<a href="https://colab.research.google.com/github/D4deben/DA6401_Assignment3/blob/main/DL_Assignment4_PartB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Question 2

In [None]:
import torch
import torch.nn as nn
import random

class Encoder(nn.Module):
    """
    Seq2seq encoder: embeds a batch of character indices and runs them through a
    (optionally multi-layer, optionally bidirectional) recurrent network.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False):
        """
        Build the embedding table, the input dropout and the recurrent layer.

        ``cell_type`` is matched case-insensitively against RNN / GRU / LSTM;
        anything else raises ValueError.
        """
        super().__init__()

        normalized_cell = cell_type.upper()
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = normalized_cell
        self.bidirectional = bidirectional
        self.num_directions = 1 + int(bool(bidirectional))

        self.embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(normalized_cell)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # PyTorch only applies dropout between stacked layers, so a single
        # layer gets 0 to avoid the library's warning.
        between_layer_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = recurrent_cls(embedding_dim, hidden_dim, n_layers,
                                 dropout=between_layer_dropout, batch_first=True,
                                 bidirectional=self.bidirectional)

    def forward(self, input_seq):
        """
        Encode ``input_seq`` (batch-first token indices).

        Returns the recurrent layer's ``(outputs, final_hidden)`` pair unchanged;
        for an LSTM ``final_hidden`` is the ``(h_n, c_n)`` tuple.
        """
        return self.rnn(self.dropout(self.embedding(input_seq)))

class Decoder(nn.Module):
    """
    Seq2seq decoder that produces one target character per call.

    Each step embeds the previous character, advances a unidirectional
    recurrent network by one position and projects onto the output vocabulary.
    """
    def __init__(self, output_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, encoder_bidirectional=False):
        """
        Build the target embedding, dropout, recurrent layer and output projection.

        ``encoder_bidirectional`` is only stored on the instance; this class does
        not use it. An unknown ``cell_type`` raises ValueError.
        """
        super().__init__()

        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.encoder_bidirectional = encoder_bidirectional

        self.embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)

        cell_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in cell_classes:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout is meaningless (and warned about) for one layer.
        inter_layer_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = cell_classes[self.cell_type](embedding_dim, hidden_dim, n_layers,
                                                dropout=inter_layer_dropout, batch_first=True)

        self.fc_out = nn.Linear(hidden_dim, output_vocab_size)

    def forward(self, input_char, hidden_state):
        """
        Run a single decoding step.

        ``input_char`` is unsqueezed at dim 1 into a length-1 batch-first
        sequence (presumably a 1-D tensor of one index per batch row).
        Returns ``(logits, new_hidden_state)``; the time axis is squeezed out of
        the logits before returning.
        """
        step_input = self.dropout(self.embedding(input_char.unsqueeze(1)))
        step_output, next_hidden = self.rnn(step_input, hidden_state)
        logits = self.fc_out(step_output.squeeze(1))
        return logits, next_hidden

class Seq2Seq(nn.Module):
    """
    Encoder-decoder wrapper that seeds the decoder with the encoder's final state.

    For a bidirectional encoder the forward/backward halves of each layer are
    concatenated and linearly projected to the decoder's hidden size (hidden and,
    for LSTMs, cell state). For a unidirectional encoder the state is passed
    through unchanged, so encoder and decoder must then agree on layer count and
    hidden size.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        """
        Args:
            encoder: module exposing ``bidirectional``, ``hidden_dim``,
                ``num_directions``, ``n_layers`` and ``cell_type``.
            decoder: module exposing ``hidden_dim`` and ``output_vocab_size``.
            device: device for the output buffer and the start token in ``predict``.
            target_sos_idx: start-of-sequence index fed first in ``predict``.
        """
        super(Seq2Seq, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        if self.encoder.bidirectional:
            encoder_hidden_dim_actual = self.encoder.hidden_dim * self.encoder.num_directions
            decoder_hidden_dim_expected = self.decoder.hidden_dim

            self.fc_hidden = nn.Linear(encoder_hidden_dim_actual, decoder_hidden_dim_expected)
            if self.encoder.cell_type == 'LSTM':
                self.fc_cell = nn.Linear(encoder_hidden_dim_actual, decoder_hidden_dim_expected)

    def _merge_directions(self, state, projection):
        """
        Concatenate the forward/backward state of every layer and project it.

        ``state`` uses PyTorch's layer-major layout (n_layers * 2, batch, hidden);
        the result has shape (n_layers, batch, decoder hidden).
        """
        layered = state.view(self.encoder.n_layers, self.encoder.num_directions,
                             state.size(1), self.encoder.hidden_dim)
        return projection(torch.cat((layered[:, 0], layered[:, 1]), dim=2))

    def _adapt_encoder_hidden(self, encoder_hidden):
        """
        Turn the encoder's final state into the decoder's initial state.

        Unidirectional encoders: returned unchanged. Bidirectional encoders:
        per-layer direction halves are merged and projected (both h and c for LSTMs).
        """
        if not self.encoder.bidirectional:
            return encoder_hidden
        if self.encoder.cell_type == 'LSTM':
            h_n, c_n = encoder_hidden
            return (self._merge_directions(h_n, self.fc_hidden),
                    self._merge_directions(c_n, self.fc_cell))
        return self._merge_directions(encoder_hidden, self.fc_hidden)

    def forward(self, source_seq, target_seq, teacher_forcing_ratio=0.5):
        """
        Decode ``target_seq`` step by step with per-step teacher forcing.

        Returns logits of shape (batch, target_len, output_vocab_size). Position 0
        corresponds to the SOS input and is left as zeros.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        outputs = torch.zeros(batch_size, target_len, self.decoder.output_vocab_size,
                              device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq)
        decoder_hidden = self._adapt_encoder_hidden(encoder_final_hidden)

        decoder_input = target_seq[:, 0]
        for t in range(1, target_len):
            decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs[:, t] = decoder_output_logits
            # One coin flip per time step, shared by the whole batch.
            teacher_force_this_step = random.random() < teacher_forcing_ratio
            decoder_input = (target_seq[:, t] if teacher_force_this_step
                             else decoder_output_logits.argmax(1))

        return outputs

    def predict(self, source_seq, max_output_len=50, target_eos_idx=None):
        """
        Greedily decode a single source sequence (no teacher forcing).

        Args:
            source_seq: batch-first tensor with batch size 1.
            max_output_len: maximum number of tokens to generate.
            target_eos_idx: if given, decoding stops when this index is produced
                (the EOS index itself is not included in the result).

        Returns:
            List of predicted token indices.

        Raises:
            ValueError: if the batch size is not 1.

        The module's train/eval mode is restored afterwards, so calling this on a
        model in eval mode no longer flips it back to training mode.
        """
        if source_seq.shape[0] != 1:
            raise ValueError("Predict function currently supports batch_size=1 for simplicity.")

        was_training = self.training
        self.eval()
        predicted_indices = []
        try:
            with torch.no_grad():
                _, encoder_final_hidden = self.encoder(source_seq)
                decoder_hidden = self._adapt_encoder_hidden(encoder_final_hidden)

                decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
                for _ in range(max_output_len):
                    decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    decoder_input = decoder_output_logits.argmax(1)
                    predicted_idx = decoder_input.item()
                    if target_eos_idx is not None and predicted_idx == target_eos_idx:
                        break
                    predicted_indices.append(predicted_idx)
        finally:
            self.train(was_training)
        return predicted_indices


if __name__ == '__main__':
    # Smoke test for the Question 2 model: build a bidirectional LSTM seq2seq,
    # push random data through forward() and predict(), then build a smaller
    # unidirectional GRU variant and report parameter counts.
    INPUT_VOCAB_SIZE = 50
    OUTPUT_VOCAB_SIZE = 60
    TARGET_SOS_IDX = 0
    # NOTE(review): TARGET_EOS_IDX is defined but not used anywhere in this demo.
    TARGET_EOS_IDX = 1

    EMBEDDING_DIM = 128
    HIDDEN_DIM_ENCODER = 256
    HIDDEN_DIM_DECODER = 256
    N_LAYERS_ENCODER = 2
    N_LAYERS_DECODER = 2
    CELL_TYPE = 'LSTM'
    DROPOUT = 0.3
    ENCODER_BIDIRECTIONAL = True

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Weight initialisation draws from torch's global RNG, so construction order
    # determines the initial weights.
    encoder = Encoder(INPUT_VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM_ENCODER, N_LAYERS_ENCODER,
                      CELL_TYPE, DROPOUT, bidirectional=ENCODER_BIDIRECTIONAL).to(DEVICE)

    decoder = Decoder(OUTPUT_VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM_DECODER, N_LAYERS_DECODER,
                      CELL_TYPE, DROPOUT, encoder_bidirectional=ENCODER_BIDIRECTIONAL).to(DEVICE)

    model = Seq2Seq(encoder, decoder, DEVICE, TARGET_SOS_IDX).to(DEVICE)

    print(f"Model initialized on {DEVICE}")
    print(f"Encoder cell type: {encoder.cell_type}, Layers: {encoder.n_layers}, Hidden: {encoder.hidden_dim}, Bidirectional: {encoder.bidirectional}")
    print(f"Decoder cell type: {decoder.cell_type}, Layers: {decoder.n_layers}, Hidden: {decoder.hidden_dim}")

    BATCH_SIZE = 4
    SOURCE_SEQ_LEN = 10
    TARGET_SEQ_LEN = 12

    dummy_source_seq = torch.randint(0, INPUT_VOCAB_SIZE, (BATCH_SIZE, SOURCE_SEQ_LEN)).to(DEVICE)

    # Target tokens are drawn from [1, OUTPUT_VOCAB_SIZE) so the SOS index (0)
    # appears only in the first column, which is overwritten explicitly.
    dummy_target_seq = torch.randint(1, OUTPUT_VOCAB_SIZE, (BATCH_SIZE, TARGET_SEQ_LEN)).to(DEVICE)
    dummy_target_seq[:, 0] = TARGET_SOS_IDX

    print(f"\nDummy source shape: {dummy_source_seq.shape}")
    print(f"Dummy target shape: {dummy_target_seq.shape}")

    # Training-mode forward pass with 50% per-step teacher forcing.
    model.train()
    output_logits = model(dummy_source_seq, dummy_target_seq, teacher_forcing_ratio=0.5)
    print(f"\nOutput logits shape from forward pass: {output_logits.shape}")

    # Single-example greedy decoding; predict() itself switches to eval mode
    # and disables gradient tracking.
    model.eval()
    dummy_single_source_seq = torch.randint(0, INPUT_VOCAB_SIZE, (1, SOURCE_SEQ_LEN)).to(DEVICE)
    predicted_sequence_indices = model.predict(dummy_single_source_seq, max_output_len=TARGET_SEQ_LEN)
    print(f"\nPredicted sequence indices for a single source: {predicted_sequence_indices}")
    print(f"Length of predicted sequence: {len(predicted_sequence_indices)}")

    def count_parameters(m):
        # Number of scalar weights that receive gradients.
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    print(f'\nThe model has {count_parameters(model):,} trainable parameters')

    print("\n--- Example of changing parameters ---")
    # Unidirectional variant: encoder and decoder hidden sizes and layer counts
    # are kept equal so the encoder state can be handed to the decoder as-is.
    encoder_gru_unidir = Encoder(INPUT_VOCAB_SIZE, embedding_dim=64, hidden_dim=128, n_layers=1,
                                 cell_type='GRU', dropout_p=0.1, bidirectional=False).to(DEVICE)
    decoder_gru_unidir = Decoder(OUTPUT_VOCAB_SIZE, embedding_dim=64, hidden_dim=128, n_layers=1,
                                 cell_type='GRU', dropout_p=0.1, encoder_bidirectional=False).to(DEVICE)
    model_gru_unidir = Seq2Seq(encoder_gru_unidir, decoder_gru_unidir, DEVICE, TARGET_SOS_IDX).to(DEVICE)
    print(f"GRU Unidirectional Model initialized on {DEVICE}")
    print(f"Encoder cell type: {encoder_gru_unidir.cell_type}, Layers: {encoder_gru_unidir.n_layers}, Hidden: {encoder_gru_unidir.hidden_dim}, Bidirectional: {encoder_gru_unidir.bidirectional}")
    print(f"Decoder cell type: {decoder_gru_unidir.cell_type}, Layers: {decoder_gru_unidir.n_layers}, Hidden: {decoder_gru_unidir.hidden_dim}")
    print(f'The GRU Unidir model has {count_parameters(model_gru_unidir):,} trainable parameters')


# Question 3 - Vanilla Seq2Seq Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq

# --- Constants ---
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"

# --- Model Definition (Encoder, Decoder, Seq2Seq) ---
class Encoder(nn.Module):
    """
    Padding-aware recurrent encoder for the vanilla seq2seq model.

    Padded batches are packed before the RNN so padding positions do not affect
    the final state. Only that final state is returned, because the decoder
    is initialised from it.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super().__init__()
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 1 + int(bool(bidirectional))

        # padding_idx keeps the pad embedding fixed at zero and out of gradient updates.
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        stacked_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = recurrent_cls(embedding_dim, hidden_dim, n_layers,
                                 dropout=stacked_dropout, batch_first=True,
                                 bidirectional=self.bidirectional)

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded, batch-first index tensor.

        ``input_lengths`` holds the true length of each row. It is moved to the
        CPU because ``pack_padded_sequence`` expects CPU lengths. Rows need not
        be sorted by length.

        Returns ``(None, final_hidden)``. The per-step outputs are discarded, and
        ``final_hidden`` is a ``(h_n, c_n)`` tuple for LSTMs.
        """
        features = self.dropout(self.embedding(input_seq))
        packed = nn.utils.rnn.pack_padded_sequence(features, input_lengths.cpu(),
                                                   batch_first=True, enforce_sorted=False)
        _, final_state = self.rnn(packed)
        return None, final_state

class Decoder(nn.Module):
    """
    Single-step recurrent decoder for the vanilla (attention-free) seq2seq model.

    Each call embeds one token per batch row, advances the RNN by one position
    and projects the output onto the target vocabulary.
    """
    def __init__(self, output_vocab_size, embedding_dim,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super().__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        cell_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in cell_classes:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # The RNN consumes only the embedding (no attention context in this model).
        stacked_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = cell_classes[self.cell_type](embedding_dim, decoder_hidden_dim, n_layers,
                                                dropout=stacked_dropout, batch_first=True)

        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden):
        """
        Advance the decoder by one token.

        ``input_char`` is unsqueezed at dim 1 into a length-1 batch-first
        sequence (presumably one index per batch row).
        Returns ``(prediction_logits, current_decoder_hidden)``.
        """
        step_input = self.dropout(self.embedding(input_char.unsqueeze(1)))
        step_output, next_hidden = self.rnn(step_input, prev_decoder_hidden)
        return self.fc_out(step_output.squeeze(1)), next_hidden

class Seq2Seq(nn.Module):
    """
    Attention-free encoder-decoder model.

    The encoder's final state (hidden, plus cell state for LSTMs) is adapted
    for the decoder in three steps:
    1. Reshape it to one tensor per layer, concatenating bidirectional halves.
    2. If widths differ, linearly project it to the decoder's hidden size.
    3. Truncate or pad it to the decoder's layer count.

    Whether a tuple or a single tensor is produced follows the encoder's
    ``cell_type``, so encoder and decoder are expected to use the same cell.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        encoder_effective_output_dim_per_layer = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_expected_hidden_dim = self.decoder.decoder_hidden_dim

        # Projections exist only when the per-layer encoder width differs from
        # the decoder's hidden size.
        self.needs_dim_adaptation = encoder_effective_output_dim_per_layer != decoder_rnn_expected_hidden_dim
        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None

        if self.needs_dim_adaptation:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)

    def _merge_directions(self, state):
        """Reshape (n_layers * num_directions, batch, hidden) to (n_layers, batch, hidden * num_directions)."""
        per_direction = state.view(self.encoder.n_layers, self.encoder.num_directions,
                                   state.size(1), self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            return torch.cat([per_direction[:, 0], per_direction[:, 1]], dim=2)
        return per_direction.squeeze(1)

    def _match_layers(self, state):
        """Keep the top decoder.n_layers layers, or repeat the top layer to pad up to that count."""
        enc_layers, dec_layers = self.encoder.n_layers, self.decoder.n_layers
        if enc_layers == dec_layers:
            return state
        if enc_layers > dec_layers:
            return state[-dec_layers:]
        repeated_top = state[-1:].expand(dec_layers - enc_layers, -1, -1)
        return torch.cat([state, repeated_top], dim=0)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Convert the encoder's final state into the decoder's initial state.

        Returns a ``(h, c)`` tuple when the encoder is an LSTM, otherwise a single
        tensor. Each tensor has shape (decoder.n_layers, batch, decoder_hidden_dim).
        """
        is_lstm = self.encoder.cell_type == 'LSTM'
        if is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        h_processed = self._merge_directions(h_from_enc)
        c_processed = self._merge_directions(c_from_enc) if c_from_enc is not None else None

        if self.needs_dim_adaptation:
            h_processed = self.fc_adapt_hidden(h_processed)
            if c_processed is not None and self.fc_adapt_cell is not None:
                c_processed = self.fc_adapt_cell(c_processed)

        final_h = self._match_layers(h_processed)
        if not is_lstm:
            return final_h
        # An LSTM decoder needs a cell state; fall back to zeros if none was supplied.
        final_c = self._match_layers(c_processed) if c_processed is not None else torch.zeros_like(final_h)
        return (final_h, final_c)

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Teacher-forced decoding over ``target_seq``.

        Returns logits of shape (batch, target_len, output_vocab_size). Position 0
        corresponds to the SOS input and is left as zeros.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        outputs_logits = torch.zeros(batch_size, target_len, self.decoder.output_vocab_size,
                                     device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        decoder_input = target_seq[:, 0]
        for t in range(1, target_len):
            decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs_logits[:, t] = decoder_output_logits
            # One coin flip per time step, shared by the whole batch.
            teacher_force_this_step = random.random() < teacher_forcing_ratio
            decoder_input = (target_seq[:, t] if teacher_force_this_step
                             else decoder_output_logits.argmax(1))

        return outputs_logits

    def predict_greedy(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Greedy decoding for one source sequence.

        A 1-D ``source_seq`` is treated as a single example. The caller is
        responsible for putting the model in eval mode.

        Returns the predicted indices, excluding the EOS token.

        Raises:
            ValueError: if the batch size is not 1.
        """
        if source_seq.dim() == 1:
            source_seq = source_seq.unsqueeze(0)
            source_lengths = torch.tensor([source_lengths if isinstance(source_lengths, int) else len(source_lengths)], device=self.device)
        if source_seq.shape[0] != 1:
            raise ValueError("Greedy predict function currently supports batch_size=1.")

        predicted_indices = []
        with torch.no_grad():
            _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

            decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
            for _ in range(max_output_len):
                decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                decoder_input = decoder_output_logits.argmax(1)
                predicted_idx = decoder_input.item()
                if target_eos_idx is not None and predicted_idx == target_eos_idx:
                    break
                predicted_indices.append(predicted_idx)
        return predicted_indices

    def predict_beam_search(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None, beam_width=3):
        """
        Beam-search decoding for one source sequence.

        Finished and unfinished hypotheses are ranked by log-probability divided
        by their length (the leading SOS is counted). Every hypothesis still on
        the beam when decoding stops is considered, including ones that emitted
        EOS on the final step.

        Returns the best token list without the leading SOS. A trailing EOS is
        kept if the hypothesis produced one.

        Raises:
            ValueError: if the batch size is not 1.
        """
        if source_seq.dim() == 1:
            source_seq = source_seq.unsqueeze(0)
            source_lengths = torch.tensor([source_lengths if isinstance(source_lengths, int) else len(source_lengths)], device=self.device)

        batch_size = source_seq.shape[0]
        if batch_size != 1:
            raise ValueError("Beam search predict function currently supports batch_size=1.")

        def length_normalised(log_prob, seq):
            return log_prob / len(seq) if len(seq) > 0 else -float('inf')

        with torch.no_grad():
            _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden_init = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

            # Each beam: (cumulative log-prob, tokens starting with SOS, decoder state).
            beams = [(0.0, [self.target_sos_idx], decoder_hidden_init)]
            completed_sequences = []

            for _ in range(max_output_len):
                if len(completed_sequences) >= beam_width and all(b[1][-1] == target_eos_idx for b in beams if b[1]):
                    break

                new_beams = []
                for log_prob_beam, seq_beam, hidden_beam in beams:
                    if not seq_beam or seq_beam[-1] == target_eos_idx:
                        completed_sequences.append((length_normalised(log_prob_beam, seq_beam), seq_beam))
                        continue

                    decoder_input = torch.tensor([seq_beam[-1]], device=self.device)
                    decoder_output_logits, next_hidden_beam = self.decoder(decoder_input, hidden_beam)

                    log_probs_next_token = F.log_softmax(decoder_output_logits, dim=1)
                    # topk fails if k exceeds the vocabulary size.
                    k = min(beam_width, log_probs_next_token.size(1))
                    topk_log_probs, topk_indices = torch.topk(log_probs_next_token, k, dim=1)
                    for token_log_prob, next_token_idx in zip(topk_log_probs[0].tolist(), topk_indices[0].tolist()):
                        new_beams.append((log_prob_beam + token_log_prob, seq_beam + [next_token_idx], next_hidden_beam))

                if not new_beams:
                    # Every beam was finished and has just been recorded above.
                    beams = []
                    break

                new_beams.sort(key=lambda x: x[0], reverse=True)
                beams = new_beams[:beam_width]

            # Whatever is still on the beam has not been recorded yet, whether or
            # not it ends in EOS (e.g. EOS produced on the last allowed step).
            for log_prob_beam, seq_beam, _ in beams:
                completed_sequences.append((length_normalised(log_prob_beam, seq_beam), seq_beam))

            if not completed_sequences:
                return [target_eos_idx] if target_eos_idx is not None else []

            completed_sequences.sort(key=lambda x: x[0], reverse=True)

            best_sequence_indices = completed_sequences[0][1]
            return best_sequence_indices[1:] if best_sequence_indices and best_sequence_indices[0] == self.target_sos_idx else best_sequence_indices

# --- Data Loading and Preprocessing ---
class Vocabulary:
    """
    Character-level vocabulary mapping characters to integer indices and back.

    Indices 0-3 are reserved for the PAD, SOS, EOS and UNK tokens. Real
    characters are appended by ``build_vocab`` in order of decreasing frequency,
    with ties broken by the character itself.
    """
    def __init__(self, name):
        self.name = name
        reserved = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
        self.char2index = {token: position for position, token in enumerate(reserved)}
        self.index2char = dict(enumerate(reserved))
        self.char_counts = Counter()
        self.n_chars = len(reserved)
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Add one occurrence of every character in ``sequence`` to the frequency counts."""
        # Passing a list (never a mapping) makes Counter.update count elements.
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """Assign indices to counted characters seen at least ``min_freq`` times."""
        ranked = sorted(self.char_counts, key=lambda ch: (-self.char_counts[ch], ch))
        for ch in ranked:
            if self.char_counts[ch] < min_freq or ch in self.char2index:
                continue
            self.char2index[ch] = self.n_chars
            self.index2char[self.n_chars] = ch
            self.n_chars += 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False):
        """Encode ``sequence``; unknown characters map to the UNK index."""
        unknown = self.char2index[UNK_TOKEN]
        body = [self.char2index.get(ch, unknown) for ch in sequence]
        prefix = [self.sos_idx] if add_sos else []
        suffix = [self.eos_idx] if add_eos else []
        return prefix + body + suffix

    def indices_to_sequence(self, indices):
        """
        Decode indices into a string.

        Decoding stops at the first EOS, and SOS/PAD indices are skipped. An
        index with no entry decodes to the literal UNK token text.
        """
        decoded = []
        for idx in indices:
            if idx == self.eos_idx:
                break
            if idx != self.sos_idx and idx != self.pad_idx:
                decoded.append(self.index2char.get(idx, UNK_TOKEN))
        return "".join(decoded)

class TransliterationDataset(Dataset):
    """
    Dataset of (source, target) string pairs read from a tab-separated file.

    Each line supplies the target word in column 0 and the source word in
    column 1; further columns are ignored. Lines with fewer than two columns or
    an empty word are skipped. Either word longer than ``max_len`` also causes
    a skip (when ``max_len`` is truthy). A missing or unreadable file yields an
    empty dataset and an error message rather than an exception.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=None):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found during Dataset init: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Remove only the line terminator: str.strip() would also eat a
                    # leading tab and shift every column one position to the left.
                    parts = line.rstrip('\r\n').split('\t')
                    if len(parts) < 2:
                        continue
                    target_sequence, source_sequence = parts[0].strip(), parts[1].strip()

                    if not source_sequence or not target_sequence:
                        continue
                    if max_len and (len(source_sequence) > max_len or len(target_sequence) > max_len):
                        continue
                    self.pairs.append((source_sequence, target_sequence))
            print(f"Loaded {len(self.pairs)} pairs from {file_path}.")
        except (OSError, UnicodeDecodeError) as e:
            print(f"ERROR: Could not read or process file {file_path}. Error: {e}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """
        Return ``(source_indices, target_indices)`` as long tensors.

        The source gets an EOS appended. The target gets both SOS and EOS.
        """
        if idx >= len(self.pairs):
            raise IndexError("Index out of bounds for dataset")
        source_str, target_str = self.pairs[idx]
        source_indices = self.source_vocab.sequence_to_indices(source_str, add_eos=True)
        target_indices = self.target_vocab.sequence_to_indices(target_str, add_sos=True, add_eos=True)
        return torch.tensor(source_indices, dtype=torch.long), \
               torch.tensor(target_indices, dtype=torch.long)

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Pad a list of ``(source, target)`` tensor pairs into batch-first tensors.

    Missing items and pairs whose source is empty are dropped. Returns
    ``(padded_sources, source_lengths, padded_targets)``, or
    ``(None, None, None)`` when nothing usable is left.
    """
    def _complete(pair):
        return pair is not None and pair[0] is not None and pair[1] is not None

    usable = list(filter(_complete, batch))
    if len(usable) == 0:
        return (None, None, None)

    src_seqs, tgt_seqs = zip(*usable)
    kept_pairs = [(src, tgt) for src, tgt in zip(src_seqs, tgt_seqs) if len(src) > 0]
    if len(kept_pairs) == 0:
        return (None, None, None)

    src_list = [src for src, _ in kept_pairs]
    tgt_list = [tgt for _, tgt in kept_pairs]
    # True (unpadded) source lengths, needed later for sequence packing.
    lengths = torch.tensor(list(map(len, src_list)), dtype=torch.long)
    return (
        pad_sequence(src_list, batch_first=True, padding_value=pad_idx_source),
        lengths,
        pad_sequence(tgt_list, batch_first=True, padding_value=pad_idx_target),
    )

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Run one training epoch and return ``(average_loss, exact_match_accuracy)``.

    Args:
        model: called as ``model(sources, source_lengths, targets, teacher_forcing_ratio)``
            and expected to return logits of shape (batch, target_len, vocab).
        dataloader: yields ``(padded_sources, source_lengths, padded_targets)``;
            batches whose first element is None are skipped.
        optimizer: optimiser stepped once per usable batch.
        criterion: token-level loss applied to the flattened logits and targets.
        device: device the batch tensors are moved to.
        clip_value: maximum gradient norm passed to ``clip_grad_norm_``.
        teacher_forcing_ratio: forwarded to the model.
        target_vocab: vocabulary whose ``indices_to_sequence`` decodes rows to strings.

    The loss is averaged over the batches that were actually optimised, so
    skipped batches (invalid, or NaN/Inf loss) are excluded from the average.
    Accuracy is the fraction of rows whose decoded prediction equals the decoded
    target. The predictions come from the same pass used for the loss, so with
    ``teacher_forcing_ratio > 0`` the accuracy is measured under teacher forcing.
    """
    model.train()
    epoch_loss = 0.0
    optimised_batches = 0
    total_correct_train = 0
    total_samples_train = 0

    if len(dataloader) == 0:
        print("Warning: Training dataloader is empty.")
        return 0.0, 0.0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        if batch_data[0] is None:
            continue
        sources, source_lengths, targets = batch_data

        if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
            print("Warning: Empty or invalid batch data after collate_fn in training. Skipping.")
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

        optimizer.zero_grad()
        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 is the SOS input, which is never predicted, so it is excluded.
        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)

        loss = criterion(flat_outputs, flat_targets)
        if torch.isnan(loss) or torch.isinf(loss):
            print("Warning: NaN or Inf loss detected in training. Skipping batch.")
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        epoch_loss += loss.item()
        optimised_batches += 1

        # Convert the whole batch once instead of one host transfer per row.
        predicted_rows = outputs_logits[:, 1:].argmax(dim=2).tolist()
        target_rows = targets[:, 1:].tolist()
        for predicted_row, target_row in zip(predicted_rows, target_rows):
            if target_vocab.indices_to_sequence(predicted_row) == target_vocab.indices_to_sequence(target_row):
                total_correct_train += 1
        total_samples_train += len(target_rows)

    avg_epoch_loss = epoch_loss / optimised_batches if optimised_batches > 0 else 0.0
    train_accuracy = total_correct_train / total_samples_train if total_samples_train > 0 else 0.0
    return avg_epoch_loss, train_accuracy

def evaluate(model, dataloader, criterion, device, target_vocab, beam_width=1, target_eos_idx=None):
    """
    Evaluate a seq2seq model: loss without teacher forcing plus exact-match accuracy.

    The loss comes from one batched ``model(...)`` call with
    ``teacher_forcing_ratio=0.0``. Accuracy decodes every example on its own
    (beam search when ``beam_width > 1`` and the model has
    ``predict_beam_search``, otherwise ``predict_greedy`` if available,
    otherwise the argmax of the loss-pass logits) and compares the decoded
    string with the reference string. Both strings are produced by
    ``target_vocab.indices_to_sequence``.

    Args:
        model: Seq2Seq-style model; it is switched to eval mode here.
        dataloader: Yields ``(sources, source_lengths, targets)`` batches.
        criterion: Loss applied to the flattened logits/targets, excluding position 0.
        device: Device the batch tensors are moved to.
        target_vocab: Vocabulary used to turn index lists into strings.
        beam_width (int): Beam size; values <= 1 use greedy decoding.
        target_eos_idx (int, optional): EOS index forwarded to the decoders.

    Returns:
        tuple(float, float): ``(average loss, accuracy)``. The loss is averaged
        only over batches that produced a finite loss; skipped batches and
        NaN/Inf batches are not counted. Both values are 0.0 for an empty
        dataloader.
    """
    model.eval()
    epoch_loss = 0.0
    loss_batches = 0  # batches that actually contributed to epoch_loss
    total_correct = 0
    total_samples = 0

    if len(dataloader) == 0:
        print("WARNING: Validation dataloader is empty. Returning 0 loss and 0 accuracy.")
        return 0.0, 0.0

    # The decoding strategy depends only on the model, so decide it once.
    use_beam = beam_width > 1 and hasattr(model, 'predict_beam_search')
    use_greedy = hasattr(model, 'predict_greedy')

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
            if batch_data[0] is None:
                continue
            sources, source_lengths, targets = batch_data

            if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
                print("Warning: Empty or invalid batch data in eval after collate_fn. Skipping.")
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

            outputs_for_loss = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs_for_loss.shape[-1]
            # Position 0 holds the <sos> input and is never predicted, so it is excluded.
            flat_outputs_for_loss = outputs_for_loss[:, 1:].reshape(-1, output_dim)
            flat_targets_for_loss = targets[:, 1:].reshape(-1)

            loss = criterion(flat_outputs_for_loss, flat_targets_for_loss)
            if torch.isnan(loss) or torch.isinf(loss):
                print("Warning: NaN or Inf loss detected in evaluation. Skipping batch for loss accumulation.")
            else:
                epoch_loss += loss.item()
                loss_batches += 1

            max_output_len = targets.size(1)
            for i in range(sources.shape[0]):
                src_single = sources[i:i+1]
                src_len_single = source_lengths[i:i+1]

                if use_beam:
                    predicted_indices = model.predict_beam_search(src_single, src_len_single,
                                                                  max_output_len=max_output_len,
                                                                  target_eos_idx=target_eos_idx,
                                                                  beam_width=beam_width)
                elif use_greedy:
                    predicted_indices = model.predict_greedy(src_single, src_len_single,
                                                             max_output_len=max_output_len,
                                                             target_eos_idx=target_eos_idx)
                else:
                    predicted_indices = outputs_for_loss[i:i+1].argmax(dim=2)[0, 1:].tolist()

                pred_str = target_vocab.indices_to_sequence(predicted_indices)
                true_seq_indices = targets[i, 1:].tolist()
                true_str = target_vocab.indices_to_sequence(true_seq_indices)

                if pred_str == true_str:
                    total_correct += 1
                total_samples += 1

                # Show one worked example, from the first sample of the first batch only.
                if batch_idx == 0 and i == 0:
                    print(f"\n--- Evaluation Debug Sample {i} (Batch {batch_idx}) ---")
                    print(f"    Source (indices): {src_single[0, :15].tolist()}")
                    print(f"    Predicted Indices: {predicted_indices[:15]}")
                    print(f"    Predicted String: '{pred_str}'")
                    print(f"    True Indices: {true_seq_indices[:15]}")
                    print(f"    True String: '{true_str}'")
                    print(f"    Match: {pred_str == true_str}")

    avg_epoch_loss = epoch_loss / loss_batches if loss_batches > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    print(f"Evaluation - Total Correct: {total_correct}, Total Samples: {total_samples}, Calculated Accuracy: {accuracy:.4f}, Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss, accuracy

# --- Main Training Function for W&B ---
def train_model():
    """
    Run a single W&B sweep trial for the seq2seq transliteration model.

    Reads hyper-parameters from ``wandb.config`` and builds source/target
    vocabularies from the training split. It then trains with
    ReduceLROnPlateau (driven by dev accuracy) and optional early stopping,
    logging per-epoch metrics and the best dev accuracy to W&B.

    Returns:
        int: 0 after training finishes; 1 if the data directory or files are
        missing, or a dataset/dataloader turns out to be empty.
    """
    with wandb.init() as run:
        config = wandb.config

        run_name = f"{config.cell_type}_emb{config.embedding_dim}_hid{config.hidden_dim}" \
                   f"_encL{config.encoder_layers}_decL{config.decoder_layers}" \
                   f"_do{config.dropout_p:.2f}_lr{config.learning_rate:.2e}" \
                   f"_encBi{str(config.encoder_bidirectional)[0]}" \
                   f"_beam{config.get('beam_width_eval', 1)}"

        if hasattr(run, 'name'):
            run.name = run_name

        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        use_cuda = DEVICE.type == 'cuda'
        current_run_name_for_log = run.name if run.name else run.id
        print(f"Using device: {DEVICE}. Run: {current_run_name_for_log} (ID: {run.id})")
        print(f"Config: {config}")

        BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")

        if not os.path.exists(DATA_DIR):
            print(f"ERROR: Lexicons directory not found: {DATA_DIR}")
            return 1

        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        if not os.path.exists(train_file) or not os.path.exists(dev_file):
            print(f"ERROR: Train or Dev file not found. Train: {train_file}, Dev: {dev_file}")
            return 1

        # Vocabularies are built from the training split only.
        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")

        temp_train_dataset_for_vocab = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=config.max_seq_len)
        if not temp_train_dataset_for_vocab.pairs:
            print(f"ERROR: No data loaded for vocab building from {train_file}.")
            return 1
        for src_str, tgt_str in temp_train_dataset_for_vocab.pairs:
            source_vocab.add_sequence(src_str)
            target_vocab.add_sequence(tgt_str)
        source_vocab.build_vocab(min_freq=config.vocab_min_freq)
        target_vocab.build_vocab(min_freq=config.vocab_min_freq)

        print(f"Source Vocab: {source_vocab.n_chars} chars. Target Vocab: {target_vocab.n_chars} chars.")

        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=config.max_seq_len)
        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=config.max_seq_len)

        if len(train_dataset) == 0 or len(dev_dataset) == 0:
            print(f"ERROR: Train/Dev dataset empty. Train: {len(train_dataset)}, Dev: {len(dev_dataset)}")
            return 1

        # Worker processes are only used on GPU runs, capped at 4 and at half the CPUs.
        cpu_count = os.cpu_count() or 1
        num_loader_workers = min(4, cpu_count // 2) if use_cuda and cpu_count > 1 else 0

        def _collate(batch):
            # Bind both padding indices so DataLoader can call collate_fn with one argument.
            return collate_fn(batch, source_vocab.pad_idx, target_vocab.pad_idx)

        loader_kwargs = dict(batch_size=config.batch_size, collate_fn=_collate,
                             num_workers=num_loader_workers, pin_memory=use_cuda)
        train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
        dev_dataloader = DataLoader(dev_dataset, shuffle=False, drop_last=False, **loader_kwargs)

        if len(train_dataloader) == 0 or len(dev_dataloader) == 0:
            print(f"ERROR: Train/Dev Dataloader empty. Train: {len(train_dataloader)}, Dev: {len(dev_dataloader)}")
            return 1

        encoder = Encoder(source_vocab.n_chars, config.embedding_dim, config.hidden_dim,
                          config.encoder_layers, config.cell_type, config.dropout_p,
                          config.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(DEVICE)
        decoder = Decoder(target_vocab.n_chars, config.embedding_dim,
                          config.hidden_dim,
                          config.decoder_layers, config.cell_type, config.dropout_p,
                          pad_idx=target_vocab.pad_idx).to(DEVICE)
        model = Seq2Seq(encoder, decoder, DEVICE, target_vocab.sos_idx).to(DEVICE)

        print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)

        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
        # `verbose` is not passed: it is deprecated since PyTorch 2.2 and dropped in newer
        # releases. LR reductions are reported manually in the epoch loop instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=config.get('lr_scheduler_patience', 3), factor=0.3)
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

        best_val_accuracy = -1.0  # sentinel: no epoch evaluated yet
        epochs_no_improve = 0
        max_epochs_no_improve = config.get('max_epochs_no_improve', 7)

        def _finite_or(value, fallback):
            # W&B charts break on NaN, so NaN metrics are replaced by a fallback value.
            return fallback if np.isnan(value) else value

        for epoch in range(config.epochs):
            train_loss, train_accuracy = train_epoch(model, train_dataloader, optimizer, criterion, DEVICE,
                                                     config.clip_value, config.teacher_forcing_ratio, target_vocab)

            val_loss, val_accuracy = evaluate(model, dev_dataloader, criterion, DEVICE, target_vocab,
                                              beam_width=config.get('beam_width_eval', 1),
                                              target_eos_idx=target_vocab.eos_idx)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_accuracy)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f"ReduceLROnPlateau: learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

            print(f"Epoch {epoch+1}/{config.epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f}")
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": _finite_or(train_loss, 0.0),
                "train_accuracy": _finite_or(train_accuracy, 0.0),
                "val_loss": _finite_or(val_loss, 0.0),
                "val_accuracy": _finite_or(val_accuracy, 0.0),
                "learning_rate": lr_after,
            })

            current_val_acc = _finite_or(val_accuracy, -1.0)
            if current_val_acc > best_val_accuracy:
                best_val_accuracy = current_val_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if config.early_stopping and epochs_no_improve >= max_epochs_no_improve:
                print(f"Early stopping triggered at epoch {epoch+1} after {epochs_no_improve} epochs with no improvement.")
                break

        # Never report the -1.0 sentinel (e.g. when config.epochs is 0) as an accuracy.
        wandb.summary["best_val_accuracy"] = max(best_val_accuracy, 0.0)
        print(f"Finished run. Best Validation Accuracy: {best_val_accuracy:.4f}")
        return 0

# --- W&B Sweep Configuration ---
# W&B sweep: Bayesian search that maximises the logged dev-set `val_accuracy`.
# Each entry under `parameters` is either a fixed `value`, a list of `values`
# to choose from, or a continuous `distribution` with `min`/`max` bounds.
sweep_config = dict(
    method='bayes',
    metric=dict(name='val_accuracy', goal='maximize'),
    parameters=dict(
        # Model architecture
        embedding_dim=dict(values=[128, 256, 300]),
        hidden_dim=dict(values=[256, 512]),
        encoder_layers=dict(values=[1, 2]),
        decoder_layers=dict(values=[1, 2]),
        cell_type=dict(values=['GRU', 'LSTM']),
        dropout_p=dict(values=[0.2, 0.3, 0.4, 0.5]),
        encoder_bidirectional=dict(values=[True, False]),
        # Optimisation
        learning_rate=dict(distribution='log_uniform_values', min=1e-4, max=3e-3),
        batch_size=dict(values=[32, 64, 128]),
        epochs=dict(value=15),
        clip_value=dict(value=1.0),
        teacher_forcing_ratio=dict(distribution='uniform', min=0.4, max=0.6),
        # Data
        vocab_min_freq=dict(value=1),
        max_seq_len=dict(value=50),
        # Training control and evaluation
        early_stopping=dict(value=True),
        max_epochs_no_improve=dict(values=[5, 7]),
        lr_scheduler_patience=dict(values=[2, 3]),
        beam_width_eval=dict(values=[1, 3]),
    ),
)

# --- Main Execution Block ---
if __name__ == '__main__':
    import sys  # only needed to exit with a failure status if W&B login fails

    # --- W&B authentication ---
    # On Kaggle, try the WANDB_API_KEY secret first, then fall back to the
    # default login (environment variable or interactive prompt).
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensure WANDB_API_KEY is set as a secret or environment variable.")
            try:
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
                print("W&B login with Kaggle secret successful.")
            except Exception as e_secret:
                print(f"Could not login with Kaggle secret ({e_secret}). Attempting default login.")
                if "WANDB_API_KEY" in os.environ:
                    wandb.login()
                    print("W&B login using environment variable.")
                else:
                    print("WANDB_API_KEY not found. Please set it up or login manually if prompted.")
                    wandb.login()
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        print(f"W&B login failed: {e}. Please ensure your W&B API key is correctly configured.")
        # sys.exit does not depend on the `site`-provided `exit()` helper and
        # reports failure (status 1); bare `exit()` would exit with status 0.
        sys.exit(1)

    SWEEP_PROJECT_NAME = "DL_A3"

    # --- Sweep creation and a local agent ---
    print("Initializing sweep...")
    try:
        sweep_id = wandb.sweep(sweep_config, project=SWEEP_PROJECT_NAME)
        print(f"Sweep ID: {sweep_id}")
        print(f"To run agents, execute: wandb agent YOUR_WANDB_USERNAME/{SWEEP_PROJECT_NAME}/{sweep_id}")

        print("Starting a W&B agent (adjust count as needed)...")
        wandb.agent(sweep_id, function=train_model, count=10)
    except Exception as e:
        print(f"Could not initialize sweep or run agent. Error: {e}")

    print("Sweep agent finished or stopped. Check W&B dashboard for results.")

# Question 4: Training and Testing

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq

# --- Constants for Special Tokens ---
# Special tokens shared by both vocabularies:
#   <sos> marks the start of a sequence, <eos> its end,
#   <pad> fills batches to equal length, <unk> replaces unseen characters.
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN = "<sos>", "<eos>", "<pad>", "<unk>"

# --- Model Definitions (Non-Attention Version) ---

class Encoder(nn.Module):
    """
    Recurrent encoder for the attention-free seq2seq model.

    Token indices are embedded, passed through dropout and fed to a stacked
    RNN, GRU or LSTM (optionally bidirectional). Only the RNN's final hidden
    state is handed back; it serves as the context that initialises the decoder.
    """

    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super().__init__()
        # Sizes and flags are kept as attributes; Seq2Seq reads hidden_dim,
        # n_layers, num_directions, bidirectional and cell_type from here.
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # padding_idx keeps the <pad> embedding fixed at zero (no gradient updates).
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        rnn_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in rnn_classes:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # PyTorch's RNN dropout acts between stacked layers only, so a single
        # layer gets none.
        inter_layer_dropout = 0 if n_layers <= 1 else dropout_p
        self.rnn = rnn_classes[self.cell_type](
            embedding_dim, hidden_dim, n_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bidirectional=self.bidirectional,
        )

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded batch and return its final hidden state.

        Args:
            input_seq (torch.Tensor): Padded token indices, (batch_size, seq_len).
            input_lengths (torch.Tensor): Unpadded length of every sequence, (batch_size,).

        Returns:
            tuple: ``(None, hidden)``. No per-step outputs are returned because
            the attention-free decoder does not use them. ``hidden`` is the
            RNN's final hidden state; for an LSTM it is the pair ``(h, c)``.
        """
        embedded_tokens = self.dropout(self.embedding(input_seq))

        # Packing keeps padded positions out of the recurrence; enforce_sorted=False
        # means the batch need not be ordered by length.
        packed_tokens = nn.utils.rnn.pack_padded_sequence(
            embedded_tokens, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _step_outputs, final_hidden = self.rnn(packed_tokens)
        return None, final_hidden

class Decoder(nn.Module):
    """
    Single-step recurrent decoder without attention.

    Each call embeds one token per batch element, advances the RNN/GRU/LSTM
    by one step from the previous hidden state and projects the RNN output
    to vocabulary logits.
    """

    def __init__(self, output_vocab_size, embedding_dim,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super().__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        # The recurrent input is only the embedding of the previous token (no attention context).
        rnn_factory = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if rnn_factory is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout only exists for stacked RNNs.
        self.rnn = rnn_factory(embedding_dim, decoder_hidden_dim, n_layers,
                               dropout=dropout_p if n_layers > 1 else 0, batch_first=True)

        # Maps the RNN output at each step to one score per vocabulary entry.
        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden):
        """
        Run one decoding step.

        Args:
            input_char (torch.Tensor): Previous token index per batch element, (batch_size,).
            prev_decoder_hidden (torch.Tensor or tuple): Previous RNN hidden
                state; ``(h, c)`` for an LSTM.

        Returns:
            tuple: ``(prediction_logits, current_decoder_hidden)``, where the
            logits have shape (batch_size, output_vocab_size).
        """
        # (batch,) -> (batch, 1): a length-1 sequence for the batch_first RNN.
        step_tokens = input_char.unsqueeze(1)
        step_embedded = self.dropout(self.embedding(step_tokens))

        step_output, next_hidden = self.rnn(step_embedded, prev_decoder_hidden)

        logits = self.fc_out(step_output.squeeze(1))
        return logits, next_hidden

class Seq2Seq(nn.Module):
    """
    The main Sequence-to-Sequence model that connects the Encoder and Decoder.
    This architecture does NOT use an attention mechanism.

    The encoder's final hidden state has its directions concatenated
    (bidirectional encoders) and is linearly projected if its width differs
    from the decoder's. It is then trimmed or extended to the decoder's layer
    count and used as the decoder's initial state. The class provides
    teacher-forced training (``forward``), greedy decoding
    (``predict_greedy``) and beam search (``predict_beam_search``).
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # Width of one encoder layer's state once both directions are concatenated.
        encoder_effective_output_dim_per_layer = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_expected_hidden_dim = self.decoder.decoder_hidden_dim

        # A projection is only created when the widths differ.
        self.needs_dim_adaptation = encoder_effective_output_dim_per_layer != decoder_rnn_expected_hidden_dim
        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None
        if self.needs_dim_adaptation:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)

    def _merge_directions(self, state):
        """
        Regroup one encoder state tensor by layer.

        Views ``state`` as (n_layers, num_directions, batch, hidden_dim) and
        returns (n_layers, batch, hidden_dim * num_directions). Forward and
        backward states are concatenated along the last dimension when the
        encoder is bidirectional.
        """
        batch_size = state.size(1)
        state = state.view(self.encoder.n_layers, self.encoder.num_directions,
                           batch_size, self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            return torch.cat([state[:, 0, :, :], state[:, 1, :, :]], dim=2)
        return state.squeeze(1)

    def _match_decoder_layers(self, state):
        """
        Give a per-layer state tensor exactly ``decoder.n_layers`` layers.

        With more encoder layers, only the last ``decoder.n_layers`` are kept.
        With fewer, the last encoder layer is repeated for the missing
        decoder layers.
        """
        enc_layers = self.encoder.n_layers
        dec_layers = self.decoder.n_layers
        if enc_layers == dec_layers:
            return state
        if enc_layers > dec_layers:
            return state[-dec_layers:, :, :]
        repeated_last = state[-1:, :, :].expand(dec_layers - enc_layers, -1, -1)
        return torch.cat([state, repeated_last], dim=0)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Convert the encoder's final hidden state into the decoder's initial state.

        Directions are concatenated, the width is projected when
        ``needs_dim_adaptation`` is set, and the layer count is matched to the
        decoder (see ``_match_decoder_layers``).

        Returns:
            ``(h, c)`` for an LSTM encoder, otherwise ``h``; each tensor has
            shape (decoder.n_layers, batch, decoder.decoder_hidden_dim).
        """
        is_lstm = self.encoder.cell_type == 'LSTM'
        if is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        h_processed = self._merge_directions(h_from_enc)
        c_processed = self._merge_directions(c_from_enc) if c_from_enc is not None else None

        if self.needs_dim_adaptation:
            h_processed = self.fc_adapt_hidden(h_processed)
            if c_processed is not None and self.fc_adapt_cell is not None:
                c_processed = self.fc_adapt_cell(c_processed)

        final_h = self._match_decoder_layers(h_processed)
        if not is_lstm:
            return final_h
        return final_h, self._match_decoder_layers(c_processed)

    def _as_single_batch(self, source_seq, source_lengths):
        """
        Add a batch dimension to a 1-D source and build a matching lengths tensor.

        2-D inputs are returned unchanged. For a 1-D ``source_seq``, the length
        is taken from ``source_lengths`` when it is an int or a one-element
        tensor. When it is None, the full sequence length is used. Any other
        value falls back to ``len(source_lengths)``, as before.
        """
        if source_seq.dim() != 1:
            return source_seq, source_lengths
        source_seq = source_seq.unsqueeze(0)
        if isinstance(source_lengths, int):
            length = source_lengths
        elif torch.is_tensor(source_lengths) and source_lengths.numel() == 1:
            # Previously len(tensor([n])) gave 1 instead of n.
            length = int(source_lengths.item())
        elif source_lengths is None:
            length = source_seq.size(1)
        else:
            length = len(source_lengths)
        return source_seq, torch.tensor([length], device=self.device)

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Teacher-forced forward pass used for training and loss computation.

        Args:
            source_seq (torch.Tensor): Padded source sequences, (batch, src_len).
            source_lengths (torch.Tensor): Lengths of the source sequences.
            target_seq (torch.Tensor): Padded target sequences starting with <sos>, (batch, tgt_len).
            teacher_forcing_ratio (float): Probability, drawn once per time step
                for the whole batch, of feeding the gold token instead of the
                model's own argmax prediction.

        Returns:
            torch.Tensor: Logits of shape (batch, tgt_len, target_vocab_size).
            Position 0 is left as zeros because it corresponds to the <sos> input.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size

        outputs_logits = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        decoder_input = target_seq[:, 0]

        # Step t consumes token t and predicts token t+1.
        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs_logits[:, t+1] = decoder_output_logits

            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)
            decoder_input = target_seq[:, t+1] if teacher_force_this_step else top1_predicted_token

        return outputs_logits

    def predict_greedy(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Decode one source sequence greedily.

        Puts the model in eval mode. Decoding stops at ``max_output_len``
        tokens or when ``target_eos_idx`` is predicted.

        Args:
            source_seq (torch.Tensor): One source sequence, 1-D or a batch of size 1.
            source_lengths (torch.Tensor or int): Length of the source sequence.
            max_output_len (int): Maximum number of tokens to generate.
            target_eos_idx (int, optional): EOS index of the target vocabulary.

        Returns:
            list[int]: Predicted token indices, excluding <sos> and <eos>.

        Raises:
            ValueError: If more than one sequence is passed.
        """
        self.eval()
        source_seq, source_lengths = self._as_single_batch(source_seq, source_lengths)
        if source_seq.shape[0] != 1:
            raise ValueError("Greedy predict function currently supports batch_size=1.")

        with torch.no_grad():
            _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

            decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
            predicted_indices = []
            for _ in range(max_output_len):
                decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                top1_predicted_token = decoder_output_logits.argmax(1)
                predicted_idx = top1_predicted_token.item()
                if target_eos_idx is not None and predicted_idx == target_eos_idx:
                    break
                predicted_indices.append(predicted_idx)
                decoder_input = top1_predicted_token
        return predicted_indices

    def predict_beam_search(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None, beam_width=3):
        """
        Decode one source sequence with beam search.

        Every beam is ranked by summed token log-probability. Finished
        hypotheses are scored by that sum divided by their length, which
        counts <sos> and, if present, <eos>. Hypotheses that emit EOS on the
        last allowed step are now kept as well.

        Args:
            source_seq (torch.Tensor): One source sequence, 1-D or a batch of size 1.
            source_lengths (torch.Tensor or int): Length of the source sequence.
            max_output_len (int): Maximum number of decoding steps.
            target_eos_idx (int, optional): EOS index of the target vocabulary.
            beam_width (int): Number of hypotheses kept per step; it is capped
                at the target vocabulary size.

        Returns:
            list[int]: Token indices of the best hypothesis, excluding <sos>
            and <eos> (same convention as ``predict_greedy``).

        Raises:
            ValueError: If more than one sequence is passed or ``beam_width < 1``.
        """
        self.eval()
        source_seq, source_lengths = self._as_single_batch(source_seq, source_lengths)
        if source_seq.shape[0] != 1:
            raise ValueError("Beam search predict function currently supports batch_size=1.")
        if beam_width < 1:
            raise ValueError("beam_width must be at least 1.")
        # torch.topk fails if k exceeds the number of candidates.
        k = min(beam_width, self.decoder.output_vocab_size)

        def _normalized_score(log_prob, seq):
            return log_prob / len(seq) if len(seq) > 0 else -float('inf')

        with torch.no_grad():
            _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden_init = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

            # Each beam: (summed log-probability, token indices incl. <sos>, decoder hidden state).
            beams = [(0.0, [self.target_sos_idx], decoder_hidden_init)]
            completed_sequences = []  # (normalized score, token indices)

            for _ in range(max_output_len):
                candidates = []
                for log_prob_beam, seq_beam, hidden_beam in beams:
                    if seq_beam[-1] == target_eos_idx:
                        # Finished on the previous step: record it once and stop expanding it.
                        completed_sequences.append((_normalized_score(log_prob_beam, seq_beam), seq_beam))
                        continue

                    decoder_input = torch.tensor([seq_beam[-1]], device=self.device)
                    decoder_output_logits, next_hidden_beam = self.decoder(decoder_input, hidden_beam)
                    log_probs_next_token = F.log_softmax(decoder_output_logits, dim=1)
                    topk_log_probs, topk_indices = torch.topk(log_probs_next_token, k, dim=1)

                    for token_log_prob, next_token_idx in zip(topk_log_probs[0].tolist(), topk_indices[0].tolist()):
                        candidates.append((log_prob_beam + token_log_prob, seq_beam + [next_token_idx], next_hidden_beam))

                if not candidates:
                    # Every beam had already finished and was recorded above.
                    beams = []
                    break

                candidates.sort(key=lambda beam: beam[0], reverse=True)
                beams = candidates[:k]

            # Beams still in `beams` have not been recorded yet. This includes
            # beams that emitted EOS on the final step; previously these were
            # silently dropped.
            for log_prob_beam, seq_beam, _ in beams:
                completed_sequences.append((_normalized_score(log_prob_beam, seq_beam), seq_beam))

            if not completed_sequences:
                return []

            completed_sequences.sort(key=lambda item: item[0], reverse=True)
            best_sequence_indices = completed_sequences[0][1]

            if best_sequence_indices and best_sequence_indices[0] == self.target_sos_idx:
                best_sequence_indices = best_sequence_indices[1:]
            if target_eos_idx is not None and best_sequence_indices and best_sequence_indices[-1] == target_eos_idx:
                best_sequence_indices = best_sequence_indices[:-1]
            return best_sequence_indices

# --- Data Loading and Preprocessing ---

class Vocabulary:
    """
    Two-way lookup table between characters and integer ids for one script.

    Ids 0-3 are reserved for PAD, SOS, EOS and UNK (in that order). Regular
    characters are appended by `build_vocab`, most frequent first, with ties
    broken by the character itself.
    """
    def __init__(self, name):
        self.name = name
        reserved = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: slot for slot, token in enumerate(reserved)}
        self.index2char = {slot: token for slot, token in enumerate(reserved)}
        self.char_counts = Counter()
        self.n_chars = len(reserved)  # next id to hand out
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count every character of `sequence` (no ids are assigned here)."""
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """
        Assign ids to counted characters seen at least `min_freq` times.

        Characters below the threshold get no id and are later mapped to UNK_TOKEN
        by `sequence_to_indices`. Characters that already have an id are skipped.
        """
        ranked = sorted(self.char_counts, key=lambda ch: (-self.char_counts[ch], ch))
        fresh = [ch for ch in ranked
                 if self.char_counts[ch] >= min_freq and ch not in self.char2index]
        for ch in fresh:
            self.char2index[ch] = self.n_chars
            self.index2char[self.n_chars] = ch
            self.n_chars += 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False):
        """Map each character to its id (UNK id if unknown), optionally framed by SOS/EOS."""
        unk_id = self.char2index[UNK_TOKEN]
        head = [self.sos_idx] if add_sos else []
        tail = [self.eos_idx] if add_eos else []
        return head + [self.char2index.get(ch, unk_id) for ch in sequence] + tail

    def indices_to_sequence(self, indices):
        """
        Decode ids back to a string.

        Decoding stops at the first EOS id; SOS and PAD ids are dropped, and ids
        with no known character decode to UNK_TOKEN.
        """
        dropped = (self.sos_idx, self.pad_idx)
        pieces = []
        for token_id in indices:
            if token_id == self.eos_idx:
                break
            if token_id in dropped:
                continue
            pieces.append(self.index2char.get(token_id, UNK_TOKEN))
        return "".join(pieces)

class TransliterationDataset(Dataset):
    """
    A PyTorch Dataset for loading and preparing transliteration pairs.

    Reads a tab-separated file whose first column is the target (native-script)
    word and whose second column is the source (romanized) word. Any further
    columns are ignored. Pairs are kept as ``(source, target)`` strings and are
    converted to index tensors lazily in ``__getitem__``.

    Args:
        file_path (str): Path to the TSV file. If it does not exist, an error is
            printed and the dataset is left empty.
        source_vocab: Vocabulary used by ``__getitem__`` to index source strings.
        target_vocab: Vocabulary used by ``__getitem__`` to index target strings.
        max_len (int | None): Pairs where either side is longer than this are
            skipped. A falsy value disables the length filter.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Remove only the line terminator before splitting. A full strip()
                # would also eat a leading tab, dropping an empty first column and
                # shifting every remaining field one position to the left.
                parts = line.rstrip('\r\n').split('\t')
                if len(parts) < 2:
                    continue
                target, source = parts[0].strip(), parts[1].strip()
                # Skip empty sequences or sequences exceeding max_len
                if not source or not target:
                    continue
                if max_len and (len(source) > max_len or len(target) > max_len):
                    continue
                self.pairs.append((source, target))
        print(f"Loaded {len(self.pairs)} pairs from {file_path}.")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Returns a source-target pair as Tensors of indices.

        The source gets a trailing EOS; the target gets both a leading SOS and a
        trailing EOS.
        """
        source_str, target_str = self.pairs[idx]
        source_indices = self.source_vocab.sequence_to_indices(source_str, add_eos=True)
        target_indices = self.target_vocab.sequence_to_indices(target_str, add_sos=True, add_eos=True)
        return torch.tensor(source_indices, dtype=torch.long), torch.tensor(target_indices, dtype=torch.long)

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Merge ``(source_tensor, target_tensor)`` samples into padded, batch-first tensors.

    Samples that are ``None`` or whose source or target is empty are discarded.
    Each side is then right-padded to the longest remaining sequence.

    Returns:
        tuple: ``(padded_sources, source_lengths, padded_targets)``, or
        ``(None, None, None)`` when no usable sample is left.
    """
    def _usable(sample):
        return sample is not None and len(sample[0]) > 0 and len(sample[1]) > 0

    kept = list(filter(_usable, batch))
    if len(kept) == 0:
        return None, None, None

    sources, targets = zip(*kept)
    lengths = torch.tensor(list(map(len, sources)), dtype=torch.long)
    return (
        pad_sequence(sources, batch_first=True, padding_value=pad_idx_source),
        lengths,
        pad_sequence(targets, batch_first=True, padding_value=pad_idx_target),
    )

# --- Training and Evaluation Functions ---

def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Trains the model for a single epoch.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the training set.
        optimizer (optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss).
        device (torch.device): Device to run the model on (CPU or GPU).
        clip_value (float): Gradient clipping value (max gradient norm).
        teacher_forcing_ratio (float): Probability of using teacher forcing.
        target_vocab (Vocabulary): Vocabulary for the target language.

    Returns:
        tuple: ``(avg_loss, train_accuracy)``.

        * ``avg_loss`` is the mean loss over the batches that actually produced a
          finite loss. Skipped batches no longer dilute the average.
        * ``train_accuracy`` is the exact-match rate of the argmax predictions
          from the same (possibly teacher-forced) forward pass.
    """
    model.train()  # Set model to training mode
    epoch_loss = 0.0
    n_loss_batches = 0  # batches whose finite loss was back-propagated and accumulated
    total_correct_train = 0
    total_samples_train = 0

    if len(dataloader) == 0:
        print("Warning: Training dataloader is empty.")
        return 0.0, 0.0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        # collate_fn returns (None, None, None) when every sample in the batch was filtered out
        if batch_data[0] is None:
            continue
        sources, source_lengths, targets = batch_data
        if sources.shape[0] == 0:
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()  # Clear gradients

        # Forward pass: model predicts logits for the target sequence
        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 corresponds to SOS, so drop it from both outputs and targets
        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)

        loss = criterion(flat_outputs, flat_targets)

        # Skip the update on a NaN/Inf loss so one bad batch cannot corrupt the weights
        if torch.isfinite(loss):
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)  # guard against exploding gradients
            optimizer.step()
            epoch_loss += loss.item()
            n_loss_batches += 1

        # Exact-sequence-match accuracy. Move each tensor to host once per batch,
        # not once per sample.
        pred_rows = outputs_logits.argmax(dim=2)[:, 1:].tolist()
        true_rows = targets[:, 1:].tolist()
        for pred_row, true_row in zip(pred_rows, true_rows):
            if target_vocab.indices_to_sequence(pred_row) == target_vocab.indices_to_sequence(true_row):
                total_correct_train += 1
        total_samples_train += len(true_rows)

    avg_epoch_loss = epoch_loss / n_loss_batches if n_loss_batches > 0 else 0.0
    train_accuracy = total_correct_train / total_samples_train if total_samples_train > 0 else 0.0
    return avg_epoch_loss, train_accuracy

