import torch
from torch.autograd import Variable
import random
from model import DMN
from copy import deepcopy
flatten = lambda l: [item for sublist in l for item in sublist]
random.seed(1024)
USE_CUDA = torch.cuda.is_available()
gpus = [0]
if USE_CUDA:  # calling set_device on a CPU-only machine would raise
    torch.cuda.set_device(gpus[0])
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor
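
# The loader below expects the plain-text bAbI task format, e.g. (illustrative
# sample lines, not taken from this repository's data files):
#   1 Mary moved to the bathroom.
#   2 John went to the hallway.
#   3 Where is Mary?\tbathroom\t1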
def bAbI_data_load(path):
    """Parse one bAbI task file into a list of [facts, question, answer] triples."""
    try:
        data = open(path).readlines()
    except IOError:
        print("Such a file does not exist at {}".format(path))
        return None

    data = [d[:-1] for d in data]  # strip the trailing newline from each line
    data_p = []
    fact = []
    try:
        for d in data:
            index = d.split(' ')[0]
            if index == '1':  # the line counter resets to 1 at the start of each story
                fact = []
            if '?' in d:  # question line: "<id> <question>\t<answer>\t<supporting fact ids>"
                temp = d.split('\t')
                q = temp[0].strip().replace('?', '').split(' ')[1:] + ['?']
                a = temp[1].split() + ['</s>']
                stemp = deepcopy(fact)  # snapshot of the facts seen so far in this story
                data_p.append([stemp, q, a])
            else:  # fact line: "<id> <sentence>."
                tokens = d.replace('.', '').split(' ')[1:] + ['</s>']
                fact.append(tokens)
    except Exception:
        print("Please check that the data is in the bAbI format")
        return None
    return data_p
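
# Each returned entry is [facts, question, answer], e.g. (hypothetical):
# [[['mary', 'moved', 'to', 'the', 'bathroom', '</s>'], ...],
#  ['where', 'is', 'mary', '?'],
#  ['bathroom', '</s>']]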
def prepare_sequence(seq, to_index):
    """Map a token list to a LongTensor Variable, falling back to <UNK> for unseen words."""
    idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index["<UNK>"], seq))
    return Variable(LongTensor(idxs))
def prepare_word2index(train_data):
    """Build the vocabulary from the training facts, questions and answers."""
    fact, q, a = list(zip(*train_data))
    vocab = list(set(flatten(flatten(fact)) + flatten(q) + flatten(a)))
    word2index = {'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3}
    for vo in vocab:
        if word2index.get(vo) is None:
            word2index[vo] = len(word2index)
    index2word = {v: k for k, v in word2index.items()}
    return word2index, index2word
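
# word2index starts from the four special tokens, so entries map like
# {'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3, 'mary': 4, ...} (the order of
# the non-special words depends on set() iteration order).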
def pad_to_fact(fact, x_to_ix):  # this is for inference
    """Pad every fact to the longest one and build the matching <PAD> mask."""
    max_x = max([s.size(1) for s in fact])
    x_p = []
    for i in range(len(fact)):
        if fact[i].size(1) < max_x:
            pad = Variable(LongTensor([x_to_ix['<PAD>']] * (max_x - fact[i].size(1)))).view(1, -1)
            x_p.append(torch.cat([fact[i], pad], 1))
        else:
            x_p.append(fact[i])

    fact = torch.cat(x_p)
    # the mask is 1 wherever the token is <PAD> (index 0)
    fact_mask = torch.cat(
        [Variable(ByteTensor(tuple(map(lambda s: s == 0, t.data))), volatile=False) for t in fact]
    ).view(fact.size(0), -1)
    return fact, fact_mask
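
# Sketch of the padding behaviour (hypothetical shapes): facts of lengths
# 3, 5 and 4 come back as one (3, 5) LongTensor, with fact_mask a same-shaped
# ByteTensor that is 1 exactly at the <PAD> positions.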
if __name__ == '__main__':
    train_data = bAbI_data_load('dataset/babi/en-10k/qa5_three-arg-relations_train.txt')
    word2index, index2word = prepare_word2index(train_data)
    test_data = bAbI_data_load('dataset/babi/en-10k/qa5_three-arg-relations_test.txt')

    # Pick one random test instance and convert its tokens to index tensors.
    t = random.choice(test_data)
    for i, fact in enumerate(t[0]):
        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)
    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], word2index).view(1, -1)

    fact, fact_mask = pad_to_fact(t[0], word2index)
    question = t[1]
    # the question carries no padding, so its mask is all zeros
    question_mask = Variable(ByteTensor([0] * t[1].size(1)), volatile=False).unsqueeze(0)
    answer = t[2].squeeze(0)

    # Load the trained DMN and run one inference pass: decode answer.size(0)
    # tokens; the final argument (3) is the model's episode count.
    dmn_model = torch.load('dmn_qa')
    dmn_model.zero_grad()
    pred = dmn_model([fact], [fact_mask], question, question_mask, answer.size(0), 3)

    print("Facts : ")
    print('\n'.join([' '.join(list(map(lambda x: index2word[x], f))) for f in fact.data.tolist()]))
    print("")
    print("Question : ", ' '.join(list(map(lambda x: index2word[x], question.data.tolist()[0]))))
    print("")
    print("Answer : ", ' '.join(list(map(lambda x: index2word[x], answer.data.tolist()))))
    print("Prediction : ", ' '.join(list(map(lambda x: index2word[x], pred.max(1)[1].data.tolist()))))