# 12. 	Dynamic Memory Networks for Question Answering

* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture16-DMN-QA.pdf
* https://arxiv.org/abs/1506.07285

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np
from collections import Counter, OrderedDict
import nltk
from copy import deepcopy
import os
import re
import unicodedata
flatten = lambda l: [item for sublist in l for item in sublist]

from torch.nn.utils.rnn import PackedSequence,pack_padded_sequence

In [2]:
USE_CUDA = torch.cuda.is_available()

FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

In [3]:
def getBatch(batch_size,train_data):
    random.shuffle(train_data)
    sindex=0
    eindex=batch_size
    while eindex < len(train_data):
        batch = train_data[sindex:eindex]
        temp = eindex
        eindex = eindex+batch_size
        sindex = temp
        yield batch
    
    if eindex >= len(train_data):
        batch = train_data[sindex:]
        yield batch

In [4]:
def pad_to_batch(batch,x_to_ix,y_to_ix):
    
    sorted_batch =  sorted(batch, key=lambda b:b[0].size(1),reverse=True) # sort by len
    x,y = list(zip(*sorted_batch))
    max_x = max([s.size(1) for s in x])
    max_y = max([s.size(1) for s in y])
    x_p,y_p=[],[]
    for i in range(len(batch)):
        if x[i].size(1)<max_x:
            x_p.append(torch.cat([x[i],Variable(LongTensor([x_to_ix['<PAD>']]*(max_x-x[i].size(1)))).view(1,-1)],1))
        else:
            x_p.append(x[i])
        if y[i].size(1)<max_y:
            y_p.append(torch.cat([y[i],Variable(LongTensor([y_to_ix['<PAD>']]*(max_y-y[i].size(1)))).view(1,-1)],1))
        else:
            y_p.append(y[i])
        
    input_var = torch.cat(x_p)
    target_var = torch.cat(y_p)
    input_len = [list(map(lambda s: s ==0, t.data)).count(False) for t in input_var]
    target_len = [list(map(lambda s: s ==0, t.data)).count(False) for t in target_var]
    
    return input_var, target_var, input_len, target_len

In [5]:
def prepare_sequence(seq, to_index):
    idxs = list(map(lambda w: to_index[w] if w in to_index.keys() else to_index["<UNK>"], seq))
    return Variable(LongTensor(idxs))

### Data load and Preprocessing 

In [6]:
DIR_PATH='../dataset/corpus/bAbI/en-10k/'
flist = os.listdir(DIR_PATH)

In [46]:
data=[]
for f in flist:
    #if f.endswith('train.txt'):
        fname = DIR_PATH+f
        #print(fname)
        temp = open(fname,'r',encoding='utf-8').readlines()
        temp = [t[:-1] for t in temp]
        data.extend(temp)

In [7]:
data = open('../dataset/corpus/bAbI/en-10k/qa1_single-supporting-fact_train.txt').readlines()
data = [d[:-1] for d in data]

In [8]:
count=0
for d in data:
    if '?' in d:
        count+=1
print(count)

10000


In [54]:
train=[]
support=[]
qa=[]

In [55]:
for d in data:

    index = d.split(' ')[0]
    if index=='1':
        support=[]
        qa=[]
        
    if '?' in d:
        temp = d.split('\t')
        q = temp[0].strip().replace('?','').split(' ')[1:]+['?']
        a = temp[1].split()+['</s>']
        #f = [int(t)-1 for t in temp[2].split()]
        stemp = deepcopy(support)

        train.append([stemp,q,a])
    else:
        tokens = d.replace('.','').split(' ')[1:]+['</s>']
        support.append(tokens)

In [56]:
support ,q,a = list(zip(*train))

In [57]:
support[0]

[['Mary', 'moved', 'to', 'the', 'bathroom', '</s>'],
 ['John', 'went', 'to', 'the', 'hallway', '</s>']]

In [58]:
Counter([len(s) for s in support])

Counter({2: 2000, 4: 2000, 6: 2000, 8: 2000, 10: 2000})

In [59]:
vocab = list(set(flatten(flatten(support))+flatten(q)+flatten(a)))

In [60]:
word2index={'<PAD>':0,'</s>':1}
for vo in vocab:
    word2index[vo]=len(word2index)
