# Imports

In [None]:
# !pip install -r /kaggle/input/requirements-txt/requirements.txt
# !nvidia-smi

Collecting tree-sitter (from -r /kaggle/input/requirements-txt/requirements.txt (line 10))
  Downloading tree_sitter-0.24.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.8 kB)
Collecting tree-sitter-python (from -r /kaggle/input/requirements-txt/requirements.txt (line 11))
  Downloading tree_sitter_python-0.23.6-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting tree-sitter-java (from -r /kaggle/input/requirements-txt/requirements.txt (line 12))
  Downloading tree_sitter_java-0.23.5-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting evaluate (from -r /kaggle/input/requirements-txt/requirements.txt (line 22))
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting rouge_score (from -r /kaggle/input/requirements-txt/requirements.txt (line 23))
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Prepari

In [None]:
import os
import re
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset

from transformers import AutoTokenizer, T5ForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model, TaskType

from sklearn.utils import resample
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import precision_recall_curve, average_precision_score

# WandB
import wandb

# AST
from tree_sitter import Language, Parser
import tree_sitter_python
import tree_sitter_java

## AST Graphing
import graphviz

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

import evaluate

# Datasets
from pathlib import Path
from datasets import load_dataset, load_from_disk, DatasetDict, concatenate_datasets



# Data Loading

## Python Data

In [None]:
# Python loading: a fixed-size prefix of each CodeSearchNet split
python_dataset = DatasetDict({
    split_name: load_dataset('code_search_net', 'python', split=f"{split_name}[:{n_rows}]", trust_remote_code=True)
    for split_name, n_rows in (("train", 60000), ("validation", 7000), ("test", 3500))
})

python_dataset

## Java Data

In [4]:
# Java loading: a fixed-size prefix of each CodeSearchNet split
java_dataset = DatasetDict({
    split_name: load_dataset('code_search_net', 'java', split=f"{split_name}[:{n_rows}]", trust_remote_code=True)
    for split_name, n_rows in (("train", 60000), ("validation", 7000), ("test", 3500))
})

java_dataset

java.zip:   0%|          | 0.00/1.06G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/454451 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/26909 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/15328 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 60000
    })
    validation: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 7000
    })
    test: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 3500
    })
})

## Debug and Test modes

In [5]:
# Mode switches:
#   debug    -> False for the full training config, True for a small smoke test
#   test_run -> True to enable the testing (sample prediction) blocks
debug, test_run = False, False

# Concatenation

In [6]:
# Merge the Python and Java splits into one multilingual dataset.
combined_dataset = DatasetDict({
    split: concatenate_datasets([python_dataset[split], java_dataset[split]])
    for split in ('train', 'validation', 'test')
})

# Shuffle BEFORE sub-sampling. The concatenation is ordered (all Python rows,
# then all Java rows), so taking the first N rows of an unshuffled split would
# give a Python-only debug subset.
for split in ('train', 'validation', 'test'):
    combined_dataset[split] = combined_dataset[split].shuffle(seed=42)

if debug:
    for split, n_rows in (('train', 200), ('validation', 50), ('test', 50)):
        combined_dataset[split] = combined_dataset[split].select(range(n_rows))

# Peek at one example. Index the row first so only that row is materialised;
# indexing the column first loads every code string of the split into memory.
combined_dataset['train'][1]['func_code_string']

'private static int[][] findAromaticRings(int[][] cycles, int[] contribution, int[] dbs) {\n\n        // loop control variables, the while loop continual checks all cycles\n        // until no changes are found\n        boolean found;\n        boolean[] checked = new boolean[cycles.length];\n\n        // stores the aromatic atoms as a bit set and the aromatic bonds as\n        // a hash set. the aromatic bonds are the result of this method but the\n        // aromatic atoms are needed for checking each ring\n        final boolean[] aromaticAtoms = new boolean[contribution.length];\n\n        final List<int[]> ringsOfSize6 = new ArrayList<int[]>();\n        final List<int[]> ringsOfSize5 = new ArrayList<int[]>();\n\n        do {\n            found = false;\n            for (int i = 0; i < cycles.length; i++) {\n\n                // note paths are closed walks and repeat first/last vertex so\n                // the true length is one less\n                int[] cycle = cycles[i];\n      

# ASTs

