# Attention Mechanism Best Model Train Test

**Submitted By:**
Joyojyoti Acharya - CS21M024,
Vrushab Karia - CS21M075

### Importing the Necessary Packages

In [1]:
#using tensorflow 1.13.2
!pip install tensorflow==1.13.2
import tensorflow as tf
import pandas as pd
import numpy as np
from tensorflow import keras
tf.test.gpu_device_name()

### Importing Dataset

In [12]:
#Using target Language as Hindi
target_language = "hi"
DATAPATH = "/kaggle/input/dakshina/dakshina_dataset_v1.0/{}/lexicons/{}.translit.sampled.{}.tsv"

# Shared reader options for the three lexicon splits.
# na_filter=False stops pandas from turning words such as "nan", "null" or "NA"
# into missing values (which a later str() cast would turn into the literal "nan").
# Columns 0 and 1 are forced to str so that digit-only tokens are not parsed as numbers.
LEXICON_READ_OPTIONS = dict(sep = '\t', header = None, na_filter = False, dtype = {0: str, 1: str})

#Defining training, validation and test path and reading the data from dataset.

#Training
train_path = DATAPATH.format(target_language, target_language, "train")
train_data = pd.read_csv(train_path, **LEXICON_READ_OPTIONS)

#Validation
dev_path = DATAPATH.format(target_language, target_language, "dev")
dev_data = pd.read_csv(dev_path, **LEXICON_READ_OPTIONS)

#Test
test_path = DATAPATH.format(target_language, target_language, "test")
test_data = pd.read_csv(test_path, **LEXICON_READ_OPTIONS)

### Splitting the dataset wordwise and characterwise

In [13]:
# Word lists per split: column 1 is the model input, column 0 is the target.
# Every target is wrapped as "\t" + word + "\n" (start / end markers).

#Training Data
train_input = list(map(str, train_data[1]))
train_target = ["\t{!s}\n".format(w) for w in train_data[0]]

#Validation Data
dev_input = list(map(str, dev_data[1]))
dev_target = ["\t{!s}\n".format(w) for w in dev_data[0]]

#Test Data
test_input = list(map(str, test_data[1]))
test_target = ["\t{!s}\n".format(w) for w in test_data[0]]

# Vocabularies: every character seen in any split plus the space character,
# stored as sorted lists.
input_characters = sorted(set(' ' + "".join(train_input + dev_input + test_input)))
target_characters = sorted(set(' ' + "".join(train_target + dev_target + test_target)))

### Fetching character and maximum sequence length

In [14]:
# Vocabulary sizes.
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)

# Longest word over all three splits, for inputs and for (marker-wrapped) targets.
max_encoder_seq_length = max(max(map(len, split)) for split in (train_input, dev_input, test_input))
max_decoder_seq_length = max(max(map(len, split)) for split in (train_target, dev_target, test_target))

# Summary of the dataset dimensions.
for label, value in [
    ("Number of Training samples:", len(train_input)),
    ("Number of Validation samples:", len(dev_input)),
    ("Number of Test samples:", len(test_input)),
    ("Number of unique input tokens:", num_encoder_tokens),
    ("Number of unique output tokens:", num_decoder_tokens),
    ("Max sequence length for inputs:", max_encoder_seq_length),
    ("Max sequence length for outputs:", max_decoder_seq_length),
]:
    print(label, value)

### Dictionary Indexing and Inverse Dictionary Indexing for the unique Characters

In [15]:
# character -> index and index -> character lookups, following the sorted vocabulary order
input_token_index = {char: index for index, char in enumerate(input_characters)}
inverse_input_token_index = dict(enumerate(input_characters))
target_token_index = {char: index for index, char in enumerate(target_characters)}
inverse_target_token_index = dict(enumerate(target_characters))

### Training Encoder-Decoder One Hot Data Preparation

In [16]:
# Padding uses the index of the space character in the respective vocabulary.
encoder_pad_index = input_token_index[' ']
decoder_pad_index = target_token_index[' ']

