<a href="https://colab.research.google.com/github/SanjanaRitika/TextToCode_seq2seq/blob/main/text_to_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets tqdm nltk seaborn

import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu
import nltk
# NOTE(review): no NLTK tokenizer is called in the visible cells (BLEU below is
# computed on pre-tokenized lists), so this download may be removable.
nltk.download('punkt')

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
# configs — every tunable knob for data, model and training lives in this cell

# data: tokens kept per sequence (before <sos>/<eos> are added) and batch size
MAX_SRC_LEN = 50   # docstring tokens
MAX_TGT_LEN = 80   # code tokens
BATCH_SIZE = 64

# model
EMBED_DIM, HIDDEN_DIM = 256, 256
DROPOUT = 0.3

# optimisation
LEARNING_RATE = 1e-3
TEACHER_FORCING_RATIO = 0.5
EPOCHS = 15
GRAD_CLIP = 1  # max global gradient norm handed to clip_grad_norm_
SEED = 42

# list position becomes the token id, so "<pad>" is always id 0
SPECIAL_TOKENS = ["<pad>", "<sos>", "<eos>", "<unk>"]

# seed torch (weight init, DataLoader shuffling), numpy and Python's `random`
# (used for the teacher-forcing coin flips)
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
# load dataset: shuffle once with the global seed, keep a 10k subset (reduce for Colab),
# then carve it into contiguous 8000 / 1000 / 1000 train / val / test slices
ds = load_dataset("Nan-Do/code-search-net-python")
full_data = ds["train"].shuffle(seed=SEED).select(range(10000))

train_data, val_data, test_data = (
    full_data.select(range(start, stop))
    for start, stop in ((0, 8000), (8000, 9000), (9000, 10000))
)

print(len(train_data), len(val_data), len(test_data))

8000 1000 1000


In [None]:
# tokenization and vocab
import re
from collections import Counter

def tokenize(text):
    """Split ``text`` into runs of word characters and single punctuation symbols.

    All whitespace is dropped, including newlines and Python indentation.
    """
    word_or_symbol = r"\w+|[^\w\s]"
    return [match.group(0) for match in re.finditer(word_or_symbol, text)]

def build_vocab(data, field, max_vocab_size=20000):
    """Build (stoi, itos) lookup tables from ``item[field]`` of every record in ``data``.

    Ids 0..len(SPECIAL_TOKENS)-1 are the special tokens in list order. They are
    followed by up to ``max_vocab_size`` corpus tokens, most frequent first. The
    total size can therefore be ``max_vocab_size + len(SPECIAL_TOKENS)``.
    """
    frequencies = Counter(tok for record in data for tok in tokenize(record[field]))
    ordered = SPECIAL_TOKENS + [tok for tok, _count in frequencies.most_common(max_vocab_size)]
    stoi = {tok: idx for idx, tok in enumerate(ordered)}
    # inverse built from stoi (not from `ordered`) so both tables always agree
    itos = {idx: tok for tok, idx in stoi.items()}
    return stoi, itos

# Vocabularies come from the training split only; unseen val/test tokens become <unk>.
(src_stoi, src_itos), (tgt_stoi, tgt_itos) = (
    build_vocab(train_data, field) for field in ("docstring", "code")
)

print("Source vocab:", len(src_stoi), "Target vocab:", len(tgt_stoi))

Source vocab: 20004 Target vocab: 20004


In [None]:
# dataset and dataloader
class CodeDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(src_ids, tgt_ids)`` tensors for one record.

    The source is the tokenized ``"docstring"`` field and the target is the
    tokenized ``"code"`` field. Both are framed with <sos>/<eos> and right-padded
    with <pad> to a fixed ``max_len + 2``. Uses the notebook-level vocabularies
    (``src_stoi``/``tgt_stoi``) and length limits.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def encode(self, tokens, stoi, max_len):
        """Truncate ``tokens`` to ``max_len``, map to ids (<unk> for OOV), frame and pad.

        Always returns a 1-D integer tensor of length ``max_len + 2``.
        """
        body = [stoi.get(tok, stoi["<unk>"]) for tok in tokens[:max_len]]
        framed = [stoi["<sos>"], *body, stoi["<eos>"]]
        padding = [stoi["<pad>"]] * (max_len + 2 - len(framed))
        return torch.tensor(framed + padding)

    def __getitem__(self, idx):
        record = self.data[idx]
        return (
            self.encode(tokenize(record["docstring"]), src_stoi, MAX_SRC_LEN),
            self.encode(tokenize(record["code"]), tgt_stoi, MAX_TGT_LEN),
        )