def _evaluate_one_epoch(model, dataloader, criterion, device, target_vocab, beam_width=1, is_test_set=False,
                        source_vocab=None):
    """
    Evaluates the model's performance on a given dataset (validation or test).

    The loss is computed from a teacher-forcing-free forward pass over the whole
    batch. Accuracy is the exact-match rate of per-sample greedy or beam-search
    decodes.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the evaluation set.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the model on.
        target_vocab (Vocabulary): Vocabulary for the target language.
        beam_width (int): Beam width for decoding (1 for greedy, >1 for beam search).
        is_test_set (bool): Flag for the final test evaluation. It changes the
            progress label and prints the first 3 samples of the first batch.
        source_vocab (Vocabulary, optional): Used only to decode the source text of
            the printed test samples. If omitted, raw source indices are printed.

    Returns:
        tuple: ``(avg_loss, accuracy)``. ``avg_loss`` is averaged over batches that
        produced a finite loss.
    """
    model.eval()  # Set model to evaluation mode
    epoch_loss = 0.0
    n_loss_batches = 0
    total_correct, total_samples = 0, 0

    if len(dataloader) == 0:
        print("WARNING: Evaluation dataloader is empty.")
        return 0.0, 0.0

    desc_prefix = "Testing" if is_test_set else "Validating"
    # Fall back to greedy decoding if beam_width is 1 or the model has no beam search
    use_beam = beam_width > 1 and hasattr(model, 'predict_beam_search')

    with torch.no_grad():  # Disable gradient calculations during evaluation
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=desc_prefix, leave=False)):
            if batch_data[0] is None:
                continue
            sources, source_lengths, targets = batch_data
            if sources.shape[0] == 0:
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

            # Forward pass for loss calculation (teacher_forcing_ratio=0.0)
            outputs_for_loss = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs_for_loss.shape[-1]
            flat_outputs_for_loss = outputs_for_loss[:, 1:].reshape(-1, output_dim)
            flat_targets_for_loss = targets[:, 1:].reshape(-1)

            loss = criterion(flat_outputs_for_loss, flat_targets_for_loss)
            if torch.isfinite(loss):
                epoch_loss += loss.item()
                n_loss_batches += 1

            max_output_len = targets.size(1) + 5  # Allow slightly longer output than the reference
            true_rows = targets[:, 1:].tolist()  # Exclude SOS from the true targets

            # Decode each item of the batch individually for accuracy
            for i in range(sources.shape[0]):
                src_single, src_len_single = sources[i:i+1], source_lengths[i:i+1]

                if use_beam:
                    predicted_indices = model.predict_beam_search(src_single, src_len_single,
                                                                  max_output_len=max_output_len,
                                                                  target_eos_idx=target_vocab.eos_idx,
                                                                  beam_width=beam_width)
                else:
                    predicted_indices = model.predict_greedy(src_single, src_len_single,
                                                             max_output_len=max_output_len,
                                                             target_eos_idx=target_vocab.eos_idx)

                pred_str = target_vocab.indices_to_sequence(predicted_indices)
                true_str = target_vocab.indices_to_sequence(true_rows[i])
                is_match = pred_str == true_str

                if is_match:
                    total_correct += 1
                total_samples += 1

                # Print the first 3 samples of the first batch for the test set.
                # The source vocabulary is taken from the argument, never from a global.
                if is_test_set and batch_idx == 0 and i < 3:
                    src_tokens = src_single[0].tolist()
                    src_text = (source_vocab.indices_to_sequence(src_tokens)
                                if source_vocab is not None else str(src_tokens))
                    print(f"  Test Example {i} - Source: '{src_text}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {is_match}")

    avg_epoch_loss = epoch_loss / n_loss_batches if n_loss_batches > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    print(f"{desc_prefix} - Total Correct: {total_correct}, Total Samples: {total_samples}, Accuracy: {accuracy:.4f}, Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss, accuracy

# --- Function to Train and Save the Best Model ---

def train_and_save_best_model(config, model_save_path, device):
    """
    Performs a dedicated training run using the best hyperparameters found from a sweep.

    The model's ``state_dict`` is written to ``model_save_path`` every time the
    validation exact-match accuracy improves. Training stops early after
    ``max_epochs_no_improve_train`` epochs without improvement.

    Args:
        config (dict): Dictionary of hyperparameters for this specific training run
            (read back through ``wandb.config``).
        model_save_path (str): Path to save the best model's state_dict.
        device (torch.device): Device to run the training on.

    Returns:
        bool: True if training ran and a checkpoint was written by this run,
        False if the data or data loaders were empty.
    """
    import functools  # local import: only needed for the picklable collate function

    # Create a unique run name for this dedicated training
    run_name_train_best = f"TRAIN_BEST_{config['cell_type']}_emb{config['embedding_dim']}_hid{config['hidden_dim']}"

    # Initialize W&B run for this dedicated training
    with wandb.init(project="DL_A3", name=run_name_train_best, config=config, job_type="training_best_model", reinit=True):
        cfg = wandb.config  # Access hyperparameters via wandb.config
        print(f"Starting dedicated training for best model with config: {cfg}")

        # --- Data Loading and Vocabulary Building ---
        BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")

        # The dataset stores only vocab references and converts strings lazily in
        # __getitem__. The same object can therefore build the vocabulary and then
        # serve as the training set, so the training file is read only once.
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs:
            print(f"ERROR: No training data loaded for vocabulary building from {train_file}.")
            return False  # Indicate failure
        for src, tgt in train_dataset.pairs:
            source_vocab.add_sequence(src)
            target_vocab.add_sequence(tgt)
        source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        print(f"Vocabs built. Source: {source_vocab.n_chars} unique chars, Target: {target_vocab.n_chars} unique chars.")

        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs or not dev_dataset.pairs:
            print(f"ERROR: Train or Dev dataset is empty after filtering. Train: {len(train_dataset.pairs)}, Dev: {len(dev_dataset.pairs)}")
            return False

        # --- DataLoader Setup ---
        # Adjust number of workers based on CPU availability and device
        num_w = 0 if device.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
        use_pin_memory = device.type == 'cuda'
        # functools.partial over the module-level collate_fn can be pickled, unlike a
        # lambda defined in this function. Worker processes therefore also start
        # under the 'spawn' multiprocessing start method.
        batch_collate = functools.partial(collate_fn,
                                          pad_idx_source=source_vocab.pad_idx,
                                          pad_idx_target=target_vocab.pad_idx)
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                                  collate_fn=batch_collate, num_workers=num_w,
                                  pin_memory=use_pin_memory, drop_last=True)
        dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                                collate_fn=batch_collate, num_workers=num_w,
                                pin_memory=use_pin_memory)
        # With drop_last=True the train loader has zero batches when the dataset is smaller than one batch
        if len(train_loader) == 0 or len(dev_loader) == 0:
            print(f"ERROR: Train or Dev DataLoader is empty. Train: {len(train_loader)}, Dev: {len(dev_loader)}")
            return False

        # --- Model, Optimizer, Loss Function Setup ---
        encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                          cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
        decoder = Decoder(target_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.decoder_layers, cfg.cell_type, cfg.dropout_p,
                          pad_idx=target_vocab.pad_idx).to(device)
        model = Seq2Seq(encoder, decoder, device, target_vocab.sos_idx).to(device)
        print(f"Best model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)  # Log model weights and gradients to W&B

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
        # Reduce the learning rate when validation accuracy plateaus. The deprecated
        # `verbose` flag is not used; reductions are printed in the loop below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3)
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)  # Ignore padding tokens in loss calculation

        # --- Training Loop ---
        best_val_acc_this_training = -1.0  # any accuracy (>= 0) beats this, so epoch 1 always saves
        epochs_no_improve = 0  # Counter for early stopping
        saved_checkpoint = False

        for epoch in range(cfg.epochs_train):
            train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                                     cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)
            # Evaluate on validation set using greedy decoding during training for simplicity
            val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device, target_vocab,
                                                    beam_width=1)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_acc)  # Update learning rate based on validation accuracy
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f"Reduced learning rate from {lr_before:.3e} to {lr_after:.3e}.")

            print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            # Log metrics to W&B
            wandb.log({"epoch_train_best": epoch + 1, "train_loss_best": train_loss, "train_acc_best": train_acc,
                       "val_loss_best": val_loss, "val_acc_best": val_acc, "lr_best": lr_after})

            # Checkpointing and early stopping
            if val_acc > best_val_acc_this_training:
                best_val_acc_this_training = val_acc
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path)  # Save best model
                saved_checkpoint = True
                print(f"Saved new best model to {model_save_path} (Val Acc: {best_val_acc_this_training:.4f})")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= cfg.max_epochs_no_improve_train:
                print(f"Early stopping for best model training at epoch {epoch+1}.")
                break

        # Only reachable when no epoch ran (epochs_train == 0). The check uses this
        # run's own flag, so a stale file from an earlier run is not mistaken for it.
        if not saved_checkpoint:
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved final model state (no improvement over initial) to {model_save_path}")

        wandb.summary["final_best_val_accuracy_during_training"] = best_val_acc_this_training
        print(f"Finished training best model. Best Val Acc: {best_val_acc_this_training:.4f}")
    return True  # Indicate successful training

# --- Main Execution Block ---

