In [None]:
pip install torch wandb pandas tqdm

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np
import os
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Language code and the directory that holds that language's lexicon files.
LANG = 'te'
data_path = '/content/drive/MyDrive/dakshina_dataset_v1.0/' + LANG + '/lexicons/'

def read_data(filepath, max_len=40):
    """Load (source, target) transliteration pairs from a tab-separated file.

    Each usable line carries the native-script word in its first column and
    the Latin-script word in its second; any further columns are ignored.
    The model maps Latin -> native script, so pairs are returned as
    ``(latin, native)``.

    Args:
        filepath: Path of the UTF-8 encoded TSV file.
        max_len: Longest word (in characters) allowed on either side.

    Returns:
        List of ``(source, target)`` string tuples that passed the filter.
    """
    kept = []
    with open(filepath, encoding='utf8') as handle:
        for raw in handle:
            fields = raw.strip().split('\t')
            # Lines with fewer than two columns carry no usable pair
            if len(fields) >= 2:
                native, roman = fields[0], fields[1]
                # Keep the pair only when neither side exceeds max_len
                if max(len(roman), len(native)) <= max_len:
                    kept.append((roman, native))
    return kept

def make_vocab(sequences):
    """Build character <-> index lookup tables for a collection of words.

    Indices 0, 1 and 2 are reserved for ``<pad>``, ``<sos>`` and ``<eos>``;
    every other character receives the next free index in order of first
    appearance.

    Returns:
        ``(vocab, idx2char)`` where ``vocab`` maps character -> index and
        ``idx2char`` is the inverse mapping.
    """
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    for seq in sequences:
        for ch in seq:
            # len(vocab) is always the next unused index
            vocab.setdefault(ch, len(vocab))

    idx2char = dict((index, char) for char, index in vocab.items())
    return vocab, idx2char

def encode_word(word, vocab):
    """Encode *word* as vocabulary indices wrapped in ``<sos>`` ... ``<eos>``.

    Raises:
        KeyError: If a character of *word* is missing from *vocab*.
    """
    return [vocab['<sos>'], *(vocab[ch] for ch in word), vocab['<eos>']]

def pad_seq(seq, max_len, pad_idx=0):
    """Return *seq* right-padded with *pad_idx* up to *max_len* entries.

    A sequence that is already *max_len* long or longer comes back unchanged
    in content (it is never truncated).
    """
    filler = [pad_idx for _ in range(max_len - len(seq))]
    return seq + filler

class TransliterationDataset(Dataset):
    """Dataset of index-encoded (source, target) word pairs.

    Every word is encoded once at construction time; ``__getitem__`` pads
    both sides to the longest encoded sequence in the dataset so that the
    default collate function can stack them.

    Args:
        pairs: Iterable of ``(source, target)`` strings.
        source_vocab: Character -> index mapping for the source side; must
            contain ``'<pad>'``.
        target_vocab: Character -> index mapping for the target side; must
            contain ``'<pad>'``.
    """

    def __init__(self, pairs, source_vocab, target_vocab):
        # Padding indices for each side
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']

        # Encode every pair up front (adds <sos>/<eos> via encode_word)
        self.data = [
            (encode_word(source, source_vocab), encode_word(target, target_vocab))
            for source, target in pairs
        ]

        # default=0 keeps an empty dataset constructible instead of letting
        # max() raise ValueError on an empty sequence
        self.source_max = max((len(src) for src, _ in self.data), default=0)
        self.target_max = max((len(tgt) for _, tgt in self.data), default=0)

    def __len__(self):
        """Return the number of encoded pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the padded ``(source, target)`` tensors of sample *idx*."""
        source, target = self.data[idx]
        source = pad_seq(source, self.source_max, self.source_pad)
        target = pad_seq(target, self.target_max, self.target_pad)
        return torch.tensor(source), torch.tensor(target)

