In [1]:
!pip install datasets tokenizers nltk sacrebleu transformers rouge-score -q

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m100.8/100.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np
import random
import os
from tqdm import tqdm
import math
import warnings
warnings.filterwarnings('ignore')

In [3]:
class Config:
    # Model Architecture
    EMBED_SIZE = 256
    HIDDEN_SIZE = 256
    NUM_LAYERS = 2
    DROPOUT = 0.3
    BIDIRECTIONAL = False  # optional True if you reduce HIDDEN_SIZE
    
    # Training
    BATCH_SIZE = 32
    EPOCHS = 20
    LR = 0.001
    CLIP = 1.0
    TEACHER_FORCING_RATIO = 0.7
    TEACHER_FORCING_DECAY = 0.95
    
    # Data
    MAX_DOC_LEN = 60
    MAX_CODE_LEN = 120
    VOCAB_SIZE = 10000
    
    # Optimization
    WEIGHT_DECAY = 1e-5
    LR_SCHEDULER_PATIENCE = 2
    LR_SCHEDULER_FACTOR = 0.5
    
    # Generation
    BEAM_SIZE = 5
    LENGTH_PENALTY = 1.0
    
    # System
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    SEED = 42

config = Config()

In [4]:
# Set seeds for reproducibility
torch.manual_seed(config.SEED)
random.seed(config.SEED)
np.random.seed(config.SEED)

print(f"Using device: {config.DEVICE}")
print(f"Model will have {config.NUM_LAYERS} layers with hidden size {config.HIDDEN_SIZE}")

Using device: cuda
Model will have 2 layers with hidden size 256


In [5]:
print("\nüìä Loading dataset...")
dataset = load_dataset("Nan-Do/code-search-net-python")

# Use more data for better performance
full_data = dataset["train"].select(range(8000))

split1 = full_data.train_test_split(test_size=0.15, seed=config.SEED)
train_data = split1["train"]
temp_data = split1["test"]

split2 = temp_data.train_test_split(test_size=0.5, seed=config.SEED)
val_data = split2["train"]
test_data = split2["test"]

print(f"‚úÖ Train: {len(train_data):,} | Val: {len(val_data):,} | Test: {len(test_data):,}")


üìä Loading dataset...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-ee77a7de79eb2a(‚Ä¶):   0%|          | 0.00/155M [00:00<?, ?B/s]

data/train-00001-of-00004-648b3bede2edf6(‚Ä¶):   0%|          | 0.00/139M [00:00<?, ?B/s]

data/train-00002-of-00004-1dfd72b171e6b2(‚Ä¶):   0%|          | 0.00/153M [00:00<?, ?B/s]

data/train-00003-of-00004-184ab6d0e3c690(‚Ä¶):   0%|          | 0.00/151M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/455243 [00:00<?, ? examples/s]

‚úÖ Train: 6,800 | Val: 600 | Test: 600


In [6]:
def create_advanced_tokenizer(train_data, vocab_size=10000):
    """Create a robust BPE tokenizer with proper configuration"""
    tokenizer_path = f"tokenizer_v{vocab_size}.json"
    
    if os.path.exists(tokenizer_path):
        print("üìÇ Loading existing tokenizer...")
        tokenizer = Tokenizer.from_file(tokenizer_path)
    else:
        print("üîß Training new tokenizer...")
        # Initialize BPE tokenizer
        tokenizer = Tokenizer(models.BPE(unk_token="<UNK>"))
        
        # Use byte-level pre-tokenizer for better coverage
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
        tokenizer.decoder = decoders.ByteLevel()
        
        # Special tokens
        special_tokens = ["<PAD>", "<SOS>", "<EOS>", "<UNK>", "<MASK>"]
        
        trainer = trainers.BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=special_tokens,
            min_frequency=2,
            show_progress=True
        )
        
        # Prepare corpus
        corpus = []
        for i in range(min(10000, len(train_data))):
            corpus.append(train_data[i]["docstring"] if train_data[i]["docstring"] else "")
            corpus.append(train_data[i]["code"] if train_data[i]["code"] else "")
        
        tokenizer.train_from_iterator(corpus, trainer)
        tokenizer.save(tokenizer_path)
        print(f"‚úÖ Tokenizer saved with vocab size: {tokenizer.get_vocab_size()}")
    
    return tokenizer

