<a href="https://colab.research.google.com/github/aliviyabose261105-a11y/aliviyabose261105-a11y/blob/main/Alcronix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install -q evaluate rouge-score datasets transformers accelerate sentencepiece gradio sumy



In [5]:
!pip install -q nltk
import nltk
nltk.download("punkt")
nltk.download("punkt_tab")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [6]:
# %% [markdown]
# # Text Summarization System (Extractive + Abstractive)
# - Dataset: CNN/DailyMail (https://huggingface.co/datasets/abisee/cnn_dailymail)
# - Abstractive model: BART (facebook/bart-large-cnn)
# - Extractive method: TextRank (sumy)
# - Metrics: ROUGE (and optional BERTScore)
# - App: Gradio UI to compare methods
# - Notes:
#   * For a reliable demo, you can use the pretrained model without fine-tuning.
#   * If you fine-tune, use more data than a few hundred examples to avoid overfitting.

# %% [markdown]
# ## 0) Setup

# Install deps (uncomment if needed when running in a fresh environment)
# !pip -q install "transformers>=4.44.0" "datasets>=2.20.0" "evaluate>=0.4.2" "rouge-score>=0.1.2" \
#                  "accelerate>=0.33.0" "sentencepiece" "gradio>=4.31.5" "sumy>=0.11.0"

# (Optional, for BERTScore—comment out if you want faster setup)
# !pip -q install bert-score

# %% [markdown]
# ## 1) Imports, Config, Utilities

import os
import random
import numpy as np
from dataclasses import dataclass
from typing import Dict, List

import evaluate
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
)

import torch
from tqdm.auto import tqdm

# Repro
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ----- Switch: fine-tune or not -----
DO_FINETUNE = False  # Recommended False for quick, strong results using pretrained BART

# Model + lengths
MODEL_NAME = "facebook/bart-large-cnn"  # already fine-tuned on CNN/DM
MAX_SOURCE_LEN = 1024
MAX_TARGET_LEN = 144  # ~3-4 sentences
MIN_TARGET_LEN = 32

# Demo sizes (increase for better results if you DO_FINETUNE)
TRAIN_SAMPLES = 400
VAL_SAMPLES   = 500
TEST_SAMPLES  = 500

# %% [markdown]
# ## 2) Load CNN/DailyMail

# Use the namespaced Hub ID that matches the dataset URL given in the header
dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")  # fields: 'article', 'highlights'

dataset_small = DatasetDict({
    "train": dataset["train"].select(range(min(TRAIN_SAMPLES, len(dataset["train"])))),
    "validation": dataset["validation"].select(range(min(VAL_SAMPLES, len(dataset["validation"])))),
    "test": dataset["test"].select(range(min(TEST_SAMPLES, len(dataset["test"])))),
})

print(dataset_small)

# %% [markdown]
# ## 3) Cleaning & Tokenization

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

def clean_text(s: str) -> str:
    if not isinstance(s, str):
        return ""
    s = " ".join(s.split())  # collapse whitespace/newlines
    return s

def preprocess_fn(batch):
    articles = [clean_text(x) for x in batch["article"]]
    summaries = [clean_text(x) for x in batch["highlights"]]

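    # No padding here: DataCollatorForSeq2Seq pads each batch dynamically and
    # fills label padding with -100, so pad tokens are excluded from the loss.
    # (padding="max_length" would leave pad_token_id in the labels.)
    model_inputs = tokenizer(
        articles,
        max_length=MAX_SOURCE_LEN,
        truncation=True,
    )
    # Modern, non-deprecated target tokenization
    labels = tokenizer(
        text_target=summaries,
        max_length=MAX_TARGET_LEN,
        truncation=True,
    )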
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized = dataset_small.map(
    preprocess_fn,
    batched=True,
    remove_columns=dataset_small["train"].column_names,
    desc="Tokenizing",
)

# %% [markdown]
# ## 4) Abstractive Model (BART) + (Optional) Training

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.to(DEVICE)

# The collator takes the model object (not its name) so it can build
# decoder_input_ids from the labels; this matters when label smoothing is on.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100  # Trainer will ignore padding in loss
)

# ROUGE metric
rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
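    # Generated sequences can also be padded with -100 when batches are gathered
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)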
    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = [clean_text(x) for x in decoded_preds]
    decoded_labels = [clean_text(x) for x in decoded_labels]

    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )
    return {k: round(v * 100, 2) for k, v in result.items()}

if DO_FINETUNE:
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",            # named evaluation_strategy in transformers < 4.41
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        learning_rate=1e-5,               # lower LR → better stability on small data
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,    # effective batch size ↑ without OOM
        num_train_epochs=3,
        warmup_ratio=0.1,
        weight_decay=0.01,
        label_smoothing_factor=0.1,
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LEN,
        fp16=torch.cuda.is_available(),
        gradient_checkpointing=torch.cuda.is_available(),
        save_total_limit=2,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_rougeLsum",  # logs are prefixed with eval_
        greater_is_better=True,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=1)],
    )

    print("Starting fine-tuning (demo settings)...")
    train_result = trainer.train()
    print("Training done.")
    trainer.save_model("./results/best_model")
else:
    print("Skipping fine-tuning. Using pretrained BART-Large-CNN as-is.")

# %% [markdown]
# ## 5) Inference Helpers (Abstractive)

def generate_abstractive(
    text: str,
    max_len: int = MAX_TARGET_LEN,
    min_len: int = MIN_TARGET_LEN,
    num_beams: int = 6,
    length_penalty: float = 2.0,
    no_repeat_ngram_size: int = 3,
    repetition_penalty: float = 1.1,
) -> str:
    text = clean_text(text)
    if not text:
        return ""
    inputs = tokenizer([text], max_length=MAX_SOURCE_LEN, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        ids = model.generate(
            **inputs,
            num_beams=num_beams,
            max_length=max_len,
            min_length=min_len,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            repetition_penalty=repetition_penalty,
            early_stopping=True,
        )
    return tokenizer.decode(ids[0], skip_special_tokens=True)

# %% [markdown]
# ## 6) Extractive Summarizer (TextRank via Sumy)

from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer as SumyTokenizer
from sumy.summarizers.text_rank import TextRankSummarizer

def generate_extractive_text_rank(text: str, max_sentences: int = 4) -> str:
    text = clean_text(text)
    if not text:
        return ""
    parser = PlaintextParser.from_string(text, SumyTokenizer("english"))
    summarizer = TextRankSummarizer()
    sentences = summarizer(parser.document, max_sentences)
    return " ".join(str(s) for s in sentences)

# %% [markdown]
# ## 7) Side-by-Side Demo on a Validation Sample

idx = 10  # pick any index < len(dataset_small["validation"])
sample_article = dataset_small["validation"][idx]["article"]
sample_ref     = dataset_small["validation"][idx]["highlights"]

abs_summary = generate_abstractive(sample_article)
ext_summary = generate_extractive_text_rank(sample_article, max_sentences=4)

print("=== ARTICLE (truncated) ===")
print(sample_article[:800], "...\n")
print("=== REFERENCE SUMMARY ===")
print(sample_ref, "\n")
print("=== ABSTRACTIVE (BART) ===")
print(abs_summary, "\n")
print("=== EXTRACTIVE (TextRank) ===")
print(ext_summary, "\n")

# %% [markdown]
# ## 8) Evaluation on a Small Validation Subset

N = min(100, len(dataset_small["validation"]))  # adjust
val_articles = [dataset_small["validation"][i]["article"] for i in range(N)]
val_refs     = [dataset_small["validation"][i]["highlights"] for i in range(N)]

def batched_rouge(preds, refs, name=""):
    # Reuse the ROUGE metric loaded in section 4 instead of reloading it on every call
    r = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
    r = {f"{name}{k}": round(v * 100, 2) for k, v in r.items()}
    return r

abs_preds, ext_preds = [], []
for art in tqdm(val_articles, desc="Generating (Abstractive)"):
    abs_preds.append(generate_abstractive(art))
for art in tqdm(val_articles, desc="Generating (Extractive)"):
    ext_preds.append(generate_extractive_text_rank(art, max_sentences=4))

abs_scores = batched_rouge(abs_preds, val_refs, name="abs_")
ext_scores = batched_rouge(ext_preds, val_refs, name="ext_")

print("Abstractive ROUGE:", abs_scores)
print("Extractive ROUGE:", ext_scores)

# (Optional) BERTScore
# import evaluate as ev
# bert = ev.load("bertscore")
# abs_bert = bert.compute(predictions=abs_preds, references=val_refs, lang="en")
# ext_bert = bert.compute(predictions=ext_preds, references=val_refs, lang="en")
# print("Abstractive BERTScore F1 (mean):", float(np.mean(abs_bert["f1"])))
# print("Extractive  BERTScore F1 (mean):", float(np.mean(ext_bert["f1"])))

# %% [markdown]
# ## 9) Gradio App (Compare Methods)

import gradio as gr

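def summarize(text, method, max_len, min_len, beams, ext_sentences, length_penalty, repetition_penalty):
    text = clean_text(text)
    if not text:
        return ""
    if method == "Abstractive (BART)":
        # Slider values may arrive as floats; generate() needs integer lengths and beam counts
        max_len, min_len, beams = int(max_len), int(min_len), int(beams)
        return generate_abstractive(
            text,
            max_len=max_len,
            min_len=min(min_len, max_len),  # keep min <= max so the two sliders cannot conflict
            num_beams=beams,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
        )
    else:
        return generate_extractive_text_rank(text, max_sentences=int(ext_sentences))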

with gr.Blocks(title="Summarization: Extractive vs Abstractive") as demo:
    gr.Markdown("# 📰 Text Summarization\nCompare **Extractive (TextRank)** vs **Abstractive (BART)** on any text.")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Input Text", lines=14, placeholder="Paste document/article/transcript here...")
            method = gr.Radio(choices=["Abstractive (BART)", "Extractive (TextRank)"], value="Abstractive (BART)", label="Method")
            with gr.Accordion("Abstractive Settings (BART)", open=False):
                max_len = gr.Slider(32, 256, value=144, step=8, label="Max Length")
                min_len = gr.Slider(0, 128, value=32, step=4, label="Min Length")
                beams   = gr.Slider(1, 8, value=6, step=1, label="Beams")
                length_penalty = gr.Slider(0.5, 3.0, value=2.0, step=0.1, label="Length Penalty")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition Penalty")
            with gr.Accordion("Extractive Settings (TextRank)", open=False):
                ext_sentences = gr.Slider(1, 10, value=4, step=1, label="Number of sentences")
            btn = gr.Button("Summarize", variant="primary")
        with gr.Column():
            out = gr.Textbox(label="Summary", lines=14)
    btn.click(
        fn=summarize,
        inputs=[inp, method, max_len, min_len, beams, ext_sentences, length_penalty, repetition_penalty],
        outputs=[out]
    )

demo.launch(share=False)

# %% [markdown]
# ## 10) Save/Load Utilities

def save_pipeline(save_dir="./results/final"):
    os.makedirs(save_dir, exist_ok=True)
    tokenizer.save_pretrained(save_dir)
    model.save_pretrained(save_dir)
    print(f"Saved tokenizer+model to: {save_dir}")

def load_pipeline(load_dir="./results/final"):
    tok = AutoTokenizer.from_pretrained(load_dir, use_fast=True)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(load_dir).to(DEVICE)
    return tok, mdl

save_pipeline("./results/final")

# %% [markdown]
# ## 11) Quick Example on Your Own Text

# The first test article is a stand-in; assign your own string to example_text to try other input
example_text = dataset_small["test"][0]["article"]
print("Abstractive:", generate_abstractive(example_text)[:400], "...")
print("Extractive:", generate_extractive_text_rank(example_text, max_sentences=4)[:400], "...")


DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 400
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 500
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 500
    })
})


Skipping fine-tuning. Using pretrained BART-Large-CNN as-is.
=== ARTICLE (truncated) ===
March 10, 2015 . We're truly international in scope on Tuesday. We're visiting Italy, Russia, the United Arab Emirates, and the Himalayan Mountains. Find out who's attempting to circumnavigate the globe in a plane powered partially by the sun, and explore the mysterious appearance of craters in northern Asia. You'll also get a view of Mount Everest that was previously reserved for climbers. On this page you will find today's show Transcript and a place for you to request to be on the CNN Student News Roll Call. TRANSCRIPT . Click here to access the transcript of today's CNN Student News program. Please note that there may be a delay between the time when the video is available and when the transcript is published. CNN Student News is created by a team of journalists who consider the Common ...

=== REFERENCE SUMMARY ===
This page includes the show Transcript .
Use the Transcript to help students wi

Abstractive ROUGE: {'abs_rouge1': np.float64(37.26), 'abs_rouge2': np.float64(16.94), 'abs_rougeL': np.float64(27.36), 'abs_rougeLsum': np.float64(32.06)}
Extractive ROUGE: {'ext_rouge1': np.float64(21.55), 'ext_rouge2': np.float64(7.83), 'ext_rougeL': np.float64(14.31), 'ext_rougeLsum': np.float64(17.44)}
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


Saved tokenizer+model to: ./results/final
Abstractive: The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. ...
Extractive: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. ...
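
The extractive output above is made of sentences copied verbatim from the article. To make the ranking idea behind TextRank concrete, here is a minimal pure-Python sketch. It is not sumy's implementation: the regex sentence splitter, the word-overlap similarity, and the names `split_sentences`, `similarity`, and `textrank_summary` are all assumptions made for illustration.

```python
import math
import re


def split_sentences(text):
    # Naive splitter (assumption): break after ., ! or ? followed by whitespace.
    # Real tokenizers such as NLTK punkt also handle abbreviations.
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]


def similarity(a, b):
    # Shared unique words, normalized by the log of each sentence's vocabulary size.
    wa = set(re.findall(r"\w+", a.lower()))
    wb = set(re.findall(r"\w+", b.lower()))
    if len(wa) < 2 or len(wb) < 2:
        return 0.0
    return len(wa & wb) / (math.log(len(wa)) + math.log(len(wb)))


def textrank_summary(text, max_sentences=2, damping=0.85, iterations=50):
    sentences = split_sentences(text)
    n = len(sentences)
    if n <= max_sentences:
        return " ".join(sentences)
    # Weighted, undirected sentence graph (no self-loops).
    weights = [
        [similarity(sentences[i], sentences[j]) if i != j else 0.0 for j in range(n)]
        for i in range(n)
    ]
    out_sums = [sum(row) for row in weights]
    scores = [1.0 / n] * n
    # PageRank-style power iteration over the sentence graph.
    for _ in range(iterations):
        scores = [
            (1 - damping) / n
            + damping * sum(
                weights[j][i] / out_sums[j] * scores[j]
                for j in range(n)
                if out_sums[j] > 0
            )
            for i in range(n)
        ]
    # Keep the highest-scoring sentences, restored to document order.
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)[:max_sentences]
    return " ".join(sentences[i] for i in sorted(ranked))


demo_text = (
    "Solar planes fly on sunlight. "
    "The solar plane crossed the ocean on sunlight alone. "
    "Bananas are yellow. "
    "The pilot said the solar plane can fly on sunlight at night using batteries."
)
print(textrank_summary(demo_text, max_sentences=2))
```

Sentences that share vocabulary with many others collect score from their neighbours, while an unrelated sentence keeps only the baseline `(1 - damping) / n` and drops out of the summary.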

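
The ROUGE numbers printed earlier are overlap scores scaled to percentages. As a rough illustration of what ROUGE-1 measures, the sketch below computes a unigram-overlap F1 with plain whitespace tokenization. It is a simplification and not the `rouge-score` implementation used by `evaluate` (which tokenizes differently and can apply stemming), so its values will not match the reported ones exactly; the function name `rouge1_f1` is hypothetical.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 between two strings (lowercased, whitespace-split)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped overlap: a token counts at most as often as it occurs in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("the cat sat on the mat", "the cat lay on the mat")
print(round(score * 100, 2))  # 83.33 (5 of 6 unigrams match on each side)
```

Precision penalizes summaries that add words missing from the reference, and recall penalizes summaries that leave reference words out. Multi-sentence extractive output tends to lose on precision, which is one plausible reason its scores trail the abstractive ones here.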