Seq2Seq model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
import wandb
import pandas as pd
import os

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Encoder(nn.Module):
    """Recurrent encoder: embeds source token ids and runs them through an RNN.

    Args:
        vocab_size: Number of rows in the embedding table (index 0 is padding).
        embed_dim: Width of the token embeddings fed to the RNN.
        hidden_dim: Total hidden width. For a bidirectional encoder it is split
            evenly between the two directions, so the concatenated per-step
            outputs are still ``hidden_dim`` wide.
        num_layers: Number of stacked recurrent layers.
        rnn_type: One of ``'RNN'``, ``'LSTM'`` or ``'GRU'``.
        dropout: Dropout applied between stacked layers.
        bidirectional: Whether the RNN reads the sequence in both directions.
        use_attention: When true, ``forward`` also returns the per-step outputs.

    Raises:
        KeyError: If ``rnn_type`` is not one of the supported names.
        ValueError: If ``hidden_dim`` cannot be split evenly across directions.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, rnn_type='LSTM',
                 dropout=0.2, bidirectional=False, use_attention=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.is_bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_directions = 2 if bidirectional else 1
        if hidden_dim % self.num_directions != 0:
            # Integer division below would otherwise silently shrink the
            # encoder width and break the hand-off to the decoder later on.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by the number of "
                f"directions ({self.num_directions})"
            )
        self.hidden_dim = hidden_dim
        self.use_attention = use_attention

        rnn_cls = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]
        self.rnn = rnn_cls(
            input_size=embed_dim,
            hidden_size=hidden_dim // self.num_directions,
            num_layers=num_layers,
            # Inter-layer dropout has no effect on a single layer; passing 0
            # there avoids PyTorch's warning without changing the computation.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=bidirectional
        )

    def forward(self, x):
        """Encode a batch-first tensor of token ids.

        Returns:
            ``(outputs, hidden)`` when attention is enabled, otherwise only
            ``hidden`` (a ``(h, c)`` tuple for LSTM, a single tensor otherwise).
        """
        embedded = self.embedding(x)
        outputs, hidden = self.rnn(embedded)
        if self.use_attention:
            return outputs, hidden
        return hidden


class Decoder(nn.Module):
    """Single-step recurrent decoder with optional additive attention.

    Args:
        vocab_size: Size of the target vocabulary (index 0 is padding).
        embed_dim: Width of the target token embeddings.
        hidden_dim: Hidden width of the decoder RNN; with attention the
            encoder outputs are expected to have this width as well.
        num_layers: Number of stacked recurrent layers.
        rnn_type: One of ``'RNN'``, ``'LSTM'`` or ``'GRU'``.
        dropout: Dropout applied between stacked layers.
        use_attention: Whether to attend over the encoder outputs at each step.

    Raises:
        KeyError: If ``rnn_type`` is not one of the supported names.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, rnn_type='LSTM',
                 dropout=0.2, use_attention=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn_type = rnn_type
        self.use_attention = use_attention

        rnn_cls = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]
        self.rnn = rnn_cls(
            # With attention the context vector is concatenated to the embedding.
            input_size=embed_dim + hidden_dim if use_attention else embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # Inter-layer dropout has no effect on a single layer; passing 0
            # there avoids PyTorch's warning without changing the computation.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True
        )
        self.attn = None
        if use_attention:
            # Scores each (encoder output, decoder query) pair with a small MLP.
            self.attn = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, 1)
            )
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_token, hidden_state, encoder_outputs=None, src_mask=None):
        """Decode one time step.

        Args:
            input_token: ``(B,)`` tensor holding the previous target token ids.
            hidden_state: Recurrent state from the previous step (``(h, c)``
                for LSTM, a single tensor otherwise).
            encoder_outputs: ``(B, T, H)`` encoder outputs; required when the
                decoder was built with ``use_attention=True``.
            src_mask: Optional ``(B, T)`` boolean mask, true at real source
                positions; masked-out positions receive zero attention.

        Returns:
            ``(logits, hidden)`` without attention, or
            ``(logits, hidden, attn_weights)`` with attention, where
            ``attn_weights`` is ``(B, T)``.

        Raises:
            ValueError: If attention is enabled but ``encoder_outputs`` is missing.
        """
        if self.use_attention and encoder_outputs is None:
            # Without this check the RNN would fail later with an opaque
            # input-size mismatch, because its input includes the context vector.
            raise ValueError("encoder_outputs is required when use_attention=True")

        embedded = self.embedding(input_token.unsqueeze(1))  # (B, 1, E)

        if not self.use_attention:
            rnn_output, hidden = self.rnn(embedded, hidden_state)  # (B, 1, H)
            return self.fc_out(rnn_output.squeeze(1)), hidden

        # The top layer's hidden state acts as the attention query.
        top_hidden = hidden_state[0][-1] if self.rnn_type == 'LSTM' else hidden_state[-1]
        query = top_hidden.unsqueeze(1).expand(-1, encoder_outputs.size(1), -1)  # (B, T, H)

        energy = self.attn(torch.cat((encoder_outputs, query), dim=2)).squeeze(2)  # (B, T)
        if src_mask is not None:
            energy = energy.masked_fill(~src_mask, float('-inf'))
        attn_weights = F.softmax(energy, dim=1)  # (B, T)

        # Weighted sum of encoder outputs, then feed [embedding ; context] to the RNN.
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # (B, 1, H)
        rnn_input = torch.cat((embedded, context), dim=2)  # (B, 1, E + H)

        rnn_output, hidden = self.rnn(rnn_input, hidden_state)  # (B, 1, H)
        logits = self.fc_out(rnn_output.squeeze(1))  # (B, V)
        return logits, hidden, attn_weights
        