In [7]:
# Wrap each bundled tree-sitter grammar in a Language object
PY_LANGUAGE, JAVA_LANGUAGE = (
    Language(grammar.language()) for grammar in (tree_sitter_python, tree_sitter_java)
)

# One parser per language; Parser takes the Language directly in its constructor
python_parser, java_parser = (Parser(lang) for lang in (PY_LANGUAGE, JAVA_LANGUAGE))

def parse_code_to_ast(code, language):
    """Parse a source string into a tree-sitter syntax tree.

    Args:
        code: Source code of a single function/method.
        language: 'python' or 'java' (case-insensitive).

    Returns:
        The tree produced by the module-level parser for that language.

    Raises:
        ValueError: if ``language`` is neither 'python' nor 'java'.
    """
    lang = language.lower()
    if lang == 'python':
        parser = python_parser
    elif lang == 'java':
        parser = java_parser
    else:
        # Fail loudly instead of reaching parser.parse with `parser` unbound
        raise ValueError(f"Unsupported language for AST parsing: {language!r}")
    return parser.parse(code.encode('utf8'))

def sbt_traverse(node):
    """
    Serialise an AST with SBT (Structure-Based Traversal).

    Emits ``<type>`` when a node is entered and ``</type>`` when it is left,
    visiting children left-to-right (pre-order opens, post-order closes).
    Only ``node.type`` is used, so identifier text never appears in the output.

    Uses an explicit stack instead of recursion, so very deep trees (e.g. long
    chained expressions) cannot exceed Python's recursion limit.

    Args:
        node: Any object exposing ``.type`` (str) and ``.children`` (sequence),
            e.g. a tree-sitter Node.

    Returns:
        list[str]: The marker tokens in traversal order.
    """
    sequence = []
    # Entries are (node, closing); closing=True means "emit the end marker".
    stack = [(node, False)]
    while stack:
        current, closing = stack.pop()
        if closing:
            sequence.append(f"</{current.type}>")
            continue
        sequence.append(f"<{current.type}>")
        # The end marker must come after every descendant, so push it first.
        stack.append((current, True))
        # Push children reversed so the first child is popped (visited) first.
        stack.extend((child, False) for child in reversed(current.children))
    return sequence

# Visualize AST Graph
def visualize_ast(tree):
    """Build a graphviz Digraph of a tree-sitter tree, labelling each node with its type.

    Args:
        tree: A tree-sitter Tree (e.g. from ``parse_code_to_ast``).

    Returns:
        graphviz.Digraph (PNG format) with one vertex per AST node and an edge
        from each parent to each of its children.
    """
    dot = graphviz.Digraph(format="png")

    # Graph ids come from a running counter rather than id(node). py-tree-sitter
    # builds fresh Python Node wrappers on each `.children` access. Once a
    # finished subtree's wrappers are freed, CPython may reuse their addresses,
    # so id() values can collide and silently merge distinct graph nodes.
    next_id = 0
    # Explicit stack of (parent graph id, node) instead of recursion, so deep
    # trees cannot hit the recursion limit.
    stack = [(None, tree.root_node)]
    while stack:
        parent_id, node = stack.pop()
        node_id = str(next_id)
        next_id += 1

        dot.node(node_id, label=node.type)
        if parent_id is not None:
            dot.edge(parent_id, node_id)

        # Reversed so children are visited (and numbered) left-to-right
        stack.extend((node_id, child) for child in reversed(node.children))
    return dot

### TEST: AST Output

In [8]:
if debug:
    # One sample per language. Index the row before the column so only that
    # row is materialised; ds['func_code_string'][2] would load the whole column.
    python_code = python_dataset['train'][2]['func_code_string']
    java_code = java_dataset['train'][2]['func_code_string']

    # Parse with the same helper the preprocessing pipeline uses
    python_tree = parse_code_to_ast(python_code, 'python')
    java_tree = parse_code_to_ast(java_code, 'java')

    # Save the Python AST graph to ast_visualization.png and show it inline.
    # view=True would try to launch an external image viewer, which is
    # presumably unavailable on a headless host such as Kaggle.
    from IPython.display import display
    ast_viz = visualize_ast(python_tree)
    ast_viz.render("ast_visualization", format="png", view=False)
    display(ast_viz)

    # String form of each tree
    print("Python AST:")
    print(str(python_tree.root_node))
    print("\nJava AST:")
    print(str(java_tree.root_node))

    # SBT sequences, i.e. what goes between <AST> and </AST> in model inputs
    python_sbt_sequence = " ".join(sbt_traverse(python_tree.root_node))
    java_sbt_sequence = " ".join(sbt_traverse(java_tree.root_node))

    print("\nPython SBT Sequence:")
    print(python_sbt_sequence)
    print("\nJava SBT Sequence:")
    print(java_sbt_sequence)

