# T5 Fine-tuning


## Import Libraries

In [None]:
# pip install -qqq --upgrade -r requirements.txt

In [None]:
# General libraries
import os
import shutil
import json

# Dataset libraries
from datasets import load_dataset as hf_load_dataset, Dataset, DatasetDict
import kagglehub
from kagglehub import KaggleDatasetAdapter

# Data manipulation
import pandas as pd
import numpy as np

# Transformers libraries
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback
)

# Visualization
import matplotlib.pyplot as plt

# PyTorch for deep learning
import torch

# Evaluation
from rouge_score import rouge_scorer, scoring

# Hyperparameter tuning
import optuna
from optuna.trial import TrialState

## Load dataset

In [None]:
def load_dataset(dataset_id):
    """Load the ``train`` split of a dataset through ``hf_load_dataset``.

    Args:
        dataset_id: Dataset identifier forwarded unchanged to
            ``hf_load_dataset``.

    Returns:
        Whatever ``hf_load_dataset`` returns for ``split="train"``.
    """
    return hf_load_dataset(dataset_id, split="train")


In [None]:
def split_dataset(dataset, test_size=0.1, validation_size=0.1, seed=42):
    """Split a dataset into train / validation / test partitions.

    The split is done in two steps: first ``test_size`` is carved off the
    full dataset, then ``validation_size`` is carved off the *remaining*
    training portion (so it is a fraction of the remainder, not of the
    whole dataset).

    Args:
        dataset: Object exposing ``train_test_split`` (e.g. a
            ``datasets.Dataset``).
        test_size: Size of the test partition, forwarded to the first
            ``train_test_split`` call.
        validation_size: Size of the validation partition, forwarded to the
            second ``train_test_split`` call on the remaining training data.
        seed: Seed forwarded to both ``train_test_split`` calls so the
            partitions are reproducible across runs. Pass ``None`` for a
            non-deterministic shuffle.

    Returns:
        A ``DatasetDict`` with the keys ``"train"``, ``"validation"`` and
        ``"test"``.
    """
    # The seed must be forwarded explicitly; otherwise every run shuffles
    # differently and results cannot be compared between runs.
    outer_split = dataset.train_test_split(test_size=test_size, seed=seed)
    inner_split = outer_split["train"].train_test_split(
        test_size=validation_size,
        seed=seed,
    )

    return DatasetDict({
        "train": inner_split["train"],
        "validation": inner_split["test"],
        "test": outer_split["test"],
    })

## Load Model and Tokenizer

