# Train an Encoder-Decoder Model for Java Code Summarisation

---




**INSTALL LIBRARIES**
----------------------
----------------------
----------------------
----------------------

In [None]:

!pip install transformers datasets evaluate rouge_score bert_score --quiet


  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m119.3 MB/s[0m eta [36

**MOUNT DRIVE**
----------------------
----------------------
----------------------
----------------------

In [None]:
import os
from google.colab import drive

# ------------------------------------------------------------------
# Google Drive: later cells write checkpoints and prediction CSVs
# under /content/drive/MyDrive, so the drive is mounted up front.
# ------------------------------------------------------------------
drive.mount('/content/drive')

Mounted at /content/drive


**TOKENIZATION**
----------------------
----------------------
----------------------
----------------------

In [None]:
from transformers import AutoTokenizer
from datasets import load_dataset
import torch
from torch.utils.data import Dataset
import re

from datasets import load_dataset

# Token budgets used when padding/truncating code (encoder side) and
# summaries (decoder side) in CodeSummaryDataset.
MAX_SOURCE_LENGTH = 256
MAX_TARGET_LENGTH = 80

# Java configuration of the "code_x_glue_ct_code_to_text" dataset, and the
# CodeT5 tokenizer shared by every later cell.
dataset = load_dataset("code_x_glue_ct_code_to_text", "java")
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")

class CodeSummaryDataset(Dataset):
    """Map-style dataset that turns (code, docstring) rows into seq2seq tensors.

    Relies on the module-level ``tokenizer``, ``MAX_SOURCE_LENGTH`` and
    ``MAX_TARGET_LENGTH`` defined earlier in the notebook.

    Each item is a dict of 1-D tensors:

    * ``input_ids``         - code tokens, padded/truncated to MAX_SOURCE_LENGTH
    * ``attention_mask``    - 1 for real code tokens, 0 for padding, so callers
      can keep attention off pad positions
    * ``decoder_input_ids`` - summary tokens without the final position
    * ``labels``            - summary tokens shifted left by one (teacher forcing)

    Both decoder tensors have length MAX_TARGET_LENGTH - 1.
    """

    def __init__(self, hf_dataset_split):
        # Rows are indexed by position and must expose "code" and "docstring".
        self.dataset = hf_dataset_split

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]

        # Collapse every whitespace run (newlines, indentation, tabs) to one space.
        code = re.sub(r'\s+', ' ', sample["code"]).strip()
        summary = re.sub(r'\s+', ' ', sample["docstring"]).strip()

        # Tokenize code (encoder input)
        source = tokenizer(code,
                           padding="max_length",
                           truncation=True,
                           max_length=MAX_SOURCE_LENGTH,
                           return_tensors="pt",
                           add_special_tokens=True)

        # Tokenize summary (decoder input/labels)
        target = tokenizer(summary,
                           padding="max_length",
                           truncation=True,
                           max_length=MAX_TARGET_LENGTH,
                           return_tensors="pt",
                           add_special_tokens=True)

        # Teacher forcing: the decoder reads positions [0 .. n-2] and is
        # trained to predict positions [1 .. n-1].
        decoder_input_ids = target.input_ids[:, :-1]
        labels = target.input_ids[:, 1:]

        # Defensive: padding="max_length" already yields MAX_TARGET_LENGTH - 1
        # positions after the shift, so this only fires if that ever changes.
        pad_len = MAX_TARGET_LENGTH - 1 - decoder_input_ids.size(1)
        if pad_len > 0:
            pad = torch.full((1, pad_len), tokenizer.pad_token_id,
                             dtype=decoder_input_ids.dtype)
            decoder_input_ids = torch.cat([decoder_input_ids, pad], dim=1)
            labels = torch.cat([labels, pad], dim=1)

        return {
            "input_ids": source.input_ids.squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
            "decoder_input_ids": decoder_input_ids.squeeze(0),
            "labels": labels.squeeze(0),
        }

**MODEL ARCHITECTURE**
----------------------
----------------------
----------------------
----------------------

In [None]:
# 🔧 Custom Encoder-Decoder Model using PyTorch and CodeT5 Embeddings

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import time
import os
from google.colab import drive
import re
import math

# ==========================
# Sinusoidal Positional Encoding
# ==========================
def get_sinusoidal_encoding(max_len, d_model):
    """Return the fixed sine/cosine positional-encoding table.

    Args:
        max_len: number of positions to encode.
        d_model: embedding width. Odd widths are supported: the extra last
            column is a sine term.

    Returns:
        Float tensor of shape ``[1, max_len, d_model]``. Column ``2i`` holds
        ``sin(pos / 10000^(2i/d_model))`` and column ``2i+1`` the matching
        cosine.
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model)
    )  # [ceil(d_model / 2)]
    angles = position * div_term  # [max_len, ceil(d_model / 2)]
    pe[:, 0::2] = torch.sin(angles)
    # An odd d_model has one fewer cosine column than sine columns.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe.unsqueeze(0)  # [1, max_len, d_model]

# ==========================
# Model Components
# ==========================
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    Shapes: ``query`` is ``[B, T, D]``; ``key`` and ``value`` are ``[B, S, D]``
    (``S`` may differ from ``T``, as in cross-attention). The result is
    ``[B, T, D]``.

    ``mask`` (optional): non-zero means "may attend", 0 means "blocked".
    Accepted forms:

    * 2-D ``[B, S]`` key-padding mask, broadcast over heads and queries;
    * 3-D ``[B, T, S]`` per-query mask, broadcast over heads;
    * 4-D, already broadcastable to ``[B, H, T, S]`` (e.g. a causal
      ``[1, 1, T, T]`` mask).

    Every query row must keep at least one allowed key; otherwise the ``-inf``
    fill makes that row's softmax NaN.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    @staticmethod
    def _expand_mask(mask):
        """Reshape ``mask`` to 4-D so it broadcasts against ``[B, H, T, S]`` scores."""
        if mask.dim() == 2:
            return mask[:, None, None, :]
        if mask.dim() == 3:
            return mask[:, None, :, :]
        if mask.dim() == 4:
            return mask
        raise ValueError(f"attention mask must be 2-D, 3-D or 4-D, got {mask.dim()}-D")

    def forward(self, query, key, value, mask=None):
        B, T, D = query.size()
        H = self.num_heads

        def split_heads(x):
            # [B, L, D] -> [B, H, L, head_dim]
            return x.view(B, -1, H, self.head_dim).transpose(1, 2)

        Q = split_heads(self.q_proj(query))
        K = split_heads(self.k_proj(key))
        V = split_heads(self.v_proj(value))

        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, H, T, S]
        if mask is not None:
            # Each mask must end up exactly 4-D. Blindly unsqueezing a 4-D causal
            # mask gives a 6-D broadcast whose memory layout no longer matches
            # the [B, H, T, head_dim] layout the final reshape assumes. That
            # silently mixes later positions into earlier outputs.
            attn_weights = attn_weights.masked_fill(self._expand_mask(mask) == 0, float('-inf'))

        attn_probs = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_probs, V)  # [B, H, T, head_dim]
        # Merge heads back: [B, H, T, head_dim] -> [B, T, H * head_dim]
        attn_output = attn_output.transpose(1, 2).reshape(B, T, D)
        return self.out_proj(attn_output)

class TransformerEncoderLayer(nn.Module):
    """One post-LayerNorm encoder block: self-attention then feed-forward.

    Each sub-layer output passes through dropout, is added back to its input
    (residual connection) and is then layer-normalised.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        # Position-wise feed-forward: expand to ff_dim, ReLU, project back.
        # Creation order is kept fixed because it determines parameter
        # initialisation and the state_dict keys (ff.0.*, ff.2.*) that saved
        # checkpoints use.
        expand = nn.Linear(embed_dim, ff_dim)
        activation = nn.ReLU()
        project = nn.Linear(ff_dim, embed_dim)
        self.ff = nn.Sequential(expand, activation, project)
        self.norm1, self.norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        # Sub-layer 1: self-attention over the source sequence.
        attended = self.self_attn(src, src, src, src_mask)
        hidden = self.norm1(src + self.dropout(attended))
        # Sub-layer 2: position-wise feed-forward.
        return self.norm2(hidden + self.dropout(self.ff(hidden)))

class TransformerDecoderLayer(nn.Module):
    """One post-LayerNorm decoder block.

    It runs masked self-attention, then cross-attention to the encoder
    memory, then feed-forward. Each step applies dropout, a residual add and
    a LayerNorm.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads)
        # Creation order is kept fixed: it determines parameter initialisation
        # and the state_dict keys (ff.0.*, ff.2.*) used by saved checkpoints.
        expand = nn.Linear(embed_dim, ff_dim)
        activation = nn.ReLU()
        project = nn.Linear(ff_dim, embed_dim)
        self.ff = nn.Sequential(expand, activation, project)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask, memory_mask):
        # Masked self-attention over the target prefix.
        hidden = self.norm1(tgt + self.dropout(self.self_attn(tgt, tgt, tgt, tgt_mask)))
        # Cross-attention: target positions query the encoder memory.
        hidden = self.norm2(hidden + self.dropout(self.cross_attn(hidden, memory, memory, memory_mask)))
        # Position-wise feed-forward.
        return self.norm3(hidden + self.dropout(self.ff(hidden)))

class TransformerModel(nn.Module):
    """Encoder-decoder Transformer that consumes pre-computed token embeddings.

    The caller embeds token ids (this notebook uses the CodeT5 input-embedding
    table) and passes ``[B, S, D]`` encoder and ``[B, T, D]`` decoder
    embeddings. This module adds sinusoidal positions, runs ``num_layers``
    encoder and decoder layers, and projects to ``vocab_size`` logits.
    """

    def __init__(self, embed_dim=768, num_heads=8, ff_dim=2048, num_layers=4, dropout=0.1, vocab_size=32100, max_len=512):
        super().__init__()
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        self.lm_head = nn.Linear(embed_dim, vocab_size)

        # Non-persistent buffer: it follows model.to(device) automatically (no
        # per-forward host-to-device copy) but stays out of state_dict(), so
        # checkpoints saved without it still load with strict=True.
        self.register_buffer(
            "positional_encoding",
            get_sinusoidal_encoding(max_len, embed_dim),
            persistent=False,
        )

    def forward(self, encoder_embeddings, decoder_embeddings, src_mask=None, tgt_mask=None):
        """Return next-token logits for every decoder position.

        Args:
            encoder_embeddings: ``[B, S, D]`` source embeddings.
            decoder_embeddings: ``[B, T, D]`` (shifted) target embeddings.
            src_mask: optional source mask (non-zero = real token, 0 = pad),
                applied in encoder self-attention and decoder cross-attention.
            tgt_mask: optional decoder self-attention mask; defaults to a
                causal lower-triangular ``[1, 1, T, T]`` mask.

        Returns:
            ``[B, T, vocab_size]`` logits.

        Raises:
            ValueError: if S or T exceeds the positional table length (``max_len``).
        """
        S = encoder_embeddings.size(1)
        T = decoder_embeddings.size(1)
        max_len = self.positional_encoding.size(1)
        if S > max_len or T > max_len:
            raise ValueError(
                f"sequence lengths (source={S}, target={T}) exceed max_len={max_len}"
            )

        memory = encoder_embeddings + self.positional_encoding[:, :S, :]
        for layer in self.encoder_layers:
            memory = layer(memory, src_mask)

        if tgt_mask is None:
            # Position t may attend only to positions <= t.
            tgt_mask = torch.tril(
                torch.ones(T, T, device=decoder_embeddings.device)
            ).view(1, 1, T, T)

        output = decoder_embeddings + self.positional_encoding[:, :T, :]
        for layer in self.decoder_layers:
            output = layer(output, memory, tgt_mask, src_mask)

        return self.lm_head(output)


**MODEL TRAINING**
----------------------
----------------------
----------------------
----------------------

In [None]:
from torch.utils.data import DataLoader
import torch
import time
import os

# Load dataset and tokenizer
dataset = load_dataset("code_x_glue_ct_code_to_text", "java")
train_dataset = CodeSummaryDataset(dataset["train"])
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Tokenizer plus the CodeT5 input-embedding table. The embeddings are looked
# up under torch.no_grad() below, so only `model` is trained.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
codet5 = AutoModel.from_pretrained("Salesforce/codet5-base")
embedding_layer = codet5.get_input_embeddings()

# Training Setup
model = TransformerModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
embedding_layer.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# Label positions that hold pad_token_id contribute nothing to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

EPOCHS = 3
# Print progress every N batches: printing every batch overflowed the
# notebook's 5000-line output buffer.
LOG_EVERY = 100
CHECKPOINT_DIR = "/content/checkpoints"
DRIVE_DIR = "/content/drive/MyDrive/codet5_checkpoints-new"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(DRIVE_DIR, exist_ok=True)

total_batches = len(train_loader)
# max(1, ...) keeps the modulo below valid when an epoch has fewer than 2 batches.
save_every = max(1, total_batches // 2)


def _save_checkpoint(filename):
    """Save model weights to CHECKPOINT_DIR and, best effort, to DRIVE_DIR.

    Returns True when both copies were written, False when only the local one was.
    """
    state = model.state_dict()
    torch.save(state, os.path.join(CHECKPOINT_DIR, filename))
    try:
        torch.save(state, os.path.join(DRIVE_DIR, filename))
    except (RuntimeError, OSError) as err:
        # A Drive write has failed mid-run ("File ... cannot be opened"); keep
        # training on the local copy instead of losing the whole run.
        print(f" Warning: Drive save failed for {filename} ({err}); local copy kept in {CHECKPOINT_DIR}")
        return False
    return True


print(f" Starting training for {EPOCHS} epochs — {total_batches} batches/epoch\n")

# Training Loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    print(f" Epoch {epoch+1}/{EPOCHS}")

    for batch_idx, batch in enumerate(train_loader):
        start_time = time.time()

        input_ids = batch["input_ids"].to(device)
        decoder_input_ids = batch["decoder_input_ids"].to(device)
        labels = batch["labels"].to(device)
        # 1 for real code tokens, 0 for padding: keeps encoder self-attention
        # and decoder cross-attention off the pad positions.
        src_mask = (input_ids != tokenizer.pad_token_id).long()

        with torch.no_grad():
            encoder_input_embeddings = embedding_layer(input_ids)
            decoder_input_embeddings = embedding_layer(decoder_input_ids)

        logits = model(encoder_input_embeddings, decoder_input_embeddings, src_mask=src_mask)
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_time = time.time() - start_time
        if (batch_idx + 1) % LOG_EVERY == 0 or batch_idx + 1 == total_batches:
            print(f" Epoch {epoch+1} | Batch {batch_idx+1}/{total_batches} | Loss: {loss.item():.4f} | Time: {step_time:.2f}s")

        # Checkpoint mid-epoch
        if (batch_idx + 1) % save_every == 0:
            name = f"model_epoch{epoch+1}_step{batch_idx+1}.pt"
            where = "Colab and Drive" if _save_checkpoint(name) else "Colab only"
            print(f" Checkpoint saved to {where}: {name}")

    # End-of-epoch checkpoint
    final_name = f"model_epoch{epoch+1}.pt"
    where = "Colab and Drive" if _save_checkpoint(final_name) else "Colab only"
    print(f"End-of-epoch checkpoint saved ({where}): {final_name}")
    print(f"Average Epoch Loss: {total_loss / max(1, total_batches):.4f}\n")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
🌀 Epoch 1 | Batch 5309/20616 | Loss: 1.2042 | Time: 0.07s
🌀 Epoch 1 | Batch 5310/20616 | Loss: 1.4902 | Time: 0.07s
🌀 Epoch 1 | Batch 5311/20616 | Loss: 1.9371 | Time: 0.07s
🌀 Epoch 1 | Batch 5312/20616 | Loss: 1.3970 | Time: 0.07s
🌀 Epoch 1 | Batch 5313/20616 | Loss: 1.1576 | Time: 0.07s
🌀 Epoch 1 | Batch 5314/20616 | Loss: 1.2867 | Time: 0.07s
🌀 Epoch 1 | Batch 5315/20616 | Loss: 1.7901 | Time: 0.07s
🌀 Epoch 1 | Batch 5316/20616 | Loss: 1.5753 | Time: 0.08s
🌀 Epoch 1 | Batch 5317/20616 | Loss: 2.1265 | Time: 0.07s
🌀 Epoch 1 | Batch 5318/20616 | Loss: 1.8460 | Time: 0.07s
🌀 Epoch 1 | Batch 5319/20616 | Loss: 1.8576 | Time: 0.07s
🌀 Epoch 1 | Batch 5320/20616 | Loss: 2.0082 | Time: 0.08s
🌀 Epoch 1 | Batch 5321/20616 | Loss: 2.2212 | Time: 0.08s
🌀 Epoch 1 | Batch 5322/20616 | Loss: 2.1579 | Time: 0.08s
🌀 Epoch 1 | Batch 5323/20616 | Loss: 1.5295 | Time: 0.07s
🌀 Epoch 1 | Batch 5324/20616 | Loss: 1.2544 | Time: 0.08s
🌀 Epoch

RuntimeError: File /content/drive/MyDrive/codet5_checkpoints-new/model_epoch1_step10308.pt cannot be opened.

**INFERENCE PIPELINE**
----------------------
----------------------
----------------------
----------------------

In [None]:
class CustomSummaryGenerator:
    """Sample a natural-language summary for a code snippet from a trained TransformerModel.

    Inference mirrors the training pipeline:

    1. Collapse whitespace and tokenize the code.
    2. Embed it with the given embedding layer, add the model's positional
       encoding, and run the encoder once.
    3. Step the decoder from the tokenizer's BOS token with top-k sampling
       (k=50, temperature=0.7).

    Decoding stops at EOS, after ``max_target_len`` new tokens, or when every
    id in the last-10 window equals the newly sampled token.

    NOTE(review): dropout is not disabled here; the caller is expected to have
    called ``model.eval()``.
    """

    def __init__(self, model, tokenizer, embedding_layer, device, max_input_len=256, max_target_len=80):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.embedding_layer = embedding_layer.to(device)
        self.device = device
        self.max_input_len = max_input_len    # source padding/truncation length
        self.max_target_len = max_target_len  # maximum number of decoding steps

    def top_k_sampling(self, logits, k=50, temperature=1.0):
        """Sample one token id per row from the ``k`` highest-scoring logits.

        Args:
            logits: scores over the vocabulary in the last dimension.
            k: number of candidates kept; clamped to the vocabulary size so
                ``torch.topk`` cannot fail on a small vocabulary.
            temperature: logits are divided by this before the softmax
                (<1 sharpens, >1 flattens the distribution).

        Returns:
            Tensor of sampled vocabulary ids whose last dimension is 1.
        """
        logits = logits / temperature
        k = min(k, logits.size(-1))
        top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
        probs = torch.softmax(top_k_values, dim=-1)
        sampled_idx = torch.multinomial(probs, num_samples=1)
        return top_k_indices.gather(-1, sampled_idx)

    def _positions(self, length):
        """Return the model's positional encodings for the first ``length`` positions on ``self.device``."""
        return self.model.positional_encoding[:, :length, :].to(self.device)

    def _encode(self, input_ids, attention_mask):
        """Run the encoder stack once and return the memory used by cross-attention."""
        memory = self.embedding_layer(input_ids) + self._positions(input_ids.size(1))
        for layer in self.model.encoder_layers:
            memory = layer(memory, attention_mask)
        return memory

    def _next_token_logits(self, generated_ids, memory, attention_mask):
        """Decode the current prefix against ``memory`` and return next-token logits ``[1, vocab]``."""
        length = generated_ids.size(1)
        hidden = self.embedding_layer(generated_ids) + self._positions(length)
        causal_mask = torch.tril(
            torch.ones(length, length, device=self.device)
        ).view(1, 1, length, length)
        for layer in self.model.decoder_layers:
            hidden = layer(hidden, memory, causal_mask, attention_mask)
        return self.model.lm_head(hidden)[:, -1, :]

    def generate_summary(self, code_snippet):
        """Return a sampled summary string for ``code_snippet``."""
        import re
        code_snippet = re.sub(r'\s+', ' ', code_snippet).strip()
        source = self.tokenizer(code_snippet,
                                padding="max_length",
                                truncation=True,
                                max_length=self.max_input_len,
                                return_tensors="pt").to(self.device)
        src_mask = source["attention_mask"]

        # Training targets are tokenized with add_special_tokens=True and then
        # shifted, so decoder inputs start with the tokenizer's start token and
        # labels end with its end token. For CodeT5 that is presumably <s> ... </s>
        # (verify with tokenizer("x").input_ids). Decoding therefore starts at
        # BOS and stops at EOS, not at the pad token.
        bos_id = self.tokenizer.bos_token_id
        if bos_id is None:
            bos_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id
        if eos_id is None:
            eos_id = self.tokenizer.pad_token_id

        with torch.no_grad():
            # Encode once; the memory is reused at every decoding step.
            memory = self._encode(source["input_ids"], src_mask)
            generated_ids = torch.full((1, 1), bos_id, dtype=torch.long, device=self.device)

            for _ in range(self.max_target_len):
                next_token_logits = self._next_token_logits(generated_ids, memory, src_mask)
                next_token_id = self.top_k_sampling(next_token_logits, k=50, temperature=0.7)
                generated_ids = torch.cat((generated_ids, next_token_id), dim=1)

                if next_token_id.item() == eos_id:
                    break
                if (generated_ids[0, -10:] == next_token_id).all():
                    print("Stopping early due to repetition")
                    break

        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()


In [None]:
from huggingface_hub import hf_hub_download

# Load tokenizer and embeddings
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
codet5 = AutoModel.from_pretrained("Salesforce/codet5-base")
embedding_layer = codet5.get_input_embeddings()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = hf_hub_download(
    repo_id="pritammane105/Custom-Java-Summarisation",
    filename="my_model.pt"
)

model = TransformerModel()
# The checkpoint comes from a remote Hub repo. weights_only=True limits
# unpickling to tensors and plain containers (a state_dict) rather than
# arbitrary Python objects.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
embedding_layer.to(device)
model.eval()  # disable dropout for inference

generator = CustomSummaryGenerator(model, tokenizer, embedding_layer, device)
summary = generator.generate_summary("public int add(int a, int b) { return a + b; }")
print("Summary: ", summary)


**INFERENCE ON VALIDATION & TEST SETS**
----------------------
----------------------
----------------------
----------------------

In [None]:
from datasets import load_dataset
import pandas as pd
import os

def generate_batch_summaries(generator, split: str, save_path: str):
    """Summarise every example of one Java code-to-text split and save a CSV.

    Args:
        generator: object exposing ``generate_summary(code) -> str``.
        split: dataset split name, e.g. "validation" or "test".
        save_path: output CSV path; missing parent directories are created.

    Returns:
        The DataFrame that was written, with columns ``gold_summary`` and
        ``predicted_summary``. A snippet whose generation raises is logged
        and recorded as an empty prediction.
    """
    # Load only the requested split instead of materialising every split.
    dataset = load_dataset("code_x_glue_ct_code_to_text", "java", split=split)
    total = len(dataset)

    predictions = []
    references = []

    print(f"Generating summaries for the {split} set...")
    for idx, example in enumerate(dataset):
        code = example["code"]
        reference = example["docstring"]

        try:
            summary = generator.generate_summary(code)
        except Exception as e:
            # Best effort: one failing snippet should not abort a long run.
            print(f"Error at index {idx}: {e}")
            summary = ""

        predictions.append(summary)
        references.append(reference)

        if (idx + 1) % 50 == 0:
            print(f"Processed {idx + 1}/{total} examples")

    df = pd.DataFrame({
        "gold_summary": references,
        "predicted_summary": predictions
    })

    out_dir = os.path.dirname(save_path)
    if out_dir:  # a bare filename has no directory to create (makedirs('') raises)
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(save_path, index=False)
    print(f"Saved {split} summaries to: {save_path}")
    return df


In [None]:
# Generate predictions for both evaluation splits: validation first, then test.
for split_name, csv_path in (
    ("validation", "/content/drive/MyDrive/custom_val_topk_predictions.csv"),
    ("test", "/content/drive/MyDrive/custom_test_topk_predictions.csv"),
):
    generate_batch_summaries(generator, split_name, csv_path)


In [None]:
java_code = """
public int max(int a, int b) {
    return a > b ? a : b;
}
"""

# generate_summary is a method of the CustomSummaryGenerator instance
# (`generator`) built in the inference cell; calling it as a bare function
# raises NameError.
summary = generator.generate_summary(java_code)
print("Generated Summary:", summary)

⚠️ Stopping early due to repetition
🧠 Generated Summary: FixFixFixFixFixFixFixFixFixFix


In [None]:
java_code = """
public int function(int a, int b) {
    return a > b ? a : b;
}
"""

# Same snippet with a generic method name, to see how much the summary
# depends on the identifier. Uses the `generator` instance's method.
summary = generator.generate_summary(java_code)
print("Generated Summary:", summary)

⚠️ Stopping early due to repetition
🧠 Generated Summary: FixFixFixFixFixFixFixFixFixFix


**EVALUATION**
----------------------
----------------------
----------------------
----------------------

In [None]:
import evaluate
import numpy as np
from collections import Counter
from typing import List, Dict, Optional

class SummaryEvaluator:
    """Score predicted summaries against gold summaries with several metrics.

    Metrics: ROUGE-1/2/L (with stemming), BLEU, mean BERTScore F1 (``lang="en"``,
    on CUDA when available) and an average token-repetition ratio.
    """

    def __init__(self):
        self.rouge = evaluate.load("rouge")
        self.bleu = evaluate.load("bleu")
        self.bertscore = evaluate.load("bertscore")

    def avg_token_repetition(self, predictions):
        """Return the mean fraction of tokens that belong to a repeated token type.

        Each prediction is split on whitespace. Every occurrence of a token that
        appears more than once is counted, and the count is divided by the
        number of tokens: ``"a a b"`` scores 2/3, ``"x x x x"`` scores 1.0 and
        an empty string scores 0. Higher values indicate more redundant text.
        An empty list returns 0.0.
        """
        rep_counts = []
        for text in predictions:
            tokens = text.strip().split()
            counts = Counter(tokens)
            repeated_tokens = sum(v for v in counts.values() if v > 1)
            rep_counts.append(repeated_tokens / max(1, len(tokens)))
        return float(np.mean(rep_counts)) if rep_counts else 0.0

    def evaluate_csvs(self, files: Dict[str, str]):
        """Evaluate several prediction CSVs and return one row of metrics per file.

        Args:
            files: mapping from a display name to a CSV path. Each CSV must have
                ``predicted_summary`` and ``gold_summary`` columns.

        Returns:
            DataFrame with columns Version, ROUGE-1, ROUGE-2, ROUGE-L, BLEU,
            BERTScore (mean F1) and Avg Token Repetition. Missing files and
            files without rows are reported and skipped.
        """
        bert_device = "cuda" if torch.cuda.is_available() else "cpu"
        all_results = []
        for name, path in files.items():
            if not os.path.exists(path):
                print(f"File not found: {path}")
                continue

            # keep_default_na=False: empty predictions (written as "") and
            # summaries such as "null"/"NA" stay strings. Otherwise they become
            # NaN, which astype(str) turns into the literal word "nan".
            df = pd.read_csv(path, keep_default_na=False)
            predictions = df["predicted_summary"].astype(str).tolist()
            references = df["gold_summary"].astype(str).tolist()
            if not predictions:
                print(f"No rows to evaluate in: {path}")
                continue

            rouge_scores = self.rouge.compute(predictions=predictions, references=references, use_stemmer=True)
            bleu_score = self.bleu.compute(predictions=predictions, references=references)
            bert_score = self.bertscore.compute(predictions=predictions, references=references, lang="en", device=bert_device)
            repetition = self.avg_token_repetition(predictions)

            all_results.append({
                "Version": name,
                "ROUGE-1": round(rouge_scores["rouge1"], 4),
                "ROUGE-2": round(rouge_scores["rouge2"], 4),
                "ROUGE-L": round(rouge_scores["rougeL"], 4),
                "BLEU": round(bleu_score["bleu"], 4),
                "BERTScore": round(np.mean(bert_score["f1"]), 4),
                "Avg Token Repetition": round(repetition, 4)
            })

        return pd.DataFrame(all_results)


In [None]:
if __name__ == "__main__":

    # Evaluation: one metrics row per prediction file (validation and test splits).
    evaluator = SummaryEvaluator()
    files = {
        "custom_topk_val": "/content/drive/MyDrive/custom_val_topk_predictions.csv",
        # Test-split predictions; the label now matches the file name.
        "custom_topk_test": "/content/drive/MyDrive/custom_test_topk_predictions.csv"
    }
    results_df = evaluator.evaluate_csvs(files)
    display(results_df)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,Version,ROUGE-1,ROUGE-2,ROUGE-L,BLEU,BERTScore,Avg Token Repetition
0,custom_topk_val,0.0473,0.0011,0.0417,0.0013,0.7776,0.6539
1,custom_topk_text,0.0465,0.0011,0.0405,0.0014,0.7785,0.6535
