In [16]:
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk.tokenize import word_tokenize
from datasets import load_dataset, DatasetDict
from collections import Counter
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import time

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\surya\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\surya\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [17]:
torch.__version__

'2.5.1'

In [18]:
SRC_LANGUAGE = 'en'
TRG_LANGUAGE = 'hi'

In [19]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [20]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
print(device)

cpu


In [21]:
# Load dataset
dataset = load_dataset("pary/hind_encorp", trust_remote_code=True, split="train")

# Shuffle dataset before splitting
dataset = dataset.shuffle(seed=42)

# Define split sizes
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Split dataset
train_dataset = dataset.select(range(train_size))
val_dataset = dataset.select(range(train_size, train_size + val_size))
test_dataset = dataset.select(range(train_size + val_size, len(dataset)))

# Store in a DatasetDict
split_datasets = DatasetDict({
    "train": train_dataset,
    "val": val_dataset,
    "test": test_dataset
})

print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")

Train size: 219108, Validation size: 27388, Test size: 27389


In [22]:
# maxValue = 0

# for idxa, value in enumerate(split_datasets['train']):
#     if len(value['translation'][SRC_LANGUAGE]) >= maxValue:
#         maxValue = len(value['translation'][SRC_LANGUAGE])
#     if idxa == 5:
#         break
        
# print(maxValue)
# print(len(split_datasets['train'][22]['translation']['en']))

In [23]:
# from indicnlp.tokenize import indic_tokenize
def english_tokenizer(text):
    return [token.lower() for token in word_tokenize(text)]  # Tokenizes English text into words

def hindi_tokenizer(text):
    return [token for token in word_tokenize(text)]
    # return indic_tokenize.trivial_tokenize(text, lang='hi')# Tokenizes Hindi text into words (works well for simple Hindi text)


In [24]:
PAD_IDX = 1
SOS_IDX = 2
EOS_IDX = 3
UNK_IDX = 0
SPECIAL_TOKENS = ['<unk>', '<pad>', '<sos>', '<eos>']

# Tokenize the dataset and add special tokens
def tokenize_addSplTokens(sample):
    sample['translation']['en'] = english_tokenizer(sample['translation']['en'])
    sample['translation']['hi'] = hindi_tokenizer(sample['translation']['hi'])
    return sample

# Apply tokenization
tokenized_datasets = split_datasets.map(tokenize_addSplTokens).remove_columns(['id', 'source', 'alignment_type', 'alignment_quality'])

In [25]:
tokenized_datasets['train'][1]

{'translation': {'en': ['enable', "'adview", "'", 'element'],
  'hi': ["'adview", "'", 'तत्व', 'सक्षम', 'करें']}}

In [26]:
def build_vocab(dataset, lang, min_freq=2):
    counter = Counter()
    for line in dataset:
        for word in line['translation'][lang]:
            counter.update(word)  # Counting word occurrences

    vocab = {word: idx + 4 for idx, (word, freq) in enumerate(counter.items())}

    # Add special tokens at the beginning
    vocab["<pad>"] = PAD_IDX
    vocab["<sos>"] = SOS_IDX
    vocab["<eos>"] = EOS_IDX
    vocab["<unk>"] = UNK_IDX

    return vocab

# Build vocab for English and Hindi using the train dataset
# en_vocab = build_vocab(tokenized_datasets["train"], lang="en")
# hi_vocab = build_vocab(tokenized_datasets["train"], lang="hi")
en_vocab = build_vocab(tokenized_datasets["train"], lang="en", min_freq=2)
hi_vocab = build_vocab(tokenized_datasets["train"], lang="hi", min_freq=2)

print(f"English vocab size: {len(en_vocab)}, Hindi vocab size: {len(hi_vocab)}")


English vocab size: 253, Hindi vocab size: 544


In [27]:
print(tokenized_datasets['train'][0]['translation']['en'])
print(tokenized_datasets['train'][0]['translation']['hi'])

print(en_vocab)
print(hi_vocab)

