In [213]:
import torch.nn as nn
import torch
from itertools import chain
from collections import Counter
from torchtext.vocab import vocab
from torch.utils.data import Dataset,random_split,DataLoader
import random
from torch.nn.utils.rnn import pad_sequence
from functools import partial
import os
import json
from tqdm import tqdm
from torchtext.data.metrics import bleu_score

In [214]:
# Training hyperparameters shared by the cells below.
num_epochs=5  # number of epochs handed to the Trainer
learning_rate=0.001  # learning rate handed to the Adam optimizer
batch_size=64  # sentence pairs per DataLoader batch

In [215]:
def preprocess(data):
    """Normalise the raw tab-separated corpus text.

    Steps, in order:
      1. replace the narrow no-break, no-break and thin space characters
         with a plain space;
      2. put a space in front of any of ``,.!?`` that is not already
         preceded by one (the very first character is never touched);
      3. keep only the first two tab-separated fields of every line.

    Args:
        data: full corpus as one string, lines separated by ``\\n``.

    Returns:
        The cleaned corpus as one string with the same line structure.
    """
    for odd_space in ("\u202f", "\xa0", "\u2009"):
        data = data.replace(odd_space, " ")

    pieces = []
    for position, char in enumerate(data):
        needs_space = position > 0 and char in ",.!?" and data[position - 1] != ' '
        pieces.append(' ' + char if needs_space else char)
    spaced = ''.join(pieces)

    trimmed_lines = []
    for line in spaced.split('\n'):
        fields = line.split("\t")
        trimmed_lines.append("\t".join(fields[:2]))
    return '\n'.join(trimmed_lines)

In [216]:
def build_vocab(list_tokens):
    """Build a torchtext vocabulary from tokenised sentences.

    Args:
        list_tokens: iterable of token lists (one list per sentence).

    Returns:
        A vocabulary whose first two entries are the specials ``<unk>`` and
        ``<pad>``; out-of-vocabulary lookups fall back to index 0.
    """
    flat_tokens = list(chain.from_iterable(list_tokens))
    # Sorting first makes the Counter's insertion order alphabetical.
    flat_tokens.sort()
    frequencies = Counter(flat_tokens)
    vocabulary = vocab(frequencies, specials=['<unk>', '<pad>'])
    vocabulary.set_default_index(0)
    return vocabulary
    

In [217]:
def tokenizer(text):
    """Split ``text`` on single spaces and wrap it in ``<sos>`` / ``<eos>``.

    Empty pieces produced by consecutive spaces are dropped.

    Args:
        text: sentence to tokenise.

    Returns:
        List of tokens starting with ``<sos>`` and ending with ``<eos>``.
    """
    wrapped = f"<sos> {text} <eos>"
    tokens = []
    for piece in wrapped.split(" "):
        if piece:
            tokens.append(piece)
    return tokens

In [218]:
def separate_src_tgt(data,max_samples=None):
    """Split tab-separated sentence pairs into tokenised source and target lists.

    Lines that do not contain exactly two tab-separated fields are skipped.

    Args:
        data: iterable of strings shaped like ``"source\\ttarget"``.
        max_samples: upper bound on the number of pairs returned; ``None``
            (the default) means no limit. ``0`` yields no pairs.

    Returns:
        ``(src, tgt)``: two equally long lists of token lists.
    """
    src=[]
    tgt=[]
    for text in data:
        # Count collected pairs, not lines read, so the cap is exact even when
        # malformed lines are skipped (the previous `i > max_samples` check let
        # one extra line through and treated 0 as "unlimited").
        if max_samples is not None and len(src)>=max_samples:
            break
        parts=text.split('\t')
        if len(parts)==2:
            src.append(tokenizer(parts[0]))
            tgt.append(tokenizer(parts[1]))
    return src,tgt
    

In [219]:
class CustomDataset(Dataset):
    """Map-style dataset over parallel source/target sequences."""

    def __init__(self,dataset) -> None:
        """Store the two parallel sequences.

        Args:
            dataset: indexable whose element 0 holds the source sequences and
                element 1 the target sequences.
        """
        super().__init__()
        self.src_data,self.tgt_data=dataset[0],dataset[1]

    def __len__(self):
        """Number of examples, taken from the source side."""
        return len(self.src_data)

    def __getitem__(self, index):
        """Return the ``(source, target)`` pair stored at ``index``."""
        pair=(self.src_data[index],self.tgt_data[index])
        return pair

