# STATIC GRAPH

In [55]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable


## DATA

In [56]:
import data_utils
metadata, idx_q, idx_a = data_utils.load_data('../data/')

In [57]:
# add special symbol
i2w = metadata['idx2w'] + ['GO']
w2i = metadata['w2idx']
w2i['GO'] = len(i2w)-1

## Parameters

In [282]:
batch_size = 128
L = len(idx_q[0])
vocab_size = len(i2w)
hidden_size = 256

In [84]:
class Config:
    pass

config = Config()
config.printsize = True

In [169]:
len(idx_q)

267518


## Graph

In [444]:
def initial_state(batch_size, hidden_size):
    state = torch.zeros([batch_size, hidden_size])
    return Variable(state.cuda())

def psize(name, variable):
    if config.printsize:
        print(name, variable.size(), type(variable.data))
        
class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(Encoder, self).__init__()
        
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
                
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.encode = nn.LSTMCell(hidden_size, hidden_size)
            
    def forward(self, enc_inputs, hidden, batch_size):
        input_length = enc_inputs.size()[0]
        psize('enc_inputs', enc_inputs)
        enc_embeddings = self.embed(enc_inputs)
        psize('enc_embeddings', enc_embeddings)
        enc_embeddings = enc_embeddings.view(input_length, 
                                            batch_size, 
                                            hidden_size)            #LxBxH       
                
        psize('enc_embeddings', enc_embeddings)        
        hidden, cell_state = hidden
        for i in range(enc_embeddings.size()[0]):
            hidden, cell_state = self.encode(enc_embeddings[i], (hidden, cell_state))
            
        return hidden, cell_state
    
class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(Decoder, self).__init__()
        
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        
        self.decode = nn.LSTMCell(hidden_size, hidden_size)
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.project = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, outputs, hidden, batch_size):
        length = outputs.size()[0]
        psize('hidden', hidden[0]), psize('hidden', hidden[1])
        predicted_outputs = []
    
        dec_embeddings = self.embed(outputs).view(length,
                                                 batch_size,
                                                 hidden_size)           #LxBxH
    
        GO = torch.LongTensor([w2i['GO']] * batch_size).cuda()            
        GO = Variable(GO)
        psize('GO', GO)
        GO_emb = self.embed(GO)
        psize('GO_emd', GO_emb)
        
        hidden, cell_state = self.decode(GO_emb, hidden)
        predicted_outputs.append(hidden)
        for i in range(length - 1):
            import random
            if random.random() > 0:
                dec_input = dec_embeddings[i+1]
            else:
                dec_input = hidden
            hidden, cell_state = self.decode(dec_input, (hidden, cell_state))
            predicted_outputs.append(hidden)
            
        predicted_outputs = torch.stack(predicted_outputs).squeeze(1)
        psize('predicted_outputs', predicted_outputs)
              
        predicted_outputs = self.project(predicted_outputs.view(length*batch_size, hidden_size))
        psize('predicted_outputs', predicted_outputs)
        predicted_outputs = predicted_outputs.view(length, batch_size, vocab_size)
        psize('predicted_outputs', predicted_outputs)

        return predicted_outputs
    
    def predict(self, outputs, hidden, batch_size):
        length = outputs.size()[0]
        psize('hidden', hidden[0]), psize('hidden', hidden[1])
        predicted_outputs = []
    
        dec_embeddings = self.embed(outputs).view(length,
                                                 batch_size,
                                                 hidden_size)           #LxBxH
    
        GO = torch.LongTensor([w2i['GO']] * batch_size).cuda()            
        GO = Variable(GO)
        psize('GO', GO)
        GO_emb = self.embed(GO)
        psize('GO_emd', GO_emb)
        
        hidden, cell_state = self.decode(GO_emb, hidden)
        predicted_outputs.append(hidden)
        for i in range(length - 1):
            dec_input = hidden
            hidden, cell_state = self.decode(dec_input, (hidden, cell_state))
            predicted_outputs.append(hidden)
            
        predicted_outputs = torch.stack(predicted_outputs).squeeze(1)
        psize('predicted_outputs', predicted_outputs)
        
        predicted_outputs = self.project(predicted_outputs.view(length*batch_size, 
                                                                hidden_size))
        psize('predicted_outputs', predicted_outputs)
        predicted_outputs = predicted_outputs.view(length, batch_size, vocab_size)

        return predicted_outputs
    

# TRAINING

In [442]:
from pprint import pprint
from tqdm import tqdm
def train_epochs(epochs, encoder, decoder, eoptim, doptim, criterion, print_every=1):
    model.train()
    losses = []
    config.printsize = True

    for epoch in tqdm(range(epochs+1)):
        loss = train(encoder, decoder, eoptim, doptim, criterion, idx_q[:30000], idx_a[:30000],
                    print_every=print_every*100)    
        if epoch % print_every == 0:
            losses.append(loss)
            print('{} - loss: {}'.format(epoch, loss))

        
def train(encoder, decoder, eoptim, doptim, criterion, question_ids, answer_ids, print_every=100):
    input_length = len(question_ids[0])
    for batch_index in range(len(idx_q)):
        l,r = batch_index * batch_size, (batch_index + 1) * batch_size
        
        question_id, answer_id = question_ids[l:r], answer_ids[l:r]
        _batch_size = len(question_id)
        if _batch_size != batch_size:
            print('breaking because batch sizes do not match')
            break

        data = Variable(torch.from_numpy(question_id).long().cuda().t())
        target = Variable(torch.from_numpy(answer_id).long().cuda().t())

        eoptim.zero_grad(), doptim.zero_grad()    
        initial_hidden = initial_state(batch_size, hidden_size).cuda(), initial_state(batch_size, hidden_size).cuda()
        
        encoder_output = encoder(data, initial_hidden, _batch_size)
        decoder_output = decoder(target, encoder_output, _batch_size)
        loss = 0
        for i in range(input_length):
            logits = F.log_softmax(decoder_output[i])
            loss += criterion(logits, target[i])    
            
        loss.backward()
        eoptim.step(), doptim.step()
        config.printsize = False
        
        if batch_index % print_every == 0:
            print('\t{} - loss: {}'.format(batch_index, loss.data[0]))
        
    return loss.data[0]

In [407]:
encoder = Encoder(vocab_size, hidden_size)
decoder = Decoder(vocab_size, hidden_size)

encoder.cuda()
decoder.cuda()

criterion = nn.NLLLoss()

eoptim = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.1)
doptim = optim.SGD(decoder.parameters(), lr=0.1, momentum=0.1)

In [434]:
train_epochs(10, encoder, decoder, eoptim, doptim, criterion,)

  0%|          | 0/11 [00:00<?, ?it/s]

enc_inputs torch.Size([21, 128]) <class 'torch.cuda.LongTensor'>
enc_embeddings torch.Size([21, 128, 256]) <class 'torch.cuda.FloatTensor'>
enc_embeddings torch.Size([21, 128, 256]) <class 'torch.cuda.FloatTensor'>
hidden torch.Size([128, 256]) <class 'torch.cuda.FloatTensor'>
hidden torch.Size([128, 256]) <class 'torch.cuda.FloatTensor'>
GO torch.Size([128]) <class 'torch.cuda.LongTensor'>
GO_emd torch.Size([128, 256]) <class 'torch.cuda.FloatTensor'>
predicted_outputs torch.Size([21, 128, 256]) <class 'torch.cuda.FloatTensor'>
predicted_outputs torch.Size([2688, 6005]) <class 'torch.cuda.FloatTensor'>
predicted_outputs torch.Size([21, 128, 6005]) <class 'torch.cuda.FloatTensor'>
	0 - loss: 1.3411242961883545
	100 - loss: 1.3380681276321411
	200 - loss: 1.3729528188705444


  9%|▉         | 1/11 [00:21<03:33, 21.34s/it]

breaking because batch sizes do not match
0 - loss: 1.544586420059204
	0 - loss: 1.3162250518798828
	100 - loss: 1.3365718126296997
	200 - loss: 1.1995617151260376


 18%|█▊        | 2/11 [00:43<03:13, 21.54s/it]

breaking because batch sizes do not match
1 - loss: 1.3850281238555908
	0 - loss: 1.209272861480713
	100 - loss: 1.1428279876708984
	200 - loss: 1.070444107055664


 27%|██▋       | 3/11 [01:05<02:53, 21.68s/it]

breaking because batch sizes do not match
2 - loss: 1.2881885766983032
	0 - loss: 1.199554204940796
	100 - loss: 1.1904761791229248
	200 - loss: 1.1714420318603516


 36%|███▋      | 4/11 [01:27<02:32, 21.78s/it]

breaking because batch sizes do not match
3 - loss: 1.2804590463638306
	0 - loss: 1.099460244178772
	100 - loss: 1.082923173904419
	200 - loss: 1.1439708471298218


 45%|████▌     | 5/11 [01:49<02:11, 21.88s/it]

breaking because batch sizes do not match
4 - loss: 1.1584521532058716
	0 - loss: 1.008824348449707
	100 - loss: 1.0420303344726562
	200 - loss: 1.0742990970611572


 55%|█████▍    | 6/11 [02:11<01:49, 21.91s/it]

breaking because batch sizes do not match
5 - loss: 1.2309650182724
	0 - loss: 1.077671766281128
	100 - loss: 1.1084177494049072
	200 - loss: 1.0507218837738037


 64%|██████▎   | 7/11 [02:33<01:27, 21.90s/it]

breaking because batch sizes do not match
6 - loss: 1.37307608127594
	0 - loss: 1.2246034145355225
	100 - loss: 1.0969810485839844
	200 - loss: 1.0633491277694702


 73%|███████▎  | 8/11 [02:55<01:05, 21.91s/it]

breaking because batch sizes do not match
7 - loss: 1.16168212890625
	0 - loss: 1.0008502006530762
	100 - loss: 0.998125433921814
	200 - loss: 0.9589577913284302


 82%|████████▏ | 9/11 [03:17<00:43, 21.93s/it]

breaking because batch sizes do not match
8 - loss: 1.0832334756851196
	0 - loss: 0.9744068384170532
	100 - loss: 1.0879734754562378
	200 - loss: 1.1067527532577515


 91%|█████████ | 10/11 [03:39<00:21, 21.93s/it]

breaking because batch sizes do not match
9 - loss: 1.1511907577514648
	0 - loss: 1.045896291732788
	100 - loss: 0.9164615869522095
	200 - loss: 0.9757237434387207


100%|██████████| 11/11 [04:01<00:00, 21.96s/it]

breaking because batch sizes do not match
10 - loss: 1.0804495811462402





In [433]:
torch.save(encoder.state_dict(), 'graph.pytorch.encoder.pth')
torch.save(decoder.state_dict(), 'graph.pytorch.decoder.pth')

## Test

In [446]:
encoder_test = Encoder(vocab_size, hidden_size)
decoder_test = Decoder(vocab_size, hidden_size)
encoder_test.cuda()
decoder_test.cuda()
encoder_test.load_state_dict(torch.load('graph.pytorch.encoder.pth'))
decoder_test.load_state_dict(torch.load('graph.pytorch.decoder.pth'))

In [447]:
batch = 0
l, r = batch * B, (batch + 1) * B
test_q, test_a = idx_q[0], idx_a[0]

encoder_test.eval()
decoder_test.eval()

test_q = Variable(torch.from_numpy(test_q).long().cuda())
test_a = Variable(torch.from_numpy(test_a).long().cuda())

config.printsize = True
batch_size = 1
hidden = initial_state(batch_size, hidden_size).cuda(), initial_state(batch_size, hidden_size).cuda()
predictions = decoder_test.predict(test_a, encoder_test(test_q, hidden, 1), 1)
predictions = F.log_softmax(predictions).max(1)[1].squeeze(1)


enc_inputs torch.Size([21]) <class 'torch.cuda.LongTensor'>
enc_embeddings torch.Size([21, 256]) <class 'torch.cuda.FloatTensor'>
enc_embeddings torch.Size([21, 1, 256]) <class 'torch.cuda.FloatTensor'>
hidden torch.Size([1, 256]) <class 'torch.cuda.FloatTensor'>
hidden torch.Size([1, 256]) <class 'torch.cuda.FloatTensor'>
GO torch.Size([1]) <class 'torch.cuda.LongTensor'>
GO_emd torch.Size([1, 256]) <class 'torch.cuda.FloatTensor'>
predicted_outputs torch.Size([21, 256]) <class 'torch.cuda.FloatTensor'>
predicted_outputs torch.Size([21, 6005]) <class 'torch.cuda.FloatTensor'>


In [216]:
def arr2sent(arr):
    return ' '.join([i2w[item] for item in arr])

In [448]:
print(predictions)
print(arr2sent(predictions.cpu().data.numpy()))
print(arr2sent(test_a.cpu().data.numpy()))

Variable containing:
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
       ...          ⋱          ...       
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
[torch.cuda.LongTensor of size 21x6005 (GPU 0)]



TypeError: only integer scalar arrays can be converted to a scalar index

In [394]:
x = torch.Tensor([[1,2, 2, 3],[1,2,3,4]])
_ = Variable
print(F.log_softmax(_(x)).sum(0))

Variable containing:
-6.0667 -4.0667 -3.0667 -1.0667
[torch.FloatTensor of size 1x4]

