In [1]:
from transformer import Transformer
from transformer import create_padding_mask
from transformer import create_causal_mask
from transformer import combine_masks
import torch
import torch.nn as nn

In [2]:
# Prefer the GPU when one is present, but fall back to the CPU so that the
# later `.to(device)` calls and checkpoint loading also work on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Special tokens shared by the English and Tamil vocabularies.
START_TOKEN = '<SOS>'      # prepended to every encoded sequence
PADDING_TOKEN = '<PAD>'    # used to pad batches to a common length
END_TOKEN = '<EOS>'        # appended to every encoded sequence; stops decoding
# NOTE(review): UNKNOWN_TOKEN is defined but is not an entry of either
# vocabulary list in the cells visible here, so there is no index for it.
UNKNOWN_TOKEN = '<UNK>'

In [4]:
# Tamil (target) vocabulary. The position in this list is the token index:
# PADDING_TOKEN is index 0, START_TOKEN is index 1 and END_TOKEN is the last entry.
#
# NOTE(review): the list is deliberately left untouched. `len(ta_vocab)` is passed
# to the model as `tgt_vocab_size` and the positions define the output indices, so
# editing entries would presumably make the list incompatible with an already
# trained checkpoint -- confirm before fixing any of the following:
#   * 'க்ஷ' appears twice in the 'க்ஷ' row (the 4th entry of that row is where the
#     other rows have their 'ி' form), so the reverse map {token: index} keeps only
#     the later position for it.
#   * the 'ன' row stops at 'னை' (10 entries) while every other consonant row has 13.
#   * sentences are tokenized with `list(sentence)`, i.e. one code point at a time;
#     the consonant+vowel-sign entries look like multi-code-point strings and would
#     then never match a token (only the bare letters and the standalone signs in
#     the last row would) -- TODO confirm.
ta_vocab = [PADDING_TOKEN, START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ', 
            'ஃ', 'அ', 'ஆ', 'இ', 'ஈ', 'உ', 'ஊ', 'எ', 'ஏ', 'ஐ', 'ஒ', 'ஓ', 'ஔ', 'க்', 'க', 'கா', 'கி', 'கீ', 'கு', 'கூ', 'கெ', 
            'கே', 'கை', 'கொ', 'கோ', 'கௌ', 'ங்', 'ங', 'ஙா', 'ஙி', 'ஙீ', 'ஙு', 'ஙூ', 'ஙெ', 'ஙே', 'ஙை', 'ஙொ', 'ஙோ', 'ஙௌ', 'ச்', 
            'ச', 'சா', 'சி', 'சீ', 'சு', 'சூ', 'செ', 'சே', 'சை', 'சொ', 'சோ', 'சௌ',
            'ஞ்', 'ஞ', 'ஞா', 'ஞி', 'ஞீ', 'ஞு', 'ஞூ', 'ஞெ', 'ஞே', 'ஞை', 'ஞொ', 'ஞோ', 'ஞௌ',
            'ட்', 'ட', 'டா', 'டி', 'டீ', 'டு', 'டூ', 'டெ', 'டே', 'டை', 'டொ', 'டோ', 'டௌ',
            'ண்', 'ண', 'ணா', 'ணி', 'ணீ', 'ணு', 'ணூ', 'ணெ', 'ணே', 'ணை', 'ணொ', 'ணோ', 'ணௌ',
            'த்', 'த', 'தா', 'தி', 'தீ', 'து', 'தூ', 'தெ', 'தே', 'தை', 'தொ', 'தோ', 'தௌ',
            'ந்', 'ந', 'நா', 'நி', 'நீ', 'நு', 'நூ', 'நெ', 'நே', 'நை', 'நொ', 'நோ', 'நௌ',
            'ப்', 'ப', 'பா', 'பி', 'பீ', 'பு', 'பூ', 'பெ', 'பே', 'பை', 'பொ', 'போ', 'பௌ',
            'ம்', 'ம', 'மா', 'மி', 'மீ', 'மு', 'மூ', 'மெ', 'மே', 'மை', 'மொ', 'மோ', 'மௌ',
            'ய்', 'ய', 'யா', 'யி', 'யீ', 'யு', 'யூ', 'யெ', 'யே', 'யை', 'யொ', 'யோ', 'யௌ',
            'ர்', 'ர', 'ரா', 'ரி', 'ரீ', 'ரு', 'ரூ', 'ரெ', 'ரே', 'ரை', 'ரொ', 'ரோ', 'ரௌ',
            'ல்', 'ல', 'லா', 'லி', 'லீ', 'லு', 'லூ', 'லெ', 'லே', 'லை', 'லொ', 'லோ', 'லௌ',
            'வ்', 'வ', 'வா', 'வி', 'வீ', 'வு', 'வூ', 'வெ', 'வே', 'வை', 'வொ', 'வோ', 'வௌ',
            'ழ்', 'ழ', 'ழா', 'ழி', 'ழீ', 'ழு', 'ழூ', 'ழெ', 'ழே', 'ழை', 'ழொ', 'ழோ', 'ழௌ',
            'ள்', 'ள', 'ளா', 'ளி', 'ளீ', 'ளு', 'ளூ', 'ளெ', 'ளே', 'ளை', 'ளொ', 'ளோ', 'ளௌ',
            'ற்', 'ற', 'றா', 'றி', 'றீ', 'று', 'றூ', 'றெ', 'றே', 'றை', 'றொ', 'றோ', 'றௌ',
            'ன்', 'ன', 'னா', 'னி', 'னீ', 'னு', 'னூ', 'னெ', 'னே', 'னை',
            'ஶ்', 'ஶ', 'ஶா', 'ஶி', 'ஶீ', 'ஶு', 'ஶூ', 'ஶெ', 'ஶே', 'ஶை', 'ஶொ', 'ஶோ', 'ஶௌ',
            'ஜ்', 'ஜ', 'ஜா', 'ஜி', 'ஜீ', 'ஜு', 'ஜூ', 'ஜெ', 'ஜே', 'ஜை', 'ஜொ', 'ஜோ', 'ஜௌ',
            'ஷ்', 'ஷ', 'ஷா', 'ஷி', 'ஷீ', 'ஷு', 'ஷூ', 'ஷெ', 'ஷே', 'ஷை', 'ஷொ', 'ஷோ', 'ஷௌ',
            'ஸ்', 'ஸ', 'ஸா', 'ஸி', 'ஸீ', 'ஸு', 'ஸூ', 'ஸெ', 'ஸே', 'ஸை', 'ஸொ', 'ஸோ', 'ஸௌ',
            'ஹ்', 'ஹ', 'ஹா', 'ஹி', 'ஹீ', 'ஹு', 'ஹூ', 'ஹெ', 'ஹே', 'ஹை', 'ஹொ', 'ஹோ', 'ஹௌ',
            'க்ஷ்', 'க்ஷ', 'க்ஷா', 'க்ஷ', 'க்ஷீ', 'க்ஷு', 'க்ஷூ', 'க்ஷெ', 'க்ஷே', 'க்ஷை', 'க்ஷொ', 'க்ஷோ', 'க்ஷௌ', 
            '்', 'ா', 'ி', 'ீ', 'ு', 'ூ', 'ெ', 'ே', 'ை', 'ொ', 'ோ', 'ௌ',END_TOKEN]