# Encoder / decoder inputs hold token indices (pre-filled with padding);
# decoder targets are one-hot over the target vocabulary.
train_encoder_input_data = np.full((len(train_input), max_encoder_seq_length), encoder_pad_index, dtype="float32")
train_decoder_input_data = np.full((len(train_input), max_decoder_seq_length), decoder_pad_index, dtype="float32")
train_decoder_target_data = np.zeros((len(train_input), max_decoder_seq_length, num_decoder_tokens), dtype="float32")

for i, (input_text, target_text) in enumerate(zip(train_input, train_target)):
    # Writing whole slices (instead of padding from a leaked loop index) keeps
    # the row correct even when a word is empty.
    train_encoder_input_data[i, :len(input_text)] = [input_token_index[char] for char in input_text]

    target_ids = np.array([target_token_index[char] for char in target_text], dtype=np.int64)
    train_decoder_input_data[i, :len(target_ids)] = target_ids

    # The target at step t is the decoder input at step t + 1 (sequence shifted left by one);
    # every remaining step is marked as padding.
    next_ids = target_ids[1:]
    train_decoder_target_data[i, np.arange(len(next_ids)), next_ids] = 1.0
    train_decoder_target_data[i, len(next_ids):, decoder_pad_index] = 1.0

### Validation Encoder-Decoder One Hot Data Preparation

In [17]:
# Padding uses the index of the space character in the respective vocabulary.
encoder_pad_index = input_token_index[' ']
decoder_pad_index = target_token_index[' ']

# Encoder / decoder inputs hold token indices (pre-filled with padding);
# decoder targets are one-hot over the target vocabulary.
dev_encoder_input_data = np.full((len(dev_input), max_encoder_seq_length), encoder_pad_index, dtype="float32")
dev_decoder_input_data = np.full((len(dev_input), max_decoder_seq_length), decoder_pad_index, dtype="float32")
dev_decoder_target_data = np.zeros((len(dev_input), max_decoder_seq_length, num_decoder_tokens), dtype="float32")

for i, (input_text, target_text) in enumerate(zip(dev_input, dev_target)):
    # Writing whole slices (instead of padding from a leaked loop index) keeps
    # the row correct even when a word is empty.
    dev_encoder_input_data[i, :len(input_text)] = [input_token_index[char] for char in input_text]

    target_ids = np.array([target_token_index[char] for char in target_text], dtype=np.int64)
    dev_decoder_input_data[i, :len(target_ids)] = target_ids

    # The target at step t is the decoder input at step t + 1 (sequence shifted left by one);
    # every remaining step is marked as padding.
    next_ids = target_ids[1:]
    dev_decoder_target_data[i, np.arange(len(next_ids)), next_ids] = 1.0
    dev_decoder_target_data[i, len(next_ids):, decoder_pad_index] = 1.0

### TEST DATA SETUP

In [18]:
# Token indices of the test inputs, pre-filled with the index of the space character
# so that padding does not depend on a loop index leaking out of an inner loop
# (which breaks for an empty word).
test_encoder_input_data = np.full((len(test_input), max_encoder_seq_length), input_token_index[' '], dtype="float32")
for i, input_word in enumerate(test_input):
    test_encoder_input_data[i, :len(input_word)] = [input_token_index[char] for char in input_word]

### Defining Attention Class

In [19]:
#importing packages
import tensorflow as tf
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras import backend as K

