In [6]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F

def get_device():
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled.permute(1, 0, 2, 3) + mask
        scaled = scaled.permute(1, 0, 2, 3)
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i/self.d_model)
        position = (torch.arange(self.max_sequence_length)
                          .reshape(self.max_sequence_length, 1))
        even_PE = torch.sin(position / denominator)
        odd_PE = torch.cos(position / denominator)
        stacked = torch.stack([even_PE, odd_PE], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)
        return PE

class SentenceEmbedding(nn.Module):
    "For a given sentence, create an embedding"
    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(num_embeddings=self.vocab_size,embedding_dim= d_model)
        self.language_to_index = language_to_index
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN
    
    def batch_tokenize(self, batch, start_token, end_token):

        def tokenize(sentence, start_token, end_token):
            sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
            for _ in range(len(sentence_word_indicies), self.max_sequence_length):
                sentence_word_indicies.append(self.language_to_index[self.PADDING_TOKEN])
            return torch.tensor(sentence_word_indicies)

        tokenized = []
        for sentence_num in range(len(batch)):
           tokenized.append( tokenize(batch[sentence_num], start_token, end_token) )
        tokenized = torch.stack(tokenized)
        return tokenized.to(get_device())
    
    def forward(self, x, start_token, end_token): # sentence
        x = self.batch_tokenize(x, start_token, end_token)
        x = self.embedding(x)
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask):
        batch_size, sequence_length, d_model = x.size()
        qkv = self.qkv_layer(x)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out


class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y + self.beta
        return out

  
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, self_attention_mask):
        residual_x = x.clone()
        x = self.attention(x, mask=self_attention_mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual_x)
        residual_x = x.clone()
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual_x)
        return x
    
class SequentialEncoder(nn.Sequential):
    def forward(self, *inputs):
        x, self_attention_mask  = inputs
        for module in self._modules.values():
            x = module(x, self_attention_mask)
        return x

class Encoder(nn.Module):
    def __init__(self, 
                 d_model, 
                 ffn_hidden, 
                 num_heads, 
                 drop_prob, 
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN, 
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
                                      for _ in range(num_layers)])

    def forward(self, x, self_attention_mask, start_token, end_token):
        x = self.sentence_embedding(x, start_token, end_token)
        x = self.layers(x, self_attention_mask)
        return x


class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model , 2 * d_model)
        self.q_layer = nn.Linear(d_model , d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, y, mask):
        batch_size, sequence_length, d_model = x.size() # in practice, this is the same for both languages...so we can technically combine with normal attention
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask) # We don't need the mask for cross attention, removing in outer function!
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, d_model)
        out = self.linear_layer(values)
        return out


class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        _y = y.clone()
        y = self.self_attention(y, mask=self_attention_mask)
        y = self.dropout1(y)
        y = self.layer_norm1(y + _y)

        _y = y.clone()
        y = self.encoder_decoder_attention(x, y, mask=cross_attention_mask)
        y = self.dropout2(y)
        y = self.layer_norm2(y + _y)

        _y = y.clone()
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.layer_norm3(y + _y)
        return y


class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for module in self._modules.values():
            y = module(x, y, self_attention_mask, cross_attention_mask)
        return y

class Decoder(nn.Module):
    def __init__(self, 
                 d_model, 
                 ffn_hidden, 
                 num_heads, 
                 drop_prob, 
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN, 
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

    def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
        y = self.sentence_embedding(y, start_token, end_token)
        y = self.layers(x, y, self_attention_mask, cross_attention_mask)
        return y


class Transformer(nn.Module):
    def __init__(self, 
                d_model, 
                ffn_hidden, 
                num_heads, 
                drop_prob, 
                num_layers,
                max_sequence_length, 
                ta_vocab_size,
                eng_to_ind,
                tam_to_ind,
                START_TOKEN, 
                END_TOKEN, 
                PADDING_TOKEN
                ):
        super().__init__()
        self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, eng_to_ind, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, tam_to_ind, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.linear = nn.Linear(d_model, ta_vocab_size)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self, 
                x, 
                y, 
                encoder_self_attention_mask=None, 
                decoder_self_attention_mask=None, 
                decoder_cross_attention_mask=None,
                enc_start_token=False,
                enc_end_token=False,
                dec_start_token=False, # We should make this true
                dec_end_token=False): # x, y are batch of sentences
        x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
        out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token)
        out = self.linear(out)
        return out
    