## AST Integration, Masking & Preprocessing Functions

In [9]:
def mask_func_name(code_str: str, func_name: str, lang: str) -> str:
    """Replace the defining occurrence of a function name with the T5 sentinel ``<extra_id_0>``.

    Only the first match is replaced, so later uses of the name in the body
    (e.g. recursive calls) are left untouched.

    Args:
        code_str: Source of a single function/method.
        func_name: Unqualified name to hide (``Class.method`` -> ``method``).
        lang: 'python' or 'java' (case-insensitive); anything else is a no-op.

    Returns:
        ``code_str`` with the first matching occurrence masked. The input is
        returned unchanged when nothing matches, the language is unsupported,
        or ``func_name`` is empty.
    """
    # An empty name would reduce the Java pattern to a zero-width match before
    # the first "(" that follows a non-word character, injecting a stray sentinel.
    if not func_name:
        return code_str

    lang = lang.lower()

    if lang == 'python':
        # \b keeps "def" from matching inside a longer word (e.g. "undef")
        pattern = rf"(\bdef\s+)({re.escape(func_name)})(\s*\()"
        return re.sub(pattern, r"\1<extra_id_0>\3", code_str, count=1)

    if lang == 'java':
        # No definition keyword in Java: mask the first occurrence of the name
        # that is not the tail of a longer identifier and is followed by "(".
        pattern = rf"(?<!\w){re.escape(func_name)}(?=\s*\()"
        return re.sub(pattern, "<extra_id_0>", code_str, count=1)

    return code_str

def _print_masked_samples(dataset, lang, header, num_samples):
    """Print original vs. masked code for the first `num_samples` rows of dataset['train']."""
    print(header)

    for i in range(num_samples):
        row = dataset['train'][i]  # fetch the row once rather than once per field
        full_func_name = row['func_name']
        method_name = full_func_name.split('.')[-1]
        code = row['func_code_string']

        print(f"--- Sample #{i} ---")
        print(f"Original Function Name: {full_func_name}")
        print("\nOriginal Code:\n", code)
        print("\nMasked Code:\n", mask_func_name(code, method_name, lang=lang))
        print("=" * 100 + "\n")


def test_real_python_samples(dataset, num_samples=3):
    """Show how `mask_func_name` masks the first `num_samples` Python training functions."""
    _print_masked_samples(dataset, "python", "=== REAL PYTHON SAMPLES ===", num_samples)


def test_real_java_samples(dataset, num_samples=3):
    """Show how `mask_func_name` masks the first `num_samples` Java training functions."""
    _print_masked_samples(dataset, "java", "=== REAL JAVA SAMPLES ===", num_samples)


def inspect_samples(dataset, lang: str, num_samples: int = 5):
    """Print each stage of model-input construction for the first `num_samples` training rows.

    A placeholder AST string stands in for the real SBT sequence to keep the
    output short. Requires the module-level `tokenizer` at call time.
    """
    print(f"\n=== {lang.upper()} SAMPLE VERIFICATION ===\n")
    placeholder_ast = "<AST> dummy AST </AST>"

    for idx in range(num_samples):
        row = dataset['train'][idx]
        original_code = row['func_code_string']
        qualified_name = row['func_name']
        short_name = qualified_name.split('.')[-1]

        masked = mask_func_name(original_code, short_name, lang)
        model_input = " ".join((masked, placeholder_ast))
        pieces = tokenizer.tokenize(model_input)

        print(f"--- Sample #{idx} ---")
        print(f"Original Function Name: {qualified_name}")
        for caption, value in (
            ("\nOriginal Code:\n", original_code),
            ("\nMasked Code:\n", masked),
            ("\nFinal Combined Input:\n", model_input),
            ("\nTokenized Input:\n", pieces),
        ):
            print(caption, value)
        print("=" * 100)

