In [43]:
from transformer import Transformer # this is the transformer.py file
import torch
import numpy as np

In [64]:
english_file = '../data/dev.en'
sanskrit_file = '../data/dev.sn'

# Generated this by filtering Appendix code

START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'

sanskrit_vocabulary = [
    START_TOKEN, ' ', '!', '"', "'", '(', ')', ',', '-', '.', '?', ':', ';',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',

    # Independent vowels
    'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ॠ', 'ऌ', 'ॡ', 'ए', 'ऐ', 'ओ', 'औ',
    
    # Consonants
    'क', 'ख', 'ग', 'घ', 'ङ',
    'च', 'छ', 'ज', 'झ', 'ञ',
    'ट', 'ठ', 'ड', 'ढ', 'ण',
    'त', 'थ', 'द', 'ध', 'न',
    'प', 'फ', 'ब', 'भ', 'म',
    'य', 'र', 'ल', 'व',
    'श', 'ष', 'स', 'ह',

    # Vowel signs
    'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॄ', 'े', 'ै', 'ो', 'ौ',

    # Other signs
    'ं', 'ः', 'ँ', '्',  # virama, anusvara, visarga, chandrabindu
    '।', '॥',  # danda marks

    # Special tokens
    PADDING_TOKEN, END_TOKEN
]

english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                        ':', ';', '<', '=', '>', '?', '@', 
                        'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 
                        'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 
                        'Y', 'Z',
                        '_',
                        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                        'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                        'y', 'z', 
                        '{', '|', '}', '~', '·', 'º',
                        
                        # Extended Latin characters found in the data
                        'á', 'â', 'ã', 'ä', 'å', 'ç', 'é', 'î', 'ñ', 'ú', 'ü', 'ă', 'ć', 'ę', 'ı', 'ļ', 'ł', 'ņ',
                        'Ś', 'ś', 'Ş', 'ş', 'Š', 'š', 'ţ', 'ſ', 'ș', 'ț', 'ə',
                        
                        # IAST transliteration characters
                        'ā', 'ī', 'ū', 'ṛ', 'ṝ', 'ḷ', 'ḹ', 'ṅ', 'ṭ', 'ḍ', 'ṇ', 'ṣ',
                        'Ā', 'Ī', 'Ū', 'Ṛ', 'Ṝ', 'Ḷ', 'Ḹ', 'Ṅ', 'Ṭ', 'Ḍ', 'Ṇ', 'Ṣ','о','ả',
                        
                        # Vietnamese characters
                        'ả', 'ặ', 'ị',
                        
                        # Cyrillic (found in data)
                        'О',
                        
                        # Devanagari characters found in English text
                        'ं', 'उ', 'ए', 'क', 'च', 'त', 'द', 'ध', 'न', 'भ', 'म', 'र', 'ल', 'व', 'श', 'स', 
                        'ा', 'ि', 'ु', 'ै', 'ो', '्', '।', '॥',
                        
                        # Special punctuation (en dash, em dash, curved quotes)
                        '–', '—', '\u2018', '\u201c', '\u201d',
                        
                        PADDING_TOKEN, END_TOKEN]

In [65]:
index_to_sanskrit = {k:v for k,v in enumerate(sanskrit_vocabulary)}
sanskrit_to_index = {v:k for k,v in enumerate(sanskrit_vocabulary)}
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}

In [66]:
with open(english_file, 'r') as file:
    english_sentences = file.readlines()
with open(sanskrit_file, 'r') as file:
    sanskrit_sentences = file.readlines()

# Limit Number of sentences
TOTAL_SENTENCES = 200000
english_sentences = english_sentences[:TOTAL_SENTENCES]
sanskrit_sentences = sanskrit_sentences[:TOTAL_SENTENCES]
english_sentences = [sentence.rstrip('\n').lower() for sentence in english_sentences]
sanskrit_sentences = [sentence.rstrip('\n') for sentence in sanskrit_sentences]

In [67]:
english_sentences[:10]

