In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import wandb
from tqdm.notebook import tqdm
import random
import matplotlib.pyplot as plt
import copy
from torch.cuda.amp import autocast, GradScaler
import matplotlib.font_manager as fm

# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and the current CUDA device) so repeated runs are comparable.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Download the Anek Devanagari font archive, extract only the static TTF files,
# and copy them into the system TrueType directory (the path the next cell
# points matplotlib's FontProperties at).
!wget -qq https://ektype.in/fontshost/Anek_Devanagari.zip
!unzip Anek_Devanagari.zip "static/AnekDevanagari/*"
!cp -r static/AnekDevanagari /usr/share/fonts/truetype

Archive:  Anek_Devanagari.zip
  inflating: static/AnekDevanagari/AnekDevanagari-Thin.ttf  
  inflating: static/AnekDevanagari/AnekDevanagari-ExtraLight.ttf  
  inflating: static/AnekDevanagari/AnekDevanagari-Light.ttf  
  inflating: static/AnekDevanagari/AnekDevanagari-Regular.ttf  
  inflating: static/AnekDevanagari/AnekDevanagari-Medium.ttf  
  inflating: static/AnekDevanagari/AnekDevanagari-SemiBold.ttf  
  inflating: static/AnekDevanagari/AnekDevanagari-Bold.ttf  
  inflating: static/AnekDevanagari/AnekDevanagari-ExtraBold.ttf  


In [3]:
# Font handle for the Regular weight installed by the previous cell.
# NOTE(review): `font_prop` is not used in the visible cells; presumably it is
# meant to be passed as `fontproperties=` when drawing Devanagari text — confirm.
font_path = '/usr/share/fonts/truetype/AnekDevanagari/AnekDevanagari-Regular.ttf'
font_prop = fm.FontProperties(fname=font_path)

In [4]:
# Download the Dakshina dataset archive (about 1.9 GB according to the wget
# log below) and unpack it into the current working directory.
!wget https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar
!tar -xf dakshina_dataset_v1.0.tar

# Data paths (Hindi lexicon splits inside the unpacked Dakshina archive)
train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"
val_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.dev.tsv"
test_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.test.tsv"

# Load data. Columns: Devanagari word, romanized word, and a third column
# that is not used by this notebook.
# keep_default_na=False matters here: by default pandas converts cells whose
# text is e.g. "null", "nan" or "NA" into NaN, and the later str(...) calls
# would then turn every such word into the string "nan".
train_df = pd.read_csv(train_path, delimiter='\t', names=['hi', 'en', '_'], keep_default_na=False)
val_df = pd.read_csv(val_path, delimiter='\t', names=['hi', 'en', '_'], keep_default_na=False)
test_df = pd.read_csv(test_path, delimiter='\t', names=['hi', 'en', '_'], keep_default_na=False)

