## ENCODER DECODER NETWORK

AND TEACHER FORCING

**References:**

Tutorials Given in Competition Document : [Competetion Link](https://docs.google.com/document/d/1p74wG-bECCgbpyq5x_x2QJrf5RSf9FnMLGSAiyUkHLo/edit)

PyTorch NMT Tutorial : [Pytorch NMT](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html)

Github Page : To understand batch Processing in PyTorch [Github Pengyuchen](https://github.com/pengyuchen/PyTorch-Batch-Seq2seq)

Referred Few Stackoverflow Links for few Regex examples and for some bugs.

The whole code is divided into two sections:
a)  Functions containing all required procedures b) Execution : Using the function . Expand or Collapse to view each sections and subsections.

Observations :
1.   Using the default learning models work better in Adam.
2.   Training in epochs of 20 20 to avoid failure of timeouts.
3.   Saving the models is not working. Due to randomness everywhere. Language Word2index and index2word gets mapped to different word everytime. So all randomness need to be removed for saving and reusing the models.


NOTE : Change the directory location with respect to google drive location where the data is stored and EXPAND/COLLAPSE Section for the code.

No package other than the specified packages are imported


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
location = r"/content/drive/My Drive/Files/"                  
INDIC_NLP_LIB_HOME = location + "indic_nlp_library"
INDIC_NLP_RESOURCES = location + "indic_nlp_resources"
data_location        = location + 'NMT/'                   
model_location       = location + 'NMT/NMT_GRUATTN/' 
weekly_data_location = location + 'NMT/Weekly Data/'

### LIBRARIES  -
This subsection contains importing various libraries. Download or clone the indic nlp library and resources to your drive. And change the location accordingly.
Also google colab does not have morfessor and uses old version of nltk. So needed to update/install those two packages.

In [3]:
import sys
sys.path.append(r'{}'.format(INDIC_NLP_LIB_HOME))
from indicnlp import common
common.set_resources_path(INDIC_NLP_RESOURCES)
from indicnlp import loader
loader.load()

In [4]:
!pip install Morfessor
import csv
import re
import string
import spacy
import tqdm.notebook as tq
nlpen = spacy.load("en_core_web_sm")
import random
import pickle
from indicnlp.tokenize import sentence_tokenize
from indicnlp.tokenize import indic_tokenize
from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator
from indicnlp.transliterate.unicode_transliterate import ItransTransliterator
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory

Collecting Morfessor
  Downloading https://files.pythonhosted.org/packages/39/e6/7afea30be2ee4d29ce9de0fa53acbb033163615f849515c0b1956ad074ee/Morfessor-2.0.6-py3-none-any.whl
Installing collected packages: Morfessor
Successfully installed Morfessor-2.0.6


In [5]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

!pip install -U nltk
import nltk
import sys
nltk.download('wordnet')
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.meteor_score import single_meteor_score
import numpy as np

Collecting nltk
[?25l  Downloading https://files.pythonhosted.org/packages/5e/37/9532ddd4b1bbb619333d5708aaad9bf1742f051a664c3c6fa6632a105fd8/nltk-3.6.2-py3-none-any.whl (1.5MB)
[K     |▎                               | 10kB 24.4MB/s eta 0:00:01[K     |▌                               | 20kB 30.4MB/s eta 0:00:01[K     |▊                               | 30kB 35.8MB/s eta 0:00:01[K     |█                               | 40kB 28.3MB/s eta 0:00:01[K     |█▏                              | 51kB 28.9MB/s eta 0:00:01[K     |█▍                              | 61kB 31.2MB/s eta 0:00:01[K     |█▋                              | 71kB 21.5MB/s eta 0:00:01[K     |█▉                              | 81kB 22.5MB/s eta 0:00:01[K     |██                              | 92kB 21.6MB/s eta 0:00:01[K     |██▎                             | 102kB 23.0MB/s eta 0:00:01[K     |██▌                             | 112kB 23.0MB/s eta 0:00:01[K     |██▊                             | 122kB 23.0MB/s e

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


In [6]:
def read_csv(location, file_type):
    cFile = open(location) 
    cReader = csv.reader(cFile, delimiter=',')
    header = next(cReader)
    if( file_type == 'train'):
        df = {}
        df['hindi'] = []
        df['english'] = []
        for t in cReader:
            df['hindi'].append(t[1])
            df['english'].append(t[2])
    elif( file_type == 'weekly' ):
        df = {}
        df['hindi'] = []
        for t in cReader:
            df['hindi'].append(t[2])
    return df

In [7]:
def train_test_split(dataset, test_split_percentage):

    total_len   = len(dataset)
    total_index = list(range(total_len))
    test_index = list( total_index[: int(test_split_percentage*total_len)] )
    train_index  = list( total_index[int(test_split_percentage*total_len) : ] )
    #np.random.shuffle(test_index)
    #np.random.shuffle(train_index)
    index = { 'train' : train_index, 'test' : test_index}
    train_df = [ dataset[i] for i in train_index ]
    test_df  = [ dataset[i] for i in test_index ]
    return index, train_df, test_df

### TEXT PROCESSING
This subsection contains processing of english and hindi sentences.
Since processing the 1 Lakh text pairs takes a lot of time. Instead of doing same thing again and again. I have stored the processed texts and token using pickle. 

In [8]:
english_nums = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
hindi_nums =   ['०', '१', '२', '३', '४', '५', '६', '७', '८', '९']

def clean_string( instr ):
    instr = instr.lower()
    instr = instr.replace(u'[', ' ')
    instr = instr.replace(u']', ' ')
    instr = instr.replace(u'{', ' ')
    instr = instr.replace(u'}', ' ')
    instr = instr.replace(u'(', ' ')
    instr = instr.replace(u')', ' ')
    instr = instr.replace(u'...', ' ')
    instr = instr.replace(u'..', ' ')
    instr = instr.replace(u'-', ' ')
    instr = instr.replace(u',', ' ')
    instr = instr.replace(u'"', ' ')
    instr = re.sub(' +',' ', instr)
    return instr
  
def preprocess_hindi( instr ):
    factory    = IndicNormalizerFactory()
    normalizer = factory.get_normalizer("hi",remove_nuktas=True)
    instr      = normalizer.normalize(instr)

    instr      = clean_string( instr )
    #instr = instr.replace(u'॥', '')
    for nums in hindi_nums:
        instr    = instr.replace(nums, nums + ' ')

    instr      = ItransTransliterator.from_itrans( instr , 'hi')  
    instr      = re.sub(' +',' ', instr)
    instr      = ItransTransliterator.from_itrans( instr , 'hi')
    instr      = instr.strip() #sentence_tokenize.sentence_split(instr, lang='hi')
    
    return instr

def preprocess_english( instr ):
    instr = clean_string(instr)

    instr = instr.replace("’", "'")
    instr = instr.replace("n\'t", " not")
    instr = instr.replace("'re" , " are")
    instr = instr.replace("'ve" , " have")
    instr = instr.replace("'s"  , " is")
    instr = instr.replace("'ll" , " will")
    instr = instr.replace("'m" , " am")
    #instr = re.sub(r'[^\w\s\\d]' , " " , instr)
    #instr = re.sub(r'[\d]' , ' ' , instr)

    for nums in english_nums:
        instr    = instr.replace(nums, nums + ' ')
    instr = re.sub(' +',' ', instr)
    instr = instr.strip()

    return instr

def get_hindi_tokens(sentence):
    return indic_tokenize.trivial_tokenize(sentence)

def get_english_tokens(sentence):
    tokens = []
    tokstr = nlpen(sentence)
    for token in tokstr:
        tokens.append(token.text)
    return tokens

In [9]:
# Load_From_file =
#   -1   : Process the texts and store/dump the files into the location
#    0   : Process the texts and do not store the files
#    1   : Directly load the processed text from the location

def process_pairs(df, load_from_file = 0, location = ''):
    if( load_from_file == 1):
        with open(location + r'pairs.pickle', 'rb') as handle:
            pairs = pickle.load(handle)
        with open(location + r'pairs_tokens.pickle', 'rb') as handle:
            pairs_tokens = pickle.load(handle)
        return pairs, pairs_tokens
    else:
        pairs = []
        pairs_tokens = []
        for i in tq.tqdm( range( len(df['hindi']) )):
            hinsen  = df['hindi'][i]
            hsent   = preprocess_hindi( hinsen )
            htokens = get_hindi_tokens(hsent)

            engsen  = df['english'][i]
            esent   = preprocess_english( engsen )
            etokens = get_english_tokens(esent)

            pairs.append( [hsent, esent] )
            pairs_tokens.append( [htokens, etokens] )

        if( load_from_file == -1):
            with open(location + r'pairs.pickle', 'wb') as handle:
                pickle.dump(pairs, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(location + r'pairs_tokens.pickle', 'wb') as handle:
                pickle.dump(pairs_tokens, handle, protocol=pickle.HIGHEST_PROTOCOL)

        return pairs, pairs_tokens

### LANGUAGE
This subsection contains the class 'Laguage' which stores all the token and its equivalent index. This subsection also contains functions to convert a sentence to a tensor.

This subsection is referred from pytorch tutorial on NMT.

In [10]:
START_TOKEN = 0
END_TOKEN = 1
PAD_TOKEN = 2
UNK_TOKEN = 3

class Language:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {}
        self.num_words = 4
        self.word2index['START_TOKEN'] = START_TOKEN
        self.word2index['END_TOKEN']   = END_TOKEN
        self.word2index['PAD_TOKEN']   = PAD_TOKEN
        self.word2index['UNK_TOKEN']   = UNK_TOKEN
        self.index2word[START_TOKEN] = 'START_TOKEN'
        self.index2word[END_TOKEN] = 'END_TOKEN'
        self.index2word[PAD_TOKEN] = 'PAD_TOKEN'
        self.index2word[UNK_TOKEN] = 'UNK_TOKEN'

    def addWord(self, word):
        if word in self.word2count:
            self.word2count[word] = self.word2count[word] + 1
        else:
            self.word2count[word] = 1
            #self.word2index[word] = self.num_words
            #self.index2word[self.num_words] = word
            self.num_words = self.num_words + 1
    
    def addSentence(self, sentence_tokens):
        for word in sentence_tokens:
            self.addWord(word)
    
    def filter_words(self):
        self.num_words = 4
        for word in self.word2count:
            if( self.word2count[word] != 1):
                self.word2index[word] = self.num_words
                self.index2word[self.num_words] = word
                self.num_words = self.num_words + 1


def generate_language( pairs_tokens ):
    hindi   = Language('hindi')
    english = Language('english')
    for i in tq.tqdm( range(len(pairs_tokens)) ):
        hindi.addSentence(pairs_tokens[i][0])
        english.addSentence(pairs_tokens[i][1])
    hindi.filter_words()
    english.filter_words()
    return hindi, english

PROCESS TEXT TO TENSOR

In [11]:
def indexesFromSentence(lang, tokens, max_length):
    indexes = []
    indexes.append(START_TOKEN)
    for word in tokens:
        if word in lang.word2index.keys():
            indexes.append( lang.word2index[word] )
        else:
            indexes.append( lang.word2index['UNK_TOKEN'] )
    indexes = indexes[0:max_length-1]
    indexes.append(END_TOKEN)
    indexes.extend( [PAD_TOKEN]*( max_length - len(indexes)))
    return indexes

def tensorFromSentence(lang, sentence, max_length):
    indexes = indexesFromSentence(lang, sentence, max_length)
    return torch.tensor(indexes, dtype=torch.long, device=device)

def tensorsFromPair(pairs, input_lang, output_lang, max_length):
    res_pairs = []
    for pair in pairs:
        input_tensor  = tensorFromSentence(input_lang, pair[0], max_length)
        target_tensor = tensorFromSentence(output_lang, pair[1], max_length)
        res_pairs.append( (input_tensor, target_tensor) )
    return res_pairs

### NEURAL MACHINE TRANSLATOR
This subjection contains 3 main classes Encoder , Decoder and an seq2seq which merge the two encoder and decoder.
It also contains a function to train, use and evaluate the seq2seq model.


ENCODER and DECODER

In [12]:
class Encoder(nn.Module):
    def __init__(self, input_size, embed_size, hidden_size):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_size, embed_size)
        self.rnn = nn.GRU(embed_size, hidden_size, bidirectional = True)
        self.fc = nn.Linear(hidden_size * 2, hidden_size)
        #self.fc_cell = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, input):
        # input.shape :    [Sentence Length, Batch Size]
        # embedded.shape : [Sentence Length, Batch Size, Embedding Dimension]
        # output.shape :   [Sentence Length, Batch Size, Hidden Size]
        # hidden.shape :   [Layers = 2*2 , Batch Size, Hidden Size]
        # cell.shape   :   [Layers = 2*2 , Batch Size, Hidden Size]

        embedded = self.embedding(input)
        output, hidden = self.rnn(embedded)
        hidden = torch.cat( (hidden[-2,:,:], hidden[-1,:,:]), dim=1)
        hidden = self.fc(hidden)
        #hidden = torch.tanh(hidden)

        return output, hidden

In [13]:
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear((hidden_size * 2) + hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias = False)
        
    def forward(self, hidden, encoder_outputs):
        # hidden.shape          :   [Batch Size, Hidden Size]
        # encoder_outputs.shape :   [Sen Len, Batch Size, Hidden_size*2]
        
        # After Ajusting
        # hidden.shape          :   [Batch Size, Sen Length, Hidden Size]
        # encoder_outputs.shape :   [Batch Size, Sen Length, Hidden_size*2]

        src_len = encoder_outputs.shape[0]
        hidden = hidden.unsqueeze(1)
        hidden = hidden.repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        comb        = torch.cat((hidden, encoder_outputs), dim = 2)
        energy      = torch.tanh( self.attn(comb) )
        attention   = self.v(energy).squeeze(2)
        attention   = F.softmax(attention, dim=1)
        attention   = attention.unsqueeze(1)
        weights     = torch.bmm(attention, encoder_outputs)
        weights     = weights.permute(1,0,2)
        return weights

In [14]:
class DecoderAttn(nn.Module):
    def __init__(self, output_size, embed_size, hidden_size):
        super(DecoderAttn, self).__init__()
        self.embedding = nn.Embedding(output_size, embed_size)
        self.rnn   = nn.GRU((hidden_size*2)+embed_size, hidden_size)
        self.dense  = nn.Linear(hidden_size*3 + embed_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

        self.attention = Attention(hidden_size)

    def forward(self, target, hidden, encoder_outputs):
        # target.shape :   [Batch Size]
        # target.shape :   [1, Batch Size] after unsqueezing
        # embed.shape  :   [1, Batch Size, Embedding Size]
        # output.shape :   [1, Batch Size, Hidden Size] before squeezing
        # hidden.shape :   [Batch Size, Hidden Size]
        # preds.shape  :   [Batch Size, Output_Vocabulary_Size]

        target = target.unsqueeze(0)
        embed  = self.embedding(target)
        weights = self.attention(hidden, encoder_outputs)
        rinput   = torch.cat((embed, weights), dim = 2)
        hidden = hidden.unsqueeze(0)
        output, hidden = self.rnn(rinput, hidden)
        dense_input = torch.cat((output, weights, embed), dim=2)
        preds = self.dense(dense_input[0])
        #preds = F.relu(preds)
        preds = self.softmax(preds)
        return preds, hidden.squeeze(0)

In [15]:
class seq2seq(nn.Module):
    def __init__(self, input_size, output_size, embed_size, hidden_size, max_length):
        super(seq2seq, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.embed_size = embed_size
        self.max_length = max_length

        self.encoder = Encoder(input_size, embed_size, hidden_size).to(device)
        self.decoder = DecoderAttn(output_size, embed_size, hidden_size).to(device)

    def forward(self, src, target , teacher_forcing = 0.5):
        # If teacher forcing is set to 0.5, it will use true outputs half the time for
        # next input to decoder and use the predicted output as input
        # If teacher forcing is 0, it will always use previous output as input to decoder.

        # src.shape    = [Input Sentence Length, Batch Size]
        # target.shape = [Output Sentence Length, Batch Size]
        # decoder_output.shape = [ Output Sentence Length, Batch Size, ]
        # Encode the Source Sentence; Decode the tokens one by one.

        batch_size, target_vocab_size = src.shape[1], self.output_size
        outputs = torch.zeros(self.max_length, batch_size, target_vocab_size).to(device)
        encoder_outputs, hidden = self.encoder(src)
        dinput = src[0,:]
        for index in range(1, self.max_length):
            output, hidden = self.decoder(dinput, hidden, encoder_outputs)
            if random.random() < teacher_forcing:
                dinput = target[index]  
            else:
                dinput = output.argmax(1)
            outputs[index] = output

        return outputs

In [16]:
# Set model in training mode to activate dropouts
# Transpose the text tokens to adjust to pytorch
# Forward Pass on Encoder-Decoder
# Optimize network
def train( model, opt, lossfn, train_loader, r_epoch, save_model=0):
    model.train()
    history = []
    num_batches = len(train_loader)
    tf_ratio = 0.5

    for epoch in range(r_epoch[0], r_epoch[1]):
        epoch_loss = 0
        if((epoch+1) % 5 == 0):
            tf_ratio = tf_ratio - 0.1

        for inS, outS in tq.tqdm( train_loader ):
            opt.zero_grad()
            loss = 0

            inS =  inS.transpose(0, 1)
            outS = outS.transpose(0, 1)
            predoutS = model(inS, target = outS, teacher_forcing=tf_ratio)
            outS     = outS[1:].reshape(-1)       # Reshape outputs
            predoutS = predoutS[1:].reshape(-1, predoutS.shape[-1])

            loss = lossfn(predoutS, outS)         # Compute Loss
            loss.backward()                       # Propagate Loss To the Netowork
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)  # Gradient Clipping
            opt.step()                            # Update the weights
            epoch_loss = epoch_loss + loss.item()
            

        print(' Epoch : ', epoch , '   loss  : ', epoch_loss / num_batches )
        history.append(epoch_loss / num_batches)

        if( save_model == 1):
            if( (epoch+1)%5 == 0):
                torch.save(model.state_dict(), model_location + 'gruattn_dict_' + str(epoch) )
            if( (epoch+1)%20 == 0):
                torch.save(model, model_location + 'gruattn_' +  str(epoch) )

    return history

In [17]:
# Set model to evaluation model to disable dropout layer
# get Tensor from Sentence and adjust it to size [Sequence Length, Max Length = 1]
def make_sentence(tokens):
    str = ''
    for x in tokens:
        if x is 'UNK_TOKEN':
            str = str + ' ' + '<UNK>'
        elif x not in ['START_TOKEN', 'END_TOKEN', 'PAD_TOKEN']:
            str = str + ' ' + x
    return re.sub('(?<=\d)+ (?=\d)+', '', str)[1:]

def translate(model, sentence, input_lang, output_lang, max_length):
    model.eval()
    with torch.no_grad():
        input = tensorFromSentence( input_lang, sentence, max_length= max_length)
        input = torch.transpose( input.unsqueeze(0) , 0 , 1)
        output = model(input, target=None, teacher_forcing = 0)
        dec_words = []
        for x in output.squeeze():
            i = x.argmax(0)
            dec_words.append( output_lang.index2word[ i.item() ] )
            if(i.item() == END_TOKEN ):
                break
    return make_sentence( dec_words )


### PERFORMANCE EVALUATION
Evaluation Script Modified to give Bleu and Meteor Score

In [18]:
def get_bleu_score(model, pairs, input_lang, output_lang, max_length):
    total_num = len(pairs)
    total_bleu_scores = 0
    total_meteor_scores = 0
    
    for i in tq.tqdm( range(total_num) ):
        output    = translate(model, pairs[i][0], input_lang, output_lang, max_length)
        original  = make_sentence(pairs[i][1])
        total_bleu_scores   += sentence_bleu([output.split(" ")], original.split(" "))
        total_meteor_scores += single_meteor_score(output, original)

    bleu_result = total_bleu_scores/total_num
    meteor_result = total_meteor_scores/total_num
    
    print()
    print("BLEU score: ",bleu_result)
    print("METEOR score: ",meteor_result)

# EXECUTION
Executing the whole process.


1.   Read the training data
2.   Process all sentences( english and hindi)
3.   Generate Language ( word2index and index2word)
4.   Prepare tensors for all tokens.
5.   Create the seq2seq model and train the model
6.   Evaluate the performance
7.   Use the model for weekly translation



READ AND PROCESS FILE

In [19]:
MAX_LENGTH = 32
batch_size = 256


print('Reading Training Data ... ', end = '')
df = read_csv(data_location + 'train.csv', 'train')
print('Done')

print('Processing Strings ... ', end = '')
pairs, tokens = process_pairs(df, load_from_file=1, location = data_location + 'DataPairs/')
print('Done')

print('Splitting Dataset ... ', end = '')
index, train_tokens, test_tokens = train_test_split(tokens,  0.2)
print('Done')

print('Preparing Language Word2vectors and inverse ... ', end = '')
# Generate Langauge Input and Output
hindi, english = generate_language(tokens)
print('Done, Hindi Token Count : ', hindi.num_words, '  English Token Count : ', english.num_words)

print('Preparing Tensors ... ', end = '')
# Get Tensors for tokens and create Dataloaders
train_tensors = tensorsFromPair(train_tokens, hindi, english, MAX_LENGTH)
test_tensors = tensorsFromPair(test_tokens, hindi, english, MAX_LENGTH)
print('Done')

print('Preparing Dataloaders ... ', end = '')
train_loader = torch.utils.data.DataLoader(train_tensors, batch_size=batch_size, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_tensors, batch_size=batch_size, shuffle=True)
pretrainloader = torch.utils.data.DataLoader(train_tensors[0:batch_size], batch_size=batch_size, shuffle=True)
print('Done')

Reading Training Data ... Done
Processing Strings ... Done
Splitting Dataset ... Done
Preparing Language Word2vectors and inverse ... 

HBox(children=(FloatProgress(value=0.0, max=102322.0), HTML(value='')))


Done, Hindi Token Count :  21104   English Token Count :  18988
Preparing Tensors ... Done
Preparing Dataloaders ... Done


TRAIN MODEL

In [21]:
# Model Parameters
print('Initialising Parameters')
hidden_size = 512
input_vocab_size = hindi.num_words + 1
output_vocab_size = english.num_words + 1
embedding_dim = 300
epochs = 20
pretrain_epoch = 0
#save_losses
Losses = []

#Generate Model, optimizer, lossfn
print('Creating Models ... ', end = ' ')
model = seq2seq(input_vocab_size, output_vocab_size , embedding_dim, hidden_size, MAX_LENGTH)
optimizer = optim.Adam( model.parameters())
lossfn = nn.NLLLoss(ignore_index=PAD_TOKEN)
print('Done')

#load_model weights if available
load_model = 1
if(load_model==1):
    print('Loading Pretrained Weights .. :')
    model.load_state_dict( torch.load(model_location + 'gruattn_dict_39'))
model.eval() 

Initialising Parameters
Creating Models ...  Done
Loading Pretrained Weights .. :


seq2seq(
  (encoder): Encoder(
    (embedding): Embedding(21105, 300)
    (rnn): GRU(300, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
  )
  (decoder): DecoderAttn(
    (embedding): Embedding(18989, 300)
    (rnn): GRU(1324, 512)
    (dense): Linear(in_features=1836, out_features=18989, bias=True)
    (softmax): LogSoftmax(dim=1)
    (attention): Attention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
  )
)

In [None]:
# Pretrain to Overfit Model on single batch
train(model, optimizer, lossfn, pretrainloader, (0,200))


 Epoch :  21    loss  :  4.4923481941223145


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  22    loss  :  4.315109729766846


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  23    loss  :  4.192707061767578


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  24    loss  :  4.22307014465332


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  25    loss  :  4.148548603057861


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  26    loss  :  4.013594150543213


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  27    loss  :  3.904195785522461


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  28    loss  :  3.8517346382141113


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  29    loss  :  3.7031123638153076


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  30    loss  :  3.5979599952697754


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  31    loss  :  3.5061099529266357


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  32    loss  :  3.3658814430236816


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  33    loss  :  3.227588415145874


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  34    loss  :  3.108192205429077


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  35    loss  :  2.989473581314087


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  36    loss  :  2.9054555892944336


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  37    loss  :  2.782717704772949


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  38    loss  :  2.6248650550842285


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  39    loss  :  2.5911011695861816


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  40    loss  :  2.4286556243896484


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  41    loss  :  2.2922093868255615


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  42    loss  :  2.158637762069702


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  43    loss  :  2.102001905441284


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  44    loss  :  2.0473146438598633


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  45    loss  :  1.8954404592514038


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  46    loss  :  1.7875083684921265


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  47    loss  :  1.7885924577713013


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  48    loss  :  1.696744680404663


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  49    loss  :  1.5549472570419312


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  50    loss  :  1.4763239622116089


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  51    loss  :  1.3945285081863403


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  52    loss  :  1.464187741279602


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  53    loss  :  1.370223045349121


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  54    loss  :  1.2373545169830322


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  55    loss  :  1.4181526899337769


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  56    loss  :  1.2203702926635742


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  57    loss  :  1.3587162494659424


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  58    loss  :  1.2117406129837036


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  59    loss  :  1.3462430238723755


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  60    loss  :  1.2742024660110474


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  61    loss  :  1.0425728559494019


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  62    loss  :  0.994838297367096


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  63    loss  :  1.048134207725525


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  64    loss  :  0.9572029709815979


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  65    loss  :  0.9624354243278503


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  66    loss  :  0.9495874643325806


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  67    loss  :  0.8780005574226379


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  68    loss  :  0.8131539225578308


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  69    loss  :  0.8641667366027832


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  70    loss  :  0.7942917346954346


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  71    loss  :  0.849055826663971


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  72    loss  :  0.8468276262283325


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  73    loss  :  0.7074281573295593


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  74    loss  :  0.6649332642555237


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  75    loss  :  0.6715803742408752


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  76    loss  :  0.6364166140556335


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  77    loss  :  0.692819356918335


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  78    loss  :  0.6329879760742188


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  79    loss  :  0.6328212022781372


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  80    loss  :  0.5866047143936157


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  81    loss  :  0.5908501744270325


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  82    loss  :  0.5586667060852051


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  83    loss  :  0.6056114435195923


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  84    loss  :  0.6242931485176086


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  85    loss  :  0.4890003204345703


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  86    loss  :  0.5065665245056152


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  87    loss  :  0.4413069784641266


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  88    loss  :  0.4644148051738739


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  89    loss  :  0.42235955595970154


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  90    loss  :  0.4836525022983551


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  91    loss  :  0.4429602324962616


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  92    loss  :  0.4264688491821289


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  93    loss  :  0.3949178159236908


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  94    loss  :  0.40409475564956665


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  95    loss  :  0.35054299235343933


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  96    loss  :  0.3880467116832733


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  97    loss  :  0.30869174003601074


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  98    loss  :  0.34084129333496094


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  99    loss  :  0.2964801788330078


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  100    loss  :  0.335616797208786


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  101    loss  :  0.2888067066669464


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  102    loss  :  0.31755098700523376


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  103    loss  :  0.32373976707458496


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  104    loss  :  0.2828691303730011


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  105    loss  :  0.25718721747398376


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  106    loss  :  0.2636008858680725


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  107    loss  :  0.24309884011745453


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  108    loss  :  0.21844856441020966


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  109    loss  :  0.23450873792171478


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  110    loss  :  0.20218731462955475


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  111    loss  :  0.1983618438243866


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  112    loss  :  0.1838233470916748


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  113    loss  :  0.17346717417240143


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  114    loss  :  0.18399757146835327


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  115    loss  :  0.16645342111587524


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  116    loss  :  0.15033318102359772


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  117    loss  :  0.15069593489170074


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  118    loss  :  0.14430131018161774


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  119    loss  :  0.13557611405849457


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  120    loss  :  0.13188707828521729


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  121    loss  :  0.12576478719711304


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  122    loss  :  0.11524034291505814


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  123    loss  :  0.10989689081907272


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  124    loss  :  0.104647696018219


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  125    loss  :  0.09635228663682938


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  126    loss  :  0.091677725315094


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  127    loss  :  0.08817350119352341


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  128    loss  :  0.08333997428417206


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  129    loss  :  0.0818260982632637


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  130    loss  :  0.07733173668384552


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  131    loss  :  0.07284438610076904


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  132    loss  :  0.06807900965213776


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  133    loss  :  0.06466206908226013


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  134    loss  :  0.058770183473825455


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  135    loss  :  0.05854014679789543


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  136    loss  :  0.05658678337931633


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  137    loss  :  0.05206621065735817


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  138    loss  :  0.050384435802698135


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  139    loss  :  0.04574684798717499


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  140    loss  :  0.043905433267354965


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  141    loss  :  0.04179270192980766


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  142    loss  :  0.037626560777425766


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  143    loss  :  0.038949061185121536


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  144    loss  :  0.03609991818666458


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  145    loss  :  0.0351262167096138


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  146    loss  :  0.03249825909733772


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  147    loss  :  0.027943218126893044


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  148    loss  :  0.029193934053182602


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  149    loss  :  0.02740940824151039


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  150    loss  :  0.024275045841932297


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  151    loss  :  0.022574564442038536


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  152    loss  :  0.020735347643494606


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  153    loss  :  0.021070245653390884


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  154    loss  :  0.019285883754491806


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  155    loss  :  0.018850266933441162


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  156    loss  :  0.017842713743448257


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  157    loss  :  0.017124265432357788


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  158    loss  :  0.01679646596312523


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  159    loss  :  0.015677275136113167


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  160    loss  :  0.017407724633812904


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  161    loss  :  0.015913961455225945


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  162    loss  :  0.013977693393826485


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  163    loss  :  0.01575889252126217


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  164    loss  :  0.012438065372407436


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  165    loss  :  0.011586836539208889


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  166    loss  :  0.011761379428207874


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  167    loss  :  0.011530132032930851


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  168    loss  :  0.010492634028196335


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  169    loss  :  0.012285822071135044


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  170    loss  :  0.01134910061955452


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  171    loss  :  0.00986152421683073


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  172    loss  :  0.013377047143876553


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  173    loss  :  0.01326777134090662


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  174    loss  :  0.009412718936800957


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  175    loss  :  0.009097730740904808


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  176    loss  :  0.009273458272218704


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  177    loss  :  0.008486878126859665


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  178    loss  :  0.009845603257417679


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  179    loss  :  0.009131411090493202


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  180    loss  :  0.008241208270192146


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  181    loss  :  0.008914864622056484


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  182    loss  :  0.009108040481805801


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  183    loss  :  0.008841301314532757


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  184    loss  :  0.008330042473971844


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  185    loss  :  0.008251533843576908


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  186    loss  :  0.007786578964442015


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  187    loss  :  0.007017624098807573


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  188    loss  :  0.008318580687046051


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  189    loss  :  0.00719123100861907


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  190    loss  :  0.007481854408979416


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  191    loss  :  0.00736029539257288


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  192    loss  :  0.007389924023300409


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  193    loss  :  0.007896550931036472


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  194    loss  :  0.00790481548756361


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  195    loss  :  0.007229872513562441


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  196    loss  :  0.00945352204144001


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  197    loss  :  0.008480006828904152


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  198    loss  :  0.007577402051538229


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


 Epoch :  199    loss  :  0.009074542671442032


[9.881617546081543,
 8.673712730407715,
 7.365865230560303,
 6.418414115905762,
 5.95231819152832,
 5.70348596572876,
 5.56167459487915,
 5.473485469818115,
 5.238452911376953,
 5.263691425323486,
 5.132071018218994,
 5.080777168273926,
 5.009632110595703,
 4.886823654174805,
 4.85363245010376,
 4.731806755065918,
 4.771327972412109,
 4.608567714691162,
 4.543424606323242,
 4.5701189041137695,
 4.395766735076904,
 4.4923481941223145,
 4.315109729766846,
 4.192707061767578,
 4.22307014465332,
 4.148548603057861,
 4.013594150543213,
 3.904195785522461,
 3.8517346382141113,
 3.7031123638153076,
 3.5979599952697754,
 3.5061099529266357,
 3.3658814430236816,
 3.227588415145874,
 3.108192205429077,
 2.989473581314087,
 2.9054555892944336,
 2.782717704772949,
 2.6248650550842285,
 2.5911011695861816,
 2.4286556243896484,
 2.2922093868255615,
 2.158637762069702,
 2.102001905441284,
 2.0473146438598633,
 1.8954404592514038,
 1.7875083684921265,
 1.7885924577713013,
 1.696744680404663,
 1.554947

In [None]:
# Final Train on all Training Data Set, # Append the losses
pretrain_epoch = 0
epochs = 40
history = train(model, optimizer, lossfn, train_loader, (pretrain_epoch , pretrain_epoch + epochs), save_model = 1)
Losses.extend(history)
Losses

HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  0    loss  :  5.306207177042961


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  1    loss  :  3.798873773962259


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  2    loss  :  3.059540618956089


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  3    loss  :  2.6531201243400573


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  4    loss  :  2.4992314126342534


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  5    loss  :  2.2506532415747644


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  6    loss  :  2.016348884999752


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  7    loss  :  1.8162709075957537


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  8    loss  :  1.6412615414708853


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  9    loss  :  1.6089615866541862


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  10    loss  :  1.5047208599746227


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  11    loss  :  1.407477853819728


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  12    loss  :  1.305465718358755


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  13    loss  :  1.217481330037117


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  14    loss  :  1.2765130223706365


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  15    loss  :  1.2628283154219389


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  16    loss  :  1.2313409056514502


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  17    loss  :  1.1776728732511401


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  18    loss  :  1.1321511428803206


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  19    loss  :  1.2369312556460499


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  20    loss  :  1.2499164000153542


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  21    loss  :  1.21764185577631


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  22    loss  :  1.1904155423864722


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  23    loss  :  1.1611024629324675


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  24    loss  :  1.2524032736197115


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  25    loss  :  1.2650340868160128


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  26    loss  :  1.245679411664605


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  27    loss  :  1.2161254012957214


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  28    loss  :  1.1931087624281644


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  29    loss  :  1.1677618868649007


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  30    loss  :  1.1548232071101665


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  31    loss  :  1.141481806896627


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  32    loss  :  1.1371271597221493


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  33    loss  :  1.1207850560545922


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  34    loss  :  1.1091734379529954


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  35    loss  :  1.0991636136546732


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  36    loss  :  1.0955843701958656


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  37    loss  :  1.085945650190115


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  38    loss  :  1.0795336263254285


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))


 Epoch :  39    loss  :  1.0680181326344609


[5.306207177042961,
 3.798873773962259,
 3.059540618956089,
 2.6531201243400573,
 2.4992314126342534,
 2.2506532415747644,
 2.016348884999752,
 1.8162709075957537,
 1.6412615414708853,
 1.6089615866541862,
 1.5047208599746227,
 1.407477853819728,
 1.305465718358755,
 1.217481330037117,
 1.2765130223706365,
 1.2628283154219389,
 1.2313409056514502,
 1.1776728732511401,
 1.1321511428803206,
 1.2369312556460499,
 1.2499164000153542,
 1.21764185577631,
 1.1904155423864722,
 1.1611024629324675,
 1.2524032736197115,
 1.2650340868160128,
 1.245679411664605,
 1.2161254012957214,
 1.1931087624281644,
 1.1677618868649007,
 1.1548232071101665,
 1.141481806896627,
 1.1371271597221493,
 1.1207850560545922,
 1.1091734379529954,
 1.0991636136546732,
 1.0955843701958656,
 1.085945650190115,
 1.0795336263254285,
 1.0680181326344609]

In [None]:
#save Model and its dictionary
torch.save(model.state_dict(), model_location + 'bilstm_np_dict_' + str(epochs) )
torch.save(model, model_location + 'bilstm_np_' + str(epochs) )
torch.save(model.encoder.state_dict(), model_location + 'bilstm_enc_dict_' + str(epochs) )
torch.save(model.encoder, model_location + 'bilstm_enc_' + str(epochs) )
torch.save(model.decoder.state_dict(), model_location + 'bilstm_dec_dict_' + str(epochs) )
torch.save(model.decoder, model_location + 'bilstm_dec_' + str(epochs) )

### USE MODEL

In [22]:
get_bleu_score(model, test_tokens, hindi, english, MAX_LENGTH)

HBox(children=(FloatProgress(value=0.0, max=20464.0), HTML(value='')))

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()




BLEU score:  0.0336906359919707
METEOR score:  0.31819703302161806


USE MODEL FOR WEEKLY TRANSLATION

In [None]:
print('Reading Weekly Data ... ', end = '')
week = read_csv(weekly_data_location + 'Week4/hindistatements.csv', file_type='weekly')    # Load weekly data 
print('Done')

print('Process Weekly Hindi Data ... ', end = '')
week_processed = []
for x in  week['hindi']:
    t = get_hindi_tokens(preprocess_hindi(x))
    week_processed.append(t)
print('Done')

print('Trasnlating all the sentences ... ', end = '')
translated_texts = []
for i in tq.tqdm( range(len(week_processed)) ):
  translated_texts.append( translate(model, week_processed[i], hindi, english, MAX_LENGTH) ) 
print('Done')

print('Storing translated Sentences ... ', end = '')
with open(weekly_data_location + 'Week4/bigruattn20.txt', 'w') as f:
    for item in translated_texts:
        f.write("%s\n" % item)
print('Done')

Reading Weekly Data ... Done
Process Weekly Hindi Data ... Done
Trasnlating all the sentences ... 

HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))


Done
Storing translated Sentences ... Done


In [None]:
#translate(model, week_processed[i], hindi, english, MAX_LENGTH)


#torch.save( tmodel.state_dict(), model_location + 'gru_dict_100')
#torch.save(model, location+ 'gru_enc_dec')

#tmodel = torch.load(model_location+ 'gru_100')
#tmodel.eval()

#tq.tqdm._instances.clear()