In [25]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
from torch import nn
from torch.nn import functional as F

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader,Dataset,TensorDataset,RandomSampler
import time

# Run on Apple's Metal (MPS) backend when this PyTorch build/machine exposes it,
# otherwise fall back to the CPU. To force CPU, set: device = torch.device("cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [26]:
SOS_token = 0
EOS_token = 1
class Language:
    """Vocabulary for one language: word <-> index maps plus word frequencies.

    Indices SOS_token (0) and EOS_token (1) are reserved; words receive
    consecutive indices starting at 2 in first-seen order, so ``num_words`` is
    always one past the largest index and can be used directly as an
    ``nn.Embedding`` vocabulary size.
    """
    def __init__(self,name):
        self.wordcount = {}
        self.word2index = {}
        self.name = name
        self.index2word = {SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 2

    def addWord(self,word):
        """Count one occurrence of ``word``, giving it the next free index if new."""
        if word not in self.word2index:
            # Assign before incrementing so indices stay contiguous (2, 3, ...)
            # and every index is < num_words, the embedding size used for
            # this vocabulary.
            self.word2index[word] = self.num_words
            self.index2word[self.num_words] = word
            self.wordcount[word] = 1
            self.num_words += 1
        else:
            self.wordcount[word] += 1

    def addSentence(self,sentence):
        """Add every token of ``sentence``, splitting on single spaces."""
        for word in sentence.split(" "):
            self.addWord(word)


In [27]:
def unicodeToASCII(unicode):
    """Strip accents: NFD-decompose ``unicode`` and drop combining marks (category 'Mn')."""
    decomposed = unicodedata.normalize('NFD', unicode)
    return ''.join([ch for ch in decomposed if unicodedata.category(ch) != 'Mn'])


def normaliseText(text):
    """Lower-case and de-accent ``text``, keeping only ASCII letters, ``!`` and ``?``.

    ``.``, ``!`` and ``?`` are first split off from the preceding word with a
    space. Every run of characters other than ASCII letters, ``!`` and ``?``
    then collapses to a single space. That run includes ``.``, digits,
    apostrophes and any non-decomposable accented letters. Finally the result
    is stripped.

    Examples: ``"I'm OK."`` -> ``"i m ok"``; ``"Hi!"`` -> ``"hi !"``.
    """
    spaced = re.sub(r"([.!?])", r" \1", text)
    ascii_text = unicodeToASCII(spaced.lower().strip())
    collapsed = re.sub(r"[^a-zA-Z!?]+", r" ", ascii_text)
    return collapsed.lower().strip()

In [28]:
MAX_LENGTH = 10

# Accepted sentence openings, written in normalised form
# (normaliseText turns "I'm" into "i m").
eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    """Keep a pair only if both sides are short and ``p[1]`` has an English prefix.

    Both ``p[0]`` and ``p[1]`` must have fewer than MAX_LENGTH space-separated
    tokens, and ``p[1]`` must start with one of ``eng_prefixes``.

    NOTE(review): with ``readData(False)`` column 1 is the French sentence (the
    vocabulary loop feeds ``pair[1]`` to ``french``), so this test presumably
    only makes sense when English sits in column 1 (i.e. ``reverse=True``).
    """
    for column in (0, 1):
        if len(p[column].split(' ')) >= MAX_LENGTH:
            return False
    return p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    """Return the sub-list of ``pairs`` accepted by ``filterPair``, in order."""
    return list(filter(filterPair, pairs))

In [29]:
def readData(reverse=False, path="fra.txt"):
    """Load tab-separated sentence pairs and create empty vocabularies.

    Every non-blank line of ``path`` is split on tabs, and each column is passed
    through ``normaliseText``. Elsewhere in this notebook column 0 is treated
    as English and column 1 as French. ``getData`` unpacks a third column, so
    lines presumably carry an extra trailing field.

    Args:
        reverse: if True, swap the first two columns of every pair; any extra
            columns stay at the end.
        path: corpus file to read (default ``"fra.txt"`` in the working dir).

    Returns:
        ``(english, french, pairs)``: two empty ``Language`` objects (they are
        not populated here) and the list of normalised pairs.
    """
    # `with` closes the handle; read as UTF-8 explicitly instead of relying on
    # the platform default (the corpus presumably contains accented French).
    with open(path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    # Skip blank lines, which would otherwise become 1-element "pairs".
    pairs = [[normaliseText(s) for s in line.split('\t')] for line in lines if line.strip()]
    print(f"{len(pairs)}")
    # Optional: pairs = filterPairs(pairs). filterPair checks column 1 for
    # English prefixes, so it expects reverse=True.

    if reverse:
        # Swap only the two sentence columns. A plain reversed() would also move
        # any trailing column to the front.
        pairs = [p[1::-1] + p[2:] for p in pairs]

    english = Language("English")
    french = Language("French")
    return english,french,pairs


In [30]:
english, french, pairs = readData(reverse=False)
print(f"No. of Sentences: {len(pairs)}")
# Populate both vocabularies from every loaded pair: column 0 -> English, column 1 -> French.
for pair in pairs:
    english_sentence, french_sentence = pair[0], pair[1]
    english.addSentence(english_sentence)
    french.addSentence(french_sentence)
print(f"No of english words: {english.num_words}")
print(f"No of french words: {french.num_words}")

232736
No. of Sentences: 232736
No of english words: 16033
No of french words: 25684


In [31]:
# Keep only the first half of the pairs, presumably to shorten training.
# NOTE: this rebinds `pairs` from itself, so re-running the cell halves the data
# again. Re-run the loading cell before re-running this one.
pairs = pairs[slice(len(pairs) // 2)]

In [32]:
class EncoderLSTM(nn.Module):
    """Embed a batch of source token ids and encode it with a single-layer LSTM.

    Args:
        input_size: source vocabulary size (embedding rows).
        hidden_size: embedding and LSTM hidden width.
        dropout_p: dropout probability applied to the embeddings.
    """
    def __init__(self,input_size,hidden_size,dropout_p=0.2):
        super().__init__()
        self.embedding = nn.Embedding(input_size,hidden_size)
        # batch_first=True: DataLoader batches are (batch, seq_len), and the
        # decoder's LSTM is batch_first as well.
        self.lstm = nn.LSTM(hidden_size,hidden_size,batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self,X):
        """Encode ``X`` of shape ``(batch, seq_len)``.

        Returns:
            ``(output, (h_n, c_n))``. ``output`` is ``(batch, seq_len,
            hidden_size)``; ``h_n`` and ``c_n`` are ``(1, batch, hidden_size)``.
            The full ``(h, c)`` state is returned so the decoder LSTM can be
            initialised from it.
        """
        embed = self.dropout(self.embedding(X))
        # nn.LSTM returns a 2-tuple: output, (h_n, c_n).
        output, (final_hidden, final_cell) = self.lstm(embed)
        return output, (final_hidden, final_cell)

In [33]:
class DecoderLSTM(nn.Module):
    """Single-layer LSTM decoder emitting log-probabilities over the target vocab.

    Args:
        hidden_size: embedding and LSTM hidden width (must match the encoder).
        output_size: target vocabulary size.
        dropout_p: dropout probability applied to the embedded decoder inputs.
    """
    def __init__(self,hidden_size,output_size,dropout_p=0.2):
        super().__init__()
        self.lstm = nn.LSTM(hidden_size,hidden_size,batch_first=True)
        self.embedding = nn.Embedding(output_size,hidden_size)
        self.ff = nn.Linear(hidden_size,output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward_step(self, inputs, hidden,target_tensor=None):
        """Run one decoding step.

        Args:
            inputs: ``(batch, 1)`` token ids.
            hidden: LSTM state tuple ``(h, c)``.
            target_tensor: unused; kept so existing callers keep working.

        Returns:
            ``(logits, (h, c))`` with logits of shape ``(batch, 1, output_size)``.
        """
        out = self.dropout(self.embedding(inputs))
        # Apply the activation functionally; nn.ReLU(out) would only build a module.
        out = F.relu(out)
        # nn.LSTM returns output, (h_n, c_n); carry the whole state forward.
        output, hidden = self.lstm(out, hidden)
        output = self.ff(output)
        return output, hidden

    def forward(self, encoder_output, encoder_hidden, target_tensor=None, max_length=None):
        """Decode a batch, starting every sequence from SOS_token.

        Args:
            encoder_output: only ``.size(0)`` (the batch size) and ``.device``
                are used.
            encoder_hidden: initial state, either ``(h, c)`` or just ``h``. If
                only ``h`` is given, the cell state starts at zero.
            target_tensor: optional ``(batch, tgt_len)`` token ids. When given,
                decode ``tgt_len`` steps with teacher forcing (the input at step
                t+1 is the true token t), so outputs line up with the target.
            max_length: number of greedy steps when no target is given;
                defaults to ``MAX_LENGTH``.

        Returns:
            ``(log_probs, hidden)``: log_probs is ``(batch, steps, output_size)``
            and hidden is the final ``(h, c)``.
        """
        batch_size = encoder_output.size(0)
        if target_tensor is not None:
            steps = target_tensor.size(1)
        else:
            # Fixed, bounded length: batched greedy decoding cannot stop on a
            # single EOS, and an unbounded loop might never terminate.
            steps = MAX_LENGTH if max_length is None else max_length

        decoder_hidden = encoder_hidden
        if not isinstance(decoder_hidden, tuple):
            decoder_hidden = (decoder_hidden, torch.zeros_like(decoder_hidden))

        decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long,
                                   device=encoder_output.device)
        decoder_outputs = []
        for step in range(steps):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)
            if target_tensor is not None:
                # Teacher forcing: feed the ground-truth token for this position.
                decoder_input = target_tensor[:, step].unsqueeze(1)
            else:
                # Greedy: feed back the arg-max token; detach so no gradient
                # flows through the sampled input.
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()
        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

In [34]:
def sentenceToIndex(sentence:str,lang:Language):
    """Map a space-separated sentence to vocabulary ids framed by SOS/EOS.

    Each token is re-normalised with ``normaliseText`` before lookup. A token
    missing from ``lang.word2index`` raises ``KeyError``.
    """
    word_ids = [lang.word2index[normaliseText(token)] for token in sentence.split(" ")]
    return [SOS_token, *word_ids, EOS_token]

def tensorFromSentence(sentence:str,lang:Language):
    """Convert ``sentence`` to a ``(1, seq_len)`` LongTensor on ``device``.

    ``sentenceToIndex`` already frames the ids with SOS_token/EOS_token, so they
    are not added again. Adding them here would yield ``SOS SOS ... EOS EOS``.
    """
    indexes = sentenceToIndex(sentence,lang)
    return torch.tensor(indexes,dtype=torch.long,device=device).view(1,-1)

def tensorFromPair(pairs:list,lang1:Language,lang2:Language):
    """Return ``(source_tensor, target_tensor)`` for one sentence pair.

    ``pairs[0]`` is encoded with ``lang1`` and ``pairs[1]`` with ``lang2``.
    Any further elements are ignored.
    """
    sides = ((pairs[0], lang1), (pairs[1], lang2))
    return tuple(tensorFromSentence(text, lang) for text, lang in sides)

def getData(batch_size):
    """Build a shuffled DataLoader of padded ``(source_ids, target_ids)`` batches.

    Reads the notebook-level ``pairs``, ``english``, ``french`` and ``device``.
    ``pair[0]`` is indexed with ``english`` and ``pair[1]`` with ``french``; any
    further columns are ignored. Each side is zero-padded to its own longest
    sequence.

    NOTE(review): the padding value 0 equals SOS_token and is not masked by the
    loss. Consider a dedicated PAD index and ``ignore_index``.

    Returns:
        ``(english, french, train_dataloader)``.
    """
    input_seqs = []
    output_seqs = []
    for pair in pairs:
        input_seqs.append(sentenceToIndex(pair[0], english))
        output_seqs.append(sentenceToIndex(pair[1], french))

    def _pad(seqs):
        # Size the matrix to (num_pairs, longest_sequence). An (n, n) matrix
        # would need tens of GB for a corpus this size.
        width = max((len(s) for s in seqs), default=0)
        ids = np.zeros((len(seqs), width), dtype=np.int64)
        for row, seq in enumerate(seqs):
            ids[row, :len(seq)] = seq
        return torch.tensor(ids, dtype=torch.long, device=device)

    train_data = TensorDataset(_pad(input_seqs), _pad(output_seqs))
    train_sampler = RandomSampler(train_data)
    # DataLoader has no `device` argument; the tensors already live on `device`.
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return english, french, train_dataloader

In [35]:
def train_epoch(train_data, encoder, decoder, encoder_opt, decoder_opt, criterion):
    """Run one optimisation pass over ``train_data`` and return the mean batch loss.

    Args:
        train_data: iterable of ``(input_tensor, target_tensor)`` batches that
            supports ``len()`` (e.g. a DataLoader).
        encoder: module returning ``(outputs, hidden)`` for an input batch.
        decoder: module called as ``decoder(outputs, hidden, target_tensor)``
            and returning ``(log_probs, hidden)``, with log_probs aligned
            position-by-position with ``target_tensor``.
        encoder_opt, decoder_opt: optimisers for the two modules.
        criterion: loss taking ``(N, vocab)`` log-probs and ``(N,)`` targets.

    Returns:
        Average of the per-batch ``loss.item()`` values.
    """
    total_loss = 0.0
    for input_tensor, target_tensor in train_data:
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()

        encoder_output, final_hidden = encoder(input_tensor)
        # Pass the target so the decoder teacher-forces exactly tgt_len steps.
        # Otherwise its output length need not match the target, and the
        # flattened loss below would fail.
        decoder_output, _ = decoder(encoder_output, final_hidden, target_tensor)
        loss = criterion(decoder_output.reshape(-1, decoder_output.size(-1)),
                         target_tensor.reshape(-1))
        loss.backward()

        encoder_opt.step()
        decoder_opt.step()

        total_loss += loss.item()

    return total_loss / len(train_data)

def training(train_dataloader,encoder,decoder,epochs,lr=1e-3,print_interval=100):
    """Train ``encoder``/``decoder`` with Adam and NLLLoss.

    Args:
        train_dataloader: batches passed unchanged to ``train_epoch``.
        encoder, decoder: modules to optimise.
        epochs: number of epochs to run (an int).
        lr: Adam learning rate for both optimisers.
        print_interval: print progress every this many epochs.

    Returns:
        List of per-epoch mean losses, one entry per epoch.
    """
    start = time.time()
    losses = []
    encoder_opt = torch.optim.Adam(encoder.parameters(),lr=lr)
    decoder_opt = torch.optim.Adam(decoder.parameters(),lr=lr)
    criterion = nn.NLLLoss()

    # `epochs` is a count, not an iterable. Number epochs from 1 so the running
    # average below never divides by zero.
    for epoch in range(1, epochs + 1):
        loss = train_epoch(train_dataloader,encoder,decoder,encoder_opt,decoder_opt,criterion)
        losses.append(loss)
        if epoch % print_interval == 0:
            avg_loss = sum(losses) / epoch
            elapsed = time.time() - start
            # `d` accepts no precision and floats need `f`; the old specs raised ValueError.
            print(f"epoch: {epoch:3d}: loss = {loss:.7f} \tavg_loss = {avg_loss:.7f} ({elapsed:.1f} secs)")
    return losses
        

In [36]:
hidden_size = 128
batch_size = 32
epochs = 10
# Seed torch so weight init, RandomSampler shuffling and dropout are reproducible.
torch.manual_seed(0)

# The vocabularies were already populated from every pair when the data was
# loaded. Re-adding the (halved) pairs here would only double-count wordcount.
english,french,train_dataloader = getData(batch_size)
print(f"No. of Sentences: {len(pairs)}")
print(f"No of english words: {english.num_words}")
print(f"No of french words: {french.num_words}")

# getData places the dataset tensors on `device`, so the models must be there too.
encoder = EncoderLSTM(english.num_words,hidden_size).to(device)
decoder = DecoderLSTM(hidden_size,french.num_words).to(device)

losses = training(train_dataloader,encoder,decoder,epochs,print_interval=1)

: 