In [99]:
import torch
import numpy as np
import torch.nn as nn

In [100]:
# Locations of the parallel corpus: one sentence per line, English and Hindi in separate files.
# NOTE(review): absolute, machine-specific paths; adjust them to wherever the data lives locally.
english_file = r'D:\Python Projects\Transformers\train.en\train.en'
hindi_file = r'D:\Python Projects\Transformers\train.hi\train.hi'

In [101]:
# Special-token markers placed in both vocabularies.
# They must be distinct, non-empty strings: three identical '' markers collapse into a
# single key when the vocabulary is inverted into a character -> id dict, and they would
# never match the '<START>' / '<END>' / '<PAD>' keys that the notebook's tokenize() helper
# looks up.
start_token = '<START>'
end_token = '<END>'
padding_token = '<PAD>'

In [102]:
# Character-level Hindi vocabulary. A character's position in this list is later used as
# its token id, so the ordering below must not change.
hindi_vocabulary = [
    start_token,
    # ASCII ' ' through '@' (punctuation and digits), with ';' left out
    *(chr(code) for code in range(ord(' '), ord('@') + 1) if chr(code) != ';'),
    # independent vowels
    'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ॠ', 'ऌ', 'ॡ', 'ए', 'ऐ', 'ओ', 'औ',
    # consonants
    'क', 'ख', 'ग', 'घ', 'ङ',
    'च', 'छ', 'ज', 'झ', 'ञ',
    'ट', 'ठ', 'ड', 'ढ', 'ण',
    'त', 'थ', 'द', 'ध', 'न',
    'प', 'फ', 'ब', 'भ', 'म',
    'य', 'र', 'ल', 'ळ', 'व', 'श', 'ष', 'स', 'ह',
    # nukta, avagraha, dependent vowel signs, virama, anusvara, visarga
    '़', 'ऽ', 'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॄ', 'ॅ', 'ॆ', 'े', 'ै', 'ॉ', 'ॊ', 'ो', 'ौ', '्', 'ं', 'ः',
    # Devanagari digits
    '०', '१', '२', '३', '४', '५', '६', '७', '८', '९',
    padding_token, end_token,
    # zero-width joiner and danda (sentence terminator)
    '\u200d', '।',
]


In [103]:
# Character-level English vocabulary: the start marker, every printable ASCII character
# from ' ' (32) through '~' (126) except ';', then the padding and end markers.
# A character's position in this list is later used as its token id.
english_vocabulary = [
    start_token,
    *(chr(code) for code in range(ord(' '), ord('~') + 1) if chr(code) != ';'),
    padding_token,
    end_token,
]


In [None]:
# Iterating a Python string yields one item per code point, so combining signs
# (virama, vowel signs) show up as separate entries.
text = "नमस्ते"
[character for character in text]

In [105]:
# Token id -> character.
index_to_hindi = dict(enumerate(hindi_vocabulary))
index_to_english = dict(enumerate(english_vocabulary))

# Character -> token id. If a character occurs more than once in a vocabulary,
# the later position wins.
hindi_to_index = {character: position for position, character in enumerate(hindi_vocabulary)}
english_to_index = {character: position for position, character in enumerate(english_vocabulary)}

In [106]:
# Read both corpus files as UTF-8. Without an explicit encoding, open() falls back to
# the platform's locale encoding, which can mis-decode or reject non-ASCII characters
# in the English file and makes the result depend on the machine running the notebook.
with open(english_file, 'r', encoding='utf-8') as file:
    english_sentences = file.readlines()

with open(hindi_file, 'r', encoding='utf-8') as file:
    hindi_sentences = file.readlines()


In [107]:
total_sentences = 100000

# Keep the first `total_sentences` lines of each corpus and strip the trailing
# newline(s) that readlines() leaves on every line.
english_sentences = [line.rstrip('\n') for line in english_sentences[:total_sentences]]
hindi_sentences = [line.rstrip('\n') for line in hindi_sentences[:total_sentences]]

In [None]:
# Preview the first three Hindi sentences.
hindi_sentences[:3]

In [None]:
# Longest sentence length in characters, displayed as (Hindi, English).
tuple(max(map(len, corpus)) for corpus in (hindi_sentences, english_sentences))

In [None]:
PERCENTILE = 97

# Report the sentence length (in characters) at the chosen percentile for each language.
for language, corpus in (("Hindi", hindi_sentences), ("English", english_sentences)):
    lengths = [len(sentence) for sentence in corpus]
    print(f"{PERCENTILE}th percentile length {language}: {np.percentile(lengths, PERCENTILE)}")

In [111]:
# Fixed sequence length (in characters) used by the length filter, the tokenizer's
# truncation/padding, and the attention-mask shapes in this notebook.
max_seq_len = 256

def is_valid_tokens(sentence, vocab):
    """Return True when every distinct character of `sentence` is present in `vocab`.

    An empty sentence is considered valid.
    """
    return all(token in vocab for token in set(sentence))

def is_valid_length(sentence, max_sequence_length):
    """Return True when `sentence` has fewer than `max_sequence_length - 1` characters.

    That leaves at least two free positions in a sequence of `max_sequence_length`.
    """
    character_count = sum(1 for _ in sentence)
    return character_count < max_sequence_length - 1

In [None]:
# Indices of sentence pairs where both the Hindi and the English side fit the length limit.
# NOTE(review): is_valid_tokens() is defined in this notebook but not applied here, so
# sentences containing out-of-vocabulary characters are kept — confirm this is intended.
valid_sentence_indicies = [
    index
    for index in range(len(hindi_sentences))
    if is_valid_length(hindi_sentences[index], max_seq_len)
    and is_valid_length(english_sentences[index], max_seq_len)
]

print(f"Number of sentences: {len(hindi_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

In [113]:
# Keep only the sentence pairs whose index passed the filter. Both right-hand lists are
# built from the unfiltered corpora before either name is rebound.
# This cell overwrites its inputs, so re-running it applies stale indices to already-filtered lists.
hindi_sentences, english_sentences = (
    [corpus[position] for position in valid_sentence_indicies]
    for corpus in (hindi_sentences, english_sentences)
)

In [114]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    """Parallel-corpus dataset that yields (english_sentence, hindi_sentence) pairs.

    The two sequences are aligned by position: item `idx` pairs
    `english_sentences[idx]` with `hindi_sentences[idx]`.
    """

    def __init__(self, english_sentences, hindi_sentences):
        """Store the two aligned sentence sequences.

        Raises:
            ValueError: if the sequences differ in length. A mismatch would otherwise
                surface later as an IndexError or as silently ignored sentences,
                because __len__ reports the English length only.
        """
        if len(english_sentences) != len(hindi_sentences):
            raise ValueError(
                "english_sentences and hindi_sentences must have the same length, "
                f"got {len(english_sentences)} and {len(hindi_sentences)}"
            )
        self.english_sentences = english_sentences
        self.hindi_sentences = hindi_sentences

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.english_sentences)

    def __getitem__(self, idx):
        """Return the pair at position `idx` as (english, hindi)."""
        return self.english_sentences[idx], self.hindi_sentences[idx]

In [None]:
# Wrap the two sentence lists in a Dataset and display how many pairs it holds.
dataset = TextDataset(english_sentences, hindi_sentences)
len(dataset)

In [None]:


batch_size = 3
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

# Print the first five batches (indices 0 through 4), then stop.
# `iterator` is left partially consumed and `batch` stays bound to the last batch printed.
for batch_num, batch in enumerate(iterator):
    print(batch)
    if batch_num >= 4:
        break
     