if __name__ == '__main__':
    # Script entry point:
    #   Phase 1 trains the fixed best configuration and checkpoints it.
    #   Phase 2 reloads the checkpoint and reports exact-match accuracy on the test split.
    import functools
    import sys

    # Determine the device (GPU if available, else CPU)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- W&B Login ---
    try:
        # Handle W&B login for Kaggle environments
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensuring WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
            else:
                # Attempt to get API key from Kaggle secrets
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
            print("W&B login for Kaggle successful.")
        else:
            wandb.login()  # Regular W&B login for local environments
        print("W&B login process attempted/completed.")
    except Exception as e:
        # Broad catch is deliberate: any login failure (missing secret, bad key, network) aborts the run.
        print(f"W&B login failed: {e}. Please ensure your W&B API key is correctly configured.")
        sys.exit(1)  # non-zero status; does not depend on the interactive `exit` helper

    # --- Best Hyperparameters (obtained from a previous sweep) ---
    # These hyperparameters are chosen based on the provided sweep results (e.g., from the best run shown in comments).
    BEST_HYPERPARAMETERS = {
        'embedding_dim': 300,
        'hidden_dim': 512,
        'encoder_layers': 2,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.5,
        'encoder_bidirectional': True,
        # Training specific parameters for this dedicated run:
        'learning_rate_train': 0.000376, # Learning rate from the best sweep run
        'batch_size_train': 128,         # Batch size from the best sweep run
        'epochs_train': 20,              # Max epochs for this dedicated training
        'clip_value_train': 1.0,         # Gradient clipping value
        'teacher_forcing_train': 0.5,    # Teacher forcing ratio for this training
        'max_epochs_no_improve_train': 7, # Early stopping patience for this training run
        'lr_scheduler_patience_train': 3, # Patience for LR scheduler
        # Parameters needed for dataset/vocabulary consistency:
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        # Parameters for final test evaluation:
        'eval_batch_size': 128,          # Batch size for evaluation on test set
        'beam_width_eval': 3             # Beam width for final test evaluation
    }
    MODEL_SAVE_PATH = "/kaggle/working/best_model_for_testing.pt" # Path to save the best model checkpoint

    print(f"Best hyperparameters selected for dedicated training and testing: {BEST_HYPERPARAMETERS}")

    # --- Phase 1: Train and Save the Best Model ---
    print("\n" + "=" * 80)
    print("--- Phase 1: Training and Saving Best Model Configuration ---".center(80))
    print("=" * 80 + "\n")

    training_successful = train_and_save_best_model(BEST_HYPERPARAMETERS, MODEL_SAVE_PATH, DEVICE)

    if not training_successful or not os.path.exists(MODEL_SAVE_PATH):
        print("ERROR: Failed to train and save the best model. Exiting before test evaluation.")
        sys.exit(1)
    print(f"\nBest model trained and saved to {MODEL_SAVE_PATH}")

    # --- Phase 2: Load and Evaluate on Test Set ---
    print("\n" + "=" * 80)
    print("--- Phase 2: Loading and Evaluating Best Model on Test Set ---".center(80))
    print("=" * 80 + "\n")

    # --- Data Setup for Test Evaluation ---
    # Rebuild the vocabulary from the training data before processing the test data.
    # Vocabulary ordering is deterministic, so this reproduces the training-time indices.
    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    source_vocab_test = Vocabulary("latin_test")
    target_vocab_test = Vocabulary("devanagari_test")

    # Build vocabulary using the training dataset (crucial for consistent tokenization)
    temp_train_ds_test_vocab = TransliterationDataset(train_file_for_vocab, source_vocab_test, target_vocab_test,
                                                      max_len=BEST_HYPERPARAMETERS['max_seq_len'])
    if not temp_train_ds_test_vocab.pairs:
        print(f"ERROR: No data loaded for vocabulary building from {train_file_for_vocab}.")
        sys.exit(1)
    for src, tgt in temp_train_ds_test_vocab.pairs:
        source_vocab_test.add_sequence(src)
        target_vocab_test.add_sequence(tgt)
    source_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS['vocab_min_freq'])
    target_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS['vocab_min_freq'])
    print(f"Test Vocab built from training data. Source: {source_vocab_test.n_chars} chars. Target: {target_vocab_test.n_chars} chars.")

    # Create the test dataset
    test_dataset = TransliterationDataset(test_file, source_vocab_test, target_vocab_test,
                                          max_len=BEST_HYPERPARAMETERS['max_seq_len'])
    if not test_dataset.pairs:
        print(f"ERROR: Test dataset is empty from {test_file}.")
        sys.exit(1)

    # Setup test DataLoader. A partial over the module-level collate_fn can be
    # pickled, unlike a lambda, so worker processes also work with 'spawn'.
    num_w_test = 0 if DEVICE.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
    test_collate = functools.partial(collate_fn,
                                     pad_idx_source=source_vocab_test.pad_idx,
                                     pad_idx_target=target_vocab_test.pad_idx)
    test_dataloader = DataLoader(test_dataset, batch_size=BEST_HYPERPARAMETERS['eval_batch_size'], shuffle=False,
                                 collate_fn=test_collate,
                                 num_workers=num_w_test, pin_memory=DEVICE.type == 'cuda')
    if len(test_dataloader) == 0:
        print("ERROR: Test DataLoader is empty.")
        sys.exit(1)

    # --- Initialize Model for Testing and Load Weights ---
    # Instantiate the model architecture with the same hyperparameters as the trained model
    encoder_test = Encoder(source_vocab_test.n_chars, BEST_HYPERPARAMETERS['embedding_dim'], BEST_HYPERPARAMETERS['hidden_dim'],
                           BEST_HYPERPARAMETERS['encoder_layers'], BEST_HYPERPARAMETERS['cell_type'], BEST_HYPERPARAMETERS['dropout_p'],
                           BEST_HYPERPARAMETERS['encoder_bidirectional'], pad_idx=source_vocab_test.pad_idx).to(DEVICE)
    decoder_test = Decoder(target_vocab_test.n_chars, BEST_HYPERPARAMETERS['embedding_dim'], BEST_HYPERPARAMETERS['hidden_dim'],
                           BEST_HYPERPARAMETERS['decoder_layers'], BEST_HYPERPARAMETERS['cell_type'], BEST_HYPERPARAMETERS['dropout_p'],
                           pad_idx=target_vocab_test.pad_idx).to(DEVICE)
    model_test = Seq2Seq(encoder_test, decoder_test, DEVICE, target_vocab_test.sos_idx).to(DEVICE)

    print(f"Loading weights into test model from: {MODEL_SAVE_PATH}")
    # The checkpoint is a plain state_dict of tensors, so weights_only=True is enough and avoids arbitrary unpickling
    model_test.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True))
    print("Weights loaded successfully for test evaluation.")

    # --- Evaluate on Test Set ---
    criterion_test = nn.CrossEntropyLoss(ignore_index=target_vocab_test.pad_idx)  # Use the same loss criterion
    test_loss, test_accuracy = _evaluate_one_epoch(model_test, test_dataloader, criterion_test, DEVICE,
                                                   target_vocab_test,
                                                   beam_width=BEST_HYPERPARAMETERS.get('beam_width_eval', 1),  # Use the eval beam width
                                                   is_test_set=True)  # Flag for printing test samples

    print("\n" + "=" * 80)
    print("--- FINAL TEST SET PERFORMANCE ---".center(80))
    print("=" * 80)
    print(f"  Model Checkpoint: '{MODEL_SAVE_PATH}'")
    print(f"  Hyperparameters used for training and testing: {BEST_HYPERPARAMETERS}")
    print(f"  Test Loss (avg per batch): {test_loss:.4f}")
    print(f"  Test Accuracy (Exact Match): {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
    print("=" * 80 + "\n")

    # --- Log Final Test Results to W&B ---
    try:
        run_name_final_test = f"FINAL_TEST_EVAL_{BEST_HYPERPARAMETERS['cell_type']}_beam{BEST_HYPERPARAMETERS['beam_width_eval']}"
        # Start a new W&B run to log only the final test results
        with wandb.init(project="DL_A3", name=run_name_final_test, config=BEST_HYPERPARAMETERS, job_type="final_evaluation", reinit=True) as test_run:
            test_run.summary["final_test_accuracy"] = test_accuracy
            test_run.summary["final_test_loss"] = test_loss
            test_run.summary["model_checkpoint_used"] = MODEL_SAVE_PATH
            print("Final test results logged to W&B.")
    except Exception as e:
        # Best-effort logging: results were already printed above, so a W&B failure is not fatal
        print(f"Could not log final test results to W&B: {e}")

# Question 5: Attention-Based Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq

# --- Constants for Special Tokens ---
# Reserved vocabulary symbols (start-of-sequence, end-of-sequence, padding, unknown).
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN = (
    "<sos>",  # start of sequence
    "<eos>",  # end of sequence
    "<pad>",  # padding
    "<unk>",  # unknown character
)

# --- Model Definitions (Non-Attention Version) ---

class Encoder(nn.Module):
    """
    Recurrent encoder for the non-attention seq2seq model.

    Embeds padded source indices, applies dropout, and feeds the packed result
    through a (possibly bidirectional, multi-layer) RNN, GRU or LSTM. Only the
    final hidden state is kept; per-timestep outputs are discarded.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super().__init__()
        kind = cell_type.upper()

        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = kind
        self.bidirectional = bidirectional
        self.num_directions = 1 + int(bool(bidirectional))

        # Submodules are registered in the order embedding -> dropout -> rnn
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(kind)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout is only meaningful (and only allowed without a warning) with >1 layer
        self.rnn = recurrent_cls(
            embedding_dim, hidden_dim, n_layers,
            dropout=dropout_p if n_layers > 1 else 0,
            batch_first=True,
            bidirectional=self.bidirectional,
        )

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded batch.

        Args:
            input_seq (torch.Tensor): Padded source indices, shape (batch_size, seq_len).
            input_lengths (torch.Tensor): Unpadded length of each sequence, shape (batch_size,).

        Returns:
            tuple: ``(None, hidden)``. The first slot is always None because this
            encoder exposes no per-step outputs for attention. ``hidden`` is the
            RNN's final hidden state; for an LSTM it is the pair ``(h, c)``.
        """
        dropped = self.dropout(self.embedding(input_seq))
        # pack_padded_sequence needs lengths on the CPU; enforce_sorted=False accepts any order
        packed = nn.utils.rnn.pack_padded_sequence(
            dropped, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        final_hidden = self.rnn(packed)[1]
        return None, final_hidden

class Decoder(nn.Module):
    """
    Attention-free decoder that emits one target token per call.

    Each call embeds the previous token, advances a (possibly stacked)
    RNN/GRU/LSTM by a single time step from the supplied hidden state, and
    projects the RNN output onto the target vocabulary.
    """
    def __init__(self, output_vocab_size, embedding_dim,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super(Decoder, self).__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in recurrent_classes:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")

        # PyTorch applies RNN dropout only between stacked layers, so a single-layer
        # network is given 0 rather than a value it would ignore.
        inter_layer_dropout = 0 if n_layers <= 1 else dropout_p
        rnn_cls = recurrent_classes[self.cell_type]
        self.rnn = rnn_cls(embedding_dim, decoder_hidden_dim, n_layers,
                           dropout=inter_layer_dropout, batch_first=True)

        # Maps each step's top-layer RNN output to one logit per vocabulary entry.
        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden):
        """
        Run one decoding step.

        Args:
            input_char (torch.Tensor): Previous token index per batch row, shape (batch_size,).
            prev_decoder_hidden (torch.Tensor or tuple): Decoder state from the previous
                step; an ``(h, c)`` pair when the cell is an LSTM.

        Returns:
            tuple: ``(prediction_logits, current_decoder_hidden)`` where the logits have
            shape (batch_size, output_vocab_size) and the hidden state has the same
            structure as ``prev_decoder_hidden``.
        """
        # A length-1 time axis lets the batch_first RNN process exactly one step.
        step_embedding = self.dropout(self.embedding(input_char.unsqueeze(1)))
        step_output, next_hidden = self.rnn(step_embedding, prev_decoder_hidden)
        return self.fc_out(step_output.squeeze(1)), next_hidden

class Seq2Seq(nn.Module):
    """
    The main Sequence-to-Sequence model that connects the Encoder and Decoder.
    This architecture does NOT use an attention mechanism.

    The encoder's final hidden state (forward and backward halves concatenated for
    a bidirectional encoder) is linearly projected when its width differs from the
    decoder's hidden size, resized to the decoder's layer count, and used as the
    decoder's initial state. Encoder and decoder may use different cell types: an
    LSTM cell state is dropped for a GRU/RNN decoder, and a zero cell state is
    supplied to an LSTM decoder fed by a GRU/RNN encoder.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # Per-layer width of the encoder's final state once both directions are concatenated.
        encoder_effective_output_dim_per_layer = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_expected_hidden_dim = self.decoder.decoder_hidden_dim

        # Projection layers are created only when the widths differ.
        self.needs_dim_adaptation = encoder_effective_output_dim_per_layer != decoder_rnn_expected_hidden_dim
        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None
        if self.needs_dim_adaptation:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)

    def _merge_directions(self, state, batch_size):
        """Reshape an encoder state (n_layers * num_directions, batch, hidden) into
        (n_layers, batch, hidden * num_directions), concatenating forward/backward halves."""
        state = state.reshape(self.encoder.n_layers, self.encoder.num_directions,
                              batch_size, self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            return torch.cat([state[:, 0, :, :], state[:, 1, :, :]], dim=2)
        return state.squeeze(1)

    def _match_layers(self, state):
        """Resize a per-layer state (encoder n_layers, batch, dim) to the decoder's layer
        count: keep the top layers if the encoder is deeper, repeat the top layer if shallower."""
        enc_layers, dec_layers = self.encoder.n_layers, self.decoder.n_layers
        if enc_layers >= dec_layers:
            return state[enc_layers - dec_layers:]
        extra_layers = state[-1:].expand(dec_layers - enc_layers, -1, -1)
        return torch.cat([state, extra_layers], dim=0)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Adapts the encoder's final hidden state(s) to match the decoder's expected hidden state dimensions and layer count.
        Handles bidirectionality by concatenating forward and backward hidden states.
        If decoder has more layers than encoder, the last encoder layer's state is repeated.

        Returns an ``(h, c)`` tuple when the decoder is an LSTM (with a zero ``c`` if the
        encoder has no cell state), otherwise just ``h``.
        """
        enc_is_lstm = self.encoder.cell_type == 'LSTM'
        dec_is_lstm = self.decoder.cell_type == 'LSTM'
        if enc_is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        batch_size = h_from_enc.size(1)
        h_processed = self._merge_directions(h_from_enc, batch_size)
        c_processed = self._merge_directions(c_from_enc, batch_size) if c_from_enc is not None else None

        # Project to the decoder's hidden width when the sizes differ.
        if self.needs_dim_adaptation:
            h_processed = self.fc_adapt_hidden(h_processed)
            if c_processed is not None and self.fc_adapt_cell is not None:
                c_processed = self.fc_adapt_cell(c_processed)

        final_h = self._match_layers(h_processed)
        if not dec_is_lstm:
            # GRU/RNN decoders take only the hidden state; any LSTM cell state is dropped.
            return final_h
        if c_processed is None:
            # GRU/RNN encoder feeding an LSTM decoder: start the cell state at zero.
            return final_h, torch.zeros_like(final_h)
        return final_h, self._match_layers(c_processed)

    def _prepare_single_source(self, source_seq, source_lengths):
        """
        Normalize inputs for the single-example predict methods.

        A 1-D ``source_seq`` is turned into a batch of one, and ``source_lengths`` is
        converted to a 1-element tensor. It may be an int, a tensor holding the length
        (0-d or 1-element), None (use the whole sequence), or any other sized object
        whose ``len()`` is the length.
        """
        if source_seq.dim() == 1:
            seq_len = source_seq.size(0)
            if isinstance(source_lengths, int):
                length = source_lengths
            elif torch.is_tensor(source_lengths) and source_lengths.numel() == 1:
                # The tensor holds the length value; len() of it would always be 1.
                # A length can never exceed the tokens actually present.
                length = min(int(source_lengths.item()), seq_len)
            elif source_lengths is None:
                length = seq_len
            else:
                length = len(source_lengths)
            return source_seq.unsqueeze(0), torch.tensor([length], device=self.device)
        if isinstance(source_lengths, int):
            source_lengths = torch.tensor([source_lengths], device=self.device)
        return source_seq, source_lengths

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Forward pass for the Seq2Seq model during training.

        Args:
            source_seq (torch.Tensor): Padded source sequences.
            source_lengths (torch.Tensor): Lengths of source sequences.
            target_seq (torch.Tensor): Padded target sequences (including SOS token).
            teacher_forcing_ratio (float): Probability of using actual target token as next input.

        Returns:
            torch.Tensor: Logits of shape (batch, target_len, vocab). Position 0 (the
            <sos> slot) is never predicted and stays all zeros.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size
        outputs_logits = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        # The first decoder input is the <sos> column of the target.
        decoder_input = target_seq[:, 0]
        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs_logits[:, t + 1] = decoder_output_logits

            # One coin flip per time step, shared by the whole batch.
            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)
            decoder_input = target_seq[:, t + 1] if teacher_force_this_step else top1_predicted_token
        return outputs_logits

    def predict_greedy(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Generates a sequence using greedy decoding.

        Args:
            source_seq (torch.Tensor): A single 1-D example or a batch of 1.
            source_lengths (int or torch.Tensor): Length of the source sequence.
            max_output_len (int): Maximum length of the sequence to generate.
            target_eos_idx (int, optional): Index of the EOS token in the target vocabulary.

        Returns:
            list: Predicted token indices, excluding <sos> and the terminating <eos>.

        The model is put in eval mode while decoding and restored to its previous
        mode afterwards.
        """
        source_seq, source_lengths = self._prepare_single_source(source_seq, source_lengths)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
                decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
                decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
                predicted_indices = []
                for _ in range(max_output_len):
                    decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    top1_predicted_token = decoder_output_logits.argmax(1)
                    predicted_idx = top1_predicted_token.item()
                    if target_eos_idx is not None and predicted_idx == target_eos_idx:
                        break
                    predicted_indices.append(predicted_idx)
                    decoder_input = top1_predicted_token
        finally:
            self.train(was_training)
        return predicted_indices

    def predict_beam_search(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None, beam_width=3):
        """
        Generates a sequence using beam search decoding.

        Args:
            source_seq (torch.Tensor): Input source sequence (must be a single example).
            source_lengths (int or torch.Tensor): Length of the source sequence.
            max_output_len (int): Maximum length of the sequence to generate.
            target_eos_idx (int, optional): Index of the EOS token.
            beam_width (int): The number of top sequences to keep at each step.

        Returns:
            list: Token indices of the best (length-normalized) sequence, without the
            leading <sos>. A trailing <eos> is kept if the beam produced one.

        Raises:
            ValueError: if more than one example is given.

        The model is put in eval mode while decoding and restored to its previous
        mode afterwards.
        """
        source_seq, source_lengths = self._prepare_single_source(source_seq, source_lengths)
        if source_seq.shape[0] != 1:
            raise ValueError("Beam search predict function currently supports batch_size=1.")

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self._run_beam_search(source_seq, source_lengths, max_output_len,
                                             target_eos_idx, beam_width)
        finally:
            self.train(was_training)

    def _run_beam_search(self, source_seq, source_lengths, max_output_len, target_eos_idx, beam_width):
        """Beam-search core for one example; the caller sets eval mode and no_grad."""
        def length_normalized(log_prob, seq):
            # Dividing by length counteracts the bias towards shorter sequences.
            return log_prob / len(seq) if len(seq) > 0 else -float('inf')

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden_init = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        # Beams are (cumulative_log_probability, sequence_of_indices, decoder_hidden_state).
        beams = [(0.0, [self.target_sos_idx], decoder_hidden_init)]
        completed_sequences = []

        for _ in range(max_output_len):
            new_beams = []
            all_current_beams_ended = True
            for log_prob_beam, seq_beam, hidden_beam in beams:
                # A beam that already ended moves to the completed list and is not expanded.
                if not seq_beam or seq_beam[-1] == target_eos_idx:
                    completed_sequences.append((length_normalized(log_prob_beam, seq_beam), seq_beam))
                    continue
                all_current_beams_ended = False

                decoder_input = torch.tensor([seq_beam[-1]], device=self.device)
                decoder_output_logits, next_hidden_beam = self.decoder(decoder_input, hidden_beam)

                log_probs_next_token = torch.log_softmax(decoder_output_logits, dim=1)
                # topk raises when k exceeds the vocabulary size, so cap it.
                k_eff = min(beam_width, log_probs_next_token.size(1))
                topk_log_probs, topk_indices = torch.topk(log_probs_next_token, k_eff, dim=1)

                for k in range(k_eff):
                    next_token_idx = topk_indices[0, k].item()
                    token_log_prob = topk_log_probs[0, k].item()
                    new_beams.append((log_prob_beam + token_log_prob,
                                      seq_beam + [next_token_idx],
                                      next_hidden_beam))

            if not new_beams or all_current_beams_ended:
                break  # Nothing left to expand.

            # Keep the `beam_width` highest-scoring candidates.
            new_beams.sort(key=lambda x: x[0], reverse=True)
            beams = new_beams[:beam_width]

        # Beams still active when the length budget ran out also compete.
        for log_prob_beam, seq_beam, _ in beams:
            completed_sequences.append((length_normalized(log_prob_beam, seq_beam), seq_beam))

        if not completed_sequences:
            return [target_eos_idx] if target_eos_idx is not None else []
        completed_sequences.sort(key=lambda x: x[0], reverse=True)

        best_sequence_indices = completed_sequences[0][1]
        # Strip the leading <sos>.
        if best_sequence_indices and best_sequence_indices[0] == self.target_sos_idx:
            return best_sequence_indices[1:]
        return best_sequence_indices

# --- Data Loading and Preprocessing ---

class Vocabulary:
    """
    Two-way mapping between characters and integer indices for one script.

    Indices 0-3 are reserved, in order, for PAD_TOKEN, SOS_TOKEN, EOS_TOKEN and
    UNK_TOKEN. `build_vocab` appends real characters after them, most frequent
    first and ties broken by the character itself.
    """
    def __init__(self, name):
        self.name = name
        reserved_tokens = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: position for position, token in enumerate(reserved_tokens)}
        self.index2char = {position: token for position, token in enumerate(reserved_tokens)}
        self.char_counts = Counter()
        self.n_chars = len(reserved_tokens)
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count every character of `sequence` towards the frequency table."""
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """
        Assign indices to counted characters seen at least `min_freq` times.

        Characters below the threshold get no index, so `sequence_to_indices`
        maps them to UNK_TOKEN. Characters that already have an index keep it.
        """
        ranked_chars = sorted(self.char_counts.keys(), key=lambda ch: (-self.char_counts[ch], ch))
        for ch in ranked_chars:
            if self.char_counts[ch] < min_freq or ch in self.char2index:
                continue
            next_index = self.n_chars
            self.char2index[ch] = next_index
            self.index2char[next_index] = ch
            self.n_chars = next_index + 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False):
        """Encode `sequence` as indices, optionally framed by SOS/EOS; unknown chars become UNK."""
        unk_index = self.char2index[UNK_TOKEN]
        body = [self.char2index.get(ch, unk_index) for ch in sequence]
        prefix = [self.sos_idx] if add_sos else []
        suffix = [self.eos_idx] if add_eos else []
        return prefix + body + suffix

    def indices_to_sequence(self, indices):
        """Decode indices into a string, stopping at EOS and skipping SOS/PAD."""
        decoded = []
        for position_value in indices:
            if position_value == self.eos_idx:
                break
            if position_value in (self.sos_idx, self.pad_idx):
                continue
            decoded.append(self.index2char.get(position_value, UNK_TOKEN))
        return "".join(decoded)


class TransliterationDataset(Dataset):
    """
    A PyTorch Dataset for loading and preparing transliteration pairs.

    Reads a TSV file where each line holds at least two tab-separated fields:
    the target-script word first and the source word second. Further columns
    are ignored. A line is skipped if it has fewer than two fields, if either
    word is empty, or (when `max_len` is truthy) if either word is longer than
    `max_len` characters. Items are returned as LongTensors of vocabulary indices.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found during Dataset init: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Remove only the line terminator before splitting. A whole-line
                    # strip() would also eat a leading tab, shifting every column left
                    # when the first field is empty. Fields are stripped individually.
                    parts = line.rstrip('\r\n').split('\t')
                    if len(parts) < 2:
                        continue
                    target_sequence, source_sequence = parts[0].strip(), parts[1].strip()
                    if not source_sequence or not target_sequence:
                        continue
                    if max_len and (len(source_sequence) > max_len or len(target_sequence) > max_len):
                        continue
                    self.pairs.append((source_sequence, target_sequence))
            print(f"Loaded {len(self.pairs)} pairs from {file_path}.")
        except (OSError, UnicodeError) as e:
            # Best effort: report the I/O or decoding problem and keep any pairs read so far.
            print(f"ERROR: Could not read or process file {file_path}. Error: {e}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return (source_indices + EOS, SOS + target_indices + EOS) as LongTensors."""
        if idx >= len(self.pairs):
            raise IndexError("Index out of bounds for dataset")
        source_str, target_str = self.pairs[idx]
        source_indices = self.source_vocab.sequence_to_indices(source_str, add_eos=True)
        target_indices = self.target_vocab.sequence_to_indices(target_str, add_sos=True, add_eos=True)
        return torch.tensor(source_indices, dtype=torch.long), \
               torch.tensor(target_indices, dtype=torch.long)

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Collate variable-length (source, target) tensor pairs into padded batches.

    Items that are None or contain a None tensor are dropped, as are pairs whose
    source is empty. Each side is right-padded to its own longest sequence.

    Returns:
        tuple: (padded_sources, source_lengths, padded_targets) with batch_first
        layout, or (None, None, None) when nothing usable remains.
    """
    empty_result = (None, None, None)

    usable = [entry for entry in batch
              if entry is not None and entry[0] is not None and entry[1] is not None]
    if not usable:
        return empty_result

    all_sources, all_targets = zip(*usable)
    kept_pairs = [(src, tgt) for src, tgt in zip(all_sources, all_targets) if len(src) > 0]
    if not kept_pairs:
        return empty_result

    kept_sources = [src for src, _ in kept_pairs]
    kept_targets = [tgt for _, tgt in kept_pairs]
    lengths = torch.tensor([len(src) for src in kept_sources], dtype=torch.long)
    return (
        pad_sequence(kept_sources, batch_first=True, padding_value=pad_idx_source),
        lengths,
        pad_sequence(kept_targets, batch_first=True, padding_value=pad_idx_target),
    )

# --- Training and Evaluation Functions ---

def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Trains the model for a single epoch.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the training set.
        optimizer (optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss).
        device (torch.device): Device to run the model on.
        clip_value (float): Gradient clipping value.
        teacher_forcing_ratio (float): Probability of using teacher forcing.
        target_vocab (Vocabulary): Vocabulary for the target language.

    Returns:
        tuple: (average loss, training accuracy).
            - The loss is averaged over the batches that were actually optimized.
              Batches skipped as invalid or for a non-finite loss do not pull the
              average toward zero.
            - Accuracy is the exact-sequence match rate of the teacher-forced
              argmax predictions, decoded up to <eos>.
    """
    model.train()  # Set model to training mode
    epoch_loss = 0.0
    num_optimized_batches = 0
    total_correct_train = 0
    total_samples_train = 0

    if len(dataloader) == 0:
        print("Warning: Training dataloader is empty.")
        return 0.0, 0.0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        # collate_fn yields (None, None, None) when a batch has no usable pairs.
        if batch_data[0] is None: continue
        sources, source_lengths, targets = batch_data

        if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
            print("Warning: Empty or invalid batch data after collate_fn in training. Skipping.")
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()  # Clear gradients

        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 is the <sos> slot, which the Seq2Seq model leaves as zeros;
        # exclude it from the loss.
        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)
        loss = criterion(flat_outputs, flat_targets)

        # Skip the update entirely on NaN/Inf so bad gradients never reach the weights.
        if not torch.isfinite(loss):
            print("Warning: NaN or Inf loss detected in training. Skipping batch.")
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        epoch_loss += loss.item()
        num_optimized_batches += 1

        # Exact sequence match. Whole batches are converted to Python lists once,
        # instead of calling .tolist() per row (each call is a device->host copy).
        predicted_rows = outputs_logits.argmax(dim=2)[:, 1:].tolist()
        true_rows = targets[:, 1:].tolist()
        for pred_row, true_row in zip(predicted_rows, true_rows):
            if target_vocab.indices_to_sequence(pred_row) == target_vocab.indices_to_sequence(true_row):
                total_correct_train += 1
        total_samples_train += len(true_rows)

    avg_epoch_loss = epoch_loss / num_optimized_batches if num_optimized_batches > 0 else 0.0
    train_accuracy = total_correct_train / total_samples_train if total_samples_train > 0 else 0.0
    return avg_epoch_loss, train_accuracy


def _evaluate_one_epoch(model, dataloader, criterion, device, source_vocab, target_vocab, beam_width=1, is_test_set=False):
    """
    Evaluates the model's performance on a given dataset (validation or test).

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the evaluation set.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the model on.
        source_vocab (Vocabulary): Vocabulary for the source language (used for debug prints).
        target_vocab (Vocabulary): Vocabulary for the target language.
        beam_width (int): Beam width for decoding (1 for greedy, >1 for beam search).
        is_test_set (bool): Flag to indicate if this is the final test evaluation (for logging and samples).

    Returns:
        tuple: (average loss, accuracy).
            - The loss comes from a teacher-forcing-free forward pass and is averaged
              over batches that produced a finite loss only.
            - Accuracy is the exact-match rate of sequences decoded one example at a
              time with greedy or beam search.
    """
    model.eval()  # Set model to evaluation mode
    epoch_loss = 0.0
    num_loss_batches = 0
    total_correct, total_samples = 0, 0
    if len(dataloader) == 0:
        print("WARNING: Evaluation dataloader is empty. Returning 0 loss and 0 accuracy.")
        return 0.0, 0.0
    desc_prefix = "Testing" if is_test_set else "Validating"

    # Decoding strategy is fixed for the whole evaluation.
    use_beam_search = beam_width > 1 and hasattr(model, 'predict_beam_search')
    use_greedy = hasattr(model, 'predict_greedy')

    with torch.no_grad():  # Disable gradient calculations during evaluation
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=desc_prefix, leave=False)):
            if batch_data[0] is None: continue
            sources, source_lengths, targets = batch_data

            if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
                print("Warning: Empty or invalid batch data in eval after collate_fn. Skipping.")
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

            # Loss with teacher forcing disabled; position 0 (<sos>) is excluded.
            outputs_for_loss = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs_for_loss.shape[-1]
            flat_outputs_for_loss = outputs_for_loss[:, 1:].reshape(-1, output_dim)
            flat_targets_for_loss = targets[:, 1:].reshape(-1)
            loss = criterion(flat_outputs_for_loss, flat_targets_for_loss)
            if torch.isfinite(loss):
                epoch_loss += loss.item()
                num_loss_batches += 1

            max_output_len = targets.size(1) + 5
            # Convert targets to Python lists once per batch rather than once per row.
            true_rows = targets[:, 1:].tolist()

            for i, true_row in enumerate(true_rows):
                src_single, src_len_single = sources[i:i+1], source_lengths[i:i+1]

                if use_beam_search:
                    predicted_indices = model.predict_beam_search(src_single, src_len_single,
                                                                  max_output_len=max_output_len,
                                                                  target_eos_idx=target_vocab.eos_idx,
                                                                  beam_width=beam_width)
                elif use_greedy:
                    predicted_indices = model.predict_greedy(src_single, src_len_single,
                                                             max_output_len=max_output_len,
                                                             target_eos_idx=target_vocab.eos_idx)
                else:  # Fall back to argmax on outputs_for_loss if predict methods are not available
                    predicted_indices = outputs_for_loss[i:i+1].argmax(dim=2)[0, 1:].tolist()

                pred_str = target_vocab.indices_to_sequence(predicted_indices)
                true_str = target_vocab.indices_to_sequence(true_row)
                is_match = pred_str == true_str
                if is_match: total_correct += 1
                total_samples += 1

                # Print the first three examples of the first test batch for inspection.
                if is_test_set and batch_idx == 0 and i < 3:
                    print(f"  Test Example {i} - Source: '{source_vocab.indices_to_sequence(src_single[0].tolist())}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {is_match}")

    avg_epoch_loss = epoch_loss / num_loss_batches if num_loss_batches > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    print(f"{desc_prefix} - Total Correct: {total_correct}, Total Samples: {total_samples}, Accuracy: {accuracy:.4f}, Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss, accuracy

# --- Function to Train and Save the Best Model ---

def train_and_save_best_model(config_params, model_save_path, device):
    """
    Performs a dedicated training run using the best hyperparameters found from a sweep.
    Saves the model checkpoint with the best validation accuracy.

    The checkpoint at ``model_save_path`` is (over)written every time validation
    accuracy improves. If no epoch ever improves (for example every validation
    accuracy is NaN, or ``epochs_train`` is 0), the final model state is written
    instead, so a stale file left by an earlier run is never mistaken for this
    run's result.

    Args:
        config_params (dict): Dictionary of hyperparameters for this specific training run.
        model_save_path (str): Path to save the best model's state_dict.
        device (torch.device): Device to run the training on.

    Returns:
        bool: True if training finished and this run wrote a checkpoint, False if
        data loading produced no usable data.
    """
    run_name_train_best = f"TRAIN_BEST_{config_params['cell_type']}_emb{config_params['embedding_dim']}_hid{config_params['hidden_dim']}"

    with wandb.init(project="DL_A3", name=run_name_train_best, config=config_params, job_type="training_best_model_final", reinit=True):
        cfg = wandb.config
        print(f"Starting dedicated training for best model with config: {cfg}")

        # --- Data Loading and Vocabulary Building ---
        BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")

        # Build vocabulary from the training split only
        temp_train_ds_vocab = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not temp_train_ds_vocab.pairs:
            print(f"ERROR: No training data loaded for vocabulary building from {train_file}.")
            return False
        for src, tgt in temp_train_ds_vocab.pairs:
            source_vocab.add_sequence(src)
            target_vocab.add_sequence(tgt)
        source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        print(f"Vocabs built. Source: {source_vocab.n_chars}, Target: {target_vocab.n_chars}")

        # Create actual datasets for training and validation
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs or not dev_dataset.pairs:
            print(f"ERROR: Train or Dev dataset is empty after filtering. Train: {len(train_dataset.pairs)}, Dev: {len(dev_dataset.pairs)}")
            return False

        # --- DataLoader Setup ---
        # Worker processes only for non-CPU devices: at most 4, and at most half the CPUs.
        num_w = 0 if device.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
        pin = device.type == 'cuda'
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                                  collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                  num_workers=num_w, pin_memory=pin, drop_last=True)
        dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                                collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                num_workers=num_w, pin_memory=pin)
        # With drop_last=True the train loader has zero batches when the dataset is
        # smaller than one batch, so check lengths explicitly.
        if len(train_loader) == 0 or len(dev_loader) == 0:
            print(f"ERROR: Train or Dev DataLoader is empty. Train: {len(train_loader)}, Dev: {len(dev_loader)}")
            return False

        # --- Model, Optimizer, Loss Function Setup ---
        encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                          cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
        decoder = Decoder(target_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.decoder_layers, cfg.cell_type, cfg.dropout_p,
                          pad_idx=target_vocab.pad_idx).to(device)
        model = Seq2Seq(encoder, decoder, device, target_vocab.sos_idx).to(device)
        print(f"Best model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
        # 'max' mode: the scheduler is stepped with validation accuracy, which should increase.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3, verbose=True)
        # Padding positions in the targets do not contribute to the loss.
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

        best_val_acc_this_training = -1.0
        epochs_no_improve = 0
        # Tracks whether *this* run wrote the checkpoint; os.path.exists() cannot tell
        # a fresh file apart from one left behind by a previous run.
        checkpoint_saved = False

        # --- Training Loop ---
        for epoch in range(cfg.epochs_train):
            train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                                     cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)

            val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device,
                                                    source_vocab, target_vocab,
                                                    beam_width=1)  # Use greedy for val during this training

            scheduler.step(val_acc)
            print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            wandb.log({"epoch_train_best": epoch + 1, "train_loss_best": train_loss, "train_acc_best": train_acc,
                       "val_loss_best": val_loss, "val_acc_best": val_acc, "lr_best": optimizer.param_groups[0]['lr']})

            # Early stopping logic; a NaN accuracy is treated as "no improvement".
            current_val_acc = val_acc if not np.isnan(val_acc) else -1.0
            if current_val_acc > best_val_acc_this_training:
                best_val_acc_this_training = current_val_acc
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path)  # Save best model
                checkpoint_saved = True
                print(f"Saved new best model to {model_save_path} (Val Acc: {best_val_acc_this_training:.4f})")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= cfg.max_epochs_no_improve_train:
                print(f"Early stopping for best model training at epoch {epoch+1}.")
                break

        # Make sure this run leaves a checkpoint even if no epoch improved.
        if not checkpoint_saved:
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved final model state (no improvement over initial) to {model_save_path}")

        wandb.summary["final_best_val_accuracy_during_training"] = best_val_acc_this_training
        print(f"Finished training best model. Best Val Acc: {best_val_acc_this_training:.4f}")
    return True

