<a href="https://colab.research.google.com/github/SaiRajesh228/DA6401_Assignment3/blob/main/withAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader
import random
import os
import pickle
import json
import pandas as pd
from tqdm.auto import tqdm
import csv
import wandb

# Log in to Weights & Biases.
# The API key is read from the WANDB_API_KEY environment variable instead of
# being hard-coded, so the secret never lands in version control. When the
# variable is unset, key=None lets wandb fall back to its own credential
# lookup / interactive prompt.
# NOTE(review): the key that was previously committed here should be revoked.
wandb.login(key=os.environ.get('WANDB_API_KEY'))

# For reproducibility
def set_random_seeds(seed=42):
    """Seed the Python and PyTorch (CPU + CUDA) RNGs and pin cuDNN settings.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Character vocabulary class
class CharacterVocabulary:
    """Character-level vocabulary for transliteration tasks.

    Index layout: the special tokens come first (in the order given), followed
    by the regular characters. `idx_to_char` maps id -> token and
    `char_to_idx` is the inverse mapping.
    """

    # Immutable default so instances never share (and mutate) one list.
    DEFAULT_SPECIAL_TOKENS = ('<pad>', '<bos>', '<eos>', '<unk>')

    def __init__(self, token_list=None, special_tokens=None):
        """Create a vocabulary.

        Args:
            token_list: Regular (non-special) tokens, in index order.
            special_tokens: Special tokens placed at the start of the index
                space. Defaults to ``<pad>, <bos>, <eos>, <unk>``.
        """
        if special_tokens is None:
            special_tokens = self.DEFAULT_SPECIAL_TOKENS
        # Copy so later mutation of the caller's list (or of this attribute)
        # cannot leak into other vocabularies.
        self.special_tokens = list(special_tokens)
        self.idx_to_char = list(self.special_tokens) + list(token_list or [])
        self.char_to_idx = {ch: i for i, ch in enumerate(self.idx_to_char)}

    @classmethod
    def create_from_texts(cls, text_list):
        """Build a vocabulary from the sorted set of characters in `text_list`."""
        unique_chars = sorted({char for text in text_list for char in text})
        return cls(token_list=unique_chars)

    @classmethod
    def create_from_file(cls, file_path, src_col='src', tgt_col='tgt', is_csv=True):
        """Build a vocabulary from a two-column data file.

        Args:
            file_path: Path to the file.
            src_col, tgt_col: Column names assigned to the header-less CSV.
            is_csv: Read with pandas as CSV when True; otherwise read
                tab-separated lines and use the first two fields.
        """
        if is_csv:
            # dtype=str keeps numeric-looking cells as text; without it pandas
            # yields ints/floats, which cannot be iterated character by character.
            df = pd.read_csv(file_path, header=None, names=[src_col, tgt_col], dtype=str)
            texts = df[src_col].dropna().tolist() + df[tgt_col].dropna().tolist()
        else:
            texts = []
            with open(file_path, encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        texts.extend([parts[0], parts[1]])

        return cls.create_from_texts(texts)

    def save(self, path):
        """Save the id -> token list to a JSON file."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.idx_to_char, f, ensure_ascii=False)

    @classmethod
    def load(cls, path):
        """Load a vocabulary from a JSON file written by `save`.

        Only the id -> token list is stored in the file, so the loaded
        vocabulary uses the default special tokens.
        """
        with open(path, encoding='utf-8') as f:
            idx_to_char = json.load(f)

        vocab = cls(token_list=[])
        vocab.idx_to_char = idx_to_char
        vocab.char_to_idx = {c: i for i, c in enumerate(idx_to_char)}
        return vocab

    def tokenize(self, text, add_bos=False, add_eos=False):
        """Convert text to a list of ids; unknown characters map to `<unk>`.

        Args:
            text: Iterable of characters.
            add_bos: Prepend the `<bos>` id.
            add_eos: Append the `<eos>` id.
        """
        unk = self.char_to_idx['<unk>']
        lookup = self.char_to_idx.get

        indices = []
        if add_bos:
            indices.append(self.char_to_idx['<bos>'])
        indices.extend(lookup(c, unk) for c in text)
        if add_eos:
            indices.append(self.char_to_idx['<eos>'])
        return indices

    def detokenize(self, indices, remove_special=True, join=True):
        """Convert a sequence of ids back to text.

        Args:
            indices: List of ids, or any object with `.tolist()` (e.g. a tensor).
            remove_special: Drop special tokens from the result.
            join: Return a single string when True, else a list of tokens.

        Ids outside ``[0, len(self))`` are skipped.
        """
        if hasattr(indices, 'tolist'):
            indices = indices.tolist()

        size = len(self.idx_to_char)
        # The lower bound matters: a negative id would otherwise wrap around
        # and silently pick a token from the end of the vocabulary.
        chars = [self.idx_to_char[i] for i in indices if 0 <= i < size]

        if remove_special:
            special = set(self.special_tokens)
            chars = [c for c in chars if c not in special]

        return ''.join(chars) if join else chars

    def batch_detokenize(self, batch_indices, remove_special=True):
        """Decode a batch of id sequences into a list of strings."""
        return [self.detokenize(seq, remove_special=remove_special) for seq in batch_indices]

    def get_statistics(self):
        """Return total size, special-token count and regular-character count."""
        return {
            'total_size': len(self.idx_to_char),
            'special_tokens': len(self.special_tokens),
            'character_count': len(self.idx_to_char) - len(self.special_tokens)
        }

    def __len__(self):
        return len(self.idx_to_char)

    @property
    def pad_id(self): return self.char_to_idx['<pad>']

    @property
    def bos_id(self): return self.char_to_idx['<bos>']

    @property
    def eos_id(self): return self.char_to_idx['<eos>']

    @property
    def unk_id(self): return self.char_to_idx['<unk>']

    @property
    def vocab_size(self): return len(self.idx_to_char)

