In [111]:
from transformer import Transformer # this is the transformer.py file
import torch
import numpy as np

# Check CUDA availability
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("Number of GPUs:", torch.cuda.device_count())
    print("Current GPU:", torch.cuda.current_device())
    print("GPU name:", torch.cuda.get_device_name(0))
    print("GPU memory:", f"{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("CUDA not available - will use CPU")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

CUDA available: False
CUDA not available - will use CPU
Using device: cpu


In [112]:
# GPU Memory Management Utilities
def clear_gpu_memory():
    """Clear GPU memory cache"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU memory cache cleared")
    else:
        print("CUDA not available")

def print_gpu_memory():
    """Print current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        cached = torch.cuda.memory_reserved() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU Memory - Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB, Total: {total:.1f} GB")
    else:
        print("CUDA not available")

# Clear any existing GPU memory
clear_gpu_memory()
print_gpu_memory()

CUDA not available
CUDA not available


In [113]:
english_file = '../data/dev.en'
sanskrit_file = '../data/dev.sn'

START_TOKEN = ''
PADDING_TOKEN = ''
END_TOKEN = ''

sanskrit_vocabulary = [
    START_TOKEN, ' ', '!', '"', "'", '(', ')', ',', '-', '.', '?', ':', ';',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',

    # Independent vowels
    'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ॠ', 'ऌ', 'ॡ', 'ए', 'ऐ', 'ओ', 'औ',
    
    # Consonants
    'क', 'ख', 'ग', 'घ', 'ङ',
    'च', 'छ', 'ज', 'झ', 'ञ',
    'ट', 'ठ', 'ड', 'ढ', 'ण',
    'त', 'थ', 'द', 'ध', 'न',
    'प', 'फ', 'ब', 'भ', 'म',
    'य', 'र', 'ल', 'व',
    'श', 'ष', 'स', 'ह',

    # Vowel signs
    'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॄ', 'े', 'ै', 'ो', 'ौ',

    # Other signs
    'ं', 'ः', 'ँ', '्',  # virama, anusvara, visarga, chandrabindu
    '।', '॥',  # danda marks

    # Special tokens
    PADDING_TOKEN, END_TOKEN
]

english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                        ':', ';', '<', '=', '>', '?', '@', 
                        'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 
                        'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 
                        'Y', 'Z',
                        '_',
                        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                        'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                        'y', 'z', 
                        '{', '|', '}', '~', '·', 'º',
                        
                        # Extended Latin characters found in the data
                        'á', 'â', 'ã', 'ä', 'å', 'ç', 'é', 'î', 'ñ', 'ú', 'ü', 'ă', 'ć', 'ę', 'ı', 'ļ', 'ł', 'ņ',
                        'Ś', 'ś', 'Ş', 'ş', 'Š', 'š', 'ţ', 'ſ', 'ș', 'ț', 'ə',
                        
                        # IAST transliteration characters
                        'ā', 'ī', 'ū', 'ṛ', 'ṝ', 'ḷ', 'ḹ', 'ṅ', 'ṭ', 'ḍ', 'ṇ', 'ṣ',
                        'Ā', 'Ī', 'Ū', 'Ṛ', 'Ṝ', 'Ḷ', 'Ḹ', 'Ṅ', 'Ṭ', 'Ḍ', 'Ṇ', 'Ṣ',
                        
                        # Vietnamese characters
                        'ả', 'ặ', 'ị',
                        
                        # Cyrillic (found in data)
                        'О',
                        
                        # Devanagari characters found in English text
                        'ं', 'उ', 'ए', 'क', 'च', 'त', 'द', 'ध', 'न', 'भ', 'म', 'र', 'ल', 'व', 'श', 'स', 
                        'ा', 'ि', 'ु', 'ै', 'ो', '्', '।', '॥',
                        
                        # Special punctuation (en dash, em dash, curved quotes)
                        '–', '—', '\u2018', '\u201c', '\u201d',
                        
                        PADDING_TOKEN, END_TOKEN]

In [114]:
index_to_sanskrit = {k:v for k,v in enumerate(sanskrit_vocabulary)}
sanskrit_to_index = {v:k for k,v in enumerate(sanskrit_vocabulary)}
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}

In [115]:
with open(english_file, 'r') as file:
    english_sentences = file.readlines()
with open(sanskrit_file, 'r') as file:
    sanskrit_sentences = file.readlines()

# Limit Number of sentences
TOTAL_SENTENCES = 200000
english_sentences = english_sentences[:TOTAL_SENTENCES]
sanskrit_sentences = sanskrit_sentences[:TOTAL_SENTENCES]
english_sentences = [sentence.rstrip('\n').lower() for sentence in english_sentences]
sanskrit_sentences = [sentence.rstrip('\n') for sentence in sanskrit_sentences]

In [116]:
english_sentences[:10]

['when şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “o dasaratha, fie on you!"',
 'aggrieved at the uproar that arose there in consequence, the lord of earth banished from his heart all regard for life, virtue, and fame. and sighing hot, that descendant of ikşvāku spoke to that wife of his, saying, o kaikeyi, sītā deserves not to go in a kuća dress.',
 'tender, and youthful, and worthy of happiness, she is by no means capable of living in the forest. my spiritual guide has spoken the truth.',
 'whom has this one injured that, being the daughter of the foremost of kings, she like a female ascetic, wearing a meagre garb in the presence of all, will (repair to the woods and) remain there like a beggar destitute of everything?',
 "let janaka's daughter leave off her ascetic guise. this is not the promise that i had made to you before. let the princess go to the forest in comfort, furnished with all 

In [117]:
sanskrit_sentences[:10]

['तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्विति ॥',
 'तेन तत्र प्रणादेन दुःखितः स महीपतिः। चिच्छेद जीविते श्रद्धां धर्मे यशसि चात्मनः॥ स निःश्वस्योष्णमैक्ष्वाकस्तां भार्यामिदमब्रवीत्। कैकेयि कुशचीरेण न सीता गन्तुमर्हति॥',
 'सुकुमारी च बाला च सततं च सुखोचिता। नेयं वनस्य योग्येति सत्यमाह गुरुर्मम ॥',
 'इयं हि कस्यापि करोति किंचित् तपस्विनी राजवरस्य पुत्री। या चीरमासाद्य वनस्य मध्ये जाता विसंज्ञा श्रमणीव काचित्॥',
 'चीराण्यपास्याज्जनकस्य कन्या नेयं प्रतिज्ञा मम दत्तपूर्वा। यथासुखं गच्छतु राजपुत्री वनं समग्रा सह सर्वरत्नैः॥',
 'अजीवनाहेण मया नृशंसा कृता प्रतिज्ञा नियमेन तावत्। त्वया हि बाल्यात् प्रतिपन्नमेतत् तन्मा दहेद् वेणुमिवात्मपुष्पम्॥',
 'रामेण यदि ते पापे किंचित्कृतमशोभनम्। अपकारः क इह ते वैदेह्या दर्शितोऽधमे॥',
 'मृगीवोत्फुल्लनयना मृदुशीला मनस्विनी। अपकारं कमिव ते करोति जनकात्मजा॥',
 'ननु पर्याप्तमेवं ते पापे रामविवासनम्। किमेभिः कृपणैर्भूयः पातकैरपि ते कृतैः॥',
 'प्रतिज्ञातं मया तावत् त्वयोक्तं देवि शृण्वता। रामं यदभिषेकाय त्वमिहागतमब्रवीः॥']

In [118]:
import numpy as np
PERCENTILE = 97
print( f"{PERCENTILE}th percentile length sanskrit: {np.percentile([len(x) for x in sanskrit_sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length English: {np.percentile([len(x) for x in english_sentences], PERCENTILE)}" )


97th percentile length sanskrit: 219.0
97th percentile length English: 388.0


In [119]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(sanskrit_sentences)):
    sanskrit_sentence, english_sentence = sanskrit_sentences[index], english_sentences[index]
    if is_valid_length(sanskrit_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_valid_tokens(sanskrit_sentence, sanskrit_vocabulary):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(sanskrit_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 6148
Number of valid sentences: 3832


In [120]:
sanskrit_sentences = [sanskrit_sentences[i] for i in valid_sentence_indicies]
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]

In [121]:
sanskrit_sentences[:3]

['तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्विति ॥',
 'सुकुमारी च बाला च सततं च सुखोचिता। नेयं वनस्य योग्येति सत्यमाह गुरुर्मम ॥',
 'चीराण्यपास्याज्जनकस्य कन्या नेयं प्रतिज्ञा मम दत्तपूर्वा। यथासुखं गच्छतु राजपुत्री वनं समग्रा सह सर्वरत्नैः॥']

In [122]:
# Add missing character to English vocabulary
if 'о' not in english_vocabulary:
    english_vocabulary.append('о')  # Cyrillic small letter O
    print("Added missing Cyrillic character 'о' to English vocabulary")

# Recreate the vocabularies
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}