class Attention(nn.Module):
    """Additive attention of one decoder state over all encoder outputs.

    The score of each source position is ``v . tanh(W [hidden; enc_out])``;
    scores are softmax-normalised over the source axis and used to average
    the encoder outputs into a context vector.
    """

    def __init__(self, hid_dimensions):
        super().__init__()
        # Projects the concatenated [hidden; encoder_output] pair
        self.attn = nn.Linear(hid_dimensions * 2, hid_dimensions)

        # Scoring vector, re-initialised uniformly in +-1/sqrt(hid_dimensions)
        self.v = nn.Parameter(torch.rand(hid_dimensions))
        bound = 1. / (hid_dimensions ** 0.5)
        self.v.data.uniform_(-bound, bound)

        self.hid_dimensions = hid_dimensions

    def forward(self, hidden, encoder_outputs):
        """Return ``(context, weights)`` for the given decoder state.

        Args:
            hidden: 2-D state, or a 3-D stack of per-layer states of which
                only the last entry along dim 0 is used.
            encoder_outputs: 3-D tensor whose dim 1 is the source axis.

        Raises:
            ValueError: If *hidden* is neither 2-D nor 3-D.
        """
        src_len = encoder_outputs.size(1)

        if hidden.dim() not in (2, 3):
            raise ValueError(f"Expected hidden to be 2D or 3D, got shape {hidden.shape}")
        query = hidden[-1] if hidden.dim() == 3 else hidden

        # One copy of the query per source position
        tiled = query.unsqueeze(1).repeat(1, src_len, 1)
        features = torch.cat((tiled, encoder_outputs), dim=2)

        # Scalar score per source position, normalised over the source axis
        scores = torch.matmul(torch.tanh(self.attn(features)), self.v)
        weights = torch.softmax(scores, dim=1)

        # Attention-weighted average of the encoder outputs
        context = (weights.unsqueeze(2) * encoder_outputs).sum(dim=1)
        return context, weights

