<a href="https://colab.research.google.com/github/Rohit-Potnuru/Machine-Translation-EN-TE/blob/main/Training_Machine_Translation_EN_TE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Helper Functions

In [2]:
import time
def format_time(seconds):
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    seconds = seconds % 60

    if hours > 0:
        return f"{int(hours):02d}::{int(minutes):02d}::{int(seconds):02d}"
    elif minutes > 0:
        return f"{int(minutes):02d}::{int(seconds):02d}"
    else:
        return f"{seconds:.3f} sec"

In [3]:
def num_parameters(n_parameter):
  scales = {1e12: "T",
            1e9: "B",
            1e6: "M",
            1e3: "K"
           }
  for scale, val in scales.items():
    if(n_parameter > scale):
      return f'{n_parameter/scale}{val} parameters'

# Preprocessing

In [5]:
!pip install memory-profiler
%load_ext memory_profiler



## English and Telugu Vocabulary

In [6]:
# Start the timer
start_time = time.time()

In [7]:
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
telugu_vocabulary = [START_TOKEN,
    ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    ':', '<', '=', '>', '?', '@', 'ˌ',
    '[', '\\', ']', '^', '_', '`','{', '|', '}', '~',

    'అ', 'ఆ', 'ఇ', 'ఈ', 'ఉ', 'ఊ', 'ఋ', 'ౠ', 'ఎ', 'ఏ', 'ఐ', 'ఒ','ఓ', 'ఔ',
    'ా', 'ి', 'ీ', 'ు', 'ూ', 'ృ', 'ౄ', 'ె', 'ే',  'ై',  'ొ', 'ో', 'ౌ', '్',  'ఁ', 'ం', 'ః','ఀ',

    'క', 'ఖ', 'గ', 'ఘ', 'ఙ',
    'చ', 'ఛ', 'జ', 'ఝ','ఞ',
    'ట', 'ఠ', 'డ', 'ఢ', 'ణ',
    'త', 'థ', 'ద', 'ధ', 'న',
    'ప', 'ఫ', 'బ', 'భ', 'మ',
    'య', 'ర', 'ల', 'వ', 'ళ', 'శ', 'ష', 'స', 'హ', 'ఱ',

   # 0,   1,   2,   3,   4,   5,   6,   7,   8,  9
    '౦', '౧', '౨', '౩', '౪', '౫', '౬', '౭', '౮', '౯',
    PADDING_TOKEN, END_TOKEN
]

english_vocabulary = [START_TOKEN,
    ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    ':', '<', '=', '>', '?', '@', 'ˌ',
    '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~',

    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
    'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
    'y', 'z',
    PADDING_TOKEN, END_TOKEN]

print(f'total number of characters for Telugu Vocbulary is {len(telugu_vocabulary)}')
print(f'total number of characters for English Vocbulary is {len(english_vocabulary)}')

total number of characters for Telugu Vocbulary is 123
total number of characters for English Vocbulary is 72


In [8]:
def generate_index_key_map(keys):
  key_to_index = { ch: i for i, ch in enumerate(keys)}
  index_to_key = { i: ch for ch, i in key_to_index.items()}
  return key_to_index, index_to_key

telugu_to_index, index_to_telugu = generate_index_key_map(telugu_vocabulary)
english_to_index, index_to_english = generate_index_key_map(english_vocabulary)

In [9]:
def tokenize(sentence, key_to_index, max_sentence_length, start = True, end = True):
  sentence_indicies = []
  if start:
    sentence_indicies += [key_to_index[START_TOKEN]]
  sentence_indicies += [key_to_index[ch] for ch in list(sentence)]
  if end:
    sentence_indicies.append(key_to_index[END_TOKEN])
  # Adding PADDING_TOKEN so that sentence_indicies array length is equivalent to max_sentence_length
  if max_sentence_length is not None:
    sentence_indicies += [key_to_index[PADDING_TOKEN]] * (max_sentence_length - len(sentence_indicies))
  return sentence_indicies

