# dPrune: Forgetting Score Example

This notebook demonstrates how to prune a dataset using the **Forgetting Score** in `dPrune`. The forgetting score comes from the paper [An Empirical Study of Example Forgetting during Deep Neural Network Learning](https://arxiv.org/abs/1812.05159) and counts how many times an example is "forgotten" during training. The paper reports that frequently forgotten examples are more *informative* for training than examples that are never forgotten.

An example is "forgotten" when it goes from being classified correctly to being classified incorrectly between epochs. Because this definition depends on a correct/incorrect prediction, the score is well suited to classification tasks.
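
To make the definition concrete, here is a minimal sketch of how forgetting events could be counted from a per-epoch record of whether one example was classified correctly. The function name `count_forgetting_events` and the toy history are illustrative assumptions, not part of dPrune's API.

```python
from typing import Sequence


def count_forgetting_events(correct_by_epoch: Sequence[bool]) -> int:
    """Count correct -> incorrect transitions between consecutive epochs."""
    return sum(
        1
        for prev, curr in zip(correct_by_epoch, correct_by_epoch[1:])
        if prev and not curr
    )


# Hypothetical history for a single example over five epochs:
# wrong, right, wrong, right, right
history = [False, True, False, True, True]
events = count_forgetting_events(history)
print(f"Forgetting events: {events}")
```

An example that is never classified correctly has no correct-to-incorrect transition, so under this sketch its count is 0, the same as an example that is learned once and never forgotten.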


## 1. Setup and Installation


In [1]:
# Install required packages if needed
# !pip install -e .[test]
!pip install transformers torch scikit-learn tqdm accelerate

!pip install -U datasets huggingface_hub fsspec


In [3]:
import torch
import numpy as np
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)

from dprune.callbacks import ForgettingCallback
from dprune.scorers.supervised import ForgettingScorer
from dprune.pruners.selection import TopKPruner, BottomKPruner
from dprune.pipeline import PruningPipeline


## 2. Load the IMDB dataset

For the forgetting score to be meaningful, the dataset must be large enough, and training long enough, for forgetting events to occur. We use the IMDB dataset from the Hugging Face Hub.


In [32]:
from datasets import load_dataset
raw_dataset = load_dataset("stanfordnlp/imdb", split="train")

# To train faster on a 10% sample of the dataset, uncomment the snippet below.
# raw_dataset = raw_dataset.shuffle()
# raw_dataset = raw_dataset.filter(lambda entry, index: index < 0.1 * len(raw_dataset), with_indices=True)

print(f"Positive examples: {sum(raw_dataset['label'])}")
print(f"Negative examples: {len(raw_dataset) - sum(raw_dataset['label'])}")
print("\nFirst few examples:")
for i in range(3):
    print(f"  {i}: '{raw_dataset['text'][i]}' -> {raw_dataset['label'][i]}")


Positive examples: 12500
Negative examples: 12500

First few examples:
  0: 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered

## 3. Setup Model and Tokenizer


In [37]:
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)

tokenized_dataset = raw_dataset.map(tokenize_function, batched=True)

print("Model and tokenizer loaded successfully!")
print(f"Tokenized dataset: {tokenized_dataset}")


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Model and tokenizer loaded successfully!
Tokenized dataset: Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 25000
})


## 4. Initialize the Forgetting Callback

This is the key step! We create a `ForgettingCallback` that will monitor the training process.


In [39]:
forgetting_callback = ForgettingCallback()

## 5. Train the Model with the Callback

We train for three epochs so that the model has a chance to "forget" some examples.


In [40]:
# Training arguments - we want multiple epochs to observe forgetting
training_args = TrainingArguments(
    output_dir='./forgetting_results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    logging_steps=100,
    save_strategy="no",
    report_to="none"
)

# Create trainer with our callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    callbacks=[forgetting_callback],  # This is the key addition!
)

forgetting_callback.trainer = trainer

print("Trainer created with ForgettingCallback.")
print("Starting training...")

# Train the model
trainer.train()

print("Training completed!")


Trainer created with ForgettingCallback.
Starting training...


Step,Training Loss
100,0.4922
200,0.3725
300,0.3458
400,0.3504
500,0.3201
600,0.3274
700,0.3137
800,0.2992
900,0.2353
1000,0.2237


predicted_labels:  [0 0 0 ... 1 1 1]
true_labels:  [0 0 0 ... 1 1 1]
predicted_labels:  [0 0 0 ... 1 1 1]
true_labels:  [0 0 0 ... 1 1 1]
predicted_labels:  [0 0 0 ... 1 1 1]
true_labels:  [0 0 0 ... 1 1 1]
Training completed!


## 6. Examine the Forgetting Events

Let's look at what the callback recorded during training.


In [41]:
print(f"Number of examples tracked: {len(forgetting_callback.learning_events)}")

# Calculate forgetting scores
forgetting_scores = forgetting_callback.calculate_forgetting_scores()
print(f"\nForgetting scores calculated for {len(forgetting_scores)} examples")
print(f"Score distribution: min={min(forgetting_scores)}, max={max(forgetting_scores)}, mean={np.mean(forgetting_scores):.2f}")

# Count the examples that were forgotten at least once
total_examples_forgotten = sum(1 for score in forgetting_scores if score > 0)
total_examples_forgotten

Number of examples tracked: 25000

Forgetting scores calculated for 25000 examples
Score distribution: min=0, max=1, mean=0.01


199
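
When scores can exceed 1, the number of examples forgotten at least once and the total number of forgetting events are different quantities. The sketch below shows one way to report both, along with a histogram of scores. The list `example_scores` is a hypothetical stand-in for the scores returned by the callback.

```python
from collections import Counter

# Hypothetical scores; in the notebook these would come from
# forgetting_callback.calculate_forgetting_scores().
example_scores = [0, 0, 1, 0, 2, 1, 0, 0]

histogram = Counter(example_scores)
num_forgotten = sum(1 for s in example_scores if s > 0)  # examples with score > 0
total_events = sum(example_scores)  # all forgetting events combined

for score, count in sorted(histogram.items()):
    print(f"score={score}: {count} examples")
print(f"Examples forgotten at least once: {num_forgotten}")
print(f"Total forgetting events: {total_events}")
```

In the run above the maximum score is 1, so the two quantities coincide.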

## 7. Use the Forgetting Scorer in a Pipeline

Now we can use the populated callback with our `ForgettingScorer`.


In [48]:
from datasets import Dataset

from dprune.base import Scorer
from dprune.callbacks import ForgettingCallback

# Note: this cell redefines ForgettingScorer locally, shadowing the class
# imported from dprune.scorers.supervised in the setup cell.

class ForgettingScorer(Scorer):
    """
    A Scorer that uses a ForgettingCallback to assign a "forgetting score"
    to each example. The score is the number of times an example was
    "forgotten" during training (i.e., transitioned from being classified
    correctly to incorrectly).
    """

    def __init__(self, forgetting_callback: ForgettingCallback):
        """
        Initializes the ForgettingScorer.

        Args:
            forgetting_callback (ForgettingCallback): A ForgettingCallback instance
                that has been used during a Trainer's training run.
        """
        self.callback = forgetting_callback

    def score(self, dataset: Dataset, **kwargs) -> Dataset:
        """
        Calculates and adds the forgetting scores to the dataset.
        The dataset passed here should be the same one used for training.
        """
        scores = self.callback.calculate_forgetting_scores()

        if len(scores) != len(dataset):
            raise ValueError(
                f"The number of scores from the callback ({len(scores)}) does not match "
                f"the dataset size ({len(dataset)}). Ensure the same dataset was used "
                "for training and scoring."
            )

        return dataset.add_column("score", scores)


In [49]:
# Create the forgetting scorer using our callback
forgetting_scorer = ForgettingScorer(forgetting_callback)

# Score the dataset
scored_dataset = forgetting_scorer.score(raw_dataset)

print("Dataset scored with forgetting scores!")
print(f"Scored dataset columns: {scored_dataset.column_names}")
print("\nFirst few examples with scores:")
for i in range(5):
    print(f"  Score: {scored_dataset['score'][i]}, Text: '{scored_dataset['text'][i][:50]}...', Label: {scored_dataset['label'][i]}")

Dataset scored with forgetting scores!
Scored dataset columns: ['text', 'label', 'score']

First few examples with scores:
  Score: 0, Text: 'I rented I AM CURIOUS-YELLOW from my video store b...', Label: 0
  Score: 0, Text: '"I Am Curious: Yellow" is a risible and pretentiou...', Label: 0
  Score: 0, Text: 'If only to avoid making this type of film in the f...', Label: 0
  Score: 0, Text: 'This film was probably inspired by Godard's Mascul...', Label: 0
  Score: 0, Text: 'Oh, brother...after hearing about this ridiculous ...', Label: 0


## 8. Pruning with Forgetting Scores



In [57]:
# Strategy: Keep examples that are forgotten the most (hardest examples)
top_pruner = TopKPruner(k=0.1)  # Keep top 10%
pipeline_hard = PruningPipeline(scorer=forgetting_scorer, pruner=top_pruner)
hard_examples = pipeline_hard.run(raw_dataset)

print("\nHardest examples (most forgotten):")
for i in range(min(3, len(hard_examples))):
    print(f" Text: '{hard_examples['text'][i][:60]}...', Label: {hard_examples['label'][i]}")

print("Length of pruned dataset", len(hard_examples))


Hardest examples (most forgotten):
 Text: 'En route to a small town that lays way off the beaten track ...', Label: 0
 Text: 'Ned Kelly (Ledger), the infamous Australian outlaw and legen...', Label: 0
 Text: 'The perfect murder is foiled when a wife(played by Mary Elle...', Label: 0
Length of pruned dataset 2500
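
In this run only 199 examples have a non-zero score (the maximum score is 1), while keeping the top 10% retains 2,500 examples, so most of the retained examples are tied at score 0. The sketch below reproduces that arithmetic with a plain-Python stand-in for top-k selection. It assumes selection by descending score and says nothing about how `TopKPruner` actually breaks ties.

```python
# Toy stand-in matching the run above: 25,000 examples, 199 with score 1.
n_examples, n_forgotten, k = 25_000, 199, 0.1
scores = [1] * n_forgotten + [0] * (n_examples - n_forgotten)

# Assumed selection rule: sort indices by descending score, keep the top fraction k.
n_keep = round(k * n_examples)
ranked = sorted(range(n_examples), key=lambda i: scores[i], reverse=True)
kept = ranked[:n_keep]

kept_nonzero = sum(1 for i in kept if scores[i] > 0)
kept_tied_at_zero = n_keep - kept_nonzero
print(f"Kept: {n_keep}, forgotten at least once: {kept_nonzero}, tied at 0: {kept_tied_at_zero}")
```

If the goal is to keep only examples that were actually forgotten, training for more epochs or choosing a smaller `k` may give a more selective subset.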
