In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from io import open
import unicodedata
import re

In [2]:
if torch.cuda.is_available():
    device=torch.device(type='cuda', index=0)
else:
    device=torch.device(type='cpu', index=0)

print(f'Using device {device}')

Using device cuda:0


In [3]:
#unicode 2 ascii, remove non-letter characters, trim
def normalizeString(s): 
    sres=""
    for ch in unicodedata.normalize('NFD', s): 
        #Return the normal form form ('NFD') for the Unicode string s.
        if unicodedata.category(ch) != 'Mn':
            # The function in the first part returns the general 
            # category assigned to the character ch as string. 
            # "Mn' refers to Mark, Nonspacing
            sres+=ch
    sres = re.sub(r"([.!?])", r" \1", sres) 
    # inserts a space before any occurrence of ".", "!", or "?" in the string sres. 
    sres = re.sub(r"[^a-zA-Z!?]+", r" ", sres) 
    # this line of code replaces any sequence of characters in sres 
    # that are not letters (a-z or A-Z) or the punctuation marks 
    # "!" or "?" with a single space character.
    return sres.strip()

#create list of pairs (list of lists) (no filtering)
def createNormalizedPairs():
    initpairs=[]
    for pair in data:
        s1,s2=pair.split('\t')
        s1=normalizeString(s1.lower().strip())
        s2=normalizeString(s2.lower().strip())
        initpairs.append([s1,s2])
    #print(len(initpairs))
    return initpairs

#filter pairs
max_length = 10
def filterPairs(initpairs):
    #filtering conditions in addition to max_length
    eng_prefixes = (
        "i am ", "i m ",
        "he is", "he s ",
        "she is", "she s ",
        "you are", "you re ",
        "we are", "we re ",
        "they are", "they re "
    )

    pairs=[]
    for pair in initpairs:
        if len(pair[0].split(" ")) < max_length and len(pair[1].split(" ")) < max_length and pair[0].lower().startswith(eng_prefixes):
            pairs.append(pair)

    print("Number of pairs after filtering:", len(pairs))
    return pairs #list of lists

In [4]:
class Vocab:
    def __init__(self, name):
        self.name=name
        self.word2index={'SOS':0, 'EOS':1}
        self.index2word={0:'SOS', 1:'EOS'}
        self.word2count={}
        self.nwords=2
    
    def buildVocab(self,s):
        for word in s.split(" "):
            if word not in self.word2index:
                self.word2index[word]=self.nwords
                self.index2word[self.nwords]=word
                self.word2count[word]=1
                self.nwords+=1
            else:
                self.word2count[word]+=1

In [5]:
class Encoder(nn.Module):
    def __init__(self, input_size, embed_size, hidden_size, dropout_p=0.1):
        super().__init__()
        self.e=nn.Embedding(input_size, embed_size)
        self.dropout=nn.Dropout(dropout_p)
        self.gru=nn.GRU(embed_size,hidden_size, batch_first=True)
    
    def forward(self,x):
        x=self.e(x)
        x=self.dropout(x)
        outputs, hidden=self.gru(x)
        return outputs, hidden

In [6]:
class Decoder(nn.Module):
    def __init__(self,output_size,embed_size,hidden_size):
        super().__init__()
        self.e=nn.Embedding(output_size,embed_size)
        self.relu=nn.ReLU()
        self.gru=nn.GRU(embed_size, hidden_size, batch_first=True)
        self.lin=nn.Linear(hidden_size,output_size)
        self.lsoftmax=nn.LogSoftmax(dim=-1)
    
    def forward(self,x,prev_hidden):
        x=self.e(x)
        x=self.relu(x)
        output,hidden=self.gru(x,prev_hidden)
        y=self.lin(output)
        y=self.lsoftmax(y)
        return y, hidden

In [7]:
def get_input_ids(sentence,langobj):
    input_ids=[]
    for word in sentence.split(" "):
        input_ids.append(langobj.word2index[word])
    
    if langobj.name=='fre': #translation-direction sensitive
        input_ids.append(langobj.word2index['EOS'])
    else:
        input_ids.insert(0,langobj.word2index['SOS'])
        input_ids.append(langobj.word2index['EOS'])
    return torch.tensor(input_ids)

In [8]:
class CustomDataset(Dataset):
    def __init__(self):
        super().__init__()
    
    def __len__(self):
        return length
    
    def __getitem__(self,idx):
        t=pairs[idx][0] #translation-direction sensitive
        s=pairs[idx][1] #translation-direction sensitive
        s_input_ids=torch.zeros(max_length+1, dtype=torch.int64)
        t_input_ids=torch.zeros(max_length+2, dtype=torch.int64)
        s_input_ids[:len(s.split(" "))+1]=get_input_ids(s,fre) #translation-direction sensitive
        t_input_ids[:len(t.split(" "))+2]=get_input_ids(t,eng) #translation-direction sensitive
        
        return s_input_ids, t_input_ids

In [9]:
def train_one_epoch():
    encoder.train()
    decoder.train()
    track_loss=0
    
    for i, (s_ids,t_ids) in enumerate(train_dataloader):
        s_ids=s_ids.to(device)
        t_ids=t_ids.to(device)
        encoder_outputs, encoder_hidden=encoder(s_ids)
        decoder_hidden=encoder_hidden
        yhats, decoder_hidden = decoder(t_ids[:,0:-1],decoder_hidden)
                    
        gt=t_ids[:,1:]
        
        yhats_reshaped=yhats.view(-1,yhats.shape[-1])
        
        gt=gt.reshape(-1)
        
        
        loss=loss_fn(yhats_reshaped,gt)
        track_loss+=loss.item()
        
        opte.zero_grad()
        optd.zero_grad()
        
        loss.backward()
        
        opte.step()
        optd.step()
        
    return track_loss/len(train_dataloader)

