In [1]:
from transformer import Transformer
import torch
import torch.nn as nn
import numpy as np
import math

In [2]:
english_file = 'data/english.txt'
kannada_file = 'data/kannada.txt'

START_TOKEN = '<start>'
PADDING_TOKEN = '<padding>'
END_TOKEN = '<end>'

kannada_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ', 
                      'ँ', 'ఆ', 'ఇ', 'ా', 'ి', 'ీ', 'ు', 'ూ', 
                      'ಅ', 'ಆ', 'ಇ', 'ಈ', 'ಉ', 'ಊ', 'ಋ', 'ೠ', 'ಌ', 'ಎ', 'ಏ', 'ಐ', 'ಒ', 'ಓ', 'ಔ', 
                      'ಕ', 'ಖ', 'ಗ', 'ಘ', 'ಙ', 
                      'ಚ', 'ಛ', 'ಜ', 'ಝ', 'ಞ', 
                      'ಟ', 'ಠ', 'ಡ', 'ಢ', 'ಣ', 
                      'ತ', 'ಥ', 'ದ', 'ಧ', 'ನ', 
                      'ಪ', 'ಫ', 'ಬ', 'ಭ', 'ಮ', 
                      'ಯ', 'ರ', 'ಱ', 'ಲ', 'ಳ', 'ವ', 'ಶ', 'ಷ', 'ಸ', 'ಹ', 
                      '಼', 'ಽ', 'ಾ', 'ಿ', 'ೀ', 'ು', 'ೂ', 'ೃ', 'ೄ', 'ೆ', 'ೇ', 'ೈ', 'ೊ', 'ೋ', 'ೌ', '್', 'ೕ', 'ೖ', 'ೞ', 'ೣ', 'ಂ', 'ಃ', 
                      '೦', '೧', '೨', '೩', '೪', '೫', '೬', '೭', '೮', '೯', PADDING_TOKEN, END_TOKEN]

english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                        ':', '<', '=', '>', '?', '@', 
                        'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 
                        'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 
                        'Y', 'Z',
                        '[', '\\', ']', '^', '_', '`', 
                        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                        'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                        'y', 'z', 
                        '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN]

In [3]:
index_to_kannada = {k:v for k,v in enumerate(kannada_vocabulary)}
kannada_to_index = {v:k for k,v in enumerate(kannada_vocabulary)}
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}
english_to_index 

{'<start>': 0,
 ' ': 1,
 '!': 2,
 '"': 3,
 '#': 4,
 '$': 5,
 '%': 6,
 '&': 7,
 "'": 8,
 '(': 9,
 ')': 10,
 '*': 11,
 '+': 12,
 ',': 13,
 '-': 14,
 '.': 15,
 '/': 16,
 '0': 17,
 '1': 18,
 '2': 19,
 '3': 20,
 '4': 21,
 '5': 22,
 '6': 23,
 '7': 24,
 '8': 25,
 '9': 26,
 ':': 27,
 '<': 28,
 '=': 29,
 '>': 30,
 '?': 31,
 '@': 32,
 'A': 33,
 'B': 34,
 'C': 35,
 'D': 36,
 'E': 37,
 'F': 38,
 'G': 39,
 'H': 40,
 'I': 41,
 'J': 42,
 'K': 43,
 'L': 44,
 'M': 45,
 'N': 46,
 'O': 47,
 'P': 48,
 'Q': 49,
 'R': 50,
 'S': 51,
 'T': 52,
 'U': 53,
 'V': 54,
 'W': 55,
 'X': 56,
 'Y': 57,
 'Z': 58,
 '[': 59,
 '\\': 60,
 ']': 61,
 '^': 62,
 '_': 63,
 '`': 64,
 'a': 65,
 'b': 66,
 'c': 67,
 'd': 68,
 'e': 69,
 'f': 70,
 'g': 71,
 'h': 72,
 'i': 73,
 'j': 74,
 'k': 75,
 'l': 76,
 'm': 77,
 'n': 78,
 'o': 79,
 'p': 80,
 'q': 81,
 'r': 82,
 's': 83,
 't': 84,
 'u': 85,
 'v': 86,
 'w': 87,
 'x': 88,
 'y': 89,
 'z': 90,
 '{': 91,
 '|': 92,
 '}': 93,
 '~': 94,
 '<padding>': 95,
 '<end>': 96}

