# Question 2

In [None]:
import torch
import torch.nn as nn
import random

class Encoder(nn.Module):
    """
    Encoder half of the seq2seq model.

    Embeds a batch of token indices, applies dropout, and runs the result
    through a (optionally stacked and/or bidirectional) RNN, GRU or LSTM.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False):
        """
        Build the embedding, dropout and recurrent layers.

        Args:
            input_vocab_size: number of source tokens.
            embedding_dim: size of each token embedding.
            hidden_dim: hidden size of every recurrent layer (per direction).
            n_layers: number of stacked recurrent layers.
            cell_type: 'RNN', 'GRU' or 'LSTM' (case-insensitive).
            dropout_p: dropout on embeddings and, when n_layers > 1, between layers.
            bidirectional: whether the recurrent layers run in both directions.

        Raises:
            ValueError: if cell_type is not one of the supported names.
        """
        super().__init__()

        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)

        rnn_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if rnn_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout is only meaningful for stacked layers.
        self.rnn = rnn_cls(embedding_dim, hidden_dim, n_layers,
                           dropout=dropout_p if n_layers > 1 else 0,
                           batch_first=True, bidirectional=self.bidirectional)

    def forward(self, input_seq):
        """
        Encode a batch of token indices (batch dimension first).

        Returns:
            (outputs, hidden) exactly as produced by the underlying recurrent
            layer; for an LSTM, ``hidden`` is an ``(h_n, c_n)`` tuple.
        """
        return self.rnn(self.dropout(self.embedding(input_seq)))

class Decoder(nn.Module):
    """
    Decoder half of the seq2seq model.

    Each call performs one decoding step: one token per batch element plus the
    previous recurrent state go in, vocabulary logits and the new state come out.
    """
    def __init__(self, output_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, encoder_bidirectional=False):
        """
        Build the embedding, dropout, recurrent and output projection layers.

        Args:
            output_vocab_size: number of target tokens (size of the logits).
            embedding_dim: size of each token embedding.
            hidden_dim: hidden size of the (unidirectional) recurrent layers.
            n_layers: number of stacked recurrent layers.
            cell_type: 'RNN', 'GRU' or 'LSTM' (case-insensitive).
            dropout_p: dropout on embeddings and, when n_layers > 1, between layers.
            encoder_bidirectional: stored as an attribute only; it does not
                affect any layer built here.

        Raises:
            ValueError: if cell_type is not one of the supported names.
        """
        super().__init__()

        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.encoder_bidirectional = encoder_bidirectional

        self.embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)

        rnn_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if rnn_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout is only meaningful for stacked layers.
        self.rnn = rnn_cls(embedding_dim, hidden_dim, n_layers,
                           dropout=dropout_p if n_layers > 1 else 0, batch_first=True)

        self.fc_out = nn.Linear(hidden_dim, output_vocab_size)

    def forward(self, input_char, hidden_state):
        """
        Run a single decoding step.

        Args:
            input_char: 1-D tensor of token indices, one per batch element.
            hidden_state: previous recurrent state (an ``(h, c)`` tuple for LSTM).

        Returns:
            (logits of shape (batch, output_vocab_size), new recurrent state).
        """
        # Add a length-1 time axis so the batch_first RNN sees (batch, 1, emb).
        step_embedding = self.dropout(self.embedding(input_char.unsqueeze(1)))
        step_output, new_hidden_state = self.rnn(step_embedding, hidden_state)
        return self.fc_out(step_output.squeeze(1)), new_hidden_state

class Seq2Seq(nn.Module):
    """
    The main Seq2Seq model that combines Encoder and Decoder.

    For a bidirectional encoder, the forward and backward final states of each
    layer are concatenated and linearly projected to the decoder's hidden size.
    For a unidirectional encoder the final state is passed to the decoder
    unchanged, so hidden size, layer count and cell type must then match.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        """
        Args:
            encoder: encoder module exposing bidirectional, hidden_dim,
                num_directions, n_layers and cell_type.
            decoder: single-step decoder exposing hidden_dim and output_vocab_size.
            device: device on which new tensors (outputs, SOS input) are created.
            target_sos_idx: index of the start-of-sequence token.
        """
        super(Seq2Seq, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        if self.encoder.bidirectional:
            encoder_hidden_dim_actual = self.encoder.hidden_dim * self.encoder.num_directions
            decoder_hidden_dim_expected = self.decoder.hidden_dim

            self.fc_hidden = nn.Linear(encoder_hidden_dim_actual, decoder_hidden_dim_expected)
            if self.encoder.cell_type == 'LSTM':
                self.fc_cell = nn.Linear(encoder_hidden_dim_actual, decoder_hidden_dim_expected)

    def _concat_directions(self, state):
        """Reshape (n_layers * 2, batch, hidden) to (n_layers, batch, 2 * hidden) by joining directions."""
        # PyTorch lays out h_n/c_n layer-major with the two directions adjacent,
        # so a (n_layers, num_directions, ...) view separates them per layer.
        state = state.view(self.encoder.n_layers, self.encoder.num_directions,
                           state.size(1), self.encoder.hidden_dim)
        return torch.cat((state[:, 0, :, :], state[:, 1, :, :]), dim=2)

    def _adapt_encoder_hidden(self, encoder_hidden):
        """
        Adapt the encoder's final state for use as the decoder's initial state.

        Unidirectional states are returned unchanged; bidirectional ones have
        their directions concatenated and projected by fc_hidden (and fc_cell
        for the LSTM cell state).
        """
        if not self.encoder.bidirectional:
            return encoder_hidden

        if self.encoder.cell_type == 'LSTM':
            h_n, c_n = encoder_hidden
            return (self.fc_hidden(self._concat_directions(h_n)),
                    self.fc_cell(self._concat_directions(c_n)))

        return self.fc_hidden(self._concat_directions(encoder_hidden))

    def forward(self, source_seq, target_seq, teacher_forcing_ratio=0.5):
        """
        Teacher-forced training pass.

        Args:
            source_seq: (batch, src_len) source token indices.
            target_seq: (batch, tgt_len) target indices; column 0 is fed as the
                first decoder input.
            teacher_forcing_ratio: per-step probability of feeding the gold token
                instead of the model's own argmax prediction.

        Returns:
            (batch, tgt_len, vocab) logits; position 0 is left as zeros because
            nothing is predicted for the first input token.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size

        outputs = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq)
        decoder_hidden = self._adapt_encoder_hidden(encoder_final_hidden)

        decoder_input = target_seq[:, 0]

        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs[:, t + 1] = decoder_output_logits

            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)
            decoder_input = target_seq[:, t + 1] if teacher_force_this_step else top1_predicted_token

        return outputs

    def predict(self, source_seq, max_output_len=50, target_eos_idx=None):
        """
        Greedily decode one source sequence (batch size must be 1).

        No teacher forcing is used. Decoding runs in eval mode under
        torch.no_grad(); the model's previous train/eval mode is restored
        afterwards, even if decoding raises.

        Args:
            source_seq: (1, src_len) source token indices.
            max_output_len: maximum number of tokens to generate.
            target_eos_idx: optional end-of-sequence index; when given, decoding
                stops at the first EOS prediction (EOS is not included).

        Returns:
            List of predicted token indices.

        Raises:
            ValueError: if the batch size is not 1.
        """
        batch_size = source_seq.shape[0]
        if batch_size != 1:
            raise ValueError("Predict function currently supports batch_size=1 for simplicity.")

        was_training = self.training
        self.eval()
        predicted_indices = []
        try:
            with torch.no_grad():
                _, encoder_final_hidden = self.encoder(source_seq)
                decoder_hidden = self._adapt_encoder_hidden(encoder_final_hidden)

                decoder_input = torch.tensor([self.target_sos_idx], device=self.device)

                for _ in range(max_output_len):
                    decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

                    top1_predicted_token = decoder_output_logits.argmax(1)
                    predicted_idx = top1_predicted_token.item()
                    if target_eos_idx is not None and predicted_idx == target_eos_idx:
                        break
                    predicted_indices.append(predicted_idx)

                    decoder_input = top1_predicted_token
        finally:
            # Restore whatever mode the caller had instead of forcing train mode.
            self.train(was_training)
        return predicted_indices


if __name__ == '__main__':
    def count_parameters(m):
        """Return the number of trainable scalar parameters in module *m*."""
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    def _print_summary(enc, dec):
        """Print the one-line encoder and decoder summaries."""
        print(f"Encoder cell type: {enc.cell_type}, Layers: {enc.n_layers}, Hidden: {enc.hidden_dim}, Bidirectional: {enc.bidirectional}")
        print(f"Decoder cell type: {dec.cell_type}, Layers: {dec.n_layers}, Hidden: {dec.hidden_dim}")

    # Vocabulary sizes and special-token indices for this synthetic smoke test.
    INPUT_VOCAB_SIZE, OUTPUT_VOCAB_SIZE = 50, 60
    TARGET_SOS_IDX, TARGET_EOS_IDX = 0, 1

    # Architecture hyper-parameters.
    EMBEDDING_DIM = 128
    HIDDEN_DIM_ENCODER = HIDDEN_DIM_DECODER = 256
    N_LAYERS_ENCODER = N_LAYERS_DECODER = 2
    CELL_TYPE, DROPOUT, ENCODER_BIDIRECTIONAL = 'LSTM', 0.3, True

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    encoder = Encoder(INPUT_VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM_ENCODER, N_LAYERS_ENCODER,
                      CELL_TYPE, DROPOUT, bidirectional=ENCODER_BIDIRECTIONAL).to(DEVICE)
    decoder = Decoder(OUTPUT_VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM_DECODER, N_LAYERS_DECODER,
                      CELL_TYPE, DROPOUT, encoder_bidirectional=ENCODER_BIDIRECTIONAL).to(DEVICE)
    model = Seq2Seq(encoder, decoder, DEVICE, TARGET_SOS_IDX).to(DEVICE)

    print(f"Model initialized on {DEVICE}")
    _print_summary(encoder, decoder)

    BATCH_SIZE, SOURCE_SEQ_LEN, TARGET_SEQ_LEN = 4, 10, 12

    dummy_source_seq = torch.randint(0, INPUT_VOCAB_SIZE, (BATCH_SIZE, SOURCE_SEQ_LEN)).to(DEVICE)
    # Values are drawn from [1, OUTPUT_VOCAB_SIZE), so index 0 (SOS here) only
    # appears in the first column, which is set explicitly below.
    dummy_target_seq = torch.randint(1, OUTPUT_VOCAB_SIZE, (BATCH_SIZE, TARGET_SEQ_LEN)).to(DEVICE)
    dummy_target_seq[:, 0] = TARGET_SOS_IDX

    print(f"\nDummy source shape: {dummy_source_seq.shape}")
    print(f"Dummy target shape: {dummy_target_seq.shape}")

    model.train()
    output_logits = model(dummy_source_seq, dummy_target_seq, teacher_forcing_ratio=0.5)
    print(f"\nOutput logits shape from forward pass: {output_logits.shape}")

    model.eval()
    dummy_single_source_seq = torch.randint(0, INPUT_VOCAB_SIZE, (1, SOURCE_SEQ_LEN)).to(DEVICE)
    predicted_sequence_indices = model.predict(dummy_single_source_seq, max_output_len=TARGET_SEQ_LEN)
    print(f"\nPredicted sequence indices for a single source: {predicted_sequence_indices}")
    print(f"Length of predicted sequence: {len(predicted_sequence_indices)}")

    print(f'\nThe model has {count_parameters(model):,} trainable parameters')

    # Second configuration: single-layer unidirectional GRU on both sides.
    print("\n--- Example of changing parameters ---")
    encoder_gru_unidir = Encoder(INPUT_VOCAB_SIZE, embedding_dim=64, hidden_dim=128, n_layers=1,
                                 cell_type='GRU', dropout_p=0.1, bidirectional=False).to(DEVICE)
    decoder_gru_unidir = Decoder(OUTPUT_VOCAB_SIZE, embedding_dim=64, hidden_dim=128, n_layers=1,
                                 cell_type='GRU', dropout_p=0.1, encoder_bidirectional=False).to(DEVICE)
    model_gru_unidir = Seq2Seq(encoder_gru_unidir, decoder_gru_unidir, DEVICE, TARGET_SOS_IDX).to(DEVICE)
    print(f"GRU Unidirectional Model initialized on {DEVICE}")
    _print_summary(encoder_gru_unidir, decoder_gru_unidir)
    print(f'The GRU Unidir model has {count_parameters(model_gru_unidir):,} trainable parameters')


# Q3 - Vanilla Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq

# --- Constants ---
# Special-token strings shared by the vocabulary and the data pipeline.
SOS_TOKEN, EOS_TOKEN = "<sos>", "<eos>"
PAD_TOKEN, UNK_TOKEN = "<pad>", "<unk>"

# --- Model Definition (Encoder, Decoder, Seq2Seq) ---
class Encoder(nn.Module):
    """
    Encoder for padded, variable-length batches.

    Embeds token indices (the padding index gets a zero, non-trained
    embedding), applies dropout, and runs a packed sequence through an
    RNN/GRU/LSTM so that padding positions do not influence the final state.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        """
        Args:
            input_vocab_size: number of source tokens.
            embedding_dim: size of each token embedding.
            hidden_dim: hidden size of every recurrent layer (per direction).
            n_layers: number of stacked recurrent layers.
            cell_type: 'RNN', 'GRU' or 'LSTM' (case-insensitive).
            dropout_p: dropout on embeddings and, when n_layers > 1, between layers.
            bidirectional: whether the recurrent layers run in both directions.
            pad_idx: index passed to nn.Embedding as padding_idx.

        Raises:
            ValueError: if cell_type is not one of the supported names.
        """
        super().__init__()
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        rnn_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if rnn_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # Inter-layer dropout is only meaningful for stacked layers.
        self.rnn = rnn_cls(embedding_dim, hidden_dim, n_layers,
                           dropout=dropout_p if n_layers > 1 else 0,
                           batch_first=True, bidirectional=self.bidirectional)

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded (batch, max_len) batch given the true length of each row.

        Returns:
            (None, hidden): per-timestep outputs are discarded; ``hidden`` is the
            recurrent layer's final state (an ``(h_n, c_n)`` tuple for LSTM).
        """
        embedded = self.dropout(self.embedding(input_seq))
        # pack_padded_sequence expects lengths on the CPU; enforce_sorted=False
        # lets batches arrive in any order.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, hidden = self.rnn(packed)
        return None, hidden

class Decoder(nn.Module):
    """
    Attention-free single-step decoder.

    Each call consumes one token per batch element plus the previous recurrent
    state and returns vocabulary logits along with the updated state.
    """
    def __init__(self, output_vocab_size, embedding_dim,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        """
        Args:
            output_vocab_size: number of target tokens (size of the logits).
            embedding_dim: size of each token embedding.
            decoder_hidden_dim: hidden size of the unidirectional recurrent layers.
            n_layers: number of stacked recurrent layers.
            cell_type: 'RNN', 'GRU' or 'LSTM' (case-insensitive).
            dropout_p: dropout on embeddings and, when n_layers > 1, between layers.
            pad_idx: index passed to nn.Embedding as padding_idx.

        Raises:
            ValueError: if cell_type is not one of the supported names.
        """
        super().__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        rnn_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if rnn_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # The RNN input is just the token embedding (no attention context here).
        self.rnn = rnn_cls(embedding_dim, decoder_hidden_dim, n_layers,
                           dropout=dropout_p if n_layers > 1 else 0, batch_first=True)

        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden):
        """
        Run one decoding step.

        Args:
            input_char: 1-D tensor of token indices, one per batch element.
            prev_decoder_hidden: previous recurrent state (``(h, c)`` for LSTM).

        Returns:
            (logits of shape (batch, output_vocab_size), new recurrent state).
        """
        # Add a length-1 time axis for the batch_first RNN.
        step_embedding = self.dropout(self.embedding(input_char.unsqueeze(1)))
        step_output, current_decoder_hidden = self.rnn(step_embedding, prev_decoder_hidden)
        return self.fc_out(step_output.squeeze(1)), current_decoder_hidden

class Seq2Seq(nn.Module):
    """
    Attention-free encoder-decoder that seeds the decoder with the encoder's final state.

    The encoder's final hidden (and, for LSTMs, cell) state is reshaped so the
    two directions of a bidirectional encoder are concatenated per layer,
    optionally projected to the decoder's hidden size, truncated or padded to
    the decoder's layer count, and finally shaped to match the decoder's cell
    type (a plain tensor for RNN/GRU, an ``(h, c)`` tuple for LSTM).
    """

    def __init__(self, encoder, decoder, device, target_sos_idx):
        """
        Args:
            encoder: module exposing hidden_dim, num_directions, n_layers,
                bidirectional and cell_type; called as ``encoder(src, lengths)``
                and returning ``(_, final_hidden)``.
            decoder: single-step module exposing decoder_hidden_dim, n_layers,
                cell_type and output_vocab_size.
            device: device on which new tensors (outputs, token inputs) are created.
            target_sos_idx: index of the start-of-sequence token.
        """
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # A bidirectional encoder's per-layer state is twice as wide once the
        # directions are concatenated.
        encoder_effective_output_dim_per_layer = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_expected_hidden_dim = self.decoder.decoder_hidden_dim

        self.needs_dim_adaptation = encoder_effective_output_dim_per_layer != decoder_rnn_expected_hidden_dim
        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None

        if self.needs_dim_adaptation:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)

    def _merge_directions(self, state):
        """Reshape (n_layers * num_directions, batch, hidden) to (n_layers, batch, hidden * num_directions)."""
        batch_size = state.size(1)
        # PyTorch lays out h_n/c_n layer-major with the directions adjacent, so
        # this view separates layer and direction axes.
        state = state.view(self.encoder.n_layers, self.encoder.num_directions,
                           batch_size, self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            return torch.cat([state[:, 0, :, :], state[:, 1, :, :]], dim=2)
        return state.squeeze(1)

    def _match_decoder_layers(self, state):
        """Keep the top decoder.n_layers layers, or pad by repeating the top encoder layer."""
        n_dec = self.decoder.n_layers
        n_enc = state.size(0)
        if n_enc >= n_dec:
            return state[-n_dec:]
        padding = state[-1:].expand(n_dec - n_enc, -1, -1)
        return torch.cat([state, padding], dim=0)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Convert the encoder's final state into the decoder's initial state.

        Returns a tensor when the decoder is an RNN/GRU and an ``(h, c)`` tuple
        when it is an LSTM. An encoder cell state is dropped for a non-LSTM
        decoder, and a zero cell state is supplied to an LSTM decoder whose
        encoder has none.
        """
        encoder_is_lstm = self.encoder.cell_type == 'LSTM'
        decoder_is_lstm = self.decoder.cell_type == 'LSTM'

        if encoder_is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        h_processed = self._merge_directions(h_from_enc)
        if self.needs_dim_adaptation:
            h_processed = self.fc_adapt_hidden(h_processed)
        final_h = self._match_decoder_layers(h_processed)

        if not decoder_is_lstm:
            return final_h

        if c_from_enc is None:
            # RNN/GRU encoder feeding an LSTM decoder: start the cell state at zero.
            return final_h, torch.zeros_like(final_h)

        c_processed = self._merge_directions(c_from_enc)
        if self.needs_dim_adaptation and self.fc_adapt_cell is not None:
            c_processed = self.fc_adapt_cell(c_processed)
        return final_h, self._match_decoder_layers(c_processed)

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Teacher-forced training pass.

        Args:
            source_seq: (batch, src_len) padded source indices.
            source_lengths: true length of each source row.
            target_seq: (batch, tgt_len) target indices; column 0 is fed as the
                first decoder input.
            teacher_forcing_ratio: per-step probability of feeding the gold token
                instead of the model's own argmax prediction.

        Returns:
            (batch, tgt_len, vocab) logits; position 0 stays zero because nothing
            is predicted for the first input token.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size

        outputs_logits = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        decoder_input = target_seq[:, 0]

        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs_logits[:, t + 1] = decoder_output_logits

            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)
            decoder_input = target_seq[:, t + 1] if teacher_force_this_step else top1_predicted_token

        return outputs_logits

    def _as_batched_source(self, source_seq, source_lengths):
        """
        Turn a single unbatched (1-D) source into a batch of one.

        For 1-D input, ``source_lengths`` may be None (use the sequence length),
        an int, or a one-element tensor/list holding the length. Batched
        (2-D) input is returned unchanged.
        """
        if source_seq.dim() != 1:
            return source_seq, source_lengths
        source_seq = source_seq.unsqueeze(0)
        if source_lengths is None:
            length = source_seq.size(1)
        elif isinstance(source_lengths, int):
            length = source_lengths
        else:
            # Use the stored length value, not the size of the container holding it.
            length = int(torch.as_tensor(source_lengths).reshape(-1)[0])
        return source_seq, torch.tensor([length], device=self.device)

    @staticmethod
    def _length_normalized(log_prob, seq):
        """Average log-probability per token of *seq*; -inf for an empty sequence."""
        return log_prob / len(seq) if len(seq) > 0 else -float('inf')

    def predict_greedy(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Greedy decoding for a single source sequence.

        Runs under torch.no_grad() and does not change train/eval mode (the
        caller should put the model in eval mode to disable dropout).

        Args:
            source_seq: (1, src_len) batch or an unbatched 1-D sequence.
            source_lengths: lengths tensor for a batch, or see _as_batched_source.
            max_output_len: maximum number of generated tokens.
            target_eos_idx: optional EOS index; decoding stops before it.

        Returns:
            List of predicted token indices (without SOS/EOS).

        Raises:
            ValueError: if more than one sequence is given.
        """
        source_seq, source_lengths = self._as_batched_source(source_seq, source_lengths)
        if source_seq.shape[0] != 1:
            raise ValueError("Greedy predict function currently supports batch_size=1.")

        with torch.no_grad():
            _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

            decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
            predicted_indices = []

            for _ in range(max_output_len):
                decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

                top1_predicted_token = decoder_output_logits.argmax(1)
                predicted_idx = top1_predicted_token.item()

                if target_eos_idx is not None and predicted_idx == target_eos_idx:
                    break
                predicted_indices.append(predicted_idx)
                decoder_input = top1_predicted_token
        return predicted_indices

    def predict_beam_search(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None, beam_width=3):
        """
        Beam-search decoding for a single source sequence.

        Hypotheses are ranked by length-normalised log-probability (the length
        counts the SOS token and, when present, the EOS token). Runs under
        torch.no_grad() and does not change train/eval mode.

        Args:
            source_seq: (1, src_len) batch or an unbatched 1-D sequence.
            source_lengths: lengths tensor for a batch, or see _as_batched_source.
            max_output_len: maximum number of expansion steps.
            target_eos_idx: EOS index that marks a hypothesis as finished.
            beam_width: number of live hypotheses kept per step (capped at the
                vocabulary size when choosing candidate tokens).

        Returns:
            Best token sequence with the leading SOS removed; a trailing EOS, if
            generated, is kept.

        Raises:
            ValueError: if more than one sequence is given.
        """
        source_seq, source_lengths = self._as_batched_source(source_seq, source_lengths)
        if source_seq.shape[0] != 1:
            raise ValueError("Beam search predict function currently supports batch_size=1.")

        with torch.no_grad():
            _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
            decoder_hidden_init = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

            # Each beam: (summed log-prob, token indices starting with SOS, decoder state).
            beams = [(0.0, [self.target_sos_idx], decoder_hidden_init)]
            completed_sequences = []

            for _ in range(max_output_len):
                if len(completed_sequences) >= beam_width and all(b[1][-1] == target_eos_idx for b in beams if b[1]):
                    break

                new_beams = []
                for log_prob_beam, seq_beam, hidden_beam in beams:
                    if not seq_beam or seq_beam[-1] == target_eos_idx:
                        completed_sequences.append((self._length_normalized(log_prob_beam, seq_beam), seq_beam))
                        continue

                    decoder_input = torch.tensor([seq_beam[-1]], device=self.device)
                    decoder_output_logits, next_hidden_beam = self.decoder(decoder_input, hidden_beam)

                    log_probs_next_token = F.log_softmax(decoder_output_logits, dim=1)
                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(beam_width, log_probs_next_token.size(1))
                    topk_log_probs, topk_indices = torch.topk(log_probs_next_token, k, dim=1)

                    for next_token_idx, token_log_prob in zip(topk_indices[0].tolist(), topk_log_probs[0].tolist()):
                        new_beams.append((log_prob_beam + token_log_prob, seq_beam + [next_token_idx], next_hidden_beam))

                if not new_beams:
                    # Every live hypothesis was just moved into completed_sequences.
                    beams = []
                    break

                new_beams.sort(key=lambda x: x[0], reverse=True)
                beams = new_beams[:beam_width]

            # Hypotheses still in `beams` have not been recorded yet, whether they
            # reached EOS on the final step / early stop or were cut off by
            # max_output_len; all of them are candidates.
            for log_prob_beam, seq_beam, _ in beams:
                completed_sequences.append((self._length_normalized(log_prob_beam, seq_beam), seq_beam))

            if not completed_sequences:
                return [target_eos_idx] if target_eos_idx is not None else []

            completed_sequences.sort(key=lambda x: x[0], reverse=True)

            best_sequence_indices = completed_sequences[0][1]
            return best_sequence_indices[1:] if best_sequence_indices and best_sequence_indices[0] == self.target_sos_idx else best_sequence_indices

# --- Data Loading and Preprocessing ---
class Vocabulary:
    """
    Character-level vocabulary with four reserved special tokens.

    Indices 0-3 are PAD, SOS, EOS and UNK in that order; characters admitted by
    ``build_vocab`` receive consecutive indices starting at 4.
    """
    def __init__(self, name):
        self.name = name
        reserved = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: position for position, token in enumerate(reserved)}
        self.index2char = dict(enumerate(reserved))
        self.char_counts = Counter()
        self.n_chars = len(reserved)
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count each character of *sequence*; indices are assigned later by build_vocab."""
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """
        Give an index to every counted character seen at least *min_freq* times.

        Characters are added most frequent first, ties broken by the character
        itself; characters that already have an index keep it.
        """
        by_frequency = sorted(self.char_counts, key=lambda ch: (-self.char_counts[ch], ch))
        for ch in by_frequency:
            if self.char_counts[ch] < min_freq or ch in self.char2index:
                continue
            self.char2index[ch] = self.n_chars
            self.index2char[self.n_chars] = ch
            self.n_chars += 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False):
        """Map characters to indices (unknown ones to UNK), optionally framed by SOS/EOS."""
        prefix = [self.sos_idx] if add_sos else []
        suffix = [self.eos_idx] if add_eos else []
        body = [self.char2index.get(ch, self.char2index[UNK_TOKEN]) for ch in list(sequence)]
        return prefix + body + suffix

    def indices_to_sequence(self, indices):
        """
        Decode indices into a string.

        Stops at the first EOS, skips SOS and PAD, and renders unknown indices
        as the UNK token string.
        """
        pieces = []
        for idx in indices:
            if idx == self.eos_idx:
                break
            if idx not in (self.sos_idx, self.pad_idx):
                pieces.append(self.index2char.get(idx, UNK_TOKEN))
        return "".join(pieces)

