In [41]:
%%bash
cat > main.py << 'EOF'
import argparse
import random
import time
import os
import csv
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import wandb

# At startup, *before* creating any models/loaders:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# ─── CharVocab & Dataset ─────────────────────────────────────────────────────
class CharVocab:
    def __init__(self, filepaths: List[str]):
        self.rom_char2idx: Dict[str,int] = {}
        self.dev_char2idx: Dict[str,int] = {}
        self.rom_idx2char: Dict[int,str] = {}
        self.dev_idx2char: Dict[int,str] = {}
        self._build_vocab(filepaths)

    def _build_vocab(self, filepaths: List[str]):
        rom_chars = set()
        dev_chars = set()
        for fp in filepaths:
            with open(fp, "r", encoding="utf-8") as f:
                reader = csv.reader(f, delimiter="\t")
                for row in reader:
                    if len(row) < 2:
                        continue
                    devanagari = row[0].strip()
                    roman      = row[1].strip()
                    rom_chars.update(list(roman))
                    dev_chars.update(list(devanagari))

        PAD, SOS, EOS = "<pad>", "<sos>", "<eos>"

        all_rom = [PAD, SOS, EOS] + sorted(rom_chars)
        for i, ch in enumerate(all_rom):
            self.rom_char2idx[ch] = i
            self.rom_idx2char[i] = ch

        all_dev = [PAD, SOS, EOS] + sorted(dev_chars)
        for i, ch in enumerate(all_dev):
            self.dev_char2idx[ch] = i
            self.dev_idx2char[i] = ch

        self.rom_pad_idx = self.rom_char2idx[PAD]
        self.rom_sos_idx = self.rom_char2idx[SOS]
        self.rom_eos_idx = self.rom_char2idx[EOS]

        self.dev_pad_idx = self.dev_char2idx[PAD]
        self.dev_sos_idx = self.dev_char2idx[SOS]
        self.dev_eos_idx = self.dev_char2idx[EOS]

    # ─── Add these two properties ────────────────────────────────────────────
    @property
    def rom_vocab_size(self) -> int:
        return len(self.rom_char2idx)

    @property
    def dev_vocab_size(self) -> int:
        return len(self.dev_char2idx)
    # ───────────────────────────────────────────────────────────────────────────

    def roman_to_indices(self, s: str) -> List[int]:
        return [self.rom_sos_idx] + [self.rom_char2idx[ch] for ch in s] + [self.rom_eos_idx]

    def dev_to_indices(self, s: str) -> List[int]:
        return [self.dev_sos_idx] + [self.dev_char2idx[ch] for ch in s] + [self.dev_eos_idx]

    def indices_to_dev(self, idxs: List[int]) -> str:
        chars = []
        for i in idxs:
            if i in (self.dev_sos_idx, self.dev_eos_idx, self.dev_pad_idx):
                continue
            chars.append(self.dev_idx2char[i])
        return "".join(chars)



