In [1]:
%%bash
cat > main.py << 'EOF'
import argparse
import random
import time
import os
import csv
from typing import List, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import wandb

# ─── CharVocab & Dataset ─────────────────────────────────────────────────────
class CharVocab:
    def __init__(self, filepaths: List[str]):
        self.rom_char2idx: Dict[str,int] = {}
        self.dev_char2idx: Dict[str,int] = {}
        self.rom_idx2char: Dict[int,str] = {}
        self.dev_idx2char: Dict[int,str] = {}
        self._build_vocab(filepaths)

    def _build_vocab(self, filepaths: List[str]):
        rom_chars = set()
        dev_chars = set()
        for fp in filepaths:
            with open(fp, "r", encoding="utf-8") as f:
                reader = csv.reader(f, delimiter="\t")
                for row in reader:
                    if len(row) < 2:
                        continue
                    devanagari = row[0].strip()
                    roman      = row[1].strip()
                    rom_chars.update(list(roman))
                    dev_chars.update(list(devanagari))

        PAD, SOS, EOS = "<pad>", "<sos>", "<eos>"

        all_rom = [PAD, SOS, EOS] + sorted(rom_chars)
        for i, ch in enumerate(all_rom):
            self.rom_char2idx[ch] = i
            self.rom_idx2char[i] = ch

        all_dev = [PAD, SOS, EOS] + sorted(dev_chars)
        for i, ch in enumerate(all_dev):
            self.dev_char2idx[ch] = i
            self.dev_idx2char[i] = ch

        self.rom_pad_idx = self.rom_char2idx[PAD]
        self.rom_sos_idx = self.rom_char2idx[SOS]
        self.rom_eos_idx = self.rom_char2idx[EOS]

        self.dev_pad_idx = self.dev_char2idx[PAD]
        self.dev_sos_idx = self.dev_char2idx[SOS]
        self.dev_eos_idx = self.dev_char2idx[EOS]

    @property
    def rom_vocab_size(self) -> int:
        return len(self.rom_char2idx)

    @property
    def dev_vocab_size(self) -> int:
        return len(self.dev_char2idx)

    def roman_to_indices(self, s: str) -> List[int]:
        return [self.rom_sos_idx] + [self.rom_char2idx[ch] for ch in s] + [self.rom_eos_idx]

    def dev_to_indices(self, s: str) -> List[int]:
        return [self.dev_sos_idx] + [self.dev_char2idx[ch] for ch in s] + [self.dev_eos_idx]

    def indices_to_dev(self, idxs: List[int]) -> str:
        chars = []
        for i in idxs:
            if i in (self.dev_sos_idx, self.dev_eos_idx, self.dev_pad_idx):
                continue
            chars.append(self.dev_idx2char[i])
        return "".join(chars)