class TransliterationDataset(Dataset):
    """
    In-memory dataset of (source, target) transliteration pairs from a TSV file.

    Each usable line has at least two tab-separated columns: column 0 is stored
    as the target word and column 1 as the source word; extra columns are
    ignored. Lines with fewer than two columns, an empty source/target, or
    (when ``max_len`` is truthy) a word longer than ``max_len`` are skipped.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=None):
        """
        Load pairs from *file_path*.

        Loading is best-effort: a missing file or a read/decode error is
        reported with a printed message and leaves whatever pairs were loaded
        so far (possibly none).

        Args:
            file_path: path to a UTF-8 tab-separated file.
            source_vocab: object with ``sequence_to_indices`` for source words.
            target_vocab: object with ``sequence_to_indices`` for target words.
            max_len: optional maximum word length (in characters).
        """
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found during Dataset init: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Remove only the line terminator before splitting and trim each
                    # field separately: a whole-line strip() would also remove a
                    # leading tab and shift every column one position left.
                    parts = [field.strip() for field in line.rstrip('\r\n').split('\t')]
                    if len(parts) < 2:
                        continue
                    target_sequence, source_sequence = parts[0], parts[1]
                    if not source_sequence or not target_sequence:
                        continue
                    if max_len and (len(source_sequence) > max_len or len(target_sequence) > max_len):
                        continue
                    self.pairs.append((source_sequence, target_sequence))
            print(f"Loaded {len(self.pairs)} pairs from {file_path}.")
        except Exception as e:
            # Deliberately best-effort: report and keep the pairs read so far.
            print(f"ERROR: Could not read or process file {file_path}. Error: {e}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """
        Return ``(source_indices, target_indices)`` as int64 tensors.

        The source gets a trailing EOS; the target gets both a leading SOS and
        a trailing EOS.

        Raises:
            IndexError: if *idx* is not smaller than the dataset length.
        """
        if idx >= len(self.pairs):
            raise IndexError("Index out of bounds for dataset")
        source_str, target_str = self.pairs[idx]
        source_indices = self.source_vocab.sequence_to_indices(source_str, add_eos=True)
        target_indices = self.target_vocab.sequence_to_indices(target_str, add_sos=True, add_eos=True)
        return torch.tensor(source_indices, dtype=torch.long), \
               torch.tensor(target_indices, dtype=torch.long)

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Collate ``(source, target)`` tensor pairs into padded, batch-first tensors.

    Samples that are None, contain a None tensor, or have an empty source are
    dropped.

    Returns:
        ``(padded_sources, source_lengths, padded_targets)``, or
        ``(None, None, None)`` when no usable sample remains.
    """
    usable = [item for item in batch if item is not None and item[0] is not None and item[1] is not None]
    if not usable:
        return None, None, None

    usable = [pair for pair in usable if len(pair[0]) > 0]
    if not usable:
        return None, None, None

    sources = [src for src, _ in usable]
    targets = [tgt for _, tgt in usable]
    lengths = torch.tensor([len(src) for src in sources], dtype=torch.long)

    return (pad_sequence(sources, batch_first=True, padding_value=pad_idx_source),
            lengths,
            pad_sequence(targets, batch_first=True, padding_value=pad_idx_target))

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Run one training epoch.

    Args:
        model: seq2seq model called as
            ``model(sources, source_lengths, targets, teacher_forcing_ratio)``.
        dataloader: yields ``(sources, source_lengths, targets)`` batches; a
            batch whose first element is None (nothing usable after collation)
            is skipped silently.
        optimizer: stepped once per contributing batch.
        criterion: loss over flattened logits/targets; position 0 (the
            first target token) is excluded.
        device: device the batch tensors are moved to.
        clip_value: max gradient norm passed to ``clip_grad_norm_``.
        teacher_forcing_ratio: forwarded to the model.
        target_vocab: object whose ``indices_to_sequence`` decodes index lists.

    Returns:
        ``(average_loss, accuracy)``. The loss is averaged over batches that
        produced a finite loss (0.0 if none did). Accuracy is the fraction of
        samples whose decoded argmax prediction equals the decoded target,
        measured on the same teacher-forced pass used for the loss, so it can
        overstate free-running accuracy.
    """
    model.train()
    epoch_loss = 0.0
    num_loss_batches = 0
    total_correct_train = 0
    total_samples_train = 0

    if len(dataloader) == 0:
        print("Warning: Training dataloader is empty.")
        return 0.0, 0.0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        if batch_data[0] is None:
            continue
        sources, source_lengths, targets = batch_data

        if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
            print("Warning: Empty or invalid batch data after collate_fn in training. Skipping.")
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

        optimizer.zero_grad()
        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)

        loss = criterion(flat_outputs, flat_targets)
        if not torch.isfinite(loss):
            print("Warning: NaN or Inf loss detected in training. Skipping batch.")
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        epoch_loss += loss.item()
        # Count only batches that contributed, so skipped batches don't deflate the average.
        num_loss_batches += 1

        # One device-to-host transfer per batch rather than one per row.
        predicted_rows = outputs_logits.argmax(dim=2)[:, 1:].tolist()
        target_rows = targets[:, 1:].tolist()
        for pred_row, true_row in zip(predicted_rows, target_rows):
            if target_vocab.indices_to_sequence(pred_row) == target_vocab.indices_to_sequence(true_row):
                total_correct_train += 1
        total_samples_train += len(target_rows)

    avg_epoch_loss = epoch_loss / num_loss_batches if num_loss_batches > 0 else 0.0
    train_accuracy = total_correct_train / total_samples_train if total_samples_train > 0 else 0.0
    return avg_epoch_loss, train_accuracy

def evaluate(model, dataloader, criterion, device, target_vocab, beam_width=1, target_eos_idx=None):
    """
    Compute validation loss and exact-match accuracy for a seq2seq model.

    The loss comes from one forward pass per batch with teacher forcing disabled.
    Accuracy is computed per example by decoding each source on its own:
    beam search when ``beam_width > 1`` and the model has ``predict_beam_search``,
    otherwise greedy decoding when the model has ``predict_greedy``, otherwise the
    argmax of the loss-pass logits. A prediction is correct only when its decoded
    string (``target_vocab.indices_to_sequence``) equals the decoded target.

    Args:
        model: called as ``model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)``
            and expected to return logits of shape (batch, target_len, vocab).
        dataloader: yields ``(sources, source_lengths, targets)`` batches.
        criterion: loss over flattened ``(N, vocab)`` logits and ``(N,)`` targets.
        device: device the batch tensors are moved to.
        target_vocab: object providing ``indices_to_sequence``.
        beam_width (int): beam size; 1 selects greedy decoding.
        target_eos_idx (int | None): EOS index forwarded to the decoding helpers.

    Returns:
        tuple[float, float]: (mean loss over batches whose loss was finite,
        exact-match accuracy). Both are 0.0 when nothing could be evaluated.
    """
    model.eval()
    epoch_loss = 0.0
    loss_batches = 0  # number of batches actually added to epoch_loss
    total_correct = 0
    total_samples = 0

    if len(dataloader) == 0:
        print("WARNING: Validation dataloader is empty. Returning 0 loss and 0 accuracy.")
        return 0.0, 0.0

    # Print one decoded example (the first one actually evaluated) as a sanity check.
    debug_pending = True

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
            if batch_data is None or batch_data[0] is None:
                continue
            sources, source_lengths, targets = batch_data

            if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
                print("Warning: Empty or invalid batch data in eval after collate_fn. Skipping.")
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

            outputs_for_loss = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs_for_loss.shape[-1]
            # Column 0 of the targets (presumably <sos>) is never predicted, so it is
            # excluded from both logits and labels.
            flat_outputs_for_loss = outputs_for_loss[:, 1:].reshape(-1, output_dim)
            flat_targets_for_loss = targets[:, 1:].reshape(-1)

            loss = criterion(flat_outputs_for_loss, flat_targets_for_loss)
            if not torch.isfinite(loss):
                print("Warning: NaN or Inf loss detected in evaluation. Skipping batch for loss accumulation.")
            else:
                epoch_loss += loss.item()
                loss_batches += 1

            max_output_len = targets.size(1)
            for i in range(sources.shape[0]):
                # Keep a batch dimension of 1 for the single-example decoders.
                src_single = sources[i:i + 1]
                src_len_single = source_lengths[i:i + 1]

                if beam_width > 1 and hasattr(model, 'predict_beam_search'):
                    predicted_indices = model.predict_beam_search(src_single, src_len_single,
                                                                  max_output_len=max_output_len,
                                                                  target_eos_idx=target_eos_idx,
                                                                  beam_width=beam_width)
                elif hasattr(model, 'predict_greedy'):
                    predicted_indices = model.predict_greedy(src_single, src_len_single,
                                                             max_output_len=max_output_len,
                                                             target_eos_idx=target_eos_idx)
                else:
                    predicted_indices = outputs_for_loss[i, 1:].argmax(dim=-1).tolist()

                pred_str = target_vocab.indices_to_sequence(predicted_indices)
                true_seq_indices = targets[i, 1:].tolist()
                true_str = target_vocab.indices_to_sequence(true_seq_indices)

                if pred_str == true_str:
                    total_correct += 1
                total_samples += 1

                if debug_pending:
                    debug_pending = False
                    print(f"\n--- Evaluation Debug Sample {i} (Batch {batch_idx}) ---")
                    print(f"    Source (indices): {src_single[0, :15].tolist()}")
                    print(f"    Predicted Indices: {predicted_indices[:15]}")
                    print(f"    Predicted String: '{pred_str}'")
                    print(f"    True Indices: {true_seq_indices[:15]}")
                    print(f"    True String: '{true_str}'")
                    print(f"    Match: {pred_str == true_str}")

    # Average only over batches that contributed a loss; dividing by len(dataloader)
    # would bias the loss low whenever batches were skipped.
    avg_epoch_loss = epoch_loss / loss_batches if loss_batches > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    print(f"Evaluation - Total Correct: {total_correct}, Total Samples: {total_samples}, Calculated Accuracy: {accuracy:.4f}, Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss, accuracy

# --- Main Training Function for W&B ---
def train_model():
    """
    Run a single W&B sweep trial for the transliteration Seq2Seq model.

    Hyper-parameters are read from ``wandb.config``. The function:
    builds Latin (source) and Devanagari (target) character vocabularies from
    the training lexicon; creates train/dev DataLoaders; trains for up to
    ``config.epochs`` epochs with ReduceLROnPlateau on validation accuracy and
    optional early stopping; and logs per-epoch metrics plus the best
    validation accuracy to the run.

    The lexicon root defaults to the Kaggle Dakshina Hindi path. It can be
    overridden with an optional ``data_dir`` config key.

    Returns:
        int: 0 on success, 1 if the data files, datasets or loaders are missing or empty.
    """
    with wandb.init() as run:
        config = wandb.config

        run_name = f"{config.cell_type}_emb{config.embedding_dim}_hid{config.hidden_dim}" \
                   f"_encL{config.encoder_layers}_decL{config.decoder_layers}" \
                   f"_do{config.dropout_p:.2f}_lr{config.learning_rate:.2e}" \
                   f"_encBi{str(config.encoder_bidirectional)[0]}" \
                   f"_beam{config.get('beam_width_eval', 1)}"

        if hasattr(run, 'name'): run.name = run_name

        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        use_cuda = DEVICE.type == 'cuda'
        current_run_name_for_log = run.name if run.name else run.id
        print(f"Using device: {DEVICE}. Run: {current_run_name_for_log} (ID: {run.id})")
        print(f"Config: {config}")

        # Optional override so the same sweep can run outside Kaggle.
        BASE_DATA_DIR = config.get('data_dir', "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/")
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")

        if not os.path.exists(DATA_DIR):
            print(f"ERROR: Lexicons directory not found: {DATA_DIR}")
            return 1

        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        if not os.path.exists(train_file) or not os.path.exists(dev_file):
            print(f"ERROR: Train or Dev file not found. Train: {train_file}, Dev: {dev_file}")
            return 1

        # --- Vocabulary: count characters in the training pairs, then freeze the mapping ---
        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")

        temp_train_dataset_for_vocab = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=config.max_seq_len)
        if not temp_train_dataset_for_vocab.pairs:
            print(f"ERROR: No data loaded for vocab building from {train_file}.")
            return 1
        for src_str, tgt_str in temp_train_dataset_for_vocab.pairs:
            source_vocab.add_sequence(src_str)
            target_vocab.add_sequence(tgt_str)
        source_vocab.build_vocab(min_freq=config.vocab_min_freq)
        target_vocab.build_vocab(min_freq=config.vocab_min_freq)

        print(f"Source Vocab: {source_vocab.n_chars} chars. Target Vocab: {target_vocab.n_chars} chars.")

        # Datasets are rebuilt now that both vocabularies are final.
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=config.max_seq_len)
        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=config.max_seq_len)

        if len(train_dataset) == 0 or len(dev_dataset) == 0:
            print(f"ERROR: Train/Dev dataset empty. Train: {len(train_dataset)}, Dev: {len(dev_dataset)}")
            return 1

        num_loader_workers = 0
        if use_cuda and os.cpu_count() and os.cpu_count() > 1:
            num_loader_workers = min(4, os.cpu_count() // 2)

        train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                                      collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                      num_workers=num_loader_workers, pin_memory=use_cuda, drop_last=True)
        dev_dataloader = DataLoader(dev_dataset, batch_size=config.batch_size, shuffle=False,
                                    collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                    num_workers=num_loader_workers, pin_memory=use_cuda, drop_last=False)

        if len(train_dataloader) == 0 or len(dev_dataloader) == 0:
            print(f"ERROR: Train/Dev Dataloader empty. Train: {len(train_dataloader)}, Dev: {len(dev_dataloader)}")
            return 1

        # --- Model ---
        encoder = Encoder(source_vocab.n_chars, config.embedding_dim, config.hidden_dim,
                          config.encoder_layers, config.cell_type, config.dropout_p,
                          config.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(DEVICE)
        decoder = Decoder(target_vocab.n_chars, config.embedding_dim,
                          config.hidden_dim,
                          config.decoder_layers, config.cell_type, config.dropout_p,
                          pad_idx=target_vocab.pad_idx).to(DEVICE)
        model = Seq2Seq(encoder, decoder, DEVICE, target_vocab.sos_idx).to(DEVICE)

        print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)

        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
        # `verbose` is deprecated for LR schedulers (PyTorch >= 2.2); reductions are
        # reported explicitly in the epoch loop instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=config.get('lr_scheduler_patience', 3), factor=0.3)
        # Padding positions do not contribute to the loss.
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

        best_val_accuracy = -1.0
        epochs_no_improve = 0
        max_epochs_no_improve = config.get('max_epochs_no_improve', 7)

        for epoch in range(config.epochs):
            train_loss, train_accuracy = train_epoch(model, train_dataloader, optimizer, criterion, DEVICE,
                                                     config.clip_value, config.teacher_forcing_ratio, target_vocab)

            val_loss, val_accuracy = evaluate(model, dev_dataloader, criterion, DEVICE, target_vocab,
                                              beam_width=config.get('beam_width_eval', 1),
                                              target_eos_idx=target_vocab.eos_idx)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_accuracy)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

            print(f"Epoch {epoch+1}/{config.epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f}")
            log_dict = {
                "epoch": epoch + 1,
                "train_loss": train_loss if not np.isnan(train_loss) else 0.0,
                "train_accuracy": train_accuracy if not np.isnan(train_accuracy) else 0.0,
                "val_loss": val_loss if not np.isnan(val_loss) else 0.0,
                "val_accuracy": val_accuracy if not np.isnan(val_accuracy) else 0.0,
                "learning_rate": lr_after
            }
            wandb.log(log_dict)

            # --- Early-stopping bookkeeping (a NaN accuracy never counts as an improvement) ---
            current_val_acc = val_accuracy if not np.isnan(val_accuracy) else -1.0
            if current_val_acc > best_val_accuracy:
                best_val_accuracy = current_val_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if config.early_stopping and epochs_no_improve >= max_epochs_no_improve:
                print(f"Early stopping triggered at epoch {epoch+1} after {epochs_no_improve} epochs with no improvement.")
                break

        run.summary["best_val_accuracy"] = best_val_accuracy if not np.isnan(best_val_accuracy) else 0.0
        print(f"Finished run. Best Validation Accuracy: {best_val_accuracy:.4f}")
        return 0

# --- W&B Sweep Configuration ---
sweep_config = dict(
    # Bayesian search that maximises the per-epoch validation accuracy logged by train_model.
    method='bayes',
    metric=dict(name='val_accuracy', goal='maximize'),
    parameters=dict(
        # Architecture
        embedding_dim=dict(values=[128, 256, 300]),
        hidden_dim=dict(values=[256, 512]),
        encoder_layers=dict(values=[1, 2]),
        decoder_layers=dict(values=[1, 2]),
        cell_type=dict(values=['GRU', 'LSTM']),
        dropout_p=dict(values=[0.2, 0.3, 0.4, 0.5]),
        encoder_bidirectional=dict(values=[True, False]),
        # Optimisation
        learning_rate=dict(distribution='log_uniform_values', min=1e-4, max=3e-3),
        batch_size=dict(values=[32, 64, 128]),
        epochs=dict(value=15),
        clip_value=dict(value=1.0),
        teacher_forcing_ratio=dict(distribution='uniform', min=0.4, max=0.6),
        # Data
        vocab_min_freq=dict(value=1),
        max_seq_len=dict(value=50),
        # Early stopping, LR scheduling and evaluation decoding
        early_stopping=dict(value=True),
        max_epochs_no_improve=dict(values=[5, 7]),
        lr_scheduler_patience=dict(values=[2, 3]),
        beam_width_eval=dict(values=[1, 3]),
    ),
)

# --- Main Execution Block ---
if __name__ == '__main__':
    # --- W&B authentication ---
    # On Kaggle, try the WANDB_API_KEY notebook secret first. Otherwise (or if that
    # fails) fall back to wandb.login(), which may pick the key up from the
    # environment or prompt for it.
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensure WANDB_API_KEY is set as a secret or environment variable.")
            try:
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
                print("W&B login with Kaggle secret successful.")
            except Exception as e_secret:
                print(f"Could not login with Kaggle secret ({e_secret}). Attempting default login.")
                if "WANDB_API_KEY" in os.environ:
                    wandb.login()
                    print("W&B login using environment variable.")
                else:
                    print("WANDB_API_KEY not found. Please set it up or login manually if prompted.")
                    wandb.login()
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        print(f"W&B login failed: {e}. Please ensure your W&B API key is correctly configured.")
        # Signal failure with a non-zero status. A bare exit() reports success (status 0)
        # and depends on the `site`-injected builtin.
        raise SystemExit(1)

    SWEEP_PROJECT_NAME = "DL_A3"

    # --- Create the sweep and run a local agent against it ---
    print("Initializing sweep...")
    try:
        sweep_id = wandb.sweep(sweep_config, project=SWEEP_PROJECT_NAME)
        print(f"Sweep ID: {sweep_id}")
        print(f"To run agents, execute: wandb agent YOUR_WANDB_USERNAME/{SWEEP_PROJECT_NAME}/{sweep_id}")

        print("Starting a W&B agent (adjust count as needed)...")
        wandb.agent(sweep_id, function=train_model, count=10)
    except Exception as e:
        print(f"Could not initialize sweep or run agent. Error: {e}")

    print("Sweep agent finished or stopped. Check W&B dashboard for results.")

# Question 4: Training and Testing

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq

# --- Constants for Special Tokens ---
# Start-of-sequence, end-of-sequence, padding and unknown-character markers.
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN = "<sos>", "<eos>", "<pad>", "<unk>"

# --- Model Definitions (Non-Attention Version) ---

class Encoder(nn.Module):
    """
    Recurrent encoder for the attention-free seq2seq model.

    Token indices are embedded, passed through dropout, packed by their true
    lengths, and fed to a (possibly multi-layer, possibly bidirectional)
    RNN/GRU/LSTM. Only the final recurrent state is returned; it serves as the
    context that initialises the decoder.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super(Encoder, self).__init__()
        # Configuration kept as attributes so callers can reshape the final state.
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 1 + int(bool(bidirectional))

        # Sub-modules are created in a fixed order (embedding -> dropout -> rnn).
        # That order determines both the state_dict layout and the sequence in
        # which weights are randomly initialised.
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_layers = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in recurrent_layers:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # The built-in recurrent modules apply dropout only between stacked layers,
        # so it is disabled for a single layer.
        self.rnn = recurrent_layers[self.cell_type](
            embedding_dim, hidden_dim, n_layers,
            dropout=dropout_p if n_layers > 1 else 0,
            batch_first=True,
            bidirectional=self.bidirectional,
        )

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded batch and return its final recurrent state.

        Args:
            input_seq (torch.Tensor): padded token indices, shape (batch_size, seq_len).
            input_lengths (torch.Tensor): unpadded length of each sequence, shape (batch_size,).

        Returns:
            tuple: ``(None, hidden)``. No per-step outputs are produced because the
            decoder does not attend; ``hidden`` is the RNN's final state, an
            ``(h, c)`` pair for LSTM cells.
        """
        embedded = self.dropout(self.embedding(input_seq))
        # pack_padded_sequence requires lengths on the CPU; enforce_sorted=False
        # lets the batch arrive in any order.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        final_hidden = self.rnn(packed)[1]
        return None, final_hidden

class Decoder(nn.Module):
    """
    Attention-free recurrent decoder.

    Each call performs one decoding step: it embeds one previous token per batch
    element, advances the RNN/GRU/LSTM by one step from the supplied state, and
    projects the RNN output to logits over the target vocabulary.
    """
    def __init__(self, output_vocab_size, embedding_dim,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super(Decoder, self).__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        # Creation order (embedding -> dropout -> rnn -> fc_out) fixes the
        # state_dict layout and the order of random weight initialisation.
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_layers = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if self.cell_type not in recurrent_layers:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")
        # The RNN input is only the previous token's embedding (no attention context).
        # Inter-layer dropout only makes sense with more than one layer.
        self.rnn = recurrent_layers[self.cell_type](
            embedding_dim, decoder_hidden_dim, n_layers,
            dropout=dropout_p if n_layers > 1 else 0,
            batch_first=True,
        )

        # Maps the RNN output at each step to per-token scores.
        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden):
        """
        Run one decoding step.

        Args:
            input_char (torch.Tensor): previous token indices, shape (batch_size,).
            prev_decoder_hidden (torch.Tensor or tuple): previous RNN state;
                an ``(h, c)`` pair for LSTM cells.

        Returns:
            tuple: ``(prediction_logits, current_decoder_hidden)`` where the logits
            have shape (batch_size, output_vocab_size).
        """
        # Insert a length-1 time axis: (batch,) -> (batch, 1) -> (batch, 1, emb).
        step_input = self.dropout(self.embedding(input_char.unsqueeze(1)))
        step_output, current_decoder_hidden = self.rnn(step_input, prev_decoder_hidden)
        # Drop the time axis before projecting to the vocabulary.
        return self.fc_out(step_output[:, 0]), current_decoder_hidden

class Seq2Seq(nn.Module):
    """
    Attention-free sequence-to-sequence model joining an Encoder and a Decoder.

    The encoder's final hidden state (plus cell state for LSTMs) is merged
    across directions, optionally projected to the decoder's hidden size, and
    matched to the decoder's layer count. The result initialises the decoder,
    which then emits the target one token at a time.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # LSTM states are (h, c) pairs while RNN/GRU states are single tensors, so
        # the encoder's final state can only seed a decoder of the same family.
        if (self.encoder.cell_type == 'LSTM') != (self.decoder.cell_type == 'LSTM'):
            raise ValueError(
                f"Incompatible cell types: encoder {self.encoder.cell_type!r} "
                f"cannot initialise decoder {self.decoder.cell_type!r}."
            )

        # Per-layer width of the encoder state after concatenating directions.
        encoder_effective_output_dim_per_layer = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_expected_hidden_dim = self.decoder.decoder_hidden_dim

        # Linear projections are only needed when those widths differ.
        self.needs_dim_adaptation = encoder_effective_output_dim_per_layer != decoder_rnn_expected_hidden_dim
        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None
        if self.needs_dim_adaptation:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)

    # ------------------------------------------------------------------ #
    # Encoder -> decoder state adaptation
    # ------------------------------------------------------------------ #
    def _merge_encoder_directions(self, state):
        """Turn an encoder state of shape (layers*dirs, B, H) into (layers, B, dirs*H)."""
        batch_size = state.size(1)
        # The RNN state is layer-major: index = layer * num_directions + direction.
        state = state.view(self.encoder.n_layers, self.encoder.num_directions,
                           batch_size, self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            return torch.cat([state[:, 0], state[:, 1]], dim=2)
        return state[:, 0]

    def _match_decoder_layers(self, state):
        """Keep the top decoder-many layers, or repeat the last layer to fill extra decoder layers."""
        enc_layers, dec_layers = self.encoder.n_layers, self.decoder.n_layers
        if enc_layers >= dec_layers:
            return state[enc_layers - dec_layers:]
        extra_layers = state[-1:].expand(dec_layers - enc_layers, -1, -1)
        return torch.cat([state, extra_layers], dim=0)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Convert the encoder's final state into the decoder's initial state.

        Forward and backward states are concatenated per layer. The result is
        projected when the widths differ. When the encoder has more layers, the
        last ``decoder.n_layers`` are kept; when it has fewer, its last layer is
        repeated.

        Returns:
            torch.Tensor for RNN/GRU, or an ``(h, c)`` tuple for LSTM, each of
            shape (decoder.n_layers, batch, decoder.decoder_hidden_dim).
        """
        is_lstm = self.encoder.cell_type == 'LSTM'
        if is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        h_state = self._merge_encoder_directions(h_from_enc)
        if self.needs_dim_adaptation:
            h_state = self.fc_adapt_hidden(h_state)
        h_state = self._match_decoder_layers(h_state)

        if not is_lstm:
            return h_state

        c_state = self._merge_encoder_directions(c_from_enc)
        if self.needs_dim_adaptation and self.fc_adapt_cell is not None:
            c_state = self.fc_adapt_cell(c_state)
        return h_state, self._match_decoder_layers(c_state)

    # ------------------------------------------------------------------ #
    # Training-time forward pass
    # ------------------------------------------------------------------ #
    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Decode the whole target sequence with (stochastic) teacher forcing.

        Args:
            source_seq (torch.Tensor): padded source indices, shape (batch, src_len).
            source_lengths (torch.Tensor): unpadded source lengths, shape (batch,).
            target_seq (torch.Tensor): padded target indices starting with <sos>, shape (batch, tgt_len).
            teacher_forcing_ratio (float): per-step probability (shared by the whole
                batch) of feeding the ground-truth token instead of the prediction.

        Returns:
            torch.Tensor: logits of shape (batch, tgt_len, vocab). Position 0
            (the <sos> slot) is left as zeros.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        outputs_logits = torch.zeros(batch_size, target_len, self.decoder.output_vocab_size,
                                     device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        decoder_input = target_seq[:, 0]  # <sos>
        for t in range(1, target_len):
            step_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs_logits[:, t] = step_logits
            use_teacher = random.random() < teacher_forcing_ratio
            decoder_input = target_seq[:, t] if use_teacher else step_logits.argmax(1)

        return outputs_logits

    # ------------------------------------------------------------------ #
    # Inference
    # ------------------------------------------------------------------ #
    def _prepare_single_source(self, source_seq, source_lengths):
        """
        Promote a 1-D source to a batch of one and build its length tensor.

        For a 1-D ``source_seq``, ``source_lengths`` may be an int, a one-element
        tensor/sequence holding the length, or None (use the whole sequence).
        Any other container falls back to its ``len()``. 2-D inputs are returned
        unchanged.
        """
        if source_seq.dim() != 1:
            return source_seq, source_lengths
        if isinstance(source_lengths, int):
            length = source_lengths
        elif source_lengths is None:
            length = source_seq.size(0)
        else:
            as_tensor = torch.as_tensor(source_lengths)
            # len() of a one-element length container is 1, not the stored length.
            length = int(as_tensor.reshape(-1)[0]) if as_tensor.numel() == 1 else len(source_lengths)
        return source_seq.unsqueeze(0), torch.tensor([length], device=self.device)

    def predict_greedy(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Decode a single source sequence greedily.

        Args:
            source_seq (torch.Tensor): one source, either 1-D or shaped (1, src_len).
            source_lengths: its length (see ``_prepare_single_source``).
            max_output_len (int): maximum number of decoding steps.
            target_eos_idx (int, optional): stop as soon as this index is predicted.

        Returns:
            list[int]: predicted indices, excluding <sos> and the terminating EOS.

        The model runs in eval mode during decoding; its previous train/eval
        mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            source_seq, source_lengths = self._prepare_single_source(source_seq, source_lengths)
            with torch.no_grad():
                _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
                decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

                decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
                predicted_indices = []
                for _ in range(max_output_len):
                    step_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    top1_predicted_token = step_logits.argmax(1)
                    predicted_idx = top1_predicted_token.item()
                    if target_eos_idx is not None and predicted_idx == target_eos_idx:
                        break
                    predicted_indices.append(predicted_idx)
                    decoder_input = top1_predicted_token
        finally:
            self.train(was_training)
        return predicted_indices

    def predict_beam_search(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None, beam_width=3):
        """
        Decode a single source sequence with beam search.

        At every step each active beam is extended by its top tokens, and the
        ``beam_width`` highest-scoring candidates overall are kept. Candidates
        ending in EOS are moved to the finished pool at that point, so beams
        finishing on the last allowed step are not lost. Finished and unfinished
        beams are ranked by log-probability divided by sequence length (counting
        <sos> and EOS).

        Args:
            source_seq (torch.Tensor): one source, either 1-D or shaped (1, src_len).
            source_lengths: its length (see ``_prepare_single_source``).
            max_output_len (int): maximum number of decoding steps.
            target_eos_idx (int, optional): index that terminates a beam.
            beam_width (int): number of beams kept per step.

        Returns:
            list[int]: best sequence, excluding <sos> and a trailing EOS (the same
            convention as ``predict_greedy``).

        Raises:
            ValueError: if the source batch size is not 1.
        """
        was_training = self.training
        self.eval()
        try:
            source_seq, source_lengths = self._prepare_single_source(source_seq, source_lengths)
            if source_seq.shape[0] != 1:
                raise ValueError("Beam search predict function currently supports batch_size=1.")

            with torch.no_grad():
                _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
                decoder_hidden_init = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

                # Beams are (cumulative log-prob, token list, decoder state).
                active_beams = [(0.0, [self.target_sos_idx], decoder_hidden_init)]
                completed_sequences = []  # (length-normalised log-prob, token list)
                # topk cannot request more entries than the vocabulary has.
                expansions_per_beam = min(beam_width, self.decoder.output_vocab_size)

                for _ in range(max_output_len):
                    if not active_beams:
                        break
                    candidates = []
                    for log_prob_beam, seq_beam, hidden_beam in active_beams:
                        decoder_input = torch.tensor([seq_beam[-1]], device=self.device)
                        step_logits, next_hidden_beam = self.decoder(decoder_input, hidden_beam)
                        log_probs_next_token = F.log_softmax(step_logits, dim=1)
                        topk_log_probs, topk_indices = torch.topk(log_probs_next_token, expansions_per_beam, dim=1)
                        for token_log_prob, next_token_idx in zip(topk_log_probs[0].tolist(), topk_indices[0].tolist()):
                            candidates.append((log_prob_beam + token_log_prob, seq_beam + [next_token_idx], next_hidden_beam))

                    # Keep the best `beam_width` candidates; finished ones leave the beam.
                    active_beams = []
                    for candidate in heapq.nlargest(beam_width, candidates, key=lambda c: c[0]):
                        cand_log_prob, cand_seq, _ = candidate
                        if target_eos_idx is not None and cand_seq[-1] == target_eos_idx:
                            completed_sequences.append((cand_log_prob / len(cand_seq), cand_seq))
                        else:
                            active_beams.append(candidate)

                # Beams still running when the length budget ran out also compete.
                for log_prob_beam, seq_beam, _ in active_beams:
                    completed_sequences.append((log_prob_beam / len(seq_beam), seq_beam))
        finally:
            self.train(was_training)

        if not completed_sequences:
            return []

        # max() keeps the earliest of equally-scored sequences (same as a stable descending sort).
        best_sequence_indices = max(completed_sequences, key=lambda c: c[0])[1]
        if best_sequence_indices and best_sequence_indices[0] == self.target_sos_idx:
            best_sequence_indices = best_sequence_indices[1:]
        if target_eos_idx is not None and best_sequence_indices and best_sequence_indices[-1] == target_eos_idx:
            best_sequence_indices = best_sequence_indices[:-1]
        return best_sequence_indices

# --- Data Loading and Preprocessing ---

class Vocabulary:
    """
    Bidirectional character <-> integer-index mapping for one script.

    Indices 0..3 are reserved, in this order, for PAD_TOKEN, SOS_TOKEN,
    EOS_TOKEN and UNK_TOKEN. Ordinary characters are appended after them by
    `build_vocab`, most frequent first (ties broken alphabetically).
    """
    def __init__(self, name):
        self.name = name
        reserved = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: position for position, token in enumerate(reserved)}
        self.index2char = dict(enumerate(reserved))
        self.char_counts = Counter()
        self.n_chars = 4  # next free index; the four reserved tokens occupy 0..3
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count every character of `sequence` (vocabulary is not updated until `build_vocab`)."""
        # tuple(...) makes sure each element is counted once, even for mapping inputs.
        self.char_counts.update(tuple(sequence))

    def build_vocab(self, min_freq=1):
        """
        Assign indices to every counted character seen at least `min_freq` times.

        Characters below the threshold get no index and therefore map to
        UNK_TOKEN in `sequence_to_indices`. Already-indexed characters are kept.
        """
        ranked = sorted(self.char_counts.keys(),
                        key=lambda ch: (-self.char_counts[ch], ch))
        for ch in ranked:
            if self.char_counts[ch] < min_freq or ch in self.char2index:
                continue
            self.char2index[ch] = self.n_chars
            self.index2char[self.n_chars] = ch
            self.n_chars += 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False):
        """Encode `sequence` as a list of indices, optionally framed by SOS/EOS."""
        body = [self.char2index.get(ch, self.char2index[UNK_TOKEN]) for ch in sequence]
        head = [self.sos_idx] if add_sos else []
        tail = [self.eos_idx] if add_eos else []
        return head + body + tail

    def indices_to_sequence(self, indices):
        """
        Decode indices back to text.

        Decoding stops at the first EOS; SOS and PAD indices are dropped, and
        indices with no known character render as UNK_TOKEN.
        """
        ignored = (self.sos_idx, self.pad_idx)
        pieces = []
        for idx in indices:
            if idx == self.eos_idx:
                break
            if idx in ignored:
                continue
            pieces.append(self.index2char.get(idx, UNK_TOKEN))
        return "".join(pieces)