tokenizer = create_advanced_tokenizer(train_data, config.VOCAB_SIZE)
# Token helpers
PAD_IDX = tokenizer.token_to_id("<PAD>")
SOS_IDX = tokenizer.token_to_id("<SOS>")
EOS_IDX = tokenizer.token_to_id("<EOS>")
UNK_IDX = tokenizer.token_to_id("<UNK>")
VOCAB_SIZE = tokenizer.get_vocab_size()

print(f"üìù Vocab: {VOCAB_SIZE} | PAD: {PAD_IDX} | SOS: {SOS_IDX} | EOS: {EOS_IDX}")
def encode(text):
    """Encode text with proper error handling"""
    if not text:
        return []
    try:
        return tokenizer.encode(text).ids
    except:
        return [UNK_IDX]

def decode(ids):
    """Decode ids with proper error handling"""
    if not ids:
        return ""
    try:
        return tokenizer.decode(ids)
    except:
        return "<DECODE_ERROR>"

üìÇ Loading existing tokenizer...
üìù Vocab: 10000 | PAD: 0 | SOS: 1 | EOS: 2


In [7]:
class CodeDataset(Dataset):
    def __init__(self, data, max_doc_len, max_code_len, augment=False):
        self.data = data
        self.max_doc_len = max_doc_len
        self.max_code_len = max_code_len
        self.augment = augment
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Extract and clean text
        docstring = str(item["docstring"]).strip() if item["docstring"] else ""
        code = str(item["code"]).strip() if item["code"] else ""
        
        # Data augmentation for training
        if self.augment and random.random() < 0.1:
            # Randomly drop some characters (simulates noise)
            if len(docstring) > 10:
                drop_idx = random.randint(0, len(docstring)-1)
                docstring = docstring[:drop_idx] + docstring[drop_idx+1:]
        
        # Encode
        doc_ids = encode(docstring)[:self.max_doc_len]
        code_ids = encode(code)[:self.max_code_len]
        
        # Ensure minimum length
        if len(doc_ids) == 0:
            doc_ids = [PAD_IDX]
        if len(code_ids) == 0:
            code_ids = [PAD_IDX]
            
        return {
            "doc": doc_ids,
            "code": code_ids,
            "doc_len": len(doc_ids),
            "code_len": len(code_ids)
        }

In [8]:
def advanced_collate_fn(batch):
    """Advanced collate function with sequence lengths"""
    docs = [b["doc"] for b in batch]
    codes = [b["code"] for b in batch]
    doc_lens = [b["doc_len"] for b in batch]
    code_lens = [b["code_len"] for b in batch]
    
    max_doc = max(doc_lens)
    max_code = max(code_lens)
    
    src_list = []
    trg_in_list = []
    trg_out_list = []
    
    for d, c, d_len, c_len in zip(docs, codes, doc_lens, code_lens):
        # Source padding
        src_list.append(d + [PAD_IDX] * (max_doc - d_len))
        
        # Target input (with SOS)
        trg_in_list.append([SOS_IDX] + c + [PAD_IDX] * (max_code - c_len))
        
        # Target output (with EOS)
        trg_out_list.append(c + [EOS_IDX] + [PAD_IDX] * (max_code - c_len))
    
    return {
        "src": torch.tensor(src_list, dtype=torch.long),
        "trg_in": torch.tensor(trg_in_list, dtype=torch.long),
        "trg_out": torch.tensor(trg_out_list, dtype=torch.long),
        "src_len": torch.tensor(doc_lens, dtype=torch.long),
        "trg_len": torch.tensor(code_lens, dtype=torch.long)
    }

In [9]:
# Create datasets with augmentation for training
train_dataset = CodeDataset(train_data, config.MAX_DOC_LEN, config.MAX_CODE_LEN, augment=True)
val_dataset = CodeDataset(val_data, config.MAX_DOC_LEN, config.MAX_CODE_LEN, augment=False)
test_dataset = CodeDataset(test_data, config.MAX_DOC_LEN, config.MAX_CODE_LEN, augment=False)

In [10]:
# DataLoaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=config.BATCH_SIZE, 
    shuffle=True, 
    collate_fn=advanced_collate_fn,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=config.BATCH_SIZE, 
    collate_fn=advanced_collate_fn,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=config.BATCH_SIZE, 
    collate_fn=advanced_collate_fn,
    num_workers=2,
    pin_memory=True
)