#AttentionLayer Class
class AttentionLayer(Layer):
    """Additive attention between an encoder output sequence and a decoder output sequence.

    The layer is called on ``[encoder_output_sequence, decoder_output_sequence]``.
    For each decoder time step it scores every encoder time step with
    ``V_a . tanh(enc . W_a + dec . U_a)``, normalises the scores with a softmax
    and uses them to form a weighted sum of the encoder outputs.

    Returns a pair ``(context_vectors, attention_weights)``, both produced by
    iterating over the decoder time axis with ``K.rnn``.
    """

    def __init__(self, **args):
        super(AttentionLayer, self).__init__(**args)

    def build(self, input_shape):
        """Create the trainable weights.

        ``input_shape`` is indexed as ``[encoder_shape, decoder_shape]``; axis 2
        of each shape is used as the respective hidden size.
        """
        encoder_hidden = input_shape[0][2]
        decoder_hidden = input_shape[1][2]

        # Projection applied to the encoder outputs.
        self.W_a = self.add_weight(name='W_a',
                                   shape = tf.TensorShape((encoder_hidden, encoder_hidden)),
                                   initializer = 'uniform',
                                   trainable = True)

        # Projection applied to the decoder output of the current step.
        self.U_a = self.add_weight(name = 'U_a',
                                   shape = tf.TensorShape((decoder_hidden, encoder_hidden)),
                                   initializer = 'uniform',
                                   trainable = True)

        # Vector that reduces the tanh activation to one score per encoder step.
        self.V_a = self.add_weight(name = 'V_a',
                                   shape = tf.TensorShape((encoder_hidden, 1)),
                                   initializer = 'uniform',
                                   trainable = True)

        super(AttentionLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs):
        """Compute context vectors and attention weights.

        inputs: [encoder_output_sequence, decoder_output_sequence]
        """
        encoder_out_seq, decoder_out_seq = inputs

        def energy_step(inputs, states):
            """Attention weights over the encoder steps for one decoder step.

            ``inputs`` is the decoder output for the current step; ``states`` is
            required by ``K.rnn`` but is not read.
            """
            # S.Wa for the whole encoder sequence.
            W_a_dot_s = K.dot(encoder_out_seq, self.W_a)

            # hj.Ua, with an extra axis so it broadcasts over the encoder steps.
            U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1)

            # tanh(S.Wa + hj.Ua)
            Ws_plus_Uh = K.tanh(W_a_dot_s + U_a_dot_h)

            # softmax(va.tanh(S.Wa + hj.Ua))
            e_i = K.squeeze(K.dot(Ws_plus_Uh, self.V_a), axis=-1)
            e_i = K.softmax(e_i)

            return e_i, [e_i]

        def context_step(inputs, states):
            """Weighted sum of the encoder outputs for one decoder step.

            ``inputs`` holds the attention weights of that step; ``states`` is
            required by ``K.rnn`` but is not read.
            """
            c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)
            return c_i, [c_i]

        # Placeholder initial states for K.rnn; neither step function reads its
        # state, so these are only built from the encoder output by summing
        # over one axis.
        fake_state_c = K.sum(encoder_out_seq, axis=1)
        fake_state_e = K.sum(encoder_out_seq, axis=2)

        # Attention weights for every decoder step.
        last_out, e_outputs, _ = K.rnn(energy_step, decoder_out_seq, [fake_state_e],)

        # Context vectors for every decoder step.
        last_out, c_outputs, _ = K.rnn(context_step, e_outputs, [fake_state_c],)

        return c_outputs, e_outputs

### Wordwise Inference Mechanism for Attention Approach

In [20]:
#importing packages
import numpy as np
from tensorflow import keras
from random import sample

#sigmoid function
def sigmoid(i):
    """Return a list with the logistic function 1 / (1 + exp(-z)) applied to each element z of ``i``."""
    activations = []
    for z in i:
        activations.append(1/(1 + np.exp(-z)))
    return activations

