In [None]:
# Install required packages if needed (uncomment; %pip installs into the active kernel's environment)
# %pip install -e ".[test]"
# %pip install transformers datasets torch scikit-learn tqdm accelerate matplotlib


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)

# dPrune imports
from dprune.callbacks import ForgettingCallback
from dprune.scorers.supervised import ForgettingScorer
from dprune.pruners.selection import TopKPruner, BottomKPruner
from dprune.pipeline import PruningPipeline

# Reaching this statement means every import in this cell resolved.
print("All imports successful!", end="\n")


In [None]:
# Create a small, balanced dataset for binary sentiment analysis
# (label 1 = positive, 0 = negative).
SEED = 42

positive_texts = [
    "This movie is absolutely fantastic!",
    "I loved every minute of this film.",
    "Outstanding performance by all actors.",
    "A masterpiece of modern cinema.",
    "Brilliant storytelling and direction.",
    "This is one of the best movies I've ever seen.",
    "Incredible cinematography and soundtrack.",
    "A delightful and heartwarming story.",
    "Perfect blend of action and emotion.",
    "This film exceeded all my expectations.",
    "Amazing special effects and great plot.",
    "A truly inspiring and uplifting movie."
]

negative_texts = [
    "This movie was a complete waste of time.",
    "Boring and predictable storyline.",
    "Poor acting and terrible direction.",
    "I couldn't wait for this movie to end.",
    "Disappointing and overrated film.",
    "The worst movie I've seen this year.",
    "Confusing plot and bad character development.",
    "Not worth the money or time.",
    "Terrible script and poor execution.",
    "This film was incredibly dull.",
    "Weak storyline and unconvincing performances.",
    "A forgettable and mediocre movie."
]

# Combine into dataset
texts = positive_texts + negative_texts
labels = [1] * len(positive_texts) + [0] * len(negative_texts)  # 1=positive, 0=negative

# Shuffle with a local, seeded Generator instead of np.random.seed(), which
# would reset NumPy's global RNG for every other cell/library in the kernel.
rng = np.random.default_rng(SEED)
indices = rng.permutation(len(texts))

shuffled_texts = [texts[i] for i in indices]
shuffled_labels = [labels[i] for i in indices]

raw_dataset = Dataset.from_dict({
    'text': shuffled_texts,
    'label': shuffled_labels
})

# Labels are 0/1, so their sum is the number of positive examples.
n_positive = sum(shuffled_labels)
print(f"Dataset created with {len(raw_dataset)} examples")
print(f"Positive examples: {n_positive}")
print(f"Negative examples: {len(raw_dataset) - n_positive}")
print("\nFirst few examples:")
# Iterate rows of a small selection rather than indexing whole columns per row.
for i, example in enumerate(raw_dataset.select(range(min(3, len(raw_dataset))))):
    print(f"  {i}: '{example['text']}' -> {example['label']}")


In [None]:
def tokenize_function(examples):
    """Tokenize a batch of examples from a `datasets.Dataset.map(batched=True)` call.

    Every text in ``examples['text']`` is truncated and padded to exactly 128
    tokens. Fixed-length padding gives all rows the same length, since the
    Trainer cell below is built without a tokenizer or custom data collator.
    """
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=128,
    )


model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Two output classes: 0 = negative, 1 = positive (see the dataset cell).
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

tokenized_dataset = raw_dataset.map(tokenize_function, batched=True)

print("Model and tokenizer loaded successfully!")
print(f"Tokenized dataset: {tokenized_dataset}")


In [None]:
# The callback is handed to the Trainer in the next cell; afterwards its
# `learning_events` attribute is inspected to see what it recorded.
forgetting_callback = ForgettingCallback()

print("ForgettingCallback initialized!\nThis callback will track learning events during training.")


In [None]:
# Training arguments - we want multiple epochs so that examples have the chance
# to be learned in one epoch and forgotten in a later one.
training_args = TrainingArguments(
    output_dir='./forgetting_results',
    num_train_epochs=5,  # More epochs to observe forgetting
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=2e-5,
    logging_steps=10,
    save_strategy="no",  # Don't save checkpoints for this example
    seed=42,  # Explicit (same as the library default) so reruns are reproducible
    report_to="none",  # Don't push logs to any installed tracker (e.g. W&B) from a demo
)

# Create trainer with our callback so it is invoked during training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    callbacks=[forgetting_callback],  # This is the key addition!
)

print("Trainer created with ForgettingCallback.")
print("Starting training...")

# Train the model
trainer.train()

print("Training completed!")


In [None]:
def count_forgetting_events(events):
    """Count forgetting events in one example's per-step correctness history.

    A forgetting event is a transition from correct (1) to incorrect (0)
    between consecutive entries of ``events``. Entries are presumably 0/1 or
    False/True as recorded by ForgettingCallback — TODO confirm against dprune.
    """
    return sum(1 for prev, curr in zip(events, events[1:]) if prev == 1 and curr == 0)


# Check what the callback recorded
learning_events = forgetting_callback.learning_events
print(f"Number of examples tracked: {len(learning_events)}")
print("\nLearning events for first 5 examples:")
for i in range(min(5, len(learning_events))):
    events = learning_events.get(i, [])
    print(f"  Example {i}: {events}")
    if events:
        print(f"    -> Forgetting events: {count_forgetting_events(events)}")

