In [1]:
import torch.nn as nn
import torch
from itertools import chain
from collections import Counter
from torchtext.vocab import vocab
from torch.utils.data import Dataset,random_split,DataLoader
import random
from torch.nn.utils.rnn import pad_sequence
from functools import partial
import os
import json
from tqdm import tqdm
from torchtext.data.metrics import bleu_score
from torch.utils.tensorboard import SummaryWriter

In [2]:
num_epochs=10
learning_rate=0.003
batch_size=1024

In [3]:
def preprocess(data):
    data=data.replace("\u202f"," ").replace("\xa0"," ").replace("\u2009"," ")
    no_space=lambda char,prev_char: char in ",.!?" and prev_char !=' '
    out=[' '+char if i>0 and no_space(char,data[i-1]) else char
            for i,char in enumerate(data)]
    out=''.join(out)
    out=["\t".join(sentence.split("\t")[:2]) for sentence in out.split('\n')]
    out='\n'.join(out)
    return out

In [4]:
def build_vocab(list_tokens):
    tokens=sorted(chain.from_iterable((list_tokens)))
    token_freq=Counter(tokens)
    vocabulary=vocab(token_freq,specials=['<unk>','<pad>'])
    vocabulary.set_default_index(0)
    return vocabulary
    

In [5]:
def tokenizer(text):
    return [token for token in f"<sos> {text} <eos>".split(" ") if token]

In [6]:
def separate_src_tgt(data,max_samples=None):
    src=[]
    tgt=[]
    for i,text in enumerate(data):
        if max_samples and i> max_samples:break
        parts= text.split('\t')
        if len(parts)==2:
            src.append(tokenizer(parts[0]))
            tgt.append(tokenizer(parts[1]))
    return src,tgt
    

In [7]:
class CustomDataset(Dataset):
    def __init__(self,dataset) -> None:
        super().__init__()
        self.src_data=dataset[0]
        self.tgt_data=dataset[1]
    def __len__(self):
        return len(self.src_data)
    def __getitem__(self, index):
        return self.src_data[index],self.tgt_data[index]

In [8]:
def data_process_pipeline(data,eng_vocab=None,fra_vocab=None):
    src,tgt=separate_src_tgt(data)
    if eng_vocab is None:
        eng_vocab=build_vocab(src)
        fra_vocab=build_vocab(tgt)
    src_idx=[eng_vocab.forward(sent) for sent in src]
    tgt_idx=[fra_vocab.forward(sent) for sent in tgt]
    train_dataset=CustomDataset((src_idx,tgt_idx))
    return train_dataset,eng_vocab,fra_vocab

In [9]:
def train_test_split(dataset,train_percent):
    train_size=int(len(dataset)*train_percent)
    train_data=dataset[:train_size]
    test_data=dataset[train_size:]
    return train_data,test_data
    

In [10]:
def load_data(data_path,train_size):
    with open(data_path,'r',encoding="utf-8") as fp:
        data=fp.read()
    clean_data=preprocess(data)
    sent_list=[sent for sent in clean_data.split("\n") if len(sent)>0]
    sorted_sent_list=sorted(sent_list,key=lambda x: len(x.split('\t')[0].split(' ')))
    train_data,test_data=train_test_split(sorted_sent_list,train_size)
    return train_data,test_data

In [11]:
raw_train_data,raw_val_data=load_data("data/fra.txt",0.8)

In [12]:
train_data,eng_vocab,fra_vocab=data_process_pipeline(raw_train_data)
val_data,_,_=data_process_pipeline(raw_val_data,)


In [13]:
def collate_fn(train_data,src_pad_val,tgt_pad_val):
    src_data=[torch.LongTensor(src[0]) for src in train_data]
    tgt_data=[torch.LongTensor(tgt[1]) for tgt in train_data]
    src_tensor=pad_sequence(src_data,padding_value=src_pad_val)
    tgt_tensor=pad_sequence(tgt_data,padding_value=tgt_pad_val)
    return src_tensor,tgt_tensor

train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=8,collate_fn=partial(collate_fn,src_pad_val=eng_vocab['<pad>'],tgt_pad_val=fra_vocab['<pad>']))
val_loader=DataLoader(val_data,batch_size=batch_size,num_workers=8,collate_fn=partial(collate_fn,src_pad_val=eng_vocab['<pad>'],tgt_pad_val=fra_vocab['<pad>']))

In [14]:

input_size_encoder=len(eng_vocab)
input_size_decoder=len(fra_vocab)
output_size_decoder=len(fra_vocab)
encoder_embedding_size=300
decoder_embedding_size=300
hidden_size=1024
encoder_n_layers=2
decoder_n_layers=2
encoder_dropout=0.5
decoder_dropout=0.5
teacher_forcing_ratio=0.5
device="cuda" if torch.cuda.is_available() else "cpu"

In [15]:
class Encoder(nn.Module):
    def __init__(self,vocab_size,embed_size,hidden_size,num_layers,drop_prob):
        super().__init__()
        self.embed_layer=nn.Embedding(vocab_size,embed_size)
        self.rnn=nn.LSTM(embed_size,hidden_size,num_layers,dropout=drop_prob)
    def forward(self,X):
        embeddings=self.embed_layer(X)
        # print(f"embeddings shape : {embeddings.shape}")
        outputs,(hidden,cell)=self.rnn(embeddings)
        return hidden,cell

In [16]:
class Decoder(nn.Module):
    def __init__(self,vocab_size,embed_size,hidden_size,num_layers,drop_prob,output_size) -> None:
        super().__init__()
        self.embed=nn.Embedding(vocab_size,embed_size)
        self.rnn=nn.LSTM(embed_size,hidden_size,num_layers=num_layers,dropout=drop_prob)
        self.fc=nn.Linear(hidden_size,output_size)
    def forward(self,X,hidden,cell):
        X=X.unsqueeze(0)
        embeddings=self.embed(X)
        outputs,(hidden,cell)=self.rnn(embeddings,(hidden,cell))
        predictions=self.fc(outputs)
        predictions=predictions.squeeze(0)
        return predictions   

In [17]:
class Seq2Seq(nn.Module):
    def __init__(self,encoder,decoder) -> None:
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder
    def forward(self,src,tgt,tgt_vocab_size,teacher_force_ratio):
        
        batch_size=src.shape[1]
        target_len=tgt.shape[0]
        outputs=torch.zeros((target_len,batch_size,tgt_vocab_size))
        hidden,cell=self.encoder(src)
        x=tgt[0]
        for t in range(1,target_len):
            output=self.decoder(x,hidden,cell)
            outputs[t]=output
            teacher_force=random.random()<teacher_force_ratio
            top1=output.argmax(1)
            x=tgt[t] if teacher_force else top1
            
        return outputs