In [11]:
class Attention(nn.Module):
    """Combined attention mechanism with multiple scoring functions"""
    def __init__(self, hidden_size, method='general'):
        super().__init__()
        self.method = method
        self.hidden_size = hidden_size
        
        if method == 'general':
            self.attn = nn.Linear(hidden_size, hidden_size)
        elif method == 'concat':
            self.attn = nn.Linear(hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(1, hidden_size))
        elif method == 'dot':
            pass  # No parameters needed
        else:
            raise ValueError(f'Unknown attention method: {method}')
            
        self.reset_parameters()
    
    def reset_parameters(self):
        if hasattr(self, 'attn'):
            nn.init.xavier_uniform_(self.attn.weight)
        if hasattr(self, 'v'):
            nn.init.xavier_uniform_(self.v)
    
    def forward(self, hidden, encoder_outputs, mask=None):
        """
        hidden: (batch_size, hidden_size * num_directions)
        encoder_outputs: (batch_size, seq_len, hidden_size * num_directions)
        mask: (batch_size, seq_len)
        """
        batch_size, seq_len = encoder_outputs.shape[0], encoder_outputs.shape[1]
        
        # Calculate attention energies
        if self.method == 'dot':
            # Dot product attention
            hidden = hidden.unsqueeze(2)  # (batch, hidden, 1)
            energy = torch.bmm(encoder_outputs, hidden).squeeze(2)  # (batch, seq_len)
            
        elif self.method == 'general':
            # General attention
            energy = self.attn(encoder_outputs)  # (batch, seq_len, hidden)
            hidden = hidden.unsqueeze(2)  # (batch, hidden, 1)
            energy = torch.bmm(energy, hidden).squeeze(2)  # (batch, seq_len)
            
        else:  # concat
            # Concatenation attention
            hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)  # (batch, seq_len, hidden)
            energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
            energy = energy @ self.v.t()  # (batch, seq_len, 1)
            energy = energy.squeeze(2)  # (batch, seq_len)
        
        # Apply mask if provided
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        # Normalize
        attention_weights = torch.softmax(energy, dim=1)
        
        # Apply attention
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        
        return context, attention_weights

In [12]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout, bidirectional):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=PAD_IDX)
        self.embedding_dropout = nn.Dropout(dropout)
        
        # LSTM
        self.rnn = nn.LSTM(
            embed_size,
            hidden_size,
            num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True
        )
        
        # Layer normalization
        self.layer_norm = nn.LayerNorm(hidden_size * self.num_directions)
        
        # Initialize weights
        self.init_weights()
    
    def init_weights(self):
        for name, param in self.rnn.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)
                # Set forget gate bias to 1
                n = param.size(0)
                param.data[n//4:n//2].fill_(1.0)
    
    def forward(self, src, src_len):
        """
        src: (batch_size, seq_len)
        src_len: (batch_size)
        """
        # Embedding
        embedded = self.embedding(src)
        embedded = self.embedding_dropout(embedded)
        
        # Pack sequences
        packed_embedded = pack_padded_sequence(
            embedded, 
            src_len.cpu(), 
            batch_first=True, 
            enforce_sorted=False
        )
        
        # RNN forward
        packed_outputs, (hidden, cell) = self.rnn(packed_embedded)
        
        # Unpack
        outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True)
        
        # Apply layer norm
        outputs = self.layer_norm(outputs)
        
        # Handle bidirectional hidden states
        if self.bidirectional:
            # Reshape hidden: (num_layers * num_directions, batch, hidden_size)
            # to: (num_layers, batch, hidden_size * num_directions)
            hidden = hidden.view(self.num_layers, self.num_directions, 
                               hidden.size(1), self.hidden_size)
            hidden = hidden.transpose(1, 2).contiguous()
            hidden = hidden.view(self.num_layers, hidden.size(2), 
                               self.hidden_size * self.num_directions)
            
            cell = cell.view(self.num_layers, self.num_directions,
                           cell.size(1), self.hidden_size)
            cell = cell.transpose(1, 2).contiguous()
            cell = cell.view(self.num_layers, cell.size(2), 
                           self.hidden_size * self.num_directions)
        
        return outputs, hidden, cell