class TransliterationModel(nn.Module):
    """Encoder-decoder model that decodes the target one step at a time.

    Args:
        input_vocab_size: Size of the source vocabulary.
        output_vocab_size: Size of the target vocabulary.
        embed_dim: Embedding width used by both encoder and decoder.
        hidden_dim: Hidden width used by both encoder and decoder.
        enc_layers: Number of encoder RNN layers.
        dec_layers: Number of decoder RNN layers.
        rnn_type: One of ``'RNN'``, ``'LSTM'`` or ``'GRU'``.
        dropout: Dropout passed to both RNN stacks.
        bidirectional: Whether the encoder is bidirectional.
        use_attention: Whether the decoder attends over the encoder outputs.
    """

    def __init__(self, input_vocab_size, output_vocab_size, embed_dim, hidden_dim,
                 enc_layers, dec_layers, rnn_type='LSTM', dropout=0.2, bidirectional=False, use_attention=False):
        super().__init__()
        self.encoder = Encoder(input_vocab_size, embed_dim, hidden_dim,
                               enc_layers, rnn_type, dropout, bidirectional, use_attention)
        self.decoder = Decoder(output_vocab_size, embed_dim, hidden_dim,
                               dec_layers, rnn_type, dropout, use_attention)
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.hidden_dim = hidden_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.use_attention = use_attention

    @staticmethod
    def _merge_bidir_states(state):
        """Concatenate forward/backward states of each layer along the feature axis.

        Even indices along dim 0 are taken as one direction and odd indices as
        the other, turning ``(2 * L, B, H / 2)`` into ``(L, B, H)``.
        """
        return torch.cat([state[::2], state[1::2]], dim=2)

    @staticmethod
    def _match_layers(state, target_layers):
        """Resize the layer axis (dim 0) of ``state`` to ``target_layers``.

        A shallower state is extended with zero layers at the end. A deeper
        state keeps only its last ``target_layers`` layers; previously this
        case crashed while trying to build a negative-sized padding tensor.
        """
        layers = state.size(0)
        if layers == target_layers:
            return state
        if layers > target_layers:
            return state[-target_layers:].contiguous()
        pad = state.new_zeros(target_layers - layers, *state.shape[1:])
        return torch.cat([state, pad], dim=0)

    def _init_decoder_state(self, enc_hidden):
        """Convert the encoder's final state into the decoder's initial state."""
        def adapt(state):
            if self.bidirectional:
                state = self._merge_bidir_states(state)
            return self._match_layers(state, self.dec_layers)

        if self.rnn_type == 'LSTM':
            h, c = enc_hidden
            return adapt(h), adapt(c)
        return adapt(enc_hidden)

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Run the encoder and decode ``tgt_len - 1`` steps.

        Args:
            src: ``(B, S)`` source token ids; id 0 is treated as padding.
            tgt: ``(B, T)`` target token ids whose first column is the start token.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the model's
                own prediction.

        Returns:
            ``outputs`` of shape ``(B, T, V)`` (position 0 is left at zero), and
            additionally the attention weights of shape ``(B, T - 1, S)`` when
            attention is enabled.
        """
        batch_size, tgt_len = tgt.shape
        vocab_size = self.decoder.fc_out.out_features
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=src.device)
        src_mask = (src != 0)

        enc_outputs = None
        if self.use_attention:
            enc_outputs, enc_hidden = self.encoder(src)
        else:
            enc_hidden = self.encoder(src)

        dec_hidden = self._init_decoder_state(enc_hidden)

        attn_steps = []  # one (B, S) tensor per decoding step
        dec_input = tgt[:, 0]  # Start token
        for t in range(1, tgt_len):
            if self.use_attention:
                output, dec_hidden, attn_weights = self.decoder(dec_input, dec_hidden, enc_outputs, src_mask)
                attn_steps.append(attn_weights)
            else:
                output, dec_hidden = self.decoder(dec_input, dec_hidden)

            outputs[:, t] = output
            top1 = output.argmax(1)
            teacher_force = random.random() < teacher_forcing_ratio
            dec_input = tgt[:, t] if teacher_force else top1

        if not self.use_attention:
            return outputs

        if attn_steps:
            all_attn_weights = torch.stack(attn_steps, dim=1)  # (B, T - 1, S)
        else:
            # No decoding step ran (tgt_len == 1): return an empty tensor
            # rather than failing on an unbound variable.
            all_attn_weights = outputs.new_zeros(batch_size, 0, src.size(1))
        return outputs, all_attn_weights

def read_pairs(file_path):
    """Read a tab-separated lexicon file into ``(second_column, first_column)`` pairs.

    The file is read as UTF-8 and surrounding whitespace is stripped from the
    whole text. Lines without a tab are skipped; any columns after the second
    are ignored.
    """
    with open(file_path, encoding='utf-8') as handle:
        text = handle.read().strip()

    pairs = []
    for row in text.split('\n'):
        fields = row.split('\t')
        if len(fields) < 2:
            continue
        pairs.append((fields[1], fields[0]))
    return pairs

def build_vocab_and_prepare_batch(seqs, device):
    """Build character vocabularies from ``seqs`` and a batching closure.

    Args:
        seqs: Iterable of ``(source_text, target_text)`` pairs.
        device: Device the batch tensors are moved to.

    Returns:
        ``(src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch,
        source_chars, target_chars)`` where the vocabularies map characters
        (and the four special tokens) to ids, the ``idx2*`` dicts are their
        inverses, ``create_batch(pairs)`` returns padded ``(src, tgt)`` id
        tensors, and the last two items are the sorted character lists.
    """
    special_tokens = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
    offset = len(special_tokens)

    def sorted_charset(column):
        # Every distinct character seen in the given column of the pairs.
        return sorted({ch for pair in seqs for ch in pair[column]})

    def make_vocab(chars):
        # Characters are numbered after the special tokens, in sorted order.
        mapping = {ch: offset + pos for pos, ch in enumerate(chars)}
        return {**mapping, **special_tokens}

    unique_chars_latin = sorted_charset(0)
    unique_chars_dev = sorted_charset(1)

    src_vocab = make_vocab(unique_chars_latin)
    tgt_vocab = make_vocab(unique_chars_dev)

    idx2src = {token_id: ch for ch, token_id in src_vocab.items()}
    idx2tgt = {token_id: ch for ch, token_id in tgt_vocab.items()}

    def encode_text(text, vocab):
        unknown = vocab['<unk>']
        return [vocab.get(ch, unknown) for ch in text]

    def create_batch(pairs):
        src_rows, tgt_rows = [], []
        for source_text, _ in pairs:
            # Source: characters followed by <eos>.
            src_rows.append(torch.tensor(encode_text(source_text, src_vocab) + [src_vocab['<eos>']]))
        for _, target_text in pairs:
            # Target: <sos>, characters, <eos>.
            tgt_rows.append(torch.tensor(
                [tgt_vocab['<sos>']] + encode_text(target_text, tgt_vocab) + [tgt_vocab['<eos>']]
            ))
        src_batch = pad_sequence(src_rows, batch_first=True, padding_value=src_vocab['<pad>'])
        tgt_batch = pad_sequence(tgt_rows, batch_first=True, padding_value=tgt_vocab['<pad>'])
        return src_batch.to(device), tgt_batch.to(device)

    return src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch, unique_chars_latin, unique_chars_dev

def compute_word_level_accuracy(preds, targets, vocab):
    """Return the percentage of sequences that are predicted exactly.

    Each sequence is cut at its first ``<eos>`` and padding ids are dropped
    before comparison. Cutting at ``<eos>`` matters because whatever the model
    emits after the end marker is not part of the predicted word and must not
    make an otherwise correct prediction count as wrong.

    Args:
        preds: Tensor of predicted token ids, one row per sequence.
        targets: Tensor of gold token ids with the same number of rows.
        vocab: Mapping that contains the ``'<eos>'`` and ``'<pad>'`` ids.

    Returns:
        Accuracy in percent (0-100); ``0.0`` for an empty batch.
    """
    eos, pad = vocab['<eos>'], vocab['<pad>']

    def content(tokens):
        kept = []
        for tok in tokens:
            if tok == eos:
                break
            if tok != pad:
                kept.append(tok)
        return kept

    pred_rows = preds.tolist()
    target_rows = targets.tolist()
    if not pred_rows:
        return 0.0

    correct = sum(content(p) == content(t) for p, t in zip(pred_rows, target_rows))
    return correct / len(pred_rows) * 100

def run_training():
    """Train one model for the current W&B sweep run and log per-epoch metrics.

    Hyperparameters are read from ``wandb.config``. The train and dev lexicons
    are loaded from the fixed Dakshina paths below, vocabularies are built from
    the training pairs, and after every epoch the mean batch loss and
    word-level accuracy on both splits are logged and printed.

    Returns:
        The trained ``TransliterationModel``.
    """
    # Initialize wandb config
    wandb.init()
    cfg = wandb.config
    wandb.run.name = (
        f"es_{cfg.embedding_size}_hs_{cfg.hidden_size}_"
        f"enc_{cfg.enc_layers}_dec_{cfg.dec_layers}_"
        f"rnn_{cfg.rnn_type}_dropout_{cfg.dropout_rate}_"
        f"bidirectional_{cfg.is_bidirectional}_"
        f"lr_{cfg.learning_rate}_bs_{cfg.batch_size}_"
        f"epochs_{cfg.epochs}_tfp_{cfg.teacher_forcing_prob}_"
        f"beam_size_{cfg.beam_size}"
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load and prepare data
    train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"
    dev_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.dev.tsv"
    train_set = read_pairs(train_path)
    dev_set = read_pairs(dev_path)

    src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch, _, _ = build_vocab_and_prepare_batch(train_set, device)

    # Initialize model, optimizer, criterion
    model = TransliterationModel(
        len(src_vocab), len(tgt_vocab), cfg.embedding_size, cfg.hidden_size,
        cfg.enc_layers, cfg.dec_layers, cfg.rnn_type, cfg.dropout_rate,
        cfg.is_bidirectional, cfg.use_attention
    ).to(device)
    # Honour the 'optimizer' hyperparameter (previously Adam was always used);
    # anything other than 'nadam', including a missing key, falls back to Adam.
    optimizer_name = str(cfg.get('optimizer', 'adam')).lower()
    optimizer_cls = optim.NAdam if optimizer_name == 'nadam' else optim.Adam
    optimizer = optimizer_cls(model.parameters(), lr=cfg.learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])

    def run_epoch(pairs, teacher_forcing, train):
        """Process ``pairs`` once and return (mean batch loss, mean batch accuracy)."""
        total_loss, total_acc, num_batches = 0.0, 0.0, 0
        for start in range(0, len(pairs), cfg.batch_size):
            src, tgt = create_batch(pairs[start:start + cfg.batch_size])

            if train:
                optimizer.zero_grad()
            result = model(src, tgt, teacher_forcing)
            # With attention the model also returns the attention weights.
            outputs = result[0] if cfg.use_attention else result

            # Position 0 is the start token and is never predicted.
            loss = criterion(outputs[:, 1:].reshape(-1, outputs.size(-1)), tgt[:, 1:].reshape(-1))
            if train:
                loss.backward()
                optimizer.step()

            preds = outputs.argmax(-1)
            total_acc += compute_word_level_accuracy(preds[:, 1:], tgt[:, 1:], tgt_vocab)
            total_loss += loss.item()
            num_batches += 1

        # Divide by the number of batches actually processed: the last partial
        # batch counts too, and a split smaller than one batch must not
        # divide by zero.
        divisor = max(num_batches, 1)
        return total_loss / divisor, total_acc / divisor

    # Training loop
    for epoch in range(cfg.epochs):
        model.train()
        random.shuffle(train_set)
        avg_train_loss, avg_train_acc = run_epoch(train_set, cfg.teacher_forcing_prob, train=True)

        # Validation (no teacher forcing, no gradient tracking)
        model.eval()
        with torch.no_grad():
            avg_dev_loss, avg_dev_acc = run_epoch(dev_set, 0, train=False)

        # Logging
        wandb.log({
            "Epoch": epoch + 1,
            "Train Loss": avg_train_loss,
            "Train Accuracy": avg_train_acc,
            "Validation Loss": avg_dev_loss,
            "Validation Accuracy": avg_dev_acc,
        })

        print(f"Epoch {epoch+1}/{cfg.epochs} | Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.2f}% | Val Loss: {avg_dev_loss:.4f}, Val Acc: {avg_dev_acc:.2f}%")

    wandb.finish()
    return model


In [2]:
def model_eval(cfg, model_path, project_name, csv_file_name):
    """Evaluate a model on the test lexicon, training it first if no checkpoint exists.

    Args:
        cfg: Dict of hyperparameters (``embedding_size``, ``hidden_size``,
            ``enc_layers``, ``dec_layers``, ``rnn_type``, ``dropout_rate``,
            ``is_bidirectional``, ``use_attention``, ``batch_size``,
            ``epochs``, ``learning_rate``, ``teacher_forcing_prob`` and
            optionally ``optimizer``).
        model_path: Checkpoint file; loaded when present, otherwise written
            during training whenever the training accuracy improves.
        project_name: W&B project that receives the evaluation run.
        csv_file_name: Stem of the CSV file (and W&B table name) holding the
            per-word predictions.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    wandb.finish()
    wandb.init(
        project=project_name,
        name='best_model_test_eval',
        resume="never",
        reinit=True,
        config=cfg
    )
    # Load and prepare data
    train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"
    train_set = read_pairs(train_path)
    test_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.test.tsv"
    test_set = read_pairs(test_path)

    src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch, _, _ = build_vocab_and_prepare_batch(train_set, device)

    batch_size = cfg['batch_size']
    use_attention = cfg['use_attention']

    # Initialize model
    model = TransliterationModel(
        len(src_vocab), len(tgt_vocab), cfg['embedding_size'], cfg['hidden_size'],
        cfg['enc_layers'], cfg['dec_layers'], cfg['rnn_type'], cfg['dropout_rate'],
        cfg['is_bidirectional'], use_attention
    ).to(device)

    def forward_logits(src, tgt, teacher_forcing):
        # With attention the model also returns the attention weights.
        result = model(src, tgt, teacher_forcing)
        return result[0] if use_attention else result

    def ids_to_text(ids, idx2char, eos, pad):
        # Stop at the first <eos>: anything emitted after it is not part of the word.
        chars = []
        for token_id in ids:
            if token_id == eos:
                break
            if token_id != pad:
                chars.append(idx2char.get(token_id, ''))
        return ''.join(chars)

    if not os.path.exists(model_path):
        print("❌ No saved model found, starting training.")
        # Honour the configured optimizer (previously Adam was always used);
        # anything other than 'nadam', including a missing key, means Adam.
        optimizer_name = str(cfg.get('optimizer', 'adam')).lower()
        optimizer_cls = optim.NAdam if optimizer_name == 'nadam' else optim.Adam
        optimizer = optimizer_cls(model.parameters(), lr=cfg['learning_rate'])
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])
        best_acc = 0.0
        # Training loop
        for epoch in range(cfg['epochs']):
            model.train()
            total_loss, total_acc, num_batches = 0.0, 0.0, 0
            random.shuffle(train_set)

            for i in range(0, len(train_set), batch_size):
                src, tgt = create_batch(train_set[i:i + batch_size])

                optimizer.zero_grad()
                outputs = forward_logits(src, tgt, cfg['teacher_forcing_prob'])

                loss = criterion(outputs[:, 1:].reshape(-1, outputs.size(-1)), tgt[:, 1:].reshape(-1))
                loss.backward()
                optimizer.step()

                preds = outputs.argmax(-1)
                total_acc += compute_word_level_accuracy(preds[:, 1:], tgt[:, 1:], tgt_vocab)
                total_loss += loss.item()
                num_batches += 1

            # Average over the batches actually run (includes the last partial
            # batch and cannot divide by zero on a tiny training set).
            divisor = max(num_batches, 1)
            avg_train_loss = total_loss / divisor
            avg_train_acc = total_acc / divisor

            print(f"Epoch {epoch+1}/{cfg['epochs']} | Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.2f}%")
            wandb.log({"Train Loss": avg_train_loss, "Train Accuracy": avg_train_acc})

            # Save the best model
            if avg_train_acc > best_acc:
                best_acc = avg_train_acc
                torch.save(model.state_dict(), model_path)
                print(f"💾 Saved new best model at epoch {epoch + 1} with accuracy {best_acc:.2f}%")
        print(f"Best model saved with accuracy: {best_acc:.2f}%")

    # Test the model
    if os.path.exists(model_path):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("✅ Loaded saved model from disk.")
    model.eval()
    rows = []

    src_eos, src_pad = src_vocab['<eos>'], src_vocab['<pad>']
    tgt_eos, tgt_pad = tgt_vocab['<eos>'], tgt_vocab['<pad>']

    with torch.no_grad():
        for i in range(0, len(test_set), batch_size):
            src, tgt = create_batch(test_set[i:i + batch_size])
            preds = forward_logits(src, tgt, 0).argmax(-1)

            # Column 0 of tgt/preds is the start-token slot, so it is skipped.
            for src_ids, tgt_ids, pred_ids in zip(src.tolist(), tgt[:, 1:].tolist(), preds[:, 1:].tolist()):
                input_seq = ids_to_text(src_ids, idx2src, src_eos, src_pad)
                target_seq = ids_to_text(tgt_ids, idx2tgt, tgt_eos, tgt_pad)
                pred_seq = ids_to_text(pred_ids, idx2tgt, tgt_eos, tgt_pad)
                is_correct = target_seq == pred_seq
                rows.append({'Input': input_seq, 'Target': target_seq, 'Predicted': pred_seq , 'Is_Correct': "True✅" if is_correct else "False❌"})

    # Explicit columns keep the frame well-formed even when the test set is empty.
    predictions = pd.DataFrame(rows, columns=['Input', 'Target', 'Predicted', 'Is_Correct'])
    overall_acc = (predictions.Is_Correct == "True✅").mean() if rows else 0.0
    wandb.log({"Test Accuracy": overall_acc})
    table = wandb.Table(dataframe=predictions)
    wandb.log({f"{csv_file_name}_table": table})
    # finish run
    wandb.finish()
    predictions.to_csv(f'{csv_file_name}.csv', index=False)
    print(f"Saved {len(predictions)} rows, eval accuracy = {overall_acc:.2f}")