# Attention_Inference Function
def attention_inference(model, dev_encoder_input_data, test_input, test_target, num_decoder_tokens, max_decoder_seq_length, target_token_index, inverse_target_token_index, latent_dim, model_name, predictions_path="/kaggle/working/predictions_attention.csv"):
    """Greedy word-level inference with the trained attention model.

    Builds separate encoder and decoder inference models from the layers of
    ``model``, decodes every sample one character at a time, writes the
    predictions to a CSV file and returns the word-level accuracy.

    The layers are addressed by fixed position in ``model.layers``
    (3: decoder embedding, 4: encoder recurrent layer, 5: decoder recurrent
    layer, 6: attention, 7: concatenate, 8: dense), so ``model`` must have that
    layout.

    Args:
        model: trained two-input Keras model (encoder input, decoder input).
        dev_encoder_input_data: encoded inputs of the split being evaluated,
            row-aligned with ``test_input`` / ``test_target`` (despite its
            name it is not restricted to the dev split).
        test_input: input words, used for the CSV and for the sample count.
        test_target: target words wrapped as ``"\\t" + word + "\\n"``.
        num_decoder_tokens: unused; kept for interface compatibility.
        max_decoder_seq_length: decoding stops once the output is longer than this.
        target_token_index: character -> index mapping of the target vocabulary.
        inverse_target_token_index: index -> character mapping of the target vocabulary.
        latent_dim: size of the recurrent state.
        model_name: "RNN", "GRU" or "LSTM".
        predictions_path: CSV file that receives the predictions (overwritten).

    Returns:
        Fraction of samples whose decoded word equals the target exactly
        (0.0 for an empty sample list), or None when ``model_name`` is not
        recognised.
    """
    model.summary()

    # LSTM carries two state tensors (h, c); SimpleRNN and GRU carry one.
    if model_name == "RNN" or model_name == "GRU":
        num_states = 1
    elif model_name == "LSTM":
        num_states = 2
    else:
        print("Wrong Choice of Model")
        return

    # Encoder Model: input -> [output sequence, state(s)]
    encoder_inputs = model.input[0]
    encoder_model = keras.Model(encoder_inputs, list(model.layers[4].output))

    # Decoder Model: one step, fed with the previous character and previous state(s)
    decoder_inputs = model.input[1]  # input_2
    decoder_embedded = model.layers[3](decoder_inputs)
    decoder_states_inputs = [keras.Input(shape = (latent_dim, )) for _ in range(num_states)]
    decoder_outputs, *decoder_states = model.layers[5](decoder_embedded, initial_state = decoder_states_inputs)

    # Attention over the encoder output sequence, supplied as an extra input
    attention_inputs = keras.Input(shape = (None, latent_dim, ))
    attention_output, attention_scores = model.layers[6]([attention_inputs, decoder_outputs])
    concatenated_decoder_input = model.layers[7]([decoder_outputs, attention_output])

    # Decoder Dense layer
    decoder_dense = model.layers[8]
    decoder_outputs = decoder_dense(concatenated_decoder_input)

    # Final decoder model
    decoder_model = keras.Model([decoder_inputs] + decoder_states_inputs + [attention_inputs], [decoder_outputs] + list(decoder_states) + [attention_scores])

    def decode_sequence_prediction(input_sequence):
        """Greedily decode one encoded input; the result includes the end marker when one was produced."""
        # Encode the input once; the states seed the decoder.
        encoder_predictions = encoder_model.predict(input_sequence)
        encoder_output, states_value = encoder_predictions[0], encoder_predictions[1:]

        # Start decoding from the start-of-word character.
        target_sequence = np.zeros((1, 1))
        target_sequence[0, 0] = target_token_index["\t"]

        output_sequence = ""
        while True:
            output = decoder_model.predict([target_sequence] + states_value + [encoder_output])
            # The last output holds the attention weights, which are not needed here.
            output_tokens, states_value = output[0], output[1:-1]

            # Pick the most probable character.
            sampled_token_index = np.argmax(output_tokens[0, -1, :])
            sampled_character = inverse_target_token_index[sampled_token_index]
            output_sequence += sampled_character

            if sampled_character == "\n" or len(output_sequence) > max_decoder_seq_length:
                break

            # Feed the chosen character back in for the next step.
            target_sequence = np.zeros((1, 1))
            target_sequence[0, 0] = sampled_token_index

        return output_sequence

    #count the correct predictions
    correct_count, test_size = 0, len(test_input)

    # The context manager guarantees the predictions file is flushed and closed.
    with open(predictions_path, "w", encoding='utf-8') as attention_prediction:
        attention_prediction.write("Input Sentence,Original Target Sentence,Predicted Output Sentence\n")
        for i in range(test_size):
            if i%50==0:
                print("Testing at: ",i)
            input_sequence = dev_encoder_input_data[i : i + 1]
            decoded_word = decode_sequence_prediction(input_sequence)
            original_word = test_target[i][1:]
            # Column order follows the header: input, original, predicted.
            # rstrip removes only the end marker, so a prediction that stopped
            # on the length limit keeps its last character.
            attention_prediction.write(test_input[i] + "," + original_word.rstrip("\n") + "," + decoded_word.rstrip("\n") + "\n")
            if(original_word == decoded_word):
                correct_count += 1

    if test_size == 0:
        return 0.0
    return correct_count / test_size