class translit_Decoder(nn.Module):
    """Single-step recurrent decoder with attention over encoder outputs.

    Args:
        output_dimensions: Size of the target vocabulary.
        emb_dimensions: Width of the token embeddings.
        hid_dimensions: Hidden size of the recurrent layers.
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout probability for the embeddings and, when there is
            more than one layer, between recurrent layers.
        cell: ``'rnn'``, ``'gru'`` or ``'lstm'`` (case-insensitive).
    """

    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()

        # Token embedding followed by the attention scorer
        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)
        self.attention = Attention(hid_dimensions)

        # Resolve the recurrent layer class from its name
        cell_name = cell.lower()
        recurrent_types = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        recurrent_cls = recurrent_types[cell_name]

        # Inter-layer dropout only makes sense for stacked layers
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = recurrent_cls(
            emb_dimensions,
            hid_dimensions,
            num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

        # Maps [rnn output; context] to vocabulary logits
        self.fc_out = nn.Linear(hid_dimensions * 2, output_dimensions)

        self.cell = cell_name
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one time step.

        Returns:
            ``(prediction, hidden, cell, attn_weights)``; ``cell`` is ``None``
            for non-LSTM decoders.
        """
        # (batch,) token ids -> (batch, 1, emb) embedded step
        step_emb = self.dropout(self.embedding(input.unsqueeze(1)))

        if self.cell != 'lstm':
            output, hidden = self.rnn(step_emb, hidden)
            cell = None  # only LSTMs carry a separate cell state
        else:
            output, (hidden, cell) = self.rnn(step_emb, (hidden, cell))

        # Attend over the encoder outputs with the updated hidden state
        context, attn_weights = self.attention(hidden, encoder_outputs)

        # Drop the length-1 time axis and score the vocabulary
        features = torch.cat((output.squeeze(1), context), dim=1)
        prediction = self.fc_out(features)

        return prediction, hidden, cell, attn_weights


class translit_Encoder(nn.Module):
    """Recurrent encoder producing per-position outputs and final states.

    Args:
        input_dimensions: Size of the source vocabulary.
        emb_dimensions: Width of the token embeddings.
        hid_dimensions: Hidden size of the recurrent layers.
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout probability for the embeddings and, when there is
            more than one layer, between recurrent layers.
        cell: ``'rnn'``, ``'gru'`` or ``'lstm'`` (case-insensitive).
    """

    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()

        # Embedding layer to convert input indices into dense vectors
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)

        # Choose RNN type based on cell argument
        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[cell.lower()]

        # RNN layer to process the embedded input sequence
        self.rnn = rnn_cls(
            emb_dimensions, hid_dimensions, num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        # The encoder never uses this module's output, but it is kept as a
        # submodule so that state dicts saved with the original class
        # (which contain ``attention.*`` keys) still load strictly.
        self.attention = Attention(hid_dimensions)
        self.cell = cell.lower()

        # Dropout layer for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, source):
        """Encode a batch of source index sequences.

        Returns:
            ``(outputs, hidden, cell)``: the per-position RNN outputs, the
            final hidden state, and the final cell state (``None`` for
            non-LSTM encoders).
        """
        # Convert input token indices into embeddings and apply dropout
        embedded = self.dropout(self.embedding(source))

        # Pass embedded input through RNN
        if self.cell == 'lstm':
            outputs, (hidden, cell) = self.rnn(embedded)
        else:
            outputs, hidden = self.rnn(embedded)
            cell = None

        # The original also ran self.attention(hidden, outputs) here and
        # discarded the result; that dead computation is removed.
        return outputs, hidden, cell

class translit_Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes a target sequence step by step."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder  # maps source ids to (outputs, hidden, cell)
        self.decoder = decoder  # single-step decoder exposing ``fc_out``
        self.device = device    # where the result buffers are allocated

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Decode ``target.size(1) - 1`` steps, starting from ``target[:, 0]``.

        At every step the next decoder input is the ground-truth token with
        probability *teacher_forcing_ratio*, otherwise the decoder's own
        arg-max prediction.

        Returns:
            ``(outputs, attn_weights_all)`` of shapes
            ``(batch, target_len, vocab)`` and ``(batch, target_len, src_len)``.
            Index 0 along the time axis is never written and stays zero.
        """
        n_batch, n_steps = source.size(0), target.size(1)
        vocab_size = self.decoder.fc_out.out_features

        # Result buffers for logits and attention maps
        outputs = torch.zeros(n_batch, n_steps, vocab_size, device=self.device)
        attn_weights_all = torch.zeros(n_batch, n_steps, source.size(1), device=self.device)

        encoder_outputs, hidden, cell = self.encoder(source)

        # The first target column is fed to the decoder as its first input
        step_input = target[:, 0]

        for t in range(1, n_steps):
            logits, hidden, cell, step_attn = self.decoder(step_input, hidden, cell, encoder_outputs)
            outputs[:, t] = logits
            attn_weights_all[:, t] = step_attn

            # One coin flip per step decides the next input for the whole batch
            if random.random() < teacher_forcing_ratio:
                step_input = target[:, t]
            else:
                step_input = logits.argmax(1)

        return outputs, attn_weights_all


def strip_after_eos(seq, eos_idx):
    """Return the part of *seq* that precedes the first *eos_idx*.

    Tensors and numpy arrays are converted to plain Python lists first
    (arrays have no ``.index`` method, so they previously crashed whenever
    they contained the EOS token). A sequence without *eos_idx* is returned
    whole.

    Args:
        seq: 1-D ``torch.Tensor``, ``numpy.ndarray`` or list of token ids.
        eos_idx: Id of the end-of-sequence token.
    """
    if isinstance(seq, torch.Tensor):
        seq = seq.cpu().numpy().tolist()
    elif isinstance(seq, np.ndarray):
        seq = seq.tolist()

    # Trim the sequence at the first <eos> token, if there is one
    try:
        return seq[:seq.index(eos_idx)]
    except ValueError:
        return seq

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Return the fraction of predictions that match their target exactly.

    Both sequences are trimmed at the first *eos_idx* (when one is given)
    and stripped of *pad_idx* before being compared.

    Args:
        preds: Sequence of predicted token-id sequences.
        targets: Sequence of reference token-id sequences.
        pad_idx: Padding id removed from both sides.
        eos_idx: End-of-sequence id, or ``None`` to skip trimming.

    Returns:
        Matches divided by ``len(preds)`` (``0.0`` when *preds* is empty).
    """
    correct = 0
    for pred, target in zip(preds, targets):
        # ``is not None`` so that an EOS id of 0 is still honoured
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        # Count only full-word matches
        correct += int(pred == target)
    return correct / max(len(preds), 1)


def calculate_cer(preds, targets, pad_idx=0, eos_idx=None):
    """Return the character error rate of *preds* against *targets*.

    The rate is the summed edit distance divided by the summed target
    length, where each target counts as at least one character. Sequences
    are trimmed at the first *eos_idx* (when one is given) and stripped of
    *pad_idx* first.

    Returns:
        The error rate, or ``float('inf')`` when there is nothing to score.
    """
    cer = 0
    total = 0
    for pred, target in zip(preds, targets):
        # ``is not None`` so that an EOS id of 0 is still honoured
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        # Accumulate edit distance and total characters
        cer += editdistance.eval(pred, target)
        total += max(len(target), 1)
    return cer / total if total > 0 else float('inf')


def calculate_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Return token-level accuracy measured over the non-padding target tokens.

    Every target token other than *pad_idx* counts towards the total. A
    position that the prediction does not reach is scored as wrong; the
    original skipped such positions entirely, so truncated predictions
    looked perfect. Predicted tokens beyond the target length are ignored.

    Args:
        preds: Sequence of predicted token-id sequences (lists or tensors).
        targets: Sequence of reference token-id sequences (lists or tensors).
        pad_idx: Padding id; target positions holding it are skipped.
        eos_idx: End-of-sequence id, or ``None`` to skip trimming.

    Returns:
        ``correct / total``, or ``0.0`` when no target token was scored.
    """
    correct = 0
    total = 0
    for pred, target in zip(preds, targets):
        # Convert tensors to lists if necessary
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        # Trim at <eos> if specified
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = list(pred)  # guarantee len() and indexing below
        for pos, t_token in enumerate(target):
            if t_token == pad_idx:
                continue
            total += 1
            if pos < len(pred) and pred[pos] == t_token:
                correct += 1
    return correct / total if total > 0 else 0.0

cpu


In [None]:
!pip install dash
!pip install plotly

In [None]:
# Authenticate with Weights & Biases; the later cells call wandb.init().
wandb.login()

In [None]:
import numpy as np
import wandb
import os

# Open a short-lived W&B run whose only purpose is to fetch the trained model:
# the run resolves the pinned artifact version, downloads it to a local
# directory (artifact_dir) and is closed again straight away.
run = wandb.init(project="attention-viz-2", entity="sai-sakunthala-indian-institute-of-technology-madras", name="evaluate_test")
artifact = run.use_artifact('sai-sakunthala-indian-institute-of-technology-madras/dakshina-seq2seq-3/best_model:v48', type='model')
artifact_dir = artifact.download()
run.finish()

# Read data and create vocabularies.
# Both vocabularies are built from the TRAINING pairs only, so their indices
# depend on the order in which characters first appear in the training file.
# NOTE(review): encode_word indexes the vocabulary directly, so a test word
# containing a character absent from the training data would raise KeyError
# when the test dataset is built — confirm this cannot happen for this data.
test_pairs = read_data(data_path + f"{LANG}.translit.sampled.test.tsv", max_len=30)
train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=30)
source_vocab, idx2char_src = make_vocab([x[0] for x in train_pairs])
target_vocab, idx2char_tgt = make_vocab([x[1] for x in train_pairs])