In [4]:
with open(english_file, 'r') as file:
    english_sentences = file.readlines()
with open(kannada_file, 'r') as file:
    kannada_sentences = file.readlines()

In [5]:
# Limit Number of sentences
TOTAL_SENTENCES = 100000
english_sentences = english_sentences[:TOTAL_SENTENCES]
kannada_sentences = kannada_sentences[:TOTAL_SENTENCES]

In [6]:
english_sentences = [sentence.rstrip('\n') for sentence in english_sentences]
kannada_sentences = [sentence.rstrip('\n') for sentence in kannada_sentences]

In [7]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(kannada_sentences)):
    kannada_sentence, english_sentence = kannada_sentences[index], english_sentences[index]
    if is_valid_length(kannada_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_valid_tokens(kannada_sentence, kannada_vocabulary):
        valid_sentence_indicies.append(index)

# print(f"Number of sentences: {len(kannada_sentences)}")
# print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

In [8]:
kannada_sentences = [kannada_sentences[i] for i in valid_sentence_indicies]
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]

In [9]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, english_sentences, kannada_sentences):
        self.english_sentences = english_sentences
        self.kannada_sentences = kannada_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.kannada_sentences[idx]

In [10]:
dataset = TextDataset(english_sentences, kannada_sentences)

In [11]:
NEG_INFTY = -1e9

def create_masks(eng_batch, kn_batch, max_sequence_length):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    # print(f"look_ahead_mask {look_ahead_mask.size()}: {look_ahead_mask}")
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    # print(f"look_ahead_mask {look_ahead_mask.size()}: {look_ahead_mask}")
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    # print(f"encoder_padding_mask {encoder_padding_mask.size()}: {encoder_padding_mask[0]}")
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, kn_sentence_length = len(eng_batch[idx]), len(kn_batch[idx])
      # print(f"eng_sentence_length {eng_sentence_length}, kn_sentence_length {kn_sentence_length}")
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    # print(f"encoder_self_attention_mask {encoder_self_attention_mask.size()}: {encoder_self_attention_mask[0, :10, :10]}")
    # print(f"decoder_self_attention_mask {decoder_self_attention_mask.size()}: {decoder_self_attention_mask[0, :10, :10]}")
    # print(f"decoder_cross_attention_mask {decoder_cross_attention_mask.size()}: {decoder_cross_attention_mask[0, :10, :10]}")
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [12]:
d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
kn_vocab_size = len(kannada_vocabulary)

transformer = Transformer(d_model, 
                          ffn_hidden, 
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length, 
                          kn_vocab_size,
                          english_to_index,
                          kannada_to_index,
                          START_TOKEN,
                          END_TOKEN,
                          PADDING_TOKEN)

# kannada_to_index

In [13]:
import torch
criterion = nn.CrossEntropyLoss(ignore_index=kannada_to_index[PADDING_TOKEN], reduction='none')

