In [None]:
# Cell 1: Imports and setup
import os
import random
import re

import torch
from datasets import DatasetDict, Dataset
from google.colab import drive
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments


In [None]:
# Cell 2: Mount Google Drive
# Every data and checkpoint path used later in this notebook lives under this mount point.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
# Cell 3: Function to load data from folder
def load_data_from_folder(folder_path, buggy_name="buggy.txt", fixed_name="fixed.txt"):
    """Load a line-aligned buggy/fixed corpus from ``folder_path`` as a ``Dataset``.

    Line ``i`` of the buggy file is paired with line ``i`` of the fixed file.
    Lines are kept exactly as ``readlines()`` returns them (trailing newline
    included); whitespace is stripped in a later step.

    Args:
        folder_path: Directory containing the two text files.
        buggy_name: File name of the buggy-code file (default ``"buggy.txt"``).
        fixed_name: File name of the fixed-code file (default ``"fixed.txt"``).

    Returns:
        A ``Dataset`` with columns ``input_text`` (buggy lines) and
        ``target_text`` (fixed lines).

    Raises:
        FileNotFoundError: if either file does not exist.
        ValueError: if the two files contain a different number of lines.
    """
    buggy_path = os.path.join(folder_path, buggy_name)
    fixed_path = os.path.join(folder_path, fixed_name)

    with open(buggy_path, "r", encoding="utf-8") as f:
        buggy_lines = f.readlines()

    with open(fixed_path, "r", encoding="utf-8") as f:
        fixed_lines = f.readlines()

    # Explicit check rather than `assert`: asserts are removed under `python -O`,
    # and the counts/paths in the message make the offending file easy to find.
    if len(buggy_lines) != len(fixed_lines):
        raise ValueError(
            f"Mismatch in buggy and fixed lines: {buggy_path} has {len(buggy_lines)} lines, "
            f"{fixed_path} has {len(fixed_lines)} lines"
        )

    data = {"input_text": buggy_lines, "target_text": fixed_lines}
    return Dataset.from_dict(data)

In [None]:
# Cell 4: Load datasets
# Use the same "MyDrive" root (no space) as the checkpoint directory in Cell 9.
# Current Colab mounts the user's My Drive at /content/drive/MyDrive; the
# older "My Drive" spelling may not exist on a fresh runtime.
drive_root = "/content/drive/MyDrive"
dataset_root = os.path.join(drive_root, "CodeFix_DataSet")

train_folder = os.path.join(dataset_root, "train")
eval_folder = os.path.join(dataset_root, "eval")
test_folder = os.path.join(dataset_root, "test")

train_dataset = load_data_from_folder(train_folder)
eval_dataset = load_data_from_folder(eval_folder)
test_dataset = load_data_from_folder(test_folder)

# The on-disk "eval" folder is exposed as the "validation" split.
datasets = DatasetDict({
    "train": train_dataset,
    "validation": eval_dataset,
    "test": test_dataset,
})

In [None]:
# Cell 5: Clean data - strip whitespace from strings
def strip_strings(example):
    """Trim surrounding whitespace from both text fields of one example.

    This also removes the trailing newline that ``readlines()`` leaves on each
    line. The ``example`` mapping is modified in place and returned, as
    ``Dataset.map`` expects. Other keys are left untouched.
    """
    for column in ("input_text", "target_text"):
        example[column] = example[column].strip()
    return example

# Re-assign each split with its whitespace-stripped copy (Dataset.map returns a new Dataset).
for split_name in ("train", "validation", "test"):
    datasets[split_name] = datasets[split_name].map(strip_strings)

In [None]:
# Cell 6: Tokenizer and model init
# Tokenizer and model are loaded from the same checkpoint id so their vocabularies match.
# NOTE(review): t5-base's SentencePiece vocabulary reportedly has no pieces for some
# characters common in source code (e.g. curly braces), which then become <unk>;
# a code-pretrained checkpoint may fit this task better -- verify on your data.
model_name = "t5-base"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

In [None]:
# Cell 7: Preprocessing before tokenizing
max_input_length = 128
max_target_length = 128