# Model parameters (must match training): load_state_dict below only succeeds
# when the layer shapes built from these values equal those in the checkpoint.
input_dimensions = len(source_vocab)
output_dimensions = len(target_vocab)
emb_dimensions = 256
hid_dimensions = 256 * 2
num_layers = 2
dropout = 0.2
cell = 'lstm'
batch_size = 64
max_len = 30

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model
encoder = translit_Encoder(input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell).to(device)
decoder = translit_Decoder(output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell).to(device)
model = translit_Seq2Seq(encoder, decoder, device).to(device)

# Load model weights from the downloaded artifact directory.
# NOTE(review): torch.load unpickles the file; if the installed torch version
# supports it, weights_only=True would be the safer way to read a plain
# state dict — confirm before changing.
state_dict = torch.load(f"{artifact_dir}/best_model.pt", map_location=device)
model.load_state_dict(state_dict)
# eval() switches the dropout layers off for inference
model.eval()

# Create test dataset and loader.
# drop_last must be False for evaluation: with drop_last=True the final
# partial batch (up to batch_size - 1 test words) was silently discarded, so
# the reported accuracy covered only part of the test set. The decoding code
# reads the batch size from each batch, so a smaller last batch is fine.
test_translit = TransliterationDataset(test_pairs, source_vocab, target_vocab)
test_loader = DataLoader(test_translit, batch_size=batch_size, shuffle=False, drop_last=False)

# Accumulators filled by the evaluation loop
all_src, all_preds, all_tgts, all_attn_weights = [], [], [], []
correct = 0
total = 0
selected_examples = []

def predict(model, src, max_len=30):
    """Greedily decode a batch of source sequences.

    Decoding starts from ``<sos>`` and feeds the arg-max token back in at
    every step. It stops after *max_len* steps, or earlier once every
    sequence in the batch emitted ``<eos>`` in the same step. Reads the
    file-level ``target_vocab`` and ``device``.

    Returns:
        ``(tokens, attn)``: predicted ids of shape ``(batch, steps)`` and
        attention weights of shape ``(batch, steps, src_len)``.
    """
    encoder_outputs, hidden_state, cell_state = model.encoder(src)

    # One <sos> id per batch row as the first decoder input
    step_tokens = torch.tensor([target_vocab['<sos>']] * src.size(0)).to(device)

    token_steps, attn_steps = [], []
    for _ in range(max_len):
        logits, hidden_state, cell_state, step_attn = model.decoder(
            step_tokens, hidden_state, cell_state, encoder_outputs
        )
        step_tokens = logits.argmax(1)
        token_steps.append(step_tokens)
        attn_steps.append(step_attn)

        # Stop early only when the whole batch produced <eos> at this step
        batch_finished = (step_tokens == target_vocab.get('<eos>', -1)).all()
        if batch_finished:
            break

    tokens = torch.stack(token_steps, dim=1)  # (batch_size, steps)
    attn = torch.stack(attn_steps, dim=1)     # (batch_size, steps, src_len)
    return tokens, attn

def _content_tokens(sequence, eos_id, ignored_ids):
    """Return the tokens of *sequence* before the first EOS, minus ignored ids."""
    kept = []
    for tok in sequence:
        if tok == eos_id:
            break
        if tok not in ignored_ids:
            kept.append(tok)
    return kept


# Special-token ids of the target side (-1 is the fallback for a missing entry)
tgt_eos_id = target_vocab.get('<eos>', -1)
tgt_ignored_ids = [target_vocab.get('<pad>', -1), target_vocab.get('<sos>', -1)]

with torch.no_grad():
    for src, tgt in tqdm(test_loader):
        src, tgt = src.to(device), tgt.to(device)
        preds, attn_weights_batch = predict(model, src)

        # Walk the batch row by row on the CPU side
        batch_rows = zip(
            src.cpu().numpy(),
            preds.cpu().numpy(),
            tgt.cpu().numpy(),
            attn_weights_batch.cpu().numpy(),
        )
        for s, p, t, attn in batch_rows:
            # Keep the raw sequences and attention maps of every example
            all_src.append(s)
            all_preds.append(p)
            all_tgts.append(t)
            all_attn_weights.append(attn)

            # The first 12 examples are set aside for heatmaps
            if len(selected_examples) < 12:
                selected_examples.append((s, p, t, attn))

            # A word is correct when prediction and target agree after
            # dropping padding, <sos> and everything from <eos> onwards
            p_processed = _content_tokens(p, tgt_eos_id, tgt_ignored_ids)
            t_processed = _content_tokens(t, tgt_eos_id, tgt_ignored_ids)
            if p_processed == t_processed:
                correct += 1
            total += 1