import torch
import numpy
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn

engfile='train.en'
tamfile='train.ta'
START_TOKEN='<START>'
PADDING_TOKEN='<PAD>'
END_TOKEN='<END>'
max_seq_len=300

with open(engfile,encoding='utf8') as file:
    engsentences=file.readlines()
with open(tamfile,encoding='utf8') as file:
    tamsentences=file.readlines()

from torch import nn



engsentences=engsentences[:100]
engsentences=[sentence.rstrip('\n') for sentence in engsentences]
tamsentences=tamsentences[:100]
tamsentences=[sentence.rstrip('\n') for sentence in tamsentences]

tamvocab=sorted(list(set((''.join([str(item) for item in tamsentences])))))
tamvocab.insert(0,START_TOKEN)
tamvocab.append(PADDING_TOKEN)
tamvocab.append(END_TOKEN)
tamvocab.append('`')
tamvocab=list(set(tamvocab))
print(tamvocab)

engvocab=sorted(list(set((''.join([str(item) for item in engsentences])))))
engvocab.insert(0,START_TOKEN)
engvocab.append(PADDING_TOKEN)
engvocab.append(END_TOKEN)


ind_to_eng={t:v for t,v in enumerate(engvocab)}
eng_to_ind={v:t for t,v in enumerate(engvocab)}
ind_to_tam={t:v for t,v in enumerate(tamvocab)}
tam_to_ind={v:t for t,v in enumerate(tamvocab)}

criterian = nn.CrossEntropyLoss(ignore_index=tam_to_ind[PADDING_TOKEN],
                                reduction='none')

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_seq_len):
    return len(list(sentence)) < (max_seq_len - 1)

valid_sentence_indicies = []
for index in range(len(tamsentences)):
    tamil_sentence, english_sentence = tamsentences[index], engsentences[index]
    if is_valid_length(tamil_sentence, max_seq_len) \
      and is_valid_length(english_sentence, max_seq_len) \
      and is_valid_tokens(tamil_sentence, tamvocab):
        valid_sentence_indicies.append(index)

tamsentences = [tamsentences[i] for i in valid_sentence_indicies]
engsentences = [engsentences[i] for i in valid_sentence_indicies]


d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 300
ta_vocab_size = len(tamvocab)
print(f" ta_vocab = {ta_vocab_size} tam_to_ind ={len(tam_to_ind)}")

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          ta_vocab_size,
                          eng_to_ind,
                          tam_to_ind,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class TextDataset(Dataset):
    def __init__(self,engsentences,tamsentences):
        self.engsentences=engsentences
        self.tamsentences=tamsentences
    
    def __len__(self):
        return len(self.engsentences)
    
    def __getitem__(self,i):
        return self.engsentences[i],self.tamsentences[i]

dataset=TextDataset(engsentences,tamsentences)

#print(dataset[50])
batch_size=3
train_loader=DataLoader(dataset,batch_size)
iterator=iter(train_loader)

for batch_num,batch in enumerate(iterator):
    print(batch)
    if batch_num>3:
        break

def tokenize(sentence,language_to_index,start_token=True,end_token=True):
    sentence_word_indices=[language_to_index[token] for token in list(sentence)]
    if start_token:
        sentence_word_indices.insert(0,language_to_index[START_TOKEN])
    if end_token:
        sentence_word_indices.append(language_to_index[END_TOKEN])
    for _ in range(len(sentence_word_indices),max_seq_len):
        sentence_word_indices.append(language_to_index[PADDING_TOKEN])
    return torch.tensor(sentence_word_indices)