import torch

d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
kn_vocab_size = len(sanskrit_vocabulary)
en_vocab_size = len(english_vocabulary)  # Updated English vocab size

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          kn_vocab_size,
                          english_to_index,
                          sanskrit_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

print(f"English vocabulary size: {en_vocab_size}")
print(f"Sanskrit vocabulary size: {kn_vocab_size}")

Added missing Cyrillic character 'о' to English vocabulary
English vocabulary size: 182
Sanskrit vocabulary size: 89


In [123]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(180, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding)

In [124]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, english_sentences, sanskrit_sentences):
        self.english_sentences = english_sentences
        self.sanskrit_sentences = sanskrit_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.sanskrit_sentences[idx]

In [125]:
dataset = TextDataset(english_sentences, sanskrit_sentences)

In [126]:
len(dataset)

3832

In [127]:
dataset[1]

('tender, and youthful, and worthy of happiness, she is by no means capable of living in the forest. my spiritual guide has spoken the truth.',
 'सुकुमारी च बाला च सततं च सुखोचिता। नेयं वनस्य योग्येति सत्यमाह गुरुर्मम ॥')

In [128]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [129]:
for batch_num, batch in enumerate(iterator):
    print(batch)
    if batch_num > 3:
        break

[('when şītā, having a husband although seeming as if she had none, was putting on the ascetic guise, the people got into a wrath and exclaimed, “o dasaratha, fie on you!"', 'tender, and youthful, and worthy of happiness, she is by no means capable of living in the forest. my spiritual guide has spoken the truth.', "let janaka's daughter leave off her ascetic guise. this is not the promise that i had made to you before. let the princess go to the forest in comfort, furnished with all sorts of gems.", "of eyes expanded like those of a doe, endued with a mild temperament, and virtuous, what harm has janaka's daughter done you.", 'surely, o nefarious one, the banishment of ráma is enough for you. why then do you bend your mind to perpetrate these atrocious sins?', 'o noble dame, having heard you asking for the banishment of rāma, who had at first been intended by me for being installed, and who came here afterwards, i had promised you (his exile alone.)', "but since, going beyond that pro

In [130]:
from torch import nn

# Check for GPU availability and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("CUDA not available, using CPU")

criterian = nn.CrossEntropyLoss(ignore_index=sanskrit_to_index[PADDING_TOKEN],
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)

print(f"Transformer parameters: {sum(p.numel() for p in transformer.parameters()):,}")

Using device: cpu
CUDA not available, using CPU
Transformer parameters: 7,538,777


In [131]:
NEG_INFTY = -1e9

def create_masks(eng_batch, kn_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, kn_sentence_length = len(eng_batch[idx]), len(kn_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

Modify mask such that the padding tokens cannot look ahead.
In Encoder, tokens before it should be -1e9 while tokens after it should be -inf.
 

Note the target mask starts with 2 rows of non masked items: https://github.com/SamLynnEvans/Transformer/blob/master/Beam.py#L55


In [132]:
# Build complete vocabularies from the actual data
print("Building vocabularies from actual data...")

# Read the original data again
with open(english_file, 'r') as file:
    all_english_sentences = file.readlines()
with open(sanskrit_file, 'r') as file:
    all_sanskrit_sentences = file.readlines()

# Limit and clean
all_english_sentences = all_english_sentences[:TOTAL_SENTENCES]
all_sanskrit_sentences = all_sanskrit_sentences[:TOTAL_SENTENCES]
all_english_sentences = [sentence.rstrip('\n').lower() for sentence in all_english_sentences]
all_sanskrit_sentences = [sentence.rstrip('\n') for sentence in all_sanskrit_sentences]

print(f"Total sentences: {len(all_english_sentences)}")

# Build vocabularies from ALL characters in the data
def build_vocabulary_from_data(sentences, base_vocab):
    """Build vocabulary including ALL characters found in the data"""
    vocab_set = set(base_vocab)
    
    for sentence in sentences:
        for char in sentence:
            vocab_set.add(char)
    
    return sorted(list(vocab_set))

# Base vocabularies
base_english_vocab = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', ';', '<', '=', '>', '?', '@', 
                      'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 
                      'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 
                      'Y', 'Z', '_',
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                      'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                      'y', 'z', '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN]

base_sanskrit_vocab = [START_TOKEN, ' ', '!', '"', "'", '(', ')', ',', '-', '.', '?', ':', ';',
                       '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                       # Devanagari characters
                       'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ॠ', 'ऌ', 'ॡ', 'ए', 'ऐ', 'ओ', 'औ',
                       'क', 'ख', 'ग', 'घ', 'ङ', 'च', 'छ', 'ज', 'झ', 'ञ', 'ट', 'ठ', 'ड', 'ढ', 'ण',
                       'त', 'थ', 'द', 'ध', 'न', 'प', 'फ', 'ब', 'भ', 'म', 'य', 'र', 'ल', 'व',
                       'श', 'ष', 'स', 'ह', 'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॄ', 'े', 'ै', 'ो', 'ौ',
                       'ं', 'ः', 'ँ', '्', '।', '॥', PADDING_TOKEN, END_TOKEN]

# Build complete vocabularies
english_vocabulary = build_vocabulary_from_data(all_english_sentences, base_english_vocab)
sanskrit_vocabulary = build_vocabulary_from_data(all_sanskrit_sentences, base_sanskrit_vocab)

print(f"English vocabulary size: {len(english_vocabulary)}")
print(f"Sanskrit vocabulary size: {len(sanskrit_vocabulary)}")

# Create the mappings
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}
index_to_sanskrit = {k:v for k,v in enumerate(sanskrit_vocabulary)}
sanskrit_to_index = {v:k for k,v in enumerate(sanskrit_vocabulary)}

print("✓ Complete vocabularies built from actual data")

# Final test to ensure the vocabulary issue is resolved
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)
batch = next(iterator)
eng_batch, kn_batch = batch

print("Testing transformer forward pass with real data...")
try:
    # Create masks
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, kn_batch)
    
    # Test forward pass
    transformer.eval()
    with torch.no_grad():
        kn_predictions = transformer(eng_batch,
                                     kn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
    
    print(f"✓ Forward pass successful!")
    print(f"Output shape: {kn_predictions.shape}")
    print(f"First sentence: '{eng_batch[0][:50]}...'")
    
except Exception as e:
    print(f"✗ Error: {e}")
    import traceback
    traceback.print_exc()

Building vocabularies from actual data...
Total sentences: 6148
English vocabulary size: 157
Sanskrit vocabulary size: 114
✓ Complete vocabularies built from actual data
Testing transformer forward pass with real data...
✗ Error: index out of range in self
Total sentences: 6148
English vocabulary size: 157
Sanskrit vocabulary size: 114
✓ Complete vocabularies built from actual data
Testing transformer forward pass with real data...
✗ Error: index out of range in self


Traceback (most recent call last):
  File "/var/folders/4j/n0rl6b3s0657t_7z21qxxz6w0000gn/T/ipykernel_29778/2414974385.py", line 78, in <module>
    kn_predictions = transformer(eng_batch,
                                 kn_batch,
    ...<5 lines>...
                                 dec_start_token=True,
                                 dec_end_token=True)
  File "/Users/ankitpokhrel/Downloads/All projects/ML_Projects/sanskrit_to_english/myenv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/ankitpokhrel/Downloads/All projects/ML_Projects/sanskrit_to_english/myenv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/ankitpokhrel/Downloads/All projects/ML_Projects/sanskrit_to_english/transformer/transformer.py", line 301, in forward
    x = self.encoder(x, en

In [133]:
# Filter sentences using the complete vocabularies
valid_sentence_indicies = []
for index in range(len(all_english_sentences)):
    sanskrit_sentence, english_sentence = all_sanskrit_sentences[index], all_english_sentences[index]
    if (is_valid_length(sanskrit_sentence, max_sequence_length) and 
        is_valid_length(english_sentence, max_sequence_length) and 
        is_valid_tokens(sanskrit_sentence, sanskrit_vocabulary) and
        is_valid_tokens(english_sentence, english_vocabulary)):
        valid_sentence_indicies.append(index)

print(f"Valid sentences: {len(valid_sentence_indicies)} out of {len(all_english_sentences)}")

# Create the filtered datasets
english_sentences = [all_english_sentences[i] for i in valid_sentence_indicies]
sanskrit_sentences = [all_sanskrit_sentences[i] for i in valid_sentence_indicies]

# Recreate the dataset - NOTE: Now Sanskrit is input, English is output
dataset = TextDataset(sanskrit_sentences, english_sentences)

# Set device for GPU usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the transformer: Sanskrit input -> English output
transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          len(english_vocabulary),  # English is now the output vocabulary
                          sanskrit_to_index,        # Sanskrit is now the input (encoder)
                          english_to_index,         # English is now the output (decoder)
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

# Move transformer to GPU
transformer = transformer.to(device)

# Reinitialize everything
criterian = nn.CrossEntropyLoss(ignore_index=english_to_index[PADDING_TOKEN], reduction='none')

for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)

print(f"✓ Sanskrit-to-English transformer created!")
print(f"Device: {device}")
print(f"Dataset size: {len(dataset)}")
print(f"Sanskrit vocab (input): {len(sanskrit_vocabulary)}")
print(f"English vocab (output): {len(english_vocabulary)}")
print(f"Transformer params: {sum(p.numel() for p in transformer.parameters()):,}")

if torch.cuda.is_available():
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
    print(f"GPU Memory cached: {torch.cuda.memory_reserved(device) / 1024**3:.2f} GB")

Valid sentences: 4633 out of 6148

✓ Sanskrit-to-English transformer created!
Device: cpu
Dataset size: 4633
Sanskrit vocab (input): 114
English vocab (output): 157
Transformer params: 7,575,709
✓ Sanskrit-to-English transformer created!
Device: cpu
Dataset size: 4633
Sanskrit vocab (input): 114
English vocab (output): 157
Transformer params: 7,575,709


In [134]:
# Test the Sanskrit-to-English transformer
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)
batch = next(iterator)
skt_batch, eng_batch = batch  # Sanskrit input, English output