['ramayana', 'effect', 'on', 'various', 'cultures', 'and', 'civilization-', '(', 'from', 'of', 'p.d.f', ')']
['विवध', 'संस्कृतियों', 'एवं', 'सभ्यताओं', 'पर', 'रामायण', 'का', 'प्रभाव', '-', '(', 'पी.डी.एफ़', '.', 'संरूप', 'में', ')']
{'r': 4, 'a': 5, 'm': 6, 'y': 7, 'n': 8, 'e': 9, 'f': 10, 'c': 11, 't': 12, 'o': 13, 'v': 14, 'i': 15, 'u': 16, 's': 17, 'l': 18, 'd': 19, 'z': 20, '-': 21, '(': 22, 'p': 23, '.': 24, ')': 25, 'b': 26, "'": 27, 'w': 28, 'h': 29, 'j': 30, 'g': 31, 'x': 32, '1': 33, '9': 34, '2': 35, '4': 36, ',': 37, '5': 38, '3': 39, 'k': 40, 'q': 41, ':': 42, '8': 43, '7': 44, '0': 45, '/': 46, '<': 47, '>': 48, '“': 49, '”': 50, '?': 51, '%': 52, '{': 53, '}': 54, '6': 55, ';': 56, '…': 57, '_': 58, '£': 59, '[': 60, ']': 61, '&': 62, '=': 63, '!': 64, '|': 65, 'í': 66, '♫': 67, '$': 68, 'ô': 69, 'é': 70, '☻': 71, '☺': 72, 'ğ': 73, '*': 74, 'ü': 75, '+': 76, 'ć': 77, 'á': 78, 'š': 79, 'ý': 80, 'ñ': 81, '‘': 82, '¡': 83, 'ú': 84, '\\': 85, 'ö': 86, 'ā': 87, '@': 88, 'č': 8

In [28]:
# Converts tokenized text into numericalized format
def numericalize_text(text, vocab):
    return [vocab.get(word, UNK_IDX) for word in text]


# Adds <sos> and <eos> tokens and converts to tensor
def tensor_transform(token_ids):
    return torch.cat((torch.tensor([SOS_IDX]), 
                      torch.tensor(token_ids), 
                      torch.tensor([EOS_IDX])))

# Combines all transformations into one function
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# Define text transformation pipeline
text_transform = {
    "en": sequential_transforms(lambda x: numericalize_text(x, en_vocab), tensor_transform),
    "hi": sequential_transforms(lambda x: numericalize_text(x, hi_vocab), tensor_transform)
}


In [29]:
def encode_sample(sample):
    sample['translation']['en'] = text_transform["en"](sample['translation']['en'])
    sample['translation']['hi'] = text_transform["hi"](sample['translation']['hi'])
    return sample

numericalized_datasets = tokenized_datasets.map(encode_sample)

Map:   0%|          | 0/219108 [00:00<?, ? examples/s]

Map:   0%|          | 0/27388 [00:00<?, ? examples/s]

Map:   0%|          | 0/27389 [00:00<?, ? examples/s]

In [37]:
print(numericalized_datasets["train"][34])  # Check numericalized format

{'translation': {'en': [2, 0, 42, 3], 'hi': [2, 0, 0, 24, 73, 32, 80, 3]}}


In [38]:
BATCH_SIZE = 64
# function to collate data samples into batch tesors
def collate_batch(batch):
    src_batch, src_len_batch, trg_batch = [], [], []
    for src_sample, trg_sample in batch:
        # Process and truncate source text to a max length of 1000
        processed_text = text_transform[SRC_LANGUAGE](src_sample.rstrip("\n"))[:2000]
        src_batch.append(processed_text)
        src_len_batch.append(processed_text.size(0))  # Store length of source text

        # Process and truncate target text to a max length of 1000
        target_text = text_transform[TRG_LANGUAGE](trg_sample.rstrip("\n"))[:2000]
        trg_batch.append(target_text)

    # Pad the sequences to ensure they are all the same length (max length in the batch)
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    trg_batch = pad_sequence(trg_batch, padding_value=PAD_IDX, batch_first=True)

    return src_batch, torch.tensor(src_len_batch, dtype=torch.int64), trg_batch

In [39]:
# Create DataLoaders
train_loader = DataLoader(split_datasets['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(split_datasets['val'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(split_datasets['test'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

In [41]:
for en, _, hi in train_loader:
    break

ValueError: too many values to unpack (expected 2)

In [42]:
print("English shape: ", en.shape)  # (batch_size, seq len)
print("Hindi shape: ", hi.shape)   # (batch_size, seq len)

NameError: name 'en' is not defined

In [44]:
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm        = nn.LayerNorm(hid_dim)
        self.self_attention       = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.feedforward          = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout              = nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, 1, 1, src len]   #if the token is padding, it will be 1, otherwise 0
        _src, _ = self.self_attention(src, src, src, src_mask)
        src     = self.self_attn_layer_norm(src + self.dropout(_src))
        #src: [batch_size, src len, hid dim]
        
        _src    = self.feedforward(src)
        src     = self.ff_layer_norm(src + self.dropout(_src))
        #src: [batch_size, src len, hid dim]
        
        return src

In [45]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length = 100):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers        = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
                                           for _ in range(n_layers)])
        self.dropout       = nn.Dropout(dropout)
        self.scale         = torch.sqrt(torch.FloatTensor([hid_dim])).to(self.device)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = src.shape[0]
        src_len    = src.shape[1]
        
        pos        = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos: [batch_size, src_len]
        
        src        = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        #src: [batch_size, src_len, hid_dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
        #src: [batch_size, src_len, hid_dim]
        
        return src
            