# --- Main Execution Block for Training Best Model and Testing ---
if __name__ == '__main__':
    # The bare exit() builtin is added by the `site` module for interactive use and
    # exits with status 0. sys.exit(1) is always available and reports failure to
    # whatever launched the script.
    import sys

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- W&B Login ---
    # On Kaggle, use WANDB_API_KEY from the environment if present, otherwise read it
    # from Kaggle Secrets. Elsewhere, wandb.login() is called without an explicit key.
    # Any failure aborts the script.
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensuring WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
                print("W&B login using environment variable.")
            else:
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
                print("W&B login for Kaggle successful.")
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        print(f"W&B login failed: {e}. Ensure API key is configured.")
        sys.exit(1)

    # --- Best Hyperparameters (obtained from a previous sweep) ---
    # These hyperparameters are chosen based on the provided sweep results.
    # Keys ending in '_train' are training-only; they are filtered out of the final report below.
    BEST_HYPERPARAMETERS_FOR_TRAINING = {
        'embedding_dim': 300,
        'hidden_dim': 512,
        'encoder_layers': 2,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.5,
        'encoder_bidirectional': True,
        'learning_rate_train': 0.000376,
        'batch_size_train': 128,
        'epochs_train': 20,
        'clip_value_train': 1.0,
        'teacher_forcing_train': 0.5,
        'max_epochs_no_improve_train': 7,
        'lr_scheduler_patience_train': 3,
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        'eval_batch_size': 128,
        'beam_width_eval': 3
    }
    MODEL_SAVE_PATH = "/kaggle/working/best_model_for_testing.pt"

    print(f"Best hyperparameters selected for training: {BEST_HYPERPARAMETERS_FOR_TRAINING}")

    # --- Phase 1: Train and Save the Best Model ---
    print("\n" + "="*80)
    print("--- Phase 1: Training and Saving Best Model Configuration ---".center(80))
    print("="*80 + "\n")

    training_successful = train_and_save_best_model(BEST_HYPERPARAMETERS_FOR_TRAINING, MODEL_SAVE_PATH, DEVICE)

    if not training_successful or not os.path.exists(MODEL_SAVE_PATH):
        print("ERROR: Failed to train and save the best model. Exiting before test evaluation.")
        sys.exit(1)
    print(f"Best model trained and saved to {MODEL_SAVE_PATH}")

    # --- Phase 2: Load and Evaluate on Test Set ---
    print("\n" + "="*80)
    print("--- Phase 2: Loading and Evaluating Best Model on Test Set ---".center(80))
    print("="*80 + "\n")

    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    source_vocab_test = Vocabulary("latin_test")
    target_vocab_test = Vocabulary("devanagari_test")

    # Rebuild the vocabulary from the training data so token indices match those used
    # during training. If the rebuilt sizes differ, load_state_dict() below fails
    # with a shape mismatch instead of silently mis-decoding.
    temp_train_ds_test_vocab = TransliterationDataset(train_file_for_vocab, source_vocab_test, target_vocab_test,
                                                      max_len=BEST_HYPERPARAMETERS_FOR_TRAINING['max_seq_len'])
    if not temp_train_ds_test_vocab.pairs:
        print(f"ERROR: No data loaded for vocabulary building from {train_file_for_vocab}.")
        sys.exit(1)
    for src, tgt in temp_train_ds_test_vocab.pairs:
        source_vocab_test.add_sequence(src)
        target_vocab_test.add_sequence(tgt)
    source_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS_FOR_TRAINING['vocab_min_freq'])
    target_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS_FOR_TRAINING['vocab_min_freq'])
    print(f"Test Vocab built from training data. Source: {source_vocab_test.n_chars} chars. Target: {target_vocab_test.n_chars} chars.")

    # Create the test dataset
    test_dataset = TransliterationDataset(test_file, source_vocab_test, target_vocab_test,
                                          max_len=BEST_HYPERPARAMETERS_FOR_TRAINING['max_seq_len'])
    if not test_dataset.pairs:
        print(f"ERROR: Test dataset is empty from {test_file}.")
        sys.exit(1)

    # Setup test DataLoader
    num_w_test = 0 if DEVICE.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
    test_dataloader = DataLoader(test_dataset, batch_size=BEST_HYPERPARAMETERS_FOR_TRAINING['eval_batch_size'], shuffle=False,
                                 collate_fn=lambda b: collate_fn(b, source_vocab_test.pad_idx, target_vocab_test.pad_idx),
                                 num_workers=num_w_test, pin_memory=DEVICE.type == 'cuda')
    if len(test_dataloader) == 0:
        print(f"ERROR: Test Dataloader is empty. Dataset size: {len(test_dataset)}, Batch size: {BEST_HYPERPARAMETERS_FOR_TRAINING['eval_batch_size']}")
        sys.exit(1)

    # --- Initialize Model for Testing and Load Weights ---
    # The architecture must mirror the one built in train_and_save_best_model for the
    # saved state_dict to load.
    encoder_test = Encoder(source_vocab_test.n_chars, BEST_HYPERPARAMETERS_FOR_TRAINING['embedding_dim'], BEST_HYPERPARAMETERS_FOR_TRAINING['hidden_dim'],
                           BEST_HYPERPARAMETERS_FOR_TRAINING['encoder_layers'], BEST_HYPERPARAMETERS_FOR_TRAINING['cell_type'], BEST_HYPERPARAMETERS_FOR_TRAINING['dropout_p'],
                           BEST_HYPERPARAMETERS_FOR_TRAINING['encoder_bidirectional'], pad_idx=source_vocab_test.pad_idx).to(DEVICE)
    decoder_test = Decoder(target_vocab_test.n_chars, BEST_HYPERPARAMETERS_FOR_TRAINING['embedding_dim'], BEST_HYPERPARAMETERS_FOR_TRAINING['hidden_dim'],
                           BEST_HYPERPARAMETERS_FOR_TRAINING['decoder_layers'], BEST_HYPERPARAMETERS_FOR_TRAINING['cell_type'], BEST_HYPERPARAMETERS_FOR_TRAINING['dropout_p'],
                           pad_idx=target_vocab_test.pad_idx).to(DEVICE)
    model_test = Seq2Seq(encoder_test, decoder_test, DEVICE, target_vocab_test.sos_idx).to(DEVICE)

    print(f"Loading weights into test model from: {MODEL_SAVE_PATH}")
    model_test.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    print("Weights loaded successfully for test evaluation.")

    # --- Evaluate on Test Set ---
    criterion_test = nn.CrossEntropyLoss(ignore_index=target_vocab_test.pad_idx)
    test_loss, test_accuracy = _evaluate_one_epoch(model_test, test_dataloader, criterion_test, DEVICE,
                                                   source_vocab_test, target_vocab_test,
                                                   beam_width=BEST_HYPERPARAMETERS_FOR_TRAINING.get('beam_width_eval', 1),
                                                   is_test_set=True)

    print("\n--- FINAL TEST SET PERFORMANCE ---")
    print(f"Model Checkpoint: '{MODEL_SAVE_PATH}'")
    # Print only relevant hyperparameters for the loaded model, not training-specific ones
    eval_config_to_print = {k: v for k, v in BEST_HYPERPARAMETERS_FOR_TRAINING.items() if not k.endswith('_train')}
    print(f"Hyperparameters: {eval_config_to_print}")
    print(f"  Test Loss (avg per batch): {test_loss:.4f}")
    print(f"  Test Accuracy (Exact Match): {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")

    # --- Log Test Results to W&B ---
    # Best effort: results are already printed above, so a logging failure is only reported.
    try:
        run_name_final_test = f"TEST_BEST_MODEL_{BEST_HYPERPARAMETERS_FOR_TRAINING['cell_type']}"
        with wandb.init(project="DL_A3", name=run_name_final_test, config=BEST_HYPERPARAMETERS_FOR_TRAINING, job_type="final_test_evaluation", reinit=True) as test_run:
            test_run.summary["final_test_accuracy"] = test_accuracy
            test_run.summary["final_test_loss"] = test_loss
            test_run.summary["model_checkpoint_used"] = MODEL_SAVE_PATH
            print("Final test results logged to W&B.")
    except Exception as e:
        print(f"Could not log final test results to W&B: {e}")

# testing

Question 5: Testing the attention-based model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq
from sklearn.metrics import confusion_matrix

# --- Constants for Special Tokens ---
# Reserved markers: start of sequence, end of sequence, padding, and unknown character.
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN = "<sos>", "<eos>", "<pad>", "<unk>"

# --- Attention Mechanism ---

