In [1]:
from io import open
import unicodedata
import string
import re
import random
import pandas as ps
import numpy as np
import os
import torch
import random
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm_notebook
# Prefer the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
class Util:
    """Load a tab-separated parallel corpus and encode it as index tensors.

    Each line of the file is expected to hold the input sentence and its
    translation in the first two tab-separated fields; any further fields are
    ignored. Sentences are lower-cased, ASCII punctuation is removed, and words
    are separated on whitespace.

    Attributes:
        file: the (closed, once construction finishes) corpus file object.
        word2index_input / word2index_target: word -> index.
        index2word_input / index2word_target: index -> word.
        word2count_input / word2count_target: word -> number of occurrences.
        training_data: list of (input sentence, translation) string pairs.
        training_set: list of (input tensor, target tensor) pairs; every tensor
            ends with the 'EOS' index.
    """

    # A word needs MORE than this many occurrences to get its own index;
    # rarer words are encoded as 'unk'.
    RARE_WORD_THRESHOLD = 4

    def __init__(self, FILE_PATH):
        """Read FILE_PATH (UTF-8) and build vocabularies and the training set."""
        # Word index for input/target
        self.word2index_input = {}
        self.word2index_target = {}
        # Word count for input/target
        self.word2count_input = {}
        self.word2count_target = {}
        # Word for the index
        self.index2word_input = {}
        self.index2word_target = {}
        # Create training data
        self.training_data = []
        self.file = open(FILE_PATH, encoding="utf8")
        try:
            self.read_file()
        finally:
            # The handle is only needed while reading; do not leak it.
            self.file.close()
        self.count_words()
        self.create_dictionaries()
        # Final dataset structure
        self.training_set = []
        self.create_training_set()

    @staticmethod
    def _tokenize(sentence):
        """Split a sentence on runs of whitespace (never yields empty tokens)."""
        return sentence.split()

    def read_file(self):
        """Fill ``training_data`` with cleaned (input, translation) pairs."""
        strip_punctuation = str.maketrans('', '', string.punctuation)
        for raw_line in self.file:
            # Drop the line terminator first, otherwise it stays glued to the
            # last word when the line has exactly two fields.
            fields = raw_line.rstrip('\r\n').split('\t')[0:2]
            if len(fields) < 2:
                # Blank or malformed line: there is no pair to learn from.
                continue
            sentence, translation = (
                field.translate(strip_punctuation).lower() for field in fields
            )
            self.training_data.append((sentence, translation))

    def count_words(self):
        """Count word occurrences separately for the input and target side."""
        for line, translation in self.training_data:
            for word in self._tokenize(line):
                self.word2count_input[word] = self.word2count_input.get(word, 0) + 1
            for word in self._tokenize(translation):
                self.word2count_target[word] = self.word2count_target.get(word, 0) + 1

    def _index_words(self, sentence, counts, word2index, index2word):
        """Assign the next free index to every frequent, not yet indexed word."""
        for word in self._tokenize(sentence):
            if counts[word] > self.RARE_WORD_THRESHOLD and word not in word2index:
                index = len(word2index)
                word2index[word] = index
                index2word[index] = word

    def create_dictionaries(self):
        """Build both vocabularies: special tokens first, then frequent words."""
        for index, token in enumerate(('SOS', 'EOS', 'unk')):
            self.word2index_input[token] = index
            self.index2word_input[index] = token
            self.word2index_target[token] = index
            self.index2word_target[index] = token

        for line, translation in self.training_data:
            self._index_words(line, self.word2count_input,
                              self.word2index_input, self.index2word_input)
            self._index_words(translation, self.word2count_target,
                              self.word2index_target, self.index2word_target)

    def _encode(self, sentence, word2index):
        """Return the sentence as a tensor of indices terminated by 'EOS'."""
        unknown = word2index['unk']
        indices = [word2index.get(word, unknown) for word in self._tokenize(sentence)]
        indices.append(word2index['EOS'])
        return torch.tensor(indices)

    def create_training_set(self):
        """Encode every sentence pair and store it in ``training_set``."""
        for line, translation in self.training_data:
            self.training_set.append((
                self._encode(line, self.word2index_input),
                self._encode(translation, self.word2index_target),
            ))
            
# Build the vocabularies and the tensor training set from 'ron.txt', resolved
# against the current working directory.
util = Util(os.path.abspath('ron.txt'))

In [3]:
# Sanity check: show the first (input, target) pair of index tensors.
print(util.training_set[0])

(tensor([2, 1]), tensor([3, 1]))


