# Seq2Seq Code Generation — Google Colab (Inline Edition)
### All source code lives directly in this notebook — edit any cell and re-run to update that file
**Before running:** select a GPU runtime via `Runtime → Change runtime type → T4 GPU`.

## Step 1 — Install Dependencies

In [None]:
# Install third-party packages with pip (-q keeps pip's output quiet).
!pip install datasets sacrebleu seaborn pyyaml tqdm -q
print("All dependencies installed!")

## Step 2 — Check GPU
If you see "No GPU", go to `Runtime → Change runtime type → T4 GPU` and re-run.

In [None]:
import torch
# Report the attached CUDA device, or tell the user how to enable one.
if not torch.cuda.is_available():
    print("No GPU — go to Runtime → Change runtime type → T4 GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_bytes / 1e9:.1f} GB")
    print("Training will be fast!")

## Step 3 — Source Files
Each cell below writes one file to disk.
**Edit the code here in Colab, then re-run the cell to save your changes.**

In [None]:
import os
# Create the package directory that the %%writefile cells below write into;
# exist_ok=True makes re-running this cell safe.
os.makedirs("models", exist_ok=True)
print("models/ directory ready")

### models/__init__.py

In [None]:
%%writefile models/__init__.py
"""
Models package initialization
"""
from .vanilla_rnn import create_vanilla_seq2seq
from .lstm import create_lstm_seq2seq
from .attention_lstm import create_attention_seq2seq

# Public API of the package: the three factory functions imported above.
__all__ = [
    'create_vanilla_seq2seq',
    'create_lstm_seq2seq',
    'create_attention_seq2seq'
]


### models/vanilla_rnn.py

In [None]:
%%writefile models/vanilla_rnn.py
"""
Vanilla RNN-based Seq2Seq Model
"""
import torch
import torch.nn as nn
import random


class EncoderRNN(nn.Module):
    """Single-layer vanilla RNN encoder for padded batches of token ids.

    Tokens are embedded (index 0 is the padding index), passed through
    dropout and then through an ``nn.RNN`` using packed sequences so that
    padded positions do not influence the final hidden state.
    """

    def __init__(self, input_size, embedding_dim, hidden_dim, dropout=0.3):
        """
        Args:
            input_size: size of the source vocabulary
            embedding_dim: dimensionality of the token embeddings
            hidden_dim: dimensionality of the RNN hidden state
            dropout: dropout probability applied to the embeddings
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(input_size, embedding_dim, padding_idx=0)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_seq, input_lengths):
        """
        Args:
            input_seq: (batch_size, seq_len)
            input_lengths: (batch_size,)

        Returns:
            outputs: (batch_size, seq_len, hidden_dim), zero at padded steps
            hidden: (1, batch_size, hidden_dim)
        """
        embedded = self.dropout(self.embedding(input_seq))

        # pack_padded_sequence requires the lengths on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        packed_outputs, hidden = self.rnn(packed)

        # Without total_length the unpacked tensor is only as long as the
        # longest sequence in the batch, which breaks the documented
        # (batch_size, seq_len, hidden_dim) shape for over-padded batches.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=input_seq.shape[1]
        )

        return outputs, hidden


class DecoderRNN(nn.Module):
    """One-step vanilla RNN decoder that maps a token to vocabulary logits."""

    def __init__(self, output_size, embedding_dim, hidden_dim, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.output_size = output_size

        # Sub-modules are created in a fixed order: embedding, rnn, out, dropout.
        self.embedding = nn.Embedding(output_size, embedding_dim, padding_idx=0)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden):
        """
        Run a single decoding step.

        Args:
            input_token: (batch_size, 1)
            hidden: (1, batch_size, hidden_dim)

        Returns:
            output: (batch_size, output_size)
            hidden: (1, batch_size, hidden_dim)
        """
        token_vectors = self.embedding(input_token)
        step_input = self.dropout(token_vectors)

        step_output, next_hidden = self.rnn(step_input, hidden)

        # Drop the length-1 time axis before projecting to the vocabulary.
        logits = self.out(step_output.squeeze(1))
        return logits, next_hidden


class VanillaSeq2Seq(nn.Module):
    """Vanilla RNN Seq2Seq model: an encoder plus a step-wise decoder.

    The decoder is initialised with the encoder's final hidden state and
    unrolled one token at a time.
    """

    def __init__(self, encoder, decoder, device):
        """
        Args:
            encoder: callable returning (outputs, hidden) for (src, src_lengths)
            decoder: callable returning (logits, hidden) for (token, hidden);
                must expose ``output_size`` (target vocabulary size)
            device: device on which output buffers and start tokens are created
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        """
        Args:
            src: (batch_size, src_len)
            src_lengths: (batch_size,)
            tgt: (batch_size, tgt_len)
            teacher_forcing_ratio: probability of using teacher forcing

        Returns:
            outputs: (batch_size, tgt_len, output_size); position 0 is left
                as zeros because decoding starts from tgt[:, 0]
        """
        batch_size = src.shape[0]
        tgt_len = tgt.shape[1]

        # Allocate directly on the target device instead of CPU-then-copy.
        outputs = torch.zeros(
            batch_size, tgt_len, self.decoder.output_size, device=self.device
        )

        _, hidden = self.encoder(src, src_lengths)

        # First input to the decoder is the first target token (SOS).
        input_token = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            output, hidden = self.decoder(input_token, hidden)
            outputs[:, t] = output

            # One teacher-forcing decision per time step for the whole batch.
            if random.random() < teacher_forcing_ratio:
                input_token = tgt[:, t].unsqueeze(1)
            else:
                input_token = output.argmax(1).unsqueeze(1)

        return outputs

    @torch.no_grad()
    def generate(self, src, src_lengths, max_length, sos_token):
        """
        Greedily generate an output sequence.

        Runs without gradient tracking since the result is a tensor of token
        ids. The train/eval mode of the module is not changed here, so
        callers should switch to ``eval()`` themselves to disable dropout.

        Args:
            src: (batch_size, src_len)
            src_lengths: (batch_size,)
            max_length: maximum length of generated sequence
            sos_token: start of sequence token

        Returns:
            outputs: (batch_size, max_length) tensor of token ids; has zero
                columns when max_length is 0
        """
        batch_size = src.shape[0]

        _, hidden = self.encoder(src, src_lengths)

        input_token = torch.full(
            (batch_size, 1), int(sos_token), dtype=torch.long, device=self.device
        )

        steps = []
        for _ in range(max_length):
            output, hidden = self.decoder(input_token, hidden)
            input_token = output.argmax(1, keepdim=True)
            steps.append(input_token)

        # torch.cat rejects an empty list, so handle max_length == 0 explicitly.
        if not steps:
            return torch.empty(batch_size, 0, dtype=torch.long, device=self.device)

        return torch.cat(steps, dim=1)


def create_vanilla_seq2seq(src_vocab_size, tgt_vocab_size, config, device):
    """Build a VanillaSeq2Seq from ``config['model']`` and move it to ``device``.

    Reads ``embedding_dim``, ``hidden_dim`` and ``dropout`` from
    ``config['model']``; encoder and decoder share these settings.
    """
    model_cfg = config['model']
    shared_args = (
        model_cfg['embedding_dim'],
        model_cfg['hidden_dim'],
        model_cfg['dropout'],
    )

    encoder = EncoderRNN(src_vocab_size, *shared_args)
    decoder = DecoderRNN(tgt_vocab_size, *shared_args)

    return VanillaSeq2Seq(encoder, decoder, device).to(device)


if __name__ == '__main__':
    # Smoke test: build the model and push one random batch through it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = {'model': {'embedding_dim': 256, 'hidden_dim': 256, 'dropout': 0.3}}
    model = create_vanilla_seq2seq(5000, 5000, config, device)

    # Random token ids; every source sequence uses its full length of 20.
    src = torch.randint(0, 5000, (32, 20)).to(device)
    tgt = torch.randint(0, 5000, (32, 30)).to(device)
    src_lengths = torch.tensor([20] * 32)

    # Teacher-forced forward pass
    outputs = model(src, src_lengths, tgt)
    print("Output shape:", outputs.shape)

    # Greedy generation of 30 tokens starting from token id 1
    generated = model.generate(src, src_lengths, 30, 1)
    print("Generated shape:", generated.shape)


### models/lstm.py

In [None]:
%%writefile models/lstm.py
"""
LSTM-based Seq2Seq Model
"""
import torch
import torch.nn as nn
import random


class EncoderLSTM(nn.Module):
    """Single-layer unidirectional LSTM encoder for padded batches of token ids.

    Tokens are embedded (index 0 is the padding index), passed through
    dropout and then through an ``nn.LSTM`` using packed sequences so that
    padded positions do not influence the final states.
    """

    def __init__(self, input_size, embedding_dim, hidden_dim, dropout=0.3):
        """
        Args:
            input_size: size of the source vocabulary
            embedding_dim: dimensionality of the token embeddings
            hidden_dim: dimensionality of the LSTM hidden/cell state
            dropout: dropout probability applied to the embeddings
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(input_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_seq, input_lengths):
        """
        Args:
            input_seq: (batch_size, seq_len)
            input_lengths: (batch_size,)

        Returns:
            outputs: (batch_size, seq_len, hidden_dim), zero at padded steps
            hidden: tuple of (h_n, c_n) each (1, batch_size, hidden_dim)
        """
        embedded = self.dropout(self.embedding(input_seq))

        # pack_padded_sequence requires the lengths on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        packed_outputs, (hidden, cell) = self.lstm(packed)

        # Without total_length the unpacked tensor is only as long as the
        # longest sequence in the batch, which breaks the documented
        # (batch_size, seq_len, hidden_dim) shape for over-padded batches.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=input_seq.shape[1]
        )

        return outputs, (hidden, cell)


class DecoderLSTM(nn.Module):
    """One-step LSTM decoder that maps a token to vocabulary logits."""

    def __init__(self, output_size, embedding_dim, hidden_dim, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.output_size = output_size

        # Sub-modules are created in a fixed order: embedding, lstm, out, dropout.
        self.embedding = nn.Embedding(output_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell):
        """
        Run a single decoding step.

        Args:
            input_token: (batch_size, 1)
            hidden: (1, batch_size, hidden_dim)
            cell: (1, batch_size, hidden_dim)

        Returns:
            output: (batch_size, output_size)
            hidden: (1, batch_size, hidden_dim)
            cell: (1, batch_size, hidden_dim)
        """
        token_vectors = self.embedding(input_token)
        step_input = self.dropout(token_vectors)

        step_output, state = self.lstm(step_input, (hidden, cell))
        next_hidden, next_cell = state

        # Drop the length-1 time axis before projecting to the vocabulary.
        logits = self.out(step_output.squeeze(1))
        return logits, next_hidden, next_cell


class LSTMSeq2Seq(nn.Module):
    """LSTM Seq2Seq model: an encoder plus a step-wise decoder.

    The decoder is initialised with the encoder's final hidden and cell
    states and unrolled one token at a time.
    """

    def __init__(self, encoder, decoder, device):
        """
        Args:
            encoder: callable returning (outputs, (hidden, cell)) for
                (src, src_lengths)
            decoder: callable returning (logits, hidden, cell) for
                (token, hidden, cell); must expose ``output_size``
            device: device on which output buffers and start tokens are created
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        """
        Args:
            src: (batch_size, src_len)
            src_lengths: (batch_size,)
            tgt: (batch_size, tgt_len)
            teacher_forcing_ratio: probability of using teacher forcing

        Returns:
            outputs: (batch_size, tgt_len, output_size); position 0 is left
                as zeros because decoding starts from tgt[:, 0]
        """
        batch_size = src.shape[0]
        tgt_len = tgt.shape[1]

        # Allocate directly on the target device instead of CPU-then-copy.
        outputs = torch.zeros(
            batch_size, tgt_len, self.decoder.output_size, device=self.device
        )

        _, (hidden, cell) = self.encoder(src, src_lengths)

        # First input to the decoder is the first target token (SOS).
        input_token = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            output, hidden, cell = self.decoder(input_token, hidden, cell)
            outputs[:, t] = output

            # One teacher-forcing decision per time step for the whole batch.
            if random.random() < teacher_forcing_ratio:
                input_token = tgt[:, t].unsqueeze(1)
            else:
                input_token = output.argmax(1).unsqueeze(1)

        return outputs

    @torch.no_grad()
    def generate(self, src, src_lengths, max_length, sos_token):
        """
        Greedily generate an output sequence.

        Runs without gradient tracking since the result is a tensor of token
        ids. The train/eval mode of the module is not changed here, so
        callers should switch to ``eval()`` themselves to disable dropout.

        Args:
            src: (batch_size, src_len)
            src_lengths: (batch_size,)
            max_length: maximum length of generated sequence
            sos_token: start of sequence token

        Returns:
            outputs: (batch_size, max_length) tensor of token ids; has zero
                columns when max_length is 0
        """
        batch_size = src.shape[0]

        _, (hidden, cell) = self.encoder(src, src_lengths)

        input_token = torch.full(
            (batch_size, 1), int(sos_token), dtype=torch.long, device=self.device
        )

        steps = []
        for _ in range(max_length):
            output, hidden, cell = self.decoder(input_token, hidden, cell)
            input_token = output.argmax(1, keepdim=True)
            steps.append(input_token)

        # torch.cat rejects an empty list, so handle max_length == 0 explicitly.
        if not steps:
            return torch.empty(batch_size, 0, dtype=torch.long, device=self.device)

        return torch.cat(steps, dim=1)