Vanilla model parameter tuning

In [None]:
# Bayesian hyperparameter search for the model without attention; the sweep
# maximises the "Validation Accuracy" metric logged by run_training.
sweep_config = dict(
    method='bayes',
    metric=dict(name='Validation Accuracy', goal='maximize'),
    parameters=dict(
        embedding_size=dict(values=[128, 256]),
        hidden_size=dict(values=[128, 256]),
        enc_layers=dict(values=[2, 3]),
        dec_layers=dict(values=[2, 3]),
        rnn_type=dict(values=['GRU', 'LSTM', 'RNN']),
        dropout_rate=dict(values=[0.2, 0.3]),
        batch_size=dict(values=[32, 64]),
        epochs=dict(values=[5, 10]),
        is_bidirectional=dict(values=[False, True]),
        learning_rate=dict(values=[0.001, 0.0001]),
        optimizer=dict(values=['adam', 'nadam']),
        teacher_forcing_prob=dict(values=[0.2, 0.5, 0.7]),
        beam_size=dict(values=[1, 3, 5]),
        use_attention=dict(values=[False]),
    ),
)

# Register the sweep, then let one agent execute 50 training runs.
sweep_id = wandb.sweep(sweep_config, project="dakshina_transliteration")
wandb.agent(sweep_id, function=run_training, count=50)


