# BART-Base Model for News Summarization

This notebook implements Zero-shot, Few-shot, and Fine-tuning approaches.

In [1]:
# Install required packages
!pip install -q torch transformers datasets rouge-score bert-score numpy tqdm accelerate sentencepiece


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import torch
from transformers import BartForConditionalGeneration, BartTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from torch.utils.data import Dataset
import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as bert_score
import json
from tqdm import tqdm

In [3]:
class NewsSummarizationDataset(Dataset):
    """Dataset class for news summarization"""
    def __init__(self, texts, summaries, tokenizer, max_input_length=512, max_target_length=128):
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        summary = str(self.summaries[idx])
        
        # Tokenize inputs
        inputs = self.tokenizer(
            text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Tokenize targets
        targets = self.tokenizer(
            summary,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': targets['input_ids'].squeeze()
        }

In [4]:
class BARTBaseSummarizer:
    def __init__(self, model_name="facebook/bart-base"):
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.device)
        print(f"Model loaded on {self.device}")
    
    def zero_shot_summarize(self, text, max_length=128, min_length=30):
        """
        Zero-shot summarization - BART is pre-trained for summarization
        """
        inputs = self.tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True
            )
        
        summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return summary
    
    def few_shot_summarize(self, text, examples, max_length=128, min_length=30):
        """
        Few-shot summarization with example demonstrations
        Note: BART doesn't use prompts like T5, so we concatenate examples
        Args:
            text: Input news article to summarize
            examples: List of tuples (article, summary) for few-shot learning
        """
        # Build few-shot context by concatenating examples
        # This is a simplified approach - in practice, you might want to use prompt tuning
        context_text = ""
        
        for article, summary in examples:
            context_text += f"{article}\n\nSummary: {summary}\n\n"
        
        # Append the target article
        context_text += f"{text}\n\nSummary:"
        
        inputs = self.tokenizer(
            context_text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True
            )
        
        summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return summary
    
    def fine_tune(self, train_texts, train_summaries, val_texts=None, val_summaries=None,
                  output_dir="./bart_base_finetuned", num_epochs=3, batch_size=4):
        """
        Fine-tune the model on news summarization task
        """
        # Create datasets
        train_dataset = NewsSummarizationDataset(
            train_texts, train_summaries, self.tokenizer
        )
        
        val_dataset = None
        if val_texts and val_summaries:
            val_dataset = NewsSummarizationDataset(
                val_texts, val_summaries, self.tokenizer
            )
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir=f'{output_dir}/logs',
            logging_steps=100,
            eval_strategy="epoch" if val_dataset else "no",
            save_strategy="epoch",
            load_best_model_at_end=True if val_dataset else False,
            save_total_limit=2,
        )
        
        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
        )
        
        # Train
        print("Starting fine-tuning...")
        trainer.train()
        
        # Save model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")
        
        # Load fine-tuned model
        self.model = BartForConditionalGeneration.from_pretrained(output_dir).to(self.device)
        print("Fine-tuned model loaded")
    
    def evaluate(self, texts, reference_summaries, method='zero_shot', examples=None):
        """
        Evaluate model performance using ROUGE and BERTScore
        """
        generated_summaries = []
        
        print(f"Generating summaries using {method}...")
        for text in tqdm(texts):
            if method == 'zero_shot':
                summary = self.zero_shot_summarize(text)
            elif method == 'few_shot' and examples:
                summary = self.few_shot_summarize(text, examples)
            else:
                raise ValueError(f"Invalid method: {method}")
            generated_summaries.append(summary)
        
        # Calculate ROUGE scores
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
        
        for gen_sum, ref_sum in zip(generated_summaries, reference_summaries):
            scores = scorer.score(ref_sum, gen_sum)
            rouge_scores['rouge1'].append(scores['rouge1'].fmeasure)
            rouge_scores['rouge2'].append(scores['rouge2'].fmeasure)
            rouge_scores['rougeL'].append(scores['rougeL'].fmeasure)
        
        # Calculate BERTScore
        print("Calculating BERTScore...")
        P, R, F1 = bert_score(generated_summaries, reference_summaries, lang='en', verbose=True)
        
        results = {
            'rouge1': {
                'precision': np.mean([scores['rouge1'].precision for scores in 
                                     [scorer.score(ref, gen) for gen, ref in 
                                      zip(generated_summaries, reference_summaries)]]),
                'recall': np.mean([scores['rouge1'].recall for scores in 
                                  [scorer.score(ref, gen) for gen, ref in 
                                   zip(generated_summaries, reference_summaries)]]),
                'f1': np.mean(rouge_scores['rouge1'])
            },
            'rouge2': {
                'precision': np.mean([scores['rouge2'].precision for scores in 
                                     [scorer.score(ref, gen) for gen, ref in 
                                      zip(generated_summaries, reference_summaries)]]),
                'recall': np.mean([scores['rouge2'].recall for scores in 
                                  [scorer.score(ref, gen) for gen, ref in 
                                   zip(generated_summaries, reference_summaries)]]),
                'f1': np.mean(rouge_scores['rouge2'])
            },
            'rougeL': {
                'precision': np.mean([scores['rougeL'].precision for scores in 
                                     [scorer.score(ref, gen) for gen, ref in 
                                      zip(generated_summaries, reference_summaries)]]),
                'recall': np.mean([scores['rougeL'].recall for scores in 
                                  [scorer.score(ref, gen) for gen, ref in 
                                   zip(generated_summaries, reference_summaries)]]),
                'f1': np.mean(rouge_scores['rougeL'])
            },
            'bertscore': {
                'precision': P.mean().item(),
                'recall': R.mean().item(),
                'f1': F1.mean().item()
            }
        }
        
        return results, generated_summaries