### Main Block to Train the Best Model

In [21]:
#importing packages 
import numpy as np
from tensorflow import keras

# Hyperparameters of the best attention model trained in this cell
model_name = "LSTM"        # recurrent cell type: "RNN", "LSTM" or "GRU"
embedding_size = 512       # output dimension of both embedding layers
hidden_layer_size = 384    # units of the encoder and decoder recurrent layers
dropout = 0.3              # dropout passed to the recurrent layers
optimizer = 'adam'         # 'adam', 'nadam' or 'rmsprop'
learning_rate = 0.001
batch_size = 512
epochs = 20
        
# Recurrent layer class for the chosen model type. An unknown name stops here
# instead of failing later on an undefined variable.
recurrent_layer_classes = {
    "RNN": keras.layers.SimpleRNN,
    "LSTM": keras.layers.LSTM,
    "GRU": keras.layers.GRU,
}
if model_name not in recurrent_layer_classes:
    raise ValueError("Wrong Choice of Model: {!r} (expected one of {})".format(model_name, sorted(recurrent_layer_classes)))
recurrent_layer_class = recurrent_layer_classes[model_name]

# NOTE: the inference functions address layers by position in model.layers,
# so the layers below must be created in this order.

#Encoder Model
encoder_inputs = keras.Input(shape = (None, ))
encoder_outputs = keras.layers.Embedding(input_dim = num_encoder_tokens, output_dim = embedding_size, input_length = max_encoder_seq_length)(encoder_inputs)

# One encoder layer; the full output sequence feeds the attention layer and
# the final state(s) (one for RNN/GRU, two for LSTM) initialise the decoder.
encoder_outputs, *encoder_states = recurrent_layer_class(hidden_layer_size, dropout = dropout, return_state = True, return_sequences = True)(encoder_outputs)

# Decoder Model
decoder_inputs = keras.Input(shape=(None, ))
decoder_outputs = keras.layers.Embedding(input_dim = num_decoder_tokens, output_dim = embedding_size, input_length = max_decoder_seq_length)(decoder_inputs)

# We will test on only one layer of encoder and only one layer of decoder model
decoder = recurrent_layer_class(hidden_layer_size, dropout = dropout, return_sequences = True, return_state = True)
decoder_outputs, *decoder_states = decoder(decoder_outputs, initial_state = encoder_states)

# Adding Attention Layer
attention = AttentionLayer()
attention_output, _ = attention([encoder_outputs, decoder_outputs])
concatenated_decoder_input = keras.layers.Concatenate(axis = -1)([decoder_outputs, attention_output])

#Decoder Dense Layer
decoder_dense = keras.layers.Dense(num_decoder_tokens, activation = "softmax")
decoder_outputs = decoder_dense(concatenated_decoder_input)

#Runnable Model
model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()
    
#Different Optimizers
# Instantiate the optimizer so that `learning_rate` is actually applied;
# passing only the optimizer's name would silently use its default rate.
optimizer_classes = {
    'adam': keras.optimizers.Adam,
    'nadam': keras.optimizers.Nadam,
    'rmsprop': keras.optimizers.RMSprop,
}
if optimizer not in optimizer_classes:
    # Stop here: fitting an uncompiled model would fail with a less helpful error.
    raise ValueError("Wrong optimizer Choice: {!r} (expected one of {})".format(optimizer, sorted(optimizer_classes)))

# `lr` is the keyword name used by the Keras optimizers of the pinned TensorFlow 1.13.
model.compile(optimizer=optimizer_classes[optimizer](lr=learning_rate), loss="categorical_crossentropy", metrics=["accuracy"])
        
