In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy as np
import os
from bAbI_data_utils import bAbI_data_load, data_loader, pad_to_batch, pad_to_fact
DATA_PATH = os.environ['DATA_PATH']

USE_CUDA = torch.cuda.is_available()

## 데이터 로드 

bAbI dataset(https://research.fb.com/downloads/babi/)

In [2]:
train_data, word2index = bAbI_data_load(DATA_PATH+"bAbI/en-10k/qa1_single-supporting-fact_train.txt")

Start to data loading...


In [3]:
train_loader = data_loader(train_data,batch_size=32,shuffle=True)

In [4]:
for batch in train_loader:
    break

In [5]:
len(batch)

32

## 모델링 

In [6]:
class DMN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, start_index, dropout_p=0.1):
        super(DMN, self).__init__()
        
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0)
        self.input_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.question_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        
        self.gate = nn.Sequential(
                            nn.Linear(hidden_size * 4, hidden_size),
                            nn.Tanh(),
                            nn.Linear(hidden_size, 1),
                            nn.Sigmoid()
                        )
        
        self.attention_grucell =  nn.GRUCell(hidden_size, hidden_size)
        self.memory_grucell = nn.GRUCell(hidden_size, hidden_size)
        self.answer_grucell = nn.GRUCell(hidden_size * 2, hidden_size)
        self.answer_fc = nn.Linear(hidden_size, output_size)
        
        self.start_index = start_index
        self.dropout = nn.Dropout(dropout_p)
        
    def init_hidden(self, batch_size):
        hidden = Variable(torch.zeros(1, batch_size, self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden
    
    def init_start_decode(self,batch_size):
        start = Variable(torch.LongTensor([[self.start_index] * batch_size])).transpose(0, 1)
        return start.cuda() if USE_CUDA else start
        
    def init_weight(self):
        nn.init.xavier_uniform(self.embed.state_dict()['weight'])
        
        for name, param in self.input_gru.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.question_gru.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.gate.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.attention_grucell.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.memory_grucell.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.answer_grucell.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        
        nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])
        self.answer_fc.bias.data.fill_(0)
        
    def forward(self, facts, fact_masks, questions, question_masks, num_decode, episodes=3):
        """
        facts : (B,T_C,T_I) / LongTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)
        fact_masks : (B,T_C,T_I) / ByteTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)
        questions : (B,T_Q) / LongTensor # batch_size, question_length
        question_masks : (B,T_Q) / ByteTensor # batch_size, question_length
        """
        batch_size = len(facts)
        # Input Module
        C = [] # encoded facts
        for fact, fact_mask in zip(facts, fact_masks):
            embeds = self.embed(fact)
            embeds = self.dropout(embeds)
            hidden = self.init_hidden(fact.size(0))
            outputs, hidden = self.input_gru(embeds, hidden)
            real_hidden = []

            for i, o in enumerate(outputs): # B,T,D
                real_length = fact_mask[i].data.tolist().count(0) 
                real_hidden.append(o[real_length - 1])

            C.append(torch.cat(real_hidden).view(fact.size(0), -1).unsqueeze(0))
        
        encoded_facts = torch.cat(C) # B,T_C,D
        
        # Question Module
        embeds = self.embed(questions)
        embeds = self.dropout(embeds)
        hidden = self.init_hidden(batch_size)
        outputs, hidden = self.question_gru(embeds, hidden)
        
        if isinstance(question_masks, torch.autograd.variable.Variable):
            real_question = []
            for i, o in enumerate(outputs): # B,T,D
                real_length = question_masks[i].data.tolist().count(0) 
                real_question.append(o[real_length - 1])
            encoded_question = torch.cat(real_question).view(questions.size(0), -1) # B,D
        else: # for inference mode
            encoded_question = hidden.squeeze(0) # B,D
            
        # Episodic Memory Module
        memory = encoded_question
        T_C = encoded_facts.size(1)
        for i in range(episodes):
            hidden = self.init_hidden(batch_size).squeeze(0) # B,D
            for t in range(T_C):
                z = torch.cat([
                                    encoded_facts.transpose(0, 1)[t] * encoded_question, # B,D , element-wise product
                                    encoded_facts.transpose(0, 1)[t] * memory, # B,D , element-wise product
                                    torch.abs(encoded_facts.transpose(0,1)[t] - encoded_question), # B,D
                                    torch.abs(encoded_facts.transpose(0,1)[t] - memory) # B,D
                                ], 1)
                g_t = self.gate(z) # B,1 scalar
                hidden = g_t * self.attention_grucell(encoded_facts.transpose(0, 1)[t], hidden) + (1 - g_t) * hidden
                
            e = hidden
            memory = self.memory_grucell(e, memory)
        
        # Answer Module
        answer_hidden = memory
        start_decode = self.init_start_decode(batch_size)
        y_t_1 = self.embed(start_decode).squeeze(1) # B,D
        
        decodes = []
        for t in range(num_decode):
            answer_hidden = self.answer_grucell(torch.cat([y_t_1, encoded_question], 1), answer_hidden)
            decodes.append(self.answer_fc(answer_hidden),1)
        return torch.cat(decodes, 1).view(batch_size * num_decode, -1)

