# DA6401 - Assignment 3 - MM21B051

## Step 1: Loading and preprocessing the data

The train, dev, and test TSV files were stored in a Kaggle dataset folder named `lexicons` during training and testing. Please modify `data_dir` according to your storage location.

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, GRU, Dense, Attention, Layer
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

# Define the path to your data directory on Kaggle
data_dir = '/kaggle/input/lexicons/'


def load_lexicon_split(split):
    """Read one Hindi Dakshina lexicon split ('train', 'dev' or 'test').

    Columns are named 'native', 'romanized' and 'attestation'.

    ``keep_default_na=False`` stops pandas from converting legitimate
    romanized words such as "nan", "null" or "NA" into NaN. With the default
    behaviour those rows are lost from the training set and make the
    character tokenizers fail on dev/test, which are not filtered. The two
    text columns are read as strings.
    """
    return pd.read_csv(
        f'{data_dir}hi.translit.sampled.{split}.tsv',
        sep='\t',
        header=None,
        names=['native', 'romanized', 'attestation'],
        dtype={'native': str, 'romanized': str},
        keep_default_na=False,
    )


train_df = load_lexicon_split('train')
dev_df = load_lexicon_split('dev')
test_df = load_lexicon_split('test')

print("Train data shape:", train_df.shape)
print("Dev data shape:", dev_df.shape)
print("Test data shape:", test_df.shape)

In [None]:
def preprocess_data(df, start_token='\t', end_token='\n'):
    """Extract encoder input words and decoder target words from a lexicon DataFrame.

    Args:
        df: DataFrame with a 'romanized' (source) and a 'native' (target) column.
        start_token: Marker prepended to every target word.
        end_token: Marker appended to every target word.

    Returns:
        ``(input_texts, target_texts)``: the romanized words as a list, and the
        native words wrapped as ``start_token + word + end_token``.

    Raises:
        TypeError: if a 'native' cell is not a string (e.g. NaN).

    The markers default to single characters because the tokenizers in this
    notebook are character-level (``Tokenizer(char_level=True)``). A marker
    such as '<start>' would be split into the separate tokens '<', 's', 't', ...
    and could never be looked up as one token when decoding.
    """
    input_texts = df['romanized'].tolist()
    # Plain concatenation (not an f-string) so a non-string cell fails loudly
    # instead of silently becoming the text "nan".
    target_texts = [start_token + text + end_token for text in df['native'].tolist()]
    return input_texts, target_texts

# Build (input, target) text pairs for every split. The *_raw training lists
# keep the unfiltered originals; input_texts_train / target_texts_train are
# replaced by filtered versions in the next cell.
input_texts_train, target_texts_train = preprocess_data(train_df)
input_texts_train_raw = list(input_texts_train)
target_texts_train_raw = list(target_texts_train)

input_texts_dev, target_texts_dev = preprocess_data(dev_df)
input_texts_test, target_texts_test = preprocess_data(test_df)

for label, sample in (("input", input_texts_train[0]), ("target", target_texts_train[0])):
    print(f"Example {label} (train):", sample)

Removing NaN values from the training set

In [None]:
def _is_usable_text(value):
    """True when ``value`` is a real string (NaN cells arrive as floats)."""
    return isinstance(value, str) and not pd.isna(value)


# Positions of training pairs whose romanized input is usable; input and
# target lists are filtered with the same positions so they stay aligned.
valid_train_indices = [
    position for position, candidate in enumerate(input_texts_train_raw)
    if _is_usable_text(candidate)
]
input_texts_train = [input_texts_train_raw[position] for position in valid_train_indices]
target_texts_train = [target_texts_train_raw[position] for position in valid_train_indices]

print("Example input (train - filtered):", input_texts_train[0])
print("Example target (train - filtered):", target_texts_train[0])

# Sanity check: nothing unusable should survive the filter.
non_string_inputs = [
    (position, candidate) for position, candidate in enumerate(input_texts_train)
    if not _is_usable_text(candidate)
]
if not non_string_inputs:
    print("No non-string or nan values found in input_texts_train after filtering.")
else:
    print("Found remaining non-string or nan values in input_texts_train:")
    for index, value in non_string_inputs:
        print(f"Index: {index}, Type: {type(value)}, Value: {value}")

In [None]:
def _fit_char_tokenizer(texts):
    """Return a character-level Keras Tokenizer fitted on ``texts``."""
    tokenizer = Tokenizer(char_level=True)
    tokenizer.fit_on_texts(texts)
    return tokenizer


# Source side: romanized characters. Vocabulary size is +1 because the
# Tokenizer never assigns index 0, which the padding step uses.
input_tokenizer = _fit_char_tokenizer(input_texts_train)
input_vocab_size = len(input_tokenizer.word_index) + 1
input_sequences_train, input_sequences_dev, input_sequences_test = [
    input_tokenizer.texts_to_sequences(split_texts)
    for split_texts in (input_texts_train, input_texts_dev, input_texts_test)
]

# Target side: native-script characters (including the start/end markers).
target_tokenizer = _fit_char_tokenizer(target_texts_train)
target_vocab_size = len(target_tokenizer.word_index) + 1
target_sequences_train, target_sequences_dev, target_sequences_test = [
    target_tokenizer.texts_to_sequences(split_texts)
    for split_texts in (target_texts_train, target_texts_dev, target_texts_test)
]

print("Input vocabulary size:", input_vocab_size)
print("Target vocabulary size:", target_vocab_size)

In [None]:
# Longest tokenized sequences in the training split. Dev/test are padded (or
# truncated) to the same widths so every split matches the model's input shape.
max_input_len = max(len(seq) for seq in input_sequences_train)
max_target_len = max(len(seq) for seq in target_sequences_train)


def _pad_split(input_sequences, target_sequences):
    """Pad one split and build its teacher-forcing target.

    Returns:
        ``(encoder_input, decoder_input, decoder_target)`` as int arrays, where
        ``decoder_target[:, k] == decoder_input[:, k + 1]`` (target is shifted
        by one) and the last column of ``decoder_target`` is padding (0).

    ``truncating='post'`` matters for dev/test words longer than the training
    maximum. The Keras default ('pre') would cut off the *beginning* of the
    word, including the start marker. Building the target by slicing the padded
    array, rather than indexing with the raw sequence length, also avoids an
    IndexError for such over-long sequences.
    """
    encoder_input = pad_sequences(input_sequences, maxlen=max_input_len,
                                  padding='post', truncating='post')
    decoder_input = pad_sequences(target_sequences, maxlen=max_target_len,
                                  padding='post', truncating='post')
    decoder_target = np.zeros_like(decoder_input)
    decoder_target[:, :-1] = decoder_input[:, 1:]  # Target is shifted by one
    return encoder_input, decoder_input, decoder_target


encoder_input_data_train, decoder_input_data_train, decoder_target_data_train = _pad_split(
    input_sequences_train, target_sequences_train)
encoder_input_data_dev, decoder_input_data_dev, decoder_target_data_dev = _pad_split(
    input_sequences_dev, target_sequences_dev)
encoder_input_data_test, decoder_input_data_test, decoder_target_data_test = _pad_split(
    input_sequences_test, target_sequences_test)

print("Padded input data (train) shape:", encoder_input_data_train.shape)
print("Padded target input data (train) shape:", decoder_input_data_train.shape)
print("Padded target output data (train) shape:", decoder_target_data_train.shape)

## Step 2: Defining the model builder and testing a sample model

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, SimpleRNN, LSTM, GRU, Dense