In [117]:
def tokenize(sentence, language_to_index, start_token=False, end_token=False):
    """Convert a sentence into a fixed-length tensor of character ids.

    Args:
        sentence: string (or sliceable sequence of characters) to encode.
        language_to_index: mapping from character to integer id. The special ids are
            read from the keys '<START>', '<END>', '<PAD>' and '<UNK>'.
        start_token: prepend the '<START>' id when true.
        end_token: append the '<END>' id when true.

    Returns:
        A 1-D tensor of length `max_seq_len` (module-level constant), padded on the
        right with the '<PAD>' id.

    Notes:
        Characters missing from the mapping become the '<UNK>' id. Any special key
        missing from the mapping becomes -1, which is not a valid embedding index.
        Over-long sentences are truncated *before* the special tokens are added, so a
        requested '<END>' marker is never cut off by the length limit.
    """
    unknown_index = language_to_index.get('<UNK>', -1)

    # Reserve one slot per requested special token so they survive truncation.
    reserved_slots = int(bool(start_token)) + int(bool(end_token))
    max_characters = max(max_seq_len - reserved_slots, 0)

    sentence_word_indices = [
        language_to_index.get(token, unknown_index)
        for token in sentence[:max_characters]
    ]
    if start_token:
        sentence_word_indices.insert(0, language_to_index.get('<START>', -1))
    if end_token:
        sentence_word_indices.append(language_to_index.get('<END>', -1))

    # Only relevant when max_seq_len is smaller than the number of special tokens.
    sentence_word_indices = sentence_word_indices[:max_seq_len]

    padding = [language_to_index.get('<PAD>', -1)] * (max_seq_len - len(sentence_word_indices))
    return torch.tensor(sentence_word_indices + padding)


In [118]:
# Walk the rest of `iterator` and report any batch with fewer than two elements.
# This exhausts the iterator; afterwards `batch` is bound to the last batch it produced.
for batch_num, batch in enumerate(iterator):
    if len(batch) >= 2:
        continue
    print(f"Batch {batch_num} is malformed: {batch}")


In [None]:
batch_size = 3
eng_tokenized, hn_tokenized = [], []

# Tokenize the sentence pairs in `batch`, at most `batch_size` of them. The final batch
# from a DataLoader can hold fewer than `batch_size` pairs, so iterate over what is
# actually there instead of indexing a fixed range (which raised IndexError).
sentence_pairs = list(zip(batch[0], batch[1]))[:batch_size]
for eng_sentence, hn_sentence in sentence_pairs:
    # English side: no start/end markers. Hindi side: both markers.
    eng_tokenized.append( tokenize(eng_sentence, english_to_index, start_token=False, end_token=False) )
    hn_tokenized.append( tokenize(hn_sentence, hindi_to_index, start_token=True, end_token=True) )

# One row per sentence.
eng_tokenized = torch.stack(eng_tokenized)
hn_tokenized = torch.stack(hn_tokenized)

In [120]:
NEG_INFTY = -1e9