# Character-level training: the model receives the encoder and decoder index
# sequences and is fitted against the one-hot decoder targets; the dev split
# is used as validation data.
training_inputs = [train_encoder_input_data, train_decoder_input_data]
validation_set = ([dev_encoder_input_data, dev_decoder_input_data], dev_decoder_target_data)
model.fit(
    x = training_inputs,
    y = train_decoder_target_data,
    batch_size = batch_size,
    epochs = epochs,
    validation_data = validation_set,
)

# Word-level accuracy on the validation split (greedy decoding, exact match)
validation_accuracy = attention_inference(
    model,
    dev_encoder_input_data,
    dev_input,
    dev_target,
    num_decoder_tokens,
    max_decoder_seq_length,
    target_token_index,
    inverse_target_token_index,
    hidden_layer_size,
    model_name,
)
print("Attention_Wordwise_Val_Accuracy: ", validation_accuracy)

In [22]:
# Persist the trained model under the name "Attention_best_model".
# NOTE(review): no file extension is given, so the on-disk format depends on the
# installed Keras/TensorFlow version — confirm. Reloading presumably needs the
# custom AttentionLayer passed via custom_objects — verify.
model.save("Attention_best_model")

### Test Accuracy

In [23]:
# Word-level accuracy on the TEST split (greedy decoding, exact word match).
# attention_inference opens its predictions CSV in "w" mode, so this call
# overwrites the file written by the earlier validation run.
test_accuracy = attention_inference(model,test_encoder_input_data, test_input, test_target, num_decoder_tokens, max_decoder_seq_length, target_token_index, inverse_target_token_index, hidden_layer_size, model_name)
print("Attention_Wordwise_Test_Accuracy: ", test_accuracy)

## Visualization of HeatMaps

### Wordwise Attention Inference Model with Heatmaps

In [24]:
#importing packages
import numpy as np
from tensorflow import keras
from random import sample


#sigmoid function
def sigmoid(i):
    """Return a list with the logistic function 1 / (1 + exp(-z)) applied to each element z of ``i``.

    Duplicates the definition given in the earlier inference cell.
    """
    activations = []
    for z in i:
        activations.append(1/(1 + np.exp(-z)))
    return activations

