In [44]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Counter, defaultdict

In [2]:
def load_data(in_file, max_example=None, relabeling=True):
    """
        load CNN / Daily Mail data from {train | dev | test}.txt
        relabeling: relabel the entities by their first occurence if it is True.
    """

    documents = []
    questions = []
    answers = []
    num_examples = 0
    with open(in_file, 'r') as f:
        while True:
            line = f.readline()
            if not line:
                break
            question = line.strip().lower()
            answer = f.readline().strip()
            document = f.readline().strip().lower()

            if relabeling:
                q_words = question.split(' ')
                d_words = document.split(' ')
                assert answer in d_words

                entity_dict = {}
                entity_id = 0
                for word in d_words + q_words:
                    if (word.startswith('@entity')) and (word not in entity_dict):
                        entity_dict[word] = '@entity' + str(entity_id)
                        entity_id += 1

                q_words = [entity_dict[w] if w in entity_dict else w for w in q_words]
                d_words = [entity_dict[w] if w in entity_dict else w for w in d_words]
                answer = entity_dict[answer]

                question = ' '.join(q_words)
                document = ' '.join(d_words)

            questions.append(question)
            answers.append(answer)
            documents.append(document)
            num_examples += 1

            f.readline()
            if (max_example is not None) and (num_examples >= max_example):
                break
                
    print('#Examples: %d' % len(documents))
    return (documents, questions, answers)

In [3]:
fin_train = 'data/cnn/train.txt'
fin_dev = 'data/cnn/dev.txt'

print('*' * 10 + ' Train Loading')
train_d, train_q, train_a = load_data(fin_train, 100, relabeling=True)
print('*' * 10 + ' Dev Loading')
dev_d, dev_q, dev_a = load_data(fin_dev, 100, relabeling=True)

********** Train Loading
#Examples: 100
********** Dev Loading
#Examples: 100


In [4]:
print(train_d[0])
print(train_q[0])
print(train_a[0])

days after two @entity0 journalists were killed in northern @entity1 , authorities rounded up dozens of suspects and a group linked to @entity2 claimed responsibility for the deaths . at least 30 suspects were seized in desert camps near the town of @entity3 and taken to the local @entity0 army base for questioning , three officials in @entity1 said . the officials did not want to be named because they are not authorized to talk to the media . @entity4 ( @entity4 ) has allegedly claimed responsibility for the killings , according to @entity5 news agency in @entity6 . @entity4 operates in northern @entity7 and the group 's statements have shown up before on the @entity8 outlet . @entity9 journalists @entity10 and @entity11 were abducted in front of the home of a member of the @entity12 rebels ' @entity13 of a @entity14 on saturday , @entity9 reported . they were found dead the same day . their bodies arrived in @entity15 on tuesday . @entity3 was one of the strongholds of the @entity16 

In [5]:
def build_dict(sentences, max_words=50000):
    """
        Build a dictionary for the words in `sentences`.
        Only the max_words ones are kept and the remaining will be mapped to <UNK>.
    """
    word_count = Counter()
    for sent in sentences:
        for w in sent.split(' '):
            word_count[w] += 1

    ls = word_count.most_common(max_words)
    print('#Words: %d -> %d' % (len(word_count), len(ls)))
    for key in ls[:5]:
        print(key)
    print('...')
    for key in ls[-5:]:
        print(key)

    # leave 0 to UNK
    # leave 1 to delimiter |||
    return {w[0]: index + 2 for (index, w) in enumerate(ls)}

In [6]:
print('Build dictionary..')
word_dict = build_dict(train_d + train_q)
entity_markers = list(set([w for w in word_dict.keys()
                          if w.startswith('@entity')] + train_a))
entity_markers = ['<unk_entity>'] + entity_markers
entity_dict = {w: index for (index, w) in enumerate(entity_markers)}
print('Entity markers: %d' % len(entity_dict))
num_labels = len(entity_dict)

Build dictionary..
#Words: 8178 -> 8178
('the', 4399)
(',', 4044)
('.', 3356)
('"', 2181)
('to', 1992)
...
('spout', 1)
('tide', 1)
('envoy', 1)
('heroic', 1)
('quirkier', 1)
Entity markers: 225