class TransliterationDataset(Dataset):
    """
    Dataset of (source, target) transliteration string pairs read from a TSV.

    Each usable line has the target (native script) in column 0 and the
    source (romanized) in column 1; further columns are ignored. Items are
    converted to index tensors lazily in `__getitem__`, using whatever state
    the vocabularies have at that moment.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        # A missing file is reported but not raised: callers check `self.pairs`.
        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                fields = raw_line.strip().split('\t')
                if len(fields) < 2:
                    continue
                target_text, source_text = fields[0], fields[1]
                # A falsy max_len (0 / None) disables the length filter.
                over_limit = max_len and (len(source_text) > max_len or len(target_text) > max_len)
                if not source_text or not target_text or over_limit:
                    continue
                self.pairs.append((source_text, target_text))
        print(f"Loaded {len(self.pairs)} pairs from {file_path}.")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return (source, target) as long tensors; source ends with EOS, target is SOS...EOS."""
        source_text, target_text = self.pairs[idx]
        source_ids = self.source_vocab.sequence_to_indices(source_text, add_eos=True)
        target_ids = self.target_vocab.sequence_to_indices(target_text, add_sos=True, add_eos=True)
        source_tensor = torch.tensor(source_ids, dtype=torch.long)
        target_tensor = torch.tensor(target_ids, dtype=torch.long)
        return source_tensor, target_tensor

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Pad a list of (source, target) 1-D tensors into batch-first tensors.

    Items that are None or have an empty source/target are dropped first.

    Returns:
        tuple: (padded_sources, source_lengths, padded_targets), where
        source_lengths holds the unpadded source lengths (long). Returns
        (None, None, None) when nothing survives filtering.
    """
    usable = [
        pair for pair in batch
        if pair is not None and len(pair[0]) > 0 and len(pair[1]) > 0
    ]
    if len(usable) == 0:
        return None, None, None

    sources = [pair[0] for pair in usable]
    targets = [pair[1] for pair in usable]
    lengths = torch.tensor(list(map(len, sources)), dtype=torch.long)

    # Each side is padded only up to its own longest sequence in this batch.
    padded_sources = pad_sequence(sources, batch_first=True, padding_value=pad_idx_source)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=pad_idx_target)
    return padded_sources, lengths, padded_targets

# --- Training and Evaluation Functions ---

def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Trains the model for a single epoch.

    Args:
        model (nn.Module): The Seq2Seq model; called as
            model(sources, source_lengths, targets, teacher_forcing_ratio).
        dataloader (DataLoader): DataLoader for the training set, yielding
            (padded_sources, source_lengths, padded_targets) or (None, None, None).
        optimizer (optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss).
        device (torch.device): Device to run the model on (CPU or GPU).
        clip_value (float): Max gradient norm passed to clip_grad_norm_.
        teacher_forcing_ratio (float): Probability of using teacher forcing.
        target_vocab (Vocabulary): Vocabulary for the target language.

    Returns:
        tuple: (average loss over the batches that produced a finite loss,
                exact-match training accuracy). The accuracy is computed from
                the same (possibly teacher-forced) forward pass used for the
                loss, so it is not a free-running decoding accuracy.
    """
    model.train()  # enable dropout etc.
    epoch_loss = 0.0
    n_loss_batches = 0  # batches whose loss was finite and was added to epoch_loss
    total_correct_train = 0
    total_samples_train = 0

    if len(dataloader) == 0:
        print("Warning: Training dataloader is empty.")
        return 0.0, 0.0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        # collate_fn returns (None, None, None) when every item of a batch was filtered out.
        if batch_data[0] is None:
            continue
        sources, source_lengths, targets = batch_data
        if sources.shape[0] == 0:
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()

        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 of both tensors corresponds to SOS, so it is excluded from the loss.
        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)

        loss = criterion(flat_outputs, flat_targets)

        # Skip the update (and the loss bookkeeping) for NaN/Inf losses so one bad
        # batch neither corrupts the weights nor drags the reported average.
        if torch.isfinite(loss):
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            epoch_loss += loss.item()
            n_loss_batches += 1

        # Exact-sequence-match accuracy; one host transfer per batch instead of per row.
        predicted_rows = outputs_logits.detach().argmax(dim=2)[:, 1:].tolist()
        true_rows = targets[:, 1:].tolist()
        for pred_row, true_row in zip(predicted_rows, true_rows):
            if target_vocab.indices_to_sequence(pred_row) == target_vocab.indices_to_sequence(true_row):
                total_correct_train += 1
            total_samples_train += 1

    avg_epoch_loss = epoch_loss / n_loss_batches if n_loss_batches > 0 else 0.0
    train_accuracy = total_correct_train / total_samples_train if total_samples_train > 0 else 0.0
    return avg_epoch_loss, train_accuracy

