# Q5 Solution

In [1]:
from __future__ import annotations

import argparse
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import warnings
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


# ──────────────────── 2.5. Embedding-method modules ────────────────────
class OneHotEmbedding(nn.Module):
    """Embed token ids by linearly projecting their explicit one-hot vectors.

    Every id is expanded into a one-hot row of width ``vocab_size``; rows that
    belong to the padding id are blanked to all zeros, and a bias-free linear
    layer then maps each row to ``embedding_dim`` features.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, padding_idx: int):
        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        # Bias-free map: one-hot row of length vocab_size -> embedding_dim.
        self.projection = nn.Linear(vocab_size, embedding_dim, bias=False)

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        """Return the projected one-hot rows, e.g. (B, T) ids -> (B, T, D)."""
        is_pad = token_ids.eq(self.padding_idx).unsqueeze(-1)            # (B, T, 1)
        indicators = F.one_hot(token_ids, num_classes=self.vocab_size)   # (B, T, V)
        # Padding positions contribute an all-zero row, hence a zero embedding.
        indicators = indicators.float().masked_fill(is_pad, 0.0)
        return self.projection(indicators)
    
class CharCNNEmbedding(nn.Module):
    """Embed token ids with a 1-D convolution over their one-hot encoding.

    Pipeline: ids (B, T) -> one-hot (B, T, V) -> Conv1d along the time axis
    -> optional linear projection -> (B, T, embedding_dim).

    Args:
        vocab_size: Number of distinct token ids (conv input channels).
        embedding_dim: Size of the returned feature vectors.
        padding_idx: Token id treated as padding; its one-hot rows are zeroed
            before the convolution.
        kernel_size: Width of the convolution window.
        num_filters: Number of conv output channels; defaults to
            ``embedding_dim``. When it differs, a bias-free linear layer maps
            the filters to ``embedding_dim``.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        kernel_size: int = 3,
        num_filters: Optional[int] = None
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        self.num_filters = num_filters or embedding_dim
        # Convolution: in_channels=vocab_size, out_channels=num_filters
        self.conv1d = nn.Conv1d(
            in_channels=vocab_size,
            out_channels=self.num_filters,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False
        )
        # Only project when the filter count differs from embedding_dim.
        if self.num_filters != embedding_dim:
            self.projection = nn.Linear(self.num_filters, embedding_dim, bias=False)
        else:
            self.projection = None

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        """Return (B, T, embedding_dim) features for ``token_ids`` of shape (B, T)."""
        seq_len = token_ids.size(1)
        one_hot = F.one_hot(token_ids, num_classes=self.vocab_size).float()  # (B,T,V)
        # Blank the padding rows so they behave exactly like the convolution's
        # own zero padding. Without this, the features of the last real token
        # change with how much batch padding happens to follow it.
        pad_rows = (token_ids == self.padding_idx).unsqueeze(-1)
        one_hot = one_hot.masked_fill(pad_rows, 0.0)

        x = self.conv1d(one_hot.permute(0, 2, 1))  # (B, F, T) or (B, F, T+1)
        # An even kernel_size with padding=kernel_size // 2 produces T+1 steps;
        # trim so the time axis always lines up with the input.
        x = x[:, :, :seq_len]
        x = x.permute(0, 2, 1)                     # (B, T, F)
        if self.projection is not None:
            x = self.projection(x)                 # (B, T, D)
        return x
    
class SVDPPMIEmbedding(nn.Module):
    """Static character embeddings from a PPMI matrix factorised with SVD.

    Construction counts symmetric-window co-occurrences over ``token_seqs``
    (padding ids are ignored), converts the counts to positive PMI, and keeps
    the leading singular directions scaled by the square root of their
    singular values. The resulting table is stored as a non-trainable buffer.
    If fewer than ``embedding_dim`` directions are available, a learnable
    bias-free linear layer expands the vectors to ``embedding_dim``.

    Args:
        token_seqs: Encoded sequences used to collect co-occurrence counts.
        vocab_size: Number of distinct token ids (rows of the table).
        embedding_dim: Requested size of the returned vectors.
        padding_idx: Token id to skip while counting; its row is zeroed.
        window: Number of neighbours considered on each side of a token.
    """

    def __init__(
        self,
        token_seqs: List[List[int]],
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        window: int = 2
    ):
        super().__init__()
        self.padding_idx = padding_idx

        # 1) Build co-occurrence counts
        cooc = np.zeros((vocab_size, vocab_size), dtype=np.float64)
        for seq in token_seqs:
            for i, u in enumerate(seq):
                if u == padding_idx:
                    continue
                lo, hi = max(0, i - window), min(len(seq), i + window + 1)
                for j in range(lo, hi):
                    if i == j:
                        continue
                    v = seq[j]
                    if v == padding_idx:
                        continue
                    cooc[u, v] += 1

        # 2) Compute PPMI matrix
        total = cooc.sum()
        row_sums = cooc.sum(axis=1, keepdims=True)
        col_sums = cooc.sum(axis=0, keepdims=True)
        with np.errstate(divide="ignore", invalid="ignore"):
            pmi = np.log((cooc * total) / (row_sums * col_sums))
        pmi[np.isnan(pmi)] = 0.0   # 0/0 for ids that never occur
        pmi[pmi < 0] = 0.0         # clips negative PMI, including log(0) = -inf

        # 3) Truncated SVD
        U, S, _ = np.linalg.svd(pmi, full_matrices=False)
        # D0 is the actual SVD dimension we get (≤ vocab_size)
        D0 = min(embedding_dim, U.shape[1])
        emb_matrix = U[:, :D0] * np.sqrt(S[:D0])  # (vocab_size, D0)

        # Make the padding vector exactly zero instead of relying on the SVD
        # leaving only round-off noise in that row.
        if 0 <= padding_idx < vocab_size:
            emb_matrix[padding_idx] = 0.0

        # Register static SVD weights
        self.register_buffer("weight", torch.from_numpy(emb_matrix).float())

        # 4) If the SVD rank D0 is smaller than requested embedding_dim, add a projection
        if D0 < embedding_dim:
            # Warn the user clearly
            warnings.warn(
                f"SVD/PPMI yielded only {D0} dimensions (≤ vocab size), "
                f"but embedding_size={embedding_dim} was requested. "
                "Adding a learnable linear projection to expand from "
                f"{D0} → {embedding_dim} dimensions.",
                UserWarning
            )
            self.expander = nn.Linear(D0, embedding_dim, bias=False)
        else:
            self.expander = None

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        """Look up the static vectors and, if configured, expand them."""
        # The table is a frozen buffer, so ``padding_idx`` (which only affects
        # gradients in F.embedding) is not needed here; the padding row was
        # zeroed at construction time.
        x = F.embedding(token_ids, self.weight)  # (B, T, D0)
        if self.expander is not None:
            x = self.expander(x)
        return x  # (B, T, embedding_dim)
    

# Map a cell-type name to the matching torch recurrent layer class.
# The values are the classes themselves (constructors), not module instances.
_RNN_MAP: Dict[str, type] = {
    "RNN": nn.RNN,
    "LSTM": nn.LSTM,
    "GRU": nn.GRU,
}

# ─────────────────────── 1. Vocabulary helpers ───────────────────────
# Reserved token ids; regular characters are numbered after these.
SPECIAL_TOKENS = {"<pad>": 0, "<sos>": 1, "<eos>": 2}


class CharVocabulary:
    """Character-level vocabulary that handles <pad>, <sos>, and <eos>.

    The special tokens keep the ids given in ``SPECIAL_TOKENS``; the distinct
    characters are sorted and numbered consecutively after them.
    """

    def __init__(self, characters: List[str]):
        unique_chars = sorted(set(characters))
        offset = len(SPECIAL_TOKENS)
        # string→index and index→string maps
        self.stoi: Dict[str, int] = {
            **SPECIAL_TOKENS,
            **{ch: idx + offset for idx, ch in enumerate(unique_chars)}
        }
        self.itos: Dict[int, str] = {idx: ch for ch, idx in self.stoi.items()}

    def encode(self, text: str, *, add_sos: bool = False, add_eos: bool = True) -> List[int]:
        """Convert a string to a list of token ids (with optional <sos>/<eos>).

        Raises:
            KeyError: If ``text`` contains a character that is not part of
                this vocabulary.
        """
        try:
            ids = [self.stoi[ch] for ch in text]
        except KeyError as err:
            # Same exception type as a plain dict lookup, but say what failed.
            raise KeyError(
                f"character {err.args[0]!r} in {text!r} is not in the vocabulary"
            ) from None
        if add_eos:
            ids.append(self.stoi["<eos>"])
        if add_sos:
            ids.insert(0, self.stoi["<sos>"])
        return ids

    def decode(self, ids: List[int]) -> str:
        """Convert ids back to a string, stopping at the first <eos>.

        <sos> and <pad> ids are skipped so their marker text never leaks into
        the decoded string; ids unknown to the vocabulary are dropped.
        """
        eos_id = self.stoi["<eos>"]
        silent_ids = (self.stoi["<sos>"], self.stoi["<pad>"])
        chars: List[str] = []
        for idx in ids:
            if idx == eos_id:
                break
            if idx in silent_ids:
                continue
            chars.append(self.itos.get(idx, ""))
        return "".join(chars)

    @property
    def size(self) -> int:
        """Total number of tokens in vocabulary."""
        return len(self.stoi)

# ─────────────────────────── Helper: Align Hidden State ──────────────────────────
def _align_hidden_state(hidden_state, target_num_layers: int):
    """
    Adjust the encoder's final hidden state to match the decoder's expected
    number of layers. Works for both LSTM (tuple of (h,c)) and GRU/RNN (single tensor).
    
    Strategies:
    - If encoder_layers == decoder_layers: return hidden_state unchanged.
    - If encoder_layers  > decoder_layers: take the **last** `target_num_layers` layers.
    - If encoder_layers  < decoder_layers: **repeat** the final layer's state
      so that the total number of layers equals `target_num_layers`.
    """
    def _repeat_last_layer(tensor, repeat_count: int):
        # tensor shape: (enc_layers, batch_size, hidden_dim)
        last_layer = tensor[-1:]                             # shape: (1, B, H)
        repeated   = last_layer.expand(repeat_count, -1, -1) # shape: (repeat_count, B, H)
        return torch.cat([tensor, repeated], dim=0)          # new shape: (enc_layers+repeat_count, B, H)

    if isinstance(hidden_state, tuple):
        h, c = hidden_state
        enc_layers, batch_size, hid_dim = h.shape

        if enc_layers == target_num_layers:
            return h, c
        
        # Warn whenever we need to truncate or repeat
        warnings.warn(
            f"Encoder has {enc_layers} layers but decoder expects {target_num_layers}. "
            f"{'Truncating' if enc_layers>target_num_layers else 'Repeating last layer'} hidden state.",
            UserWarning
        )

        if enc_layers > target_num_layers:
            # keep only the last `target_num_layers` layers
            return h[-target_num_layers:], c[-target_num_layers:]
        else:
            # repeat the final layer's state to pad up to target_num_layers
            to_repeat = target_num_layers - enc_layers
            return _repeat_last_layer(h, to_repeat), _repeat_last_layer(c, to_repeat)
    else:
        h = hidden_state
        enc_layers, batch_size, hid_dim = h.shape

        if enc_layers == target_num_layers:
            return h
        
        # Same warning for single‐tensor hidden states
        warnings.warn(
            f"Encoder has {enc_layers} layers but decoder expects {target_num_layers}. "
            f"{'Truncating' if enc_layers>target_num_layers else 'Repeating last layer'} hidden state.",
            UserWarning
        )

        if enc_layers > target_num_layers:
            return h[-target_num_layers:]
        else:
            to_repeat = target_num_layers - enc_layers
            return _repeat_last_layer(h, to_repeat)
        