def create_masks(eng_batch, kn_batch):
    """Build additive attention masks for a batch of sentence pairs.

    Args:
        eng_batch: sequence of source sentences; `len()` of each entry is its length.
        kn_batch: sequence of target sentences, aligned with `eng_batch`.

    Returns:
        (encoder_self_attention_mask, decoder_self_attention_mask,
        decoder_cross_attention_mask), each of shape
        [len(eng_batch), max_seq_len, max_seq_len], holding NEG_INFTY at masked
        positions and 0 elsewhere.

    For each sentence, positions from `length + 1` onward count as padding. The decoder
    self-attention mask additionally blocks every position above the main diagonal.
    The first 10x10 corner of each mask for the first sentence is printed.
    """
    num_sentences = len(eng_batch)
    mask_shape = [num_sentences, max_seq_len, max_seq_len]

    # True strictly above the diagonal: a position may not attend to later positions.
    look_ahead_mask = torch.triu(torch.full([max_seq_len, max_seq_len], True), diagonal=1)

    encoder_padding_mask = torch.zeros(mask_shape, dtype=torch.bool)
    decoder_padding_mask_self_attention = torch.zeros(mask_shape, dtype=torch.bool)
    decoder_padding_mask_cross_attention = torch.zeros(mask_shape, dtype=torch.bool)

    for idx in range(num_sentences):
        eng_padding_start = len(eng_batch[idx]) + 1
        kn_padding_start = len(kn_batch[idx]) + 1

        # Encoder self-attention: padded source positions are masked as keys and as queries.
        encoder_padding_mask[idx, :, eng_padding_start:] = True
        encoder_padding_mask[idx, eng_padding_start:, :] = True

        # Decoder self-attention: same treatment for padded target positions.
        decoder_padding_mask_self_attention[idx, :, kn_padding_start:] = True
        decoder_padding_mask_self_attention[idx, kn_padding_start:, :] = True

        # Cross-attention: columns follow source padding, rows follow target padding.
        decoder_padding_mask_cross_attention[idx, :, eng_padding_start:] = True
        decoder_padding_mask_cross_attention[idx, kn_padding_start:, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask = torch.where(
        look_ahead_mask | decoder_padding_mask_self_attention, NEG_INFTY, 0
    )
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)

    print(f"encoder_self_attention_mask {encoder_self_attention_mask.size()}: {encoder_self_attention_mask[0, :10, :10]}")
    print(f"decoder_self_attention_mask {decoder_self_attention_mask.size()}: {decoder_self_attention_mask[0, :10, :10]}")
    print(f"decoder_cross_attention_mask {decoder_cross_attention_mask.size()}: {decoder_cross_attention_mask[0, :10, :10]}")
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask


In [None]:
class SentenceEmbedding(nn.Module):
    """Turn a batch of raw sentences into dropout-regularised, position-aware embeddings.

    Sentences are tokenized character by character with `language_to_index`,
    padded to `max_sequence_length`, embedded, summed with the output of the
    positional encoder, and passed through dropout.

    Relies on `PositionalEncoding` and `get_device`, which are defined outside this class.
    """

    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        """
        Args:
            max_sequence_length: length every tokenized sentence is padded to.
            d_model: embedding dimension.
            language_to_index: mapping from character (and special tokens) to id.
            START_TOKEN, END_TOKEN, PADDING_TOKEN: the special-token keys present in
                `language_to_index`.
        """
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN

    def batch_tokenize(self, batch, start_token=True, end_token=True):
        """Tokenize every sentence in `batch` and stack the results.

        Args:
            batch: sequence of sentences (strings).
            start_token: prepend the START_TOKEN id to each sentence.
            end_token: append the END_TOKEN id to each sentence.

        Returns:
            Tensor with one row per sentence, moved to `get_device()`.

        Raises:
            KeyError: if a sentence contains a character missing from `language_to_index`.
        """

        def tokenize(sentence, start_token=True, end_token=True):
            sentence_word_indicies = [self.language_to_index[token] for token in sentence]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
            # Right-pad up to the fixed length; sentences already at or over it are left as-is.
            missing = self.max_sequence_length - len(sentence_word_indicies)
            if missing > 0:
                sentence_word_indicies.extend([self.language_to_index[self.PADDING_TOKEN]] * missing)
            return torch.tensor(sentence_word_indicies)

        tokenized = torch.stack([tokenize(sentence, start_token, end_token) for sentence in batch])
        return tokenized.to(get_device())

    def forward(self, x, end_token=True, start_token=True): # sentence
        """Embed a batch of sentences.

        Args:
            x: batch of raw sentences.
            end_token: whether the END_TOKEN id is appended to each sentence.
            start_token: whether the START_TOKEN id is prepended to each sentence.

        Returns:
            The embedded batch with positional encodings added and dropout applied.
        """
        # Pass the flags by keyword: passing `end_token` positionally would bind it to
        # batch_tokenize's `start_token` parameter and toggle the wrong marker.
        x = self.batch_tokenize(x, start_token=start_token, end_token=end_token)
        x = self.embedding(x)
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)
        return x