def _evaluate_one_epoch(model, dataloader, criterion, device, target_vocab, beam_width=1, is_test_set=False,
                        source_vocab=None):
    """
    Evaluates the model's performance on a given dataset (validation or test).

    The loss is computed with teacher forcing disabled. Accuracy is exact string
    match between each decoded prediction (greedy or beam search, one sample at
    a time) and the reference target.

    Args:
        model (nn.Module): The Seq2Seq model (must provide `predict_greedy`; beam
            search is used only if it also provides `predict_beam_search`).
        dataloader (DataLoader): DataLoader for the evaluation set.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the model on.
        target_vocab (Vocabulary): Vocabulary for the target language.
        beam_width (int): Beam width for decoding (1 for greedy, >1 for beam search).
        is_test_set (bool): Flag to indicate if this is the final test evaluation (for logging and samples).
        source_vocab (Vocabulary, optional): Used only to render the source text of
            the printed test samples; when omitted the raw source indices are printed.

    Returns:
        tuple: (average loss over batches with a finite loss, evaluation accuracy).
    """
    model.eval()
    epoch_loss = 0.0
    n_loss_batches = 0  # batches whose loss was finite and was added to epoch_loss
    total_correct, total_samples = 0, 0

    if len(dataloader) == 0:
        print("WARNING: Evaluation dataloader is empty.")
        return 0.0, 0.0

    desc_prefix = "Testing" if is_test_set else "Validating"
    use_beam = beam_width > 1 and hasattr(model, 'predict_beam_search')

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=desc_prefix, leave=False)):
            # collate_fn returns (None, None, None) when a batch is empty after filtering.
            if batch_data[0] is None:
                continue
            sources, source_lengths, targets = batch_data
            if sources.shape[0] == 0:
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

            # Loss without teacher forcing; position 0 (SOS) is excluded.
            outputs_for_loss = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs_for_loss.shape[-1]
            flat_outputs_for_loss = outputs_for_loss[:, 1:].reshape(-1, output_dim)
            flat_targets_for_loss = targets[:, 1:].reshape(-1)

            loss = criterion(flat_outputs_for_loss, flat_targets_for_loss)
            if torch.isfinite(loss):
                epoch_loss += loss.item()
                n_loss_batches += 1

            max_output_len = targets.size(1) + 5  # allow slightly longer outputs than the padded target
            true_rows = targets[:, 1:].tolist()    # exclude SOS from the references

            # Decode each item of the batch on its own.
            for i in range(sources.shape[0]):
                src_single, src_len_single = sources[i:i+1], source_lengths[i:i+1]

                if use_beam:
                    predicted_indices = model.predict_beam_search(src_single, src_len_single,
                                                                  max_output_len=max_output_len,
                                                                  target_eos_idx=target_vocab.eos_idx,
                                                                  beam_width=beam_width)
                else:  # greedy when beam_width is 1 or beam search is unavailable
                    predicted_indices = model.predict_greedy(src_single, src_len_single,
                                                             max_output_len=max_output_len,
                                                             target_eos_idx=target_vocab.eos_idx)

                pred_str = target_vocab.indices_to_sequence(predicted_indices)
                true_str = target_vocab.indices_to_sequence(true_rows[i])
                is_match = pred_str == true_str

                if is_match:
                    total_correct += 1
                total_samples += 1

                # Print the first 3 samples of the first batch for the test set.
                if is_test_set and batch_idx == 0 and i < 3:
                    src_ids = src_single[0].tolist()
                    src_text = source_vocab.indices_to_sequence(src_ids) if source_vocab is not None else src_ids
                    print(f"  Test Example {i} - Source: '{src_text}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {is_match}")

    avg_epoch_loss = epoch_loss / n_loss_batches if n_loss_batches > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    print(f"{desc_prefix} - Total Correct: {total_correct}, Total Samples: {total_samples}, Accuracy: {accuracy:.4f}, Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss, accuracy

# --- Function to Train and Save the Best Model ---

def train_and_save_best_model(config, model_save_path, device):
    """
    Performs a dedicated training run using the best hyperparameters found from a sweep.
    Saves the model checkpoint with the best validation accuracy.

    The run is logged to the W&B project "DL_A3". Validation during training uses
    greedy decoding (beam_width=1). Training stops early after
    `max_epochs_no_improve_train` epochs without a validation-accuracy improvement.

    Args:
        config (dict): Hyperparameters for this run. Must contain every key read
            below via `cfg.` (embedding_dim, hidden_dim, encoder_layers, ...).
        model_save_path (str): Path to save the best model's state_dict.
        device (torch.device): Device to run the training on.

    Returns:
        bool: True if training ran, False if the train/dev data or loaders were empty.
    """
    run_name_train_best = f"TRAIN_BEST_{config['cell_type']}_emb{config['embedding_dim']}_hid{config['hidden_dim']}"

    with wandb.init(project="DL_A3", name=run_name_train_best, config=config,
                    job_type="training_best_model", reinit=True):
        cfg = wandb.config  # attribute-style access to the hyperparameters
        print(f"Starting dedicated training for best model with config: {cfg}")

        # --- Data Loading and Vocabulary Building ---
        BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")

        # TransliterationDataset stores only raw string pairs and consults the
        # vocabularies lazily in __getitem__, so one instance can feed the character
        # counts now and serve as the training set once the vocabularies are built.
        # This avoids reading and filtering the training file twice.
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs:
            print(f"ERROR: No training data loaded for vocabulary building from {train_file}.")
            return False
        for src, tgt in train_dataset.pairs:
            source_vocab.add_sequence(src)
            target_vocab.add_sequence(tgt)
        source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        print(f"Vocabs built. Source: {source_vocab.n_chars} unique chars, Target: {target_vocab.n_chars} unique chars.")

        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not dev_dataset.pairs:
            print(f"ERROR: Train or Dev dataset is empty after filtering. Train: {len(train_dataset.pairs)}, Dev: {len(dev_dataset.pairs)}")
            return False

        # --- DataLoader Setup ---
        num_w = 0 if device.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
        use_pin_memory = device.type == 'cuda'

        def batch_collate(batch):
            """Pad a batch using this run's PAD indices."""
            return collate_fn(batch, source_vocab.pad_idx, target_vocab.pad_idx)

        # drop_last=True: fewer than batch_size_train training pairs means zero
        # training batches, which the length check below reports.
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                                  collate_fn=batch_collate, num_workers=num_w,
                                  pin_memory=use_pin_memory, drop_last=True)
        dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                                collate_fn=batch_collate, num_workers=num_w,
                                pin_memory=use_pin_memory)
        if len(train_loader) == 0 or len(dev_loader) == 0:
            print(f"ERROR: Train or Dev DataLoader is empty. Train: {len(train_loader)}, Dev: {len(dev_loader)}")
            return False

        # --- Model, Optimizer, Loss Function Setup ---
        encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                          cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
        decoder = Decoder(target_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.decoder_layers, cfg.cell_type, cfg.dropout_p,
                          pad_idx=target_vocab.pad_idx).to(device)
        model = Seq2Seq(encoder, decoder, device, target_vocab.sos_idx).to(device)
        print(f"Best model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)  # log weights and gradients to W&B

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
        # 'max' mode: the scheduler is stepped with validation accuracy (higher is better).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3, verbose=True)
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)  # padding does not contribute to the loss

        # --- Training Loop ---
        # Accuracy is never negative, so the first evaluated epoch always saves a checkpoint.
        best_val_acc_this_training = -1.0
        epochs_no_improve = 0  # early-stopping counter

        for epoch in range(cfg.epochs_train):
            train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                                     cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)
            val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device, target_vocab,
                                                    beam_width=1)

            scheduler.step(val_acc)

            print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            wandb.log({"epoch_train_best": epoch + 1, "train_loss_best": train_loss, "train_acc_best": train_acc,
                       "val_loss_best": val_loss, "val_acc_best": val_acc, "lr_best": optimizer.param_groups[0]['lr']})

            if val_acc > best_val_acc_this_training:
                best_val_acc_this_training = val_acc
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path)
                print(f"Saved new best model to {model_save_path} (Val Acc: {best_val_acc_this_training:.4f})")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= cfg.max_epochs_no_improve_train:
                print(f"Early stopping for best model training at epoch {epoch+1}.")
                break

        # Fallback so a checkpoint always exists, e.g. when epochs_train == 0.
        # Note: an existing file at model_save_path is not overwritten here.
        if not os.path.exists(model_save_path):
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved final model state (no improvement over initial) to {model_save_path}")

        wandb.summary["final_best_val_accuracy_during_training"] = best_val_acc_this_training
        print(f"Finished training best model. Best Val Acc: {best_val_acc_this_training:.4f}")
    return True

