In [None]:
# -*- coding: utf-8 -*-
# import statements
# wildcard import goes first so every explicit import below wins any name clash
from AttentionModels import *

import os
from random import randint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from IPython.display import display, HTML
from matplotlib.font_manager import FontProperties
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow_addons.optimizers import RectifiedAdam, Lookahead, AdamW

In [None]:
### BEST PARAMETERS ###
# Hyperparameters of the best configuration; the model/training cells read these.

# --- architecture ---
LAYER_TYPE = "lstm"                  # passed as `cell_type` to encoder and decoder
NUM_ENCODER_RECURRENT_LAYERS = 1     # passed as `k` to the encoder
NUM_DECODER_RECURRENT_LAYERS = 3     # passed as `k` to the decoder
NUM_RECURRENT_UNITS = 256            # passed as `cell_hidden_dim`
ENC_EMBED_DIM = 256
DEC_EMBED_DIM = 256
ATTN_DIM = 256
DROPOUT = 0.2

# --- optimisation ---
OPTIMIZER = "adamw"                  # key used to pick the optimizer ("ranger" or "adamw")
LR = 1e-2
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 128

In [None]:
# directory paths
# the three lexicon splits share one prefix and differ only in the split name
_LEXICON_PREFIX = "./dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled"
train_dir, dev_dir, test_dir = (
    f"{_LEXICON_PREFIX}.{split}.tsv" for split in ("train", "dev", "test")
)

In [None]:
# word-level accuracy, computed after cutting each sequence at the stop-token
# and stripping the padding around it
def _decode_word(indices, tokens):
    """Turn a sequence of token indices into a word.

    Each index is looked up in ``tokens``; characters are collected until the
    stop-token ``'>'`` is met, and leading/trailing whitespace is stripped from
    the result.
    """
    chars = []
    for idx in indices:
        ch = tokens[int(idx)]
        # the stop-token ends the word; anything after it is ignored
        if ch == '>':
            break
        chars.append(ch)
    return ''.join(chars).strip()


def compute_word_accuracy(y_true, y_pred, tokens):
    """Compute exact-match word accuracy between targets and predictions.

    Parameters
    ----------
    y_true : sequence of index sequences
        Target words, one sequence of token indices per word.
    y_pred : sequence of index sequences
        Predicted words, paired with ``y_true`` by position.
    tokens : sequence of str
        Vocabulary; ``tokens[i]`` is the character for index ``i``.

    Returns
    -------
    (float, pandas.DataFrame)
        The fraction of words whose decoded prediction equals the decoded
        target (number of matches divided by ``len(y_true)``), and a frame
        with columns ``'Target'`` and ``'Prediction'`` holding the decoded
        words. For empty input the accuracy is ``0.0`` and the frame is empty.
    """
    # decode both sides pairwise so each row lines up target with prediction
    pairs = [(_decode_word(t, tokens), _decode_word(y, tokens))
             for t, y in zip(y_true, y_pred)]
    count = sum(int(target == prediction) for target, prediction in pairs)

    # create a dataframe from all the targets and predictions
    df = pd.DataFrame(pairs, columns=['Target', 'Prediction'])

    # avoid dividing by zero when there is nothing to score
    total = len(y_true)
    accuracy = count / total if total else 0.0
    return accuracy, df

In [None]:
# a function to read the data into a pd dataframe
def load_data(path):
    """Read a tab-separated lexicon file into a two-column DataFrame.

    The file is read without a header; its three columns are named ``hi``,
    ``en`` and ``_``. Rows with a missing ``hi`` or ``en`` value are dropped
    and only those two columns are returned.

    The index is reset to ``0..n-1`` after filtering, so row ``k`` of the
    returned frame matches row ``k`` of any array built from it by position.
    Without the reset, dropped rows leave gaps in the index and label-based
    alignment against a fresh frame shifts or loses rows.

    NOTE(review): pandas' default NA parsing treats cells such as "nan",
    "null" or "NA" as missing, so words spelled that way are dropped here —
    confirm whether that is intended.
    """
    raw = pd.read_csv(path,
                      sep='\t',
                      encoding="utf8",
                      names=["hi", "en", "_"],
                      skip_blank_lines=True)

    complete = raw.dropna(subset=['hi', 'en'])
    return complete[['hi', 'en']].reset_index(drop=True)