print("Testing Sanskrit-to-English transformer...")
try:
    # Test forward pass
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(skt_batch, eng_batch)
    
    transformer.eval()
    with torch.no_grad():
        eng_predictions = transformer(skt_batch,
                                      eng_batch,
                                      encoder_self_attention_mask.to(device), 
                                      decoder_self_attention_mask.to(device), 
                                      decoder_cross_attention_mask.to(device),
                                      enc_start_token=False,
                                      enc_end_token=False,
                                      dec_start_token=True,
                                      dec_end_token=True)
    
    print(f"✅ SUCCESS! Forward pass completed")
    print(f"Output shape: {eng_predictions.shape}")
    print(f"Sample Sanskrit input: '{skt_batch[0][:50]}...'")
    print(f"Sample English target: '{eng_batch[0][:50]}...'")
    print("\n🎉 Sanskrit-to-English transformer is working!")
    print("You can now proceed with training.")
    
except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()

Testing Sanskrit-to-English transformer...
✅ SUCCESS! Forward pass completed
Output shape: torch.Size([30, 200, 157])
Sample Sanskrit input: 'तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश ज...'
Sample English target: 'when şītā, having a husband although seeming as if...'

🎉 Sanskrit-to-English transformer is working!
You can now proceed with training.
✅ SUCCESS! Forward pass completed
Output shape: torch.Size([30, 200, 157])
Sample Sanskrit input: 'तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश ज...'
Sample English target: 'when şītā, having a husband although seeming as if...'

