This code is inspired by the work at https://www.kaggle.com/code/priyankdl/machine-translation-seq-2-seq-bahdanau-attention. It trains a German→English GRU encoder–decoder with Bahdanau-style attention on Multi30k.

In [None]:
# Pin the torch/torchtext pair this notebook was written against.
# %pip (unlike !pip) installs into the environment of the running kernel;
# restart the kernel afterwards so the pinned versions are the ones imported.
%pip install torch==2.0.1 torchtext==0.15.2

Collecting torch==2.0.1
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting torchtext==0.15.2
  Downloading torchtext-0.15.2-cp310-cp310-manylinux1_x86_64.whl.metadata (7.4 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1)
  Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1)
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.1)
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Co

In [None]:
# Quote the requirement: unquoted, the shell treats `>=2.0.0` as an output
# redirection, so pip installs portalocker without the version bound and its
# output is written to a file named "=2.0.0".
%pip install "portalocker>=2.0.0"

In [None]:
import sys

# Use the kernel's own interpreter so the spaCy models are installed into the
# environment this notebook runs in (a bare `python` on PATH may differ).
# As spaCy's own message notes, a kernel restart may be needed afterwards.
!{sys.executable} -m spacy download en_core_web_sm
!{sys.executable} -m spacy download de_core_news_sm

Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m52.5 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting de-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.7.0/de_core_news_sm-3.7.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m70.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: de-core-ne

In [None]:
import torch

# Show which torch build the kernel actually loaded; the pinned install above
# may only take effect after a kernel restart.
installed_torch_version=torch.__version__
print(installed_torch_version)

2.0.1+cu117


In [None]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k,Multi30k
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from functools import partial


In [None]:
# Point Multi30k's train and validation splits at a GitHub mirror (presumably
# because the default host is unreachable -- confirm).  The dict keys must be
# the split names Multi30k looks up: "train", "valid", "test".  The validation
# archive belongs under "valid"; overriding "test" with it would break the
# test split without ever affecting validation.
multi30k.URL["train"]="https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"]="https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

In [None]:
# With language_pair=('de','en') every example is a (german, english) pair --
# the same order collate_fn unpacks below.
train_dataset=Multi30k(split='train',language_pair=('de','en'))

de_tokenizer=get_tokenizer('spacy','de_core_news_sm')
en_tokenizer=get_tokenizer('spacy','en_core_web_sm')

# Special tokens.  With special_first=True they get indices in list order:
# <unk>=0, <pad>=1, <sos>=2, <eos>=3.  (Each needs its own list entry --
# adjacent string literals without a comma would silently merge into one.)
SPECIAL_TOKENS=["<unk>","<pad>","<sos>","<eos>"]

en_vocab_iter=[]
de_vocab_iter=[]

# Unpack in the dataset's (de, en) order so each sentence goes through the
# tokenizer and vocab of its own language.
for de,en in train_dataset:
  de_vocab_iter.append(de_tokenizer(de))
  en_vocab_iter.append(en_tokenizer(en))


def _build_vocab(tokenized_sentences):
  """Build a vocab over tokenized sentences with SPECIAL_TOKENS first.

  Every token that occurs at least once is kept, and lookups of unknown tokens
  fall back to this vocab's own <unk> index.
  """
  vocab=build_vocab_from_iterator(
      tokenized_sentences,
      specials=SPECIAL_TOKENS,
      special_first=True,
      min_freq=1
  )
  vocab.set_default_index(vocab["<unk>"])
  return vocab


en_vocab=_build_vocab(en_vocab_iter)
print("English Vocab Size=",len(en_vocab))

de_vocab=_build_vocab(de_vocab_iter)
print("German Vocab Size=",len(de_vocab))

# Every special token must have its own, distinct id in both vocabularies.
for _vocab in (en_vocab,de_vocab):
  assert [_vocab[tok] for tok in SPECIAL_TOKENS]==list(range(len(SPECIAL_TOKENS)))



English Vocab Size= 18543
German Vocab Size= 11397


In [None]:
def text_pipeline(text,idx):
  """Tokenize `text` and map the tokens to vocabulary ids.

  Args:
    text: a raw sentence.
    idx: 0 selects the German tokenizer/vocab; any other value selects English.

  Returns:
    A list of token ids.
  """
  vocab,tokenizer=(de_vocab,de_tokenizer) if idx==0 else (en_vocab,en_tokenizer)
  return vocab.lookup_indices(tokenizer(text))

In [None]:
def collate_fn(batch):
  """Turn a list of (german, english) sentence pairs into padded id tensors.

  Each German source gets <eos> appended.  Each English target gets <sos>
  prepended and <eos> appended.  Both sides are right-padded with their own
  vocabulary's <pad> id.

  Returns:
    (source, target): batch-first tensors of shape (batch, max_src_len) and
    (batch, max_tgt_len).
  """
  sources=[]
  targets=[]
  for de_sentence,en_sentence in batch:
    src_ids=text_pipeline(de_sentence,0)+[de_vocab["<eos>"]]
    tgt_ids=[en_vocab["<sos>"]]+text_pipeline(en_sentence,1)+[en_vocab["<eos>"]]

    sources.append(torch.tensor(src_ids))
    targets.append(torch.tensor(tgt_ids,dtype=torch.long))

  padded_sources=pad_sequence(sources,padding_value=de_vocab["<pad>"],batch_first=True)
  padded_targets=pad_sequence(targets,padding_value=en_vocab["<pad>"],batch_first=True)
  return padded_sources,padded_targets