# Word-level accuracy over everything that was evaluated
accuracy = correct / total if total > 0 else 0
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Correct: {correct}, Total: {total}")

In [None]:
# Correctly predicted test samples, stored as (source, prediction, target, attention)
correctly_predicted_samples = []

# Collection stops as soon as this many correct predictions were found
wanted_samples = 10


def _before_eos(sequence):
    """Return *sequence* up to its first <eos>, without <pad>/<sos> tokens."""
    eos_id = target_vocab.get('<eos>', -1)
    ignored = [target_vocab.get('<pad>', -1), target_vocab.get('<sos>', -1)]
    kept = []
    for tok in sequence:
        if tok == eos_id:
            break
        if tok not in ignored:
            kept.append(tok)
    return kept


with torch.no_grad():
    for src, tgt in tqdm(test_loader):
        src, tgt = src.to(device), tgt.to(device)
        preds, attn_weights_batch = predict(model, src)

        batch_rows = zip(
            src.cpu().numpy(),
            preds.cpu().numpy(),
            tgt.cpu().numpy(),
            attn_weights_batch.cpu().numpy(),
        )
        for s, p, t, attn in batch_rows:
            # Only exact matches (after cleaning) are kept
            if _before_eos(p) != _before_eos(t):
                continue
            correctly_predicted_samples.append((s, p, t, attn))

            # Leave the batch once enough samples were collected
            if len(correctly_predicted_samples) >= wanted_samples:
                break

        # ... and stop reading further batches as well
        if len(correctly_predicted_samples) >= wanted_samples:
            break

  0%|          | 0/89 [00:03<?, ?it/s]


In [None]:
# Turn the collected samples into readable strings plus, for every decoding
# step, the source character that received the highest attention weight.
visualization_data = []
selected_examples = correctly_predicted_samples

# Ids of the special tokens that are hidden from the displayed words
src_specials = [source_vocab.get(tok, -1) for tok in ('<pad>', '<sos>', '<eos>')]
tgt_specials = [target_vocab.get(tok, -1) for tok in ('<pad>', '<sos>', '<eos>')]

for s, p, t, attn in selected_examples:
    # Convert indices to characters
    src_word = ''.join(idx2char_src.get(idx, '?') for idx in s if idx not in src_specials)
    pred_word = ''.join(idx2char_tgt.get(idx, '?') for idx in p if idx not in tgt_specials)
    ref_word = ''.join(idx2char_tgt.get(idx, '?') for idx in t if idx not in tgt_specials)

    # Each attention row is a distribution over source POSITIONS, so its
    # argmax has to be mapped through the source sequence `s` to obtain a
    # token id before the vocabulary lookup. (Looking the position itself up
    # in idx2char_src returned unrelated characters, e.g. '<sos>' for
    # position 1.)
    max_attention_chars = []
    for attn_weights in attn:
        max_position = int(attn_weights.argmax())
        max_attention_chars.append(idx2char_src.get(int(s[max_position]), '?'))

    # Store the visualization data
    visualization_data.append({
        "src_word": src_word,
        "pred_word": pred_word,
        "ref_word": ref_word,
        "max_attention_chars": max_attention_chars
    })

# Print one summary line per example
for data in visualization_data:
    print(f"Source: {data['src_word']}, Predicted: {data['pred_word']}, Reference: {data['ref_word']}, Max Attention: {''.join(data['max_attention_chars'])}")

In [None]:
import wandb
from PIL import Image, ImageDraw, ImageFont
import json
import base64
import os

# Initialize W&B: open the run that will receive the interactive HTML panel
run = wandb.init(project="attention-viz-2", entity="sai-sakunthala-indian-institute-of-technology-madras", name="interactive_image_visualization")