In [18]:
class Trainer:
    def __init__(
        self,
        model,
        num_epochs,
        batch_size,
        criterion,
        optimizer,
        learning_rate,
        output_size_decoder,
        teacher_forcing_ratio,
        device,
        print_stats,
        tgt_vocab,
        pred_pipeline,
    ):
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.output_size_decoder=output_size_decoder
        self.criterion = criterion
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.device = device
        self.print_stats=print_stats
        self.tgt_vocab=tgt_vocab
        self.pred_pipeline=pred_pipeline
        self.loss={"train":[],'val':[]}
        self.bleu={"train":[],'val':[]}
        self.model=model.to(device)
        self.writer=SummaryWriter()
    def train(self, train_loader, val_loader, save_path):
        for epoch in range(self.num_epochs):
            self._train_epoch(train_loader)
            self._val_epoch(val_loader)
            if self.print_stats:
                print(f"\n###########{epoch}#############\n")
                print(f"train loss : {self.loss['train'][-1]}\t val loss: {self.loss['val'][-1]}")
                print(f"train bleu score : {self.bleu['train'][-1]} \t val bleu score : {self.bleu['val'][-1]}")
                test_txt="In this story an old man sets out to ask an Indian king to dig some well in his village when their water runs dry"
                expected_translation="Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec."
                predicted_translation=self.transcribe(test_txt,self.pred_pipeline,50,189,self.tgt_vocab,188)
                blue_score=bleu_score([predicted_translation[0].split()],[[expected_translation.split()]],max_n=2,weights=[0.5,0.5])
                print(f"expected translation : \t {expected_translation}")
                print(f"predicted trainslation : \t{predicted_translation}")
                print(f"bleu_score : {blue_score}")
                print("\n-----------------------------------\n")
                if epoch==0:
                    with open(os.path.join(save_path,"sample_bleu.json"),"w+") as fp :
                        json.dump({"src text":test_txt,
                         "results":[]},fp)
                with open(os.path.join(save_path,"sample_bleu.json"),"r+") as fp :
                    result={"expected tranlation ":expected_translation,
                          "predicted translation":predicted_translation,
                          "bleu score ":blue_score,
                          "epoch":epoch
                          }
                    prev_data=json.load(fp)
                    prev_data["results"].append(result)
                    fp.seek(0)
                    json.dump(prev_data,fp)
            self.writer.add_scalar("train loss ",self.loss['train'][-1],epoch)
            self.writer.add_scalar("val loss ",self.loss['val'][-1],epoch)
            self.writer.add_scalar("train bleu ",self.bleu['train'][-1],epoch)
            self.writer.add_scalar("val bleu ",self.bleu['val'][-1],epoch)
        self.save_model(os.path.join(save_path,"model.pt"))
        self.save_loss(os.path.join(save_path,"loss.json"))
        self.save_loss(os.path.join(save_path,"bleu.json"))
    def _train_epoch(self,train_loader):
        i=0
        epoch_loss=0
        epoch_bleu=0
        for batch_src,batch_tgt in tqdm(train_loader):
            i+=1
            source=batch_src.to(self.device)
            target=batch_tgt.to(self.device)
            outputs=self.model(source,target,self.output_size_decoder,self.teacher_forcing_ratio)
            preds=outputs.argmax(-1)
            outputs=outputs[1:].reshape(-1,outputs.shape[-1])
            batch_tgt=batch_tgt[1:].reshape(-1)
            loss=self.criterion(outputs,batch_tgt)
            epoch_loss+=loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            pred_tokens=[self.tgt_vocab.lookup_tokens(tokens) for tokens in preds[1:].t().tolist()]
            target_tokens=[[self.tgt_vocab.lookup_tokens(tokens)] for tokens in target.t().tolist()]
            bleu=bleu_score(pred_tokens,target_tokens,max_n=2,weights=[0.5,0.5])
            epoch_bleu+=bleu
        self.loss['train'].append(epoch_loss/i)
        self.bleu['train'].append(epoch_bleu/i)
        
            
    def _val_epoch(self,val_loader):
        i=0
        epoch_bleu=0
        epoch_loss=0
        with torch.no_grad():
            for batch_src,batch_tgt in tqdm(val_loader):
                i+=1
                source=batch_src.to(self.device)
                target=batch_tgt.to(self.device)
                outputs=self.model(source,target,self.output_size_decoder,self.teacher_forcing_ratio)
                preds=outputs.argmax(-1)
                outputs=outputs[1:].reshape(-1,outputs.shape[-1])
                batch_tgt=batch_tgt[1:].reshape(-1)
                loss=self.criterion(outputs,batch_tgt)
                pred_tokens=[self.tgt_vocab.lookup_tokens(tokens) for tokens in preds[1:].t().tolist()]
                target_tokens=[[self.tgt_vocab.lookup_tokens(tokens)] for tokens in target.t().tolist()]
                bleu=bleu_score(pred_tokens,target_tokens,max_n=2,weights=[0.5,0.5])
                epoch_bleu+=bleu
                epoch_loss+=loss.item()
        self.loss['val'].append(epoch_loss/i)
        self.bleu['val'].append(epoch_bleu/i)
    def save_model(self,save_path):
        torch.save(self.model,save_path)
    def save_loss(self,save_path):
        with open(save_path,"w") as fp:
            json.dump(self.loss,fp)
    def save_bleu_score(self,save_path):
        with open(save_path,'w') as fp:
            json.dump(self.bleu,fp)
    
    def transcribe(self,inputs,pipeline,max_tokens,start_token,tgt_vocab,end_token):
        with torch.no_grad():
            inputs=pipeline(inputs)
            inputs=inputs.to(self.device)
            hidden,cell=self.model.encoder(inputs)
            output_tokens=[]
            i=0
            for i in range(inputs.shape[1]):
                x=torch.LongTensor([start_token]).to(self.device)
                current_output=[]
                while True:
                    predictions=self.model.decoder(x,hidden,cell)
                    predictions=predictions.argmax(1)
                    pred_tokens=tgt_vocab.lookup_token(predictions.item())
                    current_output.append(pred_tokens)
                    x=predictions
                    if end_token==predictions or len(current_output)>=max_tokens:
                        break
                output_tokens.append(" ".join(current_output))
        return output_tokens

In [19]:
encoder=Encoder(len(eng_vocab),embed_size=encoder_embedding_size,hidden_size=hidden_size,num_layers=encoder_n_layers,drop_prob=encoder_dropout)
decoder=Decoder(len(fra_vocab),embed_size=decoder_embedding_size,hidden_size=hidden_size,num_layers=decoder_n_layers,drop_prob=decoder_dropout,output_size=len(fra_vocab))
model=Seq2Seq(encoder=encoder,decoder=decoder)