def read_tsv(path: str) -> List[Tuple[str, str]]:
    """
    Expects each line of the TSV to be:
      Devanagari_word    Roman_word    <something_to_ignore>
    We only need (Roman, Devanagari) for training.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t")
        for row in reader:
            # If there are fewer than 2 columns, skip
            if len(row) < 2:
                continue

            # Unpack: Dev is first column, Roman is second, ignore anything else
            devana = row[0].strip()
            roman  = row[1].strip()

            if not roman or not devana:
                continue
            # Append (roman, devana)—this matches our CharVocab convention
            pairs.append((roman, devana))
    return pairs


class TransliterationDataset(Dataset):
    def __init__(self, filepath, vocab):
        super().__init__()
        self.pairs = read_tsv(filepath)
        self.vocab = vocab

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        roman, devanagari = self.pairs[idx]
        roman_idxs = self.vocab.roman_to_indices(roman)
        dev_idxs = self.vocab.dev_to_indices(devanagari)
        return torch.tensor(roman_idxs, dtype=torch.long), torch.tensor(dev_idxs, dtype=torch.long)

    @staticmethod
    def collate_fn(batch):
        roman_seqs, dev_seqs = zip(*batch)
        max_rom_len = max(len(x) for x in roman_seqs)
        max_dev_len = max(len(x) for x in dev_seqs)
        rom_padded = []
        dev_padded = []
        for r, d in zip(roman_seqs, dev_seqs):
            pad_r = torch.cat([r, r.new_full((max_rom_len - len(r),), r.new_tensor(0))])
            pad_d = torch.cat([d, d.new_full((max_dev_len - len(d),), d.new_tensor(0))])
            rom_padded.append(pad_r)
            dev_padded.append(pad_d)
        return torch.stack(rom_padded), torch.stack(dev_padded)


# ─── Attention Module ─────────────────────────────────────────────────────────
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # For computing alignment scores
        self.W1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W2 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.V = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden: (num_layers, batch_size, hidden_size) --> take top layer
        # encoder_outputs: (batch_size, seq_len, hidden_size)
        # Extract the top-layer hidden state for computing scores:
        h = hidden[-1]  # (B, H)
        seq_len = encoder_outputs.size(1)

        # Expand h to (B, seq_len, H) so we can add it to encoder_outputs-transformed
        h_expanded = h.unsqueeze(1).repeat(1, seq_len, 1)  # (B, T, H)

        # alignment: score = V^T * tanh(W1(enc_out) + W2(h_expanded))
        score = self.V(torch.tanh(self.W1(encoder_outputs) + self.W2(h_expanded)))  # (B, T, 1)

        # Compute attention weights over all encoder time steps:
        attention_weights = F.softmax(score.squeeze(2), dim=1)  # (B, T)

        # Weighted sum of encoder_outputs, giving a context vector of shape (B, H)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # (B, 1, H)
        context = context.squeeze(1)  # (B, H)

        return context, attention_weights


# ─── Encoder ───────────────────────────────────────────────────────────────────
class Encoder(nn.Module):
    def __init__(self, input_vocab_size, embed_dim, hidden_size, num_layers, cell_type, dropout):
        super().__init__()
        self.embed = nn.Embedding(input_vocab_size, embed_dim, padding_idx=0)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cell_type = cell_type.upper()

        if self.cell_type == "RNN":
            self.rnn = nn.RNN(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            )
        elif self.cell_type == "GRU":
            self.rnn = nn.GRU(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            )
        elif self.cell_type == "LSTM":
            self.rnn = nn.LSTM(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            )
        else:
            raise ValueError(f"Unsupported cell type: {cell_type}")

    def forward(self, x):
        # x: (batch_size, seq_len)
        emb = self.embed(x)  # (B, T, E)
        if self.cell_type == "LSTM":
            outputs, (h_n, c_n) = self.rnn(emb)
            return outputs, (h_n, c_n)
        else:
            outputs, h_n = self.rnn(emb)
            return outputs, h_n


# ─── Decoder with Attention ────────────────────────────────────────────────────
class Decoder(nn.Module):
    def __init__(self, output_vocab_size, embed_dim, hidden_size, num_layers, cell_type, dropout):
        super().__init__()
        self.embed = nn.Embedding(output_vocab_size, embed_dim, padding_idx=0)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cell_type = cell_type.upper()

        # The RNN input dimension is now (embedding + context_vector)
        self.input_dim = embed_dim + hidden_size

        if self.cell_type == "RNN":
            self.rnn = nn.RNN(
                self.input_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0
            )
        elif self.cell_type == "GRU":
            self.rnn = nn.GRU(
                self.input_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0
            )
        elif self.cell_type == "LSTM":
            self.rnn = nn.LSTM(
                self.input_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0
            )
        else:
            raise ValueError(f"Unsupported cell type: {cell_type}")

        # Bahdanau attention module
        self.attention = BahdanauAttention(hidden_size)

        # Final linear layer to project RNN output to vocabulary size
        self.out = nn.Linear(hidden_size, output_vocab_size)

    def forward(self, tgt_seq, hidden, cell=None, encoder_outputs=None, teacher_forcing_ratio=0.0):
        """
        tgt_seq: (batch_size, T_tgt)         -- includes <sos> ... <eos>
        hidden:   (num_layers, batch_size, hidden_size)
        cell:     (num_layers, batch_size, hidden_size)  # only for LSTM
        encoder_outputs: (batch_size, T_src, hidden_size)
        """
        B, T = tgt_seq.size()
        outputs = torch.zeros(B, T, self.out.out_features, device=tgt_seq.device)

        # The first input token to the decoder is always <sos>
        input_step = tgt_seq[:, 0]  # (B,)

        # Initialize hidden and cell states
        if self.cell_type == "LSTM":
            h = hidden  # (num_layers, B, H)
            c = cell
        else:
            h = hidden
            c = None

        for t in range(1, T):
            emb_t = self.embed(input_step)  # (B, E)

            # 1) Compute context vector via attention:
            context, attn_weights = self.attention(h, encoder_outputs)  # context: (B, H)

            # 2) Concatenate embedding and context vector, then feed to RNN:
            rnn_input = torch.cat([emb_t, context], dim=1).unsqueeze(1)  # (B, 1, E+H)

            if self.cell_type == "LSTM":
                out_t, (h, c) = self.rnn(rnn_input, (h, c))  # out_t: (B, 1, H)
            else:
                out_t, h = self.rnn(rnn_input, h)  # out_t: (B, 1, H)
                c = None

            logits = self.out(out_t.squeeze(1))  # (B, V)
            outputs[:, t, :] = logits

            # Decide next input (teacher forcing or predicted token)
            teacher_force = (torch.rand(1).item() < teacher_forcing_ratio)
            top1 = logits.argmax(dim=1)  # (B,)
            next_input = tgt_seq[:, t] if teacher_force else top1
            input_step = next_input

        return outputs

    def beam_search_decode(
        self,
        encoder_hidden,
        encoder_cell,
        encoder_outputs,
        max_len,
        dev_sos_idx,
        dev_eos_idx,
        beam_size
    ):
        """
        Simplified batch_size=1 beam search using attention.
        encoder_hidden: hidden state to initialize decoder
        encoder_cell:   cell state for LSTM decoder (None if not LSTM)
        encoder_outputs: (1, T_src, H)
        """
        hidden, cell = encoder_hidden, encoder_cell
        live = [([dev_sos_idx], hidden, cell, 0.0)]
        completed = []

        for _ in range(max_len):
            new_hyps = []
            for (seq, h, c, score) in live:
                last_token = seq[-1]
                if last_token == dev_eos_idx:
                    completed.append((seq, h, c, score))
                    continue

                inp = torch.tensor([last_token], dtype=torch.long, device=encoder_outputs.device).unsqueeze(0)
                emb_t = self.embed(inp).squeeze(1)  # (1, E)

                # Compute attention‐based context for this step:
                context, attn_weights = self.attention(h, encoder_outputs)  # (1, H)

                rnn_input = torch.cat([emb_t, context], dim=1).unsqueeze(1)  # (1, 1, E+H)

                if self.cell_type == "LSTM":
                    out_t, (h2, c2) = self.rnn(rnn_input, (h, c))  # (1, 1, H)
                else:
                    out_t, h2 = self.rnn(rnn_input, h)
                    c2 = None

                logits = self.out(out_t.squeeze(1))  # (1, V)
                log_probs = F.log_softmax(logits, dim=1)  # (1, V)

                topk_logprobs, topk_indices = torch.topk(log_probs, k=beam_size, dim=1)
                topk_logprobs = topk_logprobs.squeeze(0).tolist()
                topk_indices = topk_indices.squeeze(0).tolist()

                for lp, idx in zip(topk_logprobs, topk_indices):
                    new_hyps.append((seq + [idx], h2, c2, score + lp))

            # Retain only top k hypotheses
            new_hyps = sorted(new_hyps, key=lambda x: x[3], reverse=True)[:beam_size]
            live = new_hyps

            # If all live hypotheses have already ended with <eos>, stop
            if all(hyp[0][-1] == dev_eos_idx for hyp in live):
                completed.extend(live)
                break

        if not completed:
            completed = live

        best = max(completed, key=lambda x: x[3])
        return best[0]

# ─── Seq2Seq ────────────────────────────────────────────────────────────────────
class Seq2Seq(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, device: torch.device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, teacher_forcing_ratio: float = 0.5) -> torch.Tensor:
        """
        src: (batch_size, T_src)
        tgt: (batch_size, T_tgt)      # includes <sos> ... <eos>
        Returns: logits over vocabulary at each decoder timestep
                 shape = (batch_size, T_tgt, V_out)
        """
        B, T_src = src.size()
        outputs = torch.zeros(B, tgt.size(1), self.decoder.out.out_features, device=self.device)

        # 1) Run the encoder
        if self.encoder.cell_type == "LSTM":
            enc_outputs, (h_n, c_n) = self.encoder(src)
        else:
            enc_outputs, h_n = self.encoder(src)
            c_n = None

        # 2) Prepare initial decoder hidden & cell states from encoder’s final states
        enc_layers = self.encoder.num_layers
        dec_layers = self.decoder.num_layers
        hidden_size = self.encoder.hidden_size

        if self.encoder.cell_type == "LSTM":
            if enc_layers >= dec_layers:
                h_dec = h_n[-dec_layers:]
                c_dec = c_n[-dec_layers:]
            else:
                num_missing = dec_layers - enc_layers
                zeros_h = torch.zeros(num_missing, B, hidden_size, device=self.device)
                zeros_c = torch.zeros(num_missing, B, hidden_size, device=self.device)
                h_dec = torch.cat([zeros_h, h_n], dim=0)
                c_dec = torch.cat([zeros_c, c_n], dim=0)
            dec_hidden = h_dec
            dec_cell = c_dec
        else:
            if enc_layers >= dec_layers:
                h_dec = h_n[-dec_layers:]
            else:
                num_missing = dec_layers - enc_layers
                zeros_h = torch.zeros(num_missing, B, hidden_size, device=self.device)
                h_dec = torch.cat([zeros_h, h_n], dim=0)
            dec_hidden = h_dec
            dec_cell = None

        # 3) Decode with attention
        logits = self.decoder(
            tgt_seq=tgt,
            hidden=dec_hidden,
            cell=dec_cell,
            encoder_outputs=enc_outputs,
            teacher_forcing_ratio=teacher_forcing_ratio,
        )
        return logits

    @torch.no_grad()
    def predict(self, src: torch.Tensor, max_len: int, dev_sos_idx: int, dev_eos_idx: int, beam_size: int = 1):
        """
        Run inference (batch_size=1), either greedily (beam_size=1) or with beam search.
        Returns a list of token indices (the decoded sequence).
        """
        self.eval()
        B = 1

        # 1) Encode the source
        if self.encoder.cell_type == "LSTM":
            enc_outputs, (h_n, c_n) = self.encoder(src)
        else:
            enc_outputs, h_n = self.encoder(src)
            c_n = None

        # Prepare decoder initial states
        enc_layers = self.encoder.num_layers
        dec_layers = self.decoder.num_layers
        hidden_size = self.encoder.hidden_size

        if self.encoder.cell_type == "LSTM":
            if enc_layers >= dec_layers:
                h_dec = h_n[-dec_layers:]
                c_dec = c_n[-dec_layers:]
            else:
                num_missing = dec_layers - enc_layers
                zeros_h = torch.zeros(num_missing, B, hidden_size, device=self.device)
                zeros_c = torch.zeros(num_missing, B, hidden_size, device=self.device)
                h_dec = torch.cat([zeros_h, h_n], dim=0)
                c_dec = torch.cat([zeros_c, c_n], dim=0)
            hidden_state = h_dec
            cell_state = c_dec
        else:
            if enc_layers >= dec_layers:
                h_dec = h_n[-dec_layers:]
            else:
                num_missing = dec_layers - enc_layers
                zeros_h = torch.zeros(num_missing, B, hidden_size, device=self.device)
                h_dec = torch.cat([zeros_h, h_n], dim=0)
            hidden_state = h_dec
            cell_state = None

        if beam_size == 1:
            # GREEDY DECODING
            seq = [dev_sos_idx]
            hidden_, cell_ = hidden_state, cell_state

            for _ in range(max_len):
                last_token = torch.tensor([seq[-1]], dtype=torch.long, device=self.device).unsqueeze(0)  # (1,1)
                emb = self.decoder.embed(last_token).squeeze(1)  # (1, E)

                # Compute attention context
                context, attn_weights = self.decoder.attention(hidden_, enc_outputs)  # (1, H)

                rnn_input = torch.cat([emb, context], dim=1).unsqueeze(1)  # (1, 1, E+H)

                if self.decoder.cell_type == "LSTM":
                    out, (h_next, c_next) = self.decoder.rnn(rnn_input, (hidden_, cell_))
                    hidden_, cell_ = h_next, c_next
                else:
                    out, h_next = self.decoder.rnn(rnn_input, hidden_)
                    hidden_, cell_ = h_next, None

                logits = self.decoder.out(out.squeeze(1))  # (1, V)
                next_token = logits.argmax(dim=1).item()
                seq.append(next_token)
                if next_token == dev_eos_idx:
                    break

            return seq
        else:
            # BEAM SEARCH DECODING
            best_seq = self.decoder.beam_search_decode(
                encoder_hidden=hidden_state,
                encoder_cell=cell_state,
                encoder_outputs=enc_outputs,
                max_len=max_len,
                dev_sos_idx=dev_sos_idx,
                dev_eos_idx=dev_eos_idx,
                beam_size=beam_size,
            )
            return best_seq

# ─── train / evaluate ───────────────────────────────────────────────────────
def char_accuracy(
    logits: torch.Tensor,
    target: torch.Tensor,
    pad_idx: int,
    sos_idx: int,
    eos_idx: int
) -> Tuple[int, int]:
    """
    Compute (num_correct_chars, num_valid_chars) ignoring <pad>, <sos>, and <eos>.
    logits: (B, T, V)    – raw decoder logits
    target: (B, T)       – ground‐truth indices (including <sos>, <eos>, <pad>)
    pad_idx, sos_idx, eos_idx: indices for the special tokens

    Returns:
       (correct_count, valid_count)
    """
    with torch.no_grad():
        B, T, V = logits.size()
        pred = logits.argmax(dim=2)   # (B, T)

        # Mask out positions where target is pad, sos, or eos
        ignore_mask = (
            (target != pad_idx) &
            (target != sos_idx) &
            (target != eos_idx)
        )  # (B, T)

        correct = ((pred == target) & ignore_mask).sum().item()
        valid   = ignore_mask.sum().item()
        return correct, valid


def train_one_epoch(
    model: Seq2Seq,
    iterator: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.CrossEntropyLoss,
    pad_idx: int,
    sos_idx: int,
    eos_idx: int,
    device: torch.device,
    teacher_forcing_ratio: float,
) -> Tuple[float, float, float]:
    """
    Runs one epoch of training, returning:
      (train_loss, train_char_acc, train_word_acc)

    – train_char_acc is computed as (total_correct_chars / total_valid_chars)
    – train_word_acc is (num_exact_word_matches / total_words)
    """
    model.train()
    epoch_loss = 0.0

    # For micro‐average char‐accuracy:
    total_correct_chars = 0
    total_valid_chars   = 0

    # For word‐accuracy (we’ll count how many sequences are exactly correct)
    total_exact_words = 0
    total_words       = 0

    for src, tgt in iterator:
        src = src.to(device, non_blocking=True)  # (B, T_src)
        tgt = tgt.to(device, non_blocking=True)  # (B, T_tgt)
        B, T_tgt = tgt.size()

        optimizer.zero_grad()
        # Forward pass
        output_logits = model(src, tgt, teacher_forcing_ratio=teacher_forcing_ratio)
        # output_logits: (B, T_tgt, V_out)

        # 1) Compute token‐level (character) loss
        V = output_logits.size(-1)
        loss = criterion(output_logits.view(-1, V), tgt.view(-1))

        # 2) Compute char‐level correct/valid counts (ignore special tokens)
        correct, valid = char_accuracy(
            output_logits, tgt,
            pad_idx, sos_idx, eos_idx
        )
        total_correct_chars += correct
        total_valid_chars   += valid

        # 3) Compute word‐level accuracy for this batch
        #    We already have output_logits. Let’s get preds:
        with torch.no_grad():
            pred_inds = output_logits.argmax(dim=2)  # (B, T_tgt)

            # For each example, ignore index 0 (<sos>). Then check that
            #   for every position t where tgt[b,t] != pad_idx,
            #   pred_inds[b,t] == tgt[b,t].
            pred_trim = pred_inds[:, 1:]
            tgt_trim  = tgt[:,      1:]
            mask_trim = (tgt_trim != pad_idx)  # which positions to verify

            # A word is correct if for all non‐pad positions, pred == tgt
            match_trim    = (pred_trim == tgt_trim) | (~mask_trim)
            exact_matches = match_trim.all(dim=1)  # (B,)
            total_exact_words += exact_matches.sum().item()
            total_words       += B

        # 4) Backward + step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        epoch_loss += loss.item()

    # Final epoch‐level metrics:
    train_loss = epoch_loss / len(iterator)
    train_char_acc = total_correct_chars / max(total_valid_chars, 1)
    train_word_acc = total_exact_words   / max(total_words, 1)

    return train_loss, train_char_acc, train_word_acc


def evaluate(
    model: Seq2Seq,
    iterator: DataLoader,
    criterion: nn.CrossEntropyLoss,
    pad_idx: int,
    sos_idx: int,
    eos_idx: int,
    device: torch.device,
    beam_size: int,
    max_dev_len: int,
) -> Tuple[float, float, float]:
    """
    One validation pass (no teacher forcing). Returns:
      (val_loss, val_char_acc, val_word_acc)

    – val_char_acc is (total_correct_chars / total_valid_chars)
    – val_word_acc is (num_exact_word_matches / total_words)
    """
    model.eval()
    epoch_loss = 0.0

    total_correct_chars = 0
    total_valid_chars   = 0
    total_exact_words   = 0
    total_words         = 0

    with torch.no_grad():
        for src, tgt in iterator:
            src = src.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)   # (B, T_tgt)
            B, T_tgt = tgt.size()

            # 1) Forward pass (teacher_forcing_ratio=0.0)
            logits = model(src, tgt, teacher_forcing_ratio=0.0)  # (B, T_tgt, V_out)
            V = logits.size(-1)
            loss = criterion(logits.view(-1, V), tgt.view(-1))

            # 2) Character‐level correct/valid counts (ignore special tokens)
            correct, valid = char_accuracy(
                logits, tgt,
                pad_idx, sos_idx, eos_idx
            )
            total_correct_chars += correct
            total_valid_chars   += valid

            # 3) Word‐level accuracy (greedy decode)
            pred_inds = logits.argmax(dim=2)  # (B, T_tgt)
            pred_trim = pred_inds[:, 1:]
            tgt_trim  = tgt[:,      1:]
            mask_trim = (tgt_trim != pad_idx)

            match_trim    = (pred_trim == tgt_trim) | (~mask_trim)
            exact_matches = match_trim.all(dim=1)  # (B,)
            total_exact_words += exact_matches.sum().item()
            total_words       += B

            epoch_loss += loss.item()

    val_loss = epoch_loss / len(iterator)
    val_char_acc = total_correct_chars / max(total_valid_chars, 1)
    val_word_acc = total_exact_words / max(total_words, 1)

    return val_loss, val_char_acc, val_word_acc


# ─── main() with parse_known_args ─────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="/content")
    parser.add_argument("--train_file", type=str, default="bn.translit.sampled.train.tsv")
    parser.add_argument("--dev_file", type=str, default="bn.translit.sampled.dev.tsv")
    parser.add_argument("--test_file", type=str, default="bn.translit.sampled.test.tsv")
    parser.add_argument("--emb_size", type=int, default=64)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--enc_layers", type=int, default=1)
    parser.add_argument("--dec_layers", type=int, default=1)
    parser.add_argument("--cell_type", type=str, default="RNN", choices=["RNN", "GRU", "LSTM"])
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--tf_ratio", type=float, default=0.5)
    parser.add_argument("--beam_size", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_dev_len", type=int, default=32)
    parser.add_argument("--project_name", type=str, default="Attention_RNN")
    parser.add_argument("--run_name", type=str, default=None)

    # ── use parse_known_args to ignore Colab’s extra "-f …json"
    args, _ = parser.parse_known_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_path = os.path.join(args.data_dir, args.train_file)
    dev_path   = os.path.join(args.data_dir, args.dev_file)
    test_path  = os.path.join(args.data_dir, args.test_file)
    vocab = CharVocab([train_path, dev_path, test_path])

    # Extract special‐token indices for convenience
    pad_idx = vocab.dev_pad_idx
    sos_idx = vocab.dev_sos_idx
    eos_idx = vocab.dev_eos_idx

    train_ds = TransliterationDataset(train_path, vocab)
    dev_ds   = TransliterationDataset(dev_path, vocab)
    test_ds  = TransliterationDataset(test_path, vocab)

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=TransliterationDataset.collate_fn,
        num_workers=2,
        pin_memory=True,
    )
    dev_loader = DataLoader(
        dev_ds,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=TransliterationDataset.collate_fn,
        num_workers=2,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=1,
        shuffle=False,
        collate_fn=TransliterationDataset.collate_fn,
        num_workers=2,
        pin_memory=True,
    )

    if args.run_name is None:
        args.run_name = (
            f"{args.cell_type}"
            f"_emb{args.emb_size}"
            f"_hid{args.hidden_size}"
            f"_layers{args.enc_layers}x{args.dec_layers}"
            f"_lr{args.lr}"
            f"_tf{args.tf_ratio}"
        )

    wandb.init(
        project=args.project_name,
        name=args.run_name,
        config={
            "emb_size": args.emb_size,
            "hidden_size": args.hidden_size,
            "enc_layers": args.enc_layers,
            "dec_layers": args.dec_layers,
            "cell_type": args.cell_type,
            "dropout": args.dropout,
            "lr": args.lr,
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "tf_ratio": args.tf_ratio,
            "beam_size": args.beam_size,
        },
    )
    config = wandb.config

    encoder = Encoder(
        input_vocab_size=vocab.rom_vocab_size,
        embed_dim=config.emb_size,
        hidden_size=config.hidden_size,
        num_layers=config.enc_layers,
        cell_type=config.cell_type,
        dropout=config.dropout,
    )
    decoder = Decoder(
        output_vocab_size=vocab.dev_vocab_size,
        embed_dim=config.emb_size,
        hidden_size=config.hidden_size,
        num_layers=config.dec_layers,
        cell_type=config.cell_type,
        dropout=config.dropout,
    )
    model = Seq2Seq(encoder, decoder, device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.dev_pad_idx)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=1e-5)

    best_val_acc = 0.0
    for epoch in range(1, config.epochs + 1):
        start_time = time.time()

        # 1) Train
        train_loss, train_char_acc, train_word_acc = train_one_epoch(
            model=model,
            iterator=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            pad_idx=pad_idx,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            device=device,
            teacher_forcing_ratio=config.tf_ratio,
        )

        # 2) Evaluate
        val_loss, val_char_acc, val_word_acc = evaluate(
            model=model,
            iterator=dev_loader,
            criterion=criterion,
            pad_idx=pad_idx,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            device=device,
            beam_size=config.beam_size,
            max_dev_len=args.max_dev_len,
        )

        elapsed = time.time() - start_time

        # 3) Save best model based on val_char_acc
        if val_char_acc > best_val_acc:
            best_val_acc = val_char_acc
            torch.save(model.state_dict(), "best_model.pt")

        # 4) Log to W&B
        wandb.log(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_char_acc": train_char_acc,
                "train_word_acc": train_word_acc,
                "val_loss": val_loss,
                "val_char_acc": val_char_acc,
                "val_word_acc": val_word_acc,
                "epoch_time_sec": elapsed,
            }
        )

        # 5) Print progress
        print(
            f"Epoch {epoch:02d} | "
            f"Train Loss: {train_loss:.4f}, "
            f"Train Char Acc: {train_char_acc:.4f}, "
            f"Train Word Acc: {train_word_acc:.4f} | "
            f"Val Loss: {val_loss:.4f}, "
            f"Val Char Acc: {val_char_acc:.4f}, "
            f"Val Word Acc: {val_word_acc:.4f} | "
            f"Time: {elapsed:.1f}s"
        )

    wandb.finish()


if __name__ == "__main__":
    main()

EOF

In [2]:
%%bash
cat > sweep.yaml << 'EOF'
program: main.py
method: bayes
project: Attention_RNN
entity: mrsagarbiswas-iit-madras
metric:
  name: val_char_acc
  goal: maximize
parameters:
  emb_size:
    values: [128, 256]
  hidden_size:
    values: [128, 256]
  enc_layers:
    values: [3, 5]
  epochs:
    values: [15]
  dec_layers:
    values: [1, 2, 3]
  cell_type:
    values: ["RNN", "GRU", "LSTM"]
  dropout:
    values: [0.2, 0.3]
  lr:
    values: [1e-3, 1e-4]
  batch_size:
    values: [32, 64, 128]
  tf_ratio:
    values: [0.3, 0.5, 0.7]
  beam_size:
    values: [3, 5]
EOF


In [None]:
%%bash
export WANDB_API_KEY="Your-API-Key"
wandb login
SWEEP_ID=$(wandb sweep sweep.yaml)
echo "Sweep ID = $SWEEP_ID"

Sweep ID = 


wandb: Currently logged in as: mrsagarbiswas (mrsagarbiswas-iit-madras) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Creating sweep from: sweep.yaml
wandb: Creating sweep with ID: 7hka6agi
wandb: View sweep at: https://wandb.ai/mrsagarbiswas-iit-madras/Attention_RNN/sweeps/7hka6agi
wandb: Run sweep agent with: wandb agent mrsagarbiswas-iit-madras/Attention_RNN/7hka6agi


In [None]:
!wandb agent mrsagarbiswas-iit-madras/Attention_RNN/7hka6agi --count 15

[34m[1mwandb[0m: Login to W&B to use the sweep agent feature
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmrsagarbiswas[0m ([33mmrsagarbiswas-iit-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Starting wandb agent 🕵️
2025-05-19 19:30:19,723 - wandb.wandb_agent - INFO - Running runs: []
2025-05-19 19:30:19,969 - wandb.wandb_agent - INFO - Agent received command: run
2025-05-19 19:30:19,969 - wandb.wandb_agent - INFO - Agent starting run with config:
	batch_s