In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
import pandas as pd
from datasets import Dataset

# --- Data -------------------------------------------------------------------
# Read the sentence-pair CSV into a DataFrame and wrap it as a Hugging Face
# Dataset so it can be tokenized with `Dataset.map`.
df = pd.read_csv(filepath_or_buffer="../data/sinhala_correction_dataset.csv")
dataset = Dataset.from_pandas(df=df)

# --- Checkpoint -------------------------------------------------------------
# Tokenizer and model are restored from the same pretrained name.
# NOTE(review): the checkpoint is an mT5 one while the classes are the T5
# variants; transformers also ships MT5-specific classes -- confirm which is
# intended.
model_name = "google/mt5-small"
tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
model = T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=model_name)

def preprocess_function(examples):
    """Tokenize a batch of sentence pairs for seq2seq fine-tuning.

    Args:
        examples: Batch mapping with an ``"incorrect"`` column (source
            sentences) and a ``"correct"`` column (target sentences).

    Returns:
        The tokenizer output for the prefixed source sentences, with an
        added ``"labels"`` entry holding the target token ids. Padding
        positions in the labels are set to -100.
    """
    # Task prefix prepended to every source sentence.
    inputs = [f"grammar correction: {sentence}" for sentence in examples["incorrect"]]
    targets = examples["correct"]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")

    # Targets are padded to a fixed length, so most label positions are pad
    # tokens. Replace them with -100 (the index ignored by the model's
    # cross-entropy loss); otherwise the loss is dominated by predicting
    # padding instead of the corrected sentence.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

# Tokenize dataset
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Hold out part of the data for evaluation. `evaluation_strategy="epoch"`
# below makes the Trainer evaluate after every epoch, which fails unless an
# `eval_dataset` is supplied. The seed keeps the split reproducible.
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)

# Training setup
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` and the Trainer's `tokenizer=` to `processing_class=` --
# confirm against the installed version.
training_args = TrainingArguments(
    output_dir="../models/fine_tuned_mT5",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    save_total_limit=2,  # keep only the two most recent checkpoints
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    tokenizer=tokenizer,
)

trainer.train()

# Persist the fine-tuned weights and the tokenizer (kept in separate
# directories, as before).
model.save_pretrained("../models/fine_tuned_mT5")
tokenizer.save_pretrained("../models/tokenizer")