# Only the training split is shuffled; val/test keep dataset order (DataLoader's default).
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(CodeDataset(split), batch_size=BATCH_SIZE, shuffle=is_train)
    for split, is_train in ((train_data, True), (val_data, False), (test_data, False))
)

In [None]:
# vanillaRNN
class EncoderRNN(nn.Module):
    """Vanilla (Elman) RNN encoder: token ids -> final hidden state ``h_n``.

    NOTE(review): padded positions are fed through the RNN too, so ``h_n`` is the
    state after reading the trailing <pad> tokens.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        # attribute names double as state_dict keys for saved checkpoints
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        """Return only ``h_n`` (the per-step outputs are discarded)."""
        outputs_and_state = self.rnn(self.embedding(src))
        return outputs_and_state[1]

class DecoderRNN(nn.Module):
    """One-step vanilla-RNN decoder: (token, hidden) -> (vocab logits, new hidden)."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)  # hidden -> vocabulary logits

    def forward(self, input_token, hidden):
        """Advance one step.

        ``input_token`` is presumably (batch, 1) ids, as passed by the decoding loop.
        Returns logits of shape (batch, 1, vocab) and the updated hidden state.
        """
        rnn_out, next_hidden = self.rnn(self.embedding(input_token), hidden)
        return self.fc(rnn_out), next_hidden

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper for the attention-free models.

    The encoder's return value is handed to the decoder unchanged as its initial
    state. Whatever state the decoder returns is fed back on the next step.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
        """Decode ``tgt.shape[1] - 1`` steps and return per-step vocabulary logits.

        Args:
            src: source token ids, presumably (batch, src_len) for the batch_first encoders.
            tgt: target token ids (batch, tgt_len); ``tgt[:, 0]`` seeds decoding.
            teacher_forcing_ratio: probability, drawn once per time step for the
                whole batch, of feeding the gold token instead of the prediction.

        Returns:
            (batch, tgt_len, vocab) logits. ``outputs[:, 0]`` is left as zeros.
        """
        batch_size, tgt_len = tgt.shape[0], tgt.shape[1]
        vocab_size = self.decoder.fc.out_features
        # Allocate directly on the inputs' device. Building this (batch, tgt_len, vocab)
        # buffer on the CPU and calling .to(device) costs a large host allocation plus
        # a host->device copy on every forward pass, and ignores where the model lives.
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=tgt.device)
        hidden = self.encoder(src)
        input_token = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            output, hidden = self.decoder(input_token, hidden)
            outputs[:, t] = output.squeeze(1)
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(2)
            input_token = tgt[:, t].unsqueeze(1) if teacher_force else top1

        return outputs

In [None]:
# LSTM
class EncoderLSTM(nn.Module):
    """Unidirectional LSTM encoder that returns only the final hidden state ``h_n``.

    The final cell state ``c_n`` is discarded, so a decoder seeded from this
    encoder starts without the encoder's cell memory.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        # attribute names double as state_dict keys for saved checkpoints
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        """Encode token ids; returns ``h_n`` ((1, batch, hidden_dim) for batched input)."""
        _all_steps, (h_n, _c_n) = self.rnn(self.embedding(src))
        return h_n

class DecoderLSTM(nn.Module):
    """One-step LSTM decoder: (token, state) -> (vocab logits, new state).

    ``hidden`` may be either a bare ``h`` tensor or an ``(h, c)`` tuple from a
    previous call. A bare tensor is what the first step receives, because the
    LSTM encoder exposes no cell state. The returned state is always an
    ``(h, c)`` tuple, so a caller that feeds it back keeps the LSTM's cell
    memory across decoding steps.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_token, hidden):
        """Advance one step.

        Args:
            input_token: token ids, presumably (batch, 1) as passed by the decoding loop.
            hidden: ``h`` of shape (1, batch, hidden_dim), or an ``(h, c)`` tuple.

        Returns:
            (logits, (h, c)), where logits has shape (batch, 1, vocab).
        """
        if isinstance(hidden, tuple):
            h, c = hidden
        else:
            # No cell state yet: start it at zero. Previously this reset happened on
            # every step, wiping the LSTM's memory between tokens.
            h, c = hidden, torch.zeros_like(hidden)
        embedded = self.embedding(input_token)
        output, (h, c) = self.rnn(embedded, (h, c))
        prediction = self.fc(output)
        return prediction, (h, c)

In [None]:
# LSTM + attention
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over bidirectional encoder outputs.

    Scores each source position as ``v^T tanh(W [h_dec; enc_out])`` and
    softmax-normalises the scores over the source length.
    """

    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super().__init__()
        # encoder outputs carry 2*enc_hidden_dim features (both directions)
        self.attn = nn.Linear(enc_hidden_dim*2 + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Linear(dec_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        """Return attention weights of shape (batch, src_len); each row sums to 1.

        Args:
            hidden: decoder state, expected (1, batch, dec_hidden_dim) (single layer).
            encoder_outputs: (batch, src_len, 2*enc_hidden_dim).
            mask: optional (batch, src_len). Zero/False marks positions that must
                receive no attention, such as <pad>. ``None`` attends everywhere,
                which was the only behaviour before this parameter existed. A row
                that is fully masked degrades to uniform weights instead of NaN.
        """
        src_len = encoder_outputs.shape[1]
        hidden = hidden.permute(1, 0, 2).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        scores = self.v(energy).squeeze(2)
        if mask is not None:
            # most-negative finite value (not -inf) keeps fully masked rows NaN-free
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=1)

class EncoderBiLSTM(nn.Module):
    """Bidirectional LSTM encoder for the attention model.

    Returns ``(outputs, hidden)``. ``outputs`` is the per-position encoding,
    (batch, src_len, 2*hidden_dim). ``hidden`` is tanh of the concatenated final
    forward and backward hidden states, shaped (1, batch, 2*hidden_dim) to seed
    the decoder. The final cell state is discarded.

    NOTE(review): sequences are not packed, so the padded tail is also consumed
    and the backward direction reads the <pad> run first.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, dropout=DROPOUT):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.hidden_dim = hidden_dim

    def forward(self, src):
        token_vectors = self.dropout(self.embedding(src))
        outputs, (h_n, _c_n) = self.rnn(token_vectors)
        # last two rows of h_n: final forward / backward states of the (only) layer
        final_forward, final_backward = h_n[-2], h_n[-1]
        merged = torch.cat((final_forward, final_backward), dim=1)
        return outputs, torch.tanh(merged).unsqueeze(0)

class DecoderLSTMWithAttention(nn.Module):
    """One-step LSTM decoder with Bahdanau (additive) attention.

    Each step works as follows:
    - attend over ``encoder_outputs`` with the current hidden state;
    - feed [embedding; context] to the LSTM;
    - predict from [lstm output; context; embedding].

    ``hidden`` may be a bare ``h`` tensor (first step, straight from the encoder)
    or the ``(h, c)`` tuple returned by a previous call. The returned state is
    always ``(h, c)``, so the LSTM cell state survives across decoding steps.
    """

    def __init__(self, vocab_size, embed_dim, enc_hidden_dim, dec_hidden_dim, dropout=DROPOUT):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim + enc_hidden_dim*2, dec_hidden_dim, batch_first=True)
        self.fc = nn.Linear(enc_hidden_dim*2 + dec_hidden_dim + embed_dim, vocab_size)
        self.attention = BahdanauAttention(enc_hidden_dim, dec_hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, encoder_outputs):
        """Advance one step.

        Returns ``(logits, (h, c), attention_weights)``. Logits are
        (batch, 1, vocab) and the weights are (batch, src_len).
        """
        if isinstance(hidden, tuple):
            h, c = hidden
        else:
            # No cell state yet: start it at zero. Previously this reset happened on
            # every step, wiping the LSTM's memory between tokens.
            h, c = hidden, torch.zeros_like(hidden)
        embedded = self.dropout(self.embedding(input_token))
        a = self.attention(h, encoder_outputs).unsqueeze(1)
        weighted = torch.bmm(a, encoder_outputs)  # context vector, (batch, 1, 2*enc_hidden)
        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, (h, c) = self.rnn(rnn_input, (h, c))
        output_fc = self.fc(torch.cat((output, weighted, embedded), dim=2))
        return output_fc, (h, c), a.squeeze(1)

class Seq2SeqWithAttention(nn.Module):
    """Encoder-decoder wrapper whose decoder attends over all encoder outputs.

    The encoder must return ``(encoder_outputs, initial_state)``. The decoder's
    returned state is fed back each step, and its attention weights are ignored.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
        """Decode ``tgt.shape[1] - 1`` steps and return (batch, tgt_len, vocab) logits.

        ``tgt[:, 0]`` seeds decoding and ``outputs[:, 0]`` stays zero. The
        teacher-forcing coin is flipped once per time step for the whole batch.
        """
        batch_size, tgt_len = tgt.shape[0], tgt.shape[1]
        vocab_size = self.decoder.fc.out_features
        # Allocate on the inputs' device directly instead of creating the large
        # (batch, tgt_len, vocab) buffer on the CPU and copying it every forward pass.
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=tgt.device)
        encoder_outputs, hidden = self.encoder(src)
        input_token = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            output, hidden, _attn_weights = self.decoder(input_token, hidden, encoder_outputs)
            outputs[:, t] = output.squeeze(1)
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(2)
            input_token = tgt[:, t].unsqueeze(1) if teacher_force else top1
        return outputs

In [None]:
# training and eval functions
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    The model is called with its default teacher-forcing ratio. Gradients are
    clipped to ``GRAD_CLIP`` (global norm) before each step.
    """
    model.train()
    batch_losses = []
    for batch_src, batch_tgt in loader:
        batch_src = batch_src.to(device)
        batch_tgt = batch_tgt.to(device)
        optimizer.zero_grad()
        logits = model(batch_src, batch_tgt)
        n_classes = logits.shape[-1]
        # position 0 (<sos>) is never predicted, so it is excluded from the loss
        flat_logits = logits[:, 1:].reshape(-1, n_classes)
        flat_gold = batch_tgt[:, 1:].reshape(-1)
        loss = criterion(flat_logits, flat_gold)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)

def evaluate(model, loader, criterion):
    """Return the mean per-batch loss over ``loader`` with teacher forcing disabled.

    Runs in eval mode under ``torch.no_grad()``. Position 0 is skipped because
    the model never predicts it.
    """
    model.eval()
    running_total = 0
    with torch.no_grad():
        for batch_src, batch_tgt in loader:
            batch_src = batch_src.to(device)
            batch_tgt = batch_tgt.to(device)
            logits = model(batch_src, batch_tgt, teacher_forcing_ratio=0)
            n_classes = logits.shape[-1]
            flat_logits = logits[:, 1:].reshape(-1, n_classes)
            flat_gold = batch_tgt[:, 1:].reshape(-1)
            running_total += criterion(flat_logits, flat_gold).item()
    return running_total / len(loader)

def decode_tokens(output_tensor):
    """Return ``[index]`` of the arg-max along the last dimension of ``output_tensor``.

    ``.item()`` only succeeds when that arg-max has a single element, e.g. for a
    1-D score vector. NOTE(review): not called anywhere in the visible notebook.
    """
    best_index = output_tensor.argmax(dim=-1)
    return [best_index.cpu().item()]

def clean_token_ids(ids, itos, pad_id, sos_id, eos_id):
    """Turn a sequence of target ids into token strings for scoring.

    Stops at the first ``eos_id``, discarding everything generated after
    end-of-sequence, and skips ``pad_id`` / ``sos_id``.
    """
    tokens = []
    for raw in ids:
        idx = int(raw)
        if idx == eos_id:
            break
        if idx in (pad_id, sos_id):
            continue
        tokens.append(itos[idx])
    return tokens


