# Neural Machine Translation with Attention: Thai to English

In [48]:
import torch, torchdata, torchtext
import torch.nn as nn
import torch.nn.functional as F
import random, math, time

import datasets
from datasets import load_dataset

In [49]:
print("Torch version:", torch.__version__)
print("Torchtext version:", torchtext.__version__)

Torch version: 2.1.1+cpu
Torchtext version: 0.16.1+cpu


In [50]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [51]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. ETL: Loading the dataset

In [52]:
dataset = datasets.load_dataset('Tsunnami/who-en-th')

In [53]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['en', 'th'],
        num_rows: 538
    })
})


In [54]:
SRC_LANGUAGE = 'en'
TRG_LANGUAGE = 'th'

train = dataset['train']

In [55]:
train

Dataset({
    features: ['en', 'th'],
    num_rows: 538
})

## 2. EDA - simple investigation

In [56]:
sample = next(iter(train))
sample

{'en': 'Tobacco and nicotine industry tactics addict youth for life',
 'th': 'กลยุทธ์ของอุตสาหกรรมยาสูบและนิโคตินทำให้เยาวชนเสพติดไปชั่วชีวิต'}

In [57]:
# Total size of the dataset
total_size = len(train)

In [58]:
# Split the data
# 80% for training, 10% for validation, 
# and 10% for testing to maximize training data for a small dataset.
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size  # Remaining for test

In [59]:
# First split: Split into 80% train and 20% (val + test)
train, temp = train.train_test_split(test_size=0.2, seed=999).values()

# Second split: Split the remaining 20% into 50% val and 50% test (10% each)
val, test = temp.train_test_split(test_size=0.5, seed=999).values()

print(f"Train size: {len(train)}")
print(f"Validation size: {len(val)}")
print(f"Test size: {len(test)}")

Train size: 430
Validation size: 54
Test size: 54


## 3. Preprocessing 

### Tokenizing

In [60]:
# Place-holders
token_transform = {}
vocab_transform = {}

In [61]:
from torchtext.data.utils import get_tokenizer

# Tokenizer for English (using spaCy)
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

# Tokenizer for Thai (using PyThaiNLP for Thai segmentation)
from pythainlp import word_tokenize
token_transform[TRG_LANGUAGE] = word_tokenize  # PyThaiNLP tokenizer for Thai

In [62]:
# Tokenization of the source language (English)
english_sentence = sample['en']
print("Sentence (English): ", english_sentence)
tokenized_english = token_transform[SRC_LANGUAGE](english_sentence)
print("Tokenization (English): ", tokenized_english)

Sentence (English):  Tobacco and nicotine industry tactics addict youth for life
Tokenization (English):  ['Tobacco', 'and', 'nicotine', 'industry', 'tactics', 'addict', 'youth', 'for', 'life']


In [63]:
# Tokenization of the target language (Thai)
thai_sentence = sample['th']
print("Sentence (Thai): ", thai_sentence)
tokenized_thai = token_transform[TRG_LANGUAGE](thai_sentence)
print("Tokenization (Thai): ", tokenized_thai)

Sentence (Thai):  กลยุทธ์ของอุตสาหกรรมยาสูบและนิโคตินทำให้เยาวชนเสพติดไปชั่วชีวิต
Tokenization (Thai):  ['กลยุทธ์', 'ของ', 'อุตสาหกรรม', 'ยาสูบ', 'และ', 'นิโคติน', 'ทำให้', 'เยาวชน', 'เสพติด', 'ไป', 'ชั่วชีวิต']


A function to tokenize our input.

In [64]:
# Helper function to yield list of tokens
# Here data can be `train` or `val` or `test`
def yield_tokens(data, language):
    language_index = {SRC_LANGUAGE: 'en', TRG_LANGUAGE: 'th'}

    for data_sample in data:
        yield token_transform[language](data_sample[language_index[language]])  # either 'en' or 'th'

Before we tokenize, let's define some special symbols so our neural network understand the embeddings of these symbols, namely the unknown, the padding, the start of sentence, and end of sentence.

In [65]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<sos>', '<eos>']

### Text to integers (Numericalization)