In [13]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout, attention_method='general'):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=PAD_IDX)
        self.embedding_dropout = nn.Dropout(dropout)
        
        # Attention
        self.attention = Attention(hidden_size, method=attention_method)
        
        # Input feeding: combine embedding with previous attention context
        self.input_projection = nn.Linear(embed_size + hidden_size, embed_size)
        
        # LSTM
        self.rnn = nn.LSTM(
            embed_size,
            hidden_size,
            num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        
        # Output projection
        self.output_projection = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, vocab_size)
        )
        
        # Layer normalization
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size * 2)
        
        # Initialize weights
        self.init_weights()
    
    def init_weights(self):
        for name, param in self.rnn.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
        
        # Initialize output projection
        for layer in self.output_projection:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
    
    def forward(self, x, hidden, cell, encoder_outputs, mask=None, prev_context=None):
        """
        x: (batch_size)
        hidden: (num_layers, batch_size, hidden_size)
        cell: (num_layers, batch_size, hidden_size)
        encoder_outputs: (batch_size, seq_len, hidden_size)
        prev_context: (batch_size, hidden_size)
        """
        batch_size = x.size(0)
        
        # Embedding
        embedded = self.embedding(x.unsqueeze(1))  # (batch, 1, embed)
        embedded = self.embedding_dropout(embedded)
        
        # Input feeding: combine with previous context
        if prev_context is not None:
            embedded = torch.cat((embedded.squeeze(1), prev_context), dim=1)  # (batch, embed + hidden)
            embedded = self.input_projection(embedded)  # (batch, embed)
            embedded = embedded.unsqueeze(1)  # (batch, 1, embed)
        else:
            embedded = embedded
        
        # RNN forward
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        output = output.squeeze(1)  # (batch, hidden)
        output = self.layer_norm1(output)
        
        # Attention
        context, attention_weights = self.attention(output, encoder_outputs, mask)
        
        # Combine output and context
        combined = torch.cat((output, context), dim=1)  # (batch, hidden * 2)
        combined = self.layer_norm2(combined)
        
        # Generate prediction
        prediction = self.output_projection(combined)  # (batch, vocab_size)
        
        return prediction, hidden, cell, context, attention_weights


In [14]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, config):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.config = config
        
    def forward(self, src, trg, src_len, trg_len, teacher_forcing_ratio=0.5):
        """
        src: (batch_size, src_seq_len)
        trg: (batch_size, trg_seq_len)
        src_len: (batch_size)
        trg_len: (batch_size)
        """
        batch_size = src.size(0)
        trg_seq_len = trg.size(1)
        vocab_size = self.decoder.vocab_size
        
        # Create mask for encoder outputs
        mask = (src != PAD_IDX).float()
        
        # Encoder
        encoder_outputs, hidden, cell = self.encoder(src, src_len)
        
        # Initialize decoder
        decoder_input = trg[:, 0]  # SOS token
        decoder_context = None
        
        # Store outputs
        outputs = torch.zeros(batch_size, trg_seq_len, vocab_size).to(self.device)
        attentions = torch.zeros(batch_size, trg_seq_len, src.size(1)).to(self.device)
        
        for t in range(1, trg_seq_len):
            decoder_output, hidden, cell, decoder_context, attention = self.decoder(
                decoder_input, hidden, cell, encoder_outputs, mask, decoder_context
            )
            
            outputs[:, t] = decoder_output
            attentions[:, t] = attention
            
            # Teacher forcing
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = decoder_output.argmax(1)
            decoder_input = trg[:, t] if teacher_force else top1
        
        return outputs, attentions


In [15]:
def train_epoch(model, loader, optimizer, criterion, config, epoch):
    model.train()
    total_loss = 0
    total_tokens = 0
    
    # Decay teacher forcing
    teacher_forcing = config.TEACHER_FORCING_RATIO * (config.TEACHER_FORCING_DECAY ** epoch)
    
    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]")
    for batch in progress_bar:
        # Move to device
        src = batch["src"].to(config.DEVICE)
        trg_in = batch["trg_in"].to(config.DEVICE)
        trg_out = batch["trg_out"].to(config.DEVICE)
        src_len = batch["src_len"].to(config.DEVICE)
        trg_len = batch["trg_len"].to(config.DEVICE)
        
        optimizer.zero_grad()
        
        # Forward pass
        outputs, _ = model(src, trg_in, src_len, trg_len, teacher_forcing)
        
        # Reshape for loss calculation
        outputs = outputs[:, 1:].reshape(-1, outputs.shape[-1])
        trg_out = trg_out[:, 1:].reshape(-1)
        
        # Calculate loss
        loss = criterion(outputs, trg_out)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.CLIP)
        
        optimizer.step()
        
        # Update metrics
        non_pad = (trg_out != PAD_IDX).sum().item()
        total_loss += loss.item() * non_pad
        total_tokens += non_pad
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'teacher': f"{teacher_forcing:.2f}"
        })
    
    return total_loss / total_tokens if total_tokens > 0 else 0


