# T5 Fine-tuned model inference

In [None]:
# pip install -qqq --upgrade -r requirements.txt

In [None]:
# General libraries
import os
import random
import re
import time

# Data handling libraries
import pandas as pd
from datasets import load_dataset, DatasetDict

# Transformers libraries
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Import torch
import torch

# NLP and evaluation libraries
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.meteor_score import single_meteor_score
from bert_score import BERTScorer
import spacy
nlp = spacy.load("en_core_web_sm")
import nltk
nltk.download('wordnet')

# Logging for the pipeline
import logging

# Language libraries
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
nltk.download('punkt_tab')

# Load the dataset

In [None]:
def hf_load_dataset(dataset_id, split="train"):
    """Fetch a single split of a dataset through ``datasets.load_dataset``.

    Args:
        dataset_id: Dataset identifier forwarded to ``load_dataset``.
        split: Name of the split to request; defaults to ``"train"``.

    Returns:
        The object ``load_dataset`` returns for the requested split.
    """
    return load_dataset(dataset_id, split=split)

In [None]:
def split_dataset(dataset, test_size=0.1, seed=42):
    """Carve a held-out test portion out of *dataset* and return it.

    Args:
        dataset: Object exposing ``train_test_split(test_size=..., seed=...)``.
        test_size: Size of the test portion, passed through unchanged.
        seed: Seed passed through to ``train_test_split``.

    Returns:
        The ``"test"`` entry of the split; the ``"train"`` entry is discarded.
    """
    parts = dataset.train_test_split(seed=seed, test_size=test_size)
    held_out = parts["test"]
    return held_out

In [None]:
def preprocess_and_tokenize(
    dataset,
    tokenizer,
    input_col,
    input_max_length
):
    """Prefix every document with the summarization prompt and tokenize it.

    Args:
        dataset: Object exposing ``map`` and ``column_names``.
        tokenizer: Callable applied to each batch of prompted texts.
        input_col: Name of the column holding the documents.
        input_max_length: Length every input is truncated/padded to.

    Returns:
        The result of ``dataset.map``; all original columns are removed, so
        only the tokenizer outputs remain.
    """
    def _tokenize_batch(batch):
        # Missing documents (None / empty) become the bare prompt.
        prompted = []
        for doc in batch[input_col]:
            prompted.append("summarize: " + (doc or ""))

        return tokenizer(
            prompted,
            truncation=True,
            padding="max_length",
            max_length=input_max_length,
        )

    original_columns = dataset.column_names
    return dataset.map(
        _tokenize_batch,
        remove_columns=original_columns,
        batched=True,
    )

# Load the tokenizer and model

In [None]:
def load_model(model_name):
    """Load a pretrained sequence-to-sequence model for inference.

    Gradient checkpointing is intentionally not enabled: it is a memory
    optimization for backpropagation, and this notebook only runs
    generation. Leaving it out also lets the loader work with
    architectures that do not support checkpointing.

    Args:
        model_name: Hub identifier or local path passed to ``from_pretrained``.

    Returns:
        The model returned by ``AutoModelForSeq2SeqLM.from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:
def load_tokenizer(model_name):
    """Load the tokenizer that belongs to *model_name*.

    Args:
        model_name: Hub identifier or local path passed to ``from_pretrained``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    return AutoTokenizer.from_pretrained(model_name)

# Inference

In [None]:
def summarize_text(
    prompt,
    tokenizer,
    model,
    device="cpu",
    max_new_tokens=1024,
    temperature=0.7,
    top_p=0.9,
    max_input_length=8192,
    min_length=50,
):
    """Generate a sampled summary for a single prompt.

    Args:
        prompt: Full input text, including any task prefix.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model exposing ``generate``; it must already be on *device*,
            because only the encoded inputs are moved here.
        device: Device the encoded inputs are moved to.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        max_input_length: Token length the prompt is truncated to
            (previously hard-coded to 8192).
        min_length: Minimum length passed to ``generate``
            (previously hard-coded to 50).

    Returns:
        The first generated sequence decoded to text with special tokens
        stripped.
    """
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    ).to(device)

    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        no_repeat_ngram_size=3,
        repetition_penalty=1.5,
        min_length=min_length,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Only one prompt is encoded, so the first sequence is the summary.
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Evaluate the model