In [None]:
def load_model(model_name):
    """Instantiate a pretrained sequence-to-sequence model.

    Args:
        model_name: Name or path forwarded to
            ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        The model object returned by ``from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:
def load_tokenizer(model_name):
    """Instantiate the tokenizer that belongs to a pretrained model.

    Args:
        model_name: Name or path forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer object returned by ``from_pretrained``.
    """
    return AutoTokenizer.from_pretrained(model_name)

## Dataset Preprocess

In [None]:
def preprocess_and_tokenize(
    dataset_split,
    tokenizer,
    input_col="judgements",
    target_col="summary",
    input_max_length=8192,
    target_max_length=1024
):
    """Tokenize source documents and target summaries for seq2seq training.

    Every source document is prefixed with ``"summarize: "`` before it is
    tokenized. The target column is tokenized through the tokenizer's
    ``text_target`` argument and its ``input_ids`` are stored under the
    ``"labels"`` key.

    Args:
        dataset_split: Object exposing ``map(fn, batched=True)``.
        tokenizer: Callable tokenizer accepting ``max_length``,
            ``truncation``, ``padding`` and ``text_target``.
        input_col: Name of the column holding the source documents.
        target_col: Name of the column holding the reference summaries.
        input_max_length: ``max_length`` used for the source documents.
        target_max_length: ``max_length`` used for the summaries.

    Returns:
        The result of ``dataset_split.map`` applied with the batch tokenizer.
    """
    task_prefix = "summarize: "

    def tokenize_batch(examples):
        # Build the prompted source texts for the whole batch.
        prompts = [task_prefix + document for document in examples[input_col]]

        encoded = tokenizer(
            prompts,
            max_length=input_max_length,
            truncation=True,
            padding=True
        )
        target_encoding = tokenizer(
            text_target=examples[target_col],
            max_length=target_max_length,
            truncation=True,
            padding=True
        )

        # Only the target token ids are kept; they become the labels.
        encoded["labels"] = target_encoding["input_ids"]
        return encoded

    tokenized_split = dataset_split.map(tokenize_batch, batched=True)
    return tokenized_split

## Function to calculate the ROUGE metrics

In [None]:
def compute_rouge(refs, preds):
    """Compute aggregated ROUGE-1, ROUGE-2 and ROUGE-L F-measures.

    Each prediction is scored against the reference at the same position,
    the per-pair scores are aggregated with a ``BootstrapAggregator`` and the
    ``mid`` F-measure of each metric is reported as a percentage.

    Args:
        refs: Iterable of reference summaries (strings).
        preds: Iterable of predicted summaries (strings), aligned with
            ``refs``.

    Returns:
        A dict with the keys ``"rouge1"``, ``"rouge2"`` and ``"rougeL"``,
        each scaled by 100. All values are ``0.0`` when there is nothing
        to score.

    Raises:
        ValueError: If ``refs`` and ``preds`` do not have the same length.
    """
    metric_names = ("rouge1", "rouge2", "rougeL")

    refs = list(refs)
    preds = list(preds)

    # zip() would silently drop the unmatched tail and report a score for
    # only part of the data, so a mismatch is reported instead.
    if len(refs) != len(preds):
        raise ValueError(
            f"refs and preds must have the same length, "
            f"got {len(refs)} and {len(preds)}"
        )

    # With no scores added the aggregate is empty and the metric lookups
    # below would fail, so an empty input yields zero scores.
    if not refs:
        return {name: 0.0 for name in metric_names}

    scorer = rouge_scorer.RougeScorer(list(metric_names), use_stemmer=True)
    aggregator = scoring.BootstrapAggregator()
    for ref, pred in zip(refs, preds):
        aggregator.add_scores(scorer.score(ref, pred))

    result = aggregator.aggregate()
    return {name: result[name].mid.fmeasure * 100 for name in metric_names}

## Function for evaluation

In [None]:
def _pad_rows(rows, pad_value):
    """Right-pad every row with ``pad_value`` up to the longest row's length."""
    rows = [list(row) for row in rows]
    width = max(len(row) for row in rows)
    return [row + [pad_value] * (width - len(row)) for row in rows]


def manual_evaluate(
    model,
    tokenizer,
    dataset,
    batch_size=8,
    max_length=8192,
    num_beams=4,
    repetition_penalty=2.0,
    no_repeat_ngram_size=3,
    length_penalty=1.0,
    early_stopping=True
):
    """Generate a decoded prediction for every row of a tokenized dataset.

    The dataset is processed in consecutive batches of ``batch_size`` rows.
    Each batch is padded to its own longest row before being turned into
    tensors, so rows of differing length no longer make tensor construction
    fail.

    Args:
        model: Model exposing ``eval()``, ``device`` and ``generate(...)``.
        tokenizer: Tokenizer exposing ``pad_token_id`` and ``batch_decode``.
        dataset: Object supporting ``len()`` and column access for
            ``"input_ids"`` and ``"attention_mask"``.
        batch_size: Number of rows per ``generate`` call.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        no_repeat_ngram_size: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        early_stopping: Forwarded to ``model.generate``.

    Returns:
        A list of decoded strings, one per dataset row, in dataset order.
    """
    model.eval()
    all_preds = []

    all_input_ids = dataset["input_ids"]
    all_attention_mask = dataset["attention_mask"]

    # Fall back to 0 when the tokenizer defines no pad token; padded
    # positions carry attention-mask 0 either way.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = 0

    for start_idx in range(0, len(dataset), batch_size):
        stop_idx = start_idx + batch_size

        # Rows tokenized in different map batches can have different padded
        # lengths; pad per batch so the tensors are always rectangular.
        batch_input_ids = _pad_rows(all_input_ids[start_idx:stop_idx], pad_token_id)
        batch_attention_mask = _pad_rows(all_attention_mask[start_idx:stop_idx], 0)

        input_ids = torch.tensor(batch_input_ids, dtype=torch.long, device=model.device)
        attention_mask = torch.tensor(batch_attention_mask, dtype=torch.long, device=model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                length_penalty=length_penalty,
                early_stopping=early_stopping
            )

        all_preds.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

    return all_preds


## Training the model

In [None]:
def train_epoch_only(
    model,
    tokenizer,
    data_collator,
    tokenized_datasets,
    train_batch_size,
    eval_batch_size,
    learning_rate,
    num_train_epochs,
    output_dir
):
    """Train a model on the ``"train"`` split without evaluation or checkpoints.

    Evaluation and checkpoint saving are both disabled
    (``eval_strategy="no"``, ``save_strategy="no"``) and no evaluation
    dataset is handed to the trainer, so ``eval_batch_size`` is only recorded
    in the training arguments.

    Label positions equal to the tokenizer's pad token id are replaced with
    ``-100`` in every collated batch so that padding does not contribute to
    the training loss.

    Args:
        model: Model to fine-tune.
        tokenizer: Tokenizer handed to the trainer; its ``pad_token_id``
            identifies padded label positions.
        data_collator: Callable turning a list of features into a batch
            (expected to contain a ``"labels"`` tensor).
        tokenized_datasets: Mapping with a ``"train"`` entry.
        train_batch_size: Per-device training batch size.
        eval_batch_size: Per-device evaluation batch size.
        learning_rate: Learning rate.
        num_train_epochs: Number of training epochs.
        output_dir: Output directory for the trainer.

    Returns:
        The ``Seq2SeqTrainer`` after ``train()`` has completed.
    """
    pad_token_id = tokenizer.pad_token_id

    def collate_ignoring_label_padding(features):
        # Labels that were padded at tokenization time carry the pad token
        # id; the loss only skips positions marked -100, so mask them here.
        batch = data_collator(features)
        labels = batch.get("labels")
        if labels is not None and pad_token_id is not None:
            labels = labels.clone()
            labels[labels == pad_token_id] = -100
            batch["labels"] = labels
        return batch

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        eval_strategy="no",
        save_strategy="no",
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=50,
        gradient_accumulation_steps=2,
        # Mixed precision is only requested when a CUDA device is present.
        fp16=torch.cuda.is_available(),
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=None,
        tokenizer=tokenizer,
        data_collator=collate_ignoring_label_padding,
    )
    trainer.train()
    return trainer

## Hyperparameter tuning