In [16]:
def validate_epoch(model, loader, criterion, config):
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        progress_bar = tqdm(loader, desc="[Validate]")
        for batch in progress_bar:
            src = batch["src"].to(config.DEVICE)
            trg_in = batch["trg_in"].to(config.DEVICE)
            trg_out = batch["trg_out"].to(config.DEVICE)
            src_len = batch["src_len"].to(config.DEVICE)
            trg_len = batch["trg_len"].to(config.DEVICE)
            
            # Forward pass (no teacher forcing)
            outputs, _ = model(src, trg_in, src_len, trg_len, teacher_forcing_ratio=0)
            
            outputs = outputs[:, 1:].reshape(-1, outputs.shape[-1])
            trg_out = trg_out[:, 1:].reshape(-1)
            
            loss = criterion(outputs, trg_out)
            
            non_pad = (trg_out != PAD_IDX).sum().item()
            total_loss += loss.item() * non_pad
            total_tokens += non_pad
            
            progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    return total_loss / total_tokens if total_tokens > 0 else 0


In [53]:
def beam_search_decode(model, src, src_len, beam_size=5, max_len=100, length_penalty=1.0):
    """Beam search decoding for batch_first=True encoder and decoder"""
    with torch.no_grad():
        # src shape: (batch_size=1, seq_len)
        # src_len shape: (1,)
        
        # Encode source
        encoder_outputs, hidden, cell = model.encoder(src, src_len)
        # encoder_outputs: (1, seq_len, hidden_size)
        # hidden: (num_layers, 1, hidden_size)
        # cell: (num_layers, 1, hidden_size)
        
        # Initialize beam with start token
        start_token = torch.tensor([SOS_IDX]).to(config.DEVICE)
        
        # Initial hypotheses: (sequence, log_prob, hidden, cell, context)
        hypotheses = [(start_token, 0.0, hidden, cell, None)]
        
        # Create source mask
        mask = (src != PAD_IDX).float()  # (1, seq_len)
        
        # Beam search
        for step in range(max_len):
            all_candidates = []
            
            for seq, score, h, c, ctx in hypotheses:
                # Check if sequence ended
                if seq[-1].item() == EOS_IDX:
                    all_candidates.append((seq, score, h, c, ctx))
                    continue
                
                # Prepare decoder input - current token
                x = seq[-1].unsqueeze(0)  # (1,)
                
                # Decode one step
                output, h_new, c_new, ctx_new, _ = model.decoder(
                    x,  # (1,)
                    h,  # (num_layers, 1, hidden_size)
                    c,  # (num_layers, 1, hidden_size)
                    encoder_outputs,  # (1, seq_len, hidden_size)
                    mask,  # (1, seq_len)
                    ctx  # (1, hidden_size) or None
                )
                # output: (1, vocab_size)
                
                # Get top-k tokens
                log_probs = torch.log_softmax(output, dim=-1)  # (1, vocab_size)
                topk_log_probs, topk_tokens = log_probs.topk(min(beam_size, log_probs.size(-1)))
                
                for k in range(beam_size):
                    token = topk_tokens[0, k].unsqueeze(0)  # (1,)
                    log_prob = topk_log_probs[0, k].item()
                    new_seq = torch.cat([seq, token])
                    
                    # Calculate score with length penalty
                    if length_penalty > 0:
                        # Normalize by sequence length
                        new_score = (score * step + log_prob) / (step + 1)
                    else:
                        new_score = score + log_prob
                    
                    all_candidates.append((new_seq, new_score, h_new, c_new, ctx_new))
            
            # Select top-k hypotheses
            all_candidates.sort(key=lambda x: x[1], reverse=True)
            hypotheses = all_candidates[:beam_size]
            
            # Early stopping if all have EOS
            if all(h[0][-1].item() == EOS_IDX for h in hypotheses):
                break
        
        # Return best sequence (excluding SOS)
        best_seq = hypotheses[0][0]
        return best_seq[1:] if len(best_seq) > 1 else best_seq