# Hard-coded examples. This assignment replaces whatever `visualization_data`
# an earlier cell computed, and here "max_attention_chars" is a single string
# (one character per predicted character) rather than a list.
# NOTE(review): these strings look like a pasted copy of earlier notebook
# output — confirm they are still the values that should be visualised.
visualization_data = [
    {"src_word":"amkamlo","pred_word":"అంకంలో","ref_word":"అంకంలో","max_attention_chars":"&lt;sos&gt;amitthhh"},
    {"src_word":"ankamlo","pred_word":"అంకంలో","ref_word":"అంకాలో","max_attention_chars":"&lt;sos&gt;amitthhhh"},
    {"src_word":"ankamloo","pred_word":"అంకంలో","ref_word":"అంకంలో","max_attention_chars":"&lt;sos&gt;amitnbbb"},
    {"src_word":"amkitamai","pred_word":"అంకితమై","ref_word":"అంకితమై","max_attention_chars":"&lt;sos&gt;amkinbccc"},
    {"src_word":"ankitamai","pred_word":"అంకితమై","ref_word":"అంకితమై","max_attention_chars":"&lt;sos&gt;aakknbccc"},
    {"src_word":"ankela","pred_word":"అంకెల","ref_word":"అంకెల","max_attention_chars":"&lt;sos&gt;aakittt"},
    {"src_word":"ankelanu","pred_word":"అంకెలను","ref_word":"అంకెలను","max_attention_chars":"&lt;sos&gt;aakithbb"},
    {"src_word":"angeekarinchaka","pred_word":"అంగీకరించక","ref_word":"అంగీకరించక","max_attention_chars":"aamktnvcduses"},
    {"src_word":"amgiikarimchaadu","pred_word":"అంగీకరించాడు","ref_word":"అంగీకరించాడు","max_attention_chars":"&lt;sos&gt;amktnvcduseyll"},
    {"src_word":"angeekarinchaadu","pred_word":"అంగీకరించాడు","ref_word":"అంగీకరించాడు","max_attention_chars":"aamktnvcduseyll"},
]

# Preprocess: drop the HTML-escaped "<sos>" marker from every attention string,
# then turn any remaining "&lt;" / "&gt;" entities back into "<" / ">".
# The order matters: the full "&lt;sos&gt;" marker is removed first.
for d in visualization_data:
    d['max_attention_chars'] = d['max_attention_chars'].replace('&lt;sos&gt;', '').replace('&lt;', '<').replace('&gt;', '>')

# Layout settings: the predicted words are laid out as one continuous stream
# that wraps after 25 characters per line
max_chars_per_line = 25
font_path = "/content/NotoSansTelugu-VariableFont.ttf"
font_size = 40
line_height = font_size + 10

# Flatten every predicted word into parallel streams: the character itself
# and the attention character recorded at the same position ('' when the
# attention string is shorter than the word)
stream_chars = []
stream_attention = []
for entry in visualization_data:
    word = entry['pred_word']
    hints = entry['max_attention_chars']
    for pos, glyph in enumerate(word):
        stream_chars.append(glyph)
        stream_attention.append(hints[pos] if pos < len(hints) else '')

# lines: list of (characters, attention characters) tuples, one per image line;
# the last line holds whatever is left over
lines = [
    (
        stream_chars[start:start + max_chars_per_line],
        stream_attention[start:start + max_chars_per_line],
    )
    for start in range(0, len(stream_chars), max_chars_per_line)
]

# Calculate image height: one line_height per wrapped line plus 20 px of
# vertical margin (10 px above, see y_offset, and 10 px below)
image_width = 1200
image_height = line_height * len(lines) + 20

# Create image and draw text line by line
font = ImageFont.truetype(font_path, font_size)
image = Image.new("RGB", (image_width, image_height), "white")
draw = ImageDraw.Draw(image)

# One hover region per drawn character; serialised into the HTML page later
char_positions = []
y_offset = 10