# ───────────────────────── 2. Dataset class ─────────────────────────
class DakshinaLexicon(Dataset):
    """Loads a Dakshina *lexicon* TSV and encodes (source, target) pairs.

    TSV columns: native_word, romanized_word, count
    We treat romanized_word as source and native_word as target.

    Args:
        tsv_path: Path of the tab-separated lexicon file (no header row).
        source_vocab: Existing source vocabulary; must be ``None`` when
            ``build_vocabs`` is true.
        target_vocab: Existing target vocabulary; must be ``None`` when
            ``build_vocabs`` is true.
        build_vocabs: Build both vocabularies from the characters in this file.
        use_attestations: Keep the per-row ``count`` column in
            ``example_counts`` (otherwise that attribute is ``None``).

    Each item is ``(source_ids, target_ids)``; sources end with <eos>, targets
    are wrapped as <sos> ... <eos>.
    """
    def __init__(
        self,
        tsv_path: str | Path,
        source_vocab: Optional[CharVocabulary] = None,
        target_vocab: Optional[CharVocabulary] = None,
        *,
        build_vocabs: bool = False,
        use_attestations: bool = False
    ):
        # Read TSV – three columns, ensure correct dtypes.
        # pandas' default NA sentinels include ordinary romanized words such as
        # "nan" and "null"; with the defaults those rows would become NaN and
        # be silently removed by dropna(). Only a truly empty field is missing.
        dataframe = pd.read_csv(
            tsv_path, sep="\t", header=None,
            names=["target_native", "source_roman", "count"],
            dtype={"target_native": str, "source_roman": str, "count": int},
            keep_default_na=False,
            na_values=[""]
        ).dropna()

        # Optionally keep annotator counts for sampling or weighting
        self.example_counts: Optional[List[int]] = (
            dataframe["count"].tolist() if use_attestations else None
        )

        # Keep only the (src, tgt) pairs
        self.word_pairs: List[Tuple[str, str]] = list(zip(
            dataframe["source_roman"], dataframe["target_native"]
        ))

        # Build new or reuse provided vocabularies
        if build_vocabs:
            assert source_vocab is None and target_vocab is None, (
                "Cannot pass existing vocabs when build_vocabs=True"
            )
            # collect all chars
            all_src_chars = [ch for src, _ in self.word_pairs for ch in src]
            all_tgt_chars = [ch for _, tgt in self.word_pairs for ch in tgt]
            source_vocab = CharVocabulary(all_src_chars)
            target_vocab = CharVocabulary(all_tgt_chars)

        assert source_vocab is not None and target_vocab is not None, (
            "Must provide or build vocabularies"
        )
        self.src_vocab, self.tgt_vocab = source_vocab, target_vocab

        # Encode all pairs once for efficiency
        self.encoded_pairs: List[Tuple[List[int], List[int]]] = [
            (self.src_vocab.encode(src),
             self.tgt_vocab.encode(tgt, add_sos=True))
            for src, tgt in self.word_pairs
        ]
        # Also keep just the source sequences for SVD/PPMI embedding
        self.encoded_sources: List[List[int]] = [src for src, _ in self.encoded_pairs]

        # Padding token id (taken from the source vocabulary)
        self.pad_id: int = self.src_vocab.stoi["<pad>"]

    def __len__(self) -> int:
        """Number of (source, target) pairs."""
        return len(self.encoded_pairs)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        """Return the pre-encoded ``(source_ids, target_ids)`` pair at ``index``."""
        return self.encoded_pairs[index]

def collate_batch(
    batch: List[Tuple[List[int], List[int]]],
    pad_id: int
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
    """Right-pad a batch of (source, target) id lists into rectangular tensors.

    Returns ``(padded_sources, source_lengths, padded_targets)``; both padded
    tensors are filled with ``pad_id`` beyond each sequence's own length.
    """
    sources, targets = zip(*batch)

    def _right_pad(rows):
        # Extend every row with pad_id up to the longest row in the batch.
        width = max(len(row) for row in rows)
        filled = [list(row) + [pad_id] * (width - len(row)) for row in rows]
        return torch.tensor(filled, dtype=torch.long)

    source_lengths = torch.tensor([len(row) for row in sources], dtype=torch.long)
    return _right_pad(sources), source_lengths, _right_pad(targets)

# ───────────────────────── 7. Training & evaluation ─────────────────────────
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.CrossEntropyLoss,
    device: str,
    teacher_forcing: float
) -> float:
    """Run one epoch of training; return average token-level cross-entropy loss.

    Args:
        model: Module called as
            ``model(src, src_len, tgt, teacher_forcing_ratio=...)`` and
            returning logits aligned with ``tgt`` along dim 1.
        loader: Iterable of ``(src, src_len, tgt)`` tensor batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Cross-entropy criterion; its ``ignore_index`` marks padding.
            The weighting below assumes the default mean reduction.
        device: Device the batch tensors are moved to.
        teacher_forcing: Teacher-forcing ratio forwarded to the model.

    Returns:
        Loss averaged over all non-ignored target tokens, or ``nan`` when the
        loader yields no such tokens.
    """
    model.train()
    total_loss, total_tokens = 0.0, 0
    for src, src_len, tgt in loader:
        src, src_len, tgt = src.to(device), src_len.to(device), tgt.to(device)
        # Position 0 of tgt is the start token, so the gold labels begin at 1.
        gold = tgt[:, 1:].contiguous()
        n_valid = (gold != loss_fn.ignore_index).sum().item()
        if n_valid == 0:
            # A mean over zero targets is NaN; back-propagating it would
            # poison the weights, so skip batches with nothing to predict.
            continue

        optimizer.zero_grad()
        logits = model(src, src_len, tgt, teacher_forcing_ratio=teacher_forcing)
        preds = logits[:, 1:].contiguous()
        loss = loss_fn(preds.view(-1, preds.size(-1)), gold.view(-1))
        loss.backward()
        optimizer.step()

        # Re-weight the per-batch mean by its token count for a global mean.
        total_loss += loss.item() * n_valid
        total_tokens += n_valid

    if total_tokens == 0:
        return float("nan")
    return total_loss / total_tokens


def eval_epoch(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.CrossEntropyLoss,
    device: str
) -> float:
    """Run one epoch of evaluation; return average token-level cross-entropy loss.

    The model is put in eval mode and called without teacher forcing
    (``teacher_forcing_ratio=0.0``) under ``torch.no_grad()``.

    Args:
        model: Module called as
            ``model(src, src_len, tgt, teacher_forcing_ratio=0.0)`` and
            returning logits aligned with ``tgt`` along dim 1.
        loader: Iterable of ``(src, src_len, tgt)`` tensor batches.
        loss_fn: Cross-entropy criterion; its ``ignore_index`` marks padding.
            The weighting below assumes the default mean reduction.
        device: Device the batch tensors are moved to.

    Returns:
        Loss averaged over all non-ignored target tokens, or ``nan`` when the
        loader yields no such tokens.
    """
    model.eval()
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for src, src_len, tgt in loader:
            src, src_len, tgt = src.to(device), src_len.to(device), tgt.to(device)
            # Position 0 of tgt is the start token, so gold labels begin at 1.
            gold = tgt[:, 1:].contiguous()
            n_valid = (gold != loss_fn.ignore_index).sum().item()
            if n_valid == 0:
                # The mean over zero targets is NaN and NaN * 0 stays NaN,
                # which would corrupt the running total.
                continue

            logits = model(src, src_len, tgt, teacher_forcing_ratio=0.0)
            preds = logits[:, 1:].contiguous()
            loss = loss_fn(preds.view(-1, preds.size(-1)), gold.view(-1))
            total_loss += loss.item() * n_valid
            total_tokens += n_valid

    if total_tokens == 0:
        return float("nan")
    return total_loss / total_tokens