In [54]:
def generate_code(model, docstring, use_beam_search=True):
    """Generate code with optional beam search"""
    model.eval()
    
    # Encode docstring
    tokens = encode(docstring)[:config.MAX_DOC_LEN]
    if len(tokens) == 0:
        tokens = [PAD_IDX]
    
    # Convert to tensor - for batch_first=True encoder
    src = torch.tensor(tokens, dtype=torch.long).to(config.DEVICE)
    
    # Correct shapes for encoder (batch_first=True)
    src = src.unsqueeze(0)  # (batch_size=1, seq_len)
    src_len = torch.tensor([src.size(1)], dtype=torch.long).to(config.DEVICE)
    
    if use_beam_search:
        # Beam search decoding
        best_seq = beam_search_decode(
            model,
            src,  # (1, seq_len)
            src_len,  # (1,)
            beam_size=config.BEAM_SIZE,
            max_len=config.MAX_CODE_LEN,
            length_penalty=config.LENGTH_PENALTY
        )
        
        # Convert to list of tokens
        if isinstance(best_seq, torch.Tensor):
            outputs = best_seq.tolist()
        else:
            outputs = best_seq
    
    else:
        # Greedy decoding
        with torch.no_grad():
            encoder_outputs, hidden, cell = model.encoder(src, src_len)
            
            x = torch.tensor([SOS_IDX]).to(config.DEVICE)
            outputs = []
            context = None
            
            # Create mask
            mask = (src != PAD_IDX).float()  # (1, seq_len)
            
            for _ in range(config.MAX_CODE_LEN):
                output, hidden, cell, context, _ = model.decoder(
                    x,  # (1,)
                    hidden,  # (num_layers, 1, hidden_size)
                    cell,   # (num_layers, 1, hidden_size)
                    encoder_outputs,  # (1, seq_len, hidden_size)
                    mask,   # (1, seq_len)
                    context  # (1, hidden_size) or None
                )
                
                top1 = output.argmax(1).item()
                
                if top1 == EOS_IDX:
                    break
                    
                outputs.append(top1)
                x = torch.tensor([top1]).to(config.DEVICE)
    
    return decode(outputs)

In [57]:
def calculate_bleu(model, test_data, n_samples=200, use_beam_search=True):
    """Calculate BLEU score with proper tokenization"""
    from nltk.translate.bleu_score import SmoothingFunction
    import numpy as np
    from tqdm import tqdm
    
    smooth = SmoothingFunction().method4
    scores = []
    
    for i in tqdm(range(min(n_samples, len(test_data))), desc="Calculating BLEU"):
        ref = test_data[i]["code"]
        doc = test_data[i]["docstring"]
        
        if not ref or not doc:
            continue
            
        pred = generate_code(model, doc, use_beam_search)
        
        # Proper tokenization for BLEU
        ref_tokens = tokenizer.encode(ref).tokens
        pred_tokens = tokenizer.encode(pred).tokens
        
        if len(ref_tokens) > 0 and len(pred_tokens) > 0:
            try:
                score = sentence_bleu([ref_tokens], pred_tokens, 
                                    smoothing_function=smooth)
                scores.append(score)
            except:
                continue
    
    mean_score = np.mean(scores) if scores else 0.0
    return mean_score

In [20]:
def calculate_exact_match(model, test_data, n_samples=200):
    """Calculate exact match accuracy"""
    matches = 0
    total = 0
    
    for i in tqdm(range(min(n_samples, len(test_data))), desc="Calculating EM"):
        ref = test_data[i]["code"].strip()
        doc = test_data[i]["docstring"]
        
        if not ref or not doc:
            continue
            
        pred = generate_code(model, doc, use_beam_search=True)
        
        # Normalize for comparison
        ref_norm = ' '.join(ref.split())
        pred_norm = ' '.join(pred.split())
        
        if ref_norm == pred_norm:
            matches += 1
        total += 1
    
    return matches / total if total > 0 else 0


In [21]:
print("\nüöÄ Initializing model...")

encoder = Encoder(
    vocab_size=VOCAB_SIZE,
    embed_size=config.EMBED_SIZE,
    hidden_size=config.HIDDEN_SIZE,
    num_layers=config.NUM_LAYERS,
    dropout=config.DROPOUT,
    bidirectional=config.BIDIRECTIONAL
).to(config.DEVICE)

decoder = Decoder(
    vocab_size=VOCAB_SIZE,
    embed_size=config.EMBED_SIZE,
    hidden_size=config.HIDDEN_SIZE * (2 if config.BIDIRECTIONAL else 1),
    num_layers=config.NUM_LAYERS,
    dropout=config.DROPOUT,
    attention_method='general'
).to(config.DEVICE)

model = Seq2Seq(encoder, decoder, config.DEVICE, config).to(config.DEVICE)

# parameter count
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"üìä Total parameters: {total_params:,}")
print(f"üìä Trainable parameters: {trainable_params:,}")

# optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.LR,
    weight_decay=config.WEIGHT_DECAY
)