# Attention_Inference Function
def attention_inference_heatmaps(model, dev_encoder_input_data, test_input, test_target, num_decoder_tokens, max_decoder_seq_length, target_token_index, inverse_target_token_index, latent_dim, model_name, predictions_path="/kaggle/working/predictions_attention.csv", connectivity_path="connectivity_visualization.txt", num_visualizations=10):
    """Greedy word-level inference that also collects attention heatmap data.

    Works like ``attention_inference`` (separate encoder / decoder inference
    models built from the layers of ``model``, one character decoded per
    step, predictions written to a CSV file) and additionally, for a random
    subset of samples, records the attention weights of every decoding step
    and appends them to a tab-separated connectivity file.

    The layers are addressed by fixed position in ``model.layers``
    (3: decoder embedding, 4: encoder recurrent layer, 5: decoder recurrent
    layer, 6: attention, 7: concatenate, 8: dense), so ``model`` must have that
    layout.

    Args:
        model: trained two-input Keras model (encoder input, decoder input).
        dev_encoder_input_data: encoded inputs of the split being evaluated,
            row-aligned with ``test_input`` / ``test_target`` (despite its
            name it is not restricted to the dev split).
        test_input: input words, used for the files and for the sample count.
        test_target: target words wrapped as ``"\\t" + word + "\\n"``.
        num_decoder_tokens: unused; kept for interface compatibility.
        max_decoder_seq_length: decoding stops once the output is longer than this.
        target_token_index: character -> index mapping of the target vocabulary.
        inverse_target_token_index: index -> character mapping of the target vocabulary.
        latent_dim: size of the recurrent state.
        model_name: "RNN", "GRU" or "LSTM".
        predictions_path: CSV file that receives the predictions (overwritten).
        connectivity_path: text file the per-step attention weights are appended to.
        num_visualizations: number of randomly chosen samples to record
            (capped at the number of samples).

    Returns:
        ``(accuracy, heatmaps)`` where ``accuracy`` is the fraction of exactly
        matching words (0.0 for an empty sample list) and ``heatmaps`` is a list
        of ``(input_word, [(decoded_char, attention_weights), ...])``.
        Returns None when ``model_name`` is not recognised.
    """
    model.summary()

    # LSTM carries two state tensors (h, c); SimpleRNN and GRU carry one.
    if model_name == "RNN" or model_name == "GRU":
        num_states = 1
    elif model_name == "LSTM":
        num_states = 2
    else:
        print("Wrong Choice of Model")
        return

    # Encoder Model: input -> [output sequence, state(s)]
    encoder_inputs = model.input[0]
    encoder_model = keras.Model(encoder_inputs, list(model.layers[4].output))

    # Decoder Model: one step, fed with the previous character and previous state(s)
    decoder_inputs = model.input[1]  # input_2
    decoder_embedded = model.layers[3](decoder_inputs)
    decoder_states_inputs = [keras.Input(shape = (latent_dim, )) for _ in range(num_states)]
    decoder_outputs, *decoder_states = model.layers[5](decoder_embedded, initial_state = decoder_states_inputs)

    # Attention over the encoder output sequence, supplied as an extra input
    attention_inputs = keras.Input(shape = (None, latent_dim, ))
    attention_output, attention_scores = model.layers[6]([attention_inputs, decoder_outputs])
    concatenated_decoder_input = model.layers[7]([decoder_outputs, attention_output])

    # Decoder Dense layer
    decoder_dense = model.layers[8]
    decoder_outputs = decoder_dense(concatenated_decoder_input)

    # Final decoder model
    decoder_model = keras.Model([decoder_inputs] + decoder_states_inputs + [attention_inputs], [decoder_outputs] + list(decoder_states) + [attention_scores])

    def decode_sequence_prediction(input_sequence):
        """Greedily decode one encoded input.

        Returns the decoded string (including the end marker when one was
        produced), the per-step ``(character, attention_weights)`` list and
        the per-step ``(character, first decoder state)`` list.
        """
        # Encode the input once; the states seed the decoder.
        encoder_predictions = encoder_model.predict(input_sequence)
        encoder_output, states_value = encoder_predictions[0], encoder_predictions[1:]

        # Start decoding from the start-of-word character.
        target_sequence = np.zeros((1, 1))
        target_sequence[0, 0] = target_token_index["\t"]

        #Heatmaps and Visualization Data
        heatmap = []
        visualization = []
        output_sequence = ""

        while True:
            output = decoder_model.predict([target_sequence] + states_value + [encoder_output])
            output_tokens, states_value, attention_weights = output[0], output[1:-1], output[-1]

            # Pick the most probable character.
            sampled_token_index = np.argmax(output_tokens[0, -1, :])
            sampled_character = inverse_target_token_index[sampled_token_index]
            output_sequence += sampled_character

            # Record this step (the final step is recorded too).
            heatmap.append((sampled_character, attention_weights))
            visualization.append((sampled_character, states_value[0]))

            if sampled_character == "\n" or len(output_sequence) > max_decoder_seq_length:
                break

            # Feed the chosen character back in for the next step.
            target_sequence = np.zeros((1, 1))
            target_sequence[0, 0] = sampled_token_index

        return output_sequence, heatmap, visualization

    #count the correct predictions
    correct_count, test_size = 0, len(test_input)

    # Random subset of samples to record; capped so that small sample lists do
    # not make random.sample raise. A set gives constant-time membership tests.
    visualisations = set(sample(range(test_size), min(num_visualizations, test_size)))
    heatmaps = []

    # The context manager guarantees the predictions file is flushed and closed.
    with open(predictions_path, "w", encoding='utf-8') as attention_prediction:
        attention_prediction.write("Input Sentence,Original Target Sentence,Predicted Output Sentence\n")
        for i in range(test_size):
            if i%50==0:
                print("Testing at: ",i)
            input_sequence = dev_encoder_input_data[i : i + 1]
            decoded_word, heatmap, visualization = decode_sequence_prediction(input_sequence)
            original_word = test_target[i][1:]
            # Column order follows the header: input, original, predicted.
            # rstrip removes only the end marker, so a prediction that stopped
            # on the length limit keeps its last character.
            attention_prediction.write(test_input[i] + "," + original_word.rstrip("\n") + "," + decoded_word.rstrip("\n") + "\n")
            if(original_word == decoded_word):
                correct_count += 1

            if i not in visualisations:
                continue

            # Connectivity Visualization - Q6
            actual_word = test_input[i]
            with open(connectivity_path, "a", encoding='utf-8') as file:
                # Header line: input word and number of decoding steps.
                file.write(actual_word + "\t" + str(len(heatmap)) + "\n")

                # One line per decoding step: the decoded character (the last
                # step is always written as "<e>") followed by the attention
                # weight of every input character.
                last_step = len(heatmap) - 1
                for t, (dec_char, weights) in enumerate(heatmap):
                    dec_char_prob = weights.reshape(-1)
                    label = dec_char if t != last_step else "<e>"
                    file.write(label + "\t")
                    for p in range(len(actual_word)):
                        file.write(str(dec_char_prob[p]) + "\t")
                    file.write("\n")

                file.write("Next\n")

            # Heatmap Plot Data
            heatmaps.append((actual_word, heatmap))

    accuracy = correct_count / test_size if test_size else 0.0
    return accuracy, heatmaps