class Attention(nn.Module):
    """Additive (concat) attention between a decoder query and encoder outputs.

    Scores are ``v^T tanh(W [query; encoder_output])`` for every source position,
    normalised with a softmax over the source axis.
    """
    def __init__(self, encoder_hidden_dim_eff, decoder_hidden_dim):
        """
        Args:
            encoder_hidden_dim_eff: width of each encoder output vector (hidden size
                times the number of encoder directions).
            decoder_hidden_dim: width of the decoder query vector.
        """
        super(Attention, self).__init__()
        self.attn_W = nn.Linear(encoder_hidden_dim_eff + decoder_hidden_dim, decoder_hidden_dim)
        self.attn_v = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden_top_layer, encoder_outputs, mask=None):
        """Return attention weights over source positions.

        Args:
            decoder_hidden_top_layer: query of shape ``(batch, decoder_hidden_dim)``.
            encoder_outputs: ``(batch, src_len, encoder_hidden_dim_eff)``.
            mask: optional ``(batch, src_len)`` tensor, truthy at real (non-padding)
                positions. Masked positions receive zero weight. When omitted, every
                position is attended to, including zero-padded ones.

        Returns:
            ``(batch, src_len)`` weights that sum to 1 along the source axis.
        """
        src_len = encoder_outputs.size(1)
        # expand() broadcasts the query over source positions without copying (repeat() copies).
        repeated_decoder_hidden = decoder_hidden_top_layer.unsqueeze(1).expand(-1, src_len, -1)
        energy_input = torch.cat((repeated_decoder_hidden, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn_W(energy_input))
        attention_scores = self.attn_v(energy).squeeze(2)
        if mask is not None:
            # Use the most negative finite value rather than -inf so a fully masked row
            # yields a uniform distribution instead of NaNs.
            attention_scores = attention_scores.masked_fill(~mask.bool(), torch.finfo(attention_scores.dtype).min)
        return F.softmax(attention_scores, dim=1)

# --- Model Definition (Encoder, DecoderWithAttention, Seq2SeqWithAttention) ---

class Encoder(nn.Module):
    """Recurrent encoder over embedded source characters.

    ``forward`` returns the output at every time step (consumed by attention)
    together with the RNN's final state, which is an ``(h, c)`` tuple for LSTMs.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super(Encoder, self).__init__()
        # Configuration is exposed as attributes; the seq2seq wrapper reads several of them.
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        rnn_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in rnn_classes:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout is only meaningful (and only warning-free) with stacked layers.
        inter_layer_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = rnn_classes[self.cell_type](embedding_dim, hidden_dim, n_layers,
                                               dropout=inter_layer_dropout, batch_first=True,
                                               bidirectional=self.bidirectional)

    def forward(self, input_seq, input_lengths):
        """Encode a padded batch.

        Args:
            input_seq: padded token indices, batch-first.
            input_lengths: true length of each sequence; moved to CPU as packing requires.

        Returns:
            ``(encoder_outputs, hidden)``. Outputs are re-padded with zeros up to the
            longest length in the batch.
        """
        embedded = self.dropout(self.embedding(input_seq))
        # Packing makes the RNN stop at each sequence's real end, so the final state
        # is not polluted by padding.
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths.cpu(),
                                                   batch_first=True, enforce_sorted=False)
        packed_outputs, hidden = self.rnn(packed)
        encoder_outputs = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)[0]
        return encoder_outputs, hidden

class DecoderWithAttention(nn.Module):
    """Attention decoder that produces one target token per call.

    Each step embeds the previous token and attends over the encoder outputs, using
    the top layer of the previous hidden state as the query. It then feeds
    ``[embedding; context]`` through the recurrent cell and projects the step
    output to vocabulary logits.
    """
    def __init__(self, output_vocab_size, embedding_dim, encoder_hidden_dim_eff,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super(DecoderWithAttention, self).__init__()
        # Configuration is exposed as attributes; the seq2seq wrapper reads several of them.
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.encoder_hidden_dim_eff = encoder_hidden_dim_eff
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        # Sub-modules are created in a fixed order, so seeded initialisation stays reproducible.
        self.attention = Attention(encoder_hidden_dim_eff, decoder_hidden_dim)
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        rnn_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in rnn_classes:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Every step's RNN input is the token embedding concatenated with the attention context.
        step_input_dim = embedding_dim + encoder_hidden_dim_eff
        inter_layer_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = rnn_classes[self.cell_type](step_input_dim, decoder_hidden_dim, n_layers,
                                               dropout=inter_layer_dropout, batch_first=True)

        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input_char: previous token index per batch element, ``(batch,)``.
            prev_decoder_hidden: previous state; an ``(h, c)`` tuple for LSTM.
            encoder_outputs: ``(batch, src_len, encoder_hidden_dim_eff)``.

        Returns:
            ``(logits, new_hidden, attention_weights)`` with logits ``(batch, vocab)``
            and weights ``(batch, src_len)``.
        """
        token_emb = self.dropout(self.embedding(input_char.unsqueeze(1)))

        # For an LSTM the state is (h, c); attention queries the top layer of h.
        h_prev = prev_decoder_hidden[0] if self.cell_type == 'LSTM' else prev_decoder_hidden
        attention_weights = self.attention(h_prev[-1], encoder_outputs)
        context = attention_weights.unsqueeze(1).bmm(encoder_outputs)

        step_out, current_decoder_hidden = self.rnn(torch.cat((token_emb, context), dim=2),
                                                    prev_decoder_hidden)
        prediction_logits = self.fc_out(step_out.squeeze(1))
        return prediction_logits, current_decoder_hidden, attention_weights


class Seq2SeqWithAttention(nn.Module):
    """Attention-based encoder-decoder model.

    The encoder must expose ``hidden_dim``, ``num_directions``, ``n_layers``,
    ``cell_type`` and ``bidirectional``. The decoder must expose
    ``output_vocab_size``, ``decoder_hidden_dim`` and ``n_layers``. Provides a
    teacher-forced training ``forward`` plus greedy and beam-search inference.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        encoder_effective_final_hidden_dim = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_hidden_dim = self.decoder.decoder_hidden_dim

        # Linear bridges are needed only when the (direction-concatenated) encoder state
        # width differs from the decoder hidden size. The cell-state bridge exists only
        # for LSTM encoders.
        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None

        if encoder_effective_final_hidden_dim != decoder_rnn_hidden_dim:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """Map the encoder's final state onto the decoder's initial state.

        1. Reshape the ``(n_layers * num_directions, batch, hidden)`` state to
           ``(n_layers, num_directions, batch, hidden)``. For a bidirectional encoder,
           concatenate the two directions of each layer.
        2. Project to the decoder width with the optional linear bridges.
        3. Copy the first ``min(encoder layers, decoder layers)`` layers into a
           zero-initialised decoder state. If the decoder is deeper, repeat the
           encoder's last layer.

        NOTE(review): index 0 is the first (input-side) layer in PyTorch's layout.
        With a deeper encoder than decoder, the *lower* encoder layers are therefore
        used. Changing this would alter already-trained checkpoints.

        Args:
            encoder_final_hidden_state: ``h``, or an ``(h, c)`` tuple for LSTM encoders.

        Returns:
            ``(h, c)`` for LSTM encoders, otherwise ``h``. Each has shape
            ``(decoder.n_layers, batch, decoder.decoder_hidden_dim)``.
        """
        is_lstm = self.encoder.cell_type == 'LSTM'
        if is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc = encoder_final_hidden_state
            c_from_enc = None

        batch_size = h_from_enc.size(1)

        h_processed = h_from_enc.view(self.encoder.n_layers, self.encoder.num_directions,
                                      batch_size, self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            h_processed = torch.cat([h_processed[:, 0, :, :], h_processed[:, 1, :, :]], dim=2)
        else:
            h_processed = h_processed.squeeze(1)

        c_processed = None
        if is_lstm and c_from_enc is not None:
            c_processed = c_from_enc.view(self.encoder.n_layers, self.encoder.num_directions,
                                          batch_size, self.encoder.hidden_dim)
            if self.encoder.bidirectional:
                c_processed = torch.cat([c_processed[:, 0, :, :], c_processed[:, 1, :, :]], dim=2)
            else:
                c_processed = c_processed.squeeze(1)

        if self.fc_adapt_hidden is not None:
            h_processed = self.fc_adapt_hidden(h_processed)
            if is_lstm and c_processed is not None and self.fc_adapt_cell is not None:
                c_processed = self.fc_adapt_cell(c_processed)

        final_h = torch.zeros(self.decoder.n_layers, batch_size, self.decoder.decoder_hidden_dim, device=self.device)
        final_c = torch.zeros(self.decoder.n_layers, batch_size, self.decoder.decoder_hidden_dim, device=self.device) if is_lstm else None

        num_layers_to_copy = min(self.encoder.n_layers, self.decoder.n_layers)

        final_h[:num_layers_to_copy, :, :] = h_processed[:num_layers_to_copy, :, :]
        if is_lstm and c_processed is not None:
            final_c[:num_layers_to_copy, :, :] = c_processed[:num_layers_to_copy, :, :]

        # Fill any extra decoder layers with a copy of the encoder's last layer.
        if self.decoder.n_layers > self.encoder.n_layers and self.encoder.n_layers > 0:
            last_h_layer_to_repeat = h_processed[self.encoder.n_layers-1, :, :]
            for i in range(self.encoder.n_layers, self.decoder.n_layers):
                final_h[i, :, :] = last_h_layer_to_repeat
                if is_lstm and c_processed is not None:
                    last_c_layer_to_repeat = c_processed[self.encoder.n_layers-1, :, :]
                    final_c[i, :, :] = last_c_layer_to_repeat

        return (final_h, final_c) if is_lstm else final_h

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """Decode ``target_seq`` with scheduled teacher forcing.

        Args:
            source_seq: padded source indices, batch-first.
            source_lengths: true source lengths.
            target_seq: padded target indices, batch-first. Column 0 is the first
                decoder input.
            teacher_forcing_ratio: per-step probability of feeding the ground-truth
                token instead of the model's prediction. One draw is shared by the
                whole batch.

        Returns:
            Logits ``(batch, target_len, vocab)``. Position 0 stays zero, so index
            ``t`` lines up with ``target_seq[:, t]``.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size
        outputs_logits = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
        decoder_input = target_seq[:, 0]

        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden, _ = \
                self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            outputs_logits[:, t+1] = decoder_output_logits

            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)

            decoder_input = target_seq[:, t+1] if teacher_force_this_step else top1_predicted_token
        return outputs_logits

    def predict_greedy(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """Greedily decode a single source sequence.

        Switches the model to eval mode. A 1-D ``source_seq`` is treated as a batch
        of one. The decoder is seeded with a single SOS token, so a batch of one is
        assumed.

        Returns:
            Predicted token indices. Excludes SOS and stops before (and excludes)
            ``target_eos_idx``.
        """
        self.eval()
        if source_seq.dim() == 1:
            source_seq = source_seq.unsqueeze(0)
            source_lengths = torch.tensor([source_lengths if isinstance(source_lengths, int) else len(source_lengths)], device=self.device)

        with torch.no_grad():
            encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
            decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
            predicted_indices = []
            for _ in range(max_output_len):
                decoder_output_logits, decoder_hidden, _ = \
                    self.decoder(decoder_input, decoder_hidden, encoder_outputs)
                top1_predicted_token = decoder_output_logits.argmax(1)
                predicted_idx = top1_predicted_token.item()
                if target_eos_idx is not None and predicted_idx == target_eos_idx:
                    break
                predicted_indices.append(predicted_idx)
                decoder_input = top1_predicted_token
        return predicted_indices

    def predict_beam_search(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None, beam_width=3):
        """Decode a single source sequence with beam search.

        Hypotheses are ranked by total log-probability divided by sequence length
        (SOS included). A hypothesis is complete once it emits ``target_eos_idx``.
        Hypotheses still alive when ``max_output_len`` steps are used up are also
        considered, including those whose final step produced EOS.

        Args:
            beam_width: number of hypotheses kept per step. The per-hypothesis
                expansion is capped at the vocabulary size, which ``torch.topk``
                requires.

        Returns:
            Token indices of the best hypothesis with the leading SOS removed. A
            trailing EOS, if present, is kept.

        Raises:
            ValueError: if the source batch size is not 1.
        """
        self.eval()
        if source_seq.dim() == 1:
            source_seq = source_seq.unsqueeze(0)
            source_lengths = torch.tensor([source_lengths if isinstance(source_lengths, int) else len(source_lengths)], device=self.device)

        if source_seq.shape[0] != 1:
            raise ValueError("Beam search predict function currently supports batch_size=1.")

        # torch.topk raises if k exceeds the vocabulary size.
        candidates_per_beam = min(beam_width, self.decoder.output_vocab_size)

        with torch.no_grad():
            encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden_init = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

            beams = [(0.0, [self.target_sos_idx], decoder_hidden_init)]
            completed_sequences = []
            # Stays True only if the loop ends because the step budget ran out.
            ran_out_of_steps = True

            for _ in range(max_output_len):
                new_beams = []
                all_current_beams_ended = True
                for log_prob_beam, seq_beam, hidden_beam in beams:
                    if not seq_beam or seq_beam[-1] == target_eos_idx:
                        completed_sequences.append((log_prob_beam / len(seq_beam) if len(seq_beam) > 0 else -float('inf'), seq_beam))
                        continue
                    all_current_beams_ended = False

                    decoder_input = torch.tensor([seq_beam[-1]], device=self.device)
                    decoder_output_logits, next_hidden_beam, _ = \
                        self.decoder(decoder_input, hidden_beam, encoder_outputs)

                    log_probs_next_token = F.log_softmax(decoder_output_logits, dim=1)
                    topk_log_probs, topk_indices = torch.topk(log_probs_next_token, candidates_per_beam, dim=1)

                    for k in range(candidates_per_beam):
                        next_token_idx = topk_indices[0, k].item()
                        token_log_prob = topk_log_probs[0, k].item()
                        new_seq = seq_beam + [next_token_idx]
                        new_log_prob = log_prob_beam + token_log_prob
                        new_beams.append((new_log_prob, new_seq, next_hidden_beam))

                if not new_beams or all_current_beams_ended:
                    ran_out_of_steps = False
                    break

                new_beams.sort(key=lambda x: x[0], reverse=True)
                beams = new_beams[:beam_width]

            # When the step budget runs out, the surviving beams were produced on the
            # last step and never went through the completion check at the top of
            # the loop. Keep them all, including those that just emitted EOS.
            # Otherwise (early break), beams ending in EOS are already recorded.
            for log_prob_beam, seq_beam, _ in beams:
                if ran_out_of_steps or not seq_beam or seq_beam[-1] != target_eos_idx:
                    completed_sequences.append((log_prob_beam / len(seq_beam) if len(seq_beam) > 0 else -float('inf'), seq_beam))

            if not completed_sequences:
                return [target_eos_idx] if target_eos_idx is not None else []
            completed_sequences.sort(key=lambda x: x[0], reverse=True)
            best_sequence_indices = completed_sequences[0][1]
            return best_sequence_indices[1:] if best_sequence_indices and best_sequence_indices[0] == self.target_sos_idx else best_sequence_indices

# --- Data Loading and Preprocessing ---
class Vocabulary:
    """Two-way mapping between characters and integer indices.

    Indices 0-3 are reserved for the PAD, SOS, EOS and UNK tokens, in that order.
    ``build_vocab`` appends regular characters by descending frequency, breaking
    ties by the character itself.
    """
    def __init__(self, name):
        self.name = name
        specials = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: position for position, token in enumerate(specials)}
        self.index2char = dict(enumerate(specials))
        self.char_counts = Counter()
        self.n_chars = len(specials)
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count every character of ``sequence`` toward the vocabulary statistics."""
        # Wrapping in list() makes Counter count elements even for mapping inputs.
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """Assign indices to counted characters seen at least ``min_freq`` times."""
        ordered = sorted(self.char_counts, key=lambda ch: (-self.char_counts[ch], ch))
        for ch in ordered:
            if self.char_counts[ch] < min_freq or ch in self.char2index:
                continue
            self.char2index[ch] = self.n_chars
            self.index2char[self.n_chars] = ch
            self.n_chars += 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False):
        """Encode ``sequence``; unknown characters map to the UNK index."""
        unk = self.char2index[UNK_TOKEN]
        body = [self.char2index.get(ch, unk) for ch in sequence]
        head = [self.sos_idx] if add_sos else []
        tail = [self.eos_idx] if add_eos else []
        return head + body + tail

    def indices_to_sequence(self, indices):
        """Decode indices up to the first EOS, dropping SOS and PAD tokens."""
        skipped = (self.sos_idx, self.pad_idx)
        pieces = []
        for idx in indices:
            if idx == self.eos_idx:
                break
            if idx in skipped:
                continue
            pieces.append(self.index2char.get(idx, UNK_TOKEN))
        return "".join(pieces)

class TransliterationDataset(Dataset):
    """Loads transliteration pairs from a tab-separated lexicon file.

    Each line is read as ``<target word>\\t<source word>[\\t...]``: column 0 is the
    target (native script) and column 1 the source (romanized) word; any further
    columns are ignored. Pairs with an empty side, or with either side longer
    than ``max_len`` characters (when ``max_len`` is truthy), are skipped.
    Items are converted to index tensors lazily in ``__getitem__``, so the
    vocabularies may still be (re)built after construction.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        if not os.path.exists(file_path):
            # Not fatal here: callers detect this case through an empty ``self.pairs``.
            print(f"ERROR: Data file not found: {file_path}")
            return
        print(f"Loading data from: {file_path}")
        # 'utf-8-sig' behaves like 'utf-8' but drops a leading BOM, which would
        # otherwise be glued onto the first target word (and into the vocabulary).
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                # Remove only the line terminator: str.strip() would also eat a
                # leading tab and shift all columns left when the first field is empty.
                parts = line.rstrip('\r\n').split('\t')
                if len(parts) < 2:
                    continue
                target, source = parts[0].strip(), parts[1].strip()
                if not source or not target:
                    continue
                if max_len and (len(source) > max_len or len(target) > max_len):
                    continue
                self.pairs.append((source, target))
        print(f"Loaded {len(self.pairs)} pairs from {file_path}.")

    def __len__(self): return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(source_indices, target_indices)`` as ``torch.long`` tensors.

        The source gets a trailing EOS; the target is wrapped in SOS ... EOS.
        """
        source_str, target_str = self.pairs[idx]
        source_indices = self.source_vocab.sequence_to_indices(source_str, add_eos=True)
        target_indices = self.target_vocab.sequence_to_indices(target_str, add_sos=True, add_eos=True)
        return torch.tensor(source_indices, dtype=torch.long), torch.tensor(target_indices, dtype=torch.long)

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """Collate ``(source, target)`` index tensors into padded, batch-first tensors.

    Items that are ``None`` or whose source or target tensor is empty are dropped
    first; if nothing is left, ``(None, None, None)`` is returned so the
    training/evaluation loops can skip the batch.

    Returns:
        tuple: ``(padded_sources, source_lengths, padded_targets)`` where
        ``source_lengths`` holds the unpadded source lengths (``torch.long``).
    """
    kept = [pair for pair in batch if pair is not None and len(pair[0]) and len(pair[1])]
    if not kept:
        return None, None, None
    sources = [src for src, _ in kept]
    targets = [tgt for _, tgt in kept]
    lengths = torch.tensor(list(map(len, sources)), dtype=torch.long)
    return (
        pad_sequence(sources, batch_first=True, padding_value=pad_idx_source),
        lengths,
        pad_sequence(targets, batch_first=True, padding_value=pad_idx_target),
    )

# --- Training and Evaluation Functions ---
def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """Trains the model for one epoch."""
    model.train()
    epoch_loss = 0
    total_correct_train = 0
    total_samples_train = 0
    if len(dataloader) == 0: return 0.0, 0.0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        if batch_data[0] is None: continue
        sources, source_lengths, targets = batch_data
        if sources is None or sources.shape[0] == 0: continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()
        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)

        loss = criterion(flat_outputs, flat_targets)
        if not (torch.isnan(loss) or torch.isinf(loss)):
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            epoch_loss += loss.item()

        predictions_indices_train = outputs_logits.argmax(dim=2)
        for i in range(targets.shape[0]):
            pred_str_train = target_vocab.indices_to_sequence(predictions_indices_train[i, 1:].tolist())
            true_str_train = target_vocab.indices_to_sequence(targets[i, 1:].tolist())
            if pred_str_train == true_str_train: total_correct_train += 1
            total_samples_train += 1

    avg_epoch_loss = epoch_loss / len(dataloader) if len(dataloader) > 0 else 0.0
    train_accuracy = total_correct_train / total_samples_train if total_samples_train > 0 else 0.0
    return avg_epoch_loss, train_accuracy

def _evaluate_one_epoch(model, dataloader, criterion, device, source_vocab, target_vocab, beam_width=1, is_test_set=False):
    """Evaluates the model on a dataset, collecting predictions and character-level data."""
    model.eval()
    epoch_loss = 0
    total_correct, total_samples = 0, 0
    all_predictions_for_file_eval = []
    all_true_chars_flat_eval = []
    all_pred_chars_flat_eval = []

    if len(dataloader) == 0: return 0.0, 0.0, [], [], []
    desc_prefix = "Testing" if is_test_set else "Validating"

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=desc_prefix, leave=False)):
            if batch_data[0] is None: continue
            sources, source_lengths, targets = batch_data
            if sources is None or sources.shape[0] == 0: continue
            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

            outputs_for_loss = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs_for_loss.shape[-1]
            flat_outputs_for_loss = outputs_for_loss[:, 1:].reshape(-1, output_dim)
            flat_targets_for_loss = targets[:, 1:].reshape(-1)
            loss = criterion(flat_outputs_for_loss, flat_targets_for_loss)
            epoch_loss += loss.item() if not (torch.isnan(loss) or torch.isinf(loss)) else 0

            for i in range(sources.shape[0]):
                src_single_indices = sources[i].tolist()
                src_single_tensor, src_len_single = sources[i:i+1], source_lengths[i:i+1]
                original_source_str = source_vocab.indices_to_sequence(src_single_indices)

                if beam_width > 1 and hasattr(model, 'predict_beam_search'):
                    predicted_indices = model.predict_beam_search(src_single_tensor, src_len_single,
                                                                  max_output_len=targets.size(1) + 5,
                                                                  target_eos_idx=target_vocab.eos_idx,
                                                                  beam_width=beam_width)
                else:
                    predicted_indices = model.predict_greedy(src_single_tensor, src_len_single,
                                                             max_output_len=targets.size(1) + 5,
                                                             target_eos_idx=target_vocab.eos_idx)

                pred_str = target_vocab.indices_to_sequence(predicted_indices)
                true_str = target_vocab.indices_to_sequence(targets[i, 1:].tolist())

                if is_test_set:
                    all_predictions_for_file_eval.append((original_source_str, pred_str, true_str))
                    true_chars_cm = list(true_str)
                    pred_chars_cm = list(pred_str)
                    min_len_cm = min(len(true_chars_cm), len(pred_chars_cm))
                    all_true_chars_flat_eval.extend(true_chars_cm[:min_len_cm])
                    all_pred_chars_flat_eval.extend(pred_chars_cm[:min_len_cm])

                if pred_str == true_str: total_correct += 1
                total_samples += 1

                if is_test_set and batch_idx == 0 and i < 3:
                    print(f"  Test Example {i} - Source: '{original_source_str}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {pred_str == true_str}")

    avg_epoch_loss = epoch_loss / len(dataloader) if len(dataloader) > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    print(f"{desc_prefix} - Total Correct: {total_correct}, Total Samples: {total_samples}, Accuracy: {accuracy:.4f}, Avg Loss: {avg_epoch_loss:.4f}")

    if is_test_set:
        return avg_epoch_loss, accuracy, all_predictions_for_file_eval, all_true_chars_flat_eval, all_pred_chars_flat_eval
    else:
        return avg_epoch_loss, accuracy


# --- Function to Train and Save the Best Attention Model ---
def train_and_save_attention_model(config_params, model_save_path, device):
    """Train the attention seq2seq model and checkpoint its best epoch.

    Runs inside a W&B run (project ``DL_A3_Attention_Training``). After every
    epoch the model is validated with greedy decoding. Whenever validation
    exact-match accuracy improves, the model's ``state_dict`` is written to
    ``model_save_path``. Training stops early after
    ``max_epochs_no_improve_train`` epochs without improvement.

    Args:
        config_params (dict): hyperparameters. Keys read here: ``cell_type``,
            ``embedding_dim``, ``hidden_dim``, ``encoder_layers``,
            ``decoder_layers``, ``dropout_p``, ``encoder_bidirectional``,
            ``learning_rate_train``, ``batch_size_train``, ``epochs_train``,
            ``clip_value_train``, ``teacher_forcing_train``,
            ``max_epochs_no_improve_train``, ``lr_scheduler_patience_train``,
            ``vocab_min_freq``, ``max_seq_len``.
        model_save_path (str): checkpoint destination (its directory is created if needed).
        device (torch.device): device to train on.

    Returns:
        bool: False if the data could not be loaded or yields no batches, True
        once training has finished and a checkpoint from this run is on disk.
    """
    from functools import partial

    run_name_train_best_attn = f"TRAIN_BEST_ATTN_{config_params['cell_type']}_emb{config_params['embedding_dim']}_hid{config_params['hidden_dim']}"

    with wandb.init(project="DL_A3_Attention_Training", name=run_name_train_best_attn, config=config_params, job_type="training_best_attention_model", reinit=True):
        cfg = wandb.config
        print(f"Starting dedicated training for BEST ATTENTION model with config: {cfg}")

        BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")
        # The dataset keeps references to the vocabularies and only indexes items
        # in __getitem__, so one instance serves both for counting characters and
        # as the training set (no need to read the train file twice).
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs:
            return False
        for src, tgt in train_dataset.pairs:
            source_vocab.add_sequence(src)
            target_vocab.add_sequence(tgt)
        source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)

        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not dev_dataset.pairs:
            return False

        # functools.partial (unlike a lambda) is picklable, so DataLoader workers
        # also work under the "spawn" start method.
        batch_collate = partial(collate_fn, pad_idx_source=source_vocab.pad_idx, pad_idx_target=target_vocab.pad_idx)
        num_w = 0 if device.type == 'cpu' else min(4, os.cpu_count()//2 if os.cpu_count() else 0)
        pin = device.type == 'cuda'
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                                  collate_fn=batch_collate, num_workers=num_w, pin_memory=pin, drop_last=True)
        dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                                collate_fn=batch_collate, num_workers=num_w, pin_memory=pin)
        # With drop_last=True a training set smaller than one batch yields no batches at all.
        if len(train_loader) == 0 or len(dev_loader) == 0:
            return False

        # A bidirectional encoder concatenates both directions, doubling its output width.
        encoder_hidden_dim_eff = cfg.hidden_dim * (2 if cfg.encoder_bidirectional else 1)
        encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                          cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
        decoder = DecoderWithAttention(target_vocab.n_chars, cfg.embedding_dim, encoder_hidden_dim_eff,
                                       cfg.hidden_dim, cfg.decoder_layers, cfg.cell_type,
                                       cfg.dropout_p, pad_idx=target_vocab.pad_idx).to(device)
        model = Seq2SeqWithAttention(encoder, decoder, device, target_vocab.sos_idx).to(device)
        print(f"Best Attention Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
        # 'max' mode: the monitored metric is validation accuracy. The deprecated
        # ``verbose`` flag is not used; the current LR is printed and logged every epoch.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3)
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

        save_dir = os.path.dirname(model_save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        best_val_acc_this_training = -1.0
        epochs_no_improve = 0
        saved_this_run = False

        for epoch in range(cfg.epochs_train):
            train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                                     cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)
            val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device,
                                                    source_vocab, target_vocab, beam_width=1)

            scheduler.step(val_acc)
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | LR: {current_lr:.2e}")
            wandb.log({"epoch_train_best_attn": epoch + 1, "train_loss_best_attn": train_loss, "train_acc_best_attn": train_acc,
                       "val_loss_best_attn": val_loss, "val_acc_best_attn": val_acc, "lr_best_attn": current_lr})

            if val_acc > best_val_acc_this_training:
                best_val_acc_this_training = val_acc
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path)
                saved_this_run = True
                print(f"Saved new best attention model to {model_save_path} (Val Acc: {best_val_acc_this_training:.4f})")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= cfg.max_epochs_no_improve_train:
                print(f"Early stopping for best attention model training at epoch {epoch+1}.")
                break

        # Make sure the file on disk comes from THIS run; a stale checkpoint from an
        # earlier run (possibly a different architecture) must not be picked up.
        if not saved_this_run:
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved final attention model state to {model_save_path}")

        wandb.summary["final_best_val_accuracy_during_attn_training"] = best_val_acc_this_training
        print(f"Finished training best attention model. Best Val Acc: {best_val_acc_this_training:.4f}")
    return True


def display_sample_predictions_markdown(predictions_data, num_samples=15):
    """Print and return a Markdown table of the first ``num_samples`` predictions.

    Args:
        predictions_data: sequence of ``(source, predicted, true_target)`` triples.
        num_samples: maximum number of table rows.

    Returns:
        str: the Markdown table (header, separator, one row per sample), every
        line newline-terminated.
    """
    def _cell(text):
        # A literal "|" would otherwise be read as a column separator.
        return str(text).replace("|", "\\|")

    print(f"\n--- Displaying {min(num_samples, len(predictions_data))} Sample Predictions ---")
    rows = [
        "| Input (Latin) | True Output (Devanagari) | Model Prediction (Devanagari) | Correct? |",
        "|---|---|---|---|",
    ]
    for source_str, predicted_str, true_target_str in predictions_data[:num_samples]:
        is_correct = "✅ Yes" if predicted_str == true_target_str else "❌ No"
        rows.append(f"| {_cell(source_str)} | {_cell(true_target_str)} | {_cell(predicted_str)} | {is_correct} |")
    markdown_table = "\n".join(rows) + "\n"
    print(markdown_table)
    return markdown_table

def save_all_predictions_to_file(predictions_data, output_file_path):
    """Write ``source<TAB>prediction`` lines (UTF-8) for every prediction triple.

    Args:
        predictions_data: iterable of ``(source, predicted, true_target)`` triples;
            the true target is not written.
        output_file_path: destination file. Missing parent directories are
            created; a bare filename is written to the current directory.
    """
    print(f"\n--- Saving all predictions to file: {output_file_path} ---")
    output_dir = os.path.dirname(output_file_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when one is given.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    lines = [f"{source_str}\t{predicted_str}\n" for source_str, predicted_str, _ in predictions_data]
    with open(output_file_path, 'w', encoding='utf-8') as f_out:
        f_out.writelines(lines)
    print(f"Generated {len(lines)} predictions and saved to {output_file_path}")


# --- Helpers for the Q5 main block ---
def _confusion_matrix_inputs(true_chars, pred_chars, char2index, excluded_tokens):
    """Build position-aligned, index-encoded inputs for ``wandb.plot.confusion_matrix``.

    ``true_chars`` and ``pred_chars`` are parallel lists: element i of one pairs
    with element i of the other. A pair is kept only when BOTH characters are
    plottable, meaning present in ``char2index`` and not one of ``excluded_tokens``.
    Filtering therefore never shifts one list against the other.

    Returns:
        tuple: ``(class_names, y_true, preds)``. ``class_names`` lists the
        characters occurring in the kept pairs, ordered by vocabulary index.
        ``y_true`` and ``preds`` are integer positions into ``class_names``, which
        is the form wandb expects.
    """
    excluded = set(excluded_tokens)

    def plottable(char):
        return char in char2index and char not in excluded

    kept_pairs = [(t, p) for t, p in zip(true_chars, pred_chars) if plottable(t) and plottable(p)]
    class_names = sorted({c for pair in kept_pairs for c in pair}, key=lambda c: char2index[c])
    position = {c: i for i, c in enumerate(class_names)}
    y_true = [position[t] for t, _ in kept_pairs]
    preds = [position[p] for _, p in kept_pairs]
    return class_names, y_true, preds


# --- Main Execution Block for Q5 ---
if __name__ == '__main__':
    import functools
    import sys

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- W&B Login ---
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensure WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
            else:
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
            print("W&B login for Kaggle successful.")
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        print(f"W&B login failed: {e}. Ensure API key is configured.")
        sys.exit(1)

    # --- Best Attention Model Hyperparameters (UPDATE THIS SECTION) ---
    # These should come from a NEW W&B sweep specifically for the Attention Model.
    # The values below are placeholders and need to be replaced with your findings.
    BEST_ATTN_HYPERPARAMETERS = {
        'embedding_dim': 256,
        'hidden_dim': 512,
        'encoder_layers': 1,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.3,
        'encoder_bidirectional': True,
        'learning_rate_train': 0.001,
        'batch_size_train': 64,
        'epochs_train': 20,
        'clip_value_train': 1.0,
        'teacher_forcing_train': 0.5,
        'max_epochs_no_improve_train': 5,
        'lr_scheduler_patience_train': 2,
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        'eval_batch_size': 64,
        'beam_width_eval': 3
    }
    ATTN_MODEL_SAVE_PATH = "/kaggle/working/best_attention_model_q5.pt"
    PREDICTIONS_ATTN_OUTPUT_DIR = "/kaggle/working/predictions_attention"
    PREDICTIONS_ATTN_FILE_PATH = os.path.join(PREDICTIONS_ATTN_OUTPUT_DIR, "predictions_attention.tsv")

    print(f"Target hyperparameters for best attention model: {BEST_ATTN_HYPERPARAMETERS}")
    print("IMPORTANT: The above BEST_ATTN_HYPERPARAMETERS are placeholders if you haven't run an attention-specific sweep.")

    # --- 1. Train and Save the Best Attention Model ---
    print("\n--- Phase 1: Training and Saving Best Attention Model Configuration ---")
    training_successful = train_and_save_attention_model(BEST_ATTN_HYPERPARAMETERS, ATTN_MODEL_SAVE_PATH, DEVICE)

    if not training_successful or not os.path.exists(ATTN_MODEL_SAVE_PATH):
        print(f"ERROR: Failed to train/save the attention model, or checkpoint not found at {ATTN_MODEL_SAVE_PATH}. Exiting.")
        sys.exit(1)
    print(f"Best attention model trained and saved to {ATTN_MODEL_SAVE_PATH}")

    # --- 2. Load and Evaluate on Test Set ---
    print("\n--- Phase 2: Loading and Evaluating Best Attention Model on Test Set ---")

    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    # Rebuild the vocabularies from the same train file with the same max_len and
    # min_freq as during training; build_vocab orders characters deterministically,
    # so the indices match the ones baked into the checkpoint.
    source_vocab_test_attn = Vocabulary("latin_test_attn")
    target_vocab_test_attn = Vocabulary("devanagari_test_attn")

    temp_train_ds_test_vocab_attn = TransliterationDataset(train_file_for_vocab, source_vocab_test_attn, target_vocab_test_attn,
                                                max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not temp_train_ds_test_vocab_attn.pairs:
        sys.exit(1)
    for src, tgt in temp_train_ds_test_vocab_attn.pairs:
        source_vocab_test_attn.add_sequence(src)
        target_vocab_test_attn.add_sequence(tgt)
    source_vocab_test_attn.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])
    target_vocab_test_attn.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])

    test_dataset_attn = TransliterationDataset(test_file, source_vocab_test_attn, target_vocab_test_attn,
                                                  max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not test_dataset_attn.pairs:
        sys.exit(1)

    num_w_test_attn = 0 if DEVICE.type == 'cpu' else min(4, os.cpu_count()//2 if os.cpu_count() else 0)
    # partial (unlike a lambda) is picklable for DataLoader worker processes.
    test_collate_attn = functools.partial(collate_fn,
                                          pad_idx_source=source_vocab_test_attn.pad_idx,
                                          pad_idx_target=target_vocab_test_attn.pad_idx)
    test_dataloader_attn = DataLoader(test_dataset_attn, batch_size=BEST_ATTN_HYPERPARAMETERS['eval_batch_size'], shuffle=False,
                                 collate_fn=test_collate_attn,
                                 num_workers=num_w_test_attn, pin_memory=DEVICE.type == 'cuda')
    if len(test_dataloader_attn) == 0:
        sys.exit(1)

    encoder_hidden_dim_eff_test = BEST_ATTN_HYPERPARAMETERS['hidden_dim'] * (2 if BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'] else 1)
    encoder_test_attn = Encoder(source_vocab_test_attn.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'], BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                                BEST_ATTN_HYPERPARAMETERS['encoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                                BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'], pad_idx=source_vocab_test_attn.pad_idx).to(DEVICE)
    decoder_test_attn = DecoderWithAttention(target_vocab_test_attn.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'],
                                             encoder_hidden_dim_eff_test, BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                                             BEST_ATTN_HYPERPARAMETERS['decoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                                             pad_idx=target_vocab_test_attn.pad_idx).to(DEVICE)
    model_test_attn = Seq2SeqWithAttention(encoder_test_attn, decoder_test_attn, DEVICE, target_vocab_test_attn.sos_idx).to(DEVICE)

    print(f"Loading weights into attention model from: {ATTN_MODEL_SAVE_PATH}")
    model_test_attn.load_state_dict(torch.load(ATTN_MODEL_SAVE_PATH, map_location=DEVICE))
    print("Attention model weights loaded successfully for test evaluation.")

    criterion_test_attn = nn.CrossEntropyLoss(ignore_index=target_vocab_test_attn.pad_idx)

    # (5b) Evaluate Attention Model on Test Set
    test_loss_attn, test_accuracy_attn, all_test_predictions_data_attn, \
    all_true_chars_cm_attn, all_pred_chars_cm_attn = _evaluate_one_epoch(
                                                            model_test_attn, test_dataloader_attn, criterion_test_attn, DEVICE,
                                                            source_vocab_test_attn, target_vocab_test_attn,
                                                            beam_width=BEST_ATTN_HYPERPARAMETERS.get('beam_width_eval', 1),
                                                            is_test_set=True)

    print(f"\n--- (5b) FINAL TEST SET PERFORMANCE (ATTENTION MODEL) ---")
    print(f"Model Checkpoint: '{ATTN_MODEL_SAVE_PATH}'")
    print(f"Hyperparameters: {BEST_ATTN_HYPERPARAMETERS}")
    print(f"  Test Loss (avg per batch): {test_loss_attn:.4f}")
    print(f"  Test Accuracy (Exact Match): {test_accuracy_attn:.4f} ({test_accuracy_attn*100:.2f}%)")

    # (5b) Save all predictions to file
    os.makedirs(PREDICTIONS_ATTN_OUTPUT_DIR, exist_ok=True)
    save_all_predictions_to_file(all_test_predictions_data_attn, PREDICTIONS_ATTN_FILE_PATH)
    print(f"All attention model test predictions saved to: {PREDICTIONS_ATTN_FILE_PATH}")
    print(f"Please upload the file '{os.path.basename(PREDICTIONS_ATTN_FILE_PATH)}' to a folder named 'predictions_attention' in your GitHub project.")

    # --- Log Test Results and Artifacts to W&B ---
    try:
        run_name_final_test_attn = f"FINAL_TEST_ATTN_{BEST_ATTN_HYPERPARAMETERS['cell_type']}"
        with wandb.init(project="DL_A3_Attention_Test", name=run_name_final_test_attn, config=BEST_ATTN_HYPERPARAMETERS, job_type="final_attention_evaluation", reinit=True) as test_run_attn:
            test_run_attn.summary["final_test_accuracy_attention"] = test_accuracy_attn
            test_run_attn.summary["final_test_loss_attention"] = test_loss_attn
            test_run_attn.summary["model_checkpoint_used_attention"] = ATTN_MODEL_SAVE_PATH

            # Log sample predictions table (first 50 samples)
            wandb_table_cols_attn = ["Input (Latin)", "True Output (Devanagari)", "Model Prediction (Devanagari)", "Correct?"]
            wandb_table_data_attn = [[src, true_tgt, pred, "Yes" if pred == true_tgt else "No"]
                                     for src, pred, true_tgt in all_test_predictions_data_attn[:50]]
            test_run_attn.log({"test_sample_predictions_attention": wandb.Table(columns=wandb_table_cols_attn, data=wandb_table_data_attn)})

            # Log predictions file as an artifact
            if os.path.exists(PREDICTIONS_ATTN_FILE_PATH):
                predictions_artifact_attn = wandb.Artifact("test_predictions_attention_file", type="predictions")
                predictions_artifact_attn.add_file(PREDICTIONS_ATTN_FILE_PATH)
                test_run_attn.log_artifact(predictions_artifact_attn)

            # Log character-level confusion matrix. Pairs are filtered jointly so
            # true/pred stay aligned, and passed as class indices, which is what
            # wandb.plot.confusion_matrix expects alongside class_names.
            if all_true_chars_cm_attn and all_pred_chars_cm_attn:
                cm_class_names_attn, cm_y_true_attn, cm_preds_attn = _confusion_matrix_inputs(
                    all_true_chars_cm_attn, all_pred_chars_cm_attn,
                    target_vocab_test_attn.char2index,
                    (SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN))
                if cm_y_true_attn:
                    test_run_attn.log({"char_confusion_matrix_attention": wandb.plot.confusion_matrix(
                                                    preds=cm_preds_attn,
                                                    y_true=cm_y_true_attn,
                                                    class_names=cm_class_names_attn
                                                )})
                    print("Attention model character-level confusion matrix logged to W&B.")
            print("Attention model final test results and artifacts logged to W&B.")
    except Exception as e:
        print(f"Could not log attention model final test results to W&B: {e}")

Q 5c table

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq
import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend for Kaggle, crucial for saving plots
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# --- Constants for Special Tokens ---
# Reserved markers shared by both vocabularies: sequence start, sequence end,
# padding filler, and the placeholder for out-of-vocabulary characters.
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN = "<sos>", "<eos>", "<pad>", "<unk>"

# --- Attention Mechanism ---

class Attention(nn.Module):
    """
    Additive (Bahdanau-style, "concat") attention.

    Every encoder output h_j is scored against the decoder's current top-layer
    state s as ``v^T tanh(W [s; h_j])``, and the scores are normalised with a
    softmax over the source positions.
    """
    def __init__(self, encoder_hidden_dim_eff, decoder_hidden_dim):
        super(Attention, self).__init__()
        # W maps the concatenation [decoder state; encoder state] to the decoder width.
        self.attn_W = nn.Linear(encoder_hidden_dim_eff + decoder_hidden_dim, decoder_hidden_dim)
        # v reduces each projected vector to one unnormalised score.
        self.attn_v = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden_top_layer, encoder_outputs, mask=None):
        """
        Calculates attention weights.

        Args:
            decoder_hidden_top_layer (torch.Tensor): Top-layer decoder hidden state (batch_size, dec_hidden_dim).
            encoder_outputs (torch.Tensor): All encoder output states (batch_size, src_len, enc_hidden_dim).
            mask (torch.Tensor, optional): (batch_size, src_len), truthy for real source
                positions and falsy for padding. Masked positions receive (numerically)
                zero weight. ``src_len`` must match ``encoder_outputs.size(1)``. When
                omitted, every position is attended to (previous behaviour).

        Returns:
            torch.Tensor: Attention weights (batch_size, src_len); each row sums to 1.
        """
        src_len = encoder_outputs.size(1)
        # expand() gives a broadcast view instead of materialising src_len copies.
        repeated_decoder_hidden = decoder_hidden_top_layer.unsqueeze(1).expand(-1, src_len, -1)

        energy_input = torch.cat((repeated_decoder_hidden, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn_W(energy_input))

        attention_scores = self.attn_v(energy).squeeze(2)
        if mask is not None:
            # Use the most negative finite value rather than -inf, so a fully
            # masked row degrades to uniform weights instead of NaN.
            attention_scores = attention_scores.masked_fill(~mask.bool(), torch.finfo(attention_scores.dtype).min)
        return F.softmax(attention_scores, dim=1)

# --- Model Definition (Encoder, DecoderWithAttention, Seq2SeqWithAttention) ---

class Encoder(nn.Module):
    """
    Recurrent encoder for the attention seq2seq model.

    Embeds padded source indices, applies dropout, and runs a packed
    RNN/GRU/LSTM stack. It returns the last layer's per-timestep outputs (for the
    attention module) together with the final recurrent state of every layer.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super(Encoder, self).__init__()
        kind = cell_type.upper()

        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = kind
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # Submodules are registered in the same order as before (embedding, dropout,
        # rnn), which keeps state_dict keys and parameter initialisation order stable.
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(kind)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout only exists between stacked layers, so it is disabled for one layer.
        inter_layer_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = recurrent_cls(embedding_dim, hidden_dim, n_layers,
                                 dropout=inter_layer_dropout, batch_first=True,
                                 bidirectional=self.bidirectional)

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded batch.

        Args:
            input_seq (torch.Tensor): Padded source indices (batch_size, seq_len).
            input_lengths (torch.Tensor): Unpadded length of each sequence (batch_size,).

        Returns:
            tuple: (`encoder_outputs`, `hidden_state`).
                   `encoder_outputs`: last-layer outputs (batch_size, max(input_lengths), hidden_dim * num_directions),
                   zero beyond each sequence's length.
                   `hidden_state`: final hidden state of all layers (an (h, c) tuple for LSTM).
        """
        embedded = self.dropout(self.embedding(input_seq))
        # pack_padded_sequence requires the lengths on the CPU; enforce_sorted=False
        # lets the batch keep its original (unsorted) order.
        packed_input = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_outputs, hidden = self.rnn(packed_input)
        encoder_outputs = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)[0]
        return encoder_outputs, hidden

class DecoderWithAttention(nn.Module):
    """
    Attention decoder that produces one target token per call.

    At each step the top layer of the previous decoder hidden state is the
    attention query over the encoder outputs. The resulting context vector is
    concatenated with the embedded input token and fed through the RNN, whose
    output is projected to vocabulary logits.
    """
    def __init__(self, output_vocab_size, embedding_dim, encoder_hidden_dim_eff,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super().__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.encoder_hidden_dim_eff = encoder_hidden_dim_eff
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        # Sub-modules are created in a fixed order (attention, embedding,
        # dropout, rnn, fc_out) so seeded weight initialisation is reproducible.
        self.attention = Attention(encoder_hidden_dim_eff, decoder_hidden_dim)
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in recurrent_classes:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Each step consumes [token embedding ; attention context].
        # PyTorch only applies inter-layer dropout when there is more than one layer.
        self.rnn = recurrent_classes[self.cell_type](
            embedding_dim + encoder_hidden_dim_eff,
            decoder_hidden_dim,
            n_layers,
            dropout=dropout_p if n_layers > 1 else 0,
            batch_first=True,
        )

        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden, encoder_outputs):
        """
        Run one decoding step.

        Args:
            input_char (torch.Tensor): Current input token per batch item, shape (batch_size,).
            prev_decoder_hidden (torch.Tensor or tuple): Previous decoder RNN state ((h, c) for LSTM).
            encoder_outputs (torch.Tensor): Encoder states, (batch_size, src_len, enc_hidden_dim_eff).

        Returns:
            tuple: (`prediction_logits`, `current_decoder_hidden`, `attention_weights`).
                   `prediction_logits`: Unnormalised scores over the target vocabulary.
                   `current_decoder_hidden`: Decoder RNN state after this step.
                   `attention_weights`: Attention distribution over source positions (batch_size, src_len).
        """
        # (batch,) -> (batch, 1): the RNN sees a sequence of length one.
        token_embedding = self.dropout(self.embedding(input_char.unsqueeze(1)))

        # The query is the top layer of h. For an LSTM, h is the first element of (h, c).
        hidden_h = prev_decoder_hidden[0] if self.cell_type == 'LSTM' else prev_decoder_hidden
        query = hidden_h[-1, :, :]

        attention_weights = self.attention(query, encoder_outputs)
        # Weighted sum of encoder states -> (batch, 1, enc_hidden_dim_eff).
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)

        step_output, current_decoder_hidden = self.rnn(
            torch.cat((token_embedding, context), dim=2), prev_decoder_hidden
        )
        prediction_logits = self.fc_out(step_output.squeeze(1))
        return prediction_logits, current_decoder_hidden, attention_weights


class Seq2SeqWithAttention(nn.Module):
    """
    Sequence-to-sequence model joining an encoder and an attention-based decoder.

    The encoder's final hidden state (and cell state for LSTM) is reshaped so
    that bidirectional halves are concatenated. It is then optionally projected
    to the decoder hidden size and copied/repeated to the decoder's layer count
    before decoding starts.

    Args:
        encoder: Encoder exposing ``hidden_dim``, ``num_directions``, ``n_layers``,
            ``cell_type`` and ``bidirectional``.
        decoder: Decoder exposing ``decoder_hidden_dim``, ``n_layers`` and
            ``output_vocab_size``.
        device (torch.device): Device on which new tensors are allocated.
        target_sos_idx (int): Start-of-sequence index that seeds greedy decoding.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # Projection layers are needed only when the direction-concatenated
        # encoder state width differs from the decoder hidden size. Their
        # attribute names are part of the saved state_dict layout.
        encoder_effective_final_hidden_dim = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_hidden_dim = self.decoder.decoder_hidden_dim

        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None
        if encoder_effective_final_hidden_dim != decoder_rnn_hidden_dim:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)

    def _merge_directions(self, state, batch_size):
        """Reshape (n_layers * num_directions, B, H) into (n_layers, B, H * num_directions)."""
        per_layer = state.view(self.encoder.n_layers, self.encoder.num_directions,
                               batch_size, self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            return torch.cat([per_layer[:, 0, :, :], per_layer[:, 1, :, :]], dim=2)
        return per_layer.squeeze(1)

    def _fit_to_decoder_layers(self, state, batch_size):
        """Copy per-layer states into a zero (decoder.n_layers, B, H_dec) tensor, repeating the last layer if needed."""
        fitted = torch.zeros(self.decoder.n_layers, batch_size, self.decoder.decoder_hidden_dim, device=self.device)
        if state is None:
            return fitted
        n_copy = min(self.encoder.n_layers, self.decoder.n_layers)
        fitted[:n_copy, :, :] = state[:n_copy, :, :]
        # A deeper decoder reuses the encoder's top layer for its extra layers.
        if self.decoder.n_layers > self.encoder.n_layers and self.encoder.n_layers > 0:
            fitted[self.encoder.n_layers:, :, :] = state[self.encoder.n_layers - 1, :, :]
        return fitted

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Convert the encoder's final state into the decoder's initial state.

        Handles bidirectional encoders (directions are concatenated), width
        mismatch (via ``fc_adapt_hidden`` / ``fc_adapt_cell``) and differing
        layer counts (extra decoder layers copy the top encoder layer).

        Returns:
            torch.Tensor or tuple: ``h`` of shape (decoder.n_layers, B, decoder_hidden_dim),
            or ``(h, c)`` when the encoder is an LSTM.
        """
        is_lstm = self.encoder.cell_type == 'LSTM'
        if is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        batch_size = h_from_enc.size(1)

        h_processed = self._merge_directions(h_from_enc, batch_size)
        c_processed = self._merge_directions(c_from_enc, batch_size) if c_from_enc is not None else None

        if self.fc_adapt_hidden is not None:
            h_processed = self.fc_adapt_hidden(h_processed)
            if c_processed is not None and self.fc_adapt_cell is not None:
                c_processed = self.fc_adapt_cell(c_processed)

        final_h = self._fit_to_decoder_layers(h_processed, batch_size)
        if not is_lstm:
            return final_h
        return final_h, self._fit_to_decoder_layers(c_processed, batch_size)

    def _single_source_length(self, source_lengths, seq_len):
        """
        Normalise the length argument for one un-batched source sequence.

        Accepts an ``int``, a 0-d or 1-element tensor holding the length, ``None``
        (use ``seq_len``), or any sized object whose ``len()`` is the length
        (e.g. the token list itself). Returns a 1-element tensor on ``self.device``.

        A 1-element tensor is interpreted as a length value. Without that rule
        ``len(tensor([n]))`` would give 1 and only the first token would be encoded.
        """
        if source_lengths is None:
            length = seq_len
        elif isinstance(source_lengths, int):
            length = source_lengths
        elif torch.is_tensor(source_lengths) and source_lengths.numel() == 1:
            length = int(source_lengths.item())
        else:
            length = len(source_lengths)
        return torch.tensor([length], device=self.device)

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Forward pass for the Seq2SeqWithAttention model during training.

        Args:
            source_seq (torch.Tensor): Padded source sequences (batch_size, src_len).
            source_lengths (torch.Tensor): Lengths of source sequences (batch_size,).
            target_seq (torch.Tensor): Padded target sequences starting with SOS (batch_size, tgt_len).
            teacher_forcing_ratio (float): Per-step probability of feeding the true target token
                instead of the model's own argmax prediction.

        Returns:
            torch.Tensor: Logits of shape (batch_size, tgt_len, vocab_size). Position 0
            corresponds to the SOS input and is left as zeros.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size
        outputs_logits = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
        decoder_input = target_seq[:, 0]

        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden, _ = \
                self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            outputs_logits[:, t + 1] = decoder_output_logits

            # One coin flip per step decides teacher forcing for the whole batch.
            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)
            decoder_input = target_seq[:, t + 1] if teacher_force_this_step else top1_predicted_token
        return outputs_logits

    def predict_with_attention(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Greedily decode one source sequence and record the attention weights.

        Args:
            source_seq (torch.Tensor): Source indices, 1-D (src_len,) or 2-D with batch size 1.
            source_lengths: For a 2-D ``source_seq``, a tensor accepted by the encoder
                (e.g. ``tensor([n])``). For a 1-D ``source_seq``, it may also be an ``int``,
                a 0-d/1-element tensor, ``None`` (full length), or a sized object
                whose ``len()`` is the length.
            max_output_len (int): Maximum number of tokens to generate.
            target_eos_idx (int, optional): Decoding stops once this index is produced.

        Returns:
            tuple: ``(predicted_indices, attention_matrix)``. ``predicted_indices`` is a list
            of ints that includes EOS if it was produced. ``attention_matrix`` is a NumPy
            array with one row per generated token, or ``None`` if nothing was generated.

        The module's previous train/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()
        if source_seq.dim() == 1:
            source_seq = source_seq.unsqueeze(0)
            source_lengths = self._single_source_length(source_lengths, source_seq.size(1))

        predicted_indices = []
        attention_matrices = []
        try:
            with torch.no_grad():
                encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
                decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
                decoder_input = torch.tensor([self.target_sos_idx], device=self.device)

                for _ in range(max_output_len):
                    decoder_output_logits, decoder_hidden, attention_weights = \
                        self.decoder(decoder_input, decoder_hidden, encoder_outputs)
                    attention_matrices.append(attention_weights.squeeze(0).cpu().numpy())

                    top1_predicted_token = decoder_output_logits.argmax(1)
                    predicted_idx = top1_predicted_token.item()
                    predicted_indices.append(predicted_idx)
                    if target_eos_idx is not None and predicted_idx == target_eos_idx:
                        break
                    decoder_input = top1_predicted_token
        finally:
            self.train(was_training)

        return predicted_indices, (np.array(attention_matrices) if attention_matrices else None)

# --- Data Loading and Preprocessing ---

class Vocabulary:
    """
    Character-level vocabulary for one script.

    Indices 0-3 are reserved, in this order, for PAD_TOKEN, SOS_TOKEN,
    EOS_TOKEN and UNK_TOKEN. Ordinary characters are counted with
    ``add_sequence`` and receive indices when ``build_vocab`` is called.
    """
    def __init__(self, name):
        self.name = name
        reserved = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: position for position, token in enumerate(reserved)}
        self.index2char = {position: token for position, token in enumerate(reserved)}
        self.char_counts = Counter()
        self.n_chars = len(reserved)
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count every character of ``sequence``. No indices are assigned here."""
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """
        Assign indices to counted characters seen at least ``min_freq`` times.

        Characters are visited by descending frequency, with ties broken by the
        character itself. Characters below the threshold get no index and later
        map to UNK_TOKEN. Characters already indexed are left untouched.
        """
        ranked = sorted(self.char_counts, key=lambda ch: (-self.char_counts[ch], ch))
        for ch in ranked:
            if self.char_counts[ch] < min_freq or ch in self.char2index:
                continue
            self.char2index[ch] = self.n_chars
            self.index2char[self.n_chars] = ch
            self.n_chars += 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False, max_len=None):
        """
        Map a character sequence to indices, optionally wrapped in SOS/EOS.

        When ``max_len`` is truthy it caps the total length, including the
        optional SOS/EOS markers. Characters beyond that budget are dropped.
        Unknown characters map to the UNK_TOKEN index.
        """
        chars = list(sequence)
        if max_len:
            budget = max_len - (1 if add_sos else 0) - (1 if add_eos else 0)
            chars = chars[:max(0, budget)]  # never a negative slice bound

        unk_idx = self.char2index[UNK_TOKEN]
        indices = [self.sos_idx] if add_sos else []
        indices.extend(self.char2index.get(ch, unk_idx) for ch in chars)
        if add_eos:
            indices.append(self.eos_idx)
        return indices

    def indices_to_sequence(self, indices, strip_special=True):
        """
        Map indices back to a string.

        With ``strip_special`` set, decoding stops at the first EOS, and SOS/PAD
        indices are dropped. Unknown indices render as UNK_TOKEN.
        """
        skipped = (self.sos_idx, self.pad_idx)
        pieces = []
        for idx in indices:
            if strip_special:
                if idx == self.eos_idx:
                    break
                if idx in skipped:
                    continue
            pieces.append(self.index2char.get(idx, UNK_TOKEN))
        return "".join(pieces)

class TransliterationDataset(Dataset):
    """
    Dataset of (source, target) transliteration string pairs read from a TSV file.

    Each line is split on tabs as ``target<TAB>source[<TAB>...]``. A pair is kept
    only if both sides are non-empty and, when ``max_len`` is truthy, neither
    side is longer than ``max_len`` characters. A missing file or a read error
    prints an error instead of raising, which leaves ``pairs`` empty or partial.
    Strings are converted to index tensors lazily in ``__getitem__``.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.max_len = max_len

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    fields = raw_line.strip().split('\t')
                    if len(fields) < 2:
                        continue
                    target, source = fields[0], fields[1]
                    too_long = bool(self.max_len) and (len(source) > self.max_len or len(target) > self.max_len)
                    if source and target and not too_long:
                        self.pairs.append((source, target))
            print(f"Loaded {len(self.pairs)} pairs from {file_path}.")
        except Exception as e:
            print(f"ERROR: Could not read or process file {file_path}. Error: {e}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return (source_indices, target_indices) as long tensors; the target gets SOS and EOS."""
        src_text, tgt_text = self.pairs[idx]
        src_ids = self.source_vocab.sequence_to_indices(src_text, add_eos=True, max_len=self.max_len)
        tgt_ids = self.target_vocab.sequence_to_indices(tgt_text, add_sos=True, add_eos=True, max_len=self.max_len)
        src_tensor = torch.tensor(src_ids, dtype=torch.long)
        tgt_tensor = torch.tensor(tgt_ids, dtype=torch.long)
        return src_tensor, tgt_tensor

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Merge (source, target) tensor pairs into padded, batch-first tensors.

    Pairs that are ``None`` or have an empty side are dropped first. Each side
    is padded to the longest sequence in the remaining batch.

    Returns:
        tuple: (padded_sources, source_lengths, padded_targets), or
        (None, None, None) when no usable pair remains.
    """
    kept = [pair for pair in batch
            if pair is not None and len(pair[0]) > 0 and len(pair[1]) > 0]
    if len(kept) == 0:
        return None, None, None

    sources, targets = map(list, zip(*kept))
    # Lengths are taken before padding so the encoder can pack sequences.
    lengths = torch.tensor([len(seq) for seq in sources], dtype=torch.long)
    return (
        pad_sequence(sources, batch_first=True, padding_value=pad_idx_source),
        lengths,
        pad_sequence(targets, batch_first=True, padding_value=pad_idx_target),
    )

# --- Training and Evaluation Functions ---

def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Train the model for one pass over ``dataloader``.

    Batches for which the collate function returned ``None`` are skipped. The
    parameter update is also skipped for any batch whose loss is NaN/inf. The
    reported loss is averaged over the batches that produced an update, so
    skipped batches do not dilute the average.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the training set.
        optimizer (optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss).
        device (torch.device): Device to run the model on.
        clip_value (float): Max gradient norm for clipping.
        teacher_forcing_ratio (float): Probability of using teacher forcing.
        target_vocab (Vocabulary): Unused; kept for signature compatibility.

    Returns:
        tuple: (average training loss, 0.0). Training accuracy is not computed here.
    """
    model.train()
    if len(dataloader) == 0:
        return 0.0, 0.0

    epoch_loss = 0.0
    updated_batches = 0
    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        if batch_data[0] is None:
            continue
        sources, source_lengths, targets = batch_data
        if sources is None or sources.shape[0] == 0:
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()
        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 is the SOS input and has no prediction, so it is excluded from the loss.
        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)

        loss = criterion(flat_outputs, flat_targets)
        if not torch.isfinite(loss):
            # Skip the update rather than propagate NaN/inf into the weights.
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        epoch_loss += loss.item()
        updated_batches += 1

    avg_loss = epoch_loss / updated_batches if updated_batches else 0.0
    return avg_loss, 0.0


def _evaluate_one_epoch(model, dataloader, criterion, device, source_vocab, target_vocab, beam_width=1, is_test_set=False):
    """
    Evaluate the model on a validation or test set.

    Loss comes from a forward pass with ``teacher_forcing_ratio=0.0``, averaged
    over batches whose loss is finite. Exact-match accuracy is measured per
    sample by decoding each source on its own:
    ``predict_with_attention`` (greedy) when ``beam_width == 1``,
    ``predict_beam_search`` when ``beam_width > 1`` and the model provides it,
    and otherwise the argmax of the batch forward pass.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the evaluation set.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the model on.
        source_vocab (Vocabulary): Source vocabulary (used for debug prints).
        target_vocab (Vocabulary): Target vocabulary.
        beam_width (int): 1 for greedy decoding, >1 for beam search.
        is_test_set (bool): If True, prints the first three examples of the first batch.

    Returns:
        tuple: (average loss, exact-match accuracy).
    """
    model.eval()
    if len(dataloader) == 0:
        return 0.0, 0.0

    epoch_loss = 0.0
    loss_batches = 0
    total_correct = 0
    total_samples = 0
    desc_prefix = "Testing" if is_test_set else "Validating"

    with torch.no_grad():
        # enumerate() supplies batch_idx, which limits the test-set debug prints
        # to the first batch. It was previously referenced without being defined.
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=desc_prefix, leave=False)):
            if batch_data[0] is None:
                continue
            sources, source_lengths, targets = batch_data
            if sources is None or sources.shape[0] == 0:
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
            outputs = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs.shape[-1]
            loss = criterion(outputs[:, 1:].reshape(-1, output_dim), targets[:, 1:].reshape(-1))
            if torch.isfinite(loss):
                epoch_loss += loss.item()
                loss_batches += 1

            # Allow some headroom beyond the padded target length when decoding.
            max_decode_len = targets.size(1) + 5
            for i in range(targets.shape[0]):
                src_single, src_len_single = sources[i:i+1], source_lengths[i:i+1]

                if beam_width == 1 and hasattr(model, 'predict_with_attention'):
                    predicted_indices, _ = model.predict_with_attention(
                        src_single, src_len_single,
                        max_output_len=max_decode_len,
                        target_eos_idx=target_vocab.eos_idx
                    )
                elif beam_width > 1 and hasattr(model, 'predict_beam_search'):
                    predicted_indices = model.predict_beam_search(
                        src_single, src_len_single,
                        max_output_len=max_decode_len,
                        target_eos_idx=target_vocab.eos_idx,
                        beam_width=beam_width
                    )
                else:
                    # Fallback: argmax of the (non-teacher-forced) batch outputs, skipping position 0.
                    predicted_indices = outputs[i:i+1].argmax(dim=2)[0, 1:].tolist()

                pred_str = target_vocab.indices_to_sequence(predicted_indices, strip_special=True)
                true_str = target_vocab.indices_to_sequence(targets[i, 1:].tolist(), strip_special=True)

                if pred_str == true_str:
                    total_correct += 1
                total_samples += 1

                if is_test_set and batch_idx == 0 and i < 3:
                    print(f"  Test Example {i} - Source: '{source_vocab.indices_to_sequence(src_single[0].tolist(), strip_special=True)}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {pred_str == true_str}")

    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    avg_loss = epoch_loss / loss_batches if loss_batches else 0.0
    return avg_loss, accuracy

# --- Function to Train and Save the Best Attention Model ---

def train_and_save_attention_model(config_params, model_save_path, device):
    """
    Train the attention seq2seq model with a fixed config and checkpoint the best epoch.

    The Hindi Dakshina train/dev lexicons are loaded from hard-coded Kaggle paths.
    Latin (source) and Devanagari (target) vocabularies are built from the
    training pairs. Training runs for up to ``epochs_train`` epochs with
    ReduceLROnPlateau on validation accuracy and accuracy-based early stopping.
    The whole run is logged to wandb.

    Args:
        config_params (dict): Hyperparameters. Keys read here: cell_type, embedding_dim,
            hidden_dim, max_seq_len, vocab_min_freq, batch_size_train,
            encoder_bidirectional, encoder_layers, decoder_layers, dropout_p,
            learning_rate_train, lr_scheduler_patience_train, epochs_train,
            clip_value_train, teacher_forcing_train, and optionally
            max_epochs_no_improve_train (default 7).
        model_save_path (str): Where the best model's state_dict is written. Only the
            model weights are saved, not the vocabularies.
        device (torch.device): Device to run the training on.

    Returns:
        bool: False if the train/dev data or a DataLoader is empty. True otherwise,
        in which case a state_dict has been written to ``model_save_path``.
    """
    # The run name is built from the raw dict; wandb.config only exists after wandb.init.
    run_name_train_best_attn = f"TRAIN_BEST_ATTN_{config_params['cell_type']}_emb{config_params['embedding_dim']}_hid{config_params['hidden_dim']}"

    # `run` is not used directly; the context manager ends the wandb run on exit,
    # including on the early `return False` paths below.
    with wandb.init(project="DL_A3_Attention_Training", name=run_name_train_best_attn, config=config_params, job_type="training_best_attention_model", reinit=True) as run:
        cfg = wandb.config
        print(f"Starting dedicated training for BEST ATTENTION model with config: {cfg}")

        # --- Data Loading and Vocabulary Building ---
        BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        # First pass over the training file only collects character counts;
        # the vocabularies contain just the special tokens at this point.
        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")
        temp_train_ds_vocab = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not temp_train_ds_vocab.pairs:
            print(f"ERROR: No data for vocab from {train_file}")
            return False
        for src, tgt in temp_train_ds_vocab.pairs:
            source_vocab.add_sequence(src)
            target_vocab.add_sequence(tgt)
        source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)

        # NOTE(review): TransliterationDataset only stores vocab references and converts
        # strings in __getitem__, so temp_train_ds_vocab (which shares these vocab
        # objects) could likely be reused instead of re-reading the training file.
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs or not dev_dataset.pairs:
            print(f"ERROR: Train or Dev dataset empty. Train: {len(train_dataset)}, Dev: {len(dev_dataset)}")
            return False

        # --- DataLoader Setup ---
        # No worker processes on CPU; otherwise at most 4 workers and at most half the CPU count.
        # NOTE(review): the lambda collate_fn cannot be pickled, so num_workers > 0 relies on a
        # fork-based multiprocessing start method; confirm on non-Linux platforms.
        num_w = 0 if device.type == 'cpu' else min(4, os.cpu_count()//2 if os.cpu_count() else 0)
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                                  collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                  num_workers=num_w, pin_memory=True if device.type == 'cuda' else False, drop_last=True)
        dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                                collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                num_workers=num_w, pin_memory=True if device.type == 'cuda' else False)
        # drop_last=True can leave the train loader empty when the dataset is smaller than one batch.
        if not train_loader or len(train_loader)==0 or not dev_loader or len(dev_loader)==0:
            print(f"ERROR: Train/Dev Dataloader empty. Train: {len(train_loader)}, Dev: {len(dev_loader)}")
            return False

        # --- Model, Optimizer, Loss Function Setup ---
        # The decoder attends over encoder outputs whose width doubles when the encoder is bidirectional.
        encoder_hidden_dim_eff = cfg.hidden_dim * (2 if cfg.encoder_bidirectional else 1)
        encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                          cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
        decoder = DecoderWithAttention(target_vocab.n_chars, cfg.embedding_dim, encoder_hidden_dim_eff,
                                       cfg.hidden_dim, cfg.decoder_layers, cfg.cell_type,
                                       cfg.dropout_p, pad_idx=target_vocab.pad_idx).to(device)
        model = Seq2SeqWithAttention(encoder, decoder, device, target_vocab.sos_idx).to(device)
        print(f"Best Attention Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
        # 'max' mode: the scheduler is stepped with validation accuracy (higher is better).
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases; confirm against the installed version.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3, verbose=True)
        # Padding positions in the targets do not contribute to the loss.
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

        # Starts below any achievable accuracy (>= 0.0), so the first completed epoch always saves.
        best_val_acc_this_training = -1.0
        epochs_no_improve = 0
        max_epochs_no_improve_val = cfg.get('max_epochs_no_improve_train', 7)

        for epoch in range(cfg.epochs_train):
            train_loss, _ = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                             cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)
            val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device,
                                                    source_vocab, target_vocab, beam_width=1, is_test_set=False)

            scheduler.step(val_acc)
            print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            wandb.log({"epoch_train_best_attn": epoch + 1, "train_loss_best_attn": train_loss,
                       "val_loss_best_attn": val_loss, "val_acc_best_attn": val_acc, "lr_best_attn": optimizer.param_groups[0]['lr']})

            # Checkpoint on strict improvement in validation accuracy; otherwise count toward early stopping.
            if val_acc > best_val_acc_this_training:
                best_val_acc_this_training = val_acc
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path)
                print(f"Saved new best attention model to {model_save_path} (Val Acc: {best_val_acc_this_training:.4f})")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= max_epochs_no_improve_val:
                print(f"Early stopping for best attention model training at epoch {epoch+1}.")
                break

        # Fallback save: only reached when no file exists at model_save_path, which in practice
        # means no epoch ran (e.g. epochs_train == 0), because every completed epoch's first
        # check beats the -1.0 starting value. A pre-existing file from an earlier run is not overwritten here.
        if not os.path.exists(model_save_path):
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved final attention model state to {model_save_path}")

        wandb.summary["final_best_val_accuracy_during_attn_training"] = best_val_acc_this_training
        print(f"Finished training best attention model. Best Val Acc during this training: {best_val_acc_this_training:.4f}")
    return True


def plot_attention_heatmap(source_chars, predicted_chars, attention_matrix, file_path="attention_heatmap.png", title="Attention Heatmap"):
    """
    Plots an attention heatmap and saves it to a file.

    The matrix and both label lists are trimmed to their common extent, so every plotted
    row and column carries exactly one label.

    Args:
        source_chars (list): List of characters from the source sequence (column labels).
        predicted_chars (list): List of characters from the predicted target sequence (row labels).
        attention_matrix (np.ndarray): 2D NumPy array of attention weights (rows=target, cols=source).
        file_path (str): Path to save the generated heatmap image.
        title (str): Title for the heatmap plot.

    Returns:
        None. Invalid or empty inputs print a warning and skip plotting; an error while
        saving the image is printed rather than raised.
    """
    if attention_matrix is None or not isinstance(attention_matrix, np.ndarray) or attention_matrix.ndim != 2 or attention_matrix.shape[0] == 0 or attention_matrix.shape[1] == 0:
        print(f"Warning: Invalid attention matrix for '{''.join(source_chars)}' -> '{''.join(predicted_chars)}'. Skipping plot.")
        return

    plot_pred_len = len(predicted_chars)
    plot_src_len = len(source_chars)

    if plot_pred_len == 0 or plot_src_len == 0:
        print(f"Warning: Empty source or predicted characters for heatmap. Source: {plot_src_len}, Pred: {plot_pred_len}. Skipping.")
        return

    # Trim matrix AND labels to their common extent. Both dimensions are >= 1 here
    # thanks to the checks above, so no zero-size case remains.
    n_rows = min(plot_pred_len, attention_matrix.shape[0])
    n_cols = min(plot_src_len, attention_matrix.shape[1])
    current_attention_matrix = attention_matrix[:n_rows, :n_cols]
    row_labels = list(predicted_chars[:n_rows])
    col_labels = list(source_chars[:n_cols])

    if (n_rows, n_cols) != (plot_pred_len, plot_src_len):
        print(f"Warning: Attention matrix shape ({attention_matrix.shape}) "
              f"could not be perfectly aligned with char lists (Pred:{plot_pred_len}, Src:{plot_src_len}). "
              f"Using shape {current_attention_matrix.shape} for plot. Input: '{''.join(source_chars)}'")

    fig, ax = plt.subplots(figsize=(max(6, n_cols * 0.7), max(4, n_rows * 0.7)))
    try:
        cax = ax.matshow(current_attention_matrix, cmap='viridis')
        fig.colorbar(cax)

        # One fixed tick per cell, so label i sits exactly on row/column i. Attaching
        # labels to whatever ticks the default locator produces can misalign them.
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(col_labels, rotation=90, fontfamily='sans-serif', fontsize=10)
        ax.set_yticks(range(n_rows))
        # matplotlib does not raise for a missing font (it warns and falls back), so an
        # ordered family list expresses the preference instead of a try/except.
        ax.set_yticklabels(row_labels, fontfamily=['Arial Unicode MS', 'sans-serif'], fontsize=10)

        ax.set_xlabel("Source (Latin)")
        ax.set_ylabel("Prediction (Devanagari)")
        ax.set_title(title, fontsize=12)
        fig.tight_layout()

        try:
            fig.savefig(file_path)
            print(f"Saved attention heatmap to {file_path}")
        except Exception as e:
            print(f"Error saving heatmap: {e}")
    finally:
        # Close even if drawing fails, so repeated calls do not accumulate open figures.
        plt.close(fig)

# --- Main Execution Block for Heatmap Generation ---
def build_heatmap_markdown_grid(file_paths, base_dir='/kaggle/working/', n_cols=3):
    """Lay out saved heatmap images as Markdown table lines, ``n_cols`` images per row.

    The first row of images is preceded by a header row ("Heatmap k") and a separator
    row. Every row, including a final partial one, is padded with blank cells up to
    ``n_cols`` columns.

    Args:
        file_paths (list[str]): Image paths in display order.
        base_dir (str): Directory the image links are made relative to.
        n_cols (int): Number of table columns.

    Returns:
        list[str]: Markdown lines; empty when ``file_paths`` is empty.
    """
    lines = []
    for start in range(0, len(file_paths), n_cols):
        indices = range(start, min(start + n_cols, len(file_paths)))
        pad = [" "] * (n_cols - len(indices))
        headers = [f"Heatmap {k+1}" for k in indices] + pad
        images = [f"![Attention {k+1}]({os.path.relpath(file_paths[k], base_dir)})" for k in indices] + pad
        if start == 0:
            lines.append(f"| {' | '.join(headers)} |")
            lines.append("|" + "---|" * n_cols)
        lines.append(f"| {' | '.join(images)} |")
    return lines


