# FLAN-T5-Base Model for News Summarization

This notebook implements Zero-shot, Few-shot, and Fine-tuning approaches.

In [1]:
# Install required packages
!pip install -q torch transformers datasets rouge-score bert-score numpy tqdm accelerate sentencepiece


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from torch.utils.data import Dataset
import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as bert_score
import json
from tqdm import tqdm

In [3]:
class NewsSummarizationDataset(Dataset):
    """Map-style dataset yielding tokenized (article, summary) pairs.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``,
    every value a 1-D tensor padded/truncated to a fixed length.

    Args:
        texts: Indexable collection of source articles.
        summaries: Indexable collection of reference summaries, aligned with
            ``texts`` by position.
        tokenizer: Callable tokenizer accepting ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` and returning a mapping with
            ``input_ids`` and ``attention_mask`` tensors of shape (1, length).
        max_input_length: Fixed token length of the encoded article.
        max_target_length: Fixed token length of the encoded summary.
    """

    # Label value skipped by the cross-entropy loss (PyTorch's default
    # ``ignore_index``), used to mask padded target positions.
    IGNORE_INDEX = -100

    def __init__(self, texts, summaries, tokenizer, max_input_length=512, max_target_length=128):
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of articles."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize and return the ``idx``-th example.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` for the article and
            ``labels`` for the summary; padded label positions are set to
            ``IGNORE_INDEX`` so they do not contribute to the training loss.
        """
        text = str(self.texts[idx])
        summary = str(self.summaries[idx])

        # Tokenize inputs
        inputs = self.tokenizer(
            text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize targets
        targets = self.tokenizer(
            summary,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the sequence axis when its length happens to be 1.
        labels = targets['input_ids'].squeeze(0)

        # Mask padding in the labels; otherwise the model is also trained to
        # predict pad tokens after the end of every summary.
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, self.IGNORE_INDEX)

        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': labels
        }


class FLANT5BaseSummarizer:
    """Prompt-based news summarizer around a T5 conditional-generation checkpoint.

    Loads the tokenizer and model named by ``model_name``, moves the model to
    CUDA when available (CPU otherwise), and offers zero-shot and few-shot
    prompting plus a ROUGE / BERTScore evaluation loop.
    """

    # Token budget for the encoded prompt; longer prompts are truncated.
    MAX_PROMPT_TOKENS = 512

    def __init__(self, model_name="google/flan-t5-base"):
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)
        print(f"Model loaded on {self.device}")

    def _generate(self, prompt, max_length, min_length):
        """Encode ``prompt``, run beam search and return the decoded text.

        Shared by the zero-shot and few-shot entry points so both use exactly
        the same tokenization and decoding settings.
        """
        inputs = self.tokenizer(
            prompt,
            max_length=self.MAX_PROMPT_TOKENS,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True
            )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def zero_shot_summarize(self, text, max_length=128, min_length=30):
        """Zero-shot summarization using prompt-based approach.

        Args:
            text: Article to summarize.
            max_length: Upper bound on generated length, passed to ``generate``.
            min_length: Lower bound on generated length, passed to ``generate``.

        Returns:
            The generated summary as a string.
        """
        prompt = f"Summarize the following news article: {text}"
        return self._generate(prompt, max_length, min_length)

    def few_shot_summarize(self, text, examples, max_length=128, min_length=30):
        """Few-shot summarization with example demonstrations.

        Args:
            text: Article to summarize.
            examples: Iterable of ``(article, summary)`` pairs; only the first
                200 characters of each example article go into the prompt.
            max_length: Upper bound on generated length, passed to ``generate``.
            min_length: Lower bound on generated length, passed to ``generate``.

        Returns:
            The generated summary as a string.
        """
        # NOTE(review): the assembled prompt is truncated to MAX_PROMPT_TOKENS
        # during tokenization, so for long articles the tail of the prompt
        # (including the trailing "Summary:" cue) is presumably cut off,
        # assuming the tokenizer truncates on the right. Confirm whether the
        # article should be shortened before the prompt is assembled.
        parts = ["Summarize the following news articles:\n\n"]
        for i, (article, summary) in enumerate(examples, 1):
            parts.append(
                f"Example {i}:\n"
                f"Article: {article[:200]}...\n"
                f"Summary: {summary}\n\n"
            )
        parts.append(f"Now summarize this article:\n{text}\nSummary:")

        return self._generate("".join(parts), max_length, min_length)

    def evaluate(self, texts, reference_summaries, method='zero_shot', examples=None):
        """Evaluate model performance using ROUGE and BERTScore.

        Args:
            texts: Articles to summarize.
            reference_summaries: Gold summaries aligned with ``texts``.
            method: ``'zero_shot'`` or ``'few_shot'``.
            examples: ``(article, summary)`` demonstrations; required and
                non-empty when ``method='few_shot'``.

        Returns:
            ``(results, generated_summaries)`` where ``results`` holds the mean
            ROUGE-1/2/L F-measures and mean BERTScore precision/recall/F1, and
            ``generated_summaries`` lists the model outputs in input order.

        Raises:
            ValueError: If ``method`` is unknown, if few-shot is requested
                without examples, or if ``texts`` and ``reference_summaries``
                differ in length.
        """
        # Validate everything before the (slow) generation loop starts.
        if method not in ('zero_shot', 'few_shot'):
            raise ValueError(f"Invalid method: {method}")
        if method == 'few_shot' and not examples:
            raise ValueError("method='few_shot' requires a non-empty `examples` list")

        texts = list(texts)
        reference_summaries = list(reference_summaries)
        if len(texts) != len(reference_summaries):
            raise ValueError(
                f"texts and reference_summaries must have the same length "
                f"(got {len(texts)} and {len(reference_summaries)})"
            )

        generated_summaries = []

        print(f"Generating summaries using {method}...")
        for text in tqdm(texts):
            if method == 'zero_shot':
                summary = self.zero_shot_summarize(text)
            else:
                summary = self.few_shot_summarize(text, examples)
            generated_summaries.append(summary)

        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}

        for gen_sum, ref_sum in zip(generated_summaries, reference_summaries):
            # RougeScorer.score takes (target, prediction) in that order.
            scores = scorer.score(ref_sum, gen_sum)
            for name in rouge_scores:
                rouge_scores[name].append(scores[name].fmeasure)

        print("Calculating BERTScore...")
        P, R, F1 = bert_score(generated_summaries, reference_summaries, lang='en', verbose=True)

        results = {
            'rouge1': {'f1': np.mean(rouge_scores['rouge1'])},
            'rouge2': {'f1': np.mean(rouge_scores['rouge2'])},
            'rougeL': {'f1': np.mean(rouge_scores['rougeL'])},
            'bertscore': {
                'precision': P.mean().item(),
                'recall': R.mean().item(),
                'f1': F1.mean().item()
            }
        }

        return results, generated_summaries


In [4]:
def load_cnn_dailymail(split='test', num_samples=100):
    """Fetch articles and highlights from CNN/DailyMail (config "3.0.0").

    Args:
        split: Dataset split to load.
        num_samples: Keep at most this many leading rows; a falsy value
            (``None`` or ``0``) keeps the whole split.

    Returns:
        ``(texts, summaries)``: two parallel lists taken from the ``article``
        and ``highlights`` fields.
    """
    corpus = load_dataset("cnn_dailymail", "3.0.0", split=split)
    if num_samples:
        keep = min(num_samples, len(corpus))
        corpus = corpus.select(range(keep))

    articles, highlights = [], []
    for row in corpus:
        articles.append(row['article'])
        highlights.append(row['highlights'])
    return articles, highlights

def load_xsum(split='test', num_samples=100):
    """Fetch documents and one-sentence summaries from XSum.

    Args:
        split: Dataset split to load.
        num_samples: Keep at most this many leading rows; a falsy value
            (``None`` or ``0``) keeps the whole split.

    Returns:
        ``(texts, summaries)``: two parallel lists taken from the ``document``
        and ``summary`` fields.
    """
    corpus = load_dataset("xsum", split=split)
    if num_samples:
        keep = min(num_samples, len(corpus))
        corpus = corpus.select(range(keep))

    documents, abstracts = [], []
    for row in corpus:
        documents.append(row['document'])
        abstracts.append(row['summary'])
    return documents, abstracts

## Initialize Model

In [5]:
# Load the tokenizer and model weights once; every evaluation cell below reuses this instance.
summarizer = FLANT5BaseSummarizer(model_name="google/flan-t5-base")

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Model loaded on cpu


## Load Dataset

In [6]:
# Work on a small slice of CNN/DailyMail so the notebook stays quick to run:
# 1000 training rows (source of few-shot demonstrations) and 100 test rows.
print("Loading CNN/DailyMail dataset...")
train_texts, train_summaries = load_cnn_dailymail(split='train', num_samples=1000)
test_texts, test_summaries = load_cnn_dailymail(split='test', num_samples=100)
print(f"Loaded {len(train_texts)} training samples and {len(test_texts)} test samples")

Loading CNN/DailyMail dataset...
Loaded 1000 training samples and 100 test samples


## Zero-shot Evaluation

In [7]:
# Zero-shot baseline, scored on the first 10 test articles only.
print("=== Zero-shot Evaluation ===")
zero_shot_results, zero_shot_summaries = summarizer.evaluate(
    texts=test_texts[:10],
    reference_summaries=test_summaries[:10],
    method='zero_shot',
)
print("\nZero-shot Results:")
print(json.dumps(zero_shot_results, indent=2))

=== Zero-shot Evaluation ===
Generating summaries using zero_shot...


100%|██████████| 10/10 [00:29<00:00,  2.91s/it]


Calculating BERTScore...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 2.60 seconds, 3.85 sentences/sec

Zero-shot Results:
{
  "rouge1": {
    "f1": 0.32504652398970296
  },
  "rouge2": {
    "f1": 0.13829923433039357
  },
  "rougeL": {
    "f1": 0.24898921840528607
  },
  "bertscore": {
    "precision": 0.8817199468612671,
    "recall": 0.8624206781387329,
    "f1": 0.8719387054443359
  }
}


## Few-shot Evaluation

In [8]:
# Few-shot run on the same 10 test articles used for the zero-shot baseline.
print("=== Few-shot Evaluation ===")
# Demonstrations are the first three (article, summary) pairs of the training split.
few_shot_examples = [pair for pair in zip(train_texts[:3], train_summaries[:3])]
print(f"Created {len(few_shot_examples)} few-shot examples from training dataset")

few_shot_results, few_shot_summaries = summarizer.evaluate(
    texts=test_texts[:10],
    reference_summaries=test_summaries[:10],
    method='few_shot',
    examples=few_shot_examples,
)
print("\nFew-shot Results:")
print(json.dumps(few_shot_results, indent=2))

=== Few-shot Evaluation ===
Created 3 few-shot examples from training dataset
Generating summaries using few_shot...


100%|██████████| 10/10 [00:22<00:00,  2.24s/it]


Calculating BERTScore...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 2.38 seconds, 4.21 sentences/sec

Few-shot Results:
{
  "rouge1": {
    "f1": 0.27626319629513074
  },
  "rouge2": {
    "f1": 0.10690527984124629
  },
  "rougeL": {
    "f1": 0.20366161325329096
  },
  "bertscore": {
    "precision": 0.872004508972168,
    "recall": 0.848450779914856,
    "f1": 0.8600193858146667
  }
}


## Comparison of Results

In [9]:
# Side-by-side F1 report for the two prompting strategies, then persist to JSON.
print("\n=== COMPARISON OF RESULTS ===")

# (key in the results dict, label shown in the report)
_metric_labels = (
    ('rouge1', 'ROUGE-1'),
    ('rouge2', 'ROUGE-2'),
    ('rougeL', 'ROUGE-L'),
    ('bertscore', 'BERTScore'),
)
_runs = (
    ('Zero-shot', zero_shot_results),
    ('Few-shot', few_shot_results),
)
for _run_name, _run_scores in _runs:
    print(f"\n{_run_name} Performance:")
    for _metric_key, _metric_label in _metric_labels:
        print(f"  {_metric_label} F1: {_run_scores[_metric_key]['f1']:.4f}")

# Save results to file
results_summary = dict(
    model="FLAN-T5-Base",
    zero_shot=zero_shot_results,
    few_shot=few_shot_results,
)

with open('flan_t5_base_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)
print("\nResults saved to flan_t5_base_results.json")


=== COMPARISON OF RESULTS ===

Zero-shot Performance:
  ROUGE-1 F1: 0.3250
  ROUGE-2 F1: 0.1383
  ROUGE-L F1: 0.2490
  BERTScore F1: 0.8719

Few-shot Performance:
  ROUGE-1 F1: 0.2763
  ROUGE-2 F1: 0.1069
  ROUGE-L F1: 0.2037
  BERTScore F1: 0.8600

Results saved to flan_t5_base_results.json


## Example Summaries

In [10]:
# Display example summaries for comparison.
# The generated-summary lists only cover the slice of the test set that was
# evaluated, so the number of examples shown is bounded by every list involved
# rather than by len(test_texts) alone (which could index past their end).
print("\n=== EXAMPLE SUMMARIES ===")
num_examples_shown = min(
    3,
    len(test_texts),
    len(test_summaries),
    len(zero_shot_summaries),
    len(few_shot_summaries),
)
for i in range(num_examples_shown):
    print(f"\n--- Example {i+1} ---")
    print(f"\nOriginal Article (first 300 chars):\n{test_texts[i][:300]}...")
    print(f"\nReference Summary:\n{test_summaries[i]}")
    print(f"\nZero-shot Summary:\n{zero_shot_summaries[i]}")
    print(f"\nFew-shot Summary:\n{few_shot_summaries[i]}")
    print("-" * 80)


=== EXAMPLE SUMMARIES ===

--- Example 1 ---

Original Article (first 300 chars):
(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the cou...

Reference Summary:
Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .

Zero-shot Summary:
Palestinians should be allowed to join the ICC, a move that could lead to war crimes investigations against Israelis, a court official said.

Few-shot Summary:
The Palestinian Authority has become the 123rd member of the ICC, a step that gives the court jurisdiction over alleged crimes in Palestinian territories.
------------------------