In [1]:
from transformer import Transformer
from transformer import create_padding_mask
from transformer import create_causal_mask
from transformer import combine_masks
import torch
import torch.nn as nn

In [2]:
device = torch.device('cpu')

In [3]:
START_TOKEN = '<SOS>'
PADDING_TOKEN = '<PAD>'
END_TOKEN = '<EOS>'
UNKNOWN_TOKEN = '<UNK>'

In [4]:
ta_vocab = [PADDING_TOKEN, START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ', 
            'ஃ', 'அ', 'ஆ', 'இ', 'ஈ', 'உ', 'ஊ', 'எ', 'ஏ', 'ஐ', 'ஒ', 'ஓ', 'ஔ', 'க்', 'க', 'கா', 'கி', 'கீ', 'கு', 'கூ', 'கெ', 
            'கே', 'கை', 'கொ', 'கோ', 'கௌ', 'ங்', 'ங', 'ஙா', 'ஙி', 'ஙீ', 'ஙு', 'ஙூ', 'ஙெ', 'ஙே', 'ஙை', 'ஙொ', 'ஙோ', 'ஙௌ', 'ச்', 
            'ச', 'சா', 'சி', 'சீ', 'சு', 'சூ', 'செ', 'சே', 'சை', 'சொ', 'சோ', 'சௌ',
            'ஞ்', 'ஞ', 'ஞா', 'ஞி', 'ஞீ', 'ஞு', 'ஞூ', 'ஞெ', 'ஞே', 'ஞை', 'ஞொ', 'ஞோ', 'ஞௌ',
            'ட்', 'ட', 'டா', 'டி', 'டீ', 'டு', 'டூ', 'டெ', 'டே', 'டை', 'டொ', 'டோ', 'டௌ',
            'ண்', 'ண', 'ணா', 'ணி', 'ணீ', 'ணு', 'ணூ', 'ணெ', 'ணே', 'ணை', 'ணொ', 'ணோ', 'ணௌ',
            'த்', 'த', 'தா', 'தி', 'தீ', 'து', 'தூ', 'தெ', 'தே', 'தை', 'தொ', 'தோ', 'தௌ',
            'ந்', 'ந', 'நா', 'நி', 'நீ', 'நு', 'நூ', 'நெ', 'நே', 'நை', 'நொ', 'நோ', 'நௌ',
            'ப்', 'ப', 'பா', 'பி', 'பீ', 'பு', 'பூ', 'பெ', 'பே', 'பை', 'பொ', 'போ', 'பௌ',
            'ம்', 'ம', 'மா', 'மி', 'மீ', 'மு', 'மூ', 'மெ', 'மே', 'மை', 'மொ', 'மோ', 'மௌ',
            'ய்', 'ய', 'யா', 'யி', 'யீ', 'யு', 'யூ', 'யெ', 'யே', 'யை', 'யொ', 'யோ', 'யௌ',
            'ர்', 'ர', 'ரா', 'ரி', 'ரீ', 'ரு', 'ரூ', 'ரெ', 'ரே', 'ரை', 'ரொ', 'ரோ', 'ரௌ',
            'ல்', 'ல', 'லா', 'லி', 'லீ', 'லு', 'லூ', 'லெ', 'லே', 'லை', 'லொ', 'லோ', 'லௌ',
            'வ்', 'வ', 'வா', 'வி', 'வீ', 'வு', 'வூ', 'வெ', 'வே', 'வை', 'வொ', 'வோ', 'வௌ',
            'ழ்', 'ழ', 'ழா', 'ழி', 'ழீ', 'ழு', 'ழூ', 'ழெ', 'ழே', 'ழை', 'ழொ', 'ழோ', 'ழௌ',
            'ள்', 'ள', 'ளா', 'ளி', 'ளீ', 'ளு', 'ளூ', 'ளெ', 'ளே', 'ளை', 'ளொ', 'ளோ', 'ளௌ',
            'ற்', 'ற', 'றா', 'றி', 'றீ', 'று', 'றூ', 'றெ', 'றே', 'றை', 'றொ', 'றோ', 'றௌ',
            'ன்', 'ன', 'னா', 'னி', 'னீ', 'னு', 'னூ', 'னெ', 'னே', 'னை',
            'ஶ்', 'ஶ', 'ஶா', 'ஶி', 'ஶீ', 'ஶு', 'ஶூ', 'ஶெ', 'ஶே', 'ஶை', 'ஶொ', 'ஶோ', 'ஶௌ',
            'ஜ்', 'ஜ', 'ஜா', 'ஜி', 'ஜீ', 'ஜு', 'ஜூ', 'ஜெ', 'ஜே', 'ஜை', 'ஜொ', 'ஜோ', 'ஜௌ',
            'ஷ்', 'ஷ', 'ஷா', 'ஷி', 'ஷீ', 'ஷு', 'ஷூ', 'ஷெ', 'ஷே', 'ஷை', 'ஷொ', 'ஷோ', 'ஷௌ',
            'ஸ்', 'ஸ', 'ஸா', 'ஸி', 'ஸீ', 'ஸு', 'ஸூ', 'ஸெ', 'ஸே', 'ஸை', 'ஸொ', 'ஸோ', 'ஸௌ',
            'ஹ்', 'ஹ', 'ஹா', 'ஹி', 'ஹீ', 'ஹு', 'ஹூ', 'ஹெ', 'ஹே', 'ஹை', 'ஹொ', 'ஹோ', 'ஹௌ',
            'க்ஷ்', 'க்ஷ', 'க்ஷா', 'க்ஷ', 'க்ஷீ', 'க்ஷு', 'க்ஷூ', 'க்ஷெ', 'க்ஷே', 'க்ஷை', 'க்ஷொ', 'க்ஷோ', 'க்ஷௌ', 
            '்', 'ா', 'ி', 'ீ', 'ு', 'ூ', 'ெ', 'ே', 'ை', 'ொ', 'ோ', 'ௌ',END_TOKEN]