In [220]:
def data_process_pipeline(data,eng_vocab=None,fra_vocab=None):
    """Tokenise sentence pairs and convert them to index sequences.

    Args:
        data: iterable of ``"source\\ttarget"`` strings.
        eng_vocab: source vocabulary to reuse; built from ``data`` when None.
        fra_vocab: target vocabulary to reuse; built from ``data`` when None.

    Returns:
        ``(dataset, eng_vocab, fra_vocab)`` where ``dataset`` is a
        ``CustomDataset`` of index lists and the vocabularies are the ones
        actually used (given or freshly built).
    """
    src,tgt=separate_src_tgt(data)
    # Each vocabulary is handled on its own: previously a supplied fra_vocab
    # was overwritten whenever eng_vocab was None, and a missing fra_vocab
    # with a supplied eng_vocab crashed on `.forward`.
    if eng_vocab is None:
        eng_vocab=build_vocab(src)
    if fra_vocab is None:
        fra_vocab=build_vocab(tgt)
    src_idx=[eng_vocab.forward(sent) for sent in src]
    tgt_idx=[fra_vocab.forward(sent) for sent in tgt]
    dataset=CustomDataset((src_idx,tgt_idx))
    return dataset,eng_vocab,fra_vocab

In [221]:
def train_test_split(dataset,train_percent):
    """Cut a sequence into a leading train part and a trailing test part.

    Args:
        dataset: sliceable sequence.
        train_percent: fraction of the items that go to the train part; the
            cut index is ``int(len(dataset) * train_percent)``.

    Returns:
        ``(train_data, test_data)`` slices of ``dataset``.
    """
    cut=int(len(dataset)*train_percent)
    return dataset[:cut],dataset[cut:]
    

In [222]:
def load_data(data_path,train_size,shuffle=True,seed=0):
    """Read, clean and split the parallel corpus.

    Args:
        data_path: path of the UTF-8 text file with one tab-separated
            sentence pair per line.
        train_size: fraction of the sentence pairs assigned to the train split.
        shuffle: when True (default) the pairs are shuffled with a private,
            seeded RNG before splitting, so both splits cover the whole length
            range. When False the legacy order is used: sort the whole corpus
            by source length, then split, which leaves only the longest
            sentences in the second split.
        seed: seed of the private RNG used for shuffling.

    Returns:
        ``(train_data, test_data)``: lists of cleaned ``"source\\ttarget"``
        strings, each ordered by the number of space-separated source words.
    """
    def source_length(pair):
        # Number of space-separated pieces in the source half of the pair.
        return len(pair.split('\t')[0].split(' '))

    with open(data_path,'r',encoding="utf-8") as fp:
        data=fp.read()
    clean_data=preprocess(data)
    sent_list=[sent for sent in clean_data.split("\n") if len(sent)>0]

    if not shuffle:
        return train_test_split(sorted(sent_list,key=source_length),train_size)

    # A dedicated Random instance keeps the split reproducible without
    # touching the global RNG that drives teacher forcing.
    random.Random(seed).shuffle(sent_list)
    train_data,test_data=train_test_split(sent_list,train_size)
    # Sorting inside each split keeps similar lengths together in a batch.
    return sorted(train_data,key=source_length),sorted(test_data,key=source_length)

In [223]:
# Read and clean "fra.txt", keeping the first 80% of the returned pairs for training and the rest for validation.
raw_train_data,raw_val_data=load_data("fra.txt",0.8)

In [224]:
# Vocabularies are built from the training split only.
train_data,eng_vocab,fra_vocab=data_process_pipeline(raw_train_data)
# Validation sentences must be indexed with the *training* vocabularies:
# building fresh ones here would give indices that do not match the model's
# embedding and output layers. Unseen words fall back to '<unk>'.
val_data,_,_=data_process_pipeline(raw_val_data,eng_vocab,fra_vocab)


In [225]:
def collate_fn(train_data,src_pad_val,tgt_pad_val):
    """Pad a batch of ``(source, target)`` index lists into two tensors.

    Args:
        train_data: list of ``(source_indices, target_indices)`` pairs.
        src_pad_val: fill value for source positions past a sequence's end.
        tgt_pad_val: fill value for target positions past a sequence's end.

    Returns:
        ``(src_tensor, tgt_tensor)``, each of shape
        ``(longest_sequence, batch)`` as produced by ``pad_sequence``.
    """
    src_seqs,tgt_seqs=[],[]
    for pair in train_data:
        src_seqs.append(torch.LongTensor(pair[0]))
        tgt_seqs.append(torch.LongTensor(pair[1]))
    padded_src=pad_sequence(src_seqs,padding_value=src_pad_val)
    padded_tgt=pad_sequence(tgt_seqs,padding_value=tgt_pad_val)
    return padded_src,padded_tgt