# scheduler (FIXED)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config.LR_SCHEDULER_FACTOR,
    patience=config.LR_SCHEDULER_PATIENCE
)

# loss
criterion = nn.CrossEntropyLoss(
    ignore_index=PAD_IDX,
    label_smoothing=0.1
)


üöÄ Initializing model...
üìä Total parameters: 10,125,840
üìä Trainable parameters: 10,125,840


In [22]:
# Training loop
print("\nüéØ Starting training...")
best_val_loss = float('inf')
patience_counter = 0
early_stopping_patience = 5

for epoch in range(config.EPOCHS):
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, config, epoch)
    
    # Validate
    val_loss = validate_epoch(model, val_loader, criterion, config)
    
    # Update learning rate
    scheduler.step(val_loss)
    
    print(f"\nüìà Epoch {epoch+1}/{config.EPOCHS}")
    print(f"   Train Loss: {train_loss:.4f}")
    print(f"   Val Loss: {val_loss:.4f}")
    print(f"   LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'config': config
        }, "best_model_attention.pt")
        print(f"‚úÖ Saved best model with val loss: {val_loss:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("üõë Early stopping triggered")
            break


üéØ Starting training...


Epoch 1 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:46<00:00,  2.00it/s, loss=7.0653, teacher=0.70]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.43it/s, loss=7.1701]



üìà Epoch 1/20
   Train Loss: 7.0944
   Val Loss: 7.2440
   LR: 0.001000
‚úÖ Saved best model with val loss: 7.2440


Epoch 2 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:44<00:00,  2.04it/s, loss=6.5391, teacher=0.66]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.32it/s, loss=7.3326]



üìà Epoch 2/20
   Train Loss: 6.6521
   Val Loss: 7.3503
   LR: 0.001000


Epoch 3 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:43<00:00,  2.05it/s, loss=6.4853, teacher=0.63]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.43it/s, loss=7.0091]



üìà Epoch 3/20
   Train Loss: 6.3720
   Val Loss: 6.9691
   LR: 0.001000
‚úÖ Saved best model with val loss: 6.9691


Epoch 4 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:43<00:00,  2.06it/s, loss=5.8394, teacher=0.60]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.49it/s, loss=7.2262]



üìà Epoch 4/20
   Train Loss: 6.1826
   Val Loss: 7.0533
   LR: 0.001000


Epoch 5 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.08it/s, loss=6.1380, teacher=0.57]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.46it/s, loss=6.9704]



üìà Epoch 5/20
   Train Loss: 6.0580
   Val Loss: 6.9469
   LR: 0.001000
‚úÖ Saved best model with val loss: 6.9469


Epoch 6 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.07it/s, loss=6.0492, teacher=0.54]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.36it/s, loss=6.9369]



üìà Epoch 6/20
   Train Loss: 5.9746
   Val Loss: 7.0144
   LR: 0.001000


Epoch 7 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.08it/s, loss=5.8287, teacher=0.51]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.44it/s, loss=6.9805]



üìà Epoch 7/20
   Train Loss: 5.8998
   Val Loss: 6.8058
   LR: 0.001000
‚úÖ Saved best model with val loss: 6.8058


Epoch 8 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.08it/s, loss=5.9107, teacher=0.49]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.49it/s, loss=6.7529]



üìà Epoch 8/20
   Train Loss: 5.8348
   Val Loss: 6.6938
   LR: 0.001000
‚úÖ Saved best model with val loss: 6.6938


Epoch 9 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.08it/s, loss=5.8613, teacher=0.46]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.51it/s, loss=6.8484]



üìà Epoch 9/20
   Train Loss: 5.7857
   Val Loss: 6.6710
   LR: 0.001000
‚úÖ Saved best model with val loss: 6.6710


Epoch 10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:41<00:00,  2.10it/s, loss=5.8987, teacher=0.44]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.50it/s, loss=6.7190]



üìà Epoch 10/20
   Train Loss: 5.7403
   Val Loss: 6.6176
   LR: 0.001000
‚úÖ Saved best model with val loss: 6.6176


Epoch 11 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.09it/s, loss=5.5091, teacher=0.42]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.49it/s, loss=6.6606]



üìà Epoch 11/20
   Train Loss: 5.7004
   Val Loss: 6.6184
   LR: 0.001000


Epoch 12 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:41<00:00,  2.09it/s, loss=5.6428, teacher=0.40]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.48it/s, loss=6.6276]



üìà Epoch 12/20
   Train Loss: 5.6646
   Val Loss: 6.4917
   LR: 0.001000
