# Unsupervised Data Pruning with Perplexity Scoring for Summarization

This notebook demonstrates how to use the `PerplexityScorer` to improve dataset quality for text summarization tasks using the CNN/DailyMail dataset. We'll train summarization models on both original and pruned datasets and compare their ROUGE-L scores.

## Overview

The `PerplexityScorer` calculates perplexity scores for text using a KenLM language model. Higher perplexity indicates harder (and potentially more informative) instances, while lower perplexity indicates easier and more prototypical instances.
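
To make the score concrete, here is a minimal sketch of perplexity under a toy add-one-smoothed unigram model. It is not how KenLM or `PerplexityScorer` computes scores (KenLM uses smoothed n-gram models); the tiny corpus and the function name `unigram_perplexity` are illustrative assumptions.

```python
import math
from collections import Counter

def unigram_perplexity(text, counts, total, vocab_size):
    """Perplexity of `text` under an add-one-smoothed unigram model."""
    tokens = text.lower().split()
    if not tokens:
        return float("inf")
    log_prob = 0.0
    for tok in tokens:
        # Add-one smoothing gives unseen words a small, non-zero probability
        p = (counts.get(tok, 0) + 1) / (total + vocab_size)
        log_prob += math.log(p)
    # Perplexity is the exponentiated average negative log-probability per token
    return math.exp(-log_prob / len(tokens))

# Toy training corpus (an assumption for illustration only)
corpus = "the cat sat on the mat . the dog sat on the rug .".split()
counts = Counter(corpus)
total = len(corpus)
vocab_size = len(counts) + 1  # +1 slot for unknown words

typical = unigram_perplexity("the cat sat on the mat", counts, total, vocab_size)
unusual = unigram_perplexity("quantum zebra flux", counts, total, vocab_size)
print(f"typical: {typical:.2f}, unusual: {unusual:.2f}")
```

Text built from words the model has seen often receives a lower perplexity than text built from unseen words, which is the ordering the pruning steps rely on.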

For summarization tasks, we can use perplexity scoring to:
- Remove extremely low-quality or malformed articles
- Filter out articles that are too difficult or unusual
- Keep a balanced dataset of informative yet learnable examples

## Requirements

You'll need to install the dependencies:

```bash
pip install dprune transformers datasets torch rouge-score nltk matplotlib seaborn scipy tqdm
```

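To compute real perplexity scores you'll also need a KenLM language model. This notebook does not download one: the scoring cell falls back to mock scores, and the `PerplexityScorer` call is left commented out for you to enable once you have a model file.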


In [None]:
import os
import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM, 
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    DataCollatorForSeq2Seq, pipeline
)
from rouge_score import rouge_scorer
import torch
from dprune.scorers import PerplexityScorer
from dprune.pruners import TopKPruner, BottomKPruner
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Libraries imported successfully!")


## Step 1: Locate a KenLM Model (Optional) and Load Dataset

First, we'll note where a pre-trained KenLM model can be obtained (optional for this demo, which uses mock scores) and load the CNN/DailyMail dataset.


In [None]:
# Download and setup KenLM model (optional - for demo we'll use mock scores)
# You can download a pre-trained KenLM model from:
# https://github.com/kpu/kenlm or https://huggingface.co/models?search=kenlm

# For this demo, we'll show how to use PerplexityScorer with a real model:
# KENLM_MODEL_PATH = "/path/to/your/kenlm/model.bin"

# Load CNN/DailyMail dataset
print("Loading CNN/DailyMail dataset...")
dataset = load_dataset("abisee/cnn_dailymail", "3.0.0", split="train")

# Take a subset for faster experimentation (remove this for full dataset)
SUBSET_SIZE = 1000  # Adjust based on your computational resources
dataset = dataset.select(range(SUBSET_SIZE))

print(f"Dataset loaded with {len(dataset)} examples")
print(f"Dataset columns: {dataset.column_names}")

# Show a sample
sample = dataset[0]
print(f"\nSample article (first 300 chars): {sample['article'][:300]}...")
print(f"Sample highlights (first 200 chars): {sample['highlights'][:200]}...")


## Step 2: Calculate Perplexity Scores

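We'll score each article for perplexity. The `PerplexityScorer` call is shown commented out (Method 1); by default the cell generates mock scores from simple text statistics (Method 2), so the numbers below illustrate the workflow rather than real language-model perplexity.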
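
For reference, the `kenlm` Python bindings report sentence scores as total log10 probabilities, and perplexity follows by normalizing over the word count plus one end-of-sentence token. The sketch below shows that conversion in isolation; `perplexity_from_log10` is a hypothetical helper, and whether `PerplexityScorer` normalizes in exactly this way is an assumption not confirmed by this notebook.

```python
def perplexity_from_log10(total_log10_prob: float, num_words: int) -> float:
    """Convert a sentence-level log10 probability into perplexity.

    The +1 counts the end-of-sentence token that an n-gram model
    also scores when `eos=True`.
    """
    return 10.0 ** (-total_log10_prob / (num_words + 1))


# With the kenlm package and a model file, the inputs would come from
# something like this (left commented out because it needs a model on disk):
#   import kenlm
#   model = kenlm.Model(KENLM_MODEL_PATH)
#   total = model.score(text, bos=True, eos=True)
#   ppl = perplexity_from_log10(total, len(text.split()))

# Toy numbers: a 2-word sentence with a total log10 probability of -6
example_ppl = perplexity_from_log10(-6.0, 2)
print(f"Perplexity: {example_ppl:.1f}")
```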


In [None]:
# Calculate perplexity scores
# Method 1: Using a real KenLM model (uncomment if you have one)
# scorer = PerplexityScorer(
#     model_path=KENLM_MODEL_PATH,
#     text_column='article',
#     batch_size=50
# )
# scored_dataset = scorer.score(dataset)

# Method 2: Generate realistic mock scores for demonstration
def generate_mock_perplexity_scores(articles):
    """Generate mock perplexity scores based on article characteristics."""
    scores = []
    for article in tqdm(articles, desc="Generating mock perplexity scores"):
        # Calculate features that correlate with perplexity
        word_count = len(article.split())
        sentence_count = len([s for s in article.split('.') if s.strip()])
        avg_sentence_length = word_count / max(sentence_count, 1)
        
        # Count complex words (longer than 6 characters)
        complex_words = len([w for w in article.split() if len(w) > 6])
        complexity_ratio = complex_words / max(word_count, 1)
        
        # Heuristic for the demo only: longer articles and a higher share of
        # long words are treated as proxies for higher perplexity
        base_score = 50 + (word_count * 0.05) + (complexity_ratio * 100)
        
        # Add some noise to make it more realistic
        noise = np.random.normal(0, 10)
        score = max(10, base_score + noise)  # Ensure minimum score of 10
        scores.append(score)
    
    return scores

# Generate mock scores
mock_scores = generate_mock_perplexity_scores(dataset['article'])
scored_dataset = dataset.add_column('perplexity_score', mock_scores)

print(f"Perplexity scores calculated for {len(scored_dataset)} articles")
print(f"Score statistics:")
print(f"  Mean: {np.mean(mock_scores):.2f}")
print(f"  Median: {np.median(mock_scores):.2f}")
print(f"  Min: {np.min(mock_scores):.2f}")
print(f"  Max: {np.max(mock_scores):.2f}")
print(f"  Std: {np.std(mock_scores):.2f}")


In [None]:
# Visualize perplexity distribution
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(mock_scores, bins=30, alpha=0.7, edgecolor='black')
plt.axvline(np.mean(mock_scores), color='red', linestyle='--', label=f'Mean: {np.mean(mock_scores):.1f}')
plt.axvline(np.median(mock_scores), color='green', linestyle='--', label=f'Median: {np.median(mock_scores):.1f}')
plt.xlabel('Perplexity Score')
plt.ylabel('Frequency')
plt.title('Distribution of Perplexity Scores')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
article_lengths = [len(article.split()) for article in dataset['article']]
plt.scatter(article_lengths, mock_scores, alpha=0.6, s=20)
plt.xlabel('Article Length (words)')
plt.ylabel('Perplexity Score')
plt.title('Perplexity vs Article Length')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Show examples of high and low perplexity articles
sorted_indices = np.argsort(mock_scores)

print("\n=== LOWEST PERPLEXITY ARTICLES (Easiest/Most Prototypical) ===")
for i in range(3):
    idx = int(sorted_indices[i])  # plain int for dataset indexing
    print(f"\nScore: {mock_scores[idx]:.1f}")
    print(f"Article: {scored_dataset[idx]['article'][:200]}...")
    print(f"Summary: {scored_dataset[idx]['highlights'][:100]}...")

print("\n=== HIGHEST PERPLEXITY ARTICLES (Hardest/Most Complex) ===")
for i in range(3):
    idx = int(sorted_indices[-(i+1)])  # plain int for dataset indexing
    print(f"\nScore: {mock_scores[idx]:.1f}")
    print(f"Article: {scored_dataset[idx]['article'][:200]}...")
    print(f"Summary: {scored_dataset[idx]['highlights'][:100]}...")


## Step 3: Prune the Dataset

We'll create a pruned version of the dataset by dropping the top and bottom 20% of articles by perplexity score, keeping the middle 60% as a balanced set of informative examples.
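
The middle-band filter described above can be sketched with NumPy percentiles. This is a minimal illustration on synthetic scores, not output from this notebook; the helper name `middle_band_mask` and the log-normal toy data are assumptions.

```python
import numpy as np

def middle_band_mask(scores, lower_pct=20, upper_pct=80):
    """Boolean mask keeping scores between the given percentiles (inclusive)."""
    scores = np.asarray(scores, dtype=float)
    low, high = np.percentile(scores, [lower_pct, upper_pct])
    return (scores >= low) & (scores <= high)

# Synthetic, roughly log-normal scores stand in for real perplexities
rng = np.random.default_rng(0)
toy_scores = rng.lognormal(mean=4.5, sigma=0.4, size=1000)

mask = middle_band_mask(toy_scores)
kept_fraction = mask.mean()
print(f"Kept {mask.sum()} of {len(toy_scores)} examples ({kept_fraction:.0%})")
```

Widening or narrowing the band only requires changing `lower_pct` and `upper_pct`, which makes it easy to compare several pruning levels.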


In [None]:
# Create different pruned datasets for comparison
# We'll compare:
# 1. Original dataset (no pruning)
# 2. Pruned dataset (remove extreme outliers - top and bottom 20%)

# For proper pruning with perplexity scores, we want to remove:
# - Articles with extremely low perplexity (too simple/repetitive)
# - Articles with extremely high perplexity (too complex/noisy)

# Calculate pruning thresholds
perplexity_scores = np.array(mock_scores)
low_threshold = np.percentile(perplexity_scores, 20)  # Bottom 20%
high_threshold = np.percentile(perplexity_scores, 80)  # Top 20%

print(f"Perplexity thresholds:")
print(f"  Low threshold (20th percentile): {low_threshold:.2f}")
print(f"  High threshold (80th percentile): {high_threshold:.2f}")

# Create masks for filtering
middle_range_mask = (perplexity_scores >= low_threshold) & (perplexity_scores <= high_threshold)

# Create datasets
original_dataset = scored_dataset
pruned_dataset = scored_dataset.filter(lambda example, idx: middle_range_mask[idx], with_indices=True)

print(f"\nDataset sizes:")
print(f"  Original: {len(original_dataset)} examples")
print(f"  Pruned (middle 60%): {len(pruned_dataset)} examples")
print(f"  Reduction: {len(original_dataset) - len(pruned_dataset)} examples ({(1 - len(pruned_dataset)/len(original_dataset))*100:.1f}%)")

# Compare perplexity distributions
plt.figure(figsize=(10, 6))
plt.hist(original_dataset['perplexity_score'], bins=30, alpha=0.6, label='Original Dataset', color='blue')
plt.hist(pruned_dataset['perplexity_score'], bins=30, alpha=0.6, label='Pruned Dataset', color='red')
plt.axvline(low_threshold, color='green', linestyle='--', label=f'Low threshold: {low_threshold:.1f}')
plt.axvline(high_threshold, color='orange', linestyle='--', label=f'High threshold: {high_threshold:.1f}')
plt.xlabel('Perplexity Score')
plt.ylabel('Frequency')
plt.title('Perplexity Distribution: Original vs Pruned Dataset')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Show statistics for both datasets
def dataset_stats(dataset, name):
    scores = dataset['perplexity_score']
    lengths = [len(article.split()) for article in dataset['article']]
    
    print(f"\n{name} Dataset Statistics:")
    print(f"  Size: {len(dataset)} examples")
    print(f"  Perplexity - Mean: {np.mean(scores):.2f}, Std: {np.std(scores):.2f}")
    print(f"  Article length - Mean: {np.mean(lengths):.1f} words, Std: {np.std(lengths):.1f}")

dataset_stats(original_dataset, "Original")
dataset_stats(pruned_dataset, "Pruned")


## Step 4: Train Summarization Models

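We'll fine-tune the same checkpoint on both the original and pruned datasets, then compare their performance. Note that `facebook/bart-large-cnn` is a large model that has already been fine-tuned on CNN/DailyMail, so any gap between the two runs is likely to be small; for a cleaner comparison, swap in a base checkpoint such as `facebook/bart-base`.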


In [None]:
# Model configuration
MODEL_NAME = "facebook/bart-large-cnn"  # BART-large, already fine-tuned on CNN/DailyMail
MAX_INPUT_LENGTH = 1024
MAX_TARGET_LENGTH = 128
BATCH_SIZE = 4
LEARNING_RATE = 5e-5
NUM_EPOCHS = 2  # Small number for demo - increase for better results

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Loaded tokenizer for {MODEL_NAME}")

def preprocess_function(examples):
    """Tokenize articles and reference summaries for seq2seq training."""
    # BART does not use a task prefix, and inference below feeds raw articles,
    # so the inputs are passed through unchanged.
    model_inputs = tokenizer(
        examples["article"],
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
    )

    # Tokenize targets via `text_target` (replaces the deprecated
    # `as_target_tokenizer` context manager).
    labels = tokenizer(
        text_target=examples["highlights"],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )

    # Padding is left to DataCollatorForSeq2Seq, which pads labels with -100
    # so that padded positions are ignored by the loss.
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

print("Preprocessing datasets...")

# Tokenize once, then split so both models share the same held-out test set
tokenized_original = original_dataset.map(preprocess_function, batched=True)

# Split dataset into train/test (80/20 split)
def split_dataset(dataset, test_size=0.2):
    dataset = dataset.train_test_split(test_size=test_size, seed=42)
    return dataset["train"], dataset["test"]

original_train, original_test = split_dataset(tokenized_original)

# Prune only the training split (same thresholds as Step 3). Evaluating both
# models on one test set keeps the ROUGE-L comparison like-for-like.
pruned_train = original_train.filter(
    lambda example: low_threshold <= example["perplexity_score"] <= high_threshold
)
pruned_test = original_test

print(f"\nDataset splits:")
print(f"Original - Train: {len(original_train)}, Test: {len(original_test)}")
print(f"Pruned - Train: {len(pruned_train)}, Test: {len(pruned_test)} (shared test set)")


In [None]:
def train_model(train_dataset, output_dir, model_name="Model"):
    """Train a summarization model."""
    print(f"\n=== Training {model_name} ===")
    
    # Load fresh model for each training
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    
    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        # No evaluation during training (the default); the argument is omitted
        # because its name differs across transformers versions
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        weight_decay=0.01,
        save_strategy="epoch",
        save_total_limit=1,
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),  # Use mixed precision if CUDA available
        dataloader_num_workers=0,  # Avoid multiprocessing issues
        logging_steps=50,
        report_to="none",  # Disable wandb and other logging integrations
    )
    
    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    
    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    
    # Train the model
    print(f"Starting training with {len(train_dataset)} examples...")
    trainer.train()
    
    # Save the model
    trainer.save_model()
    print(f"Model saved to {output_dir}")
    
    return model

# Train both models
print("Starting model training...")
print("Note: This may take some time depending on your hardware.")

# Train on original dataset
original_model = train_model(
    train_dataset=original_train,
    output_dir="./model_original",
    model_name="Original Dataset Model"
)

# Train on pruned dataset  
pruned_model = train_model(
    train_dataset=pruned_train,
    output_dir="./model_pruned", 
    model_name="Pruned Dataset Model"
)

print("\n✅ Both models trained successfully!")


## Step 5: Evaluate Models with ROUGE-L Scores

Now we'll evaluate both models on test sets and compare their ROUGE-L scores to see the impact of perplexity-based pruning.
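
For intuition about what ROUGE-L measures, here is a simplified sketch based on the longest common subsequence (LCS) of whitespace tokens. It omits the stemming and tokenization that the `rouge_score` package applies, so its numbers will not match the package exactly; the function names are illustrative.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for tok_a in a:
        curr = [0]
        for j, tok_b in enumerate(b, start=1):
            if tok_a == tok_b:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(reference, candidate):
    """ROUGE-L F-measure from LCS-based precision and recall."""
    ref_tokens = reference.lower().split()
    cand_tokens = candidate.lower().split()
    if not ref_tokens or not cand_tokens:
        return 0.0
    lcs = lcs_length(ref_tokens, cand_tokens)
    if lcs == 0:
        return 0.0
    precision = lcs / len(cand_tokens)  # share of the candidate that is in the LCS
    recall = lcs / len(ref_tokens)      # share of the reference that is in the LCS
    return 2 * precision * recall / (precision + recall)

score = rouge_l_f1("the cat sat on the mat", "the cat is on the mat")
print(f"ROUGE-L F1: {score:.4f}")
```

Because LCS respects word order but allows gaps, a summary that keeps the reference's key words in sequence scores well even when it inserts or drops a few words.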


In [None]:
# Initialize ROUGE scorer
rouge_scorer_instance = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def generate_summaries(model, test_dataset, max_samples=None):
    """Generate summaries using the trained model."""
    if max_samples:
        test_dataset = test_dataset.select(range(min(max_samples, len(test_dataset))))
    
    summaries = []
    references = []
    
    # Create a pipeline for easier inference
    summarizer = pipeline(
        "summarization",
        model=model,
        tokenizer=tokenizer,
        max_length=MAX_TARGET_LENGTH,
        min_length=30,
        do_sample=False,
        device=0 if torch.cuda.is_available() else -1
    )
    
    print(f"Generating summaries for {len(test_dataset)} examples...")
    
    for i, example in enumerate(tqdm(test_dataset)):
        # Get original article and reference summary
        article = example['article']
        reference = example['highlights']
        
        try:
            # Generate summary; truncate articles that exceed the model's input limit
            result = summarizer(article, max_length=MAX_TARGET_LENGTH, min_length=30, do_sample=False, truncation=True)
            generated_summary = result[0]['summary_text']
            
            summaries.append(generated_summary)
            references.append(reference)
            
        except Exception as e:
            print(f"Error processing example {i}: {e}")
            summaries.append("")  # Empty summary for failed cases
            references.append(reference)
    
    return summaries, references

def calculate_rouge_scores(summaries, references):
    """Calculate ROUGE-L scores."""
    rouge_l_scores = []
    
    for summary, reference in zip(summaries, references):
        if summary.strip():  # Only calculate if summary is not empty
            scores = rouge_scorer_instance.score(reference, summary)
            rouge_l_scores.append(scores['rougeL'].fmeasure)
        else:
            rouge_l_scores.append(0.0)  # Score 0 for empty summaries
    
    return rouge_l_scores

# Evaluate both models
print("\n=== EVALUATING MODELS ===")

# For faster evaluation, limit test samples (remove this for full evaluation)
MAX_TEST_SAMPLES = 50

# Evaluate original model
print("\nEvaluating Original Dataset Model...")
original_summaries, original_references = generate_summaries(
    original_model, original_test, max_samples=MAX_TEST_SAMPLES
)
original_rouge_scores = calculate_rouge_scores(original_summaries, original_references)

# Evaluate pruned model
print("\nEvaluating Pruned Dataset Model...")
pruned_summaries, pruned_references = generate_summaries(
    pruned_model, pruned_test, max_samples=MAX_TEST_SAMPLES
)
pruned_rouge_scores = calculate_rouge_scores(pruned_summaries, pruned_references)

# Calculate statistics
original_mean_rouge = np.mean(original_rouge_scores)
original_std_rouge = np.std(original_rouge_scores)
pruned_mean_rouge = np.mean(pruned_rouge_scores)
pruned_std_rouge = np.std(pruned_rouge_scores)

print(f"\n=== ROUGE-L RESULTS ===")
print(f"Original Dataset Model:")
print(f"  Mean ROUGE-L: {original_mean_rouge:.4f} (±{original_std_rouge:.4f})")
print(f"  Test samples: {len(original_rouge_scores)}")

print(f"\nPruned Dataset Model:")
print(f"  Mean ROUGE-L: {pruned_mean_rouge:.4f} (±{pruned_std_rouge:.4f})")
print(f"  Test samples: {len(pruned_rouge_scores)}")

print(f"\nImprovement:")
improvement = pruned_mean_rouge - original_mean_rouge
improvement_pct = (improvement / original_mean_rouge) * 100
print(f"  Absolute: {improvement:+.4f}")
print(f"  Relative: {improvement_pct:+.2f}%")


In [None]:
# Visualize results
plt.figure(figsize=(12, 5))

# Plot 1: ROUGE-L score distributions
plt.subplot(1, 2, 1)
plt.hist(original_rouge_scores, bins=20, alpha=0.6, label='Original Model', color='blue')
plt.hist(pruned_rouge_scores, bins=20, alpha=0.6, label='Pruned Model', color='red')
plt.axvline(original_mean_rouge, color='blue', linestyle='--', 
           label=f'Original Mean: {original_mean_rouge:.3f}')
plt.axvline(pruned_mean_rouge, color='red', linestyle='--', 
           label=f'Pruned Mean: {pruned_mean_rouge:.3f}')
plt.xlabel('ROUGE-L Score')
plt.ylabel('Frequency')
plt.title('ROUGE-L Score Distribution')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: Box plot comparison
plt.subplot(1, 2, 2)
data_to_plot = [original_rouge_scores, pruned_rouge_scores]
labels = ['Original\nModel', 'Pruned\nModel']
bp = plt.boxplot(data_to_plot, labels=labels, patch_artist=True)
bp['boxes'][0].set_facecolor('lightblue')
bp['boxes'][1].set_facecolor('lightcoral')
plt.ylabel('ROUGE-L Score')
plt.title('ROUGE-L Score Comparison')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Show some example summaries
print("\n=== EXAMPLE SUMMARIES ===")

for i in range(min(3, len(original_summaries))):
    print(f"\n--- Example {i+1} ---")
    print(f"Article (first 200 chars): {original_test[i]['article'][:200]}...")
    print(f"\nReference Summary: {original_references[i]}")
    print(f"\nOriginal Model Summary: {original_summaries[i]}")
    print(f"Pruned Model Summary: {pruned_summaries[i]}")
    print(f"\nROUGE-L Scores:")
    print(f"  Original Model: {original_rouge_scores[i]:.4f}")
    print(f"  Pruned Model: {pruned_rouge_scores[i]:.4f}")
    print("-" * 80)


## Step 6: Analysis and Conclusions

Let's analyze the results and understand the impact of perplexity-based pruning on summarization model performance.
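
With only a few dozen test examples, a confidence interval on the difference in mean ROUGE-L is a useful complement to a p-value. The sketch below is a percentile bootstrap for two unpaired score lists; the helper name `bootstrap_mean_diff_ci` is hypothetical, and the scores are synthetic stand-ins rather than results from this notebook.

```python
import numpy as np

def bootstrap_mean_diff_ci(scores_a, scores_b, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for mean(scores_b) - mean(scores_a), unpaired."""
    rng = np.random.default_rng(seed)
    a = np.asarray(scores_a, dtype=float)
    b = np.asarray(scores_b, dtype=float)
    diffs = np.empty(n_resamples)
    for i in range(n_resamples):
        # Resample each group independently, with replacement
        diffs[i] = rng.choice(b, size=b.size).mean() - rng.choice(a, size=a.size).mean()
    lower, upper = np.percentile(diffs, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lower, upper

# Synthetic ROUGE-L-like scores (assumed values for illustration only)
rng = np.random.default_rng(1)
toy_original = rng.normal(0.30, 0.05, size=50)
toy_pruned = rng.normal(0.40, 0.05, size=50)

ci_low, ci_high = bootstrap_mean_diff_ci(toy_original, toy_pruned)
print(f"95% CI for mean difference: [{ci_low:.3f}, {ci_high:.3f}]")
```

An interval that excludes zero suggests a real difference in either direction; an interval that straddles zero means the test set is too small to tell the two models apart.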


In [None]:
# Statistical significance test (t-test)
from scipy import stats

t_stat, p_value = stats.ttest_ind(pruned_rouge_scores, original_rouge_scores)

print("=== DETAILED ANALYSIS ===")
print(f"\n📊 Statistical Analysis:")
print(f"  t-statistic: {t_stat:.4f}")
print(f"  p-value: {p_value:.4f}")
print(f"  Significant difference (two-sided): {'Yes' if p_value < 0.05 else 'No'} (α = 0.05)")

# Training efficiency analysis
print(f"\n⚡ Training Efficiency:")
print(f"  Original dataset size: {len(original_train)} examples")
print(f"  Pruned dataset size: {len(pruned_train)} examples")
print(f"  Training data reduction: {(1 - len(pruned_train)/len(original_train))*100:.1f}%")
print(f"  Performance change: {improvement_pct:+.2f}%")

efficiency_ratio = improvement_pct / ((len(original_train) - len(pruned_train))/len(original_train)*100)
print(f"  Efficiency ratio: {efficiency_ratio:.2f} (performance gain per % data reduction)")

print(f"\n🎯 Key Findings:")
if improvement > 0:
    print(f"  ✅ Perplexity-based pruning IMPROVED model performance")
    print(f"  ✅ Achieved {improvement_pct:+.2f}% relative improvement in ROUGE-L")
    print(f"  ✅ Used {(1 - len(pruned_train)/len(original_train))*100:.1f}% less training data")
else:
    print(f"  ⚠️  Perplexity-based pruning showed {improvement_pct:.2f}% change in ROUGE-L")
    print(f"  ℹ️  Results may vary with different datasets and model configurations")

print(f"\n💡 Hypotheses to verify with real KenLM scores:")
print(f"  • Removing extreme perplexity outliers may improve model training")
print(f"  • Very low-perplexity articles may be too simple or repetitive")
print(f"  • Very high-perplexity articles may be too noisy or complex")
print(f"  • The middle perplexity range may offer the best trade-off; mock scores cannot confirm this")

print(f"\n🔧 Recommendations for Production:")
print(f"  1. Use real KenLM models for more accurate perplexity scoring")
print(f"  2. Experiment with different pruning thresholds (e.g., remove top/bottom 10-30%)")
print(f"  3. Consider domain-specific KenLM models for better scoring")
print(f"  4. Validate results on larger datasets and longer training runs")
print(f"  5. Combine perplexity pruning with other data quality metrics")

# Save results summary
results_summary = {
    'original_rouge_mean': original_mean_rouge,
    'original_rouge_std': original_std_rouge,
    'pruned_rouge_mean': pruned_mean_rouge,
    'pruned_rouge_std': pruned_std_rouge,
    'improvement_absolute': improvement,
    'improvement_relative_pct': improvement_pct,
    'original_dataset_size': len(original_train),
    'pruned_dataset_size': len(pruned_train),
    'data_reduction_pct': (1 - len(pruned_train)/len(original_train))*100,
    't_statistic': t_stat,
    'p_value': p_value
}

print(f"\n💾 Results saved to 'results_summary' dictionary for further analysis.")


## Summary

This notebook walked through a perplexity-based pruning workflow for text summarization built around dPrune's **PerplexityScorer**, with mock scores standing in for a real KenLM model.

### What We Accomplished:

1. **Loaded CNN/DailyMail dataset** - A real-world summarization dataset
2. **Calculated perplexity scores** - Used mock scores based on text characteristics (in production, use real KenLM models)
3. **Applied intelligent pruning** - Removed extreme perplexity outliers (top and bottom 20%)
4. **Trained two models** - One on original data, one on pruned data
5. **Compared performance** - Used ROUGE-L scores to measure summarization quality
6. **Analyzed results** - Statistical analysis of the performance differences

### Key Takeaways:

- **Perplexity-based pruning may improve model performance** while reducing training data, but this demo (mock scores, 1,000 articles, 2 epochs) cannot establish that on its own
- **Quality over quantity**: A smaller, well-curated dataset can match or outperform a larger, noisy one
- **Sweet spot hypothesis**: Medium-perplexity examples are expected to balance informativeness and learnability
- **Efficiency gains**: Training on less data shortens training time and lowers cost, whether or not ROUGE-L improves

### Next Steps:

For production use, consider:
- Using real KenLM models for accurate perplexity calculation
- Experimenting with different pruning thresholds
- Validating on larger datasets and longer training runs
- Combining perplexity scores with other data quality metrics
- Using domain-specific language models for specialized applications

With a real KenLM model in place, the PerplexityScorer offers a simple, unsupervised way to screen dataset quality for NLP tasks.