# Both loaders pad each batch with the '<pad>' index of the matching vocabulary.
# Neither loader sets `shuffle`, so batches are produced in dataset order.
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=partial(collate_fn,src_pad_val=eng_vocab['<pad>'],tgt_pad_val=fra_vocab['<pad>']))
val_loader=DataLoader(val_data,batch_size=batch_size,collate_fn=partial(collate_fn,src_pad_val=eng_vocab['<pad>'],tgt_pad_val=fra_vocab['<pad>']))

In [226]:

# Model hyperparameters.
# Vocabulary sizes. NOTE: the model-construction cell further down calls
# len(eng_vocab) / len(fra_vocab) directly instead of reading these three names.
input_size_encoder=len(eng_vocab)
input_size_decoder=len(fra_vocab)
output_size_decoder=len(fra_vocab)
# Embedding widths of the encoder and decoder.
encoder_embedding_size=300
decoder_embedding_size=300
# LSTM hidden size and depth; encoder and decoder use the same values here.
hidden_size=1024
encoder_n_layers=2
decoder_n_layers=2
# Passed as the `dropout` argument of the encoder / decoder LSTM.
encoder_dropout=0.5
decoder_dropout=0.5
# Probability, per decoding step, of feeding the ground-truth token instead of the model's own prediction.
teacher_forcing_ratio=0.5
device="cuda" if torch.cuda.is_available() else "cpu"

In [227]:
class Encoder(nn.Module):
    """Embeds source token indices and encodes them with a multi-layer LSTM."""

    def __init__(self,vocab_size,embed_size,hidden_size,num_layers,drop_prob):
        """
        Args:
            vocab_size: number of rows in the embedding table.
            embed_size: embedding width, also the LSTM input size.
            hidden_size: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
            drop_prob: dropout passed to the LSTM.
        """
        super().__init__()
        self.embed_layer=nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_size)
        self.rnn=nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=drop_prob,
        )

    def forward(self,X):
        """Encode ``X`` and return the LSTM's final ``(hidden, cell)`` state.

        The per-step LSTM outputs are discarded.
        """
        _,final_state=self.rnn(self.embed_layer(X))
        final_hidden,final_cell=final_state
        return final_hidden,final_cell

In [228]:
class Decoder(nn.Module):
    """Single-step LSTM decoder: embeds one token per sequence and scores the next token."""

    def __init__(self,vocab_size,embed_size,hidden_size,num_layers,drop_prob,output_size) -> None:
        """
        Args:
            vocab_size: number of rows in the embedding table.
            embed_size: embedding width, also the LSTM input size.
            hidden_size: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
            drop_prob: dropout passed to the LSTM.
            output_size: number of scores produced per sequence.
        """
        super().__init__()
        self.embed=nn.Embedding(vocab_size,embed_size)
        self.rnn=nn.LSTM(embed_size,hidden_size,num_layers=num_layers,dropout=drop_prob)
        self.fc=nn.Linear(hidden_size,output_size)

    def forward(self,X,hidden,cell,return_state=False):
        """Run one decoding step.

        Args:
            X: 1-D tensor with the current token index of every sequence.
            hidden: LSTM hidden state to start the step from.
            cell: LSTM cell state to start the step from.
            return_state: when True, also return the updated LSTM state so the
                caller can feed it into the next step. The default (False)
                keeps the original return value.

        Returns:
            ``predictions`` of shape ``(batch, output_size)``, or
            ``(predictions, hidden, cell)`` when ``return_state`` is True.
        """
        # The LSTM expects a leading time dimension; a single step has length 1.
        embeddings=self.embed(X.unsqueeze(0))
        outputs,(hidden,cell)=self.rnn(embeddings,(hidden,cell))
        predictions=self.fc(outputs).squeeze(0)
        if return_state:
            # Without this the updated state is lost and every step would
            # have to restart from the state that was passed in.
            return predictions,hidden,cell
        return predictions