In [None]:
# a function to preprocess the data
def pre_process(data, max_eng_len, max_hin_len, eng_token_map, hin_token_map):
    """Encode a frame of word pairs as zero-padded index arrays.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with an ``'en'`` column (source words) and a ``'hi'`` column
        (target words).
    max_eng_len, max_hin_len : int
        Width of the encoder arrays and of the decoder arrays.
    eng_token_map, hin_token_map : dict
        Character -> integer index for each side.

    Returns
    -------
    (a, b, c) : tuple of numpy.ndarray
        ``a`` is the encoder input, ``b`` the decoder input (target word
        wrapped in ``'<'`` and ``'>'``), and ``c`` the decoder target, which
        is ``b`` shifted left by one step (no start-token). All three are
        float arrays initialised with zeros, so unused positions hold 0.

    Words longer than the given maximum are truncated to fit and characters
    missing from a token map are encoded as 0. The maxima and vocabularies are
    typically derived from the training split only, so this keeps unseen
    validation/test words from raising ``IndexError``/``KeyError``.
    """
    pad_index = 0  # value of every unused array position

    x = data['en'].values
    # add start and end tokens to the hindi word
    y = '<' + data['hi'].values + '>'

    # a is the encoder input
    a = np.zeros((len(x), max_eng_len))
    # b is the decoder input (has start-token and end-token)
    b = np.zeros((len(y), max_hin_len))
    # c is the decoder output, which leads the decoder input by one step
    # as it does not have start token in the beginning
    c = np.zeros((len(y), max_hin_len))

    # replace the characters by their integer indices so the model can process them
    for i, (xx, yy) in enumerate(zip(x, y)):
        for j, ch in enumerate(xx[:max_eng_len]):
            a[i, j] = eng_token_map.get(ch, pad_index)
        for j, ch in enumerate(yy[:max_hin_len]):
            b[i, j] = hin_token_map.get(ch, pad_index)
        # same characters as b, starting one position later
        for j, ch in enumerate(yy[1:max_hin_len + 1]):
            c[i, j] = hin_token_map.get(ch, pad_index)
    return a, b, c

In [None]:
# load the train, validation and test data (in that order)
train, dev, test = (load_data(split_path)
                    for split_path in (train_dir, dev_dir, test_dir))

# the vocabulary comes from the training split only;
# hindi words are wrapped in the start ('<') and end ('>') tokens first
x = train['en'].values
y = '<' + train['hi'].values + '>'

# unique characters of each side, sorted
eng_tokens = sorted({ch for word in x for ch in word})
hin_tokens = sorted({ch for word in y for ch in word})

# inverted index: character -> position in the sorted vocabulary, counted from 1
eng_token_map = {ch: pos for pos, ch in enumerate(eng_tokens, start=1)}
hin_token_map = {ch: pos for pos, ch in enumerate(hin_tokens, start=1)}

# position 0 is reserved for the pad character ' ' on both sides
for _vocab, _lookup in ((eng_tokens, eng_token_map), (hin_tokens, hin_token_map)):
    _vocab.insert(0, ' ')
    _lookup[' '] = 0

# the longest training word on each side fixes the padded sequence length
max_eng_len = max(map(len, x))
max_hin_len = max(map(len, y))

# every split is encoded with the same sequence lengths and token maps
_encoding_args = (max_eng_len, max_hin_len, eng_token_map, hin_token_map)

# training encoder input, decoder input and decoder target
trainxe, trainxd, trainy = pre_process(train, *_encoding_args)

# validation encoder input, decoder input and decoder target
valxe, valxd, valy = pre_process(dev, *_encoding_args)

# test encoder input, decoder input and decoder target;
# the decoder target is only used to check the metrics at the end
testxe, testxd, testy = pre_process(test, *_encoding_args)

In [None]:
# The model uses custom objects, so it cannot be saved and reloaded easily.
# It is therefore re-trained here with the best hyperparameters.

# settings shared by the encoder and the decoder
_recurrent_args = dict(cell_hidden_dim=NUM_RECURRENT_UNITS,
                       dropout=DROPOUT,
                       cell_type=LAYER_TYPE)