def replace_token(ch, replaceTokens):
  if ch in replaceTokens:
    return replaceTokens[ch]
  return ch


def generate_encoder_decoder(key_to_index, index_to_key):
  replaceTokens = {START_TOKEN : "", PADDING_TOKEN : "", END_TOKEN : ""}

  encoder = lambda s, max_sentence_length = None, start = True, end = True: tokenize(s.lower(), key_to_index, max_sentence_length, start, end)
  decoder = lambda enc_list, replaceTokens = replaceTokens: ''.join([replace_token(index_to_key[i], replaceTokens) for i in enc_list])

  return encoder, decoder

telugu_sen_encoder, telugu_sen_decoder = generate_encoder_decoder(telugu_to_index, index_to_telugu)
english_sen_encoder, english_sen_decoder = generate_encoder_decoder(english_to_index, index_to_english)

In [10]:
# def generate_encoder_decoder(key_to_index, index_to_key, enable_sentence_list = False):
#   replaceTokens = {START_TOKEN : "", PADDING_TOKEN : "", END_TOKEN : ""}

#   if (enable_sentence_list):
#     encoder = lambda list_s, max_sentence_length: \
#                                 [tokenize(s.lower(), index_to_key, max_sentence_length) for s in list_s]
#     decoder = lambda enc_sen_list, replaceTokens = replaceTokens: \
#                                 [''.join([replace_token(index_to_key[i], replaceTokens) for i in enc_list]) for enc_list in enc_sen_list]
#   else:
#     encoder = lambda s, max_sentence_length: [tokenize(s.lower(), index_to_key, max_sentence_length) for key in s.lower()]
#     decoder = lambda enc_list, replaceTokens = replaceTokens: ''.join([replace_token(index_to_key[i], replaceTokens) for i in enc_list])
#   return encoder, decoder

# def encoder(list_sen, max_sentence_length = None):
#   sen_flag = [False]
#   if(isinstance(list_sen, str)):
#     list_sen = [list_sen]
#     sen_flag[0] = True

#   enc_list_sen = []
#   for s in list_sen:
#     enc_list_sen.append(tokenize(s.lower(), key_to_index, max_sentence_length))

#   if(sen_flag[0]):
#     return enc_list_sen[0]
#   return enc_list_sen

# def decoder(enc_list, replaceTokens = replaceTokens):
#   sen_flag = [False]
#   if(isinstance(enc_list[0], int)):
#     enc_list = [enc_list]
#     sen_flag[0] = True
#   dec_list = []
#   for enc_sen in enc_list:
#     dec_sen = ""
#     for i in enc_sen:
#       dec_sen += replace_token(index_to_key[i], replaceTokens)
#     dec_list.append(dec_sen)

#   if(sen_flag[0]):
#     return dec_list[0]
#   return dec_list

In [11]:
print(english_sen_encoder('hi i am don deenu'))
print(english_sen_decoder(english_sen_encoder('hi I am Don Seenu')))

[0, 51, 52, 1, 52, 1, 44, 56, 1, 47, 58, 57, 1, 47, 48, 48, 57, 64, 71]
hi i am don seenu


In [12]:
print(f'time execution: {format_time(time.time() - start_time)}')

time execution: 0.593 sec


## Reading the Dataset

In [13]:
# Start the timer
start_time = time.time()

In [14]:
file_path = '/content/drive/MyDrive/Colab Notebooks/Machine Translation EN-TE/Dataset/filtered'
english_file = f'{file_path}/train.en'
telugu_file = f'{file_path}/train.te'

In [None]:
with open(english_file) as file:
  english_sentences = [line.rstrip('\n').lower() for line in file.readlines()]

with open(telugu_file) as file:
  telugu_sentences = [line.rstrip('\n') for line in file.readlines()]

assert len(telugu_sentences) == len(english_sentences), f"English and Telugu sentences count are not same"
n = len(english_sentences)
print(f'Total number of sentences: {n}')

Total number of sentences: 3756786


In [None]:
# n = len(telugu_sentences)
# te_len = [len(x) for x in telugu_sentences]
# te_per = [np.percentile(te_len, i) for i in range(101)]

# en_len = [len(x) for x in english_sentences]
# en_per = [np.percentile(en_len, i) for i in range(101)]

# # for idx, (e,t) in enumerate(zip(te_per, en_per)):
# #   print(idx, e, t)

# seq_len = 200
# print(len([x for x in en_len if x > seq_len]))
# print(len([x for x in te_len if x > seq_len]))

In [None]:
import numpy as np
PERCENTILE = 99
print( f"{PERCENTILE}th percentile length Telugu: {np.percentile([len(list(x)) for x in telugu_sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length English: {np.percentile([len(list(x)) for x in english_sentences], PERCENTILE)}" )

99th percentile length Telugu: 193.0
99th percentile length English: 210.0


In [None]:
max_sequence_length = 200

In [None]:
#Filtering sentences based on vocabulary characters
def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

en_voc = set(english_vocabulary)
te_voc = set(telugu_vocabulary)
valid_sentence_indicies = []

for i in range(len(english_sentences)):
  en_sen, te_sen = english_sentences[i], telugu_sentences[i]
  if (is_valid_length(en_sen, max_sequence_length) \
      and is_valid_length(te_sen, max_sequence_length) \
      and is_valid_tokens(en_sen, en_voc)) \
      and is_valid_tokens(te_sen, te_voc):
        valid_sentence_indicies.append(i)

english_sentences = [english_sentences[i] for i in valid_sentence_indicies]
telugu_sentences = [telugu_sentences[i] for i in valid_sentence_indicies]

print(f'Total number of VALID EN-TE sentence pairs : {len(english_sentences)}')
if (n - len(english_sentences)) > 0:
  print(f'Total number of INVALID EN-TE sentence pairs : {n - len(english_sentences)}')

Total number of VALID EN-TE sentence pairs : 3699089
Total number of INVALID EN-TE sentence pairs : 57697


In [None]:
print(f'Some Examples: ')
for en_sen, te_sen in zip(english_sentences[:4], telugu_sentences[:4]):
  print(f'{en_sen} -> {te_sen}')

Some Examples: 
have you heard about foie gras? -> ఇక ఫ్రూట్ ఫ్లైస్ గురించి మీరు విన్నారా?
i never thought of acting in films. -> సూర్య సినిమాల్లో నటించాలని ఎప్పుడూ అనుకోలేదు.
a case has been registered under sections 302 and 376, ipc. -> నిందితులపై సెక్షన్ 376 మరియు 302ల కింద కేసు నమోదు చేశాం.
of this, 10 people succumbed to the injuries. -> అందులో 10 మంది తీవ్రంగా గాయపడ్డారు.


In [None]:
print(f'time execution: {format_time(time.time() - start_time)}')

time execution: 57.053 sec


## Dataset Generation

In [None]:
# Start the timer
start_time = time.time()

In [None]:
# enc_english_sentences = english_sen_encoder(english_sentences, max_sequence_length)
# del english_sentences
# enc_telugu_sentences = telugu_sen_encoder(telugu_sentences, max_sequence_length)
# del telugu_sentences

In [None]:
def split_dataset(lang1_sentences, lang2_sentences, ratios):
  assert sum(ratios) == 100, f'split is not perfect'
  train_ratio, val_ratio, test_ratio = ratios
  n = len(english_sentences)
  n1, n2 = int(n * 0.01 * train_ratio), int(n * 0.01 * (train_ratio + val_ratio))

  dataset = {}
  dataset['train'] = (lang1_sentences[:n1], lang2_sentences[:n1])
  dataset['val'] = (lang1_sentences[n1:n2], lang2_sentences[n1:n2])
  dataset['test'] = (lang1_sentences[n2:], lang2_sentences[n2:])
  return dataset