In [None]:
def run_optuna_search(tokenized_datasets, model_name, tokenizer, data_collator, n_trials=3):
    """Search training hyperparameters with Optuna, maximizing validation ROUGE-L.

    Each trial samples a learning rate, a training batch size and a number
    of epochs, fine-tunes a freshly loaded model with them, generates
    predictions for the ``"validation"`` split and reports ROUGE-L.

    Args:
        tokenized_datasets: Mapping with ``"train"`` and ``"validation"``
            entries; the validation entry must expose ``"labels"``,
            ``"input_ids"`` and ``"attention_mask"`` columns.
        model_name: Name or path used to load a fresh model per trial.
        tokenizer: Tokenizer used for training and for decoding.
        data_collator: Collator handed to the trainer.
        n_trials: Number of Optuna trials to run.

    Returns:
        The completed ``optuna`` study.
    """
    val_dataset = tokenized_datasets["validation"]
    pad_token_id = tokenizer.pad_token_id

    # The reference texts do not depend on the trial, so decode them once.
    # Negative ids (the -100 ignore index) are skipped together with padding
    # because they cannot be decoded.
    references = [
        tokenizer.decode(
            [tok_id for tok_id in label_ids if tok_id >= 0 and tok_id != pad_token_id],
            skip_special_tokens=True,
        )
        for label_ids in val_dataset["labels"]
    ]

    def objective(trial):
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
        train_batch_size = trial.suggest_categorical("train_batch_size", [4, 8, 16])
        num_train_epochs = trial.suggest_int("num_train_epochs", 3, 6)

        # Every trial starts from the pretrained weights.
        model = load_model(model_name)

        train_epoch_only(
            model=model,
            tokenizer=tokenizer,
            data_collator=data_collator,
            tokenized_datasets=tokenized_datasets,
            train_batch_size=train_batch_size,
            eval_batch_size=train_batch_size,
            learning_rate=learning_rate,
            num_train_epochs=num_train_epochs,
            output_dir=f"./results/optuna_trial_{trial.number}"
        )

        preds = manual_evaluate(
            model=model,
            tokenizer=tokenizer,
            dataset=val_dataset,
            batch_size=train_batch_size,
            max_length=128,
            num_beams=4,
            repetition_penalty=2.0,
            no_repeat_ngram_size=3,
            length_penalty=1.0,
            early_stopping=True
        )
        metrics = compute_rouge(references, preds)
        return metrics["rougeL"]

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    best_trial = study.best_trial
    print("Best ROUGE-L:", best_trial.value)
    print("Best hyperparams:", best_trial.params)
    return study

## Plotting the training progress

In [None]:
def plot_training_history(trainer, save_path):
    """Plot the logged training and evaluation loss and save the figure.

    Reads ``trainer.state.log_history``; entries with a ``'loss'`` key are
    drawn as the training curve and entries with an ``'eval_loss'`` key as
    the evaluation curve, both against their ``'step'`` value. A curve is
    only drawn when at least one matching entry exists.

    Args:
        trainer: Object exposing ``state.log_history`` (a list of dicts).
        save_path: Path handed to ``plt.savefig``.
    """
    history = trainer.state.log_history

    # Collect (step, value) pairs per curve.
    train_points = [(entry['step'], entry['loss']) for entry in history if 'loss' in entry]
    eval_points = [(entry['step'], entry['eval_loss']) for entry in history if 'eval_loss' in entry]

    plt.figure(figsize=(10, 6))
    curves = (
        (train_points, "Training Loss"),
        (eval_points, "Evaluation Loss"),
    )
    for points, curve_label in curves:
        if not points:
            continue
        steps, losses = zip(*points)
        plt.plot(list(steps), list(losses), label=curve_label)

    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training and Evaluation Loss Over Time")
    plt.legend()
    plt.grid(True)
    plt.savefig(save_path)
    plt.show()

## Pipeline

Calling the functions in a sequence to fine-tune the model