In [5]:
# English (source) vocabulary, one entry per character. The position in the list
# is the token index: PADDING_TOKEN is 0, START_TOKEN is 1, END_TOKEN is last.
# Only lowercase letters are listed (input sentences are lowercased before
# encoding) and ';' is not part of the vocabulary.
en_vocab = (
    [PADDING_TOKEN, START_TOKEN]
    + list(' !"#$%&\'()*+,-./')
    + list('0123456789')
    + list(':<=>?@')
    + list('[\\]^_`')
    + list('abcdefghijklmnopqrstuvwxyz')
    + list('{|}~')
    + [END_TOKEN]
)

In [6]:
# Lookup tables between tokens and their positions in the vocabulary lists.
# For a token that occurs more than once in a vocabulary, the token -> index
# map keeps the position of its last occurrence.
index_to_tamil = dict(enumerate(ta_vocab))
tamil_to_index = {token: position for position, token in enumerate(ta_vocab)}
index_to_english = dict(enumerate(en_vocab))
english_to_index = {token: position for position, token in enumerate(en_vocab)}

In [7]:
# Load the parallel corpus: line i of English.txt is paired with line i of Tamil.txt.
# The encoding is given explicitly because the platform default (e.g. cp1252 on
# Windows) cannot decode Tamil text; UTF-8 is assumed for both files -- TODO confirm.
with open('en-ta/English.txt', 'r', encoding='utf-8') as file:
    en_sentences = file.readlines()
with open('en-ta/Tamil.txt', 'r', encoding='utf-8') as file:
    ta_sentences = file.readlines()

# Only the first TOTAL_SENTENCES pairs are used.
TOTAL_SENTENCES = 200000
en_sentences = en_sentences[:TOTAL_SENTENCES]
ta_sentences = ta_sentences[:TOTAL_SENTENCES]
# Strip the trailing newline; English is lowercased because en_vocab only
# contains lowercase letters.
en_sentences = [sentence.rstrip('\n').lower() for sentence in en_sentences]
ta_sentences = [sentence.rstrip('\n') for sentence in ta_sentences]

In [8]:
def is_valid_token(sentence, vocab):
    """Return True when every element of `sentence` is contained in `vocab`.

    An empty sentence is considered valid.
    """
    for token in sentence:
        if token not in vocab:
            return False
    return True

def find_invalid_tokens(sentence, vocab):
    """Return the distinct elements of `sentence` that are missing from `vocab`.

    Each offending token is reported once; the order of the returned list
    follows set iteration order and is therefore not meaningful.
    """
    unknown = []
    for token in set(sentence):
        if token in vocab:
            continue
        unknown.append(token)
    return unknown

def is_valid_length(sentence, max_sequence_length):
    """Return True when `sentence` has at most `max_sequence_length` elements."""
    too_long = len(sentence) > max_sequence_length
    return not too_long

# Maximum sentence length in characters (special tokens not counted).
MAX_SEQUENCE_LENGTH = 250

# Sets make each membership test O(1) instead of a linear scan of the vocab list.
_ta_vocab_set = set(ta_vocab)
_en_vocab_set = set(en_vocab)

invalid_tokens_list = []        # (invalid Tamil tokens, invalid English tokens) per rejected pair
valid_sentence_indices = []     # pairs kept for the dataset
invalid_sentence_indices = []   # pairs rejected because of out-of-vocabulary characters

for index, (ta_sentence, en_sentence) in enumerate(zip(ta_sentences, en_sentences)):
    # Over-long pairs are dropped without being recorded as invalid.
    if not (is_valid_length(ta_sentence, MAX_SEQUENCE_LENGTH)
            and is_valid_length(en_sentence, MAX_SEQUENCE_LENGTH)):
        continue

    # A single scan per sentence: a pair is valid exactly when neither side
    # contains an out-of-vocabulary character.
    invalid_ta_tokens = find_invalid_tokens(ta_sentence, _ta_vocab_set)
    invalid_en_tokens = find_invalid_tokens(en_sentence, _en_vocab_set)

    if invalid_ta_tokens or invalid_en_tokens:
        invalid_tokens_list.append((invalid_ta_tokens, invalid_en_tokens))
        invalid_sentence_indices.append(index)
    else:
        valid_sentence_indices.append(index)

print(f"Number of sentences: {len(ta_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indices)}")

# Keep only the valid pairs (both lists stay index-aligned).
ta_sentences = [ta_sentences[i] for i in valid_sentence_indices]
en_sentences = [en_sentences[i] for i in valid_sentence_indices]



Number of sentences: 200000
Number of valid sentences: 172749


In [9]:
def tokenize_sentence(sentence):
    """Split `sentence` into a list with one element per character."""
    return [character for character in sentence]

def tokens_to_indices(tokens, vocab_to_index):
    """Map each token to its index using `vocab_to_index`.

    Raises:
        KeyError: if a token is not a key of `vocab_to_index`.
    """
    indices = []
    for token in tokens:
        indices.append(vocab_to_index[token])
    return indices

def add_special_tokens(indices, sos_token_index, eos_token_index):
    """Return a new list: `sos_token_index`, then `indices`, then `eos_token_index`.

    The input list is not modified.
    """
    wrapped = [sos_token_index] + indices
    wrapped.append(eos_token_index)
    return wrapped

from torch.nn.utils.rnn import pad_sequence

def pad_sequences(batch, padding_value):
    """Pad a batch of variable-length tensors to a common length.

    Thin wrapper around `torch.nn.utils.rnn.pad_sequence` with
    `batch_first=True`, so the batch dimension comes first in the result.
    """
    padded = pad_sequence(batch, padding_value=padding_value, batch_first=True)
    return padded