index2word = {v:k for k,v in word2index.items()}

In [63]:
for t in train:
    for i,sup in enumerate(t[0]):
        t[0][i] = prepare_sequence(sup,word2index).view(1,-1)
    
    t[1] = prepare_sequence(t[1],word2index).view(1,-1)
    t[2] = prepare_sequence(t[2],word2index).view(1,-1)

In [72]:
support = torch.cat(train[12][0])

In [73]:
support

Variable containing:
  5  16  18   8   7  21
 12  13  18   8   9  21
  2  13  18   8  17  21
 12  13  18   8  14  21
 12  13  18   8   9  21
  2  13  18   8  14  21
[torch.LongTensor of size 6x6]

### Modeling 

In [123]:
class DMN(nn.Module):
    def __init__(self, input_size,hidden_size,output_size):
        super(DMN, self).__init__()
        
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0, sparse=True)
        self.input_gru = nn.GRU(hidden_size, hidden_size,batch_first=True)
        self.question_gru = nn.GRU(hidden_size, hidden_size,batch_first=True)
        
        self.gate = nn.Sequential(
                            nn.Linear(hidden_size*4,hidden_size),
                            nn.Tanh(),
                            nn.Linear(hidden_size,1),
                            nn.Sigmoid()
                        )
        
        self.attention_gru =  nn.GRUCell(hidden_size, hidden_size)
        self.memory_gru = nn.GRU(hidden_size,hidden_size,batch_first=True)
        self.answer_gru = nn.GRU(hidden_size, hidden_size,batch_first=True)
        self.answer_fc = nn.Linear(hidden_size,output_size)
        
    def init_hidden(self,inputs):
        hidden = Variable(torch.zeros(1,inputs.size(0),self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden

    def memory_update(self,Cs,Gs):
        pass

    def forward(self,inputs,questions):
        
        C=[]
        for input_ in inputs:
            embeds = self.embed(input_)
            hidden = self.init_hidden(inputs)
            outputs,hidden = self.input_gru(embeds,hidden)
            C.append(hidden.squeeze(0)) # B x T x D
        
        embeds = self.embed(questions)
        hidden = self.init_hidden(questions)
        outputs, hidden = self.question_gru(embeds,hidden)
        
        
        
        return hidden.squeeze(0)


In [124]:
model = DMN(len(word2index),50,len(word2index))

In [125]:
model(support,question)

Variable containing:

Columns 0 to 9 
 0.4173  0.4438  0.7172 -0.5022 -0.1528 -0.2227 -0.4276 -0.2200 -0.3188 -0.3380
 0.4938  0.2621  0.7536 -0.5704 -0.4606 -0.3523 -0.2763  0.0045 -0.4024 -0.2828
 0.5056  0.4136  0.7094 -0.4888 -0.1466 -0.4241 -0.1926  0.1614 -0.3194 -0.2285
 0.5455  0.2938  0.8005 -0.4930  0.0572 -0.3729 -0.3775 -0.0722 -0.0260 -0.3010
 0.4938  0.2621  0.7536 -0.5704 -0.4606 -0.3523 -0.2763  0.0045 -0.4024 -0.2828
 0.5440  0.2921  0.7982 -0.5086  0.0824 -0.3721 -0.3756 -0.0814  0.0139 -0.2941

Columns 10 to 19 
 0.1222 -0.2928  0.2636 -0.2284  0.1185  0.1745  0.4569  0.0064 -0.4305 -0.0360
 0.0670 -0.1415  0.0685 -0.3162  0.1461 -0.1845  0.3521  0.0121 -0.3087 -0.0470
 0.2300 -0.2021  0.0709 -0.3735 -0.0517  0.1506  0.3462 -0.0615 -0.3862 -0.1503
 0.1159  0.0838  0.2641 -0.3576  0.0637  0.4873  0.4478 -0.2282 -0.3949 -0.0969
 0.0670 -0.1415  0.0685 -0.3162  0.1461 -0.1845  0.3521  0.0121 -0.3087 -0.0470
 0.1305  0.0611  0.2616 -0.3400  0.0831  0.5096  0.4586 -0.2562