In [5]:
en_vocab = [PADDING_TOKEN, START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                        ':', '<', '=', '>', '?', '@',
                        '[', '\\', ']', '^', '_', '`', 
                        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                        'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                        'y', 'z', '{', '|', '}', '~', END_TOKEN]

In [6]:
index_to_tamil = {k:v for k,v in enumerate(ta_vocab)}
tamil_to_index = {v:k for k,v in enumerate(ta_vocab)}
index_to_english = {k:v for k,v in enumerate(en_vocab)}
english_to_index = {v:k for k,v in enumerate(en_vocab)}

In [7]:
with open('en-ta/English.txt', 'r') as file:
    en_sentences = file.readlines()
with open('en-ta/Tamil.txt', 'r') as file:
    ta_sentences = file.readlines()

TOTAL_SENTENCES = 200000
en_sentences = en_sentences[:TOTAL_SENTENCES]
ta_sentences = ta_sentences[:TOTAL_SENTENCES]
en_sentences = [sentence.rstrip('\n').lower() for sentence in en_sentences]
ta_sentences = [sentence.rstrip('\n') for sentence in ta_sentences]

In [8]:
def is_valid_token(sentence, vocab):
    return all(token in vocab for token in sentence)

def find_invalid_tokens(sentence, vocab):
    return [token for token in set(sentence) if token not in vocab]

def is_valid_length(sentence, max_sequence_length):
    return len(sentence) <= max_sequence_length

invalid_tokens_list = []
valid_sentence_indices = []
invalid_sentence_indices = []

for index, (ta_sentence, en_sentence) in enumerate(zip(ta_sentences, en_sentences)):
    invalid_ta_tokens = find_invalid_tokens(ta_sentence, ta_vocab)
    invalid_en_tokens = find_invalid_tokens(en_sentence, en_vocab)

    if is_valid_length(ta_sentence, 250) and is_valid_length(en_sentence, 250):
        if is_valid_token(ta_sentence, ta_vocab) and is_valid_token(en_sentence, en_vocab):
            valid_sentence_indices.append(index)
        else:
            invalid_tokens_list.append((invalid_ta_tokens, invalid_en_tokens))
            invalid_sentence_indices.append(index)
            
print(f"Number of sentences: {len(ta_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indices)}")
         
ta_sentences = [ta_sentences[i] for i in valid_sentence_indices]
en_sentences = [en_sentences[i] for i in valid_sentence_indices]



Number of sentences: 200000
Number of valid sentences: 172749


In [9]:
def tokenize_sentence(sentence):
    return list(sentence)

def tokens_to_indices(tokens, vocab_to_index):
    return [vocab_to_index[token] for token in tokens]

def add_special_tokens(indices, sos_token_index, eos_token_index):
    return [sos_token_index] + indices + [eos_token_index]

from torch.nn.utils.rnn import pad_sequence

def pad_sequences(batch, padding_value):
    return pad_sequence(batch, batch_first=True, padding_value=padding_value)

In [10]:
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    def __init__(self, source_sentences, target_sentences, 
                 source_vocab_to_index, target_vocab_to_index,
                 max_length=250):
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.source_vocab_to_index = source_vocab_to_index
        self.target_vocab_to_index = target_vocab_to_index
        self.max_length = max_length
        
        self.source_sos = source_vocab_to_index['<SOS>']
        self.source_eos = source_vocab_to_index['<EOS>']
        self.source_pad = source_vocab_to_index['<PAD>']
        
        self.target_sos = target_vocab_to_index['<SOS>']
        self.target_eos = target_vocab_to_index['<EOS>']
        self.target_pad = target_vocab_to_index['<PAD>']
        
    def __len__(self):
        return len(self.source_sentences)
    
    def __getitem__(self, idx):
        # Tokenize sentences
        src_tokens = tokenize_sentence(self.source_sentences[idx])
        tgt_tokens = tokenize_sentence(self.target_sentences[idx])
        
        # Convert tokens to indices
        src_indices = tokens_to_indices(src_tokens, self.source_vocab_to_index)
        tgt_indices = tokens_to_indices(tgt_tokens, self.target_vocab_to_index)
        
        # Add special tokens
        src_indices = add_special_tokens(src_indices, self.source_sos, self.source_eos)
        tgt_indices = add_special_tokens(tgt_indices, self.target_sos, self.target_eos)
        
        # Convert to tensors
        src_tensor = torch.tensor(src_indices, dtype=torch.long)
        tgt_tensor = torch.tensor(tgt_indices, dtype=torch.long)
        
        return src_tensor, tgt_tensor


In [11]:
def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=english_to_index['<PAD>'])
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=tamil_to_index['<PAD>'])
    return src_batch, tgt_batch

