In [6]:
#from data_preprocessing import load_and_preprocess_data
from train_model import train_model
from evaluate import evaluate_model
from datasets import load_from_disk
#from rl_training import train_rl_model, load_rl_model
from transformers import T5ForConditionalGeneration, AutoTokenizer
import os

# Step 1: Load the tokenized dataset saved by the preprocessing step.
tokenized_dataset_path = "./tokenized_cnn_dailymail"
try:
    tokenized_dataset = load_from_disk(tokenized_dataset_path)
    print(f"Tokenized dataset loaded from {tokenized_dataset_path}")
except FileNotFoundError as err:
    # Re-raise with an actionable message while keeping the original error chained.
    raise FileNotFoundError(f"Tokenized dataset not found at {tokenized_dataset_path}. Please preprocess and save it first.") from err

# Step 2: Load the baseline model, training it first if no checkpoint exists.
# Check for the directory explicitly: `from_pretrained` reports a missing local
# path as a generic OSError rather than FileNotFoundError, so catching
# FileNotFoundError never triggered the training fallback.
model_path = "./new_model"
if not os.path.isdir(model_path):
    print(f"Model not found at {model_path}. Training a new model...")
    # NOTE(review): assumes train_model writes both the model and the tokenizer
    # to save_path, since both are reloaded from it below — confirm.
    train_model(tokenized_dataset, save_path=model_path)

model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
print(f"Model and tokenizer loaded from {model_path}")



Tokenized dataset loaded from ./tokenized_cnn_dailymail
Model and tokenizer loaded from ./new_model


In [8]:
# Evaluate the baseline model on the test split.
# Reuse `model_path` from the loading cell rather than repeating the literal,
# so the evaluated checkpoint cannot drift from the one loaded/trained above.
results = evaluate_model(tokenized_dataset["test"], model_path)
print("Baseline ROUGE Scores:", results)


Baseline ROUGE Scores: {'rouge1': AggregateScore(low=Score(precision=0.26934525322955366, recall=0.3602463636274879, fmeasure=0.30252468927069914), mid=Score(precision=0.292423662285552, recall=0.3929663421342301, fmeasure=0.32776543372953215), high=Score(precision=0.3184325315090215, recall=0.42668200910643056, fmeasure=0.3536614759198495)), 'rouge2': AggregateScore(low=Score(precision=0.09900605421032654, recall=0.13524754457649948, fmeasure=0.11275603792166371), mid=Score(precision=0.11758541720820494, recall=0.15903963535058496, fmeasure=0.13168541717958604), high=Score(precision=0.13653515139547978, recall=0.18476744871069717, fmeasure=0.15236039352410796)), 'rougeL': AggregateScore(low=Score(precision=0.19423215875411667, recall=0.2635538315183025, fmeasure=0.21964526695982453), mid=Score(precision=0.21495479349538615, recall=0.2905431752390082, fmeasure=0.24146084611729618), high=Score(precision=0.23709780285043489, recall=0.3213960980307847, fmeasure=0.26478044820838914)), 'rou

In [4]:
# Evaluate the RL-tuned model on the same test split as the baseline.
# The RL checkpoint is produced outside this notebook (the rl_training import
# above is commented out), so fail fast with a clear message if it is missing.
rl_model_path = "./t5_rl_summarization_model"
if not os.path.isdir(rl_model_path):
    raise FileNotFoundError(f"RL-tuned model not found at {rl_model_path}. Train and save it before running this cell.")

results_custom = evaluate_model(tokenized_dataset["test"], rl_model_path)
print("RL-Tuned ROUGE Scores:", results_custom)

RL-Tuned ROUGE Scores: {'rouge1': AggregateScore(low=Score(precision=0.25100072075438085, recall=0.1993977160383028, fmeasure=0.21338175322955663), mid=Score(precision=0.25870779131176025, recall=0.20582763351843764, fmeasure=0.21975229117655812), high=Score(precision=0.26725030615436635, recall=0.21107152086299955, fmeasure=0.225336784409068)), 'rouge2': AggregateScore(low=Score(precision=0.05862306760884156, recall=0.0487416399903649, fmeasure=0.0510952569346142), mid=Score(precision=0.06325483974222897, recall=0.05234511810320749, fmeasure=0.05505311891697022), high=Score(precision=0.06780703311311573, recall=0.05586013673835154, fmeasure=0.05879574176006237)), 'rougeL': AggregateScore(low=Score(precision=0.20315960285073092, recall=0.1635369764552097, fmeasure=0.17397414865516192), mid=Score(precision=0.20963625260587723, recall=0.167852087071973, fmeasure=0.1784693516718131), high=Score(precision=0.21577249684188582, recall=0.17279658829993275, fmeasure=0.18340527590378947)), 'rou