In [2]:
# solution_5_model.py
# -*- coding: utf-8 -*-
"""
Attention‐augmented Seq2Seq for Dakshina Hindi transliteration (Q5).

This module defines:
  - Seq2SeqAttentionConfig: hyper‐parameters & model options
  - DotProductAttention: simple global dot‐product attention
  - EncoderWithOutputs: returns per‐time‐step features + final hidden
  - DecoderWithAttention: attends + decodes one token at a time
  - Seq2SeqAttention: end‐to‐end training / inference wrapper
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple, Any

import torch
import torch.nn as nn
import torch.nn.functional as F



# ───────────────────────── 1. Configuration ─────────────────────────
@dataclass
class Seq2SeqAttentionConfig:
    """Hyper-parameters and options for the attention Seq2Seq model.

    Values are validated in ``__post_init__``; ``attention_dim`` falls back to
    ``hidden_dim`` when it is left as ``None``.
    """

    # Vocabulary sizes (required)
    source_vocab_size: int
    target_vocab_size: int

    # Layer widths
    embedding_dim: int = 256
    hidden_dim:    int = 512

    # Depth of the recurrent stacks
    encoder_layers: int = 1
    decoder_layers: int = 1

    cell_type: str = "LSTM"       # "RNN" | "GRU" | "LSTM"
    dropout:   float = 0.1

    # Reserved token ids
    pad_index: int = 0
    sos_index: int = 1
    eos_index: int = 2

    embedding_method: str = "learned"  # "learned" | "onehot" | "char_cnn" | "svd_ppmi"
    # Encoded sequences for the "svd_ppmi" method (mandatory for that method)
    svd_sources: Optional[List[List[int]]] = None

    attention_dim: Optional[int] = None

    def __post_init__(self):
        supported_cells = {"RNN", "GRU", "LSTM"}
        supported_embeddings = {"learned", "onehot", "char_cnn", "svd_ppmi"}

        assert self.cell_type in supported_cells, "cell_type must be RNN, GRU or LSTM"
        assert self.embedding_method in supported_embeddings
        # The SVD/PPMI table cannot be built without training sequences.
        assert self.embedding_method != "svd_ppmi" or self.svd_sources is not None, \
            "svd_ppmi requires svd_sources"

        if self.attention_dim is None:
            self.attention_dim = self.hidden_dim


# ───────────────────────── 2. Encoder ───────────────────────────────
class EncoderWithOutputs(nn.Module):
    """
    Encoder that returns both:
      - outputs      : (B, T_src, hidden_dim), one feature vector per source
                       position, zero-filled at padded positions
      - hidden_state : final hidden state(s) for decoder init

    The embedding layer is selected by ``cfg.embedding_method`` and the
    recurrent stack by ``cfg.cell_type``.
    """
    def __init__(self, cfg: Seq2SeqAttentionConfig):
        super().__init__()
        self.cfg = cfg

        # character embedding
        if cfg.embedding_method == "learned":
            self.embedding = nn.Embedding(cfg.source_vocab_size,
                                          cfg.embedding_dim,
                                          padding_idx=cfg.pad_index)
        elif cfg.embedding_method == "onehot":
            self.embedding = OneHotEmbedding(
                vocab_size=cfg.source_vocab_size,
                embedding_dim=cfg.embedding_dim,
                padding_idx=cfg.pad_index
            )
        elif cfg.embedding_method == "char_cnn":
            self.embedding = CharCNNEmbedding(
                vocab_size=cfg.source_vocab_size,
                embedding_dim=cfg.embedding_dim,
                padding_idx=cfg.pad_index
            )
        else:  # "svd_ppmi"
            self.embedding = SVDPPMIEmbedding(
                token_seqs=cfg.svd_sources,
                vocab_size=cfg.source_vocab_size,
                embedding_dim=cfg.embedding_dim,
                padding_idx=cfg.pad_index
            )

        # RNN stack (inter-layer dropout only exists with more than one layer)
        RNNClass = _RNN_MAP[cfg.cell_type]
        self.rnn = RNNClass(
            input_size=cfg.embedding_dim,
            hidden_size=cfg.hidden_dim,
            num_layers=cfg.encoder_layers,
            batch_first=True,
            dropout=cfg.dropout if cfg.encoder_layers > 1 else 0.0
        )

    def forward(
        self,
        src: torch.LongTensor,
        src_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, Any]:
        """Encode ``src`` (B, T_src) given the true lengths ``src_lengths`` (B,)."""
        # embed: (B, T_src, D_emb)
        embedded = self.embedding(src)

        # pack so the RNN never runs over padding
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_outputs, hidden_state = self.rnn(packed)

        # Unpack to (B, T_src, hidden_dim). total_length pins the time axis to
        # src's width, so the outputs stay aligned with a mask derived from src
        # even when src is padded beyond its longest sequence. Padded
        # positions hold features, not token ids, hence the 0.0 fill.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs,
            batch_first=True,
            padding_value=0.0,
            total_length=src.size(1)
        )
        return outputs, hidden_state


# ───────────────────────── 3. Attention ─────────────────────────────
class DotProductAttention(nn.Module):
    """
    Parameter-free global attention that scores each encoder position by its
    dot product with the decoder state:
      score_{t,i} = h_dec(t) · h_enc(i)
    """
    def __init__(self):
        super().__init__()

    def forward(
        self,
        decoder_hidden: torch.Tensor,   # (B, hidden_dim)
        encoder_outputs: torch.Tensor,  # (B, T_src, hidden_dim)
        mask: torch.Tensor              # (B, T_src)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(context, alignments)`` of shapes (B, hidden_dim) and (B, T_src)."""
        query = decoder_hidden.unsqueeze(2)                     # (B, hidden_dim, 1)
        raw_scores = encoder_outputs.bmm(query).squeeze(2)      # (B, T_src)

        # Positions whose mask entry is 0 get -inf, i.e. zero weight after softmax.
        blocked = mask == 0
        alignments = torch.softmax(
            raw_scores.masked_fill(blocked, float("-inf")), dim=1
        )                                                       # (B, T_src)

        # Attention-weighted sum of the encoder features.
        weights_row = alignments.unsqueeze(1)                   # (B, 1, T_src)
        context = weights_row.bmm(encoder_outputs).squeeze(1)   # (B, hidden_dim)
        return context, alignments


# ───────────────────────── 4. Decoder ───────────────────────────────
class DecoderWithAttention(nn.Module):
    """
    Single-step attentional decoder. Each call:
      1) embeds the input token
      2) attends over encoder_outputs, using the top layer of the previous
         hidden state as the query
      3) runs the RNN on [embedding; context]
      4) maps the RNN output to target-vocabulary logits
    """
    def __init__(self, cfg: Seq2SeqAttentionConfig):
        super().__init__()
        self.cfg = cfg

        self.embedding = nn.Embedding(
            cfg.target_vocab_size, cfg.embedding_dim, padding_idx=cfg.pad_index
        )
        self.attention = DotProductAttention()

        # The RNN consumes the token embedding concatenated with the context.
        recurrent_cls = _RNN_MAP[cfg.cell_type]
        inter_layer_dropout = cfg.dropout if cfg.decoder_layers > 1 else 0.0
        self.rnn = recurrent_cls(
            input_size=cfg.embedding_dim + cfg.hidden_dim,
            hidden_size=cfg.hidden_dim,
            num_layers=cfg.decoder_layers,
            batch_first=True,
            dropout=inter_layer_dropout
        )
        self.output_projection = nn.Linear(cfg.hidden_dim, cfg.target_vocab_size)

    def forward(
        self,
        input_token: torch.LongTensor,    # (B,1)
        last_hidden: Any,
        encoder_outputs: torch.Tensor,    # (B, T_src, hidden_dim)
        encoder_mask: torch.Tensor        # (B, T_src)
    ) -> Tuple[torch.Tensor, Any, torch.Tensor]:
        """Decode one step; return ``(logits, new_hidden, alignments)``."""
        token_emb = self.embedding(input_token)                   # (B, 1, D_emb)

        # For an (h, c) tuple the query comes from h; otherwise from the tensor.
        h_stack = last_hidden[0] if isinstance(last_hidden, tuple) else last_hidden
        query = h_stack[-1]                                       # (B, hidden_dim)

        context, alignments = self.attention(query, encoder_outputs, encoder_mask)

        step_input = torch.cat((token_emb, context.unsqueeze(1)), dim=2)
        rnn_out, new_hidden = self.rnn(step_input, last_hidden)   # (B, 1, hidden_dim)
        return self.output_projection(rnn_out), new_hidden, alignments


# ──────────────────── 5. Seq2SeqAttention ─────────────────────────
class Seq2SeqAttention(nn.Module):
    """
    Full encoder-decoder with attention.
    Supports forward, greedy_decode, and beam_search_decode.
    """
    def __init__(self, cfg: Seq2SeqAttentionConfig):
        super().__init__()
        self.cfg     = cfg
        self.encoder = EncoderWithOutputs(cfg)
        self.decoder = DecoderWithAttention(cfg)

    def _encode(
        self,
        src: torch.LongTensor,
        src_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Any]:
        """Run the encoder; return (outputs, padding mask, decoder init state)."""
        enc_outputs, enc_hidden = self.encoder(src, src_lengths)
        enc_mask = (src != self.cfg.pad_index).to(src.device)
        dec_hidden = _align_hidden_state(enc_hidden, self.cfg.decoder_layers)
        return enc_outputs, enc_mask, dec_hidden

    def forward(
        self,
        src: torch.LongTensor,
        src_lengths: torch.LongTensor,
        tgt: torch.LongTensor,
        *,
        teacher_forcing_ratio: float = 0.5
    ) -> torch.Tensor:
        """Decode along ``tgt`` and return logits of shape (B, T_tgt, V).

        ``tgt[:, 0]`` is fed as the first decoder input. At every later step
        the gold token is fed with probability ``teacher_forcing_ratio``,
        otherwise the model's own argmax. Row 0 of the time axis is left at
        zero, so the logits at index t predict ``tgt[:, t]``.
        """
        B, T_tgt = tgt.size()
        enc_outputs, enc_mask, dec_hidden = self._encode(src, src_lengths)

        logits_all = torch.zeros(B, T_tgt,
                                 self.cfg.target_vocab_size,
                                 device=src.device)
        dec_input = tgt[:, 0].unsqueeze(1)  # (B,1)

        for t in range(1, T_tgt):
            step_logits, dec_hidden, _ = self.decoder(
                dec_input, dec_hidden,
                enc_outputs, enc_mask
            )
            logits_all[:, t] = step_logits.squeeze(1)
            if torch.rand(1).item() < teacher_forcing_ratio:
                dec_input = tgt[:, t].unsqueeze(1)
            else:
                dec_input = step_logits.argmax(-1)

        return logits_all

    def greedy_decode(
        self,
        src: torch.LongTensor,
        src_lengths: torch.LongTensor,
        *,
        max_len: int = 50
    ) -> torch.LongTensor:
        """Greedily generate exactly ``max_len`` tokens per sequence.

        Returns a (B, max_len) tensor without a leading <sos>. Generation does
        not stop at <eos>; callers are expected to cut the output there.
        """
        enc_outputs, enc_mask, dec_hidden = self._encode(src, src_lengths)

        dec_input = torch.full((src.size(0), 1),
                               self.cfg.sos_index,
                               device=src.device,
                               dtype=torch.long)
        generated = []
        for _ in range(max_len):
            logits, dec_hidden, _ = self.decoder(
                dec_input, dec_hidden,
                enc_outputs, enc_mask
            )
            dec_input = logits.argmax(-1)  # (B,1)
            generated.append(dec_input)
        return torch.cat(generated, dim=1)  # (B, max_len)

    def beam_search_decode(
        self,
        src: torch.LongTensor,
        src_lengths: torch.LongTensor,
        *,
        beam_size: int = 5,
        max_len:   int = 50
    ) -> torch.LongTensor:
        """
        Beam-search decoding (batch_size=1 only).
        Returns best seq (1, L) without leading <sos>.

        Hypotheses are ranked by summed token log-probability (no length
        normalisation). A hypothesis that emits <eos> is moved to the finished
        pool as soon as it is selected, so one that ends on the very last step
        still competes for the final answer. If nothing finished within
        ``max_len`` steps, the best unfinished hypothesis is returned.
        """
        B = src.size(0)
        assert B == 1, "beam_search_decode only supports batch_size=1"

        device = src.device
        eos = self.cfg.eos_index

        # Decoding never needs gradients; no_grad also makes it safe to share
        # one hidden-state object between sibling candidates.
        with torch.no_grad():
            enc_outputs, enc_mask, dec_hidden = self._encode(src, src_lengths)

            # beams: live hypotheses as (token_list, score, hidden_state)
            beams = [([self.cfg.sos_index], 0.0, dec_hidden)]
            completed = []  # finished hypotheses as (token_list, score)

            for _ in range(max_len):
                candidates = []
                for seq, score, hidden in beams:
                    inp = torch.tensor([[seq[-1]]], device=device)
                    logits, new_hidden, _ = self.decoder(
                        inp, hidden,
                        enc_outputs, enc_mask
                    )
                    # logits: (1,1,V) → (V,)
                    log_probs = F.log_softmax(logits.squeeze(1), dim=-1)[0]

                    # topk raises if asked for more entries than exist.
                    k = min(beam_size, log_probs.size(0))
                    topk_vals, topk_idx = log_probs.topk(k)
                    for lp, idx in zip(topk_vals.tolist(), topk_idx.tolist()):
                        candidates.append((seq + [idx], score + lp, new_hidden))

                if not candidates:
                    break

                # Keep the best beam_size candidates, then retire the finished.
                candidates.sort(key=lambda cand: cand[1], reverse=True)
                beams = []
                for cand_seq, cand_score, cand_hidden in candidates[:beam_size]:
                    if cand_seq[-1] == eos:
                        completed.append((cand_seq, cand_score))
                    else:
                        beams.append((cand_seq, cand_score, cand_hidden))

                if not beams:
                    break

        if not completed:
            completed = [(seq, score) for seq, score, _ in beams]

        # pick best
        best_seq, _ = max(completed, key=lambda item: item[1])
        # drop leading <sos>
        if best_seq and best_seq[0] == self.cfg.sos_index:
            best_seq = best_seq[1:]
        return torch.tensor(best_seq,
                            dtype=torch.long,
                            device=device).unsqueeze(0)


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Q5: W&B sweep driver for attention‐augmented Seq2Seq
────────────────────────────────────────────────────────────────
Reuses the attention Seq2Seq from solution_5_model.py.
A YAML under ./configs/ specifies the sweep space.