Vanilla model evaluation

In [None]:
# Hyperparameters used for the test-set evaluation of the model without attention.
parameters = dict(
    embedding_size=256,
    hidden_size=256,
    enc_layers=3,
    dec_layers=3,
    rnn_type='GRU',
    dropout_rate=0.3,
    batch_size=64,
    epochs=10,
    is_bidirectional=False,
    learning_rate=0.001,
    optimizer='nadam',
    teacher_forcing_prob=0.7,
    beam_size=5,
    use_attention=False,
)
model_eval(
    parameters,
    "best_vanilla_model.pt",
    "transliteration_evaluation",
    "predictions_vanilla",
)

Attention Model parameter tuning

In [None]:
# Bayesian hyperparameter search for the attention model; the sweep maximises
# the "Validation Accuracy" metric logged by run_training.
sweep_config = dict(
    method='bayes',
    metric=dict(name='Validation Accuracy', goal='maximize'),
    parameters=dict(
        embedding_size=dict(values=[128, 256]),
        hidden_size=dict(values=[128, 256]),
        enc_layers=dict(values=[2, 3]),
        dec_layers=dict(values=[2, 3]),
        rnn_type=dict(values=['GRU', 'LSTM', 'RNN']),
        dropout_rate=dict(values=[0.2, 0.3]),
        batch_size=dict(values=[32, 64]),
        epochs=dict(values=[5, 10]),
        is_bidirectional=dict(values=[False, True]),
        learning_rate=dict(values=[0.001, 0.0001]),
        optimizer=dict(values=['adam', 'nadam']),
        teacher_forcing_prob=dict(values=[0.2, 0.5, 0.7]),
        beam_size=dict(values=[1, 3, 5]),
        use_attention=dict(values=[True]),
    ),
)

# Register the sweep, then let one agent execute 50 training runs.
sweep_id = wandb.sweep(sweep_config, project="dakshina_transliteration_attention")
wandb.agent(sweep_id, function=run_training, count=50)


Attention model evaluation

In [None]:
# Hyperparameters used for the test-set evaluation of the attention model.
parameters = dict(
    embedding_size=256,
    hidden_size=256,
    enc_layers=2,
    dec_layers=3,
    rnn_type='LSTM',
    dropout_rate=0.3,
    batch_size=64,
    epochs=10,
    is_bidirectional=True,
    learning_rate=0.001,
    optimizer='adam',
    teacher_forcing_prob=0.7,
    beam_size=3,
    use_attention=True,
)
model_eval(
    parameters,
    "best_attention_model.pt",
    "transliteration_attention_evaluation",
    "predictions_attention",
)

Comparing attention model with vanilla model

In [None]:
# Build a table of test words that the vanilla model got wrong and the
# attention model got right, save it as CSV and log it to W&B.
wandb.finish()
wandb.init(
    project="Comparison_vanilla_attention",
    name = 'Compare_attention_vanilla',
    resume="never",
    reinit=True,
)
# Read every column as plain text and disable NA detection: otherwise words
# such as "nan" or "null" and empty predictions are silently turned into NaN.
vanilla_predictions = pd.read_csv("predictions_vanilla.csv", dtype=str, keep_default_na=False)
attention_predictions = pd.read_csv("predictions_attention.csv", dtype=str, keep_default_na=False)
# Rows are matched by position, so both files must describe the same words.
if len(vanilla_predictions) != len(attention_predictions):
    raise ValueError(
        "predictions_vanilla.csv and predictions_attention.csv have different "
        f"row counts ({len(vanilla_predictions)} vs {len(attention_predictions)})"
    )
comparison_df = vanilla_predictions.copy()
comparison_df['Vanilla_prediction'] = comparison_df['Predicted']
comparison_df['Attention_prediction'] = attention_predictions['Predicted']
comparison_df['Attention_Is_Correct'] = attention_predictions['Is_Correct']
fixed_by_attention = (
    (comparison_df['Is_Correct'] == "False❌") &
    (comparison_df['Attention_Is_Correct'] == "True✅")
)
# reset_index returns a new frame, avoiding an in-place edit of a slice.
filtered_df = comparison_df.loc[
    fixed_by_attention,
    ['Input', 'Target', 'Vanilla_prediction', 'Attention_prediction'],
].reset_index(drop=True)
filtered_df.to_csv("comparison.csv", index=False)
table = wandb.Table(dataframe=filtered_df)
wandb.log({"Comparison_table": table})
# finish run
wandb.finish()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import seaborn as sns
from collections import Counter
import pandas as pd

import os

# Load the dataset. Columns are read as plain text with NA detection off so an
# empty prediction stays an empty string instead of becoming a float NaN
# (which cannot be iterated character by character below).
df = pd.read_csv("comparison.csv", dtype=str, keep_default_na=False)

# Find corrected cases: vanilla wrong, attention right
corrected_cases = df[
    (df["Vanilla_prediction"] != df["Target"]) &
    (df["Attention_prediction"] == df["Target"])
]

# Character-level differences. Strings are compared position by position and
# zip stops at the shorter one, so surplus characters are not counted.
char_diffs = [
    (v_char, c_char)
    for vanilla, correct in zip(corrected_cases["Vanilla_prediction"], corrected_cases["Target"])
    for v_char, c_char in zip(vanilla, correct)
    if v_char != c_char
]

# Count top 15
most_common_corrections = Counter(char_diffs).most_common(15)
labels = [f"{v}→{c}" for (v, c), _ in most_common_corrections]
counts = [count for _, count in most_common_corrections]

# ✅ Use font from local file (path built portably rather than with a
# Windows-only backslash)
font_path = os.path.join("font", "NotoSansDevanagari-Regular.ttf")
devanagari_font = fm.FontProperties(fname=font_path)

if labels:
    # Plot with explicit font
    plt.figure(figsize=(12, 6))
    sns.barplot(x=labels, y=counts, palette="mako")

    plt.title("Most Common Character-Level Corrections", fontproperties=devanagari_font)
    plt.xlabel("Character Correction", fontproperties=devanagari_font)
    plt.ylabel("Frequency", fontproperties=devanagari_font)
    plt.xticks(rotation=45, fontproperties=devanagari_font)
    plt.yticks(fontproperties=devanagari_font)
    plt.tight_layout()
    plt.show()
else:
    print("No character-level corrections found; nothing to plot.")


In [9]:
import os
import random
from io import BytesIO
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.font_manager as fm
import wandb

# Assumes TransliterationModel, build_vocab_and_prepare_batch, compute_word_level_accuracy, and read_pairs are defined elsewhere