In [229]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time."""

    def __init__(self,encoder,decoder) -> None:
        """
        Args:
            encoder: module returning ``(hidden, cell)`` for a source batch.
            decoder: module exposing ``embed``, ``rnn`` and ``fc`` sub-modules
                (embedding, LSTM, output projection).
        """
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    def _decode_step(self,x,hidden,cell):
        """Run one decoder step and return ``(scores, hidden, cell)``.

        The decoder's sub-modules are called directly because the step must
        hand back the updated LSTM state, which a plain ``decoder(x, hidden,
        cell)`` call returning only the scores cannot provide.
        """
        embeddings=self.decoder.embed(x.unsqueeze(0))
        step_out,(hidden,cell)=self.decoder.rnn(embeddings,(hidden,cell))
        return self.decoder.fc(step_out).squeeze(0),hidden,cell

    def forward(self,src,tgt,tgt_vocab_size,teacher_force_ratio):
        """Decode ``tgt`` step by step, conditioned on the encoded ``src``.

        Args:
            src: source indices, shape ``(src_len, batch)``.
            tgt: target indices, shape ``(tgt_len, batch)``; row 0 is fed as
                the first decoder input.
            tgt_vocab_size: size of the last dimension of the result.
            teacher_force_ratio: probability, drawn once per step for the
                whole batch, of feeding the ground-truth token next instead of
                the model's arg-max prediction.

        Returns:
            Tensor of shape ``(tgt_len, batch, tgt_vocab_size)``. Row 0 is
            never written and stays zero. The buffer is created without an
            explicit device, as before, so callers that compare it with
            CPU-resident targets keep working.
        """
        batch_size=src.shape[1]
        target_len=tgt.shape[0]
        outputs=torch.zeros((target_len,batch_size,tgt_vocab_size))
        hidden,cell=self.encoder(src)
        x=tgt[0]
        for t in range(1,target_len):
            # Carry the LSTM state forward. Previously every step restarted
            # from the encoder state, so the decoder never saw its own history.
            output,hidden,cell=self._decode_step(x,hidden,cell)
            outputs[t]=output
            teacher_force=random.random()<teacher_force_ratio
            x=tgt[t] if teacher_force else output.argmax(1)

        return outputs

In [230]:
class Trainer:
    """Trains and evaluates a seq2seq model, tracking per-epoch loss and BLEU.

    Per-epoch averages are appended to ``self.loss`` and ``self.bleu`` under
    the keys ``"train"`` and ``"val"``.
    """

    def __init__(
        self,
        model,
        num_epochs,
        batch_size,
        criterion,
        optimizer,
        learning_rate,
        output_size_decoder,
        teacher_forcing_ratio,
        device,
        print_stats,
        tgt_vocab,
        pred_pipeline,
        max_batches=5,
    ):
        """
        Args:
            model: seq2seq module with ``encoder`` and ``decoder`` attributes;
                it is moved to ``device``.
            num_epochs: number of epochs run by ``train``.
            batch_size: stored for reference only.
            criterion: loss called as ``criterion(scores, target_indices)``.
            optimizer: optimizer stepped after every training batch.
            learning_rate: stored for reference only.
            output_size_decoder: target vocabulary size passed to the model.
            teacher_forcing_ratio: teacher-forcing probability used while
                training (validation always uses 0).
            device: device the model and batches are moved to.
            print_stats: when True, print metrics and a sample translation
                after every epoch.
            tgt_vocab: target vocabulary (index <-> token).
            pred_pipeline: callable turning raw text into a source tensor.
            max_batches: maximum number of batches processed per epoch; the
                default of 5 keeps the previous hard-coded cap, ``None``
                consumes the whole loader.
        """
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.output_size_decoder=output_size_decoder
        self.criterion = criterion
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.device = device
        self.print_stats=print_stats
        self.tgt_vocab=tgt_vocab
        self.pred_pipeline=pred_pipeline
        self.max_batches=max_batches
        self.loss={"train":[],'val':[]}
        self.bleu={"train":[],'val':[]}
        self.model=model.to(device)

    def train(self, train_loader, val_loader, save_path):
        """Run all epochs, then write ``model.pt`` and ``loss.json`` into ``save_path``."""
        # Create the output directory up front so a bad path fails before training.
        os.makedirs(save_path,exist_ok=True)
        # Look the special tokens up instead of hard-coding their indices.
        sos_index=self.tgt_vocab['<sos>']
        eos_index=self.tgt_vocab['<eos>']
        for epoch in tqdm(range(self.num_epochs)):
            self._train_epoch(train_loader)
            self._val_epoch(val_loader)
            if self.print_stats:
                print(f"train loss : {self.loss['train'][-1]}\t val loss: {self.loss['val'][-1]}")
                print(f"train bleu score : {self.bleu['train'][-1]} \t val bleu score : {self.bleu['val'][-1]}")
                test_txt="In this story an old man sets out to ask an Indian king to dig some well in his village when their water runs dry"
                expected_translation="Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec."
                predicted_translation=self.transcribe(test_txt,self.pred_pipeline,50,sos_index,self.tgt_vocab,eos_index)
                print(f"expected translation : \t {expected_translation}")
                print(f"predicted trainslation : \t{predicted_translation}")
                print(f"bleu_score : {bleu_score([predicted_translation[0].split()],[[expected_translation.split()]],max_n=2,weights=[0.5,0.5])}")
                print("\n-----------------------------------\n")
        self.save_model(os.path.join(save_path,"model.pt"))
        self.save_loss(os.path.join(save_path,"loss.json"))

    def _to_tokens(self,index_matrix):
        """Turn a ``(seq_len, batch)`` index tensor into one token list per sequence.

        Each sequence is cut at its first ``<eos>`` and ``<pad>`` tokens are
        removed, so padding and end markers do not count towards BLEU.
        """
        sentences=[]
        for row in index_matrix.t().tolist():
            tokens=self.tgt_vocab.lookup_tokens(row)
            if '<eos>' in tokens:
                tokens=tokens[:tokens.index('<eos>')]
            sentences.append([token for token in tokens if token!='<pad>'])
        return sentences

    def _decode_step(self,x,hidden,cell):
        """Run one decoder step and return ``(scores, hidden, cell)``.

        The decoder's ``embed`` / ``rnn`` / ``fc`` sub-modules are used
        directly so the updated LSTM state is available for the next step.
        """
        decoder=self.model.decoder
        embeddings=decoder.embed(x.unsqueeze(0))
        step_out,(hidden,cell)=decoder.rnn(embeddings,(hidden,cell))
        return decoder.fc(step_out).squeeze(0),hidden,cell

    def _run_epoch(self,loader,training):
        """Process up to ``max_batches`` batches; return ``(mean_loss, mean_bleu)``.

        With ``training=True`` the model is put in train mode and the
        optimizer is stepped. Otherwise the model is put in eval mode,
        gradients are disabled and teacher forcing is switched off, so the
        metrics reflect free-running decoding. An empty loader yields NaN.
        """
        self.model.train(training)
        teacher_forcing=self.teacher_forcing_ratio if training else 0.0
        n_batches=0
        epoch_loss=0.0
        epoch_bleu=0.0
        with torch.set_grad_enabled(training):
            for batch_src,batch_tgt in loader:
                if self.max_batches is not None and n_batches>=self.max_batches:
                    break
                n_batches+=1
                source=batch_src.to(self.device)
                target=batch_tgt.to(self.device)
                outputs=self.model(source,target,self.output_size_decoder,teacher_forcing)
                preds=outputs.argmax(-1)
                # Row 0 is the <sos> position and is never predicted.
                scores=outputs[1:].reshape(-1,outputs.shape[-1])
                # Keep the gold indices on whatever device the scores live on.
                gold=target[1:].reshape(-1).to(scores.device)
                loss=self.criterion(scores,gold)
                if training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                epoch_loss+=loss.item()
                pred_tokens=self._to_tokens(preds[1:])
                # Align references with predictions by dropping the <sos> row too.
                target_tokens=[[reference] for reference in self._to_tokens(target[1:])]
                epoch_bleu+=bleu_score(pred_tokens,target_tokens,max_n=2,weights=[0.5,0.5])
        if n_batches==0:
            return float("nan"),float("nan")
        return epoch_loss/n_batches,epoch_bleu/n_batches

    def _train_epoch(self,train_loader):
        """Run one training epoch and record its mean loss and BLEU."""
        mean_loss,mean_bleu=self._run_epoch(train_loader,training=True)
        self.loss['train'].append(mean_loss)
        self.bleu['train'].append(mean_bleu)

    def _val_epoch(self,val_loader):
        """Run one validation epoch and record its mean loss and BLEU."""
        mean_loss,mean_bleu=self._run_epoch(val_loader,training=False)
        self.loss['val'].append(mean_loss)
        self.bleu['val'].append(mean_bleu)

    def save_model(self,save_path):
        """Serialise the whole model object (not just its state dict) to ``save_path``."""
        torch.save(self.model,save_path)

    def save_loss(self,save_path):
        """Write the recorded train/val losses to ``save_path`` as JSON."""
        with open(save_path,"w") as fp:
            json.dump(self.loss,fp)

    def transcribe(self,inputs,pipeline,max_tokens,start_token,tgt_vocab,end_token):
        """Greedily translate ``inputs``.

        Args:
            inputs: raw text (or whatever ``pipeline`` accepts).
            pipeline: callable returning a ``(src_len, batch)`` index tensor.
            max_tokens: maximum number of tokens generated per sentence.
            start_token: index fed to the decoder at the first step.
            tgt_vocab: vocabulary used to map predicted indices to tokens.
            end_token: index that stops generation; it is not included in
                the returned text.

        Returns:
            List with one space-joined translation per column of the
            pipeline's output.
        """
        was_training=self.model.training
        # Dropout must be off while generating.
        self.model.eval()
        try:
            with torch.no_grad():
                inputs=pipeline(inputs).to(self.device)
                hidden,cell=self.model.encoder(inputs)
                output_tokens=[]
                for i in range(inputs.shape[1]):
                    # Decode each sentence from its own slice of the encoder state.
                    sent_hidden=hidden[:,i:i+1].contiguous()
                    sent_cell=cell[:,i:i+1].contiguous()
                    x=torch.LongTensor([start_token]).to(self.device)
                    current_output=[]
                    while len(current_output)<max_tokens:
                        scores,sent_hidden,sent_cell=self._decode_step(x,sent_hidden,sent_cell)
                        x=scores.argmax(1)
                        if x.item()==end_token:
                            break
                        current_output.append(tgt_vocab.lookup_token(x.item()))
                    output_tokens.append(" ".join(current_output))
        finally:
            # Restore whichever mode the model was in before the call.
            self.model.train(was_training)
        return output_tokens

In [231]:
# Build the encoder over the English vocabulary and the decoder over the French one,
# then wrap both in the seq2seq model. The decoder scores every French token.
encoder=Encoder(len(eng_vocab),embed_size=encoder_embedding_size,hidden_size=hidden_size,num_layers=encoder_n_layers,drop_prob=encoder_dropout)
decoder=Decoder(len(fra_vocab),embed_size=decoder_embedding_size,hidden_size=hidden_size,num_layers=decoder_n_layers,drop_prob=decoder_dropout,output_size=len(fra_vocab))
model=Seq2Seq(encoder=encoder,decoder=decoder)

In [232]:
# The loss is computed over French target indices, so the index to ignore is
# the *target* vocabulary's '<pad>' (the source vocabulary's index only
# matched because both vocabularies list their specials first).
criterion=nn.CrossEntropyLoss(ignore_index=fra_vocab['<pad>'])
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

In [233]:
def predict_pipeline(txt):
    """Turn raw English text into a padded source tensor for the model.

    Args:
        txt: one sentence as a string, or a list of sentences.

    Returns:
        LongTensor of shape ``(longest_sentence, n_sentences)`` holding
        ``eng_vocab`` indices, padded with the ``<pad>`` index.
    """
    if isinstance(txt,str):
        txt=[txt]
    # Apply the same cleaning as the training corpus (e.g. a space before
    # punctuation) so that "dry." is tokenised like the training data's "dry ."
    # instead of becoming an unknown word.
    sent_tokens=[tokenizer(preprocess(sentence)) for sentence in txt]
    int_tokens=[eng_vocab.forward(tokens) for tokens in sent_tokens]
    src_tensor=[torch.LongTensor(token_list) for token_list in int_tokens]
    src=pad_sequence(src_tensor,padding_value=eng_vocab['<pad>'])
    return src

In [234]:
# Wire the model, loss, optimizer and French vocabulary into the trainer; print_stats=True prints metrics and a sample translation after every epoch.
trainer=Trainer(model=model,num_epochs=num_epochs,batch_size=batch_size,criterion=criterion,optimizer=optimizer,learning_rate=learning_rate,output_size_decoder=len(fra_vocab),teacher_forcing_ratio=teacher_forcing_ratio,device=device,print_stats=True,tgt_vocab=fra_vocab,pred_pipeline=predict_pipeline)

In [235]:
# Run training; when it finishes, model.pt and loss.json are written under "results".
trainer.train(train_loader,val_loader,"results")

  0%|          | 0/5 [00:02<?, ?it/s]


KeyboardInterrupt: 