print(f"Train samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"Test samples: {len(test_df)}")

# Check max sequence lengths (in characters, before <SOS>/<EOS> are added)
src_max_len = max([len(str(text)) for text in train_df['en']])
tgt_max_len = max([len(str(text)) for text in train_df['hi']])
print(f"Max source sequence length: {src_max_len}")
print(f"Max target sequence length: {tgt_max_len}")

--2025-05-21 10:35:45--  https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.2.207, 74.125.137.207, 142.250.101.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.2.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2008340480 (1.9G) [application/x-tar]
Saving to: ‘dakshina_dataset_v1.0.tar’


2025-05-21 10:35:55 (194 MB/s) - ‘dakshina_dataset_v1.0.tar’ saved [2008340480/2008340480]

Train samples: 44204
Validation samples: 4358
Test samples: 4502
Max source sequence length: 20
Max target sequence length: 19


In [5]:
def create_vocab(texts, special_tokens=True):
    """Build a character-level vocabulary mapping each symbol to an integer id.

    Args:
        texts: Iterable of items; each item is converted with ``str`` and split
            into characters.
        special_tokens: When True, ids 0-3 are reserved for ``<PAD>``, ``<SOS>``,
            ``<EOS>`` and ``<UNK>`` and characters start at 4. When False, no
            ids are reserved and characters start at 0.

    Returns:
        dict mapping token -> id. Characters are numbered in sorted order and
        the ids are always contiguous, so ``len(vocab)`` is a valid embedding
        size.
    """
    chars = set()
    for text in texts:
        chars.update(str(text))

    if special_tokens:
        vocab = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
    else:
        vocab = {}

    # Start right after whatever is already reserved; a fixed offset of 4 would
    # leave a gap (and out-of-range ids) when no special tokens are requested.
    offset = len(vocab)
    for i, char in enumerate(sorted(chars)):
        vocab[char] = i + offset

    return vocab

def text_to_indices(text, vocab):
    """Encode ``text`` as a list of vocabulary ids wrapped in <SOS> ... <EOS>.

    Each character is looked up as-is first, then in lower case; characters
    found neither way are encoded as <UNK>.
    """
    unk = vocab['<UNK>']

    def encode(symbol):
        if symbol in vocab:
            return vocab[symbol]
        lowered = symbol.lower()
        return vocab[lowered] if lowered in vocab else unk

    body = [encode(symbol) for symbol in str(text)]
    return [vocab['<SOS>'], *body, vocab['<EOS>']]

# Character vocabularies are built from the training split only
src_vocab, tgt_vocab = (create_vocab(train_df[column]) for column in ('en', 'hi'))

# Inverse lookups (id -> symbol), used when turning predictions back into text
idx2src = dict(zip(src_vocab.values(), src_vocab.keys()))
idx2tgt = dict(zip(tgt_vocab.values(), tgt_vocab.keys()))

print(f"Source vocabulary size: {len(src_vocab)}")
print(f"Target vocabulary size: {len(tgt_vocab)}")

Source vocabulary size: 30
Target vocabulary size: 67


In [6]:
class TransliterationDataset(Dataset):
    """Dataset yielding (source_ids, target_ids) tensor pairs for one word.

    The source is read from the ``'en'`` column and the target from the
    ``'hi'`` column of the wrapped dataframe; both are encoded with
    ``text_to_indices`` using the matching vocabulary.
    """

    def __init__(self, dataframe, src_vocab, tgt_vocab):
        self.dataframe, self.src_vocab, self.tgt_vocab = dataframe, src_vocab, tgt_vocab

    def __len__(self):
        # One sample per dataframe row
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Positional row lookup, then encode each side with its own vocabulary
        row = self.dataframe.iloc[idx]
        encoded_src = text_to_indices(row['en'], self.src_vocab)
        encoded_tgt = text_to_indices(row['hi'], self.tgt_vocab)
        return torch.tensor(encoded_src), torch.tensor(encoded_tgt)

def _truncate_keep_eos(seq, max_len, eos_idx):
    """Cap ``seq`` at ``max_len`` tokens, forcing the last kept token to <EOS>."""
    if max_len is None or seq.size(0) <= max_len:
        return seq
    seq = seq[:max_len].clone()
    seq[-1] = eos_idx
    return seq


def collate_fn(batch, max_src_len=22, max_tgt_len=21):
    """Collate (src, tgt) id tensors into two right-padded batch tensors.

    Args:
        batch: Iterable of ``(src, tgt)`` 1-D LongTensor pairs that already
            contain the leading <SOS> and trailing <EOS> ids.
        max_src_len: Maximum number of source tokens *including* <SOS>/<EOS>
            (``None`` disables truncation). The default is the 20-character
            training maximum plus the two special tokens.
        max_tgt_len: Same for the target side (19 characters + 2).

    Returns:
        ``(src_batch, tgt_batch)`` of shape ``[batch, longest_len]``, padded
        with the respective <PAD> id.

    Reads the module-level ``src_vocab`` and ``tgt_vocab``.
    """
    src_batch, tgt_batch = [], []
    for src, tgt in batch:
        # Safety check for index bounds
        src = torch.clamp(src, 0, len(src_vocab) - 1)
        tgt = torch.clamp(tgt, 0, len(tgt_vocab) - 1)

        # The limits count the special tokens too; cutting at the raw character
        # maximum would silently chop <EOS> off the longest words.
        src = _truncate_keep_eos(src, max_src_len, src_vocab['<EOS>'])
        tgt = _truncate_keep_eos(tgt, max_tgt_len, tgt_vocab['<EOS>'])

        src_batch.append(src)
        tgt_batch.append(tgt)

    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=src_vocab['<PAD>'])
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=tgt_vocab['<PAD>'])

    return src_batch, tgt_batch