In [10]:
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    """Parallel-sentence dataset yielding index tensors for source and target.

    Each item is a `(src_tensor, tgt_tensor)` pair of 1-D `torch.long` tensors.
    A sentence is tokenized with `tokenize_sentence`, mapped to indices with the
    matching vocabulary, and wrapped in that vocabulary's '<SOS>' / '<EOS>' indices.
    No padding is applied here.

    Args:
        source_sentences: sequence of source-language sentences.
        target_sentences: sequence of target-language sentences, aligned by position.
        source_vocab_to_index: token -> index mapping for the source language;
            must contain '<SOS>', '<EOS>' and '<PAD>'.
        target_vocab_to_index: token -> index mapping for the target language;
            must contain '<SOS>', '<EOS>' and '<PAD>'.
        max_length: stored on the instance; not used by the methods of this class.
    """

    def __init__(self, source_sentences, target_sentences, 
                 source_vocab_to_index, target_vocab_to_index,
                 max_length=250):
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.source_vocab_to_index = source_vocab_to_index
        self.target_vocab_to_index = target_vocab_to_index
        self.max_length = max_length

        # Cache the special-token indices of both vocabularies.
        self.source_sos, self.source_eos, self.source_pad = (
            source_vocab_to_index['<SOS>'],
            source_vocab_to_index['<EOS>'],
            source_vocab_to_index['<PAD>'],
        )
        self.target_sos, self.target_eos, self.target_pad = (
            target_vocab_to_index['<SOS>'],
            target_vocab_to_index['<EOS>'],
            target_vocab_to_index['<PAD>'],
        )

    def __len__(self):
        """Number of sentence pairs (taken from the source side)."""
        return len(self.source_sentences)

    @staticmethod
    def _as_tensor(token_indices, sos_index, eos_index):
        """Wrap `token_indices` in SOS/EOS and return them as a long tensor."""
        return torch.tensor(
            add_special_tokens(token_indices, sos_index, eos_index),
            dtype=torch.long,
        )

    def __getitem__(self, idx):
        """Return the encoded `(source, target)` tensors for pair `idx`."""
        source_chars = tokenize_sentence(self.source_sentences[idx])
        target_chars = tokenize_sentence(self.target_sentences[idx])

        source_ids = tokens_to_indices(source_chars, self.source_vocab_to_index)
        target_ids = tokens_to_indices(target_chars, self.target_vocab_to_index)

        return (
            self._as_tensor(source_ids, self.source_sos, self.source_eos),
            self._as_tensor(target_ids, self.target_sos, self.target_eos),
        )


In [11]:
def collate_fn(batch):
    """Collate `(src_tensor, tgt_tensor)` pairs into two padded batch tensors.

    Source sequences are padded with the English '<PAD>' index and target
    sequences with the Tamil '<PAD>' index (both read from the module-level
    lookup tables). The batch dimension comes first in both results.
    """
    sources, targets = zip(*batch)
    padded_sources = pad_sequence(
        sources, batch_first=True, padding_value=english_to_index['<PAD>']
    )
    padded_targets = pad_sequence(
        targets, batch_first=True, padding_value=tamil_to_index['<PAD>']
    )
    return padded_sources, padded_targets

In [12]:
from torch.utils.data import DataLoader

# One sentence pair per batch, drawn in random order.
batch_size = 1

# English is the source language and Tamil the target.
dataset = TranslationDataset(
    en_sentences,
    ta_sentences,
    english_to_index,
    tamil_to_index,
)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=True,
)

In [14]:
# Build the network with the hyper-parameters the checkpoint is loaded into.
# Vocabulary sizes come straight from the vocab lists.
# max_len=252 presumably covers the 250-character sentence cap plus <SOS>/<EOS>.
model = Transformer(
    src_vocab_size=len(en_vocab),
    tgt_vocab_size=len(ta_vocab),
    d_model=512,
    heads=8,
    num_layers=6,
    dff=2048,
    dropout=0.1,
    max_len=252,
)
model = model.to(device)


In [15]:
# Checkpoint holding the trained weights (state dict) for `model`.
# NOTE(review): the '.ptrom' extension is unusual; the recorded cell output shows
# a load from exactly this path succeeding, so the name is kept as is.
MODEL_PATH = 'best_model_epoch_24.ptrom'

In [16]:
# Restore the trained weights. `weights_only=True` restricts unpickling to tensors
# and plain containers, so a tampered checkpoint cannot execute arbitrary code; it
# also silences the FutureWarning newer PyTorch versions emit for a bare torch.load.
# `map_location` moves the tensors onto the selected device while loading.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
print(f'Model loaded from {MODEL_PATH}')

