# Erasus: Introduction to Machine Unlearning

This notebook gives a short overview of the **Erasus** framework and runs a minimal unlearning example.

## What is Machine Unlearning?

Machine unlearning aims to remove the influence of specific training data (the *forget set*) from a trained model, so that the model behaves approximately as if it had never seen that data, while preserving performance on the remaining training data (the *retain set*).
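To make the forget/retain split concrete, the sketch below partitions a toy dataset by index with PyTorch's `Subset`. The dataset and the `forget_indices` list are hypothetical stand-ins; in practice the forget set would be whatever samples a deletion request covers.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Toy dataset standing in for the original training data (hypothetical).
torch.manual_seed(0)
full_dataset = TensorDataset(torch.randn(250, 32), torch.randint(0, 10, (250,)))

# Hypothetical indices of the samples whose influence must be removed.
forget_indices = list(range(50))
forget_ids = set(forget_indices)
retain_indices = [i for i in range(len(full_dataset)) if i not in forget_ids]

# The two sets are disjoint and together cover the full dataset.
forget_set = Subset(full_dataset, forget_indices)
retain_set = Subset(full_dataset, retain_indices)

print(len(forget_set), len(retain_set))  # 50 200
```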

## Erasus Pipeline

1. **Coreset selection**: choose a small subset (coreset) of the forget set on which the unlearning step operates.
2. **Unlearning**: apply an unlearning strategy (e.g., gradient ascent, SCRUB, or Fisher-based forgetting) to update the model.
3. **Evaluation**: measure forgetting quality (e.g., resistance to membership inference attacks, MIA) and model utility (e.g., accuracy).
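The three steps can be sketched in plain PyTorch, independently of Erasus. This is a minimal illustration under stated assumptions, not how Erasus implements them: the helper names (`random_coreset`, `gradient_ascent`, `accuracy`), the SGD optimizer, and the 50% coreset fraction are all choices made for this example. Gradient ascent here simply steps *up* the cross-entropy loss on the coreset by descending on its negative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)

def random_coreset(dataset, fraction=0.5):
    """Step 1 (hypothetical helper): keep a random fraction of the forget set."""
    k = max(1, int(len(dataset) * fraction))
    idx = torch.randperm(len(dataset))[:k].tolist()
    return Subset(dataset, idx)

def gradient_ascent(model, loader, lr=1e-3, epochs=3):
    """Step 2 (hypothetical helper): increase the loss on the forget coreset."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)  # optimizer choice is an assumption
    model.train()
    history = []
    for _ in range(epochs):
        total = 0.0
        for x, y in loader:
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            (-loss).backward()  # minimizing -loss == gradient ascent on loss
            opt.step()
            total += loss.item()
        history.append(total / len(loader))  # mean forget loss for this epoch
    return history

@torch.no_grad()
def accuracy(model, loader):
    """Step 3 (hypothetical helper): fraction of correctly classified samples."""
    model.eval()
    correct = total = 0
    for x, y in loader:
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.numel()
    return correct / total

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
forget_set = TensorDataset(torch.randn(50, 32), torch.randint(0, 10, (50,)))
retain_set = TensorDataset(torch.randn(200, 32), torch.randint(0, 10, (200,)))

coreset = random_coreset(forget_set, fraction=0.5)
history = gradient_ascent(model, DataLoader(coreset, batch_size=16, shuffle=True))
print("forget loss per epoch:", history)
print("forget acc:", accuracy(model, DataLoader(forget_set, batch_size=64)))
print("retain acc:", accuracy(model, DataLoader(retain_set, batch_size=64)))
```

Plain gradient ascent never looks at the retain set during the update, which is why it can degrade utility; methods such as SCRUB also use retain data to counteract that drift.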

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from erasus.unlearners import ErasusUnlearner
from erasus.metrics.metric_suite import MetricSuite
import erasus.strategies  # register strategies
import erasus.selectors   # register selectors

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
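torch.manual_seed(0)  # make the synthetic data and weight init reproducible

# Minimal MLP classifier: 32 input features -> 10 classes.
# It is untrained here; in practice you would unlearn from a trained model.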
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
).to(device)

# Synthetic forget & retain loaders
forget_loader = DataLoader(
    TensorDataset(torch.randn(50, 32), torch.randint(0, 10, (50,))),
    batch_size=16, shuffle=True
)
retain_loader = DataLoader(
    TensorDataset(torch.randn(200, 32), torch.randint(0, 10, (200,))),
    batch_size=16, shuffle=True
)

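# Steps 1 and 2: coreset selection ("random") and unlearning ("gradient_ascent")
unlearner = ErasusUnlearner(
    model=model,
    strategy="gradient_ascent",
    selector="random",
    device=device,
    strategy_kwargs={"lr": 1e-3},  # keyword arguments for the strategy
)
result = unlearner.fit(forget_data=forget_loader, retain_data=retain_loader, epochs=3)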
print("Forget loss history:", result.forget_loss_history)

In [None]:
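# Step 3: evaluate utility (accuracy) using the forget and retain loaders
suite = MetricSuite(["accuracy"])
metrics = suite.run(unlearner.model, forget_loader, retain_loader)
for k, v in metrics.items():
    # Skip the "_meta" entry and any non-numeric values
    if k != "_meta" and isinstance(v, (int, float)):
        print(f"  {k}: {v:.4f}")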
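The cell above measures utility only. For the forgetting-quality side that the pipeline lists as MIA, one simple proxy is a loss-based membership inference test: if the model's losses on forget samples are no lower than on data it never trained on, an attacker relying on loss alone cannot tell them apart. The sketch below computes this as an AUC in plain PyTorch. The helper names, the held-out loader, and the untrained stand-in model are assumptions for illustration, not part of Erasus.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def per_sample_losses(model, loader):
    """Cross-entropy loss for every sample in the loader (hypothetical helper)."""
    model.eval()
    device = next(model.parameters()).device
    losses = [
        F.cross_entropy(model(x.to(device)), y.to(device), reduction="none").cpu()
        for x, y in loader
    ]
    return torch.cat(losses)

def loss_mia_auc(forget_losses, heldout_losses):
    """P(random forget sample has lower loss than random held-out sample); ties count 0.5.

    Values near 0.5 mean this attack cannot separate forget samples from unseen data.
    """
    diff = heldout_losses.unsqueeze(0) - forget_losses.unsqueeze(1)  # [n_forget, n_heldout]
    return ((diff > 0).float() + 0.5 * (diff == 0).float()).mean().item()

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))  # stand-in model
forget_loader = DataLoader(TensorDataset(torch.randn(50, 32), torch.randint(0, 10, (50,))), batch_size=16)
heldout_loader = DataLoader(TensorDataset(torch.randn(50, 32), torch.randint(0, 10, (50,))), batch_size=16)

auc = loss_mia_auc(per_sample_losses(model, forget_loader), per_sample_losses(model, heldout_loader))
print(f"loss-based MIA AUC: {auc:.3f}")
```

With the notebook's objects you could pass `unlearner.model` and `forget_loader` together with a held-out loader of your own; an AUC close to 0.5 suggests the forget samples look like unseen data to this particular attack, though stronger attacks may still succeed.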