In [5]:
def load_cnn_dailymail(split='test', num_samples=100):
    """Load CNN/DailyMail dataset"""
    dataset = load_dataset("cnn_dailymail", "3.0.0", split=split)
    if num_samples:
        dataset = dataset.select(range(min(num_samples, len(dataset))))
    texts = [item['article'] for item in dataset]
    summaries = [item['highlights'] for item in dataset]
    return texts, summaries

def load_xsum(split='test', num_samples=100):
    """Load XSum dataset"""
    dataset = load_dataset("xsum", split=split)
    if num_samples:
        dataset = dataset.select(range(min(num_samples, len(dataset))))
    texts = [item['document'] for item in dataset]
    summaries = [item['summary'] for item in dataset]
    return texts, summaries

## Initialize Model

In [6]:
summarizer = BARTBaseSummarizer()

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Model loaded on cpu


## Load Dataset

In [7]:
print("Loading CNN/DailyMail dataset...")
train_texts, train_summaries = load_cnn_dailymail('train', num_samples=1000)
test_texts, test_summaries = load_cnn_dailymail('test', num_samples=100)
print(f"Loaded {len(train_texts)} training samples and {len(test_texts)} test samples")

Loading CNN/DailyMail dataset...
Loaded 1000 training samples and 100 test samples


## Zero-shot Evaluation

In [8]:
print("=== Zero-shot Evaluation ===")
zero_shot_results, zero_shot_summaries = summarizer.evaluate(
    test_texts[:10], test_summaries[:10], method='zero_shot'
)
print("\nZero-shot Results:")
print(json.dumps(zero_shot_results, indent=2))

=== Zero-shot Evaluation ===
Generating summaries using zero_shot...


100%|██████████| 10/10 [00:47<00:00,  4.75s/it]


Calculating BERTScore...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 3.48 seconds, 2.87 sentences/sec

Zero-shot Results:
{
  "rouge1": {
    "precision": 0.20874772135152114,
    "recall": 0.620037392041856,
    "f1": 0.30741451173498263
  },
  "rouge2": {
    "precision": 0.08529261882688013,
    "recall": 0.2632432069424972,
    "f1": 0.12658145082632183
  },
  "rougeL": {
    "precision": 0.1404936968656018,
    "recall": 0.42253222676183294,
    "f1": 0.20759259125989987
  },
  "bertscore": {
    "precision": 0.8464106321334839,
    "recall": 0.8873850703239441,
    "f1": 0.8663761019706726
  }
}


