In [1]:
import math
import random

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import editdistance
from transer import Dataset
from IPython.display import clear_output
import matplotlib.pyplot as plt
# Preferred compute device. NOTE(review): nothing in this notebook moves the
# models onto it yet — call `.to(device)` on them to actually use the GPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Load the training split together with its source-word and transcription
# vocabularies (each exposes `pad_idx` and supports `len()`).
data = Dataset('train.csv')
words_vocab, trans_vocab = data.words_vocab, data.trans_vocab

In [3]:
# Draw one sample batch of 32 examples, used only to inspect tensor shapes.
(words, trans_in, trans_out,
 words_lens, trans_lens) = data.get_batch(32)

In [4]:
tuple(t.size() for t in (words, trans_in, trans_out, words_lens, trans_lens))

(torch.Size([32, 13]),
 torch.Size([32, 13]),
 torch.Size([32, 13]),
 torch.Size([32]),
 torch.Size([32]))

In [5]:
# Same 32 -> 64 recurrence in two forms: a single-step cell and a full-sequence GRU.
gru_cell = nn.GRUCell(input_size=32, hidden_size=64)
gru = nn.GRU(input_size=32, hidden_size=64)

In [6]:
emb = torch.rand((12, 128, 32))  # random stand-in input: (seq_len, batch, input_size)

In [7]:
# Whole-sequence forward pass: per-step outputs plus the final hidden state.
outputs, hidden = gru(emb)
print(*(tensor.size() for tensor in (outputs, hidden)))

torch.Size([12, 128, 64]) torch.Size([1, 128, 64])


In [8]:
# A single GRUCell step on the first time slice, starting from a zero state.
emb_t = emb[0]
hidden = torch.zeros(128, 64)
print(emb_t.size(), hidden.size())
hidden = gru_cell(emb_t, hidden)

torch.Size([128, 32]) torch.Size([128, 64])


In [9]:
emb = torch.rand(12, 128, 32)   # (seq_len, batch, input_size)
hidden = torch.zeros(128, 64)   # initial state (batch, hidden_size)

# Unroll the GRU cell by hand, keeping the hidden state after every time step.
step_states = []
for i in range(len(emb)):
    hidden = gru_cell(emb[i], hidden)
    step_states.append(hidden)
outputs = torch.stack(step_states)
    
    

In [10]:
hidden.shape, outputs.shape

(torch.Size([128, 64]), torch.Size([12, 128, 64]))

In [11]:
class Encoder(nn.Module):
    """GRU encoder built from a single ``nn.GRUCell`` stepped over time.

    Embeds a batch of token ids and runs the cell along the sequence axis,
    returning only the final hidden state. Positions holding the padding
    token do not update the hidden state, so each row's result is the state
    after its last real token rather than after trailing padding.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: embedding dimension (GRU input size).
        hidden_size: GRU hidden-state dimension.
        pad_idx: padding token id, forwarded to ``nn.Embedding`` and used to
            mask the recurrence. ``None`` disables masking.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, pad_idx):
        super(Encoder, self).__init__()
        self.emb_size = emb_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        self.gru_cell = nn.GRUCell(emb_size, hidden_size)

    def forward(self, source):
        """Encode ``source`` of shape (batch, seq_len) into (batch, hidden_size)."""
        batch_size, seq_len = source.size(0), source.size(1)
        embedded = self.embedding(source)  # (batch, seq_len, emb_size)
        hidden = torch.zeros(batch_size, self.hidden_size,
                             dtype=embedded.dtype, device=source.device)

        # nn.Embedding normalises a negative padding_idx, so read it back from there.
        pad_idx = self.embedding.padding_idx

        for t in range(seq_len):
            new_hidden = self.gru_cell(embedded[:, t, :], hidden)
            if pad_idx is None:
                hidden = new_hidden
            else:
                # Keep the previous state wherever the current token is padding.
                is_token = (source[:, t] != pad_idx).unsqueeze(1)
                hidden = torch.where(is_token, new_hidden, hidden)
        return hidden
class Decoder(nn.Module):
    """Single-step GRU decoder.

    Each call embeds one token per sequence, advances the GRU cell once, and
    projects the new hidden state onto the vocabulary.

    Args:
        vocab_size: target vocabulary size (embedding rows and logit width).
        emb_size: embedding dimension (GRU input size).
        hidden_size: GRU hidden-state dimension.
        pad_idx: padding token id forwarded to ``nn.Embedding``.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, pad_idx):
        super(Decoder, self).__init__()

        self.emb_size = emb_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        self.gru_cell = nn.GRUCell(emb_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, vocab_size)

    def forward(self, target, hidden):
        """Advance one decoding step.

        Args:
            target: current input token ids, shape (batch, 1) or (batch,).
            hidden: previous hidden state, shape (batch, hidden_size).

        Returns:
            ``(logit, hidden)``: logits of shape (batch, vocab_size) and the
            new hidden state of shape (batch, hidden_size).
        """
        # Drop the singleton time axis on the ids, not on the embedded tensor:
        # squeezing dim 1 after embedding would also collapse the feature axis
        # when emb_size == 1 and target is already 1-D.
        if target.dim() == 2 and target.size(1) == 1:
            target = target.squeeze(1)
        embedded = self.embedding(target)  # (batch, emb_size)

        hidden = self.gru_cell(embedded, hidden)
        logit = self.linear_out(hidden)
        return logit, hidden

In [12]:
# Model, loss and optimisers. Both models use hidden size 64 because the
# encoder's final hidden state is handed straight to the decoder's GRU cell.
encoder = Encoder(len(words_vocab), 32, 64, pad_idx=words_vocab.pad_idx)
decoder = Decoder(len(trans_vocab), 32, 64, pad_idx=trans_vocab.pad_idx)

criterion = nn.CrossEntropyLoss()
encoder_optimizer, decoder_optimizer = (
    optim.Adam(model.parameters()) for model in (encoder, decoder)
)

# Training configuration and the running per-batch loss history.
batch_size = 32
num_epochs = 5
losses = []


In [13]:
def plot(epoch, batch_idx, losses):
    """Redraw the training-loss curve, replacing the previous cell output.

    Args:
        epoch: current epoch index, shown in the title.
        batch_idx: index of the last processed batch, shown in the title.
        losses: sequence of per-batch loss values to plot; may be empty.
    """
    clear_output(wait=True)
    fig = plt.figure(figsize=(20, 5))
    # Left third of a 1x3 grid, matching the original layout.
    ax = fig.add_subplot(131)
    # Guard the empty case so the title does not raise IndexError before any batch has run.
    last_loss = losses[-1] if len(losses) else None
    ax.set_title('epoch %s. | batch: %s | loss: %s' % (epoch, batch_idx, last_loss))
    ax.plot(losses)
    ax.set_xlabel('batch')
    ax.set_ylabel('loss')
    plt.show()
    # Close the figure so repeated calls do not keep figures open (matters for non-inline backends).
    plt.close(fig)
    

In [14]:
# Teacher-forced training: the encoder's final hidden state seeds the decoder,
# which receives the gold input token `trans_in[:, t]` at every step and is
# scored against `trans_out`. Padding positions are excluded from the loss.
encoder.train()
decoder.train()
# Follow whichever device the models live on, so batches always match them.
model_device = next(encoder.parameters()).device

num_batches = len(data) // batch_size
if num_batches == 0:
    raise ValueError('dataset has fewer than batch_size examples; nothing to train on')

for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        words, trans_in, trans_out, words_len, trans_len = data.get_batch(batch_size)
        words = words.to(model_device)
        trans_in = trans_in.to(model_device)
        trans_out = trans_out.to(model_device)

        hidden = encoder(words)

        logits = []
        for t in range(trans_in.size(1)):
            logit, hidden = decoder(trans_in[:, t].unsqueeze(1), hidden)
            logits.append(logit)
        # (batch, tgt_len, vocab) -> (batch * tgt_len, vocab), row-aligned with the
        # flattened targets; reshape also copes with non-contiguous batches.
        logits = torch.stack(logits, 1).reshape(-1, len(trans_vocab))
        trans_out = trans_out.reshape(-1)

        mask = trans_out != trans_vocab.pad_idx
        loss = criterion(logits[mask], trans_out[mask])

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        losses.append(loss.item())

    plot(epoch, batch_idx, losses)



KeyboardInterrupt: 