In [None]:
def rouge_score(response, summary):
    """Compute ROUGE-1, ROUGE-2 and ROUGE-L F-measures for a text pair.

    Args:
        response: Generated text.
        summary: Reference text.

    Returns:
        A 3-tuple ``(rouge1, rouge2, rougeL)`` of F-measure values. The texts
        are handed to the scorer in ``(response, summary)`` order and stemming
        is enabled.
    """
    metric_names = ['rouge1', 'rouge2', 'rougeL']
    scorer = rouge_scorer.RougeScorer(metric_names, use_stemmer=True)
    result = scorer.score(response, summary)

    rouge1, rouge2, rouge_l = (result[name].fmeasure for name in metric_names)
    return rouge1, rouge2, rouge_l


In [None]:
def bleu_score(response, summary):
    """Compute BLEU-4 of a generated summary against one reference summary.

    Both texts are tokenized with the module-level spaCy pipeline ``nlp`` so
    that reference and hypothesis tokens are directly comparable (mixing
    ``str.split`` with spaCy tokens leaves punctuation attached on one side
    only). The whole reference summary is passed as a single reference:
    splitting it into sentences and passing them as alternative references
    lets the brevity penalty compare against one sentence instead of the
    full summary.

    Args:
        response: Generated text (the hypothesis).
        summary: Reference text.

    Returns:
        The ``sentence_bleu`` score for the pair.
    """
    reference_tokens = [token.text for token in nlp(summary)]
    hypothesis_tokens = [token.text for token in nlp(response)]

    # Uniform 1- to 4-gram weights; BLEU weights are expected to sum to 1.
    weights = (0.25, 0.25, 0.25, 0.25)

    return sentence_bleu([reference_tokens], hypothesis_tokens, weights=weights)

In [None]:
def meteor_score(response, summary):
    """Compute the METEOR score of a generated summary against its reference.

    ``single_meteor_score`` takes the reference first and the hypothesis
    second. The order matters because METEOR does not weight precision and
    recall equally, so swapping the two changes the score.

    Args:
        response: Generated text (the hypothesis).
        summary: Reference text.

    Returns:
        The METEOR score rounded to four decimal places. Both texts are
        tokenized on whitespace.
    """
    reference_tokens = summary.split()
    hypothesis_tokens = response.split()

    return round(single_meteor_score(reference_tokens, hypothesis_tokens), 4)

In [None]:
# BERTScorer instances keyed by model type. Building a scorer is far more
# expensive than scoring one pair, so each one is created once and reused.
_BERT_SCORERS = {}


def _get_bert_scorer(model_type='bert-base-uncased'):
    """Return the shared BERTScorer for *model_type*, creating it on first use."""
    scorer = _BERT_SCORERS.get(model_type)
    if scorer is None:
        scorer = BERTScorer(model_type=model_type)
        _BERT_SCORERS[model_type] = scorer
    return scorer


def bert_score(response, summary):
    """Compute the BERTScore F1 of a generated summary against its reference.

    Args:
        response: Generated text (the candidate).
        summary: Reference text.

    Returns:
        The F1 component of the score as a Python float. Precision and recall
        are computed by the scorer but discarded.
    """
    _, _, f1 = _get_bert_scorer().score([response], [summary])
    return f1.item()


In [None]:
def compute_metrics(response, summary):
    """Run every evaluation metric on one (response, summary) pair.

    Args:
        response: Generated text.
        summary: Reference text.

    Returns:
        A dict with the keys ``rouge1``, ``rouge2``, ``rougel``, ``bleu``,
        ``meteor`` and ``bert``, in that order.
    """
    rouge1, rouge2, rougel = rouge_score(response, summary)
    results = {"rouge1": rouge1, "rouge2": rouge2, "rougel": rougel}

    results["bleu"] = bleu_score(response, summary)
    results["meteor"] = meteor_score(response, summary)
    results["bert"] = bert_score(response, summary)
    return results

In [None]:
def save_metrics_to_csv(metrics_list, output_folder, dataset_name, model_name):
    """Write per-sample metric rows to a CSV file and return its path.

    Args:
        metrics_list: Rows of seven values each, ordered as ROUGE-1, ROUGE-2,
            ROUGE-L, BLEU, METEOR, BERT, Time.
        output_folder: Destination folder; created if it does not exist. An
            empty string writes relative to the current working directory.
        dataset_name: Dataset label used in the file name.
        model_name: Model label used in the file name.

    Returns:
        Path of the written file,
        ``<output_folder>/results_<dataset_name>_<model_name>.csv``.
    """
    columns = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "BLEU", "METEOR", "BERT", "Time"]
    df = pd.DataFrame(metrics_list, columns=columns)

    # os.makedirs("") raises, so only create a folder when one is named.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    results_file = os.path.join(output_folder, f"results_{dataset_name}_{model_name}.csv")
    df.to_csv(results_file, index=False)
    return results_file

In [None]:
def pipeline(
    dataset_id,
    model_name,
    output_folder,
    input_col,
    target_col,
    num_samples,
    device
):
    """Evaluate a summarization model on a random sample of the test split.

    Loads the dataset and model, generates a summary for each sampled row,
    scores it against the reference, and writes one CSV row per sample
    (six metrics plus generation time in seconds).

    Args:
        dataset_id: Dataset identifier passed to ``hf_load_dataset``.
        model_name: Model identifier passed to ``load_model`` and
            ``load_tokenizer``.
        output_folder: Folder for the results CSV; created if missing.
        input_col: Column holding the document to summarize.
        target_col: Column holding the reference summary.
        num_samples: Maximum number of test rows to evaluate.
        device: Device the model and inputs are placed on.
    """
    os.makedirs(output_folder, exist_ok=True)

    dataset = hf_load_dataset(dataset_id)
    dataset_split = split_dataset(dataset)

    model = load_model(model_name).to(device)
    tokenizer = load_tokenizer(model_name)

    # Sample row positions with a private, fixed-seed generator: the global
    # ``random`` state is left untouched, and only the chosen rows are read
    # instead of materializing the whole split as a list.
    rng = random.Random(42)
    total_rows = len(dataset_split)
    chosen_indices = rng.sample(range(total_rows), min(num_samples, total_rows))
    test_samples = [dataset_split[i] for i in chosen_indices]

    metrics_list = []

    for idx, sample in enumerate(test_samples):
        if idx % 10 == 0:
            print(f"Processing sample {idx}/{len(test_samples)}")

        text = sample[input_col]
        reference_summary = sample[target_col]
        text_with_prompt = "summarize: " + (text or "")

        # Time only the generation call, not the metric computation.
        start_time = time.time()
        response = summarize_text(text_with_prompt, tokenizer, model, device=device)
        elapsed_time = time.time() - start_time

        metrics = compute_metrics(response, reference_summary)
        metrics_list.append([
            metrics["rouge1"],
            metrics["rouge2"],
            metrics["rougel"],
            metrics["bleu"],
            metrics["meteor"],
            metrics["bert"],
            elapsed_time
        ])

    # Keep only the last path component so the names are safe in a file name.
    model_checkpoint = model_name.split('/')[-1]
    dataset_name = dataset_id.split('/')[-1]
    results_file = save_metrics_to_csv(metrics_list, output_folder, dataset_name, model_checkpoint)
    print(f"Results saved to {results_file}")

In [None]:
# Identifiers of the evaluation dataset and the fine-tuned model.
dataset_id = "xkristian/LegalDocumentSummarization"
model_name = "xkristian/long5-LegalDocumentSummarization"
# Folder that receives the per-sample metrics CSV.
output_folder = "./inference_results"
# Upper bound on the number of test rows the pipeline evaluates.
num_samples = 50

# Run the full evaluation, on the GPU when one is available.
pipeline(
    dataset_id=dataset_id,
    model_name=model_name,
    output_folder=output_folder,
    input_col="judgement",
    target_col="summary",
    num_samples=num_samples,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

In [None]:
# Qualitative check: summarize the first test row and print it next to its
# reference summary and metrics.
dataset = hf_load_dataset(dataset_id)
dataset_split = split_dataset(dataset)

# Choose the device once and place the model on it. summarize_text moves the
# tokenized inputs to this device, so a model left on the CPU would fail with
# a device mismatch whenever CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model(model_name).to(device)
tokenizer = load_tokenizer(model_name)

sample = dataset_split[0]

# Fall back to an empty string so the len()/slice below cannot hit None.
judgement = sample["judgement"] or ""
reference_summary = sample["summary"]

text_with_prompt = "summarize: " + judgement
model_summary = summarize_text(text_with_prompt, tokenizer, model, device=device)

metrics = compute_metrics(model_summary, reference_summary)

# Print results
print("=" * 80)
print("JUDGEMENT (Input):")
print("=" * 80)
# Show at most the first 500 characters of the input document.
print(judgement[:500] + "..." if len(judgement) > 500 else judgement)
print("\n" + "=" * 80)
print("REFERENCE SUMMARY:")
print("=" * 80)
print(reference_summary)
print("\n" + "=" * 80)
print("MODEL SUMMARY:")
print("=" * 80)
print(model_summary)
print("\n" + "=" * 80)
print("METRICS:")
print("=" * 80)
for metric, value in metrics.items():
    print(f"{metric.upper():12} : {value:.4f}")
print("=" * 80)