['when şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “o dasaratha, fie on you!"',
 'aggrieved at the uproar that arose there in consequence, the lord of earth banished from his heart all regard for life, virtue, and fame. and sighing hot, that descendant of ikşvāku spoke to that wife of his, saying, o kaikeyi, sītā deserves not to go in a kuća dress.',
 'tender, and youthful, and worthy of happiness, she is by no means capable of living in the forest. my spiritual guide has spoken the truth.',
 'whom has this one injured that, being the daughter of the foremost of kings, she like a female ascetic, wearing a meagre garb in the presence of all, will (repair to the woods and) remain there like a beggar destitute of everything?',
 "let janaka's daughter leave off her ascetic guise. this is not the promise that i had made to you before. let the princess go to the forest in comfort, furnished with all 

In [68]:
sanskrit_sentences[:10]

['तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्विति ॥',
 'तेन तत्र प्रणादेन दुःखितः स महीपतिः। चिच्छेद जीविते श्रद्धां धर्मे यशसि चात्मनः॥ स निःश्वस्योष्णमैक्ष्वाकस्तां भार्यामिदमब्रवीत्। कैकेयि कुशचीरेण न सीता गन्तुमर्हति॥',
 'सुकुमारी च बाला च सततं च सुखोचिता। नेयं वनस्य योग्येति सत्यमाह गुरुर्मम ॥',
 'इयं हि कस्यापि करोति किंचित् तपस्विनी राजवरस्य पुत्री। या चीरमासाद्य वनस्य मध्ये जाता विसंज्ञा श्रमणीव काचित्॥',
 'चीराण्यपास्याज्जनकस्य कन्या नेयं प्रतिज्ञा मम दत्तपूर्वा। यथासुखं गच्छतु राजपुत्री वनं समग्रा सह सर्वरत्नैः॥',
 'अजीवनाहेण मया नृशंसा कृता प्रतिज्ञा नियमेन तावत्। त्वया हि बाल्यात् प्रतिपन्नमेतत् तन्मा दहेद् वेणुमिवात्मपुष्पम्॥',
 'रामेण यदि ते पापे किंचित्कृतमशोभनम्। अपकारः क इह ते वैदेह्या दर्शितोऽधमे॥',
 'मृगीवोत्फुल्लनयना मृदुशीला मनस्विनी। अपकारं कमिव ते करोति जनकात्मजा॥',
 'ननु पर्याप्तमेवं ते पापे रामविवासनम्। किमेभिः कृपणैर्भूयः पातकैरपि ते कृतैः॥',
 'प्रतिज्ञातं मया तावत् त्वयोक्तं देवि शृण्वता। रामं यदभिषेकाय त्वमिहागतमब्रवीः॥']

In [69]:
import numpy as np
PERCENTILE = 97
print( f"{PERCENTILE}th percentile length sanskrit: {np.percentile([len(x) for x in sanskrit_sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length English: {np.percentile([len(x) for x in english_sentences], PERCENTILE)}" )


97th percentile length sanskrit: 219.0
97th percentile length English: 388.0


In [70]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(sanskrit_sentences)):
    sanskrit_sentence, english_sentence = sanskrit_sentences[index], english_sentences[index]
    if is_valid_length(sanskrit_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_valid_tokens(sanskrit_sentence, sanskrit_vocabulary):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(sanskrit_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 6148
Number of valid sentences: 3832


In [71]:
sanskrit_sentences = [sanskrit_sentences[i] for i in valid_sentence_indicies]
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]

In [72]:
sanskrit_sentences[:3]

['तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्विति ॥',
 'सुकुमारी च बाला च सततं च सुखोचिता। नेयं वनस्य योग्येति सत्यमाह गुरुर्मम ॥',
 'चीराण्यपास्याज्जनकस्य कन्या नेयं प्रतिज्ञा मम दत्तपूर्वा। यथासुखं गच्छतु राजपुत्री वनं समग्रा सह सर्वरत्नैः॥']

In [73]:
import torch

d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
kn_vocab_size = len(sanskrit_vocabulary)

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          kn_vocab_size,
                          english_to_index,
                          sanskrit_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [74]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(182, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding)

In [75]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, english_sentences, sanskrit_sentences):
        self.english_sentences = english_sentences
        self.sanskrit_sentences = sanskrit_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.sanskrit_sentences[idx]

In [76]:
dataset = TextDataset(english_sentences, sanskrit_sentences)

In [77]:
len(dataset)

3832

In [78]:
dataset[1]

('tender, and youthful, and worthy of happiness, she is by no means capable of living in the forest. my spiritual guide has spoken the truth.',
 'सुकुमारी च बाला च सततं च सुखोचिता। नेयं वनस्य योग्येति सत्यमाह गुरुर्मम ॥')