def create_lstm_seq2seq(src_vocab_size, tgt_vocab_size, config, device):
    """Build an LSTMSeq2Seq from ``config['model']`` and move it to ``device``.

    Reads ``embedding_dim``, ``hidden_dim`` and ``dropout`` from
    ``config['model']``; encoder and decoder share these settings.
    """
    model_cfg = config['model']
    shared_args = (
        model_cfg['embedding_dim'],
        model_cfg['hidden_dim'],
        model_cfg['dropout'],
    )

    encoder = EncoderLSTM(src_vocab_size, *shared_args)
    decoder = DecoderLSTM(tgt_vocab_size, *shared_args)

    return LSTMSeq2Seq(encoder, decoder, device).to(device)


if __name__ == '__main__':
    # Smoke test: build the model and push one random batch through it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = {'model': {'embedding_dim': 256, 'hidden_dim': 256, 'dropout': 0.3}}
    model = create_lstm_seq2seq(5000, 5000, config, device)

    # Random token ids; every source sequence uses its full length of 20.
    src = torch.randint(0, 5000, (32, 20)).to(device)
    tgt = torch.randint(0, 5000, (32, 30)).to(device)
    src_lengths = torch.tensor([20] * 32)

    # Teacher-forced forward pass
    outputs = model(src, src_lengths, tgt)
    print("Output shape:", outputs.shape)

    # Greedy generation of 30 tokens starting from token id 1
    generated = model.generate(src, src_lengths, 30, 1)
    print("Generated shape:", generated.shape)


### models/attention_lstm.py

In [None]:
%%writefile models/attention_lstm.py
"""
LSTM with Bahdanau Attention Seq2Seq Model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import random


class EncoderBiLSTM(nn.Module):
    """Single-layer bidirectional LSTM encoder.

    Besides the per-step outputs, it returns initial decoder states obtained
    by concatenating the forward and backward final states and projecting
    them to ``hidden_dim`` through a tanh-activated linear layer.
    """

    def __init__(self, input_size, embedding_dim, hidden_dim, dropout=0.3):
        """
        Args:
            input_size: size of the source vocabulary
            embedding_dim: dimensionality of the token embeddings
            hidden_dim: per-direction LSTM state size (also the projected size)
            dropout: dropout probability applied to the embeddings
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(input_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)

        # Linear layers projecting the concatenated bidirectional states
        # down to the decoder's hidden size.
        self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_cell = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, input_seq, input_lengths):
        """
        Args:
            input_seq: (batch_size, seq_len)
            input_lengths: (batch_size,)

        Returns:
            outputs: (batch_size, seq_len, hidden_dim * 2), zero at padded steps
            hidden: (1, batch_size, hidden_dim)
            cell: (1, batch_size, hidden_dim)
        """
        embedded = self.dropout(self.embedding(input_seq))

        # pack_padded_sequence requires the lengths on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        packed_outputs, (hidden, cell) = self.lstm(packed)

        # total_length keeps the time axis equal to input_seq's, so masks and
        # attention buffers sized from the source batch line up with the
        # encoder outputs even when the batch is padded past its longest row.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=input_seq.shape[1]
        )

        # hidden and cell are (2, batch_size, hidden_dim) for the bidirectional
        # LSTM: index 0 is the forward direction, index 1 the backward one.
        hidden = torch.tanh(
            self.fc_hidden(torch.cat((hidden[0], hidden[1]), dim=1))
        ).unsqueeze(0)
        cell = torch.tanh(
            self.fc_cell(torch.cat((cell[0], cell[1]), dim=1))
        ).unsqueeze(0)

        return outputs, hidden, cell


class BahdanauAttention(nn.Module):
    """Bahdanau (additive) attention.

    Scores each encoder position with ``v^T tanh(W_h h + W_e e_i)`` and
    returns the softmax-weighted sum of the encoder outputs.
    """

    def __init__(self, hidden_dim, encoder_dim):
        """
        Args:
            hidden_dim: size of the decoder hidden state (and of the energy)
            encoder_dim: feature size of the encoder outputs
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.encoder_dim = encoder_dim

        # Attention layers
        self.attn_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.attn_encoder = nn.Linear(encoder_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        """
        Args:
            hidden: (batch_size, hidden_dim) - decoder hidden state
            encoder_outputs: (batch_size, src_len, encoder_dim) - all encoder outputs
            mask: (batch_size, src_len) - mask for padding; positions where
                the mask equals 0 receive (numerically) zero weight

        Returns:
            context: (batch_size, encoder_dim) - weighted context vector
            attention_weights: (batch_size, src_len) - attention weights
        """
        # Project the decoder state once and let broadcasting spread it over
        # the source positions, instead of repeating it src_len times first.
        query = self.attn_hidden(hidden).unsqueeze(1)  # (batch_size, 1, hidden_dim)
        keys = self.attn_encoder(encoder_outputs)  # (batch_size, src_len, hidden_dim)

        energy = torch.tanh(query + keys)  # (batch_size, src_len, hidden_dim)
        scores = self.v(energy).squeeze(2)  # (batch_size, src_len)

        if mask is not None:
            # Use the most negative finite value of the score dtype: a fixed
            # -1e10 cannot be represented in float16 and fails under
            # half precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=1)  # (batch_size, src_len)

        # Weighted sum of encoder outputs.
        context = torch.bmm(
            attention_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)  # (batch_size, encoder_dim)

        return context, attention_weights


class AttentionDecoderLSTM(nn.Module):
    """One-step LSTM decoder that attends over the encoder outputs."""

    def __init__(self, output_size, embedding_dim, hidden_dim, encoder_dim, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.output_size = output_size
        self.encoder_dim = encoder_dim

        self.embedding = nn.Embedding(output_size, embedding_dim, padding_idx=0)
        self.attention = BahdanauAttention(hidden_dim, encoder_dim)

        # The LSTM consumes the token embedding concatenated with the context.
        lstm_input_dim = embedding_dim + encoder_dim
        self.lstm = nn.LSTM(lstm_input_dim, hidden_dim, batch_first=True)

        # The classifier sees the LSTM output, the context and the embedding.
        classifier_input_dim = hidden_dim + encoder_dim + embedding_dim
        self.out = nn.Linear(classifier_input_dim, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell, encoder_outputs, mask=None):
        """
        Run a single attention-augmented decoding step.

        Args:
            input_token: (batch_size, 1)
            hidden: (1, batch_size, hidden_dim)
            cell: (1, batch_size, hidden_dim)
            encoder_outputs: (batch_size, src_len, encoder_dim)
            mask: (batch_size, src_len)

        Returns:
            output: (batch_size, output_size)
            hidden: (1, batch_size, hidden_dim)
            cell: (1, batch_size, hidden_dim)
            attention_weights: (batch_size, src_len)
        """
        token_emb = self.dropout(self.embedding(input_token))  # (batch_size, 1, embedding_dim)

        # Attend using the previous decoder hidden state as the query.
        query = hidden.squeeze(0)
        context, attention_weights = self.attention(query, encoder_outputs, mask)

        # (batch_size, 1, embedding_dim + encoder_dim)
        step_input = torch.cat([token_emb, context.unsqueeze(1)], dim=2)
        lstm_out, (hidden, cell) = self.lstm(step_input, (hidden, cell))

        # (batch_size, hidden_dim + encoder_dim + embedding_dim)
        features = torch.cat(
            [lstm_out.squeeze(1), context, token_emb.squeeze(1)], dim=1
        )
        prediction = self.out(features)  # (batch_size, output_size)

        return prediction, hidden, cell, attention_weights


class AttentionSeq2Seq(nn.Module):
    """LSTM Seq2Seq model whose decoder attends over the encoder outputs."""

    def __init__(self, encoder, decoder, device):
        """
        Args:
            encoder: callable returning (encoder_outputs, hidden, cell) for
                (src, src_lengths)
            decoder: callable returning (logits, hidden, cell, attention) for
                (token, hidden, cell, encoder_outputs, mask); must expose
                ``output_size``
            device: device on which buffers and start tokens are created
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def create_mask(self, src, src_lengths):
        """Create a boolean padding mask.

        Args:
            src: (batch_size, src_len)
            src_lengths: (batch_size,)

        Returns:
            mask: (batch_size, src_len) bool tensor on ``src``'s device, True
                for the first ``src_lengths[i]`` positions of row ``i``
        """
        # Vectorised comparison replaces a per-row Python loop.
        positions = torch.arange(src.shape[1], device=src.device)
        lengths = torch.as_tensor(src_lengths, device=src.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)

    def _encode_with_mask(self, src, src_lengths):
        """Run the encoder and build a mask trimmed to its output length."""
        encoder_outputs, hidden, cell = self.encoder(src, src_lengths)
        # The encoder may return fewer time steps than src has columns (when
        # the batch is padded past its longest sequence); trim the mask so it
        # always matches the attention scores.
        mask = self.create_mask(src, src_lengths)[:, :encoder_outputs.shape[1]]
        return encoder_outputs, hidden, cell, mask.to(self.device)

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        """
        Args:
            src: (batch_size, src_len)
            src_lengths: (batch_size,)
            tgt: (batch_size, tgt_len)
            teacher_forcing_ratio: probability of using teacher forcing

        Returns:
            outputs: (batch_size, tgt_len, output_size); position 0 is zeros
            attentions: (batch_size, tgt_len, src_len); position 0 is zeros
        """
        batch_size, src_len = src.shape[0], src.shape[1]
        tgt_len = tgt.shape[1]

        # Allocate directly on the target device instead of CPU-then-copy.
        outputs = torch.zeros(
            batch_size, tgt_len, self.decoder.output_size, device=self.device
        )
        attentions = torch.zeros(batch_size, tgt_len, src_len, device=self.device)

        encoder_outputs, hidden, cell, mask = self._encode_with_mask(src, src_lengths)
        enc_len = encoder_outputs.shape[1]

        # First input to the decoder is the first target token (SOS).
        input_token = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            output, hidden, cell, attention_weights = self.decoder(
                input_token, hidden, cell, encoder_outputs, mask
            )
            outputs[:, t] = output
            # Columns beyond the encoder's output length stay zero.
            attentions[:, t, :enc_len] = attention_weights

            # One teacher-forcing decision per time step for the whole batch.
            if random.random() < teacher_forcing_ratio:
                input_token = tgt[:, t].unsqueeze(1)
            else:
                input_token = output.argmax(1).unsqueeze(1)

        return outputs, attentions

    @torch.no_grad()
    def generate(self, src, src_lengths, max_length, sos_token):
        """
        Greedily generate an output sequence.

        Runs without gradient tracking. The train/eval mode of the module is
        not changed here, so callers should switch to ``eval()`` themselves
        to disable dropout.

        Args:
            src: (batch_size, src_len)
            src_lengths: (batch_size,)
            max_length: maximum length of generated sequence
            sos_token: start of sequence token

        Returns:
            outputs: (batch_size, max_length) tensor of token ids
            attentions: (batch_size, max_length, src_len)
        """
        batch_size, src_len = src.shape[0], src.shape[1]

        encoder_outputs, hidden, cell, mask = self._encode_with_mask(src, src_lengths)
        enc_len = encoder_outputs.shape[1]

        input_token = torch.full(
            (batch_size, 1), int(sos_token), dtype=torch.long, device=self.device
        )

        # Pre-allocated buffers also make max_length == 0 a valid request.
        tokens = torch.zeros(batch_size, max_length, dtype=torch.long, device=self.device)
        attentions = torch.zeros(batch_size, max_length, src_len, device=self.device)

        for t in range(max_length):
            output, hidden, cell, attention_weights = self.decoder(
                input_token, hidden, cell, encoder_outputs, mask
            )
            attentions[:, t, :enc_len] = attention_weights
            input_token = output.argmax(1, keepdim=True)
            tokens[:, t] = input_token.squeeze(1)

        return tokens, attentions