In [10]:
def ids2Sentence(ids,vocab):
    sentence=""
    for id in ids.squeeze():
        if id==0:
            continue
        word=vocab.index2word[id.item()]
        sentence+=word + " "
        if id==1:  
            break
    return sentence

In [11]:
#eval loop (written assuming batch_size=1)
def eval_one_epoch(e,n_epochs):
    encoder.eval()
    decoder.eval()
    track_loss=0
    with torch.no_grad():
        for i, (s_ids,t_ids) in enumerate(test_dataloader):
            s_ids=s_ids.to(device)
            t_ids=t_ids.to(device)
            encoder_outputs, encoder_hidden=encoder(s_ids)
            decoder_hidden=encoder_hidden #n_dim=3
            input_ids=t_ids[:,0]
            yhats=[]
            if e+1==n_epochs:
                pred_sentence=""
            for j in range(1,max_length+2): #j starts from 1
                probs, decoder_hidden = decoder(input_ids.unsqueeze(1),decoder_hidden)
                yhats.append(probs)
                _,input_ids=torch.topk(probs,1,dim=-1)
                input_ids=input_ids.squeeze(1,2) #still a tensor
                if e+1==n_epochs:
                    word=eng.index2word[input_ids.item()] #batch_size=1
                    pred_sentence+=word + " "
                if input_ids.item() == 1: #batch_size=1
                    break
                                
            if e+1==n_epochs:
                src_sentence=ids2Sentence(s_ids,fre) #translation-direction sensitive
                gt_sentence=ids2Sentence(t_ids[:,1:],eng) #translation-direction sensitive

                print("\n-----------------------------------")
                print("Source Sentence:",src_sentence)
                print("GT Sentence:",gt_sentence)
                print("Predicted Sentence:",pred_sentence)
            
            yhats_cat=torch.cat(yhats,dim=1)
            yhats_reshaped=yhats_cat.view(-1,yhats_cat.shape[-1])
            gt=t_ids[:,1:j+1]
            gt=gt.view(-1)
            

            loss=loss_fn(yhats_reshaped,gt)
            track_loss+=loss.item()
            
        if e+1==n_epochs:    
            print("-----------------------------------")
        return track_loss/len(test_dataloader)

In [13]:
#driver code

#read data
data=open("/kaggle/input/eng-fra/eng-fra.txt").read().strip().split('\n')
print("Total number of pairs:",len(data))

#create pairs (create + normalize)
initpairs=createNormalizedPairs() #list of lists. Each inner list is a pair

#filter pairs
pairs=filterPairs(initpairs)
length=len(pairs)

#create Vocab objects for each language
eng=Vocab('eng')
fre=Vocab('fre')

#build the vocab
for pair in pairs:
    eng.buildVocab(pair[0])
    fre.buildVocab(pair[1])

#print vocab size
print("English Vocab Length:",eng.nwords)
print("French Vocab Length:",fre.nwords)    
    
dataset=CustomDataset()
train_dataset,test_dataset=random_split(dataset,[0.99,0.01])

batch_size=32
train_dataloader=DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=False)
test_dataloader=DataLoader(dataset=test_dataset,batch_size=1, shuffle=False)

    
embed_size=300
hidden_size=512

encoder=Encoder(fre.nwords,embed_size,hidden_size).to(device) #translation-direction sensitive
decoder=Decoder(eng.nwords,embed_size,hidden_size).to(device) #translation-direction sensitive

loss_fn=nn.NLLLoss(ignore_index=0).to(device)
lr=0.001
opte=optim.Adam(params=encoder.parameters(), lr=lr, weight_decay=0.001)
optd=optim.Adam(params=decoder.parameters(), lr=lr, weight_decay=0.001)

n_epochs=80

for e in range(n_epochs):
    print("Epoch=",e+1, sep="", end=", ")
    print("Train Loss=", round(train_one_epoch(),4), sep="", end=", ")
    print("Eval Loss=",round(eval_one_epoch(e,n_epochs),4), sep="")

Total number of pairs: 135842
Number of pairs after filtering: 11445
English Vocab Length: 2991
French Vocab Length: 4601
Epoch=1, Train Loss=3.1455, Eval Loss=3.4373
Epoch=2, Train Loss=2.3968, Eval Loss=3.1169
Epoch=3, Train Loss=2.1571, Eval Loss=2.8849
Epoch=4, Train Loss=2.0365, Eval Loss=2.7919
Epoch=5, Train Loss=1.9634, Eval Loss=2.7005
Epoch=6, Train Loss=1.914, Eval Loss=2.6888
Epoch=7, Train Loss=1.8664, Eval Loss=2.5019
Epoch=8, Train Loss=1.8219, Eval Loss=2.3599
Epoch=9, Train Loss=1.7805, Eval Loss=2.3376
Epoch=10, Train Loss=1.7352, Eval Loss=2.3319
Epoch=11, Train Loss=1.6939, Eval Loss=2.3716
Epoch=12, Train Loss=1.6522, Eval Loss=2.3173
Epoch=13, Train Loss=1.6191, Eval Loss=2.1812
Epoch=14, Train Loss=1.5781, Eval Loss=2.2088
Epoch=15, Train Loss=1.5509, Eval Loss=2.1708
Epoch=16, Train Loss=1.5253, Eval Loss=2.1575
Epoch=17, Train Loss=1.4973, Eval Loss=2.1922
Epoch=18, Train Loss=1.4721, Eval Loss=2.1058
Epoch=19, Train Loss=1.4519, Eval Loss=2.1896
Epoch=20, Trai