eng_tokenized,tam_tokenized=[],[]
for sentence_num in range(batch_size):
    eng_sentence,ta_sentence=batch[0][sentence_num],batch[1][sentence_num]
    eng_tokenized.append(tokenize(eng_sentence,eng_to_ind,start_token=False,end_token=False))
    tam_tokenized.append(tokenize(ta_sentence,tam_to_ind,start_token=True,end_token=True))
eng_tokenized=torch.stack(eng_tokenized)
tam_tokenized=torch.stack(tam_tokenized)

neginf=-1e9
def create_masks(eng_batch,tam_batch):
    num_sentences=len(eng_batch)
    look_ahead_mask=torch.full([max_seq_len,max_seq_len],True)
    look_ahead_mask=torch.triu(look_ahead_mask,diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_seq_len, max_seq_len] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_seq_len, max_seq_len] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_seq_len, max_seq_len] , False)

    for i in range(num_sentences):
        eng_sentence_length,tam_sentence_length=len(eng_batch[i]),len(tam_batch[i])
        eng_chars_to_padding_mask=np.arange(eng_sentence_length+1,max_seq_len)
        tam_chars_to_padding_mask=np.arange(tam_sentence_length+1,max_seq_len)
        encoder_padding_mask[i, :, eng_chars_to_padding_mask] = True
        encoder_padding_mask[i, eng_chars_to_padding_mask, :] = True
        decoder_padding_mask_self_attention[i, :, tam_chars_to_padding_mask] = True
        decoder_padding_mask_self_attention[i, tam_chars_to_padding_mask, :] = True
        decoder_padding_mask_cross_attention[i, :, eng_chars_to_padding_mask] = True
        decoder_padding_mask_cross_attention[i, tam_chars_to_padding_mask, :] = True

    encoder_padding_mask[i, :, eng_chars_to_padding_mask] = True
    encoder_padding_mask[i, eng_chars_to_padding_mask, :] = True
    decoder_padding_mask_self_attention[i, :, tam_chars_to_padding_mask] = True
    decoder_padding_mask_self_attention[i, tam_chars_to_padding_mask, :] = True
    decoder_padding_mask_cross_attention[i, :, eng_chars_to_padding_mask] = True
    decoder_padding_mask_cross_attention[i, tam_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, neginf, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, neginf, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, neginf, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

class sentence_embedding(nn.Module):
    def __init__(self,max_seq_len,d_model,language_to_index,START_TOKEN,END_TOKEN,PADDING_TOKEN):
        super().__init__()
        self.vocab_size=len(language_to_index)
        self.max_seq_len=max_seq_len
        self.embedding=nn.Embedding(self.vocab_size,d_model)
        self.language_to_index=language_to_index
        self.position_encoder=PositionalEncoding(d_model,max_sequence_length=max_seq_len)
        self.dropout=nn.Dropout(p=0.1)
        self.START_TOKEN=START_TOKEN
        self.END_TOKEN=END_TOKEN
        self.PADDING_TOKEN=PADDING_TOKEN
    
    def batch_tokenize(self,batch,start_token=True,end_token=True):
        def tokenize(sentence,start_token=True,end_token=True):
            sentence_word_indices=[self.language_to_index[token] for token in list(sentence)]
            if start_token:
                sentence_word_indices.insert(0,self.language_to_index[START_TOKEN])
            if end_token:
                sentence_word_indices.append(self.language_to_index[END_TOKEN])
            for _ in range(len(sentence_word_indices),max_seq_len):
                sentence_word_indices.append(self.language_to_index[PADDING_TOKEN])
            return torch.tensor(sentence_word_indices)
        
        tokenized = []
        for sentence_num in range(len(batch)):
           tokenized.append( tokenize(batch[sentence_num], start_token, end_token) )
        tokenized = torch.stack(tokenized)
        return tokenized.to(get_device())
    
    def forward(self, x, end_token=True): # sentence
        x = self.batch_tokenize(x ,end_token)
        x = self.embedding(x)
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)
        return x