In [46]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim  = hid_dim
        self.n_heads  = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q     = nn.Linear(hid_dim, hid_dim)
        self.fc_k     = nn.Linear(hid_dim, hid_dim)
        self.fc_v     = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o     = nn.Linear(hid_dim, hid_dim)
        
        self.dropout  = nn.Dropout(dropout)
        
        self.scale    = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
                
    def forward(self, query, key, value, mask = None):
        #src, src, src, src_mask
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
        
        batch_size = query.shape[0]
        
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        #Q=K=V: [batch_size, src len, hid_dim]
        
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        #Q = [batch_size, n heads, query len, head_dim]
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        #Q = [batch_size, n heads, query len, head_dim] @ K = [batch_size, n heads, head_dim, key len]
        #energy = [batch_size, n heads, query len, key len]
        
        #for making attention to padding to 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
            
        attention = torch.softmax(energy, dim = -1)
        #attention = [batch_size, n heads, query len, key len]
        
        x = torch.matmul(self.dropout(attention), V)
        #[batch_size, n heads, query len, key len] @ [batch_size, n heads, value len, head_dim]
        #x = [batch_size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()  #we can perform .view
        #x = [batch_size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        #x = [batch_size, query len, hid dim]
        
        x = self.fc_o(x)
        #x = [batch_size, query len, hid dim]
        
        return x, attention
        

In [47]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc1 = nn.Linear(hid_dim, pf_dim)
        self.fc2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        #x = [batch size, src len, hid dim]
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        
        return x

In [48]:
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm  = nn.LayerNorm(hid_dim)
        self.ff_layer_norm        = nn.LayerNorm(hid_dim)
        self.self_attention       = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention    = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.feedforward          = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout              = nn.Dropout(dropout)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len, hid dim]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg     = self.self_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg             = self.enc_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        #attention = [batch_size, n heads, trg len, src len]
        
        _trg = self.feedforward(trg)
        trg  = self.ff_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        
        return trg, attention

In [49]:
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, 
                 pf_dim, dropout, device,max_length = 100):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers        = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
                                            for _ in range(n_layers)])
        self.fc_out        = nn.Linear(hid_dim, output_dim)
        self.dropout       = nn.Dropout(dropout)
        self.scale         = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = trg.shape[0]
        trg_len    = trg.shape[1]
        
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos: [batch_size, trg len]
        
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        #trg: [batch_size, trg len, hid dim]
        
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
            
        #trg: [batch_size, trg len, hid dim]
        #attention: [batch_size, n heads, trg len, src len]
        
        output = self.fc_out(trg)
        #output = [batch_size, trg len, output_dim]
        
        return output, attention