# --- Main Execution Block ---

if __name__ == '__main__':
    # Use the GPU when one is available.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # NOTE: failures below use `raise SystemExit(1)` rather than `exit()`: `exit` is a
    # helper the `site` module installs for the interactive shell and is not meant
    # for use in programs. The non-zero status also marks the run as failed.

    # --- W&B Login ---
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensuring WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
            else:
                # Fall back to the key stored in Kaggle's secrets store.
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
            print("W&B login for Kaggle successful.")
        else:
            wandb.login()  # regular login outside Kaggle
        print("W&B login process attempted/completed.")
    except Exception as e:
        print(f"W&B login failed: {e}. Please ensure your W&B API key is correctly configured.")
        raise SystemExit(1)

    # --- Best Hyperparameters (copied from the best run of an earlier sweep) ---
    BEST_HYPERPARAMETERS = {
        'embedding_dim': 300,
        'hidden_dim': 512,
        'encoder_layers': 2,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.5,
        'encoder_bidirectional': True,
        # Training specific parameters for this dedicated run:
        'learning_rate_train': 0.000376, # Learning rate from the best sweep run
        'batch_size_train': 128,         # Batch size from the best sweep run
        'epochs_train': 20,              # Max epochs for this dedicated training
        'clip_value_train': 1.0,         # Gradient clipping value
        'teacher_forcing_train': 0.5,    # Teacher forcing ratio for this training
        'max_epochs_no_improve_train': 7, # Early stopping patience for this training run
        'lr_scheduler_patience_train': 3, # Patience for LR scheduler
        # Parameters needed for dataset/vocabulary consistency:
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        # Parameters for final test evaluation:
        'eval_batch_size': 128,          # Batch size for evaluation on test set
        'beam_width_eval': 3             # Beam width for final test evaluation
    }
    MODEL_SAVE_PATH = "/kaggle/working/best_model_for_testing.pt"
    # Single source of truth for the evaluation beam width (used for decoding and the run name).
    BEAM_WIDTH_EVAL = BEST_HYPERPARAMETERS.get('beam_width_eval', 1)

    print(f"Best hyperparameters selected for dedicated training and testing: {BEST_HYPERPARAMETERS}")

    # --- Phase 1: Train and Save the Best Model ---
    print("\n" + "=" * 80)
    print("--- Phase 1: Training and Saving Best Model Configuration ---".center(80))
    print("=" * 80 + "\n")

    training_successful = train_and_save_best_model(BEST_HYPERPARAMETERS, MODEL_SAVE_PATH, DEVICE)

    if not training_successful or not os.path.exists(MODEL_SAVE_PATH):
        print("ERROR: Failed to train and save the best model. Exiting before test evaluation.")
        raise SystemExit(1)
    print(f"\nBest model trained and saved to {MODEL_SAVE_PATH}")

    # --- Phase 2: Load and Evaluate on Test Set ---
    print("\n" + "=" * 80)
    print("--- Phase 2: Loading and Evaluating Best Model on Test Set ---".center(80))
    print("=" * 80 + "\n")

    # --- Data Setup for Test Evaluation ---
    # The vocabularies are rebuilt from the same training file with the same settings;
    # Vocabulary.build_vocab orders characters deterministically (by count, then
    # character), so the indices match those the checkpoint was trained with.
    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    source_vocab_test = Vocabulary("latin_test")
    target_vocab_test = Vocabulary("devanagari_test")

    temp_train_ds_test_vocab = TransliterationDataset(train_file_for_vocab, source_vocab_test, target_vocab_test,
                                                      max_len=BEST_HYPERPARAMETERS['max_seq_len'])
    if not temp_train_ds_test_vocab.pairs:
        print(f"ERROR: No data loaded for vocabulary building from {train_file_for_vocab}.")
        raise SystemExit(1)
    for src, tgt in temp_train_ds_test_vocab.pairs:
        source_vocab_test.add_sequence(src)
        target_vocab_test.add_sequence(tgt)
    source_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS['vocab_min_freq'])
    target_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS['vocab_min_freq'])
    print(f"Test Vocab built from training data. Source: {source_vocab_test.n_chars} chars. Target: {target_vocab_test.n_chars} chars.")

    test_dataset = TransliterationDataset(test_file, source_vocab_test, target_vocab_test,
                                          max_len=BEST_HYPERPARAMETERS['max_seq_len'])
    if not test_dataset.pairs:
        print(f"ERROR: Test dataset is empty from {test_file}.")
        raise SystemExit(1)

    num_w_test = 0 if DEVICE.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
    test_dataloader = DataLoader(test_dataset, batch_size=BEST_HYPERPARAMETERS['eval_batch_size'], shuffle=False,
                                 collate_fn=lambda b: collate_fn(b, source_vocab_test.pad_idx, target_vocab_test.pad_idx),
                                 num_workers=num_w_test, pin_memory=DEVICE.type == 'cuda')
    if len(test_dataloader) == 0:
        print("ERROR: Test DataLoader is empty.")
        raise SystemExit(1)

    # --- Initialize Model for Testing and Load Weights ---
    # The architecture must match the trained checkpoint exactly for load_state_dict to succeed.
    encoder_test = Encoder(source_vocab_test.n_chars, BEST_HYPERPARAMETERS['embedding_dim'], BEST_HYPERPARAMETERS['hidden_dim'],
                           BEST_HYPERPARAMETERS['encoder_layers'], BEST_HYPERPARAMETERS['cell_type'], BEST_HYPERPARAMETERS['dropout_p'],
                           BEST_HYPERPARAMETERS['encoder_bidirectional'], pad_idx=source_vocab_test.pad_idx).to(DEVICE)
    decoder_test = Decoder(target_vocab_test.n_chars, BEST_HYPERPARAMETERS['embedding_dim'], BEST_HYPERPARAMETERS['hidden_dim'],
                           BEST_HYPERPARAMETERS['decoder_layers'], BEST_HYPERPARAMETERS['cell_type'], BEST_HYPERPARAMETERS['dropout_p'],
                           pad_idx=target_vocab_test.pad_idx).to(DEVICE)
    model_test = Seq2Seq(encoder_test, decoder_test, DEVICE, target_vocab_test.sos_idx).to(DEVICE)

    print(f"Loading weights into test model from: {MODEL_SAVE_PATH}")
    model_test.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    print("Weights loaded successfully for test evaluation.")

    # --- Evaluate on Test Set ---
    criterion_test = nn.CrossEntropyLoss(ignore_index=target_vocab_test.pad_idx)
    test_loss, test_accuracy = _evaluate_one_epoch(model_test, test_dataloader, criterion_test, DEVICE,
                                                   target_vocab_test,
                                                   beam_width=BEAM_WIDTH_EVAL,
                                                   is_test_set=True)  # prints a few test samples

    print("\n" + "=" * 80)
    print("--- FINAL TEST SET PERFORMANCE ---".center(80))
    print("=" * 80)
    print(f"  Model Checkpoint: '{MODEL_SAVE_PATH}'")
    print(f"  Hyperparameters used for training and testing: {BEST_HYPERPARAMETERS}")
    print(f"  Test Loss (avg per batch): {test_loss:.4f}")
    print(f"  Test Accuracy (Exact Match): {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
    print("=" * 80 + "\n")

    # --- Log Final Test Results to W&B (best effort: failures are reported, not raised) ---
    try:
        run_name_final_test = f"FINAL_TEST_EVAL_{BEST_HYPERPARAMETERS['cell_type']}_beam{BEAM_WIDTH_EVAL}"
        with wandb.init(project="DL_A3", name=run_name_final_test, config=BEST_HYPERPARAMETERS, job_type="final_evaluation", reinit=True) as test_run:
            test_run.summary["final_test_accuracy"] = test_accuracy
            test_run.summary["final_test_loss"] = test_loss
            test_run.summary["model_checkpoint_used"] = MODEL_SAVE_PATH
            print("Final test results logged to W&B.")
    except Exception as e:
        print(f"Could not log final test results to W&B: {e}")

# Question 5: Attention-based model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import numpy as np
import wandb
import os
from collections import Counter
from tqdm import tqdm
import heapq

# --- Constants for Special Tokens ---
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN = (
    "<sos>",  # start-of-sequence
    "<eos>",  # end-of-sequence
    "<pad>",  # padding
    "<unk>",  # unknown / out-of-vocabulary character
)

# --- Model Definitions (Non-Attention Version) ---

class Encoder(nn.Module):
    """
    Recurrent encoder (RNN, GRU or LSTM) that summarizes a padded batch of
    token sequences into its final hidden state.
    """
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, n_layers,
                 cell_type='LSTM', dropout_p=0.1, bidirectional=False, pad_idx=0):
        super().__init__()
        self.input_vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()
        self.bidirectional = bidirectional
        self.num_directions = 1 + (1 if bidirectional else 0)

        # Sub-modules are created in a fixed order (embedding, then RNN) so that
        # parameter initialisation draws from the RNG in the same sequence.
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        recurrent_cls = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(self.cell_type)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")

        # Inter-layer dropout is only meaningful with stacked layers.
        inter_layer_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = recurrent_cls(embedding_dim, hidden_dim, n_layers,
                                 dropout=inter_layer_dropout, batch_first=True,
                                 bidirectional=self.bidirectional)

    def forward(self, input_seq, input_lengths):
        """
        Encode a padded batch.

        Args:
            input_seq (torch.Tensor): Padded token indices, (batch_size, seq_len).
            input_lengths (torch.Tensor): Unpadded length of each sequence, (batch_size,).

        Returns:
            tuple: (None, hidden). No per-step outputs are returned because this
            encoder is not used with attention; `hidden` is the RNN's final
            state, i.e. an (h, c) tuple for LSTM.
        """
        token_vectors = self.dropout(self.embedding(input_seq))

        # Packing lets the RNN stop at each sequence's true length; the lengths
        # must live on the CPU and need not be sorted.
        packed = nn.utils.rnn.pack_padded_sequence(
            token_vectors, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _per_step_outputs, final_state = self.rnn(packed)
        return None, final_state

class Decoder(nn.Module):
    """
    Attention-free decoder that emits one target token per call.

    Each call embeds the previous token, advances a one-step RNN/GRU/LSTM from
    the supplied hidden state, and projects the RNN output to vocabulary-sized
    logits.
    """
    def __init__(self, output_vocab_size, embedding_dim,
                 decoder_hidden_dim, n_layers, cell_type='LSTM', dropout_p=0.1, pad_idx=0):
        super(Decoder, self).__init__()
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.n_layers = n_layers
        self.cell_type = cell_type.upper()

        self.embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout_p)

        # Map the (upper-cased) cell name to its PyTorch recurrent module.
        recurrent_classes = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        recurrent_cls = recurrent_classes.get(self.cell_type)
        if recurrent_cls is None:
            raise ValueError("Unsupported cell type. Choose from 'RNN', 'GRU', 'LSTM'.")

        # Inter-layer dropout only applies between stacked layers, so use 0 for a single layer.
        stacked_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = recurrent_cls(embedding_dim, decoder_hidden_dim, n_layers,
                                 dropout=stacked_dropout, batch_first=True)

        # Projects the RNN output at each step onto the target vocabulary.
        self.fc_out = nn.Linear(decoder_hidden_dim, output_vocab_size)

    def forward(self, input_char, prev_decoder_hidden):
        """
        Run a single decoding step.

        Args:
            input_char (torch.Tensor): Previous token ids, shape (batch_size,).
            prev_decoder_hidden (torch.Tensor or tuple): Previous RNN state
                (``(h, c)`` for an LSTM).

        Returns:
            tuple: ``(prediction_logits, current_decoder_hidden)``. The logits have
            shape (batch_size, output_vocab_size) and are unnormalised scores.
        """
        # (batch,) -> (batch, 1): a length-1 sequence for the batch_first RNN.
        step_tokens = input_char.unsqueeze(1)
        step_embedding = self.dropout(self.embedding(step_tokens))

        step_output, next_hidden = self.rnn(step_embedding, prev_decoder_hidden)

        # Drop the length-1 time axis before the vocabulary projection.
        logits = self.fc_out(step_output.squeeze(1))
        return logits, next_hidden