In [4]:
class Encoder(nn.Module):
    """Word-by-word LSTM encoder.

    Each call embeds a single word index, projects the embedding to
    ``hidden_size`` through a ReLU-activated linear layer and advances a
    one-layer LSTM by one step.

    Args:
        vocab_size: number of rows of the embedding table.
        emb_hidden: embedding dimension.
        hidden_size: size of the LSTM input and state.
    """

    def __init__(self, vocab_size, emb_hidden, hidden_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_hidden = emb_hidden
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(self.vocab_size, self.emb_hidden)
        self.dense = nn.Linear(self.emb_hidden, self.hidden_size)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)

    def init_hidden(self):
        """Return a zeroed ``(h_0, c_0)`` pair, each of shape (1, 1, hidden_size).

        The state is created on the same device and with the same dtype as the
        module's parameters, so it follows the module wherever it is moved
        instead of relying on a module-level ``device`` variable.
        """
        reference = next(self.parameters())
        shape = (1, 1, self.hidden_size)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

    def forward(self, x, hidden):
        """Encode one word.

        Args:
            x: tensor holding a single word index.
            hidden: ``(h, c)`` LSTM state from the previous step.

        Returns:
            ``(output, hidden)`` where ``output`` has shape (1, 1, hidden_size)
            and ``hidden`` is the updated ``(h, c)`` state.
        """
        x = self.embedding(x).view(1, 1, -1)
        x = F.relu(self.dense(x))
        x, hidden = self.lstm(x, hidden)

        return x, hidden