In [12]:
from torch.utils.data import DataLoader

dataset = TranslationDataset(
    source_sentences=en_sentences,
    target_sentences=ta_sentences,
    source_vocab_to_index=english_to_index,
    target_vocab_to_index=tamil_to_index
)

batch_size = 16

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

In [15]:
model = Transformer(
    num_layers=6,
    d_model=512,
    dff=2048,
    dropout=0.1,
    heads=8,
    src_vocab_size=len(en_vocab),
    tgt_vocab_size=len(ta_vocab),
    max_len=252
).to(device)


In [16]:
loss_fn = nn.CrossEntropyLoss(ignore_index=tamil_to_index['<PAD>'])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [19]:
num_epochs = 24

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for src_batch, tgt_batch in dataloader:
        # Move tensors to device (GPU or CPU)
        src_batch = src_batch.to(device)
        tgt_batch = tgt_batch.to(device)
        
        # Create source padding mask
        src_padding_mask = create_padding_mask(src_batch, pad_token=english_to_index['<PAD>'])
        src_padding_mask = src_padding_mask.to(device)  # Move mask to device
        
        # Create target padding mask
        tgt_padding_mask = create_padding_mask(tgt_batch[:, :-1], pad_token=tamil_to_index['<PAD>'])
        tgt_padding_mask = tgt_padding_mask.to(device)
        
        # Create causal mask for the target sequence
        seq_len = tgt_batch[:, :-1].size(1)
        causal_mask = create_causal_mask(seq_len)
        causal_mask = causal_mask.to(device)
        
        # Combine target padding mask and causal mask
        combined_mask = combine_masks(tgt_padding_mask, causal_mask)
        combined_mask = combined_mask.to(device)
        
        # Forward pass through the model
        output = model(
            src=src_batch,
            tgt=tgt_batch[:, :-1],  # Input to the decoder (excluding the last token)
            src_padding_mask=src_padding_mask,
            tgt_padding_mask=None,  # Not used directly in the model as per your code
            combined_mask=combined_mask
        )
        
        # Compute loss using target tokens shifted by one position
        target_output = tgt_batch[:, 1:]  # Exclude the first token (<SOS>)
        loss = loss_fn(output.reshape(-1, output.size(-1)), target_output.reshape(-1))
        
        # Backpropagation and optimization steps
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"LOSS: {loss}")
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch+1}, Loss: {avg_loss}')



LOSS: 51.008331298828125
LOSS: 41.636539459228516
LOSS: 35.714263916015625


KeyboardInterrupt: 