# Data processing
class TransliterationDataset(Dataset):
    """In-memory dataset of (source_ids, target_ids) tensor pairs.

    Every pair read from the file is tokenized with <bos>/<eos> markers and
    stored as a pair of `torch.long` tensors in `self.examples`.
    """

    def __init__(self, file_path, source_vocab, target_vocab, dataset_type='dakshina'):
        self.examples = []
        self.dataset_type = dataset_type

        # Only one on-disk layout is supported; reject anything else up front.
        if dataset_type != 'dakshina':
            raise ValueError(f"Unsupported dataset type: {dataset_type}")

        def to_tensor(vocab, text):
            ids = vocab.tokenize(text, add_bos=True, add_eos=True)
            return torch.tensor(ids, dtype=torch.long)

        self.examples.extend(
            (to_tensor(source_vocab, source_text), to_tensor(target_vocab, target_text))
            for source_text, target_text in self._read_tsv_file(file_path)
        )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def _read_tsv_file(self, path):
        """Yield (source, target) text pairs from a tab-separated file.

        The file's first column is the target and its second column the
        source, so the two are swapped on the way out. Lines with fewer than
        two fields are ignored.
        """
        with open(path, encoding='utf-8') as handle:
            for raw_line in handle:
                fields = raw_line.strip().split('\t')
                if len(fields) < 2:
                    continue
                target_text, source_text = fields[0], fields[1]
                yield source_text, target_text

def create_batches(batch, src_vocab, tgt_vocab):
    """Collate (source, target) tensor pairs into padded batch tensors.

    Args:
        batch: Sequence of (source_tensor, target_tensor) pairs.
        src_vocab: Object exposing `pad_id` used to pad the sources.
        tgt_vocab: Object exposing `pad_id` used to pad the targets.

    Returns:
        (padded_sources, source_lengths, padded_targets); both padded tensors
        are batch-first and `source_lengths` is a `torch.long` tensor of the
        unpadded source lengths.
    """
    sources, targets = zip(*batch)

    def pad(sequences, vocab):
        return pad_sequence(sequences, batch_first=True, padding_value=vocab.pad_id)

    padded_sources = pad(sources, src_vocab)
    padded_targets = pad(targets, tgt_vocab)
    source_lengths = torch.tensor(list(map(len, sources)), dtype=torch.long)
    return padded_sources, source_lengths, padded_targets

def load_data(
        language='te',
        dataset_type='dakshina',
        dataset_path=None,
        batch_size=64,
        device='cpu',
        worker_count=2,
        prefetch_factor=4,
        persistent_workers=True,
        cache_dir='./cache',
        use_cached_vocab=True
    ):
    """Load the train/dev/test transliteration splits and their vocabularies.

    Args:
        language: Language code used in the lexicon file names.
        dataset_type: Dataset identifier; used in the vocabulary cache file name.
        dataset_path: Directory holding the ``*.translit.sampled.*.tsv`` files.
            Defaults to the Colab location of the Dakshina lexicons.
        batch_size: Batch size for every DataLoader.
        device: Target device (string or `torch.device`); only used to decide
            whether to pin host memory.
        worker_count: Number of DataLoader worker processes (0 = load in the
            main process).
        prefetch_factor: Batches prefetched per worker (ignored when
            `worker_count` is 0).
        persistent_workers: Keep workers alive between epochs (ignored when
            `worker_count` is 0).
        cache_dir: Directory for the pickled vocabulary cache.
        use_cached_vocab: Read/write the vocabulary cache when True.

    Returns:
        (data_loaders, src_vocab, tgt_vocab) where `data_loaders` maps
        'train' / 'dev' / 'test' to a DataLoader; only 'train' is shuffled.
    """
    # Local import: functools is not among the file-level imports.
    from functools import partial

    if dataset_path is None:
        dataset_path = os.path.join(
            '/content/dakshina_dataset_v1.0',
            language, 'lexicons'
        )

    # Create cache directory if it doesn't exist
    vocab_cache_path = None
    if use_cached_vocab:
        os.makedirs(cache_dir, exist_ok=True)
        vocab_cache_path = os.path.join(cache_dir, f"{language}_{dataset_type}_vocab.pkl")

    # Try to load cached vocabularies
    if vocab_cache_path is not None and os.path.exists(vocab_cache_path):
        print(f"Loading cached vocabularies from {vocab_cache_path}")
        # SECURITY NOTE: unpickling executes arbitrary code from the file;
        # only point cache_dir at a location you trust.
        with open(vocab_cache_path, 'rb') as f:
            src_vocab, tgt_vocab = pickle.load(f)
    else:
        # Build vocabularies from the train and dev splits
        all_src, all_tgt = [], []

        for split in ['train', 'dev']:
            file_path = os.path.join(dataset_path, f"{language}.translit.sampled.{split}.tsv")
            with open(file_path, encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        all_src.append(parts[1])  # Dakshina format has target, source
                        all_tgt.append(parts[0])

        src_vocab = CharacterVocabulary.create_from_texts(all_src)
        tgt_vocab = CharacterVocabulary.create_from_texts(all_tgt)

        # Cache vocabularies
        if vocab_cache_path is not None:
            with open(vocab_cache_path, 'wb') as f:
                pickle.dump((src_vocab, tgt_vocab), f)

    # DataLoader configuration
    loader_config = dict(
        batch_size=batch_size,
        num_workers=worker_count,
        # Comparing a torch.device to the string 'cuda' is always False, so
        # normalise first; this also accepts forms such as 'cuda:0'.
        pin_memory=(torch.device(device).type == 'cuda'),
    )
    if worker_count > 0:
        # These options are only valid with worker processes; DataLoader
        # raises if prefetch_factor is supplied together with num_workers=0.
        loader_config['prefetch_factor'] = prefetch_factor
        loader_config['persistent_workers'] = persistent_workers

    # A partial of a module-level function can be pickled for worker
    # processes started with "spawn", unlike a lambda.
    collate = partial(create_batches, src_vocab=src_vocab, tgt_vocab=tgt_vocab)

    # Create data loaders for each split
    data_loaders = {}
    for split_name in ('train', 'dev', 'test'):
        file_path = os.path.join(dataset_path, f"{language}.translit.sampled.{split_name}.tsv")
        dataset = TransliterationDataset(file_path, src_vocab, tgt_vocab, dataset_type='dakshina')
        data_loaders[split_name] = DataLoader(
            dataset,
            shuffle=(split_name == 'train'),
            collate_fn=collate,
            **loader_config
        )

    return data_loaders, src_vocab, tgt_vocab

# Model Components
class RNNEncoder(nn.Module):
    """Embedding + (optionally bidirectional, multi-layer) RNN encoder.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Embedding size.
        hidden_dim: Hidden size per direction.
        num_layers: Number of stacked RNN layers.
        rnn_type: One of 'LSTM', 'GRU', 'RNN'.
        dropout: Inter-layer dropout (only applied when num_layers > 1).
        bidirectional: Run the RNN in both directions.

    Attributes:
        output_dim: Feature size of the per-step outputs
            (2 * hidden_dim when bidirectional, else hidden_dim).

    Raises:
        ValueError: If `rnn_type` is not supported.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1,
                 rnn_type='LSTM', dropout=0.0, bidirectional=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Output size will be doubled if bidirectional
        self.output_dim = hidden_dim * 2 if bidirectional else hidden_dim

        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU, 'RNN': nn.RNN}
        if rnn_type not in rnn_classes:
            raise ValueError(f"Unsupported RNN type: {rnn_type}")

        self.rnn = rnn_classes[rnn_type](embedding_dim,
                                       hidden_dim,
                                       num_layers=num_layers,
                                       dropout=dropout if num_layers > 1 else 0.0,
                                       batch_first=True,
                                       bidirectional=bidirectional)

    def _merge_directions(self, state):
        """Average the forward and backward final states of each layer.

        PyTorch lays a bidirectional state out along dim 0 as
        (layer0_fwd, layer0_bwd, layer1_fwd, layer1_bwd, ...), so the two
        directions of one layer are adjacent. Splitting the tensor into a
        first and second half would pair states from different layers
        whenever num_layers > 1.

        Args:
            state: [num_layers * 2, batch_size, hidden_dim]

        Returns:
            [num_layers, batch_size, hidden_dim]
        """
        _, batch_size, hidden_dim = state.shape
        per_layer = state.reshape(self.num_layers, 2, batch_size, hidden_dim)
        return per_layer.mean(dim=1)

    def forward(self, inputs, lengths):
        """Encode a padded batch.

        Args:
            inputs: [batch_size, seq_len] token ids.
            lengths: [batch_size] unpadded lengths.

        Returns:
            output: [batch_size, max(lengths), hidden_dim * directions]
            hidden_states: final state with leading dim num_layers; an
                (h_n, c_n) tuple for LSTM, a single tensor otherwise. For a
                bidirectional encoder the two directions are averaged.
        """
        embedded = self.embedding(inputs)  # [batch_size, seq_len, embedding_dim]
        # pack_padded_sequence expects the lengths on the CPU.
        packed_input = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, hidden_states = self.rnn(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        if self.bidirectional:
            if self.rnn_type == 'LSTM':
                h_n, c_n = hidden_states
                hidden_states = (self._merge_directions(h_n), self._merge_directions(c_n))
            else:
                hidden_states = self._merge_directions(hidden_states)

        return output, hidden_states

# Bahdanau Attention mechanism - the core of the attention model
class BahdanauAttention(nn.Module):
    """
    Bahdanau attention mechanism (additive attention)

    Scores every encoder position against the current decoder state with a
    small feed-forward network and normalises the scores with a softmax, so
    the decoder can weight different source positions at each step.

    Args:
        encoder_dim: Feature size of the encoder outputs.
        decoder_dim: Size of the decoder hidden state.
    """
    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        # Linear layer to process concatenated encoder and decoder states
        self.attention_layer = nn.Linear(encoder_dim + decoder_dim, decoder_dim)
        # Vector to convert processed states to attention scores
        self.v_layer = nn.Linear(decoder_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        """
        Calculate attention weights

        Args:
            hidden: [batch_size, decoder_dim] - Current decoder hidden state
            encoder_outputs: [batch_size, src_len, encoder_dim] - All encoder outputs
            mask: [batch_size, src_len] - Source padding mask; truthy (True / 1)
                for real tokens, falsy (False / 0) for padding. Boolean and
                numeric masks are both accepted.

        Returns:
            attention_weights: [batch_size, src_len] - Weights summing to 1 over
                src_len; padded positions get weight 0.
        """
        src_len = encoder_outputs.size(1)

        # Broadcast the decoder state over the source positions. expand()
        # creates a view, avoiding the copy that repeat() would make.
        query = hidden.unsqueeze(1).expand(-1, src_len, -1)  # [batch_size, src_len, decoder_dim]

        # Concatenation order is (decoder state, encoder output); it must stay
        # this way to remain compatible with already-trained weights.
        combined = torch.cat((query, encoder_outputs), dim=2)

        energy = torch.tanh(self.attention_layer(combined))  # [batch_size, src_len, decoder_dim]
        attention_scores = self.v_layer(energy).squeeze(2)  # [batch_size, src_len]

        # `~` is only a logical NOT on boolean tensors, so coerce first.
        padding = ~mask.to(torch.bool)
        # Use the most negative finite value of the score dtype: a fixed -1e10
        # is not representable in half precision.
        fill_value = torch.finfo(attention_scores.dtype).min
        attention_scores = attention_scores.masked_fill(padding, fill_value)

        return torch.softmax(attention_scores, dim=1)  # [batch_size, src_len]

class AttentionDecoder(nn.Module):
    """
    Single-step RNN decoder driven by Bahdanau attention.

    On every call the decoder scores the encoder outputs against its current
    top-layer hidden state, builds a context vector from those scores, feeds
    [token embedding ; context] through the RNN and predicts the next token
    from [RNN output ; context ; token embedding].
    """
    def __init__(self, vocab_size, embedding_dim, encoder_hidden_dim, decoder_hidden_dim,
                 num_layers=1, rnn_type="LSTM", dropout=0.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn_type = rnn_type
        self.attention = BahdanauAttention(encoder_hidden_dim, decoder_hidden_dim)

        try:
            rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}[rnn_type]
        except KeyError:
            raise ValueError(f"Unsupported RNN type: {rnn_type}") from None

        # The recurrent cell consumes the token embedding joined with the
        # attention context; inter-layer dropout needs at least two layers.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = rnn_cls(
            embedding_dim + encoder_hidden_dim,
            decoder_hidden_dim,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

        # The classifier sees the RNN output, the context and the embedding.
        self.output_layer = nn.Linear(
            decoder_hidden_dim + encoder_hidden_dim + embedding_dim,
            vocab_size,
        )

    def forward(self, input_token, hidden, encoder_outputs, mask):
        """
        Run one decoding step.

        Args:
            input_token : [batch_size] - Token ids fed at this step
            hidden : RNN state; an (h, c) tuple for LSTM, a tensor otherwise
            encoder_outputs : [batch_size, src_len, encoder_hidden_dim]
            mask : [batch_size, src_len] - Source padding mask

        Returns:
            logits : [batch_size, vocab_size]
            hidden : RNN state after this step
            attention_weights : [batch_size, src_len]
        """
        token_embedding = self.embedding(input_token)  # [batch_size, embedding_dim]

        # Attention is queried with the last layer's hidden state; for an
        # LSTM that is the h part of the (h, c) pair.
        state = hidden[0] if self.rnn_type == 'LSTM' else hidden
        query = state[-1]

        attention_weights = self.attention(query, encoder_outputs, mask)

        # Weighted sum of the encoder outputs -> [batch_size, 1, encoder_hidden_dim]
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)

        step_input = torch.cat((token_embedding.unsqueeze(1), context), dim=2)
        rnn_output, hidden = self.rnn(step_input, hidden)

        features = torch.cat(
            (rnn_output.squeeze(1), context.squeeze(1), token_embedding),
            dim=1,
        )
        logits = self.output_layer(features)

        return logits, hidden, attention_weights

class Seq2SeqWithAttention(nn.Module):
    """
    Sequence-to-sequence model with attention mechanism

    Wraps an encoder, called as ``encoder(src, src_lengths)`` and returning
    ``(encoder_outputs, hidden)``, and a step decoder, called as
    ``decoder(token, hidden, encoder_outputs, mask)`` and returning
    ``(logits, hidden, attention_weights)``.

    Args:
        encoder: Encoder module.
        decoder: Decoder module exposing `output_layer.out_features`.
        pad_idx: Padding id of the source vocabulary (used for the attention mask).
        device: Stored as `self.device` for callers; tensors created inside
            the model follow the device of `src` instead, so the model keeps
            working after `.to(...)` even if this argument was left at its default.
    """
    def __init__(self, encoder, decoder, pad_idx, device='cpu'):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        """
        Forward pass with teacher forcing

        Args:
            src : [batch_size, src_len] - Source sequence
            src_lengths : [batch_size] - Lengths of each source sequence
            tgt : [batch_size, tgt_len] - Target sequence (starting with <bos>)
            teacher_forcing_ratio : float - Per-step probability of feeding the
                ground-truth token instead of the model's own prediction

        Returns:
            outputs : [batch_size, tgt_len-1, vocab_size] - Decoder logits
        """
        encoder_outputs, hidden = self.encoder(src, src_lengths)

        # True for real tokens, False for padding
        mask = (src != self.pad_idx)

        batch_size, target_length = tgt.size()
        output_dim = self.decoder.output_layer.out_features

        # Allocate next to the inputs rather than on self.device, which may be stale.
        outputs = torch.zeros(batch_size, target_length - 1, output_dim, device=src.device)

        # First input to the decoder is the <bos> token
        decoder_input = tgt[:, 0]

        for t in range(1, target_length):
            decoder_output, hidden, _ = self.decoder(decoder_input, hidden, encoder_outputs, mask)
            outputs[:, t - 1] = decoder_output

            # One draw per step decides the next input for the whole batch.
            if random.random() < teacher_forcing_ratio:
                decoder_input = tgt[:, t]
            else:
                decoder_input = decoder_output.argmax(1)

        return outputs

    def generate(self, src, src_lengths, tgt_vocab, max_len=50):
        """
        Generate a translation using greedy decoding

        Args:
            src : [batch_size, src_len] - Source sequence
            src_lengths : [batch_size] - Lengths of each source sequence
            tgt_vocab : Target vocabulary exposing `bos_id`, `eos_id`, `pad_id`
            max_len : int - Maximum number of decoding steps

        Returns:
            generated_tokens : [batch_size, seq_len] - Generated ids with
                seq_len <= max_len. Every position after a sequence's first
                <eos> holds `tgt_vocab.pad_id`.
        """
        encoder_outputs, hidden = self.encoder(src, src_lengths)
        mask = (src != self.pad_idx)

        batch_size = src.size(0)
        device = src.device

        if max_len <= 0:
            # Nothing to decode; torch.cat() on an empty list would raise.
            return torch.empty((batch_size, 0), dtype=torch.long, device=device)

        decoder_input = torch.full((batch_size,), tgt_vocab.bos_id, device=device, dtype=torch.long)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        generated_tokens = []

        for _ in range(max_len):
            decoder_output, hidden, _ = self.decoder(decoder_input, hidden, encoder_outputs, mask)

            # Greedy choice; rows that already emitted <eos> are overwritten
            # with padding so nothing decoded after the end marker leaks into
            # the result.
            next_token = decoder_output.argmax(1).masked_fill(finished, tgt_vocab.pad_id)
            generated_tokens.append(next_token.unsqueeze(1))

            finished = finished | (next_token == tgt_vocab.eos_id)
            decoder_input = next_token

            # Stop once every sequence has produced <eos> at some step
            if finished.all():
                break

        return torch.cat(generated_tokens, dim=1)

# Training and evaluation utilities
def calculate_accuracy(model, data_loader, tgt_vocab, src_vocab, device):
    """Compute exact-match accuracy and collect per-example prediction details.

    The model is put in eval mode and decoded with `model.generate` under
    `torch.no_grad()`. Each prediction is cut at its first <eos> before being
    compared with the target: in a batched decode, shorter sequences keep
    receiving tokens until the whole batch is done, and those trailing tokens
    must not count as part of the prediction.

    Args:
        model: Model exposing `eval()` and
            `generate(src, src_lengths, tgt_vocab, max_len=...)`.
        data_loader: Iterable of (src, src_lengths, tgt) tensor batches.
        tgt_vocab: Target vocabulary (`detokenize`, `eos_id`).
        src_vocab: Source vocabulary (`detokenize`).
        device: Device the batches are moved to.

    Returns:
        (accuracy, correct, incorrect) where `accuracy` is a float (0.0 for an
        empty loader) and `correct` / `incorrect` are
        (sources, targets, predictions) tuples of string lists.
    """
    model.eval()
    eos_id = tgt_vocab.eos_id

    # (sources, targets, predictions) for matching and non-matching examples
    correct_rows = ([], [], [])
    incorrect_rows = ([], [], [])
    total = 0

    with torch.no_grad():
        for src, src_lengths, tgt in data_loader:
            src, src_lengths, tgt = (x.to(device) for x in (src, src_lengths, tgt))
            predictions = model.generate(src, src_lengths, tgt_vocab, max_len=tgt.size(1))

            # One host transfer per batch instead of three per example.
            predicted_rows = predictions.cpu().tolist()
            target_rows = tgt[:, 1:].cpu().tolist()  # Skip <bos>
            source_rows = src.cpu().tolist()

            for pred_ids, tgt_ids, src_ids in zip(predicted_rows, target_rows, source_rows):
                # Discard everything from the first <eos> onwards.
                if eos_id in pred_ids:
                    pred_ids = pred_ids[:pred_ids.index(eos_id)]

                predicted_text = tgt_vocab.detokenize(pred_ids)
                target_text = tgt_vocab.detokenize(tgt_ids)
                source_text = src_vocab.detokenize(src_ids)

                bucket = correct_rows if predicted_text == target_text else incorrect_rows
                bucket[0].append(source_text)
                bucket[1].append(target_text)
                bucket[2].append(predicted_text)

            total += src.size(0)

    accuracy = len(correct_rows[0]) / total if total else 0.0
    return accuracy, correct_rows, incorrect_rows

def save_predictions(src_list, tgt_list, pred_list, file_name):
    """Write aligned source / target / prediction triples to a CSV file.

    The file starts with a 'Source,Target,Predicted' header row followed by
    one row per triple. Returns `file_name`.
    """
    header = ['Source', 'Target', 'Predicted']
    with open(file_name, mode='w', newline='', encoding='utf-8') as out:
        csv_out = csv.writer(out)
        csv_out.writerow(header)
        csv_out.writerows(zip(src_list, tgt_list, pred_list))

    return file_name

def train_model(
    model,
    data_loaders,
    src_vocab,
    tgt_vocab,
    device,
    config,
    save_path=None,
    log_to_wandb=True
):
    """Train a seq2seq attention model, evaluate every epoch, then test it.

    Args:
        model: Model called as `model(src, src_lengths, tgt, teacher_forcing_ratio=...)`.
        data_loaders: Mapping with 'train', 'dev' and 'test' loaders.
        src_vocab, tgt_vocab: Source / target vocabularies.
        device: Device the batches are moved to.
        config: Mapping with 'optimizer', 'learning_rate', 'epochs' and
            'teacher_forcing'.
        save_path: Where to store the best state dict; nothing is saved when falsy.
        log_to_wandb: Send metrics to wandb when True.

    Returns:
        (model, test_accuracy)
    """
    loss_fn = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_id)

    # 'nadam' selects NAdam; every other name (including 'adam') uses Adam.
    optimizer_cls = optim.NAdam if config['optimizer'].lower() == 'nadam' else optim.Adam
    optimizer = optimizer_cls(model.parameters(), lr=config['learning_rate'])

    def batch_loss(batch, teacher_forcing_ratio):
        """Move one batch to the device and return its cross-entropy loss."""
        src, src_lengths, tgt = (tensor.to(device) for tensor in batch)
        logits = model(src, src_lengths, tgt, teacher_forcing_ratio=teacher_forcing_ratio)
        # Logits at step t are scored against the target token at t + 1.
        return loss_fn(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))

    best_val_acc = 0.0
    total_epochs = config['epochs']

    for epoch in tqdm(range(1, total_epochs + 1), desc="Epochs", position=0):
        # ---- optimisation pass over the training split ----
        model.train()
        running_loss = 0.0
        train_bar = tqdm(data_loaders['train'], desc=f"Train {epoch}", leave=False, position=1)
        for batch in train_bar:
            optimizer.zero_grad()
            loss = batch_loss(batch, config['teacher_forcing'])
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()
        train_bar.close()
        train_loss = running_loss / len(data_loaders['train'])

        # ---- validation loss, decoding without teacher forcing ----
        val_loss = 0.0
        val_bar = tqdm(data_loaders['dev'], desc=f"Val {epoch}", leave=False, position=1)
        model.eval()
        with torch.no_grad():
            for batch in val_bar:
                val_loss += batch_loss(batch, 0.0).item()
        val_bar.close()
        val_loss /= len(data_loaders['dev'])

        # ---- exact-match accuracy on both splits ----
        train_acc = calculate_accuracy(model, data_loaders['train'], tgt_vocab, src_vocab, device)[0]
        val_acc, val_correct, val_incorrect = calculate_accuracy(
            model, data_loaders['dev'], tgt_vocab, src_vocab, device
        )

        # ---- checkpoint on improvement (only when a save path is given) ----
        if save_path and val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with validation accuracy: {val_acc:.4f}")

            # Prediction dumps are written only when an improving epoch is
            # also the last epoch or a multiple of five.
            is_milestone = epoch % 5 == 0 or epoch == config['epochs']
            if is_milestone:
                ok_src, ok_tgt, ok_pred = val_correct
                save_predictions(ok_src, ok_tgt, ok_pred, f"correct_predictions_epoch_{epoch}.csv")

                bad_src, bad_tgt, bad_pred = val_incorrect
                save_predictions(bad_src, bad_tgt, bad_pred, f"incorrect_predictions_epoch_{epoch}.csv")

        # ---- reporting ----
        print(f"Epoch {epoch}/{config['epochs']}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if log_to_wandb:
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'validation_loss': val_loss,
                'train_accuracy': train_acc,
                'validation_accuracy': val_acc
            })

    # Final evaluation on the test split, using the model as it stands after
    # the last epoch.
    test_acc, test_correct, test_incorrect = calculate_accuracy(
        model, data_loaders['test'], tgt_vocab, src_vocab, device
    )
    print(f"Final test accuracy: {test_acc:.4f}")

    if log_to_wandb:
        wandb.log({'test_accuracy': test_acc})

    ok_src, ok_tgt, ok_pred = test_correct
    save_predictions(ok_src, ok_tgt, ok_pred, "correct_predictions_final.csv")

    bad_src, bad_tgt, bad_pred = test_incorrect
    save_predictions(bad_src, bad_tgt, bad_pred, "incorrect_predictions_final.csv")

    return model, test_acc

# Hyperparameter sweep configuration
def get_sweep_config():
    """Return the wandb sweep definition (Bayesian search maximising validation accuracy)."""

    def choices(*options):
        # wandb expects a discrete search space as {'values': [...]}.
        return {'values': list(options)}

    architecture = {
        'embedding_dim': choices(128, 256, 512),
        'hidden_dim': choices(128, 256, 512, 1024),
        'num_layers': choices(1, 2, 3, 4),
        'rnn_type': choices('RNN', 'GRU', 'LSTM'),
        'bidirectional': choices(True, False),
    }
    training = {
        'dropout': choices(0.0, 0.1, 0.2, 0.3, 0.5),
        'learning_rate': choices(1e-4, 2e-4, 5e-4, 8e-4, 1e-3),
        'batch_size': choices(32, 64, 128),
        'epochs': choices(10, 15, 20),
        'teacher_forcing': choices(0.3, 0.5, 0.7, 1.0),
        'optimizer': choices('Adam', 'NAdam'),
        'seed': choices(42, 43, 44, 45, 46),
    }

    return {
        'method': 'bayes',
        'name': 'Transliteration_with_Attention',
        'metric': {'name': 'validation_accuracy', 'goal': 'maximize'},
        'parameters': {**architecture, **training},
    }

def run_sweep_objective():
    """Train and evaluate one sweep trial with the hyperparameters picked by wandb.

    Takes no arguments so it can be handed to ``wandb.agent`` as its
    ``function``. It starts a wandb run, copies the sampled hyperparameters
    into a plain dict, downloads the Dakshina dataset if it is not present,
    builds the encoder / attention decoder / seq2seq model and calls
    ``train_model`` with wandb logging enabled.

    Raises:
        subprocess.CalledProcessError: If downloading or extracting the
            dataset fails.
    """
    # Function-scope import: only needed for the dataset download below.
    import subprocess

    # Initialize the wandb run; the sampled hyperparameters are then
    # available through wandb.config.
    wandb.init()
    config = wandb.config

    # Convert to a normal dictionary for our functions
    experiment_config = {
        'language': 'te',  # Telugu
        'rnn_type': config.rnn_type,
        'embedding_dim': config.embedding_dim,
        'hidden_dim': config.hidden_dim,
        'num_layers': config.num_layers,
        'dropout': config.dropout,
        'bidirectional': config.bidirectional,
        'batch_size': config.batch_size,
        'epochs': config.epochs,
        'learning_rate': config.learning_rate,
        'teacher_forcing': config.teacher_forcing,
        'optimizer': config.optimizer,
        'seed': config.seed
    }

    # Set seeds for reproducibility
    set_random_seeds(experiment_config['seed'])

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Download and extract the dataset if necessary.
    # NOTE: the existence check uses an absolute Colab path, while the
    # download and extraction happen in the current working directory.
    if not os.path.exists('/content/dakshina_dataset_v1.0'):
        print("Downloading Dakshina dataset...")
        # Plain subprocess calls instead of IPython "!" escapes so this
        # function is valid Python outside a notebook; check=True surfaces a
        # failed download/extract immediately.
        subprocess.run(
            ["wget", "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"],
            check=True,
        )
        subprocess.run(["tar", "xopf", "dakshina_dataset_v1.0.tar"], check=True)

    # Create a unique run name based on config
    run_name = f"{experiment_config['rnn_type']}_{experiment_config['num_layers']}l_{experiment_config['embedding_dim']}e_{experiment_config['hidden_dim']}h_" \
               f"{'bid' if experiment_config['bidirectional'] else 'uni'}_{experiment_config['dropout']}d_" \
               f"{experiment_config['teacher_forcing']}tf_{experiment_config['optimizer']}"
    wandb.run.name = run_name

    # Load data
    print(f"Loading {experiment_config['language']} data...")
    data_loaders, src_vocab, tgt_vocab = load_data(
        language=experiment_config['language'],
        batch_size=experiment_config['batch_size'],
        device=device
    )

    # Create model components
    print("Building model with attention...")
    encoder = RNNEncoder(
        src_vocab.vocab_size,
        experiment_config['embedding_dim'],
        experiment_config['hidden_dim'],
        num_layers=experiment_config['num_layers'],
        rnn_type=experiment_config['rnn_type'],
        dropout=experiment_config['dropout'],
        bidirectional=experiment_config['bidirectional']
    ).to(device)

    # Calculate encoder output dimension (doubled if bidirectional)
    encoder_output_dim = experiment_config['hidden_dim'] * 2 if experiment_config['bidirectional'] else experiment_config['hidden_dim']

    decoder = AttentionDecoder(
        tgt_vocab.vocab_size,
        experiment_config['embedding_dim'],
        encoder_output_dim,
        experiment_config['hidden_dim'],
        num_layers=experiment_config['num_layers'],
        rnn_type=experiment_config['rnn_type'],
        dropout=experiment_config['dropout']
    ).to(device)

    model = Seq2SeqWithAttention(encoder, decoder, pad_idx=src_vocab.pad_id, device=device).to(device)

    # Train the model; the returned (model, test accuracy) pair is not needed
    # here because train_model is called with log_to_wandb=True.
    print("Training model...")
    model_save_path = f"model_{run_name}.pt"

    train_model(
        model=model,
        data_loaders=data_loaders,
        src_vocab=src_vocab,
        tgt_vocab=tgt_vocab,
        device=device,
        config=experiment_config,
        save_path=model_save_path,
        log_to_wandb=True
    )

    # Wandb finish happens automatically when this function returns

def run_transliteration_experiment(config=None, use_wandb=True, run_sweep=False, sweep_count=20):
    """Run a transliteration experiment with attention.

    Either launches a wandb hyperparameter sweep or trains a single model
    with the given (or default) configuration.

    Args:
        config: Dict of hyperparameters for a single run. This function reads
            the keys ``language``, ``rnn_type``, ``embedding_dim``,
            ``hidden_dim``, ``num_layers``, ``dropout``, ``bidirectional``,
            ``batch_size``, ``teacher_forcing``, ``optimizer`` and ``seed``;
            the whole dict is also passed on to ``train_model``. When
            ``None``, a default Telugu / LSTM configuration is used.
        use_wandb: If true, start a wandb run for the single experiment and
            enable wandb logging in ``train_model``.
        run_sweep: If true, create a sweep from ``get_sweep_config()`` and run
            ``run_sweep_objective`` through ``wandb.agent``; ``config`` and
            ``use_wandb`` are ignored in that case.
        sweep_count: Number of sweep trials passed to ``wandb.agent``.

    Returns:
        tuple: ``(model, src_vocab, tgt_vocab)`` for a single run, or
        ``(None, None, None)`` when a sweep was run.

    Raises:
        subprocess.CalledProcessError: If downloading or extracting the
            dataset fails.
    """
    # Function-scope import: only needed for the dataset download below.
    import subprocess

    if run_sweep:
        # Run a hyperparameter sweep
        sweep_config = get_sweep_config()
        sweep_id = wandb.sweep(sweep_config, project="DA6401_Assignment_3")
        wandb.agent(sweep_id, function=run_sweep_objective, count=sweep_count)
        return None, None, None

    # Run a single experiment with the given config
    if config is None:
        config = {
            'language': 'te',  # Telugu
            'rnn_type': 'LSTM',
            'embedding_dim': 256,
            'hidden_dim': 512,
            'num_layers': 2,
            'dropout': 0.3,
            'bidirectional': True,
            'batch_size': 64,
            'epochs': 10,
            'learning_rate': 0.001,
            'teacher_forcing': 0.5,
            'optimizer': 'Adam',
            'seed': 42
        }

    # Set seeds for reproducibility
    set_random_seeds(config['seed'])

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize wandb
    if use_wandb:
        run_name = f"{config['rnn_type']}_{config['num_layers']}l_{config['embedding_dim']}e_{config['hidden_dim']}h_" \
                  f"{'bid' if config['bidirectional'] else 'uni'}_{config['dropout']}d_" \
                  f"{config['teacher_forcing']}tf_{config['optimizer']}"

        wandb.init(
            project="DA6401_Assignment_3",
            name=run_name,
            config=config
        )

    # Everything after wandb.init runs under try/finally so the wandb run is
    # always closed, even if data loading or training raises.
    succeeded = False
    try:
        # Download and extract the dataset if necessary.
        # NOTE: the existence check uses an absolute Colab path, while the
        # download and extraction happen in the current working directory.
        if not os.path.exists('/content/dakshina_dataset_v1.0'):
            print("Downloading Dakshina dataset...")
            # Plain subprocess calls instead of IPython "!" escapes so this
            # function is valid Python outside a notebook; check=True
            # surfaces a failed download/extract immediately.
            subprocess.run(
                ["wget", "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"],
                check=True,
            )
            subprocess.run(["tar", "xopf", "dakshina_dataset_v1.0.tar"], check=True)

        # Load data
        print(f"Loading {config['language']} data...")
        data_loaders, src_vocab, tgt_vocab = load_data(
            language=config['language'],
            batch_size=config['batch_size'],
            device=device
        )

        # Create model components
        print("Building model with attention...")
        encoder = RNNEncoder(
            src_vocab.vocab_size,
            config['embedding_dim'],
            config['hidden_dim'],
            num_layers=config['num_layers'],
            rnn_type=config['rnn_type'],
            dropout=config['dropout'],
            bidirectional=config['bidirectional']
        ).to(device)

        # Calculate encoder output dimension (doubled if bidirectional)
        encoder_output_dim = config['hidden_dim'] * 2 if config['bidirectional'] else config['hidden_dim']

        # Create attention decoder - explicit attention model
        decoder = AttentionDecoder(
            tgt_vocab.vocab_size,
            config['embedding_dim'],
            encoder_output_dim,
            config['hidden_dim'],
            num_layers=config['num_layers'],
            rnn_type=config['rnn_type'],
            dropout=config['dropout']
        ).to(device)

        model = Seq2SeqWithAttention(encoder, decoder, pad_idx=src_vocab.pad_id, device=device).to(device)

        # Train the model
        print("Training model with attention...")
        model_save_path = f"transliteration_model_attention_{config['language']}_{config['rnn_type']}.pt"

        model, test_acc = train_model(
            model=model,
            data_loaders=data_loaders,
            src_vocab=src_vocab,
            tgt_vocab=tgt_vocab,
            device=device,
            config=config,
            save_path=model_save_path,
            log_to_wandb=use_wandb
        )

        print(f"Training complete! Final test accuracy: {test_acc:.4f}")
        print(f"Model saved to {model_save_path}")
        succeeded = True
    finally:
        # Finish the wandb run; a non-zero exit code marks it as failed when
        # an exception is propagating.
        if use_wandb:
            wandb.finish(exit_code=0 if succeeded else 1)

    return model, src_vocab, tgt_vocab

# Entry point: the call below runs only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    # Run a single experiment with attention using the default configuration,
    # with Weights & Biases logging enabled
    run_transliteration_experiment(use_wandb=True)

    # To run a hyperparameter sweep instead, replace the call above with:
    # run_transliteration_experiment(use_wandb=True, run_sweep=True, sweep_count=20)