In [5]:
class Decoder(nn.Module):
    """Word-by-word LSTM decoder with additive-style attention.

    At every step the previous word is embedded and projected, a context vector
    is formed as an attention-weighted sum of the encoder outputs, both are
    merged, and a one-layer LSTM followed by a linear layer produces
    log-probabilities over the target vocabulary.

    Args:
        vocab_size: size of the target vocabulary (embedding rows and outputs).
        emb_hidden: embedding dimension.
        hidden_size: size of the LSTM input and state; encoder outputs must
            have this size as well.
    """

    def __init__(self, vocab_size, emb_hidden, hidden_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_hidden = emb_hidden
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(self.vocab_size, self.emb_hidden)
        self.dense = nn.Linear(self.emb_hidden, self.hidden_size)
        # Scores one encoder output against the decoder state (h and c).
        self.attn = nn.Linear(3 * hidden_size, 1)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)

        self.input_combine = nn.Linear(2 * hidden_size, hidden_size)
        self.last = nn.Linear(self.hidden_size, self.vocab_size)

    def init_hidden(self):
        """Return a zeroed ``(h_0, c_0)`` pair, each of shape (1, 1, hidden_size).

        The state is created on the same device and with the same dtype as the
        module's parameters rather than on a module-level ``device`` variable.
        """
        reference = next(self.parameters())
        shape = (1, 1, self.hidden_size)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

    def forward(self, x, decoder_hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            x: tensor holding a single word index (the previous output word).
            decoder_hidden: ``(h, c)`` LSTM state, each (1, 1, hidden_size).
            encoder_outputs: tensor reshapeable to (positions, hidden_size);
                every position takes part in the attention softmax.

        Returns:
            ``(log_probs, decoder_hidden)`` with ``log_probs`` of shape
            (1, vocab_size) and the updated ``(h, c)`` state.
        """
        x = self.embedding(x).view(1, 1, -1)
        x = F.relu(self.dense(x))

        # One row per encoder position: (positions, hidden_size).
        encoder_states = encoder_outputs.reshape(-1, self.hidden_size)
        # The decoder state is the same for every position, so build it once
        # and score all positions with a single call to the attention layer.
        state = torch.cat((decoder_hidden[0], decoder_hidden[1]), 2).view(1, -1)
        state = state.expand(encoder_states.size(0), -1)
        scores = self.attn(torch.cat((encoder_states, state), 1))  # (positions, 1)
        alphas_norm = F.softmax(scores, dim=0)

        # Weighted sum of the encoder outputs: the context vector (1, 1, hidden_size).
        c = torch.bmm(alphas_norm.view(1, 1, -1), encoder_states.unsqueeze(0))

        x = torch.cat((x, c), 2)
        x = F.relu(self.input_combine(x))

        out, decoder_hidden = self.lstm(x, decoder_hidden)
        out = self.last(out[0])

        out = F.log_softmax(out, dim=1)

        return out, decoder_hidden
        
        

In [6]:
# One network per language side: each gets its own vocabulary size,
# 50-dimensional embeddings and a 256-unit LSTM state.
encoder = Encoder(
    vocab_size=len(util.word2index_input), emb_hidden=50, hidden_size=256
).to(device)
decoder = Decoder(
    vocab_size=len(util.word2index_target), emb_hidden=50, hidden_size=256
).to(device)

# Each network is updated by its own Adam optimizer with default settings.
encoder_optimizer = optim.Adam(params=encoder.parameters())
decoder_optimizer = optim.Adam(params=decoder.parameters())

# Negative log-likelihood loss applied to the decoder output.
criterion = nn.NLLLoss()

In [7]:
def learn(input_tensor, translation_tensor, max_length=50):
    """Run one optimisation step on a single sentence pair.

    The input is encoded word by word, then the decoder is unrolled over the
    target. With probability 0.5 the ground-truth word is fed back as the next
    decoder input (teacher forcing); otherwise the decoder's own best guess is
    fed back and decoding stops once it predicts 'EOS'.

    Args:
        input_tensor: 1-D tensor of input word indices ending with 'EOS'.
        translation_tensor: 1-D tensor of target word indices ending with 'EOS'.
        max_length: minimum number of rows of the encoder-output buffer. Unused
            rows stay zero; the buffer grows for longer inputs instead of
            failing with an IndexError.

    Returns:
        0-dim tensor: the summed loss divided by the target length, detached
        from the autograd graph.
    """
    if len(translation_tensor) == 0:
        # Nothing to predict, hence nothing to learn from.
        return torch.zeros((), device=device)

    sos_index = util.word2index_target['SOS']
    eos_index = util.word2index_target['EOS']

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Encoding part
    hidden_enc = encoder.init_hidden()
    buffer_length = max(max_length, len(input_tensor))
    encoder_outputs = torch.zeros(buffer_length, 1, encoder.hidden_size, device=device)
    for i in range(len(input_tensor)):
        encoder_out, hidden_enc = encoder(input_tensor[i].to(device), hidden_enc)
        encoder_outputs[i] = encoder_out[0]

    # Start decoding from the encoder's final state, exactly as validate()
    # does, so that training and evaluation see the same conditions.
    hidden_dec = hidden_enc
    decoder_input = torch.tensor(sos_index, device=device)

    teacher_enforcing = random.random() > 0.5  # teacher enforcing
    loss = 0
    for i in range(len(translation_tensor)):
        target = translation_tensor[i].to(device)
        translation_out, hidden_dec = decoder(decoder_input, hidden_dec, encoder_outputs)
        loss += criterion(translation_out, target.unsqueeze(0))

        if teacher_enforcing:
            # Feed the ground truth; 'EOS' is always the last target word, so
            # the loop ends on its own.
            decoder_input = target
        else:
            topv, topi = translation_out.topk(1)
            decoder_input = topi.detach().long().to(device)
            # Compare against the EOS *index*; comparing with the string
            # 'EOS' can never be true.
            if decoder_input.item() == eos_index:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    # Detach so that callers accumulating the value do not hold on to the graph.
    return loss.detach() / len(translation_tensor)

In [8]:
def validate(input_tensor, target_tensor, max_length=50):
    """Greedily decode one sentence and count correctly predicted words.

    The decoder starts from the encoder's final state and always feeds its own
    best guess back in. Decoding stops when 'EOS' is predicted or the target
    length is reached.

    Args:
        input_tensor: 1-D tensor of input word indices ending with 'EOS'.
        target_tensor: 1-D tensor of target word indices ending with 'EOS'.
        max_length: minimum number of rows of the encoder-output buffer. Unused
            rows stay zero; the buffer grows for longer inputs instead of
            failing with an IndexError.

    Returns:
        ``(correct, total)``: the number of positions where the prediction
        equals the target, and the target length. Positions that are never
        reached because 'EOS' was predicted too early count as wrong.
    """
    sos_index = util.word2index_target['SOS']
    eos_index = util.word2index_target['EOS']

    correct = 0
    total = len(target_tensor)

    with torch.no_grad():
        hidden_enc = encoder.init_hidden()
        buffer_length = max(max_length, len(input_tensor))
        encoder_outputs = torch.zeros(buffer_length, 1, encoder.hidden_size, device=device)
        for i, word in enumerate(input_tensor):
            out, hidden_enc = encoder(word.to(device), hidden_enc)
            encoder_outputs[i] = out[0]

        hidden_dec = hidden_enc
        decoder_input = torch.tensor([sos_index], device=device)

        for i in range(total):
            out, hidden_dec = decoder(decoder_input, hidden_dec, encoder_outputs)
            topv, topi = out.topk(1)
            decoder_input = topi.detach().long().to(device)
            predicted = decoder_input.item()
            if predicted == target_tensor[i].item():
                correct += 1
            # End of sentence is index 'EOS' (1), not 'SOS' (0).
            if predicted == eos_index:
                break

    return correct, total

In [10]:
def train(epochs=10):
    """Train on the whole training set and report loss and accuracy per epoch.

    Every epoch runs ``learn`` on each sentence pair, prints the mean
    per-sentence loss, then runs ``validate`` on the same pairs and prints the
    word-level accuracy in percent.

    Args:
        epochs: number of passes over ``util.training_set``.
    """
    # `tqdm.tqdm_notebook` is deprecated in favour of `tqdm.notebook.tqdm`.
    from tqdm.notebook import tqdm

    # Avoid dividing by zero when the corpus is empty.
    pair_count = max(len(util.training_set), 1)

    for epoch in range(epochs):
        loss = 0.0
        for sentence, translation in tqdm(util.training_set):
            # float() keeps a plain running sum instead of a chain of tensors.
            loss += float(learn(sentence, translation))

        # True division: floor division would discard the fractional part.
        print("Loss for this bad boi at epoch {} is : {}".format(epoch + 1, loss / pair_count))
        correct = 0
        total = 0
        for sentence, translation in tqdm(util.training_set):
            correct_it, total_it = validate(sentence, translation)
            correct += correct_it
            total += total_it
        accuracy = (correct / total) * 100 if total else 0.0
        print("Accuracy for this bad boi is : {}".format(accuracy))
            
# Start training; progress bars and per-epoch metrics are printed below.
train()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=10449.0), HTML(value='')))

KeyboardInterrupt: 