## Finding Heatmaps

In [25]:
# Word-level accuracy on the TEST split, plus attention data for a random sample of words.
# attention_inference_heatmaps opens the predictions CSV in "w" mode (overwriting it)
# and opens connectivity_visualization.txt in "a" mode (re-running this cell adds more entries).
test_accuracy, heatmaps = attention_inference_heatmaps(model,test_encoder_input_data, test_input, test_target, num_decoder_tokens, max_decoder_seq_length, target_token_index, inverse_target_token_index, hidden_layer_size, model_name)
print("Attention_Wordwise_Test_Accuracy: ", test_accuracy)

### Heatmap Plots Function

In [26]:
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
#HeatMap Plots function
def attention_heatmap_plot(input_sequence, heatmap, font_path="/kaggle/input/nirmala/nirmala.ttf"):
    """Draw the attention weights of one decoded word as an image.

    Args:
        input_sequence: the input word; its characters label the x axis.
        heatmap: list of ``(decoded_character, attention_weights)`` pairs, one
            per decoding step. Each weights array is flattened and cut to the
            length of ``input_sequence``.
        font_path: font file used for the y tick labels (the decoded
            characters); the default is the "nirmala.ttf" font used for Hindi.

    Returns:
        The matplotlib figure (rows: decoding steps, columns: input characters).
    """
    decoded_characters = [entry[0] for entry in heatmap]
    # Build the weight matrix once: one row per decoding step.
    weights = np.array([entry[1].reshape(-1)[:len(input_sequence)] for entry in heatmap])

    figure, axis = plt.subplots()
    axis.imshow(weights)

    # Fix the tick positions BEFORE assigning labels, so that every label is
    # bound to its own row / column rather than to the default tick locations.
    axis.set_xticks(np.arange(weights.shape[1]))
    axis.set_yticks(np.arange(weights.shape[0]))

    axis.set_xticklabels([char for char in input_sequence])
    # The end-of-word marker is shown as "<e>".
    axis.set_yticklabels([char if char != '\n' else "<e>" for char in decoded_characters], fontproperties = FontProperties(fname = font_path))

    axis.tick_params(labelsize = 15)

    return figure

### Wandb Log of the sample HeatMaps

In [28]:
import wandb
# Start a wandb run, upload one heatmap image per sampled word under the key "Q5D", then close the run.
wandb.init(project="CS6910-Assignment-3", entity="cs21m024_cs21m075")
heatmap_images = [
    wandb.Image(attention_heatmap_plot(image[0], image[1]))
    for image in heatmaps
]
wandb.log({"Q5D": heatmap_images})
wandb.finish()