In [79]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [80]:
for batch_num, batch in enumerate(iterator):
    print(batch)
    if batch_num > 3:
        break

[('when şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “o dasaratha, fie on you!"', 'tender, and youthful, and worthy of happiness, she is by no means capable of living in the forest. my spiritual guide has spoken the truth.', "let janaka's daughter leave off her ascetic guise. this is not the promise that i had made to you before. let the princess go to the forest in comfort, furnished with all sorts of gems.", "of eyes expanded like those of a doe, endued with a mild temperament, and virtuous, what harm has janaka's daughter done you.", 'surely, o nefarious one, the banishment of ráma is enough for you. why then do you bend your mind to perpetrate these atrocious sins?', 'o noble dame, having heard you asking for the banishment of rāma, who had at first been intended by me for being installed, and who came here afterwards, i had promised you (his exile alone.)', "but since, going beyond that pro

In [81]:
from torch import nn

criterian = nn.CrossEntropyLoss(ignore_index=sanskrit_to_index[PADDING_TOKEN],
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [82]:
NEG_INFTY = -1e9

def create_masks(eng_batch, kn_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, kn_sentence_length = len(eng_batch[idx]), len(kn_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

Modify mask such that the padding tokens cannot look ahead.
In Encoder, tokens before it should be -1e9 while tokens after it should be -inf.
 

Note the target mask starts with 2 rows of non masked items: https://github.com/SamLynnEvans/Transformer/blob/master/Beam.py#L55


In [83]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, kn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, kn_batch)
        optim.zero_grad()
        kn_predictions = transformer(eng_batch,
                                     kn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(kn_batch, start_token=False, end_token=True)
        loss = criterian(
            kn_predictions.view(-1, kn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == sanskrit_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"sanskrit Translation: {kn_batch[0]}")
            kn_sentence_predicted = torch.argmax(kn_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in kn_sentence_predicted:
              if idx == sanskrit_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_sanskrit[idx.item()]
            print(f"sanskrit Prediction: {predicted_sentence}")


            transformer.eval()
            kn_sentence = ("",)
            eng_sentence = ("Do work don't expect result",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence)
                predictions = transformer(eng_sentence,
                                          kn_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_sanskrit[next_token_index]
                kn_sentence = (kn_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (do work don't expect result) : {kn_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 5.540466785430908
English: when şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “o dasaratha, fie on you!"
sanskrit Translation: तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्विति ॥
sanskrit Prediction: लल(लललघहॡहलएलललछूटभटटॡभएएएिललऔलहहललललॡहललललललभभलललॡऌ<START><START>लऌऌललललेे2हऐलललललललभल2ूएललभलभ2भलॡलॡॡॡलॡभएएएएएलॡभलल लेेभटटभहहएएलययललेेभभहेकेकभ-ल-ऐलभहहहलहललललहलललहल ूलललऐहललहेललललललललेेललेेलललललललटलललहललेलेहेेेेे
Evaluation translation (do work don't expect result) : ('ससससस्तततततततततततसससतततााााााासससाााााााताततततततततााततततततततााााातततततातत                      त                     ययेेेतेेेेेतततत---      ततततततततततत         ाााा    ााेेेेेेेालललेेेेलबेेेेेेेेेेेे',)
-------------------------------------------
Iteration 100 : 3.363590717315674
English: king yudhishthira immediately struck him with six arrows. he then cut off the bow and the sta

In [88]:
torch.save({
    'epoch': 9,  # last epoch you finished
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': optim.state_dict(),
    'loss': loss.item() if 'loss' in locals() else 0  # use 0 if loss is not defined
}, 'checkpoint.pth')

print("Checkpoint saved.")


Checkpoint saved.


In [89]:
import torch
import os

# Load model and optimizer state
checkpoint_path = "checkpoint.pth"
start_epoch = 0

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    transformer.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Checkpoint loaded. Resuming training from epoch {start_epoch}")
else:
    print("No checkpoint found. Starting from scratch.")

# Continue training from saved state
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 20  # set this to any number higher than start_epoch

for epoch in range(start_epoch, num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, kn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, kn_batch)
        optim.zero_grad()
        kn_predictions = transformer(eng_batch,
                                     kn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(kn_batch, start_token=False, end_token=True)
        loss = criterian(
            kn_predictions.view(-1, kn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == sanskrit_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()

        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"sanskrit Translation: {kn_batch[0]}")
            kn_sentence_predicted = torch.argmax(kn_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in kn_sentence_predicted:
                if idx == sanskrit_to_index[END_TOKEN]:
                    break
                predicted_sentence += index_to_sanskrit[idx.item()]
            print(f"sanskrit Prediction: {predicted_sentence}")

            transformer.eval()
            kn_sentence = ("",)
            eng_sentence = ("Do work don't expect result",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence)
                predictions = transformer(eng_sentence,
                                          kn_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter]
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_sanskrit[next_token_index]
                kn_sentence = (kn_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                    break

            print(f"Evaluation translation (do work don't expect result) : {kn_sentence}")
            print("-------------------------------------------")

    # ✅ Save checkpoint after each epoch
    torch.save({
        'epoch': epoch,
        'model_state_dict': transformer.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': loss.item()
    }, checkpoint_path)


Checkpoint loaded. Resuming training from epoch 10
Epoch 10
Iteration 0 : 2.645613670349121
English: when şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “o dasaratha, fie on you!"
sanskrit Translation: तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्विति ॥
sanskrit Prediction: तत्मा  उ व् सा्न्ं   प्राा्र  ्् ाा्। प्य् र्र् ्प्् प््या तात्रप्यान स्््ा प्यात्तप
Evaluation translation (do work don't expect result) : ('ततस त तरवरररारात्त तररातस वररररररः। तररः। तः। तरारा तरात त त त त तरारारः॥<END>',)
-------------------------------------------
Iteration 100 : 2.632965326309204
English: king yudhishthira immediately struck him with six arrows. he then cut off the bow and the standard of his antagonist with two razorfaced arrows.
sanskrit Translation: तं विव्याधशुगैः षड्भिर्धर्मराजस्त्वरनिव॥ कार्मुकं चास्य चिच्छेद क्षुराभ्यां ध्वजमेव च।
sanskrit Prediction: तत साराया ा्रार 

In [96]:
torch.save({
    'epoch': 19,
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': optim.state_dict(),
    'loss': loss.item()
}, 'checkpoint.pth')


In [97]:
import torch
import os

checkpoint_path = "checkpoint.pth"
start_epoch = 0
total_epochs = 50  # You want to train up to epoch 49

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    transformer.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}, resuming from epoch {start_epoch}")
else:
    print("No checkpoint found. Starting from epoch 0.")

# === Resume training from epoch 20 up to 49 ===
transformer.train()
transformer.to(device)

for epoch in range(start_epoch, total_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, kn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, kn_batch)
        optim.zero_grad()
        kn_predictions = transformer(eng_batch,
                                     kn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(kn_batch, start_token=False, end_token=True)
        loss = criterian(
            kn_predictions.view(-1, kn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == sanskrit_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()

        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Sanskrit Target: {kn_batch[0]}")
            kn_sentence_predicted = torch.argmax(kn_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in kn_sentence_predicted:
                if idx == sanskrit_to_index[END_TOKEN]:
                    break
                predicted_sentence += index_to_sanskrit[idx.item()]
            print(f"Sanskrit Prediction: {predicted_sentence}")

            # Evaluation example
            transformer.eval()
            kn_sentence = ("",)
            eng_sentence = ("Do work don't expect result",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_sentence, kn_sentence)
                predictions = transformer(eng_sentence,
                                          kn_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter]
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_sanskrit[next_token_index]
                kn_sentence = (kn_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                    break
            print(f"Eval (Do work don't expect result): {kn_sentence}")
            print("--------------------------------------------------")

    # ✅ Save checkpoint after each epoch
    torch.save({
        'epoch': epoch,
        'model_state_dict': transformer.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': loss.item()
    }, checkpoint_path)


Loaded checkpoint from epoch 19, resuming from epoch 20
Epoch 20
Iteration 0 : 2.4425156116485596
English: when şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “o dasaratha, fie on you!"
Sanskrit Target: तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्विति ॥
Sanskrit Prediction: त््तान प र् सि्न् ां पिराा्रा ि्नाा्। सुरा्त्षा ्पन् प््वा सरत्षपेरा  पु््ा प्॥ाति प
Eval (Do work don't expect result): ('तस रज उवच तरामहान तर्य रवर्षरमहरः पररः। दर्य दरमहराजराजव दर्य तसर्य रव<END>',)
--------------------------------------------------
Iteration 100 : 2.4286437034606934
English: king yudhishthira immediately struck him with six arrows. he then cut off the bow and the standard of his antagonist with two razorfaced arrows.
Sanskrit Target: तं विव्याधशुगैः षड्भिर्धर्मराजस्त्वरनिव॥ कार्मुकं चास्य चिच्छेद क्षुराभ्यां ध्वजमेव च।
Sanskrit Prediction: तत रि ायानि् ्र तु्रूत्वि्मााता्त्

## Inference

In [98]:
transformer.eval()
def translate(eng_sentence):
  eng_sentence = (eng_sentence,)
  kn_sentence = ("",)
  for word_counter in range(max_sequence_length):
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence)
    predictions = transformer(eng_sentence,
                              kn_sentence,
                              encoder_self_attention_mask.to(device), 
                              decoder_self_attention_mask.to(device), 
                              decoder_cross_attention_mask.to(device),
                              enc_start_token=False,
                              enc_end_token=False,
                              dec_start_token=True,
                              dec_end_token=False)
    next_token_prob_distribution = predictions[0][word_counter]
    next_token_index = torch.argmax(next_token_prob_distribution).item()
    next_token = index_to_sanskrit[next_token_index]
    kn_sentence = (kn_sentence[0] + next_token, )
    if next_token == END_TOKEN:
      break
  return kn_sentence[0]

In [111]:
#Epoch of 20
translation = translate("Your right is to perform your duty only, but never to its fruits.")
print(translation)
#कर्मण्येवाधिकारस्ते मा फलेषु कदाचन


ततो मा महार्थानां तु प्रत्ये तव त्वं त्वं त्वं त्वं त्वं तेन तेन महान्॥<END>


In [99]:
#Epoch of 50
translation = translate("""Your right is to perform your duty only, but never to its fruits." This quote emphasizes focusing on action rather than the outcome""")
print(translation)
#कर्मण्येवाधिकारस्ते मा फलेषु कदाचन

ततो युद्धा महाराज न पुन्त्रामि प्रति। अन्त्र समान्यां प्राप्य संग्राणां प्रवीत्॥<END>


In [95]:
#After 20 epoches
translation = translate("""When Şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “O Dasaratha, fie on you!" """)
print(translation)

तत्राण्या सम्या काल्वा प्राण्या प्राण्डवान्। प्राज्च प्राण्या प्राण्या प्राज्यान्ति प्राज्ट्ट्राणा प्राज्या सराज्यान्या प्यान्यान्यानित्या स्यारि प्या स्या प्यास्यामानितिते॥<END>


In [100]:
#After 50 epoches
translation = translate("""When Şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “O Dasaratha, fie on you!" """)
print(translation)

ततः स तु सुमुख्यान् सर्वे प्राप्तथा। प्राप्ता न विद्धा न त्वं प्राप्ता न च पुनः॥<END>


In [93]:
#After 20 epoches
translation = translate("i am here")
print(translation)

स<END>


In [101]:
#After 50 epoches
translation = translate("i am here")
print(translation)

यथ<END>


In [102]:
translation = translate("why did they do this?")
print(translation)


यथे वय उवा वय तस्वव क्षव प्रवय वयस्वय त्वय तस्वव वय त्वयस्वस्वय वय वय॥<END>


In [103]:
translation = translate("i am well.")
print(translation)


यस उवा वच<END>


In [108]:
translation = translate(" do you enjoy this kingdom rid of your thorn")
translation

'ततः स्थेवा तु ते ते तव त्वं ते तव ते तव ते॥<END>'

In [119]:
translation = translate("your exertions are to no purpose")
print(translation)


यथा तं यवच ते वय त्वं तं पररिच प्रववच। ते वं तं वच ते वय त्वं संय वचेव तं तव॥<END>


In [None]:
# import pickle
# import os

# # Create models directory if it doesn't exist
# models_dir = '../models'
# os.makedirs(models_dir, exist_ok=True)

# # Create comprehensive vocabulary data structure
# vocabulary_data = {
#     'sanskrit_vocabulary': sanskrit_vocabulary,
#     'english_vocabulary': english_vocabulary,
#     'sanskrit_to_index': sanskrit_to_index,
#     'index_to_sanskrit': index_to_sanskrit,
#     'english_to_index': english_to_index,
#     'index_to_english': index_to_english,
#     'special_tokens': {
#         'START_TOKEN': START_TOKEN,
#         'END_TOKEN': END_TOKEN,
#         'PADDING_TOKEN': PADDING_TOKEN
#     },
#     'vocab_sizes': {
#         'sanskrit_vocab_size': len(sanskrit_vocabulary),
#         'english_vocab_size': len(english_vocabulary)
#     },
#     'model_params': {
#         'max_sequence_length': max_sequence_length,
#         'd_model': d_model,
#         'num_heads': num_heads,
#         'num_layers': num_layers,
#         'ffn_hidden': ffn_hidden,
#         'drop_prob': drop_prob
#     }
# }

# # Save to pickle file
# pkl_file_path = os.path.join(models_dir, 'sanskrit_vocabulary.pkl')
# with open(pkl_file_path, 'wb') as f:
#     pickle.dump(vocabulary_data, f)

# print(f"✅ Sanskrit vocabulary data exported to: {pkl_file_path}")
# print(f"📊 Contents saved:")
# print(f"   - Sanskrit vocabulary: {len(sanskrit_vocabulary)} tokens")
# print(f"   - English vocabulary: {len(english_vocabulary)} tokens")
# print(f"   - Index mappings for both languages")
# print(f"   - Special tokens: {START_TOKEN}, {END_TOKEN}, {PADDING_TOKEN}")
# print(f"   - Model parameters")

# # Verify the file was created
# if os.path.exists(pkl_file_path):
#     file_size = os.path.getsize(pkl_file_path)
#     print(f"📁 File size: {file_size:,} bytes")
# else:
#     print("❌ Error: File was not created successfully")

✅ Sanskrit vocabulary data exported to: ../models/sanskrit_vocabulary.pkl
📊 Contents saved:
   - Sanskrit vocabulary: 89 tokens
   - English vocabulary: 183 tokens
   - Index mappings for both languages
   - Special tokens: <START>, <END>, <PADDING>
   - Model parameters
📁 File size: 3,910 bytes


In [None]:
# import pickle
# import os

# # Export Sanskrit vocabulary and mappings to pickle file
# vocabulary_data = {
#     'sanskrit_vocabulary': sanskrit_vocabulary,
#     'english_vocabulary': english_vocabulary,
#     'sanskrit_to_index': sanskrit_to_index,
#     'index_to_sanskrit': index_to_sanskrit,
#     'english_to_index': english_to_index,
#     'index_to_english': index_to_english,
#     'vocab_size': len(sanskrit_vocabulary),
#     'special_tokens': {
#         'START_TOKEN': START_TOKEN,
#         'PADDING_TOKEN': PADDING_TOKEN,
#         'END_TOKEN': END_TOKEN
#     }
# }

# # Create output directory if it doesn't exist
# os.makedirs('../models', exist_ok=True)

# # Save vocabulary to pickle file
# vocab_file_path = '../models/sanskrit_vocabulary.pkl'
# with open(vocab_file_path, 'wb') as f:
#     pickle.dump(vocabulary_data, f)

# print(f"Sanskrit vocabulary exported to: {vocab_file_path}")
# print(f"Vocabulary size: {len(sanskrit_vocabulary)} characters")
# print(f"Contains: {len(vocabulary_data)} data fields")
# print("\nExported data includes:")
# for key in vocabulary_data.keys():
#     if isinstance(vocabulary_data[key], dict):
#         print(f"  - {key}: {len(vocabulary_data[key])} items")
#     elif isinstance(vocabulary_data[key], list):
#         print(f"  - {key}: {len(vocabulary_data[key])} items")
#     else:
#         print(f"  - {key}: {vocabulary_data[key]}")

Sanskrit vocabulary exported to: ../models/sanskrit_vocabulary.pkl
Vocabulary size: 89 characters
Contains: 8 data fields

Exported data includes:
  - sanskrit_vocabulary: 89 items
  - english_vocabulary: 183 items
  - sanskrit_to_index: 89 items
  - index_to_sanskrit: 89 items
  - english_to_index: 182 items
  - index_to_english: 183 items
  - vocab_size: 89
  - special_tokens: 3 items