In [75]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        
        #src = [batch size, src len]
        
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #src_mask = [batch size, 1, 1, src len]

        return src_mask
    
    def make_trg_mask(self, trg):
        
        #trg = [batch size, trg len]
        
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        #trg_pad_mask = [batch size, 1, 1, trg len]
        
        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        #trg_sub_mask = [trg len, trg len]
            
        trg_mask = trg_pad_mask & trg_sub_mask
        #trg_mask = [batch size, 1, trg len, trg len]
        
        return trg_mask

    def forward(self, src, trg):
        
        #src = [batch size, src len]
        #trg = [batch size, trg len]
                
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        #src_mask = [batch size, 1, 1, src len]
        #trg_mask = [batch size, 1, trg len, trg len]
        
        enc_src = self.encoder(src, src_mask)
        #enc_src = [batch size, src len, hid dim]
                
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]
        
        return output, attention

# Training

In [76]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

In [77]:
# Initialize Encoder and Decoder
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(hi_vocab)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)

dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              device)
# Padding index for source and target sequences
SRC_PAD_IDX = PAD_IDX
TRG_PAD_IDX = PAD_IDX

In [None]:
input_dim   = len(en_vocab)
output_dim  = len(hi_vocab)
hid_dim = 256
enc_layers = 3
dec_layers = 3
enc_heads = 8
dec_heads = 8
enc_pf_dim = 512
dec_pf_dim = 512
enc_dropout = 0.1
dec_dropout = 0.1

SRC_PAD_IDX = PAD_IDX
TRG_PAD_IDX = PAD_IDX

enc = Encoder(input_dim, 
              hid_dim, 
              enc_layers, 
              enc_heads, 
              enc_pf_dim, 
              enc_dropout, 
              device)

dec = Decoder(output_dim, 
              hid_dim, 
              dec_layers, 
              dec_heads, 
              dec_pf_dim, 
              dec_dropout, 
              device)

model = Seq2SeqTransformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
model.apply(initialize_weights)

In [None]:
#we can print the complexity by the number of parameters
def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for item in params:
        print(f'{item:>6}')
    print(f'______\n{sum(params):>6}')
    
count_parameters(model)

In [80]:
import torch.optim as optim

lr = 0.0005

#training hyperparameters
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX) #combine softmax with cross entropy

In [81]:
def train(model, loader, optimizer, criterion, clip, loader_length):
    
    model.train()
    
    epoch_loss = 0
    
    for src, src_len, trg in loader:
        src = src.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()
        
        #trg[:, :-1] remove the eos, e.g., "<sos> I love sushi" since teaching forcing, the input does not need to have eos
        try:
            output, _ = model(src, trg[:,:-1])
        except:
            continue
                
        #output = [batch size, trg len - 1, output dim]
        #trg    = [batch size, trg len]
            
        output_dim = output.shape[-1]
            
        output = output.reshape(-1, output_dim)
        
        trg = trg[:,1:].reshape(-1) #trg[:, 1:] remove the sos, e.g., "i love sushi <eos>" since in teaching forcing, the output does not have sos
        #output = [batch size * trg len - 1, output dim]
        #trg    = [batch size * trg len - 1]
            
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / loader_length

In [82]:
def evaluate(model, loader, criterion, loader_length):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for src, src_len, trg in loader:
        
            src = src.to(device)
            trg = trg.to(device)

            try:
                output, _ = model(src, trg[:,:-1])
            except:
                continue
            
            #output = [batch size, trg len - 1, output dim]
            #trg = [batch size, trg len]
            
            output_dim = output.shape[-1]
            
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            
            #output = [batch size * trg len - 1, output dim]
            #trg = [batch size * trg len - 1]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / loader_length

In [83]:
train_loader_length = len(train_loader)
val_loader_length   = len(val_loader)
test_loader_length  = len(test_loader)