def create_attention_seq2seq(src_vocab_size, tgt_vocab_size, config, device):
    """Build an AttentionSeq2Seq from ``config['model']`` and move it to ``device``.

    Reads ``embedding_dim``, ``hidden_dim`` and ``dropout`` from
    ``config['model']``. The decoder's encoder feature size is twice
    ``hidden_dim`` because the encoder is bidirectional.
    """
    model_cfg = config['model']
    embedding_dim = model_cfg['embedding_dim']
    hidden_dim = model_cfg['hidden_dim']
    dropout = model_cfg['dropout']

    encoder = EncoderBiLSTM(src_vocab_size, embedding_dim, hidden_dim, dropout)
    decoder = AttentionDecoderLSTM(
        tgt_vocab_size,
        embedding_dim,
        hidden_dim,
        hidden_dim * 2,  # forward + backward encoder features
        dropout,
    )

    return AttentionSeq2Seq(encoder, decoder, device).to(device)


if __name__ == '__main__':
    # Smoke test: build the model and push one random batch through it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = {'model': {'embedding_dim': 256, 'hidden_dim': 256, 'dropout': 0.3}}
    model = create_attention_seq2seq(5000, 5000, config, device)

    # Random token ids; every source sequence uses its full length of 20.
    src = torch.randint(0, 5000, (32, 20)).to(device)
    tgt = torch.randint(0, 5000, (32, 30)).to(device)
    src_lengths = torch.tensor([20] * 32)

    # Teacher-forced forward pass
    outputs, attentions = model(src, src_lengths, tgt)
    print("Output shape:", outputs.shape)
    print("Attention shape:", attentions.shape)

    # Greedy generation of 30 tokens starting from token id 1
    generated, gen_attentions = model.generate(src, src_lengths, 30, 1)
    print("Generated shape:", generated.shape)
    print("Generated attention shape:", gen_attentions.shape)


### data_loader.py

In [None]:
%%writefile data_loader.py
"""
Data loading and preprocessing for CodeSearchNet dataset
"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from collections import Counter
import re
from typing import List, Tuple, Dict
import pickle
import os


class Vocabulary:
    """Builds and manages a token <-> index vocabulary.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and
    ``<UNK>``. Other tokens are counted with ``add_sentence`` and assigned
    indices by ``build_vocab`` in order of decreasing frequency.
    """

    def __init__(self, max_vocab_size=10000):
        """
        Args:
            max_vocab_size: upper bound on the vocabulary size, including
                the four special tokens
        """
        self.word2idx = {}
        self.idx2word = {}
        self.word_counts = Counter()
        self.max_vocab_size = max_vocab_size

        # Special tokens
        self.PAD_token = 0
        self.SOS_token = 1
        self.EOS_token = 2
        self.UNK_token = 3

        specials = (
            ('<PAD>', self.PAD_token),
            ('<SOS>', self.SOS_token),
            ('<EOS>', self.EOS_token),
            ('<UNK>', self.UNK_token),
        )
        for token, index in specials:
            self.word2idx[token] = index
            self.idx2word[index] = token

        self.n_words = len(specials)

    def add_sentence(self, sentence: str):
        """Count every token of ``sentence`` (does not assign indices)."""
        self.word_counts.update(self.tokenize(sentence))

    def build_vocab(self):
        """Assign indices to the most common counted tokens.

        The total size never exceeds ``max_vocab_size``, even when this
        method is called again after more sentences have been added.
        """
        budget = max(self.max_vocab_size - 4, 0)

        for word, _ in self.word_counts.most_common(budget):
            # A repeated call may see a different top-N; stop at the cap
            # instead of growing past max_vocab_size.
            if self.n_words >= self.max_vocab_size:
                break
            if word not in self.word2idx:
                self.word2idx[word] = self.n_words
                self.idx2word[self.n_words] = word
                self.n_words += 1

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Lower-case ``text`` and split it into word and punctuation tokens.

        A token is either a run of word characters (``\\w+``) or a single
        character that is neither a word character nor whitespace.
        """
        text = text.lower().strip()
        return re.findall(r'\w+|[^\w\s]', text)

    def encode(self, sentence: str) -> List[int]:
        """Convert a sentence to a list of indices (unknown tokens -> UNK)."""
        return [
            self.word2idx.get(token, self.UNK_token)
            for token in self.tokenize(sentence)
        ]

    def decode(self, indices: List[int]) -> str:
        """Convert a sequence of indices back to a space-joined sentence.

        Decoding stops at the first EOS index; PAD and SOS indices are
        skipped. Each element is coerced with ``int()``, so tensors or
        arrays of indices can be passed as well as plain lists.
        """
        words = []
        for idx in indices:
            # 0-dim tensors hash by identity, so without this coercion the
            # dictionary lookup below would miss and yield '<UNK>' every time.
            idx = int(idx)
            if idx == self.EOS_token:
                break
            if idx not in (self.PAD_token, self.SOS_token):
                words.append(self.idx2word.get(idx, '<UNK>'))
        return ' '.join(words)

    def save(self, path: str):
        """Pickle the vocabulary state to ``path``."""
        state = {
            'word2idx': self.word2idx,
            'idx2word': self.idx2word,
            'word_counts': self.word_counts,
            'max_vocab_size': self.max_vocab_size,
            'n_words': self.n_words
        }
        with open(path, 'wb') as f:
            pickle.dump(state, f)

    def load(self, path: str):
        """Restore the vocabulary state previously written by ``save``."""
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # vocabulary files that come from a trusted source.
        with open(path, 'rb') as f:
            data = pickle.load(f)

        self.word2idx = data['word2idx']
        self.idx2word = data['idx2word']
        self.word_counts = data['word_counts']
        self.max_vocab_size = data['max_vocab_size']
        self.n_words = data['n_words']


class CodeSearchNetDataset(Dataset):
    """Map-style dataset turning docstring/code records into index tensors."""

    def __init__(self, data, src_vocab, tgt_vocab, max_src_len, max_tgt_len):
        """
        Args:
            data: indexable collection of records with 'docstring' and 'code'
            src_vocab: vocabulary used to encode docstrings
            tgt_vocab: vocabulary used to encode code
            max_src_len: maximum number of source tokens kept (before EOS)
            max_tgt_len: maximum number of target tokens kept (before SOS/EOS)
        """
        self.data = data
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded tensors and raw texts of record ``idx``."""
        record = self.data[idx]

        # Source (docstring): truncate, then terminate with EOS.
        doc_ids = self.src_vocab.encode(record['docstring'])
        src_ids = doc_ids[:self.max_src_len] + [self.src_vocab.EOS_token]

        # Target (code): truncate, then wrap in SOS ... EOS.
        code_ids = self.tgt_vocab.encode(record['code'])
        tgt_ids = [
            self.tgt_vocab.SOS_token,
            *code_ids[:self.max_tgt_len],
            self.tgt_vocab.EOS_token,
        ]

        sample = {}
        sample['src'] = torch.tensor(src_ids, dtype=torch.long)
        sample['tgt'] = torch.tensor(tgt_ids, dtype=torch.long)
        sample['src_text'] = record['docstring']
        sample['tgt_text'] = record['code']
        return sample


def collate_fn(batch):
    """Pad a list of dataset samples into batch tensors.

    Source and target sequences are right-padded with 0 to the longest
    sequence in the batch; the unpadded lengths and the raw texts are
    returned alongside.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    src_seqs, tgt_seqs, src_texts, tgt_texts = [], [], [], []
    for sample in batch:
        src_seqs.append(sample['src'])
        tgt_seqs.append(sample['tgt'])
        src_texts.append(sample['src_text'])
        tgt_texts.append(sample['tgt_text'])

    # Record the true lengths before padding.
    src_lengths = torch.tensor([len(seq) for seq in src_seqs])
    tgt_lengths = torch.tensor([len(seq) for seq in tgt_seqs])

    collated = {
        'src': pad(src_seqs, batch_first=True, padding_value=0),
        'tgt': pad(tgt_seqs, batch_first=True, padding_value=0),
        'src_lengths': src_lengths,
        'tgt_lengths': tgt_lengths,
        'src_texts': src_texts,
        'tgt_texts': tgt_texts,
    }
    return collated


def _extract_doc_and_code(example):
    """Return the raw (docstring, code) pair of an example, trying each known field name in turn."""
    doc = example.get('func_documentation_string') or example.get('docstring') or example.get('doc')
    code = example.get('func_code_string') or example.get('code') or example.get('function')
    return doc, code