🎉 Sanskrit-to-English transformer is working!
You can now proceed with training.


In [135]:
# Training Sanskrit-to-English transformer on GPU
transformer.train()
# Ensure transformer is on the correct device
transformer = transformer.to(device)
total_loss = 0
num_epochs = 10

# Use the updated vocabulary size
eng_vocab_size = len(english_vocabulary)

print(f"Starting Sanskrit-to-English training...")
print(f"Device: {device}")
print(f"Dataset size: {len(dataset)}")
print(f"Sanskrit vocab size: {len(sanskrit_vocabulary)}")
print(f"English vocab size: {len(english_vocabulary)}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Initial GPU Memory: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch}")
    iterator = iter(train_loader)
    epoch_loss = 0
    num_batches = 0
    
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        skt_batch, eng_batch = batch  # Sanskrit input, English output
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(skt_batch, eng_batch)
        
        # Move masks to device
        encoder_self_attention_mask = encoder_self_attention_mask.to(device)
        decoder_self_attention_mask = decoder_self_attention_mask.to(device)
        decoder_cross_attention_mask = decoder_cross_attention_mask.to(device)
        
        optim.zero_grad()
        
        eng_predictions = transformer(skt_batch,
                                      eng_batch,
                                      encoder_self_attention_mask, 
                                      decoder_self_attention_mask, 
                                      decoder_cross_attention_mask,
                                      enc_start_token=False,
                                      enc_end_token=False,
                                      dec_start_token=True,
                                      dec_end_token=True)
        
        labels = transformer.decoder.sentence_embedding.batch_tokenize(eng_batch, start_token=False, end_token=True)
        loss = criterian(
            eng_predictions.view(-1, eng_vocab_size),
            labels.view(-1)
        )
        valid_indicies = torch.where(labels.view(-1) == english_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        
        loss.backward()
        optim.step()
        
        epoch_loss += loss.item()
        num_batches += 1
        
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item():.4f}")
            print(f"Sanskrit: {skt_batch[0][:80]}...")
            print(f"English Translation: {eng_batch[0][:80]}...")
            
            # Predicted translation
            eng_sentence_predicted = torch.argmax(eng_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in eng_sentence_predicted:
              if idx == english_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_english[idx.item()]
            print(f"English Prediction: {predicted_sentence[:80]}...")

            # Evaluation on a sample Sanskrit sentence
            transformer.eval()
            with torch.no_grad():
                eng_sentence = ("",)
                skt_sentence = ("नमस्ते",)  # "Hello" in Sanskrit
                for word_counter in range(max_sequence_length):
                    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(skt_sentence, eng_sentence)
                    encoder_self_attention_mask = encoder_self_attention_mask.to(device)
                    decoder_self_attention_mask = decoder_self_attention_mask.to(device)
                    decoder_cross_attention_mask = decoder_cross_attention_mask.to(device)
                    
                    predictions = transformer(skt_sentence,
                                              eng_sentence,
                                              encoder_self_attention_mask, 
                                              decoder_self_attention_mask, 
                                              decoder_cross_attention_mask,
                                              enc_start_token=False,
                                              enc_end_token=False,
                                              dec_start_token=True,
                                              dec_end_token=False)
                    next_token_prob_distribution = predictions[0][word_counter]
                    next_token_index = torch.argmax(next_token_prob_distribution).item()
                    next_token = index_to_english[next_token_index]
                    eng_sentence = (eng_sentence[0] + next_token, )
                    if next_token == END_TOKEN:
                      break
            
            print(f"Evaluation translation (नमस्ते): {eng_sentence[0]}")
            
            # GPU memory info
            if torch.cuda.is_available():
                print(f"GPU Memory: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
            print("-------------------------------------------")
    
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch} completed. Average loss: {avg_loss:.4f}")
    
    if torch.cuda.is_available():
        print(f"GPU Memory after epoch: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")

Starting Sanskrit-to-English training...
Device: cpu
Dataset size: 4633
Sanskrit vocab size: 114
English vocab size: 157

Epoch 0
Iteration 0 : 5.5575
Sanskrit: तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्वि...
English Translation: when şītā, having a husband although seeming as if she had none, was putting on ...
English Prediction: dş00åभNकककककककककFसPī॥॥3ससo003šकc00ş00şककdšकFककककbšīकककīoīcकcbककcīīīīīīīचभभभşdīचभ...
Iteration 0 : 5.5575
Sanskrit: तस्यां चीरं वसानायां नाथवत्यामनाथवत्। प्रचुक्रोश जनः सर्वो धिक् त्वां दशरथं त्वि...
English Translation: when şītā, having a husband although seeming as if she had none, was putting on ...
English Prediction: dş00åभNकककककककककFसPī॥॥3ससo003šकc00ş00şककdšकFककककbšīकककīoīcकcbककcīīīīīīīचभभभşdīचभ...
Evaluation translation (नमस्ते):          mmmmaaaaPPPoooooo  ooo  oooommmmmmmoooooooooooooooooooooooooooooooooo oooo    ooooo   PP     oooooooooooooooooooo  oo  oooooooooooooooooooooooooooooooooooooooooīī  ooooooooooo    o

KeyboardInterrupt: 

## Inference

In [None]:
transformer.eval()
def translate_sanskrit_to_english(skt_sentence):
    """Translate Sanskrit sentence to English using GPU if available"""
    transformer.eval()
    with torch.no_grad():
        skt_sentence = (skt_sentence,)
        eng_sentence = ("",)
        for word_counter in range(max_sequence_length):
            encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(skt_sentence, eng_sentence)
            
            # Move masks to device
            encoder_self_attention_mask = encoder_self_attention_mask.to(device)
            decoder_self_attention_mask = decoder_self_attention_mask.to(device)
            decoder_cross_attention_mask = decoder_cross_attention_mask.to(device)
            
            predictions = transformer(skt_sentence,
                                      eng_sentence,
                                      encoder_self_attention_mask, 
                                      decoder_self_attention_mask, 
                                      decoder_cross_attention_mask,
                                      enc_start_token=False,
                                      enc_end_token=False,
                                      dec_start_token=True,
                                      dec_end_token=False)
            next_token_prob_distribution = predictions[0][word_counter]
            next_token_index = torch.argmax(next_token_prob_distribution).item()
            next_token = index_to_english[next_token_index]
            eng_sentence = (eng_sentence[0] + next_token, )
            if next_token == END_TOKEN:
                break
        return eng_sentence[0]

In [None]:
translation = translate("what should we do when the day starts?")
print(translation)
#ದಿನ ಪ್ರಾರಂಭವಾದಾಗ ನಾವು ಏನು ಮಾಡಬೇಕು?

# Test Sanskrit-to-English translations
translation = translate_sanskrit_to_english("नमस्ते")
print(f"Sanskrit: नमस्ते")
print(f"English: {translation}")
print()

ಇದರ ಬಗ್ಗೆ ಏನು ಮಾಡಬೇಕು?<END>


In [None]:
translation = translate_sanskrit_to_english("सत्यं किम्?")
print(f"Sanskrit: सत्यं ಕಿಮ್?")
print(f"English: {translation}")
print()

ಹೇಗೆ ಇದು ಹೇಗೆ ಹೇಗೆ?<END>


In [None]:
translation = translate_sanskrit_to_english("जगत् एकः महान् स्थानः अस्ति")
print(f"Sanskrit: जगत् एकः महान् स्थानः अस्ति")
print(f"English: {translation}")
print()

ಇದರಿಂದ ಮೂಲಕ ಸಂಬಂಧಿಸಿದ ಮೇಲೆ ಮಾಡಿದ್ದಾರೆ<END>


In [None]:
translation = translate_sanskrit_to_english("मम नाम अजयः")
print(f"Sanskrit: मम नाम ಅಜಯ್")
print(f"English: {translation}")
print()

ನಾನು ಕುಟುಂಬದ ಹೆಸರು<END>


In [None]:
translation = translate_sanskrit_to_english("अहं एतं गन्धं न सहे")
print(f"Sanskrit: अहं एतं गन्धं न सहे")
print(f"English: {translation}")
print()

ನಾನು ಅಂತರ ಸಂಗತಿ ನಾನು ಕೊಡುವುದಿಲ್ಲ<END>


In [None]:
translation = translate_sanskrit_to_english("इदं सर्वोत्तमं वस्तु")
print(f"Sanskrit: इदं सर्वोत्तमं वस्तु")
print(f"English: {translation}")
print()

ಇದು ಅತ್ಯಂತ ಹೊರತಾಗಿದೆ<END>


In [None]:
translation = translate_sanskrit_to_english("अहं अत्र अस्मि")
print(f"Sanskrit: अहं अत्र अस्मि")
print(f"English: {translation}")
print()

ನಾನು ಕೇಳಿದ್ದೇನೆ.<END>


## Insights for Sanskrit-to-English Translation

- **Character-level tokenization**: While we're using character-level tokenization, word-based or BPE-based tokenization might yield better results for Sanskrit, especially considering its complex morphology.
- **Dataset quality**: Ensure the training set has diverse Sanskrit texts (classical literature, modern usage, different domains) paired with accurate English translations.
- **Model capacity**: Consider increasing the number of encoder/decoder layers for better translation quality. Currently set to 1 layer each for faster training.
- **Sanskrit script complexity**: Devanagari script has conjuncts and complex character combinations that might benefit from specialized preprocessing.
- **Evaluation metrics**: Consider implementing BLEU score or other translation quality metrics for better evaluation during training.
- **Fine-tuning**: Start with a smaller learning rate and consider learning rate scheduling for better convergence.

## Conclusion

This transformer model is now configured for **Sanskrit-to-English translation**. The model architecture includes:

- **Input**: Sanskrit text (Devanagari script)
- **Output**: English text
- **Vocabulary**: Complete character-level vocabularies built from the actual dataset
- **Model**: Single-layer encoder-decoder transformer (can be increased for better performance)

The model should learn to translate Sanskrit sentences to English. For production use, consider:
1. Increasing model layers and capacity
2. Using larger datasets
3. Implementing proper evaluation metrics
4. Adding beam search for better decoding
5. Experimenting with subword tokenization (BPE/SentencePiece)