In [None]:
pip install torch wandb pandas tqdm editdistance

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Dakshina language code (ISO 639-1 'te' = Telugu) and the folder with its lexicon TSVs.
LANG = 'te'
data_path = '/content/drive/MyDrive/dakshina_dataset_v1.0/' + LANG + '/lexicons/'

def read_data(filepath, max_len=40):
    """Load (romanized, native-script) word pairs from a tab-separated lexicon file.

    Each line is expected to hold the native-script word in column 0 and its
    Latin romanization in column 1; any further columns are ignored. Lines with
    fewer than two columns are skipped.

    The model transliterates Latin -> native script, so the romanized word is
    returned as the source and the native word as the target.

    Args:
        filepath: Path to the UTF-8 encoded TSV file.
        max_len: Pairs where either side is longer than this many characters
            are dropped.

    Returns:
        A list of ``(source, target)`` string tuples, in file order.
    """
    pairs = []
    with open(filepath, encoding='utf8') as handle:
        for raw in handle:
            fields = raw.strip().split('\t')
            if len(fields) >= 2:
                native, romanized = fields[0], fields[1]
                # Keep only pairs whose both sides fit within the length budget.
                if len(romanized) <= max_len and len(native) <= max_len:
                    pairs.append((romanized, native))
    return pairs

def make_vocab(sequences):
    """Build character<->index lookup tables from an iterable of sequences.

    Indices 0, 1 and 2 are reserved for ``<pad>``, ``<sos>`` and ``<eos>``.
    Every other distinct element gets the next free index, in order of
    first appearance.

    Returns:
        ``(vocab, idx2char)``: the char->index dict and its inverse.
    """
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    # The next free index always equals the current table size, so
    # setdefault with len(vocab) assigns indices 3, 4, ... to new characters.
    for ch in (c for seq in sequences for c in seq):
        vocab.setdefault(ch, len(vocab))

    idx2char = dict(zip(vocab.values(), vocab.keys()))
    return vocab, idx2char

def encode_word(word, vocab):
    """Map *word* to token ids framed by ``<sos>`` ... ``<eos>``.

    Raises:
        KeyError: if a character of *word* (or a special token) is missing
            from *vocab*.
    """
    start = vocab['<sos>']
    body = list(map(vocab.__getitem__, word))
    end = vocab['<eos>']
    return [start, *body, end]

def pad_seq(seq, max_len, pad_idx=0):
    """Right-pad list *seq* with *pad_idx* up to *max_len* items.

    A sequence already at or beyond *max_len* is returned as an unchanged copy
    (it is never truncated).
    """
    filler = [pad_idx for _ in range(max_len - len(seq))]
    return seq + filler

class TransliterationDataset(Dataset):
    """Map-style dataset of (source, target) token-id tensors.

    Every word is encoded with ``<sos>``/``<eos>`` framing up front. Each item
    is right-padded to the longest source/target length found in *this*
    dataset, so all items of one dataset share the same shapes.
    """

    def __init__(self, pairs, source_vocab, target_vocab):
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']
        # Pre-encode every pair once; padding is applied lazily per item.
        self.data = [
            (encode_word(src, source_vocab), encode_word(tgt, target_vocab))
            for src, tgt in pairs
        ]
        # Padding targets: longest encoded sequence on each side.
        self.source_max = max(len(src_ids) for src_ids, _ in self.data)
        self.target_max = max(len(tgt_ids) for _, tgt_ids in self.data)

    def __len__(self):
        """Number of encoded pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the padded ``(source, target)`` id tensors at position *idx*."""
        src_ids, tgt_ids = self.data[idx]
        src_tensor = torch.tensor(pad_seq(src_ids, self.source_max, self.source_pad))
        tgt_tensor = torch.tensor(pad_seq(tgt_ids, self.target_max, self.target_pad))
        return src_tensor, tgt_tensor

class Attention(nn.Module):
    """Additive (concat + tanh) attention of one decoder state over encoder outputs.

    Each encoder position j is scored as ``v . tanh(W [h; e_j])``. The scores
    are softmax-normalised over positions and used to average the encoder
    outputs into a context vector.
    """

    def __init__(self, hid_dimensions):
        super().__init__()
        # Projects the concatenated [decoder state; encoder output] to hid_dimensions.
        self.attn = nn.Linear(hid_dimensions * 2, hid_dimensions)

        # Learnable vector that reduces each projected position to a scalar score.
        self.v = nn.Parameter(torch.rand(hid_dimensions))

        # Re-initialise v uniformly in [-1/sqrt(H), 1/sqrt(H)].
        stdv = 1. / (hid_dimensions ** 0.5)
        self.v.data.uniform_(-stdv, stdv)

        self.hid_dimensions = hid_dimensions

    def forward(self, hidden, encoder_outputs, mask=None):
        """Compute the attention context.

        Args:
            hidden: Decoder state, either ``(batch, hid)`` or a stacked
                ``(layers, batch, hid)`` tensor. Only the last layer is used
                in the stacked case.
            encoder_outputs: ``(batch, src_len, hid)`` (indexed as such below).
            mask: Optional ``(batch, src_len)`` bool/0-1 tensor; positions that
                are False/0 (e.g. padding) get zero attention weight. Each row
                should keep at least one position; an all-masked row falls back
                to uniform weights. ``None`` attends to every position.

        Returns:
            ``(context, attn_weights)`` with shapes ``(batch, hid)`` and
            ``(batch, src_len)``.

        Raises:
            ValueError: if *hidden* is neither 2-D nor 3-D.
        """
        src_len = encoder_outputs.size(1)

        # Stacked multi-layer state: attend with the top layer only.
        if hidden.dim() == 3:
            hidden = hidden[-1]
        elif hidden.dim() != 2:
            raise ValueError(f"Expected hidden to be 2D or 3D, got shape {hidden.shape}")

        # Broadcast the decoder state across source positions (a view, no copy).
        hidden = hidden.unsqueeze(1).expand(-1, src_len, -1)

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        scores = energy @ self.v  # (batch, src_len)

        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf so a
            # fully masked row yields uniform weights instead of NaNs.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)

        attn_weights = torch.softmax(scores, dim=1)

        # Weighted sum of encoder outputs over the source dimension.
        context = torch.sum(attn_weights.unsqueeze(2) * encoder_outputs, dim=1)

        return context, attn_weights

class translit_Decoder(nn.Module):
    """One-step attention decoder.

    At each step it embeds the previous token, advances the RNN, attends over
    the encoder outputs with the new state, and predicts logits from
    ``[rnn_output; context]``.
    """

    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()
        # Submodules are created in a fixed order; the order determines how the
        # global RNG is consumed for weight initialisation.
        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)
        self.attention = Attention(hid_dimensions)

        kind = cell.lower()
        rnn_types = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        # Inter-layer dropout only makes sense with more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = rnn_types[kind](
            emb_dimensions, hid_dimensions, num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

        # Logits come from the RNN output concatenated with the context vector.
        self.fc_out = nn.Linear(hid_dimensions * 2, output_dimensions)
        self.cell = kind
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode a single time step.

        Args:
            input: Previous token ids, one per batch element (gets a length-1
                time axis added here).
            hidden: RNN hidden state carried over from the previous step or the encoder.
            cell: LSTM cell state; ignored (and returned as ``None``) for RNN/GRU.
            encoder_outputs: Full encoder output sequence to attend over.

        Returns:
            ``(logits, hidden, cell, attn_weights)`` for this step.
        """
        step = self.dropout(self.embedding(input.unsqueeze(1)))

        if self.cell != 'lstm':
            rnn_out, hidden = self.rnn(step, hidden)
            cell = None
        else:
            rnn_out, (hidden, cell) = self.rnn(step, (hidden, cell))

        context, attn_weights = self.attention(hidden, encoder_outputs)

        features = torch.cat((rnn_out.squeeze(1), context), dim=1)
        logits = self.fc_out(features)
        return logits, hidden, cell, attn_weights


class translit_Encoder(nn.Module):
    """Embedding + (multi-layer) RNN/GRU/LSTM encoder over source token ids."""

    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()

        # Embedding layer to convert input indices into dense vectors
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)

        # Choose RNN type based on cell argument
        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[cell.lower()]

        # Inter-layer dropout is only applied when there is more than one layer.
        self.rnn = rnn_cls(
            emb_dimensions, hid_dimensions, num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        # Not used by forward(). Kept so the parameter set / state_dict keys
        # (``attention.*``) and the weight-initialisation RNG sequence stay
        # unchanged, and existing checkpoints still load.
        self.attention = Attention(hid_dimensions)
        self.cell = cell.lower()

        # Dropout applied to the embeddings
        self.dropout = nn.Dropout(dropout)

    def forward(self, source):
        """Encode a batch of source token ids.

        Args:
            source: Token-id tensor, batch-first (the RNN is built with
                ``batch_first=True``).

        Returns:
            ``(outputs, hidden, cell)``: the per-position RNN outputs, the final
            hidden state, and the final LSTM cell state (``None`` for RNN/GRU).
        """
        embedded = self.dropout(self.embedding(source))

        if self.cell == 'lstm':
            outputs, (hidden, cell) = self.rnn(embedded)
        else:
            outputs, hidden = self.rnn(embedded)
            cell = None

        # No attention is computed here: the decoder attends over `outputs`.
        return outputs, hidden, cell

class translit_Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder step by step.

    At every step the next decoder input is either the gold target token
    (teacher forcing) or the decoder's own greedy prediction. The choice is
    drawn once per step for the whole batch.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Device on which the output buffers are allocated.
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Run the full encode/decode pass.

        Returns:
            ``(outputs, attn_weights_all)`` with shapes
            ``(batch, target_len, vocab)`` and ``(batch, target_len, src_len)``.
            Time step 0 (the ``<sos>`` slot) is never decoded, so it stays all zeros.
        """
        n_batch = source.size(0)
        n_steps = target.size(1)
        vocab_size = self.decoder.fc_out.out_features

        logits = torch.zeros(n_batch, n_steps, vocab_size).to(self.device)
        attn_history = torch.zeros(n_batch, n_steps, source.size(1)).to(self.device)

        enc_out, h, c = self.encoder(source)

        # Decoding starts from the first target token (<sos>).
        step_input = target[:, 0]
        for t in range(1, n_steps):
            step_logits, h, c, step_attn = self.decoder(step_input, h, c, enc_out)
            logits[:, t] = step_logits
            attn_history[:, t] = step_attn

            use_gold = random.random() < teacher_forcing_ratio
            greedy = step_logits.argmax(1)
            if use_gold:
                step_input = target[:, t]
            else:
                step_input = greedy

        return logits, attn_history


