In [2]:
import wandb
wandb.login(key="give wandb api key")

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mapoorvaprashanth[0m ([33mapoorvaprashanth-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

**Vanilla model for initial sweep**

> Run the below cell to sweep, sweep configuration can be chnaged based on observation

In [None]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Char-level vocabulary
class CharVocab:
    """
    A character-level vocabulary class for encoding and decoding sequences of characters.

    Attributes:
    -----------
    char2idx : dict
        Mapping from characters to integer indices. Includes special tokens:
        '<pad>' (0), '<sos>' (1), '<eos>' (2), and '<unk>' (3).
        
    idx2char : dict
        Reverse mapping from indices to characters.

    pad_idx : int
        Index of the padding token ('<pad>').

    sos_idx : int
        Index of the start-of-sequence token ('<sos>').

    eos_idx : int
        Index of the end-of-sequence token ('<eos>').

    Methods:
    --------
    encode(word: str) -> List[int]
        Converts a string into a list of indices, including <sos> at the start and <eos> at the end.
        Unknown characters are mapped to the <unk> index.

    decode(ids: List[int]) -> str
        Converts a list of indices back into a string, ignoring <sos> and <pad>, and stopping at <eos>.

    __len__() -> int
        Returns the size of the vocabulary (i.e., number of unique tokens including special tokens).
    """
    def __init__(self, words):
        chars = sorted(set("".join(words)))
        self.char2idx = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        for c in chars:
            self.char2idx[c] = len(self.char2idx)
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        self.pad_idx = self.char2idx['<pad>']
        self.sos_idx = self.char2idx['<sos>']
        self.eos_idx = self.char2idx['<eos>']

    def encode(self, word):
        return [self.sos_idx] + [self.char2idx.get(c, self.char2idx['<unk>']) for c in word] + [self.eos_idx]

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == self.eos_idx:
                break
            if idx not in (self.sos_idx, self.pad_idx):
                chars.append(self.idx2char.get(idx, ''))
        return ''.join(chars)

    def __len__(self):
        return len(self.char2idx)

# Load data
def read_file(path):
    with open(path, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [(line.split('\t')[0], line.split('\t')[1]) for line in lines if len(line.split('\t')) >= 2]

train_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.train.tsv')
dev_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.dev.tsv')
test_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.test.tsv')

src_vocab = CharVocab([src for src, _ in train_pairs])
tgt_vocab = CharVocab([tgt for _, tgt in train_pairs])

# Dataset
class TransliterationDataset(Dataset):
    """
This code defines a PyTorch dataset and dataloaders for a character-level transliteration task. It converts input-output string pairs into sequences of token indices using source and target vocabularies, and pads them for batch processing. The TransliterationDataset encodes each word pair, while the collate_fn ensures proper padding during batching. Dataloaders are created for training and validation with appropriate batch sizes.
"""
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.data = pairs
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(self.src_vocab.encode(src)), torch.tensor(self.tgt_vocab.encode(tgt))

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_pad = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=src_vocab.pad_idx)
    tgt_pad = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=tgt_vocab.pad_idx)
    return src_pad, tgt_pad

train_loader = DataLoader(TransliterationDataset(train_pairs, src_vocab, tgt_vocab),
                          batch_size=64, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(TransliterationDataset(dev_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)

# Model
class Seq2Seq(nn.Module):
    """
    A sequence-to-sequence model for character-level transliteration using RNNs."""
    def __init__(self, config, input_vocab_size, output_vocab_size):
        super().__init__()
        self.embedding_dim = config.embed_size
        self.hidden_size = config.hidden_size
        self.num_enc_layers = config.enc_layers
        self.num_dec_layers = config.dec_layers
        self.cell_type = config.cell
        self.device = device
        self.dropout = nn.Dropout(config.dropout)
        self.max_len = 30

        self.encoder_embedding = nn.Embedding(input_vocab_size, self.embedding_dim)
        self.decoder_embedding = nn.Embedding(output_vocab_size, self.embedding_dim)

        RNN = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}[self.cell_type]
        self.encoder = RNN(self.embedding_dim, self.hidden_size, num_layers=self.num_enc_layers,
                           batch_first=True, bidirectional=True)
        self.decoder = RNN(self.embedding_dim, self.hidden_size * 2, num_layers=self.num_dec_layers,
                           batch_first=True)

        self.fc = nn.Linear(self.hidden_size * 2, output_vocab_size)

        self.sos_idx = tgt_vocab.sos_idx
        self.eos_idx = tgt_vocab.eos_idx
        self.pad_idx = tgt_vocab.pad_idx

    def encode(self, src): 
        '''Encodes the input sequence using the encoder RNN.
        Args: '''
        embedded = self.dropout(self.encoder_embedding(src))
        outputs, h_n = self.encoder(embedded)
        if self.cell_type == 'LSTM':
            h, c = h_n
            h_cat = torch.cat((h[-2], h[-1]), dim=1).unsqueeze(0)
            c_cat = torch.cat((c[-2], c[-1]), dim=1).unsqueeze(0)
            return outputs, (h_cat, c_cat)
        else:
            h_cat = torch.cat((h_n[-2], h_n[-1]), dim=1).unsqueeze(0)
            return outputs, h_cat

    def decode_step(self, input_token, hidden):
        embedded = self.dropout(self.decoder_embedding(input_token))
        output, hidden = self.decoder(embedded, hidden)
        logits = self.fc(output.squeeze(1))
        return logits, hidden

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        batch_size, tgt_len = tgt.shape
        _, hidden = self.encode(src)
        input_token = tgt[:, 0].unsqueeze(1)
        outputs = []

        for t in range(1, tgt_len):
            output, hidden = self.decode_step(input_token, hidden)
            outputs.append(output.unsqueeze(1))
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_token = tgt[:, t].unsqueeze(1) if teacher_force else top1
        return torch.cat(outputs, dim=1)

# Training
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        output = model(src, tgt)
        loss = criterion(output.view(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Evaluation
def beam_decode(model, src, beam_size):
    """
    Beam search decoding for the Seq2Seq model.
    Args:
        model: The trained Seq2Seq model.
        src: The source sequence tensor.
        beam_size: The number of beams to keep during decoding.
        Returns:
        A list of predicted sequences.
    """
    model.eval()
    with torch.no_grad():
        _, hidden = model.encode(src)
        batch_size = src.size(0)
        final_outputs = []

        for b in range(batch_size):
            h_b = (hidden[0][:, b:b+1, :].contiguous(), hidden[1][:, b:b+1, :].contiguous()) if model.cell_type == 'LSTM' else hidden[:, b:b+1, :].contiguous()
            beams = [([model.sos_idx], 0.0, h_b)]
            for _ in range(model.max_len):
                new_beams = []
                for seq, score, h in beams:
                    if seq[-1] == model.eos_idx:
                        new_beams.append((seq, score, h))
                        continue
                    input_token = torch.tensor([[seq[-1]]], device=device)
                    out, h_new = model.decode_step(input_token, h)
                    log_probs = F.log_softmax(out, dim=1)
                    topk_probs, topk_idxs = torch.topk(log_probs, beam_size, dim=1)
                    for i in range(beam_size):
                        next_seq = seq + [topk_idxs[0][i].item()]
                        new_score = score + topk_probs[0][i].item()
                        new_beams.append((next_seq, new_score, h_new))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
            final_outputs.append(beams[0][0])
        return final_outputs

def evaluate_beam(model, dataloader, beam_size):
    """
    Evaluate the model using beam search decoding.
    Args:
        model: The trained Seq2Seq model.
        dataloader: DataLoader for the evaluation dataset.
        beam_size: The number of beams to keep during decoding.
        Returns:
        Tuple of sequence-level and token-level accuracy.
    """
    model.eval()
    total_seq, correct_seq = 0, 0
    total_tokens, correct_tokens = 0, 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            preds = beam_decode(model, src, beam_size)
            for pred, true in zip(preds, tgt):
                pred_trimmed = [tok for tok in pred[1:] if tok != model.pad_idx and tok != model.eos_idx]
                true_trimmed = [tok.item() for tok in true[1:] if tok.item() != model.pad_idx and tok.item() != model.eos_idx]

                # Sequence-level accuracy
                if pred_trimmed == true_trimmed:
                    correct_seq += 1
                total_seq += 1

                # Token-level accuracy
                for p, t in zip(pred_trimmed, true_trimmed):
                    if p == t:
                        correct_tokens += 1
                total_tokens += len(true_trimmed)

    seq_accuracy = correct_seq / total_seq if total_seq > 0 else 0.0
    token_accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0.0
    return seq_accuracy, token_accuracy


# W&B Sweep Training
def sweep_train(config=None):
    with wandb.init(config=config):
        config = wandb.config  # MUST be done before using config

        run_name = f"embed{config.embed_size}_hid{config.hidden_size}_enc{config.enc_layers}_dec{config.dec_layers}_{config.cell}_drop{config.dropout}_beam{config.beam_size}"
        wandb.run.name = run_name

        model = Seq2Seq(config, len(src_vocab), len(tgt_vocab)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)

        for epoch in range(10):
            train_loss = train(model, train_loader, optimizer, criterion)
            acc, token_acc = evaluate_beam(model, dev_loader, beam_size=config.beam_size)

            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_accuracy': acc,
                'val_token_accuracy': token_acc
            })

            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Seq Acc = {acc:.4f}, Token Acc = {token_acc:.4f}")




# Sweep Config
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'embed_size': {'values': [128, 256]},
        'hidden_size': {'values': [64, 128]},
        'enc_layers': {'values': [2]},
        'dec_layers': {'values': [1]},
        'dropout': {'values': [0.2, 0.3]},
        'cell': {'values': ['LSTM']},
        'beam_size': {'values': [5]}
    }
}

sweep_id = wandb.sweep(sweep_config, project="A3_ce21b020")
wandb.agent(sweep_id, function=sweep_train, count=5)


Create sweep with ID: 200n7n5k
Sweep URL: https://wandb.ai/apoorvaprashanth-indian-institute-of-technology-madras/A3_ce21b020/sweeps/200n7n5k


[34m[1mwandb[0m: Agent Starting Run: 34h930t9 with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 64


Epoch 1: Train Loss = 1.6122, Seq Acc = 0.0893, Token Acc = 0.4890
Epoch 2: Train Loss = 0.9561, Seq Acc = 0.2204, Token Acc = 0.6336
Epoch 3: Train Loss = 0.7857, Seq Acc = 0.2569, Token Acc = 0.6603
Epoch 4: Train Loss = 0.7063, Seq Acc = 0.2725, Token Acc = 0.6683
Epoch 5: Train Loss = 0.6542, Seq Acc = 0.2882, Token Acc = 0.6960
Epoch 6: Train Loss = 0.6128, Seq Acc = 0.3056, Token Acc = 0.7082
Epoch 7: Train Loss = 0.5903, Seq Acc = 0.3175, Token Acc = 0.7176
Epoch 8: Train Loss = 0.5729, Seq Acc = 0.3065, Token Acc = 0.7091
Epoch 9: Train Loss = 0.5529, Seq Acc = 0.3269, Token Acc = 0.7223
Epoch 10: Train Loss = 0.5390, Seq Acc = 0.3283, Token Acc = 0.7225


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▄▃▂▂▁▁▁▁▁
val_accuracy,▁▅▆▆▇▇█▇██
val_token_accuracy,▁▅▆▆▇█████

0,1
epoch,10.0
train_loss,0.53905
val_accuracy,0.32831
val_token_accuracy,0.72252


[34m[1mwandb[0m: Agent Starting Run: vd5e0t10 with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 64


Epoch 1: Train Loss = 1.5899, Seq Acc = 0.1377, Token Acc = 0.5648
Epoch 2: Train Loss = 0.9084, Seq Acc = 0.2351, Token Acc = 0.6516
Epoch 3: Train Loss = 0.7481, Seq Acc = 0.2650, Token Acc = 0.6665
Epoch 4: Train Loss = 0.6741, Seq Acc = 0.2854, Token Acc = 0.6790
Epoch 5: Train Loss = 0.6257, Seq Acc = 0.2913, Token Acc = 0.6932
Epoch 6: Train Loss = 0.5972, Seq Acc = 0.3210, Token Acc = 0.7129
Epoch 7: Train Loss = 0.5711, Seq Acc = 0.3171, Token Acc = 0.7089
Epoch 8: Train Loss = 0.5546, Seq Acc = 0.3221, Token Acc = 0.7171
Epoch 9: Train Loss = 0.5385, Seq Acc = 0.3264, Token Acc = 0.7200
Epoch 10: Train Loss = 0.5259, Seq Acc = 0.3251, Token Acc = 0.7188


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▄▂▂▂▁▁▁▁▁
val_accuracy,▁▅▆▆▇█████
val_token_accuracy,▁▅▆▆▇█████

0,1
epoch,10.0
train_loss,0.52593
val_accuracy,0.32512
val_token_accuracy,0.71876


[34m[1mwandb[0m: Agent Starting Run: 0whhxb57 with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 64


Epoch 1: Train Loss = 1.6110, Seq Acc = 0.1048, Token Acc = 0.5255
Epoch 2: Train Loss = 0.9154, Seq Acc = 0.2260, Token Acc = 0.6362
Epoch 3: Train Loss = 0.7561, Seq Acc = 0.2638, Token Acc = 0.6708
Epoch 4: Train Loss = 0.6743, Seq Acc = 0.2902, Token Acc = 0.7009
Epoch 5: Train Loss = 0.6295, Seq Acc = 0.2980, Token Acc = 0.6974
Epoch 6: Train Loss = 0.5924, Seq Acc = 0.3092, Token Acc = 0.7098
Epoch 7: Train Loss = 0.5702, Seq Acc = 0.3170, Token Acc = 0.7115
Epoch 8: Train Loss = 0.5503, Seq Acc = 0.3221, Token Acc = 0.7184
Epoch 9: Train Loss = 0.5374, Seq Acc = 0.3281, Token Acc = 0.7222
Epoch 10: Train Loss = 0.5221, Seq Acc = 0.3327, Token Acc = 0.7212


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▄▃▂▂▁▁▁▁▁
val_accuracy,▁▅▆▇▇▇████
val_token_accuracy,▁▅▆▇▇█████

0,1
epoch,10.0
train_loss,0.52205
val_accuracy,0.33274
val_token_accuracy,0.72118


[34m[1mwandb[0m: Agent Starting Run: kbknrf4b with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128


Epoch 1: Train Loss = 1.2823, Seq Acc = 0.2523, Token Acc = 0.6787


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


**Added attention (Q5)**

In [None]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Char-level vocabulary
class CharVocab:
    """
    A character-level vocabulary class for encoding and decoding sequences of characters.

    Attributes:
    -----------
    char2idx : dict
        Mapping from characters to integer indices. Includes special tokens:
        '<pad>' (0), '<sos>' (1), '<eos>' (2), and '<unk>' (3).
        
    idx2char : dict
        Reverse mapping from indices to characters.

    pad_idx : int
        Index of the padding token ('<pad>').

    sos_idx : int
        Index of the start-of-sequence token ('<sos>').

    eos_idx : int
        Index of the end-of-sequence token ('<eos>').

    Methods:
    --------
    encode(word: str) -> List[int]
        Converts a string into a list of indices, including <sos> at the start and <eos> at the end.
        Unknown characters are mapped to the <unk> index.

    decode(ids: List[int]) -> str
        Converts a list of indices back into a string, ignoring <sos> and <pad>, and stopping at <eos>.

    __len__() -> int
        Returns the size of the vocabulary (i.e., number of unique tokens including special tokens).
    """
    def __init__(self, words):
        chars = sorted(set("".join(words)))
        self.char2idx = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        for c in chars:
            self.char2idx[c] = len(self.char2idx)
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        self.pad_idx = self.char2idx['<pad>']
        self.sos_idx = self.char2idx['<sos>']
        self.eos_idx = self.char2idx['<eos>']

    def encode(self, word):
        return [self.sos_idx] + [self.char2idx.get(c, self.char2idx['<unk>']) for c in word] + [self.eos_idx]

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == self.eos_idx:
                break
            if idx not in (self.sos_idx, self.pad_idx):
                chars.append(self.idx2char.get(idx, ''))
        return ''.join(chars)

    def __len__(self):
        return len(self.char2idx)

def read_file(path):
    with open(path, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [(line.split('\t')[0], line.split('\t')[1]) for line in lines if len(line.split('\t')) >= 2]

train_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.train.tsv')
dev_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.dev.tsv')
test_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.test.tsv')

src_vocab = CharVocab([src for src, _ in train_pairs])
tgt_vocab = CharVocab([tgt for _, tgt in train_pairs])

class TransliterationDataset(Dataset):
    """
This code defines a PyTorch dataset and dataloaders for a character-level transliteration task. It converts input-output string pairs into sequences of token indices using source and target vocabularies, and pads them for batch processing. The TransliterationDataset encodes each word pair, while the collate_fn ensures proper padding during batching. Dataloaders are created for training and validation with appropriate batch sizes.
"""
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.data = pairs
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(self.src_vocab.encode(src)), torch.tensor(self.tgt_vocab.encode(tgt))

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_pad = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=src_vocab.pad_idx)
    tgt_pad = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=tgt_vocab.pad_idx)
    return src_pad, tgt_pad

train_loader = DataLoader(TransliterationDataset(train_pairs, src_vocab, tgt_vocab),
                          batch_size=64, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(TransliterationDataset(dev_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)

# Attention
class Attention(nn.Module):
    """
    A simple attention mechanism for the Seq2Seq model.
    Args:
        enc_hidden_dim (int): The hidden dimension of the encoder.
        dec_hidden_dim (int): The hidden dimension of the decoder.
    """
    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super().__init__()
        self.attn = nn.Linear(enc_hidden_dim + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Parameter(torch.rand(dec_hidden_dim))

    def forward(self, hidden, encoder_outputs, mask=None):
        src_len = encoder_outputs.shape[1]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        energy = energy.permute(0, 2, 1)
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)
        attn_weights = torch.bmm(v, energy).squeeze(1)
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e10)
        return F.softmax(attn_weights, dim=1)

# Seq2Seq with Attention
class Seq2Seq(nn.Module):
    def __init__(self, config, input_vocab_size, output_vocab_size):
        super().__init__()
        self.embedding_dim = config.embed_size
        self.hidden_size = config.hidden_size
        self.num_enc_layers = config.enc_layers
        self.num_dec_layers = config.dec_layers
        self.cell_type = config.cell
        self.device = device
        self.dropout = nn.Dropout(config.dropout)
        self.max_len = 30

        self.encoder_embedding = nn.Embedding(input_vocab_size, self.embedding_dim)
        self.decoder_embedding = nn.Embedding(output_vocab_size, self.embedding_dim)

        RNN = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}[self.cell_type]
        self.encoder = RNN(self.embedding_dim, self.hidden_size, num_layers=self.num_enc_layers,
                           batch_first=True, bidirectional=True)
        self.decoder = RNN(self.embedding_dim + self.hidden_size * 2, self.hidden_size * 2,
                           num_layers=self.num_dec_layers, batch_first=True)

        self.attention = Attention(self.hidden_size * 2, self.hidden_size * 2)
        self.fc = nn.Linear(self.hidden_size * 4, output_vocab_size)

        self.sos_idx = tgt_vocab.sos_idx
        self.eos_idx = tgt_vocab.eos_idx
        self.pad_idx = tgt_vocab.pad_idx

    def encode(self, src):
        embedded = self.dropout(self.encoder_embedding(src))
        outputs, h_n = self.encoder(embedded)
        if self.cell_type == 'LSTM':
            h, c = h_n
            h_cat = torch.cat((h[-2], h[-1]), dim=1).unsqueeze(0)
            c_cat = torch.cat((c[-2], c[-1]), dim=1).unsqueeze(0)
            return outputs, (h_cat, c_cat)
        else:
            h_cat = torch.cat((h_n[-2], h_n[-1]), dim=1).unsqueeze(0)
            return outputs, h_cat

    def decode_step(self, input_token, hidden, encoder_outputs):
        embedded = self.dropout(self.decoder_embedding(input_token))
        if self.cell_type == 'LSTM':
            h_t = hidden[0][-1]
        else:
            h_t = hidden[-1]
        attn_weights = self.attention(h_t, encoder_outputs)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
        rnn_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.decoder(rnn_input, hidden)
        logits = self.fc(torch.cat((output.squeeze(1), context.squeeze(1)), dim=1))
        return logits, hidden

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        batch_size, tgt_len = tgt.shape
        encoder_outputs, hidden = self.encode(src)
        input_token = tgt[:, 0].unsqueeze(1)
        outputs = []

        for t in range(1, tgt_len):
            output, hidden = self.decode_step(input_token, hidden, encoder_outputs)
            outputs.append(output.unsqueeze(1))
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_token = tgt[:, t].unsqueeze(1) if teacher_force else top1
        return torch.cat(outputs, dim=1)

# Training
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        output = model(src, tgt)
        loss = criterion(output.view(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Beam Decoding
def beam_decode(model, src, beam_size):
    model.eval()
    with torch.no_grad():
        encoder_outputs, hidden = model.encode(src)
        batch_size = src.size(0)
        final_outputs = []

        for b in range(batch_size):
            h_b = (hidden[0][:, b:b+1, :].contiguous(), hidden[1][:, b:b+1, :].contiguous()) if model.cell_type == 'LSTM' else hidden[:, b:b+1, :].contiguous()
            enc_out_b = encoder_outputs[b:b+1]
            beams = [([model.sos_idx], 0.0, h_b)]
            for _ in range(model.max_len):
                new_beams = []
                for seq, score, h in beams:
                    if seq[-1] == model.eos_idx:
                        new_beams.append((seq, score, h))
                        continue
                    input_token = torch.tensor([[seq[-1]]], device=device)
                    out, h_new = model.decode_step(input_token, h, enc_out_b)
                    log_probs = F.log_softmax(out, dim=1)
                    topk_probs, topk_idxs = torch.topk(log_probs, beam_size, dim=1)
                    for i in range(beam_size):
                        next_seq = seq + [topk_idxs[0][i].item()]
                        new_score = score + topk_probs[0][i].item()
                        new_beams.append((next_seq, new_score, h_new))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
            final_outputs.append(beams[0][0])
        return final_outputs

def evaluate_beam(model, dataloader, beam_size):
    model.eval()
    total_seq, correct_seq = 0, 0
    total_tokens, correct_tokens = 0, 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            preds = beam_decode(model, src, beam_size)
            for pred, true in zip(preds, tgt):
                pred_trimmed = [tok for tok in pred[1:] if tok != model.pad_idx and tok != model.eos_idx]
                true_trimmed = [tok.item() for tok in true[1:] if tok.item() != model.pad_idx and tok.item() != model.eos_idx]

                if pred_trimmed == true_trimmed:
                    correct_seq += 1
                total_seq += 1

                for p, t in zip(pred_trimmed, true_trimmed):
                    if p == t:
                        correct_tokens += 1
                total_tokens += len(true_trimmed)

    return correct_seq / total_seq, correct_tokens / total_tokens

# W&B Sweep Training
def sweep_train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        run_name = f"embed{config.embed_size}_hid{config.hidden_size}_enc{config.enc_layers}_dec{config.dec_layers}_{config.cell}_drop{config.dropout}_beam{config.beam_size}"
        wandb.run.name = run_name

        model = Seq2Seq(config, len(src_vocab), len(tgt_vocab)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)

        for epoch in range(10):
            train_loss = train(model, train_loader, optimizer, criterion)
            acc, token_acc = evaluate_beam(model, dev_loader, beam_size=config.beam_size)

            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_accuracy': acc,
                'val_token_accuracy': token_acc,
                'used_attention':config.used_attention
            })
            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Seq Acc = {acc:.4f}, Token Acc = {token_acc:.4f}")

# Sweep Config
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'embed_size': {'values': [128, 256]},
        'hidden_size': {'values': [128]},
        'enc_layers': {'values': [1, 2]},
        'dec_layers': {'values': [1]},
        'dropout': {'values': [0.25, 0.3, 0.35]},
        'cell': {'values': ['GRU','LSTM']},
        'beam_size': {'values': [3, 5]},
        'used_attention':{'values':[True]}
    }
}

sweep_id = wandb.sweep(sweep_config, project="A3_ce21b020")
wandb.agent(sweep_id, function=sweep_train, count=10)


Create sweep with ID: 97snhvbo
Sweep URL: https://wandb.ai/apoorvaprashanth-indian-institute-of-technology-madras/A3_ce21b020/sweeps/97snhvbo


[34m[1mwandb[0m: Agent Starting Run: fzpcg94g with config:
[34m[1mwandb[0m: 	beam_size: 3
[34m[1mwandb[0m: 	cell: GRU
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128
[34m[1mwandb[0m: 	used_attention: True


Epoch 1: Train Loss = 0.8944, Seq Acc = 0.3636, Token Acc = 0.7392
Epoch 2: Train Loss = 0.5061, Seq Acc = 0.3755, Token Acc = 0.7446
Epoch 3: Train Loss = 0.4734, Seq Acc = 0.3627, Token Acc = 0.7437
Epoch 4: Train Loss = 0.4442, Seq Acc = 0.3785, Token Acc = 0.7470
Epoch 5: Train Loss = 0.4329, Seq Acc = 0.3673, Token Acc = 0.7468
Epoch 6: Train Loss = 0.4186, Seq Acc = 0.3744, Token Acc = 0.7467
Epoch 7: Train Loss = 0.4152, Seq Acc = 0.3810, Token Acc = 0.7545
Epoch 8: Train Loss = 0.4038, Seq Acc = 0.3744, Token Acc = 0.7472
Epoch 9: Train Loss = 0.3994, Seq Acc = 0.3691, Token Acc = 0.7466
Epoch 10: Train Loss = 0.3927, Seq Acc = 0.3721, Token Acc = 0.7412


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▃▂▂▂▁▁▁▁▁
val_accuracy,▁▆▁▇▃▅█▅▃▅
val_token_accuracy,▁▃▃▅▄▄█▅▄▂

0,1
epoch,10
train_loss,0.39267
used_attention,True
val_accuracy,0.3721
val_token_accuracy,0.74119


[34m[1mwandb[0m: Agent Starting Run: nt29evpt with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	enc_layers: 1
[34m[1mwandb[0m: 	hidden_size: 128
[34m[1mwandb[0m: 	used_attention: True


Epoch 1: Train Loss = 1.0539, Seq Acc = 0.3593, Token Acc = 0.7437
Epoch 2: Train Loss = 0.5119, Seq Acc = 0.3742, Token Acc = 0.7482
Epoch 3: Train Loss = 0.4659, Seq Acc = 0.3712, Token Acc = 0.7464
Epoch 4: Train Loss = 0.4379, Seq Acc = 0.3758, Token Acc = 0.7487
Epoch 5: Train Loss = 0.4250, Seq Acc = 0.3769, Token Acc = 0.7497
Epoch 6: Train Loss = 0.4132, Seq Acc = 0.3691, Token Acc = 0.7459
Epoch 7: Train Loss = 0.4066, Seq Acc = 0.3739, Token Acc = 0.7551
Epoch 8: Train Loss = 0.3995, Seq Acc = 0.3686, Token Acc = 0.7498
Epoch 9: Train Loss = 0.3897, Seq Acc = 0.3698, Token Acc = 0.7464
Epoch 10: Train Loss = 0.3793, Seq Acc = 0.3760, Token Acc = 0.7472


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▂▂▂▁▁▁▁▁▁
val_accuracy,▁▇▆██▅▇▅▅█
val_token_accuracy,▁▄▃▄▅▂█▅▃▃

0,1
epoch,10
train_loss,0.3793
used_attention,True
val_accuracy,0.376
val_token_accuracy,0.74725


[34m[1mwandb[0m: Agent Starting Run: f8zw99p4 with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: GRU
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128
[34m[1mwandb[0m: 	used_attention: True


Epoch 1: Train Loss = 0.8683, Seq Acc = 0.3762, Token Acc = 0.7537
Epoch 2: Train Loss = 0.5146, Seq Acc = 0.3709, Token Acc = 0.7410
Epoch 3: Train Loss = 0.4759, Seq Acc = 0.3700, Token Acc = 0.7439
Epoch 4: Train Loss = 0.4467, Seq Acc = 0.3815, Token Acc = 0.7559
Epoch 5: Train Loss = 0.4321, Seq Acc = 0.3684, Token Acc = 0.7420
Epoch 6: Train Loss = 0.4294, Seq Acc = 0.3794, Token Acc = 0.7541
Epoch 7: Train Loss = 0.4163, Seq Acc = 0.3717, Token Acc = 0.7470
Epoch 8: Train Loss = 0.4056, Seq Acc = 0.3694, Token Acc = 0.7487
Epoch 9: Train Loss = 0.4038, Seq Acc = 0.3748, Token Acc = 0.7559
Epoch 10: Train Loss = 0.3979, Seq Acc = 0.3664, Token Acc = 0.7450


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▃▂▂▂▁▁▁▁▁
val_accuracy,▆▃▃█▂▇▃▂▅▁
val_token_accuracy,▇▁▂█▁▇▄▅█▃

0,1
epoch,10
train_loss,0.39791
used_attention,True
val_accuracy,0.36642
val_token_accuracy,0.745


[34m[1mwandb[0m: Agent Starting Run: c9iqdtvk with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128
[34m[1mwandb[0m: 	used_attention: True


Epoch 1: Train Loss = 0.9719, Seq Acc = 0.3625, Token Acc = 0.7362
Epoch 2: Train Loss = 0.5056, Seq Acc = 0.3725, Token Acc = 0.7474
Epoch 3: Train Loss = 0.4535, Seq Acc = 0.3765, Token Acc = 0.7470
Epoch 4: Train Loss = 0.4334, Seq Acc = 0.3739, Token Acc = 0.7462
Epoch 5: Train Loss = 0.4218, Seq Acc = 0.3804, Token Acc = 0.7485
Epoch 6: Train Loss = 0.4048, Seq Acc = 0.3909, Token Acc = 0.7555
Epoch 7: Train Loss = 0.3883, Seq Acc = 0.3815, Token Acc = 0.7553
Epoch 8: Train Loss = 0.3819, Seq Acc = 0.3845, Token Acc = 0.7516
Epoch 9: Train Loss = 0.3772, Seq Acc = 0.3774, Token Acc = 0.7527
Epoch 10: Train Loss = 0.3743, Seq Acc = 0.3764, Token Acc = 0.7502


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▃▂▂▂▁▁▁▁▁
val_accuracy,▁▃▄▄▅█▆▆▅▄
val_token_accuracy,▁▅▅▅▅██▇▇▆

0,1
epoch,10
train_loss,0.37429
used_attention,True
val_accuracy,0.37635
val_token_accuracy,0.75024


[34m[1mwandb[0m: Agent Starting Run: 3r5lxpg7 with config:
[34m[1mwandb[0m: 	beam_size: 3
[34m[1mwandb[0m: 	cell: GRU
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.35
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128
[34m[1mwandb[0m: 	used_attention: True


Epoch 1: Train Loss = 0.8944, Seq Acc = 0.3409, Token Acc = 0.7214
Epoch 2: Train Loss = 0.5244, Seq Acc = 0.3593, Token Acc = 0.7334
Epoch 3: Train Loss = 0.4739, Seq Acc = 0.3643, Token Acc = 0.7419
Epoch 4: Train Loss = 0.4524, Seq Acc = 0.3863, Token Acc = 0.7545
Epoch 5: Train Loss = 0.4396, Seq Acc = 0.3655, Token Acc = 0.7433
Epoch 6: Train Loss = 0.4236, Seq Acc = 0.3616, Token Acc = 0.7492
Epoch 7: Train Loss = 0.4153, Seq Acc = 0.3792, Token Acc = 0.7529
Epoch 8: Train Loss = 0.4065, Seq Acc = 0.3795, Token Acc = 0.7533
Epoch 9: Train Loss = 0.4054, Seq Acc = 0.3760, Token Acc = 0.7521
Epoch 10: Train Loss = 0.3966, Seq Acc = 0.3693, Token Acc = 0.7501


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▃▂▂▂▁▁▁▁▁
val_accuracy,▁▄▅█▅▄▇▇▆▅
val_token_accuracy,▁▄▅█▆▇██▇▇

0,1
epoch,10
train_loss,0.39665
used_attention,True
val_accuracy,0.36926
val_token_accuracy,0.75011


[34m[1mwandb[0m: Agent Starting Run: mlk5ph1v with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: GRU
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	enc_layers: 1
[34m[1mwandb[0m: 	hidden_size: 128
[34m[1mwandb[0m: 	used_attention: True


Epoch 1: Train Loss = 0.8739, Seq Acc = 0.3625, Token Acc = 0.7429
Epoch 2: Train Loss = 0.5159, Seq Acc = 0.3645, Token Acc = 0.7406
Epoch 3: Train Loss = 0.4794, Seq Acc = 0.3707, Token Acc = 0.7508
Epoch 4: Train Loss = 0.4531, Seq Acc = 0.3737, Token Acc = 0.7519


**best model without attention**  Q4

In [3]:
## import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random
import os
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, HTML
import zipfile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Char-level vocabulary
class CharVocab:
     """
    A character-level vocabulary class for encoding and decoding sequences of characters.

    Attributes:
    -----------
    char2idx : dict
        Mapping from characters to integer indices. Includes special tokens:
        '<pad>' (0), '<sos>' (1), '<eos>' (2), and '<unk>' (3).
        
    idx2char : dict
        Reverse mapping from indices to characters.

    pad_idx : int
        Index of the padding token ('<pad>').

    sos_idx : int
        Index of the start-of-sequence token ('<sos>').

    eos_idx : int
        Index of the end-of-sequence token ('<eos>').

    Methods:
    --------
    encode(word: str) -> List[int]
        Converts a string into a list of indices, including <sos> at the start and <eos> at the end.
        Unknown characters are mapped to the <unk> index.

    decode(ids: List[int]) -> str
        Converts a list of indices back into a string, ignoring <sos> and <pad>, and stopping at <eos>.

    __len__() -> int
        Returns the size of the vocabulary (i.e., number of unique tokens including special tokens).
    """
    def __init__(self, words):
        chars = sorted(set("".join(words)))
        self.char2idx = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        for c in chars:
            self.char2idx[c] = len(self.char2idx)
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        self.pad_idx = self.char2idx['<pad>']
        self.sos_idx = self.char2idx['<sos>']
        self.eos_idx = self.char2idx['<eos>']

    def encode(self, word):
        return [self.sos_idx] + [self.char2idx.get(c, self.char2idx['<unk>']) for c in word] + [self.eos_idx]

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == self.eos_idx:
                break
            if idx not in (self.sos_idx, self.pad_idx):
                chars.append(self.idx2char.get(idx, ''))
        return ''.join(chars)

    def __len__(self):
        return len(self.char2idx)

# Load data
def read_file(path):
    with open(path, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [(line.split('\t')[0], line.split('\t')[1]) for line in lines if len(line.split('\t')) >= 2]

train_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.train.tsv')
dev_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.dev.tsv')
test_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.test.tsv')

# Build vocabularies from ALL data to ensure full coverage
src_vocab = CharVocab([src for _, src in train_pairs])
tgt_vocab = CharVocab([tgt for tgt, _ in train_pairs])


# Dataset
class TransliterationDataset(Dataset):
    """
This code defines a PyTorch dataset and dataloaders for a character-level transliteration task. It converts input-output string pairs into sequences of token indices using source and target vocabularies, and pads them for batch processing. The TransliterationDataset encodes each word pair, while the collate_fn ensures proper padding during batching. Dataloaders are created for training and validation with appropriate batch sizes.
"""
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.data = pairs
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tgt, src = self.data[idx]
        return torch.tensor(self.src_vocab.encode(src)), torch.tensor(self.tgt_vocab.encode(tgt))

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_pad = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=src_vocab.pad_idx)
    tgt_pad = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=tgt_vocab.pad_idx)
    return src_pad, tgt_pad

train_loader = DataLoader(TransliterationDataset(train_pairs, src_vocab, tgt_vocab),
                          batch_size=64, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(TransliterationDataset(dev_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(TransliterationDataset(test_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)

# Model
class Seq2Seq(nn.Module):
    def __init__(self, config, input_vocab_size, output_vocab_size):
        super().__init__()
        self.embedding_dim = config.embed_size
        self.hidden_size = config.hidden_size
        self.num_enc_layers = config.enc_layers
        self.num_dec_layers = config.dec_layers
        self.cell_type = config.cell
        self.device = device
        self.dropout = nn.Dropout(config.dropout)
        self.max_len = 30

        self.encoder_embedding = nn.Embedding(input_vocab_size, self.embedding_dim)
        self.decoder_embedding = nn.Embedding(output_vocab_size, self.embedding_dim)

        RNN = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}[self.cell_type]
        self.encoder = RNN(self.embedding_dim, self.hidden_size, num_layers=self.num_enc_layers,
                           batch_first=True, bidirectional=True)
        self.decoder = RNN(self.embedding_dim, self.hidden_size * 2, num_layers=self.num_dec_layers,
                           batch_first=True)

        self.fc = nn.Linear(self.hidden_size * 2, output_vocab_size)

        self.sos_idx = tgt_vocab.sos_idx
        self.eos_idx = tgt_vocab.eos_idx
        self.pad_idx = tgt_vocab.pad_idx

    def encode(self, src):
        embedded = self.dropout(self.encoder_embedding(src))
        outputs, h_n = self.encoder(embedded)
        if self.cell_type == 'LSTM':
            h, c = h_n
            h_cat = torch.cat((h[-2], h[-1]), dim=1).unsqueeze(0)
            c_cat = torch.cat((c[-2], c[-1]), dim=1).unsqueeze(0)
            return outputs, (h_cat, c_cat)
        else:
            h_cat = torch.cat((h_n[-2], h_n[-1]), dim=1).unsqueeze(0)
            return outputs, h_cat

    def decode_step(self, input_token, hidden):
        embedded = self.dropout(self.decoder_embedding(input_token))
        output, hidden = self.decoder(embedded, hidden)
        logits = self.fc(output.squeeze(1))
        return logits, hidden

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        batch_size, tgt_len = tgt.shape
        _, hidden = self.encode(src)
        input_token = tgt[:, 0].unsqueeze(1)
        outputs = []

        for t in range(1, tgt_len):
            output, hidden = self.decode_step(input_token, hidden)
            outputs.append(output.unsqueeze(1))
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_token = tgt[:, t].unsqueeze(1) if teacher_force else top1
        return torch.cat(outputs, dim=1)

# Training
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        output = model(src, tgt)
        loss = criterion(output.view(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Evaluation
def beam_decode(model, src, beam_size):
    model.eval()
    with torch.no_grad():
        _, hidden = model.encode(src)
        batch_size = src.size(0)
        final_outputs = []

        for b in range(batch_size):
            h_b = (hidden[0][:, b:b+1, :].contiguous(), hidden[1][:, b:b+1, :].contiguous()) if model.cell_type == 'LSTM' else hidden[:, b:b+1, :].contiguous()
            beams = [([model.sos_idx], 0.0, h_b)]
            for _ in range(model.max_len):
                new_beams = []
                for seq, score, h in beams:
                    if seq[-1] == model.eos_idx:
                        new_beams.append((seq, score, h))
                        continue
                    input_token = torch.tensor([[seq[-1]]], device=device)
                    out, h_new = model.decode_step(input_token, h)
                    log_probs = F.log_softmax(out, dim=1)
                    topk_probs, topk_idxs = torch.topk(log_probs, beam_size, dim=1)
                    for i in range(beam_size):
                        next_seq = seq + [topk_idxs[0][i].item()]
                        new_score = score + topk_probs[0][i].item()
                        new_beams.append((next_seq, new_score, h_new))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
            final_outputs.append(beams[0][0])
        return final_outputs

def evaluate_beam(model, dataloader, beam_size):
    model.eval()
    total_seq, correct_seq = 0, 0
    total_tokens, correct_tokens = 0, 0
    all_predictions = []

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            preds = beam_decode(model, src, beam_size)
            
            # Decode each item in the batch individually
            for i, (pred, true) in enumerate(zip(preds, tgt)):
                # Process predicted sequence
                pred_trimmed = [tok for tok in pred[1:] if tok != model.pad_idx and tok != model.eos_idx]
                # Process true sequence
                true_trimmed = [tok.item() for tok in true[1:] if tok.item() != model.pad_idx and tok.item() != model.eos_idx]

                # Sequence-level accuracy
                if pred_trimmed == true_trimmed:
                    correct_seq += 1
                total_seq += 1

                # Token-level accuracy
                for p, t in zip(pred_trimmed, true_trimmed):
                    if p == t:
                        correct_tokens += 1
                total_tokens += len(true_trimmed)
                
                # Decode source, predicted and true words properly
                src_word = src_vocab.decode([x.item() for x in src[i] if x.item() not in (src_vocab.sos_idx, src_vocab.eos_idx, src_vocab.pad_idx)])
                pred_word = tgt_vocab.decode(pred)
                true_word = tgt_vocab.decode([x.item() for x in true if x.item() not in (tgt_vocab.pad_idx, tgt_vocab.eos_idx)])
                
                all_predictions.append((src_word, pred_word, true_word))

    seq_accuracy = correct_seq / total_seq if total_seq > 0 else 0.0
    token_accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0.0
    return seq_accuracy, token_accuracy, all_predictions

def visualize_predictions(predictions, num_samples=10, log_to_wandb=False):
    # Create a DataFrame
    df = pd.DataFrame(predictions[:num_samples], columns=['Input', 'Predicted', 'True'])
    
    # Highlight differences between Predicted and True
    def highlight_diff(row):
        pred, true = row['Predicted'], row['True']
        diff = []
        for p, t in zip(pred, true):
            if p == t:
                diff.append(p)
            else:
                diff.append(f'<b style="color:red">{p}</b>')  # Highlight incorrect chars in red
        return ''.join(diff)
    
    # Add a "Difference" column
    df['Difference'] = df.apply(lambda row: highlight_diff(row), axis=1)
    
    # Color entire row green if correct, else pink
    def row_style(row):
        color = 'lightgreen' if row['Predicted'] == row['True'] else 'lightpink'
        return [f'background-color: {color}' for _ in row]
    
    # Apply styling
    styled_df = df.style.apply(row_style, axis=1).set_properties(**{'text-align': 'left'})
    
    # Display in Jupyter (HTML)
    display(HTML(styled_df.to_html(escape=False)))
    
    # Log to W&B (if enabled)
    if log_to_wandb:
        wandb.log({"predictions": wandb.Table(dataframe=df)})
    
    return styled_df

def save_predictions(predictions, filename):
    os.makedirs('predictions_vanilla', exist_ok=True)
    df = pd.DataFrame(predictions, columns=['English Input', 'Predicted Native', 'True Native'])
    df.to_csv(f'predictions_vanilla/{filename}', index=False)
    print(f"Saved to predictions_vanilla/{filename}")

# W&B Training
def sweep_train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        run_name = f"embed{config.embed_size}_hid{config.hidden_size}_enc{config.enc_layers}_dec{config.dec_layers}_{config.cell}_drop{config.dropout}_beam{config.beam_size}"
        wandb.run.name = run_name

        model = Seq2Seq(config, len(src_vocab), len(tgt_vocab)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)

        for epoch in range(20):
            train_loss = train(model, train_loader, optimizer, criterion)
            acc, token_acc, val_preds = evaluate_beam(model, dev_loader, beam_size=config.beam_size)
            test_acc, test_token_acc, test_preds = evaluate_beam(model, test_loader, beam_size=config.beam_size)

            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_accuracy': acc,
                'val_token_accuracy': token_acc,
                'test_accuracy': test_acc,
                'test_token_accuracy': test_token_acc
            })

            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Seq Acc = {acc:.4f}, Token Acc = {token_acc:.4f} Test Seq Acc = {test_acc:.4f}, Test Token Acc = {test_token_acc:.4f}")
        visualize_predictions(test_preds, num_samples=15, log_to_wandb=True)
        save_predictions(test_preds, f'test_predictions_{run_name}.csv')

# Sweep Config
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'embed_size': {'values': [256]},
        'hidden_size': {'values': [128]},
        'enc_layers': {'values': [2]},
        'dec_layers': {'values': [1]},
        'dropout': {'values': [0.35]},
        'cell': {'values': ['LSTM']},
        'beam_size': {'values': [5]}
    }
}

# Run sweep
sweep_id = wandb.sweep(sweep_config, project="A3_ce21b020")
wandb.agent(sweep_id, function=sweep_train, count=1)

# Create zip
def create_prediction_zip():
    with zipfile.ZipFile('predictions_vanilla.zip', 'w') as zipf:
        for root, dirs, files in os.walk('predictions_vanilla'):
            for file in files:
                zipf.write(os.path.join(root, file))
    print("Zip created: predictions_vanilla.zip")

create_prediction_zip()

Create sweep with ID: hcp01g1t
Sweep URL: https://wandb.ai/apoorvaprashanth-indian-institute-of-technology-madras/A3_ce21b020/sweeps/hcp01g1t


[34m[1mwandb[0m: Agent Starting Run: cfyh601e with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.35
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128


Epoch 1: Train Loss = 1.4911, Seq Acc = 0.2413, Token Acc = 0.6442 Test Seq Acc = 0.2410, Test Token Acc = 0.6395
Epoch 2: Train Loss = 0.7016, Seq Acc = 0.3505, Token Acc = 0.7269 Test Seq Acc = 0.3606, Test Token Acc = 0.7331
Epoch 3: Train Loss = 0.5182, Seq Acc = 0.3960, Token Acc = 0.7471 Test Seq Acc = 0.3995, Test Token Acc = 0.7535
Epoch 4: Train Loss = 0.4284, Seq Acc = 0.4093, Token Acc = 0.7589 Test Seq Acc = 0.4075, Test Token Acc = 0.7639
Epoch 5: Train Loss = 0.3630, Seq Acc = 0.4294, Token Acc = 0.7761 Test Seq Acc = 0.4326, Test Token Acc = 0.7800
Epoch 6: Train Loss = 0.3174, Seq Acc = 0.4370, Token Acc = 0.7800 Test Seq Acc = 0.4285, Test Token Acc = 0.7800
Epoch 7: Train Loss = 0.2829, Seq Acc = 0.4364, Token Acc = 0.7796 Test Seq Acc = 0.4421, Test Token Acc = 0.7858
Epoch 8: Train Loss = 0.2562, Seq Acc = 0.4368, Token Acc = 0.7737 Test Seq Acc = 0.4401, Test Token Acc = 0.7835
Epoch 9: Train Loss = 0.2330, Seq Acc = 0.4455, Token Acc = 0.7816 Test Seq Acc = 0.4453

Unnamed: 0,Input,Predicted,True,Difference
0,amgathavavum,അംഗതാവും,അംഗത്വവും,അംഗതാവും
1,amgathvavum,അംഗത്വവും,അംഗത്വവും,അംഗത്വവും
2,angathwavum,അംഗത്വവും,അംഗത്വവും,അംഗത്വവും
3,amgabalam,അംഗബലം,അംഗബലം,അംഗബലം
4,angabalam,അംഗബലം,അംഗബലം,അംഗബലം
5,amgeekarikkuka,അംഗീകരിക്കുക,അംഗീകരിക്കുക,അംഗീകരിക്കുക
6,angeekarikkuka,അംഗീകരിക്കുക,അംഗീകരിക്കുക,അംഗീകരിക്കുക
7,ambaasadar,അംബാസഡർ,അംബാസഡർ,അംബാസഡർ
8,ambaassador,അൻബാസ്സർ,അംബാസഡർ,അൻബാസ്സ
9,ambassador,അൻബസ്സാർ,അംബാസഡർ,അൻബസ്സാ


Saved to predictions_vanilla/test_predictions_embed256_hid128_enc2_dec1_LSTM_drop0.35_beam5.csv


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
test_accuracy,▁▅▆▆▇▇█████████████▇
test_token_accuracy,▁▅▆▇▇▇█▇████████████
train_loss,█▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▆▇▇▇▇▇███████▇████
val_token_accuracy,▁▅▆▇▇██▇████████████

0,1
epoch,20.0
test_accuracy,0.43922
test_token_accuracy,0.79014
train_loss,0.11281
val_accuracy,0.45098
val_token_accuracy,0.78757


Zip created: predictions_vanilla.zip


**Best model using attention**

In [None]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML
import zipfile
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Char-level vocabulary
class CharVocab:
     """
    A character-level vocabulary class for encoding and decoding sequences of characters.

    Attributes:
    -----------
    char2idx : dict
        Mapping from characters to integer indices. Includes special tokens:
        '<pad>' (0), '<sos>' (1), '<eos>' (2), and '<unk>' (3).
        
    idx2char : dict
        Reverse mapping from indices to characters.

    pad_idx : int
        Index of the padding token ('<pad>').

    sos_idx : int
        Index of the start-of-sequence token ('<sos>').

    eos_idx : int
        Index of the end-of-sequence token ('<eos>').

    Methods:
    --------
    encode(word: str) -> List[int]
        Converts a string into a list of indices, including <sos> at the start and <eos> at the end.
        Unknown characters are mapped to the <unk> index.

    decode(ids: List[int]) -> str
        Converts a list of indices back into a string, ignoring <sos> and <pad>, and stopping at <eos>.

    __len__() -> int
        Returns the size of the vocabulary (i.e., number of unique tokens including special tokens).
    """
    def __init__(self, words):
        chars = sorted(set("".join(words)))
        self.char2idx = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        for c in chars:
            self.char2idx[c] = len(self.char2idx)
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        self.pad_idx = self.char2idx['<pad>']
        self.sos_idx = self.char2idx['<sos>']
        self.eos_idx = self.char2idx['<eos>']

    def encode(self, word):
        return [self.sos_idx] + [self.char2idx.get(c, self.char2idx['<unk>']) for c in word] + [self.eos_idx]

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == self.eos_idx:
                break
            if idx not in (self.sos_idx, self.pad_idx):
                chars.append(self.idx2char.get(idx, ''))
        return ''.join(chars)

    def __len__(self):
        return len(self.char2idx)

def read_file(path):
    with open(path, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [(line.split('\t')[0], line.split('\t')[1]) for line in lines if len(line.split('\t')) >= 2]

train_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.train.tsv')
dev_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.dev.tsv')
test_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.test.tsv')

src_vocab = CharVocab([src for _, src in train_pairs])
tgt_vocab = CharVocab([tgt for tgt, _ in train_pairs])

class TransliterationDataset(Dataset):
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.data = pairs
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tgt, src = self.data[idx]
        return torch.tensor(self.src_vocab.encode(src)), torch.tensor(self.tgt_vocab.encode(tgt))

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_pad = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=src_vocab.pad_idx)
    tgt_pad = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=tgt_vocab.pad_idx)
    return src_pad, tgt_pad

train_loader = DataLoader(TransliterationDataset(train_pairs, src_vocab, tgt_vocab),
                          batch_size=64, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(TransliterationDataset(dev_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(TransliterationDataset(test_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)

class Attention(nn.Module):
    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super().__init__()
        self.attn = nn.Linear(enc_hidden_dim + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Parameter(torch.rand(dec_hidden_dim))

    def forward(self, hidden, encoder_outputs, mask=None):
        src_len = encoder_outputs.shape[1]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        energy = energy.permute(0, 2, 1)
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)
        attention = torch.bmm(v, energy).squeeze(1)
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)
        return F.softmax(attention, dim=1)

class Seq2Seq(nn.Module):
    def __init__(self, config, input_vocab_size, output_vocab_size):
        super().__init__()
        self.embedding_dim = config.embed_size
        self.hidden_size = config.hidden_size
        self.num_enc_layers = config.enc_layers
        self.num_dec_layers = config.dec_layers
        self.cell_type = config.cell
        self.device = device
        self.dropout = nn.Dropout(config.dropout)
        self.max_len = 30
        self.attention_weights = []  # Store attention weights for visualization

        self.encoder_embedding = nn.Embedding(input_vocab_size, self.embedding_dim)
        self.decoder_embedding = nn.Embedding(output_vocab_size, self.embedding_dim)

        RNN = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}[self.cell_type]
        self.encoder = RNN(self.embedding_dim, self.hidden_size, num_layers=self.num_enc_layers,
                           batch_first=True, bidirectional=True)
        self.decoder = RNN(self.embedding_dim + self.hidden_size * 2, self.hidden_size * 2,
                           num_layers=self.num_dec_layers, batch_first=True)

        self.attention = Attention(self.hidden_size * 2, self.hidden_size * 2)
        self.fc = nn.Linear(self.hidden_size * 4, output_vocab_size)

        self.sos_idx = tgt_vocab.sos_idx
        self.eos_idx = tgt_vocab.eos_idx
        self.pad_idx = tgt_vocab.pad_idx

    def encode(self, src):
        embedded = self.dropout(self.encoder_embedding(src))
        outputs, h_n = self.encoder(embedded)
        if self.cell_type == 'LSTM':
            h, c = h_n
            h_cat = torch.cat((h[-2], h[-1]), dim=1).unsqueeze(0)
            c_cat = torch.cat((c[-2], c[-1]), dim=1).unsqueeze(0)
            return outputs, (h_cat, c_cat)
        else:
            h_cat = torch.cat((h_n[-2], h_n[-1]), dim=1).unsqueeze(0)
            return outputs, h_cat

    def decode_step(self, input_token, hidden, encoder_outputs):
        embedded = self.dropout(self.decoder_embedding(input_token))
        if self.cell_type == 'LSTM':
            h_t = hidden[0][-1]
        else:
            h_t = hidden[-1]
            
        attn_weights = self.attention(h_t, encoder_outputs)
        self.attention_weights.append(attn_weights)  # Store for visualization
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
        rnn_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.decoder(rnn_input, hidden)
        logits = self.fc(torch.cat((output.squeeze(1), context.squeeze(1)), dim=1))
        return logits, hidden, attn_weights

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        self.attention_weights = []  # Reset attention weights storage
        batch_size, tgt_len = tgt.shape
        encoder_outputs, hidden = self.encode(src)
        input_token = tgt[:, 0].unsqueeze(1)
        outputs = []

        for t in range(1, tgt_len):
            output, hidden, attn_weights = self.decode_step(input_token, hidden, encoder_outputs)
            outputs.append(output.unsqueeze(1))
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_token = tgt[:, t].unsqueeze(1) if teacher_force else top1
        return torch.cat(outputs, dim=1)

def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        output = model(src, tgt)
        loss = criterion(output.view(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def beam_decode(model, src, beam_size):
    model.eval()
    with torch.no_grad():
        encoder_outputs, hidden = model.encode(src)
        batch_size = src.size(0)
        final_outputs = []

        for b in range(batch_size):
            h_b = (hidden[0][:, b:b+1, :].contiguous(), hidden[1][:, b:b+1, :].contiguous()) if model.cell_type == 'LSTM' else hidden[:, b:b+1, :].contiguous()
            enc_out_b = encoder_outputs[b:b+1]
            beams = [([model.sos_idx], 0.0, h_b)]
            
            for _ in range(model.max_len):
                new_beams = []
                for seq, score, h in beams:
                    if seq[-1] == model.eos_idx:
                        new_beams.append((seq, score, h))
                        continue
                    input_token = torch.tensor([[seq[-1]]], device=device)
                    out, h_new, _ = model.decode_step(input_token, h, enc_out_b)
                    log_probs = F.log_softmax(out, dim=1)
                    topk_probs, topk_idxs = torch.topk(log_probs, beam_size, dim=1)
                    for i in range(beam_size):
                        next_seq = seq + [topk_idxs[0][i].item()]
                        new_score = score + topk_probs[0][i].item()
                        new_beams.append((next_seq, new_score, h_new))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
            final_outputs.append(beams[0][0])
        return final_outputs

def plot_attention_heatmaps(model, test_samples, num_samples=9):
    model.eval()
    plt.figure(figsize=(15, 15))
    samples = test_samples[:num_samples]
    
    for i, (src, tgt) in enumerate(samples):
        plt.subplot(3, 3, i+1)
        with torch.no_grad():
            src_tensor = torch.tensor([src_vocab.encode(src)], device=device)
            tgt_tensor = torch.tensor([tgt_vocab.encode(tgt)], device=device)
            model(src_tensor, tgt_tensor)  # This populates attention_weights
            
            # Get attention weights and convert to numpy
            attn_weights = torch.cat(model.attention_weights).squeeze().cpu().numpy()
            
            # Create heatmap
            sns.heatmap(attn_weights, cmap="YlGnBu", 
                        xticklabels=list(src),
                        yticklabels=list(tgt_vocab.decode(tgt_tensor[0][1:-1])))
            plt.title(f"Input: {src}\nOutput: {tgt_vocab.decode(tgt_tensor[0][1:-1])}")
            plt.xlabel("Source Characters")
            plt.ylabel("Target Characters")
    
    plt.tight_layout()
    return plt

def evaluate_beam(model, dataloader, beam_size):
    model.eval()  # Set model to evaluation mode (disables dropout, etc.)
    total_seq, correct_seq = 0, 0  # Track full sequence accuracy
    total_tokens, correct_tokens = 0, 0  # Track token-level accuracy
    all_predictions = []  # Store predictions for analysis/visualization

    with torch.no_grad():  # Disable gradient computation for faster evaluation
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)  # Move tensors to device (CPU/GPU)

            preds = beam_decode(model, src, beam_size)  # Perform beam search decoding

            for i, (pred, true) in enumerate(zip(preds, tgt)):
                # Remove special tokens (pad/eos) from predictions and targets
                pred_trimmed = [tok for tok in pred[1:] if tok != model.pad_idx and tok != model.eos_idx]
                true_trimmed = [tok.item() for tok in true[1:] if tok.item() != model.pad_idx and tok.item() != model.eos_idx]

                # Check if entire sequence matches exactly
                if pred_trimmed == true_trimmed:
                    correct_seq += 1
                total_seq += 1

                # Compare token-by-token for token-level accuracy
                for p, t in zip(pred_trimmed, true_trimmed):
                    if p == t:
                        correct_tokens += 1
                total_tokens += len(true_trimmed)

                # Decode input, prediction, and ground truth for storage
                src_word = src_vocab.decode([x.item() for x in src[i] if x.item() not in (src_vocab.sos_idx, src_vocab.eos_idx, src_vocab.pad_idx)])
                pred_word = tgt_vocab.decode(pred)
                true_word = tgt_vocab.decode([x.item() for x in true if x.item() not in (tgt_vocab.pad_idx, tgt_vocab.eos_idx)])

                # Store triplet: (input word, predicted word, true word)
                all_predictions.append((src_word, pred_word, true_word))

    # Compute accuracy scores
    seq_accuracy = correct_seq / total_seq if total_seq > 0 else 0.0
    token_accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0.0

    return seq_accuracy, token_accuracy, all_predictions  # Return both accuracy metrics and all decoded predictions

def visualize_predictions(predictions, num_samples=10, log_to_wandb=False):
    """
    Visualizes a subset of model predictions by highlighting character-level differences 
    between the predicted and true output.

    Parameters:
    -----------
    predictions : list of tuples or lists
        Each element should be a (input, predicted, true) triple representing the input string,
        the model's prediction, and the ground truth string.

    num_samples : int, optional (default=10)
        The number of prediction samples to display.

    log_to_wandb : bool, optional (default=False)
        If True, logs the raw predictions to Weights & Biases using wandb.Table.

    Returns:
    --------
    styled_df

    Notes:
    ------
    - This function is useful for qualitative evaluation of sequence prediction models.
    - It uses HTML and pandas styling to visually compare predictions.
    """
    df = pd.DataFrame(predictions[:num_samples], columns=['Input', 'Predicted', 'True'])
    
    def highlight_diff(row):
        pred, true = row['Predicted'], row['True']
        diff = []
        for p, t in zip(pred, true):
            if p == t:
                diff.append(p)
            else:
                diff.append(f'<b style="color:red">{p}</b>')
        return ''.join(diff)
    
    df['Difference'] = df.apply(lambda row: highlight_diff(row), axis=1)
    
    def row_style(row):
        color = 'lightgreen' if row['Predicted'] == row['True'] else 'lightpink'
        return [f'background-color: {color}' for _ in row]
    
    styled_df = df.style.apply(row_style, axis=1).set_properties(**{'text-align': 'left'})
    display(HTML(styled_df.to_html(escape=False)))
    
    if log_to_wandb:
        wandb.log({"predictions": wandb.Table(dataframe=df)})
    
    return styled_df

def save_predictions(predictions, filename):
    os.makedirs('predictions_attention', exist_ok=True)
    df = pd.DataFrame(predictions, columns=['Input', 'Predicted', 'True'])
    df.to_csv(f'predictions_attention/{filename}', index=False)
    print(f"Saved to predictions_attention/{filename}")

def sweep_train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        run_name = f"embed{config.embed_size}_hid{config.hidden_size}_enc{config.enc_layers}_dec{config.dec_layers}_{config.cell}_drop{config.dropout}_beam{config.beam_size}_attn"
        wandb.run.name = run_name

        model = Seq2Seq(config, len(src_vocab), len(tgt_vocab)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)

        for epoch in range(20):  # Increased epochs for better training
            train_loss = train(model, train_loader, optimizer, criterion)
            acc, token_acc, val_preds = evaluate_beam(model, dev_loader, beam_size=config.beam_size)
            test_acc, test_token_acc, test_preds = evaluate_beam(model, test_loader, beam_size=config.beam_size)

            # Log attention heatmaps at the end of training
            if epoch == 19:
                # Get first 9 test samples properly
                test_samples = []
                for i in range(9):
                    src, tgt = test_loader.dataset[i]
                    src_word = src_vocab.decode([x.item() for x in src if x.item() not in (src_vocab.sos_idx, src_vocab.eos_idx, src_vocab.pad_idx)])
                    tgt_word = tgt_vocab.decode([x.item() for x in tgt if x.item() not in (tgt_vocab.pad_idx, tgt_vocab.eos_idx)])
                    test_samples.append((src_word, tgt_word))
                
                attention_plot = plot_attention_heatmaps(model, test_samples)
                wandb.log({"attention_heatmaps": wandb.Image(attention_plot)})
                plt.close()

            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_accuracy': acc,
                'val_token_accuracy': token_acc,
                'test_accuracy': test_acc,
                'test_token_accuracy': test_token_acc,
                'used_attention': True
            })

            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Acc = {acc:.4f}, Test Acc = {test_acc:.4f}")
        visualize_predictions(test_preds, num_samples=15, log_to_wandb=True)
        save_predictions(test_preds, f'test_predictions_{run_name}.csv')

# Sweep Config
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'embed_size': {'values': [256]},
        'hidden_size': {'values': [128]},
        'enc_layers': {'values': [2]},
        'dec_layers': {'values': [1]},
        'dropout': {'values': [0.30]},
        'cell': {'values': ['LSTM']},
        'beam_size': {'values': [5]}
    }
}

sweep_id = wandb.sweep(sweep_config, project="A3_attention_ce21b020")
wandb.agent(sweep_id, function=sweep_train, count=1)

def create_prediction_zip():
    with zipfile.ZipFile('predictions_attention.zip', 'w') as zipf:
        for root, dirs, files in os.walk('predictions_attention'):
            for file in files:
                zipf.write(os.path.join(root, file))
    print("Zip created: predictions_attention.zip")

create_prediction_zip()

Create sweep with ID: vgrgvz79
Sweep URL: https://wandb.ai/apoorvaprashanth-indian-institute-of-technology-madras/A3_attention_ce21b020/sweeps/vgrgvz79


[34m[1mwandb[0m: Agent Starting Run: n8rd5zr1 with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128


Epoch 1: Train Loss = 1.0981, Val Acc = 0.4072, Test Acc = 0.4125
Epoch 2: Train Loss = 0.4345, Val Acc = 0.4583, Test Acc = 0.4506
Epoch 3: Train Loss = 0.3356, Val Acc = 0.4856, Test Acc = 0.4848
Epoch 4: Train Loss = 0.2783, Val Acc = 0.4866, Test Acc = 0.4824
Epoch 5: Train Loss = 0.2367, Val Acc = 0.5054, Test Acc = 0.4920
Epoch 6: Train Loss = 0.2121, Val Acc = 0.5136, Test Acc = 0.5123
Epoch 7: Train Loss = 0.1827, Val Acc = 0.5155, Test Acc = 0.5242
Epoch 8: Train Loss = 0.1695, Val Acc = 0.5063, Test Acc = 0.5155
Epoch 9: Train Loss = 0.1513, Val Acc = 0.5276, Test Acc = 0.5127
Epoch 10: Train Loss = 0.1402, Val Acc = 0.5022, Test Acc = 0.5052
Epoch 11: Train Loss = 0.1327, Val Acc = 0.5169, Test Acc = 0.5194
Epoch 12: Train Loss = 0.1174, Val Acc = 0.5139, Test Acc = 0.5185
Epoch 13: Train Loss = 0.1116, Val Acc = 0.5102, Test Acc = 0.5082
Epoch 14: Train Loss = 0.1068, Val Acc = 0.5153, Test Acc = 0.5169
Epoch 15: Train Loss = 0.1019, Val Acc = 0.5178, Test Acc = 0.5134
Epoc

In [None]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML
import zipfile
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Char-level vocabulary with encoding specification
class CharVocab:
    def __init__(self, words):
        chars = sorted(set("".join(words)))
        self.char2idx = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        for c in chars:
            self.char2idx[c] = len(self.char2idx)
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        self.pad_idx = self.char2idx['<pad>']
        self.sos_idx = self.char2idx['<sos>']
        self.eos_idx = self.char2idx['<eos>']

    def encode(self, word):
        return [self.sos_idx] + [self.char2idx.get(c, self.char2idx['<unk>']) for c in word] + [self.eos_idx]

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == self.eos_idx:
                break
            if idx not in (self.sos_idx, self.pad_idx):
                chars.append(self.idx2char.get(idx, ''))
        return ''.join(chars)

    def __len__(self):
        return len(self.char2idx)

def read_file(path):
    with open(path, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [(line.split('\t')[0], line.split('\t')[1]) for line in lines if len(line.split('\t')) >= 2]

train_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.train.tsv')
dev_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.dev.tsv')
test_pairs = read_file('/kaggle/input/malayalam/ml.translit.sampled.test.tsv')

src_vocab = CharVocab([src for _, src in train_pairs])
tgt_vocab = CharVocab([tgt for tgt, _ in train_pairs])

class TransliterationDataset(Dataset):
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.data = pairs
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tgt, src = self.data[idx]
        return torch.tensor(self.src_vocab.encode(src)), torch.tensor(self.tgt_vocab.encode(tgt))

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_pad = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=src_vocab.pad_idx)
    tgt_pad = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=tgt_vocab.pad_idx)
    return src_pad, tgt_pad

train_loader = DataLoader(TransliterationDataset(train_pairs, src_vocab, tgt_vocab),
                          batch_size=64, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(TransliterationDataset(dev_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(TransliterationDataset(test_pairs, src_vocab, tgt_vocab),
                        batch_size=32, shuffle=False, collate_fn=collate_fn)

class Attention(nn.Module):
    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super().__init__()
        self.attn = nn.Linear(enc_hidden_dim + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Parameter(torch.rand(dec_hidden_dim))

    def forward(self, hidden, encoder_outputs, mask=None):
        src_len = encoder_outputs.shape[1]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        energy = energy.permute(0, 2, 1)
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)
        attention = torch.bmm(v, energy).squeeze(1)
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)
        return F.softmax(attention, dim=1)

class Seq2Seq(nn.Module):
    def __init__(self, config, input_vocab_size, output_vocab_size):
        super().__init__()
        self.embedding_dim = config.embed_size
        self.hidden_size = config.hidden_size
        self.num_enc_layers = config.enc_layers
        self.num_dec_layers = config.dec_layers
        self.cell_type = config.cell
        self.device = device
        self.dropout = nn.Dropout(config.dropout)
        self.max_len = 30
        self.attention_weights = []  # Store attention weights for visualization

        self.encoder_embedding = nn.Embedding(input_vocab_size, self.embedding_dim)
        self.decoder_embedding = nn.Embedding(output_vocab_size, self.embedding_dim)

        RNN = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}[self.cell_type]
        self.encoder = RNN(self.embedding_dim, self.hidden_size, num_layers=self.num_enc_layers,
                           batch_first=True, bidirectional=True)
        self.decoder = RNN(self.embedding_dim + self.hidden_size * 2, self.hidden_size * 2,
                           num_layers=self.num_dec_layers, batch_first=True)

        self.attention = Attention(self.hidden_size * 2, self.hidden_size * 2)
        self.fc = nn.Linear(self.hidden_size * 4, output_vocab_size)

        self.sos_idx = tgt_vocab.sos_idx
        self.eos_idx = tgt_vocab.eos_idx
        self.pad_idx = tgt_vocab.pad_idx

    def encode(self, src):
        embedded = self.dropout(self.encoder_embedding(src))
        outputs, h_n = self.encoder(embedded)
        if self.cell_type == 'LSTM':
            h, c = h_n
            h_cat = torch.cat((h[-2], h[-1]), dim=1).unsqueeze(0)
            c_cat = torch.cat((c[-2], c[-1]), dim=1).unsqueeze(0)
            return outputs, (h_cat, c_cat)
        else:
            h_cat = torch.cat((h_n[-2], h_n[-1]), dim=1).unsqueeze(0)
            return outputs, h_cat

    def decode_step(self, input_token, hidden, encoder_outputs):
        embedded = self.dropout(self.decoder_embedding(input_token))
        if self.cell_type == 'LSTM':
            h_t = hidden[0][-1]
        else:
            h_t = hidden[-1]
            
        attn_weights = self.attention(h_t, encoder_outputs)
        self.attention_weights.append(attn_weights)  # Store for visualization
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
        rnn_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.decoder(rnn_input, hidden)
        logits = self.fc(torch.cat((output.squeeze(1), context.squeeze(1)), dim=1))
        return logits, hidden, attn_weights

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        self.attention_weights = []  # Reset attention weights storage
        batch_size, tgt_len = tgt.shape
        encoder_outputs, hidden = self.encode(src)
        input_token = tgt[:, 0].unsqueeze(1)
        outputs = []

        for t in range(1, tgt_len):
            output, hidden, attn_weights = self.decode_step(input_token, hidden, encoder_outputs)
            outputs.append(output.unsqueeze(1))
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_token = tgt[:, t].unsqueeze(1) if teacher_force else top1
        return torch.cat(outputs, dim=1)

def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        output = model(src, tgt)
        loss = criterion(output.view(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def beam_decode(model, src, beam_size):
    model.eval()
    with torch.no_grad():
        encoder_outputs, hidden = model.encode(src)
        batch_size = src.size(0)
        final_outputs = []

        for b in range(batch_size):
            h_b = (hidden[0][:, b:b+1, :].contiguous(), hidden[1][:, b:b+1, :].contiguous()) if model.cell_type == 'LSTM' else hidden[:, b:b+1, :].contiguous()
            enc_out_b = encoder_outputs[b:b+1]
            beams = [([model.sos_idx], 0.0, h_b)]
            
            for _ in range(model.max_len):
                new_beams = []
                for seq, score, h in beams:
                    if seq[-1] == model.eos_idx:
                        new_beams.append((seq, score, h))
                        continue
                    input_token = torch.tensor([[seq[-1]]], device=device)
                    out, h_new, _ = model.decode_step(input_token, h, enc_out_b)
                    log_probs = F.log_softmax(out, dim=1)
                    topk_probs, topk_idxs = torch.topk(log_probs, beam_size, dim=1)
                    for i in range(beam_size):
                        next_seq = seq + [topk_idxs[0][i].item()]
                        new_score = score + topk_probs[0][i].item()
                        new_beams.append((next_seq, new_score, h_new))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
            final_outputs.append(beams[0][0])
        return final_outputs

def plot_attention_heatmaps(model, test_samples, num_samples=9):
    model.eval()
    plt.figure(figsize=(15, 15))
    samples = test_samples[:num_samples]
    
    for i, (src, tgt) in enumerate(samples):
        plt.subplot(3, 3, i+1)
        with torch.no_grad():
            src_tensor = torch.tensor([src_vocab.encode(src)], device=device)
            tgt_tensor = torch.tensor([tgt_vocab.encode(tgt)], device=device)
            model(src_tensor, tgt_tensor)  # This populates attention_weights
            
            # Get attention weights and convert to numpy
            attn_weights = torch.cat(model.attention_weights).squeeze().cpu().numpy()
            
            # Create heatmap with proper labels
            ax = sns.heatmap(attn_weights, cmap="YlGnBu", 
                        xticklabels=list(src),
                        yticklabels=list(tgt))
            plt.title(f"Input: {src}\nOutput: {tgt}")
            plt.xlabel("Source Characters")
            plt.ylabel("Target Characters")
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    
    plt.tight_layout()
    return plt

def create_interactive_attention_plot(model, src_word, tgt_word):
    model.eval()
    with torch.no_grad():
        src_tensor = torch.tensor([src_vocab.encode(src_word)], device=device)
        tgt_tensor = torch.tensor([tgt_vocab.encode(tgt_word)], device=device)
        model(src_tensor, tgt_tensor)
        
        attn_weights = torch.cat(model.attention_weights).squeeze().cpu().numpy()
        
        fig = go.Figure(data=go.Heatmap(
            z=attn_weights,
            x=list(src_word),
            y=list(tgt_word),
            colorscale='YlGnBu',
            hoverongaps=False
        ))
        
        fig.update_layout(
            title=f'Attention Visualization: {src_word} → {tgt_word}',
            xaxis_title='Source Characters',
            yaxis_title='Target Characters',
            width=800,
            height=600
        )
        
        return fig

def evaluate_beam(model, dataloader, beam_size):
    model.eval()
    total_seq, correct_seq = 0, 0
    total_tokens, correct_tokens = 0, 0
    all_predictions = []

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            preds = beam_decode(model, src, beam_size)
            
            for i, (pred, true) in enumerate(zip(preds, tgt)):
                pred_trimmed = [tok for tok in pred[1:] if tok != model.pad_idx and tok != model.eos_idx]
                true_trimmed = [tok.item() for tok in true[1:] if tok.item() != model.pad_idx and tok.item() != model.eos_idx]

                if pred_trimmed == true_trimmed:
                    correct_seq += 1
                total_seq += 1

                for p, t in zip(pred_trimmed, true_trimmed):
                    if p == t:
                        correct_tokens += 1
                total_tokens += len(true_trimmed)
                
                src_word = src_vocab.decode([x.item() for x in src[i] if x.item() not in (src_vocab.sos_idx, src_vocab.eos_idx, src_vocab.pad_idx)])
                pred_word = tgt_vocab.decode(pred)
                true_word = tgt_vocab.decode([x.item() for x in true if x.item() not in (tgt_vocab.pad_idx, tgt_vocab.eos_idx)])
                
                all_predictions.append((src_word, pred_word, true_word))

    seq_accuracy = correct_seq / total_seq if total_seq > 0 else 0.0
    token_accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0.0
    return seq_accuracy, token_accuracy, all_predictions

def visualize_predictions(predictions, num_samples=10, log_to_wandb=False):
    df = pd.DataFrame(predictions[:num_samples], columns=['Input', 'Predicted', 'True'])
    
    def highlight_diff(row):
        pred, true = row['Predicted'], row['True']
        diff = []
        for p, t in zip(pred, true):
            if p == t:
                diff.append(p)
            else:
                diff.append(f'<b style="color:red">{p}</b>')
        return ''.join(diff)
    
    df['Difference'] = df.apply(lambda row: highlight_diff(row), axis=1)
    
    def row_style(row):
        color = 'lightgreen' if row['Predicted'] == row['True'] else 'lightpink'
        return [f'background-color: {color}' for _ in row]
    
    styled_df = df.style.apply(row_style, axis=1).set_properties(**{'text-align': 'left'})
    display(HTML(styled_df.to_html(escape=False)))
    
    if log_to_wandb:
        wandb.log({"predictions": wandb.Table(dataframe=df)})
    
    return styled_df

def save_predictions(predictions, filename):
    os.makedirs('predictions_attention', exist_ok=True)
    df = pd.DataFrame(predictions, columns=['Input', 'Predicted', 'True'])
    # Save with UTF-8 encoding to handle special characters
    df.to_csv(f'predictions_attention/{filename}', index=False, encoding='utf-8')
    print(f"Saved to predictions_attention/{filename}")

def save_model(model, filename):
    os.makedirs('saved_models', exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'src_vocab': src_vocab,
        'tgt_vocab': tgt_vocab,
        'config': config
    }, f'saved_models/{filename}')
    print(f"Model saved to saved_models/{filename}")


def sweep_train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        run_name = f"embed{config.embed_size}_hid{config.hidden_size}_enc{config.enc_layers}_dec{config.dec_layers}_{config.cell}_drop{config.dropout}_beam{config.beam_size}_attn"
        wandb.run.name = run_name

        model = Seq2Seq(config, len(src_vocab), len(tgt_vocab)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)

        for epoch in range(5):
            train_loss = train(model, train_loader, optimizer, criterion)
            acc, token_acc, val_preds = evaluate_beam(model, dev_loader, beam_size=config.beam_size)
            test_acc, test_token_acc, test_preds = evaluate_beam(model, test_loader, beam_size=config.beam_size)

            # Log attention visualizations
            if epoch == 4:
                # Static heatmaps
                test_samples = []
                for i in range(9):
                    src, tgt = test_loader.dataset[i]
                    src_word = src_vocab.decode([x.item() for x in src if x.item() not in (src_vocab.sos_idx, src_vocab.eos_idx, src_vocab.pad_idx)])
                    tgt_word = tgt_vocab.decode([x.item() for x in tgt if x.item() not in (tgt_vocab.pad_idx, tgt_vocab.eos_idx)])
                    test_samples.append((src_word, tgt_word))
                
                attention_plot = plot_attention_heatmaps(model, test_samples)
                wandb.log({"attention_heatmaps": wandb.Image(attention_plot)})
                plt.close()
                
                # Interactive visualization for one example
                src_word, tgt_word = test_samples[0]
                interactive_fig = create_interactive_attention_plot(model, src_word, tgt_word)
                wandb.log({"interactive_attention": wandb.Plotly(interactive_fig)})

            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_accuracy': acc,
                'val_token_accuracy': token_acc,
                'test_accuracy': test_acc,
                'test_token_accuracy': test_token_acc,
                'used_attention': True
            })

            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Acc = {acc:.4f}, Test Acc = {test_acc:.4f}")
        
        visualize_predictions(test_preds, num_samples=15, log_to_wandb=True)
        save_predictions(test_preds, f'test_predictions_{run_name}.csv')
        save_model(model, f'model_{run_name}.pt')

# Sweep Config
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'embed_size': {'values': [256]},
        'hidden_size': {'values': [128]},
        'enc_layers': {'values': [2]},
        'dec_layers': {'values': [1]},
        'dropout': {'values': [0.30]},
        'cell': {'values': ['LSTM']},
        'beam_size': {'values': [5]}
    }
}

sweep_id = wandb.sweep(sweep_config, project="A3_ce21b020")
wandb.agent(sweep_id, function=sweep_train, count=1)

def create_prediction_zip():
    with zipfile.ZipFile('predictions_attention.zip', 'w') as zipf:
        for root, dirs, files in os.walk('predictions_attention'):
            for file in files:
                zipf.write(os.path.join(root, file))
    print("Zip created: predictions_attention.zip")

create_prediction_zip()

Create sweep with ID: 28jgsziy
Sweep URL: https://wandb.ai/apoorvaprashanth-indian-institute-of-technology-madras/A3_ce21b020/sweeps/28jgsziy


[34m[1mwandb[0m: Agent Starting Run: xhz3l7hg with config:
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell: LSTM
[34m[1mwandb[0m: 	dec_layers: 1
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	hidden_size: 128


Epoch 1: Train Loss = 1.0905, Val Acc = 0.3845, Test Acc = 0.3747
Epoch 2: Train Loss = 0.4411, Val Acc = 0.4653, Test Acc = 0.4574
Epoch 3: Train Loss = 0.3372, Val Acc = 0.4930, Test Acc = 0.4829
Epoch 4: Train Loss = 0.2761, Val Acc = 0.5123, Test Acc = 0.5094


  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()


Epoch 5: Train Loss = 0.2399, Val Acc = 0.5090, Test Acc = 0.5171


Unnamed: 0,Input,Predicted,True,Difference
0,amgathavavum,അംഗതവവും,അംഗത്വവും,അംഗതവവും
1,amgathvavum,അംഗത്വവും,അംഗത്വവും,അംഗത്വവും
2,angathwavum,അംഗത്വവും,അംഗത്വവും,അംഗത്വവും
3,amgabalam,അംഗബലം,അംഗബലം,അംഗബലം
4,angabalam,അംഗബലം,അംഗബലം,അംഗബലം
5,amgeekarikkuka,അംഗീകരിക്കുക,അംഗീകരിക്കുക,അംഗീകരിക്കുക
6,angeekarikkuka,അംഗീകരിക്കുക,അംഗീകരിക്കുക,അംഗീകരിക്കുക
7,ambaasadar,അംബാസർ,അംബാസഡർ,അംബാസർ
8,ambaassador,അംബാസർ,അംബാസഡർ,അംബാസർ
9,ambassador,അംബസ്ഡഡർ,അംബാസഡർ,അംബസ്ഡഡ


Saved to predictions_attention/test_predictions_embed256_hid128_enc2_dec1_LSTM_drop0.3_beam5_attn.csv


Traceback (most recent call last):
  File "/tmp/ipykernel_35/2880310165.py", line 398, in sweep_train
    save_model(model, f'model_{run_name}.pt')
  File "/tmp/ipykernel_35/2880310165.py", line 346, in save_model
    'config': config
              ^^^^^^
NameError: name 'config' is not defined


0,1
epoch,▁▃▅▆█
test_accuracy,▁▅▆██
test_token_accuracy,▁▆▇██
train_loss,█▃▂▁▁
val_accuracy,▁▅▇██
val_token_accuracy,▁▆▇██

0,1
epoch,5
test_accuracy,0.51711
test_token_accuracy,0.83063
train_loss,0.2399
used_attention,True
val_accuracy,0.50895
val_token_accuracy,0.82101


[34m[1mwandb[0m: [32m[41mERROR[0m Run xhz3l7hg errored:
[34m[1mwandb[0m: [32m[41mERROR[0m Traceback (most recent call last):
[34m[1mwandb[0m: [32m[41mERROR[0m   File "/usr/local/lib/python3.11/dist-packages/wandb/agents/pyagent.py", line 306, in _run_job
[34m[1mwandb[0m: [32m[41mERROR[0m     self._function()
[34m[1mwandb[0m: [32m[41mERROR[0m   File "/tmp/ipykernel_35/2880310165.py", line 398, in sweep_train
[34m[1mwandb[0m: [32m[41mERROR[0m     save_model(model, f'model_{run_name}.pt')
[34m[1mwandb[0m: [32m[41mERROR[0m   File "/tmp/ipykernel_35/2880310165.py", line 346, in save_model
[34m[1mwandb[0m: [32m[41mERROR[0m     'config': config
[34m[1mwandb[0m: [32m[41mERROR[0m               ^^^^^^
[34m[1mwandb[0m: [32m[41mERROR[0m NameError: name 'config' is not defined
[34m[1mwandb[0m: [32m[41mERROR[0m 


Zip created: predictions_attention.zip


The error above means that the font being used by Matplotlib (and WandB) does not support some Malayalam characters, so those characters cannot be rendered (displayed correctly) in plots or images. Beacuse of which, some chanracters will be shown as blocks. Howvwer, this doesn't affect the functionality of the code