## 트레이닝 

In [11]:
HIDDEN_SIZE = 80
BATCH_SIZE = 16
LR = 0.001
EPOCH = 50
NUM_EPISODE = 3
EARLY_STOPPING = False

In [12]:
model = DMN(len(word2index), HIDDEN_SIZE, len(word2index),word2index['<s>'])
model.init_weight()
if USE_CUDA:
    model = model.cuda()

loss_function = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LR)

In [13]:
model.train()
for epoch in range(EPOCH):
    losses = []
    if EARLY_STOPPING: 
        break
        
    for i,batch in enumerate(data_loader(train_data,BATCH_SIZE,True)):
        facts, fact_masks, questions, question_masks, answers = pad_to_batch(batch, word2index)
        
        model.zero_grad()
        pred = model(facts, fact_masks, questions, question_masks, answers.size(1), NUM_EPISODE)
        loss = loss_function(pred, answers.view(-1))
        losses.append(loss.data.tolist()[0])
        
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("[%d/%d] mean_loss : %0.2f" %(epoch, EPOCH, np.mean(losses)))
            
            if np.mean(losses) < 0.01:
                EARLY_STOPPING = True
                print("Early Stopping!")
                break
            losses = []

[0/50] mean_loss : 3.20
[0/50] mean_loss : 1.21
[0/50] mean_loss : 0.91
[0/50] mean_loss : 0.90
[0/50] mean_loss : 0.91
[0/50] mean_loss : 0.91
[0/50] mean_loss : 0.91
[1/50] mean_loss : 0.86
[1/50] mean_loss : 0.90
[1/50] mean_loss : 0.91
[1/50] mean_loss : 0.90
[1/50] mean_loss : 0.90
[1/50] mean_loss : 0.87
[1/50] mean_loss : 0.83
[2/50] mean_loss : 0.85
[2/50] mean_loss : 0.77
[2/50] mean_loss : 0.68
[2/50] mean_loss : 0.64
[2/50] mean_loss : 0.59
[2/50] mean_loss : 0.56
[2/50] mean_loss : 0.57
[3/50] mean_loss : 0.47
[3/50] mean_loss : 0.54
[3/50] mean_loss : 0.53
[3/50] mean_loss : 0.53
[3/50] mean_loss : 0.55
[3/50] mean_loss : 0.51
[3/50] mean_loss : 0.53
[4/50] mean_loss : 0.52
[4/50] mean_loss : 0.50
[4/50] mean_loss : 0.52
[4/50] mean_loss : 0.52
[4/50] mean_loss : 0.51
[4/50] mean_loss : 0.52
[4/50] mean_loss : 0.51
[5/50] mean_loss : 0.41
[5/50] mean_loss : 0.50
[5/50] mean_loss : 0.50
[5/50] mean_loss : 0.51
[5/50] mean_loss : 0.51
[5/50] mean_loss : 0.50
[5/50] mean_loss

## 테스트 

In [14]:
test_data, word2index = bAbI_data_load(DATA_PATH+"bAbI/en-10k/qa1_single-supporting-fact_train.txt",word2index)

Start to data loading...


### 정량적 테스트 : Accruacy 

In [21]:
accuracy=0
model.eval()
for t in test_data:
    fact, fact_mask = pad_to_fact(t[0], word2index)
    question = t[1]
    question_mask = Variable(torch.ByteTensor([0] * t[1].size(1))).unsqueeze(0)
    answer = t[2].squeeze(0)
    
    model.zero_grad()
    pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)
    if pred.max(1)[1].data.tolist() == answer.data.tolist():
        accuracy += 1

print(accuracy/len(test_data) * 100)

100.0


### 정성적 테스트 

In [24]:
index2word = {v:k for k,v in word2index.items()}

In [25]:
t = random.choice(test_data)
fact, fact_mask = pad_to_fact(t[0],word2index)
question = t[1]
question_mask = Variable(torch.ByteTensor([0] * t[1].size(1))).unsqueeze(0)
answer = t[2].squeeze(0)

model.zero_grad()
pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)

print("Facts : ")
print('\n'.join([' '.join(list(map(lambda x: index2word[x],f))) for f in fact.data.tolist()]))
print("")
print("Question : ",' '.join(list(map(lambda x: index2word[x], question.data.tolist()[0]))))
print("")
print("Answer : ",' '.join(list(map(lambda x: index2word[x], answer.data.tolist()))))
print("Prediction : ",' '.join(list(map(lambda x: index2word[x], pred.max(1)[1].data.tolist()))))

Facts : 
John travelled to the bedroom </s> <pad>
Mary travelled to the garden </s> <pad>
Sandra went back to the office </s>
John journeyed to the office </s> <pad>
Mary journeyed to the kitchen </s> <pad>
John went back to the kitchen </s>
Daniel went back to the hallway </s>
Sandra went back to the bedroom </s>

Question :  Where is Daniel ?

Answer :  hallway </s>
Prediction :  hallway </s>