In [7]:
class Encoder(nn.Module):
    """Embeds a batch of source ids and runs them through a recurrent stack.

    ``cell_type`` is matched case-insensitively: ``"lstm"`` and ``"gru"`` select
    the corresponding layer, anything else falls back to a plain ``nn.RNN``.
    """

    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, num_layers, dropout, cell_type):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.cell_type = cell_type.lower()

        # Inter-layer dropout only applies to stacked RNNs, so it is zeroed
        # for a single layer.
        rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU}.get(self.cell_type, nn.RNN)
        self.rnn = rnn_cls(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
        )

        self.dropout = nn.Dropout(dropout)

        # Apply weight initialization
        self._init_weights()

    def _init_weights(self):
        """Xavier for input weights, orthogonal for recurrent weights, zero biases."""
        for name, param in self.named_parameters():
            values = param.data
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(values)
                continue
            if 'weight_hh' in name:
                nn.init.orthogonal_(values)
                continue
            if 'bias' in name:
                values.zero_()

    def forward(self, src):
        """Encode ``src`` ([batch_size, src_len]).

        Returns ``(outputs, hidden, cell)``; ``cell`` is ``None`` unless the
        recurrent layer is an LSTM.
        """
        embedded = self.dropout(self.embedding(src))  # [batch_size, src_len, emb_dim]
        outputs, state = self.rnn(embedded)

        if self.cell_type == "lstm":
            hidden, cell = state
            return outputs, hidden, cell
        return outputs, state, None