en_te_dataset = split_dataset(english_sentences, telugu_sentences, [85, 7, 8])
del english_sentences
del telugu_sentences

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Dataset Ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset
class TextDataset(Dataset):
    def __init__(self, lang1_sentences,
                       lang2_sentences,
                       lang1_encoder,
                       lang2_encoder,
                       max_sentence_length):
        self.lang1_sentences = lang1_sentences
        self.lang2_sentences = lang2_sentences
        self.max_sentence_length = max_sentence_length

    def __len__(self):
        return len(self.lang2_sentences)

    def __getitem__(self, idx):
        return self.lang1_sentences[idx], self.lang2_sentences[idx]

In [None]:
for split in en_te_dataset.keys():
  en_te_dataset[split] = TextDataset(en_te_dataset[split][0],
                                      en_te_dataset[split][1],
                                      english_sen_encoder,
                                      telugu_sen_encoder,
                                      max_sequence_length
                                      )

In [None]:
len(en_te_dataset['train']), len(en_te_dataset['val']), len(en_te_dataset['test'])

(3144225, 258936, 295928)

In [None]:
# batch_size = 300000
# trainDataLoader = DataLoader(train_dataset, batch_size)
# iterator = iter(trainDataLoader)

In [None]:
# start_time = time.time()
# for batch_num, batch in enumerate(iterator):
#   (en_batch, te_batch), tar_te_batch = batch
#   print(batch_num)
# print(f'time execution: {format_time(time.time() - start_time)}')

In [None]:
print(f'time execution: {format_time(time.time() - start_time)}')

time execution: 2.739 sec


# Transformer

In [None]:
import torch
import io
import torch.nn as nn
import torch.nn.functional as F

## Transformer Model Architecture

In [None]:
class AttentionHead(nn.Module):
  def __init__(self, config, mask):
    super().__init__()
    assert isinstance(mask, bool), f'mask should be boolean, please provide a valid "mask" input'

    self.config = config
    self.mask = mask
    self.query = nn.Linear(config.n_embd, config.head_size, bias=False)
    self.key = nn.Linear(config.n_embd, config.head_size)
    self.value = nn.Linear(config.n_embd, config.head_size)
    self.register_buffer('tril', torch.tril(torch.ones(config.block_size, config.block_size)))
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, q, k, v):
    B, T, C = q.shape
    q = self.query(q) # (B, T, head_size)
    k = self.key(k) # (B, T, head_size)
    w = q @ k.transpose(-2, -1) * k.shape[-1] ** (-0.5) # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
    if (self.mask):
      w = w.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
    w = F.softmax(w, dim = -1)
    w = self.dropout(w)

    v = self.value(v) # (B, T, head_size)
    out = w @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
    return out # (B, T, head_size)

class MultiAttentionHead(nn.Module):
  def __init__(self, config, mask = False):
    super().__init__()
    assert config.n_embd % config.n_head == 0, f'"n_head" should be divisible of "n_embd", please provide valid hyperparameter inputs'
    assert config.n_embd == config.n_head * config.head_size, f'"n_embd" is not equal to "n_head * config.head_size", please provide valid hyperparameter inputs'

    self.config = config
    self.attn_heads = nn.ModuleList([AttentionHead(config, mask) for _ in range(config.n_head)])
    self.proj = nn.Linear(config.n_head * config.head_size, config.n_embd)
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, q, k, v):
    out = torch.cat([h(q, k, v)for h in self.attn_heads], dim = -1) # (B, T, n_head * head_size)
    out = self.proj(out) # (B, T, n_embd)
    out = self.dropout(out)
    return out # (B, T, n_embd)

In [None]:
class FeedForward(nn.Module):
  def __init__(self, config):
    super().__init__()
    assert config.n_embd is not None, f'"n_embd" is not present, please provide valid inputs'
    assert config.dropout is not None, f'"dropout" is not present, please provide valid inputs'

    self.config = config
    self.net = nn.Sequential(
        nn.Linear(config.n_embd, 4*config.n_embd),
        nn.GELU(),
        nn.Linear(4 * config.n_embd, config.n_embd),
        nn.Dropout(config.dropout),
    )

  def forward(self, x):
    out = self.net(x)
    return out