def preprocess(examples):
    """Build model inputs and labels for a batch of CodeSearchNet rows (for `Dataset.map(batched=True)`).

    Each input is the function source with its name masked by ``<extra_id_0>``,
    followed by an SBT serialisation of the function's AST wrapped in
    ``<AST> ... </AST>``. The label is the unqualified method name.

    Args:
        examples: Batched columns 'func_code_string', 'func_name' and 'language'.

    Returns:
        Tokenizer output for the inputs (input_ids, attention_mask) plus
        'labels', holding the token ids of the method names.

    Sequences are truncated here but NOT padded. Padding labels to a fixed
    length with the pad token made every pad position part of the training
    loss. Left unpadded, DataCollatorForSeq2Seq pads each batch dynamically and
    fills label padding with -100, which the loss ignores. This also avoids
    storing 1024-token padded inputs for every row.
    """
    combined_inputs = []
    combined_labels = []

    for code, target, lang in zip(examples['func_code_string'], examples['func_name'], examples['language']):
        # Unqualified method name (func_name may be qualified like Class.method)
        method_name = target.split('.')[-1]
        # Mask the function name in its definition
        masked_code = mask_func_name(code, method_name, lang)

        # The AST is built from the unmasked code. The SBT contains node types
        # only, not identifier text, so it does not reveal the method name.
        ast_features = sbt_traverse(parse_code_to_ast(code, lang).root_node)
        ast_string = "<AST> " + " ".join(ast_features) + " </AST>"

        combined_inputs.append(masked_code + " " + ast_string)
        combined_labels.append(method_name)

    model_inputs = tokenizer(combined_inputs, max_length=1024, truncation=True)
    tokenized_labels = tokenizer(combined_labels, max_length=50, truncation=True)

    model_inputs['labels'] = tokenized_labels['input_ids']
    return model_inputs


# Initialize the tokenizer BEFORE the debug inspection below. inspect_samples()
# and the vocabulary checks all read `tokenizer`, so on a fresh kernel with
# debug=True the previous order failed with NameError.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
# <AST> and </AST> delimit the SBT section appended to every input; registering
# them as special tokens keeps each marker a single token.
special_tokens = {"additional_special_tokens": ["<AST>", "</AST>"]}
num_added_tokens = tokenizer.add_special_tokens(special_tokens)

if debug:
    # Run both inspections
    inspect_samples(python_dataset, lang="python", num_samples=5)
    inspect_samples(java_dataset, lang="java", num_samples=5)

    # Run the test
    test_real_java_samples(java_dataset, num_samples=5)
    test_real_python_samples(python_dataset, num_samples=5)

    print("<extra_id_0>" in tokenizer.get_vocab())
    print("<mask>" in tokenizer.get_vocab())
    print("Token ID for <extra_id_0>:", tokenizer.convert_tokens_to_ids("<extra_id_0>"))
    print("All special tokens:", tokenizer.special_tokens_map)
    print("Additional special tokens:", tokenizer.additional_special_tokens)

num_added_tokens

tokenizer_config.json:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/703k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/294k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

2

# Tokenize

In [10]:
# Tokenize every split with `preprocess`. The raw CodeSearchNet columns (code
# tokens, docstrings, URLs, ...) are dropped so they are not carried into the
# mapped dataset/cache; only input_ids, attention_mask and labels are kept.
# The raw fields remain available in `combined_dataset`, in the same row order.
tokenized_dataset = combined_dataset.map(
    preprocess,
    batched=True,
    remove_columns=combined_dataset["train"].column_names,
)



Map:   0%|          | 0/120000 [00:00<?, ? examples/s]

Map:   0%|          | 0/14000 [00:00<?, ? examples/s]

Map:   0%|          | 0/7000 [00:00<?, ? examples/s]

In [None]:
if debug:
    num_samples_to_show = 5

    for idx in range(num_samples_to_show):
        print(f"\n===== Sample {idx + 1} =====")
        row = tokenized_dataset["train"][idx]

        # Decoded input: masked code (with <extra_id_0>) followed by <AST> ... </AST>
        decoded_input = tokenizer.decode(row["input_ids"], skip_special_tokens=False)
        print("Masked Input Code:\n", decoded_input)

        # Decoded target: drop padding (pad token or the -100 ignore index) first
        label_token_ids = [tok for tok in row["labels"] if tok >= 0 and tok != tokenizer.pad_token_id]
        decoded_label = tokenizer.decode(label_token_ids, skip_special_tokens=True)
        print("Target Method Name:", decoded_label)

        # Original (possibly qualified) name; map() preserves row order
        if "func_name" in combined_dataset["train"].features:
            original_name = combined_dataset["train"][idx]["func_name"]
            print("Original Method Name:", original_name)

In [12]:
if debug:
    sample_index = 0
    sample = tokenized_dataset["train"][sample_index]

    # Raw tokenized record and its decoded input
    print(sample)
    print(tokenizer.decode(sample["input_ids"]))

    # Qualified name from the original dataset (before masking)
    original_func_name = combined_dataset["train"][sample_index]["func_name"]
    print("Full Function Name:", original_func_name)

    # Label stored in the tokenized dataset, with padding removed before decoding
    label_ids = [tok for tok in sample["labels"] if tok >= 0 and tok != tokenizer.pad_token_id]
    label_text = tokenizer.decode(label_ids, skip_special_tokens=True)
    print("Target Label Text (after masking & preprocessing):", label_text)

# W&B

#### Make all hyperparameter changes here; please do not change them elsewhere

In [13]:
# Hyperparameters shared by the debug smoke test and the full training run
config = {
    "batch_size": 8,
    "predict_with_generate": True,
    "load_best_model_at_end": True,
    "evaluation_strategy": "steps",
    "logging_strategy": "steps",
    "save_strategy": "steps",
    "report_to": "wandb",
    "model_name": "Salesforce/codet5-base",
}

if debug:
    # Smoke test: one epoch with frequent eval/save/logging
    config.update({
        "learning_rate": 5e-5,
        "num_train_epochs": 1,
        "eval_steps": 20,
        "save_steps": 20,
        "save_total_limit": 1,
        "logging_steps": 10,
        "fp16": False,  # for smoke-test
        "output_dir": "./debug_results",
        "run_name": "mngast120k_smoke_test",
    })
else:
    # Full training run
    config.update({
        # NOTE(review): 4e-7 is far below typical LoRA learning rates; presumably
        # deliberate when continuing from the loaded adapter checkpoint — confirm.
        "learning_rate": 4e-7,
        "num_train_epochs": 2,
        "eval_steps": 5000,
        "save_steps": 5000,
        "save_total_limit": 3,
        "logging_steps": 100,
        # "weight_decay": 0.01,
        # NOTE(review): T5-family models are reported to overflow (NaN loss)
        # under fp16; if that happens, try bf16 or fp32.
        "fp16": True,
        "output_dir": "./training_results",
        "run_name": "mngast120k_training",
    })


In [None]:
# Log in to W&B and start the run. Never hard-code the API key in a notebook:
# read it from the environment (e.g. a Kaggle secret exported as WANDB_API_KEY)
# or prompt for it interactively.
# NOTE(review): the key previously committed here should be revoked/rotated.
wandb_api_key = os.environ.get("WANDB_API_KEY")
if not wandb_api_key:
    from getpass import getpass
    wandb_api_key = getpass("W&B API key: ")
wandb.login(key=wandb_api_key)

# Name the run from the hyperparameter cell so debug and full runs are distinguishable
wandb.init(project="Method Name Prediction", name=config["run_name"])


In [None]:
# Record the hyperparameter dict on the active W&B run; the training-arguments
# cell reads its values back through wandb.config.
wandb.config.update(config)

# Model Loading

In [14]:
model = T5ForConditionalGeneration.from_pretrained(config["model_name"])

# Grow the embedding matrix for the added <AST> / </AST> special tokens.
# The new rows are randomly initialised (sampled from a normal distribution
# around the existing embeddings). With LoRA the base embeddings are presumably
# frozen and not stored in the adapter checkpoint. Seed the RNG first so these
# rows are identical on every run and whenever a saved adapter is reloaded.
# NOTE(review): adapters trained before this seed was added saw different values.
torch.manual_seed(42)
model.resize_token_embeddings(len(tokenizer))

config.json:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(32102, 768)

## LoRA Fine-Tuning

In [15]:
# LoRA adapters on T5's attention projections and feed-forward layers.
# PEFT matches each target name against the END of a module path, so
# "q"/"k"/"v"/"o" already cover encoder self-attention, decoder self-attention
# and decoder cross-attention. The former "decoder.q"/"decoder.k"/... entries
# never matched a module (paths look like "decoder.block.0.layer.0.SelfAttention.q"),
# so dropping them leaves the set of adapted modules unchanged.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=[
        "q", "k", "v", "o",  # attention projections (encoder and decoder)
        "wi", "wo",          # feed-forward layers
    ],
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