In [8]:
class AttentionDecoder(nn.Module):
    """Single-step recurrent decoder with additive-style attention.

    At every step the top-layer hidden state is compared against each encoder
    output, the resulting weights form a context vector, and the context is
    merged with the embedded input token before it enters the recurrent stack.

    ``cell_type`` is matched case-insensitively: ``"lstm"`` and ``"gru"`` select
    the corresponding layer, anything else falls back to a plain ``nn.RNN``.
    """

    def __init__(self, output_vocab_size, embedding_dim, hidden_dim, num_layers, dropout, cell_type):
        super().__init__()
        self.output_vocab_size = output_vocab_size
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.cell_type = cell_type.lower()

        # Attention mechanism
        self.attention = nn.Linear(hidden_dim * 2, hidden_dim)
        self.attention_combine = nn.Linear(hidden_dim + embedding_dim, embedding_dim)

        # RNN layer. Select on the normalised name: forward() branches on
        # self.cell_type, so choosing the layer from the raw argument would
        # build an nn.RNN for "LSTM" and then hand it an (h, c) tuple.
        rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU}.get(self.cell_type, nn.RNN)
        self.rnn = rnn_cls(embedding_dim, hidden_dim, num_layers=num_layers,
                           dropout=dropout if num_layers > 1 else 0, batch_first=True)

        self.fc_out = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(dropout)

        # Apply weight initialization
        self._init_weights()

    def _init_weights(self):
        """Xavier for input weights, orthogonal for recurrent weights, zero biases."""
        for name, param in self.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param.data)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param.data)
            elif 'bias' in name:
                param.data.fill_(0)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Run one decoding step.

        Args:
            input: [batch_size] ids of the previous target token.
            hidden: [num_layers, batch_size, hidden_dim] decoder hidden state.
            cell: LSTM cell state of the same shape, or ``None`` for GRU/RNN.
            encoder_outputs: [batch_size, src_len, hidden_dim].

        Returns:
            ``(prediction, hidden, cell, attn_weights)`` where ``prediction`` is
            [batch_size, output_vocab_size] logits and ``attn_weights`` is
            [batch_size, 1, src_len]; ``cell`` is ``None`` for GRU/RNN.
        """
        input = input.unsqueeze(1)  # [batch_size, 1]
        embedded = self.dropout(self.embedding(input))  # [batch_size, 1, emb_dim]

        src_len = encoder_outputs.size(1)

        # Use the last layer of hidden state as the attention query and
        # broadcast it over the source positions (expand avoids a copy).
        attn_hidden = hidden[-1].unsqueeze(1).expand(-1, src_len, -1)  # [batch_size, src_len, hidden_dim]

        # Score every source position from [encoder_output ; query]
        energy = torch.cat((encoder_outputs, attn_hidden), dim=2)  # [batch_size, src_len, 2*hidden_dim]
        energy = torch.tanh(self.attention(energy))  # [batch_size, src_len, hidden_dim]

        # Collapse the feature axis to one score per position, then normalise
        attn_weights = torch.sum(energy, dim=2)  # [batch_size, src_len]
        attn_weights = F.softmax(attn_weights, dim=1).unsqueeze(1)  # [batch_size, 1, src_len]

        # Weighted sum of encoder outputs
        context = torch.bmm(attn_weights, encoder_outputs)  # [batch_size, 1, hidden_dim]

        # Combine embedded input and context vector
        rnn_input = torch.cat((embedded, context), dim=2)  # [batch_size, 1, emb_dim + hidden_dim]
        rnn_input = self.attention_combine(rnn_input)  # [batch_size, 1, emb_dim]

        # Pass through RNN
        if self.cell_type == "lstm":
            output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        else:
            output, hidden = self.rnn(rnn_input, hidden)
            cell = None

        # Generate output
        prediction = self.fc_out(output.squeeze(1))  # [batch_size, output_vocab_size]

        return prediction, hidden, cell, attn_weights

In [9]:
class AttentionSeq2Seq(nn.Module):
    """Encoder/attention-decoder pair driven by a single ``config`` dict.

    Required config keys: ``input_vocab_size``, ``output_vocab_size``,
    ``embedding_dim``, ``hidden_dim``, ``num_encoding_layers``,
    ``num_decoding_layers``, ``dropout``, ``cell_type``. Optional keys:
    ``device`` (defaults to the module-level ``device``) and
    ``teacher_forcing_ratio`` (defaults to 0.5).
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(
            config['input_vocab_size'],
            config['embedding_dim'],
            config['hidden_dim'],
            config['num_encoding_layers'],
            config['dropout'],
            config['cell_type']
        )
        self.decoder = AttentionDecoder(
            config['output_vocab_size'],
            config['embedding_dim'],
            config['hidden_dim'],
            config['num_decoding_layers'],
            config['dropout'],
            config['cell_type']
        )
        self.device = config.get('device', device)
        self.teacher_forcing_ratio = config.get('teacher_forcing_ratio', 0.5)
        self.cell_type = config['cell_type'].lower()
        self.config = config

    def _align_decoder_state(self, hidden, cell):
        """Make the encoder's final state match the decoder's layer count.

        A deeper encoder is cut down to its first ``num_decoding_layers``
        layers; a shallower one is extended with zero layers. ``cell`` is only
        touched for LSTMs.
        """
        enc_layers = self.config['num_encoding_layers']
        dec_layers = self.config['num_decoding_layers']
        is_lstm = self.cell_type == 'lstm'

        if enc_layers > dec_layers:
            hidden = hidden[:dec_layers]
            if is_lstm:
                cell = cell[:dec_layers]
        elif enc_layers < dec_layers:
            # new_zeros inherits device and dtype from the real state tensor
            padding = hidden.new_zeros(dec_layers - enc_layers, hidden.size(1), hidden.size(2))
            hidden = torch.cat([hidden, padding], dim=0)
            if is_lstm:
                cell = torch.cat([cell, padding], dim=0)
        return hidden, cell

    def forward(self, src, trg, teacher_forcing=1):
        """Decode ``trg`` step by step, conditioned on ``src``.

        Args:
            src: [batch_size, src_len] source ids.
            trg: [batch_size, trg_len] target ids; column 0 is fed as the first
                decoder input.
            teacher_forcing: Multiplier applied to ``teacher_forcing_ratio``;
                pass 0 to always feed back the model's own prediction.

        Returns:
            ``(outputs, attentions)`` with shapes
            [batch_size, trg_len, vocab] and [batch_size, trg_len, src_len].
            Position 0 along ``trg_len`` is left as zeros in both.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_vocab_size

        # Allocate the result buffers on the device the data actually lives on,
        # so the forward pass does not depend on the configured device string.
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=src.device)
        attentions = torch.zeros(batch_size, trg_len, src.shape[1], device=src.device)

        # Encode source sequence
        encoder_outputs, hidden, cell = self.encoder(src)
        hidden, cell = self._align_decoder_state(hidden, cell)

        # First input to decoder is <SOS> token
        decoder_input = trg[:, 0]

        for t in range(1, trg_len):
            output, hidden, cell, attn_weights = self.decoder(decoder_input, hidden, cell, encoder_outputs)

            # Store prediction and attention weights
            outputs[:, t, :] = output
            attentions[:, t, :] = attn_weights.squeeze(1)

            # One coin flip per step decides for the whole batch
            teacher_force = random.random() < self.teacher_forcing_ratio * teacher_forcing

            # Next input is either ground truth or the most likely token
            decoder_input = trg[:, t] if teacher_force else output.argmax(1)

        return outputs, attentions

In [10]:
def evaluate(model, dataloader, criterion):
    """Score ``model`` on ``dataloader`` without teacher forcing.

    Args:
        model: Callable as ``model(src, trg, 0)`` returning ``(logits, _)`` with
            logits of shape [batch, trg_len, vocab].
        dataloader: Iterable of ``(src, trg)`` batches.
        criterion: Loss applied to the flattened logits/targets, skipping the
            first (<SOS>) position.

    Returns:
        dict with
        ``'loss'``                 - mean batch loss,
        ``'exact_match_accuracy'`` - share of words predicted entirely right,
        ``'char_accuracy'``        - share of non-pad target positions predicted right.
        All three are 0 when the dataloader yields nothing.

    Reads the module-level ``device`` and ``tgt_vocab``.
    """
    model.eval()
    pad_idx = tgt_vocab['<PAD>']

    epoch_loss = 0.0
    num_batches = 0
    exact_match_correct = exact_match_total = 0
    char_correct = char_total = 0

    with torch.no_grad():
        for src, trg in tqdm(dataloader, desc="Evaluating", leave=False):
            src, trg = src.to(device), trg.to(device)

            output, _ = model(src, trg, 0)  # Turn off teacher forcing

            # Loss over every position after <SOS>
            output_dim = output.shape[-1]
            loss = criterion(output[:, 1:].reshape(-1, output_dim), trg[:, 1:].reshape(-1))
            epoch_loss += loss.item()
            num_batches += 1

            # Move the whole batch to NumPy once instead of once per row
            predictions = output.argmax(dim=2)[:, 1:].cpu().numpy()
            targets = trg[:, 1:].cpu().numpy()

            for pred_seq, target_seq in zip(predictions, targets):
                # Padding is trailing, so the non-pad count is the real length
                valid_length = int((target_seq != pad_idx).sum())
                matches = pred_seq[:valid_length] == target_seq[:valid_length]

                char_correct += int(matches.sum())
                char_total += valid_length
                exact_match_correct += int(matches.all())
                exact_match_total += 1

    return {
        # Guarded like the accuracies so an empty loader does not divide by zero
        'loss': epoch_loss / num_batches if num_batches > 0 else 0,
        'exact_match_accuracy': exact_match_correct / exact_match_total if exact_match_total > 0 else 0,
        'char_accuracy': char_correct / char_total if char_total > 0 else 0
    }

In [11]:
def translate_sentence(model, sentence, src_vocab, tgt_vocab, idx2tgt, max_len=50):
    """Greedily transliterate one input string.

    The model is switched to eval mode, the source is encoded once, and the
    decoder is then stepped up to ``max_len`` times, always feeding back its
    most likely token, until it emits <EOS>.

    Returns:
        ``(text, attentions)``: the decoded string (special tokens removed) and
        a list with one NumPy attention array per executed decoder step,
        including the step that produced <EOS>.

    Reads the module-level ``device``.
    """
    model.eval()

    # Encode the source (wrapped in <SOS>/<EOS>) as a batch of one
    source = torch.LongTensor(text_to_indices(sentence, src_vocab)).unsqueeze(0).to(device)
    with torch.no_grad():
        encoder_outputs, hidden, cell = model.encoder(source)

    # Reconcile encoder/decoder depth: keep the first layers of a deeper
    # encoder, or append zero layers for a deeper decoder.
    n_enc = model.config['num_encoding_layers']
    n_dec = model.config['num_decoding_layers']
    lstm = model.cell_type == 'lstm'
    if n_enc > n_dec:
        hidden = hidden[:n_dec]
        if lstm:
            cell = cell[:n_dec]
    elif n_enc < n_dec:
        zeros = torch.zeros(n_dec - n_enc, 1, model.config['hidden_dim']).to(device)
        hidden = torch.cat([hidden, zeros], dim=0)
        if lstm:
            cell = torch.cat([cell, zeros], dim=0)

    decoded = [tgt_vocab['<SOS>']]
    attentions = []

    for _step in range(max_len):
        step_input = torch.LongTensor(decoded[-1:]).to(device)

        with torch.no_grad():
            logits, hidden, cell, attn_weights = model.decoder(step_input, hidden, cell, encoder_outputs)

        attentions.append(attn_weights.squeeze().cpu().numpy())

        next_token = logits.argmax(1).item()
        if next_token == tgt_vocab['<EOS>']:
            break
        decoded.append(next_token)

    # Drop every special token before joining the characters
    hidden_tokens = [tgt_vocab['<SOS>'], tgt_vocab['<EOS>'], tgt_vocab['<PAD>'], tgt_vocab['<UNK>']]
    text = ''.join(idx2tgt[token] for token in decoded if token not in hidden_tokens)

    return text, attentions

In [13]:
def compute_gradient_connectivity(model, src_text, src_vocab, tgt_vocab, idx2tgt, max_len=50):
    """
    Compute gradient-based connectivity between input and output characters.

    For every generated character the squared L2 norm of the gradient of its
    logit with respect to each source-character embedding is recorded, i.e.
    ||d(logit_k) / d(x_t)||^2, and each row is scaled so its maximum is 1.

    Args:
        model: The Seq2Seq model
        src_text: Source text (Latin characters)
        src_vocab: Source vocabulary mapping
        tgt_vocab: Target vocabulary mapping
        idx2tgt: Mapping from indices to target characters
        max_len: Maximum generation length

    Returns:
        translation: Generated translation (special tokens removed)
        connectivity: Matrix of normalised gradient magnitudes
            [len(translation tokens), len(src_text)]; column j refers to
            src_text[j].

    The model's train/eval mode and the cuDNN switch are restored on exit, and
    parameter ``.grad`` buffers are left untouched.
    """
    src_text = str(src_text)
    src_indices = text_to_indices(src_text, src_vocab)
    model_device = next(model.parameters()).device
    src_tensor = torch.tensor(src_indices, dtype=torch.long, device=model_device).unsqueeze(0)

    sos_idx = tgt_vocab['<SOS>']
    eos_idx = tgt_vocab['<EOS>']
    # Tokens that never appear in the returned text (and so get no matrix row)
    skipped = {sos_idx, eos_idx, tgt_vocab['<PAD>'], tgt_vocab.get('<UNK>')}

    # Attribution must be deterministic, so dropout has to be off (eval mode).
    # cuDNN refuses RNN backward passes in eval mode, hence it is disabled for
    # the duration of this call instead of switching the model to train mode.
    was_training = model.training
    cudnn_was_enabled = torch.backends.cudnn.enabled
    model.eval()
    torch.backends.cudnn.enabled = False

    trg_tokens = []
    gradient_rows = []
    try:
        # Embeddings are the quantity gradients are taken with respect to
        src_emb = model.encoder.embedding(src_tensor)

        if model.cell_type == "lstm":
            encoder_outputs, (hidden, cell) = model.encoder.rnn(src_emb)
        else:  # GRU or RNN
            encoder_outputs, hidden = model.encoder.rnn(src_emb)
            cell = None

        # Reconcile encoder/decoder depth exactly like the model's forward pass
        enc_layers = model.config['num_encoding_layers']
        dec_layers = model.config['num_decoding_layers']
        is_lstm = model.cell_type == 'lstm'
        if enc_layers > dec_layers:
            hidden = hidden[:dec_layers]
            if is_lstm:
                cell = cell[:dec_layers]
        elif enc_layers < dec_layers:
            padding = hidden.new_zeros(dec_layers - enc_layers, hidden.size(1), hidden.size(2))
            hidden = torch.cat([hidden, padding], dim=0)
            if is_lstm:
                cell = torch.cat([cell, padding], dim=0)

        uses_attention = hasattr(model, 'decoder') and hasattr(model.decoder, 'attention')
        prev_token = sos_idx

        for _ in range(max_len):
            trg_tensor = torch.tensor([prev_token], dtype=torch.long, device=model_device)

            if uses_attention:
                output, hidden, cell, _ = model.decoder(trg_tensor, hidden, cell, encoder_outputs)
            elif is_lstm:
                output, hidden, cell = model.decoder(trg_tensor, hidden, cell)
            else:
                output, hidden = model.decoder(trg_tensor, hidden)
                cell = None

            pred_token = output.argmax(1).item()
            if pred_token == eos_idx:
                break

            if pred_token not in skipped:
                # autograd.grad returns the gradient directly, so nothing
                # accumulates in the parameters and no zeroing is needed.
                (grad,) = torch.autograd.grad(
                    output[0, pred_token], src_emb, retain_graph=True, allow_unused=True
                )
                if grad is None:
                    # Fallback if the logit does not depend on the source
                    row = np.ones(len(src_indices)) / len(src_indices)
                else:
                    row = grad.pow(2).sum(dim=2).squeeze(0).detach().cpu().numpy()
                # Token and row are stored together so they can never misalign
                trg_tokens.append(idx2tgt[pred_token])
                gradient_rows.append(row)

            prev_token = pred_token
    finally:
        torch.backends.cudnn.enabled = cudnn_was_enabled
        model.train(was_training)

    translation = ''.join(trg_tokens)

    n_src = len(src_text)
    connectivity = np.zeros((len(trg_tokens), n_src))
    for i, row in enumerate(gradient_rows):
        # Position 0 of the encoded source is <SOS>; the characters occupy
        # positions 1..n_src, followed by <EOS>.
        connectivity[i] = row[1:n_src + 1]
        # Normalize each row
        if np.sum(connectivity[i]) > 0:
            connectivity[i] = connectivity[i] / np.max(connectivity[i])

    return translation, connectivity


In [14]:
# Install Bokeh if needed
!pip install bokeh -q

# Import necessary libraries
import numpy as np
import torch
import pandas as pd
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import ColumnDataSource, HoverTool, LinearColorMapper, ColorBar
from bokeh.layouts import column, row, gridplot
from bokeh.palettes import Viridis256, Reds256
from bokeh.io import output_file, save
from bokeh.models import CustomJS, TapTool

In [38]:
def create_bokeh_character_boxes_plot(model, src_text, src_vocab, tgt_vocab, idx2tgt):
    """
    Create an interactive plot with character boxes:
    - Output characters in boxes on top row
    - Input characters in boxes on bottom row
    - Hovering over output boxes highlights input boxes based on connectivity
    - No tooltips displayed

    Args:
        model: Model handed to ``compute_gradient_connectivity``
        src_text: Source text (Latin characters)
        src_vocab: Source vocabulary mapping
        tgt_vocab: Target vocabulary mapping
        idx2tgt: Mapping from indices to target characters

    Returns:
        A Bokeh ``column`` layout holding the output row above the input row.

    Side effect: calls ``output_notebook()``.
    """
    output_notebook()

    # Step 1: Compute connectivity
    translation, connectivity = compute_gradient_connectivity(model, src_text, src_vocab, tgt_vocab, idx2tgt)

    # Step 2: Prepare data for character boxes
    src_chars = list(src_text)
    tgt_chars = list(translation)

    output_source = ColumnDataSource(data={
        'x': list(range(len(tgt_chars))),
        'y': [0] * len(tgt_chars),
        'char': tgt_chars,
        'index': list(range(len(tgt_chars)))
    })

    input_source = ColumnDataSource(data={
        'x': list(range(len(src_chars))),
        'y': [0] * len(src_chars),
        'char': src_chars,
        'color': ['#e6e6e6'] * len(src_chars),  # Light gray default
        'alpha': [1.0] * len(src_chars)
    })

    # Step 3: Create the figures for output and input boxes.
    # tools="" on both: asking for "hover" here would install a default
    # HoverTool with its default tooltips next to the silent one added below.
    # max(..., 1) keeps the range non-degenerate when a row is empty.
    output_plot = figure(
        title="Output Characters (Devanagari)",
        x_range=(-0.5, max(len(tgt_chars), 1) - 0.5),
        y_range=(-0.5, 0.5),
        width=600, height=100,
        tools="",
        toolbar_location=None
    )

    input_plot = figure(
        title="Input Characters (Latin)",
        x_range=(-0.5, max(len(src_chars), 1) - 0.5),
        y_range=(-0.5, 0.5),
        width=600, height=100,
        tools="",
        toolbar_location=None
    )

    # Step 4: Add character boxes as rectangles, with the glyph drawn on top
    output_rect = output_plot.rect(
        x='x', y='y', width=0.9, height=0.9,
        source=output_source,
        fill_color="#64b5f6",  # Light blue
        line_color="black",
        line_width=2
    )
    output_plot.text(
        x='x', y='y', text='char',
        source=output_source,
        text_align="center",
        text_baseline="middle",
        text_font_size="16px"
    )

    input_plot.rect(
        x='x', y='y', width=0.9, height=0.9,
        source=input_source,
        fill_color='color',  # driven by the hover callback
        line_color="black",
        line_width=2
    )
    input_plot.text(
        x='x', y='y', text='char',
        source=input_source,
        text_align="center",
        text_baseline="middle",
        text_font_size="16px"
    )

    # Step 5: Add hover tool for output boxes with no tooltips.
    # The callback recolours the input boxes from white (weight 0) to red
    # (weight 1) using the connectivity row of the hovered output character.
    hover_tool = HoverTool(
        renderers=[output_rect],
        tooltips=None,  # Set tooltips to None to hide them
        callback=CustomJS(args=dict(
            input_source=input_source,
            connectivity=connectivity.tolist()
        ), code="""
            // Get the index of the hovered output character
            const index = cb_data.index.indices[0];
            const conn_row = (index !== undefined) ? connectivity[index] : undefined;
            if (conn_row !== undefined) {
                // Update input box colors based on connectivity
                const colors = input_source.data['color'];

                for (let i = 0; i < colors.length; i++) {
                    // Missing or non-finite entries count as no connectivity
                    const raw = conn_row[i];
                    const weight = Number.isFinite(raw) ? raw : 0;

                    // Create color based on weight (white to red gradient)
                    const r = 255;
                    const g = Math.max(0, Math.round(255 * (1 - weight)));
                    const b = Math.max(0, Math.round(255 * (1 - weight)));

                    colors[i] = `rgb(${r}, ${g}, ${b})`;
                }

                // Notify of data change
                input_source.change.emit();
            }
        """)
    )
    output_plot.add_tools(hover_tool)

    # Step 6: Remove grid lines and axes
    output_plot.grid.grid_line_color = None
    output_plot.axis.visible = False
    input_plot.grid.grid_line_color = None
    input_plot.axis.visible = False

    # Step 7: Create layout
    layout = column(
        output_plot,
        input_plot,
        sizing_mode="stretch_width"
    )

    return layout

In [19]:
# Hyperparameters for the attention model; the checkpoint loaded below is
# restored into a network built from these values.
config = dict(
    input_vocab_size=30,
    output_vocab_size=67,
    embedding_dim=128,
    hidden_dim=512,
    num_encoding_layers=2,
    num_decoding_layers=3,
    dropout=0.3,
    cell_type='lstm',
    teacher_forcing_ratio=0.9,
    learning_rate=0.0006293179845087059,
    batch_size=128,
)

In [20]:
# Build the network from `config` and move its parameters to the selected device
model = AttentionSeq2Seq(config).to(device)

In [21]:
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU also loads when the notebook runs on the CPU.
model.load_state_dict(torch.load("/content/best_attention_model.pt", map_location=device))

<All keys matched successfully>

In [39]:
# Build the hover-linked connectivity view for one romanized sample word and
# render it inline.
interactive_viz = create_bokeh_character_boxes_plot(model, "ankganit", src_vocab, tgt_vocab, idx2tgt)
show(interactive_viz)

In [50]:
import wandb
from bokeh.io import save, output_file

def log_interactive_plot_to_wandb(model, src_text, src_vocab, tgt_vocab, idx2tgt, run=None):
    """
    Create an interactive connectivity plot and log it to Weights & Biases.

    The plot is first written to ``connectivity_plot_{src_text}.html`` in the
    working directory and then uploaded under the key
    ``connectivity_{src_text}``.

    Args:
        model: Your transliteration model
        src_text: Source text (Latin characters)
        src_vocab: Source vocabulary mapping
        tgt_vocab: Target vocabulary mapping
        idx2tgt: Mapping from indices to target characters
        run: Optional wandb run object (if None, uses current run)

    A W&B run is started only when neither ``run`` nor an active global run
    exists, and only such a self-started run is finished afterwards; runs owned
    by the caller are left open.
    """
    # Create the interactive plot
    interactive_plot = create_bokeh_character_boxes_plot(model, src_text, src_vocab, tgt_vocab, idx2tgt)

    # Save the plot to an HTML file
    html_path = f"connectivity_plot_{src_text}.html"
    output_file(html_path)
    save(interactive_plot)

    # Initialize W&B if not already initialized
    started_here = False
    if run is None and wandb.run is None:
        wandb.init(project="seq2seq-attention-transliteration")
        started_here = True

    # Hand wandb an open file rather than the path: a plain string may be
    # interpreted as literal HTML, which would upload the file name instead of
    # the plot.
    with open(html_path, encoding="utf-8") as html_file:
        html = wandb.Html(html_file)

    # Log the HTML to W&B
    if run is not None:
        run.log({f"connectivity_{src_text}": html})
    else:
        wandb.log({f"connectivity_{src_text}": html})

    # Only close what this function opened
    if started_here:
        wandb.finish()

    print(f"Interactive plot for '{src_text}' saved to W&B")

In [51]:
# Generate the connectivity plot for "ankur" and upload it to Weights & Biases
log_interactive_plot_to_wandb(model, "ankur", src_vocab, tgt_vocab, idx2tgt)

Interactive plot for 'ankur' saved to W&B