In [20]:
criterion=nn.CrossEntropyLoss(ignore_index=eng_vocab['<pad>'])
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

In [21]:
def predict_pipeline(txt):
    if isinstance(txt,str):
        txt=[txt]
    sent_tokens=[tokenizer(tokens) for tokens in txt]
    int_tokens=[eng_vocab.forward(tokens) for tokens in sent_tokens]
    src_tensor=[torch.LongTensor(token_list) for token_list in int_tokens]
    src=pad_sequence(src_tensor,padding_value=eng_vocab['<pad>'])
    return src

In [22]:
trainer=Trainer(model=model,num_epochs=num_epochs,batch_size=batch_size,criterion=criterion,optimizer=optimizer,learning_rate=learning_rate,output_size_decoder=len(fra_vocab),teacher_forcing_ratio=teacher_forcing_ratio,device=device,print_stats=True,tgt_vocab=fra_vocab,pred_pipeline=predict_pipeline)

In [23]:
trainer.train(train_loader,val_loader,"results")

100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [17:50<00:00,  6.26s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:48<00:00,  2.53s/it]



###########0#############

train loss : 5.429914100825438	 val loss: 13.663817250451377
train bleu score : 0.09613293382107199 	 val bleu score : 0.0
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['« temps ? <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [18:13<00:00,  6.39s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:48<00:00,  2.53s/it]



###########1#############

train loss : 4.468157688776652	 val loss: 13.564084651858307
train bleu score : 0.13301422644803457 	 val bleu score : 0.0
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['La de la . <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [18:13<00:00,  6.39s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:49<00:00,  2.54s/it]



###########2#############

train loss : 3.623633338693987	 val loss: 13.8826162205186
train bleu score : 0.1672759608698524 	 val bleu score : 3.1363901035217715e-05
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['La voiture ? <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [18:08<00:00,  6.37s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:48<00:00,  2.51s/it]



###########3#############

train loss : 2.996352057707937	 val loss: 14.292387585307276
train bleu score : 0.1977042283981259 	 val bleu score : 2.359135772438463e-05
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	["Dans quel ordinateur , c'est un jour dans ce bâtiment a un jour dans ce bâtiment de ce soir . <eos>"]
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [17:44<00:00,  6.22s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:48<00:00,  2.52s/it]



###########4#############

train loss : 2.4698233409234653	 val loss: 14.85745423339134
train bleu score : 0.2287104228665654 	 val bleu score : 9.393288329707728e-06
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['À utiliser une histoire pour utiliser ce livre ou utiliser ce livre ? <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [17:45<00:00,  6.23s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:48<00:00,  2.52s/it]



###########5#############

train loss : 2.065730781583061	 val loss: 15.621812953505405
train bleu score : 0.2575588264363604 	 val bleu score : 1.0183563886935215e-05
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['Ce bus ? <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [17:50<00:00,  6.26s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:48<00:00,  2.52s/it]



###########6#############

train loss : 1.77553441510563	 val loss: 16.46398774967637
train bleu score : 0.2811894808441532 	 val bleu score : 5.930077550331532e-06
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['À quel point de à quel <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [17:50<00:00,  6.26s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:48<00:00,  2.52s/it]



###########7#############

train loss : 1.5626653997521651	 val loss: 16.88708873127782
train bleu score : 0.3011173415359612 	 val bleu score : 8.043150106197592e-05
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['Le bus a un <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [17:48<00:00,  6.25s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:47<00:00,  2.51s/it]



###########8#############

train loss : 1.4190863597462748	 val loss: 17.311570988144986
train bleu score : 0.3157033586118601 	 val bleu score : 5.269549539664291e-05
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['Le chien de leur oiseau . <eos>']
bleu_score : 0.0

-----------------------------------



100%|████████████████████████████████████████████████████████████████████████████████| 171/171 [17:50<00:00,  6.26s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 43/43 [01:47<00:00,  2.51s/it]



###########9#############

train loss : 1.320338190996159	 val loss: 17.702983856201172
train bleu score : 0.3266650063176182 	 val bleu score : 3.693877024823806e-05
expected translation : 	 Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec.
predicted trainslation : 	['Ce faut un un un un un un un autre zone . <eos>']
bleu_score : 0.0

-----------------------------------