for line_idx, (chars, attns) in enumerate(lines):
    # Every line starts 10 px from the left edge
    current_x = 10
    y = y_offset + line_idx * line_height

    for idx, ch in enumerate(chars):
        att_char = attns[idx] if idx < len(attns) else ''
        # Measure the character at its drawing position to get its width/height
        # NOTE(review): textbbox reports the box of the rendered glyph, whose
        # top edge may lie below `y`; the region [y, y + char_height] stored
        # below may therefore not line up exactly with the visible glyph —
        # confirm against Pillow's ImageDraw.textbbox documentation.
        char_bbox = draw.textbbox((current_x, y), ch, font=font)
        char_width = char_bbox[2] - char_bbox[0]
        char_height = char_bbox[3] - char_bbox[1]

        # NOTE(review): each character of the string is drawn on its own, so
        # combining marks are positioned independently of their base
        # character — verify that the rendered Telugu text looks as intended.
        draw.text((current_x, y), ch, font=font, fill="black")

        char_positions.append({
            "char": ch,
            "x_min": current_x,
            "x_max": current_x + char_width,
            "y_min": y,
            "y_max": y + char_height,
            "max_attention": att_char
        })

        # Advance the pen by the measured width of this character
        current_x += char_width

# Save image and encode it as base64 so it can be inlined into the HTML page
image_path = "combined_telugu_sentence_multiline.png"
image.save(image_path)
with open(image_path, "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode()

# Generate HTML: a self-contained page that paints the rendered image onto a
# canvas and shows a tooltip with the max-attention character while the mouse
# hovers over one of the recorded character regions.
# Literal CSS/JS braces are doubled ({{ }}) because this is an f-string; the
# single-brace fields insert image_width, image_height, image_base64 and the
# JSON-encoded char_positions.
# The tooltip uses `position: fixed` because it is placed with
# clientX/clientY, which are viewport coordinates; with `position: absolute`
# it drifted away from the cursor whenever the page was scrolled.
html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <style>
        #canvas {{ position: relative; border: 1px solid #ccc; }}
        #tooltip {{
            position: fixed;
            background: rgba(0,0,0,0.8);
            color: white;
            padding: 5px 10px;
            border-radius: 4px;
            font-family: 'Noto Sans Telugu', Arial, sans-serif;
            font-size: 14px;
            pointer-events: none;
            display: none;
            z-index: 1000;
            white-space: nowrap;
        }}
    </style>
    <link href="https://fonts.googleapis.com/css2?family=Noto+Sans+Telugu&display=swap" rel="stylesheet">
</head>
<body>
    <canvas id="canvas" width="{image_width}" height="{image_height}"></canvas>
    <div id="tooltip"></div>
    <script>
        (function() {{
            const canvas = document.getElementById('canvas');
            const ctx = canvas.getContext('2d');
            const img = new Image();
            img.src = 'data:image/png;base64,{image_base64}';
            img.onload = function() {{
                ctx.drawImage(img, 0, 0);
            }};

            const regions = {json.dumps(char_positions)};
            const tooltip = document.getElementById('tooltip');

            function escapeHtml(text) {{
                if (!text) return '';
                return text.replace(/[&<>"']/g, function(m) {{
                    return {{'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'}}[m];
                }});
            }}

            canvas.addEventListener('mousemove', function(e) {{
                const rect = canvas.getBoundingClientRect();
                const x = e.clientX - rect.left;
                const y = e.clientY - rect.top;

                let found = false;
                for (let region of regions) {{
                    if (x >= region.x_min && x <= region.x_max &&
                        y >= region.y_min && y <= region.y_max) {{
                            tooltip.style.display = 'block';
                            tooltip.style.left = (e.clientX + 10) + 'px';
                            tooltip.style.top = (e.clientY + 10) + 'px';

                            const maxAtt = region.max_attention ? escapeHtml(region.max_attention) : '(none)';
                            tooltip.innerHTML = 'Telugu Character: ' + escapeHtml(region.char) + '<br>Max Attention Char: ' + maxAtt;
                            found = true;
                            break;
                    }}
                }}

                if (!found) {{
                    tooltip.style.display = 'none';
                }}
            }});

            canvas.addEventListener('mouseout', function() {{
                tooltip.style.display = 'none';
            }});
        }})();
    </script>
</body>
</html>
"""

# Log to W&B: attach the generated page to the current run as an HTML panel
wandb.log({
    "combined_telugu_attention_visualization_multiline": wandb.Html(html_content)
})

# Best-effort cleanup of the temporary PNG; the image is already embedded in
# the HTML as base64, so a failed delete is deliberately ignored.
try:
    os.remove(image_path)
except OSError:
    pass

# Close the run
run.finish()