def preprocess_function(examples):
    """Tokenize a batch of buggy/fixed pairs for seq2seq fine-tuning.

    Args:
        examples: Batched mapping, as passed by ``Dataset.map(batched=True)``,
            with list-valued ``"input_text"`` and ``"target_text"`` columns.

    Returns:
        The tokenizer's encoding of the prefixed inputs plus a ``labels``
        column. Inputs are truncated/padded to ``max_input_length`` and
        labels to ``max_target_length``. Label padding positions are set to
        -100 so the loss ignores them.
    """
    # T5 is steered with a task prefix in front of the buggy code.
    inputs = [f"fix code: {code.strip()}" for code in examples["input_text"]]
    targets = [code.strip() for code in examples["target_text"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding="max_length")

    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True, padding="max_length")

    # padding="max_length" fills labels with pad_token_id. Left as-is, the model
    # is trained and evaluated on predicting padding. Label positions equal to
    # -100 are ignored by T5ForConditionalGeneration's loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

In [None]:
# Cell 8: Tokenize datasets
# Run preprocess_function batch-wise over every split of the DatasetDict. The raw
# text columns are not removed here; they travel alongside the token columns.
tokenized_datasets = datasets.map(function=preprocess_function, batched=True)


In [None]:
# Cell 9: Checkpoint directory on Drive
# Trainer checkpoints, and in Cell 14 the final model, are written here so they
# persist on Google Drive across Colab runtime resets. Creating it is a no-op if it exists.
checkpoint_dir = "/content/drive/MyDrive/code_fix_model_checkpoint"
os.makedirs(name=checkpoint_dir, exist_ok=True)

In [None]:
# Cell 10: Setup training args
# NOTE(review): `eval_strategy` is the newer spelling of `evaluation_strategy` in
# transformers; confirm the installed version accepts it.
training_args = TrainingArguments(
    # Checkpoints go to the Drive directory created in Cell 9.
    output_dir=checkpoint_dir,
    # Optimisation
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    max_grad_norm=1.0,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluate and checkpoint on the same 500-step cadence; cap retained checkpoints at 3.
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    # Logging; report_to="none" disables external experiment trackers.
    logging_steps=100,
    logging_dir="./logs",
    report_to="none",
)

In [None]:
# Cell 11: Initialize Trainer
train_split = tokenized_datasets["train"]
validation_split = tokenized_datasets["validation"]

# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; kept as-is for older versions -- check the installed one.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_split,
    eval_dataset=validation_split,
)

In [None]:
# Cell 12: Find latest checkpoint to resume if any
def find_latest_checkpoint(path):
    """Return the checkpoint directory with the highest step number under ``path``.

    Trainer saves checkpoints as ``<output_dir>/checkpoint-<global_step>``.
    The step is parsed from the directory name instead of trusting modification
    times, which can be misleading on synced storage such as Google Drive or
    after files are copied. Entries that are not directories, or whose names
    are not exactly ``checkpoint-<digits>``, are ignored.

    Args:
        path: Output directory to scan.

    Returns:
        Full path of the latest checkpoint directory, or ``None`` if ``path``
        is not an existing directory or contains no checkpoints.
    """
    if not os.path.isdir(path):
        return None

    checkpoint_pattern = re.compile(r"^checkpoint-([0-9]+)$")
    latest_step, latest_checkpoint = -1, None
    for name in os.listdir(path):
        match = checkpoint_pattern.match(name)
        full_path = os.path.join(path, name)
        if match is None or not os.path.isdir(full_path):
            continue
        # Compare numerically so checkpoint-10000 beats checkpoint-9500.
        step = int(match.group(1))
        if step > latest_step:
            latest_step, latest_checkpoint = step, full_path
    return latest_checkpoint

In [None]:
# Cell 13: Train or resume
# Continue from the most recent checkpoint in checkpoint_dir when one exists;
# passing None starts training from scratch.
last_checkpoint = find_latest_checkpoint(checkpoint_dir)
if last_checkpoint:
    print(f"Resuming from checkpoint: {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint or None)

In [None]:
# Cell 14: Save model and tokenizer to Drive
# The final weights/config are written to the top level of checkpoint_dir,
# next to any checkpoint-* subfolders. The explicit tokenizer save is presumably
# redundant because the Trainer was given the tokenizer, but it is harmless.
final_model_dir = checkpoint_dir
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

In [None]:
# Cell 15: Extra: Quick dataset samples check (you can run anytime)
print(f"Train size: {len(datasets['train'])}, Eval size: {len(datasets['validation'])}, Test size: {len(datasets['test'])}")

# Show up to 3 distinct training examples. random.sample draws without
# replacement, and unlike randint(0, n - 1) it does not raise on an empty split.
num_train = len(datasets["train"])
print("Random buggy code samples:")
for i in random.sample(range(num_train), k=min(3, num_train)):
    print(datasets["train"][i]["input_text"])