def evaluate_metrics(model, loader, tgt_itos):
    """Greedy-decode ``loader`` (no teacher forcing) and score against the targets.

    Args:
        model: a seq2seq model called as ``model(src, tgt, teacher_forcing_ratio=0)``.
        loader: yields ``(src, tgt)`` id batches.
        tgt_itos: id -> token dict for the target vocabulary. It must contain
            "<pad>", "<sos>" and "<eos>".

    Returns:
        ``(token_acc, bleu, exact_match)``:
        - token_acc: fraction of non-<pad> target positions (excluding position 0)
          where the arg-max prediction equals the target id.
        - bleu: NLTK corpus BLEU of the EOS-truncated hypotheses against the references.
        - exact_match: fraction of examples whose cleaned hypothesis equals the
          cleaned reference.
        Ratios over an empty denominator are reported as 0.0.
    """
    model.eval()
    # derive special ids from the vocabulary we were given, not from a global
    tok2id = {tok: idx for idx, tok in tgt_itos.items()}
    pad_id, sos_id, eos_id = tok2id["<pad>"], tok2id["<sos>"], tok2id["<eos>"]

    total_tokens = 0
    correct_tokens = 0
    references, hypotheses = [], []
    exact_matches = 0

    with torch.no_grad():
        for src, tgt in loader:
            src, tgt = src.to(device), tgt.to(device)
            output = model(src, tgt, teacher_forcing_ratio=0)
            pred_tokens = output.argmax(-1).cpu().numpy()
            tgt_tokens = tgt.cpu().numpy()

            # Token-level accuracy (positional, over non-pad target slots)
            mask = tgt_tokens[:, 1:] != pad_id
            correct_tokens += ((pred_tokens[:, 1:] == tgt_tokens[:, 1:]) & mask).sum()
            total_tokens += mask.sum()

            # BLEU + exact match. Position 0 is skipped: it holds <sos> in the
            # target, and the model never fills it in the prediction.
            for t, p in zip(tgt_tokens, pred_tokens):
                ref = clean_token_ids(t[1:], tgt_itos, pad_id, sos_id, eos_id)
                hyp = clean_token_ids(p[1:], tgt_itos, pad_id, sos_id, eos_id)
                references.append([ref])
                hypotheses.append(hyp)
                exact_matches += int(ref == hyp)

    token_acc = correct_tokens / total_tokens if total_tokens else 0.0
    bleu = corpus_bleu(references, hypotheses)
    n_examples = len(loader.dataset)
    exact_match = exact_matches / n_examples if n_examples else 0.0
    return token_acc, bleu, exact_match

In [None]:
# train models: each architecture is trained for EPOCHS epochs. The checkpoint with
# the lowest validation loss is kept and then scored on the test split.
from tqdm.notebook import tqdm

model_configs = {
    "VanillaRNN": (EncoderRNN(len(src_stoi), EMBED_DIM, HIDDEN_DIM),
                   DecoderRNN(len(tgt_stoi), EMBED_DIM, HIDDEN_DIM)),
    "LSTM": (EncoderLSTM(len(src_stoi), EMBED_DIM, HIDDEN_DIM),
             DecoderLSTM(len(tgt_stoi), EMBED_DIM, HIDDEN_DIM)),
    # the bidirectional encoder emits 2*HIDDEN_DIM features, hence the wider decoder
    "Attention": (EncoderBiLSTM(len(src_stoi), EMBED_DIM, HIDDEN_DIM),
                  DecoderLSTMWithAttention(len(tgt_stoi), EMBED_DIM, HIDDEN_DIM, HIDDEN_DIM*2, DROPOUT))
}

results = {}

