## Acknowledgement
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html <br />
https://pytorch.org/tutorials/beginner/transformer_tutorial.html <br />
https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html <br />
https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html <br />
https://github.com/lkulowski/LSTM_encoder_decoder <br />
https://docs.python.org/3/library/re.html <br />
https://stackexchange.com/ <br />
https://stackoverflow.com/ <br />
https://discuss.pytorch.org/ <br />

## README
- Download the model from https://drive.google.com/drive/folders/10DTWr96vJ0yXWy3kDos5ql4owqgcI_2D?usp=sharing
- Update Path Variables appropriately

# Import Libraries

In [None]:
# Model Libraries
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Scoring Libraries
import nltk
!pip install sacrebleu
from sacrebleu import sentence_chrf
!pip install rouge
from rouge import Rouge
nltk.download('punkt')

# Supporting
import io
from io import open
import numpy as np
import pandas as pd
import random
import math
from timeit import default_timer as timer
from tqdm.auto import tqdm as time_bar
from collections import Counter
import json
import dill
import re
import csv
from copy import deepcopy as cpy
import time

# Path Variables

In [None]:
# Path variables for Kaggle
# Training corpus (JSON) consumed by readLangs() through Datareader
train_data_path = r"/kaggle/input/phase-2/train_data2.json"
# NOTE(review): not referenced by any code visible in this notebook - outputs are written to the working directory
prediction_folder_path = r"."
# JSON file read by validate()
val_data_path = r"/kaggle/input/phase-2/val_data2.json"
# dill-serialised bundle read in the "Loading Already-Trained Model" cell
model_path = r"/kaggle/input/phase-2-final-model/model_final"

# Modules

## Constants

In [None]:
# Languages are Bengali, Gujarati, Hindi, Kannada, Malayalam, Tamil and Telugu

# NOTE(review): "Telgu" is kept as spelled - presumably it matches the keys used in the data files; confirm before renaming
LANGUAGE_PAIRS = ["English-Bengali", "English-Gujarati", "English-Hindi", "English-Kannada", "English-Malayalam", "English-Tamil", "English-Telgu"]
IN_LANGUAGE_CODES = ["bn", "gu", "hi", "kn", "ml", "ta", "te"]

# Train:Test Split Ratio - fraction of the shuffled data kept for training, the rest is held out for validation
TEST_TRAIN_SPLIT = 0.99

# Encoding type of the json
ENC_TYPE = "UTF-8"

# Reference
# Wikipedia Unicode Blocks

# Punctuation in various languages (which usually is not part of simple words)
# Every entry must be a single character: the tokeniser pads each one with spaces.
# (The closing bracket used to be listed as the two-character string '],' so a lone ']' was never split off.)
EN_PUNCT = ['~', '`', '!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '-', '_', '+', '=', '{', '}', '[', ']', '|', '\\', ':', '\"', ';', '\'', '<', '>', '?', ',', '.', '/']
EN_NUM = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
UNI_SYM = ['\u2013', '\u2014', '\u2015', '\u2017', '\u2018', '\u2019', '\u201A', '\u201B', '\u201C', '\u201D', '\u201E', '\u2020', '\u2021', '\u2022', '\u2026', '\u2030', '\u2032', '\u2033', '\u2039', '\u203A', '\u203C', '\u203E', '\u2044', '\u204A']
BE_PUNCT = ['\u09F7']
BE_NUM = ['\u09E6', '\u09E7', '\u09E8', '\u09E9', '\u09EA', '\u09EB', '\u09EC', '\u09ED', '\u09EE', '\u09EF']
GU_PUNCT = []
GU_NUM = ['\u0AE6', '\u0AE7', '\u0AE8', '\u0AE9', '\u0AEA', '\u0AEB', '\u0AEC', '\u0AED', '\u0AEE', '\u0AEF']
HI_PUNCT = ['\u0964','\u0965']
HI_NUM = ['\u0966', '\u0967', '\u0968', '\u0969', '\u096A', '\u096B', '\u096C', '\u096D', '\u096E', '\u096F']
KA_PUNCT = []
KA_NUM = ['\u0CE6', '\u0CE7', '\u0CE8', '\u0CE9', '\u0CEA', '\u0CEB', '\u0CEC', '\u0CED', '\u0CEE', '\u0CEF']
MA_PUNCT = []
MA_NUM = ['\u0D66', '\u0D67', '\u0D68', '\u0D69', '\u0D6A', '\u0D6B', '\u0D6C', '\u0D6D', '\u0D6E', '\u0D6F']
TA_PUNCT = []
# NOTE(review): the original list repeated '\u0BF1'; the duplicate is dropped here.
# Possibly '\u0BF2' was intended as the last entry - confirm against the Tamil Unicode block.
TA_NUM = ['\u0BE6', '\u0BE7', '\u0BE8', '\u0BE9', '\u0BEA', '\u0BEB', '\u0BEC', '\u0BED', '\u0BEE', '\u0BEF', '\u0BF0', '\u0BF1']
TE_PUNCT = []
TE_NUM = ['\u0C66', '\u0C67', '\u0C68', '\u0C69', '\u0C6A', '\u0C6B', '\u0C6C', '\u0C6D', '\u0C6E', '\u0C6F']

# Combination of all above
TOK_PUNCT = EN_PUNCT + EN_NUM + UNI_SYM + BE_PUNCT + BE_NUM + GU_PUNCT + GU_NUM + HI_PUNCT + HI_NUM + KA_PUNCT + KA_NUM + MA_PUNCT + MA_NUM + TA_PUNCT + TA_NUM + TE_PUNCT + TE_NUM

# Range on Unicode for each Language - indices 0 to 6 denote Indian languages while index 7 denotes English
# The three lists are parallel: UNI_BEG[i]..UNI_END[i] is the inclusive code-point range of LANG_IX[i]
UNI_BEG = ['\u0980', '\u0A80', '\u0900', '\u0C80', '\u0D00', '\u0B80', '\u0C00', '\u0000']
UNI_END = ['\u09FF', '\u0AFF', '\u097F', '\u0CFF', '\u0D7F', '\u0BFF', '\u0C7F', '\u007F']
LANG_IX = ["Bengali", "Gujarati", "Hindi", "Kannada", "Malayalam", "Tamil", "Telgu", "English"]

# Language Names
SRC_LANGUAGE = "Indian"
TGT_LANGUAGE = "English"

# Vocabulary Size - upper bound on the number of most frequent words kept per side
IP_VOC_SIZE = 25000 * 7
OP_VOC_SIZE = 10000000 # Use Full Vocab

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['unk', 'pad', 'sos', 'eos']

# Length Filtering
MIN_LENGTH = 1
# The two maxima are filled in at run time by set_lengths() from the MAX_PERCENTILE-th length percentile
MAX_LENGTH_TGT = None
MAX_LENGTH_SRC = None
MAX_PERCENTILE = 99

# Post-Processing Constant - minimum percentage of cased occurrences for a word to be re-capitalised
THRESH = 50

## Tokeniser
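- Splits text on whitespace and isolates every number and punctuation character as its own token
- Handles the digits and punctuation of all supported languages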

In [None]:
# References
# https://stackoverflow.com/questions/30933216/split-by-regex-without-resulting-empty-strings-in-python#:~:text=%3E%3E%3E%20re.findall(r%27%5CS%2B%27%2C%20%27%20a%20b%20%20%20c%20%20de%20%20%27)

class Tokeniser:
    '''
    Whitespace tokeniser that also isolates every character listed in TOK_PUNCT
    (punctuation and digits of all supported languages) as its own token.
    '''

    def __init__(self):
        pass

    def tok_sen(self, sen):
        '''To tokenise a single input (precisely not a sentence) after lower-casing it'''

        # Lower-casing is the only difference from the unlowered variant
        return self.tok_sen_u(sen.lower())

    def tok_sen_u(self, sen):
        '''Unlowered - To tokenise a single input (precisely not a sentence)'''

        # Add space before and after punctuations and numbers
        for ch in TOK_PUNCT:
            if ch in sen:
                sen = sen.replace(ch, ' ' + ch + ' ')

        # Split text at whitespace; findall never yields empty strings
        # (raw string: '\S' in a normal literal is an invalid escape sequence)
        return re.findall(r'\S+', sen)

    def tok_data(self, data):
        '''To tokenise data - list of sentences (lower-cased)'''

        return [self.tok_sen(sen) for sen in data]

    def tok_data_u(self, data):
        '''Unlowered - To tokenise data - list of sentences'''

        return [self.tok_sen_u(sen) for sen in data]

In [None]:
# Creating a Global tokeniser object
# Shared instance used by the length filters, the index lookup and the post-processing helpers

tokeniser = Tokeniser()

## Data Reader
- For reading all the data from json

In [None]:
class Datareader:
    '''
    To read json

    Flattens a nested JSON file into parallel lists. The layout walked by
    json_to_list is:
        {language_pair: {data_type: {entry_id: {"source": ..., "target": ...}}}}
    '''

    def __init__(self, file_path, encoding="UTF-8"):
        '''
        file_path: path of the JSON file
        encoding: text encoding used to open the file (explicit so that the
                  platform default cannot corrupt non-ASCII text)
        '''
        self.file_path = file_path
        self.encoding = encoding
        self.json_data = None
        self.ids = []
        self.source_data = []
        self.target_data = []

    def load_json(self):
        '''Parse the file into self.json_data'''
        with open(self.file_path, 'r', encoding=self.encoding) as file:
            self.json_data = json.load(file)

    def json_to_list(self, is_train):
        '''
        Append every entry of self.json_data to the id/source lists and,
        when is_train is truthy, its "target" to the target list
        '''
        for language_data in self.json_data.values():
            for data_entries in language_data.values():
                for entry_id, entry_data in data_entries.items():
                    self.ids.append(entry_id)
                    self.source_data.append(entry_data["source"])
                    if is_train:
                        self.target_data.append(entry_data["target"])

    def get_data(self, is_train):
        '''
        Load the file and return (ids, sources, targets) when is_train is
        truthy, otherwise (ids, sources)
        '''
        # Start from fresh lists so that a repeated call does not duplicate entries
        self.ids = []
        self.source_data = []
        self.target_data = []

        self.load_json()
        self.json_to_list(is_train)
        print("Data Loaded")
        if is_train:
            return self.ids, self.source_data, self.target_data
        else:
            return self.ids, self.source_data

## Language
- To store all the data for a given language
- Vocabulary Wrapper

In [None]:
class Lang:
    '''
    Vocabulary wrapper for one side of the translation task.

    Words are first counted (addSentence / addWord); trim_voc then assigns
    indices to the most frequent ones. The four special tokens always
    occupy indices 0-3.
    '''

    def __init__(self, name):
        self.tokeniser = Tokeniser()
        self.name = name
        self.word2index = {"sos": 2,"unk": 0,"pad": 1,"eos": 3}
        self.index2word = {2: "sos",0: "unk",1: "pad",3: "eos"}
        self.n_words = 4  # Count of Base Tokens
        self.counter = Counter()

    def addSentence(self, sentence):
        '''Tokenise (lower-cased) the sentence and count each of its words'''
        for word in self.tokeniser.tok_sen(sentence):
            self.addWord(word)

    def lookup_tokens(self, list_of_ix):
        '''Map indices back to words; unknown indices become "unk"'''
        return [self.index2word.get(ix, "unk") for ix in list_of_ix]

    def addWord(self, word):
        '''Increase the frequency count of a word (does not assign an index)'''
        self.counter.update([word])

    def trim_voc(self, VOC_SIZE):
        '''Assign indices to the VOC_SIZE most frequent counted words'''
        for word, _ in self.counter.most_common(VOC_SIZE):
            self.addWord_final(word)

    def addWord_final(self, word):
        '''Give the word the next free index unless it already has one'''
        if word in self.word2index:
            return
        self.word2index[word] = self.n_words
        self.index2word[self.n_words] = word
        self.n_words += 1
        

## Language Checker 
- For filtering wrongly labelled sentences in the training data
- Takes a majority vote over the characters of each language

In [None]:
class Language_Checker:
    '''
    Script-based language guesser: every character votes for the language
    whose Unicode range (UNI_BEG / UNI_END) contains it.
    '''

    def __init__(self):
        pass

    def get_lang_char(self, ch):
        '''
        Return the LANG_IX name whose code-point range contains the character,
        or None when no range matches
        '''
        code_point = ord(ch)
        for name, first, last in zip(LANG_IX, UNI_BEG, UNI_END):
            if ord(first) <= code_point <= ord(last):
                return name

        return None

    def get_score_word(self, word):
        '''
        Return an array of 8 vote counts (one slot per LANG_IX entry) for the word
        '''
        votes = np.zeros(8)
        for ch in word:
            lang = self.get_lang_char(ch)
            if lang is None:
                continue
            # Punctuation and digits inside the English range carry no language signal
            if lang == "English" and ch in TOK_PUNCT:
                continue
            votes[LANG_IX.index(lang)] += 1

        return cpy(votes)

    def get_lang_sen(self, sen):
        '''
        Given a tokenised sentence, return the language with the most character votes
        '''
        totals = np.zeros(8)
        for word in sen:
            totals += self.get_score_word(word)

        return LANG_IX[np.argmax(totals)]

    def check_pair(self, sen1, lang1, sen2, lang2):
        '''
        Return True when both tokenised sentences are predicted to be in their expected languages
        '''
        first_ok = self.get_lang_sen(sen1) == lang1
        second_ok = self.get_lang_sen(sen2) == lang2
        return first_ok and second_ok

## To read the data

In [None]:
def readLangs(lang1, lang2):
    '''
    Load the training json, shuffle it, hold out the tail as validation data
    and build the list of [source, target] training pairs.

    Returns (input_lang, output_lang, pairs, val_source_data, val_target_data);
    the two Lang objects are still empty at this point.
    '''
    lc = Language_Checker()  # only needed by the disabled language filter below

    # Load Data
    print("Reading Data...")
    ids, source_data, target_data = Datareader(train_data_path).get_data(True)

    # Random Shuffle - the three lists are shuffled together so entries stay aligned
    print("Data Randomly Shuffling...")
    shuffled = list(zip(ids, source_data, target_data))
    random.shuffle(shuffled)
    ids, source_data, target_data = zip(*shuffled)

    # Val-Train Split: the first TEST_TRAIN_SPLIT fraction trains, the rest validates
    split_at = int(TEST_TRAIN_SPLIT * len(source_data))
    val_source_data, source_data = source_data[split_at:], source_data[:split_at]
    val_target_data, target_data = target_data[split_at:], target_data[:split_at]

    # Filter and make pairs (the list is seeded with an "unk" -> "unk" pair)
    print("Checking Correctness...")
    pairs = [["unk", "unk"]]
    for i in time_bar(range(len(target_data))):
        # Disabled filter: lc.get_lang_sen(target_data[i]) == "English"
        pairs.append([source_data[i], target_data[i]])

    # Max-Min Lengths
    print("Calculating the threshold lengths...")
    set_lengths(pairs)

    # Sample
    print("Sample Pair")
    print(pairs[0])

    # Convert to languages
    return Lang(lang1), Lang(lang2), pairs, val_source_data, val_target_data

## Length Filters

In [None]:
def set_lengths(pairs):
    '''
    Set the global MAX_LENGTH_SRC / MAX_LENGTH_TGT from the token-length
    distribution of the pairs: both become the larger of the two
    MAX_PERCENTILE-th percentiles.
    '''
    global MAX_LENGTH_SRC
    global MAX_LENGTH_TGT

    # Token counts of every (source, target) pair
    lengths = [(len(tokeniser.tok_sen(p[0])), len(tokeniser.tok_sen(p[1]))) for p in pairs]
    src_len = np.array([n_src for n_src, _ in lengths])
    tgt_len = np.array([n_tgt for _, n_tgt in lengths])

    src_cut = int(np.percentile(src_len, MAX_PERCENTILE))
    tgt_cut = int(np.percentile(tgt_len, MAX_PERCENTILE))

    # Both sides share the same (larger) threshold
    MAX_LENGTH_SRC = MAX_LENGTH_TGT = max(src_cut, tgt_cut)

    print("MAX_LEN_SRC =", MAX_LENGTH_SRC)
    print("MAX_LEN_TGT =", MAX_LENGTH_TGT)

def filterPair(p):
    '''
    Check whether the pair satisfies the length criteria.

    p: [source_sentence, target_sentence]
    Returns True when both token counts are at least MIN_LENGTH and strictly
    below MAX_LENGTH_SRC / MAX_LENGTH_TGT respectively (set_lengths must
    have been called first so that the maxima are not None).
    '''
    # Tokenise each side once; this runs for every pair of the corpus
    src_len = len(tokeniser.tok_sen(p[0]))
    tgt_len = len(tokeniser.tok_sen(p[1]))

    return (MIN_LENGTH <= src_len < MAX_LENGTH_SRC) and (MIN_LENGTH <= tgt_len < MAX_LENGTH_TGT)

def filterPairs(pairs):
    '''
    Return the list of pairs accepted by filterPair, in their original order
    '''
    return list(filter(filterPair, pairs))

In [None]:
def prepareData(lang1, lang2):
    '''
    Read the corpus, drop pairs outside the length limits and build both vocabularies.

    Returns (input_lang, output_lang, pairs, val_source_data, val_target_data).
    '''
    # Read the data
    input_lang, output_lang, pairs, val_source_data, val_target_data = readLangs(lang1, lang2)
    print("Read %s sentence pairs" % len(pairs))

    # Filter the data based on length
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))

    # Count the words of every remaining pair
    for src_sentence, tgt_sentence in pairs:
        input_lang.addSentence(src_sentence)
        output_lang.addSentence(tgt_sentence)

    # Keep only the most frequent words of each side
    for lang, voc_size in ((input_lang, IP_VOC_SIZE), (output_lang, OP_VOC_SIZE)):
        lang.trim_voc(voc_size)

    print("Counted words:")
    for lang in (input_lang, output_lang):
        print(lang.name, lang.n_words)
    return input_lang, output_lang, pairs, val_source_data, val_target_data

## Utility functions

### Pre-processing Related

In [None]:
# Replacing the words with index in vocabulary
def indexesFromSentence(lang, sentence):
    '''
    Tokenise the sentence (lower-cased) and map every token to its index in
    lang.word2index; tokens missing from the vocabulary map to the "unk" index.
    '''
    unk_index = lang.word2index["unk"]
    return [lang.word2index.get(word, unk_index) for word in tokeniser.tok_sen(sentence)]

# Convert Sentence to indexes
# The lambdas look up the globals src_lang / tgt_lang at call time, so they
# pick up whichever Lang objects are bound when a sentence is transformed
vocab_transform = {}
vocab_transform[SRC_LANGUAGE] = lambda y: indexesFromSentence(src_lang, y)
vocab_transform[TGT_LANGUAGE] = lambda y: indexesFromSentence(tgt_lang, y)

# Add sos and eos, and then convert to tensor
def tensor_transform(token_ids: list[int]):
    '''
    Wrap the token indices in BOS_IDX ... EOS_IDX and return a 1-D long tensor.

    Building one tensor with an explicit dtype keeps the result integer even
    for an empty token list (torch.tensor([]) alone would be a float tensor).
    '''
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    '''
    Compose the given callables left to right: the returned function feeds
    its argument through each transform in turn and returns the final value.
    '''
    def func(txt_input):
        value = txt_input
        for step in transforms:
            value = step(value)
        return value

    return func

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors of indices
# text_transform[language](sentence) -> 1-D index tensor wrapped in the start / end tokens
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(vocab_transform[ln], #Tokenisation and Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor

# function to collate data samples into batch tensors
def collate_fn(batch):
    '''
    Turn a list of (source, target) sentence pairs into two padded index tensors.

    Trailing newlines are stripped, each sentence is converted with
    text_transform and the sequences are padded with PAD_IDX.
    '''
    to_src = text_transform[SRC_LANGUAGE]
    to_tgt = text_transform[TGT_LANGUAGE]

    src_tensors = [to_src(src_sample.rstrip("\n")) for src_sample, _ in batch]
    tgt_tensors = [to_tgt(tgt_sample.rstrip("\n")) for _, tgt_sample in batch]

    pad = torch.nn.utils.rnn.pad_sequence
    return pad(src_tensors, padding_value=PAD_IDX), pad(tgt_tensors, padding_value=PAD_IDX)

### Time related

In [None]:
def asMinutes(s):
    '''Format a duration given in seconds as "<minutes>m <seconds>s"'''
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)

def timeSince(since, percent):
    '''
    Return "<elapsed> (- <remaining>)" where elapsed is the time since
    `since` (a time.time() stamp) and remaining is extrapolated from the
    completed fraction `percent` (must be non-zero).
    '''
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

## Scorer
- To calculate validation scores while training

### CHRF Calculator

In [None]:
def calculate_chrf_score(ref_data, hypo_data):
    '''
    Average sentence-level chrF++ of the hypotheses against the references.

    ref_data / hypo_data: aligned sequences of [id, sentence] rows; rows with
    fewer than two fields are ignored.
    Returns the mean score, or 0.0 when there is nothing to score.
    '''

    # Tokenise the reference sentences
    ref_data_tok = [(row[0], nltk.word_tokenize(row[1])) for row in ref_data if len(row) >= 2]

    # Tokenise the hypothesis sentences
    hypo_data_tok = [(row[0], nltk.word_tokenize(row[1])) for row in hypo_data if len(row) >= 2]

    # Calculate the chrf++ score for each pair.
    # chrF is not symmetric, so the hypothesis must go first and the
    # reference second (the two used to be unpacked the wrong way round).
    chrF_score = []
    for (_, ref), (_, hyp) in zip(ref_data_tok, hypo_data_tok):
        hyp_text = ' '.join(hyp)
        ref_text = ' '.join(ref)
        scores = sentence_chrf(hyp_text, [ref_text], word_order=2)  # word order = 2 represent chrf++
        chrF_score.append(scores.score)

    # Average the score over all pairs
    if not chrF_score:
        return 0.0
    return sum(chrF_score) / len(chrF_score)

### ROUGE Calculator

In [None]:
def calculate_rouge_scores(ref_data, hypo_data):
    '''
    Average ROUGE-L F-score of the hypotheses against the references.

    ref_data / hypo_data: aligned sequences of [id, sentence] rows; rows with
    fewer than two fields are ignored.
    Pairs whose ids differ, or which the scorer cannot process, count as 0.
    Returns the mean score, or 0.0 when there is nothing to score.
    Raises ValueError when the two sides hold a different number of rows.
    '''

    # Tokenise the reference sentences
    ref_data_tok = [(row[0], nltk.word_tokenize(row[1])) for row in ref_data if len(row) >= 2]

    # Tokenise the hypothesis sentences
    hypo_data_tok = [(row[0], nltk.word_tokenize(row[1])) for row in hypo_data if len(row) >= 2]

    # Error Handling
    if len(ref_data_tok) != len(hypo_data_tok):
        raise ValueError(f"Number of data points in ref ({len(ref_data_tok)}) and hypo ({len(hypo_data_tok)}) must be the same")

    rouge_scorer = Rouge()

    # Calculate rouge score for each pair
    totalscore = 0.0
    total_itrs = 0
    for (ref_id, ref), (answer_id, hyp) in zip(ref_data_tok, hypo_data_tok):
        total_itrs += 1

        # Error Handling
        if answer_id != ref_id:
            print(f"Warning: Answer ID ({answer_id}) and Ref ID ({ref_id}) do not match. Skipping this iteration.")
            continue
        hyp_text = ' '.join(hyp)
        ref_text = ' '.join(ref)

        # Calculate - get_scores expects the hypothesis first, then the reference
        try:
            scores = rouge_scorer.get_scores(hyp_text, ref_text)
        except Exception:
            # Best effort: a pair the scorer cannot handle contributes a score of 0
            continue
        totalscore += scores[0]['rouge-l']['f']

    # Average out the rouge score (failed and skipped pairs stay in the denominator)
    if total_itrs == 0:
        return 0.0
    return totalscore / total_itrs

### BLEU Calculator

In [None]:
def calculate_bleu_score(ref_data, hypo_data):
    '''
    Corpus-level BLEU of the hypotheses against the references.

    ref_data / hypo_data: aligned sequences of (id, sentence) rows.
    '''

    # One single-element reference list per sentence, stripped and tokenised
    truth = [[nltk.word_tokenize(line.strip())] for _, line in ref_data]

    # Tokenised hypothesis sentences
    submission_answer = [nltk.word_tokenize(line.strip()) for _, line in hypo_data]

    # Calculates the bleu score for the corpus
    return nltk.translate.bleu_score.corpus_bleu(truth, submission_answer)

### Scorer
- Calculate the scores on the validation set
- chrf++, ROUGE, BLEU


In [None]:
class Scorer:
    '''
    Translates a held-out set and reports BLEU, chrF++ and ROUGE-L for it.
    '''

    def __init__(self):
        # Latest scores
        self.bleu = self.chrf_pp = self.rouge = None
        # [id, sentence] rows of the references and of the model output
        self.ref_data = self.hypo_data = None

    def calc_scores(self):
        '''
        Compute the three metrics from the stored reference / hypothesis rows
        '''
        refs, hyps = self.ref_data, self.hypo_data
        self.bleu = calculate_bleu_score(refs, hyps)
        self.chrf_pp = calculate_chrf_score(refs, hyps)
        self.rouge = calculate_rouge_scores(refs, hyps)

    def print_score(self):
        '''
        Compute and print the scores (chrF++ is printed divided by 100)
        '''
        self.calc_scores()

        print(f"Bleu Score: {self.bleu}")
        print(f"CHRF++ Score: {self.chrf_pp/100}")
        print(f"Rouge Score: {self.rouge}")

    def translate_data(self, transformer, val_source_data, val_target_data):
        '''
        Translate every source sentence that has a non-empty target and
        store the reference and hypothesis rows
        '''
        self.ref_data, self.hypo_data = [], []

        for idx, sen in time_bar(enumerate(val_source_data), total=len(val_source_data)):
            target = val_target_data[idx]
            if len(target) > 0:
                self.ref_data.append([idx, target])
                self.hypo_data.append([idx, translate(transformer, sen)])

    def score(self, transformer, val_source_data, val_target_data):
        '''
        Translate, print and return (bleu, chrf_pp, rouge) for the data provided
        '''
        # Calculate ref and hypo
        self.translate_data(transformer, val_source_data, val_target_data)

        # Print Scores
        self.print_score()

        # Clear Memory
        self.ref_data = self.hypo_data = None

        return self.bleu, self.chrf_pp, self.rouge

# Scorer Global Object
scorer = Scorer()

## Post Processing

In [None]:
def create_replace_dict(pairs, THRESH):
    '''
    Creating a dictionary to store which words are to be capitalised in post-processing

    pairs: [source, target] sentence pairs; only the targets are used
    THRESH: minimum percentage (0-100) of a word's occurrences that must be in
            a given cased form for that form to be restored
    Returns {lower-cased word: cased form}.
    '''

    target_data = [pair[1] for pair in pairs]

    # Frequencies of the words as written and of their lower-cased forms
    counter_u = Counter()
    for sen in tokeniser.tok_data_u(target_data):
        counter_u.update(sen)
    counter = Counter()
    for sen in tokeniser.tok_data(target_data):
        counter.update(sen)

    # Fill the replace_dict if the fraction of the word occurring in a non-lower form reaches the threshold.
    # Iterating the unique items gives the same result as walking every occurrence, without the repeated work.
    min_fraction = THRESH / 100
    replace_dict = {}
    for word, count_form in counter_u.items():
        if word in TOK_PUNCT or word.islower():
            continue

        low_word = word.lower()
        count_low = counter[low_word]

        # count_low == 0 would mean the two tokenisations disagree; skip instead of dividing by zero
        if count_low and count_form / count_low >= min_fraction:
            replace_dict[low_word] = word

    return replace_dict

def post_process(text):
    '''
    Perform post-processing of machine generated translations

    - restores the casing recorded in the global replace_dict
    - upper-cases the first letter of the sentence
    - glues consecutive number / punctuation tokens back together

    Returns the processed sentence as a single space-joined string.
    '''

    # Tokenise the sentence
    tokens = tokeniser.tok_sen(text)
    new_tokens = []
    prev_token = None

    for token in tokens:
        # Other capitalisations
        new_token = replace_dict.get(token, token)

        if prev_token is None:
            # Start of Sentence: upper-case only the first character so that a
            # casing restored from replace_dict (e.g. an acronym) is kept
            new_token = new_token[:1].upper() + new_token[1:]
        elif token in TOK_PUNCT and prev_token in TOK_PUNCT:
            # Joining Numbers and puncts that follow another number / punct
            new_tokens[-1] = new_tokens[-1] + token
            continue

        # Update Variable
        new_tokens.append(new_token)
        prev_token = token

    return " ".join(new_tokens)

# Model

## Positional Encoding
- Provides a representation of the position at which each word occurs in the sentence

In [None]:
class PositionalEncoding(torch.nn.Module):
    '''
    Adds fixed sinusoidal position information to token embeddings and
    applies dropout to the sum.
    '''

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()

        # One frequency per pair of embedding dimensions
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = positions * inv_freq

        # Even dimensions hold the sine, odd dimensions the cosine
        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = torch.nn.Dropout(dropout)
        # Stored as a buffer of shape (maxlen, 1, emb_size)
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding: Tensor):
        # Use the first size(0) positions; the singleton axis broadcasts
        seq_len = token_embedding.size(0)
        positional = self.pos_embedding[:seq_len, :]
        return self.dropout(token_embedding + positional)

## Embedding
- Provides the embedding vector for a given token index

In [None]:
class TokenEmbedding(torch.nn.Module):
    '''
    Embedding lookup whose output is scaled by sqrt(emb_size)
    '''

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # Indices are cast to long before the lookup
        scale = math.sqrt(self.emb_size)
        embedded = self.embedding(tokens.long())
        return embedded * scale

## Transformer
- The actual model

In [None]:
class Seq2SeqTransformer(torch.nn.Module):
    '''
    Encoder-decoder transformer: token embeddings plus positional encoding
    feed torch.nn.Transformer, and a linear layer projects the decoder
    output onto the target vocabulary.
    '''

    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, dim_feedforward, dropout=0.1):
        super().__init__()
        self.transformer = torch.nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = torch.nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        '''Return the vocabulary logits for the target positions (no memory mask is used)'''
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src, src_mask):
        '''Run only the encoder stack on the source indices'''
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(src_emb, src_mask)

    def decode(self, tgt, memory, tgt_mask):
        '''Run only the decoder stack on the target indices and the encoder memory'''
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)

## Masks 
- Preventing looking to future words
- Ignoring the padding tokens

In [None]:
def generate_square_subsequent_mask(sz):
    '''
    Causal (sz, sz) float mask: 0.0 on and below the diagonal, -inf above it,
    so a position may attend to itself and to earlier positions only
    '''

    blocked = torch.full((sz, sz), float('-inf'), device=DEVICE)
    # Keep -inf strictly above the diagonal; triu zeroes everything else
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt):
    '''
    Build the attention and padding masks for a source / target batch.

    Returns (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
    an all-False source mask, a causal target mask, and the transposed
    PAD_IDX positions of each side.
    '''

    src_len, tgt_len = src.shape[0], tgt.shape[0]

    # The source may attend everywhere; the target only to earlier positions
    src_mask = torch.zeros((src_len, src_len), device=DEVICE).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

## Evaluation Functions

### Greedy Decoding

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    '''
    Using the most-probable word

    Decodes one token at a time, always taking the highest-scoring word,
    until EOS_IDX is produced or max_len tokens (start symbol included) exist.

    model: object providing encode(), decode() and generator
    src: source index tensor for a single sentence
    src_mask: attention mask for the source
    max_len: maximum length of the returned sequence
    start_symbol: index of the token that starts the output
    Returns a long tensor of shape (length, 1) beginning with start_symbol.
    '''

    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Pure inference: no autograd graph is needed for any decoding step
    with torch.no_grad():
        # The encoder output does not change between steps
        memory = model.encode(src, src_mask).to(DEVICE)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)

        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(DEVICE)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)

            # Scores of the last position only
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            # Append with the dtype / device of ys so the sequence stays integer
            ys = torch.cat([ys,
                            torch.ones(1, 1).type_as(ys).fill_(next_word)], dim=0)
            if next_word == EOS_IDX:
                break

    return ys

### Translating Wrapper
- For performing translations

In [None]:
def translate(model, src_sentence):
    '''
    Translate one source sentence with greedy decoding and post-process it.

    The model is put in eval mode for the call and returned to the mode it
    was in before (previously it was always left in train mode).
    '''
    was_training = model.training
    model.eval()

    try:
        # Sentence to tensor
        src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
        num_tokens = src.shape[0]

        # Masks
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

        # Inference
        tgt_tokens = greedy_decode(model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()

        # Drop the start / end markers by index. Removing the substrings
        # "sos" / "eos" from the joined text would also damage ordinary
        # words that merely contain them.
        token_ids = [ix for ix in tgt_tokens.cpu().tolist() if ix not in (BOS_IDX, EOS_IDX)]

        # Post Processing
        predicted_sen = " ".join(tgt_lang.lookup_tokens(token_ids))
        post_proc_sen = post_process(predicted_sen)
    finally:
        model.train(was_training)

    return post_proc_sen

### Randomly translating during training

In [None]:
def evaluateRandomly(transformer, pairs, n=5):
    '''
    Print n randomly chosen pairs together with the model's translation of
    the source, for eyeballing progress during training
    '''

    for _ in range(n):
        sample = random.choice(pairs)
        source, target = sample[0], sample[1]
        print('Source:', source)
        print('Target:', target)
        print('Predicted:', translate(transformer, source))
        print('')

## Training

### Taking an epoch

In [None]:
def train_epoch(model, optimizer):
    '''
    Run one pass over the global `pairs` and return the mean batch loss.

    Reads the globals `pairs`, `BATCH_SIZE` and `loss_fn`.
    Returns 0.0 when there are no batches.
    '''
    model.train()
    losses = 0
    train_iter = pairs
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in time_bar(train_dataloader):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # Teacher forcing: the decoder sees the target without its last token ...
        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        # ... and is trained to predict the target shifted by one position
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    # len() of the loader is its number of batches; materialising it with
    # list() would tokenise and collate the whole dataset a second time
    num_batches = len(train_dataloader)
    return losses / num_batches if num_batches else 0.0

### Full Training
- Model is saved if it has the best BLEU score till now

In [None]:
def train(transformer, optimizer, loss_fn, n_epochs, pairs, val_source_data, val_target_data, learning_rate=0.001, loss_every=100, score_every=100):
    '''
    Train for n_epochs, periodically reporting the loss and validation scores.

    transformer: model to train
    optimizer: optimizer stepping the model's parameters
    loss_fn, learning_rate: accepted for interface compatibility but not read
        here - train_epoch uses the global loss_fn and the optimizer carries
        its own learning rate
    pairs: training pairs, used for the random sample translations
    val_source_data, val_target_data: held-out sentences for scoring
    loss_every: print the average loss and sample translations every this many epochs
    score_every: score the validation data every this many epochs

    The model is written to the file "transformer_max" whenever the
    validation BLEU score improves.
    '''
    start = time.time()
    print_loss_total = 0

    b_max = 0

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(transformer, optimizer)
        print_loss_total += loss

        # Printing the loss and randomly evaluating
        if epoch % loss_every == 0:

            # Average over the epochs since the last report
            # (this used to reference an undefined name, print_every)
            print_loss_avg = print_loss_total / loss_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))
            evaluateRandomly(transformer, pairs)

        # Calculating BLEU Scores on the validation Data
        if epoch % score_every == 0:
            b,c,r = scorer.score(transformer, val_source_data, val_target_data)

            # Saving the model if BLEU score is best till now
            if b_max < b:
                b_max = b
                with open("transformer_max", 'wb') as file:
                    dill.dump(transformer, file)

# Training

### Reading Data

In [None]:
# Builds both vocabularies, the length-filtered training pairs and the held-out validation split from train_data_path
src_lang, tgt_lang, pairs, val_source_data, val_target_data = prepareData(SRC_LANGUAGE, TGT_LANGUAGE)

### Post-Processing Data

In [None]:
# Lower-case -> cased-form lookup that post_process() reads when restoring capitalisation
replace_dict = create_replace_dict(pairs, THRESH)

### Creating New Model

In [None]:
# Hyper-parameters passed to Seq2SeqTransformer below
EMB_SIZE = 512  # embedding size, also used as the transformer's d_model
NHEAD = 8  # number of attention heads
FFN_HID_DIM = 512  # dim_feedforward of the transformer layers
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
# Vocabulary sizes include the four special tokens
SRC_VOCAB_SIZE = len(src_lang.word2index)
TGT_VOCAB_SIZE = len(tgt_lang.word2index)

# Freshly initialised model, moved to the GPU when one is available
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
transformer = transformer.to(DEVICE)

### Loading Already-Trained Model 

In [None]:
# SECURITY: dill.load executes arbitrary code stored in the file - only load
# model files obtained from a source you trust.
with open(model_path, 'rb') as file:
    m = dill.load(file)

# The saved bundle carries the network and both vocabularies; they replace
# the objects created in the cells above
transformer = m.transformer.to(DEVICE)
src_lang = m.src_lang
tgt_lang = m.tgt_lang
SRC_VOCAB_SIZE = len(src_lang.word2index)
TGT_VOCAB_SIZE = len(tgt_lang.word2index)

### Training the Model

In [None]:
# Number of sentence pairs per batch, read by train_epoch()
BATCH_SIZE = 100

# Padding positions do not contribute to the loss
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# The effective learning rate is the one given here
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# 10 epochs, reporting the loss and scoring the validation data after every epoch
# NOTE(review): train() never reads its learning_rate argument, so the 0.001 below has no effect
train(transformer, optimizer, loss_fn, 10, pairs, val_source_data, val_target_data, learning_rate=0.001,
               loss_every=1, score_every=1)

### Running on Validation Data

In [None]:
def validate(transformer):
    '''
    Translate every sentence of the file at val_data_path and write the results.

    Two files are written to the working directory:
    - translations.csv: comma separated, columns ID and Translation
    - answer.csv: tab separated, with all non-numeric fields quoted

    The model is left in train mode afterwards.
    '''
    transformer.eval()

    # Load Data (no targets are read for this file)
    print("Reading Data...")
    data_reader = Datareader(val_data_path)
    ids, source_data = data_reader.get_data(False)

    # Translate for Validation Data
    translations = []
    for ix, sen in time_bar(enumerate(source_data), total=len(source_data)):
        translated_sen = translate(transformer, sen)
        translations.append({"ID": ids[ix], "Translation": translated_sen})

    # Explicit columns keep the header intact even when there is nothing to translate
    translations_df = pd.DataFrame(translations, columns=["ID", "Translation"])
    translations_df.to_csv("translations.csv", index=False)

    translations_df.to_csv("answer.csv", index=False, quotechar='"', quoting=csv.QUOTE_NONNUMERIC, sep='\t')

    transformer.train()

In [None]:
# Validation
# Writes translations.csv and answer.csv to the working directory

validate(transformer)

### Saving the Model
Model consists of 3 parts
- Transformer
- Source Language
- Target Language

In [None]:
class model:
    '''
    Bundle saved to disk: the trained network together with the two
    vocabularies needed to use it
    '''

    def __init__(self, transformer, src_lang, tgt_lang):
        self.transformer = transformer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

m = model(transformer, src_lang, tgt_lang)

# The context manager closes the file even if serialisation fails
with open("model_final", 'wb') as file:
    dill.dump(m, file)

### Experimentation
- Now it's your chance to play with the model

In [None]:
# Devanagari sample sentence used to try out the full translate() pipeline on a single input
sen = "ये सभी 14 साल से महाराष्ट्र में पढ़ाई कर रहे हैं।"
translated_sen = translate(transformer, sen)
print(translated_sen)