for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
transformer
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [15]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, kn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, kn_batch, max_sequence_length)
        optimizer.zero_grad()
        kn_predictions = transformer(eng_batch,
                                     kn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(kn_batch, start_token=False, end_token=True)
        loss = criterion(
            kn_predictions.view(-1, kn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == kannada_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optimizer.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Kannada Translation: {kn_batch[0]}")
            kn_sentence_predicted = torch.argmax(kn_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in kn_sentence_predicted:
              if idx == kannada_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_kannada[idx.item()]
            print(f"Kannada Prediction: {predicted_sentence}")


            transformer.eval()
            kn_sentence = ("",)
            eng_sentence = ("should we go to the mall?",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence, max_sequence_length)
                predictions = transformer(eng_sentence,
                                          kn_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_kannada[next_token_index]
                kn_sentence = (kn_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (should we go to the mall?) : {kn_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 5.497995853424072
English: Hes a scientist.
Kannada Translation: ಇವರು ಸಂಶೋಧಕ ಸ್ವಭಾವದವರು.
Kannada Prediction: 6666ిిిిిిಟిిిి6ిೊిిిిిిೊಮಡಡಡಮಮೊೖಭ6ೂిೂೖಟ171ಭಗಭಡ,ಮಮೊಮಮಮಮೊೃೊೊಟೖೖೖೖೊೊೊೃೃೃೃಮೊೊಌೊಌి111ೊೖೖೖిಢిిೊಟಡೊೖಒಙಙ0ೊೖ7ನನಡಡನನನ7ಳ7ಟಟನೖ1ೖೖಳಳಗ1ಟಡೖೖೖಟಚೊ=ೖಒೖೖಳೊೊೊ7ిిి/ಒిిిిిిిిಡ==ಒ7ಒిಳిೃನನ=ಳ=ిೊಒిిನిిೃ777ి77ೊೊೊಡೖನನನನిನನನಗ೯೯ి
Evaluation translation (should we go to the mall?) : ('6 6 6     ೆೆೆೆೆೆೆೆ     ೆೆೊುು                           ು              ್್          ್್್್್್         00ನನನನನನನನನನಳಳಳಳಳಳಳಳಳಳಳಳ   ಳಳಳಳಳಳಳಳಳಳಳಳಳ  ್್್್್್್್್000್ವವ77 ಳ  ನನನನನನನ್ ದನನನನನ      ೆೆೆೆನನನನನನನನನನೆೆೆೆ',)
-------------------------------------------
Iteration 100 : 3.508317470550537
English: She ate it.
Kannada Translation: ಅವಳು ಅವನಿಗೆ ಊಟ ಹಾಕಿದಳೂ.
Kannada Prediction: ಅನ್           ್ು 
Evaluation translation (should we go to the mall?) : ('ದದ್                       ು<end>',)
-------------------------------------------
Iteration 200 : 3.253875970840454
English: Caste and religion were unknown.
Kannada Translat

In [17]:
state = {
    'epoch': epoch,
    'transformer_state_dict': transformer.state_dict(),
    'transformer_optimizer_state_dict': optimizer.state_dict()
}
torch.save(state, 'checkpoint_final_' + str(epoch) + '.pth')

In [18]:
checkpoint = torch.load('checkpoint_final_9.pth')

  checkpoint = torch.load('checkpoint_final_9.pth')


In [19]:
d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
kn_vocab_size = len(kannada_vocabulary)

transformer = Transformer(d_model, 
                          ffn_hidden, 
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length, 
                          kn_vocab_size,
                          english_to_index,
                          kannada_to_index,
                          START_TOKEN,
                          END_TOKEN,
                          PADDING_TOKEN)

In [21]:
transformer.load_state_dict(checkpoint['transformer_state_dict'])
optimizer.load_state_dict(checkpoint['transformer_optimizer_state_dict'])

# Restore the last epoch
start_epoch = checkpoint['epoch']

In [32]:

transformer.eval()
def translate(eng_sentence):
  eng_sentence = (eng_sentence,)
  kn_sentence = ("",)
  for word_counter in range(max_sequence_length):
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence, max_sequence_length)
    predictions = transformer(eng_sentence,
                              kn_sentence,
                              encoder_self_attention_mask.to(device), 
                              decoder_self_attention_mask.to(device), 
                              decoder_cross_attention_mask.to(device),
                              enc_start_token=False,
                              enc_end_token=False,
                              dec_start_token=True,
                              dec_end_token=False)
    next_token_prob_distribution = predictions[0][word_counter]
    next_token_index = torch.argmax(next_token_prob_distribution).item()
    next_token = index_to_kannada[next_token_index]
    kn_sentence = (kn_sentence[0] + next_token, )
    if next_token == END_TOKEN:
      break
  return kn_sentence[0]

In [33]:
translation = translate("what should we do when the day starts?")
print(translation)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)