def strip_after_eos(seq, eos_idx):
    """Return *seq* cut just before its first *eos_idx* token.

    A tensor is first moved to the CPU and turned into a Python list. When
    *eos_idx* does not occur, the (converted) sequence is returned whole.
    """
    tokens = seq.cpu().numpy().tolist() if isinstance(seq, torch.Tensor) else seq
    if eos_idx not in tokens:
        return tokens
    return tokens[:tokens.index(eos_idx)]

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Return the fraction of sequences whose prediction matches the target exactly.

    Each pair is normalised before comparison: tensors become Python lists,
    everything from the first *eos_idx* onward is dropped (when *eos_idx* is
    not ``None``), and *pad_idx* tokens are removed.

    Returns:
        ``correct / len(preds)``, or ``0.0`` for empty input.
    """
    correct = 0
    for pred, target in zip(preds, targets):
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        # `is not None` so that a legitimate eos index of 0 is still honoured.
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        correct += int(pred == target)
    return correct / max(len(preds), 1)


def calculate_cer(preds, targets, pad_idx=0, eos_idx=None):
    """Return the character error rate: summed edit distance / summed target length.

    Each pair is normalised like the other metrics: tensors become Python
    lists, everything from the first *eos_idx* onward is dropped (when
    *eos_idx* is not ``None``), and *pad_idx* tokens are removed. An empty
    target still counts as length 1 in the denominator.

    Returns:
        The corpus-level CER, or ``float('inf')`` when there are no pairs.
    """
    errors = 0
    total = 0
    for pred, target in zip(preds, targets):
        # editdistance compares items by hash; tensors hash by identity, so
        # convert to plain ints before scoring.
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        # `is not None` so that a legitimate eos index of 0 is still honoured.
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        errors += editdistance.eval(pred, target)
        total += max(len(target), 1)
    return errors / total if total > 0 else float('inf')


def calculate_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Return position-wise token accuracy over all non-pad target tokens.

    Every non-pad target token counts toward the denominator. It is correct
    only if the prediction holds the same token at the same position. A
    prediction that is too short is therefore penalised for each missing
    token; extra predicted tokens beyond the target are not scored (CER
    covers those).

    Tensors are converted to lists. When *eos_idx* is not ``None``, both
    sides are cut at their first *eos_idx*.

    Returns:
        ``correct / total``, or ``0.0`` if there are no non-pad target tokens.
    """
    correct = 0
    total = 0
    for pred, target in zip(preds, targets):
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        # Iterate over the *target* (not zip) so missing predictions are
        # counted as errors instead of silently shrinking the denominator.
        for pos, t_token in enumerate(target):
            if t_token == pad_idx:
                continue
            total += 1
            if pos < len(pred) and pred[pos] == t_token:
                correct += 1
    return correct / total if total > 0 else 0.0

def _decode_and_strip(output, target, eos_idx):
    """Greedy-decode a batch of logits and cut predictions and targets at their first <eos>.

    Time step 0 is dropped on both sides: translit_Seq2Seq never decodes it
    (its logits row is left as zeros) and the target holds <sos> there.

    Returns:
        ``(preds, targets)``: two lists of Python int lists, one per sequence.
    """
    raw_preds = output.argmax(2)[:, 1:].tolist()
    raw_targets = target[:, 1:].tolist()
    preds = [strip_after_eos(p, eos_idx) for p in raw_preds]
    targets = [strip_after_eos(t, eos_idx) for t in raw_targets]
    return preds, targets


def run(config=None):
    """Train and validate one attention seq2seq model inside a W&B run.

    Reads the hyperparameters embed_dim, hidden_dim_config, layers, dropout, lr,
    cell_type, teacher_forcing, batch_size and epochs from ``wandb.config``.
    Metrics are logged once per epoch. The weights with the lowest validation
    loss are saved to ``best_model.pt``, uploaded as a W&B artifact, and loaded
    back into the model.
    """
    with wandb.init(config=config):
        cfg = wandb.config

        # Hidden size is either equal to or double the embedding size.
        cfg.hidden_dim = 2 * cfg.embed_dim if cfg.hidden_dim_config == 'double' else cfg.embed_dim

        # Descriptive run name built from the key hyperparameters.
        sweep_name = f"{cfg.cell_type}_{cfg.embed_dim}e_{cfg.hidden_dim_config}h_{cfg.layers}l_" \
                     f"{int(cfg.dropout*100)}d_{int(cfg.teacher_forcing*10)}tf_" \
                     f"{str(cfg.lr).replace('.', '')}lr_attention"
        wandb.run.name = sweep_name

        max_len = 30

        # Only the train and dev splits are used here; the test split is not touched.
        train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=max_len)
        val_pairs   = read_data(data_path + f"{LANG}.translit.sampled.dev.tsv",   max_len=max_len)

        # Vocabularies come from the training split only.
        source_vocab, _ = make_vocab([x[0] for x in train_pairs])
        target_vocab, _ = make_vocab([x[1] for x in train_pairs])
        assert source_vocab['<pad>'] == 0 and target_vocab['<pad>'] == 0, "Pad token must be index 0"
        pad_idx = target_vocab['<pad>']
        eos_idx = target_vocab['<eos>']

        train_translit = TransliterationDataset(train_pairs, source_vocab, target_vocab)
        val_translit   = TransliterationDataset(val_pairs,   source_vocab, target_vocab)

        # drop_last only for training. Validation must score every sample, and
        # drop_last there would also yield zero batches when the dev set is
        # smaller than one batch.
        train_loader = DataLoader(train_translit, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
        val_loader   = DataLoader(val_translit,   batch_size=cfg.batch_size)

        encoder = translit_Encoder(len(source_vocab), cfg.embed_dim, cfg.hidden_dim, cfg.layers, cfg.dropout, cfg.cell_type).to(device)
        decoder = translit_Decoder(len(target_vocab), cfg.embed_dim, cfg.hidden_dim, cfg.layers, cfg.dropout, cfg.cell_type).to(device)
        model = translit_Seq2Seq(encoder, decoder, device).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)

        best_val_loss = float('inf')
        patience = 15
        wait = 0

        for epoch in range(cfg.epochs):
            # ---- Training ----
            model.train()
            total_loss = 0.0
            total_acc = 0.0
            total_char_acc = 0.0

            for source, target in train_loader:
                source, target = source.to(device), target.to(device)
                optimizer.zero_grad()

                # The model returns (logits, attention weights); only logits are needed.
                output, _ = model(source, target, cfg.teacher_forcing)
                out_dimensions = output.shape[-1]

                # Skip time step 0 (<sos>), which is never predicted.
                loss = criterion(output[:, 1:].reshape(-1, out_dimensions), target[:, 1:].reshape(-1))

                # These are measured on teacher-forced outputs, so they are
                # optimistic compared with the free-running validation metrics.
                preds, targets = _decode_and_strip(output, target, eos_idx)
                acc = calculate_word_accuracy(preds, targets, pad_idx=pad_idx)
                char_acc = calculate_accuracy(preds, targets, pad_idx=pad_idx)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()
                total_acc += acc
                total_char_acc += char_acc

            n_train_batches = max(len(train_loader), 1)
            avg_train_loss = total_loss / n_train_batches
            avg_train_acc = total_acc / n_train_batches
            avg_train_char_acc = total_char_acc / n_train_batches

            # ---- Validation (greedy decoding, no teacher forcing) ----
            model.eval()
            val_loss = 0.0
            val_acc = 0.0
            val_cer = 0.0
            val_char_acc = 0.0
            n_val = 0

            with torch.no_grad():
                for source, target in val_loader:
                    source, target = source.to(device), target.to(device)
                    output, _ = model(source, target, teacher_forcing_ratio=0)
                    out_dimensions = output.shape[-1]
                    loss = criterion(output[:, 1:].reshape(-1, out_dimensions), target[:, 1:].reshape(-1))

                    preds, targets = _decode_and_strip(output, target, eos_idx)
                    n = len(preds)

                    # Weight each batch by its size, because the final batch may be smaller.
                    val_loss += loss.item() * n
                    val_acc += calculate_word_accuracy(preds, targets, pad_idx=pad_idx) * n
                    val_char_acc += calculate_accuracy(preds, targets, pad_idx=pad_idx) * n
                    # CER on the same <eos>-stripped sequences, so <sos> and
                    # tokens generated after <eos> are not scored.
                    val_cer += calculate_cer(preds, targets, pad_idx=pad_idx) * n
                    n_val += n

            n_val = max(n_val, 1)
            avg_val_loss = val_loss / n_val
            avg_val_acc = val_acc / n_val
            avg_val_cer = val_cer / n_val
            avg_val_char_acc = val_char_acc / n_val

            scheduler.step(avg_val_loss)

            wandb.log({
                'train_loss': avg_train_loss,
                'train_accuracy': avg_train_acc,
                'train_char_accuracy': avg_train_char_acc,
                'val_loss': avg_val_loss,
                'val_accuracy': avg_val_acc,
                'val_cer': avg_val_cer,
                'val_char_accuracy': avg_val_char_acc,
                'epoch': epoch + 1
            })

            print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.3f} Acc: {avg_train_acc:.3f} | "
                  f"Val Loss: {avg_val_loss:.3f} Acc: {avg_val_acc:.3f} CER: {avg_val_cer:.3f}")

            # Checkpoint on improvement; early-stop after `patience` stale epochs.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                wait = 0
                torch.save(model.state_dict(), 'best_model.pt')
            else:
                wait += 1
                if wait >= patience:
                    print("Early stopping triggered.")
                    break

        # Upload and reload only if this run actually wrote a checkpoint. No
        # checkpoint is written when epochs == 0 or when every validation loss
        # was NaN; in that case a stale file from an earlier run could otherwise
        # be picked up, or a missing one would crash.
        if best_val_loss < float('inf'):
            artifact = wandb.Artifact('best_model', type='model')
            artifact.add_file('best_model.pt')
            wandb.log_artifact(artifact)
            model.load_state_dict(torch.load('best_model.pt', map_location=device))

# W&B sweep definition. Every searched hyperparameter currently has a single
# candidate, so the Bayesian search effectively evaluates one configuration.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_char_accuracy', 'goal': 'maximize'},
    'parameters': {
        # Searched hyperparameters: W&B picks from each 'values' list.
        **{name: {'values': choices} for name, choices in (
            ('embed_dim', [256]),
            ('hidden_dim_config', ['double']),
            ('layers', [2]),
            ('dropout', [0.3]),
            ('lr', [0.001]),
            ('cell_type', ['lstm']),
            ('teacher_forcing', [0.5]),
        )},
        # Constants passed unchanged to every run.
        'batch_size': {'value': 64},
        'epochs': {'value': 20},
    },
}

In [None]:
import wandb

wandb.login()
try:
    sweep_id = wandb.sweep(sweep_config, project="dakshina-seq2seq-3")
    wandb.agent(sweep_id, function=run, count=1)
finally:
    # Always close any W&B run left open. Any exception still propagates, so
    # failures are visible rather than silently swallowed.
    wandb.finish()