# encoder: vocabulary size is the largest training index plus one
encoder = Encoder(input_dim=int(trainxe.max()) + 1,
                  embed_dim=ENC_EMBED_DIM,
                  k=NUM_ENCODER_RECURRENT_LAYERS,
                  **_recurrent_args)

# attention decoder, sized the same way from the decoder input/target arrays
decoder = AttentionDecoder(input_dim=int(trainxd.max()) + 1,
                           output_dim=int(trainy.max()) + 1,
                           embed_dim=DEC_EMBED_DIM,
                           attn_dim=ATTN_DIM,
                           k=NUM_DECODER_RECURRENT_LAYERS,
                           **_recurrent_args)

# transliteration model wrapping the encoder and decoder
model = TransliterationModel(encoder=encoder,
                             decoder=decoder,
                             tgt_max_len=max_hin_len)

# both candidate optimizers take the same settings; OPTIMIZER picks one
_optimizer_args = dict(learning_rate=LR, weight_decay=WEIGHT_DECAY, amsgrad=True)
_optimizers = {
    "ranger": Lookahead(RectifiedAdam(**_optimizer_args)),
    "adamw": AdamW(**_optimizer_args),
}
optimizer = _optimizers[OPTIMIZER]

# early stopping: end the run once the validation accuracy has failed to
# improve by at least 1e-3 for 4 epochs in a row, and restore the best weights
early_stop = EarlyStopping(monitor="val_accuracy",
                           min_delta=1e-3,
                           patience=4,
                           restore_best_weights=True)

# compile the model and fit it to the data
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=optimizer,
              metrics=["accuracy"])

model.fit([trainxe, trainxd],
          trainy,
          batch_size=BATCH_SIZE,
          epochs=25,
          shuffle=True,
          validation_data=([valxe, valxd], valy),
          callbacks=[early_stop])

In [None]:
# create a tf dataset from the test data to work with batches easily
test_dataset = tf.data.Dataset.from_tensor_slices((testxe, testxd)).batch(BATCH_SIZE)
attention_weights, test_pred = [], []

# get the predictions and attention weights for each input batch;
# only the first column of the decoder input is handed to the model
for xe, xd in test_dataset:
    p, a = model([xe, xd[:, 0]])
    test_pred.append(p.numpy())
    attention_weights.append(a.numpy())

# concatenate the per-batch results along the sample axis
attention_weights = np.concatenate(attention_weights, axis=0)
test_pred = np.concatenate(test_pred, axis=0)

# obtain the test word-level accuracy and complete set of predictions
test_word_accuracy, df = compute_word_accuracy(testy.tolist(),
                                               test_pred.tolist(),
                                               hin_tokens)

# save the predictions as a csv file
# Insert the inputs by position: a Series would be aligned on its index
# labels, and `test` may carry gaps in its index from filtered rows, which
# would shift the inputs against their predictions.
df.insert(loc=0, column="data", value=test['en'].to_numpy())
# make sure the output folder exists before writing into it
os.makedirs("predictions_attention", exist_ok=True)
df.to_csv("predictions_attention/predictions.csv", encoding="utf-8")
print(f"Test_word_accuracy: {test_word_accuracy:.4f}\n")

# Plot Attention Heatmaps

In [None]:
# set required font for Devanagari characters
font_prop = FontProperties(fname="VesperLibre-Regular.ttf")
attn_maps, xs, zs = [], [], []
# Rows of [input, target, prediction]. The guard keeps this cell re-runnable:
# after the first run `df` is already a list and has no `.values`.
if isinstance(df, pd.DataFrame):
    df = df.values.tolist()

correct = 0
# sample 9 random english words and their corresponding hindi transliterations and attention weights
for _ in range(9):
    i = randint(0, len(df)-1)
    x, y, z = df[i]
    correct += int(y == z)
    aw = attention_weights[i]
    # keep one row per predicted character and one column per input character
    mp = aw[:len(z)][:, :len(x)]
    attn_maps.append(mp)
    xs.append(x)
    zs.append(z)

# check how many of that sample are correct
print(f"correct: {correct}/9")
plt.close('all')
# plot those attention weights as a heatmap using sns
fig, axes = plt.subplots(3, 3, figsize=(15, 15), constrained_layout=True)
fig.suptitle('Attention Heatmaps', fontsize='xx-large')