def load_and_prepare_data(config):
    """
    Load CodeSearchNet dataset and prepare vocabularies
    
    The dataset is shuffled with a fixed seed, filtered by token length and
    cut into consecutive train/val/test slices. Vocabularies are built from
    the training slice only and saved under config['paths']['checkpoints'].
    
    Args:
        config: Configuration dictionary
    
    Returns:
        tuple: (train_loader, val_loader, test_loader, src_vocab, tgt_vocab)
    
    Raises:
        KeyError: if a required configuration entry is missing
        ValueError: if no example survives filtering or the train split is empty
    """
    # Read every setting up front. A missing key then fails immediately with a
    # KeyError instead of being swallowed inside the per-example filter (where
    # it used to look like "no examples passed filtering").
    dataset_cfg = config['dataset']
    train_size = dataset_cfg['train_size']
    val_size = dataset_cfg['val_size']
    test_size = dataset_cfg['test_size']
    max_doc_len = dataset_cfg['max_docstring_length']
    max_code_len = dataset_cfg['max_code_length']
    max_vocab_size = config['model']['max_vocab_size']
    batch_size = config['training']['batch_size']
    checkpoint_dir = config['paths']['checkpoints']
    
    print("Loading CodeSearchNet dataset...")
    
    # Load dataset from Hugging Face
    dataset = load_dataset(
        dataset_cfg['name'],
        split='train',
        cache_dir=dataset_cfg['cache_dir']
    )
    
    print(f"Total dataset size: {len(dataset)}")
    
    # Take larger subset initially (we'll filter and then select what we need)
    total_needed = train_size + val_size + test_size
    # Get 5x more than needed to account for filtering
    initial_sample = min(total_needed * 5, len(dataset))
    dataset = dataset.shuffle(seed=42).select(range(initial_sample))
    
    # Filter out examples that are too long or empty
    def filter_fn(example):
        try:
            doc, code = _extract_doc_and_code(example)
            if not doc or not code:
                return False
            
            # Check if strings are not empty after stripping
            doc = str(doc).strip()
            code = str(code).strip()
            if not doc or not code:
                return False
            
            doc_len = len(Vocabulary.tokenize(doc))
            code_len = len(Vocabulary.tokenize(code))
        except Exception:
            # Best effort: an example that cannot be read or tokenized is
            # dropped rather than aborting the whole run.
            return False
        
        return 2 < doc_len <= max_doc_len and 2 < code_len <= max_code_len
    
    print("Filtering dataset...")
    dataset = dataset.filter(filter_fn)
    print(f"After filtering: {len(dataset)} examples")
    
    # Check if we have enough data
    if len(dataset) == 0:
        raise ValueError("No examples passed filtering! The dataset might have different field names or all examples were too long.")
    
    # Convert to simpler format
    processed_data = []
    for item in dataset:
        doc, code = _extract_doc_and_code(item)
        processed_data.append({
            'docstring': str(doc).strip(),
            'code': str(code).strip()
        })
        
        # Stop if we have enough examples
        if len(processed_data) >= total_needed:
            break
    
    # Split into train/val/test
    val_end = train_size + val_size
    train_data = processed_data[:train_size]
    val_data = processed_data[train_size:val_end]
    test_data = processed_data[val_end:val_end + test_size]
    
    print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")
    
    # Check if we have enough data
    if len(train_data) == 0:
        raise ValueError(f"No training data! Got {len(processed_data)} examples after filtering, but needed at least {train_size}. Try reducing the dataset sizes in config.yaml or increasing max_docstring_length/max_code_length.")
    
    if len(train_data) < train_size:
        print(f"⚠️  Warning: Only got {len(train_data)} training examples (requested {train_size})")
        print(f"   Continuing with available data...")
    
    # The slices are consecutive, so a short train split leaves nothing for
    # validation/test; say so instead of returning empty loaders silently.
    if not val_data or not test_data:
        print(f"⚠️  Warning: validation/test split is empty (val={len(val_data)}, test={len(test_data)})")
    
    # Build vocabularies
    print("Building vocabularies...")
    src_vocab = Vocabulary(max_vocab_size=max_vocab_size)
    tgt_vocab = Vocabulary(max_vocab_size=max_vocab_size)
    
    # Only the training split contributes words to the vocabularies
    for item in train_data:
        src_vocab.add_sentence(item['docstring'])
        tgt_vocab.add_sentence(item['code'])
    
    src_vocab.build_vocab()
    tgt_vocab.build_vocab()
    
    print(f"Source vocabulary size: {src_vocab.n_words}")
    print(f"Target vocabulary size: {tgt_vocab.n_words}")
    
    # Create datasets and data loaders (only the training loader shuffles)
    def make_loader(examples, shuffle):
        split_dataset = CodeSearchNetDataset(
            examples, src_vocab, tgt_vocab, max_doc_len, max_code_len
        )
        return DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=0
        )
    
    train_loader = make_loader(train_data, shuffle=True)
    val_loader = make_loader(val_data, shuffle=False)
    test_loader = make_loader(test_data, shuffle=False)
    
    # Save vocabularies
    os.makedirs(checkpoint_dir, exist_ok=True)
    src_vocab.save(os.path.join(checkpoint_dir, 'src_vocab.pkl'))
    tgt_vocab.save(os.path.join(checkpoint_dir, 'tgt_vocab.pkl'))
    
    return train_loader, val_loader, test_loader, src_vocab, tgt_vocab