In [None]:
# Model / training hyper-parameters:
#   batch_size  -- sentence pairs per training batch
#   embed_size  -- token embedding width
#   hidden_size -- GRU hidden-state width
batch_size,embed_size,hidden_size=16,300,512

In [None]:
class Encoder(nn.Module):
  """Embedding layer followed by a single-layer, batch-first GRU.

  Args:
    input_size: source vocabulary size (rows of the embedding table).
    embed_size: embedding width.
    hidden_size: GRU hidden width.

  forward(x) takes token ids (batch, seq) and returns the GRU's
  (outputs, hidden): outputs is (batch, seq, hidden_size) and hidden is
  (1, batch, hidden_size).
  """
  def __init__(self,input_size,embed_size,hidden_size):
    super().__init__()
    self.embed=nn.Embedding(input_size,embed_size)
    self.gru=nn.GRU(embed_size,hidden_size,batch_first=True)

  def forward(self,x):
    embedded=self.embed(x)
    # nn.GRU already returns the (outputs, final_hidden) pair.
    return self.gru(embedded)

class Decoder(nn.Module):
  """Single-layer, batch-first GRU followed by a linear projection to logits.

  The decoder has no embedding layer of its own.  Its input is a
  (batch, steps, embed_size + context_size) tensor -- in this notebook, the
  embedded token concatenated with the attention context produced by Bahdanau.

  Args:
    output_size: target vocabulary size (number of logits per step).
    embed_size: width of the embedded-token part of the input.
    hidden_size: GRU hidden width.
    context_size: width of the attention-context part of the input.
  """
  def __init__(self,output_size,embed_size,hidden_size,context_size):
    super().__init__()
    self.lin=nn.Linear(hidden_size,output_size)
    self.gru=nn.GRU(embed_size+context_size,hidden_size,batch_first=True)

  def forward(self,x,prev):
    """Advance the GRU from hidden state `prev`; return (logits, new_hidden)."""
    gru_out,next_hidden=self.gru(x,prev)
    logits=self.lin(gru_out)
    return logits,next_hidden


class Bahdanau(nn.Module):
  """Additive attention that also embeds the decoder's input token.

  Each encoder position is scored as score(tanh(lin2(enc) + lin1(prev))).
  The scores are softmaxed over source positions, and the module returns the
  embedded input token concatenated with the resulting context vector.  That
  concatenation is what Decoder.forward consumes in this notebook.

  NOTE(review): the context is a weighted sum of the *projected* encoder states
  (lin2 output, width new_hidden), not of the raw encoder outputs as in the
  classic Bahdanau formulation.  The Decoder's context_size must therefore
  equal new_hidden.  Confirm this deviation is intended.

  Args:
    input_size: target vocabulary size (rows of the embedding table).
    embed_size: token embedding width.
    encoder_hidden_size: width of the encoder outputs.
    decoder_hidden_size: width of the decoder GRU hidden state.
    new_hidden: width of the shared attention projection space.
  """
  def __init__(self,input_size,embed_size,encoder_hidden_size,decoder_hidden_size,new_hidden):
    super().__init__()
    self.embed=nn.Embedding(input_size,embed_size)
    self.lin1=nn.Linear(decoder_hidden_size,new_hidden)
    self.lin2=nn.Linear(encoder_hidden_size,new_hidden)
    self.score=nn.Linear(new_hidden,1)
    self.soft=nn.Softmax(dim=2)


  def forward(self,x,prev_hidden,encoder_hidden,mask=None):
    """Attend over the encoder states for one decoding step.

    Args:
      x: token ids for this step; the training loop passes (batch, 1).
      prev_hidden: decoder hidden state, assumed (1, batch, decoder_hidden_size)
        as returned by a single-layer GRU.
      encoder_hidden: encoder outputs, (batch, src_len, encoder_hidden_size).
      mask: optional (batch, src_len) tensor, truthy at real source tokens and
        falsy at padding.  Masked positions get exactly zero attention weight.
        Every row must keep at least one position unmasked, otherwise the
        softmax yields NaN.  None (the default) attends over every position.

    Returns:
      (batch, 1, embed_size + new_hidden): embedded x concatenated with the
      attention context.
    """
    embedded=self.embed(x)
    # (1, B, H_dec) -> (B, 1, H_dec) so the query broadcasts over source positions.
    query=self.lin1(prev_hidden.permute(1,0,2))
    keys=self.lin2(encoder_hidden)                             # (B, S, new_hidden)
    scores=self.score(torch.tanh(keys+query)).permute(0,2,1)   # (B, 1, S)
    if mask is not None:
      scores=scores.masked_fill(~mask.bool().unsqueeze(1),float("-inf"))
    weights=self.soft(scores)                                  # sums to 1 over S
    context=torch.bmm(weights,keys)                            # (B, 1, new_hidden)
    return torch.cat((embedded,context),dim=-1)

In [None]:
# Seed before constructing the modules so weight initialisation is reproducible.
torch.manual_seed(0)

# Width of the attention projection.  It is also the width of the context
# vector Bahdanau concatenates onto the decoder input, so it doubles as the
# Decoder's context_size.
attention_size=400

encoder=Encoder(len(de_vocab),embed_size,hidden_size)
decoder=Decoder(len(en_vocab),embed_size,hidden_size,attention_size)
# The encoder and decoder share hidden_size, so it serves as both the encoder
# and the decoder hidden width here.
bahdanau=Bahdanau(len(en_vocab),embed_size,hidden_size,hidden_size,attention_size)

In [None]:
# One Adam optimiser per module, all with the same learning rate.  The training
# loop zeroes and steps all three after every batch.
en_optimiser,de_optimiser,b_optimiser=(
    optim.Adam(module.parameters(),lr=0.001)
    for module in (encoder,decoder,bahdanau)
)

In [None]:
# collate_fn right-pads targets with en_vocab["<pad>"].  Those positions are not
# real tokens, so exclude exactly that id from the loss (and from its mean).
loss_function=nn.CrossEntropyLoss(ignore_index=en_vocab["<pad>"])

In [None]:
def train_one_epoch(log_every=100):
  """Run one teacher-forced training pass over the Multi30k de->en train split.

  For every batch the encoder reads the German source.  The decoder is then
  unrolled one step at a time: at step t it receives the ground-truth English
  token gt[:, t] (teacher forcing) and is trained to predict gt[:, t+1].  The
  encoder, decoder and attention optimisers are each stepped once per batch.

  Args:
    log_every: print the running mean loss every `log_every` batches.  Use 1 to
      log every batch; 0 or None disables per-batch logging.

  Returns:
    The mean training loss over all batches of the epoch (0.0 if the
    dataloader yields no batches).
  """
  train_dataset=Multi30k(split='train',language_pair=('de','en'))
  train_dataloader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True,collate_fn=collate_fn)

  encoder.train()
  decoder.train()
  bahdanau.train()

  total_loss=0.0
  num_batches=0

  for batch_idx,(src,gt) in enumerate(train_dataloader):
    encoder_outputs,encoder_hidden=encoder(src)
    decoder_hidden=encoder_hidden

    # `t` is the time step; keep it distinct from batch_idx, which is used
    # for the running-average loss below.
    step_logits=[]
    for t in range(gt.shape[1]-1):
      to_input=gt[:,t:t+1]
      new_x=bahdanau(to_input,decoder_hidden,encoder_outputs)
      decoder_output,decoder_hidden=decoder(new_x,decoder_hidden)
      step_logits.append(decoder_output)

    # Concatenate once rather than growing the tensor every step.
    y_hats=torch.cat(step_logits,dim=1)   # (B, T-1, vocab)
    targets=gt[:,1:]                       # (B, T-1): next-token targets
    loss=loss_function(y_hats.reshape(-1,y_hats.shape[-1]),targets.reshape(-1))

    en_optimiser.zero_grad()
    de_optimiser.zero_grad()
    b_optimiser.zero_grad()

    loss.backward()

    en_optimiser.step()
    de_optimiser.step()
    b_optimiser.step()

    total_loss+=loss.item()
    num_batches=batch_idx+1
    if log_every and num_batches%log_every==0:
      print("Train loss on batch ",num_batches," : ",total_loss/num_batches)

  mean_loss=total_loss/num_batches if num_batches else 0.0
  print("Epoch mean train loss over",num_batches,"batches:",mean_loss)
  return mean_loss

In [None]:
# Train for a single epoch.  The trailing semicolon keeps Jupyter from echoing
# the call's return value as the cell output.
train_one_epoch();

Train loss on batch  16  :  0.6127297282218933
Train loss on batch  22  :  0.7747320045124401
Train loss on batch  16  :  1.3872261941432953
Train loss on batch  20  :  1.2702928900718689
Train loss on batch  22  :  1.2849308902567083
Train loss on batch  25  :  1.2284870052337646
Train loss on batch  20  :  1.6494329571723938
Train loss on batch  25  :  1.3972454261779785
Train loss on batch  19  :  1.9460628283651251
Train loss on batch  19  :  2.0510202708997225
Train loss on batch  20  :  2.0304929077625276
Train loss on batch  25  :  1.6868005657196046
Train loss on batch  22  :  1.9797257347540422
Train loss on batch  22  :  2.045749306678772
Train loss on batch  23  :  2.0275291670923647
Train loss on batch  27  :  1.7743705687699494
Train loss on batch  27  :  1.8149370087517633
Train loss on batch  23  :  2.1877728441487188
Train loss on batch  20  :  2.5777811288833616
Train loss on batch  20  :  2.65565710067749
Train loss on batch  25  :  2.175804085731506
Train loss on bat

KeyboardInterrupt: 