In [1]:
import wandb
# SECURITY: never store a W&B API key in the notebook itself. The key that was
# previously hard-coded on this line is exposed to anyone who can read the file
# and should be revoked. The key is now taken from the WANDB_API_KEY environment
# variable; when that is unset, key=None lets wandb fall back to its own
# credential lookup (existing netrc entry or an interactive prompt).
import os
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mma23m015[0m ([33miitm-ma23m015[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:

try:
    import torch, wandb, pandas as pd, numpy as np
except ImportError:
    !pip install -q torch==2.2.1+cpu torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
    !pip install -q wandb pandas numpy
    import torch, wandb, pandas as pd, numpy as np

import random, os, math, json, shutil
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


# Directory holding the three transliteration splits; each split lives in a
# file named ta.translit.sampled.<split>.tsv.
DATA_DIR = Path("/kaggle/input/tamil-translit")
TRAIN_F, DEV_F, TEST_F = (
    DATA_DIR / f"ta.translit.sampled.{split}.tsv"
    for split in ("train", "dev", "test")
)

def read_pairs(path):
    """Load a header-less, tab-separated lexicon as ``(source, target)`` pairs.

    The columns are read, in file order, as ``target``, ``source`` and
    ``freq``; ``freq`` is not used.  Rows whose source or target field is
    empty are dropped.

    Every field is read as text and only the empty string counts as missing.
    With pandas' default NA parsing, legitimate words such as "nan", "null"
    or "NA" would be converted to NaN and then silently removed by ``dropna``.
    Quote characters are treated as ordinary text so a stray ``"`` inside a
    word cannot swallow the lines that follow it.

    Args:
        path: Location of the TSV file (``str`` or ``pathlib.Path``).

    Returns:
        list[tuple[str, str]]: ``(source, target)`` pairs in file order.
    """
    import csv  # local import: only needed for the QUOTE_NONE constant

    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=["target", "source", "freq"],
        dtype=str,               # keep every field as text
        keep_default_na=False,   # do not treat "nan"/"null"/"NA"/... as missing
        na_values=[""],          # only a truly empty field is missing
        quoting=csv.QUOTE_NONE,  # raw TSV: quotes are data, not delimiters
    )
    df = df.dropna(subset=["source", "target"])
    return [(s, t) for t, s in zip(df["target"].astype(str), df["source"].astype(str))]

# Read the three splits eagerly, in train / dev / test order.
train_pairs, dev_pairs, test_pairs = [
    read_pairs(split_file) for split_file in (TRAIN_F, DEV_F, TEST_F)
]


class CharVocab:
    """Character-level vocabulary with four reserved symbols.

    Indices 0-3 are ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``; every
    distinct character found in ``seqs`` follows in sorted order.
    """

    def __init__(self, seqs):
        """Build both lookup tables from an iterable of strings."""
        reserved = ['<pad>', '<sos>', '<eos>', '<unk>']
        alphabet = sorted(set("".join(seqs)))
        self.idx2char = reserved + alphabet
        self.char2idx = {symbol: index for index, symbol in enumerate(self.idx2char)}

    def encode(self, txt):
        """Map each character of ``txt`` to its index (3, i.e. ``<unk>``, if unseen)."""
        lookup = self.char2idx.get
        return [lookup(ch, 3) for ch in txt]

    def decode(self, idxs):
        """Turn indices back into text.

        Decoding stops at the first ``<eos>`` (2); ``<pad>`` (0) and ``<sos>``
        (1) are skipped.  Any other index, including ``<unk>``, is emitted as
        its table entry.
        """
        pieces = []
        for index in idxs:
            if index == 2:
                break
            if index == 0 or index == 1:
                continue
            pieces.append(self.idx2char[index])
        return "".join(pieces)

    def __len__(self):
        """Number of symbols, reserved ones included."""
        return len(self.idx2char)

# Vocabularies are built from the training split only; characters that appear
# solely in dev/test are encoded as <unk> by CharVocab.encode.
src_vocab = CharVocab(source for source, _ in train_pairs)
tgt_vocab = CharVocab(target for _, target in train_pairs)


class TransliterationDS(Dataset):
    """Dataset over ``(source, target)`` string pairs yielding index tensors.

    Encoding uses the module-level ``src_vocab`` and ``tgt_vocab``.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(src, tgt)`` long tensors for pair ``idx``.

        The target is wrapped as ``<sos>`` (1) + characters + ``<eos>`` (2);
        the source carries no boundary markers.
        """
        source, target = self.pairs[idx]
        src_ids = src_vocab.encode(source)
        tgt_ids = [1, *tgt_vocab.encode(target), 2]
        return (torch.tensor(src_ids, dtype=torch.long),
                torch.tensor(tgt_ids, dtype=torch.long))
def collate_fn(batch):
    """Stack a list of ``(src, tgt)`` tensors into two right-padded batches.

    Both outputs are batch-first and padded with 0 (the ``<pad>`` index) up to
    the longest sequence in the batch.
    """
    src_seqs, tgt_seqs = zip(*batch)
    return tuple(
        pad_sequence(seqs, batch_first=True, padding_value=0)
        for seqs in (src_seqs, tgt_seqs)
    )

# One dataset wrapper per split, in train / dev / test order.
train_ds, dev_ds, test_ds = (
    TransliterationDS(split_pairs)
    for split_pairs in (train_pairs, dev_pairs, test_pairs)
)


class Attention(nn.Module):
    """Additive attention: ``score = v · tanh(W_enc·enc + W_dec·dec)``.

    The class previously carried a second, empty ``__init__(hid, attn)`` that
    was immediately shadowed by this one; the dead definition is removed.

    NOTE(review): scores are not masked, so padded source positions also
    receive attention weight — confirm whether that is intended.
    """

    def __init__(self, hid_dim, attn_dim):
        """
        Args:
            hid_dim: Size of the encoder outputs and of the decoder hidden state.
            attn_dim: Size of the intermediate attention space.
        """
        super().__init__()
        self.W_enc = nn.Linear(hid_dim, attn_dim, bias=False)
        self.W_dec = nn.Linear(hid_dim, attn_dim, bias=False)
        self.v     = nn.Linear(attn_dim, 1,      bias=False)

    def forward(self, enc_out, dec_h):
        """
        Args:
            enc_out: Encoder outputs, batch-first (time on dim 1).
            dec_h: Current decoder hidden state, one vector per batch row.

        Returns:
            ``(ctx, alpha)``: the attention-weighted sum of ``enc_out`` ([B,H])
            and the weights themselves ([B,T], each row summing to 1).
        """
        T = enc_out.size(1)
        # Repeat the projected decoder state along the time axis.
        dec = self.W_dec(dec_h).unsqueeze(1).expand(-1, T, -1)
        e   = torch.tanh(self.W_enc(enc_out) + dec)
        scores = self.v(e).squeeze(-1)             # [B,T]
        alpha  = torch.softmax(scores, dim=1)
        ctx    = torch.bmm(alpha.unsqueeze(1), enc_out).squeeze(1)  # [B,H]
        return ctx, alpha

class Encoder(nn.Module):
    """Embedding + dropout + single-layer GRU/LSTM over source indices."""

    def __init__(self, vocab, emb, hid, drop, cell):
        """
        Args:
            vocab: Number of source symbols (index 0 is the padding symbol).
            emb: Embedding size.
            hid: Recurrent hidden size.
            drop: Dropout probability applied to the embeddings.
            cell: ``'GRU'`` or ``'LSTM'``; anything else raises ``KeyError``.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=0)
        recurrent_layers = {'GRU': nn.GRU, 'LSTM': nn.LSTM}
        self.rnn = recurrent_layers[cell](emb, hid, batch_first=True)
        self.drop = nn.Dropout(drop)
        self.cell = cell

    def forward(self, src):
        """Return ``(outputs, final_state)`` exactly as produced by the RNN."""
        embedded = self.emb(src)
        return self.rnn(self.drop(embedded))

class Decoder(nn.Module):
    """Single-step attention decoder.

    Each call consumes one token per batch row, attends over the encoder
    outputs using the current hidden state, and predicts the next token.
    """

    def __init__(self, vocab, emb, hid, drop, cell, attn_dim):
        """
        Args:
            vocab: Number of target symbols (index 0 is the padding symbol).
            emb: Embedding size.
            hid: Recurrent hidden size (also the context-vector size).
            drop: Dropout probability applied to the token embedding.
            cell: ``'GRU'`` or ``'LSTM'``.
            attn_dim: Size of the attention projection space.
        """
        super().__init__()
        self.emb  = nn.Embedding(vocab, emb, padding_idx=0)
        self.attn = Attention(hid, attn_dim)
        recurrent_layers = {'GRU': nn.GRU, 'LSTM': nn.LSTM}
        # The RNN input is the token embedding concatenated with the context.
        self.rnn  = recurrent_layers[cell](emb + hid, hid, batch_first=True)
        # The output layer sees the RNN output and the context side by side.
        self.fc   = nn.Linear(hid * 2, vocab)
        self.drop = nn.Dropout(drop)
        self.cell = cell

    def forward(self, tok, hid, enc_out):
        """Run one decoding step.

        Args:
            tok: Previous token index for every batch row.
            hid: Recurrent state (an ``(h, c)`` tuple when the cell is LSTM).
            enc_out: Encoder outputs to attend over.

        Returns:
            ``(pred, hid, alpha)``: vocabulary logits, the updated recurrent
            state and the attention weights of this step.
        """
        step_emb = self.drop(self.emb(tok).unsqueeze(1))   # [B,1,E]
        # For an LSTM only the hidden part h of (h, c) drives attention.
        h_state = hid[0] if self.cell == 'LSTM' else hid
        ctx, alpha = self.attn(enc_out, h_state[-1])       # query = last layer, [B,H]
        step_in = torch.cat([step_emb, ctx.unsqueeze(1)], dim=2)
        out, hid = self.rnn(step_in, hid)
        logits = self.fc(torch.cat([out.squeeze(1), ctx], dim=1))
        return logits, hid, alpha

class Seq2SeqAttn(nn.Module):
    """Encoder/decoder transliteration model with additive attention.

    Vocabulary sizes and the ``<sos>``/``<eos>`` indices come from the
    module-level ``src_vocab`` and ``tgt_vocab``.
    """

    def __init__(self, cfg, device):
        """
        Args:
            cfg: Object exposing ``emb_dim``, ``hid_dim``, ``dropout``,
                ``cell`` and ``attn_dim`` as attributes.
            device: Stored as ``self.device`` for callers; the model itself
                is not moved here.
        """
        super().__init__()
        self.device = device
        self.encoder = Encoder(len(src_vocab), cfg.emb_dim, cfg.hid_dim,
                               cfg.dropout, cfg.cell)
        self.decoder = Decoder(len(tgt_vocab), cfg.emb_dim, cfg.hid_dim,
                               cfg.dropout, cfg.cell, cfg.attn_dim)
        self.tgt_size = len(tgt_vocab)
        self.cell = cfg.cell

    def forward(self, src, tgt, teacher=0.5):
        """Decode ``tgt.size(1) - 1`` steps and return the logits.

        Args:
            src: Padded source indices, [B, T_src].
            tgt: Padded target indices starting with ``<sos>``, [B, T].
            teacher: Probability, drawn once per step for the whole batch, of
                feeding the gold token instead of the model's own argmax.

        Returns:
            Tensor [B, T, vocab]; position 0 is left as zeros because nothing
            is predicted for ``<sos>``.
        """
        B, T = tgt.shape
        # Allocate next to the inputs instead of trusting the constructor
        # argument, so a model moved after construction keeps working.
        outs = torch.zeros(B, T, self.tgt_size, device=src.device)
        enc_out, hid = self.encoder(src)
        tok = tgt[:, 0]
        for t in range(1, T):
            pred, hid, _ = self.decoder(tok, hid, enc_out)
            outs[:, t] = pred
            tok = tgt[:, t] if random.random() < teacher else pred.argmax(1)
        return outs

    @torch.no_grad()
    def greedy(self, src, max_len=50):
        """Greedy-decode a batch and return the predicted strings.

        Decoding stops as soon as every row has produced ``<eos>`` at least
        once (or after ``max_len`` steps).  Tokens after a row's first
        ``<eos>`` are ignored by ``tgt_vocab.decode``, so stopping early does
        not change the returned text.

        Side effect: switches the model to eval mode and leaves it there.
        """
        self.eval()
        enc_out, hid = self.encoder(src)
        eos = tgt_vocab.char2idx['<eos>']
        tok = torch.full((src.size(0),), tgt_vocab.char2idx['<sos>'],
                         dtype=torch.long, device=src.device)
        finished = torch.zeros_like(tok, dtype=torch.bool)
        preds = []
        for _ in range(max_len):
            pred, hid, _ = self.decoder(tok, hid, enc_out)
            tok = pred.argmax(1)
            preds.append(tok)
            finished |= tok == eos
            if finished.all():
                break
        if not preds:
            # max_len <= 0: nothing was decoded, and torch.stack([]) would raise.
            return [""] * src.size(0)
        preds = torch.stack(preds, 1)
        texts = [tgt_vocab.decode(row.cpu().numpy()) for row in preds]
        return texts


def train_epoch(model, loader, opt, crit, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The model is put in train mode and called with ``teacher=0.5``.  The first
    time step (the ``<sos>`` slot) is excluded from the loss.
    """
    model.train()
    running = 0
    for src, tgt in loader:
        src = src.to(device)
        tgt = tgt.to(device)
        opt.zero_grad()
        logits = model(src, tgt, teacher=0.5)
        flat_logits = logits[:, 1:].reshape(-1, model.tgt_size)
        flat_gold = tgt[:, 1:].reshape(-1)
        loss = crit(flat_logits, flat_gold)
        loss.backward()
        opt.step()
        running += loss.item()
    return running / len(loader)

@torch.no_grad()
def word_acc(model, loader, device):
    """Fraction of items whose decoded prediction equals the decoded target.

    The model runs in eval mode with ``teacher=0`` and is unrolled for as many
    steps as the padded target has.  Both sequences are passed through
    ``tgt_vocab.decode`` before comparison.
    """
    model.eval()
    correct, total = 0, 0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)
        pred_ids = model(src, tgt, teacher=0).argmax(2).cpu()
        gold_ids = tgt.cpu()
        for pred_row, gold_row in zip(pred_ids, gold_ids):
            hypothesis = tgt_vocab.decode(pred_row.numpy())
            reference = tgt_vocab.decode(gold_row.numpy())
            if hypothesis == reference:
                correct += 1
            total += 1
    return correct / total

@torch.no_grad()
def eval_loss(model, loader, crit, device):
    """Mean batch loss over ``loader`` without teacher forcing.

    Gradient tracking is disabled: evaluation never calls ``backward``, so
    building the autograd graph only wasted memory and time.  The model is
    switched to eval mode; the first time step (the ``<sos>`` slot) is
    excluded from the loss.

    Args:
        model: Callable as ``model(src, tgt, teacher=0)`` and exposing ``tgt_size``.
        loader: Iterable of ``(src, tgt)`` batches; must support ``len``.
        crit: Loss taking flattened logits and flattened targets.
        device: Device the batches are moved to.

    Returns:
        float: Sum of per-batch losses divided by the number of batches.
    """
    model.eval(); tot = 0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)
        out = model(src, tgt, teacher=0)
        loss = crit(out[:, 1:].reshape(-1, model.tgt_size),
                    tgt[:, 1:].reshape(-1))
        tot += loss.item()
    return tot / len(loader)


# Bayesian hyperparameter search that maximises the "validation_accuracy"
# metric; every parameter is drawn from a fixed list of choices.
sweep_cfg = dict(
    method="bayes",
    name="attn-sweep",
    metric=dict(name="validation_accuracy", goal="maximize"),
    parameters=dict(
        emb_dim=dict(values=[64, 128, 256]),
        hid_dim=dict(values=[64, 128, 256]),
        attn_dim=dict(values=[128, 256]),
        dropout=dict(values=[0.1, 0.3]),
        cell=dict(values=["GRU", "LSTM"]),
    ),
)
# Registering the sweep returns the id that the agent is later started with.
sweep_id = wandb.sweep(sweep_cfg, project="MA23M015_DL_Assignment3")

def sweep_run():
    """Train and evaluate one hyperparameter configuration of the sweep.

    Inside a W&B run this function:
      1. builds the train/dev/test loaders (batch size 64),
      2. trains for 10 epochs, logging loss and word accuracy each epoch,
      3. keeps the checkpoint with the highest validation accuracy,
      4. reloads that checkpoint and logs its test accuracy,
      5. writes greedy test predictions to
         ``predictions_attention/test_predictions.tsv`` and uploads the file.

    Changes from the earlier version:
      * the checkpoint file is named after the run id, so a run can never
        load weights written by a different run or architecture;
      * the first epoch is always checkpointed (``best`` starts below any
        possible accuracy), so the reload cannot hit a missing file;
      * the progress line prints ``val_acc:<value>`` (the colon was missing);
      * ``wandb.config`` is passed to the model directly, since it already
        exposes the hyperparameters as attributes.
    """
    with wandb.init() as run:
        cfg = wandb.config
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dl_train = DataLoader(train_ds, batch_size=64, shuffle=True,  collate_fn=collate_fn)
        dl_dev   = DataLoader(dev_ds,   batch_size=64, shuffle=False, collate_fn=collate_fn)
        dl_test  = DataLoader(test_ds,  batch_size=64, shuffle=False, collate_fn=collate_fn)

        model = Seq2SeqAttn(cfg, device).to(device)
        opt = torch.optim.Adam(model.parameters())
        # Index 0 is <pad>; padded positions do not contribute to the loss.
        crit = nn.CrossEntropyLoss(ignore_index=0)

        ckpt_path = f"best_model_{run.id}.pt"
        best = -1.0  # below any real accuracy, so epoch 1 always checkpoints
        for ep in range(1, 11):
            tr_loss = train_epoch(model, dl_train, opt, crit, device)
            tr_acc = word_acc(model, dl_train, device)
            val_loss = eval_loss(model, dl_dev, crit, device)
            val_acc = word_acc(model, dl_dev, device)
            wandb.log({"epoch": ep, "train_loss": tr_loss, "train_acc": tr_acc,
                       "validation_loss": val_loss, "validation_accuracy": val_acc})
            print(f"epoch:{ep:02d} train_loss:{tr_loss:.3f} train_acc:{tr_acc:.3f} "
                  f"val_loss:{val_loss:.3f} val_acc:{val_acc:.3f}")
            if val_acc > best:
                best = val_acc
                torch.save(model.state_dict(), ckpt_path)

        # Final test evaluation using the best model
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        test_acc = word_acc(model, dl_test, device)
        wandb.log({"test_accuracy": test_acc})

        print(f"BEST val_acc {best*100:.2f}%  →  test_acc {test_acc*100:.2f}%")

        # save predictions (dl_test is not shuffled, so rows line up with test_pairs)
        preds = []
        for src, _ in dl_test:
            src = src.to(device)
            txt = model.greedy(src)
            preds.extend(txt)
        out_dir = Path("predictions_attention"); out_dir.mkdir(exist_ok=True)
        pd.DataFrame({"source": [s for s, _ in test_pairs],
                      "prediction": preds}
                     ).to_csv(out_dir/"test_predictions.tsv", sep="\t", index=False)
        wandb.save(str(out_dir/"test_predictions.tsv"))


# Let the agent execute sweep_run for two configurations of the registered
# sweep, then close whatever W&B run is still active.
wandb.agent(
    sweep_id,
    function=sweep_run,
    count=2,
)
wandb.finish()


Create sweep with ID: bhtzulgd
Sweep URL: https://wandb.ai/iitm-ma23m015/MA23M015_DL_Assignment3/sweeps/bhtzulgd


[34m[1mwandb[0m: Agent Starting Run: i127tr8q with config:
[34m[1mwandb[0m: 	attn_dim: 256
[34m[1mwandb[0m: 	cell: GRU
[34m[1mwandb[0m: 	dropout: 0.1
[34m[1mwandb[0m: 	emb_dim: 64
[34m[1mwandb[0m: 	hid_dim: 128


epoch:01 train_loss:0.948 train_acc:0.427 val_loss:0.876 val_acc0.411
epoch:02 train_loss:0.467 train_acc:0.545 val_loss:0.738 val_acc0.511
epoch:03 train_loss:0.399 train_acc:0.568 val_loss:0.730 val_acc0.506
epoch:04 train_loss:0.360 train_acc:0.616 val_loss:0.716 val_acc0.540
epoch:05 train_loss:0.328 train_acc:0.639 val_loss:0.678 val_acc0.560
epoch:06 train_loss:0.304 train_acc:0.670 val_loss:0.690 val_acc0.564
epoch:07 train_loss:0.289 train_acc:0.684 val_loss:0.694 val_acc0.566
epoch:08 train_loss:0.274 train_acc:0.683 val_loss:0.730 val_acc0.559
epoch:09 train_loss:0.262 train_acc:0.713 val_loss:0.718 val_acc0.572
epoch:10 train_loss:0.251 train_acc:0.726 val_loss:0.699 val_acc0.576
BEST val_acc 57.61%  →  test_acc 55.32%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test_accuracy,▁
train_acc,▁▄▄▅▆▇▇▇██
train_loss,█▃▂▂▂▂▁▁▁▁
validation_accuracy,▁▅▅▆▇██▇██
validation_loss,█▃▃▂▁▁▂▃▂▂

0,1
epoch,10.0
test_accuracy,0.55318
train_acc,0.72579
train_loss,0.2514
validation_accuracy,0.57609
validation_loss,0.69933


[34m[1mwandb[0m: Agent Starting Run: 7re6hdui with config:
[34m[1mwandb[0m: 	attn_dim: 256
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	emb_dim: 256
[34m[1mwandb[0m: 	hid_dim: 64


epoch:01 train_loss:1.297 train_acc:0.296 val_loss:1.110 val_acc0.282
epoch:02 train_loss:0.652 train_acc:0.400 val_loss:0.941 val_acc0.366
epoch:03 train_loss:0.542 train_acc:0.468 val_loss:0.823 val_acc0.449
epoch:04 train_loss:0.488 train_acc:0.506 val_loss:0.801 val_acc0.475
epoch:05 train_loss:0.454 train_acc:0.524 val_loss:0.760 val_acc0.491
epoch:06 train_loss:0.426 train_acc:0.549 val_loss:0.731 val_acc0.505
epoch:07 train_loss:0.408 train_acc:0.570 val_loss:0.732 val_acc0.516
epoch:08 train_loss:0.394 train_acc:0.576 val_loss:0.737 val_acc0.516
epoch:09 train_loss:0.381 train_acc:0.592 val_loss:0.738 val_acc0.519
epoch:10 train_loss:0.368 train_acc:0.601 val_loss:0.709 val_acc0.533
BEST val_acc 53.27%  →  test_acc 51.50%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
test_accuracy,▁
train_acc,▁▃▅▆▆▇▇▇██
train_loss,█▃▂▂▂▁▁▁▁▁
validation_accuracy,▁▃▆▆▇▇████
validation_loss,█▅▃▃▂▁▁▁▁▁

0,1
epoch,10.0
test_accuracy,0.51501
train_acc,0.60086
train_loss,0.36826
validation_accuracy,0.53274
validation_loss,0.70939


In [6]:

import random, torch, numpy as np, pandas as pd, seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from io import BytesIO
from PIL import Image
import wandb


# One seed for every random number generator the notebook touches:
# Python's `random`, NumPy's legacy global generator, and PyTorch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# Optional font: when the file is present it is registered with matplotlib and
# made the default family; otherwise matplotlib's defaults stay in effect.
font_path = Path("/kaggle/input/notosans-tamil/NotoSansTamil-VariableFont_wdth,wght.ttf")
if font_path.exists():
    fm.fontManager.addfont(str(font_path))
    plt.rcParams.update({"font.family": "Noto Sans Tamil"})



# Directory holding the three transliteration splits; each split lives in a
# file named ta.translit.sampled.<split>.tsv.
DATA_DIR = Path("/kaggle/input/tamil-translit")
TRAIN_F, DEV_F, TEST_F = (
    DATA_DIR / f"ta.translit.sampled.{split}.tsv"
    for split in ("train", "dev", "test")
)


def read_pairs(path):
    """Load a header-less, tab-separated lexicon as ``(source, target)`` pairs.

    The columns are read, in file order, as ``target``, ``source`` and
    ``freq``; ``freq`` is not used.  Rows whose source or target field is
    empty are dropped.

    Every field is read as text and only the empty string counts as missing.
    With pandas' default NA parsing, legitimate words such as "nan", "null"
    or "NA" would be converted to NaN and then silently removed by ``dropna``.
    Quote characters are treated as ordinary text so a stray ``"`` inside a
    word cannot swallow the lines that follow it.

    Args:
        path: Location of the TSV file (``str`` or ``pathlib.Path``).

    Returns:
        list[tuple[str, str]]: ``(source, target)`` pairs in file order.
    """
    import csv  # local import: only needed for the QUOTE_NONE constant

    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=["target", "source", "freq"],
        dtype=str,               # keep every field as text
        keep_default_na=False,   # do not treat "nan"/"null"/"NA"/... as missing
        na_values=[""],          # only a truly empty field is missing
        quoting=csv.QUOTE_NONE,  # raw TSV: quotes are data, not delimiters
    )
    df = df.dropna(subset=["source", "target"])
    return [(s, t) for t, s in zip(df["target"].astype(str),
                                   df["source"].astype(str))]

# Read the three splits eagerly, in train / dev / test order.
train_pairs, dev_pairs, test_pairs = [
    read_pairs(split_file) for split_file in (TRAIN_F, DEV_F, TEST_F)
]

class CharVocab:
    """Character-level vocabulary with four reserved symbols.

    Indices 0-3 are ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``; every
    distinct character found in ``seqs`` follows in sorted order.
    """

    def __init__(self, seqs):
        """Build both lookup tables from an iterable of strings."""
        symbols = ["<pad>", "<sos>", "<eos>", "<unk>"]
        symbols.extend(sorted(set("".join(seqs))))
        self.char2idx = dict(zip(symbols, range(len(symbols))))
        self.idx2char = symbols

    def encode(self, txt):
        """Map each character of ``txt`` to its index (3, i.e. ``<unk>``, if unseen)."""
        table = self.char2idx
        return [table.get(ch, 3) for ch in txt]

    def decode(self, idxs):
        """Turn indices back into text.

        Decoding stops at the first ``<eos>`` (2); ``<pad>`` (0) and ``<sos>``
        (1) are skipped.  Any other index, including ``<unk>``, is emitted as
        its table entry.
        """
        kept = []
        for index in idxs:
            if index == 2:
                break
            if index == 0 or index == 1:
                continue
            kept.append(self.idx2char[index])
        return "".join(kept)

    def __len__(self):
        """Number of symbols, reserved ones included."""
        return len(self.idx2char)

# Vocabularies are built from the training split only; characters that appear
# solely in dev/test are encoded as <unk> by CharVocab.encode.
src_vocab = CharVocab(source for source, _ in train_pairs)
tgt_vocab = CharVocab(target for _, target in train_pairs)

class TransliterationDS(Dataset):
    """Dataset over ``(source, target)`` string pairs.

    Items carry both the encoded tensors and the raw strings, so plotting code
    can label axes without decoding.  Encoding uses the module-level
    ``src_vocab`` and ``tgt_vocab``.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(src, tgt, source_str, target_str)`` for pair ``idx``.

        The target tensor is ``<sos>`` (1) + characters + ``<eos>`` (2); the
        source tensor carries no boundary markers.
        """
        source, target = self.pairs[idx]
        src_ids = src_vocab.encode(source)
        tgt_ids = [1, *tgt_vocab.encode(target), 2]
        return (
            torch.tensor(src_ids, dtype=torch.long),
            torch.tensor(tgt_ids, dtype=torch.long),
            source,
            target,
        )

def collate_fn(batch):
    """Combine dataset items into one batch.

    Index tensors are right-padded with 0 (the ``<pad>`` index) to the longest
    sequence in the batch, batch-first; the raw source/target strings are
    passed through as tuples.
    """
    def _pad(seqs):
        return pad_sequence(seqs, batch_first=True, padding_value=0)

    src_seqs, tgt_seqs, src_strs, tgt_strs = zip(*batch)
    return _pad(src_seqs), _pad(tgt_seqs), src_strs, tgt_strs


class Attention(nn.Module):
    """Additive attention: ``score = v · tanh(W_enc·enc + W_dec·dec)``."""

    def __init__(self, hid_dim, attn_dim):
        """
        Args:
            hid_dim: Size of the encoder outputs and of the decoder hidden state.
            attn_dim: Size of the intermediate attention space.
        """
        super().__init__()
        self.W_enc = nn.Linear(hid_dim, attn_dim, bias=False)
        self.W_dec = nn.Linear(hid_dim, attn_dim, bias=False)
        self.v     = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, enc_out, dec_h):
        """Weight the encoder outputs by their relevance to ``dec_h``.

        Args:
            enc_out: Encoder outputs, [B,T,H].
            dec_h: Decoder hidden state, [B,H].

        Returns:
            ``(ctx, alpha)``: context vectors [B,H] and attention weights
            [B,T]; each row of ``alpha`` sums to 1.
        """
        steps = enc_out.size(1)
        query = self.W_dec(dec_h).unsqueeze(1).expand(-1, steps, -1)  # [B,T,A]
        keys = self.W_enc(enc_out)                                    # [B,T,A]
        energy = torch.tanh(keys + query)
        alpha = torch.softmax(self.v(energy).squeeze(-1), dim=1)      # [B,T]
        weights = alpha.unsqueeze(1)                                  # [B,1,T]
        ctx = torch.bmm(weights, enc_out).squeeze(1)                  # [B,H]
        return ctx, alpha

class Encoder(nn.Module):
    """Embedding + dropout + single-layer GRU/LSTM over source indices."""

    def __init__(self, vocab, emb, hid, drop, cell):
        """
        Args:
            vocab: Number of source symbols (index 0 is the padding symbol).
            emb: Embedding size.
            hid: Recurrent hidden size.
            drop: Dropout probability applied to the embeddings.
            cell: ``"GRU"`` or ``"LSTM"``; anything else raises ``KeyError``.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=0)
        recurrent_layers = {"GRU": nn.GRU, "LSTM": nn.LSTM}
        self.rnn = recurrent_layers[cell](emb, hid, batch_first=True)
        self.drop = nn.Dropout(drop)
        self.cell = cell

    def forward(self, src):
        """Return ``(outputs, final_state)`` exactly as produced by the RNN."""
        embedded = self.emb(src)
        return self.rnn(self.drop(embedded))

class Decoder(nn.Module):
    """Single-step attention decoder.

    Each call consumes one token per batch row, attends over the encoder
    outputs using the current hidden state, and predicts the next token.
    """

    def __init__(self, vocab, emb, hid, drop, cell, attn_dim):
        """
        Args:
            vocab: Number of target symbols (index 0 is the padding symbol).
            emb: Embedding size.
            hid: Recurrent hidden size (also the context-vector size).
            drop: Dropout probability applied to the token embedding.
            cell: ``"GRU"`` or ``"LSTM"``.
            attn_dim: Size of the attention projection space.
        """
        super().__init__()
        self.emb  = nn.Embedding(vocab, emb, padding_idx=0)
        self.attn = Attention(hid, attn_dim)
        recurrent_layers = {"GRU": nn.GRU, "LSTM": nn.LSTM}
        # The RNN input is the token embedding concatenated with the context.
        self.rnn  = recurrent_layers[cell](emb + hid, hid, batch_first=True)
        # The output layer sees the RNN output and the context side by side.
        self.fc   = nn.Linear(hid * 2, vocab)
        self.drop = nn.Dropout(drop)
        self.cell = cell

    def forward(self, tok, hid, enc_out):
        """Run one decoding step.

        Args:
            tok: Previous token index for every batch row.
            hid: Recurrent state (an ``(h, c)`` tuple when the cell is LSTM).
            enc_out: Encoder outputs to attend over.

        Returns:
            ``(pred, hid, alpha)``: vocabulary logits [B,V], the updated
            recurrent state and the attention weights of this step.
        """
        step_emb = self.drop(self.emb(tok).unsqueeze(1))        # [B,1,E]
        # For an LSTM only the hidden part h of (h, c) drives attention.
        h_state = hid[0] if self.cell == "LSTM" else hid
        ctx, alpha = self.attn(enc_out, h_state[-1])            # query = last layer, [B,H]
        step_in = torch.cat([step_emb, ctx.unsqueeze(1)], dim=2)  # [B,1,E+H]
        out, hid = self.rnn(step_in, hid)
        logits = self.fc(torch.cat([out.squeeze(1), ctx], dim=1))  # [B,V]
        return logits, hid, alpha

class Seq2SeqAttn(nn.Module):
    """Encoder/decoder transliteration model with additive attention.

    Vocabulary sizes and the ``<sos>``/``<eos>`` indices come from the
    module-level ``src_vocab`` and ``tgt_vocab``.
    """

    def __init__(self, cfg, device):
        """
        Args:
            cfg: Object exposing ``emb_dim``, ``hid_dim``, ``dropout``,
                ``cell`` and ``attn_dim`` as attributes.
            device: Stored as ``self.device`` for callers; the model itself
                is not moved here.
        """
        super().__init__()
        self.device  = device
        self.encoder = Encoder(len(src_vocab), cfg.emb_dim,
                               cfg.hid_dim, cfg.dropout, cfg.cell)
        self.decoder = Decoder(len(tgt_vocab), cfg.emb_dim,
                               cfg.hid_dim, cfg.dropout, cfg.cell,
                               cfg.attn_dim)
        self.tgt_size = len(tgt_vocab)
        self.cell     = cfg.cell

    def forward(self, src, tgt, teacher=0.5):
        """Decode ``tgt.size(1) - 1`` steps and return the logits.

        Args:
            src: Padded source indices, [B, T_src].
            tgt: Padded target indices starting with ``<sos>``, [B, T].
            teacher: Probability, drawn once per step for the whole batch, of
                feeding the gold token instead of the model's own argmax.

        Returns:
            Tensor [B, T, vocab]; position 0 is left as zeros because nothing
            is predicted for ``<sos>``.
        """
        B, T = tgt.shape
        # Allocate next to the inputs instead of trusting the constructor
        # argument, so a model moved after construction keeps working.
        outs = torch.zeros(B, T, self.tgt_size, device=src.device)
        enc_out, hid = self.encoder(src)
        tok = tgt[:, 0]      # <sos>
        for t in range(1, T):
            pred, hid, _ = self.decoder(tok, hid, enc_out)
            outs[:, t] = pred
            use_gt = random.random() < teacher
            tok = tgt[:, t] if use_gt else pred.argmax(1)
        return outs

    @torch.no_grad()
    def greedy_with_attention(self, src, max_len=50):
        """Greedy-decode a batch, also returning per-step attention weights.

        Decoding stops once every row has produced ``<eos>`` at least once (or
        after ``max_len`` steps).  The earlier check only fired when all rows
        emitted ``<eos>`` on the very same step.  Tokens after a row's first
        ``<eos>`` are ignored by ``tgt_vocab.decode``, so the returned texts
        are unaffected; only trailing, unused attention rows are dropped.

        Side effect: switches the model to eval mode and leaves it there.

        Returns:
            ``(texts, attentions)``: list of decoded strings and a tensor
            [B, T_tgt, T_src] with one attention row per decoding step.
        """
        self.eval()
        enc_out, hid = self.encoder(src)
        eos = tgt_vocab.char2idx["<eos>"]
        tok = torch.full((src.size(0),),
                         tgt_vocab.char2idx["<sos>"],
                         dtype=torch.long, device=src.device)
        finished = torch.zeros_like(tok, dtype=torch.bool)
        preds, attentions = [], []
        for _ in range(max_len):
            pred, hid, alpha = self.decoder(tok, hid, enc_out)
            tok = pred.argmax(1)
            preds.append(tok)
            attentions.append(alpha)
            finished |= tok == eos
            if finished.all():
                break
        if not preds:
            # max_len <= 0: nothing was decoded, and torch.stack([]) would raise.
            empty_attn = enc_out.new_zeros(src.size(0), 0, enc_out.size(1))
            return [""] * src.size(0), empty_attn
        preds      = torch.stack(preds, 1)          # [B,T_tgt]
        attentions = torch.stack(attentions, 1)     # [B,T_tgt,T_src]
        texts = [tgt_vocab.decode(row.cpu().numpy()) for row in preds]
        return texts, attentions

def train_epoch(model, loader, opt, crit, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Batches are ``(src, tgt, source_str, target_str)``; the strings are
    ignored.  The model is put in train mode and called with ``teacher=0.5``.
    The first time step (the ``<sos>`` slot) is excluded from the loss.
    """
    model.train()
    running = 0
    for src, tgt, *_ in loader:
        src = src.to(device)
        tgt = tgt.to(device)
        opt.zero_grad()
        logits = model(src, tgt, teacher=0.5)
        flat_logits = logits[:, 1:].reshape(-1, model.tgt_size)
        flat_gold = tgt[:, 1:].reshape(-1)
        batch_loss = crit(flat_logits, flat_gold)
        batch_loss.backward()
        opt.step()
        running += batch_loss.item()
    return running / len(loader)

@torch.no_grad()
def word_acc(model, loader, device):
    """Fraction of items whose decoded prediction equals the decoded target.

    Batches are ``(src, tgt, source_str, target_str)``; the strings are
    ignored.  The model runs in eval mode with ``teacher=0`` and is unrolled
    for as many steps as the padded target has.  Both sequences are passed
    through ``tgt_vocab.decode`` before comparison.
    """
    model.eval()
    correct, total = 0, 0
    for src, tgt, *_ in loader:
        src, tgt = src.to(device), tgt.to(device)
        pred_ids = model(src, tgt, teacher=0).argmax(2).cpu()
        gold_ids = tgt.cpu()
        for pred_row, gold_row in zip(pred_ids, gold_ids):
            hypothesis = tgt_vocab.decode(pred_row.numpy())
            reference = tgt_vocab.decode(gold_row.numpy())
            if hypothesis == reference:
                correct += 1
            total += 1
    return correct / total

@torch.no_grad()
def eval_loss(model, loader, crit, device):
    """Mean batch loss over ``loader`` without teacher forcing.

    Batches are ``(src, tgt, source_str, target_str)``; the strings are
    ignored.  The model is switched to eval mode and the first time step (the
    ``<sos>`` slot) is excluded from the loss.
    """
    model.eval()
    batch_losses = []
    for src, tgt, *_ in loader:
        src = src.to(device)
        tgt = tgt.to(device)
        logits = model(src, tgt, teacher=0)
        flat_logits = logits[:, 1:].reshape(-1, model.tgt_size)
        flat_gold = tgt[:, 1:].reshape(-1)
        batch_losses.append(crit(flat_logits, flat_gold).item())
    return sum(batch_losses) / len(loader)


class Config:
    """Hyperparameters of the single run trained in this cell.

    Values are stored as instance attributes, in the order emb_dim, hid_dim,
    attn_dim, dropout, cell.
    """

    def __init__(self):
        vars(self).update(
            emb_dim=128,
            hid_dim=256,
            attn_dim=256,
            dropout=0.3,
            cell="LSTM",
        )


# Log exactly the hyperparameters the model is built from: the run config is
# derived from Config() instead of being retyped by hand, so the dashboard
# cannot drift from the values actually used. The seed is recorded alongside.
wandb.init(project="MA23M015_DL_Assignment3",
           name="attention-visualization-lstm",
           config={**vars(Config()), "seed": SEED})


# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] ‑ Using device: {device}")

train_ds, dev_ds, test_ds = (
    TransliterationDS(split_pairs)
    for split_pairs in (train_pairs, dev_pairs, test_pairs)
)

# Only the training data is shuffled. The test loader yields one item at a time.
train_loader = DataLoader(
    train_ds, batch_size=64, shuffle=True, collate_fn=collate_fn
)
dev_loader = DataLoader(
    dev_ds, batch_size=64, shuffle=False, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_ds, batch_size=1, shuffle=False, collate_fn=collate_fn
)


# Run-level constants and the running best validation accuracy.
EPOCHS = 10
MODEL_PATH = "transliteration_model.pt"
best_val_acc = 0.0

# Model, optimiser (Adam with default settings) and loss.
cfg = Config()
model = Seq2SeqAttn(cfg, device)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
# Index 0 is <pad>; padded target positions do not contribute to the loss.
crit = nn.CrossEntropyLoss(ignore_index=0)


# Train for EPOCHS epochs, logging metrics to W&B and keeping the checkpoint
# with the highest validation accuracy at MODEL_PATH.
for epoch in range(1, EPOCHS + 1):
    tr_loss = train_epoch(model, train_loader, opt, crit, device)
    tr_acc  = word_acc(model, train_loader, device)
    val_loss = eval_loss(model, dev_loader, crit, device)
    val_acc  = word_acc(model, dev_loader, device)

    wandb.log({
        "epoch": epoch,
        "train_loss": tr_loss,
        "train_acc":  tr_acc,
        "val_loss": val_loss,
        "val_acc":  val_acc
    })

    print(f"Epoch {epoch:02d}: "
          f"Train‑Loss {tr_loss:.4f}  Train‑Acc {tr_acc:.4f} | "
          f"Val‑Loss {val_loss:.4f}  Val‑Acc {val_acc:.4f}")

    # The first epoch is always checkpointed. With a strict ">" against the
    # initial 0.0, a run whose validation accuracy never rises above zero
    # would write no file, and the later reload of MODEL_PATH would fail or
    # pick up a stale checkpoint from an earlier session.
    if epoch == 1 or val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"[INFO] ‑ New best model saved (Val‑Acc = {val_acc:.4f})")


# Restore the best checkpoint before visualising. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU
# can still be loaded in a CPU-only session.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

import random
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from PIL import Image
from io import BytesIO
import wandb

figs  = []
count = 0

# Draw up to nine distinct test items. Capping at the dataset size avoids the
# ValueError random.sample raises when asked for more items than exist.
sample_ids = random.sample(range(len(test_ds)), min(9, len(test_ds)))

for idx in sample_ids:
    src_tok, _, src_str, _ = test_ds[idx]
    src_tensor = src_tok.unsqueeze(0).to(device)
    pred_texts, attn = model.greedy_with_attention(src_tensor)
    pred_str    = pred_texts[0]

    # attn is [1, T_tgt, T_src]; keep one row per predicted character and one
    # column per source character (this drops the <eos> step).
    attn_mat = attn.squeeze(0).cpu().numpy()[:len(pred_str), :len(src_str)]

    if attn_mat.size == 0:
        # Empty prediction or empty source: there is no cell to draw, so skip
        # the sample instead of handing an empty matrix to the heatmap.
        continue

    # Plot single heatmap
    fig, ax = plt.subplots(figsize=(4, 4))
    sns.heatmap(attn_mat,
                ax=ax, cmap="viridis",
                cbar=False,
                xticklabels=list(src_str),
                yticklabels=list(pred_str))
    ax.set_xlabel("Source")
    ax.set_ylabel("Predicted")
    ax.set_title(f"{src_str} -> {pred_str}")
    plt.tight_layout()

    figs.append(fig)
    count += 1

# Compose the individual heatmaps into one 3x3 image: each small figure is
# rasterised and shown as a picture inside one cell of the grid.
grid_fig, grid_axes = plt.subplots(3, 3, figsize=(18, 18))

for number, (cell_ax, sample_fig) in enumerate(zip(grid_axes.flat, figs), start=1):
    sample_fig.canvas.draw()
    rendered = np.asarray(sample_fig.canvas.buffer_rgba())
    cell_ax.imshow(rendered)
    cell_ax.axis("off")
    cell_ax.set_title(f"Sample {number}", fontsize=14)

plt.tight_layout()

# Serialise the grid to an in-memory PNG and log it to W&B as an image.
buf = BytesIO()
grid_fig.savefig(buf, format="png", dpi=150)
buf.seek(0)
wandb.log({"attention_grid": wandb.Image(Image.open(buf))})

# Release every figure created for the visualisation.
for finished_fig in [grid_fig, *figs]:
    plt.close(finished_fig)



[INFO] ‑ Using device: cuda
Epoch 01: Train‑Loss 0.8536  Train‑Acc 0.5113 | Val‑Loss 0.7628  Val‑Acc 0.4831
[INFO] ‑ New best model saved (Val‑Acc = 0.4831)
Epoch 02: Train‑Loss 0.4225  Train‑Acc 0.5843 | Val‑Loss 0.7359  Val‑Acc 0.5260
[INFO] ‑ New best model saved (Val‑Acc = 0.5260)
Epoch 03: Train‑Loss 0.3544  Train‑Acc 0.6310 | Val‑Loss 0.6799  Val‑Acc 0.5660
[INFO] ‑ New best model saved (Val‑Acc = 0.5660)
Epoch 04: Train‑Loss 0.3139  Train‑Acc 0.6763 | Val‑Loss 0.6959  Val‑Acc 0.5827
[INFO] ‑ New best model saved (Val‑Acc = 0.5827)
Epoch 05: Train‑Loss 0.2846  Train‑Acc 0.6957 | Val‑Loss 0.6944  Val‑Acc 0.5726
Epoch 06: Train‑Loss 0.2568  Train‑Acc 0.7282 | Val‑Loss 0.7299  Val‑Acc 0.5784
Epoch 07: Train‑Loss 0.2382  Train‑Acc 0.7507 | Val‑Loss 0.7368  Val‑Acc 0.5869
[INFO] ‑ New best model saved (Val‑Acc = 0.5869)
Epoch 08: Train‑Loss 0.2171  Train‑Acc 0.7739 | Val‑Loss 0.7101  Val‑Acc 0.5871
[INFO] ‑ New best model saved (Val‑Acc = 0.5871)
Epoch 09: Train‑Loss 0.2021  Train‑Acc