In [1]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Counter

In [2]:
def load_data(in_file, max_example=None, relabeling=True):
    """
        Load CNN / Daily Mail data from {train | dev | test}.txt.

        The file is read as consecutive 4-line records: question, answer,
        document, and one trailing separator line that is discarded.
        Questions and documents are stripped and lower-cased; answers are
        only stripped.

        in_file: path of the text file to read (decoded as UTF-8).
        max_example: stop after this many examples; None reads the whole file.
        relabeling: if True, rename the '@entity*' tokens of each example to
            '@entity0', '@entity1', ... in order of first occurrence
            (document tokens first, then question tokens).

        Returns a tuple (documents, questions, answers) of equally long
        lists of strings.

        Raises AssertionError when relabeling is on and the answer is not a
        token of the document, and KeyError when that answer is a document
        token but does not start with '@entity'.
    """

    documents = []
    questions = []
    answers = []

    # Pin the encoding so decoding does not depend on the platform's locale
    # default.
    with open(in_file, 'r', encoding='utf-8') as f:
        # readline() returns '' only at end of file, so the two-argument form
        # of iter() yields the first line of every record until EOF.
        for line in iter(f.readline, ''):
            question = line.strip().lower()
            answer = f.readline().strip()
            document = f.readline().strip().lower()
            f.readline()  # skip the separator line that closes the record

            if relabeling:
                q_words = question.split(' ')
                d_words = document.split(' ')
                assert answer in d_words

                # Number the entities by first appearance; the next free id
                # is always the current size of the mapping.
                entity_dict = {}
                for word in d_words + q_words:
                    if word.startswith('@entity') and word not in entity_dict:
                        entity_dict[word] = '@entity' + str(len(entity_dict))

                question = ' '.join(entity_dict.get(w, w) for w in q_words)
                document = ' '.join(entity_dict.get(w, w) for w in d_words)
                answer = entity_dict[answer]

            questions.append(question)
            answers.append(answer)
            documents.append(document)

            if (max_example is not None) and (len(documents) >= max_example):
                break

    print('#Examples: %d' % len(documents))
    return (documents, questions, answers)

In [3]:
# Locations of the CNN splits, relative to the notebook's working directory.
fin_train, fin_dev = 'data/cnn/train.txt', 'data/cnn/dev.txt'

# Read at most 100 examples from each split, with per-example entity relabeling.
print('*' * 10 + ' Train Loading')
train_d, train_q, train_a = load_data(in_file=fin_train, max_example=100, relabeling=True)

print('*' * 10 + ' Dev Loading')
dev_d, dev_q, dev_a = load_data(in_file=fin_dev, max_example=100, relabeling=True)

********** Train Loading
#Examples: 100
********** Dev Loading
#Examples: 100


In [4]:
# Show the first training example: document, then question, then answer.
for example_part in (train_d[0], train_q[0], train_a[0]):
    print(example_part)

days after two @entity0 journalists were killed in northern @entity1 , authorities rounded up dozens of suspects and a group linked to @entity2 claimed responsibility for the deaths . at least 30 suspects were seized in desert camps near the town of @entity3 and taken to the local @entity0 army base for questioning , three officials in @entity1 said . the officials did not want to be named because they are not authorized to talk to the media . @entity4 ( @entity4 ) has allegedly claimed responsibility for the killings , according to @entity5 news agency in @entity6 . @entity4 operates in northern @entity7 and the group 's statements have shown up before on the @entity8 outlet . @entity9 journalists @entity10 and @entity11 were abducted in front of the home of a member of the @entity12 rebels ' @entity13 of a @entity14 on saturday , @entity9 reported . they were found dead the same day . their bodies arrived in @entity15 on tuesday . @entity3 was one of the strongholds of the @entity16 

In [5]:
def build_dict(sentences, max_words=50000):
    """
        Map the most frequent words of `sentences` to integer ids.

        Each sentence is split on single spaces and the tokens are counted.
        The `max_words` most common tokens receive ids 2, 3, 4, ... in
        decreasing order of frequency; any other token is simply absent
        from the returned dict.

        A short summary (vocabulary size before -> after the cut, the five
        most and the five least frequent kept entries) is printed.
    """
    counts = Counter(token
                     for sentence in sentences
                     for token in sentence.split(' '))
    kept = counts.most_common(max_words)

    print('#Words: %d -> %d' % (len(counts), len(kept)))
    for entry in kept[:5]:
        print(entry)
    print('...')
    for entry in kept[-5:]:
        print(entry)

    # Ids 0 and 1 are reserved (0 for UNK, 1 for the delimiter |||),
    # so real words are numbered from 2.
    return {word: word_id for word_id, (word, _) in enumerate(kept, start=2)}

In [6]:
print('Build dictionary..')
word_dict = build_dict(train_d + train_q)

# Candidate answer labels: every '@entity*' token in the vocabulary plus every
# training answer. The set removes duplicates; sorting makes the label ids
# reproducible, because the iteration order of a set of strings can change
# between interpreter runs (string hash randomization).
entity_markers = sorted(set([w for w in word_dict.keys()
                             if w.startswith('@entity')] + train_a))
# Label 0 is a catch-all marker placed ahead of all real entity labels.
entity_markers = ['<unk_entity>'] + entity_markers
entity_dict = {w: index for (index, w) in enumerate(entity_markers)}
print('Entity markers: %d' % len(entity_dict))
num_labels = len(entity_dict)

Build dictionary..
#Words: 8178 -> 8178
('the', 4399)
(',', 4044)
('.', 3356)
('"', 2181)
('to', 1992)
...
('spout', 1)
('tide', 1)
('envoy', 1)
('heroic', 1)
('quirkier', 1)
Entity markers: 225
