In [2]:
import ast
import os
import re

import torch
from datasets import load_dataset
from transformers import (
    RobertaTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

In [3]:
# Load the Python subset of CodeSearchNet.
# This Hub dataset is built by a loading script; `datasets` asks for explicit
# opt-in to execute it and warns that the opt-in will become mandatory in its
# next major release, so pass it explicitly.
dataset = load_dataset("code_search_net", "python", trust_remote_code=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Downloading builder script: 100%|██████████| 8.44k/8.44k [00:00<00:00, 8.42MB/s]
Downloading readme: 100%|██████████| 12.9k/12.9k [00:00<00:00, 12.9MB/s]
Downloading data: 100%|██████████| 941M/941M [00:31<00:00, 30.2MB/s] 
Generating train split: 100%|██████████| 412178/412178 [05:08<00:00, 1336.65 examples/s]
Generating test split: 100%|██████████| 22176/22176 [00:16<00:00, 1365.78 examples/s]
Generating validation split: 100%|██████████| 23107/23107 [00:17<00:00, 1303.57 examples/s]


In [4]:
# Inspect the loaded splits: their feature columns and row counts
dataset_summary = str(dataset)
print(dataset_summary)

DatasetDict({
    train: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 412178
    })
    test: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 22176
    })
    validation: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 23107
    })
})


In [None]:
# Preprocessing function to extract function headers and docstrings
# Matches the start of a function header: optional `async`, `def`, the name and "(".
_DEF_HEADER_START = re.compile(r"(?:async\s+)?def\s+\w+\s*\(")


def _extract_function_header(code_string):
    """Return the ``def name(...)`` header of the first function in ``code_string``.

    The header runs from ``def`` (or ``async def``) up to and including the
    parenthesis that closes the parameter list. That parenthesis is found by
    bracket matching, so nested parentheses in default values (``b=max(1, 2)``)
    and decorators placed before ``def`` do not truncate or hide the header.
    Characters inside simple quoted strings are skipped, so ``sep=')'`` does not
    end the scan early.

    Returns:
        The header text, or None if no complete header is found.
    """
    match = _DEF_HEADER_START.search(code_string)
    if match is None:
        return None
    depth = 0
    quote = None
    i = match.end() - 1  # index of the opening "(" of the parameter list
    while i < len(code_string):
        ch = code_string[i]
        if quote is not None:
            if ch == "\\":
                i += 2  # skip the escaped character inside a string literal
                continue
            if ch == quote:
                quote = None
        elif ch in "'\"":
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return code_string[match.start():i + 1]
        i += 1
    return None


def preprocess_function(sample):
    """Turn one CodeSearchNet row into a (docstring -> function header) pair.

    Args:
        sample: A dataset row; reads the 'func_documentation_string' and
            'func_code_string' keys.

    Returns:
        ``{"input": docstring, "output": header}`` where ``header`` is the
        ``def name(...)`` signature (ending with ")"). Returns
        ``{"input": "", "output": ""}`` when either field is missing/empty or no
        header can be extracted, so callers can filter those rows out.
    """
    docstring = sample.get('func_documentation_string')
    code_string = sample.get('func_code_string')

    if docstring and code_string:
        function_header = _extract_function_header(code_string)
        if function_header:
            return {
                "input": docstring,
                "output": function_header
            }
    return {
        "input": "",
        "output": ""
    }

In [27]:
# Apply preprocessing to the dataset.
# preprocess_function returns empty strings for rows it cannot turn into a
# (docstring, header) pair, so keep only rows where BOTH fields are non-empty.
train_pairs = dataset['train'].map(preprocess_function, batched=False)
print(len(train_pairs))
train_data = train_pairs.filter(lambda x: len(x["input"]) > 0 and len(x["output"]) > 0)  # Remove invalid samples
print(len(train_data))
assert len(train_data) > 0, "preprocessing kept no samples"

Map:   0%|          | 390/412178 [00:00<01:50, 3714.07 examples/s]

{'input': 'Trains a k-nearest neighbors classifier for face recognition.\n\n    :param train_dir: directory that contains a sub-directory for each known person, with its name.\n\n     (View in source code to see train_dir example tree structure)\n\n     Structure:\n        <train_dir>/\n        ├── <person1>/\n        │   ├── <somename1>.jpeg\n        │   ├── <somename2>.jpeg\n        │   ├── ...\n        ├── <person2>/\n        │   ├── <somename1>.jpeg\n        │   └── <somename2>.jpeg\n        └── ...\n\n    :param model_save_path: (optional) path to save model on disk\n    :param n_neighbors: (optional) number of neighbors to weigh in classification. Chosen automatically if not specified\n    :param knn_algo: (optional) underlying data structure to support knn.default is ball_tree\n    :param verbose: verbosity of training\n    :return: returns knn classifier that was trained on the given data.', 'output': "def train(train_dir, model_save_path=None, n_neighbors=None, knn_algo='ball_

Map:   0%|          | 680/412178 [00:00<02:12, 3109.30 examples/s]


KeyboardInterrupt: 

In [None]:
# Sanity-check the preprocessing by looking at the first kept sample
first_example = train_data[0]
print("Example from preprocessed dataset:", first_example)

Example from preprocessed dataset: 


In [None]:
# Load the CodeT5 tokenizer and model.
# CodeT5 uses a RoBERTa-style byte-level BPE tokenizer (vocab.json / merges.txt),
# not a SentencePiece model, so T5Tokenizer cannot load it; the model card
# loads it with RobertaTokenizer.
CHECKPOINT = "Salesforce/codet5-base"
tokenizer = RobertaTokenizer.from_pretrained(CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)

In [None]:
# Tokenization function for the dataset
def tokenize_function(example, max_length=128):
    """Tokenize a batch of (description, header) pairs for seq2seq training.

    Intended for ``Dataset.map(..., batched=True)``, so ``example["input"]`` and
    ``example["output"]`` are lists of strings.

    Args:
        example: Batch with "input" (docstrings) and "output" (headers).
        max_length: Length that both inputs and labels are truncated/padded to.

    Returns:
        Dict with "input_ids", "attention_mask" (from the inputs) and "labels"
        (target ids). Label positions holding the pad token are set to -100 so
        the cross-entropy loss ignores them; otherwise the model is also trained
        to predict padding.
    """
    inputs = tokenizer(example['input'], truncation=True, padding='max_length', max_length=max_length)
    targets = tokenizer(example['output'], truncation=True, padding='max_length', max_length=max_length)
    pad_id = tokenizer.pad_token_id
    labels = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in targets['input_ids']
    ]
    return {
        "input_ids": inputs['input_ids'],
        "attention_mask": inputs['attention_mask'],
        "labels": labels
    }