Next we gonna create function (torchtext called vocabs) that turn these tokens into integers.  Here we use built in factory function <code>build_vocab_from_iterator</code> which accepts iterator that yield list or iterator of tokens.

In [66]:
from torchtext.vocab import build_vocab_from_iterator

# Build vocabulary for both source (English) and target (Thai) languages
for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train, ln),
        min_freq=2,  # If a token appears less than twice, it will be treated as UNK
        specials=special_symbols,
        special_first=True  # Insert special symbols at the beginning
    )

# Set UNK_IDX as the default index. This index is returned when the token is not found.
# If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

In [67]:
# Check some examples in the vocabulary
print("Example tokenization (English):", vocab_transform[SRC_LANGUAGE](['here', 'is', 'a', 'unknownword', 'a']))

Example tokenization (English): [0, 17, 12, 0, 12]


In [68]:
# Reverse vocabulary lookup: get the token by its index
mapping = vocab_transform[SRC_LANGUAGE].get_itos()

# For example, get the word corresponding to index 22
print("Mapping index 22 to word:", mapping[22])

Mapping index 22 to word: The


In [69]:
# Looking up the special token for unknown words
print("Special token for <unk>:", mapping[0])  # All unknown words will map to <unk>

Special token for <unk>: <unk>


In [70]:
# Looking up other special tokens
print("Special token for <pad>:", mapping[1])
print("Special token for <sos>:", mapping[2])
print("Special token for <eos>:", mapping[3])

Special token for <pad>: <pad>
Special token for <sos>: <sos>
Special token for <eos>: <eos>


In [71]:
# Check how many unique words are in the English vocabulary
print(f"Vocabulary size: {len(mapping)}")

Vocabulary size: 1135


## 4. Preparing the dataloader

One thing we change here is the <code>collate_fn</code> which now also returns the length of sentence.  This is required for <code>packed_padded_sequence</code>

In [72]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

BATCH_SIZE = 16 # Better for small datasets and faster updates

# Helper function to apply sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# Function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids):
    return torch.cat((torch.tensor([SOS_IDX]), 
                      torch.tensor(token_ids), 
                      torch.tensor([EOS_IDX])))

# src and trg language text transforms to convert raw strings into tensor indices
text_transform = {}
for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    text_transform[ln] = sequential_transforms(
        token_transform[ln], # Tokenization
        vocab_transform[ln], # Numericalization
        tensor_transform     # Add BOS/EOS and create tensor
    )
# Function to collate data samples into batch tensors
def collate_batch(batch):
    src_batch, src_len_batch, trg_batch = [], [], []
    
    for sample in batch:
        # Make sure to access the correct fields
        src_sample = sample['en'] 
        trg_sample = sample['th']
        
        # Apply text transformations (tokenization, vocab conversion, tensor creation)
        processed_src_text = text_transform[SRC_LANGUAGE](src_sample.rstrip("\n"))
        processed_trg_text = text_transform[TRG_LANGUAGE](trg_sample.rstrip("\n"))
        
        src_batch.append(processed_src_text)
        trg_batch.append(processed_trg_text)
        src_len_batch.append(processed_src_text.size(0))

    # Pad sequences in the batch
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    trg_batch = pad_sequence(trg_batch, padding_value=PAD_IDX, batch_first=True)
    
    # Return the batch along with source sequence lengths
    return src_batch, torch.tensor(src_len_batch, dtype=torch.int64), trg_batch

Create data loaders for training, validation, and test sets

In [73]:
BATCH_SIZE = 16

train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

In [74]:
# Example: Check the shape of a batch from the training data loader
for en, _, th in train_loader:
    break

In [75]:
print("English shape: ", en.shape)  # (batch_size, seq_len)
print("Thai shape: ", th.shape)   # (batch_size, seq_len)

English shape:  torch.Size([16, 37])
Thai shape:  torch.Size([16, 38])


## 5. Design the model

### Seq2Seq

In [76]:
class Seq2SeqPackedAttention(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.device  = device
        
    def create_mask(self, src):
        #src: [src len, batch_size]
        mask = (src == self.src_pad_idx).permute(1, 0)  #permute so that it's the same shape as attention
        #mask: [batch_size, src len] #(0, 0, 0, 0, 0, 1, 1)
        return mask
        
    def forward(self, src, src_len, trg, teacher_forcing_ratio = 0.5):
        #src: [src len, batch_size]
        #trg: [trg len, batch_size]
        
        #initialize something
        batch_size = src.shape[1]
        trg_len    = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        outputs    = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        attentions = torch.zeros(trg_len, batch_size, src.shape[0]).to(self.device)
        
        #send our src text into encoder
        encoder_outputs, hidden = self.encoder(src, src_len)
        #encoder_outputs refer to all hidden states (last layer)
        #hidden refer to the last hidden state (of each layer, of each direction)
        
        input_ = trg[0, :]
        
        mask   = self.create_mask(src) #(0, 0, 0, 0, 0, 1, 1)
        
        #for each of the input of the trg text
        for t in range(1, trg_len):
            #send them to the decoder
            output, hidden, attention = self.decoder(input_, hidden, encoder_outputs, mask)
            #output: [batch_size, output_dim] ==> predictions
            #hidden: [batch_size, hid_dim]
            #attention: [batch_size, src len]
            
            #append the output to a list
            outputs[t] = output
            attentions[t] = attention
            
            teacher_force = random.random() < teacher_forcing_ratio
            top1          = output.argmax(1)  #autoregressive
            
            input_ = trg[t] if teacher_force else top1
            
        return outputs, attentions

### Encoder

In [77]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn       = nn.GRU(emb_dim, hid_dim, bidirectional=True)
        self.fc        = nn.Linear(hid_dim * 2, hid_dim)
        self.dropout   = nn.Dropout(dropout)
        
    def forward(self, src, src_len):
        #embedding
        embedded = self.dropout(self.embedding(src))
        #packed
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len.to('cpu'), enforce_sorted=False)
        #rnn
        packed_outputs, hidden = self.rnn(packed_embedded)
        #unpacked
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)
        #-1, -2 hidden state
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim = 1)))
        
        #outputs: [src len, batch_size, hid dim * 2]
        #hidden:  [batch_size, hid_dim]
        
        return outputs, hidden
        

### Attention  

General Attention

In [78]:
class GeneralAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.W = nn.Linear(hid_dim, hid_dim)  # for decoder input
        self.U = nn.Linear(hid_dim * 2, hid_dim)  # for encoder outputs
        self.v = nn.Linear(hid_dim, 1, bias=False)  # for final score
    
    def forward(self, hidden, encoder_outputs, mask):
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # [batch_size, src_len, hid_dim]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch_size, src_len, hid_dim * 2]
        
        energy = self.v(torch.tanh(self.W(hidden) + self.U(encoder_outputs))).squeeze(2)  # [batch_size, src_len]
        
        energy = energy.masked_fill(mask, -1e10)
        
        return F.softmax(energy, dim=1)

Multiplicative Attention

In [79]:
class MultiplicativeAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.W = nn.Linear(hid_dim, hid_dim, bias=False)  # for decoder input
        self.U = nn.Linear(hid_dim * 2, hid_dim, bias=False)  # for encoder outputs
    
    def forward(self, hidden, encoder_outputs, mask):
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # [batch_size, src_len, hid_dim]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch_size, src_len, hid_dim * 2]
        
        # Scaling dot product (multiplicative attention)
        energy = torch.bmm(hidden, encoder_outputs.permute(0, 2, 1))  # [batch_size, src_len, src_len]
        energy = energy.squeeze(1)  # [batch_size, src_len]
        
        energy = energy.masked_fill(mask, -1e10)
        
        return F.softmax(energy, dim=1)

Additive Attention

In [80]:
class AdditiveAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.v = nn.Linear(hid_dim, 1, bias=False)
        self.W = nn.Linear(hid_dim, hid_dim)  # for decoder input
        self.U = nn.Linear(hid_dim * 2, hid_dim)  # for encoder outputs
    
    def forward(self, hidden, encoder_outputs, mask):
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # [batch_size, src_len, hid_dim]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch_size, src_len, hid_dim * 2]
        
        energy = self.v(torch.tanh(self.W(hidden) + self.U(encoder_outputs))).squeeze(2)  # [batch_size, src_len]
        
        energy = energy.masked_fill(mask, -1e10)
        
        return F.softmax(energy, dim=1)

### Decoder

In [81]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention  = attention
        self.embedding  = nn.Embedding(output_dim, emb_dim)
        self.rnn        = nn.GRU((hid_dim * 2) + emb_dim, hid_dim)
        self.fc         = nn.Linear((hid_dim * 2) + hid_dim + emb_dim, output_dim)
        self.dropout    = nn.Dropout(dropout)
        
    def forward(self, input, hidden, encoder_outputs, mask):
        #input: [batch_size]
        #hidden: [batch_size, hid_dim]
        #encoder_ouputs: [src len, batch_size, hid_dim * 2]
        #mask: [batch_size, src len]
                
        #embed our input
        input    = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        #embedded = [1, batch_size, emb_dim]
        
        #calculate the attention
        a = self.attention(hidden, encoder_outputs, mask)
        #a = [batch_size, src len]
        a = a.unsqueeze(1)
        #a = [batch_size, 1, src len]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        #encoder_ouputs: [batch_size, src len, hid_dim * 2]
        weighted = torch.bmm(a, encoder_outputs)
        #weighted: [batch_size, 1, hid_dim * 2]
        weighted = weighted.permute(1, 0, 2)
        #weighted: [1, batch_size, hid_dim * 2]
        
        #send the input to decoder rnn
            #concatenate (embed, weighted encoder_outputs)
            #[1, batch_size, emb_dim]; [1, batch_size, hid_dim * 2]
        rnn_input = torch.cat((embedded, weighted), dim = 2)
        #rnn_input: [1, batch_size, emb_dim + hid_dim * 2]
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
            
        #send the output of the decoder rnn to fc layer to predict the word
            #prediction = fc(concatenate (output, weighted, embed))
        embedded = embedded.squeeze(0)
        output   = output.squeeze(0)
        weighted = weighted.squeeze(0)
        prediction = self.fc(torch.cat((embedded, output, weighted), dim = 1))
        #prediction: [batch_size, output_dim]
            
        return prediction, hidden.squeeze(0), a.squeeze(1)

## 6. Training

We use a simplified version of the weight initialization scheme used in the paper. Here, we will initialize all biases to zero and all weights from $\mathcal{N}(0, 0.01)$.

In [82]:
def initialize_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

In [83]:
input_dim   = len(vocab_transform[SRC_LANGUAGE])
output_dim  = len(vocab_transform[TRG_LANGUAGE])
emb_dim     = 256  
hid_dim     = 512  
dropout     = 0.5
SRC_PAD_IDX = PAD_IDX

attn_general = GeneralAttention(hid_dim)
attn_multiplicative = MultiplicativeAttention(hid_dim)
attn_additive = AdditiveAttention(hid_dim)

enc  = Encoder(input_dim, emb_dim, hid_dim, dropout)

dec_general = Decoder(output_dim, emb_dim, hid_dim, dropout, attn_general)
dec_multiplicative = Decoder(output_dim, emb_dim, hid_dim, dropout, attn_multiplicative)
dec_additive = Decoder(output_dim, emb_dim, hid_dim, dropout, attn_additive)

model_general = Seq2SeqPackedAttention(enc, dec_general, SRC_PAD_IDX, device).to(device)
model_multiplicative = Seq2SeqPackedAttention(enc, dec_multiplicative, SRC_PAD_IDX, device).to(device)
model_additive = Seq2SeqPackedAttention(enc, dec_additive, SRC_PAD_IDX, device).to(device)

model_general.apply(initialize_weights)
model_multiplicative.apply(initialize_weights)
model_additive.apply(initialize_weights)