def read_tsv(path: str) -> List[Tuple[str, str]]:
    """
    Expects each line of the TSV to be:
      Devanagari_word    Roman_word    <something_to_ignore>
    We only need (Roman, Devanagari) for training.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t")
        for row in reader:
            # If there are fewer than 2 columns, skip
            if len(row) < 2:
                continue

            # Unpack: Dev is first column, Roman is second, ignore anything else
            devana = row[0].strip()
            roman  = row[1].strip()

            if not roman or not devana:
                continue
            # Append (roman, devana)—this matches our CharVocab convention
            pairs.append((roman, devana))
    return pairs



class TransliterationDataset(Dataset):
    def __init__(self, filepath, vocab):
        super().__init__()
        self.pairs = read_tsv(filepath)
        self.vocab = vocab

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        roman, devanagari = self.pairs[idx]
        roman_idxs = self.vocab.roman_to_indices(roman)
        dev_idxs = self.vocab.dev_to_indices(devanagari)
        return torch.tensor(roman_idxs, dtype=torch.long), torch.tensor(dev_idxs, dtype=torch.long)

    @staticmethod
    def collate_fn(batch):
        roman_seqs, dev_seqs = zip(*batch)
        max_rom_len = max(len(x) for x in roman_seqs)
        max_dev_len = max(len(x) for x in dev_seqs)
        rom_padded = []
        dev_padded = []
        for r, d in zip(roman_seqs, dev_seqs):
            pad_r = torch.cat([r, r.new_full((max_rom_len - len(r),), r.new_tensor(0))])
            pad_d = torch.cat([d, d.new_full((max_dev_len - len(d),), d.new_tensor(0))])
            rom_padded.append(pad_r)
            dev_padded.append(pad_d)
        return torch.stack(rom_padded), torch.stack(dev_padded)


# ─── Encoder & Decoder ────────────────────────────────────────────────────────
class Encoder(nn.Module):
    def __init__(self, input_vocab_size, embed_dim, hidden_size, num_layers, cell_type, dropout):
        super().__init__()
        self.embed = nn.Embedding(input_vocab_size, embed_dim, padding_idx=0)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cell_type = cell_type.upper()
        if self.cell_type == "RNN":
            self.rnn = nn.RNN(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            )
        elif self.cell_type == "GRU":
            self.rnn = nn.GRU(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            )
        elif self.cell_type == "LSTM":
            self.rnn = nn.LSTM(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            )
        else:
            raise ValueError(f"Unsupported cell type: {cell_type}")

    def forward(self, x):
        emb = self.embed(x)
        if self.cell_type == "LSTM":
            outputs, (h_n, c_n) = self.rnn(emb)
            return outputs, (h_n, c_n)
        else:
            outputs, h_n = self.rnn(emb)
            return outputs, h_n


class Decoder(nn.Module):
    def __init__(self, output_vocab_size, embed_dim, hidden_size, num_layers, cell_type, dropout):
        super().__init__()
        self.embed = nn.Embedding(output_vocab_size, embed_dim, padding_idx=0)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cell_type = cell_type.upper()
        if self.cell_type == "RNN":
            self.rnn = nn.RNN(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0
            )
        elif self.cell_type == "GRU":
            self.rnn = nn.GRU(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0
            )
        elif self.cell_type == "LSTM":
            self.rnn = nn.LSTM(
                embed_dim, hidden_size, num_layers=num_layers,
                batch_first=True, dropout=dropout if num_layers > 1 else 0.0
            )
        else:
            raise ValueError(f"Unsupported cell type: {cell_type}")

        self.out = nn.Linear(hidden_size, output_vocab_size)

    def forward(self, tgt_seq, hidden, cell=None, teacher_forcing_ratio=0.0):
        B, T = tgt_seq.size()
        outputs = torch.zeros(B, T, self.out.out_features, device=tgt_seq.device)
        input_step = tgt_seq[:, 0]
        if self.cell_type == "LSTM":
            # hidden is h_n, and `cell` is c_n
            h = hidden       # h_n: (num_layers, B, hidden_size)
            c = cell         # c_n: (num_layers, B, hidden_size)
        else:
            # hidden is the single tensor from RNN/GRU; no cell
            h = hidden       # (num_layers, B, hidden_size)
            c = None

        for t in range(1, T):
            emb_t = self.embed(input_step).unsqueeze(1)
            if self.cell_type == "LSTM":
                out_step, (h, c) = self.rnn(emb_t, (h, c))
            else:
                out_step, h = self.rnn(emb_t, h)
            logits = self.out(out_step.squeeze(1))
            outputs[:, t, :] = logits

            teacher_force = (torch.rand(1).item() < teacher_forcing_ratio)
            top1 = logits.argmax(dim=1)
            next_input = tgt_seq[:, t] if teacher_force else top1
            input_step = next_input.view(-1)

        return outputs

    def beam_search_decode(self, encoder_hidden, encoder_cell, max_len, dev_sos_idx, dev_eos_idx, beam_size):
        hidden, cell = encoder_hidden, encoder_cell
        Hyp = lambda seq, h, c, scr: (seq, h, c, scr)
        live = [( [dev_sos_idx], hidden, cell, 0.0 )]
        completed = []

        for _ in range(max_len):
            new_hyps = []
            for (seq, h, c, score) in live:
                last_token = seq[-1]
                if last_token == dev_eos_idx:
                    completed.append((seq, h, c, score))
                    continue

                inp = torch.tensor([last_token], dtype=torch.long, device=h.device).unsqueeze(0)
                emb_t = self.embed(inp)

                if self.cell_type == "LSTM":
                    out_t, (h2, c2) = self.rnn(emb_t, (h, c))
                else:
                    out_t, h2 = self.rnn(emb_t, h)
                    c2 = None

                logits = self.out(out_t.squeeze(1))
                log_probs = F.log_softmax(logits, dim=1)

                topk_logprobs, topk_indices = torch.topk(log_probs, k=beam_size, dim=1)
                topk_logprobs = topk_logprobs.squeeze(0).tolist()
                topk_indices = topk_indices.squeeze(0).tolist()
                for lp, idx in zip(topk_logprobs, topk_indices):
                    new_hyps.append(( seq + [idx], h2, c2, score + lp ))

            new_hyps = sorted(new_hyps, key=lambda x: x[3], reverse=True)[:beam_size]
            live = new_hyps
            if all((hyp[0][-1] == dev_eos_idx) for hyp in live):
                completed.extend(live)
                break

        if not completed:
            completed = live
        best = max(completed, key=lambda x: x[3])
        return best[0]


class Seq2Seq(nn.Module):
    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        device: torch.device,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        teacher_forcing_ratio: float = 0.5,
    ) -> torch.Tensor:
        """
        src: (B, T_src)
        tgt: (B, T_tgt)   # including <sos> … <eos>
        Returns:
          logits: (B, T_tgt, V_out)
        """
        B, T_src = src.size()
        # Prepare a tensor to hold all decoder logits
        outputs = torch.zeros(
            B, tgt.size(1), self.decoder.out.out_features, device=self.device
        )

        # 1) Run the encoder
        if self.encoder.cell_type == "LSTM":
            enc_outputs, (h_n, c_n) = self.encoder(src)
        else:
            enc_outputs, h_n = self.encoder(src)
            c_n = None

        # 2) Transform encoder's hidden‐state to match decoder layers
        enc_layers = self.encoder.num_layers
        dec_layers = self.decoder.num_layers
        hidden_size = self.encoder.hidden_size

        if self.encoder.cell_type == "LSTM":
            # h_n, c_n each have shape (enc_layers, B, hidden_size)
            if enc_layers >= dec_layers:
                # Take the top-most `dec_layers` layers from encoder
                h_dec = h_n[-dec_layers:]              # shape: (dec_layers, B, H)
                c_dec = c_n[-dec_layers:]              # shape: (dec_layers, B, H)
            else:
                # enc_layers < dec_layers → prepend zeros to match dec_layers
                num_missing = dec_layers - enc_layers   # how many extra layers decoder wants
                zeros_h = torch.zeros(
                    num_missing, B, hidden_size, device=self.device
                )
                zeros_c = torch.zeros(
                    num_missing, B, hidden_size, device=self.device
                )
                # concatenate: (num_missing, B, H) + (enc_layers, B, H) → (dec_layers, B, H)
                h_dec = torch.cat([zeros_h, h_n], dim=0)
                c_dec = torch.cat([zeros_c, c_n], dim=0)
            dec_hidden = h_dec
            dec_cell = c_dec

        else:
            # RNN or GRU case: h_n has shape (enc_layers, B, hidden_size)
            if enc_layers >= dec_layers:
                h_dec = h_n[-dec_layers:]   # just take the top dec_layers
            else:
                num_missing = dec_layers - enc_layers
                zeros_h = torch.zeros(
                    num_missing, B, hidden_size, device=self.device
                )
                h_dec = torch.cat([zeros_h, h_n], dim=0)
            dec_hidden = h_dec
            dec_cell = None

        # 3) Run the decoder (training mode, with teacher_forcing_ratio)
        logits = self.decoder(
            tgt_seq=tgt,
            hidden=dec_hidden,
            cell=dec_cell,
            teacher_forcing_ratio=teacher_forcing_ratio,
        )

        return logits

    @torch.no_grad()
    def predict(
        self,
        src: torch.Tensor,
        max_len: int,
        dev_sos_idx: int,
        dev_eos_idx: int,
        beam_size: int = 1,
    ) -> List[List[int]]:
        """
        Greedy (beam_size=1) or beam search decoding for a batch of size=1. Returns list of decoded sequences.
        """
        self.eval()
        # Assume batch_size=1
        if self.encoder.cell_type == "LSTM":
            _, (h_n, c_n) = self.encoder(src)
        else:
            _, h_n = self.encoder(src)
            c_n = None

        # Now transform encoder’s (h_n, c_n) → (dec_hidden, dec_cell) exactly as above:
        enc_layers = self.encoder.num_layers
        dec_layers = self.decoder.num_layers
        hidden_size = self.encoder.hidden_size
        B = 1  # we only support batch=1 in predict()

        if self.encoder.cell_type == "LSTM":
            if enc_layers >= dec_layers:
                h_dec = h_n[-dec_layers:]
                c_dec = c_n[-dec_layers:]
            else:
                num_missing = dec_layers - enc_layers
                zeros_h = torch.zeros(num_missing, B, hidden_size, device=self.device)
                zeros_c = torch.zeros(num_missing, B, hidden_size, device=self.device)
                h_dec = torch.cat([zeros_h, h_n], dim=0)
                c_dec = torch.cat([zeros_c, c_n], dim=0)
            hidden_state = h_dec
            cell_state = c_dec
        else:
            if enc_layers >= dec_layers:
                h_dec = h_n[-dec_layers:]
            else:
                num_missing = dec_layers - enc_layers
                zeros_h = torch.zeros(num_missing, B, hidden_size, device=self.device)
                h_dec = torch.cat([zeros_h, h_n], dim=0)
            hidden_state = h_dec
            cell_state = None

        # 4) Now do greedy or beam search decode using (hidden_state, cell_state)
        if beam_size == 1:
            seq = [dev_sos_idx]
            hidden_ = hidden_state
            cell_ = cell_state
            for _ in range(max_len):
                last_token = torch.tensor([seq[-1]], dtype=torch.long, device=self.device).unsqueeze(0)
                emb = self.decoder.embed(last_token)  # (1,1,embed_dim)

                if self.decoder.cell_type == "LSTM":
                    out, (h_next, c_next) = self.decoder.rnn(emb, (hidden_, cell_))
                    hidden_, cell_ = h_next, c_next
                else:
                    out, h_next = self.decoder.rnn(emb, hidden_)
                    hidden_, cell_ = h_next, None

                logits = self.decoder.out(out.squeeze(1))  # (1, V_out)
                next_token = logits.argmax(dim=1).item()
                seq.append(next_token)
                if next_token == dev_eos_idx:
                    break
            return [seq]

        else:
            best_seq = self.decoder.beam_search_decode(
                encoder_hidden=hidden_state,
                encoder_cell=cell_state,
                max_len=max_len,
                dev_sos_idx=dev_sos_idx,
                dev_eos_idx=dev_eos_idx,
                beam_size=beam_size,
            )
            return [best_seq]



# ─── train / evaluate ────────────────────────────────────────────────────────
def char_accuracy(logits, target, pad_idx):
    with torch.no_grad():
        pred = logits.argmax(dim=2)
        mask = (target != pad_idx)
        correct = (pred == target) & mask
        total = mask.sum().item()
        return correct.sum().item() / max(total, 1)


def train_one_epoch(
    model: Seq2Seq,
    iterator: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.CrossEntropyLoss,
    pad_idx: int,
    device: torch.device,
    teacher_forcing_ratio: float,
) -> Tuple[float, float, float]:
    """
    Runs one epoch of training, returning (train_loss, train_char_acc, train_word_acc).
    """
    model.train()
    epoch_loss = 0.0
    epoch_char_acc = 0.0
    epoch_word_acc = 0.0
    total_batches = 0

    for src, tgt in iterator:
        src = src.to(device, non_blocking=True)  # (B, T_src)
        tgt = tgt.to(device, non_blocking=True)  # (B, T_tgt)
        B, T_tgt = tgt.size()

        optimizer.zero_grad()
        # Forward pass
        output_logits = model(src, tgt, teacher_forcing_ratio=teacher_forcing_ratio)
        # output_logits: (B, T_tgt, V_out)

        # 1) Compute token‐level (character) loss & accuracy
        V = output_logits.size(-1)
        loss = criterion(output_logits.view(-1, V), tgt.view(-1))

        # char‐level accuracy (ignore pad)
        with torch.no_grad():
            pred_inds = output_logits.argmax(dim=2)      # (B, T_tgt)
            char_mask = (tgt != pad_idx)                  # (B, T_tgt)
            char_correct = ((pred_inds == tgt) & char_mask).sum().item()
            char_total = char_mask.sum().item()
            batch_char_acc = char_correct / max(char_total, 1)

        # 2) Compute word‐level accuracy: count how many sequences match exactly (ignoring pad)
        with torch.no_grad():
            # For each example b, we want (pred_inds[b, t] == tgt[b, t]) for ALL t where tgt[b, t] != pad_idx.
            # We can OR with (~mask) on both sides:
            #   (pred_inds == tgt) | (~char_mask)  → True for any pad position
            # Then check .all(dim=1).
            match_or_pad = (pred_inds == tgt) | (~char_mask)    # (B, T_tgt) boolean
            exact_matches = match_or_pad.all(dim=1)             # (B,) boolean: True if all positions match or are pad
            batch_word_acc = exact_matches.sum().item() / B

        # Backward + step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        epoch_loss += loss.item()
        epoch_char_acc += batch_char_acc
        epoch_word_acc += batch_word_acc
        total_batches += 1

    return (
        epoch_loss / total_batches,
        epoch_char_acc / total_batches,
        epoch_word_acc / total_batches,
    )



def evaluate(
    model: Seq2Seq,
    iterator: DataLoader,
    criterion: nn.CrossEntropyLoss,
    pad_idx: int,
    device: torch.device,
    beam_size: int,
    max_dev_len: int,
) -> Tuple[float, float, float]:
    """
    One validation pass (no teacher forcing). Returns (val_loss, val_char_acc, val_word_acc).
    """
    model.eval()
    epoch_loss = 0.0
    epoch_char_acc = 0.0
    epoch_word_acc = 0.0
    total_batches = 0

    with torch.no_grad():
        for src, tgt in iterator:
            src = src.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)     # (B, T_tgt)
            B, T_tgt = tgt.size()

            # 1) Compute loss by feeding gold tgt through decoder with teacher_forcing=0
            logits = model(src, tgt, teacher_forcing_ratio=0.0)  # (B, T_tgt, V_out)
            V = logits.size(-1)
            loss = criterion(logits.view(-1, V), tgt.view(-1))

            # 2) Character‐level accuracy (token‐level)
            pred_inds = logits.argmax(dim=2)           # (B, T_tgt)
            char_mask = (tgt != pad_idx)               # (B, T_tgt)
            char_correct = ((pred_inds == tgt) & char_mask).sum().item()
            char_total = char_mask.sum().item()
            batch_char_acc = char_correct / max(char_total, 1)

            # 3) Word‐level accuracy on this batch using simple greedy decode
            # (Note: you could also use beam_search here if you want word-acc under beam search.
            #  For simplicity, we’ll just use the greedy `pred_inds` we already have.)
            match_or_pad = (pred_inds == tgt) | (~char_mask)  # True if match or tgt is pad
            exact_matches = match_or_pad.all(dim=1)           # (B,) boolean
            batch_word_acc = exact_matches.sum().item() / B

            epoch_loss += loss.item()
            epoch_char_acc += batch_char_acc
            epoch_word_acc += batch_word_acc
            total_batches += 1

    return (
        epoch_loss / total_batches,
        epoch_char_acc / total_batches,
        epoch_word_acc / total_batches,
    )



# ─── main() with parse_known_args ─────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="/content")
    parser.add_argument("--train_file", type=str, default="bn.translit.sampled.train.tsv")
    parser.add_argument("--dev_file", type=str, default="bn.translit.sampled.dev.tsv")
    parser.add_argument("--test_file", type=str, default="bn.translit.sampled.test.tsv")
    parser.add_argument("--emb_size", type=int, default=64)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--enc_layers", type=int, default=1)
    parser.add_argument("--dec_layers", type=int, default=1)
    parser.add_argument("--cell_type", type=str, default="RNN", choices=["RNN", "GRU", "LSTM"])
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--tf_ratio", type=float, default=0.5)
    parser.add_argument("--beam_size", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_dev_len", type=int, default=32)
    parser.add_argument("--project_name", type=str, default="Vanilla_RNN")
    parser.add_argument("--run_name", type=str, default=None)

    # ── use parse_known_args to ignore Colab’s extra "-f …json"
    args, _ = parser.parse_known_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_path = os.path.join(args.data_dir, args.train_file)
    dev_path   = os.path.join(args.data_dir, args.dev_file)
    test_path  = os.path.join(args.data_dir, args.test_file)
    vocab = CharVocab([train_path, dev_path, test_path])

    train_ds = TransliterationDataset(train_path, vocab)
    dev_ds   = TransliterationDataset(dev_path, vocab)
    test_ds  = TransliterationDataset(test_path, vocab)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=TransliterationDataset.collate_fn, num_workers=2, pin_memory=True)
    dev_loader   = DataLoader(dev_ds,   batch_size=args.batch_size, shuffle=False, collate_fn=TransliterationDataset.collate_fn, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=1, shuffle=False, collate_fn=TransliterationDataset.collate_fn, num_workers=2, pin_memory=True)

    wandb.init(
        project=args.project_name,
        name=args.run_name,
        config={
            "emb_size": args.emb_size,
            "hidden_size": args.hidden_size,
            "enc_layers": args.enc_layers,
            "dec_layers": args.dec_layers,
            "cell_type": args.cell_type,
            "dropout": args.dropout,
            "lr": args.lr,
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "tf_ratio": args.tf_ratio,
            "beam_size": args.beam_size,
        },
    )
    config = wandb.config

    encoder = Encoder(
        input_vocab_size=vocab.rom_vocab_size,
        embed_dim=config.emb_size,
        hidden_size=config.hidden_size,
        num_layers=config.enc_layers,
        cell_type=config.cell_type,
        dropout=config.dropout,
    )
    decoder = Decoder(
        output_vocab_size=vocab.dev_vocab_size,
        embed_dim=config.emb_size,
        hidden_size=config.hidden_size,
        num_layers=config.dec_layers,
        cell_type=config.cell_type,
        dropout=config.dropout,
    )
    model = Seq2Seq(encoder, decoder, device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.dev_pad_idx)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    best_val_acc = 0.0
    for epoch in range(1, config.epochs + 1):
        start_time = time.time()

        # 1) Unpack three values from train_one_epoch:
        train_loss, train_char_acc, train_word_acc = train_one_epoch(
            model=model,
            iterator=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            pad_idx=vocab.dev_pad_idx,
            device=device,
            teacher_forcing_ratio=config.tf_ratio,
        )

        # 2) Unpack three values from evaluate:
        val_loss, val_char_acc, val_word_acc = evaluate(
            model=model,
            iterator=dev_loader,
            criterion=criterion,
            pad_idx=vocab.dev_pad_idx,
            device=device,
            beam_size=config.beam_size,
            max_dev_len=args.max_dev_len,
        )

        elapsed = time.time() - start_time

        # 3) Decide “best” model based on whichever metric you prefer (e.g. val_char_acc)
        if val_char_acc > best_val_acc:
            best_val_acc = val_char_acc
            torch.save(model.state_dict(), "best_model.pt")

        # 4) Log all four accuracy metrics to W&B
        wandb.log(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_char_acc": train_char_acc,
                "train_word_acc": train_word_acc,   # newly added
                "val_loss": val_loss,
                "val_char_acc": val_char_acc,
                "val_word_acc": val_word_acc,       # newly added
                "epoch_time_sec": elapsed,
            }
        )

        # 5) Print so you see them in Colab output
        print(
            f"Epoch {epoch:02d} | "
            f"Train Loss: {train_loss:.4f}, "
            f"Train Char Acc: {train_char_acc:.4f}, "
            f"Train Word Acc: {train_word_acc:.4f} | "
            f"Val Loss: {val_loss:.4f}, "
            f"Val Char Acc: {val_char_acc:.4f}, "
            f"Val Word Acc: {val_word_acc:.4f} | "
            f"Time: {elapsed:.1f}s"
        )

    wandb.finish()


if __name__ == "__main__":
    main()
EOF

In [42]:
%%bash
cat > sweep.yaml << 'EOF'
program: main.py
method: random
project: Vanilla_RNN
entity: mrsagarbiswas-iit-madras
metric:
  name: val_char_acc
  goal: maximize
parameters:
  emb_size:
    values: [16, 32, 64, 256]
  hidden_size:
    values: [16, 32, 64, 256]
  enc_layers:
    values: [1, 2, 3]
  epochs:
    values: [5, 10, 15]
  dec_layers:
    values: [1, 2, 3]
  cell_type:
    values: ["RNN", "GRU", "LSTM"]
  dropout:
    values: [0.2, 0.3]
  lr:
    values: [1e-3, 1e-4]
  batch_size:
    values: [32, 64, 128]
  tf_ratio:
    values: [0.3, 0.5, 0.7]
  beam_size:
    values: [1, 3]
EOF


In [43]:
%%bash
wandb login

wandb: Currently logged in as: mrsagarbiswas (mrsagarbiswas-iit-madras) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [44]:
%%bash
SWEEP_ID=$(wandb sweep sweep.yaml)
echo "Sweep ID = $SWEEP_ID"

Sweep ID = 


wandb: Creating sweep from: sweep.yaml
wandb: Creating sweep with ID: wpzjutzc
wandb: View sweep at: https://wandb.ai/mrsagarbiswas-iit-madras/Vanilla_RNN/sweeps/wpzjutzc
wandb: Run sweep agent with: wandb agent mrsagarbiswas-iit-madras/Vanilla_RNN/wpzjutzc


In [None]:
!wandb agent mrsagarbiswas-iit-madras/Vanilla_RNN/wpzjutzc --count 30

[34m[1mwandb[0m: Starting wandb agent 🕵️
2025-05-18 20:01:04,306 - wandb.wandb_agent - INFO - Running runs: []
2025-05-18 20:01:04,648 - wandb.wandb_agent - INFO - Agent received command: run
2025-05-18 20:01:04,648 - wandb.wandb_agent - INFO - Agent starting run with config:
	batch_size: 128
	beam_size: 3
	cell_type: RNN
	dec_layers: 2
	dropout: 0.2
	emb_size: 16
	enc_layers: 2
	epochs: 15
	hidden_size: 32
	lr: 0.0001
	tf_ratio: 0.7
2025-05-18 20:01:04,650 - wandb.wandb_agent - INFO - About to run command: /usr/bin/env python main.py --batch_size=128 --beam_size=3 --cell_type=RNN --dec_layers=2 --dropout=0.2 --emb_size=16 --enc_layers=2 --epochs=15 --hidden_size=32 --lr=0.0001 --tf_ratio=0.7
[34m[1mwandb[0m: Currently logged in as: [33mmrsagarbiswas[0m ([33mmrsagarbiswas-iit-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
2025-05-18 20:01:09,661 - wandb.wandb_agent - INFO - Running runs: ['dise9ble']
[34m[1mwandb[0m: Track