Usage:
    # Sweep:
    python solution_5.py \
      --mode sweep \
      --sweep_config sweep_attention.yaml \
      --wandb_project DA6401_Intro_to_DeepLearning_Assignment_3 \
      --wandb_run_tag solution_5 \
      --gpu_ids 0 2 3 \
      --train_tsv ./lexicons/hi.translit.sampled.train.tsv \
      --dev_tsv   ./lexicons/hi.translit.sampled.dev.tsv \
      --test_tsv  ./lexicons/hi.translit.sampled.test.tsv \
      --sweep_count 75

    # Single debug run:
    python solution_5.py \
      --mode single \
      --wandb_project transliteration \
      --wandb_run_tag attention_debug \
      --train_tsv ... \
      --dev_tsv   ... \
      --test_tsv  ...
"""
from __future__ import annotations
import argparse
import math
import os
import yaml
from pathlib import Path
from typing import Any, Dict

import torch
import wandb
from torch.utils.data import DataLoader


def get_configs(project_root: str | Path, config_filename: str) -> Dict[str, Any]:
    """Parse ``<project_root>/configs/<config_filename>`` as YAML and return the result."""
    config_file = Path(project_root).joinpath("configs", config_filename)
    with config_file.open("r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def compute_sequence_accuracy(
    model: Seq2SeqAttention,
    dataset: DakshinaLexicon,
    device: str,
    beam_size: int = 1
) -> float:
    """
    Fraction of dataset items whose beam-search prediction decodes to exactly
    the same string as the gold target.

    Each item is decoded on its own with ``max_len`` set to the gold target's
    id count. A leading <sos> in the prediction is removed before decoding;
    the gold sequence is decoded without its first id. Returns 0.0 for an
    empty dataset. The model is switched to eval mode.
    """
    model.eval()
    outcomes = []
    with torch.no_grad():
        for src_ids, tgt_ids in dataset:
            src_tensor = torch.tensor([src_ids], device=device)
            src_len = torch.tensor([len(src_ids)], device=device)
            decoded = model.beam_search_decode(
                src_tensor, src_len,
                beam_size=beam_size,
                max_len=len(tgt_ids)
            )
            hypothesis = decoded[0].tolist()  # (L,)

            # drop leading <sos>
            if hypothesis and hypothesis[0] == dataset.tgt_vocab.stoi["<sos>"]:
                hypothesis = hypothesis[1:]

            predicted_text = dataset.tgt_vocab.decode(hypothesis)
            reference_text = dataset.tgt_vocab.decode(tgt_ids[1:])
            outcomes.append(int(predicted_text == reference_text))

    if not outcomes:
        return 0.0
    return sum(outcomes) / len(outcomes)


def run_single_training(sweep_config: Dict[str, Any], static_args: argparse.Namespace) -> None:
    """
    Run one full train/validate cycle and log the results to the active W&B run.

    Hyperparameters are read from ``sweep_config``; file paths, GPU ids and
    the run tag come from ``static_args``. Per-epoch losses/perplexities and
    the final dev exact-match accuracy are sent through ``wandb.log``.
    """
    hp = sweep_config

    # Limit the visible CUDA devices to the ids given on the command line.
    visible_gpus = [str(gpu) for gpu in static_args.gpu_ids]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible_gpus)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ── Datasets: vocabularies come from the training split ──
    train_ds = DakshinaLexicon(
        static_args.train_tsv,
        build_vocabs=True,
        use_attestations=hp.get("use_attestations", False),
    )
    src_vocab = train_ds.src_vocab
    tgt_vocab = train_ds.tgt_vocab
    dev_ds = DakshinaLexicon(static_args.dev_tsv, src_vocab, tgt_vocab)
    test_ds = DakshinaLexicon(static_args.test_tsv, src_vocab, tgt_vocab)

    def collate_fn(batch):
        return collate_batch(batch, pad_id=src_vocab.stoi["<pad>"])

    def make_loader(split, shuffle):
        return DataLoader(
            split,
            batch_size=hp["batch_size"],
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    train_loader = make_loader(train_ds, True)
    dev_loader = make_loader(dev_ds, False)
    test_loader = make_loader(test_ds, False)  # constructed but not consumed below

    # ── Model, optimizer, loss ──
    if hp["embedding_method"] == "svd_ppmi":
        extra = {"svd_sources": train_ds.encoded_sources}
    else:
        extra = {}

    cfg = Seq2SeqAttentionConfig(
        source_vocab_size=src_vocab.size,
        target_vocab_size=tgt_vocab.size,
        embedding_dim=hp["embedding_size"],
        hidden_dim=hp["hidden_size"],
        encoder_layers=hp["encoder_layers"],
        decoder_layers=hp["decoder_layers"],
        cell_type=hp["cell"],
        dropout=hp["dropout"],
        pad_index=src_vocab.stoi["<pad>"],
        sos_index=tgt_vocab.stoi["<sos>"],
        eos_index=tgt_vocab.stoi["<eos>"],
        embedding_method=hp["embedding_method"],
        **extra,
    )
    model = Seq2SeqAttention(cfg).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=hp["learning_rate"])
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=cfg.pad_index)

    # ── Run name: one "key:value" field per hyperparameter, joined by "|" ──
    name_fields = [
        f"emb:{hp['embedding_method']}",
        f"es:{hp['embedding_size']}",
        f"cell:{hp['cell']}",
        f"hs:{hp['hidden_size']}",
        f"enc:{hp['encoder_layers']}",
        f"dec:{hp['decoder_layers']}",
        f"dr:{hp['dropout']}",
        f"lr:{hp['learning_rate']:.1e}",
        f"bsz:{hp['batch_size']}",
        f"tf:{hp['teacher_forcing']}",
        f"ep:{hp['epochs']}",
        f"beam:{hp.get('beam_size',1)}",
        f"att:{hp.get('use_attestations',False)}",
    ]
    wandb.run.name = "|".join(name_fields)
    wandb.run.tags = [static_args.wandb_run_tag]

    # ── Epoch loop: train, validate, log ──
    for epoch in range(1, hp["epochs"] + 1):
        train_loss = train_epoch(
            model, train_loader, optimizer, loss_fn,
            device, hp["teacher_forcing"]
        )
        dev_loss = eval_epoch(model, dev_loader, loss_fn, device)

        epoch_metrics = {
            "Q5_epoch": epoch,
            "Q5_train_loss": train_loss,
            "Q5_train_ppl": math.exp(train_loss),
            "Q5_dev_loss": dev_loss,
            "Q5_dev_ppl": math.exp(dev_loss),
        }
        wandb.log(epoch_metrics)

    # ── Final exact-match accuracy on dev via beam search ──
    dev_acc = compute_sequence_accuracy(
        model, dev_ds, device, beam_size=hp.get("beam_size", 1)
    )
    wandb.log({"Q5_dev_accuracy": dev_acc})


def main():
    """Parse the CLI and launch either a W&B sweep or a single debug run.

    ``--mode sweep`` registers the sweep described by the YAML file in
    ``<parent of cwd>/configs`` and runs ``--sweep_count`` trials through a
    W&B agent. ``--mode single`` starts one run with small, fixed
    hyperparameters for debugging.
    """
    parser = argparse.ArgumentParser(
        description="W&B sweep driver for Q5 attention model."
    )
    parser.add_argument("--mode",         choices=["sweep","single"], required=True)
    parser.add_argument("--sweep_config", type=str, default="sweep_attention.yaml")
    parser.add_argument("--wandb_project",type=str, required=True)
    parser.add_argument("--wandb_run_tag",type=str, default="attention")
    parser.add_argument("--gpu_ids",      nargs="+", type=int, default=[0])
    parser.add_argument("--train_tsv",    type=str, required=True)
    parser.add_argument("--dev_tsv",      type=str, required=True)
    parser.add_argument("--test_tsv",     type=str, required=True)
    parser.add_argument("--sweep_count",  type=int, default=30)
    args = parser.parse_args()

    project_root = Path.cwd().resolve().parent
    # An empty YAML file parses to None; fall back to an empty mapping so the
    # dict operations below keep working.
    sweep_yaml   = get_configs(project_root, args.sweep_config) or {}

    if args.mode == "sweep":
        # Fill in anything the YAML file left out.
        sweep_yaml.setdefault("method", "bayes")
        # The metric name has to be a key that the training run actually
        # logs; run_single_training logs dev perplexity as "Q5_dev_ppl".
        sweep_yaml.setdefault("metric", {"name": "Q5_dev_ppl", "goal": "minimize"})
        sweep_yaml.setdefault("parameters", {})
        sweep_yaml["run_cap"]  = args.sweep_count

        sweep_id = wandb.sweep(sweep=sweep_yaml,
                               project=args.wandb_project)
        print(f"Registered sweep: {sweep_id}")

        def _agent():
            # One sweep trial: W&B injects the sampled hyperparameters into run.config.
            with wandb.init(project=args.wandb_project) as run:
                run_single_training(dict(run.config), args)

        wandb.agent(sweep_id, function=_agent,
                    count=args.sweep_count)

    else:
        # single debug run
        with wandb.init(
            project=args.wandb_project,
            config=sweep_yaml.get("parameters", {})
        ) as run:
            # Override the sweep search-space entries with concrete values.
            run.config.update({
                "epochs":            3,
                "batch_size":       64,
                "embedding_method":"learned",
                "embedding_size":   64,
                "hidden_size":     128,
                "encoder_layers":    1,
                "decoder_layers":    1,
                "cell":            "LSTM",
                "dropout":          0.1,
                "learning_rate":   1e-3,
                "teacher_forcing":  0.5,
                "use_attestations":False,
                "beam_size":        1,
            }, allow_val_change=True)
            run_single_training(dict(run.config), args)

import sys

# Notebook stand-in for the shell command: argparse reads sys.argv, so it is
# populated by hand (one tuple per flag) before main() is called.
_cli_groups = (
    ("--mode", "sweep"),
    ("--sweep_config", "sweep_attention.yaml"),
    ("--wandb_project", "transliteration"),
    ("--wandb_run_tag", "solution_5"),
    ("--gpu_ids", "0", "2", "3"),
    ("--train_tsv", "../lexicons/hi.translit.sampled.train.tsv"),
    ("--dev_tsv", "../lexicons/hi.translit.sampled.dev.tsv"),
    ("--test_tsv", "../lexicons/hi.translit.sampled.test.tsv"),
    ("--sweep_count", "75"),
)
sys.argv = ["solution_5.py"] + [token for group in _cli_groups for token in group]

main()


#### Q5 Solution - Part A

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
solution_5a.py: Tune beam size for your best attention‐based Hindi transliteration model (the model is trained by this script, not loaded from a checkpoint).

Usage example:
    python solution_5a.py \
        --train_tsv ./lexicons/hi.translit.sampled.train.tsv \
        --dev_tsv   ./lexicons/hi.translit.sampled.dev.tsv \
        --test_tsv  ./lexicons/hi.translit.sampled.test.tsv \
        --gpu_ids   0 1

This script:
  1. Loads the Dakshina Hindi train/dev/test splits.
  2. Trains a Seq2SeqAttention model with your best attention hyperparameters.
  3. Runs beam‐search decoding on the dev set for various beam sizes.
  4. Reports exact‐match accuracy for each beam size.
"""