if __name__ == '__main__':
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- W&B Login ---
    # On failure, W&B is switched to "disabled" mode instead of rebinding `wandb = None`.
    # train_and_save_attention_model calls wandb.log / wandb.summary without a None check,
    # so a None module would crash training. Disabled mode turns W&B calls into no-ops.
    wandb_enabled = True
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensuring WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
            else:
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
            print("W&B login for Kaggle successful.")
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        print(f"W&B login failed: {e}. Ensure API key is configured.")
        os.environ["WANDB_MODE"] = "disabled"
        wandb_enabled = False

    # --- Best Attention Model Hyperparameters (UPDATE THIS SECTION) ---
    # These hyperparameters should come from your W&B sweep for the Attention Model.
    BEST_ATTN_HYPERPARAMETERS = {
        'embedding_dim': 256,
        'hidden_dim': 512,
        'encoder_layers': 1,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.3,
        'encoder_bidirectional': True,
        'learning_rate_train': 0.0008,
        'batch_size_train': 64,
        'epochs_train': 20,
        'clip_value_train': 1.0,
        'teacher_forcing_train': 0.5,
        'max_epochs_no_improve_train': 7,
        'lr_scheduler_patience_train': 3,
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        'eval_batch_size': 64,
        # beam_width_eval is not used here: predict_with_attention decodes greedily.
    }
    ATTN_MODEL_SAVE_PATH = "/kaggle/working/best_attention_model_q5.pt"
    HEATMAP_OUTPUT_DIR = "/kaggle/working/attention_heatmaps_q5d"

    print(f"Target hyperparameters for best attention model: {BEST_ATTN_HYPERPARAMETERS}")

    # --- Phase 1: Train and Save the Best Attention Model (if checkpoint doesn't exist) ---
    # Failures below use `raise SystemExit(1)` rather than exit(). exit() is an
    # interactive-shell helper that the Python docs say should not be used in programs.
    # SystemExit stops execution both in scripts and in notebook cells.
    if not os.path.exists(ATTN_MODEL_SAVE_PATH):
        print(f"\n--- Attention Model Checkpoint NOT FOUND at {ATTN_MODEL_SAVE_PATH} ---")
        print("--- Attempting to TRAIN AND SAVE the best attention model using BEST_ATTN_HYPERPARAMETERS ---")
        training_successful = train_and_save_attention_model(BEST_ATTN_HYPERPARAMETERS, ATTN_MODEL_SAVE_PATH, DEVICE)
        if not training_successful or not os.path.exists(ATTN_MODEL_SAVE_PATH):
            print("ERROR: Failed to train and save the best attention model. Exiting.")
            raise SystemExit(1)
        print(f"Best attention model trained and saved to {ATTN_MODEL_SAVE_PATH}")
    else:
        print(f"Found existing attention model checkpoint at {ATTN_MODEL_SAVE_PATH}. Will use this for heatmaps.")

    # --- Phase 2: Load Model and Prepare for Heatmap Generation ---
    print("\n--- Preparing for Heatmap Generation ---")

    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    # Vocabularies are rebuilt from the training split; they must match the ones used in
    # training for the checkpoint's embedding/output sizes to line up.
    source_vocab = Vocabulary("latin_attn_heatmap")
    target_vocab = Vocabulary("devanagari_attn_heatmap")

    temp_train_ds_attn_vocab = TransliterationDataset(train_file_for_vocab, source_vocab, target_vocab,
                                                max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not temp_train_ds_attn_vocab.pairs:
        print("Training data for vocabulary building is empty. Exiting.")
        raise SystemExit(1)
    for src, tgt in temp_train_ds_attn_vocab.pairs:
        source_vocab.add_sequence(src)
        target_vocab.add_sequence(tgt)
    source_vocab.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])
    target_vocab.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])

    test_dataset_for_heatmaps = TransliterationDataset(test_file, source_vocab, target_vocab,
                                                  max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not test_dataset_for_heatmaps.pairs:
        print("Test dataset for heatmaps is empty. Exiting.")
        raise SystemExit(1)

    encoder_hidden_dim_eff_test = BEST_ATTN_HYPERPARAMETERS['hidden_dim'] * (2 if BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'] else 1)
    encoder = Encoder(source_vocab.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'], BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                      BEST_ATTN_HYPERPARAMETERS['encoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                      BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'], pad_idx=source_vocab.pad_idx).to(DEVICE)
    decoder = DecoderWithAttention(target_vocab.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'],
                                   encoder_hidden_dim_eff_test, BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                                   BEST_ATTN_HYPERPARAMETERS['decoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                                   pad_idx=target_vocab.pad_idx).to(DEVICE)
    model = Seq2SeqWithAttention(encoder, decoder, DEVICE, target_vocab.sos_idx).to(DEVICE)

    print(f"Loading weights into attention model from: {ATTN_MODEL_SAVE_PATH}")
    model.load_state_dict(torch.load(ATTN_MODEL_SAVE_PATH, map_location=DEVICE))
    print("Attention model weights loaded successfully.")
    model.eval()  # Disable dropout for deterministic greedy decoding

    # --- Generate and Plot Attention Heatmaps ---
    os.makedirs(HEATMAP_OUTPUT_DIR, exist_ok=True)

    num_heatmap_samples = 10
    n_test_pairs = len(test_dataset_for_heatmaps.pairs)
    sample_indices_heatmap = random.sample(range(n_test_pairs), min(num_heatmap_samples, n_test_pairs))

    generated_heatmap_files = []

    print(f"\n--- Generating {len(sample_indices_heatmap)} Attention Heatmaps ---")
    for i, data_idx in enumerate(sample_indices_heatmap):
        source_str, true_target_str = test_dataset_for_heatmaps.pairs[data_idx]

        # Batch of one source sequence (with EOS appended) for predict_with_attention.
        source_indices_for_pred = source_vocab.sequence_to_indices(source_str, add_eos=True, max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
        source_tensor_hm = torch.tensor(source_indices_for_pred, dtype=torch.long).unsqueeze(0).to(DEVICE)
        source_length_hm = torch.tensor([len(source_indices_for_pred)], dtype=torch.long).to(DEVICE)

        predicted_indices_hm, attention_matrix = model.predict_with_attention(
            source_tensor_hm,
            source_length_hm,
            max_output_len=len(true_target_str) + 10,
            target_eos_idx=target_vocab.eos_idx
        )

        # Convert indices to characters for plotting
        source_chars_for_plot = list(source_str)  # Original source characters (no special tokens)
        predicted_chars_for_plot = list(target_vocab.indices_to_sequence(predicted_indices_hm, strip_special=True))

        if attention_matrix is not None and len(predicted_chars_for_plot) > 0 and len(source_chars_for_plot) > 0:
            # Slice attention_matrix to match the lengths of displayed characters
            valid_pred_len = min(len(predicted_chars_for_plot), attention_matrix.shape[0])
            valid_src_len = min(len(source_chars_for_plot), attention_matrix.shape[1])

            if valid_pred_len > 0 and valid_src_len > 0:
                plot_attn_matrix = attention_matrix[:valid_pred_len, :valid_src_len]

                heatmap_file_path = os.path.join(HEATMAP_OUTPUT_DIR, f"attn_heatmap_{i+1}_{source_str[:15].replace(' ','_').replace('/','')}.png")
                plot_attention_heatmap(source_chars_for_plot[:valid_src_len],
                                       predicted_chars_for_plot[:valid_pred_len],
                                       plot_attn_matrix,
                                       heatmap_file_path,
                                       title=f"Input: {source_str} -> Pred: {''.join(predicted_chars_for_plot)}")
                generated_heatmap_files.append(heatmap_file_path)
            else:
                print(f"Warning: Not enough data in attention matrix or char lists for '{source_str}'. Skipping heatmap.")
        else:
            print(f"Warning: Could not generate attention matrix or empty prediction/source for '{source_str}'. Skipping heatmap.")

    print(f"\n--- Markdown for Displaying Attention Heatmaps (Saved in '{HEATMAP_OUTPUT_DIR}') ---")
    if generated_heatmap_files:
        print("You can use the following Markdown in your report (adjust paths if needed):")
        print("\n".join(build_heatmap_markdown_grid(generated_heatmap_files, '/kaggle/working/')))
        print("\n(Note: Paths are relative to '/kaggle/working/'. Adjust if your report is viewed elsewhere.)")
    else:
        print("No heatmaps were generated successfully.")

    # --- Log Heatmaps to W&B ---
    if wandb_enabled:
        if wandb.run is None:
            wandb.init(project="DL_A3", name="Q5d_Attention_Heatmaps", config=BEST_ATTN_HYPERPARAMETERS, job_type="q5d_heatmap_generation", reinit=True)
        if generated_heatmap_files:
            for i, f_path in enumerate(generated_heatmap_files):
                if os.path.exists(f_path):
                    wandb.log({f"q5d_attention_heatmap_sample_{i+1}": wandb.Image(f_path, caption=f"Heatmap for sample {i+1}")})
            print("Attention heatmaps logged to W&B.")
        # Finish unconditionally so a run opened above is not left dangling when nothing was logged.
        wandb.finish()

    print("Script finished.")

# Q5(d): Attention heatmaps

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq
import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend for Kaggle, crucial for saving plots
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# --- Special-token strings shared by the source and target vocabularies ---
PAD_TOKEN = "<pad>"  # Fills the tail of shorter sequences in a padded batch
SOS_TOKEN = "<sos>"  # Marks the start of a sequence
EOS_TOKEN = "<eos>"  # Marks the end of a sequence
UNK_TOKEN = "<unk>"  # Stands in for characters that are not in the vocabulary

# --- Attention Mechanism ---

class Attention(nn.Module):
    """
    Additive ("concat" / Bahdanau-style) attention between the decoder's current hidden
    state and every encoder output state:

        score_j = v^T tanh(W [s ; h_j]),    weights = softmax_j(score_j)

    where s is the decoder's top-layer hidden state and h_j is the j-th encoder output.
    """
    def __init__(self, encoder_hidden_dim_eff, decoder_hidden_dim):
        """
        Args:
            encoder_hidden_dim_eff (int): Feature size of each encoder output
                (encoder hidden size times number of directions).
            decoder_hidden_dim (int): Hidden size of the decoder RNN.
        """
        super(Attention, self).__init__()
        self.attn_W = nn.Linear(encoder_hidden_dim_eff + decoder_hidden_dim, decoder_hidden_dim)
        self.attn_v = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden_top_layer, encoder_outputs, mask=None):
        """
        Calculates attention weights.

        Args:
            decoder_hidden_top_layer (torch.Tensor): The top layer's hidden state of the decoder RNN (batch_size, dec_hidden_dim).
            encoder_outputs (torch.Tensor): All hidden states from the encoder (batch_size, src_len, enc_hidden_dim).
            mask (torch.Tensor, optional): Boolean (batch_size, src_len) tensor where True marks
                real (non-padding) encoder positions; masked positions receive exactly zero
                weight. Every row needs at least one True, otherwise softmax yields NaN.
                Defaults to None, which attends over all src_len positions.

        Returns:
            torch.Tensor: Attention weights (batch_size, src_len); each row sums to 1.
        """
        src_len = encoder_outputs.size(1)
        # expand() is a broadcast view, so the query is not copied once per source step.
        query = decoder_hidden_top_layer.unsqueeze(1).expand(-1, src_len, -1)

        energy = torch.tanh(self.attn_W(torch.cat((query, encoder_outputs), dim=2)))
        attention_scores = self.attn_v(energy).squeeze(2)

        if mask is not None:
            # -inf scores become exactly 0 after softmax.
            attention_scores = attention_scores.masked_fill(~mask.bool(), float('-inf'))
        return F.softmax(attention_scores, dim=1)

# --- Model Definition (Encoder, DecoderWithAttention, Seq2SeqWithAttention) ---

class Encoder(nn.Module):
    """
    Recurrent encoder for the attention model.

    Embeds padded source token indices, runs them through a (optionally bidirectional,
    optionally stacked) RNN/GRU/LSTM as a packed sequence, and returns both the last
    layer's per-step outputs and the final state of every layer.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super(Encoder, self).__init__()
        # Configuration is kept on the instance; Seq2SeqWithAttention reads
        # hidden_dim, num_directions, n_layers, cell_type and bidirectional.
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 2 if self.bidirectional else 1

        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        rnn_factory = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if rnn_factory is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Recurrent dropout only acts between stacked layers, so it is zeroed for a single layer.
        self.rnn = rnn_factory(embedding_dim, hidden_dim, n_layers,
                               dropout=dropout_p if n_layers > 1 else 0,
                               batch_first=True, bidirectional=self.bidirectional)

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded batch of source sequences.

        Args:
            input_seq (torch.Tensor): Padded token indices (batch_size, seq_len).
            input_lengths (torch.Tensor): True length of each sequence (batch_size,).

        Returns:
            tuple: (`encoder_outputs`, `hidden_state`).
                   `encoder_outputs`: last-layer states for every step, shape
                   (batch_size, max(input_lengths), hidden_dim * num_directions); padded steps are 0.
                   `hidden_state`: final state of all layers (an `(h, c)` tuple for LSTM).
        """
        embedded = self.dropout(self.embedding(input_seq))
        # Packing lets the RNN skip padded steps; lengths must live on the CPU and
        # enforce_sorted=False accepts batches in any length order.
        packed_input = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, final_state = self.rnn(packed_input)
        step_outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        return step_outputs, final_state

class DecoderWithAttention(nn.Module):
    """
    Step-wise decoder with attention.

    At each call it embeds the current target token, attends over the encoder outputs
    using the previous top-layer hidden state as the query, feeds
    [token embedding ; context vector] into an RNN/GRU/LSTM, and projects the RNN output
    to vocabulary logits.
    """
    def __init__(self, output_vocab_size, embedding_dim, encoder_hidden_dim_eff,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super(DecoderWithAttention, self).__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.encoder_hidden_dim_eff = encoder_hidden_dim_eff
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        # Keep this submodule creation order: weight initialisation draws from the
        # global RNG in this order, so reordering would change seeded initial weights.
        self.attention = Attention(encoder_hidden_dim_eff, decoder_hidden_dim)
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        rnn_factory = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if rnn_factory is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Each step's RNN input is the token embedding concatenated with the context vector.
        self.rnn = rnn_factory(embedding_dim + encoder_hidden_dim_eff, decoder_hidden_dim, n_layers,
                               dropout=dropout_p if n_layers > 1 else 0, batch_first=True)

        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden, encoder_outputs):
        """
        Run one decoding step.

        Args:
            input_char (torch.Tensor): Current input token index per example (batch_size,).
            prev_decoder_hidden (torch.Tensor or tuple): Previous decoder RNN state
                (an `(h, c)` tuple for LSTM).
            encoder_outputs (torch.Tensor): Encoder states (batch_size, src_len, enc_hidden_dim_eff).

        Returns:
            tuple: (`prediction_logits`, `current_decoder_hidden`, `attention_weights`).
                   `prediction_logits`: unnormalised scores (batch_size, output_vocab_size).
                   `current_decoder_hidden`: updated decoder RNN state.
                   `attention_weights`: attention distribution over source steps (batch_size, src_len).
        """
        embedded = self.dropout(self.embedding(input_char.unsqueeze(1)))  # (batch, 1, emb)

        # The attention query is the top layer's hidden state (h, not c, for an LSTM).
        hidden_states = prev_decoder_hidden[0] if self.cell_type == 'LSTM' else prev_decoder_hidden
        attention_weights = self.attention(hidden_states[-1], encoder_outputs)

        # Weighted sum of encoder states: (batch, 1, src_len) @ (batch, src_len, enc) -> (batch, 1, enc)
        context_vector = attention_weights.unsqueeze(1).bmm(encoder_outputs)

        rnn_output, current_decoder_hidden = self.rnn(
            torch.cat([embedded, context_vector], dim=2), prev_decoder_hidden)
        prediction_logits = self.fc_out(rnn_output.squeeze(1))
        return prediction_logits, current_decoder_hidden, attention_weights


class Seq2SeqWithAttention(nn.Module):
    """
    Sequence-to-sequence model joining an Encoder with an attention-based decoder.

    The encoder's final state initialises the decoder (via
    `_adapt_encoder_hidden_for_decoder`), and the decoder attends over all encoder
    outputs at every step.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        """
        Args:
            encoder (nn.Module): Encoder; must expose `hidden_dim`, `num_directions`,
                `n_layers`, `cell_type` and `bidirectional`.
            decoder (nn.Module): Attention decoder; must expose `decoder_hidden_dim`,
                `n_layers` and `output_vocab_size`.
            device (torch.device): Device for tensors created inside the model.
            target_sos_idx (int): Target-vocabulary index of <sos>, used as the first
                decoder input during inference.
        """
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # Per layer, the encoder's final state is hidden_dim * num_directions wide once
        # directions are concatenated. Linear adapters map it to the decoder's hidden size,
        # and are only created when the sizes differ.
        encoder_effective_final_hidden_dim = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_hidden_dim = self.decoder.decoder_hidden_dim

        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None
        if encoder_effective_final_hidden_dim != decoder_rnn_hidden_dim:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Convert the encoder's final RNN state into an initial state for the decoder RNN.

        The conversion has three steps:
        (1) When the encoder is bidirectional, concatenate each layer's forward and
            backward states.
        (2) Project to the decoder hidden size if an adapter exists.
        (3) Match the decoder's layer count. Extra encoder layers are dropped; if the
            decoder is deeper, the last encoder layer's state is repeated.

        Args:
            encoder_final_hidden_state: `h` of shape (enc_layers * num_directions, batch, enc_hidden),
                or an `(h, c)` tuple for an LSTM encoder.

        Returns:
            `h` of shape (dec_layers, batch, dec_hidden), or an `(h, c)` tuple for LSTM.
        """
        enc, dec = self.encoder, self.decoder
        is_lstm = enc.cell_type == 'LSTM'
        if is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None
        batch_size = h_from_enc.size(1)

        def merge_directions(state):
            # Split the (layers * directions) axis, then place each layer's directions
            # side by side on the feature axis.
            state = state.view(enc.n_layers, enc.num_directions, batch_size, enc.hidden_dim)
            if enc.bidirectional:
                return torch.cat([state[:, 0], state[:, 1]], dim=2)
            return state.squeeze(1)

        h_processed = merge_directions(h_from_enc)
        c_processed = merge_directions(c_from_enc) if c_from_enc is not None else None

        if self.fc_adapt_hidden is not None:
            h_processed = self.fc_adapt_hidden(h_processed)
            if c_processed is not None and self.fc_adapt_cell is not None:
                c_processed = self.fc_adapt_cell(c_processed)

        state_shape = (dec.n_layers, batch_size, dec.decoder_hidden_dim)
        final_h = torch.zeros(state_shape, device=self.device)
        final_c = torch.zeros(state_shape, device=self.device) if is_lstm else None

        num_layers_to_copy = min(enc.n_layers, dec.n_layers)
        final_h[:num_layers_to_copy] = h_processed[:num_layers_to_copy]
        if c_processed is not None:
            final_c[:num_layers_to_copy] = c_processed[:num_layers_to_copy]

        # Deeper decoder: fill the remaining layers with the last encoder layer's state
        # (broadcast over the extra layers).
        if dec.n_layers > enc.n_layers > 0:
            final_h[enc.n_layers:] = h_processed[enc.n_layers - 1]
            if c_processed is not None:
                final_c[enc.n_layers:] = c_processed[enc.n_layers - 1]

        return (final_h, final_c) if is_lstm else final_h

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Forward pass for the Seq2SeqWithAttention model during training.

        Args:
            source_seq (torch.Tensor): Padded source sequences (batch_size, src_len).
            source_lengths (torch.Tensor): Lengths of source sequences (batch_size,).
            target_seq (torch.Tensor): Padded target sequences starting with SOS (batch_size, tgt_len).
            teacher_forcing_ratio (float): Per-step probability (shared across the batch) of
                feeding the ground-truth token instead of the model's prediction.

        Returns:
            torch.Tensor: Logits (batch_size, tgt_len, vocab_size); position 0 (the SOS slot)
            is left as zeros.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size
        outputs_logits = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
        decoder_input = target_seq[:, 0]

        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden, _ = \
                self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            outputs_logits[:, t + 1] = decoder_output_logits

            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)
            decoder_input = target_seq[:, t + 1] if teacher_force_this_step else top1_predicted_token
        return outputs_logits

    def predict_with_attention(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Greedily decode one source sequence, recording the attention weights of every step.

        Args:
            source_seq (torch.Tensor): Source token indices, shape (1, src_len), or an
                unbatched (src_len,) tensor.
            source_lengths (torch.Tensor or int): For batched input, a (1,) length tensor.
                For unbatched input, an int or a one-element tensor holding the length;
                any other value falls back to ``len(source_lengths)``.
            max_output_len (int): Maximum number of decoding steps.
            target_eos_idx (int, optional): Stop after this token is produced (it is kept in
                the output). If None, exactly ``max_output_len`` steps are run.

        Returns:
            tuple: ``(predicted_indices, attention)``. ``predicted_indices`` is a list of
            ints. ``attention`` is a NumPy array of shape (num_steps, encoder_len), or None
            when no step ran (``max_output_len <= 0``).

        The model's previous train/eval mode is restored before returning.
        """
        if source_seq.dim() == 1:
            source_seq = source_seq.unsqueeze(0)
            if isinstance(source_lengths, int):
                length = source_lengths
            elif torch.is_tensor(source_lengths) and source_lengths.numel() == 1:
                # len() of a one-element tensor is 1 (and it raises for a 0-d tensor),
                # which would truncate the source to its first token.
                length = int(source_lengths.item())
            else:
                length = len(source_lengths)
            source_lengths = torch.tensor([length], device=self.device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
                decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
                decoder_input = torch.tensor([self.target_sos_idx], device=self.device)

                predicted_indices = []
                attention_matrices = []
                for _ in range(max_output_len):
                    decoder_output_logits, decoder_hidden, attention_weights = \
                        self.decoder(decoder_input, decoder_hidden, encoder_outputs)
                    attention_matrices.append(attention_weights.squeeze(0).cpu().numpy())

                    decoder_input = decoder_output_logits.argmax(1)  # greedy choice feeds the next step
                    predicted_idx = decoder_input.item()
                    predicted_indices.append(predicted_idx)
                    if target_eos_idx is not None and predicted_idx == target_eos_idx:
                        break
        finally:
            # Restore the caller's mode; otherwise dropout would stay disabled after inference.
            self.train(was_training)
        return predicted_indices, (np.array(attention_matrices) if attention_matrices else None)

# --- Data Loading and Preprocessing ---

class Vocabulary:
    """
    Two-way lookup between characters and integer ids for one script.

    Ids 0-3 are reserved for PAD, SOS, EOS and UNK, in that order. Every other
    character receives the next free id when `build_vocab` is called.
    """

    def __init__(self, name):
        self.name = name
        reserved = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: idx for idx, token in enumerate(reserved)}
        self.index2char = {idx: token for idx, token in enumerate(reserved)}
        self.char_counts = Counter()
        self.n_chars = len(reserved)  # next id to hand out
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count each character of `sequence`; ids are only assigned later by `build_vocab`."""
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """
        Assign ids to every counted character seen at least `min_freq` times.

        Characters are numbered by descending count, with ties broken by the
        character itself, so the mapping is deterministic. Rarer characters get
        no id and are later encoded as UNK_TOKEN. Characters that already have
        an id keep it.
        """
        by_rank = sorted(self.char_counts, key=lambda ch: (-self.char_counts[ch], ch))
        for ch in by_rank:
            frequent_enough = self.char_counts[ch] >= min_freq
            if frequent_enough and ch not in self.char2index:
                new_id = self.n_chars
                self.char2index[ch] = new_id
                self.index2char[new_id] = ch
                self.n_chars = new_id + 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False, max_len=None):
        """
        Encode `sequence` as a list of ids, optionally wrapped in SOS ... EOS.

        A truthy `max_len` bounds the total length *including* the special tokens.
        The characters are cut to `max_len` minus the number of special tokens
        requested, never below zero. Unknown characters map to UNK_TOKEN's id.
        """
        chars = list(sequence)
        if max_len:
            n_special = (1 if add_sos else 0) + (1 if add_eos else 0)
            chars = chars[:max(0, max_len - n_special)]

        unk_idx = self.char2index[UNK_TOKEN]
        body = [self.char2index.get(ch, unk_idx) for ch in chars]
        head = [self.sos_idx] if add_sos else []
        tail = [self.eos_idx] if add_eos else []
        return head + body + tail

    def indices_to_sequence(self, indices, strip_special=True):
        """
        Decode ids back into a string.

        With `strip_special`, decoding stops at the first EOS id and SOS/PAD ids
        are dropped. Ids that have no entry decode as UNK_TOKEN.
        """
        pieces = []
        skipped_ids = (self.sos_idx, self.pad_idx)
        for idx in indices:
            if strip_special:
                if idx == self.eos_idx:
                    break
                if idx in skipped_ids:
                    continue
            pieces.append(self.index2char.get(idx, UNK_TOKEN))
        return "".join(pieces)

class TransliterationDataset(Dataset):
    """
    A PyTorch Dataset of (source, target) transliteration pairs read from a TSV file.

    Each line is expected to be ``target<TAB>source[<TAB>...]``; extra columns are ignored.
    Pairs with an empty side are dropped. So are pairs that would not fit into ``max_len``
    tokens once the special tokens added by ``__getitem__`` are counted.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.max_len = max_len

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.rstrip('\r\n').split('\t')
                    if len(parts) < 2:
                        continue
                    # Strip each field on its own. Stripping the whole line would remove
                    # a leading empty column and shift target/source by one position.
                    target, source = parts[0].strip(), parts[1].strip()
                    if not source or not target or self._exceeds_max_len(source, target):
                        continue
                    self.pairs.append((source, target))
            print(f"Loaded {len(self.pairs)} pairs from {file_path}.")
        except (OSError, UnicodeError) as e:
            # Best effort: report the problem and keep whatever was read so far.
            print(f"ERROR: Could not read or process file {file_path}. Error: {e}")

    def _exceeds_max_len(self, source, target):
        """Return True if __getitem__ would have to truncate this pair to respect max_len."""
        if not self.max_len:
            return False
        # __getitem__ appends EOS to the source (+1) and wraps the target in SOS/EOS (+2),
        # and sequence_to_indices counts those tokens against max_len.
        return len(source) + 1 > self.max_len or len(target) + 2 > self.max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the pair at `idx` as (source, target) LongTensors of vocabulary indices."""
        source_str, target_str = self.pairs[idx]
        source_indices = self.source_vocab.sequence_to_indices(source_str, add_eos=True, max_len=self.max_len)
        target_indices = self.target_vocab.sequence_to_indices(target_str, add_sos=True, add_eos=True, max_len=self.max_len)
        return torch.tensor(source_indices, dtype=torch.long), torch.tensor(target_indices, dtype=torch.long)

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Merge (source, target) pairs into padded batch tensors for a DataLoader.

    Pairs that are None or that have an empty side are dropped first.

    Returns:
        tuple: (padded sources, source lengths before padding, padded targets).
        All three are None when no usable pair remains. Padding is appended to
        each sequence up to the longest one in the batch.
    """
    usable = [
        pair for pair in batch
        if pair is not None and len(pair[0]) > 0 and len(pair[1]) > 0
    ]
    if not usable:
        return None, None, None

    src_list, tgt_list = map(list, zip(*usable))
    lengths = torch.tensor(list(map(len, src_list)), dtype=torch.long)
    return (
        pad_sequence(src_list, batch_first=True, padding_value=pad_idx_source),
        lengths,
        pad_sequence(tgt_list, batch_first=True, padding_value=pad_idx_target),
    )

# --- Training and Evaluation Functions ---

def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Train the model for a single epoch.

    Args:
        model (nn.Module): The Seq2Seq model; called as
            ``model(sources, source_lengths, targets, teacher_forcing_ratio)``.
        dataloader (DataLoader): Training batches of ``(sources, source_lengths, targets)``.
        optimizer (optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss).
        device (torch.device): Device to run the model on.
        clip_value (float): Max gradient norm passed to ``clip_grad_norm_``.
        teacher_forcing_ratio (float): Probability of using teacher forcing.
        target_vocab (Vocabulary): Unused; kept for call compatibility.

    Returns:
        tuple: ``(average loss, training accuracy)``. The loss is averaged over the batches
        that produced an optimizer step, and the accuracy is always 0.0 (it is not computed
        here).
    """
    model.train()
    if len(dataloader) == 0:
        return 0.0, 0.0

    epoch_loss = 0.0
    n_updates = 0
    for sources, source_lengths, targets in tqdm(dataloader, desc="Training", leave=False):
        # collate_fn returns (None, None, None) when a batch had no usable pairs.
        if sources is None or sources.shape[0] == 0:
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()
        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 is the SOS step and is excluded from the loss.
        output_dim = outputs_logits.shape[-1]
        loss = criterion(outputs_logits[:, 1:].reshape(-1, output_dim), targets[:, 1:].reshape(-1))
        if not torch.isfinite(loss):
            # Skip the update: a NaN/inf gradient step would corrupt the weights.
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        epoch_loss += loss.item()
        n_updates += 1

    # Dividing by len(dataloader) would count skipped batches and bias the mean low.
    return (epoch_loss / n_updates if n_updates else 0.0), 0.0


def _evaluate_one_epoch(model, dataloader, criterion, device, source_vocab, target_vocab, beam_width=1, is_test_set=False):
    """
    Evaluate the model on a validation or test set.

    The loss comes from a forward pass with teacher forcing disabled. Accuracy is the
    exact string match between each decoded prediction and its reference. Each sample
    is decoded with ``predict_with_attention`` (``beam_width == 1``),
    ``predict_beam_search`` (``beam_width > 1``), or, if the model has neither, the
    argmax of the forward-pass outputs.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): Batches of ``(sources, source_lengths, targets)``.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the model on.
        source_vocab (Vocabulary): Source vocabulary (only used to print test examples).
        target_vocab (Vocabulary): Target vocabulary.
        beam_width (int): 1 for greedy decoding, >1 for beam search.
        is_test_set (bool): If True, label the progress bar "Testing" and print the first
            three decoded examples.

    Returns:
        tuple: ``(average loss, accuracy)``. The loss is averaged over batches whose loss
        was finite.
    """
    model.eval()
    if len(dataloader) == 0:
        return 0.0, 0.0
    desc_prefix = "Testing" if is_test_set else "Validating"

    epoch_loss = 0.0
    loss_batches = 0
    total_correct = 0
    total_samples = 0
    examples_shown = 0

    with torch.no_grad():
        for sources, source_lengths, targets in tqdm(dataloader, desc=desc_prefix, leave=False):
            # collate_fn returns (None, None, None) when a batch had no usable pairs.
            if sources is None or sources.shape[0] == 0:
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
            outputs = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs.shape[-1]
            # Position 0 is the SOS step and is excluded from the loss.
            loss = criterion(outputs[:, 1:].reshape(-1, output_dim), targets[:, 1:].reshape(-1))
            if torch.isfinite(loss):
                epoch_loss += loss.item()
                loss_batches += 1

            max_output_len = targets.size(1) + 5
            for i in range(targets.shape[0]):
                src_single, src_len_single = sources[i:i+1], source_lengths[i:i+1]

                if beam_width == 1 and hasattr(model, 'predict_with_attention'):
                    predicted_indices, _ = model.predict_with_attention(
                        src_single, src_len_single,
                        max_output_len=max_output_len,
                        target_eos_idx=target_vocab.eos_idx
                    )
                elif beam_width > 1 and hasattr(model, 'predict_beam_search'):
                    predicted_indices = model.predict_beam_search(
                        src_single, src_len_single,
                        max_output_len=max_output_len,
                        target_eos_idx=target_vocab.eos_idx,
                        beam_width=beam_width
                    )
                else:  # Fall back to the argmax of the forward pass computed above.
                    predicted_indices = outputs[i:i+1].argmax(dim=2)[0, 1:].tolist()

                pred_str = target_vocab.indices_to_sequence(predicted_indices, strip_special=True)
                true_str = target_vocab.indices_to_sequence(targets[i, 1:].tolist(), strip_special=True)
                is_match = pred_str == true_str
                total_correct += int(is_match)
                total_samples += 1

                # Show the first few test examples. The previous version tested an undefined
                # `batch_idx` here, which raised NameError whenever is_test_set was True.
                if is_test_set and examples_shown < 3:
                    print(f"  Test Example {examples_shown} - Source: '{source_vocab.indices_to_sequence(src_single[0].tolist(), strip_special=True)}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {is_match}")
                    examples_shown += 1

    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    avg_loss = epoch_loss / loss_batches if loss_batches else 0.0
    return avg_loss, accuracy

# --- Function to Train and Save the Best Attention Model ---

def _fit_attention_model(cfg, max_epochs_no_improve, model_save_path, device, use_wandb):
    """
    Build data, vocabularies and the attention model from ``cfg``, train with early stopping
    on validation accuracy, and save the best checkpoint to ``model_save_path``.

    Args:
        cfg: Attribute-style hyperparameter container (``wandb.config`` or a
            ``types.SimpleNamespace``) holding the keys of ``BEST_ATTN_HYPERPARAMETERS``.
        max_epochs_no_improve (int): Early-stopping patience in epochs.
        model_save_path (str): Destination of the best ``state_dict``.
        device (torch.device): Device to train on.
        use_wandb (bool): Whether to call ``wandb.watch`` / ``wandb.log``.

    Returns:
        tuple: ``(success, best_val_acc)``; ``best_val_acc`` is -1.0 if no epoch ran.
    """
    import functools

    # --- Data Loading and Vocabulary Building ---
    base_data_dir = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    data_dir = os.path.join(base_data_dir, "lexicons/")
    train_file = os.path.join(data_dir, "hi.translit.sampled.train.tsv")
    dev_file = os.path.join(data_dir, "hi.translit.sampled.dev.tsv")

    source_vocab = Vocabulary("latin")
    target_vocab = Vocabulary("devanagari")
    # TransliterationDataset only stores references to the vocabularies and reads them in
    # __getitem__. One instance can therefore serve both vocab building and training,
    # instead of parsing the training file twice.
    train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
    if not train_dataset.pairs:
        print(f"ERROR: No data for vocab from {train_file}")
        return False, -1.0
    for src, tgt in train_dataset.pairs:
        source_vocab.add_sequence(src)
        target_vocab.add_sequence(tgt)
    source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
    target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)

    dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
    if not dev_dataset.pairs:
        print(f"ERROR: Train or Dev dataset empty. Train: {len(train_dataset)}, Dev: {len(dev_dataset)}")
        return False, -1.0

    # --- DataLoader Setup ---
    num_w = 0 if device.type == 'cpu' else min(4, (os.cpu_count() or 0) // 2)
    pin = device.type == 'cuda'
    # Unlike a lambda, a functools.partial of a module-level function can be pickled, which
    # DataLoader worker processes require under non-fork multiprocessing start methods.
    batch_collate = functools.partial(collate_fn,
                                      pad_idx_source=source_vocab.pad_idx,
                                      pad_idx_target=target_vocab.pad_idx)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                              collate_fn=batch_collate, num_workers=num_w, pin_memory=pin, drop_last=True)
    dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                            collate_fn=batch_collate, num_workers=num_w, pin_memory=pin)
    if len(train_loader) == 0 or len(dev_loader) == 0:
        print(f"ERROR: Train/Dev Dataloader empty. Train: {len(train_loader)}, Dev: {len(dev_loader)}")
        return False, -1.0

    # --- Model, Optimizer, Loss Function Setup ---
    encoder_hidden_dim_eff = cfg.hidden_dim * (2 if cfg.encoder_bidirectional else 1)
    encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                      cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                      cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
    decoder = DecoderWithAttention(target_vocab.n_chars, cfg.embedding_dim, encoder_hidden_dim_eff,
                                   cfg.hidden_dim, cfg.decoder_layers, cfg.cell_type,
                                   cfg.dropout_p, pad_idx=target_vocab.pad_idx).to(device)
    model = Seq2SeqWithAttention(encoder, decoder, device, target_vocab.sos_idx).to(device)
    print(f"Best Attention Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    if use_wandb:
        wandb.watch(model, log="all", log_freq=100)

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
    # 'max' mode: the scheduler is stepped with validation accuracy.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3, verbose=True)
    criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

    best_val_acc = -1.0
    epochs_no_improve = 0
    for epoch in range(cfg.epochs_train):
        train_loss, _ = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                         cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)
        val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device,
                                                source_vocab, target_vocab, beam_width=1, is_test_set=False)

        scheduler.step(val_acc)
        print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        if use_wandb:
            wandb.log({"epoch_train_best_attn": epoch + 1, "train_loss_best_attn": train_loss,
                       "val_loss_best_attn": val_loss, "val_acc_best_attn": val_acc, "lr_best_attn": optimizer.param_groups[0]['lr']})

        # Early stopping: checkpoint on every improvement, stop after `max_epochs_no_improve` stale epochs.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved new best attention model to {model_save_path} (Val Acc: {best_val_acc:.4f})")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= max_epochs_no_improve:
            print(f"Early stopping for best attention model training at epoch {epoch+1}.")
            break

    # Guarantees a checkpoint exists even if no epoch ran (epochs_train == 0).
    if not os.path.exists(model_save_path):
        torch.save(model.state_dict(), model_save_path)
        print(f"Saved final attention model state to {model_save_path}")

    print(f"Finished training best attention model. Best Val Acc during this training: {best_val_acc:.4f}")
    return True, best_val_acc


def train_and_save_attention_model(config_params, model_save_path, device):
    """
    Perform a dedicated training run with the given hyperparameters for the attention model
    and save the checkpoint with the best validation accuracy.

    When the module-level ``wandb`` name is None (W&B disabled), training runs without
    W&B logging instead of failing on ``wandb.init``.

    Args:
        config_params (dict): Hyperparameters for this training run.
        model_save_path (str): Path to save the best model's state_dict.
        device (torch.device): Device to run the training on.

    Returns:
        bool: True if training was successful and a model was saved, False otherwise.
    """
    if wandb is None:
        import types
        cfg = types.SimpleNamespace(**config_params)
        print(f"W&B disabled. Starting dedicated training for BEST ATTENTION model with config: {config_params}")
        success, _ = _fit_attention_model(cfg, config_params.get('max_epochs_no_improve_train', 7),
                                          model_save_path, device, use_wandb=False)
        return success

    run_name_train_best_attn = f"TRAIN_BEST_ATTN_{config_params['cell_type']}_emb{config_params['embedding_dim']}_hid{config_params['hidden_dim']}"
    with wandb.init(project="DL_A3_Attention_Training", name=run_name_train_best_attn, config=config_params, job_type="training_best_attention_model", reinit=True):
        cfg = wandb.config
        print(f"Starting dedicated training for BEST ATTENTION model with config: {cfg}")
        success, best_val_acc = _fit_attention_model(cfg, cfg.get('max_epochs_no_improve_train', 7),
                                                     model_save_path, device, use_wandb=True)
        if not success:
            return False
        wandb.summary["final_best_val_accuracy_during_attn_training"] = best_val_acc
    return True


def plot_attention_heatmap(source_chars, predicted_chars, attention_matrix, file_path="attention_heatmap.png", title="Attention Heatmap"):
    """
    Plot an attention heatmap and save it to a file.

    Args:
        source_chars (list): Source characters, one per attention column.
        predicted_chars (list): Predicted target characters, one per attention row.
        attention_matrix (np.ndarray): 2-D array of attention weights (rows=target, cols=source).
        file_path (str): Path to save the generated heatmap image.
        title (str): Title for the heatmap plot.

    The matrix is cropped to ``len(predicted_chars) x len(source_chars)``. Invalid or empty
    inputs are reported with a printed warning and no file is written. Save errors are
    printed, not raised.
    """
    source_chars = list(source_chars)
    predicted_chars = list(predicted_chars)

    if (attention_matrix is None or not isinstance(attention_matrix, np.ndarray)
            or attention_matrix.ndim != 2 or 0 in attention_matrix.shape):
        print(f"Warning: Invalid attention matrix for '{''.join(source_chars)}' -> '{''.join(predicted_chars)}'. Skipping plot.")
        return

    plot_pred_len = len(predicted_chars)
    plot_src_len = len(source_chars)
    if plot_pred_len == 0 or plot_src_len == 0:
        print(f"Warning: Empty source or predicted characters for heatmap. Source: {plot_src_len}, Pred: {plot_pred_len}. Skipping.")
        return

    current_attention_matrix = attention_matrix[:plot_pred_len, :plot_src_len]
    n_rows, n_cols = current_attention_matrix.shape
    if (n_rows, n_cols) != (plot_pred_len, plot_src_len):
        print(f"Warning: Attention matrix shape ({attention_matrix.shape}) "
              f"could not be perfectly aligned with char lists (Pred:{plot_pred_len}, Src:{plot_src_len}). "
              f"Using shape {current_attention_matrix.shape} for plot. Input: '{''.join(source_chars)}'")

    # Label only the rows/columns that are actually drawn.
    row_labels = predicted_chars[:n_rows]
    col_labels = source_chars[:n_cols]

    fig, ax = plt.subplots(figsize=(max(6, plot_src_len * 0.7), max(4, plot_pred_len * 0.7)))
    try:
        image = ax.matshow(current_attention_matrix, cmap='viridis')
        fig.colorbar(image, ax=ax)

        # One fixed tick per cell keeps every label aligned with its cell. The previous
        # [''] + labels / MultipleLocator trick depended on FixedFormatter behaviour
        # that matplotlib warns about.
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(col_labels, rotation=90, fontfamily='sans-serif', fontsize=10)
        ax.set_yticks(range(n_rows))
        # A family list is a fallback order. matplotlib does not raise for a missing font
        # (it warns at render time), so a try/except around this call never triggers.
        ax.set_yticklabels(row_labels, fontfamily=['Arial Unicode MS', 'sans-serif'], fontsize=10)

        ax.set_xlabel("Source (Latin)")
        ax.set_ylabel("Prediction (Devanagari)")
        ax.set_title(title, fontsize=12)
        fig.tight_layout()

        try:
            fig.savefig(file_path)
            print(f"Saved attention heatmap to {file_path}")
        except Exception as e:  # best effort: a failed save should not abort the batch of plots
            print(f"Error saving heatmap: {e}")
    finally:
        plt.close(fig)  # always release the figure, even if plotting raised

def _build_heatmap_markdown_grid(file_paths, base_dir, n_cols=3):
    """
    Return Markdown table lines that lay out the heatmap images ``n_cols`` per row.

    Only the first row gets a "Heatmap k" header plus the separator line. Short rows are
    padded with blank cells so every row has ``n_cols`` columns. Image links are made
    relative to ``base_dir``.
    """
    lines = []
    for start in range(0, len(file_paths), n_cols):
        chunk = file_paths[start:start + n_cols]
        padding = [" "] * (n_cols - len(chunk))
        if start == 0:
            headers = [f"Heatmap {start + k + 1}" for k in range(len(chunk))] + padding
            lines.append("| " + " | ".join(headers) + " |")
            lines.append("|" + "---|" * n_cols)
        images = [f"![Attention {start + k + 1}]({os.path.relpath(path, base_dir)})"
                  for k, path in enumerate(chunk)] + padding
        lines.append("| " + " | ".join(images) + " |")
    return lines


# --- Main Execution Block for Heatmap Generation ---
if __name__ == '__main__':
    import sys

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- W&B Login ---
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensuring WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
            else:
                from kaggle_secrets import UserSecretsClient
                wandb.login(key=UserSecretsClient().get_secret("WANDB_API_KEY"))
            print("W&B login for Kaggle successful.")
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        # Best effort: continue with W&B disabled rather than aborting the script.
        print(f"W&B login failed: {e}. Ensure API key is configured.")
        wandb = None  # Disable wandb if login fails

    # --- Best Attention Model Hyperparameters (UPDATE THIS SECTION) ---
    # These hyperparameters should come from your W&B sweep for the Attention Model.
    BEST_ATTN_HYPERPARAMETERS = {
        'embedding_dim': 256,
        'hidden_dim': 512,
        'encoder_layers': 1,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.3,
        'encoder_bidirectional': True,
        'learning_rate_train': 0.0008,
        'batch_size_train': 64,
        'epochs_train': 20,
        'clip_value_train': 1.0,
        'teacher_forcing_train': 0.5,
        'max_epochs_no_improve_train': 7,
        'lr_scheduler_patience_train': 3,
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        'eval_batch_size': 64,
        # beam_width_eval is not used: predict_with_attention decodes greedily, which is all the heatmaps need.
    }
    ATTN_MODEL_SAVE_PATH = "/kaggle/working/best_attention_model_q5.pt"
    HEATMAP_OUTPUT_DIR = "/kaggle/working/attention_heatmaps_q5d"

    print(f"Target hyperparameters for best attention model: {BEST_ATTN_HYPERPARAMETERS}")

    # --- Phase 1: Train and Save the Best Attention Model (if checkpoint doesn't exist) ---
    if not os.path.exists(ATTN_MODEL_SAVE_PATH):
        print(f"\n--- Attention Model Checkpoint NOT FOUND at {ATTN_MODEL_SAVE_PATH} ---")
        print("--- Attempting to TRAIN AND SAVE the best attention model using BEST_ATTN_HYPERPARAMETERS ---")
        training_successful = train_and_save_attention_model(BEST_ATTN_HYPERPARAMETERS, ATTN_MODEL_SAVE_PATH, DEVICE)
        if not training_successful or not os.path.exists(ATTN_MODEL_SAVE_PATH):
            print("ERROR: Failed to train and save the best attention model. Exiting.")
            sys.exit(1)
        print(f"Best attention model trained and saved to {ATTN_MODEL_SAVE_PATH}")
    else:
        print(f"Found existing attention model checkpoint at {ATTN_MODEL_SAVE_PATH}. Will use this for heatmaps.")

    # --- Phase 2: Load Model and Prepare for Heatmap Generation ---
    print("\n--- Preparing for Heatmap Generation ---")

    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    # Rebuild the vocabularies from the training file with the same settings as training,
    # so the character ids match the checkpoint's embedding/output layers.
    source_vocab = Vocabulary("latin_attn_heatmap")
    target_vocab = Vocabulary("devanagari_attn_heatmap")

    temp_train_ds_attn_vocab = TransliterationDataset(train_file_for_vocab, source_vocab, target_vocab,
                                                      max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not temp_train_ds_attn_vocab.pairs:
        print(f"ERROR: No training pairs to build the vocabulary from {train_file_for_vocab}. Exiting.")
        sys.exit(1)
    for src, tgt in temp_train_ds_attn_vocab.pairs:
        source_vocab.add_sequence(src)
        target_vocab.add_sequence(tgt)
    source_vocab.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])
    target_vocab.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])

    test_dataset_for_heatmaps = TransliterationDataset(test_file, source_vocab, target_vocab,
                                                       max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not test_dataset_for_heatmaps.pairs:
        print("Test dataset for heatmaps is empty. Exiting.")
        sys.exit(1)

    encoder_hidden_dim_eff_test = BEST_ATTN_HYPERPARAMETERS['hidden_dim'] * (2 if BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'] else 1)
    encoder = Encoder(source_vocab.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'], BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                      BEST_ATTN_HYPERPARAMETERS['encoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                      BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'], pad_idx=source_vocab.pad_idx).to(DEVICE)
    decoder = DecoderWithAttention(target_vocab.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'],
                                   encoder_hidden_dim_eff_test, BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                                   BEST_ATTN_HYPERPARAMETERS['decoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                                   pad_idx=target_vocab.pad_idx).to(DEVICE)
    model = Seq2SeqWithAttention(encoder, decoder, DEVICE, target_vocab.sos_idx).to(DEVICE)

    print(f"Loading weights into attention model from: {ATTN_MODEL_SAVE_PATH}")
    model.load_state_dict(torch.load(ATTN_MODEL_SAVE_PATH, map_location=DEVICE))
    print("Attention model weights loaded successfully.")
    model.eval()  # Ensure model is in evaluation mode

    # --- Generate and Plot Attention Heatmaps ---
    os.makedirs(HEATMAP_OUTPUT_DIR, exist_ok=True)

    num_heatmap_samples = 10
    n_test_pairs = len(test_dataset_for_heatmaps.pairs)  # > 0, checked above
    sample_indices_heatmap = random.sample(range(n_test_pairs), min(num_heatmap_samples, n_test_pairs))
    generated_heatmap_files = []

    print(f"\n--- Generating {len(sample_indices_heatmap)} Attention Heatmaps ---")
    for i, data_idx in enumerate(sample_indices_heatmap):
        source_str, true_target_str = test_dataset_for_heatmaps.pairs[data_idx]

        # Encode exactly as the dataset does for sources (characters + EOS).
        source_indices_for_pred = source_vocab.sequence_to_indices(source_str, add_eos=True, max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
        source_tensor_hm = torch.tensor(source_indices_for_pred, dtype=torch.long).unsqueeze(0).to(DEVICE)
        source_length_hm = torch.tensor([len(source_indices_for_pred)], dtype=torch.long).to(DEVICE)

        predicted_indices_hm, attention_matrix = model.predict_with_attention(
            source_tensor_hm,
            source_length_hm,
            max_output_len=len(true_target_str) + 10,
            target_eos_idx=target_vocab.eos_idx
        )
        predicted_str = target_vocab.indices_to_sequence(predicted_indices_hm, strip_special=True)
        source_chars_for_plot = list(source_str)  # one attention column per original source character

        if attention_matrix is None or not source_chars_for_plot:
            print(f"Warning: Could not generate attention matrix or empty prediction/source for '{source_str}'. Skipping heatmap.")
            continue

        # Keep exactly one attention row per displayed token. Splitting the decoded string
        # with list() would break "<unk>" into five labels and ignore skipped SOS/PAD steps,
        # shifting every later row against its label.
        kept_rows, predicted_chars_for_plot = [], []
        n_steps = min(len(predicted_indices_hm), attention_matrix.shape[0])
        for step in range(n_steps):
            token_idx = predicted_indices_hm[step]
            if token_idx == target_vocab.eos_idx:
                break
            if token_idx in (target_vocab.sos_idx, target_vocab.pad_idx):
                continue
            kept_rows.append(step)
            predicted_chars_for_plot.append(target_vocab.index2char.get(token_idx, UNK_TOKEN))

        valid_src_len = min(len(source_chars_for_plot), attention_matrix.shape[1])
        if not kept_rows or valid_src_len == 0:
            print(f"Warning: Not enough data in attention matrix or char lists for '{source_str}'. Skipping heatmap.")
            continue

        plot_attn_matrix = attention_matrix[kept_rows, :valid_src_len]
        file_stub = source_str[:15].replace(' ', '_').replace('/', '')
        heatmap_file_path = os.path.join(HEATMAP_OUTPUT_DIR, f"attn_heatmap_{i+1}_{file_stub}.png")
        plot_attention_heatmap(source_chars_for_plot[:valid_src_len],
                               predicted_chars_for_plot,
                               plot_attn_matrix,
                               heatmap_file_path,
                               title=f"Input: {source_str} -> Pred: {predicted_str}")
        generated_heatmap_files.append(heatmap_file_path)

    print(f"\n--- Markdown for Displaying Attention Heatmaps (Saved in '{HEATMAP_OUTPUT_DIR}') ---")
    if generated_heatmap_files:
        print("You can use the following Markdown in your report (adjust paths if needed):")
        print("\n".join(_build_heatmap_markdown_grid(generated_heatmap_files, '/kaggle/working/')))
        print("\n(Note: Paths are relative to '/kaggle/working/'. Adjust if your report is viewed elsewhere.)")
    else:
        print("No heatmaps were generated successfully.")

    # --- Log Heatmaps to W&B ---
    # Only start a run when there is something to log, and always close it afterwards
    # (previously a run could be opened with nothing to log and then left unfinished).
    if wandb is not None and generated_heatmap_files:
        if wandb.run is None:
            wandb.init(project="DL_A3", name="Q5d_Attention_Heatmaps", config=BEST_ATTN_HYPERPARAMETERS, job_type="q5d_heatmap_generation", reinit=True)
        for i, f_path in enumerate(generated_heatmap_files):
            if os.path.exists(f_path):
                wandb.log({f"q5d_attention_heatmap_sample_{i+1}": wandb.Image(f_path, caption=f"Heatmap for sample {i+1}")})
        print("Attention heatmaps logged to W&B.")
        if wandb.run is not None:
            wandb.finish()

    print("Script finished.")

# Question 6

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq
import json
import html

# --- Constants for Special Tokens ---
# Reserved special-token strings; `Vocabulary` maps them to indices 0-3.
PAD_TOKEN = "<pad>"  # fills the tail of shorter sequences inside a batch
SOS_TOKEN = "<sos>"  # first token fed to the decoder
EOS_TOKEN = "<eos>"  # terminates every encoded sequence
UNK_TOKEN = "<unk>"  # replaces characters missing from the vocabulary

# --- Attention Mechanism ---

class Attention(nn.Module):
    """
    Additive (Bahdanau-style, "concat") attention.

    Each encoder output state is scored against the decoder's current top-layer
    hidden state as ``v^T tanh(W [h_dec; h_enc] + b)``. The scores are then
    normalised over the source positions with a softmax.
    """
    def __init__(self, encoder_hidden_dim_eff, decoder_hidden_dim):
        """
        Args:
            encoder_hidden_dim_eff (int): Feature size of each encoder output state
                (hidden_dim * num_directions for a bidirectional encoder).
            decoder_hidden_dim (int): Hidden size of the decoder RNN.
        """
        super(Attention, self).__init__()
        self.attn_W = nn.Linear(encoder_hidden_dim_eff + decoder_hidden_dim, decoder_hidden_dim)
        self.attn_v = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden_top_layer, encoder_outputs, mask=None):
        """
        Calculates attention weights.

        Args:
            decoder_hidden_top_layer (torch.Tensor): Top-layer hidden state of the decoder RNN
                (batch_size, dec_hidden_dim).
            encoder_outputs (torch.Tensor): All encoder output states
                (batch_size, src_len, enc_hidden_dim_eff).
            mask (torch.Tensor, optional): (batch_size, src_len) tensor that is truthy for
                real source positions and falsy for padding. Masked positions receive
                zero weight. If None (the default), every position is attended, as before.

        Returns:
            torch.Tensor: Attention weights (batch_size, src_len); each row sums to 1.
        """
        src_len = encoder_outputs.size(1)
        # expand() broadcasts the query along the source axis without copying it src_len times.
        query = decoder_hidden_top_layer.unsqueeze(1).expand(-1, src_len, -1)

        energy = torch.tanh(self.attn_W(torch.cat((query, encoder_outputs), dim=2)))
        attention_scores = self.attn_v(energy).squeeze(2)

        if mask is not None:
            # Use the most negative finite value rather than -inf. A row whose every position
            # is masked then degrades to uniform weights instead of producing NaNs.
            attention_scores = attention_scores.masked_fill(
                ~mask.bool(), torch.finfo(attention_scores.dtype).min
            )
        return F.softmax(attention_scores, dim=1)

# --- Model Definition (Encoder, DecoderWithAttention, Seq2SeqWithAttention) ---

class Encoder(nn.Module):
    """
    Recurrent encoder for the attention seq2seq model.

    Padded source indices are embedded, passed through dropout, and fed as a
    packed sequence to an RNN/GRU/LSTM stack. The encoder returns:
    - the per-timestep outputs of the top layer (consumed by attention), and
    - the final recurrent state of every layer (used to seed the decoder).
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super(Encoder, self).__init__()
        # Configuration kept as attributes; Seq2SeqWithAttention reads several of them.
        self.input_vocab_size, self.embedding_dim = input_vocab_size, embedding_dim
        self.hidden_dim, self.n_layers = hidden_dim, n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout only applies between stacked layers, so it is disabled for a single layer.
        inter_layer_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = recurrent_cls(embedding_dim, hidden_dim, n_layers,
                                 dropout=inter_layer_dropout, batch_first=True,
                                 bidirectional=self.bidirectional)

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded batch.

        Args:
            input_seq (torch.Tensor): Padded source indices (batch_size, seq_len).
            input_lengths (torch.Tensor): True length of every sequence (batch_size,).

        Returns:
            tuple: (encoder_outputs, hidden).
                encoder_outputs: top-layer outputs
                    (batch_size, max(input_lengths), hidden_dim * num_directions),
                    zero-filled past each sequence's length.
                hidden: final state of every layer; an (h, c) tuple for LSTM.
        """
        features = self.dropout(self.embedding(input_seq))
        # pack_padded_sequence requires the lengths on the CPU; enforce_sorted=False accepts any batch order.
        packed_in = nn.utils.rnn.pack_padded_sequence(features, input_lengths.cpu(),
                                                      batch_first=True, enforce_sorted=False)
        packed_out, final_state = self.rnn(packed_in)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return unpacked, final_state

class DecoderWithAttention(nn.Module):
    """
    Single-step recurrent decoder with attention over the encoder outputs.

    At every step the previous top-layer hidden state queries the encoder
    outputs. The resulting context vector is concatenated with the embedded
    input token, fed through the decoder RNN, and projected to vocabulary logits.
    """
    def __init__(self, output_vocab_size, embedding_dim, encoder_hidden_dim_eff,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super(DecoderWithAttention, self).__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.encoder_hidden_dim_eff = encoder_hidden_dim_eff
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        self.attention = Attention(encoder_hidden_dim_eff, decoder_hidden_dim)
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Each step consumes [token embedding ; attention context].
        step_input_dim = embedding_dim + encoder_hidden_dim_eff
        self.rnn = recurrent_cls(step_input_dim, decoder_hidden_dim, n_layers,
                                 dropout=dropout_p if n_layers > 1 else 0,
                                 batch_first=True)

        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden, encoder_outputs):
        """
        Run one decoding step.

        Args:
            input_char (torch.Tensor): Current input token per batch element (batch_size,).
            prev_decoder_hidden (torch.Tensor or tuple): Previous decoder state(s); an (h, c)
                tuple when cell_type is LSTM.
            encoder_outputs (torch.Tensor): Encoder outputs (batch_size, src_len, enc_hidden_dim_eff).

        Returns:
            tuple: (prediction_logits, current_decoder_hidden, attention_weights).
                prediction_logits: (batch_size, output_vocab_size) unnormalised scores.
                current_decoder_hidden: updated decoder state(s).
                attention_weights: (batch_size, src_len) attention distribution.
        """
        # (batch,) -> (batch, 1) so the RNN sees a length-1 sequence.
        token_embedding = self.dropout(self.embedding(input_char.unsqueeze(1)))

        # The query is the top layer of the previous hidden state; for LSTM only h (not c) is used.
        hidden_stack = prev_decoder_hidden[0] if self.cell_type == 'LSTM' else prev_decoder_hidden
        query = hidden_stack[-1]

        attention_weights = self.attention(query, encoder_outputs)
        # (batch, 1, src_len) x (batch, src_len, enc_dim) -> (batch, 1, enc_dim)
        context = attention_weights.unsqueeze(1).bmm(encoder_outputs)

        step_output, current_decoder_hidden = self.rnn(
            torch.cat((token_embedding, context), dim=2), prev_decoder_hidden
        )
        prediction_logits = self.fc_out(step_output.squeeze(1))
        return prediction_logits, current_decoder_hidden, attention_weights


class Seq2SeqWithAttention(nn.Module):
    """
    The main Sequence-to-Sequence model integrating the Encoder and Attention-based Decoder.

    The encoder's final state is reshaped (bidirectional halves concatenated per layer),
    optionally projected to the decoder's hidden size, and padded or repeated to the
    decoder's layer count before it seeds the decoder.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # Projection layers are only needed when the encoder's effective final-state
        # width differs from the decoder's hidden size.
        encoder_effective_final_hidden_dim = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_hidden_dim = self.decoder.decoder_hidden_dim

        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None
        if encoder_effective_final_hidden_dim != decoder_rnn_hidden_dim:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_final_hidden_dim, decoder_rnn_hidden_dim)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Convert the encoder's final state(s) into an initial decoder state.

        Args:
            encoder_final_hidden_state: h of shape (enc_layers * num_directions, batch, hidden),
                or an (h, c) tuple for LSTM.

        Returns:
            Tensor (dec_layers, batch, dec_hidden), or an (h, c) tuple of such tensors for LSTM.
            Decoder layers beyond the encoder's depth repeat the last encoder layer.
        """
        is_lstm = self.encoder.cell_type == 'LSTM'
        if is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        batch_size = h_from_enc.size(1)
        enc_layers = self.encoder.n_layers
        dec_layers = self.decoder.n_layers

        def merge_directions(state):
            # (layers*dirs, B, H) -> (layers, B, H*dirs); PyTorch orders h_n as (layer, direction).
            per_layer = state.view(enc_layers, self.encoder.num_directions,
                                   batch_size, self.encoder.hidden_dim)
            if self.encoder.bidirectional:
                return torch.cat([per_layer[:, 0], per_layer[:, 1]], dim=2)
            return per_layer.squeeze(1)

        def fit_decoder_layers(state):
            # Copy the shared layers; repeat the last encoder layer for any extra decoder layers.
            fitted = torch.zeros(dec_layers, batch_size, self.decoder.decoder_hidden_dim, device=self.device)
            if state is None:
                return fitted
            n_copy = min(enc_layers, dec_layers)
            fitted[:n_copy] = state[:n_copy]
            if dec_layers > enc_layers > 0:
                fitted[enc_layers:] = state[enc_layers - 1]
            return fitted

        h_processed = merge_directions(h_from_enc)
        c_processed = merge_directions(c_from_enc) if (is_lstm and c_from_enc is not None) else None

        if self.fc_adapt_hidden is not None:
            h_processed = self.fc_adapt_hidden(h_processed)
            if c_processed is not None and self.fc_adapt_cell is not None:
                c_processed = self.fc_adapt_cell(c_processed)

        final_h = fit_decoder_layers(h_processed)
        if is_lstm:
            return final_h, fit_decoder_layers(c_processed)
        return final_h

    def _lengths_to_tensor(self, source_lengths, padded_len):
        """
        Normalise source lengths to a 1-D int64 tensor.

        Accepts an int, a sequence, a 0-d or 1-d tensor, or None. None means
        "use the padded length".
        """
        if source_lengths is None:
            return torch.tensor([padded_len], dtype=torch.long, device=self.device)
        if isinstance(source_lengths, int):
            return torch.tensor([source_lengths], dtype=torch.long, device=self.device)
        return torch.as_tensor(source_lengths, dtype=torch.long).reshape(-1)

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Forward pass for the Seq2SeqWithAttention model during training.

        Args:
            source_seq (torch.Tensor): Padded source sequences (batch, src_len).
            source_lengths (torch.Tensor): Lengths of source sequences.
            target_seq (torch.Tensor): Padded target sequences (including SOS token).
            teacher_forcing_ratio (float): Probability, drawn once per time step for the whole
                batch, of feeding the true target token instead of the model's argmax.

        Returns:
            torch.Tensor: Logits (batch, target_len, vocab). Position 0 (the SOS slot) is left as zeros.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        outputs_logits = torch.zeros(batch_size, target_len, self.decoder.output_vocab_size,
                                     device=self.device)

        encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
        decoder_input = target_seq[:, 0]

        for t in range(1, target_len):
            step_logits, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            outputs_logits[:, t] = step_logits
            teacher_force_this_step = random.random() < teacher_forcing_ratio
            decoder_input = target_seq[:, t] if teacher_force_this_step else step_logits.argmax(1)
        return outputs_logits

    def predict_with_attention(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Generate a sequence using greedy decoding and record the attention weights.

        Puts the model in eval mode and runs without gradients.

        Args:
            source_seq (torch.Tensor): A single source sequence, either 1-D (src_len,) or 2-D (1, src_len).
            source_lengths (int, sequence, torch.Tensor or None): Length of the source sequence.
                An int, a 0-d tensor, or a one-element list/tensor are all accepted.
                None means the full (padded) length.
            max_output_len (int): Maximum number of tokens to generate.
            target_eos_idx (int, optional): Decoding stops after this token is produced.

        Returns:
            tuple: (predicted token indices, including EOS if produced;
                    np.ndarray of attention rows (steps, src_len), or None if nothing was generated).
        """
        self.eval()
        if source_seq.dim() == 1:
            source_seq = source_seq.unsqueeze(0)
        # Previously len(source_lengths) was used here, which yields 1 for tensor([n]) / [n].
        source_lengths = self._lengths_to_tensor(source_lengths, source_seq.size(1))

        predicted_indices = []
        attention_matrices = []
        with torch.no_grad():
            encoder_outputs, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
            decoder_input = torch.tensor([self.target_sos_idx], device=self.device)

            for _ in range(max_output_len):
                step_logits, decoder_hidden, attention_weights = \
                    self.decoder(decoder_input, decoder_hidden, encoder_outputs)
                attention_matrices.append(attention_weights.squeeze(0).cpu().numpy())

                decoder_input = step_logits.argmax(1)
                predicted_idx = decoder_input.item()
                predicted_indices.append(predicted_idx)
                if target_eos_idx is not None and predicted_idx == target_eos_idx:
                    break

        attention = np.array(attention_matrices) if attention_matrices else None
        return predicted_indices, attention

# --- Data Loading and Preprocessing ---

class Vocabulary:
    """
    Two-way character <-> index lookup for one script.

    Indices 0-3 are reserved for PAD, SOS, EOS and UNK, in that order. Ordinary
    characters are appended by ``build_vocab`` starting at index 4, most frequent
    first (ties broken alphabetically).
    """
    def __init__(self, name):
        self.name = name
        reserved = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: position for position, token in enumerate(reserved)}
        self.index2char = {position: token for position, token in enumerate(reserved)}
        self.char_counts = Counter()
        self.n_chars = len(reserved)
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Tally each character of *sequence*; indices are assigned later by build_vocab."""
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """
        Assign indices to every counted character seen at least ``min_freq`` times.

        Characters below the threshold get no index and map to UNK_TOKEN at encode time.
        Characters that already have an index are skipped.
        """
        ranked = sorted(self.char_counts.items(), key=lambda item: (-item[1], item[0]))
        for char, count in ranked:
            if count < min_freq or char in self.char2index:
                continue
            self.char2index[char] = self.n_chars
            self.index2char[self.n_chars] = char
            self.n_chars += 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False, max_len=None):
        """
        Encode *sequence* as a list of indices, optionally framed by SOS/EOS.

        When ``max_len`` is truthy, the characters are truncated so that the result,
        including any SOS/EOS, is at most ``max_len`` long (never fewer than 0 characters).
        """
        chars = list(sequence)
        if max_len:
            char_budget = max_len - int(add_sos) - int(add_eos)
            chars = chars[:max(0, char_budget)]

        unk_idx = self.char2index[UNK_TOKEN]
        body = [self.char2index.get(char, unk_idx) for char in chars]
        head = [self.sos_idx] if add_sos else []
        tail = [self.eos_idx] if add_eos else []
        return head + body + tail

    def indices_to_sequence(self, indices, strip_special=True):
        """
        Decode indices back to a string.

        With ``strip_special``, decoding stops at the first EOS and SOS/PAD are dropped.
        Unknown indices render as UNK_TOKEN.
        """
        skipped = (self.sos_idx, self.pad_idx)
        decoded = []
        for index_val in indices:
            if strip_special:
                if index_val == self.eos_idx:
                    break
                if index_val in skipped:
                    continue
            decoded.append(self.index2char.get(index_val, UNK_TOKEN))
        return "".join(decoded)

class TransliterationDataset(Dataset):
    """
    A PyTorch Dataset for loading and preparing transliteration pairs.

    Each TSV line is read as ``target<TAB>source[<TAB>...]``: parts[0] is the target
    and parts[1] is the source. Pairs are stored as raw strings and encoded
    lazily in ``__getitem__`` with the vocabularies passed in.

    When ``max_len`` is set, it bounds the encoded length in tokens. A pair is kept only
    if it fits after special tokens are added:
    - the source gets EOS, so at most ``max_len - 1`` characters are allowed;
    - the target gets SOS and EOS, so at most ``max_len - 2`` characters are allowed.
    This guarantees the truncation in ``__getitem__`` never clips real characters.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.max_len = max_len

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found: {file_path}")
            return

        # Character budgets that leave room for the special tokens added in __getitem__.
        max_source_chars = max_len - 1 if max_len else None
        max_target_chars = max_len - 2 if max_len else None

        print(f"Loading data from: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) < 2:
                        continue
                    target, source = parts[0], parts[1]
                    if not source or not target:
                        continue
                    if max_len and (len(source) > max_source_chars or len(target) > max_target_chars):
                        continue
                    self.pairs.append((source, target))
            print(f"Loaded {len(self.pairs)} pairs from {file_path}.")
        except Exception as e:
            # Best-effort loading: report the problem and keep whatever was read so far.
            print(f"ERROR: Could not read or process file {file_path}. Error: {e}")

    def __len__(self): return len(self.pairs)

    def __getitem__(self, idx):
        """Returns a source-target pair as Tensors of indices."""
        source_str, target_str = self.pairs[idx]
        source_indices = self.source_vocab.sequence_to_indices(source_str, add_eos=True, max_len=self.max_len)
        target_indices = self.target_vocab.sequence_to_indices(target_str, add_sos=True, add_eos=True, max_len=self.max_len)
        return torch.tensor(source_indices, dtype=torch.long), torch.tensor(target_indices, dtype=torch.long)

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Collate variable-length (source, target) tensor pairs into padded batch tensors.

    Items that are None, or whose source or target is empty, are dropped. If nothing
    survives, (None, None, None) is returned so the training loop can skip the batch.

    Returns:
        tuple: (padded_sources (B, S), source_lengths (B,), padded_targets (B, T)).
    """
    kept = []
    for pair in batch:
        if pair is None:
            continue
        if len(pair[0]) and len(pair[1]):
            kept.append(pair)
    if not kept:
        return None, None, None

    sources, targets = zip(*kept)
    lengths = torch.tensor(list(map(len, sources)), dtype=torch.long)
    return (
        pad_sequence(sources, batch_first=True, padding_value=pad_idx_source),
        lengths,
        pad_sequence(targets, batch_first=True, padding_value=pad_idx_target),
    )

# --- Training and Evaluation Functions ---

def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Trains the model for a single epoch.

    Batches that collate to nothing are skipped. Batches whose loss is NaN/Inf are
    also skipped, with no optimizer step. The reported loss is the mean over the
    batches that were actually trained on, so skipped batches no longer pull the
    average toward zero.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the training set.
        optimizer (optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss).
        device (torch.device): Device to run the model on.
        clip_value (float): Max gradient norm for clipping.
        teacher_forcing_ratio (float): Probability of using teacher forcing.
        target_vocab (Vocabulary): Vocabulary for the target language (unused; kept for API compatibility).

    Returns:
        tuple: (mean training loss over the batches used, 0.0). The second element is a
            placeholder because training accuracy is not computed here.
    """
    model.train()
    total_loss = 0.0
    trained_batches = 0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        sources, source_lengths, targets = batch_data
        if sources is None or sources.shape[0] == 0:
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()
        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 is the SOS slot, which the model never predicts.
        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)

        loss = criterion(flat_outputs, flat_targets)
        if not torch.isfinite(loss):
            # Do not back-propagate a non-finite loss; it would corrupt the weights.
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        total_loss += loss.item()
        trained_batches += 1

    mean_loss = total_loss / trained_batches if trained_batches else 0.0
    return mean_loss, 0.0  # Train accuracy is not computed here


def _evaluate_one_epoch(model, dataloader, criterion, device, source_vocab, target_vocab, beam_width=1, is_test_set=False):
    """
    Evaluates the model's performance on a given dataset (validation or test).

    The loss comes from a forward pass with ``teacher_forcing_ratio=0.0``. Exact-match
    word accuracy comes from decoding every example individually. For test sets, the
    first three examples of the first batch are printed.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the evaluation set.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the model on.
        source_vocab (Vocabulary): Vocabulary for the source language (used for debug prints).
        target_vocab (Vocabulary): Vocabulary for the target language.
        beam_width (int): Beam width for decoding (1 for greedy, >1 for beam search if the model
            provides ``predict_beam_search``).
        is_test_set (bool): Flag to indicate if this is the final test evaluation (enables debug prints).

    Returns:
        tuple: (mean loss over batches with a finite loss, exact-match accuracy).
    """
    model.eval()
    total_loss = 0.0
    loss_batches = 0
    total_correct = 0
    total_samples = 0
    desc_prefix = "Testing" if is_test_set else "Validating"

    with torch.no_grad():
        # enumerate() supplies batch_idx, which limits the debug prints to the first batch.
        # It was previously referenced without being defined (NameError on test sets).
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=desc_prefix, leave=False)):
            sources, source_lengths, targets = batch_data
            if sources is None or sources.shape[0] == 0:
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
            outputs = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs.shape[-1]
            loss = criterion(outputs[:, 1:].reshape(-1, output_dim), targets[:, 1:].reshape(-1))
            if torch.isfinite(loss):
                total_loss += loss.item()
                loss_batches += 1

            max_output_len = targets.size(1) + 5
            for i in range(targets.shape[0]):
                src_single, src_len_single = sources[i:i+1], source_lengths[i:i+1]

                if beam_width == 1 and hasattr(model, 'predict_with_attention'):
                    predicted_indices, _ = model.predict_with_attention(
                        src_single, src_len_single,
                        max_output_len=max_output_len,
                        target_eos_idx=target_vocab.eos_idx,
                    )
                elif beam_width > 1 and hasattr(model, 'predict_beam_search'):
                    predicted_indices = model.predict_beam_search(
                        src_single, src_len_single,
                        max_output_len=max_output_len,
                        target_eos_idx=target_vocab.eos_idx,
                        beam_width=beam_width,
                    )
                else:
                    # Fallback: reuse the argmax of the teacher_forcing_ratio=0.0 logits computed
                    # above for the loss. Position 0 is the unused SOS slot.
                    predicted_indices = outputs[i].argmax(dim=-1)[1:].tolist()

                pred_str = target_vocab.indices_to_sequence(predicted_indices, strip_special=True)
                true_str = target_vocab.indices_to_sequence(targets[i, 1:].tolist(), strip_special=True)

                if pred_str == true_str:
                    total_correct += 1
                total_samples += 1

                if is_test_set and batch_idx == 0 and i < 3:
                    print(f"  Test Example {i} - Source: '{source_vocab.indices_to_sequence(src_single[0].tolist(), strip_special=True)}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {pred_str == true_str}")

    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    mean_loss = total_loss / loss_batches if loss_batches else 0.0
    return mean_loss, accuracy

# --- Function to Train and Save the Best Attention Model ---

def train_and_save_attention_model(config_params, model_save_path, device,
                                   base_data_dir="/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"):
    """
    Performs a dedicated training run using the best hyperparameters for the attention model.
    Saves the model checkpoint with the best validation accuracy.

    Args:
        config_params (dict): Dictionary of hyperparameters for this specific training run.
        model_save_path (str): Path to save the best model's state_dict.
        device (torch.device): Device to run the training on.
        base_data_dir (str): Dakshina language directory that contains ``lexicons/``.
            Defaults to the Kaggle input path used previously.

    Returns:
        bool: True if training ran and a model was saved, False if data or loaders were empty.
    """
    import functools  # local import: builds picklable collate functions for DataLoader workers

    run_name_train_best_attn = f"TRAIN_BEST_ATTN_{config_params['cell_type']}_emb{config_params['embedding_dim']}_hid{config_params['hidden_dim']}"

    with wandb.init(project="DL_A3_Attention_Training", name=run_name_train_best_attn, config=config_params, job_type="training_best_attention_model", reinit=True):
        cfg = wandb.config
        print(f"Starting dedicated training for BEST ATTENTION model with config: {cfg}")

        # --- Data Loading and Vocabulary Building ---
        DATA_DIR = os.path.join(base_data_dir, "lexicons/")
        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")
        # The dataset keeps references to the vocabularies and only encodes in __getitem__.
        # One instance can therefore both feed the vocab counts and serve as the training set,
        # so the training file is read once instead of twice.
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs:
            print(f"ERROR: No data for vocab from {train_file}")
            return False
        for src, tgt in train_dataset.pairs:
            source_vocab.add_sequence(src)
            target_vocab.add_sequence(tgt)
        source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)

        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not dev_dataset.pairs:
            print(f"ERROR: Train or Dev dataset empty. Train: {len(train_dataset)}, Dev: {len(dev_dataset)}")
            return False

        # --- DataLoader Setup ---
        num_w = 0 if device.type == 'cpu' else min(4, os.cpu_count()//2 if os.cpu_count() else 0)
        pin = device.type == 'cuda'
        # functools.partial (unlike a lambda) can be pickled, so it also works with spawn-based workers.
        collate = functools.partial(collate_fn, pad_idx_source=source_vocab.pad_idx,
                                    pad_idx_target=target_vocab.pad_idx)
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                                  collate_fn=collate, num_workers=num_w, pin_memory=pin, drop_last=True)
        dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                                collate_fn=collate, num_workers=num_w, pin_memory=pin)
        if len(train_loader) == 0 or len(dev_loader) == 0:
            print(f"ERROR: Train/Dev Dataloader empty. Train: {len(train_loader)}, Dev: {len(dev_loader)}")
            return False

        # --- Model, Optimizer, Loss Function Setup ---
        encoder_hidden_dim_eff = cfg.hidden_dim * (2 if cfg.encoder_bidirectional else 1)
        encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                          cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
        decoder = DecoderWithAttention(target_vocab.n_chars, cfg.embedding_dim, encoder_hidden_dim_eff,
                                       cfg.hidden_dim, cfg.decoder_layers, cfg.cell_type,
                                       cfg.dropout_p, pad_idx=target_vocab.pad_idx).to(device)
        model = Seq2SeqWithAttention(encoder, decoder, device, target_vocab.sos_idx).to(device)
        print(f"Best Attention Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
        # The scheduler's `verbose` flag is deprecated in recent PyTorch; LR drops are printed below instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3)
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

        best_val_acc_this_training = -1.0
        epochs_no_improve = 0
        max_epochs_no_improve_val = cfg.get('max_epochs_no_improve_train', 7)

        for epoch in range(cfg.epochs_train):
            train_loss, _ = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                             cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)
            val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device,
                                                    source_vocab, target_vocab, beam_width=1, is_test_set=False)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_acc)
            lr_now = optimizer.param_groups[0]['lr']
            if lr_now < lr_before:
                print(f"Reducing learning rate from {lr_before:.2e} to {lr_now:.2e}")

            print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            wandb.log({"epoch_train_best_attn": epoch + 1, "train_loss_best_attn": train_loss,
                       "val_loss_best_attn": val_loss, "val_acc_best_attn": val_acc, "lr_best_attn": lr_now})

            # Early stopping logic; the checkpoint always holds the best validation accuracy so far.
            if val_acc > best_val_acc_this_training:
                best_val_acc_this_training = val_acc
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path)
                print(f"Saved new best attention model to {model_save_path} (Val Acc: {best_val_acc_this_training:.4f})")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= max_epochs_no_improve_val:
                print(f"Early stopping for best attention model training at epoch {epoch+1}.")
                break

        # Fallback save when no checkpoint exists yet (e.g. epochs_train == 0).
        if not os.path.exists(model_save_path):
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved final attention model state to {model_save_path}")

        wandb.summary["final_best_val_accuracy_during_attn_training"] = best_val_acc_this_training
        print(f"Finished training best attention model. Best Val Acc during this training: {best_val_acc_this_training:.4f}")
    return True


# --- Function to Generate Interactive HTML for W&B ---

def generate_interactive_attention_html(examples_data):
    """
    Generates an HTML string for interactive attention visualization.

    The page renders a drop-down of examples, the source characters, and the
    predicted target characters. Hovering a target character tints every
    source character green with an opacity of ``weight * 0.8 + 0.2``, where
    ``weight`` is the attention weight from that target position to that
    source position.

    Args:
        examples_data (list): A list of dictionaries, each containing:
                              - 'name': Name of the example (e.g., "Source -> Target").
                              - 'sourceChars': List of source characters.
                              - 'targetChars': List of predicted target characters.
                              - 'attentionMatrix': 2D list of attention weights (rows=target, cols=source).
                              Must be JSON-serializable.

    Returns:
        str: The complete, self-contained HTML content as a string. Tailwind CSS
        is loaded from its CDN when the page is opened.
    """
    # The data is embedded as a JavaScript literal inside an inline <script>.
    # json.dumps leaves '<', '>' and '&' untouched. A sequence like "</script>"
    # inside any string would therefore close the script element early and
    # break the page (or inject markup). In JSON these characters can only
    # occur inside strings, so replacing them with their \uXXXX escapes is
    # lossless: JavaScript decodes them back to the same characters.
    escaped_examples_data_str = (
        json.dumps(examples_data)
        .replace("&", "\\u0026")
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
    )

    # NOTE: this is an f-string, so literal CSS/JS braces are doubled ({{ }}).
    html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Interactive Attention Visualization</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <style>
        body {{ font-family: 'Inter', sans-serif; display: flex; flex-direction: column; align-items: center; padding: 20px; background-color: #f9fafb; }}
        .container {{ background-color: white; padding: 2rem; border-radius: 0.75rem; box-shadow: 0 10px 15px -3px rgba(0,0,0,0.1), 0 4px 6px -2px rgba(0,0,0,0.05); width: 100%; max-width: 900px; }}
        .word-display {{ display: flex; flex-wrap: wrap; margin-bottom: 1rem; padding: 0.5rem; border-radius: 0.5rem; background-color: #f3f4f6; min-height: 40px; align-items: center; justify-content: flex-start; }}
        .char-box {{ display: inline-flex; align-items: center; justify-content: center; padding: 0.3rem 0.6rem; margin: 0.15rem; border: 1px solid #d1d5db; border-radius: 0.375rem; font-size: 1.25rem; min-width: 30px; height: 40px; text-align: center; transition: background-color 0.1s ease-in-out; }}
        .input-char {{ background-color: #ffffff; }}
        .output-char {{ background-color: #eff6ff; cursor: pointer; }}
        .output-char:hover {{ background-color: #dbeafe; }}
        .devanagari {{ font-family: 'Arial Unicode MS', 'Noto Sans Devanagari', sans-serif; }}
        .highlight-base {{ background-color: #6ee7b7; }} /* Tailwind green-300 base */
    </style>
</head>
<body>
    <div class="container">
        <h1 class="text-2xl font-semibold text-gray-800 mb-4">Interactive Attention Connectivity</h1>
        <div class="controls">
            <label for="exampleSelectViz" class="mr-2 text-gray-700">Select Example:</label>
            <select id="exampleSelectViz"></select>
        </div>
        <div class="text-lg font-medium text-gray-700 mb-1">Input (Latin):</div>
        <div id="inputSequenceViz" class="word-display"></div>
        <div class="text-lg font-medium text-gray-700 mt-4 mb-1">Predicted Output (Devanagari):</div>
        <div id="outputSequenceViz" class="word-display devanagari"></div>
        <p class="text-sm text-gray-600 mt-2">Hover over a Devanagari character to see its attention to Latin characters (highlighted green).</p>
    </div>
    <script>
        const examplesData = {escaped_examples_data_str};

        const inputDiv = document.getElementById('inputSequenceViz');
        const outputDiv = document.getElementById('outputSequenceViz');
        const selectEl = document.getElementById('exampleSelectViz');

        function displayExample(index) {{
            const example = examplesData[index];
            inputDiv.innerHTML = '';
            outputDiv.innerHTML = '';
            const inputSpans = [];

            example.sourceChars.forEach(char => {{
                const span = document.createElement('span');
                span.className = 'char-box input-char';
                span.textContent = char;
                inputDiv.appendChild(span);
                inputSpans.push(span);
            }});

            example.targetChars.forEach((char, targetIdx) => {{
                const span = document.createElement('span');
                span.className = 'char-box output-char devanagari';
                span.textContent = char;
                span.addEventListener('mouseover', () => {{
                    if (targetIdx < example.attentionMatrix.length) {{
                        const weights = example.attentionMatrix[targetIdx];
                        inputSpans.forEach((inputSpan, inputIdx) => {{
                            if (inputIdx < weights.length) {{
                                const weight = weights[inputIdx];
                                inputSpan.style.backgroundColor = `rgba(52, 211, 153, ${{weight * 0.8 + 0.2}})`; // Tailwind green-400 with opacity
                            }}
                        }});
                    }}
                }});
                span.addEventListener('mouseout', () => {{
                    inputSpans.forEach(s => s.style.backgroundColor = '#ffffff');
                }});
                outputDiv.appendChild(span);
            }});
        }}

        examplesData.forEach((ex, i) => {{
            const option = document.createElement('option');
            option.value = i;
            option.textContent = ex.name || `Example ${{i + 1}}`;
            selectEl.appendChild(option);
        }});
        selectEl.addEventListener('change', (e) => displayExample(e.target.value));
        if (examplesData.length > 0) {{
            displayExample(0); // Load first example
        }}
    </script>
</body>
</html>
    """
    return html_content

# --- Main Execution Block for Attention Model Training & Visualization ---

if __name__ == '__main__':
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- W&B Login ---
    # On Kaggle, the API key comes from the WANDB_API_KEY environment variable
    # or, failing that, from Kaggle Secrets. Elsewhere, plain wandb.login() is used.
    # Any failure disables W&B for the rest of this script (wandb = None).
    # Every later W&B call in this block is guarded by `if wandb ...`.
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensuring WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
                print("W&B login using environment variable.")
            else:
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
                print("W&B login for Kaggle successful.")
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        print(f"W&B login failed: {e}. Ensure API key is configured.")
        wandb = None # Disable wandb if login fails

    # --- Best Attention Model Hyperparameters (UPDATE THIS SECTION) ---
    # These hyperparameters should come from your W&B sweep for the Attention Model.
    BEST_ATTN_HYPERPARAMETERS = {
        'embedding_dim': 256,
        'hidden_dim': 512,
        'encoder_layers': 1,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.3,
        'encoder_bidirectional': True,
        'learning_rate_train': 0.0008,
        'batch_size_train': 64,
        'epochs_train': 20,
        'clip_value_train': 1.0,
        'teacher_forcing_train': 0.5,
        'max_epochs_no_improve_train': 7,
        'lr_scheduler_patience_train': 3,
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        'eval_batch_size': 64,
    }
    ATTN_MODEL_SAVE_PATH = "/kaggle/working/best_attention_model_q5.pt"
    VIZ_HTML_PATH = "/kaggle/working/interactive_attention_viz.html"

    print(f"Target hyperparameters for best attention model: {BEST_ATTN_HYPERPARAMETERS}")

    # --- Phase 1: Train and Save the Best Attention Model (if checkpoint doesn't exist) ---
    if not os.path.exists(ATTN_MODEL_SAVE_PATH):
        print(f"\n--- Attention Model Checkpoint NOT FOUND at {ATTN_MODEL_SAVE_PATH} ---")
        print("--- Attempting to TRAIN AND SAVE the best attention model using BEST_ATTN_HYPERPARAMETERS ---")
        training_successful = train_and_save_attention_model(BEST_ATTN_HYPERPARAMETERS, ATTN_MODEL_SAVE_PATH, DEVICE)
        if not training_successful or not os.path.exists(ATTN_MODEL_SAVE_PATH):
            print("ERROR: Failed to train and save the best attention model. Exiting.")
            exit()
        print(f"Best attention model trained and saved to {ATTN_MODEL_SAVE_PATH}")
    else:
        print(f"Found existing attention model checkpoint at {ATTN_MODEL_SAVE_PATH}. Will use this for visualization.")

    # --- Phase 2: Load Model and Prepare for Interactive Visualization ---
    print("\n--- Preparing for Interactive Attention Visualization ---")

    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    # Rebuild the vocabularies from the training split with the same min_freq
    # used for training. The loaded checkpoint's embedding sizes must match
    # source_vocab.n_chars / target_vocab.n_chars.
    source_vocab = Vocabulary("latin_attn_viz")
    target_vocab = Vocabulary("devanagari_attn_viz")

    temp_train_ds_attn_vocab = TransliterationDataset(train_file_for_vocab, source_vocab, target_vocab,
                                                max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not temp_train_ds_attn_vocab.pairs:
        print("Training dataset for vocabulary building is empty. Exiting.")
        exit()
    for src, tgt in temp_train_ds_attn_vocab.pairs:
        source_vocab.add_sequence(src)
        target_vocab.add_sequence(tgt)
    source_vocab.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])
    target_vocab.build_vocab(min_freq=BEST_ATTN_HYPERPARAMETERS['vocab_min_freq'])

    test_dataset_for_viz = TransliterationDataset(test_file, source_vocab, target_vocab,
                                                  max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
    if not test_dataset_for_viz.pairs:
        print("Test dataset for visualization is empty. Exiting.")
        exit()

    # A bidirectional encoder concatenates both directions, so the decoder's
    # attention sees encoder outputs of width 2 * hidden_dim.
    encoder_hidden_dim_eff_test = BEST_ATTN_HYPERPARAMETERS['hidden_dim'] * (2 if BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'] else 1)
    encoder = Encoder(source_vocab.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'], BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                      BEST_ATTN_HYPERPARAMETERS['encoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                      BEST_ATTN_HYPERPARAMETERS['encoder_bidirectional'], pad_idx=source_vocab.pad_idx).to(DEVICE)
    decoder = DecoderWithAttention(target_vocab.n_chars, BEST_ATTN_HYPERPARAMETERS['embedding_dim'],
                                   encoder_hidden_dim_eff_test, BEST_ATTN_HYPERPARAMETERS['hidden_dim'],
                                   BEST_ATTN_HYPERPARAMETERS['decoder_layers'], BEST_ATTN_HYPERPARAMETERS['cell_type'], BEST_ATTN_HYPERPARAMETERS['dropout_p'],
                                   pad_idx=target_vocab.pad_idx).to(DEVICE)
    model = Seq2SeqWithAttention(encoder, decoder, DEVICE, target_vocab.sos_idx).to(DEVICE)

    print(f"Loading weights into attention model from: {ATTN_MODEL_SAVE_PATH}")
    model.load_state_dict(torch.load(ATTN_MODEL_SAVE_PATH, map_location=DEVICE))
    print("Attention model weights loaded successfully.")
    model.eval() # Set model to evaluation mode

    # --- Generate Data for Interactive Visualization ---
    num_viz_samples = 10
    actual_num_viz_samples = min(num_viz_samples, len(test_dataset_for_viz.pairs))
    # test_dataset_for_viz.pairs is non-empty here (checked above); sampling is unseeded.
    sample_indices_viz = random.sample(range(len(test_dataset_for_viz.pairs)), actual_num_viz_samples)

    examples_for_html = []
    print(f"\n--- Generating Data for {len(sample_indices_viz)} Interactive Visualizations ---")

    for i, data_idx in enumerate(sample_indices_viz):
        source_str, true_target_str = test_dataset_for_viz.pairs[data_idx]

        source_indices_for_pred = source_vocab.sequence_to_indices(source_str, add_eos=True, max_len=BEST_ATTN_HYPERPARAMETERS['max_seq_len'])
        # Batch of one: shape (1, src_len) plus its length.
        source_tensor = torch.tensor(source_indices_for_pred, dtype=torch.long).unsqueeze(0).to(DEVICE)
        source_length = torch.tensor([len(source_indices_for_pred)], dtype=torch.long).to(DEVICE)

        # Get predicted indices and attention matrix. Pure inference, so
        # autograd graph construction is disabled.
        with torch.no_grad():
            predicted_indices, attention_matrix_np = model.predict_with_attention(
                source_tensor,
                source_length,
                max_output_len=len(true_target_str) + 10, # Generate slightly longer than true target
                target_eos_idx=target_vocab.eos_idx
            )

        # Get characters for display (strip all special tokens for clean display)
        display_source_chars = list(source_vocab.indices_to_sequence(source_indices_for_pred, strip_special=True))
        display_target_chars = list(target_vocab.indices_to_sequence(predicted_indices, strip_special=True))

        # Align attention matrix with the displayed characters.
        # NOTE(review): the leading-prefix slice below assumes stripped special
        # tokens (e.g. EOS) only occur after the displayed characters in both
        # the prediction and the source — TODO confirm against indices_to_sequence.
        final_attention_matrix_for_js = []
        if attention_matrix_np is not None and \
           len(display_target_chars) > 0 and \
           len(display_source_chars) > 0 and \
           attention_matrix_np.shape[0] >= len(display_target_chars) and \
           attention_matrix_np.shape[1] >= len(display_source_chars) :

            # Slice attention matrix to correspond to the displayed non-special tokens
            num_rows_to_take = len(display_target_chars)
            num_cols_to_take = len(display_source_chars)
            final_attention_matrix_for_js = attention_matrix_np[:num_rows_to_take, :num_cols_to_take].tolist()
        else:
            print(f"  Skipping example '{source_str}' due to attention matrix dimension mismatch or empty strings.")
            print(f"  Attn shape: {attention_matrix_np.shape if attention_matrix_np is not None else 'None'}, "
                  f"display_target_len: {len(display_target_chars)}, display_source_len: {len(display_source_chars)}")
            continue # Skip this example if dimensions don't align

        examples_for_html.append({
            "name": f"Ex {i+1}: {source_str} -> {true_target_str}",
            "sourceChars": display_source_chars,
            "targetChars": display_target_chars,
            "attentionMatrix": final_attention_matrix_for_js
        })

    # --- Generate and Log Interactive HTML to W&B ---
    if examples_for_html:
        interactive_html_content = generate_interactive_attention_html(examples_for_html)

        # Save HTML locally first, so the file exists whether or not W&B
        # logging succeeds. The W&B artifact below is built from this file.
        with open(VIZ_HTML_PATH, "w", encoding="utf-8") as f_html:
            f_html.write(interactive_html_content)
        print(f"Saved interactive HTML to {VIZ_HTML_PATH}")

        # Ensure a W&B run is active for logging
        if wandb and wandb.run is None:
            wandb.init(project="DL_A3", name="Q6_Interactive_Attention_Viz", config=BEST_ATTN_HYPERPARAMETERS, job_type="q6_visualization", reinit=True)

        if wandb and wandb.run:
            wandb.log({"interactive_attention_visualization": wandb.Html(interactive_html_content)})
            print("\nLogged interactive attention visualization to W&B.")

            # Log the generated HTML file as an artifact
            html_artifact = wandb.Artifact("interactive_attention_html", type="visualization")
            html_artifact.add_file(VIZ_HTML_PATH)
            wandb.log_artifact(html_artifact)
            print("Logged HTML file as a W&B artifact.")
        else:
            print("W&B run not active. Could not log interactive HTML.")
    else:
        print("No valid examples with attention data were generated for the interactive HTML.")

    # Finish the W&B run if one was started
    if wandb and wandb.run: wandb.finish()
    print("Script finished.")