## Training: fine-tune T5-small for text classification

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
import evaluate
import torch
import numpy as np
import pandas as pd

# Token-length limits applied by the tokenizer calls in preprocessing
max_input_len, max_target_len = 512, 128

# Pretrained checkpoint: the model and its tokenizer come from the same hub id
model_name = "google-t5/t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Turn on gradient checkpointing for the model before training starts
model.gradient_checkpointing_enable()

# Preprocess
def preprocess(example):
    """Tokenize one example into encoder inputs and decoder labels.

    Args:
        example: Mapping holding the raw text under "input" and the
            reference text under "target".

    Returns:
        The tokenizer encoding of the prefixed input, extended with a
        "labels" list in which every padding position is set to -100.
    """
    inputs = tokenizer(
        "clean: " + example["input"],
        truncation=True,
        padding="max_length",
        max_length=max_input_len,
    )
    targets = tokenizer(
        example["target"],
        truncation=True,
        padding="max_length",
        max_length=max_target_len,
    )
    # Targets are padded to a fixed length; without masking, the loss would be
    # computed on the padding as well. -100 is the label value the loss ignores.
    pad_id = tokenizer.pad_token_id
    inputs["labels"] = [
        token if token != pad_id else -100 for token in targets["input_ids"]
    ]
    return inputs

tokenized_ds = dataset.map(preprocess, batched=False)
# Fixed seed so that re-running the cell keeps the same train/validation split.
split = tokenized_ds.train_test_split(test_size=0.1, seed=42)
train_ds = split["train"]
val_ds = split["test"]

# Metric
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Decode predictions and references and score them with ROUGE.

    Args:
        eval_pred: Pair ``(predictions, labels)``. Predictions may be a tuple
            whose first item is the relevant array, and may be either raw
            3-D scores (reduced here with argmax) or 2-D token ids.

    Returns:
        The result of ``rouge.compute`` on the stripped, decoded texts.
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]
    if preds.ndim == 3:
        preds = np.argmax(preds, axis=-1)

    pad_id = tokenizer.pad_token_id
    ignored = labels == -100
    # Positions whose label is ignored carry no supervised prediction, so they
    # are blanked out rather than decoded into the scored text.
    if preds.shape == labels.shape:
        preds = np.where(ignored, pad_id, preds)

    # -100 is not a valid token id; swap it for padding before decoding.
    labels = np.where(ignored, pad_id, labels)
    preds = np.where(preds != -100, preds, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    return rouge.compute(predictions=decoded_preds, references=decoded_labels)

# Free cached GPU memory whenever the evaluation event fires
class CudaClearCallback(TrainerCallback):
    """Trainer callback that empties PyTorch's CUDA memory cache."""

    def on_evaluate(
        self,
        args,
        state,
        control,
        **kwargs,
    ):
        """Handle the evaluate event by calling ``torch.cuda.empty_cache``.

        None of the arguments are inspected.
        """
        torch.cuda.empty_cache()

# Training args
training_args = TrainingArguments(
    output_dir="./t5_clean_model",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # 1 sample x 8 steps per optimizer update
    num_train_epochs=10,
    warmup_ratio=0.2,
    learning_rate=3e-5,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=20,
    # NOTE(review): T5 checkpoints are often reported to overflow under fp16;
    # if the loss turns NaN, try bf16 or full precision - confirm on target GPU.
    fp16=True,
    report_to="none",
)


def _logits_to_token_ids(logits, labels):
    """Reduce per-batch model outputs to predicted token ids.

    Passed to the Trainer as ``preprocess_logits_for_metrics`` so that only
    the argmax ids are kept for metric computation, instead of the full
    per-token score tensors for every evaluation example.

    Args:
        logits: Score tensor, or a tuple whose first item is that tensor.
        labels: Label tensor for the batch (unused).

    Returns:
        Tensor of token ids obtained by argmax over the last dimension.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=_logits_to_token_ids,
    callbacks=[CudaClearCallback()],
)

# Train
trainer.train()

# Save the model and tokenizer side by side so both load from one directory
trainer.save_model("./t5_clean_final")
tokenizer.save_pretrained("./t5_clean_final")

## Inference after training: check text preprocessing

In [None]:
# Load fine-tuned model and tokenizer
model_dir = "./t5_clean_final"
tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir).to("cuda")
model.eval()

# Hyperparameters
# The input limit matches the 512 tokens used when tokenizing the training
# data; a smaller limit here would truncate inputs the model was trained to
# see in full.
max_input_len = 512
max_target_len = 128

# Load new CSV data (must contain a column "input")
df = pd.read_csv("new_data.csv")

# Function to generate predictions for a batch of texts
def generate_clean_text(batch_texts):
    """Generate model output for a batch of raw input strings.

    Args:
        batch_texts: Iterable of strings; each one is prefixed with
            "clean: " before tokenization.

    Returns:
        List of decoded, whitespace-stripped predictions, one per input.
        An empty batch yields an empty list.
    """
    texts = list(batch_texts)
    if not texts:
        # Nothing to tokenize or generate for an empty batch.
        return []

    # Follow whichever device the model is on instead of assuming "cuda".
    device = model.device

    # Tokenize input with prefix "clean: "
    inputs = tokenizer(
        ["clean: " + t for t in texts],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_len,
    ).to(device)

    # Generate predictions (greedy decoding)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_target_len,
            num_beams=1,      # use greedy decoding (faster, less memory)
            do_sample=False,
        )

    # Decode predictions into text
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [d.strip() for d in decoded]

# Run inference in batches
batch_size = 16

# Empty CSV cells are read as NaN (a float), which cannot be concatenated with
# the "clean: " prefix; treat them as empty strings and coerce the rest to str.
texts = df["input"].fillna("").astype(str).tolist()

results = []
for start in range(0, len(texts), batch_size):
    results.extend(generate_clean_text(texts[start:start + batch_size]))

# Save results to new CSV
df["prediction"] = results
df.to_csv("predictions.csv", index=False)