from __future__ import annotations
import argparse
import math
import os
from pathlib import Path
from typing import Tuple

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm

# ──────────────────────────── best hyperparameters ────────────────────────────
# Hyperparameters used to train the model whose beam size is tuned below.
best = dict(
    batch_size=128,
    cell_type="GRU",
    encoder_layers=1,
    decoder_layers=2,
    hidden_size=512,
    embedding_method="svd_ppmi",
    embedding_size=64,
    dropout=0.2,
    learning_rate=0.0006899910999897612,
    teacher_forcing=0.5,
    use_attestations=True,
    epochs=10,
)
# ──────────────────────────────────────────────────────────────────────────────

def parse_args():
    """Build the beam-tuning CLI and parse ``sys.argv``.

    Three required lexicon paths (train/dev/test) plus an optional list of
    CUDA device ids (default ``[0]``). Returns the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description="Tune beam size for your best attention‐based model"
    )

    # The three lexicon paths share the same type/required settings.
    lexicon_flags = (
        ("--train_tsv", "Path to Hindi training lexicon TSV"),
        ("--dev_tsv", "Path to Hindi development lexicon TSV"),
        ("--test_tsv", "Path to Hindi test lexicon TSV"),
    )
    for flag, help_text in lexicon_flags:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--gpu_ids",
        type=int,
        nargs="+",
        default=[0],
        help="CUDA device IDs to use (e.g. 0 1).",
    )
    return parser.parse_args()

def build_data_loaders(
    train_path: str,
    dev_path:   str,
    test_path:  str,
    batch_size: int,
    use_attest: bool
) -> Tuple[DataLoader, DataLoader, DataLoader, DakshinaLexicon, DakshinaLexicon]:
    """
    Create the three datasets and their loaders.

    Vocabularies are built on the training split and handed to the dev/test
    datasets. With ``use_attest=True`` the training loader draws examples
    through a ``WeightedRandomSampler`` weighted by ``train_ds.example_counts``;
    otherwise it shuffles. Dev/test loaders never shuffle.

    Returns ``(train_loader, dev_loader, test_loader, train_ds, dev_ds)``.
    """
    train_ds = DakshinaLexicon(
        train_path, build_vocabs=True, use_attestations=use_attest
    )
    src_vocab = train_ds.src_vocab
    tgt_vocab = train_ds.tgt_vocab

    # Evaluation splits share the training vocabularies.
    dev_ds, test_ds = (
        DakshinaLexicon(path, src_vocab, tgt_vocab)
        for path in (dev_path, test_path)
    )

    pad_id = src_vocab.stoi["<pad>"]

    def collate_fn(batch):
        return collate_batch(batch, pad_id=pad_id)

    # Training order: attestation-weighted sampling or a plain shuffle.
    if use_attest:
        order_kwargs = {
            "sampler": WeightedRandomSampler(
                weights=train_ds.example_counts,
                num_samples=len(train_ds),
                replacement=True,
            )
        }
    else:
        order_kwargs = {"shuffle": True}
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, collate_fn=collate_fn, **order_kwargs
    )

    dev_loader, test_loader = (
        DataLoader(split, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        for split in (dev_ds, test_ds)
    )

    return train_loader, dev_loader, test_loader, train_ds, dev_ds

def train_model(
    train_loader: DataLoader,
    dev_loader:   DataLoader,
    cfg:          Seq2SeqAttentionConfig,
    device:       torch.device,
    epochs:       int,
    lr:           float,
    teacher_forcing: float
) -> Seq2SeqAttention:
    """
    Build a ``Seq2SeqAttention`` from ``cfg`` and train it for ``epochs`` epochs
    with Adam and pad-ignoring cross-entropy, printing train/dev perplexity
    after every epoch. Returns the trained model.
    """
    net = Seq2SeqAttention(cfg).to(device)
    adam = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=cfg.pad_index)

    for step in range(epochs):
        epoch = step + 1
        train_loss = train_epoch(net, train_loader, adam, criterion, device, teacher_forcing)
        dev_loss = eval_epoch(net, dev_loader, criterion, device)

        # Perplexity is exp(mean cross-entropy).
        train_ppl = math.exp(train_loss)
        dev_ppl = math.exp(dev_loss)
        print(f"Epoch {epoch:02d} | train ppl={train_ppl:.2f} | dev ppl={dev_ppl:.2f}")

    return net

def evaluate_beam_exact_match(
    model:       Seq2SeqAttention,
    dataset:     DakshinaLexicon,
    beam_size:   int,
    device:      torch.device
) -> float:
    """
    Decode every example of ``dataset`` one at a time with
    ``model.beam_search_decode`` and return the share of predictions whose
    decoded string equals the decoded gold target (0.0 for an empty dataset).
    """
    model.eval()
    seen = 0
    hits = 0
    with torch.no_grad():
        progress = tqdm(dataset, desc=f"beam={beam_size}", leave=False)
        for seen, (src_ids, tgt_ids) in enumerate(progress, start=1):
            n_src = len(src_ids)

            # Single-example batch; the length budget depends on the source only.
            hypothesis = model.beam_search_decode(
                torch.tensor([src_ids], device=device),
                torch.tensor([n_src], device=device),
                beam_size=beam_size,
                max_len=max(n_src * 2, 50),
            )[0].tolist()

            vocab = dataset.tgt_vocab
            # Strip a leading <sos> from the hypothesis, if there is one.
            if hypothesis and hypothesis[0] == vocab.stoi["<sos>"]:
                hypothesis = hypothesis[1:]

            # Gold ids start with <sos>, hence the [1:] slice.
            hits += int(vocab.decode(hypothesis) == vocab.decode(tgt_ids[1:]))

    if seen > 0:
        return hits / seen
    return 0.0

def main():
    """Train the model with the ``best`` hyperparameters, then report dev
    exact-match accuracy for a fixed list of beam sizes."""
    cli = parse_args()

    # Restrict visible GPUs, then pick the device.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in cli.gpu_ids)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data: the test loader is returned but not needed here.
    loaders = build_data_loaders(
        train_path=cli.train_tsv,
        dev_path=cli.dev_tsv,
        test_path=cli.test_tsv,
        batch_size=best["batch_size"],
        use_attest=best["use_attestations"],
    )
    train_loader, dev_loader, _, train_ds, dev_ds = loaders

    # Model configuration; SVD-PPMI embeddings additionally need the encoded sources.
    if best["embedding_method"] == "svd_ppmi":
        extra_cfg = {"svd_sources": train_ds.encoded_sources}
    else:
        extra_cfg = {}

    src_vocab = train_ds.src_vocab
    tgt_vocab = train_ds.tgt_vocab
    cfg = Seq2SeqAttentionConfig(
        source_vocab_size=src_vocab.size,
        target_vocab_size=tgt_vocab.size,
        embedding_dim=best["embedding_size"],
        hidden_dim=best["hidden_size"],
        encoder_layers=best["encoder_layers"],
        decoder_layers=best["decoder_layers"],
        cell_type=best["cell_type"],
        dropout=best["dropout"],
        pad_index=src_vocab.stoi["<pad>"],
        sos_index=tgt_vocab.stoi["<sos>"],
        eos_index=tgt_vocab.stoi["<eos>"],
        embedding_method=best["embedding_method"],
        **extra_cfg,
    )

    model = train_model(
        train_loader,
        dev_loader,
        cfg=cfg,
        device=device,
        epochs=best["epochs"],
        lr=best["learning_rate"],
        teacher_forcing=best["teacher_forcing"],
    )

    # Sweep the beam width on the dev split.
    print("\nTuning beam size on dev set (exact‐match rate):")
    for beam in (1, 2, 3, 5, 8, 10):
        acc = evaluate_beam_exact_match(model, dev_ds, beam_size=beam, device=device)
        print(f"  beam_size={beam:2d} → dev accuracy = {acc * 100:5.2f}%")

import sys

# Notebook stand-in for the shell command: fill sys.argv (one tuple per flag)
# so parse_args() inside main() sees the same tokens a terminal would pass.
_cli_groups = (
    ("--train_tsv", "../lexicons/hi.translit.sampled.train.tsv"),
    ("--dev_tsv", "../lexicons/hi.translit.sampled.dev.tsv"),
    ("--test_tsv", "../lexicons/hi.translit.sampled.test.tsv"),
    ("--gpu_ids", "3"),
)
sys.argv = ["solution_5a.py"] + [token for group in _cli_groups for token in group]

main()


#### Q5 Solution - Part B

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Q5.b: Train (with early stopping) & evaluate best attention‐augmented Seq2Seq model on the Dakshina Hindi test set.

This script:
  1. Loads train/dev/test lexicons and builds/uses the same vocabularies.
  2. Constructs the Seq2SeqAttention model with your best hyperparameters.
  3. If no checkpoint exists:
       - Trains with early stopping on dev cross‐entropy loss (patience=3).
       - Saves the best‐so‐far model to --checkpoint.
     Otherwise loads the existing checkpoint.
  4. Runs beam-search decoding on the test set, computes exact‐match accuracy.
  5. Saves all (source, gold, prediction) triples under `predictions_attention/` as TSV and CSV.
  6. Samples 20 predictions, builds a colored table figure, saves it, and logs it to W&B.
  7. Selects 10 random test examples, computes their greedy attention heatmaps, and
     plots them in a 3-column grid (4 rows for 10 examples), saving & logging the figure to W&B.

Usage example:

    python solution_5b.py \
      --train_tsv ./lexicons/hi.translit.sampled.train.tsv \
      --dev_tsv   ./lexicons/hi.translit.sampled.dev.tsv \
      --test_tsv  ./lexicons/hi.translit.sampled.test.tsv \
      --checkpoint ./checkpoints/best_attention.pt \
      --output_dir predictions_attention \
      --gpu_ids 3 \
      --wandb_project DA6401_Intro_to_DeepLearning_Assignment_3 \
      --wandb_run_name solution_5b_run \
      --wandb_run_tag solution_5b
"""
from __future__ import annotations
import argparse
import os
import math
from pathlib import Path
import random

import torch
from torch.utils.data import DataLoader
import pandas as pd
import wandb
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# ────────────────────────────────
# Register the Devanagari font (Hind) with matplotlib
# ────────────────────────────────
# Nothing is downloaded here: the TTF file must already be present in
# ./fonts (the directory is created if missing so the user can drop it in).
font_dir = Path.cwd() / "fonts"
font_dir.mkdir(exist_ok=True)
font_path = font_dir / "Hind-Regular.ttf"