def plot_attention(attn_weights, input_tokens, output_tokens, input_word, output_word, idx):
    """Render one attention matrix as a heatmap and return it as a ``wandb.Image``.

    Args:
        attn_weights: 2-D array of weights, rows = output steps, columns =
            input positions.
        input_tokens: Source tokens labelling the columns.
        output_tokens: Predicted tokens labelling the rows.
        input_word: Source word shown in the title and caption.
        output_word: Predicted word shown in the title and caption.
        idx: Number of the plot, shown in the title and caption.

    Returns:
        A ``wandb.Image`` wrapping the rendered PNG.
    """
    # Load Devanagari-compatible font. The path is joined portably (a
    # hard-coded backslash only resolves on Windows); if the file is missing,
    # fall back to matplotlib's default font instead of failing while drawing.
    font_path = os.path.join("font", "NotoSansDevanagari-Regular.ttf")
    if os.path.exists(font_path):
        devanagari_font = fm.FontProperties(fname=font_path, size=12)
    else:
        devanagari_font = fm.FontProperties(size=12)

    # Drop special tokens from the labels and crop the matrix to the labels left.
    special_tokens = {'<pad>', '<sos>', '<eos>'}
    filtered_input_tokens = [tok for tok in input_tokens if tok not in special_tokens]
    filtered_output_tokens = [tok for tok in output_tokens if tok not in special_tokens]
    attn_weights_filtered = attn_weights[:len(filtered_output_tokens), :len(filtered_input_tokens)]

    # Create heatmap on a figure we hold a handle to, so exactly this figure
    # is saved and closed.
    fig = plt.figure(figsize=(8, 6))
    ax = sns.heatmap(attn_weights_filtered,
                     xticklabels=filtered_input_tokens,
                     yticklabels=filtered_output_tokens,
                     cmap='viridis',
                     cbar_kws={"shrink": 0.7})

    # Set font sizes for labels and title
    ax.set_xlabel("Input Sequence (characters)", fontsize=16)
    ax.set_ylabel("Predicted Output (characters)", fontproperties=devanagari_font, fontsize=16)
    ax.set_title(f"Heatmap {idx}: '{input_word}' - '{output_word}'", fontproperties=devanagari_font, fontsize=16)

    # Set tick label font sizes
    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)

    # Apply Devanagari font to y-tick labels (the predicted characters)
    for label in ax.get_yticklabels():
        label.set_fontproperties(devanagari_font)

    plt.xticks(rotation=90)
    plt.yticks(rotation=0)

    plt.tight_layout()

    # Convert to wandb.Image via an in-memory PNG
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)

    return wandb.Image(image, caption=f"{idx}: {input_word} → {output_word}")



def attention_heatmaps(cfg, model_path):
    """Log attention heatmaps for the first test words to W&B.

    The model is loaded from ``model_path`` when that file exists; otherwise
    it is trained first and the checkpoint with the best training accuracy is
    saved there. Up to 10 heatmaps are rendered with ``plot_attention``.

    Args:
        cfg: Dict of hyperparameters; ``cfg['use_attention']`` must be true.
        model_path: Path of the model checkpoint.

    Raises:
        ValueError: If ``cfg['use_attention']`` is false, because the model
            then returns no attention weights to plot.
    """
    if not cfg['use_attention']:
        raise ValueError("attention_heatmaps requires cfg['use_attention'] to be True")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    wandb.finish()
    wandb.init(
        project="transliteration_attention_heatmap",
        name='best_attention_model_test_eval',
        resume="never",
        reinit=True,
        config=cfg
    )

    # Load and prepare data
    train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"
    train_set = read_pairs(train_path)
    test_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.test.tsv"
    test_set = read_pairs(test_path)

    src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch, _, _ = build_vocab_and_prepare_batch(train_set, device)

    batch_size = cfg['batch_size']

    # Initialize model
    model = TransliterationModel(
        len(src_vocab), len(tgt_vocab), cfg['embedding_size'], cfg['hidden_size'],
        cfg['enc_layers'], cfg['dec_layers'], cfg['rnn_type'], cfg['dropout_rate'],
        cfg['is_bidirectional'], cfg['use_attention']
    ).to(device)

    if not os.path.exists(model_path):
        print("❌ No saved model found, starting training.")
        # Honour the configured optimizer (previously Adam was always used);
        # anything other than 'nadam', including a missing key, means Adam.
        optimizer_name = str(cfg.get('optimizer', 'adam')).lower()
        optimizer_cls = optim.NAdam if optimizer_name == 'nadam' else optim.Adam
        optimizer = optimizer_cls(model.parameters(), lr=cfg['learning_rate'])
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])
        best_acc = 0.0

        for epoch in range(cfg['epochs']):
            model.train()
            total_loss, total_acc, num_batches = 0.0, 0.0, 0
            random.shuffle(train_set)

            for i in range(0, len(train_set), batch_size):
                src, tgt = create_batch(train_set[i:i + batch_size])

                optimizer.zero_grad()
                outputs, _ = model(src, tgt, cfg['teacher_forcing_prob'])

                loss = criterion(outputs[:, 1:].reshape(-1, outputs.size(-1)), tgt[:, 1:].reshape(-1))
                loss.backward()
                optimizer.step()

                preds = outputs.argmax(-1)
                total_acc += compute_word_level_accuracy(preds[:, 1:], tgt[:, 1:], tgt_vocab)
                total_loss += loss.item()
                num_batches += 1

            # Average over the batches actually run (includes the last partial
            # batch and cannot divide by zero on a tiny training set).
            divisor = max(num_batches, 1)
            avg_train_loss = total_loss / divisor
            avg_train_acc = total_acc / divisor

            print(f"Epoch {epoch+1}/{cfg['epochs']} | Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.2f}%")

            if avg_train_acc > best_acc:
                best_acc = avg_train_acc
                torch.save(model.state_dict(), model_path)
                print(f"💾 Saved new best model at epoch {epoch + 1} with accuracy {best_acc:.2f}%")

        print(f"Best model saved with accuracy: {best_acc:.2f}%")

    if os.path.exists(model_path):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("✅ Loaded saved model from disk.")
    model.eval()

    max_plots = 10
    images = []
    src_eos = src_vocab['<eos>']
    tgt_eos = tgt_vocab['<eos>']
    # Ids that plot_attention would drop from the labels; their attention rows
    # are removed here as well so that rows and labels stay aligned.
    unlabelled = (tgt_vocab['<pad>'], tgt_vocab['<sos>'])

    with torch.no_grad():
        for i in range(0, len(test_set), batch_size):
            if len(images) >= max_plots:
                break
            src, tgt = create_batch(test_set[i:i + batch_size])

            outputs, attn_weights = model(src, tgt, teacher_forcing_ratio=0.0)
            # Column 0 is the start-token slot; attention row k belongs to
            # prediction column k + 1.
            pred_rows = outputs.argmax(-1)[:, 1:].tolist()

            for j, (src_ids, pred_ids) in enumerate(zip(src.tolist(), pred_rows)):
                if len(images) >= max_plots:
                    break

                # Keep only what comes before the first <eos> on either side.
                src_len = src_ids.index(src_eos) if src_eos in src_ids else len(src_ids)
                pred_len = pred_ids.index(tgt_eos) if tgt_eos in pred_ids else len(pred_ids)

                input_tokens = [idx2src.get(tok, '') for tok in src_ids[:src_len]]
                keep_rows = [pos for pos, tok in enumerate(pred_ids[:pred_len]) if tok not in unlabelled]
                output_tokens = [idx2tgt.get(pred_ids[pos], '') for pos in keep_rows]
                if not input_tokens or not output_tokens:
                    # Nothing to draw for an empty source or empty prediction.
                    continue

                row_index = torch.tensor(keep_rows, dtype=torch.long, device=attn_weights.device)
                attn_matrix = attn_weights[j].index_select(0, row_index)[:, :src_len].cpu().numpy()

                input_seq = ''.join(input_tokens)
                pred_seq = ''.join(output_tokens)
                images.append(
                    plot_attention(attn_matrix, input_tokens, output_tokens, input_seq, pred_seq, len(images) + 1)
                )

    if images:
        wandb.log({"attention_heatmaps": images})
    # Always close the run, even when no heatmap was produced.
    wandb.finish()