## Few-shot Evaluation

In [9]:
print("=== Few-shot Evaluation ===")
few_shot_examples = list(zip(train_texts[:3], train_summaries[:3]))
few_shot_results, few_shot_summaries = summarizer.evaluate(
    test_texts[:10], test_summaries[:10], method='few_shot', examples=few_shot_examples
)
print("\nFew-shot Results:")
print(json.dumps(few_shot_results, indent=2))

=== Few-shot Evaluation ===
Generating summaries using few_shot...


100%|██████████| 10/10 [00:45<00:00,  4.59s/it]


Calculating BERTScore...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 2.60 seconds, 3.85 sentences/sec

Few-shot Results:
{
  "rouge1": {
    "precision": 0.08316831683168317,
    "recall": 0.22851163550331957,
    "f1": 0.11990969968115048
  },
  "rouge2": {
    "precision": 0.009,
    "recall": 0.023183313122174996,
    "f1": 0.01277936243076817
  },
  "rougeL": {
    "precision": 0.06534653465346536,
    "recall": 0.1857887296408788,
    "f1": 0.09492511253450639
  },
  "bertscore": {
    "precision": 0.7934611439704895,
    "recall": 0.8140667080879211,
    "f1": 0.8036127090454102
  }
}


## Comparison of Results

In [10]:
# Compare results from Zero-shot and Few-shot
print("\n=== COMPARISON OF RESULTS ===")
print("\nZero-shot Performance:")
print(f"  ROUGE-1 F1: {zero_shot_results['rouge1']['f1']:.4f}")
print(f"  ROUGE-2 F1: {zero_shot_results['rouge2']['f1']:.4f}")
print(f"  ROUGE-L F1: {zero_shot_results['rougeL']['f1']:.4f}")
print(f"  BERTScore F1: {zero_shot_results['bertscore']['f1']:.4f}")

print("\nFew-shot Performance:")
print(f"  ROUGE-1 F1: {few_shot_results['rouge1']['f1']:.4f}")
print(f"  ROUGE-2 F1: {few_shot_results['rouge2']['f1']:.4f}")
print(f"  ROUGE-L F1: {few_shot_results['rougeL']['f1']:.4f}")
print(f"  BERTScore F1: {few_shot_results['bertscore']['f1']:.4f}")

# Save results to file
results_summary = {
    "model": "BART-Base",
    "zero_shot": zero_shot_results,
    "few_shot": few_shot_results
}

with open('bart_base_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)
print("\nResults saved to bart_base_results.json")


=== COMPARISON OF RESULTS ===

Zero-shot Performance:
  ROUGE-1 F1: 0.3074
  ROUGE-2 F1: 0.1266
  ROUGE-L F1: 0.2076
  BERTScore F1: 0.8664

Few-shot Performance:
  ROUGE-1 F1: 0.1199
  ROUGE-2 F1: 0.0128
  ROUGE-L F1: 0.0949
  BERTScore F1: 0.8036

Results saved to bart_base_results.json


## Example Summaries

In [11]:
# Display example summaries for comparison
print("\n=== EXAMPLE SUMMARIES ===")
for i in range(min(3, len(test_texts))):
    print(f"\n--- Example {i+1} ---")
    print(f"\nOriginal Article (first 300 chars):\n{test_texts[i][:300]}...")
    print(f"\nReference Summary:\n{test_summaries[i]}")
    print(f"\nZero-shot Summary:\n{zero_shot_summaries[i]}")
    print(f"\nFew-shot Summary:\n{few_shot_summaries[i]}")
    print("-" * 80)


=== EXAMPLE SUMMARIES ===

--- Example 1 ---

Original Article (first 300 chars):
(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the cou...

Reference Summary:
Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .

Zero-shot Summary:
(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed th