# STATIC GRAPH

In [29]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable


## DATA

In [2]:
import data_utils
metadata, idx_q, idx_a = data_utils.load_data('../data/')

In [3]:
# add special symbol
i2w = metadata['idx2w'] + ['GO']
w2i = metadata['w2idx']
w2i['GO'] = len(i2w)-1

## Parameters

In [14]:
B = 1000
L = len(idx_q[0])
vocab_size = len(i2w)
enc_hdim = 250
dec_hdim = enc_hdim

## Graph

In [167]:
class Model(nn.Module):
    def __init__(self, batch_size, vocab_size, hidden_size, input_size):
        super(Model, self).__init__()
        
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.input_size  = input_size
        
        self.encode = nn.LSTMCell(hidden_size, hidden_size)
        self.decode = nn.LSTMCell(hidden_size, hidden_size)
        
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.project = nn.Linear(hidden_size, vocab_size)
        
        self.printsize = True
        
    def initial_state(self):
        state = torch.zeros([self.batch_size, self.hidden_size])
        return Variable(state)
    
    def psize(self, name, tensor):
        if self.printsize:
            print(name, tensor.size())
    
    def forward(self, enc_inputs, dec_outputs):
        self.psize('enc_inputs', enc_inputs)
        self.psize('dec_outputs', dec_outputs)

        enc_embeddings = self.embed(enc_inputs)                                 #BxLXH <- BxL
        dec_embeddings = self.embed(dec_outputs)                                #BXLXH <- BxL
        
        self.psize('enc_embeddings', enc_embeddings)
        self.psize('dec_embeddings', dec_embeddings)
        
        hidden = cell_state = self.initial_state()                              #BxH
        for i in range(enc_embeddings.size()[1]):
            hidden, cell_state = self.encode(enc_embeddings[:,i], (hidden, cell_state))
        
        self.psize('hidden', hidden)
        self.psize('cell_state', cell_state)
        
        predicted_outputs = []
        first_input = torch.LongTensor([w2i['GO']] * self.batch_size).view(self.batch_size, self.input_size)
        first_input = Variable(first_input).squeeze(1)
        self.psize('first_input', first_input)
        first_input = self.embed(first_input)
        self.psize('first_input', first_input)
        hidden, cell_state = self.decode(first_input, (hidden, cell_state))
        predicted_outputs.append(hidden)
        for i in range(dec_embeddings.size()[1] - 1):
            hidden, cell_state = self.decode(dec_embeddings[:,i+1], (hidden, cell_state))
            predicted_outputs.append(hidden)
            
        predicted_outputs = torch.stack(predicted_outputs)
        predicted_outputs = predicted_outputs.view(
                            self.batch_size * dec_outputs.size()[1], 
                            self.hidden_size)
        self.psize('predicted_outputs', predicted_outputs)
        outputs = self.project(predicted_outputs)
        self.psize('outputs', outputs)
        outputs = outputs.view(self.batch_size, dec_outputs.size()[1], self.vocab_size)
        self.psize('outputs', outputs)

        outputs = F.log_softmax(outputs).max(2)[1].squeeze(2)
        self.psize('outputs', outputs)

        return outputs
        

# TRAINING

In [153]:
def train(epochs, model, question_ids, answer_ids):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.1)
    for epoch in range(epochs+1):
        avg_loss = 0
        for batch in range(len(question_ids)//B):
            l, r = batch * B, (batch + 1) * B
            data = Variable(torch.from_numpy(question_ids[l:r]).long())
            target = Variable(torch.from_numpy(answer_ids[l:r]).long())
            
            optimizer.zero_grad()
            logits = model(data, target)
            model.psize('data', data)
            model.psize('target', target)
            model.psize('logits', logits)
            loss = F.nll_loss(logits, target)
            loss.backward()
            optimizer.step()
            
            avg_loss += loss.data[0]
            model.printsize = False
            
            if batch % 20:
                print('{}.{} - {}'.format(epoch, batch, avg_loss/20))
                avg_loss = 0

In [168]:
model = Model(B, vocab_size, enc_hdim, 1)

In [169]:
train(10, model, idx_q, idx_a)

enc_inputs torch.Size([1000, 20])
dec_outputs torch.Size([1000, 20])
enc_embeddings torch.Size([1000, 20, 250])
dec_embeddings torch.Size([1000, 20, 250])
hidden torch.Size([1000, 250])
cell_state torch.Size([1000, 250])
first_input torch.Size([1000])
first_input torch.Size([1000, 250])
predicted_outputs torch.Size([20000, 250])
outputs torch.Size([20000, 8003])
outputs torch.Size([1000, 20, 8003])
outputs torch.Size([1000, 20])
data torch.Size([1000, 20])
target torch.Size([1000, 20])
logits torch.Size([1000, 20])


KeyError: <class 'torch.LongTensor'>

### Training parameters

In [16]:
num_epochs = 10

## Begin

In [17]:
for i in range(num_epochs):
    avg_loss = 0.
    for j in range(len(idx_q)//B):
        _, loss_v = sess.run([train_op, loss], feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            targets : idx_a[j*B:(j+1)*B]
        })
        avg_loss += loss_v
        if j and j%20==0:
            print('{}.{} : {}'.format(i,j,avg_loss/20))
            avg_loss = 0.

0.20 : 4.374236321449279
0.40 : 2.9668627977371216
0.60 : 2.504377770423889
0.80 : 1.9982671082019805
0.100 : 1.444276547431946
0.120 : 1.019878402352333
1.20 : 0.7292188316583633
1.40 : 0.5727434158325195
1.60 : 0.49078361988067626
1.80 : 0.4375475347042084
1.100 : 0.4037304982542992
1.120 : 0.3748603016138077
2.20 : 0.35517664849758146
2.40 : 0.32739430814981463
2.60 : 0.3153625696897507
2.80 : 0.3068651154637337
2.100 : 0.30491813123226164
2.120 : 0.3002783715724945
3.20 : 0.3121029630303383
3.40 : 0.2952332764863968
3.60 : 0.29264847487211226
3.80 : 0.2903579115867615
3.100 : 0.2914220854640007
3.120 : 0.29028379172086716
4.20 : 0.3058256059885025
4.40 : 0.2903655216097832
4.60 : 0.28862053602933885
4.80 : 0.2873972564935684
4.100 : 0.2889208883047104
4.120 : 0.2882304400205612
5.20 : 0.3042555972933769
5.40 : 0.2887479826807976
5.60 : 0.2871220126748085
5.80 : 0.2859190031886101
5.100 : 0.28765359073877333
5.120 : 0.28710009157657623
6.20 : 0.3032954305410385
6.40 : 0.287793479859

In [19]:
j = 7 
pred_v = sess.run(prediction, feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            targets : idx_a[j*B:(j+1)*B]
        })

In [20]:
pred_v[0], idx_a[j*B:(j+1)*B][0]

(array([  4,  45, 754,  11, 694,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0]),
 array([1989,   45,  754,   11,  694,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0], dtype=int32))

In [21]:
def arr2sent(arr):
    return ' '.join([i2w[item] for item in arr])

In [26]:
arr2sent(pred_v[11])

'i as if her supporters just accept shes a crook _ _ _ _ _ _ _ _ _ _'

In [27]:
arr2sent( idx_a[j*B:(j+1)*B][11])

'its as if her supporters just accept shes a crook _ _ _ _ _ _ _ _ _ _'