for model_name, (encoder, decoder) in model_configs.items():
    print(f"\n=== Training {model_name} ===")
    wrapper_cls = Seq2SeqWithAttention if model_name == "Attention" else Seq2Seq
    model = wrapper_cls(encoder, decoder).to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_stoi["<pad>"])
    checkpoint_path = f"{model_name}_best.pth"

    best_val_loss = float("inf")
    for epoch in range(EPOCHS):
        # ---- Training ----
        model.train()
        train_loss = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
        for n_batches, (src, tgt) in enumerate(loop, start=1):
            src, tgt = src.to(device), tgt.to(device)
            optimizer.zero_grad()
            output = model(src, tgt)
            output_dim = output.shape[-1]
            # position 0 (<sos>) is never predicted, so it is excluded from the loss
            loss = criterion(output[:,1:].reshape(-1, output_dim), tgt[:,1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()
            train_loss += loss.item()
            # running mean over batches seen so far (dividing by len(loop) under-reports it)
            loop.set_postfix(train_loss=train_loss / n_batches)

        # ---- Validation ---- (greedy decoding, no teacher forcing)
        val_loss = evaluate(model, val_loader, criterion)

        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved checkpoint!")

    # ---- Evaluation ---- (best-validation checkpoint, not the last epoch)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    token_acc, bleu, exact = evaluate_metrics(model, test_loader, tgt_itos)
    results[model_name] = {"Token Accuracy": token_acc, "BLEU": bleu, "Exact Match": exact}


=== Training VanillaRNN ===


Epoch 1/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 1/15 | Train Loss: 5.6372 | Val Loss: 5.2227
Saved checkpoint!


Epoch 2/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 2/15 | Train Loss: 4.8800 | Val Loss: 6.0615


Epoch 3/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 3/15 | Train Loss: 4.6892 | Val Loss: 5.2066
Saved checkpoint!


Epoch 4/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 4/15 | Train Loss: 4.5706 | Val Loss: 5.5046


Epoch 5/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 5/15 | Train Loss: 4.5034 | Val Loss: 5.5071


Epoch 6/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 6/15 | Train Loss: 4.4456 | Val Loss: 5.4462


Epoch 7/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 7/15 | Train Loss: 4.4127 | Val Loss: 5.5169


Epoch 8/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 8/15 | Train Loss: 4.3531 | Val Loss: 5.5126


Epoch 9/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 9/15 | Train Loss: 4.3231 | Val Loss: 5.4109


Epoch 10/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 10/15 | Train Loss: 4.2780 | Val Loss: 5.5639


Epoch 11/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 11/15 | Train Loss: 4.2527 | Val Loss: 5.3470


Epoch 12/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 12/15 | Train Loss: 4.2285 | Val Loss: 5.4767


Epoch 13/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 13/15 | Train Loss: 4.2057 | Val Loss: 5.6324


Epoch 14/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 14/15 | Train Loss: 4.2093 | Val Loss: 5.6225


Epoch 15/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 15/15 | Train Loss: 4.1650 | Val Loss: 5.4330

=== Training LSTM ===


Epoch 1/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 1/15 | Train Loss: 5.9795 | Val Loss: 5.3177
Saved checkpoint!


Epoch 2/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 2/15 | Train Loss: 5.0219 | Val Loss: 5.2127
Saved checkpoint!


Epoch 3/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 3/15 | Train Loss: 4.8323 | Val Loss: 5.2180


Epoch 4/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 4/15 | Train Loss: 4.7194 | Val Loss: 6.2190


Epoch 5/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 5/15 | Train Loss: 4.6371 | Val Loss: 5.3584


Epoch 6/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 6/15 | Train Loss: 4.5640 | Val Loss: 5.4415


Epoch 7/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 7/15 | Train Loss: 4.5072 | Val Loss: 5.4312


Epoch 8/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 8/15 | Train Loss: 4.4690 | Val Loss: 5.5333


Epoch 9/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 9/15 | Train Loss: 4.4278 | Val Loss: 5.5817


Epoch 10/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 10/15 | Train Loss: 4.3977 | Val Loss: 5.4639


Epoch 11/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 11/15 | Train Loss: 4.3463 | Val Loss: 5.4913


Epoch 12/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 12/15 | Train Loss: 4.3220 | Val Loss: 5.5710


Epoch 13/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 13/15 | Train Loss: 4.2937 | Val Loss: 5.5701


Epoch 14/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 14/15 | Train Loss: 4.2606 | Val Loss: 5.4792


Epoch 15/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 15/15 | Train Loss: 4.2463 | Val Loss: 5.5065

=== Training Attention ===


Epoch 1/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 1/15 | Train Loss: 5.4671 | Val Loss: 5.2491
Saved checkpoint!


Epoch 2/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 2/15 | Train Loss: 4.7479 | Val Loss: 5.3218


Epoch 3/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 3/15 | Train Loss: 4.4284 | Val Loss: 5.3144


Epoch 4/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 4/15 | Train Loss: 4.0786 | Val Loss: 5.4330


Epoch 5/15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch 5/15 | Train Loss: 3.7322 | Val Loss: 5.4370


Epoch 6/15:   0%|          | 0/125 [00:00<?, ?it/s]

In [None]:
print("\n=== Final Comparison ===")
# every metric is a fraction in [0, 1]; show them scaled to percent
for model_name, metrics in results.items():
    acc_pct = metrics['Token Accuracy'] * 100
    bleu_pct = metrics['BLEU'] * 100
    em_pct = metrics['Exact Match'] * 100
    print(f"{model_name}: Token Acc: {acc_pct:.2f}%, BLEU: {bleu_pct:.2f}, Exact Match: {em_pct:.2f}%")