In [7]:
def gen_embeddings(word_dict, dim, in_file=None):
    """
        Generate an initial embedding matrix for `word_dict`.
        If an embedding file is not given or a word is not in the embedding file,
        a randomly initialized vector will be used.
    """

    num_words = len(word_dict) + 2
    embeddings = np.random.uniform(size=(num_words, dim))
    print('Embeddings: %d x %d' % (num_words, dim))

    if in_file is not None:
        print('Loading embedding file: %s' % in_file)
        pre_trained = 0
        for line in open(in_file).readlines():
            sp = line.split()
            assert len(sp) == dim + 1 # word + embeddings ..
            if sp[0] in word_dict:
                pre_trained += 1
                embeddings[word_dict[sp[0]]] = [float(x) for x in sp[1:]]
        print('Pre-trained: %d (%.2f%%)' %
              (pre_trained, pre_trained * 100.0 / num_words))
    return embeddings

In [10]:
embedding_size = 50
embeddings = gen_embeddings(word_dict, embedding_size, 'data/glove.6B/glove.6B.{}d.txt'.format(embedding_size))

Embeddings: 8180 x 50
Loading embedding file: data/glove.6B/glove.6B.50d.txt
Pre-trained: 7887 (96.42%)


In [67]:
class Net(nn.Module):

    def __init__(self, word_dict, embeddings, embedding_dim, hidden_dim):
        super(Net, self).__init__()
        
        self.word_dict = word_dict
        self.embeddings = embeddings
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim // 2 * 2

        self.d_gru = nn.GRU(embedding_dim, self.hidden_dim // 2,
                            num_layers=1, bidirectional=True)
        self.q_gru = nn.GRU(embedding_dim, self.hidden_dim // 2,
                            num_layers=1, bidirectional=True)

    def init_hidden(self):
        # Variable(num_layers*num_directions, minibatch_size, hidden_dim)
        return Variable(torch.randn(2, 1, self.hidden_dim // 2))
    
    def forward(self, d, q):
        d_words = d.split()
        q_words = q.split()
        d_idx = [self.word_dict[dw] for dw in d_words]
        q_idx = [self.word_dict[qw] for qw in q_words]
        d_emb = [self.embeddings[i] for i in d_idx] # !bug: max_words not in word_dict
        q_emb = [self.embeddings[i] for i in q_idx]
        d_emb = Variable(torch.FloatTensor(d_emb), requires_grad=True)
        q_emb = Variable(torch.FloatTensor(q_emb), requires_grad=True)
        
        d_hidden = self.init_hidden()
        q_hidden = self.init_hidden()
        d_gru_out, d_hidden = self.d_gru(d_emb.view(len(d_words), 1, -1), # (seq_len, batch, input_size)
                                         d_hidden)
        q_gru_out, q_hidden = self.q_gru(q_emb.view(len(q_words), 1, -1), # (seq_len, batch, input_size)
                                         q_hidden)
        
        d_gru_out = d_gru_out.view(len(d_words), self.hidden_dim)
        q_gru_out = q_gru_out.view(len(q_words), self.hidden_dim)
        sim = torch.mm(d_gru_out, q_hidden.view(self.hidden_dim,1))
        prob = F.softmax(sim, dim=0)
        
        prob_uniq = defaultdict(float)
        for i,p in zip(q_idx, prob):
            prob_uniq[i] += p
        return prob_uniq

In [71]:
hidden_dim = 5
net = Net(word_dict, embeddings, embedding_size, hidden_dim)
for idx, d,q,a in zip(range(len(train_d)), train_d, train_q, train_a): # !test
    prob = net(d, q)
    for k,v in prob.items():
        print('idx: {}'.format(k))
        print(v)
    break

idx: 2208
Variable containing:
1.00000e-03 *
  3.0187
[torch.FloatTensor of size 1]

idx: 2
Variable containing:
1.00000e-03 *
  4.2678
[torch.FloatTensor of size 1]

idx: 452
Variable containing:
1.00000e-03 *
  4.5723
[torch.FloatTensor of size 1]

idx: 6
Variable containing:
1.00000e-03 *
  3.4829
[torch.FloatTensor of size 1]

idx: 1512
Variable containing:
1.00000e-03 *
  5.9110
[torch.FloatTensor of size 1]

idx: 1596
Variable containing:
1.00000e-03 *
  1.9943
[torch.FloatTensor of size 1]

idx: 491
Variable containing:
1.00000e-03 *
  1.9940
[torch.FloatTensor of size 1]

idx: 2946
Variable containing:
1.00000e-03 *
  2.6689
[torch.FloatTensor of size 1]

idx: 14
Variable containing:
1.00000e-03 *
  5.5144
[torch.FloatTensor of size 1]

idx: 104
Variable containing:
1.00000e-03 *
  1.1711
[torch.FloatTensor of size 1]

idx: 59
Variable containing:
1.00000e-03 *
  7.4605
[torch.FloatTensor of size 1]

idx: 188
Variable containing:
1.00000e-03 *
  1.5242
[torch.FloatTensor of siz