In [None]:
class EncoderBlock(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.ln1 = nn.LayerNorm(config.n_embd)
    self.attn_heads = MultiAttentionHead(config)
    self.ln2 = nn.LayerNorm(config.n_embd)
    self.ffd = FeedForward(config)

  def forward(self, x):
    out = self.ln1(x)
    x = self.attn_heads(out, out, out) + x
    out = self.ln2(x)
    out = self.ffd(out) + x
    return out

In [None]:
class SequentialEncoder(nn.Sequential):
  def forward(self, *inputs):
    x, = inputs
    for module in self._modules.values():
      x = module(x)
    return x

class Encoder(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.token_embedding_table = nn.Embedding(config.enc_vocab_size, config.n_embd)
    self.pos_embedding_table = nn.Embedding(config.enc_vocab_size, config.n_embd)
    self.encoderBlocks = SequentialEncoder(*[EncoderBlock(config) for _ in range(config.n_layer)])
    self.ln_f = nn.LayerNorm(config.n_embd)

  def forward(self, x):
    B, T = x.shape # (B, T)
    tok_emb = self.token_embedding_table(x) # (B, T, C) C===n_embd
    pos_emb = self.pos_embedding_table(x) # (B, T, C)
    out = tok_emb + pos_emb # (B, T, C)
    out = self.encoderBlocks(out) # (B, T, C)
    out = self.ln_f(out)
    return out # (B, T, C)

In [None]:
class DecoderBlock(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.ln1 = nn.LayerNorm(config.n_embd)
    self.mask_attn_heads = MultiAttentionHead(config,True)
    self.ln2 = nn.LayerNorm(config.n_embd)
    self.cross_attn_heads = MultiAttentionHead(config)
    self.ln3 = nn.LayerNorm(config.n_embd)
    self.ffd = FeedForward(config)
    self

  def forward(self, x, enc_out):
    out = self.ln1(x)
    x = self.mask_attn_heads(out, out, out) + x # (B, T, C)
    out = self.ln2(x)
    x = self.cross_attn_heads(enc_out, enc_out, out) + x # (B, T, C)
    out = self.ln3(x)
    out = self.ffd(out) + x # (B, T, C)
    return out # (B, T, C)

In [None]:
class SequentialDecoder(nn.Sequential):
  def forward(self, *inputs):
    x, enc_out = inputs
    for module in self._modules.values():
      x = module(x, enc_out)
    return x

class Decoder(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.token_embedding_table = nn.Embedding(config.dec_vocab_size, config.n_embd)
    self.pos_embedding_table = nn.Embedding(config.dec_vocab_size, config.n_embd)
    self.decoderBlocks = SequentialDecoder(*[DecoderBlock(config) for _ in range(config.n_layer)])
    self.ln_f = nn.LayerNorm(config.n_embd)
    self.lm_head = nn.Linear(config.n_embd, config.dec_vocab_size)

  def forward(self, x, enc_out):
    B, T = x.shape # (B, T)
    tok_emb = self.token_embedding_table(x) # (B, T, C) C===n_embd
    pos_emb = self.pos_embedding_table(x) # (B, T, C)
    out = tok_emb + pos_emb # (B, T, C)
    out = self.decoderBlocks(out, enc_out) # (B, T, C)
    out = self.ln_f(out)
    out = self.lm_head(out) # (B, T, dec_vocab_size)

    return out # (B, T, C)

In [None]:
class TokenEncoding(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config

  def forward(self, lang_sen_list, start = False, end = True):
    if isinstance(lang_sen_list, str):
      lang_sen_list = [lang_sen_list]

    return torch.tensor([self.config.encoder(sentence,
                         self.config.block_size,
                         start,
                         end) for sentence in lang_sen_list], dtype = torch.long)

class TokenDecoding(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config

  def forward(self, enc_sen_list, replace_token = None):
    if replace_token is not None:
      return [self.config.decoder(enc_list, replace_token) for enc_list in enc_sen_list]
    return [self.config.decoder(enc_list) for enc_list in enc_sen_list]


class SentenceTokenize():
  def __init__(self, config):
    assert config.sen_encoder is not None
    assert config.sen_decoder is not None

    self.config = config
    self.encoder, self.decoder = generate_encoder_decoder(config.sen_encoder, config.sen_decoder)

  def tokenize(sentence, key_to_index, max_sentence_length, start = True, end = True):
    sentence_indicies = []
    if start:
      sentence_indicies += [key_to_index[START_TOKEN]]
    sentence_indicies += [key_to_index[ch] for ch in list(sentence)]
    if end:
      sentence_indicies.append(key_to_index[END_TOKEN])
    # Adding PADDING_TOKEN so that sentence_indicies array length is equivalent to max_sentence_length
    if max_sentence_length is not None:
      sentence_indicies += [key_to_index[PADDING_TOKEN]] * (max_sentence_length - len(sentence_indicies))
    return sentence_indicies

  def replace_token(ch, replaceTokens):
    if ch in replaceTokens:
      return replaceTokens[ch]
    return ch

  def generate_encoder_decoder(key_to_index, index_to_key):
    replaceTokens = {START_TOKEN : "", PADDING_TOKEN : "", END_TOKEN : ""}

    encoder = lambda s, max_sentence_length = None, start = True, end = True: tokenize(s.lower(), key_to_index, max_sentence_length, start, end)
    decoder = lambda enc_list, replaceTokens = replaceTokens: ''.join([replace_token(index_to_key[i], replaceTokens) for i in enc_list])

    return encoder, decoder

  def sen_encode(self, lang_sen_list, start = False, end = True):
    if isinstance(lang_sen_list, str):
      lang_sen_list = [lang_sen_list]

    return torch.tensor([self.encoder(sentence,
                         self.config.block_size,
                         start,
                         end) for sentence in lang_sen_list], dtype = torch.long)

  def sen_decode(self, enc_sen_list, replace_token = None):
    if replace_token is not None:
      return [self.decoder(enc_list, replace_token) for enc_list in enc_sen_list]
    return [self.decoder(enc_list) for enc_list in enc_sen_list]

In [None]:
class Transformer(nn.Module):
  def __init__(self, model_config, token_configs):
    super().__init__()
    assert len(token_configs) == 2, f'token_configs should two config classes, one for source language conversions, one for target language conversions'
    self.config = model_config
    self.encoder = Encoder(model_config)
    self.decoder = Decoder(model_config)

    self.token_configs = token_configs
    src_token_config, tar_token_config = token_configs
    self.src_tokenize = SentenceTokenize(src_token_config)
    self.tar_tokenize = SentenceTokenize(tar_token_config)

    # self.token_configs = token_configs
    # srclang_token_config, tarlang_token_config = token_configs
    # self.srclang_token_encoder = TokenEncoding(srclang_token_config)
    # self.srclang_token_decoder = TokenDecoding(srclang_token_config)
    # self.tarlang_token_encoder = TokenEncoding(tarlang_token_config)
    # self.tarlang_token_decoder = TokenDecoding(tarlang_token_config)

    self.initial()

  def initial(self):
    for params in self.parameters():
      if params.dim() > 1:
          nn.init.xavier_uniform_(params)

  def forward(self, enc_inp, dec_inp, targets = None):
    enc_inp = self.src_tokenize.sen_encode(enc_inp, False, False).to(self.config.device) # (B, T)
    enc_out = self.encoder(enc_inp) # (B, T, C) C===n_embd

    dec_inp = self.tar_tokenize.sen_encode(dec_inp, True, True).to(self.config.device) # (B, T)
    dec_inp = dec_inp.to(self.config.device)
    logits = self.decoder(dec_inp, enc_out) # (B, T, dec_vocab_size)

    if targets is None:
      loss = None
    else:
      targets = self.tar_tokenize.sen_encode(targets, False, True).to(self.config.device) # (B, T)
      B, T, V = logits.shape
      logits = logits.view(B * T, V) # (B * T, V)
      targets = targets.view(B * T) # (B * T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  @torch.no_grad()
  def generate(self, enc_inp):
    self.eval()
    flag = False
    if isinstance(enc_inp, str):
      flag = True
      enc_inp = [enc_inp]

    B = len(enc_inp)
    outputs = []

    for batch_num in range(B):
      dec_inp = [""]
      for i in range(max_sequence_length - 1):
        logits, _ = self(enc_inp, dec_inp)
        next_token_index = torch.argmax(logits, dim = -1)
        next_token_index = next_token_index[:, i].view(-1, 1)
        next_token = self.tar_tokenize.sen_decode(next_token_index.tolist(),{})
        if next_token[0] in [START_TOKEN, END_TOKEN, PADDING_TOKEN]:
          break
        dec_inp = [sen + nxt_t for (sen, nxt_t) in zip(dec_inp, next_token)]
      outputs.append(dec_inp[0])
    if flag: return outputs[0]
    return outputs

  def save_model(self, filepath, losses = None, config = None):
    if config is None:
      config = self.config
    torch.save({
        "state_dict": self.state_dict(),
        "config": self.config,
        "token_configs": self.token_configs,
        "losses": losses
    }, filepath)

  def load_model(filepath, device = None):
    # Load the model from the buffer
    checkpoint = torch.load(filepath)
    config = checkpoint['config']
    token_configs = checkpoint['token_configs']
    if device is not None:
      config['device'] = device
    model = Transformer(config, token_configs)
    return model

## Hyper Parameters

In [None]:
# Training HyperParameters
learning_rate = 3e-4
max_iters = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_interval = 1000
eval_iters = 200

In [None]:
class Config:
  def __init__(self, **kwargs):
     for key, value in kwargs.items():
      setattr(self, key, value)

model_config = Config(
    batch_size = 32,
    block_size = max_sequence_length,
    enc_vocab_size = len(english_vocabulary),
    dec_vocab_size = len(telugu_vocabulary),

    n_embd = 512,
    n_layer = 2,
    n_head = 8,
    head_size = 512//8,
    dropout = 0.2,

    learning_rate = learning_rate,
    max_iter = max_iters,
    device = device,
    eval_interval = eval_interval,
    eval_iters = eval_iters
)

en_token_config = Config(
    block_size = model_config.block_size,
    sen_encoder = english_to_index,
    sen_decoder = index_to_english
)

te_token_config = Config(
    block_size = model_config.block_size,
    sen_encoder = telugu_to_index,
    sen_decoder = index_to_telugu
)

## Training

In [None]:
def get_DataLoader(dataset, batch_size):
  dataLoader = DataLoader(dataset, batch_size, shuffle = True)
  iterator = iter(dataLoader)
  return dataLoader, iterator

trainDataLoader, trainIterator = get_DataLoader(en_te_dataset['train'], model_config.batch_size)
valDataLoader, valIterator = get_DataLoader(en_te_dataset['val'], batch_size = model_config.batch_size)
testDataLoader, testIterator = get_DataLoader(en_te_dataset['test'], 1)

In [None]:
transformer = Transformer(model_config, (en_token_config, te_token_config))
transformer.to(device)
optimizer = torch.optim.AdamW(transformer.parameters(), lr = learning_rate)

In [None]:
print(f'number of parameters of the model: {num_parameters(sum(p.numel() for p in transformer.parameters()))}')

number of parameters of the model: 14.974587M parameters


In [None]:
# def get_batch(split, batch_size):
#   assert split in en_te_dataset.keys(), f'Provide valid split which should be ["train", "val", "test"]'

#   enc_data = []
#   dec_data = []
#   ix = torch.randint(len(data), (batch_size,))
#   for i in ix:
#     enc_sen, dec_sen = en_te_dataset[split][i.item()]
#     enc_data.append(enc_sen)
#     dec_data.append(dec_sen)
#   return enc_data, dec_data

# @torch.no_grad()
# def estimate_loss(model, config):
#   out = {}
#   model.eval()
#   for split in ['train', 'val']:
#     losses = torch.zeros(eval_iters)
#     for k in range(config.eval_iters):
#       en_batch, te_batch = get_batch(split, config.batch_size)
#       logits, loss = transformer(en_batch, te_batch, te_batch)
#       losses[k] = (loss.item())
#     out[split] = losses.mean()
#   return out

In [None]:
start_time = time.time()
losses = {'train':[],
          'val': []}
for step in range(max_iters):
  out = {'train':[],
          'val': []}
  for batch_num, batch in enumerate(trainIterator):
    transformer.train()
    en_batch, te_batch = batch
    logits, loss = transformer(en_batch, te_batch, te_batch)
    out['train'].append(loss)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()

    transformer.eval()
    for val_batch in valDataLoader:
      en_val_batch, te_val_batch = val_batch
      break
    logits, loss = transformer(en_val_batch, te_val_batch, te_val_batch)
    out['val'].append(loss)

    if batch_num % model_config.eval_interval == 0:
      for split in losses.keys():
        losses[split].append(torch.tensor(out[split]).mean().item())
        out[split] = []

      print(f'Iteration {batch_num} Train loss: {losses["train"][-1]}, Val loss: {losses["val"][-1]}, time taken: {time.time() - start_time}')
      en_sentence = "I love you"
      te_sentence = "నేను నిన్ను ప్రేమిస్తున్నాను"
      print(f'English: {en_sentence}')
      print(f'Telugu Prediction: {transformer.generate(en_sentence)}')
      print(f'Telugu Target: {te_sentence}')
      print(f'------------------------------------------------------------------------------------------------------------------------')

  for split in losses.keys():
    losses[split].append(torch.tensor(out[split]).mean().item())

  print(f'***************************************************************************************************************************')
  print(f'Epoch {step} Train loss: {losses["train"][-1]}, Val loss: {losses["val"][-1]}, time taken: {time.time() - start_time}')
  en_sentence = "I love you"
  te_sentence = "నేను నిన్ను ప్రేమిస్తున్నాను"
  print(f'English: {en_sentence}')
  print(f'Telugu Prediction: {transformer.generate(en_sentence)}')
  print(f'Telugu Target: {te_sentence}')
  print(f'***************************************************************************************************************************')
print(f'time execution: {format_time(time.time() - start_time)}')

Iteration 0 Train loss: 3.8702831268310547, Val loss: 3.863093852996826, time taken: 1.3994348049163818
English: I love you
Telugu Prediction: 
Telugu Target: నేను నిన్ను ప్రేమిస్తున్నాను
------------------------------------------------------------------------------------------------------------------------


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 5.06 MiB is free. Process 80915 has 14.74 GiB memory in use. Of the allocated memory 14.07 GiB is allocated by PyTorch, and 555.02 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
filepath = './save_model.pt'
transformer = Transformer(model_config, (en_token_config, te_token_config))
# transformer.save_model(filepath)
# transformer = Transformer.load_model(filepath)
lang1_inp = 'Rohit'
transformer.generate(lang1_inp)

'ూఝృదఘద దఇ౪ా ఫాఝా చదఉ"దఝదఝఇ"ఇషఇఇ\'ష"జ|షషఝష"ఝషఇఝఝఝ>౧ఇష"ఝఇ(ఇ౯ ."|జషూజ7ఙజష""ష"""""షఝజజఝష<ౄఝ"నఝూఝఝఝఝప"చ/చఝచచచఝ",\'"ఇచఐచఝఙచ,చ*చఝఙ`ఝఋచఐష|"\'చచఐ"ౠఇ\'ఐషఝఐచేచఐ\'`నఐఐ\'౭$`ృపఐచఐ౭ఝ*ఐ"9౭ఐష\'ఐౠఋఝఐచఉఐఐ\'ఐ౭ఐచఐ\'౭ఐ\'౭ఐషఐఐ\'ఐఐచఉఐ'