for x, z, mp, ax in zip(xs, zs, attn_maps, axes.flat):
    g = sns.heatmap(mp, linewidth=0.5, ax=ax)
    # input characters along the top, predicted characters down the side
    g.set_xticklabels(list(x), fontproperties=font_prop, fontsize='xx-large')
    g.set_yticklabels(list(z), fontproperties=font_prop, rotation=45, fontsize='xx-large')
    g.set_xlabel(f'{x}', fontproperties=font_prop, fontsize="xx-large")
    g.set_ylabel(f'{z}', fontproperties=font_prop, fontsize="xx-large")
    g.tick_params(labelsize=15)
    g.xaxis.tick_top()
    g.xaxis.set_label_position('top')
    g.set_aspect("equal")
    g.set_frame_on(False)

plt.show()

# Visualize the connectivity

In [None]:
## MOSTLY CODE TAKEN FROM BLOG GIVEN IN THE QUESTION ##
def get_clr(value):
    """Map a weight to a background colour from a 20-shade palette.

    Each shade covers a band of width 0.05: ``value`` in ``[0, 0.05)`` gives
    the first shade, ``[0.05, 0.10)`` the second, and so on. Values of 0.95
    and above use the last shade, and negative values use the first.

    Parameters
    ----------
    value : float
        The weight to colour.

    Returns
    -------
    str
        A ``'#rrggbb'`` colour string.
    """
    # every entry needs its own comma: adjacent string literals are silently
    # concatenated, which would fuse two shades into one invalid colour
    colors = ['#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
              '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
              '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
              '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e']
    bucket = int((value * 100) / 5)
    # clamp to the palette so out-of-range values cannot wrap around or overflow
    bucket = max(0, min(bucket, len(colors) - 1))
    return colors[bucket]

def cstr(s, color='black'):
    """Wrap a piece of text in an HTML ``<text>`` tag with a background colour.

    Parameters
    ----------
    s : str
        The text to show. A single space gets extra left padding so that the
        coloured cell stays visible.
    color : str
        CSS colour used as the background.

    Returns
    -------
    str
        The HTML fragment.
    """
    if s == ' ':
        # the background must come from `color`, not from the text itself
        return f"<text style=color:#000;padding-left:10px;background-color:{color}> </text>"
    return f"<text style=color:#000;background-color:{color}>{s} </text>"

# helps print colors in html document
def print_color(t):
    """Display (text, colour) pairs as one run of highlighted HTML.

    ``t`` is an iterable of ``(text, colour)`` pairs; each pair is turned
    into an HTML fragment by ``cstr`` and the joined result is displayed.
    """
    fragments = []
    for text, colour in t:
        fragments.append(cstr(text, color=colour))
    display(HTML(''.join(fragments)))
 
# for each character being decoded, highlight the input sequence characters according to the attention weights
def visualize(input_word, output_word, attn_map, idx):
    """Show which input characters were attended to for one output character.

    Prints the output character at position ``idx``, then displays
    ``input_word`` with each character coloured by the matching weight in
    ``attn_map[idx]``, followed by an empty line.
    """
    print(f"Highlighting connectivity for: {output_word[idx]}")
    weights = attn_map[idx]
    highlighted = []
    for ch, weight in zip(input_word, weights):
        highlighted.append((ch, get_clr(weight)))
    print_color(highlighted)
    print()

In [None]:
correct = 0
# Rows of [input, target, prediction]. `df` is a DataFrame straight after the
# evaluation cell and a list once the heatmap cell has run; accept both so this
# cell does not depend on which cells were executed before it.
rows = df.values.tolist() if isinstance(df, pd.DataFrame) else df

# sample and print the connectivity for 5 random samples
for _ in range(5):
    i = randint(0, len(rows)-1)
    x, y, z = rows[i]
    # crop once per sample: one row per predicted character,
    # one column per input character
    mp = attention_weights[i][:len(z)][:, :len(x)]
    # to check how many of those sample are correct
    correct += int(y == z)
    # plot the visualization
    print(f"visualization for {x} --> {z}")
    for idx in range(len(z)):
        visualize(x, z, mp, idx)
    print("-"*50)
print(f"correct: {correct}/5")