if not font_path.exists():
    # Adjacent string literals are concatenated verbatim, so the separating
    # space has to be written explicitly.
    print(
        "Font file not found. Please ensure 'Hind-Regular.ttf' is in the 'fonts' directory; "
        "it can be downloaded from Google Fonts: https://fonts.google.com/specimen/Hind"
    )
else:
    # Make the font known to matplotlib and use it as the default family so
    # Devanagari labels render in the figures produced below.
    fm.fontManager.addfont(str(font_path))
    plt.rcParams["font.family"] = "Hind"
    plt.style.use("seaborn-v0_8-pastel")
    print("Font loaded and matplotlib configured.")

# ─────────────────────────────────────────────────────────────────────
# Replace these with best attention‐model hyperparameters:
# Hyperparameters of the model trained/evaluated in Part B.
best = dict(
    batch_size=128,
    beam_size=3,
    cell_type="GRU",
    decoder_layers=2,
    dropout=0.2,
    embedding_method="svd_ppmi",
    embedding_size=64,
    encoder_layers=1,
    hidden_size=512,
    learning_rate=0.0006899910999897612,
    teacher_forcing=0.5,
    use_attestations=True,
    patience=3,   # epochs without dev-loss improvement before stopping early
    epochs=25,    # upper bound on training epochs
)
# ─────────────────────────────────────────────────────────────────────

def parse_args():
    """Build the Part B CLI and parse ``sys.argv``.

    Required: train/dev/test lexicon paths and a checkpoint path.
    Optional: output directory, CUDA device ids and W&B project/run name/tag.
    Returns the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description="Q5.b: Train (with early stopping) & evaluate best attention model"
    )

    # Required string-valued paths.
    required_paths = (
        ("--train_tsv", "Path to train lexicon TSV"),
        ("--dev_tsv", "Path to dev lexicon TSV"),
        ("--test_tsv", "Path to test lexicon TSV"),
        ("--checkpoint", "Path to save/load model checkpoint"),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--output_dir",
        type=str,
        default="predictions_attention",
        help="Directory to write predictions.tsv/csv",
    )
    parser.add_argument(
        "--gpu_ids", type=int, nargs="+", default=[0], help="CUDA device IDs"
    )

    # Optional W&B settings: (flag, default, help).
    wandb_options = (
        ("--wandb_project", None, "W&B project name"),
        ("--wandb_run_name", None, "W&B run name"),
        ("--wandb_run_tag", "solution_5b", "W&B run tag"),
    )
    for flag, default, help_text in wandb_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    return parser.parse_args()

def get_attention_heatmap(
    model: Seq2SeqAttention,
    src_ids: list[int],
    src_lens: torch.Tensor,
    src_vocab: DakshinaLexicon.src_vocab.__class__,
    tgt_vocab: DakshinaLexicon.tgt_vocab.__class__,
    device: torch.device,
    max_len: int = 50
) -> tuple[list[int], list[list[float]]]:
    """
    Greedy-decode one source sequence and collect the attention weights of
    every decoder step.

    The whole routine runs under ``torch.no_grad()``: it is pure inference,
    so no autograd graph needs to be kept for the decoding steps.

    Args:
        model: Attention model exposing ``encoder``, ``decoder`` and ``cfg``.
        src_ids: Source token ids of a single example.
        src_lens: Length tensor passed to the encoder for that example.
        src_vocab: Not used by this function; kept for call compatibility.
        tgt_vocab: Not used by this function; kept for call compatibility.
        device: Device on which the input tensors are created.
        max_len: Maximum number of decoder steps.

    Returns:
        ``(predicted_ids, attention_weights)``. ``predicted_ids`` holds the
        greedily chosen ids up to, but excluding, <eos>.
        ``attention_weights[t]`` is the alignment recorded at decoder step
        ``t``. Note that the step which produces <eos> is recorded as well,
        so when decoding stops on <eos> the matrix has one more row than
        ``predicted_ids`` has entries.
    """
    model.eval()

    predicted_ids: list[int] = []
    attention_weights: list[list[float]] = []

    with torch.no_grad():
        # prepare tensors
        src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
        enc_outputs, enc_hidden = model.encoder(src_tensor, src_lens)
        # True wherever the source position is not padding
        enc_mask = (src_tensor != model.cfg.pad_index).to(device)
        dec_hidden = _align_hidden_state(enc_hidden, model.cfg.decoder_layers)
        # start with <sos>
        dec_input = torch.full((1, 1), model.cfg.sos_index, dtype=torch.long, device=device)

        for _ in range(max_len):
            logits, dec_hidden, align = model.decoder(
                dec_input, dec_hidden, enc_outputs, enc_mask
            )
            # presumably align is (1, T_src), leaving one weight per source
            # position after squeeze(0) — TODO confirm against the decoder
            attention_weights.append(align.squeeze(0).tolist())
            # greedy choice; it is also the next decoder input
            dec_input = logits.argmax(-1)
            next_id = dec_input.item()
            if next_id == model.cfg.eos_index:
                break
            predicted_ids.append(next_id)

    return predicted_ids, attention_weights


def main():
    """Train (or reload) the best attention model, evaluate it on the test
    set and produce the prediction files and figures.

    Steps:
      1. Build datasets/loaders; vocabularies come from the training split.
      2. Train with early stopping on dev loss unless ``--checkpoint`` exists,
         then load the checkpoint.
      3. Beam-search decode the test set and report exact-match accuracy.
      4. Write all (source, target, prediction) triples as TSV and CSV.
      5. Save a colour-coded table of up to 20 sampled predictions.
      6. Save greedy attention heatmaps for up to 10 random test examples.
    Figures and metrics are also logged to W&B when ``--wandb_project`` is set.
    """
    args = parse_args()

    # ─── Pin GPUs & select device ────────────────────────────────────
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, args.gpu_ids))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ─── Init WandB if requested ─────────────────────────────────────
    use_wandb = args.wandb_project is not None
    if use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            tags=[args.wandb_run_tag],
            config=best
        )

    # ─── Ensure output directories exist ──────────────────────────────
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    ckpt_path = Path(args.checkpoint)
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    # ─── Build vocab from train set ──────────────────────────────────
    train_ds = DakshinaLexicon(
        args.train_tsv,
        build_vocabs=True,
        use_attestations=best["use_attestations"]
    )
    src_vocab = train_ds.src_vocab
    tgt_vocab = train_ds.tgt_vocab
    sos_idx = tgt_vocab.stoi["<sos>"]

    # ─── Prepare loaders for all three splits ────────────────────────
    def collate_fn(batch):
        return collate_batch(batch, pad_id=src_vocab.stoi["<pad>"])

    # NOTE(review): training uses a plain shuffle here even though
    # use_attestations is True; confirm whether attestation-weighted
    # sampling was intended for this script.
    train_loader = DataLoader(train_ds, batch_size=best["batch_size"],
                              shuffle=True, collate_fn=collate_fn)
    dev_ds = DakshinaLexicon(args.dev_tsv, src_vocab, tgt_vocab)
    dev_loader = DataLoader(dev_ds, batch_size=best["batch_size"],
                            shuffle=False, collate_fn=collate_fn)
    test_ds = DakshinaLexicon(args.test_tsv, src_vocab, tgt_vocab)
    test_loader = DataLoader(test_ds, batch_size=best["batch_size"],
                             shuffle=False, collate_fn=collate_fn)

    # ─── Build the Seq2SeqAttention model ────────────────────────────
    extra = {}
    if best["embedding_method"] == "svd_ppmi":
        extra["svd_sources"] = train_ds.encoded_sources

    cfg = Seq2SeqAttentionConfig(
        source_vocab_size=src_vocab.size,
        target_vocab_size=tgt_vocab.size,
        embedding_dim=best["embedding_size"],
        hidden_dim=best["hidden_size"],
        encoder_layers=best["encoder_layers"],
        decoder_layers=best["decoder_layers"],
        cell_type=best["cell_type"],
        dropout=best["dropout"],
        pad_index=src_vocab.stoi["<pad>"],
        sos_index=tgt_vocab.stoi["<sos>"],
        eos_index=tgt_vocab.stoi["<eos>"],
        embedding_method=best["embedding_method"],
        **extra
    )
    model = Seq2SeqAttention(cfg).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=best["learning_rate"])
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=cfg.pad_index)

    # ─── Train with early stopping if checkpoint missing ─────────────
    if not ckpt_path.exists():
        print("No checkpoint found; starting training with early stopping...")
        best_dev_loss = float("inf")
        no_improve = 0

        for epoch in range(1, best["epochs"] + 1):
            train_loss = train_epoch(
                model, train_loader, optimizer, loss_fn,
                device, teacher_forcing=best["teacher_forcing"]
            )
            dev_loss = eval_epoch(model, dev_loader, loss_fn, device)
            train_ppl = math.exp(train_loss)
            dev_ppl   = math.exp(dev_loss)

            print(f"Epoch {epoch:02d} | "
                  f"train_loss={train_loss:.4f} ppl={train_ppl:.2f} | "
                  f"dev_loss={dev_loss:.4f} ppl={dev_ppl:.2f}")

            if use_wandb:
                wandb.log({
                    "Q5b_epoch":        epoch,
                    "Q5b_train_loss":   train_loss,
                    "Q5b_train_ppl":    train_ppl,
                    "Q5b_dev_loss":     dev_loss,
                    "Q5b_dev_ppl":      dev_ppl,
                })

            # early stopping on dev_loss: only the best-so-far weights are saved
            if dev_loss < best_dev_loss:
                best_dev_loss = dev_loss
                no_improve = 0
                torch.save({"model_state_dict": model.state_dict()}, str(ckpt_path))
                print("  ↳ dev improved; checkpoint saved.")
            else:
                no_improve += 1
                print(f"  ↳ no improvement for {no_improve} epoch(s)")
                if no_improve >= best["patience"]:
                    print("Early stopping.")
                    break

        print("Training complete.\n")
    else:
        print(f"Found existing checkpoint at {ckpt_path}; skipping training.\n")

    # ─── Load the best checkpoint ─────────────────────────────────────
    ckpt = torch.load(str(ckpt_path), map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    # ─── Decode test set & compute exact‐match accuracy ──────────────
    total, correct = 0, 0
    predictions = []

    with torch.no_grad():
        for src_batch, src_lens, tgt_batch in test_loader:
            for i in range(src_batch.size(0)):
                total += 1
                # Trim to true length to fix mask/score mismatch
                length = src_lens[i].item()
                src_row = src_batch[i, :length].unsqueeze(0).to(device)
                len_row = torch.tensor([length], device=device)

                pred_ids = model.beam_search_decode(
                    src_row, len_row,
                    beam_size=best["beam_size"],
                    max_len=50
                )[0].tolist()

                # drop leading <sos> if present
                if pred_ids and pred_ids[0] == sos_idx:
                    pred_ids = pred_ids[1:]

                src_str  = src_vocab.decode(src_row[0].tolist())
                gold_str = tgt_vocab.decode(tgt_batch[i].tolist()[1:])
                pred_str = tgt_vocab.decode(pred_ids)

                if pred_str == gold_str:
                    correct += 1
                predictions.append((src_str, gold_str, pred_str))

    # An empty test set must not end in a ZeroDivisionError.
    accuracy = correct / total * 100 if total else 0.0
    print(f"\nTest exact‐match accuracy: {accuracy:.2f}% ({correct}/{total})\n")
    if use_wandb:
        wandb.log({"Q5b_test_accuracy": accuracy})

    # ─── Save all predictions ─────────────────────────────────────────
    df = pd.DataFrame(predictions, columns=["source", "target", "prediction"])
    df.to_csv(output_path / "predictions.tsv", sep="\t", index=False)
    df.to_csv(output_path / "predictions.csv", index=False)
    print(f"Saved predictions → {output_path/'predictions.tsv'}, {output_path/'predictions.csv'}\n")

    # ─── Sample up to 20 and build colored table figure ───────────────
    if not df.empty:
        # DataFrame.sample raises when asked for more rows than exist.
        sample_df = df.sample(min(20, len(df)), random_state=42)
        # green for correct, red for wrong; one colour per row, repeated per column
        row_colors = [
            "#c8e6c9" if gold == pred else "#ffcdd2"
            for gold, pred in zip(sample_df["target"], sample_df["prediction"])
        ]
        colors = [[color] * len(sample_df.columns) for color in row_colors]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.axis("off")
        tbl = ax.table(
            cellText=sample_df.values.tolist(),
            colLabels=sample_df.columns.tolist(),
            cellColours=colors,
            cellLoc="center",
            loc="center"
        )
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(10)
        tbl.scale(1, 2)
        figure_path = output_path / "sample_predictions.png"
        fig.savefig(figure_path, bbox_inches="tight")
        plt.close(fig)

        print(f"Saved sample predictions figure to {figure_path}")
        if use_wandb:
            wandb.log({"Q5b_sample_table": wandb.Image(str(figure_path))})

    # ─── Attention heatmaps for up to 10 random test examples ────────
    random.seed(42)
    # random.sample raises when asked for more items than the population holds.
    indices = random.sample(range(len(test_ds)), min(10, len(test_ds)))

    # Decode each selected example exactly once.
    preds_and_attn = []
    for idx in indices:
        src_ids, _ = test_ds[idx]
        pred_ids, attn = get_attention_heatmap(
            model, src_ids, torch.tensor([len(src_ids)], device=device),
            src_vocab, tgt_vocab, device
        )
        # drop any leading <sos> together with its attention row
        if pred_ids and pred_ids[0] == sos_idx:
            pred_ids = pred_ids[1:]
            attn = attn[1:]
        preds_and_attn.append((src_ids, pred_ids, attn))

    if preds_and_attn:
        # grid with 3 columns and as many rows as needed (4 rows for 10 examples)
        n = len(preds_and_attn)
        cols = 3
        rows = math.ceil(n / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 3))
        axes = axes.flatten()
        for i, (src_ids, pred_ids, attn) in enumerate(preds_and_attn):
            ax = axes[i]
            if not attn:
                # nothing to draw for this example
                ax.axis("off")
                continue
            im = ax.imshow(attn, aspect="auto", origin="lower")
            # x-axis: source chars
            x_labels = [src_vocab.itos[tok] for tok in src_ids]
            ax.set_xticks(range(len(x_labels)))
            ax.set_xticklabels(x_labels, rotation=90, fontsize=8)
            # y-axis: predicted chars; get_attention_heatmap also records the
            # step that emitted <eos>, so label that trailing row explicitly
            y_labels = [tgt_vocab.itos[tok] for tok in pred_ids]
            if len(attn) == len(y_labels) + 1:
                y_labels.append("<eos>")
            ax.set_yticks(range(len(y_labels)))
            ax.set_yticklabels(y_labels, fontsize=8)
            ax.set_xlabel("Source")
            ax.set_ylabel("Predicted")
            ax.set_title(f"Example {i+1}")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        # turn off any extra axes
        for ax in axes[n:]:
            ax.axis("off")
        plt.tight_layout()
        heatmap_path = output_path / "attention_heatmaps.png"
        fig.savefig(heatmap_path, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved attention heatmaps to {heatmap_path}")
        if use_wandb:
            wandb.log({"Q5b_attention_heatmaps": wandb.Image(str(heatmap_path))})

    if use_wandb:
        wandb.finish()

import sys

# Notebook stand-in for the shell command: fill sys.argv (one tuple per flag)
# so parse_args() inside main() sees the same tokens a terminal would pass.
_cli_groups = (
    ("--train_tsv", "../lexicons/hi.translit.sampled.train.tsv"),
    ("--dev_tsv", "../lexicons/hi.translit.sampled.dev.tsv"),
    ("--test_tsv", "../lexicons/hi.translit.sampled.test.tsv"),
    ("--checkpoint", "../checkpoints/best_attention_.pt"),
    ("--output_dir", "../predictions_attention_"),
    ("--gpu_ids", "3"),
    ("--wandb_project", "transliteration"),
    ("--wandb_run_name", "solution_5b_run"),
    ("--wandb_run_tag", "solution_5b"),
)
sys.argv = ["solution_5b.py"] + [token for group in _cli_groups for token in group]

main()


# Q6 Solution

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Q6: Interactive “Connectivity” visualization for your attention-augmented Seq2Seq model

This script lets you:
  1. Load your trained Seq2SeqAttention checkpoint.
  2. Sample N examples from the test set.
  3. For each example, run a greedy decode while recording the attention weights
     at each decoder timestep.
  4. Save a standalone HTML that displays, for each example:
       - The source characters along the x-axis.
       - The predicted output characters along the y-axis.
       - A Plotly heatmap of attention weights.
     You can hover over any cell to see the exact weight—and thus see “connectivity.”
  5. (New!) Log the final HTML to WandB so you can browse it in your project.

Usage (as a script):

    python solution_6.py \
      --checkpoint ./checkpoints/best_attention.pt \
      --train_tsv   ./lexicons/hi.translit.sampled.train.tsv \
      --test_tsv    ./lexicons/hi.translit.sampled.test.tsv \
      --n_examples  5 \
      --output_html connectivity.html \
      --wandb_project DA6401_Intro_to_DeepLearning_Assignment_3 \
      --wandb_run_name solution_6_run \
      --wandb_run_tag  solution_6
"""

