In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import random
from tqdm import tqdm
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ✅ Load Dataset
dataset = load_dataset("google/code_x_glue_ct_code_to_text", "java")
train_data = dataset["train"]
val_data = dataset["validation"]
train_data = train_data.select(random.sample(range(len(train_data)), int(len(train_data) * 0.3)))
val_data = val_data.select(random.sample(range(len(val_data)), int(len(val_data) * 0.3)))


In [3]:

# ✅ Tokenizer setup
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '<pad>', 'sep_token': '<sep>', 'bos_token': '<s>', 'eos_token': '</s>'})


4

In [None]:
# ✅ Tokenization
def tokenize(example):
    code = example["code"]
    docstring = example["docstring"]
    input_text = f"<s> {code} <sep> {docstring} </s>"  # 👈 Still used for full tokenization
    tokenized = tokenizer(input_text, truncation=True, padding="max_length", max_length=256)

    # Create labels that mask code tokens, only docstring is target
    try:
        sep_idx = tokenized["input_ids"].index(tokenizer.convert_tokens_to_ids("<sep>"))
    except ValueError:
        sep_idx = 128  # Fallback if separator is not found
    
    # Only predict tokens AFTER the <sep>
    labels = [-100] * (sep_idx + 1) + tokenized["input_ids"][sep_idx + 1:]
    labels += [-100] * (256 - len(labels))
    tokenized["labels"] = labels[:256]

    return tokenized

train_data = train_data.map(tokenize)
val_data = val_data.map(tokenize)

Map: 100%|████████████████████████████████████████████████████████████████████████████| 49476/49476 [00:22<00:00, 2186.63 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████| 1554/1554 [00:00<00:00, 2390.76 examples/s]


In [6]:
# ✅ Dataset Class
class CodeDataset(Dataset):
    def __init__(self, hf_dataset):
        self.input_ids = torch.tensor(hf_dataset["input_ids"])
        self.labels = torch.tensor(hf_dataset["labels"])

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx]
        }

train_dataset = CodeDataset(train_data)
val_dataset = CodeDataset(val_data)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)

In [31]:

# ✅ Decoder-Only Model
class DecoderOnlyModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, num_layers, max_seq_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        batch_size, seq_length = x.shape
        positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0).expand(batch_size, -1)
        x = self.embedding(x) + self.pos_embedding(positions)

        tgt_mask = torch.triu(torch.ones(seq_length, seq_length, device=x.device), diagonal=1)
        tgt_mask = tgt_mask.masked_fill(tgt_mask == 1, float('-inf'))

        memory = torch.zeros((seq_length, batch_size, x.size(-1)), device=x.device)
        x = self.transformer_decoder(x.permute(1, 0, 2), memory, tgt_mask=tgt_mask)
        return self.fc_out(x.permute(1, 0, 2))

In [None]:
# ✅ Train Function
def train_model(name, d_model, n_layers, n_heads, epochs=5):
    print(f"Training {name} with d_model={d_model}, layers={n_layers}, heads={n_heads}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DecoderOnlyModel(len(tokenizer), d_model=d_model, n_heads=n_heads, num_layers=n_layers).to(device)
    model.embedding.weight.data.normal_(mean=0.0, std=0.02)
    model.fc_out.bias.data.zero_()

    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"{name} Epoch {epoch+1}"):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            optimizer.zero_grad()
            outputs = model(input_ids)
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} Loss: {total_loss / len(train_loader):.4f}")

    # Save model + check perplexity
    torch.save(model.state_dict(), f"{name}_final_model.pth")
    ppl = calculate_perplexity(model, val_loader, criterion, device)
    print(f"{name} Final Validation Perplexity: {ppl:.2f}")

train_model("M8", d_model=768, n_layers=6, n_heads=8, epochs=5)


Training M8 with d_model=768, layers=6, heads=8


M8 Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████| 6185/6185 [09:59<00:00, 10.32it/s]


Epoch 1 Loss: 1.8673


M8 Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████| 6185/6185 [10:54<00:00,  9.45it/s]


Epoch 2 Loss: 0.9518


M8 Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████| 6185/6185 [12:18<00:00,  8.37it/s]


Epoch 3 Loss: 0.5918


M8 Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████| 6185/6185 [16:44<00:00,  6.16it/s]


Epoch 4 Loss: 0.3925


M8 Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████| 6185/6185 [16:49<00:00,  6.13it/s]


Epoch 5 Loss: 0.2844
M8 Final Validation Perplexity: 1.28
Training M10 with d_model=768, layers=8, heads=12


M10 Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████| 6185/6185 [19:48<00:00,  5.21it/s]


Epoch 1 Loss: 1.9471


M10 Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████| 6185/6185 [19:48<00:00,  5.20it/s]


Epoch 2 Loss: 1.2511


M10 Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████| 6185/6185 [19:49<00:00,  5.20it/s]


Epoch 3 Loss: 1.0373


M10 Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████| 6185/6185 [19:49<00:00,  5.20it/s]


Epoch 4 Loss: 0.8972


M10 Epoch 5:   9%|███████▎                                                                        | 570/6185 [01:51<18:25,  5.08it/s]

In [7]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '<pad>', 'sep_token': '<sep>', 'bos_token': '<s>', 'eos_token': '</s>'})


4

In [8]:
import torch
import torch.nn as nn

class DecoderOnlyModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, num_layers, max_seq_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        batch_size, seq_length = x.shape
        positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0).expand(batch_size, -1)
        x = self.embedding(x) + self.pos_embedding(positions)
        tgt_mask = torch.triu(torch.ones(seq_length, seq_length, device=x.device), diagonal=1)
        tgt_mask = tgt_mask.masked_fill(tgt_mask == 1, float('-inf'))
        memory = torch.zeros((seq_length, batch_size, x.size(-1)), device=x.device)
        x = self.transformer_decoder(x.permute(1, 0, 2), memory, tgt_mask=tgt_mask)
        return self.fc_out(x.permute(1, 0, 2))


In [28]:

# Matching architecture
d_model = 768
# n_layers = 8 if model_name == "M10" else 6
n_heads = 12 

# Initialize and load weights
model = DecoderOnlyModel(len(tokenizer), d_model, n_heads, n_layers)
model.load_state_dict(torch.load(f"{model_name}_final_model.pth", map_location="cpu"))
model.eval()


DecoderOnlyModel(
  (embedding): Embedding(50261, 768)
  (pos_embedding): Embedding(256, 768)
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(

In [16]:
def generate_docstring(model, tokenizer, code_snippet, max_length=100):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    input_text = f"<s> {code_snippet} <sep>"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    for _ in range(max_length):
        outputs = model(input_ids)
        next_token_logits = outputs[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    output_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return output_text.split("<sep>")[-1].strip()


In [18]:
# Testing on model M8
java_code = "public int add(int a, int b) { return a + b; }"
docstring = generate_docstring(model, tokenizer, java_code)
print("Generated docstring:", docstring)


Generated docstring: public int add(int a, int b) { return a + b; } s Create theience route Bind Return file elementven Create attribute map - Create bound.. in in path path path left option controller proxy rate rate video users users users users users users users users users users users users users users users users users users users usersatteratterDeviceDeviceDeviceWebitionutiodvenuevenuevenuevenuevenue Campaign budget budgetreshshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshotshot
