In [36]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Trainer,
    TrainingArguments,
    Pipeline,
)
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from rouge_score import rouge_scorer
import pandas as pd

In [5]:
# Test split stored as a single parquet shard; the evaluation cell below reads its
# "article" and "highlights" columns. (Appears to be CNN/DailyMail-style data — confirm.)
TEST_SPLIT_PATH = './data/test-00000-of-00001.parquet'
dataset = pq.read_table(TEST_SPLIT_PATH)

In [6]:
# Single source of truth for the checkpoint so the model and tokenizer can never diverge.
MODEL_ID = "PrekshaJoon/flan-t5-finetuned-summarization"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

In [43]:
def generate_summary(article, max_length=256, min_length=100, length_penalty=2.0, num_beams=16,
                     *, summarizer_model=None, summarizer_tokenizer=None, max_input_length=1024):
    """Generate an abstractive summary of ``article`` using beam search.

    The article is prefixed with the task prompt ``"summarize: "``, tokenized and
    truncated to ``max_input_length`` tokens, then decoded with ``model.generate``.

    Args:
        article: Text to summarize; coerced with ``str()`` (so ``None`` becomes ``"None"``).
        max_length: Maximum length of the generated summary, in tokens.
        min_length: Minimum length of the generated summary, in tokens.
        length_penalty: Length penalty forwarded to ``generate``.
        num_beams: Number of beams for beam search.
        summarizer_model: Seq2seq model to use; defaults to the notebook-level ``model``.
        summarizer_tokenizer: Tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_input_length: Token limit applied to the prompt (truncation), previously hard-coded.

    Returns:
        The first returned sequence decoded to text with special tokens removed.
    """
    # Fall back to the globals loaded earlier in the notebook so existing calls keep working.
    mdl = summarizer_model if summarizer_model is not None else model
    tok = summarizer_tokenizer if summarizer_tokenizer is not None else tokenizer

    inputs = tok(
        "summarize: " + str(article),
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )

    # Keep the inputs on the model's device, otherwise generate() fails after model.to("cuda").
    device = getattr(mdl, "device", None)
    if device is not None and hasattr(inputs, "to"):
        inputs = inputs.to(device)

    # Forward every tokenizer output (input_ids AND attention_mask) instead of input_ids alone,
    # so generate() does not have to infer the mask.
    summary_ids = mdl.generate(
        **inputs,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        length_penalty=length_penalty,
        early_stopping=True,
    )

    return tok.decode(summary_ids[0], skip_special_tokens=True)


In [46]:
# Score generated summaries against the reference highlights with ROUGE-1/2/L F-measure.
# NOTE: the previous version broke *after* appending index 50, i.e. it scored 51 articles;
# 50 looks like the intended sample size. Set to len(dataset) for a full evaluation run.
N_EVAL_EXAMPLES = 50
ROUGE_TYPES = ["rouge1", "rouge2", "rougeL"]

# Defined here so the cell runs on a fresh kernel; `scorer` previously existed only in
# leftover kernel state. Stemming is a common choice for CNN/DailyMail ROUGE reporting.
scorer = rouge_scorer.RougeScorer(ROUGE_TYPES, use_stemmer=True)

# Slice once and convert to Python lists instead of re-indexing the columns every iteration.
eval_table = dataset.slice(0, N_EVAL_EXAMPLES)
articles = eval_table.column("article").to_pylist()
references = eval_table.column("highlights").to_pylist()

results = []
for raw_article, raw_reference in zip(articles, references):
    article = str(raw_article)
    reference_summary = str(raw_reference)
    generated_summary = str(generate_summary(article))

    rouge_scores = scorer.score(reference_summary, generated_summary)

    results.append({
        "article": article,
        "reference_summary": reference_summary,
        "generated_summary": generated_summary,
        **{name: rouge_scores[name].fmeasure for name in ROUGE_TYPES},
    })

results_df = pd.DataFrame(results)
results_df.to_csv("rouge_scores_flanT5_finetuned.csv", index=False)

results_df.head()

                                             article  \
0  (CNN)The Palestinian Authority officially beca...   
1  (CNN)Never mind cats having nine lives. A stra...   
2  (CNN)If you've been following the news lately,...   
3  (CNN)Five Americans who were monitored for thr...   
4  (CNN)A Duke student has admitted to hanging a ...   

                                   reference_summary  \
0  Membership gives the ICC jurisdiction over all...   
1  Theia, a bully breed mix, was apparently hit b...   
2  Mohammad Javad Zarif has spent more time with ...   
3  17 Americans were exposed to the Ebola virus w...   
4  Student is no longer on Duke University campus...   

                                   generated_summary    rouge1    rouge2  \
0  The Palestinian Authority officially became th...  0.323529  0.194030   
1  A stray pooch in Washington State has used up ...  0.413793  0.175439   
2  If you've been following the news lately, ther...  0.250000  0.106667   
3  Five Americans who 