Seq2SeqPackedAttention(
  (encoder): Encoder(
    (embedding): Embedding(1135, 256)
    (rnn): GRU(256, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): AdditiveAttention(
      (v): Linear(in_features=512, out_features=1, bias=False)
      (W): Linear(in_features=512, out_features=512, bias=True)
      (U): Linear(in_features=1024, out_features=512, bias=True)
    )
    (embedding): Embedding(1089, 256)
    (rnn): GRU(1280, 512)
    (fc): Linear(in_features=1792, out_features=1089, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [84]:
#we can print the complexity by the number of parameters
def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for item in params:
        print(f'{item:>6}')
    print(f'______\n{sum(params):>6}')
    
count_parameters(model_general)
count_parameters(model_multiplicative)
count_parameters(model_additive)

290560
393216
786432
  1536
  1536
393216
786432
  1536
  1536
524288
   512
262144
   512
524288
   512
   512
278784
1966080
786432
  1536
  1536
1951488
  1089
______
8955713
290560
393216
786432
  1536
  1536
393216
786432
  1536
  1536
524288
   512
262144
524288
278784
1966080
786432
  1536
  1536
1951488
  1089
______
8954177
290560
393216
786432
  1536
  1536
393216
786432
  1536
  1536
524288
   512
   512
262144
   512
524288
   512
278784
1966080
786432
  1536
  1536
1951488
  1089
______
8955713


In [85]:
import torch.optim as optim

lr = 0.001

# training hyperparameters
optimizer_general = optim.Adam(model_general.parameters(), lr=lr)
optimizer_multiplicative = optim.Adam(model_multiplicative.parameters(), lr=lr)
optimizer_additive = optim.Adam(model_additive.parameters(), lr=lr)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX) # combine softmax with cross entropy

In [86]:
def train(model, loader, optimizer, criterion, clip, loader_length):
    
    model.train()
    epoch_loss = 0
    
    for src, src_length, trg in loader:
        
        src = src.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()
        
        output, attentions = model(src, src_length, trg)
        
        #trg    = [trg len, batch size]
        #output = [trg len, batch size, output dim]
        output_dim = output.shape[-1]
        
        #the loss function only works on 2d inputs with 1d targets thus we need to flatten each of them
        output = output[1:].view(-1, output_dim)
        trg    = trg[1:].view(-1)
        #trg    = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        #clip the gradients to prevent them from exploding (a common issue in RNNs)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / loader_length

In [87]:
def evaluate(model, loader, criterion, loader_length):
        
    #turn off dropout (and batch norm if used)
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for src, src_length, trg in loader:
        
            src = src.to(device)
            trg = trg.to(device)

            output, attentions = model(src, src_length, trg, 0) #turn off teacher forcing

            #trg    = [trg len, batch size]
            #output = [trg len, batch size, output dim]

            output_dim = output.shape[-1]
            
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            #trg    = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size, output dim]

            loss = criterion(output, trg)
            
            epoch_loss += loss.item()
        
    return epoch_loss / loader_length

### Putting everything together

In [88]:
train_loader_length = len(list(iter(train_loader)))
val_loader_length   = len(list(iter(valid_loader)))
test_loader_length  = len(list(iter(test_loader)))

In [89]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
best_valid_loss_general = float('inf')
best_valid_loss_multiplicative = float('inf')
best_valid_loss_additive = float('inf')

num_epochs = 10
clip       = 1

save_path_general = f'models/{model_general.__class__.__name__}.pt'
save_path_multiplicative = f'models/{model_multiplicative.__class__.__name__}.pt'
save_path_additive = f'models/{model_additive.__class__.__name__}.pt'

train_losses_general = []
valid_losses_general = []

train_losses_multiplicative = []
valid_losses_multiplicative = []

train_losses_additive = []
valid_losses_additive = []

for epoch in range(num_epochs):
    
    start_time = time.time()

    train_loss_general = train(model_general, train_loader, optimizer_general, criterion, clip, train_loader_length)
    valid_loss_general = evaluate(model_general, valid_loader, criterion, val_loader_length)
    train_losses_general.append(train_loss_general)
    valid_losses_general.append(valid_loss_general)
    
    train_loss_multiplicative = train(model_multiplicative, train_loader, optimizer_multiplicative, criterion, clip, train_loader_length)
    valid_loss_multiplicative = evaluate(model_multiplicative, valid_loader, criterion, val_loader_length)
    train_losses_multiplicative.append(train_loss_multiplicative)
    valid_losses_multiplicative.append(valid_loss_multiplicative)
    
    train_loss_additive = train(model_additive, train_loader, optimizer_additive, criterion, clip, train_loader_length)
    valid_loss_additive = evaluate(model_additive, valid_loader, criterion, val_loader_length)
    train_losses_additive.append(train_loss_additive)
    valid_losses_additive.append(valid_loss_additive)
    
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss_general < best_valid_loss_general:
        best_valid_loss_general = valid_loss_general
        torch.save(model_general.state_dict(), save_path_general)
        
    if valid_loss_multiplicative < best_valid_loss_multiplicative:
        best_valid_loss_multiplicative = valid_loss_multiplicative
        torch.save(model_multiplicative.state_dict(), save_path_multiplicative)
        
    if valid_loss_additive < best_valid_loss_additive:
        best_valid_loss_additive = valid_loss_additive
        torch.save(model_additive.state_dict(), save_path_additive)
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tGeneral Model - Train Loss: {train_loss_general:.3f} | Train PPL: {math.exp(train_loss_general):7.3f}')
    print(f'\tGeneral Model - Val. Loss: {valid_loss_general:.3f} |  Val. PPL: {math.exp(valid_loss_general):7.3f}')
    
    print(f'\tMultiplicative Model - Train Loss: {train_loss_multiplicative:.3f} | Train PPL: {math.exp(train_loss_multiplicative):7.3f}')
    print(f'\tMultiplicative Model - Val. Loss: {valid_loss_multiplicative:.3f} |  Val. PPL: {math.exp(valid_loss_multiplicative):7.3f}')
    
    print(f'\tAdditive Model - Train Loss: {train_loss_additive:.3f} | Train PPL: {math.exp(train_loss_additive):7.3f}')
    print(f'\tAdditive Model - Val. Loss: {valid_loss_additive:.3f} |  Val. PPL: {math.exp(valid_loss_additive):7.3f}')

Plotting Training and Validation Losses for Each Attention Type

In [None]:
import matplotlib.pyplot as plt

# Create a figure for plotting
fig, axs = plt.subplots(3, 1, figsize=(10, 15))

# Plot for General Attention Model
axs[0].plot(train_losses_general, label='train loss')
axs[0].plot(valid_losses_general, label='valid loss')
axs[0].set_title('General Attention')
axs[0].legend()
axs[0].set_xlabel('Epochs')
axs[0].set_ylabel('Loss')

# Plot for Multiplicative Attention Model
axs[1].plot(train_losses_multiplicative, label='train loss')
axs[1].plot(valid_losses_multiplicative, label='valid loss')
axs[1].set_title('Multiplicative Attention')
axs[1].legend()
axs[1].set_xlabel('Epochs')
axs[1].set_ylabel('Loss')

# Plot for Additive Attention Model
axs[2].plot(train_losses_additive, label='train loss')
axs[2].plot(valid_losses_additive, label='valid loss')
axs[2].set_title('Additive Attention')
axs[2].legend()
axs[2].set_xlabel('Epochs')
axs[2].set_ylabel('Loss')

plt.tight_layout()
plt.show()

Evaluate and Test the Best Model for Each Attention Type

In [None]:
# For General Attention Model
model_general.load_state_dict(torch.load(save_path_general))
test_loss_general = evaluate(model_general, test_loader, criterion, test_loader_length)
print(f'General Model - Test Loss: {test_loss_general:.3f} | Test PPL: {math.exp(test_loss_general):7.3f}')

# For Multiplicative Attention Model
model_multiplicative.load_state_dict(torch.load(save_path_multiplicative))
test_loss_multiplicative = evaluate(model_multiplicative, test_loader, criterion, test_loader_length)
print(f'Multiplicative Model - Test Loss: {test_loss_multiplicative:.3f} | Test PPL: {math.exp(test_loss_multiplicative):7.3f}')

# For Additive Attention Model
model_additive.load_state_dict(torch.load(save_path_additive))
test_loss_additive = evaluate(model_additive, test_loader, criterion, test_loader_length)
print(f'Additive Model - Test Loss: {test_loss_additive:.3f} | Test PPL: {math.exp(test_loss_additive):7.3f}')

The BLEU score shows how good a translation is by comparing it to a reference translation. A higher BLEU score means the translation is better and more similar to the reference.

In [None]:
from nltk.translate.bleu_score import corpus_bleu

def calculate_bleu_score(reference, hypothesis):
    return corpus_bleu(reference, hypothesis)

# Evaluate BLEU score for each attention mechanism after translation
def evaluate_bleu_score(models, test_loader):
    all_hypotheses = {model.__class__.__name__: [] for model in models}
    all_references = []

    for src, src_length, trg in test_loader:
        src = src.to(device)
        trg = trg.to(device)

        for model in models:
            model.eval()

            output, _ = model(src, src_length, trg, 0)  # No teacher forcing during eval
            output_dim = output.shape[-1]
            output = output.argmax(dim=-1).cpu().numpy()

            # Collect predictions for BLEU score calculation
            for batch_idx, pred in enumerate(output):
                hypothesis = [vocab_transform[TRG_LANGUAGE].lookup_tokens(pred)]
                reference = [vocab_transform[TRG_LANGUAGE].lookup_tokens(trg[batch_idx].cpu().numpy())]
                all_hypotheses[model.__class__.__name__].append(hypothesis)
                all_references.append(reference)

    bleu_scores = {}
    for model_name, hypotheses in all_hypotheses.items():
        bleu_scores[model_name] = calculate_bleu_score(all_references, hypotheses)
    
    return bleu_scores

# Example of using BLEU scores
bleu_scores = evaluate_bleu_score([model_general, model_multiplicative, model_additive], test_loader)
print(bleu_scores)


## 7. Test randomly

In [None]:
sample[0]

In [None]:
src_text = text_transform[SRC_LANGUAGE](sample[0]).to(device)
src_text

In [None]:
trg_text = text_transform[TRG_LANGUAGE](sample[1]).to(device)
trg_text

In [None]:
src_text = src_text.reshape(-1, 1)  #because batch_size is 1

In [None]:
trg_text = trg_text.reshape(-1, 1)

In [None]:
src_text.shape, trg_text.shape

In [None]:
text_length = torch.tensor([src_text.size(0)]).to(dtype=torch.int64)

In [None]:
model_general.load_state_dict(torch.load(save_path_general))

model_general.eval()
with torch.no_grad():
    output, attentions = model_general(src_text, text_length, trg_text, 0) #turn off teacher forcing

In [None]:
output.shape #trg_len, batch_size, trg_output_dim

In [None]:
output = output.squeeze(1)
output.shape

In [None]:
output = output[1:]
output.shape #trg_len, trg_output_dim

In [None]:
output_max = output.argmax(1) #returns max indices
output_max

Get the mapping of the target language

In [None]:
mapping = vocab_transform[TRG_LANGUAGE].get_itos()

In [None]:
for token in output_max:
    print(mapping[token.item()])

## 8. Attention

In [None]:
attentions.shape

In [None]:
src_tokens = ['<sos>'] + token_transform[SRC_LANGUAGE](sample[0]) + ['<eos>']
src_tokens

In [None]:
trg_tokens = ['<sos>'] + [mapping[token.item()] for token in output_max]
trg_tokens

Display Attention Maps for Each Attention Type

In [None]:
import matplotlib.ticker as ticker

def display_attention(sentence, translation, attention, attention_type):

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    attention = attention.squeeze(1).cpu().detach().numpy()

    cax = ax.matshow(attention, cmap='bone')
    
    ax.tick_params(labelsize=10)
    
    y_ticks = [''] + translation
    x_ticks = [''] + sentence
    
    ax.set_xticklabels(x_ticks, rotation=45)
    ax.set_yticklabels(y_ticks)
    
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    
    ax.set_title(f'{attention_type} Attention')
    
    plt.show()
    plt.close()

In [None]:
display_attention(src_tokens, trg_tokens, attentions)