In [1]:
import json
import math
import torch
import numpy as np
import torch.nn as nn
from transformer import Transformer

In [2]:
with open('English_vocab.json', 'r') as f:
    eng_vocab = json.load(f)

with open('Hindi_vocab.json', 'r') as f:
    hn_vocab = json.load(f)

In [3]:
len(eng_vocab), len(hn_vocab)

(68193, 83630)

In [4]:
START_TOKEN = '<SOS>'
END_TOKEN = '<EOS>'
PADDING_TOKEN = '<POS>'

In [5]:
# eng_vocab[START_TOKEN] = len(eng_vocab)
# eng_vocab[END_TOKEN] = len(eng_vocab)+1
# eng_vocab[PADDING_TOKEN] = len(eng_vocab)+1


# hn_vocab[START_TOKEN] = len(hn_vocab)
# hn_vocab[END_TOKEN] = len(hn_vocab)+1
# hn_vocab[PADDING_TOKEN] = len(hn_vocab)+1

In [6]:
index_to_eng = {}
index_to_hn = {}

for k,v in eng_vocab.items():
    index_to_eng[v] = k

for k,v in hn_vocab.items():
    index_to_hn[v] = k

In [7]:
eng_vocab[PADDING_TOKEN], hn_vocab[PADDING_TOKEN]

(68192, 83629)

In [8]:
with open('english_sentences.txt', 'r') as f:
    english_sentences = f.readlines()

with open('hindi_sentences.txt', 'r') as f:
    hindi_sentences = f.readlines()

In [9]:
english_sentences = [sen.rstrip('\n') for sen in english_sentences]
hindi_sentences = [sen.rstrip('\n') for sen in hindi_sentences]


In [10]:
from torch.utils.data import DataLoader,Dataset

class TextData(Dataset):
    def __init__(self,english_sentences, hindi_sentences) -> None:
        super().__init__()
        self.english_sentences = english_sentences
        self.hindi_sentences = hindi_sentences
    
    def __len__(self):
        return len(self.english_sentences)
    
    def __getitem__(self, index) :
        return self.english_sentences[index],self.hindi_sentences[index]

In [11]:
dataset = TextData(english_sentences, hindi_sentences)

In [12]:
batch_size = 30

In [13]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [14]:
for batch_num, batch in enumerate(iterator):
    print(batch)
    break