if __name__ == '__main__':
    import yaml
    
    # Smoke test: run the full data pipeline with the local config
    with open('config.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)
    
    prepared = load_and_prepare_data(config)
    train_loader, val_loader, test_loader, src_vocab, tgt_vocab = prepared
    
    # Show the first training batch only
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        print("Source shape:", first_batch['src'].shape)
        print("Target shape:", first_batch['tgt'].shape)
        print("Source text:", first_batch['src_texts'][0])
        print("Target text:", first_batch['tgt_texts'][0])


### train.py

In [None]:
%%writefile train.py
"""
Training script for Seq2Seq models
"""
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import yaml
import os
import argparse
import json
import numpy as np

from data_loader import load_and_prepare_data, Vocabulary
from models import create_vanilla_seq2seq, create_lstm_seq2seq, create_attention_seq2seq


class Trainer:
    """
    Trainer class for Seq2Seq models
    
    Owns the loss function, the Adam optimizer, the per-epoch loss history
    and the checkpoint/log directories of one named model.
    """
    
    def __init__(self, model, config, device, model_name, src_vocab, tgt_vocab):
        """
        Args:
            model: seq2seq model to optimise
            config: configuration dictionary; config['training'] and
                config['paths'] are read here and during training
            device: device the source/target batches are moved to
            model_name: name of the model; used for the output
                sub-directories, and a name containing 'attention' means the
                model's forward pass returns a 2-tuple whose first item is
                the outputs
            src_vocab: source vocabulary (stored inside every checkpoint)
            tgt_vocab: target vocabulary (stored inside every checkpoint)
        """
        self.model = model
        self.config = config
        self.device = device
        self.model_name = model_name
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        
        # Loss and optimizer
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config['training']['learning_rate']
        )
        
        # Training history
        self.train_losses = []
        self.val_losses = []
        
        # Create directories
        self.checkpoint_dir = os.path.join(config['paths']['checkpoints'], model_name)
        self.log_dir = os.path.join(config['paths']['logs'], model_name)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
    
    def _batch_loss(self, batch, teacher_forcing_ratio):
        """Run the forward pass for one batch and return its loss tensor."""
        src = batch['src'].to(self.device)
        tgt = batch['tgt'].to(self.device)
        src_lengths = batch['src_lengths']  # passed through as delivered by the loader
        
        result = self.model(
            src, src_lengths, tgt,
            teacher_forcing_ratio=teacher_forcing_ratio
        )
        # Attention models return (outputs, <second value>); only outputs are scored
        if 'attention' in self.model_name:
            outputs, _ = result
        else:
            outputs = result
        
        # Drop position 0 of both tensors, then flatten:
        # outputs (batch_size, tgt_len, vocab_size) -> (batch_size * (tgt_len - 1), vocab_size)
        # tgt (batch_size, tgt_len) -> (batch_size * (tgt_len - 1))
        output_dim = outputs.shape[-1]
        logits = outputs[:, 1:].contiguous().view(-1, output_dim)
        targets = tgt[:, 1:].contiguous().view(-1)
        
        return self.criterion(logits, targets)
    
    def train_epoch(self, train_loader, epoch):
        """
        Train for one epoch
        
        Args:
            train_loader: iterable of batches that also supports len()
            epoch: zero-based epoch number (only used for the progress bar)
        
        Returns:
            mean of the per-batch losses
        """
        self.model.train()
        epoch_loss = 0
        teacher_forcing_ratio = self.config['training']['teacher_forcing_ratio']
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        
        for batch in progress_bar:
            self.optimizer.zero_grad()
            
            loss = self._batch_loss(batch, teacher_forcing_ratio)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config['training']['gradient_clip']
            )
            
            self.optimizer.step()
            
            # Read the scalar once and reuse it for the total and the display
            loss_value = loss.item()
            epoch_loss += loss_value
            progress_bar.set_postfix({'loss': loss_value})
        
        return epoch_loss / len(train_loader)
    
    def evaluate(self, val_loader):
        """
        Evaluate on validation set
        
        Runs without gradients and with teacher_forcing_ratio=0.
        
        Returns:
            mean of the per-batch losses
        """
        self.model.eval()
        epoch_loss = 0
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                epoch_loss += self._batch_loss(batch, 0).item()
        
        return epoch_loss / len(val_loader)
    
    def train(self, train_loader, val_loader, num_epochs):
        """
        Full training loop
        
        Trains and validates for num_epochs epochs, writes a checkpoint every
        config['training']['save_every'] epochs plus one whenever the
        validation loss improves, and finally writes the loss history.
        
        Raises:
            ValueError: if either loader yields no batches
        """
        # Fail before any work is done: the per-epoch averages divide by the
        # number of batches, so an empty loader would otherwise surface as a
        # ZeroDivisionError at the end of an epoch.
        if len(train_loader) == 0 or len(val_loader) == 0:
            raise ValueError(
                f'Cannot train {self.model_name}: got {len(train_loader)} training '
                f'batches and {len(val_loader)} validation batches; both must be non-empty.'
            )
        
        print(f"\nTraining {self.model_name}...")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        
        best_val_loss = float('inf')
        
        for epoch in range(num_epochs):
            train_loss = self.train_epoch(train_loader, epoch)
            val_loss = self.evaluate(val_loader)
            
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}')
            print(f'  Val Loss:   {val_loss:.4f}')
            
            # Save checkpoint
            if (epoch + 1) % self.config['training']['save_every'] == 0:
                self.save_checkpoint(epoch, val_loss)
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch, val_loss, is_best=True)
                print(f'  New best model saved!')
        
        # Save training history
        self.save_training_history()
        
        print(f'\nTraining completed for {self.model_name}!')
        print(f'Best validation loss: {best_val_loss:.4f}')
    
    def save_checkpoint(self, epoch, val_loss, is_best=False):
        """
        Save model checkpoint
        
        The checkpoint holds the model and optimizer state, the loss history,
        the config and both vocabulary objects.
        
        Args:
            epoch: zero-based epoch number
            val_loss: validation loss of that epoch
            is_best: write to 'best_model.pt' instead of a per-epoch file
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'config': self.config,
            'src_vocab': self.src_vocab,
            'tgt_vocab': self.tgt_vocab
        }
        
        filename = 'best_model.pt' if is_best else f'checkpoint_epoch_{epoch+1}.pt'
        torch.save(checkpoint, os.path.join(self.checkpoint_dir, filename))
    
    def save_training_history(self):
        """Write the train/validation loss history to training_history.json in the log directory"""
        history = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses
        }
        
        path = os.path.join(self.log_dir, 'training_history.json')
        with open(path, 'w') as f:
            json.dump(history, f, indent=2)


def main():
    """
    Command-line entry point: train one or all of the Seq2Seq models.
    
    Reads the YAML config given by --config, prepares the data once and then
    trains each model selected by --model with its own Trainer.
    """
    parser = argparse.ArgumentParser(description='Train Seq2Seq models')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')
    parser.add_argument('--model', type=str, default='all',
                        choices=['vanilla', 'lstm', 'attention', 'all'],
                        help='Which model to train')
    args = parser.parse_args()
    
    # Load config
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)
    
    # Use the configured device only when CUDA is available, else the CPU
    if torch.cuda.is_available():
        device = torch.device(config['device'])
    else:
        device = torch.device('cpu')
    print(f'Using device: {device}')
    
    # Load data
    print('Loading data...')
    train_loader, val_loader, test_loader, src_vocab, tgt_vocab = load_and_prepare_data(config)
    
    src_vocab_size, tgt_vocab_size = src_vocab.n_words, tgt_vocab.n_words
    
    print(f'Source vocabulary size: {src_vocab_size}')
    print(f'Target vocabulary size: {tgt_vocab_size}')
    
    # Model name -> factory; the insertion order is the training order for 'all'
    model_factories = {
        'vanilla': create_vanilla_seq2seq,
        'lstm': create_lstm_seq2seq,
        'attention': create_attention_seq2seq,
    }
    models_to_train = list(model_factories) if args.model == 'all' else [args.model]
    
    banner = '=' * 60
    
    # Train models
    for model_name in models_to_train:
        print(f'\n{banner}')
        print(f'Training {model_name.upper()} model')
        print(f'{banner}')
        
        # Create model
        build_model = model_factories[model_name]
        model = build_model(src_vocab_size, tgt_vocab_size, config, device)
        
        # Create trainer and train
        trainer = Trainer(model, config, device, model_name, src_vocab, tgt_vocab)
        trainer.train(train_loader, val_loader, config['training']['num_epochs'])
    
    print('\n' + banner)
    print('All training completed!')
    print(banner)


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()


### evaluate.py

In [None]:
%%writefile evaluate.py
"""
Evaluation metrics and testing script
"""
import torch
import yaml
import os
import argparse
import json
import numpy as np
from tqdm import tqdm
from sacrebleu.metrics import BLEU
from collections import defaultdict

from data_loader import load_and_prepare_data, Vocabulary
from models import create_vanilla_seq2seq, create_lstm_seq2seq, create_attention_seq2seq


class Evaluator:
    """
    Evaluation class for Seq2Seq models
    
    Generates predictions for a data loader and scores them with BLEU,
    token-level accuracy and exact match, overall and per source length.
    """
    
    def __init__(self, model, config, device, model_name, src_vocab, tgt_vocab):
        """
        Args:
            model: seq2seq model exposing generate()
            config: configuration dictionary; config['dataset'] and
                config['paths'] are read during evaluation
            device: device the source batches are moved to
            model_name: name of the model; a name containing 'attention'
                means generate() returns a 2-tuple whose first item is the
                generated sequences
            src_vocab: source vocabulary (used to tokenize source texts)
            tgt_vocab: target vocabulary (used to decode and tokenize)
        """
        self.model = model
        self.config = config
        self.device = device
        self.model_name = model_name
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        
        # BLEU metric
        self.bleu = BLEU()
    
    def load_checkpoint(self, checkpoint_path):
        """
        Load model weights from a checkpoint file.
        
        Args:
            checkpoint_path: path of a checkpoint containing
                'model_state_dict', 'epoch' and 'val_loss'
        """
        # weights_only=False matches the other loaders in this project: the
        # training checkpoints also carry the config and the vocabulary
        # objects, which the restricted weights-only unpickler rejects.
        # SECURITY: this is a full pickle load; only open trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        print(f'Loaded checkpoint from epoch {checkpoint["epoch"]+1}')
        print(f'Validation loss: {checkpoint["val_loss"]:.4f}')
    
    def generate_sequences(self, data_loader):
        """
        Generate sequences for all examples in data loader
        
        Returns:
            tuple: (predictions, references, src_texts), three parallel lists
            of decoded predictions, reference code and source docstrings
        """
        self.model.eval()
        
        all_predictions = []
        all_references = []
        all_src_texts = []
        
        # Configured code length plus two extra positions
        max_length = self.config['dataset']['max_code_length'] + 2
        
        with torch.no_grad():
            for batch in tqdm(data_loader, desc='Generating'):
                src = batch['src'].to(self.device)
                src_lengths = batch['src_lengths']
                
                # Generate
                generated = self.model.generate(
                    src, src_lengths, max_length, self.tgt_vocab.SOS_token
                )
                if 'attention' in self.model_name:
                    # Attention models return a 2-tuple; keep the sequences only
                    generated, _ = generated
                
                # Decode generated sequences
                for row in generated.cpu().tolist():
                    all_predictions.append(self.tgt_vocab.decode(row))
                all_references.extend(batch['tgt_texts'][:generated.shape[0]])
                all_src_texts.extend(batch['src_texts'][:generated.shape[0]])
        
        return all_predictions, all_references, all_src_texts
    
    def calculate_token_accuracy(self, predictions, references):
        """
        Calculate token-level accuracy
        
        Tokens are compared position by position over the shorter of the two
        sequences; every position of the longer sequence that has no partner
        counts as wrong.
        
        Returns:
            correct / total as a float, or 0 when there are no tokens
        """
        correct_tokens = 0
        total_tokens = 0
        
        for pred, ref in zip(predictions, references):
            pred_tokens = self.tgt_vocab.tokenize(pred)
            ref_tokens = self.tgt_vocab.tokenize(ref)
            
            # zip stops at the shorter sequence
            correct_tokens += sum(p == r for p, r in zip(pred_tokens, ref_tokens))
            # overlap + length mismatch == length of the longer sequence
            total_tokens += max(len(pred_tokens), len(ref_tokens))
        
        return correct_tokens / total_tokens if total_tokens > 0 else 0
    
    def calculate_exact_match(self, predictions, references):
        """
        Calculate exact match accuracy
        
        Two texts match when they are equal after collapsing whitespace runs.
        
        Returns:
            fraction of matching pairs, or 0 for an empty prediction list
        """
        if not predictions:
            return 0
        
        exact_matches = sum(
            ' '.join(pred.split()) == ' '.join(ref.split())
            for pred, ref in zip(predictions, references)
        )
        
        return exact_matches / len(predictions)
    
    def calculate_bleu(self, predictions, references):
        """
        Calculate BLEU score
        
        Returns:
            corpus-level BLEU score, or 0.0 when there is nothing to score or
            the metric raises
        """
        if not predictions:
            return 0.0
        
        # sacrebleu format: a list of reference streams, here a single stream
        # holding one reference per prediction
        reference_streams = [list(references)]
        
        try:
            bleu_score = self.bleu.corpus_score(predictions, reference_streams)
            return bleu_score.score
        except Exception as e:
            print(f"Error calculating BLEU: {e}")
            return 0.0
    
    @staticmethod
    def _length_bucket(src_len):
        """Map a source token count to its length-bucket label."""
        if src_len <= 10:
            return '0-10'
        if src_len <= 20:
            return '11-20'
        if src_len <= 30:
            return '21-30'
        return '31+'
    
    def analyze_by_length(self, predictions, references, src_texts):
        """
        Analyze performance by source sequence length
        
        Returns:
            dict mapping a bucket label to its 'count', 'bleu',
            'token_accuracy' and 'exact_match'; empty buckets are omitted
        """
        length_buckets = defaultdict(lambda: {'predictions': [], 'references': []})
        
        for pred, ref, src in zip(predictions, references, src_texts):
            bucket = self._length_bucket(len(self.src_vocab.tokenize(src)))
            length_buckets[bucket]['predictions'].append(pred)
            length_buckets[bucket]['references'].append(ref)
        
        # Calculate metrics for each bucket
        results = {}
        for bucket, data in sorted(length_buckets.items()):
            bucket_preds, bucket_refs = data['predictions'], data['references']
            if not bucket_preds:
                continue
            bleu = self.calculate_bleu(bucket_preds, bucket_refs)
            token_acc = self.calculate_token_accuracy(bucket_preds, bucket_refs)
            exact_match = self.calculate_exact_match(bucket_preds, bucket_refs)
            results[bucket] = {
                'count': len(bucket_preds),
                'bleu': bleu,
                'token_accuracy': token_acc,
                'exact_match': exact_match
            }
        
        return results
    
    def evaluate(self, data_loader):
        """
        Full evaluation
        
        Generates predictions, prints the overall and per-length metrics and
        writes them, with a few example predictions, to
        <results>/<model_name>/evaluation_results.json.
        
        Returns:
            the results dictionary that was written to disk
        """
        print(f'\nEvaluating {self.model_name}...')
        
        # Generate predictions
        predictions, references, src_texts = self.generate_sequences(data_loader)
        
        # Calculate overall metrics
        bleu_score = self.calculate_bleu(predictions, references)
        token_accuracy = self.calculate_token_accuracy(predictions, references)
        exact_match = self.calculate_exact_match(predictions, references)
        
        print(f'\nOverall Results:')
        print(f'  BLEU Score:        {bleu_score:.2f}')
        print(f'  Token Accuracy:    {token_accuracy*100:.2f}%')
        print(f'  Exact Match:       {exact_match*100:.2f}%')
        
        # Analyze by length
        length_analysis = self.analyze_by_length(predictions, references, src_texts)
        
        print(f'\nResults by Source Length:')
        for bucket, metrics in sorted(length_analysis.items()):
            print(f'  Length {bucket}: (n={metrics["count"]})')
            print(f'    BLEU:          {metrics["bleu"]:.2f}')
            print(f'    Token Acc:     {metrics["token_accuracy"]*100:.2f}%')
            print(f'    Exact Match:   {metrics["exact_match"]*100:.2f}%')
        
        # Save results. Metrics are cast to plain floats for JSON; 'count' is
        # a number of examples and stays an integer.
        by_length = {
            bucket: {
                key: int(value) if key == 'count' else float(value)
                for key, value in metrics.items()
            }
            for bucket, metrics in length_analysis.items()
        }
        results = {
            'model': self.model_name,
            'overall': {
                'bleu': float(bleu_score),
                'token_accuracy': float(token_accuracy),
                'exact_match': float(exact_match)
            },
            'by_length': by_length,
            'examples': self.get_example_predictions(predictions, references, src_texts, n=10)
        }
        
        # Save to file
        results_dir = os.path.join(self.config['paths']['results'], self.model_name)
        os.makedirs(results_dir, exist_ok=True)
        
        with open(os.path.join(results_dir, 'evaluation_results.json'), 'w') as f:
            json.dump(results, f, indent=2)
        
        print(f'\nResults saved to {results_dir}')
        
        return results
    
    def get_example_predictions(self, predictions, references, src_texts, n=10):
        """
        Get example predictions for analysis
        
        Returns:
            list of at most n randomly chosen (without replacement) dicts
            with 'source', 'reference' and 'prediction'
        """
        if not predictions:
            return []
        
        sample_size = min(n, len(predictions))
        indices = np.random.choice(len(predictions), sample_size, replace=False)
        
        return [
            {
                'source': src_texts[idx],
                'reference': references[idx],
                'prediction': predictions[idx]
            }
            for idx in indices
        ]


def _load_eval_examples(model_config, split):
    """
    Return the raw examples of the requested split ('val' or 'test').
    
    The examples come from load_and_prepare_data, i.e. the same shuffle,
    filter and slicing that train.py uses, so the returned split does not
    overlap the training slice as long as model_config is the config the
    model was trained with.
    """
    # NOTE: load_and_prepare_data also rebuilds the vocabularies and re-saves
    # them under paths.checkpoints; they are ignored here because the
    # vocabularies stored in the checkpoint are the authoritative ones.
    _, val_loader, test_loader, _, _ = load_and_prepare_data(model_config)
    chosen_loader = test_loader if split == 'test' else val_loader
    return chosen_loader.dataset.data


def main():
    """
    Command-line entry point: evaluate one or all trained Seq2Seq models.
    
    For each selected model the best checkpoint is loaded (weights, config
    and vocabularies), the model is rebuilt from the checkpoint's config and
    evaluated on the requested split. When more than one model was evaluated
    a comparison table is printed.
    """
    # Imported here so the module can be imported without these being resolved
    from data_loader import CodeSearchNetDataset, collate_fn
    from torch.utils.data import DataLoader
    
    parser = argparse.ArgumentParser(description='Evaluate Seq2Seq models')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')
    parser.add_argument('--model', type=str, default='all',
                        choices=['vanilla', 'lstm', 'attention', 'all'],
                        help='Which model to evaluate')
    parser.add_argument('--split', type=str, default='test',
                        choices=['val', 'test'],
                        help='Which split to evaluate on')
    args = parser.parse_args()
    
    # Load config
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    
    # Set device
    device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    
    # Determine which models to evaluate
    model_factories = {
        'vanilla': create_vanilla_seq2seq,
        'lstm': create_lstm_seq2seq,
        'attention': create_attention_seq2seq,
    }
    models_to_eval = list(model_factories) if args.model == 'all' else [args.model]
    
    all_results = {}
    
    # Raw evaluation examples: loaded once (with the first model's config)
    # and shared by every model
    eval_examples = None
    
    # Evaluate models
    for model_name in models_to_eval:
        print(f'\n{"="*60}')
        print(f'Evaluating {model_name.upper()} model')
        print(f'{"="*60}')
        
        # Load checkpoint first to get the config it was trained with
        checkpoint_path = os.path.join(
            config['paths']['checkpoints'],
            model_name,
            'best_model.pt'
        )
        
        if not os.path.exists(checkpoint_path):
            print(f'Checkpoint not found: {checkpoint_path}')
            continue
        
        # Load checkpoint to get training config and vocabularies.
        # SECURITY: weights_only=False is a full pickle load (the checkpoint
        # stores the vocabulary objects); only open trusted checkpoints.
        print(f'Loading checkpoint: {checkpoint_path}')
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        
        # Extract vocabularies from checkpoint
        if 'src_vocab' in checkpoint and 'tgt_vocab' in checkpoint:
            src_vocab = checkpoint['src_vocab']
            tgt_vocab = checkpoint['tgt_vocab']
            print(f'Loaded vocabularies from checkpoint: src={src_vocab.n_words}, tgt={tgt_vocab.n_words}')
        else:
            print('Error: Vocabularies not found in checkpoint!')
            print('Please retrain the model with the updated train.py')
            continue
        
        # Use the config from checkpoint (has correct model dimensions)
        model_config = checkpoint.get('config', config)
        print(f'Using embedding_dim={model_config["model"]["embedding_dim"]}, hidden_dim={model_config["model"]["hidden_dim"]}')
        
        src_vocab_size = src_vocab.n_words
        tgt_vocab_size = tgt_vocab.n_words
        
        # Create model with checkpoint's config and vocab sizes
        model = model_factories[model_name](src_vocab_size, tgt_vocab_size, model_config, device)
        
        # Load the state dict
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f'Loaded checkpoint from epoch {checkpoint["epoch"]+1}')
        print(f'Validation loss: {checkpoint["val_loss"]:.4f}')
        
        # Load the evaluation examples (only for the first model)
        if eval_examples is None:
            print('Loading evaluation data...')
            eval_examples = _load_eval_examples(model_config, args.split)
            print(f'Evaluation dataset size: {len(eval_examples)}')
        
        if not eval_examples:
            print(f'No examples available in the {args.split} split; nothing to evaluate.')
            return
        
        # Encode the shared examples with THIS model's vocabularies; wrapping
        # is lazy, so building a loader per model is cheap
        max_src_len = model_config['dataset'].get('max_docstring_length', 50)
        max_tgt_len = model_config['dataset'].get('max_code_length', 100)
        eval_dataset = CodeSearchNetDataset(eval_examples, src_vocab, tgt_vocab, max_src_len, max_tgt_len)
        data_loader = DataLoader(
            eval_dataset,
            batch_size=model_config['training']['batch_size'],
            shuffle=False,
            collate_fn=collate_fn
        )
        
        # Create evaluator and evaluate
        evaluator = Evaluator(model, model_config, device, model_name, src_vocab, tgt_vocab)
        results = evaluator.evaluate(data_loader)
        
        all_results[model_name] = {
            'overall': results['overall'],
            'by_length': results['by_length']
        }
    
    # Print comparison
    if len(all_results) > 1:
        # Overall comparison
        print(f'\n{"="*60}')
        print('Model Comparison — Overall')
        print(f'{"="*60}')
        print(f'{"Model":<15} {"BLEU":<10} {"Token Acc":<12} {"Exact Match":<12}')
        print('-' * 60)
        for model_name, data in all_results.items():
            metrics = data['overall']
            print(f'{model_name:<15} {metrics["bleu"]:<10.2f} '
                  f'{metrics["token_accuracy"]*100:<12.2f} '
                  f'{metrics["exact_match"]*100:<12.2f}')

        # Cross-model by-length comparison
        all_buckets = sorted({
            bucket
            for data in all_results.values()
            for bucket in data['by_length']
        })
        model_names = list(all_results)

        for metric_key, metric_label in [
            ('bleu', 'BLEU'),
            ('token_accuracy', 'Token Accuracy (%)'),
            ('exact_match', 'Exact Match (%)'),
        ]:
            print(f'\n{"="*60}')
            print(f'Model Comparison by Docstring Length — {metric_label}')
            print(f'{"="*60}')
            header = f'{"Length":<10}' + ''.join(f'{m:<15}' for m in model_names)
            print(header)
            print('-' * (10 + 15 * len(model_names)))
            for bucket in all_buckets:
                row = f'{bucket:<10}'
                for model_name in model_names:
                    bucket_data = all_results[model_name]['by_length'].get(bucket)
                    if bucket_data:
                        val = bucket_data[metric_key]
                        # BLEU is already on a 0-100 scale; the others are fractions
                        if metric_key != 'bleu':
                            val *= 100
                        row += f'{val:<15.2f}'
                    else:
                        row += f'{"N/A":<15}'
                print(row)


# Start the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()


### visualize_attention.py

In [None]:
%%writefile visualize_attention.py
"""
Attention visualization script
"""
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import yaml
import os
import argparse
import numpy as np
from tqdm import tqdm

from data_loader import load_and_prepare_data, Vocabulary
from models import create_attention_seq2seq


class AttentionVisualizer:
    """Plot and analyze the attention weights of a trained attention seq2seq model.

    The wrapped model must expose ``generate(src, src_lengths, max_length,
    sos_token)`` returning ``(generated, attention_weights)``. Every figure is
    written below ``<config['paths']['visualizations']>/attention``.
    """

    # Heatmaps are cropped to these sizes so the tick labels stay legible.
    MAX_TGT_TOKENS = 50
    MAX_SRC_TOKENS = 30
    # Added inside the logarithm so zero weights never evaluate log(0).
    ENTROPY_EPSILON = 1e-10

    def __init__(self, model, config, device, src_vocab, tgt_vocab):
        """
        Args:
            model: attention seq2seq model to inspect
            config: parsed config dict; reads config['paths']['visualizations']
                and config['dataset']['max_code_length']
            device: torch device the inputs are moved to
            src_vocab: source vocabulary (its tokenize() is used)
            tgt_vocab: target vocabulary (SOS_token, decode() and tokenize() are used)
        """
        self.model = model
        self.config = config
        self.device = device
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

        # Visualization directory
        self.viz_dir = os.path.join(config['paths']['visualizations'], 'attention')
        os.makedirs(self.viz_dir, exist_ok=True)

    @staticmethod
    def _attention_entropy(attention_weights, epsilon=ENTROPY_EPSILON):
        """Return the entropy (natural log) of every row of a 2-D attention matrix.

        Args:
            attention_weights: (tgt_len, src_len) numpy array
            epsilon: small constant added before the logarithm

        Returns:
            1-D numpy array with one entropy value per row.
        """
        return -np.sum(attention_weights * np.log(attention_weights + epsilon), axis=1)

    def load_checkpoint(self, checkpoint_path):
        """Load the model weights stored under 'model_state_dict' in a checkpoint file.

        Args:
            checkpoint_path: path of the checkpoint to read
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects, so
        # only load checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        print(f'Loaded checkpoint from epoch {checkpoint["epoch"]+1}')

    def visualize_attention(self, src_tokens, tgt_tokens, attention_weights, example_id):
        """
        Visualize attention weights as a heatmap

        The plot is saved as ``attention_example_<example_id>.png`` in the
        visualization directory. Plotting is best effort: any error is printed
        and all open figures are closed instead of aborting the run.

        Args:
            src_tokens: list of source tokens
            tgt_tokens: list of target tokens
            attention_weights: (tgt_len, src_len) attention matrix
            example_id: identifier for saving the plot
        """
        try:
            # Limit visualization to reasonable size
            if len(tgt_tokens) > self.MAX_TGT_TOKENS:
                tgt_tokens = tgt_tokens[:self.MAX_TGT_TOKENS]
                attention_weights = attention_weights[:self.MAX_TGT_TOKENS, :]
                print(f'  Note: Truncated visualization to first {self.MAX_TGT_TOKENS} target tokens')

            if len(src_tokens) > self.MAX_SRC_TOKENS:
                src_tokens = src_tokens[:self.MAX_SRC_TOKENS]
                attention_weights = attention_weights[:, :self.MAX_SRC_TOKENS]
                print(f'  Note: Truncated visualization to first {self.MAX_SRC_TOKENS} source tokens')

            # Scale the figure with the number of tokens, with a minimum size
            fig, ax = plt.subplots(figsize=(max(10, len(src_tokens) * 0.5),
                                           max(8, len(tgt_tokens) * 0.4)))

            # Fixed 0..1 color range keeps heatmaps comparable across examples
            sns.heatmap(
                attention_weights,
                xticklabels=src_tokens,
                yticklabels=tgt_tokens,
                cmap='YlOrRd',
                cbar=True,
                ax=ax,
                vmin=0,
                vmax=1,
                square=False
            )

            ax.set_xlabel('Source Sequence (Docstring)', fontsize=12, fontweight='bold')
            ax.set_ylabel('Target Sequence (Generated Code)', fontsize=12, fontweight='bold')
            ax.set_title(f'Attention Weights - Example {example_id}', fontsize=14, fontweight='bold')

            # Rotate labels for better readability
            plt.xticks(rotation=45, ha='right', fontsize=8)
            plt.yticks(rotation=0, fontsize=8)

            plt.tight_layout()

            # Save figure
            save_path = os.path.join(self.viz_dir, f'attention_example_{example_id}.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close(fig)

            print(f'Saved attention visualization: {save_path}')
        except Exception as e:
            # Deliberate best effort: one broken plot must not stop the remaining examples.
            print(f'Error creating visualization: {e}')
            plt.close('all')

    def visualize_example(self, src, tgt, src_text, tgt_text, example_id, src_length=None):
        """Generate code for one example, plot its attention and return the details.

        Args:
            src: 1-D tensor of source token ids (may include batch padding)
            tgt: target tensor of the example (kept for interface compatibility; unused)
            src_text: raw source text (docstring)
            tgt_text: raw reference text (code)
            example_id: identifier used in the printout and the saved file name
            src_length: optional true (unpadded) length of ``src``. When given,
                ``src`` is cut to this length so padding is not fed to the
                model as if it were real input.

        Returns:
            (info, attention_weights): ``info`` is a dict with the keys 'source',
            'reference', 'generated', 'src_tokens', 'tgt_tokens' and
            'attention_shape'; ``attention_weights`` is the numpy attention
            matrix cropped to the token lists.
        """
        self.model.eval()

        with torch.no_grad():
            if src_length is not None and int(src_length) > 0:
                src = src[:int(src_length)]
            src = src.unsqueeze(0).to(self.device)
            src_lengths = torch.tensor([src.shape[1]])

            # Generate with attention
            max_length = self.config['dataset']['max_code_length'] + 2
            generated, attention_weights = self.model.generate(
                src, src_lengths, max_length, self.tgt_vocab.SOS_token
            )

            # Get generated sequence
            generated_indices = generated[0].cpu().tolist()
            generated_text = self.tgt_vocab.decode(generated_indices)

            # Get attention weights (remove batch dimension)
            attention_weights = attention_weights[0].cpu().numpy()

            # NOTE(review): assumes tokenize() yields one token per encoded source
            # position / decoding step; special tokens would shift the labels —
            # TODO confirm against Vocabulary.
            src_tokens = self.src_vocab.tokenize(src_text)
            tgt_tokens = self.tgt_vocab.tokenize(generated_text)

            # Truncate attention to actual lengths
            attention_weights = attention_weights[:len(tgt_tokens), :len(src_tokens)]
            # The matrix can still be smaller than the token lists (e.g. a source
            # text longer than the encoded sequence); trim the labels so that
            # every row/column has exactly one label.
            tgt_tokens = tgt_tokens[:attention_weights.shape[0]]
            src_tokens = src_tokens[:attention_weights.shape[1]]

            # Print example
            print(f'\n{"="*80}')
            print(f'Example {example_id}')
            print(f'{"="*80}')
            print(f'Source (Docstring):\n  {src_text}')
            print(f'\nReference Code:\n  {tgt_text}')
            print(f'\nGenerated Code:\n  {generated_text}')
            print(f'\nAttention matrix shape: {attention_weights.shape}')

            # Visualize attention
            self.visualize_attention(src_tokens, tgt_tokens, attention_weights, example_id)

            # Save detailed info
            info = {
                'source': src_text,
                'reference': tgt_text,
                'generated': generated_text,
                'src_tokens': src_tokens,
                'tgt_tokens': tgt_tokens,
                'attention_shape': attention_weights.shape
            }

            return info, attention_weights

    def analyze_attention_patterns(self, attention_weights, src_tokens, tgt_tokens):
        """Print the most-attended source token per target token and the mean entropy.

        Args:
            attention_weights: (tgt_len, src_len) attention matrix
            src_tokens: list of source tokens (column labels)
            tgt_tokens: list of target tokens (row labels)
        """
        print('\nAttention Analysis:')

        attention_weights = np.asarray(attention_weights)
        if attention_weights.size == 0:
            # max/argmax over an empty axis raise and the mean of nothing is NaN,
            # so an empty source or empty generation is reported instead.
            print('  No attention weights to analyze (empty source or generated sequence)')
            return

        # Find max attention for each target token
        max_attentions = np.max(attention_weights, axis=1)
        max_src_indices = np.argmax(attention_weights, axis=1)

        print('\nTarget -> Most Attended Source:')
        for tgt_token, src_idx, max_att in zip(tgt_tokens, max_src_indices, max_attentions):
            # Skip positions that have no matching source label
            if src_idx < len(src_tokens):
                print(f'  {tgt_token:<15} -> {src_tokens[src_idx]:<15} (weight: {max_att:.3f})')

        # Calculate attention entropy (measure of focus)
        avg_entropy = np.mean(self._attention_entropy(attention_weights))

        print(f'\nAverage Attention Entropy: {avg_entropy:.3f}')
        print('  (Lower entropy = more focused attention)')

    def visualize_multiple_examples(self, data_loader, num_examples=5):
        """Visualize and analyze attention for the first ``num_examples`` examples.

        Args:
            data_loader: iterable of batches providing the keys 'src', 'tgt',
                'src_texts', 'tgt_texts' and 'src_lengths'
            num_examples: maximum number of examples to process
        """
        print(f'\nVisualizing attention for {num_examples} examples...')

        examples_found = 0

        with torch.no_grad():
            for batch in data_loader:
                if examples_found >= num_examples:
                    break

                rows = zip(batch['src'], batch['tgt'], batch['src_texts'],
                           batch['tgt_texts'], batch['src_lengths'])
                for src, tgt, src_text, tgt_text, src_length in rows:
                    if examples_found >= num_examples:
                        break

                    # Pass the true length so batch padding is stripped before generation
                    info, attention_weights = self.visualize_example(
                        src,
                        tgt,
                        src_text,
                        tgt_text,
                        examples_found + 1,
                        src_length=src_length
                    )

                    # Analyze attention patterns
                    self.analyze_attention_patterns(
                        attention_weights,
                        info['src_tokens'],
                        info['tgt_tokens']
                    )

                    examples_found += 1

        print(f'\n{"="*80}')
        print(f'Visualized {examples_found} examples')
        print(f'Visualizations saved to: {self.viz_dir}')
        print(f'{"="*80}')

    def create_summary_visualization(self, data_loader, num_samples=100):
        """Plot histograms of attention entropy and maximum attention weight.

        Statistics are collected per decoding step over at most ``num_samples``
        examples and saved as ``attention_statistics.png`` in the visualization
        directory. Nothing is plotted when no statistics could be collected.

        Args:
            data_loader: iterable of batches providing the keys 'src' and 'src_lengths'
            num_samples: maximum number of examples to include
        """
        print('\nCreating attention summary visualization...')

        all_entropies = []
        all_max_attentions = []

        self.model.eval()

        samples_processed = 0

        with torch.no_grad():
            for batch in tqdm(data_loader, desc='Processing'):
                if samples_processed >= num_samples:
                    break

                src = batch['src'].to(self.device)
                src_lengths = batch['src_lengths']

                max_length = self.config['dataset']['max_code_length'] + 2
                _, attention_weights = self.model.generate(
                    src, src_lengths, max_length, self.tgt_vocab.SOS_token
                )

                # NOTE(review): rows are taken as returned by generate(); if it pads
                # steps after end-of-sequence or padded source positions, those are
                # included in the statistics — TODO confirm.
                remaining = num_samples - samples_processed
                for sample_attention in attention_weights[:remaining]:
                    att = sample_attention.cpu().numpy()

                    all_entropies.extend(self._attention_entropy(att).tolist())
                    all_max_attentions.extend(np.max(att, axis=1).tolist())

                    samples_processed += 1

        if not all_entropies:
            # The mean of an empty list is NaN and would only yield an empty plot.
            print('No attention weights collected; skipping summary plot.')
            return

        mean_entropy = np.mean(all_entropies)
        mean_max_attention = np.mean(all_max_attentions)

        # Create summary plot
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Entropy distribution
        axes[0].hist(all_entropies, bins=50, edgecolor='black', alpha=0.7)
        axes[0].set_xlabel('Attention Entropy', fontsize=12)
        axes[0].set_ylabel('Frequency', fontsize=12)
        axes[0].set_title('Distribution of Attention Entropy', fontsize=14, fontweight='bold')
        axes[0].axvline(mean_entropy, color='red', linestyle='--',
                       label=f'Mean: {mean_entropy:.3f}')
        axes[0].legend()

        # Max attention distribution
        axes[1].hist(all_max_attentions, bins=50, edgecolor='black', alpha=0.7, color='orange')
        axes[1].set_xlabel('Maximum Attention Weight', fontsize=12)
        axes[1].set_ylabel('Frequency', fontsize=12)
        axes[1].set_title('Distribution of Maximum Attention Weights', fontsize=14, fontweight='bold')
        axes[1].axvline(mean_max_attention, color='red', linestyle='--',
                       label=f'Mean: {mean_max_attention:.3f}')
        axes[1].legend()

        plt.tight_layout()

        save_path = os.path.join(self.viz_dir, 'attention_statistics.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f'Saved attention statistics: {save_path}')


def main():
    """Command-line entry point: visualize attention for the trained attention model.

    Reads the YAML config given by ``--config``, loads
    ``<checkpoints>/attention/best_model.pt``, rebuilds the test split with the
    vocabularies stored in the checkpoint, and writes ``--num_examples``
    attention heatmaps (plus summary statistics when ``--summary`` is given).
    Returns early with a message when the checkpoint or its vocabularies are
    missing.
    """
    parser = argparse.ArgumentParser(description='Visualize attention weights')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')
    parser.add_argument('--num_examples', type=int, default=5,
                        help='Number of examples to visualize')
    parser.add_argument('--summary', action='store_true',
                        help='Create summary statistics visualization')
    args = parser.parse_args()

    # Load config
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    # Fall back to the CPU when CUDA is unavailable, whatever the config requests
    device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load checkpoint first to check if it exists
    checkpoint_path = os.path.join(
        config['paths']['checkpoints'],
        'attention',
        'best_model.pt'
    )

    if not os.path.exists(checkpoint_path):
        print(f'Error: Checkpoint not found: {checkpoint_path}')
        print('Please train the attention model first.')
        print('\nTo train the model, run:')
        # Point at the config actually in use rather than a fixed file name
        print(f'  python train.py --config {args.config} --model attention')
        return

    # Load checkpoint to get vocabularies and config
    print('Loading checkpoint...')
    # NOTE(security): weights_only=False unpickles arbitrary Python objects (it is
    # needed here because the checkpoint stores the vocabulary objects); only
    # load checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Extract vocabularies from checkpoint
    if 'src_vocab' not in checkpoint or 'tgt_vocab' not in checkpoint:
        print('Error: Vocabularies not found in checkpoint!')
        print('Please retrain the model with the updated train.py')
        return

    src_vocab = checkpoint['src_vocab']
    tgt_vocab = checkpoint['tgt_vocab']
    print(f'Loaded vocabularies from checkpoint: src={src_vocab.n_words}, tgt={tgt_vocab.n_words}')

    # Use config from checkpoint (has correct dimensions)
    model_config = checkpoint.get('config', config)
    print(f'Using embedding_dim={model_config["model"]["embedding_dim"]}, hidden_dim={model_config["model"]["hidden_dim"]}')

    # Load test data with checkpoint's vocabularies
    print('Loading test data...')
    from datasets import load_dataset
    from data_loader import CodeSearchNetDataset, collate_fn
    from torch.utils.data import DataLoader

    dataset_config = model_config['dataset']
    dataset = load_dataset(
        dataset_config['name'],
        split='train',
        cache_dir=dataset_config['cache_dir']
    )

    # Shuffle and split with fixed seeds so the selection is reproducible.
    # NOTE(review): presumably mirrors the split used for training — verify
    # against data_loader.
    held_out_size = dataset_config['val_size'] + dataset_config['test_size']
    total_needed = dataset_config['train_size'] + held_out_size
    dataset = dataset.shuffle(seed=42).select(range(min(total_needed, len(dataset))))

    splits = dataset.train_test_split(test_size=held_out_size, seed=42)
    temp_splits = splits['test'].train_test_split(test_size=dataset_config['test_size'], seed=42)
    test_data = temp_splits['test']

    max_src_len = dataset_config.get('max_docstring_length', 50)
    max_tgt_len = dataset_config.get('max_code_length', 100)
    test_dataset = CodeSearchNetDataset(test_data, src_vocab, tgt_vocab, max_src_len, max_tgt_len)
    test_loader = DataLoader(
        test_dataset,
        batch_size=model_config['training']['batch_size'],
        shuffle=False,
        collate_fn=collate_fn
    )
    print(f'Test dataset size: {len(test_dataset)}')

    # Create attention model with checkpoint's config and vocab sizes
    print('Creating attention model...')
    model = create_attention_seq2seq(src_vocab.n_words, tgt_vocab.n_words, model_config, device)

    # Load checkpoint weights
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f'Loaded checkpoint from epoch {checkpoint["epoch"]+1}, val_loss={checkpoint["val_loss"]:.4f}')

    # Create visualizer
    visualizer = AttentionVisualizer(model, model_config, device, src_vocab, tgt_vocab)

    # Visualize examples
    visualizer.visualize_multiple_examples(test_loader, num_examples=args.num_examples)

    # Create summary visualization
    if args.summary:
        num_samples = min(100, len(test_dataset))
        visualizer.create_summary_visualization(test_loader, num_samples=num_samples)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()


### generate_report_plots.py

In [None]:
%%writefile generate_report_plots.py
"""
Generate plots and figures for the report
"""
import json
import matplotlib.pyplot as plt
import seaborn as sns
import os
import argparse
import numpy as np

sns.set_style('whitegrid')
sns.set_palette('husl')


def plot_training_curves(model_names, config):
    """Plot training and validation loss curves, one panel per model.

    Each model's history is read from
    ``<config['paths']['logs']>/<model>/training_history.json`` (keys
    'train_losses' and 'val_losses'). The figure is saved as
    ``training_curves.png`` in ``config['paths']['results']``, which is created
    if it does not exist yet.

    Args:
        model_names: list of model names, one subplot each
        config: parsed config dict providing config['paths']['logs'] and
            config['paths']['results']
    """
    if not model_names:
        # plt.subplots cannot create a grid with zero columns
        print("Warning: No models given; skipping training curves")
        return

    # squeeze=False always returns a 2-D array, so one model needs no special case
    fig, axes = plt.subplots(1, len(model_names), figsize=(6*len(model_names), 5),
                             squeeze=False)

    for ax, model_name in zip(axes[0], model_names):
        # Load training history
        history_path = os.path.join(config['paths']['logs'], model_name, 'training_history.json')

        if not os.path.exists(history_path):
            print(f"Warning: History not found for {model_name}")
            # Label the gap instead of leaving an empty, unnamed set of axes
            ax.set_axis_off()
            ax.text(0.5, 0.5, f'{model_name.upper()}\n(no training history found)',
                    ha='center', va='center', transform=ax.transAxes, fontsize=12)
            continue

        with open(history_path, 'r') as f:
            history = json.load(f)

        # Epochs are numbered from 1 on the x-axis
        epochs = range(1, len(history['train_losses']) + 1)

        ax.plot(epochs, history['train_losses'], label='Train Loss', linewidth=2)
        ax.plot(epochs, history['val_losses'], label='Val Loss', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(f'{model_name.upper()} - Training Progress',
                     fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    # Save (the results directory may not exist before evaluation has been run)
    os.makedirs(config['paths']['results'], exist_ok=True)
    save_path = os.path.join(config['paths']['results'], 'training_curves.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved training curves: {save_path}")
    plt.close(fig)


def plot_model_comparison(model_names, config):
    """Draw a grouped bar chart comparing the overall metrics of each model.

    For every model the 'overall' section of
    ``<config['paths']['results']>/<model>/evaluation_results.json`` is read.
    BLEU is plotted as stored; token accuracy and exact match are multiplied by
    100. Models without a results file are reported and left out. The chart is
    saved as ``model_comparison.png`` in the results directory.

    Args:
        model_names: list of model names to compare
        config: parsed config dict providing config['paths']['results']
    """
    results_dir = config['paths']['results']
    labels, bleu_scores, token_accuracies, exact_matches = [], [], [], []

    for model_name in model_names:
        results_path = os.path.join(results_dir, model_name, 'evaluation_results.json')

        if not os.path.exists(results_path):
            print(f"Warning: Results not found for {model_name}")
            continue

        with open(results_path, 'r') as results_file:
            overall = json.load(results_file)['overall']

        labels.append(model_name.upper())
        bleu_scores.append(overall['bleu'])
        token_accuracies.append(overall['token_accuracy'] * 100)
        exact_matches.append(overall['exact_match'] * 100)

    # One group per model; the three metrics sit side by side inside a group
    positions = np.arange(len(labels))
    bar_width = 0.25
    series = [
        (-bar_width, bleu_scores, 'BLEU Score'),
        (0, token_accuracies, 'Token Accuracy (%)'),
        (bar_width, exact_matches, 'Exact Match (%)'),
    ]

    fig, ax = plt.subplots(figsize=(12, 6))
    bar_groups = [
        ax.bar(positions + offset, values, bar_width, label=label)
        for offset, values, label in series
    ]

    ax.set_xlabel('Model', fontsize=14, fontweight='bold')
    ax.set_ylabel('Score', fontsize=14, fontweight='bold')
    ax.set_title('Model Comparison Across Metrics', fontsize=16, fontweight='bold')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3, axis='y')

    # Write each bar's value just above its top edge
    for bars in bar_groups:
        for bar in bars:
            bar_height = bar.get_height()
            ax.annotate(f'{bar_height:.1f}',
                        xy=(bar.get_x() + bar.get_width() / 2, bar_height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom',
                        fontsize=10)

    fig.tight_layout()

    save_path = os.path.join(results_dir, 'model_comparison.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved model comparison: {save_path}")
    plt.close(fig)


def plot_performance_by_length(model_names, config):
    """Plot BLEU, token accuracy and exact match per source-length bucket.

    The 'by_length' section of each model's
    ``<config['paths']['results']>/<model>/evaluation_results.json`` is drawn
    as one line per model in three side-by-side panels (one per metric). BLEU
    is plotted as stored; the other two metrics are multiplied by 100. The
    figure is saved as ``performance_by_length.png`` in the results directory,
    which is created if it does not exist yet.

    Args:
        model_names: list of model names to plot
        config: parsed config dict providing config['paths']['results']
    """
    results_dir = config['paths']['results']

    # Read every results file once up front instead of once per metric panel
    loaded_results = []
    for model_name in model_names:
        results_path = os.path.join(results_dir, model_name, 'evaluation_results.json')

        if not os.path.exists(results_path):
            print(f"Warning: Results not found for {model_name}")
            continue

        with open(results_path, 'r') as f:
            loaded_results.append((model_name, json.load(f)))

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    metric_names = ['bleu', 'token_accuracy', 'exact_match']
    titles = ['BLEU Score by Length', 'Token Accuracy by Length', 'Exact Match by Length']

    for ax, metric, title in zip(axes, metric_names, titles):
        # BLEU is plotted as stored; the other metrics are converted to percent
        scale = 1 if metric == 'bleu' else 100

        for model_name, results in loaded_results:
            by_length = results['by_length']
            # NOTE(review): bucket keys are sorted as plain strings; confirm this
            # matches the intended order of the length buckets.
            buckets = sorted(by_length)
            values = [by_length[bucket][metric] * scale for bucket in buckets]

            ax.plot(buckets, values, marker='o', linewidth=2,
                    label=model_name.upper(), markersize=8)

        ax.set_xlabel('Source Length (tokens)', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        if loaded_results:
            # legend() on axes without labeled lines only emits a warning
            ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    # Save
    os.makedirs(results_dir, exist_ok=True)
    save_path = os.path.join(results_dir, 'performance_by_length.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved performance by length: {save_path}")
    plt.close(fig)


# Characters LaTeX treats specially, mapped to a form that renders them literally.
_LATEX_ESCAPES = str.maketrans({
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
})


def _latex_escape(text):
    """Return *text* with LaTeX special characters escaped."""
    return text.translate(_LATEX_ESCAPES)


def generate_results_table(model_names, config):
    """Generate a LaTeX table of the overall results, save it and print it.

    One row per model is built from the 'overall' section of
    ``<config['paths']['results']>/<model>/evaluation_results.json``; token
    accuracy and exact match are multiplied by 100. Models without a results
    file are reported and left out. The table is written to
    ``results_table.tex`` in the results directory, which is created if it
    does not exist yet.

    Args:
        model_names: list of model names, in row order
        config: parsed config dict providing config['paths']['results']
    """
    results_dir = config['paths']['results']

    lines = [
        "\\begin{table}[h]\n",
        "\\centering\n",
        "\\caption{Model Performance Comparison}\n",
        "\\begin{tabular}{|l|c|c|c|}\n",
        "\\hline\n",
        "\\textbf{Model} & \\textbf{BLEU} & \\textbf{Token Acc (\\%)} & \\textbf{Exact Match (\\%)} \\\\\n",
        "\\hline\n",
    ]

    for model_name in model_names:
        results_path = os.path.join(results_dir, model_name,
                                   'evaluation_results.json')

        if not os.path.exists(results_path):
            print(f"Warning: Results not found for {model_name}")
            continue

        with open(results_path, 'r') as f:
            overall = json.load(f)['overall']

        bleu = overall['bleu']
        token_acc = overall['token_accuracy'] * 100
        exact_match = overall['exact_match'] * 100

        # Escape the name so that e.g. an underscore does not break the LaTeX build
        label = _latex_escape(model_name.upper())
        lines.append(f"{label} & {bleu:.2f} & {token_acc:.2f} & {exact_match:.2f} \\\\\n")

    lines += [
        "\\hline\n",
        "\\end{tabular}\n",
        "\\end{table}\n",
    ]
    table = "".join(lines)

    # Save (the results directory may not exist before evaluation has been run)
    os.makedirs(results_dir, exist_ok=True)
    save_path = os.path.join(results_dir, 'results_table.tex')
    with open(save_path, 'w') as f:
        f.write(table)

    print(f"Saved LaTeX table: {save_path}")
    print("\nLaTeX table:")
    print(table)


def main():
    """Command-line entry point: build every report plot and the LaTeX table.

    Reads the YAML config given by ``--config`` (default ``config.yaml``) and
    runs each generator for the models 'vanilla', 'lstm' and 'attention'.
    """
    import yaml

    cli = argparse.ArgumentParser(description='Generate plots for report')
    cli.add_argument('--config', type=str, default='config.yaml')
    options = cli.parse_args()

    # Load config
    with open(options.config, 'r') as config_file:
        config = yaml.safe_load(config_file)

    model_names = ['vanilla', 'lstm', 'attention']
    divider = "="*60

    print("Generating plots for report...")
    print(divider)

    # Every generator takes the same (model_names, config) arguments
    generators = (
        plot_training_curves,
        plot_model_comparison,
        plot_performance_by_length,
        generate_results_table,
    )
    for generate in generators:
        generate(model_names, config)

    print(divider)
    print("All plots generated successfully!")
    print(f"Check the '{config['paths']['results']}' directory")


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()


## Step 4 — Create Config
Edit the hyperparameters here, then re-run this cell to rewrite `config_colab.yaml`.

In [None]:
%%writefile config_colab.yaml
# Settings passed via --config to train.py, evaluate.py, visualize_attention.py
# and generate_report_plots.py in the steps below.
dataset:
  # Dataset identifier handed to datasets.load_dataset
  name: "Nan-Do/code-search-net-python"
  # train_size + val_size + test_size examples are selected after shuffling;
  # val_size + test_size of them are held out, then test_size are split off for testing
  train_size: 10000
  val_size: 1000
  test_size: 1000
  # Maximum source (docstring) / target (code) lengths given to the dataset;
  # generation in visualize_attention.py is capped at max_code_length + 2 steps
  max_docstring_length: 50
  max_code_length: 80
  # Cache directory handed to datasets.load_dataset
  cache_dir: "./data_cache"

# Model dimensions; embedding_dim and hidden_dim are passed on to the model factories.
# NOTE(review): dropout and max_vocab_size are consumed outside this section —
# meaning inferred from the key names, confirm in train.py / data_loader.py
model:
  embedding_dim: 256
  hidden_dim: 256
  dropout: 0.3
  max_vocab_size: 10000

# batch_size is also used for the test DataLoader in visualize_attention.py.
# NOTE(review): the remaining keys are presumably read by train.py (not shown
# here) — confirm their exact semantics there
training:
  batch_size: 64
  num_epochs: 5
  learning_rate: 0.001
  teacher_forcing_ratio: 0.5
  gradient_clip: 5.0
  save_every: 1

paths:
  # visualize_attention.py loads <checkpoints>/attention/best_model.pt
  checkpoints: "./checkpoints"
  # <results>/<model>/evaluation_results.json is read, and the report plots and
  # LaTeX table are written here
  results: "./results"
  # Attention heatmaps are written to <visualizations>/attention
  visualizations: "./visualizations"
  # <logs>/<model>/training_history.json is read for the training curves
  logs: "./logs"

# Used only when CUDA is available; visualize_attention.py falls back to "cpu" otherwise
device: "cuda"

## Step 5 — Train All 3 Models
Takes ~30–60 minutes on a T4 GPU. Progress bars show the loss for each epoch.

In [None]:
!python train.py --config config_colab.yaml --model all

## Step 6 — Evaluate All Models
Computes BLEU, Token Accuracy, Exact Match.
Also prints a cross-model comparison table broken down by docstring length.

In [None]:
!python evaluate.py --config config_colab.yaml --model all --split test

## Step 7 — Attention Heatmaps & Report Plots

In [None]:
# Attention heatmaps (5 examples + entropy stats)
!python visualize_attention.py --config config_colab.yaml --num_examples 5 --summary

# All report plots (training curves, model comparison, performance-by-length)
!python generate_report_plots.py --config config_colab.yaml

# Display heatmaps inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

viz_dir = "./visualizations/attention"
# visualize_attention.py only creates this directory once it got past its
# checkpoint check, so a missing directory is treated like "no images".
if os.path.isdir(viz_dir):
    png_files = sorted(f for f in os.listdir(viz_dir) if f.endswith(".png"))
else:
    png_files = []
print(f"Found {len(png_files)} attention visualizations")

# Show at most five images side by side
shown_files = png_files[:5]
if shown_files:
    # squeeze=False always returns a 2-D array of axes, even for a single image
    fig, axes = plt.subplots(1, len(shown_files), figsize=(25, 6), squeeze=False)
    for ax, fname in zip(axes[0], shown_files):
        img = mpimg.imread(os.path.join(viz_dir, fname))
        ax.imshow(img)
        ax.set_title(fname.replace(".png", ""), fontsize=9)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
else:
    # plt.subplots cannot create a grid with zero columns
    print("No attention heatmaps to display. Train the attention model, then re-run this cell.")

## Step 8 — View Report Plots Inline

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

# Report figures written by generate_report_plots.py, paired with display titles
plots = [
    ("./results/training_curves.png",       "Training Curves"),
    ("./results/model_comparison.png",       "Model Comparison"),
    ("./results/performance_by_length.png",  "Performance by Docstring Length"),
]

for image_path, caption in plots:
    # Report a missing file and move on rather than failing the whole cell
    if not os.path.exists(image_path):
        print(f"Not found: {image_path}")
        continue

    plt.figure(figsize=(14, 6))
    plt.imshow(mpimg.imread(image_path))
    plt.title(caption, fontsize=14)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

## Step 9 — Download All Results
Downloads three zip files: trained model weights, evaluation results, and visualizations.

In [None]:
import zipfile, os
from google.colab import files

def zip_folder(folder, zipname):
    """Write every file found under *folder* into the deflate-compressed archive *zipname*.

    Files are stored under the path they were found at (``folder`` joined with
    their relative location); directories themselves are not added. The name
    and size of the finished archive are printed.
    """
    file_paths = (
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(folder)
        for filename in filenames
    )
    with zipfile.ZipFile(zipname, "w", zipfile.ZIP_DEFLATED) as archive:
        for file_path in file_paths:
            archive.write(file_path)

    size_mb = os.path.getsize(zipname) / 1e6
    print(f"Zipped: {zipname} ({size_mb:.1f} MB)")

# Output folder -> archive name; all archives are built before any download starts
archives = {
    "./checkpoints": "checkpoints.zip",
    "./results": "results.zip",
    "./visualizations": "visualizations.zip",
}

for output_folder, archive_name in archives.items():
    zip_folder(output_folder, archive_name)

for archive_name in archives.values():
    files.download(archive_name)
print("Done! Check your Downloads folder.")