model = get_peft_model(model, lora_config)

## Please verify the checkpoint directory path

In [16]:
# LoRA adapter checkpoint from a previous run, mounted as a Kaggle input dataset.
# The doubled "training_results" segment presumably comes from how that archive
# was extracted — verify the path for your dataset before running.
checkpoint_dir = os.path.join(
    "/kaggle/input", "checkpoint-path", "training_results", "training_results", "checkpoint-25000"
)
if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError(f"LoRA checkpoint directory not found: {checkpoint_dir}")

# Load the saved weights into the existing "default" adapter. is_trainable=True
# makes explicit that training continues from these weights; PEFT's default
# (False) is meant for inference and also switches the model to eval mode.
model.load_adapter(checkpoint_dir, adapter_name="default", is_trainable=True)

<All keys matched successfully>

## Training Params

In [None]:
cnfg = wandb.config

# Every value comes from the run config, so the hyperparameter cell stays the
# single source of truth. Arguments are grouped by concern.
training_args = Seq2SeqTrainingArguments(
    # run identity / reporting
    output_dir=cnfg.output_dir,
    run_name=cnfg.run_name,
    report_to=cnfg.report_to,
    # optimisation
    learning_rate=cnfg.learning_rate,
    num_train_epochs=cnfg.num_train_epochs,
    per_device_train_batch_size=cnfg.batch_size,
    per_device_eval_batch_size=cnfg.batch_size,
    fp16=cnfg.fp16,
    # weight_decay=cnfg.weight_decay,
    # evaluation / logging / checkpoint cadence
    eval_strategy=cnfg.evaluation_strategy,
    eval_steps=cnfg.eval_steps,
    logging_strategy=cnfg.logging_strategy,
    logging_steps=cnfg.logging_steps,
    save_strategy=cnfg.save_strategy,
    save_steps=cnfg.save_steps,
    save_total_limit=cnfg.save_total_limit,
    load_best_model_at_end=cnfg.load_best_model_at_end,
    # evaluate by generating sequences
    predict_with_generate=cnfg.predict_with_generate,
)

# Data Collator & Trainer

In [None]:
# Pads each batch dynamically: inputs with the tokenizer's pad token, labels
# with -100 (the loss ignore index; DataCollatorForSeq2Seq's default, written
# out here for clarity).
collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    label_pad_token_id=-100,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=collator,
    processing_class=tokenizer,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

# Train

In [None]:
# Fine-tune the LoRA adapters. Evaluation, logging and checkpointing follow the
# step cadence in `training_args`; with load_best_model_at_end the best
# checkpoint is restored when training finishes.
trainer.train()

# Evaluation

In [None]:
# Load evaluation metrics
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Evaluate on a small subset for speed (adjust as needed). Clamp to the split
# size: in debug mode the test split has only 50 rows, and select(range(100))
# on it fails.
eval_samples = 100
num_eval = min(eval_samples, len(tokenized_dataset["test"]))
test_subset = tokenized_dataset["test"].select(range(num_eval))

model.eval()
device = model.device


def _label_text(label_ids):
    """Decode a label id sequence, dropping -100 ignore indices and pad tokens."""
    kept = [tok for tok in label_ids if tok >= 0 and tok != tokenizer.pad_token_id]
    return tokenizer.decode(kept, skip_special_tokens=True).strip()


predictions, references = [], []
for example in tqdm(test_subset):
    input_ids = torch.tensor(example["input_ids"]).unsqueeze(0).to(device)
    # Pass the attention mask so any padding in the stored inputs is ignored
    attention_mask = torch.tensor(example["attention_mask"]).unsqueeze(0).to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=50,
            num_beams=4,
            early_stopping=True,
        )

    predictions.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip())
    references.append(_label_text(example["labels"]))

exact_matches = sum(pred == ref for pred, ref in zip(predictions, references))

# ROUGE
rouge_result = rouge.compute(predictions=predictions, references=references)
print("ROUGE scores:")
for k, v in rouge_result.items():
    print(f"{k}: {v:.4f}")

# BLEU. NOTE(review): method names are usually a single word-level token
# (e.g. "getValue"), so 4-gram BLEU is expected to be near 0; interpret with care.
bleu_result = bleu.compute(predictions=predictions, references=[[r] for r in references])
print(f"\nBLEU score: {bleu_result['bleu']:.4f}")

# Accuracy (exact match) over the samples actually evaluated
exact_match_accuracy = exact_matches / max(len(predictions), 1)
print(f"\nExact Match Accuracy: {exact_match_accuracy:.4f}")


# Perplexity
def calculate_perplexity(model, tokenizer, examples):
    """Return exp(mean per-example loss) of `model` over `examples`.

    Args:
        model: Seq2seq model whose forward pass returns `.loss` when given labels.
        tokenizer: Supplies pad_token_id; pad positions in the labels are set
            to -100 so they do not count towards the loss.
        examples: Iterable of dicts with 'input_ids', 'attention_mask' and 'labels'.
    """
    model.eval()
    device = model.device
    losses = []
    for example in tqdm(examples, desc="Perplexity"):
        input_ids = torch.tensor(example["input_ids"]).unsqueeze(0).to(device)
        attention_mask = torch.tensor(example["attention_mask"]).unsqueeze(0).to(device)
        labels = torch.tensor(example["labels"]).unsqueeze(0).to(device)
        labels[labels == tokenizer.pad_token_id] = -100

        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        losses.append(output.loss.item())

    avg_loss = np.mean(losses)
    return np.exp(avg_loss)


perplexity = calculate_perplexity(model, tokenizer, test_subset)
print(f"\nPerplexity: {perplexity:.2f}")


### Test Output: Sample Predictions

In [None]:
if test_run:
    # Hand-written samples whose method names are already masked with <extra_id_0>
    test_input = '''def <extra_id_0>(x, y):
        return (x ** 2 + y ** 2) ** 0.5
    '''
    test_input2 = '''public static int <extra_id_0>(int n) {
        if (n == 0) {
            return 1;
        }
        return n * <extra_id_0>(n - 1);
    }
    '''
    test_input3 = '''def <extra_id_0>(data, window_size=3):
        if len(data) < window_size:
            raise ValueError("Data length must be at least equal to the window size.")
        
        moving_averages = []
        for i in range(len(data) - window_size + 1):
            window = data[i : i + window_size]
            window_average = sum(window) / window_size
            moving_averages.append(window_average)
        
        return moving_averages
    '''
    test_input4 = '''public static int <extra_id_0>(int[] numbers) {
        int max = Integer.MIN_VALUE;
        for (int num : numbers) {
            if (num > max) {
                max = num;
            }
        }
        return max;
    }
    '''

    def _with_ast(masked_code, lang):
        """Append the same "<AST> ... </AST>" section that `preprocess` adds during training."""
        # The sentinel is not valid syntax, so parse a copy with a plain
        # identifier in its place. SBT keeps node types only, so the
        # placeholder's text never reaches the output.
        parseable = masked_code.replace("<extra_id_0>", "masked_name")
        ast_features = sbt_traverse(parse_code_to_ast(parseable, lang).root_node)
        return masked_code + " " + "<AST> " + " ".join(ast_features) + " </AST>"

    model.eval()  # disable dropout so predictions are deterministic

    for lang, sample_code in (
        ("python", test_input),
        ("java", test_input2),
        ("python", test_input3),
        ("java", test_input4),
    ):
        # Tokenize with the same truncation as training, on the model's device
        inputs = tokenizer(
            _with_ast(sample_code, lang), max_length=1024, truncation=True, return_tensors="pt"
        ).to(model.device)

        # Generate
        generated_ids = model.generate(**inputs, max_length=16)

        # Decode
        output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        print(f"--- {lang} sample ---")
        print("Predicted method name:", output_text)

In [None]:
# Archive the Trainer output directory so checkpoints can be downloaded from
# /kaggle/working.
# NOTE(review): these paths duplicate config["output_dir"] ("./debug_results" /
# "./training_results", presumably resolved against /kaggle/working as the
# working directory). Keep them in sync if output_dir changes.
if debug:
    !zip -r /kaggle/working/training_checkpoints.zip /kaggle/working/debug_results
else:
    !zip -r /kaggle/working/training_checkpoints.zip /kaggle/working/training_results