# Step 1: Imports

In [1]:
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq
import evaluate
import torch
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

  from .autonotebook import tqdm as notebook_tqdm


# Step 2: Load & Subsample CNN/DailyMail

In [2]:
dataset = load_dataset("cnn_dailymail", "3.0.0")

# Subsample to reduce training time

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
)
import evaluate
from random import seed

In [4]:
seed(42)
train_dataset = dataset["train"].shuffle(seed=42).select(range(50000))
val_dataset = dataset["validation"].shuffle(seed=42).select(range(3000))
test_dataset = dataset["test"].shuffle(seed=42).select(range(3000))

# Step 3: Load Tokenizer & Model

In [5]:
model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",
    gradient_checkpointing=True  # Saves memory during training
)


# Step 4: Preprocessing function

In [6]:
def preprocess(example):
    inputs = tokenizer(
        example["article"], max_length=1024, truncation=True, padding="max_length"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example["highlights"], max_length=128, truncation=True, padding="max_length"
        )
    inputs["labels"] = labels["input_ids"]
    return inputs

# Tokenize datasets

In [7]:
print("Tokenizing datasets...")
tokenized_train = train_dataset.map(preprocess, batched=True, remove_columns=train_dataset.column_names)
tokenized_val = val_dataset.map(preprocess, batched=True, remove_columns=val_dataset.column_names)
tokenized_test = test_dataset.map(preprocess, batched=True, remove_columns=test_dataset.column_names)

Tokenizing datasets...


# Step 5: Data Collator

In [8]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Step 6: Define Evaluation Metric (ROUGE)

In [9]:
rouge = evaluate.load("rouge")

# ✅ Correct compute_metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Replace -100 in labels with pad_token_id for decoding
    labels = [[(label if label != -100 else tokenizer.pad_token_id) for label in l] for l in labels]

    # Decode predicted and reference summaries
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Clean text
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    # Compute ROUGE
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}

    return result

# Step 7: Training Arguments

In [10]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    save_steps=1000,
    per_device_train_batch_size=2,        # Small batch for limited GPU
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,        # Effective larger batch
    num_train_epochs=3,
    save_total_limit=2,
    predict_with_generate=True,           # ✅ Needed for summarization
    fp16=True,                            # Use if you have a compatible GPU
    logging_dir="./logs",
    report_to="none",                     # Can be "wandb", "tensorboard", etc.
)

# Step 8: Define Trainer

In [11]:
trainer = Seq2SeqTrainer(
    model=model,                    # e.g. BartForConditionalGeneration or T5ForConditionalGeneration
    data_collator=data_collator,
    args=training_args,             # instance of Seq2SeqTrainingArguments
    train_dataset=tokenized_train, # your pre-tokenized training dataset
    eval_dataset=tokenized_val,    # your pre-tokenized validation dataset
    tokenizer=tokenizer,           # the tokenizer you used for preprocessing
)

  trainer = Seq2SeqTrainer(


# Step 9: Train the Model

In [12]:
trainer.train()

Step,Training Loss,Validation Loss
500,1.6077,1.104044
1000,1.1559,1.104603
1500,1.124,1.072914
2000,1.1007,1.048071
2500,1.0982,1.050083
3000,1.097,1.041155
3500,1.081,1.056923
4000,1.0717,1.028411
4500,1.0792,1.029974
5000,1.057,1.025284




TrainOutput(global_step=18750, training_loss=0.9688578588867187, metrics={'train_runtime': 19882.3783, 'train_samples_per_second': 7.544, 'train_steps_per_second': 0.943, 'total_flos': 9.1460468736e+16, 'train_loss': 0.9688578588867187, 'epoch': 3.0})

# Step 10: Save the Model

In [13]:
model.save_pretrained(".")
tokenizer.save_pretrained(".")


('.\\tokenizer_config.json',
 '.\\special_tokens_map.json',
 '.\\vocab.json',
 '.\\merges.txt',
 '.\\added_tokens.json')

 # Inference on Custom Text

In [14]:
article = "The European Central Bank has decided to leave interest rates unchanged as the economic outlook remains uncertain..."

inputs = tokenizer(article, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
summary_ids = model.generate(inputs["input_ids"], max_length=128, num_beams=4, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

print("\nGenerated Summary:\n", summary)




Generated Summary:
 The European Central Bank has decided to leave interest rates unchanged as the economic outlook remains uncertain .


# Step 1: Tokenize the test_dataset

In [15]:
tokenized_test = test_dataset.map(preprocess, batched=True, remove_columns=dataset["test"].column_names)

# Step 2: Run Evaluation on Test Set

In [16]:
metrics = trainer.evaluate(eval_dataset=tokenized_test)
print("📊 Test Set ROUGE Metrics:", metrics)

📊 Test Set ROUGE Metrics: {'eval_loss': 0.9473769664764404, 'eval_runtime': 71.0281, 'eval_samples_per_second': 42.237, 'eval_steps_per_second': 21.118, 'epoch': 3.0}


In [17]:
# ✅ Reattach compute_metrics to Trainer (optional if already done)
trainer.compute_metrics = compute_metrics

# ✅ Proper evaluation — this will generate predictions and compute ROUGE
metrics = trainer.evaluate(eval_dataset=tokenized_test)
print("📊 Test Set ROUGE Metrics:", metrics)

📊 Test Set ROUGE Metrics: {'eval_loss': 0.9473769664764404, 'eval_rouge1': 25.2112, 'eval_rouge2': 12.28, 'eval_rougeL': 20.6701, 'eval_rougeLsum': 23.6094, 'eval_runtime': 813.7209, 'eval_samples_per_second': 3.687, 'eval_steps_per_second': 1.843, 'epoch': 3.0}


# Generate Summaries for Test Samples

In [18]:
for i in range(5):
    article = test_dataset[i]["article"]
    reference = test_dataset[i]["highlights"]

    inputs = tokenizer(article, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    summary_ids = model.generate(inputs["input_ids"], max_length=128, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    print(f"\n📰 Article #{i+1}")
    print("📌 Reference Summary:", reference)
    print("🧠 Generated Summary:", summary)



📰 Article #1
📌 Reference Summary: CNN's Dr. Sanjay Gupta says we should legalize medical marijuana now .
He says he knows how easy it is do nothing "because I did nothing for too long"
🧠 Generated Summary: For the first time a majority, 53%, favor marijuana legalization .
Support for legalization has risen 11 points in the past few years alone .
"Weed" is the first federally approved clinical study on the use of marijuana for PTSD .

📰 Article #2
📌 Reference Summary: Child has amassed thousands of Twitter followers with 'gang life' photos .
In one video he points gun at camera as adults look on unfazed .
His tweets have prompted backlash with calls for intervention .
🧠 Generated Summary: Baby-faced boy from Memphis, Tennessee, poses with guns, cash, and bags of marijuana .
Tweets include phrases such as 'I need a bad b****', 'f*** da police', and 'gang sh** n****'
As he is a minor, DailyMail.com will not identify the boy .

📰 Article #3
📌 Reference Summary: The presidential hopeful he