‚úÖ Saved best model with val loss: 6.4917


Epoch 13 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.09it/s, loss=6.0645, teacher=0.38]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.49it/s, loss=6.6706]



üìà Epoch 13/20
   Train Loss: 5.6378
   Val Loss: 6.5983
   LR: 0.001000


Epoch 14 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:41<00:00,  2.09it/s, loss=5.8301, teacher=0.36]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.46it/s, loss=6.6591]



üìà Epoch 14/20
   Train Loss: 5.6247
   Val Loss: 6.5397
   LR: 0.001000


Epoch 15 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:41<00:00,  2.10it/s, loss=5.8361, teacher=0.34]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.49it/s, loss=6.5196]



üìà Epoch 15/20
   Train Loss: 5.6048
   Val Loss: 6.4933
   LR: 0.000500


Epoch 16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:41<00:00,  2.09it/s, loss=5.4131, teacher=0.32]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.48it/s, loss=6.5524]



üìà Epoch 16/20
   Train Loss: 5.5387
   Val Loss: 6.4384
   LR: 0.000500
‚úÖ Saved best model with val loss: 6.4384


Epoch 17 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:41<00:00,  2.10it/s, loss=5.8234, teacher=0.31]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.49it/s, loss=6.5286]



üìà Epoch 17/20
   Train Loss: 5.5414
   Val Loss: 6.4728
   LR: 0.000500


Epoch 18 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:41<00:00,  2.09it/s, loss=5.5788, teacher=0.29]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.46it/s, loss=6.5509]



üìà Epoch 18/20
   Train Loss: 5.5351
   Val Loss: 6.3695
   LR: 0.000500
‚úÖ Saved best model with val loss: 6.3695


Epoch 19 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:43<00:00,  2.06it/s, loss=5.2958, teacher=0.28]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.48it/s, loss=6.4334]



üìà Epoch 19/20
   Train Loss: 5.5170
   Val Loss: 6.3651
   LR: 0.000500
‚úÖ Saved best model with val loss: 6.3651


Epoch 20 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [01:42<00:00,  2.09it/s, loss=5.6870, teacher=0.26]
[Validate]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.41it/s, loss=6.5123]



üìà Epoch 20/20
   Train Loss: 5.5107
   Val Loss: 6.3517
   LR: 0.000500
‚úÖ Saved best model with val loss: 6.3517


In [58]:
print("\nFinal Evaluation")
print("=" * 50)

# Load best model
checkpoint = torch.load("best_model_attention.pt", weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"Model loaded from best_model_attention.pt")
print(f"Epoch: {checkpoint.get('epoch', 'N/A')}")
print(f"Loss: {checkpoint.get('loss', 'N/A'):.4f}" if 'loss' in checkpoint else "")

# Calculate metrics
print("\nCalculating BLEU score with greedy decoding...")
bleu_score_greedy = calculate_bleu(model, test_data, n_samples=200, use_beam_search=False)
print(f"\n{'='*50}")
print(f"MEAN BLEU SCORE (Greedy): {bleu_score_greedy:.4f}")
print(f"{'='*50}")

# Optionally calculate with beam search as well
print("\nCalculating BLEU score with beam search...")
bleu_score_beam = calculate_bleu(model, test_data, n_samples=100, use_beam_search=True)
print(f"\n{'='*50}")
print(f"MEAN BLEU SCORE (Beam Search): {bleu_score_beam:.4f}")
print(f"{'='*50}")

# Summary
print("\n" + "="*50)
print("FINAL RESULTS SUMMARY")
print("="*50)
print(f"Greedy Decoding BLEU: {bleu_score_greedy:.4f}")
print(f"Beam Search BLEU:     {bleu_score_beam:.4f}")
print(f"Best Method: {'Beam Search' if bleu_score_beam > bleu_score_greedy else 'Greedy'}")
print("="*50)


Final Evaluation
Model loaded from best_model_attention.pt
Epoch: 19


Calculating BLEU score with greedy decoding...


Calculating BLEU: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200/200 [00:20<00:00,  9.66it/s]



MEAN BLEU SCORE (Greedy): 0.0107

Calculating BLEU score with beam search...


Calculating BLEU: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [01:08<00:00,  1.47it/s]


MEAN BLEU SCORE (Beam Search): 0.0131

FINAL RESULTS SUMMARY
Greedy Decoding BLEU: 0.0107
Beam Search BLEU:     0.0131
Best Method: Beam Search