from __future__ import annotations
import argparse
import os
import random
from pathlib import Path

import torch
import pandas as pd
import plotly.graph_objects as go
import wandb

# ──────────────────────────────────────────────────────────────────────────────
# Settings of the best attention run. The architecture-related entries are fed
# into the model config when the checkpoint is reloaded, so they have to agree
# with the values the checkpoint was trained with.
best = dict(
    batch_size=128,
    beam_size=3,
    cell_type="GRU",
    decoder_layers=3,
    dropout=0.2,
    embedding_method="svd_ppmi",
    embedding_size=64,
    encoder_layers=1,
    hidden_size=512,
    learning_rate=0.0006899910999897612,
    teacher_forcing=0.5,
    use_attestations=True,
    patience=3,   # early-stopping patience
    epochs=20,    # number of training epochs to try
)
# ──────────────────────────────────────────────────────────────────────────────

def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """
    Parse the command-line options of the Q6 connectivity visualizer.

    Args:
        argv: Argument strings to parse, without the program name. When left
            as None, argparse falls back to ``sys.argv[1:]`` (the previous,
            and still default, behaviour). Passing a list lets notebooks and
            tests call this without mutating ``sys.argv``.

    Returns:
        Namespace with ``checkpoint``, ``train_tsv`` and ``test_tsv`` (all
        required), ``n_examples`` (int, default 3), ``output_html`` (default
        "connectivity.html") and the optional ``wandb_project``,
        ``wandb_run_name`` and ``wandb_run_tag`` (default None).

    Raises:
        SystemExit: raised by argparse on missing/invalid arguments or --help.
    """
    p = argparse.ArgumentParser(
        description="Q6: Visualize character-level attention connectivity interactively"
    )
    p.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to your trained Seq2SeqAttention .pt file"
    )
    # The column layout is spelled out in words: argparse collapses whitespace
    # in help text, so an embedded tab escape would render as a stray space.
    p.add_argument(
        "--train_tsv", type=str, required=True,
        help="TSV of the TRAIN split (tab-separated columns: native, romanized, count) "
             "for building vocab"
    )
    p.add_argument(
        "--test_tsv", type=str, required=True,
        help="TSV of the TEST split (tab-separated columns: native, romanized, count)"
    )
    p.add_argument(
        "--n_examples", type=int, default=3,
        help="Number of random examples to visualize"
    )
    p.add_argument(
        "--output_html", type=str, default="connectivity.html",
        help="Path to save the standalone HTML with Plotly plots"
    )
    # WandB logging args
    p.add_argument(
        "--wandb_project", type=str, default=None,
        help="(optional) WandB project name to log this HTML"
    )
    p.add_argument(
        "--wandb_run_name", type=str, default=None,
        help="(optional) WandB run name"
    )
    p.add_argument(
        "--wandb_run_tag", type=str, default=None,
        help="(optional) WandB run tag"
    )
    return p.parse_args(argv)

def _detect_decoder_layers(state_dict) -> int | None:
    """Return the decoder RNN depth implied by the checkpoint keys, or None if no layer key is found."""
    prefix = "decoder.rnn.weight_ih_l"
    layer_indices = []
    for key in state_dict:
        if not key.startswith(prefix):
            continue
        # e.g. "decoder.rnn.weight_ih_l0", "decoder.rnn.weight_ih_l1", ...
        # Keys whose remainder is not a plain integer are skipped.
        try:
            layer_indices.append(int(key[len(prefix):]))
        except ValueError:
            continue
    # max index + 1 = number of layers
    return max(layer_indices) + 1 if layer_indices else None


def load_model_and_vocab(
    ckpt_path: str,
    train_tsv: str,
    device: torch.device
) -> tuple[Seq2SeqAttention, Any, Any]:
    """
    Rebuild the trained Seq2SeqAttention model together with its vocabularies.

    The vocabularies are built from the TRAIN split so that src/tgt indices
    align with the ones the checkpoint was trained with. All architecture
    settings come from the module-level ``best`` dict, except the decoder
    depth, which is read from the checkpoint itself when it can be detected.

    Args:
        ckpt_path: Path to a checkpoint containing a "model_state_dict" entry.
        train_tsv: Path to the training TSV used to build the vocabularies.
        device:    Device the weights are mapped to and the model is moved to.

    Returns:
        (model, src_vocab, tgt_vocab): the model in eval mode plus the source
        and target vocabulary objects taken from the training dataset
        (``train_ds.src_vocab`` / ``train_ds.tgt_vocab``).

    Raises:
        KeyError: if the checkpoint has no "model_state_dict" entry.
    """
    # 1) Build the vocabulary from the true training split
    train_ds = DakshinaLexicon(
        train_tsv,
        build_vocabs=True,
        use_attestations=best["use_attestations"]
    )
    src_vocab = train_ds.src_vocab
    tgt_vocab = train_ds.tgt_vocab

    # 2) Peek at the checkpoint to detect how many decoder layers it actually has
    ckpt = torch.load(ckpt_path, map_location=device)
    state_dict = ckpt["model_state_dict"]

    actual_decoder_layers = _detect_decoder_layers(state_dict)
    if actual_decoder_layers is not None:
        print(f"  ↳ Checkpoint has {actual_decoder_layers} decoder layers (detected)")
    else:
        actual_decoder_layers = best["decoder_layers"]
        print(f"  ↳ No decoder.rnn.weight_ih_l* keys found; defaulting to {actual_decoder_layers}")

    # 3) Construct exactly the same config you used for training,
    #    except use the detected number of decoder layers.
    #    The SVD/PPMI embedding additionally needs the encoded training
    #    sources; other embedding methods receive no extra keyword.
    extra_kwargs = (
        {"svd_sources": train_ds.encoded_sources}
        if best["embedding_method"] == "svd_ppmi" else {}
    )
    cfg = Seq2SeqAttentionConfig(
        source_vocab_size=src_vocab.size,
        target_vocab_size=tgt_vocab.size,
        embedding_dim=best["embedding_size"],
        hidden_dim=best["hidden_size"],
        encoder_layers=best["encoder_layers"],
        decoder_layers=actual_decoder_layers,
        cell_type=best["cell_type"],
        dropout=best["dropout"],
        pad_index=src_vocab.stoi["<pad>"],
        sos_index=tgt_vocab.stoi["<sos>"],
        eos_index=tgt_vocab.stoi["<eos>"],
        embedding_method=best["embedding_method"],
        **extra_kwargs
    )

    # 4) Instantiate and load
    model = Seq2SeqAttention(cfg).to(device)
    model.load_state_dict(state_dict)
    model.eval()

    return model, src_vocab, tgt_vocab