In [84]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [85]:
import math

In [None]:
best_val_loss = float('inf')
num_epochs = 2
clip       = 1

save_path = f'models/{model.__class__.__name__}.pt'

train_losses = []
val_losses = []

for epoch in range(num_epochs):
    
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
    val_loss = evaluate(model, val_loader, criterion, val_loader_length)
    
    #for plotting
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {val_loss:.3f} |  Val. PPL: {math.exp(val_loss):7.3f}')
    
    #lower perplexity is better

In [None]:
print("Sample English vocab:", list(en_vocab.items())[:20])
print("Sample Hindi vocab:", list(hi_vocab.items())[:20])

In [None]:
for example in tokenized_datasets["train"]:
    print("Tokenized English:", example["translation"]["en"])
    print("Tokenized Hindi:", example["translation"]["hi"])
    break


In [None]:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(1, 1, 1)
ax.plot(train_losses, label = 'train loss')
ax.plot(val_losses, label = 'valid loss')
plt.legend()
ax.set_xlabel('updates')
ax.set_ylabel('loss')

In [None]:
model.load_state_dict(torch.load(save_path))
test_loss = evaluate(model, test_loader, criterion, test_loader_length)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

# Test on some random data

In [91]:
sample = next(iter(train_dataset))

In [None]:
sample['translation']['en']

In [None]:
sample['translation']['hi']

In [None]:
src_text = text_transform[SRC_LANGUAGE](sample['translation']['en']).to(device)
src_text

In [None]:
trg_text = text_transform[TRG_LANGUAGE](sample['translation']['hi']).to(device)
trg_text

In [96]:
src_text = src_text.reshape(1, -1)  #because batch_size is 1

In [97]:
trg_text = trg_text.reshape(1, -1)

In [None]:
src_text.shape, trg_text.shape

In [99]:
text_length = torch.tensor([src_text.size(0)]).to(dtype=torch.int64)

In [None]:
model.load_state_dict(torch.load(save_path))

model.eval()
with torch.no_grad():
    output, attentions = model(src_text, trg_text) #turn off teacher forcing

In [None]:
output.shape #batch_size, trg_len, trg_output_dim

Since batch size is 1, we just take off that dimension

In [102]:
output = output.squeeze(0)

In [None]:
output.shape

We shall remove the first token since it's zeroes anyway

In [None]:
output = output[1:]
output.shape #trg_len, trg_output_dim

Then we just take the top token with highest probabilities

In [None]:
output_max = output.argmax(-2) #returns max indices

In [None]:
output_max

Get the mapping of the target language

In [124]:
hi_vocab['<unk>'] = 0
hi_vocab['<pad>'] = 1
hi_vocab['<sos>'] = 2
hi_vocab['<eos>'] = 3

In [125]:
mapping = list(hi_vocab.keys())

In [None]:
for token in output_max:
    print(mapping[token.item()])

# Attention

Let's display the attentions to understand how the source text links with the generated text

In [None]:
attentions.shape

Since there are 8 heads, we can look at just 1 head for sake of simplicity.

In [None]:
attention = attentions[0, 0, :, :]
attention.shape

In [None]:
src_tokens = ['<sos>'] + token_transform[SRC_LANGUAGE](sample[0]) + ['<eos>']
src_tokens

In [None]:
trg_tokens = ['<sos>'] + [mapping[token.item()] for token in output_max]
trg_tokens

In [None]:
import matplotlib.ticker as ticker

def display_attention(sentence, translation, attention):
    
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111)
    
    attention = attention.squeeze(1).cpu().detach().numpy()
    
    cax = ax.matshow(attention, cmap='bone')
   
    ax.tick_params(labelsize=10)
    
    y_ticks =  [''] + translation
    x_ticks =  [''] + sentence 
     
    ax.set_xticklabels(x_ticks, rotation=45)
    ax.set_yticklabels(y_ticks)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()
    plt.close()

In [None]:
display_attention(src_tokens, trg_tokens, attention)