# Calculate forgetting scores
forgetting_scores = forgetting_callback.calculate_forgetting_scores()
print(f"\nForgetting scores calculated for {len(forgetting_scores)} examples")
# min()/max() raise ValueError on an empty sequence, e.g. if the callback never fired.
if len(forgetting_scores) == 0:
    print("No forgetting scores were recorded - check that the callback was passed to the Trainer.")
else:
    print(f"Score distribution: min={min(forgetting_scores)}, max={max(forgetting_scores)}, mean={np.mean(forgetting_scores):.2f}")


In [None]:
# Create the forgetting scorer using our callback
forgetting_scorer = ForgettingScorer(forgetting_callback)

# Score the dataset; the result is expected to carry a 'score' column.
scored_dataset = forgetting_scorer.score(raw_dataset)

print("Dataset scored with forgetting scores!")
print(f"Scored dataset columns: {scored_dataset.column_names}")
print("\nFirst few examples with scores:")
# Bound the preview by the dataset size (like the other preview loops) and read
# each row once instead of materializing full columns on every access.
n_preview = min(5, len(scored_dataset))
for example in scored_dataset.select(range(n_preview)):
    print(f"  Score: {example['score']}, Text: '{example['text'][:50]}...', Label: {example['label']}")


In [None]:
def print_extreme_examples(dataset, title, n=3, highest_first=True):
    """Print ``title`` followed by the ``n`` rows of ``dataset`` with the most extreme score.

    Rows are sorted by their 'score' column here, so the "hardest"/"easiest"
    headings hold regardless of the row order the pruner returns. Ties keep
    their original relative order (``sorted`` is stable).
    """
    print(title)
    ranked = sorted(dataset, key=lambda row: row['score'], reverse=highest_first)
    for row in ranked[:n]:
        print(f"  Score: {row['score']}, Text: '{row['text'][:60]}...', Label: {row['label']}")


# Strategy 1: Keep the most-forgotten examples (hardest to retain)
top_pruner = TopKPruner(k=0.5)  # Keep top 50%
pipeline_hard = PruningPipeline(scorer=forgetting_scorer, pruner=top_pruner)
hard_examples = pipeline_hard.run(raw_dataset)

# Strategy 2: Keep the least-forgotten examples (easy/stable). This is the
# lowest-scoring 50%, which only coincides with "never forgotten" when at
# least half of the examples have a score of 0.
bottom_pruner = BottomKPruner(k=0.5)  # Keep bottom 50%
pipeline_easy = PruningPipeline(scorer=forgetting_scorer, pruner=bottom_pruner)
easy_examples = pipeline_easy.run(raw_dataset)

print("Pruning Results:")
print(f"Original dataset: {len(raw_dataset)} examples")
print(f"Hard examples (most forgotten): {len(hard_examples)} examples")
print(f"Easy examples (least forgotten): {len(easy_examples)} examples")

print_extreme_examples(hard_examples, "\nHardest examples (most forgotten):", highest_first=True)
print_extreme_examples(easy_examples, "\nEasiest examples (least forgotten):", highest_first=False)


In [None]:
# Read each column once; repeated `scored_dataset['score']` lookups may rebuild
# the full column every time.
scores = list(scored_dataset['score'])
texts_scored = list(scored_dataset['text'])
labels_scored = list(scored_dataset['label'])
assert scores, "scored_dataset is empty - nothing to analyse"

min_score = min(scores)
max_score = max(scores)

# Forgetting scores are event counts, so use one bin per integer value centred
# on that integer (edges at k - 0.5, k + 0.5). Unlike `bins=max + 1` over the
# data range, this keeps bars aligned with the counts and also accepts float
# scores (a float `bins` value raises).
bin_edges = np.arange(np.floor(min_score), np.floor(max_score) + 2) - 0.5

# Plot distribution of forgetting scores
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(scores, bins=bin_edges, alpha=0.7, edgecolor='black')
ax.set_xlabel('Forgetting Score')
ax.set_ylabel('Number of Examples')
ax.set_title('Distribution of Forgetting Scores')
ax.grid(True, alpha=0.3)
plt.show()

# Analyze by label
positive_scores = [score for score, label in zip(scores, labels_scored) if label == 1]
negative_scores = [score for score, label in zip(scores, labels_scored) if label == 0]

print("\nForgetting Score Analysis:")
for class_name, class_scores in (("Positive", positive_scores), ("Negative", negative_scores)):
    # np.mean([]) returns nan with a RuntimeWarning, so report empty classes explicitly.
    if class_scores:
        print(f"{class_name} examples - Mean: {np.mean(class_scores):.2f}, Std: {np.std(class_scores):.2f}")
    else:
        print(f"{class_name} examples - none")

# Find the most and least forgotten examples (first occurrence on ties,
# matching list.index semantics).
most_forgotten_idx = int(np.argmax(scores))
least_forgotten_idx = int(np.argmin(scores))

print(f"\nMost forgotten example (score: {max_score}):")
print(f"  '{texts_scored[most_forgotten_idx]}' (Label: {labels_scored[most_forgotten_idx]})")

print(f"\nLeast forgotten example (score: {min_score}):")
print(f"  '{texts_scored[least_forgotten_idx]}' (Label: {labels_scored[least_forgotten_idx]})")