In [None]:
# Apply tokenization to the train data.
# The raw CodeSearchNet columns and the intermediate "input"/"output" strings
# are dropped, so the dataset only carries the model inputs and labels.
tokenized_train_data = train_data.map(
    tokenize_function,
    batched=True,
    remove_columns=train_data.column_names,
)

In [None]:
# Evaluation data: the validation split, preprocessed, filtered and tokenized
# the same way as the training data. `evaluation_strategy="steps"` below makes
# Trainer evaluate every `eval_steps`, which requires an eval_dataset.
EVAL_SAMPLES = 1_000  # cap so periodic evaluation stays cheap; None = whole split

eval_pairs_all = dataset["validation"].map(preprocess_function, batched=False)
eval_pairs = eval_pairs_all.filter(lambda x: len(x["input"]) > 0 and len(x["output"]) > 0)
if EVAL_SAMPLES is not None:
    eval_pairs = eval_pairs.shuffle(seed=42).select(range(min(EVAL_SAMPLES, len(eval_pairs))))
tokenized_eval_data = eval_pairs.map(
    tokenize_function, batched=True, remove_columns=eval_pairs.column_names
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./finetuned_codet5",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=100,
    save_steps=1000,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # switch if the installed version rejects this keyword.
    evaluation_strategy="steps",
    eval_steps=500
)

# Create a Trainer object for fine-tuning
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_eval_data,
    tokenizer=tokenizer
)

In [None]:
# Run fine-tuning, then write the trained weights and the tokenizer into the
# same output directory.
save_dir = "./finetuned_codet5"
trainer.train()
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)
print("Fine-tuning complete! Model saved at ./finetuned_codet5")

In [None]:
# Define a function to generate function headers with syntax restrictions
def generate_function_header(description, max_length=50, num_beams=5):
    """Generate a Python function header for a natural-language description.

    Uses the notebook-level ``tokenizer`` and ``model``.

    Args:
        description: Natural-language description of the function.
        max_length: Maximum length of the generated sequence (typical headers
            are short).
        num_beams: Beam width for beam search.

    Returns:
        ``(header, valid)``. ``header`` is the decoded text, stripped, forced to
        end with ")" and followed by ":\\n". ``valid`` is True when the header
        parses as a ``def`` statement once a placeholder body is attached.
    """
    # Generation should run with dropout off; training leaves the model in train mode.
    model.eval()
    # Keep the inputs on the model's device (Trainer may have moved it to a GPU).
    input_ids = tokenizer.encode(
        description, return_tensors="pt", truncation=True, max_length=128
    ).to(model.device)

    # Training targets are "def name(...)" followed by the tokenizer's end-of-sequence
    # token, so the model's default EOS is the natural stop signal. An id taken from
    # tokenizer.encode(")")[0] is not reliably the ")" token seen mid-sequence:
    # special tokens may be prepended, and BPE can merge ")" with its neighbours.
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )

    generated_header = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Ensure that the generated output ends with ')'
    if not generated_header.endswith(")"):
        generated_header += ")"

    # Add ':' and a newline after the function header
    generated_header += ":\n"

    # A bare "def f(...):" is not a complete statement (ast.parse raises
    # "expected an indented block"), so validate it with a placeholder body.
    try:
        ast.parse(generated_header + "    pass\n")
        valid = True
    except SyntaxError:
        valid = False

    return generated_header, valid

In [None]:

# Example usage: ask the model for a header matching a short description
description = "Calculate the factorial of a number."
header, is_valid = generate_function_header(description)
for caption, value in (("Generated Function Header:", header), ("Is the header valid?:", is_valid)):
    print(caption, value)


In [None]:

# Run the example again with the `generate_function_header` defined in the cell
# above. It is intentionally not re-defined here: a second definition would
# silently shadow any edit made to that cell.
description = "Calculate the factorial of a number."
header, is_valid = generate_function_header(description)
print("Generated Function Header:", header)
print("Is the header valid?:", is_valid)