Model loaded from best_model_epoch_24.ptrom


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


In [17]:
def translate(sentence, model, 
              english_to_index, index_to_tamil, 
              max_length=250):
    """Translate an English sentence to Tamil with greedy decoding.

    The sentence is lowercased, split into characters, mapped to indices and
    wrapped in <SOS>/<EOS>. Starting from the Tamil <SOS> index, the
    highest-scoring next token is appended at every step until <EOS> is
    produced or `max_length` tokens have been generated.

    Args:
        sentence (str): English input; every character must be a key of
            `english_to_index` after lowercasing (otherwise KeyError).
        model: the Transformer, called as
            `model(src, tgt, src_padding_mask, tgt_padding_mask, combined_mask)`.
            It is switched to eval mode.
        english_to_index (dict): English token -> index.
        index_to_tamil (dict): Tamil index -> token.
        max_length (int): maximum number of decoding steps.

    Returns:
        str: the generated tokens joined together, without <SOS> and <EOS>.

    Note:
        Reads the module-level `tamil_to_index` and `device`.
    """
    model.eval()

    # Encode the source sentence once.
    source_ids = add_special_tokens(
        tokens_to_indices(tokenize_sentence(sentence.lower()), english_to_index),
        english_to_index[START_TOKEN],
        english_to_index[END_TOKEN],
    )
    src_tensor = torch.tensor(source_ids, dtype=torch.long).unsqueeze(0).to(device)
    src_padding_mask = create_padding_mask(src_tensor, pad_token=english_to_index[PADDING_TOKEN]).to(device)

    generated = [tamil_to_index[START_TOKEN]]

    for _step in range(max_length):
        # Re-run the decoder on everything generated so far.
        tgt_tensor = torch.tensor(generated, dtype=torch.long).unsqueeze(0).to(device)
        tgt_padding_mask = create_padding_mask(tgt_tensor, pad_token=tamil_to_index[PADDING_TOKEN]).to(device)
        causal_mask = create_causal_mask(tgt_tensor.size(1)).to(device)
        combined_mask = combine_masks(tgt_padding_mask, causal_mask)

        with torch.no_grad():
            output = model(src_tensor, tgt_tensor,
                           src_padding_mask,
                           tgt_padding_mask,
                           combined_mask)

        # Greedy choice: index of the largest score at the last position.
        predicted = output[0, -1, :].max(dim=-1).indices.item()
        generated.append(predicted)

        if predicted == tamil_to_index[END_TOKEN]:
            break

    # Drop the leading <SOS> and any <EOS> before joining.
    eos_index = tamil_to_index[END_TOKEN]
    pieces = []
    for token_index in generated[1:]:
        if token_index != eos_index:
            pieces.append(index_to_tamil[token_index])
    return ''.join(pieces)

In [18]:
import heapq

def translate_beam_search(sentence, model, 
                          english_to_index, index_to_tamil, 
                          max_length=250, beam_width=3):
    """
    Translates an English sentence to Tamil using Beam Search with the trained Transformer model.

    At every step each unfinished hypothesis is extended with its `beam_width`
    most likely next tokens, and the `beam_width` best extensions (by summed
    log-probability) are kept. Hypotheses ending in <EOS> are collected as
    completed; the completed hypothesis with the highest score is returned.
    Scores are not length-normalized.

    Args:
        sentence (str): The English sentence to translate. Every character must
            be a key of `english_to_index` after lowercasing (otherwise KeyError).
        model (Transformer): The trained Transformer model (switched to eval mode).
        english_to_index (dict): Mapping from English tokens to indices.
        index_to_tamil (dict): Mapping from Tamil indices to tokens.
        max_length (int): Maximum number of decoding steps.
        beam_width (int): The number of beams to keep during decoding.

    Returns:
        str: The translated Tamil sentence, without <SOS> and <EOS>.

    Note:
        Reads the module-level `tamil_to_index` and `device`.
    """
    model.eval()

    # Loop-invariant special-token indices of the target vocabulary.
    sos_index = tamil_to_index[START_TOKEN]
    eos_index = tamil_to_index[END_TOKEN]
    pad_index = tamil_to_index[PADDING_TOKEN]

    # Preprocess the input sentence
    tokens = tokenize_sentence(sentence.lower())
    indices = tokens_to_indices(tokens, english_to_index)
    indices = add_special_tokens(indices,
                                 english_to_index[START_TOKEN],
                                 english_to_index[END_TOKEN])
    src_tensor = torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)

    # Create source padding mask
    src_padding_mask = create_padding_mask(src_tensor, pad_token=english_to_index[PADDING_TOKEN]).to(device)

    # Each beam is a tuple: (token index sequence, cumulative log-probability)
    beams = [([sos_index], 0.0)]
    completed_beams = []

    for _ in range(max_length):
        new_beams = []
        for seq, score in beams:
            # A beam that ended with <EOS> is finished and is not expanded further.
            if seq[-1] == eos_index:
                completed_beams.append((seq, score))
                continue

            tgt_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(device)
            tgt_padding_mask = create_padding_mask(tgt_tensor, pad_token=pad_index).to(device)
            causal_mask = create_causal_mask(tgt_tensor.size(1)).to(device)
            combined_mask = combine_masks(tgt_padding_mask, causal_mask)

            with torch.no_grad():
                output = model(src_tensor, tgt_tensor,
                               src_padding_mask,
                               tgt_padding_mask,
                               combined_mask)

            # Log-probabilities of the next token (scores at the last position).
            next_token_logits = output[0, -1, :]
            log_probs = nn.functional.log_softmax(next_token_logits, dim=-1)

            # torch.topk raises when k exceeds the number of entries, so never
            # ask for more candidates than the vocabulary holds.
            num_candidates = min(beam_width, log_probs.size(-1))
            topk_log_probs, topk_indices = torch.topk(log_probs, num_candidates)

            for i in range(num_candidates):
                next_token = topk_indices[i].item()
                next_log_prob = topk_log_probs[i].item()
                new_beams.append((seq + [next_token], score + next_log_prob))

        # Every beam was already finished: nothing left to expand.
        if not new_beams:
            break

        # Keep the top `beam_width` beams based on cumulative score
        beams = heapq.nlargest(beam_width, new_beams, key=lambda x: x[1])

        # Stop early once enough hypotheses have finished
        if len(completed_beams) >= beam_width:
            break

    # Finished beams are normally collected at the start of the *next* step.
    # When the loop stops (early exit or max_length reached) there is no next
    # step, so hypotheses that produced <EOS> on the final step are collected
    # here; otherwise they would be ignored when picking the best translation.
    for beam in beams:
        if beam[0][-1] == eos_index and beam not in completed_beams:
            completed_beams.append(beam)

    # If nothing finished, fall back to the unfinished beams
    if not completed_beams:
        completed_beams = beams

    # Select the beam with the highest score
    best_beam = max(completed_beams, key=lambda x: x[1])
    tgt_indices = best_beam[0]

    # Convert indices to tokens, excluding <SOS> and <EOS>
    translated_tokens = [index_to_tamil[idx] for idx in tgt_indices[1:] if idx != eos_index]
    translated_sentence = ''.join(translated_tokens)

    return translated_sentence


In [19]:
# Single-sentence check of the greedy decoder.
english_sentence = "farmers in this region largely grow paddy and wheat."
tamil_translation = translate(
    english_sentence,
    model,
    english_to_index=english_to_index,
    index_to_tamil=index_to_tamil,
)
print(f"English: {english_sentence}")
print(f"Tamil: {tamil_translation}")

English: farmers in this region largely grow paddy and wheat.
Tamil: இதனால் பல பகுதிகளில் போக்குவரத்து பெரிதும் பாதிக்கப்பட்டுள்ளது.


In [26]:
# Sentences used to compare greedy and beam-search decoding.
# Mixed case is fine here: both translate functions lowercase their input.
test_sentences = ["How are you?"]

In [27]:
# Decode every test sentence with both strategies and print them side by side.
for english_sentence in test_sentences:
    tamil_translation_greedy = translate(
        english_sentence,
        model,
        english_to_index=english_to_index,
        index_to_tamil=index_to_tamil,
    )
    tamil_translation_beam = translate_beam_search(
        english_sentence,
        model,
        english_to_index=english_to_index,
        index_to_tamil=index_to_tamil,
        beam_width=3,
    )
    print(f"English: {english_sentence}")
    print(f"Tamil (Greedy): {tamil_translation_greedy}")
    print(f"Tamil (Beam Search): {tamil_translation_beam}")
    print("-" * 50)

English: How are you?
Tamil (Greedy): யார் இருக்கா?
Tamil (Beam Search): எப்படி இருக்கிறாய்?
--------------------------------------------------