device=get_device()



transformer.to(device)
total_loss = 0
num_epochs = 400


for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, ta_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, ta_batch)
        optim.zero_grad()
        ta_predictions = transformer(eng_batch,
                                     ta_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(ta_batch, start_token=False, end_token=True)
        loss = criterian(
            ta_predictions.view(-1, ta_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == tam_to_ind[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Tamil Translation: {ta_batch[0]}")
            ta_sentence_predicted = torch.argmax(ta_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in ta_sentence_predicted:
              if idx == tam_to_ind[END_TOKEN]:
                break
              predicted_sentence += ind_to_tam[idx.item()]
            print(f"Tamil Prediction: {predicted_sentence}")


            transformer.eval()
            ta_sentence = ("",)
            eng_sentence = ("What is your name?",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, ta_sentence)
                predictions = transformer(eng_sentence,
                                          ta_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = ind_to_tam[next_token_index]
                ta_sentence = (ta_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (What is your name?) : {ta_sentence}")
            print("-------------------------------------------")
        
    FILEA="model1save.pth"
    torch.save(transformer,FILEA)

FILEA="finalmodel.pth"
torch.save(transformer,FILEA)


def translate(eng_sentence):
    transformer.eval()
    ta_sentence = ("",)
    eng_sentence=(eng_sentence,)
    for word_counter in range(max_sequence_length):
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, ta_sentence)
        predictions = transformer(eng_sentence,
                                    ta_sentence,
                                    encoder_self_attention_mask.to(device), 
                                    decoder_self_attention_mask.to(device), 
                                    decoder_cross_attention_mask.to(device),
                                    enc_start_token=False,
                                    enc_end_token=False,
                                    dec_start_token=True,
                                    dec_end_token=False)
        next_token_prob_distribution = predictions[0][word_counter] # not actual probs
        next_token_index = torch.argmax(next_token_prob_distribution).item()
        next_token = ind_to_tam[next_token_index]
        ta_sentence = (ta_sentence[0] + next_token, )
        if next_token == END_TOKEN:
            break

    print(f"Evaluation translation (What is your name?) : {ta_sentence}")
    print("-------------------------------------------")

translate('what is your name')

['ஹ', 'ொ', 'C', '9', 'L', 'ஃ', 'அ', 'ி', 'i', 'ஏ', '7', 'ஓ', '—', '?', 'ன', 'ஞ', 'ா', ':', 'o', 'த', '<END>', '0', "'", 'ெ', 'ர', 'ஸ', '\u2060', 'க', 'ீ', 'ஷ', 'ல', 'ஆ', '-', 'F', 'ே', 'A', 'ட', 'T', 'உ', '்', '6', 'ஊ', '‘', 'ு', '5', '<START>', 'e', 'ந', 'ஜ', '3', '2', 'ண', 'ப', 'a', 'E', 'n', 'l', '4', 'ள', '“', '\u200b', 'm', ',', 'ம', 'ழ', 'ோ', 'N', 'k', '1', 'ங', 't', ')', ' ', '"', 'ை', 'ூ', 'ஒ', 'R', 'g', '”', 'ஈ', '.', 'இ', 'ஐ', 'ய', '`', 'எ', '’', '<PAD>', 's', 'ச', 'u', '(', 'வ', 'ற']
 ta_vocab = 95 tam_to_ind =95
[('Some 14 months later, the second calf is born.', '"Senior advocate Kapil Sibal, who was appearing for Chidambaram, said the condition was not justified for a member of Parliament and ""he would not run away anywhere."""', 'This photo was taken then.'), ('சுமார் 14 மாதங்கள் கழித்து, இரண்டாம் கன்றை ஈனுகிறது.', '‘காா்த்தி சிதம்பரம் எம். பி. யாக உள்ளதால் எங்கும் தப்பிவிட மாட்டாா்’ என்று அவரது சாா்பில் ஆஜரான மூத்த வழக்குரைஞா் கபில் சிபல் வாதாடினாா்.', 'அதன்போது எடுக்க

In [7]:
def translate(eng_sentence):
    transformer.eval()
    ta_sentence = ("",)
    eng_sentence=(eng_sentence,)
    for word_counter in range(max_sequence_length):
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, ta_sentence)
        predictions = transformer(eng_sentence,
                                    ta_sentence,
                                    encoder_self_attention_mask.to(device), 
                                    decoder_self_attention_mask.to(device), 
                                    decoder_cross_attention_mask.to(device),
                                    enc_start_token=False,
                                    enc_end_token=False,
                                    dec_start_token=True,
                                    dec_end_token=False)
        next_token_prob_distribution = predictions[0][word_counter] # not actual probs
        next_token_index = torch.argmax(next_token_prob_distribution).item()
        next_token = ind_to_tam[next_token_index]
        ta_sentence = (ta_sentence[0] + next_token, )
        if next_token == END_TOKEN:
            break

    print(f"Evaluation translation ({eng_sentence}) : {ta_sentence}")
    print("-------------------------------------------")

translate('where are you?')
translate('Thanks')

Evaluation translation (('where are you?',)) : ('ஸ்மோக்பிக் கொண்கான்ட கோாா - 100 கிராம்<END>',)
-------------------------------------------
Evaluation translation (('Thanks',)) : ('திலே உயிக்கொய்டேலே<END>',)
-------------------------------------------


In [8]:
translate('Some 14 months later, the second calf is born.')

Evaluation translation (('Some 14 months later, the second calf is born.',)) : ('சுமார் 14 மாதங்கள் கழித்து, இரண்டாம் கன்றை ஈனுகிறது.<END>',)
-------------------------------------------


In [10]:
translate('What is joy about?')

Evaluation translation (('What is joy about?',)) : ('அப்படி என்ன ஆனந்தம்?<END>',)
-------------------------------------------


In [13]:
translate('I don\'t know')

Evaluation translation (("I don't know",)) : ('எனக்கு தெரியவில்லை<END>',)
-------------------------------------------


In [16]:
translate('Thank you')

Evaluation translation (('Thank you',)) : ('நன்றி நன்றி நன்றி<END>',)
-------------------------------------------


translate(')

In [5]:
translate('Some 14 months later, the second calf is born.')

Evaluation translation (('Some 14 months later, the second calf is born.',)) : ('இதன் பித்த் தங் க் க கு கள்த் ப் பட் க்த்த் ஈடியியொட்ட்ட்.<END>',)
-------------------------------------------


In [14]:
def save_checkpoint(state,filename="savetest.pth.tar"):
    torch.save(state,filename)

checkpoint={'state_dict' : transformer.state_dict(), 'optimizer' : optim.state_dict()}
save_checkpoint(checkpoint)



In [30]:
translate('what is joy about?')

Evaluation translation (('what is joy about?',)) : ('என்ன பயன் என்ன?<END>',)
-------------------------------------------


In [31]:
translate('Happy birthday')

Evaluation translation (('Happy birthday',)) : ('குழந்தை மகிழ்ச்சி<END>',)
-------------------------------------------


In [33]:
translate('What is your name?')

Evaluation translation (('What is your name?',)) : ('உங்கள் என்ன பெயர்?<END>',)
-------------------------------------------


In [35]:
translate('Why are you sad')

Evaluation translation (('Why are you sad',)) : ('ஏன் என்னை ஏன் இப்படி இருக்கிறீர்கள்?<END>',)
-------------------------------------------


In [58]:
translate('Leave me alone')

Evaluation translation (('Leave me alone',)) : ('என்னை விட்டு விட்டு விடுவோம்<END>',)
-------------------------------------------


In [70]:
translate('That was good!')

Evaluation translation (('That was good!',)) : ('அது நல்லாம்!<END>',)
-------------------------------------------


In [53]:
translate('How are you?')

Evaluation translation (('How are you?',)) : ('எப்படி எப்படி இருக்கிறீர்கள்?<END>',)
-------------------------------------------