In [None]:
def pipeline(
    dataset_id,
    model_name,
    input_col="judgement",
    target_col="summary",
    push_to_hub=False,
    hub_repo_name=None
):
    """Run the full fine-tuning workflow for one dataset.

    Steps: load and split the dataset, load model and tokenizer, tokenize,
    run the Optuna hyperparameter search, retrain a fresh model with the
    best parameters, report ROUGE on the validation split, save the model
    and tokenizer locally, optionally push them to the Hugging Face Hub,
    and plot the training history.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``; its last
            ``/``-separated component is used in output names.
        model_name: Pretrained model name or path; its last ``-``-separated
            component is used in the local model directory name.
        input_col: Column holding the source documents.
        target_col: Column holding the reference summaries.
        push_to_hub: Whether to upload the model and tokenizer to the Hub.
        hub_repo_name: Hub repository name; defaults to the local model
            directory name when empty.
    """

    ds_raw = load_dataset(dataset_id)
    ds_split = split_dataset(ds_raw)

    print("Loading model and tokenizer")
    # This first model instance is only handed to the data collator; the
    # models that are actually trained are loaded fresh further down.
    model = load_model(model_name)
    tokenizer = load_tokenizer(model_name)
    data_collator = DataCollatorForSeq2Seq(tokenizer, model)

    print("Tokenizing dataset")
    tokenized = preprocess_and_tokenize(ds_split, tokenizer, input_col, target_col)

    print("Running hyperparameter tuning...")
    study = run_optuna_search(
        tokenized_datasets=tokenized,
        model_name=model_name,
        tokenizer=tokenizer,
        data_collator=data_collator,
        n_trials=3
    )
    best_params = study.best_trial.params
    print(f"Best params: {best_params}")

    # Train with best hyperparameters
    print("Step 6: Training with best hyperparameters...")
    dataset_name = dataset_id.split('/')[-1]
    model = load_model(model_name)
    trainer = train_epoch_only(
        model=model,
        tokenizer=tokenizer,
        data_collator=data_collator,
        tokenized_datasets=tokenized,
        train_batch_size=best_params["train_batch_size"],
        eval_batch_size=best_params["train_batch_size"],
        learning_rate=best_params["learning_rate"],
        num_train_epochs=best_params["num_train_epochs"],
        output_dir=f"./results/final_{dataset_name}"
    )

    # Evaluate
    print("Step 7: Evaluating...")
    val_ds = tokenized["validation"]
    pad_token_id = tokenizer.pad_token_id
    # Padding and negative ids (the -100 ignore index) are dropped before
    # decoding because they carry no reference text.
    references = [
        tokenizer.decode(
            [tok_id for tok_id in label_ids if tok_id >= 0 and tok_id != pad_token_id],
            skip_special_tokens=True,
        )
        for label_ids in val_ds["labels"]
    ]

    preds = manual_evaluate(
        model=model,
        tokenizer=tokenizer,
        dataset=val_ds,
        batch_size=best_params["train_batch_size"],
        max_length=128,
        num_beams=4,
        repetition_penalty=2.0,
        no_repeat_ngram_size=3,
        length_penalty=1.0,
        early_stopping=True
    )

    final_metrics = compute_rouge(references, preds)
    print(f"Final Validation ROUGE: {final_metrics}")

    # Save model locally
    print("Step 8: Saving model...")
    local_model_dir = f"t5_{model_name.split('-')[-1]}_fine_tuned_{dataset_name}"
    trainer.save_model(local_model_dir)
    tokenizer.save_pretrained(local_model_dir)

    # Push to HuggingFace Hub if requested
    if push_to_hub:
        print("Step 9: Pushing to HuggingFace Hub...")
        if not hub_repo_name:
            hub_repo_name = local_model_dir
        # Trainer.push_to_hub takes a commit message as its first positional
        # argument, so passing the repo name there would not select the
        # target repository. Push the model and tokenizer directly to the
        # named repository instead.
        model.push_to_hub(hub_repo_name)
        tokenizer.push_to_hub(hub_repo_name)
        print(f"Model pushed to HF Hub: {hub_repo_name}")

    # Plot training history
    print("Step 10: Plotting training history...")
    plot_training_history(trainer, save_path=f'./training_history_{dataset_name}.png')

    print("Pipeline completed successfully!")

## Main function

In [None]:
# Pretrained checkpoint that is fine-tuned for every dataset below.
model_name = "google/long-t5-tglobal-base"

# Datasets to process; the pipeline runs once per entry.
dataset_ids = ["xkristian/LegalDocumentSummarization"]

for dataset_id in dataset_ids:
    # Each run fine-tunes from the pretrained checkpoint and pushes the
    # result to the Hub repository named below.
    pipeline(
        dataset_id,
        model_name,
        input_col="judgement",
        target_col="summary",
        push_to_hub=True,
        hub_repo_name="LegalDocumentSummarization",
    )