[('however paes who was partnering australias paul hanley could only go as far as the quarterfinals where they lost to bhupathi and knowles', 'whosoever desires the reward of the world with allah is the reward of the world and of the everlasting life allah is the hearer the seer', 'the value of insects in the biosphere is enormous because they outnumber all other living groups in measure of species richness', 'mithali to anchor indian team against australia in odis', 'after the assent of the honble president on 8thseptember 2016 the 101thconstitutional amendment act 2016 came into existence', 'the court has fixed a hearing for february 12', 'please select the position where the track should be split', 'as per police armys 22rr special operation group sog of police and the central reserve police force crpf cordoned the village and launched search operation in the area', 'jharkhand chief minister hemant soren', 'arvind kumar sho of the sector 5556 police station said a case has been regi

In [15]:
d_model = 512
max_seq_len = 100 
ffn = 2048 
drop_prob = 0.1
num_head = 8
head_dim = 64
n_layers = 6

transformer = Transformer(d_model, 
                          max_seq_len, 
                          ffn, 
                          drop_prob, 
                          eng_vocab, 
                          hn_vocab,
                          num_head, 
                          head_dim, 
                          n_layers, 
                          True, 
                          True,
                          True)
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(68193, 512)
      (postion): PostionEncoding()
    )
    (enocder_layers): SequentialEncoder(
      (0): EncoderLayers(
        (selfAttention): MultiheadAttention(
          (q_layer): Linear(in_features=512, out_features=512, bias=True)
          (k_layer): Linear(in_features=512, out_features=512, bias=True)
          (v_layer): Linear(in_features=512, out_features=512, bias=True)
          (linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm): LayerNormalize()
        (feedForward): PositionWiseFeedForward(
          (dropout): Dropout(p=0.1, inplace=False)
          (layer1): Linear(in_features=512, out_features=2048, bias=True)
          (layer2): Linear(in_features=2048, out_features=1024, bias=True)
          (layer3): Linear(in_features=1024, out_features=512, bias=True)
          (layer4): Linear(in_features=512, out_features=512, bias

In [16]:
criterian = nn.CrossEntropyLoss(ignore_index=hn_vocab[PADDING_TOKEN], reduction='mean')

In [17]:
for params in transformer.parameters():
    if(params.dim() > 1):
        nn.init.xavier_uniform_(params)

In [18]:
optim = torch.optim.Adam(transformer.parameters(), lr=1e-5)

In [19]:
NEG_INFTY = float('-inf')
NEG_INFTY = -1e9

def create_masks(eng_batch, hn_batch):
    num_sentences = len(eng_batch)
    max_sequence_length = 100
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, hn_sentence_length = len(eng_batch[idx].split()), len(hn_batch[idx].split())
      eng_chars_to_padding_mask = np.arange(eng_sentence_length , max_sequence_length)
      hn_chars_to_padding_mask = np.arange(hn_sentence_length , max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, hn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, hn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, hn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [20]:
transformer.train()
# transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, hn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, hn_batch)
        optim.zero_grad()
        hn_predictions = transformer(eng_batch,
                                     hn_batch,
                                     encoder_self_attention_mask,
                                     decoder_cross_attention_mask,
                                     decoder_self_attention_mask, )
        
        labels = transformer.decoder.sentence_embedding.batchTokenize(hn_batch, start_token=False, end_token=True)
        # print("LABELS :- ", labels)
        loss = criterian(
            hn_predictions.view(-1, len(hn_vocab)),
            labels.view(-1)
        )
        valid_indicies = torch.where(labels.view(-1) == hn_vocab[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Hindi Translation: {hn_batch[0]}")
            hn_sentence_predicted = torch.argmax(hn_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in hn_sentence_predicted:
              if idx == hn_vocab[END_TOKEN]:
                break
              predicted_sentence += index_to_hn[idx.item()]+" "
            print(f"Hindi Prediction: {predicted_sentence}")


            transformer.eval()
            hn_sentence = ("",)
            eng_sentence = ("please select the position where the track should be split",)
            for word_counter in range(90):
                # print(word_counter)
                # if(word_counter == 99):
                #    print(eng_sentence, hn_sentence)
                #    print(len(hn_sentence[0].split()))
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, hn_sentence)
                predictions = transformer(eng_sentence,
                                          hn_sentence,
                                          encoder_self_attention_mask, 
                                          decoder_cross_attention_mask,
                                          decoder_self_attention_mask, )
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_hn[next_token_index]
                hn_sentence = (hn_sentence[0] + next_token+ ' ', )
                # hn_sentence += ' '
                if next_token == END_TOKEN:
                  print("Break Applied")
                  break
            
            print(f"Evaluation translation (please select the position where the track should be split) : {hn_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 0.019243065267801285
English: however paes who was partnering australias paul hanley could only go as far as the quarterfinals where they lost to bhupathi and knowles
Hindi Translation: आस्ट्रेलिया के पाल हेनली के साथ जोड़ी बनाने वाले पेस मियामी में क्वार्टरफाइनल तक ही पहुंच सके क्योंकि इस दौर में उन्हें भूपति और नोल्स ने हराया था।
Hindi Prediction: मॉगों एपीआई मॉगों एपीआई मॉगों मॉगों एपीआई एपीआई मॉगों एपीआई एपीआई एपीआई एपीआई मॉगों मॉगों एपीआई एपीआई मॉगों एपीआई एपीआई एपीआई एपीआई मॉगों एपीआई एपीआई एपीआई एपीआई एपीआई एपीआई एपीआई एपीआई मॉगों ब्रित मॉगों एपीआई एपीआई मॉगों एपीआई मॉगों एपीआई एपीआई मॉगों एपीआई ब्रित मॉगों एपीआई एपीआई एपीआई एपीआई एपीआई मॉगों एपीआई मॉगों एपीआई मॉगों एपीआई ब्रित एबर्ट स्पोक एपीआई एपीआई एपीआई मॉगों मॉगों एपीआई एपीआई आंवावस्था एपीआई एपीआई एपीआई एपीआई एपीआई एपीआई एपीआई एपीआई मॉगों एपीआई ब्रित एपीआई एपीआई ब्रित एपीआई एपीआई मॉगों मॉगों ब्रित एपीआई एपीआई मॉगों ब्रित एपीआई एपीआई ब्रित मॉगों मॉगों एबर्ट एपीआई एपीआई एपीआई एपीआई 
Evaluation translation

68194