class Seq2Seq(nn.Module):
    """
    The main Sequence-to-Sequence model that connects the Encoder and Decoder.
    This architecture does NOT use an attention mechanism.

    The encoder's final hidden state is reshaped, optionally projected by a
    linear layer when widths differ, and trimmed or extended to the decoder's
    layer count. It is then used to initialise the decoder, which emits the
    target one token at a time.
    """
    def __init__(self, encoder, decoder, device, target_sos_idx):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.target_sos_idx = target_sos_idx

        # Width of one encoder layer's state after forward/backward directions are concatenated.
        encoder_effective_output_dim_per_layer = self.encoder.hidden_dim * self.encoder.num_directions
        decoder_rnn_expected_hidden_dim = self.decoder.decoder_hidden_dim

        # A projection is needed only when the two widths disagree.
        self.needs_dim_adaptation = encoder_effective_output_dim_per_layer != decoder_rnn_expected_hidden_dim
        self.fc_adapt_hidden = None
        self.fc_adapt_cell = None

        # NOTE: which adapter layers exist (and hence the state_dict keys) is kept
        # unchanged so previously saved checkpoints still load.
        if self.needs_dim_adaptation:
            self.fc_adapt_hidden = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)
            if self.encoder.cell_type == 'LSTM':
                self.fc_adapt_cell = nn.Linear(encoder_effective_output_dim_per_layer, decoder_rnn_expected_hidden_dim)

    def _merge_directions(self, state):
        """Reshape an encoder state (n_layers * num_directions, batch, hidden) to
        (n_layers, batch, num_directions * hidden), concatenating the directions."""
        batch_size = state.size(1)
        # PyTorch lays out h_n as (layer, direction) pairs; this view separates them.
        state = state.view(self.encoder.n_layers, self.encoder.num_directions,
                           batch_size, self.encoder.hidden_dim)
        if self.encoder.bidirectional:
            return torch.cat([state[:, 0], state[:, 1]], dim=2)
        return state.squeeze(1)

    def _match_decoder_layers(self, state):
        """Trim or extend a (layers, batch, dim) state to the decoder's layer count.

        If the encoder has more layers, only the top-most ones are kept. If it
        has fewer, the top encoder layer is repeated for the missing layers.
        """
        enc_layers = state.size(0)
        dec_layers = self.decoder.n_layers
        if enc_layers >= dec_layers:
            return state[enc_layers - dec_layers:]
        padding = state[-1:].expand(dec_layers - enc_layers, -1, -1)
        return torch.cat([state, padding], dim=0)

    def _adapt_encoder_hidden_for_decoder(self, encoder_final_hidden_state):
        """
        Adapts the encoder's final hidden state(s) to the decoder's expected
        hidden size, layer count and cell type.

        Bidirectional states are concatenated per layer. When the decoder has
        more layers than the encoder, the last encoder layer is repeated. The
        result matches the decoder's cell type:
        - LSTM decoder: ``(h, c)``; ``c`` is zeros if the encoder had no cell state.
        - RNN/GRU decoder: ``h`` only; any encoder cell state is dropped.
        """
        enc_is_lstm = self.encoder.cell_type == 'LSTM'
        # Decoders without a cell_type attribute are assumed to match the encoder.
        dec_is_lstm = getattr(self.decoder, 'cell_type', self.encoder.cell_type) == 'LSTM'

        if enc_is_lstm:
            h_from_enc, c_from_enc = encoder_final_hidden_state
        else:
            h_from_enc, c_from_enc = encoder_final_hidden_state, None

        h = self._merge_directions(h_from_enc)
        # Carry the encoder cell state forward only when the decoder can consume it.
        c = self._merge_directions(c_from_enc) if (c_from_enc is not None and dec_is_lstm) else None

        if self.needs_dim_adaptation:
            h = self.fc_adapt_hidden(h)
            if c is not None and self.fc_adapt_cell is not None:
                c = self.fc_adapt_cell(c)

        h = self._match_decoder_layers(h)
        if c is not None:
            c = self._match_decoder_layers(c)

        if not dec_is_lstm:
            return h
        if c is None:
            # Non-LSTM encoder feeding an LSTM decoder: start from an empty cell state.
            c = torch.zeros_like(h)
        return (h, c)

    def forward(self, source_seq, source_lengths, target_seq, teacher_forcing_ratio=0.5):
        """
        Forward pass for the Seq2Seq model during training.

        Args:
            source_seq (torch.Tensor): Padded source sequences, (batch, src_len).
            source_lengths (torch.Tensor): Lengths of source sequences.
            target_seq (torch.Tensor): Padded target sequences (including SOS token), (batch, tgt_len).
            teacher_forcing_ratio (float): Probability of using the actual target token as the next input.

        Returns:
            torch.Tensor: Logits of shape (batch, tgt_len, vocab). Position 0
            (the SOS slot) is left as zeros and is not a prediction.
        """
        batch_size = source_seq.shape[0]
        target_len = target_seq.shape[1]
        target_vocab_size = self.decoder.output_vocab_size
        outputs_logits = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        # The first decoder input is the <sos> column of the target.
        decoder_input = target_seq[:, 0]

        for t in range(target_len - 1):
            decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs_logits[:, t + 1] = decoder_output_logits

            # One teacher-forcing coin flip per time step (shared by the whole batch).
            teacher_force_this_step = random.random() < teacher_forcing_ratio
            top1_predicted_token = decoder_output_logits.argmax(1)
            decoder_input = target_seq[:, t + 1] if teacher_force_this_step else top1_predicted_token
        return outputs_logits

    def _prepare_single_source(self, source_seq, source_lengths):
        """Normalise the inputs of the predict_* helpers.

        Returns a (batch, seq_len) source tensor and a 1-D LongTensor of lengths.
        ``source_lengths`` may be an int, a 0-d/1-d tensor, or None (meaning the
        full padded width). For a 1-D source, any other sized object is converted
        with ``len()``, for backward compatibility.
        """
        was_1d = source_seq.dim() == 1
        if was_1d:
            source_seq = source_seq.unsqueeze(0)
        if torch.is_tensor(source_lengths):
            # Use the stored length value(s); len() would count entries instead.
            return source_seq, source_lengths.reshape(-1)
        if source_lengths is None:
            lengths = [source_seq.size(1)] * source_seq.size(0)
        elif isinstance(source_lengths, int):
            lengths = [source_lengths]
        elif was_1d:
            lengths = [len(source_lengths)]
        else:
            lengths = list(source_lengths)
        return source_seq, torch.tensor(lengths, dtype=torch.long, device=self.device)

    def predict_greedy(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None):
        """
        Generates a sequence using greedy decoding.

        Args:
            source_seq (torch.Tensor): Input source sequence (a single 1-D example or a batch of 1).
            source_lengths (int, torch.Tensor or None): Length of the source sequence.
            max_output_len (int): Maximum number of tokens to generate.
            target_eos_idx (int, optional): Index of the EOS token in the target vocabulary.

        Returns:
            list: Predicted token indices, excluding SOS and EOS.

        The model's train/eval mode is restored on return.
        """
        was_training = self.training
        self.eval()
        predicted_indices = []
        try:
            source_seq, source_lengths = self._prepare_single_source(source_seq, source_lengths)
            with torch.no_grad():
                _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
                decoder_hidden = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)
                decoder_input = torch.tensor([self.target_sos_idx], device=self.device)
                for _ in range(max_output_len):
                    decoder_output_logits, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    top1_predicted_token = decoder_output_logits.argmax(1)
                    predicted_idx = top1_predicted_token.item()
                    if target_eos_idx is not None and predicted_idx == target_eos_idx:
                        break
                    predicted_indices.append(predicted_idx)
                    decoder_input = top1_predicted_token
        finally:
            self.train(was_training)
        return predicted_indices

    def predict_beam_search(self, source_seq, source_lengths, max_output_len=50, target_eos_idx=None, beam_width=3):
        """
        Generates a sequence using beam search decoding.

        Args:
            source_seq (torch.Tensor): Input source sequence (must be a single example).
            source_lengths (int, torch.Tensor or None): Length of the source sequence.
            max_output_len (int): Maximum number of tokens to generate.
            target_eos_idx (int, optional): Index of the EOS token.
            beam_width (int): Number of top sequences to keep at each step
                (capped at the vocabulary size when expanding a beam).

        Returns:
            list: Token indices of the best sequence, excluding SOS and EOS
            (the same convention as ``predict_greedy``).

        Raises:
            ValueError: If the source batch contains more than one example.

        The model's train/eval mode is restored on return.
        """
        was_training = self.training
        self.eval()
        try:
            source_seq, source_lengths = self._prepare_single_source(source_seq, source_lengths)
            if source_seq.shape[0] != 1:
                raise ValueError("Beam search predict function currently supports batch_size=1.")
            with torch.no_grad():
                best_sequence = self._run_beam_search(source_seq, source_lengths, max_output_len,
                                                      target_eos_idx, beam_width)
        finally:
            self.train(was_training)

        # Strip the leading SOS and a trailing EOS so the output matches predict_greedy.
        if best_sequence and best_sequence[0] == self.target_sos_idx:
            best_sequence = best_sequence[1:]
        if target_eos_idx is not None and best_sequence and best_sequence[-1] == target_eos_idx:
            best_sequence = best_sequence[:-1]
        return best_sequence

    def _run_beam_search(self, source_seq, source_lengths, max_output_len, target_eos_idx, beam_width):
        """Beam search for one example; returns the best raw sequence (starting with SOS, possibly ending with EOS)."""
        def length_normalised(log_prob, seq):
            # Dividing by length counteracts the bias towards shorter sequences.
            return log_prob / len(seq) if seq else -float('inf')

        _, encoder_final_hidden = self.encoder(source_seq, source_lengths)
        decoder_hidden_init = self._adapt_encoder_hidden_for_decoder(encoder_final_hidden)

        # Beams are stored as (cumulative_log_probability, sequence_of_indices, decoder_hidden_state).
        beams = [(0.0, [self.target_sos_idx], decoder_hidden_init)]
        completed_sequences = []

        for _ in range(max_output_len):
            new_beams = []
            all_current_beams_ended = True
            for log_prob_beam, seq_beam, hidden_beam in beams:
                # A beam that already emitted EOS is finalised and not expanded further.
                if not seq_beam or seq_beam[-1] == target_eos_idx:
                    completed_sequences.append((length_normalised(log_prob_beam, seq_beam), seq_beam))
                    continue
                all_current_beams_ended = False

                decoder_input = torch.tensor([seq_beam[-1]], device=self.device)
                decoder_output_logits, next_hidden_beam = self.decoder(decoder_input, hidden_beam)

                log_probs_next_token = torch.log_softmax(decoder_output_logits, dim=1)
                # topk cannot return more candidates than the vocabulary holds.
                k = min(beam_width, log_probs_next_token.size(1))
                topk_log_probs, topk_indices = torch.topk(log_probs_next_token, k, dim=1)

                for token_log_prob, next_token_idx in zip(topk_log_probs[0].tolist(), topk_indices[0].tolist()):
                    new_beams.append((log_prob_beam + token_log_prob, seq_beam + [next_token_idx], next_hidden_beam))

            if not new_beams or all_current_beams_ended:
                break

            # Keep only the `beam_width` highest-scoring candidates.
            new_beams.sort(key=lambda x: x[0], reverse=True)
            beams = new_beams[:beam_width]

        # Beams still active when the length limit was reached also compete.
        for log_prob_beam, seq_beam, _ in beams:
            completed_sequences.append((length_normalised(log_prob_beam, seq_beam), seq_beam))

        if not completed_sequences:
            return []
        completed_sequences.sort(key=lambda x: x[0], reverse=True)
        return completed_sequences[0][1]