In [None]:
# Hyperparameters of the attention model whose heatmaps are plotted.
parameters = dict(
    embedding_size=256,
    hidden_size=256,
    enc_layers=2,
    dec_layers=3,
    rnn_type='LSTM',
    dropout_rate=0.3,
    batch_size=64,
    epochs=10,
    is_bidirectional=True,
    learning_rate=0.001,
    optimizer='adam',
    teacher_forcing_prob=0.7,
    beam_size=3,
    use_attention=True,
)
attention_heatmaps(parameters, "best_attention_model.pt")

Interactive plot

In [20]:
def generate_sentence_html(words_data):
    """Build a standalone HTML page that shows attention on hover.

    Every output character becomes a hoverable ``<span>``. Hovering it fills
    the three boxes at the top of the page with the input characters that
    received the highest attention weight for that output character.

    Args:
        words_data: List of dictionaries, one per word, each with:
            - ``'input_chars'``: list of input characters (Latin)
            - ``'output_chars'``: list of output characters (native script)
            - ``'attention_weights'``: list of lists; the i-th inner list
              holds the attention weights for the i-th output character.
            Weight rows are paired with ``input_chars`` position by position,
            so weights beyond the last input character are ignored.

    Returns:
        The complete HTML document as a string.
    """
    import json
    from html import escape

    page_head = """
    <html>
    <head>
      <meta charset="UTF-8">
      <style>
        body {
          font-family: Arial, sans-serif;
          padding: 20px;
        }
        .top-attn-boxes {
          display: flex;
          gap: 20px;
          margin-bottom: 20px;
        }
        .attn-box {
          border: 1px solid #ccc;
          padding: 10px;
          width: 150px;
          height: 50px;
          text-align: center;
          font-size: 18px;
          background-color: #f9f9f9;
          border-radius: 8px;
          display: flex;
          align-items: center;
          justify-content: center;
        }
        .sentence {
          line-height: 2;
          font-size: 28px;
        }
        .word {
          margin-right: 20px; /* proper spacing between words */
          display: inline-block;
        }
        .output-char {
          display: inline-block;
          margin: 0 3px;
          padding: 8px 5px;
          cursor: pointer;
          border-bottom: 1px dotted #555;
          transition: background-color 0.2s;
        }
        .output-char:hover {
          background-color: #eef;
        }
      </style>
    </head>
    <body>
      <div class="top-attn-boxes">
        <div class="attn-box" id="top1">—</div>
        <div class="attn-box" id="top2">—</div>
        <div class="attn-box" id="top3">—</div>
      </div>
      <div class="sentence">
    """

    page_tail = """
      </div>
      <script>
        function showTop3(attnData) {
          // Create a shallow copy to sort so we don't modify the original array
          let sortedData = attnData.slice().sort((a, b) => b.weight - a.weight);
          const top = sortedData.slice(0, 3);
          for (let i = 0; i < 3; i++) {
            const el = document.getElementById("top" + (i + 1));
            if (i < top.length) {
              el.innerText = top[i].char + " : " + top[i].weight.toFixed(2);
            } else {
              el.innerText = "—";
            }
          }
        }

        // Attach hover event to each character.
        document.querySelectorAll(".output-char").forEach(span => {
          span.addEventListener("mouseenter", () => {
            // The browser has already decoded the attribute's HTML entities,
            // so the value is plain JSON.
            const attn = JSON.parse(span.dataset.attn);
            showTop3(attn);
          });
        });
      </script>
    </body>
    </html>
    """

    word_spans = []
    for word in words_data:
        input_chars = word["input_chars"]
        char_spans = []
        for out_char, weights in zip(word["output_chars"], word["attention_weights"]):
            # One {char, weight} entry per input character; float() turns
            # numpy/torch scalars into plain numbers that JSON can represent.
            data = [
                {"char": in_char, "weight": round(float(w), 3)}
                for in_char, w in zip(input_chars, weights)
            ]
            # Serialise as real JSON, then escape it for use inside an HTML
            # attribute, so quotes or '<' in the characters cannot break the markup.
            data_attr = escape(json.dumps(data, ensure_ascii=False), quote=True)
            # If an output character is empty, we show a placeholder (like ␣)
            display_char = escape(out_char) if out_char else "␣"
            char_spans.append(
                f'<span class="output-char" data-attn="{data_attr}">{display_char}</span>'
            )
        word_spans.append('<span class="word">' + ''.join(char_spans) + '</span>')

    return page_head + ''.join(word_spans) + page_tail


In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import random
import wandb

def Interactive_plot(cfg, model_path):
    """Render attention weights for up to ten test words as interactive HTML.

    Starts a fresh wandb run, builds a ``TransliterationModel`` from ``cfg``,
    trains it only when no checkpoint exists at ``model_path``, then decodes
    shuffled test words with ``teacher_forcing_ratio=0.0`` and collects, for
    the first ten of them, the input characters, the predicted characters and
    the attention matrix.  The collected data is rendered by
    ``generate_sentence_html``, written to ``sentence_attention.html`` in the
    working directory and logged to wandb.

    Args:
        cfg: Mapping of hyperparameters; it is also passed to wandb as the
            run config.  Keys read here: ``embedding_size``, ``hidden_size``,
            ``enc_layers``, ``dec_layers``, ``rnn_type``, ``dropout_rate``,
            ``is_bidirectional``, ``use_attention``, ``batch_size`` and, when
            training is needed, ``learning_rate``, ``epochs`` and
            ``teacher_forcing_prob``.
        model_path: Path of the ``state_dict`` checkpoint to load.  When the
            file does not exist the model is trained first and the best
            epoch (by training accuracy) is saved there.

    Returns:
        None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Close any run left open by a previous cell before starting a new one.
    wandb.finish()
    wandb.init(
        project="transliteration_attention_Interactive_plot",
        name='connectivity_plot',
        resume="never",
        reinit=True,
        config=cfg
    )

    # Load training and test datasets
    train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"
    train_set = read_pairs(train_path)
    test_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.test.tsv"
    test_set = read_pairs(test_path)

    # Prepare vocabulary and batch creation
    src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch, tensor_to_words, _ = build_vocab_and_prepare_batch(train_set, device)

    # Initialize the model
    model = TransliterationModel(
        len(src_vocab), len(tgt_vocab), cfg['embedding_size'], cfg['hidden_size'],
        cfg['enc_layers'], cfg['dec_layers'], cfg['rnn_type'], cfg['dropout_rate'],
        cfg['is_bidirectional'], cfg['use_attention']
    ).to(device)

    # If model not trained yet, train it
    if not os.path.exists(model_path):
        print("❌ No saved model found, starting training.")
        optimizer = optim.Adam(model.parameters(), lr=cfg['learning_rate'])
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])
        best_acc = 0.0

        for epoch in range(cfg['epochs']):
            model.train()
            total_loss, total_acc = 0, 0
            num_batches = 0
            random.shuffle(train_set)

            for i in range(0, len(train_set), cfg['batch_size']):
                batch = train_set[i:i + cfg['batch_size']]
                src, tgt = create_batch(batch)

                optimizer.zero_grad()
                outputs, attn_weights = model(src, tgt, cfg['teacher_forcing_prob'])

                # Position 0 is skipped on both sides before computing the loss.
                loss = criterion(outputs[:, 1:].reshape(-1, outputs.size(-1)), tgt[:, 1:].reshape(-1))
                loss.backward()
                optimizer.step()

                preds = outputs.argmax(-1)
                acc = compute_word_level_accuracy(preds[:, 1:], tgt[:, 1:], tgt_vocab)

                total_loss += loss.item()
                total_acc += acc
                num_batches += 1

            # Average over the batches that actually ran.  Using
            # len(train_set) // batch_size would omit the final partial batch
            # (inflating both averages) and divide by zero for a training set
            # smaller than one batch.
            denom = max(num_batches, 1)
            avg_train_loss = total_loss / denom
            avg_train_acc = total_acc / denom

            print(f"Epoch {epoch+1}/{cfg['epochs']} | Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.2f}%")

            if avg_train_acc > best_acc:
                best_acc = avg_train_acc
                torch.save(model.state_dict(), model_path)
                print(f"💾 Saved new best model at epoch {epoch + 1} with accuracy {best_acc:.2f}%")

        print(f"Best model saved with accuracy: {best_acc:.2f}%")

    # Load and evaluate saved model
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("✅ Loaded saved model from disk.")

    model.eval()
    max_plots = 10
    word_data = []

    # Ids dropped while decoding, looked up once instead of per character.
    src_skip_ids = {src_vocab['<pad>'], src_vocab['<eos>']}
    tgt_skip_ids = {tgt_vocab['<pad>'], tgt_vocab['<eos>']}
    # Second, string-level filter for any special token that survives the
    # id-level filter above.
    special_tokens = {'<pad>', '<sos>', '<eos>'}

    random.shuffle(test_set)
    with torch.no_grad():
        for i in range(0, len(test_set), cfg['batch_size']):
            batch = test_set[i:i + cfg['batch_size']]
            src, tgt = create_batch(batch)

            # NOTE(review): attn_weights is indexed below, so this presumably
            # requires a model built with cfg['use_attention'] enabled — confirm.
            outputs, attn_weights = model(src, tgt, teacher_forcing_ratio=0.0)
            preds = outputs.argmax(-1)

            for j in range(src.size(0)):
                if len(word_data) >= max_plots:
                    break

                input_tokens = [
                    idx2src.get(idx, '')
                    for idx in src[j].tolist()
                    if idx not in src_skip_ids
                ]
                # The first predicted position is skipped, as in training.
                output_tokens = [
                    idx2tgt.get(idx, '')
                    for idx in preds[j][1:].tolist()
                    if idx not in tgt_skip_ids
                ]
                attn_matrix = attn_weights[j].cpu().numpy()

                filtered_input_tokens = [tok for tok in input_tokens if tok not in special_tokens]
                filtered_output_tokens = [tok for tok in output_tokens if tok not in special_tokens]
                # NOTE(review): this slice keeps the leading rows/columns, which
                # assumes the dropped tokens sit at the end of each sequence and
                # that row k of the matrix belongs to the k-th kept output
                # character — confirm against the model's attention layout.
                attn_matrix_filtered = attn_matrix[:len(filtered_output_tokens), :len(filtered_input_tokens)]

                word_data.append({
                    "input_chars": filtered_input_tokens,
                    "output_chars": filtered_output_tokens,
                    "attention_weights": attn_matrix_filtered.tolist(),
                })

            if len(word_data) >= max_plots:
                break

    # Generate HTML
    html_str = generate_sentence_html(word_data)

    # Save the HTML file with UTF-8 encoding
    with open("sentence_attention.html", "w", encoding="utf-8") as f:
        f.write(html_str)

    wandb.log({"sentence_attention_viz": wandb.Html("sentence_attention.html", inject=False)})
    wandb.finish()

    print("HTML file generated and logged to wandb (if wandb is configured).")



In [22]:
# Hyperparameters for the attention model.  The whole mapping is handed to
# Interactive_plot, which also forwards it to wandb as the run config.
parameters = dict(
    # Model architecture
    embedding_size=256,
    hidden_size=256,
    enc_layers=2,
    dec_layers=3,
    rnn_type='LSTM',
    dropout_rate=0.3,
    # Training schedule
    batch_size=64,
    epochs=10,
    is_bidirectional=True,
    learning_rate=0.001,
    optimizer='adam',
    teacher_forcing_prob=0.7,
    # Decoding / attention
    beam_size=3,
    use_attention=True,
)

# Reuses the checkpoint at this path when it exists; otherwise trains first.
Interactive_plot(parameters, "best_attention_model.pt")



✅ Loaded saved model from disk.


HTML file generated and logged to wandb (if wandb is configured).
