# In [None]:
# Install necessary libraries
# !pip install transformers datasets accelerate sentencepiece torch evaluate rouge_score

import evaluate
import numpy as np
import torch
from datasets import Dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

# --- 1. Model Selection & Data Preparation ---
# Choose a model. 't5-small' is a small language model (SLM); 't5-base' is a larger option.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Create a small, illustrative dataset for fine-tuning: each "incorrect"
# sentence is paired with the "correct" sentence at the same index.
# In a real application, you'd load a much larger, diverse dataset.
raw_data = {
    "incorrect": [
        "I has went to the store.",
        "They is planning a party.",
        "She go to the gym every day.",
        "The boy are playing in the park.",
        "He dont like to eat apples.",
        "I didn't liked the movie.",
        "The cat sleep on the chair.",
    ],
    "correct": [
        "I have gone to the store.",
        "They are planning a party.",
        "She goes to the gym every day.",
        "The boys are playing in the park.",
        "He doesn't like to eat apples.",
        "I didn't like the movie.",
        "The cat sleeps on the chair.",
    ]
}

dataset = Dataset.from_dict(raw_data)
# Split the dataset into train and validation sets. The split is shuffled, so a
# fixed seed is needed for reproducible runs: with only 7 examples, a different
# shuffle changes which sentences are held out and therefore the reported metrics.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset['train']
val_dataset = dataset['test']

# Task prefix prepended to every input sentence during preprocessing.
prefix = "grammar correction: "

def preprocess_function(examples):
    """Tokenize a batch of ``incorrect``/``correct`` sentence pairs for T5.

    Each ``incorrect`` sentence is prefixed with the module-level ``prefix`` and
    becomes the encoder input. The matching ``correct`` sentence becomes the
    ``labels``. Both sides are truncated and padded to 128 tokens.

    Args:
        examples: Batched mapping (as supplied by ``Dataset.map(batched=True)``)
            with ``"incorrect"`` and ``"correct"`` lists of strings.

    Returns:
        The tokenizer output for the inputs (``input_ids``, ``attention_mask``)
        plus a ``labels`` list of token-id lists in which padding positions are
        replaced by ``-100``.
    """
    inputs = [prefix + text for text in examples["incorrect"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    # ``text_target=`` replaces the deprecated ``as_target_tokenizer()`` context manager.
    labels = tokenizer(
        text_target=examples["correct"], max_length=128, truncation=True, padding="max_length"
    )
    # -100 is the ignore index of the seq2seq loss. Without this mask the padded
    # label positions are scored, and the model is trained to emit pad tokens.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

# Run the same batched tokenization over each split (train first, then validation).
tokenized_train_dataset, tokenized_val_dataset = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, val_dataset)
)

# --- 2. Training and Evaluation ---
# Metric scorers consumed by compute_metrics (ROUGE first, then BLEU).
rouge, bleu = (evaluate.load(metric_name) for metric_name in ("rouge", "bleu"))

def compute_metrics(eval_preds):
    """Decode model outputs and score them against the references with ROUGE and BLEU.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair passed in by the Trainer.
            ``predictions`` may be a tuple whose first element is the model
            output. It can be token ids (2-D, e.g. from generation) or
            vocabulary logits (3-D); logits are reduced to ids with ``argmax``.
            Either array may contain ``-100`` ignore/padding markers.

    Returns:
        A flat dict of metric name -> value rounded to 4 decimals. List-valued
        metrics (BLEU's per-n-gram ``precisions``) are expanded into
        ``precisions_1``, ``precisions_2``, ... so that every value is a scalar.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    # Without generation the Trainer returns logits; keep the most likely token per position.
    if preds.ndim == 3:
        preds = preds.argmax(axis=-1)
    # -100 is an ignore/padding marker, not a vocabulary id, so it cannot be decoded.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = [
        text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)
    ]
    decoded_labels = [
        text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)
    ]

    # ROUGE metric
    rouge_results = rouge.compute(predictions=decoded_preds, references=decoded_labels)

    # BLEU metric: each prediction may have several references, so wrap each label in a list.
    bleu_results = bleu.compute(
        predictions=decoded_preds, references=[[ref] for ref in decoded_labels]
    )

    # Combine results. round() cannot be applied to a list (BLEU's ``precisions``),
    # so list-valued metrics are expanded into one scalar entry per element.
    result = {}
    for key, value in {**rouge_results, **bleu_results}.items():
        if isinstance(value, (list, tuple)):
            for index, item in enumerate(value, start=1):
                result[f"{key}_{index}"] = round(item, 4)
        else:
            result[key] = round(value, 4)
    return result

# Seq2SeqTrainer with predict_with_generate=True scores text produced by
# model.generate() rather than teacher-forced logits, which is what the
# ROUGE/BLEU metrics in compute_metrics are meant to evaluate.
# NOTE: ``eval_strategy`` and ``processing_class`` require a recent transformers
# release (>= 4.46); the older names ``evaluation_strategy`` and
# ``Trainer(tokenizer=...)`` are deprecated or removed.
train_batch_size = 2
num_epochs = 10
# Optimizer steps for the whole run on a single device (ceil division per epoch).
steps_per_epoch = (len(tokenized_train_dataset) + train_batch_size - 1) // train_batch_size
total_steps = steps_per_epoch * num_epochs

training_args = Seq2SeqTrainingArguments(
    output_dir="./t5_grammar_correction",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=2,
    # Warm up for ~10% of the run. A fixed 500 warmup steps is far longer than
    # the whole run on this tiny dataset, so the learning rate never got close
    # to its configured peak.
    warmup_steps=max(1, total_steps // 10),
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    # Keep only the most recent checkpoints instead of one per epoch.
    save_total_limit=2,
    eval_strategy="epoch",
    predict_with_generate=True,
    generation_max_length=128,
    # T5 checkpoints are known to overflow (NaN loss) under fp16; use bf16 when
    # the GPU supports it, otherwise train in full fp32.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

# Save the final model and tokenizer side by side so the directory can be
# reloaded with from_pretrained().
model.save_pretrained("./fine_tuned_t5_model")
tokenizer.save_pretrained("./fine_tuned_t5_model")