def record_attention(
    model: Seq2SeqAttention,
    src_ids: list[int],
    src_vocab,
    tgt_vocab,
    device: torch.device,
    max_len: int = 50
) -> tuple[list[str], list[str], list[list[float]]]:
    """
    Greedy-decode `src_ids` with the model, capturing attention weights at each step.

    Args:
        model:     Attention model exposing ``encoder``, ``decoder`` and ``cfg``
                   (``pad_index``, ``sos_index``, ``eos_index``, ``decoder_layers``).
        src_ids:   Source token ids of a single example.
        src_vocab: Source vocabulary; its ``itos`` maps ids to characters.
        tgt_vocab: Target vocabulary; its ``itos`` maps ids to characters.
        device:    Device on which the input tensors are created.
        max_len:   Maximum number of decoder steps.

    Returns:
      (source_chars, predicted_chars, attention_matrix)
    where
      attention_matrix[t][i] = score attending to src position i when predicting
                              output char at position t

    The matrix has exactly one row per entry of `predicted_chars`: the step
    that emits <eos> is not recorded, because it has no output character to
    label it with.
    """
    # Prepare tensors (batch of one)
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
    src_len    = torch.tensor([len(src_ids)], device=device)

    predicted_ids: list[int] = []
    attentions: list[list[float]] = []

    # Pure inference: no autograd graph is needed for the recorded weights.
    with torch.no_grad():
        # Encoder: get outputs and initial hidden
        enc_outputs, enc_hidden = model.encoder(src_tensor, src_len)
        enc_mask = (src_tensor != model.cfg.pad_index).to(device)
        dec_hidden = _align_hidden_state(enc_hidden, model.cfg.decoder_layers)

        # Initialize decoder input (<sos>)
        dec_input = torch.full(
            (1, 1), model.cfg.sos_index, dtype=torch.long, device=device
        )

        # Step through decoder, capture alignments
        for _ in range(max_len):
            logits, dec_hidden, alignments = model.decoder(
                dec_input, dec_hidden, enc_outputs, enc_mask
            )

            # Greedy choice; it is also the decoder input of the next step.
            dec_input = logits.argmax(-1)  # shape (1,1)
            next_id = dec_input.item()
            if next_id == model.cfg.eos_index:
                break

            predicted_ids.append(next_id)
            # Recorded only for emitted characters so rows and predictions
            # stay aligned.
            # assumes alignments is (1, src_len), giving one flat row — TODO confirm
            attentions.append(alignments.squeeze(0).tolist())

    # Convert indices back to characters
    source_chars    = [src_vocab.itos[i] for i in src_ids]
    predicted_chars = [tgt_vocab.itos[i] for i in predicted_ids]
    return source_chars, predicted_chars, attentions

def make_plotly_figure(
    source_chars:    list[str],
    predicted_chars: list[str],
    attentions:      list[list[float]],
    title:           str
) -> go.Figure:
    """
    Builds a Plotly heatmap figure of attention weights with the source
    characters along x and the predicted characters along y.

    Cells are placed on integer positions and the characters are attached as
    tick labels. Passing the characters directly as x/y would make the axes
    categorical, where a character that occurs more than once does not get its
    own labelled column/row.

    Args:
        source_chars:    Labels of the columns, in source order.
        predicted_chars: Labels of the rows, in output order.
        attentions:      attentions[t][i] is the weight on source position i
                         at output step t. Rows without a matching entry in
                         `predicted_chars` (and vice versa) are not drawn.
        title:           Figure title.

    Returns:
        A go.Figure containing a single heatmap trace. Hover text shows the
        input character, the output character and the exact weight.
    """
    # Draw only rows that have both weights and a label.
    n_rows = min(len(attentions), len(predicted_chars))
    row_labels = list(predicted_chars[:n_rows])
    col_labels = list(source_chars)
    weights = list(attentions[:n_rows])

    col_positions = list(range(len(col_labels)))
    row_positions = list(range(n_rows))

    # Per-cell [input char, output char] pairs for the hover text; x/y now
    # hold positions, so the characters have to travel with the cells.
    hover_chars = [
        [[src_char, out_char] for src_char in col_labels]
        for out_char in row_labels
    ]

    heatmap = go.Heatmap(
        z=weights,
        x=col_positions,
        y=row_positions,
        customdata=hover_chars,
        colorscale="Blues",
        zmin=0, zmax=1,
        colorbar=dict(
            title=dict(
                text="Attention",
                side="right",
                font=dict(size=12)
            ),
            lenmode="fraction",
            len=0.6,
            tickfont=dict(size=10)
        ),
        hovertemplate=(
            "input: %{customdata[0]}<br>"
            "output: %{customdata[1]}<br>"
            "weight: %{z:.3f}"
            "<extra></extra>"
        )
    )

    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,           # center
            xanchor="center",
            yanchor="top",
            font=dict(size=16)
        ),
        margin=dict(l=80, r=50, t=100, b=80),
        width=600,
        height=450,
        xaxis=dict(
            title=dict(
                text="Source characters",
                font=dict(size=14)
            ),
            # one tick per column, labelled with its (possibly repeated) character
            tickmode="array",
            tickvals=col_positions,
            ticktext=col_labels,
            tickangle=-45,
            tickfont=dict(size=12),
            side="top",
            automargin=True
        ),
        yaxis=dict(
            title=dict(
                text="Predicted characters",
                font=dict(size=14)
            ),
            # one tick per row, labelled with its (possibly repeated) character
            tickmode="array",
            tickvals=row_positions,
            ticktext=row_labels,
            tickfont=dict(size=12),
            automargin=True
        )
    )
    return fig


def _render_connectivity_page(figures: list[go.Figure]) -> str:
    """Join the figures into one standalone HTML document and return it as a string."""
    # The plotly.js <script> tag is emitted with the first figure only; the
    # later figures sit in the same document and reuse that one load.
    html_snippets = [
        fig.to_html(
            full_html=False,
            include_plotlyjs="cdn" if position == 0 else False
        )
        for position, fig in enumerate(figures)
    ]
    html_body = "\n<hr>\n".join(html_snippets)
    return f"""\
<html>
  <head><meta charset="utf-8"/><title>Attention Connectivity</title></head>
  <body>
    <h1>Q6: Attention Connectivity Visualizations</h1>
    {html_body}
  </body>
</html>
"""


def main():
    """
    Entry point of the Q6 connectivity visualizer.

    Reads the command line via parse_args(), reloads the model, samples test
    examples with a fixed seed (42), decodes each one while recording
    attention, writes all heatmaps into one standalone HTML file and, when a
    WandB project was given, logs that file to WandB.

    If --n_examples is larger than the test set (or negative) it is clamped to
    the valid range instead of failing. A WandB run that was started is always
    finished, and is marked as failed when an exception escapes.
    """
    args = parse_args()

    # ─── WandB initialization (optional) ─────────────────────────
    use_wandb = args.wandb_project is not None
    if use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            tags=[args.wandb_run_tag] if args.wandb_run_tag else None
        )

    succeeded = False
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1) Load model and vocab built from the true training split
        model, src_vocab, tgt_vocab = load_model_and_vocab(
            args.checkpoint, args.train_tsv, device
        )

        # 2) Load test set (for sampling)
        test_ds = DakshinaLexicon(args.test_tsv, src_vocab, tgt_vocab)

        # 3) Randomly pick N examples; random.sample rejects a sample size
        #    outside [0, population size], so clamp the request first.
        n_examples = max(0, min(args.n_examples, len(test_ds)))
        if n_examples != args.n_examples:
            print(
                f"Requested {args.n_examples} examples but the test set has "
                f"{len(test_ds)}; visualizing {n_examples}"
            )
        random.seed(42)
        selected_indices = random.sample(range(len(test_ds)), k=n_examples)

        # 4) For each example, record attention and make a Plotly figure
        figures: list[go.Figure] = []
        for idx_rank, idx in enumerate(selected_indices, start=1):
            src_ids, _ = test_ds[idx]
            source_chars, predicted_chars, attn_matrix = record_attention(
                model, src_ids, src_vocab, tgt_vocab, device
            )
            title = (
                f'Example {idx_rank}: “{"".join(source_chars)}” → '
                f'“{"".join(predicted_chars)}”'
            )
            figures.append(
                make_plotly_figure(source_chars, predicted_chars, attn_matrix, title)
            )

        # 5) Assemble all figures into one standalone HTML
        full_html = _render_connectivity_page(figures)

        # 6) Write out the HTML (creating the target directory if needed)
        out_path = Path(args.output_html)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(full_html, encoding="utf-8")
        print(f"Wrote interactive connectivity HTML to {out_path}")

        # 7) Log to WandB (if requested)
        if use_wandb:
            # wandb.Html will render your HTML in the WandB UI
            wandb.log({
                "connectivity": wandb.Html(str(out_path))
            })
        succeeded = True
    finally:
        # Close the run on every exit path so a crash does not leave it open.
        if use_wandb:
            wandb.finish(exit_code=0 if succeeded else 1)

import sys

# Notebook stand-in for running solution_6.py from a shell: the argument
# vector is rebuilt by hand (program name first, then one flag/value pair per
# line) so that parse_args() sees the same options a command line would give it.
sys.argv = ["solution_6.py"]
sys.argv += ["--checkpoint", "../checkpoints/best_attention.pt"]
sys.argv += ["--train_tsv", "../lexicons/hi.translit.sampled.train.tsv"]
sys.argv += ["--test_tsv", "../lexicons/hi.translit.sampled.test.tsv"]
sys.argv += ["--n_examples", "5"]
sys.argv += ["--output_html", "../connectivity.html"]
sys.argv += ["--wandb_project", "transliteration"]
sys.argv += ["--wandb_run_name", "solution_6_run"]
sys.argv += ["--wandb_run_tag", "solution_6"]

# Run the Q6 entry point with the simulated command line installed above.
main()