def build_vanilla_rnn_seq2seq(
    input_vocab_size,
    target_vocab_size,
    embedding_dim,
    encoder_units,
    decoder_units,
    cell_type='rnn',
    num_encoder_layers=1,
    num_decoder_layers=1,
    max_input_len=None,  # Needed for the input layer shape
    max_target_len=None  # Not strictly needed here but good for context
):
    """
    Builds a vanilla RNN-based sequence-to-sequence model.

    The encoder is a stack of recurrent layers over the embedded input. Decoder
    layer ``i`` is initialised with the final state of encoder layer ``i``.
    Decoder layers beyond the encoder depth reuse the top encoder layer's
    state.

    Args:
        input_vocab_size (int): Size of the input character vocabulary.
        target_vocab_size (int): Size of the target character vocabulary.
        embedding_dim (int): Dimension of the character embeddings.
        encoder_units (int): Number of hidden units in the encoder RNN(s).
        decoder_units (int): Number of hidden units in the decoder RNN(s). Must
            equal ``encoder_units`` because encoder states seed the decoder.
        cell_type (str): Type of RNN cell to use ('rnn', 'lstm', 'gru'). Defaults to 'rnn'.
        num_encoder_layers (int): Number of layers in the encoder RNN. Defaults to 1.
        num_decoder_layers (int): Number of layers in the decoder RNN. Defaults to 1.
        max_input_len (int, optional): Maximum length of the input sequences. Defaults to None.
        max_target_len (int, optional): Maximum length of the target sequences.
            Unused; accepted for call compatibility.

    Returns:
        tf.keras.Model: Model taking ``[encoder_input, decoder_input]`` and
        returning per-step softmax probabilities over the target vocabulary.

    Raises:
        ValueError: for an unknown ``cell_type``, for decoder layers without any
            encoder layer, or when encoder and decoder widths differ.
    """
    rnn_classes = {'rnn': SimpleRNN, 'lstm': LSTM, 'gru': GRU}
    if cell_type not in rnn_classes:
        raise ValueError("Invalid cell_type. Choose from 'rnn', 'lstm', or 'gru'.")
    if num_decoder_layers > 0 and num_encoder_layers < 1:
        raise ValueError("At least one encoder layer is needed to initialise the decoder.")
    if num_decoder_layers > 0 and encoder_units != decoder_units:
        raise ValueError(
            "encoder_units must equal decoder_units: the decoder is initialised "
            "with the encoder's final state."
        )
    rnn_class = rnn_classes[cell_type]

    # ---------------- Encoder ----------------
    encoder_inputs = Input(shape=(max_input_len,))
    encoder_outputs = Embedding(input_vocab_size, embedding_dim)(encoder_inputs)
    encoder_layer_states = []  # One state list per layer: [h] or, for LSTM, [h, c].

    for i in range(num_encoder_layers):
        # Every layer except the top one must emit the full sequence so the
        # next layer receives a 3-D (batch, time, features) input.
        is_top_layer = i == num_encoder_layers - 1
        encoder_rnn = rnn_class(encoder_units, return_state=True,
                                return_sequences=not is_top_layer,
                                name=f'encoder_{cell_type}_{i}')
        encoder_outputs, *layer_states = encoder_rnn(encoder_outputs)
        encoder_layer_states.append(layer_states)

    # ---------------- Decoder ----------------
    decoder_inputs = Input(shape=(None,))  # Length of target sequence is variable
    decoder_outputs = Embedding(target_vocab_size, embedding_dim)(decoder_inputs)

    for i in range(num_decoder_layers):
        initial_state = encoder_layer_states[min(i, len(encoder_layer_states) - 1)]
        decoder_rnn = rnn_class(decoder_units, return_sequences=True, return_state=True,
                                name=f'decoder_{cell_type}_{i}')
        decoder_outputs, *_ = decoder_rnn(decoder_outputs, initial_state=initial_state)

    # Output layer
    decoder_dense = Dense(target_vocab_size, activation='softmax')(decoder_outputs)

    # Define the model
    model = Model([encoder_inputs, decoder_inputs], decoder_dense)
    return model

# ---------------- Instantiate and Test the Model ----------------
# Smoke-test configuration. Vocabulary sizes and maximum lengths come from the
# preprocessing cells.
embedding_dim = 64
encoder_units = decoder_units = 128
num_encoder_layers = num_decoder_layers = 1
cell_type = 'lstm'  # alternatives: 'rnn' or 'gru'

model = build_vanilla_rnn_seq2seq(
    input_vocab_size,
    target_vocab_size,
    embedding_dim,
    encoder_units,
    decoder_units,
    cell_type=cell_type,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    max_input_len=max_input_len,
    max_target_len=max_target_len,
)
model.summary()

## Step 3: W&B sweep across hyperparameters
Please supply your own W&B API key before running, and never commit an API key to a shared notebook.

In [21]:
import wandb
!pip install wandb -q
wandb.login(key = "39ded67ba2685c0f85010b40d27298a712244e64")


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmm21b051[0m ([33mmm21b051-iitmaana[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Candidate values for every hyperparameter explored by the sweep.
_sweep_search_space = {
    'embedding_dim': [16, 32, 64, 128, 256],
    'num_encoder_layers': [1, 2],
    'num_decoder_layers': [1, 2],
    'decoder_units': [64, 128, 256],
    'cell_type': ['gru'],
    'dropout_rate': [0.0, 0.2],
    'optimizer': ['adam', 'rmsprop'],
    'learning_rate': [0.001, 0.0001],
    'batch_size': [64, 128],
}

sweep_config = {
    'program': 'your_kaggle_notebook.ipynb',  # Replace with your notebook name
    'method': 'bayes',  # Or 'grid', 'random'
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    # W&B expects each discrete parameter as {'values': [...]}.
    'parameters': {name: {'values': choices} for name, choices in _sweep_search_space.items()},
}

In [None]:
import wandb
from wandb.integration.keras import WandbCallback
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, SimpleRNN, LSTM, GRU, Dense, Dropout, Layer
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.callbacks import EarlyStopping
# NOTE(review): tf.config.run_functions_eagerly(True) makes tf.function-wrapped
# code execute eagerly. That is usually a debugging aid and slows training;
# confirm it is still needed here.
tf.config.run_functions_eagerly(True)

class ExpandDims(Layer):
    """Keras layer wrapping ``tf.expand_dims`` so it can be used inside a functional model.

    Args:
        axis: Position at which the new length-1 dimension is inserted.
    """

    def __init__(self, axis, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, self.axis)

    def get_config(self):
        # Include ``axis`` so the layer can be re-created from its config.
        return {**super().get_config(), 'axis': self.axis}

def build_and_train_model(config=None):
    """Build, train and log one seq2seq model for a W&B sweep trial.

    Hyperparameters come from ``wandb.config``, filled in by the sweep agent or
    from ``config`` when called directly. Expected keys:
    - embedding_dim, num_encoder_layers, num_decoder_layers, decoder_units
    - cell_type ('rnn' | 'lstm' | 'gru')
    - dropout_rate, optimizer ('adam' | 'rmsprop'), learning_rate, batch_size

    The encoder reuses ``decoder_units`` as its width so that its final state
    can initialise the decoder.

    Returns:
        The Keras ``History`` returned by ``model.fit``.

    Raises:
        ValueError: if ``cell_type`` or ``optimizer`` is not recognised.
    """
    rnn_classes = {'rnn': SimpleRNN, 'lstm': LSTM, 'gru': GRU}
    optimizer_classes = {'adam': Adam, 'rmsprop': RMSprop}

    with wandb.init(config=config):
        config = wandb.config
        if config.cell_type not in rnn_classes:
            raise ValueError("Invalid cell_type. Choose from 'rnn', 'lstm', or 'gru'.")
        if config.optimizer not in optimizer_classes:
            raise ValueError("Invalid optimizer. Choose from 'adam' or 'rmsprop'.")
        rnn_class = rnn_classes[config.cell_type]
        units = config.decoder_units

        # ---------------- Encoder ----------------
        encoder_inputs = Input(shape=(max_input_len,))
        encoder_outputs = Embedding(input_vocab_size, config.embedding_dim)(encoder_inputs)
        # State list of the most recent recurrent layer: [h] for rnn/gru,
        # [h, c] for lstm. Stays None (zero initial state) with no encoder layers.
        state = None
        for i in range(config.num_encoder_layers):
            encoder_rnn = rnn_class(units, return_state=True, return_sequences=True,
                                    name=f'encoder_{config.cell_type}_{i}')
            encoder_outputs, *state = encoder_rnn(encoder_outputs)
            if config.dropout_rate > 0:
                encoder_outputs = Dropout(config.dropout_rate)(encoder_outputs)

        # ---------------- Decoder ----------------
        decoder_inputs = Input(shape=(None,))
        decoder_outputs = Embedding(target_vocab_size, config.embedding_dim)(decoder_inputs)
        for i in range(config.num_decoder_layers):
            # Layer 0 starts from the top encoder layer's state; each later
            # layer starts from the previous decoder layer's state. The state
            # list is passed through unchanged: each tensor must keep its
            # (batch, units) shape, and an LSTM needs both h and c.
            decoder_rnn = rnn_class(units, return_sequences=True, return_state=True,
                                    name=f'decoder_{config.cell_type}_{i}')
            decoder_outputs, *state = decoder_rnn(decoder_outputs, initial_state=state)
            if config.dropout_rate > 0:
                decoder_outputs = Dropout(config.dropout_rate)(decoder_outputs)

        decoder_dense = Dense(target_vocab_size, activation='softmax')(decoder_outputs)
        model = Model([encoder_inputs, decoder_inputs], decoder_dense)

        optimizer = optimizer_classes[config.optimizer](learning_rate=config.learning_rate)
        model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

        # Teacher forcing: decoder_target_data_* already holds the decoder input
        # shifted left by one position, so both arrays are trimmed by the same
        # final column. Slicing the target again would train the model to
        # predict two characters ahead.
        decoder_input_train = decoder_input_data_train[:, :-1]
        decoder_target_train = decoder_target_data_train[:, :-1]
        decoder_target_train_one_hot = tf.one_hot(decoder_target_train, depth=target_vocab_size)

        # Prepare validation data similarly
        decoder_input_dev = decoder_input_data_dev[:, :-1]
        decoder_target_dev = decoder_target_data_dev[:, :-1]
        decoder_target_dev_one_hot = tf.one_hot(decoder_target_dev, depth=target_vocab_size)

        # Print shapes before fitting
        print("Encoder input data shape:", encoder_input_data_train.shape)
        print("Decoder input train shape:", decoder_input_train.shape)
        print("Decoder target train one-hot shape:", decoder_target_train_one_hot.shape)
        print("Decoder input dev shape:", decoder_input_dev.shape)
        print("Decoder target dev one-hot shape:", decoder_target_dev_one_hot.shape)

        callbacks = [
            WandbMetricsLogger(log_freq='epoch'),
            # NOTE(review): newer Keras versions require a '.keras' suffix when
            # saving full models (save_weights_only=False); confirm '.h5' works
            # with the installed version.
            WandbModelCheckpoint(
                filepath="best_model_{epoch:02d}-{val_accuracy:.4f}.h5",
                monitor='val_accuracy',
                save_best_only=True,
                save_weights_only=False
            ),
        ]

        history = model.fit(
            [encoder_input_data_train, decoder_input_train],
            decoder_target_train_one_hot,
            batch_size=config.batch_size,
            epochs=10,
            validation_data=(
                [encoder_input_data_dev, decoder_input_dev],
                decoder_target_dev_one_hot
            ),
            callbacks=callbacks
        )
        return history

In [None]:
# Register the sweep under the W&B project, then let an agent run up to 50
# trials, each calling build_and_train_model with a sampled configuration.
sweep_id = wandb.sweep(sweep_config, project="dakshina-transliteration")
wandb.agent(sweep_id, function=build_and_train_model, count=50)


## Step 4: Retraining with the best sweep configuration and evaluating on the test set

In [None]:
import wandb
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, SimpleRNN, LSTM, GRU, Dense, Dropout, Layer
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
import os

# NOTE(review): forces tf.function-wrapped code to run eagerly, which is slower;
# presumably a debugging aid. Confirm it is still needed.
tf.config.run_functions_eagerly(True)

# Vocabulary sizes from the fitted tokenizers (+1: index 0 is used for padding).
input_vocab_size, target_vocab_size = (
    len(fitted.word_index) + 1 for fitted in (input_tokenizer, target_tokenizer)
)

# --- Best configuration from your sweep ---
best_config = {
    'batch_size': 64,
    'cell_type': 'lstm',
    'decoder_units': 128,
    'dropout_rate': 0.2,
    'embedding_dim': 64,
    'learning_rate': 0.001,
    'num_decoder_layers': 2,
    'num_encoder_layers': 1,
    'optimizer': 'adam'
}

class ExpandDims(Layer):
    """Keras layer inserting a length-1 axis via ``tf.expand_dims``.

    Same definition as in the sweep cell; re-declaring it keeps this cell
    self-contained. It is not used by build_and_train_model in this cell.

    Args:
        axis: Index at which the new dimension is inserted.
    """

    def __init__(self, axis, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, axis=self.axis)

    def get_config(self):
        base_config = super().get_config()
        return dict(base_config, axis=self.axis)

def build_and_train_model(config):
    """Build and compile (but do not train) a seq2seq model from a config dict.

    The caller runs ``fit``; the name is kept for compatibility.

    Args:
        config: dict with these keys:
            - 'embedding_dim', 'num_encoder_layers', 'num_decoder_layers'
            - 'decoder_units'. The encoder reuses this as its width so that its
              final state can initialise the decoder.
            - 'cell_type' ('rnn' | 'lstm' | 'gru')
            - 'dropout_rate', 'learning_rate'
            - 'optimizer' ('adam' | 'rmsprop', case-insensitive)

    Returns:
        A compiled ``tf.keras.Model`` taking ``[encoder_input, decoder_input]``
        and returning per-step softmax probabilities over the target vocabulary.

    Raises:
        ValueError: if 'cell_type' or 'optimizer' is not recognised.
    """
    rnn_classes = {'rnn': SimpleRNN, 'lstm': LSTM, 'gru': GRU}
    optimizer_classes = {'adam': Adam, 'rmsprop': RMSprop}

    cell_type = config['cell_type']
    optimizer_name = config['optimizer'].lower()
    if cell_type not in rnn_classes:
        raise ValueError("Invalid cell_type. Choose from 'rnn', 'lstm', or 'gru'.")
    if optimizer_name not in optimizer_classes:
        raise ValueError("Invalid optimizer. Choose from 'adam' or 'rmsprop'.")
    rnn_class = rnn_classes[cell_type]
    units = config['decoder_units']
    dropout_rate = config['dropout_rate']

    # ---------------- Encoder ----------------
    encoder_inputs = Input(shape=(max_input_len,))
    encoder_outputs = Embedding(input_vocab_size, config['embedding_dim'])(encoder_inputs)
    # State list of the most recent recurrent layer: [h] for rnn/gru, [h, c]
    # for lstm. Stays None (zero initial state) with no encoder layers.
    state = None
    for i in range(config['num_encoder_layers']):
        encoder_rnn = rnn_class(units, return_state=True, return_sequences=True,
                                name=f'encoder_{cell_type}_{i}')
        encoder_outputs, *state = encoder_rnn(encoder_outputs)
        if dropout_rate > 0:
            encoder_outputs = Dropout(dropout_rate)(encoder_outputs)

    # ---------------- Decoder ----------------
    decoder_inputs = Input(shape=(None,))
    decoder_outputs = Embedding(target_vocab_size, config['embedding_dim'])(decoder_inputs)
    for i in range(config['num_decoder_layers']):
        # Layer 0 starts from the top encoder layer's state; each later layer
        # starts from the previous decoder layer's state.
        decoder_rnn = rnn_class(units, return_sequences=True, return_state=True,
                                name=f'decoder_{cell_type}_{i}')
        decoder_outputs, *state = decoder_rnn(decoder_outputs, initial_state=state)
        if dropout_rate > 0:
            decoder_outputs = Dropout(dropout_rate)(decoder_outputs)

    decoder_dense = Dense(target_vocab_size, activation='softmax')(decoder_outputs)
    model = Model([encoder_inputs, decoder_inputs], decoder_dense)

    optimizer = optimizer_classes[optimizer_name](learning_rate=config['learning_rate'])
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def _teacher_forcing_split(decoder_input_data, decoder_target_data):
    """Return ``(decoder_input, decoder_target, decoder_target_one_hot)`` for one split.

    ``decoder_target_data`` already holds ``decoder_input_data`` shifted left
    by one position. Both are therefore trimmed by the same final column: at
    step k the model sees character k and is trained to predict character
    k + 1. Slicing the target by one more position would make it predict two
    characters ahead.
    """
    decoder_input = decoder_input_data[:, :-1]
    decoder_target = decoder_target_data[:, :-1]
    return decoder_input, decoder_target, tf.one_hot(decoder_target, depth=target_vocab_size)


# Prepare target data for training (teacher forcing)
decoder_input_train, decoder_target_train, decoder_target_train_one_hot = _teacher_forcing_split(
    decoder_input_data_train, decoder_target_data_train)

# Prepare validation and test data similarly
decoder_input_dev, decoder_target_dev, decoder_target_dev_one_hot = _teacher_forcing_split(
    decoder_input_data_dev, decoder_target_data_dev)
decoder_input_test, decoder_target_test, decoder_target_test_one_hot = _teacher_forcing_split(
    decoder_input_data_test, decoder_target_data_test)


# --- Build and train the best model ---
# Train on the first GPU when one is visible; otherwise fall back to the CPU.
train_device = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'
with tf.device(train_device):
    best_model = build_and_train_model(best_config)

    history = best_model.fit(
        [encoder_input_data_train, decoder_input_train],
        decoder_target_train_one_hot,
        batch_size=best_config['batch_size'],
        epochs=10,
        validation_data=(
            [encoder_input_data_dev, decoder_input_dev],
            decoder_target_dev_one_hot
        ),
        # Stop after 5 epochs without val_accuracy improvement and keep the best weights.
        callbacks=[EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)]
    )
    print("Best model trained.")

# --- Evaluate the best model on the test set ---
# Use the test arrays exactly as prepared for teacher forcing, so inputs and
# targets are aligned as in training/validation. Slicing them again here would
# shift the targets relative to the inputs.
# NOTE(review): 'accuracy' is per-character and includes padding positions (no
# mask or sample weights are used), so it is not word-level exact-match accuracy.
loss, accuracy = best_model.evaluate(
    [encoder_input_data_test, decoder_input_test],
    decoder_target_test_one_hot,
    batch_size=best_config['batch_size'],
    verbose=0
)
print(f"\nTest Loss: {loss:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")

# --- Prediction Function (Adjusted for the trained model) ---
def translate_sentence(input_sequence, model, input_tokenizer, target_tokenizer, max_target_len, config,
                       start_token='\t', end_token='\n'):
    """Greedily transliterate one padded, tokenized input word.

    Decoding repeatedly feeds the growing output prefix through the trained
    teacher-forcing model and takes the most probable character at the last
    position. Every decoder layer is a left-to-right RNN, so that prediction
    depends only on the prefix; no separate encoder/decoder inference models
    need to be rebuilt. Dropout is disabled via ``training=False``.

    Args:
        input_sequence: 1-D array of input token ids, padded to the encoder length.
        model: Trained model called as ``model([encoder_input, decoder_input], training=False)``.
        input_tokenizer: Unused; kept for call compatibility.
        target_tokenizer: Tokenizer fitted on the target words (``word_index`` / ``index_word``).
        max_target_len: Maximum number of characters to generate.
        config: Unused; kept for call compatibility.
        start_token: Marker prepended to target words during preprocessing.
        end_token: Marker appended to target words during preprocessing.

    Returns:
        The predicted word, with characters concatenated without separators.

    Raises:
        KeyError: if ``start_token`` is not in the target vocabulary.
    """
    word_index = target_tokenizer.word_index
    if start_token not in word_index:
        raise KeyError(
            f"Start token {start_token!r} is not in the target vocabulary; pass the "
            "marker that was prepended to target words during preprocessing."
        )
    start_id = word_index[start_token]
    end_id = word_index.get(end_token)

    encoder_input = np.asarray(input_sequence)[np.newaxis, :]  # batch of one
    decoded_ids = [start_id]
    decoded_chars = []
    for _ in range(max_target_len):
        decoder_input = np.asarray(decoded_ids)[np.newaxis, :]
        probabilities = np.asarray(model([encoder_input, decoder_input], training=False))
        predicted_id = int(np.argmax(probabilities[0, -1]))
        # Stop at the end marker, or at padding (id 0), which follows the end
        # marker in the post-padded training targets.
        if predicted_id == 0 or predicted_id == end_id:
            break
        predicted_char = target_tokenizer.index_word.get(predicted_id)
        if predicted_char:
            decoded_chars.append(predicted_char)
        decoded_ids.append(predicted_id)

    return ''.join(decoded_chars)

# --- Generate and Save All Test Set Predictions ---
predictions_folder = "predictions_vanilla"
os.makedirs(predictions_folder, exist_ok=True)
all_predictions = []

for i, input_sequence in enumerate(encoder_input_data_test):
    predicted_translation = translate_sentence(
        input_sequence,
        best_model,
        input_tokenizer,
        target_tokenizer,
        max_target_len,
        best_config
    )
    all_predictions.append(predicted_translation)
    filename = os.path.join(predictions_folder, f"prediction_{i}.txt")
    # Predictions are non-ASCII native-script text: write UTF-8 explicitly
    # rather than relying on the platform's default encoding.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(predicted_translation)

print(f"\nAll test set predictions saved to '{predictions_folder}'.")

# --- Sample Inputs and Predictions (Creative Grid) ---
num_samples_to_show = 5
sample_indices = np.random.choice(len(encoder_input_data_test), size=num_samples_to_show, replace=False)

print("\n--- Sample Test Inputs and Predictions ---")
print("-------------------------------------------------------------------------")
print(f"{'Input':<20} | {'Predicted':<20} | {'Reference':<20}")
print("-------------------------------------------------------------------------")

for index in sample_indices:
    input_sequence = encoder_input_data_test[index]
    predicted_translation = translate_sentence(
        input_sequence,
        best_model,
        input_tokenizer,
        target_tokenizer,
        max_target_len,
        best_config
    )

    # Show the raw test words (the test split is not filtered, so row `index`
    # matches encoder_input_data_test[index]). Decoding the shifted target ids
    # would drop leading characters, and joining char tokens with ' ' would
    # insert spaces between letters.
    input_text = str(input_texts_test[index])
    reference_translation = str(test_df['native'].iloc[index])

    print(f"{input_text[:20]:<20} | {predicted_translation[:20]:<20} | {reference_translation[:20]:<20}")

print("-------------------------------------------------------------------------")

# Attention-based model

## Step 1: Define the attention-based model and sweep across hyperparameters

In [None]:
import wandb
from wandb.integration.keras import WandbCallback
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, GRU, Dense, Dropout, Attention, Concatenate, Layer
from tensorflow.keras.optimizers import Adam, RMSprop

class ReshapeInitialState(Layer):
    """Keras layer that reshapes its input with ``tf.reshape`` to a fixed shape.

    Args:
        target_shape: Shape handed to ``tf.reshape`` (its usual semantics
            apply, e.g. a single -1 entry is inferred).
    """

    def __init__(self, target_shape, **kwargs):
        super().__init__(**kwargs)
        self.target_shape = target_shape

    def call(self, inputs):
        return tf.reshape(inputs, shape=self.target_shape)

    def get_config(self):
        # Serialise target_shape so the layer can be rebuilt from its config.
        base_config = super().get_config()
        return dict(base_config, target_shape=self.target_shape)

def build_and_train_model(config=None):
    """Build, compile and train one attention seq2seq model for a W&B sweep run.

    Intended as the ``function`` passed to ``wandb.agent``: hyperparameters are
    read from ``wandb.config`` once the run is initialised. Uses notebook
    globals from earlier cells (``max_input_len``, ``input_vocab_size``,
    ``target_vocab_size`` and the ``*_data_train`` / ``*_data_dev`` arrays).

    Architecture: Embedding -> single-layer RNN encoder (LSTM, GRU, or
    SimpleRNN for any other ``cell_type``) returning its full output sequence
    and final state(s); a decoder RNN of the same type initialised from those
    states; a Keras ``Attention`` layer over the encoder outputs whose context
    is concatenated with the decoder outputs before a softmax ``Dense`` over
    the target vocabulary.

    Args:
        config: Optional config forwarded to ``wandb.init``; during a sweep the
            agent supplies the values.

    Returns:
        The ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: If ``config.optimizer`` is neither ``'adam'`` nor ``'rmsprop'``.
    """
    # SimpleRNN is not imported by this cell; without this import the
    # fallback cell type would raise NameError.
    from tensorflow.keras.layers import SimpleRNN

    with wandb.init(config=config):
        config = wandb.config

        rnn_classes = {'lstm': LSTM, 'gru': GRU}
        rnn_class = rnn_classes.get(config.cell_type, SimpleRNN)
        # Layer-name suffix: 'lstm' / 'gru'; any other cell type is a SimpleRNN ('rnn').
        suffix = config.cell_type if config.cell_type in rnn_classes else 'rnn'

        # ---------------- Encoder ----------------
        encoder_inputs = Input(shape=(max_input_len,))
        encoder_embedding = Embedding(input_vocab_size, config.embedding_dim)(encoder_inputs)
        encoder_results = rnn_class(
            config.encoder_units, return_sequences=True, return_state=True,
            name=f'encoder_{suffix}_0',
        )(encoder_embedding)
        # LSTM returns (outputs, h, c); GRU and SimpleRNN return (outputs, h).
        encoder_outputs, encoder_states = encoder_results[0], list(encoder_results[1:])

        if config.dropout_rate > 0:
            encoder_outputs = Dropout(config.dropout_rate)(encoder_outputs)

        # The sweep samples encoder_units and decoder_units independently. The
        # decoder is initialised from the encoder states and the attention
        # compares decoder and encoder vectors by dot product, so both need
        # decoder_units features. Project them when the sizes differ instead of
        # failing the trial with a shape error.
        if config.encoder_units != config.decoder_units:
            encoder_states = [
                Dense(config.decoder_units, name=f'encoder_state_bridge_{i}')(state)
                for i, state in enumerate(encoder_states)
            ]
            encoder_outputs = Dense(config.decoder_units, name='encoder_output_bridge')(encoder_outputs)

        # ---------------- Decoder ----------------
        decoder_inputs = Input(shape=(None,))
        decoder_embedding = Embedding(target_vocab_size, config.embedding_dim)(decoder_inputs)
        decoder_rnn = rnn_class(
            config.decoder_units, return_sequences=True, return_state=True,
            name=f'decoder_{suffix}_0',
        )

        # Initial decoder state is the final encoder state.
        decoder_initial_state = encoder_states
        if config.cell_type == 'gru':
            # Explicit (-1, decoder_units) reshape kept from the original design.
            decoder_initial_state = [ReshapeInitialState((-1, config.decoder_units))(state) for state in decoder_initial_state]

        # Only the output sequence is used; the decoder's final states are discarded.
        decoder_outputs = decoder_rnn(decoder_embedding, initial_state=decoder_initial_state)[0]

        if config.dropout_rate > 0:
            decoder_outputs = Dropout(config.dropout_rate)(decoder_outputs)

        # ---------------- Attention Mechanism ----------------
        context_vector = Attention()([decoder_outputs, encoder_outputs])

        # Concatenate the context vector with the decoder output
        decoder_concat_input = Concatenate(axis=-1)([decoder_outputs, context_vector])

        # Final dense layer
        decoder_dense = Dense(target_vocab_size, activation='softmax')(decoder_concat_input)

        # Define the model
        model = Model([encoder_inputs, decoder_inputs], decoder_dense)

        optimizers = {'adam': Adam, 'rmsprop': RMSprop}
        if config.optimizer not in optimizers:
            raise ValueError(f"Unsupported optimizer {config.optimizer!r}; expected 'adam' or 'rmsprop'.")
        optimizer = optimizers[config.optimizer](learning_rate=config.learning_rate)

        model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

        # Prepare target data for training (teacher forcing): the decoder is fed
        # positions [:-1] and trained to predict positions [1:].
        # NOTE(review): if decoder_target_data_* is already the decoder input
        # shifted by one step, this shifts it a second time; confirm against
        # the data-preparation cell.
        decoder_input_train = decoder_input_data_train[:, :-1]
        decoder_target_train = decoder_target_data_train[:, 1:]
        decoder_target_train_one_hot = tf.one_hot(decoder_target_train, depth=target_vocab_size, axis=-1)

        decoder_input_dev = decoder_input_data_dev[:, :-1]
        decoder_target_dev = decoder_target_data_dev[:, 1:]
        decoder_target_dev_one_hot = tf.one_hot(decoder_target_dev, depth=target_vocab_size, axis=-1)

        # Print shapes before fitting (the arrays actually passed to fit)
        print("Encoder input data shape:", encoder_input_data_train.shape)
        print("Decoder input train shape:", decoder_input_train.shape)
        print("Decoder target train one-hot shape:", decoder_target_train_one_hot.shape)
        print("Decoder input dev shape:", decoder_input_dev.shape)
        print("Decoder target dev one-hot shape:", decoder_target_dev_one_hot.shape)

        callbacks = [
            WandbMetricsLogger(log_freq='epoch'),
            WandbModelCheckpoint(
                filepath="attention_best_model_{epoch:02d}-{val_accuracy:.4f}.h5",
                monitor='val_accuracy',
                save_best_only=True,
                save_weights_only=False
            ),
        ]

        # Train the model
        history = model.fit(
            [encoder_input_data_train, decoder_input_train],
            decoder_target_train_one_hot,
            batch_size=config.batch_size,
            epochs=config.epochs,
            validation_data=(
                [encoder_input_data_dev, decoder_input_dev],
                decoder_target_dev_one_hot
            ),
            callbacks=callbacks
        )
        return history

if __name__ == '__main__':
    # Sweep configuration for the attention model.
    # WandbMetricsLogger(log_freq='epoch') records Keras metrics under an
    # "epoch/" prefix (e.g. "epoch/val_accuracy"), so the sweep metric must use
    # that key; a bare "val_accuracy" is never logged to the run history and
    # the Bayesian search would have nothing to optimise.
    attention_sweep_config = {
        'method': 'bayes',
        'metric': {
            'name': 'epoch/val_accuracy',
            'goal': 'maximize'
        },
        'parameters': {
            'embedding_dim': {'values': [128, 256]},
            'encoder_units': {'values': [128, 256]},
            'decoder_units': {'values': [128, 256]},
            'dropout_rate': {'values': [0.0, 0.2]},
            'cell_type': {'values': ['lstm']},  # Only LSTM cells are swept for the attention model
            'optimizer': {'values': ['adam', 'rmsprop']},
            'learning_rate': {'values': [1e-3, 1e-4]},
            'batch_size': {'values': [64, 128]},
            'epochs': {'value': 10}  # Fixed number of epochs for every run in this sweep
        }
    }

    # Register the sweep with W&B
    attention_sweep_id = wandb.sweep(attention_sweep_config, project="dakshina-transliteration-attention", entity="mm21b051-iitmaana")

    # Launch the agent: up to 20 runs, each calling build_and_train_model
    wandb.agent(attention_sweep_id, function=build_and_train_model, count=20)


## Testing the best attention-based model and generating heatmaps

In [None]:
import wandb
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, GRU, Dense, Dropout, Attention, Concatenate, Layer
from tensorflow.keras.optimizers import Adam, RMSprop
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Define the ReshapeInitialState layer
class ReshapeInitialState(Layer):
    """Serializable Keras layer applying ``tf.reshape`` with a fixed target shape.

    The shape given at construction time is kept on the layer and included
    in ``get_config`` so the layer can be re-created from its config.
    """

    def __init__(self, target_shape, **kwargs):
        super().__init__(**kwargs)
        self.target_shape = target_shape

    def call(self, inputs):
        """Reshape ``inputs`` to ``self.target_shape``."""
        reshaped = tf.reshape(inputs, self.target_shape)
        return reshaped

    def get_config(self):
        """Base layer config plus the ``target_shape`` entry."""
        return {**super().get_config(), 'target_shape': self.target_shape}

def build_and_train_model(config, encoder_input_data_train, decoder_input_data_train, decoder_target_data_train,
                          encoder_input_data_dev, decoder_input_data_dev, decoder_target_data_dev):
    """Build, compile and train the attention seq2seq model for a fixed config.

    Uses notebook globals ``max_input_len``, ``input_vocab_size`` and
    ``target_vocab_size`` from earlier cells. Layers get fixed names:
    ``encoder_<cell>_0``, ``decoder_<cell>_0`` (``<cell>`` is ``lstm``, ``gru``
    or ``rnn``) and ``attention_layer``.

    Args:
        config (dict): Hyperparameters with keys ``embedding_dim``,
            ``encoder_units``, ``decoder_units``, ``cell_type`` (``'lstm'``,
            ``'gru'``, anything else -> SimpleRNN), ``dropout_rate``,
            ``optimizer`` (``'adam'`` or ``'rmsprop'``), ``learning_rate``,
            ``batch_size`` and optionally ``epochs`` (default 10).
        encoder_input_data_train, encoder_input_data_dev: Encoder input id arrays.
        decoder_input_data_train, decoder_input_data_dev: Decoder id arrays;
            ``[:, :-1]`` is fed to the decoder.
        decoder_target_data_train, decoder_target_data_dev: Target id arrays;
            ``[:, 1:]`` is one-hot encoded as the training label.

    Returns:
        tuple: ``(model, history)`` - the trained Keras model and its ``History``.

    Raises:
        ValueError: If the optimizer is unsupported, or if ``encoder_units`` and
            ``decoder_units`` differ (the decoder is initialised from the encoder
            states and the dot-product attention compares decoder and encoder
            vectors, so both widths must match).
    """
    # SimpleRNN is not imported by this cell; import it here so the fallback
    # cell type does not raise NameError.
    from tensorflow.keras.layers import SimpleRNN

    optimizers = {'adam': Adam, 'rmsprop': RMSprop}
    if config['optimizer'] not in optimizers:
        raise ValueError(f"Unsupported optimizer {config['optimizer']!r}; expected 'adam' or 'rmsprop'.")
    if config['encoder_units'] != config['decoder_units']:
        raise ValueError(
            "encoder_units and decoder_units must be equal: the decoder is initialised "
            "from the encoder states and attends over the encoder outputs."
        )

    rnn_classes = {'lstm': LSTM, 'gru': GRU}
    rnn_class = rnn_classes.get(config['cell_type'], SimpleRNN)
    # Layer-name suffix: 'lstm' / 'gru'; any other cell type is a SimpleRNN ('rnn').
    suffix = config['cell_type'] if config['cell_type'] in rnn_classes else 'rnn'

    with tf.device('/GPU:0'):  # Ensure training happens on GPU
        # ---------------- Encoder ----------------
        encoder_inputs = Input(shape=(max_input_len,))
        encoder_embedding = Embedding(input_vocab_size, config['embedding_dim'])(encoder_inputs)
        encoder_results = rnn_class(
            config['encoder_units'], return_sequences=True, return_state=True,
            name=f'encoder_{suffix}_0',
        )(encoder_embedding)
        # LSTM returns (outputs, h, c); GRU and SimpleRNN return (outputs, h).
        encoder_outputs, encoder_states = encoder_results[0], list(encoder_results[1:])

        if config['dropout_rate'] > 0:
            encoder_outputs = Dropout(config['dropout_rate'])(encoder_outputs)

        # ---------------- Decoder ----------------
        decoder_inputs = Input(shape=(None,))
        decoder_embedding = Embedding(target_vocab_size, config['embedding_dim'])(decoder_inputs)
        # Only create the decoder layer here; it is called below with the encoder
        # states as its initial state.
        decoder_rnn = rnn_class(
            config['decoder_units'], return_sequences=True, return_state=True,
            name=f'decoder_{suffix}_0',
        )

        decoder_initial_state = encoder_states
        if config['cell_type'] == 'gru':
            # Explicit (-1, decoder_units) reshape kept from the original design.
            decoder_initial_state = [ReshapeInitialState((-1, config['decoder_units']))(state) for state in decoder_initial_state]

        # Only the output sequence is used; the decoder's final states are discarded.
        decoder_outputs = decoder_rnn(decoder_embedding, initial_state=decoder_initial_state)[0]

        if config['dropout_rate'] > 0:
            decoder_outputs = Dropout(config['dropout_rate'])(decoder_outputs)

        # ---------------- Attention + output ----------------
        context_vector = Attention(name='attention_layer')([decoder_outputs, encoder_outputs])

        decoder_concat_input = Concatenate(axis=-1)([decoder_outputs, context_vector])
        decoder_dense = Dense(target_vocab_size, activation='softmax')(decoder_concat_input)

        model = Model([encoder_inputs, decoder_inputs], decoder_dense)

        optimizer = optimizers[config['optimizer']](learning_rate=config['learning_rate'])
        model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

        # Prepare target data for training (teacher forcing): the decoder is fed
        # positions [:-1] and trained to predict positions [1:].
        # NOTE(review): if decoder_target_data_* is already the decoder input
        # shifted by one step, this shifts it a second time; confirm against
        # the data-preparation cell.
        decoder_input_train = decoder_input_data_train[:, :-1]
        decoder_target_train = decoder_target_data_train[:, 1:]
        decoder_target_train_one_hot = tf.one_hot(decoder_target_train, depth=target_vocab_size, axis=-1)

        decoder_input_dev = decoder_input_data_dev[:, :-1]
        decoder_target_dev = decoder_target_data_dev[:, 1:]
        decoder_target_dev_one_hot = tf.one_hot(decoder_target_dev, depth=target_vocab_size, axis=-1)

        # Shapes of the arrays actually passed to fit
        print("Encoder input data shape:", encoder_input_data_train.shape)
        print("Decoder input train shape:", decoder_input_train.shape)
        print("Decoder target train one-hot shape:", decoder_target_train_one_hot.shape)
        print("Decoder input dev shape:", decoder_input_dev.shape)
        print("Decoder target dev one-hot shape:", decoder_target_dev_one_hot.shape)

        history = model.fit(
            [encoder_input_data_train, decoder_input_train],
            decoder_target_train_one_hot,
            batch_size=config['batch_size'],
            epochs=config.get('epochs', 10),
            validation_data=(
                [encoder_input_data_dev, decoder_input_dev],
                decoder_target_dev_one_hot
            ),
            verbose=1
        )
        return model, history

def decode_sequence(input_seq, encoder_model, decoder_model, target_vocab_inv, max_decoder_seq_length):
    """Greedily transliterate one input sequence, one target token per step.

    Args:
        input_seq: Encoder input ids for a single example, shape ``(1, timesteps)``.
        encoder_model: Model mapping encoder ids to
            ``[encoder_outputs, *final_states]`` (see ``_build_inference_models``).
        decoder_model: Single-step model mapping
            ``[token (1, 1), encoder_outputs, *states]`` to
            ``[probabilities, *new_states, attention_scores]``.
        target_vocab_inv (dict): Target index -> token string. Must contain
            ``'<start>'``; ``'<end>'`` is used as the stop token.
        max_decoder_seq_length (int): Maximum number of decoding steps.

    Returns:
        tuple: ``(decoded_sentence, attention_weights_list)``. ``decoded_sentence``
        is the concatenation of the predicted tokens without ``<start>``/``<end>``.
        ``attention_weights_list`` holds one numpy array per emitted token, as
        returned by the attention layer for that step (``(1, 1, encoder_timesteps)``).

    Raises:
        ValueError: If ``'<start>'`` is missing from ``target_vocab_inv``.
    """
    token_to_index = {token: index for index, token in target_vocab_inv.items()}
    if '<start>' not in token_to_index:
        raise ValueError("target_vocab_inv has no '<start>' token")

    with tf.device('/GPU:0'):  # Ensure inference happens on GPU
        # Calling the models directly avoids per-call `predict` overhead in this
        # one-token-at-a-time loop.
        encoder_outputs, *states = encoder_model(input_seq, training=False)

        target_seq = np.array([[token_to_index['<start>']]])
        decoded_tokens = []
        attention_weights_list = []

        for _ in range(max_decoder_seq_length):
            output_tokens, *rest = decoder_model([target_seq, encoder_outputs, *states], training=False)
            # Feed the decoder's new recurrent state into the next step.
            states, attention_weights = rest[:-1], rest[-1]

            sampled_token_index = int(np.argmax(np.asarray(output_tokens)[0, -1, :]))
            sampled_token = target_vocab_inv.get(sampled_token_index)
            # Keras Tokenizer indices start at 1, so index 0 (padding) is not in
            # the vocabulary; treat it like <end>.
            if sampled_token is None or sampled_token == '<end>':
                break

            decoded_tokens.append(sampled_token)
            attention_weights_list.append(np.asarray(attention_weights))
            target_seq = np.array([[sampled_token_index]])

    return ''.join(decoded_tokens), attention_weights_list


def visualize_attention(input_text, predicted_text, attention_weights_list, input_tokens, target_tokens, ax=None):
    """
    Visualizes attention weights for a given input-output pair as a heatmap.

    Args:
        input_text (str): The original input text (shown in the title).
        predicted_text (str): The predicted output text. ``<start>``/``<end>`` are
            stripped and each remaining character becomes one heatmap row.
        attention_weights_list (list of numpy arrays): Attention weights for each
            decoder step; ``aw[0][0]`` of each array is plotted as one row.
        input_tokens (list): Column labels; should have one entry per attention column.
        target_tokens (list): Reference target tokens. Not used for plotting; kept
            for backward compatibility.
        ax (matplotlib.axes.Axes, optional): Axes to draw on (e.g. one cell of a
            subplot grid). When omitted, a new 10x10 figure is created and shown.
    """
    predicted_text = predicted_text.replace("<start>", "").replace("<end>", "")

    # Pair each predicted character with the attention of the step that produced it.
    step_count = min(len(predicted_text), len(attention_weights_list))
    predicted_text = predicted_text[:step_count]
    if step_count == 0:
        print(f'No attention weights to plot for input "{input_text}".')
        return

    attention_matrix = np.array([np.asarray(aw)[0][0] for aw in attention_weights_list[:step_count]])

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(10, 10))
    # Use predicted text and input tokens for labels
    sns.heatmap(attention_matrix, xticklabels=input_tokens, yticklabels=list(predicted_text), cmap='viridis', ax=ax)
    ax.set_xlabel('Input Tokens')
    ax.set_ylabel('Predicted Tokens')
    ax.set_title(f'Attention Heatmap for Input: "{input_text}"\nPredicted: "{predicted_text}"')
    if owns_figure:
        plt.show()


def _find_decoder_embedding(model):
    """Return the Embedding layer that consumes the model's second (decoder) input.

    Matches on the layer's input tensor first; falls back to the single
    Embedding whose vocabulary size equals the global ``target_vocab_size``.

    Raises:
        ValueError: If no unambiguous decoder Embedding is found.
    """
    embeddings = [layer for layer in model.layers if isinstance(layer, Embedding)]
    decoder_input = model.inputs[1]
    for layer in embeddings:
        if layer.input is decoder_input:
            return layer
    by_vocab = [layer for layer in embeddings if layer.input_dim == target_vocab_size]
    if len(by_vocab) == 1:
        return by_vocab[0]
    raise ValueError("Could not identify the decoder Embedding layer in the trained model.")


def _build_inference_models(model, cell_type):
    """Split the trained teacher-forcing model into step-wise inference models.

    The training graph consumes the whole target sequence at once, so greedy
    decoding needs separate models. These reuse the trained layers (and
    therefore their weights):

    * encoder model: encoder ids -> ``[encoder_outputs, *final_states]``
    * decoder model: ``[token (batch, 1), encoder_outputs, *states]`` ->
      ``[probabilities, *new_states, attention_scores]``

    Dropout layers are skipped because Keras Dropout is the identity at
    inference time. Layers are looked up by the names ``build_and_train_model``
    assigns (``encoder_<cell>_0``, ``decoder_<cell>_0``, ``attention_layer``)
    rather than by fragile positional indices.

    Args:
        model: Trained model returned by ``build_and_train_model``.
        cell_type (str): ``'lstm'``, ``'gru'``, or anything else for SimpleRNN.

    Returns:
        tuple: ``(encoder_model, decoder_model)``.
    """
    suffix = cell_type if cell_type in ('lstm', 'gru') else 'rnn'
    encoder_rnn = model.get_layer(f'encoder_{suffix}_0')
    decoder_rnn = model.get_layer(f'decoder_{suffix}_0')
    attention_layer = model.get_layer('attention_layer')
    decoder_embedding = _find_decoder_embedding(model)
    output_dense = next(layer for layer in model.layers if isinstance(layer, Dense))

    # encoder_rnn.output is [sequence_outputs, *final_states].
    encoder_output_tensors = list(encoder_rnn.output)
    encoder_model = Model(model.inputs[0], encoder_output_tensors)

    token_input = Input(shape=(1,))
    encoder_outputs_input = Input(shape=tuple(encoder_output_tensors[0].shape[1:]))
    state_inputs = [Input(shape=(decoder_rnn.units,)) for _ in encoder_output_tensors[1:]]

    decoder_sequence, *decoder_states = decoder_rnn(decoder_embedding(token_input), initial_state=state_inputs)
    # The layer returns the attention distribution as its second output
    # when return_attention_scores=True.
    context, attention_scores = attention_layer(
        [decoder_sequence, encoder_outputs_input], return_attention_scores=True
    )
    probabilities = output_dense(Concatenate(axis=-1)([decoder_sequence, context]))

    decoder_model = Model(
        [token_input, encoder_outputs_input, *state_inputs],
        [probabilities, *decoder_states, attention_scores],
    )
    return encoder_model, decoder_model


if __name__ == '__main__':

    # 1. Define the best parameters
    best_params = {
        'batch_size': 64,
        'cell_type': 'lstm',
        'decoder_units': 256,
        'dropout_rate': 0.2,
        'embedding_dim': 128,
        'encoder_units': 256,
        'learning_rate': 0.001,
        'optimizer': 'adam'
    }
    # Index -> character for the romanized (encoder) vocabulary.
    reverse_input_char_index = {v: k for k, v in input_tokenizer.word_index.items()}
    # Character -> index (target_vocab) and index -> character (target_vocab_inv)
    # for the native-script (decoder) vocabulary.
    target_vocab = dict(target_tokenizer.word_index)
    target_vocab_inv = {index: char for char, index in target_vocab.items()}

    target_start_token_index = target_tokenizer.word_index.get('<start>', None)
    target_end_token_index = target_tokenizer.word_index.get('<end>', None)

    # 2. Build and train the model with the best parameters
    model, history = build_and_train_model(best_params, encoder_input_data_train, decoder_input_data_train, decoder_target_data_train,
                                          encoder_input_data_dev, decoder_input_data_dev, decoder_target_data_dev)

    # 3. Prepare data for testing (teacher-forcing views of the test targets)
    decoder_input_test = decoder_input_data_test[:, :-1]  # shape (num_samples, max_decoder_len - 1)
    decoder_target_test = decoder_target_data_test[:, 1:]  # shape (num_samples, max_decoder_len - 1)
    decoder_target_test_one_hot = tf.one_hot(decoder_target_test, depth=target_vocab_size, axis=-1)  # shape (num_samples, max_decoder_len - 1, target_vocab_size)

    # 4. Make predictions on the test set with greedy step-wise decoding
    encoder_model, decoder_model = _build_inference_models(model, best_params['cell_type'])
    # Target ids that are not part of the visible reference text.
    special_target_ids = {0, target_start_token_index, target_end_token_index}

    predictions = []
    attention_weights = []
    input_texts = []
    target_texts = []
    input_tokens_list = []
    target_tokens_list = []

    with tf.device('/GPU:0'):  # Ensure inference happens on GPU
        for source_ids, reference_ids in zip(encoder_input_data_test, decoder_target_data_test):
            decoded_sentence, attention_weights_list = decode_sequence(
                source_ids[np.newaxis, :], encoder_model, decoder_model, target_vocab_inv, max_decoder_len
            )
            predictions.append(decoded_sentence)
            attention_weights.append(attention_weights_list)

            # Encoder ids belong to the input tokenizer, decoder ids to the target one.
            input_tokens = [reverse_input_char_index.get(idx, '') for idx in source_ids if idx != 0]
            target_tokens = [target_vocab_inv.get(idx, '') for idx in reference_ids if idx != 0]
            input_texts.append(''.join(input_tokens))
            target_texts.append(''.join(target_vocab_inv.get(idx, '') for idx in reference_ids if idx not in special_target_ids))
            input_tokens_list.append(input_tokens)
            target_tokens_list.append(target_tokens)

    # 5. Save predictions
    np.save('test_predictions.npy', predictions)

    # 6. Print sample predictions
    print("\nSample Predictions on Test Data:")
    for i in range(min(5, len(predictions))):  # Print up to the first 5 predictions
        print(f"Input:  {input_texts[i]}")
        print(f"Target: {target_texts[i]}")
        print(f"Predicted: {predictions[i]}")

    # 7. Visualize attention heatmaps for up to 9 inputs on one 3x3 grid
    print("\nGenerating Attention Heatmaps...")
    num_heatmaps = min(9, len(predictions))
    _, axes = plt.subplots(3, 3, figsize=(15, 15))
    grid_axes = list(axes.flat)
    for i, ax in zip(range(num_heatmaps), grid_axes):
        # Keep only attention columns for real (non-padding) input ids so they
        # line up with input_tokens_list[i], whichever side was padded.
        non_padding = np.asarray(encoder_input_data_test[i]) != 0
        step_weights = [aw[..., non_padding] for aw in attention_weights[i]]
        visualize_attention(input_texts[i], predictions[i], step_weights, input_tokens_list[i], target_tokens_list[i], ax=ax)
    for ax in grid_axes[num_heatmaps:]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


# Logging the interactive attention map

In [None]:
# Self-contained HTML/JS page rendering hard-coded attention weights for two
# example words (romanized source tokens "kandala" and "anusuya"). Hovering a
# target (Devanagari) token shades each source token with opacity equal to its
# attention weight. The weights are embedded literals in this string; they are
# not computed from the model in this cell.
html_content="""
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <title>Attention Visualization</title>
  <style>
    body {
      font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
      background: #f9f9fb;
      color: #333;
      padding: 2em;
      line-height: 1.6;
    }

    h2 {
      color: #444;
      margin-bottom: 0.5em;
    }

    .card {
      background: #fff;
      border-radius: 12px;
      box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
      padding: 1.5em;
      margin-bottom: 2em;
      transition: transform 0.2s ease;
    }

    .card:hover {
      transform: translateY(-4px);
    }

    .token-container {
      margin-top: 1em;
      margin-bottom: 1.2em;
    }

    .label {
      font-weight: bold;
      margin-bottom: 0.3em;
      display: block;
    }

    .token {
      display: inline-block;
      margin: 0.1em 0.2em;
      padding: 0.4em 0.6em;
      border-radius: 6px;
      transition: background-color 0.3s ease;
    }

    .source-token {
      background-color: #f1f1f6;
      font-weight: 500;
    }

    .target-token {
      background-color: #dcefff;
      cursor: pointer;
      font-weight: bold;
    }

    .target-token:hover {
      background-color: #b3e5fc;
    }
  </style>
</head>
<body>

  <h1>Attention Visualization</h1>

  <div id="visualizations"></div>

  <script>
    const dataset = [
      {
        source_tokens: ['<sos>', 'k', 'a', 'n', 'd', 'a', 'l', 'a', '<eos>'],
        target_tokens: ['&#x0915;', '&#x093E;', '&#x0902;', '&#x0921;', '&#x0932;', '&#x093E;'],
        attention: [
          [0.0, 0.899, 0.043, 0.052, 0.005, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.573, 0.426, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.152, 0.576, 0.033, 0.238, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.004, 0.229, 0.168, 0.526, 0.056, 0.017],
          [0.0, 0.0, 0.0, 0.0, 0.006, 0.051, 0.863, 0.040, 0.040],
          [0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.009, 0.438, 0.551]
        ]
      },
      {
        source_tokens: ['<sos>', 'a', 'n', 'u', 's', 'u', 'y', 'a', '<eos>'],
        target_tokens: ['&#x0905;', '&#x0928;', '&#x0941;', '&#x0938;', '&#x0942;', '&#x092F;', '&#x093E;'],
        attention: [
          [0.0, 0.77, 0.112, 0.100, 0.017, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.596, 0.264, 0.140, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.048, 0.554, 0.363, 0.027, 0.008, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.009, 0.125, 0.612, 0.240, 0.003, 0.012],
          [0.0, 0.0, 0.0, 0.001, 0.005, 0.497, 0.417, 0.009, 0.071],
          [0.0, 0.0, 0.0, 0.0, 0.0, 0.005, 0.193, 0.507, 0.293],
          [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.002, 0.189, 0.808]
        ]
      }
    ];

    const visualizations = document.getElementById('visualizations');

    dataset.forEach((sample, idx) => {
      const card = document.createElement('div');
      card.className = 'card';

      const heading = document.createElement('h2');
      heading.textContent = `Sample ${idx + 1}`;
      card.appendChild(heading);

      const sourceContainer = document.createElement('div');
      sourceContainer.className = 'token-container';
      const srcLabel = document.createElement('span');
      srcLabel.className = 'label';
      srcLabel.textContent = 'Source Tokens:';
      sourceContainer.appendChild(srcLabel);

      const sourceSpans = [];
      sample.source_tokens.forEach((tok, sIdx) => {
        const span = document.createElement('span');
        span.textContent = tok;
        span.className = 'token source-token';
        span.dataset.index = sIdx;
        sourceContainer.appendChild(span);
        sourceSpans.push(span);
      });

      const targetContainer = document.createElement('div');
      targetContainer.className = 'token-container';
      const tgtLabel = document.createElement('span');
      tgtLabel.className = 'label';
      tgtLabel.textContent = 'Target Tokens (hover to see attention):';
      targetContainer.appendChild(tgtLabel);

      sample.target_tokens.forEach((tok, tIdx) => {
        const span = document.createElement('span');
        span.className = 'token target-token';
        span.innerHTML = tok;

        span.addEventListener('mouseover', () => {
          sourceSpans.forEach((srcSpan, sIdx) => {
            const weight = sample.attention[tIdx][sIdx];
            srcSpan.style.backgroundColor = `rgba(255, 215, 0, ${weight})`;
          });
        });

        span.addEventListener('mouseout', () => {
          sourceSpans.forEach(s => {
            s.style.backgroundColor = '#f1f1f6';
          });
        });

        targetContainer.appendChild(span);
      });

      card.appendChild(sourceContainer);
      card.appendChild(targetContainer);
      visualizations.appendChild(card);
    });
  </script>
</body>
</html>

"""
# Log the page to W&B as an HTML panel. inject=False passes the markup through
# unchanged, without W&B's default stylesheet.
wandb_project = "dakshina-transliteration"
with wandb.init(project=wandb_project, name="attention_final_new2") as run:
    wandb.log({"attention_final_new2": wandb.Html(html_content, inject=False)})