# --- Data Loading and Preprocessing ---

class Vocabulary:
    """
    Bidirectional mapping between characters and integer ids.

    The ids 0-3 are reserved for PAD_TOKEN, SOS_TOKEN, EOS_TOKEN and UNK_TOKEN
    (in that order). Ordinary characters are appended by ``build_vocab`` after
    their frequencies have been collected with ``add_sequence``.
    """
    def __init__(self, name):
        self.name = name
        reserved = (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN)
        self.char2index = {token: position for position, token in enumerate(reserved)}
        self.index2char = dict(enumerate(reserved))
        self.char_counts = Counter()
        # Next free id; the four reserved tokens occupy 0-3.
        self.n_chars = 4
        self.pad_idx = self.char2index[PAD_TOKEN]
        self.sos_idx = self.char2index[SOS_TOKEN]
        self.eos_idx = self.char2index[EOS_TOKEN]

    def add_sequence(self, sequence):
        """Count every character of ``sequence`` (one increment per occurrence)."""
        self.char_counts.update(list(sequence))

    def build_vocab(self, min_freq=1):
        """
        Assign ids to every counted character seen at least ``min_freq`` times.

        Characters are added in order of descending frequency, ties broken by the
        character itself. Characters below the threshold get no id and are later
        mapped to UNK_TOKEN by ``sequence_to_indices``.
        """
        by_frequency = sorted(self.char_counts, key=lambda ch: (-self.char_counts[ch], ch))
        for ch in by_frequency:
            if self.char_counts[ch] < min_freq or ch in self.char2index:
                continue
            new_id = self.n_chars
            self.char2index[ch] = new_id
            self.index2char[new_id] = ch
            self.n_chars = new_id + 1

    def sequence_to_indices(self, sequence, add_eos=False, add_sos=False):
        """Encode ``sequence`` as ids, optionally wrapped in SOS/EOS; unknown characters become UNK."""
        prefix = [self.sos_idx] if add_sos else []
        body = [self.char2index.get(ch, self.char2index[UNK_TOKEN]) for ch in sequence]
        suffix = [self.eos_idx] if add_eos else []
        return prefix + body + suffix

    def indices_to_sequence(self, indices):
        """Decode ids into a string, stopping at the first EOS and skipping SOS/PAD ids."""
        decoded = []
        for token_id in indices:
            if token_id == self.eos_idx:
                break
            if token_id == self.sos_idx or token_id == self.pad_idx:
                continue
            decoded.append(self.index2char.get(token_id, UNK_TOKEN))
        return "".join(decoded)


class TransliterationDataset(Dataset):
    """
    Dataset of (source, target) transliteration string pairs read from a TSV file.

    Column 0 of each line is used as the target string and column 1 as the
    source string. Pairs where either side is empty, or longer than ``max_len``
    characters (when ``max_len`` is truthy), are dropped. Items are returned as
    LongTensors of vocabulary ids.
    """
    def __init__(self, file_path, source_vocab, target_vocab, max_len=50):
        self.pairs = []
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        if not os.path.exists(file_path):
            print(f"ERROR: Data file not found during Dataset init: {file_path}")
            return

        print(f"Loading data from: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    fields = raw_line.strip().split('\t')
                    if len(fields) < 2:
                        continue
                    target_text, source_text = fields[0], fields[1]
                    too_long = max_len and (len(source_text) > max_len or len(target_text) > max_len)
                    if too_long or not source_text or not target_text:
                        continue
                    self.pairs.append((source_text, target_text))
            print(f"Loaded {len(self.pairs)} pairs from {file_path}.")
        except Exception as e:
            # Best-effort loading: report the problem and keep whatever was read.
            print(f"ERROR: Could not read or process file {file_path}. Error: {e}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(source_ids + [EOS], [SOS] + target_ids + [EOS])`` as LongTensors."""
        if idx >= len(self.pairs):
            raise IndexError("Index out of bounds for dataset")
        source_text, target_text = self.pairs[idx]
        encoded_source = self.source_vocab.sequence_to_indices(source_text, add_eos=True)
        encoded_target = self.target_vocab.sequence_to_indices(target_text, add_sos=True, add_eos=True)
        return (torch.tensor(encoded_source, dtype=torch.long),
                torch.tensor(encoded_target, dtype=torch.long))

def collate_fn(batch, pad_idx_source, pad_idx_target):
    """
    Batch variable-length (source, target) tensor pairs for a DataLoader.

    Drops ``None`` items, items with a ``None`` member and pairs whose source
    tensor is empty. The remaining sources and targets are right-padded to the
    longest length in the batch.

    Returns:
        tuple: ``(padded_sources, source_lengths, padded_targets)``, both padded
        tensors batch-first, lengths as a LongTensor; or ``(None, None, None)``
        when nothing usable remains.
    """
    usable = [item for item in batch
              if item is not None and item[0] is not None and item[1] is not None]
    if not usable:
        return None, None, None

    # An empty source cannot be packed by the encoder, so such pairs are discarded.
    kept = [(src, tgt) for src, tgt in usable if len(src) > 0]
    if not kept:
        return None, None, None

    sources = [src for src, _ in kept]
    targets = [tgt for _, tgt in kept]
    lengths = torch.tensor([len(src) for src in sources], dtype=torch.long)

    padded_sources = pad_sequence(sources, batch_first=True, padding_value=pad_idx_source)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=pad_idx_target)
    return padded_sources, lengths, padded_targets

# --- Training and Evaluation Functions ---

def _train_one_epoch(model, dataloader, optimizer, criterion, device, clip_value, teacher_forcing_ratio, target_vocab):
    """
    Trains the model for a single epoch.

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the training set.
        optimizer (optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss).
        device (torch.device): Device to run the model on.
        clip_value (float): Max gradient norm for clipping.
        teacher_forcing_ratio (float): Probability of using teacher forcing.
        target_vocab (Vocabulary): Vocabulary for the target language.

    Returns:
        tuple: ``(avg_epoch_loss, train_accuracy)``.
        - ``avg_epoch_loss`` is averaged over the batches that were actually
          optimised. Skipped batches (invalid, or with non-finite loss) are
          excluded, so they no longer bias the mean towards 0.
        - ``train_accuracy`` is the exact-match rate of the teacher-forced
          argmax predictions.
    """
    model.train()
    epoch_loss = 0.0
    n_loss_batches = 0
    total_correct_train = 0
    total_samples_train = 0

    if len(dataloader) == 0:
        print("Warning: Training dataloader is empty.")
        return 0.0, 0.0

    for batch_data in tqdm(dataloader, desc="Training", leave=False):
        # collate_fn signals an unusable batch with (None, None, None).
        if batch_data[0] is None:
            continue
        sources, source_lengths, targets = batch_data

        if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
            print("Warning: Empty or invalid batch data after collate_fn in training. Skipping.")
            continue

        sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)
        optimizer.zero_grad()

        outputs_logits = model(sources, source_lengths, targets, teacher_forcing_ratio)

        # Position 0 is the SOS slot (never predicted), so it is excluded from the loss.
        output_dim = outputs_logits.shape[-1]
        flat_outputs = outputs_logits[:, 1:].reshape(-1, output_dim)
        flat_targets = targets[:, 1:].reshape(-1)
        loss = criterion(flat_outputs, flat_targets)

        if not torch.isfinite(loss):
            print("Warning: NaN or Inf loss detected in training. Skipping batch.")
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        epoch_loss += loss.item()
        n_loss_batches += 1

        # Exact sequence match on decoded strings (stops at EOS, ignores SOS/PAD).
        predicted_rows = outputs_logits.argmax(dim=2)[:, 1:].tolist()
        true_rows = targets[:, 1:].tolist()
        for pred_ids, true_ids in zip(predicted_rows, true_rows):
            if target_vocab.indices_to_sequence(pred_ids) == target_vocab.indices_to_sequence(true_ids):
                total_correct_train += 1
            total_samples_train += 1

    avg_epoch_loss = epoch_loss / n_loss_batches if n_loss_batches > 0 else 0.0
    train_accuracy = total_correct_train / total_samples_train if total_samples_train > 0 else 0.0
    return avg_epoch_loss, train_accuracy


def _evaluate_one_epoch(model, dataloader, criterion, device, source_vocab, target_vocab, beam_width=1, is_test_set=False):
    """
    Evaluates the model's performance on a given dataset (validation or test).

    The loss uses a teacher-forcing-free forward pass. Accuracy is the
    exact-match rate of per-example decoding (beam search when
    ``beam_width > 1``, otherwise greedy).

    Args:
        model (nn.Module): The Seq2Seq model.
        dataloader (DataLoader): DataLoader for the evaluation set.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the model on.
        source_vocab (Vocabulary): Vocabulary for the source language (used for debug prints).
        target_vocab (Vocabulary): Vocabulary for the target language.
        beam_width (int): Beam width for decoding (1 for greedy, >1 for beam search).
        is_test_set (bool): Flag for the final test evaluation (changes the log prefix and prints samples).

    Returns:
        tuple: ``(avg_epoch_loss, accuracy)``. The loss is averaged only over
        batches that produced a finite loss.
    """
    model.eval()
    epoch_loss = 0.0
    n_loss_batches = 0
    total_correct, total_samples = 0, 0
    if len(dataloader) == 0:
        print("WARNING: Evaluation dataloader is empty. Returning 0 loss and 0 accuracy.")
        return 0.0, 0.0
    desc_prefix = "Testing" if is_test_set else "Validating"

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=desc_prefix, leave=False)):
            if batch_data[0] is None:
                continue
            sources, source_lengths, targets = batch_data

            if sources is None or source_lengths is None or targets is None or source_lengths.numel() == 0 or sources.shape[0] == 0:
                print("Warning: Empty or invalid batch data in eval after collate_fn. Skipping.")
                continue

            sources, targets, source_lengths = sources.to(device), targets.to(device), source_lengths.to(device)

            # Loss with teacher forcing disabled; position 0 (SOS slot) is excluded.
            outputs_for_loss = model(sources, source_lengths, targets, teacher_forcing_ratio=0.0)
            output_dim = outputs_for_loss.shape[-1]
            flat_outputs_for_loss = outputs_for_loss[:, 1:].reshape(-1, output_dim)
            flat_targets_for_loss = targets[:, 1:].reshape(-1)
            loss = criterion(flat_outputs_for_loss, flat_targets_for_loss)
            # Non-finite losses are left out of both the sum and the batch count.
            if torch.isfinite(loss):
                epoch_loss += loss.item()
                n_loss_batches += 1

            # Allow a little slack beyond the padded target length when decoding.
            max_decode_len = targets.size(1) + 5
            for i in range(sources.shape[0]):
                src_single, src_len_single = sources[i:i+1], source_lengths[i:i+1]

                if beam_width > 1 and hasattr(model, 'predict_beam_search'):
                    predicted_indices = model.predict_beam_search(src_single, src_len_single,
                                                                  max_output_len=max_decode_len,
                                                                  target_eos_idx=target_vocab.eos_idx,
                                                                  beam_width=beam_width)
                elif hasattr(model, 'predict_greedy'):
                    predicted_indices = model.predict_greedy(src_single, src_len_single,
                                                             max_output_len=max_decode_len,
                                                             target_eos_idx=target_vocab.eos_idx)
                else:
                    # Fallback: argmax of the teacher-forcing-free logits computed above.
                    predicted_indices = outputs_for_loss[i:i+1].argmax(dim=2)[0, 1:].tolist()

                pred_str = target_vocab.indices_to_sequence(predicted_indices)
                true_str = target_vocab.indices_to_sequence(targets[i, 1:].tolist())

                if pred_str == true_str:
                    total_correct += 1
                total_samples += 1

                # Print the first three examples of the first test batch.
                if is_test_set and batch_idx == 0 and i < 3:
                    print(f"  Test Example {i} - Source: '{source_vocab.indices_to_sequence(src_single[0].tolist())}'")
                    print(f"    Pred: '{pred_str}', True: '{true_str}', Match: {pred_str == true_str}")

    avg_epoch_loss = epoch_loss / n_loss_batches if n_loss_batches > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    print(f"{desc_prefix} - Total Correct: {total_correct}, Total Samples: {total_samples}, Accuracy: {accuracy:.4f}, Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss, accuracy

# --- Function to Train and Save the Best Model ---

def train_and_save_best_model(config_params, model_save_path, device):
    """
    Performs a dedicated training run using the best hyperparameters found from a sweep.
    Saves the model checkpoint with the best validation accuracy.

    The source/target vocabularies are rebuilt from the training TSV inside this
    function. Any caller that later reloads the checkpoint must rebuild them the same
    way (same file, ``max_seq_len`` and ``vocab_min_freq``) so that the embedding and
    output-layer sizes match the saved weights.

    Args:
        config_params (dict): Dictionary of hyperparameters for this specific training run.
            Keys read here: cell_type, embedding_dim, hidden_dim, encoder_layers,
            decoder_layers, dropout_p, encoder_bidirectional, learning_rate_train,
            batch_size_train, epochs_train, clip_value_train, teacher_forcing_train,
            max_epochs_no_improve_train, lr_scheduler_patience_train, vocab_min_freq,
            max_seq_len.
        model_save_path (str): Path to save the best model's state_dict.
        device (torch.device): Device to run the training on.

    Returns:
        bool: True if training ran and this call wrote a checkpoint to
        ``model_save_path``; False if the training/dev data or loaders were empty.
    """
    run_name_train_best = f"TRAIN_BEST_{config_params['cell_type']}_emb{config_params['embedding_dim']}_hid{config_params['hidden_dim']}"

    # The context manager finishes the W&B run on every exit path, including early returns.
    with wandb.init(project="DL_A3", name=run_name_train_best, config=config_params, job_type="training_best_model_final", reinit=True):
        cfg = wandb.config
        print(f"Starting dedicated training for best model with config: {cfg}")

        # --- Data Loading and Vocabulary Building ---
        BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
        DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
        train_file = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
        dev_file = os.path.join(DATA_DIR, "hi.translit.sampled.dev.tsv")

        source_vocab = Vocabulary("latin")
        target_vocab = Vocabulary("devanagari")

        # Build vocabulary from the training data (a throwaway dataset is loaded only
        # to obtain the raw (src, tgt) pairs before the vocabularies are finalized).
        temp_train_ds_vocab = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not temp_train_ds_vocab.pairs:
            print(f"ERROR: No training data loaded for vocabulary building from {train_file}.")
            return False
        for src, tgt in temp_train_ds_vocab.pairs:
            source_vocab.add_sequence(src)
            target_vocab.add_sequence(tgt)
        source_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        target_vocab.build_vocab(min_freq=cfg.vocab_min_freq)
        print(f"Vocabs built. Source: {source_vocab.n_chars}, Target: {target_vocab.n_chars}")

        # Create actual datasets for training and validation
        train_dataset = TransliterationDataset(train_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        dev_dataset = TransliterationDataset(dev_file, source_vocab, target_vocab, max_len=cfg.max_seq_len)
        if not train_dataset.pairs or not dev_dataset.pairs:
            print(f"ERROR: Train or Dev dataset is empty after filtering. Train: {len(train_dataset.pairs)}, Dev: {len(dev_dataset.pairs)}")
            return False

        # --- DataLoader Setup ---
        num_w = 0 if device.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size_train, shuffle=True,
                                  collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                  num_workers=num_w, pin_memory=True if device.type == 'cuda' else False, drop_last=True)
        dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size_train, shuffle=False,
                                collate_fn=lambda b: collate_fn(b, source_vocab.pad_idx, target_vocab.pad_idx),
                                num_workers=num_w, pin_memory=True if device.type == 'cuda' else False)
        # With drop_last=True the train loader has zero batches when the dataset is
        # smaller than one batch; check the batch counts explicitly.
        if len(train_loader) == 0 or len(dev_loader) == 0:
            print(f"ERROR: Train or Dev DataLoader is empty. Train: {len(train_loader)}, Dev: {len(dev_loader)}")
            return False

        # --- Model, Optimizer, Loss Function Setup ---
        encoder = Encoder(source_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.encoder_layers, cfg.cell_type, cfg.dropout_p,
                          cfg.encoder_bidirectional, pad_idx=source_vocab.pad_idx).to(device)
        decoder = Decoder(target_vocab.n_chars, cfg.embedding_dim, cfg.hidden_dim,
                          cfg.decoder_layers, cfg.cell_type, cfg.dropout_p,
                          pad_idx=target_vocab.pad_idx).to(device)
        model = Seq2Seq(encoder, decoder, device, target_vocab.sos_idx).to(device)
        print(f"Best model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        wandb.watch(model, log="all", log_freq=100)

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate_train)
        # 'max' mode: the scheduler is stepped with validation accuracy (higher is better).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.lr_scheduler_patience_train, factor=0.3, verbose=True)
        criterion = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)

        best_val_acc_this_training = -1.0
        epochs_no_improve = 0
        # Tracks whether THIS run wrote the checkpoint; a file already present at
        # model_save_path may be stale output from an earlier session.
        saved_checkpoint = False

        # --- Training Loop ---
        for epoch in range(cfg.epochs_train):
            train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                                     cfg.clip_value_train, cfg.teacher_forcing_train, target_vocab)

            val_loss, val_acc = _evaluate_one_epoch(model, dev_loader, criterion, device,
                                                    source_vocab, target_vocab,
                                                    beam_width=1) # Use greedy for val during this training

            scheduler.step(val_acc)
            print(f"Epoch {epoch+1}/{cfg.epochs_train} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            wandb.log({"epoch_train_best": epoch + 1, "train_loss_best": train_loss, "train_acc_best": train_acc,
                       "val_loss_best": val_loss, "val_acc_best": val_acc, "lr_best": optimizer.param_groups[0]['lr']})

            # Early stopping logic; a NaN accuracy is treated as the worst possible score.
            current_val_acc = val_acc if not np.isnan(val_acc) else -1.0
            if current_val_acc > best_val_acc_this_training:
                best_val_acc_this_training = current_val_acc
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path) # Save best model
                saved_checkpoint = True
                print(f"Saved new best model to {model_save_path} (Val Acc: {best_val_acc_this_training:.4f})")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= cfg.max_epochs_no_improve_train:
                print(f"Early stopping for best model training at epoch {epoch+1}.")
                break

        # If no epoch ever produced an improving (non-NaN) accuracy, or no epochs ran,
        # save the current weights so the checkpoint on disk belongs to this run.
        if not saved_checkpoint:
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved final model state (no improvement over initial) to {model_save_path}")

        wandb.summary["final_best_val_accuracy_during_training"] = best_val_acc_this_training
        print(f"Finished training best model. Best Val Acc: {best_val_acc_this_training:.4f}")
    return True

# --- Main Execution Block for Training Best Model and Testing ---
if __name__ == '__main__':
    # Local import: sys.exit is used for error exits below. The site-provided
    # exit() is meant for the interactive shell, and with no argument it
    # reports status 0 (success) even on failure.
    import sys

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- W&B Login ---
    try:
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            print("Detected Kaggle environment. Ensuring WANDB_API_KEY is set.")
            if "WANDB_API_KEY" in os.environ:
                wandb.login()
                print("W&B login using environment variable.")
            else:
                from kaggle_secrets import UserSecretsClient
                user_secrets = UserSecretsClient()
                wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
                wandb.login(key=wandb_api_key)
                print("W&B login for Kaggle successful.")
        else:
            wandb.login()
        print("W&B login process attempted/completed.")
    except Exception as e:
        # Top-level boundary: report the failure and stop with a non-zero status.
        print(f"W&B login failed: {e}. Ensure API key is configured.")
        sys.exit(1)

    # --- Best Hyperparameters (obtained from a previous sweep) ---
    # These hyperparameters are chosen based on the provided sweep results.
    # Keys ending in '_train' are used only by train_and_save_best_model.
    BEST_HYPERPARAMETERS_FOR_TRAINING = {
        'embedding_dim': 300,
        'hidden_dim': 512,
        'encoder_layers': 2,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'dropout_p': 0.5,
        'encoder_bidirectional': True,
        'learning_rate_train': 0.000376,
        'batch_size_train': 128,
        'epochs_train': 20,
        'clip_value_train': 1.0,
        'teacher_forcing_train': 0.5,
        'max_epochs_no_improve_train': 7,
        'lr_scheduler_patience_train': 3,
        'vocab_min_freq': 1,
        'max_seq_len': 50,
        'eval_batch_size': 128,
        'beam_width_eval': 3
    }
    MODEL_SAVE_PATH = "/kaggle/working/best_model_for_testing.pt"

    print(f"Best hyperparameters selected for training: {BEST_HYPERPARAMETERS_FOR_TRAINING}")

    # --- Phase 1: Train and Save the Best Model ---
    print("\n" + "="*80)
    print("--- Phase 1: Training and Saving Best Model Configuration ---".center(80))
    print("="*80 + "\n")

    training_successful = train_and_save_best_model(BEST_HYPERPARAMETERS_FOR_TRAINING, MODEL_SAVE_PATH, DEVICE)

    if not training_successful or not os.path.exists(MODEL_SAVE_PATH):
        print("ERROR: Failed to train and save the best model. Exiting before test evaluation.")
        sys.exit(1)
    print(f"Best model trained and saved to {MODEL_SAVE_PATH}")

    # --- Phase 2: Load and Evaluate on Test Set ---
    print("\n" + "="*80)
    print("--- Phase 2: Loading and Evaluating Best Model on Test Set ---".center(80))
    print("="*80 + "\n")

    BASE_DATA_DIR = "/kaggle/input/dakshina-dl-a3/dakshina_dataset_v1.0/hi/"
    DATA_DIR = os.path.join(BASE_DATA_DIR, "lexicons/")
    train_file_for_vocab = os.path.join(DATA_DIR, "hi.translit.sampled.train.tsv")
    test_file = os.path.join(DATA_DIR, "hi.translit.sampled.test.tsv")

    source_vocab_test = Vocabulary("latin_test")
    target_vocab_test = Vocabulary("devanagari_test")

    # Build vocabulary using the training dataset (crucial for consistent tokenization).
    # This mirrors the vocabulary building in train_and_save_best_model (same file,
    # max_seq_len and vocab_min_freq) so vocabulary sizes match the saved weights.
    temp_train_ds_test_vocab = TransliterationDataset(train_file_for_vocab, source_vocab_test, target_vocab_test,
                                                      max_len=BEST_HYPERPARAMETERS_FOR_TRAINING['max_seq_len'])
    if not temp_train_ds_test_vocab.pairs:
        print(f"ERROR: No data loaded for vocabulary building from {train_file_for_vocab}.")
        sys.exit(1)
    for src, tgt in temp_train_ds_test_vocab.pairs:
        source_vocab_test.add_sequence(src)
        target_vocab_test.add_sequence(tgt)
    source_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS_FOR_TRAINING['vocab_min_freq'])
    target_vocab_test.build_vocab(min_freq=BEST_HYPERPARAMETERS_FOR_TRAINING['vocab_min_freq'])
    print(f"Test Vocab built from training data. Source: {source_vocab_test.n_chars} chars. Target: {target_vocab_test.n_chars} chars.")

    # Create the test dataset
    test_dataset = TransliterationDataset(test_file, source_vocab_test, target_vocab_test,
                                          max_len=BEST_HYPERPARAMETERS_FOR_TRAINING['max_seq_len'])
    if not test_dataset.pairs:
        print(f"ERROR: Test dataset is empty from {test_file}.")
        sys.exit(1)

    # Setup test DataLoader
    num_w_test = 0 if DEVICE.type == 'cpu' else min(4, os.cpu_count() // 2 if os.cpu_count() else 0)
    test_dataloader = DataLoader(test_dataset, batch_size=BEST_HYPERPARAMETERS_FOR_TRAINING['eval_batch_size'], shuffle=False,
                                 collate_fn=lambda b: collate_fn(b, source_vocab_test.pad_idx, target_vocab_test.pad_idx),
                                 num_workers=num_w_test, pin_memory=True if DEVICE.type == 'cuda' else False)
    # DataLoader truthiness already delegates to len(); a single explicit check suffices.
    if len(test_dataloader) == 0:
        print(f"ERROR: Test Dataloader is empty. Dataset size: {len(test_dataset)}, Batch size: {BEST_HYPERPARAMETERS_FOR_TRAINING['eval_batch_size']}")
        sys.exit(1)

    # --- Initialize Model for Testing and Load Weights ---
    encoder_test = Encoder(source_vocab_test.n_chars, BEST_HYPERPARAMETERS_FOR_TRAINING['embedding_dim'], BEST_HYPERPARAMETERS_FOR_TRAINING['hidden_dim'],
                           BEST_HYPERPARAMETERS_FOR_TRAINING['encoder_layers'], BEST_HYPERPARAMETERS_FOR_TRAINING['cell_type'], BEST_HYPERPARAMETERS_FOR_TRAINING['dropout_p'],
                           BEST_HYPERPARAMETERS_FOR_TRAINING['encoder_bidirectional'], pad_idx=source_vocab_test.pad_idx).to(DEVICE)
    decoder_test = Decoder(target_vocab_test.n_chars, BEST_HYPERPARAMETERS_FOR_TRAINING['embedding_dim'], BEST_HYPERPARAMETERS_FOR_TRAINING['hidden_dim'],
                           BEST_HYPERPARAMETERS_FOR_TRAINING['decoder_layers'], BEST_HYPERPARAMETERS_FOR_TRAINING['cell_type'], BEST_HYPERPARAMETERS_FOR_TRAINING['dropout_p'],
                           pad_idx=target_vocab_test.pad_idx).to(DEVICE)
    model_test = Seq2Seq(encoder_test, decoder_test, DEVICE, target_vocab_test.sos_idx).to(DEVICE)

    print(f"Loading weights into test model from: {MODEL_SAVE_PATH}")
    # The checkpoint is a plain state_dict of tensors, so restrict unpickling to weights.
    model_test.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True))
    print("Weights loaded successfully for test evaluation.")

    # --- Evaluate on Test Set ---
    criterion_test = nn.CrossEntropyLoss(ignore_index=target_vocab_test.pad_idx)
    test_loss, test_accuracy = _evaluate_one_epoch(model_test, test_dataloader, criterion_test, DEVICE,
                                                   source_vocab_test, target_vocab_test,
                                                   beam_width=BEST_HYPERPARAMETERS_FOR_TRAINING.get('beam_width_eval', 1),
                                                   is_test_set=True)

    print("\n--- FINAL TEST SET PERFORMANCE ---")
    print(f"Model Checkpoint: '{MODEL_SAVE_PATH}'")
    # Print only relevant hyperparameters for the loaded model, not training-specific ones
    eval_config_to_print = {k: v for k, v in BEST_HYPERPARAMETERS_FOR_TRAINING.items() if not k.endswith('_train')}
    print(f"Hyperparameters: {eval_config_to_print}")
    print(f"  Test Loss (avg per batch): {test_loss:.4f}")
    print(f"  Test Accuracy (Exact Match): {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")

    # --- Log Test Results to W&B ---
    # Best-effort: a logging failure must not hide the test results printed above.
    try:
        run_name_final_test = f"TEST_BEST_MODEL_{BEST_HYPERPARAMETERS_FOR_TRAINING['cell_type']}"
        with wandb.init(project="DL_A3", name=run_name_final_test, config=BEST_HYPERPARAMETERS_FOR_TRAINING, job_type="final_test_evaluation", reinit=True) as test_run:
            test_run.summary["final_test_accuracy"] = test_accuracy
            test_run.summary["final_test_loss"] = test_loss
            test_run.summary["model_checkpoint_used"] = MODEL_SAVE_PATH
            print("Final test results logged to W&B.")
    except Exception as e:
        print(f"Could not log final test results to W&B: {e}")