# Machine Unlearning Experiments

This notebook demonstrates unlearning operations on linearized models:
- Unlearning specific identities
- Comparing different unlearning methods (set `unlearning.method` in `config.yaml` and re-run)
- Evaluating unlearning effectiveness

## Objectives
1. Perform unlearning operations
2. Compare unlearning methods (one method per run, selected via `unlearning.method` in `config.yaml`)
3. Verify unlearning effectiveness


In [None]:
import sys
import os

# Make the project's `src/` directory importable. Insert it at the front (rather
# than appending) so generic package names such as `utils` and `data` resolve to
# this project instead of an unrelated installed package, and skip the insert if
# the path is already present so re-running this cell does not grow sys.path.
SRC_DIR = os.path.join(os.path.dirname(os.getcwd()), 'src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
from tqdm import tqdm

from utils.model_loader import load_model_from_config
from linearizer.linearizer import Linearizer
from unlearning.unlearning import UnlearningEngine
from data.dataloader import get_ms1mv2_dataloader

# Load configuration. Like SRC_DIR above, this path is resolved relative to the
# notebook's working directory (its parent holds both `src/` and `config.yaml`).
with open('../config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Seed the RNGs used in this notebook so dataloader shuffling and any stochastic
# unlearning step are reproducible across Restart & Run All. `seed` is an
# optional top-level key in config.yaml; 42 is used when it is absent.
SEED = config.get('seed', 42)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


## 1. Load Model and Build Linearizer

In [None]:
# Load the pretrained backbone and wrap it in a linearizer.
model = load_model_from_config(config)
model = model.to(device)
# The backbone is only used for inference in this notebook. Eval mode keeps
# BatchNorm/Dropout deterministic and stops BatchNorm running statistics from
# being updated by the unlearning and evaluation passes below.
model.eval()

linearizer_config = config['linearizer']
linearizer = Linearizer(
    model=model,
    embedding_size=config['model'].get('embedding_size', 512),
    num_blocks=linearizer_config.get('num_blocks', 4),
    hidden_dim=linearizer_config.get('hidden_dim', 1024)
)
linearizer = linearizer.to(device)

# Restore trained linearizer weights when a checkpoint is configured via the
# optional `linearizer.checkpoint_path` key. Without one, the linearizer keeps its
# freshly initialised weights, so the unlearning results are presumably not
# meaningful.
# NOTE(review): assumes the checkpoint is a plain state_dict saved with
# `torch.save(linearizer.state_dict(), path)`; adjust if training saves a wrapper
# dict. torch.load unpickles data, so only load checkpoints from trusted sources.
linearizer_checkpoint = linearizer_config.get('checkpoint_path')
if linearizer_checkpoint:
    linearizer_state = torch.load(linearizer_checkpoint, map_location=device)
    linearizer.load_state_dict(linearizer_state)
    print(f"Linearizer loaded from {linearizer_checkpoint}")
else:
    print("WARNING: no linearizer checkpoint configured "
          "(config['linearizer']['checkpoint_path']); using untrained linearizer weights")


## 2. Select Identities to Unlearn

In [None]:
# Identities to unlearn come from `unlearning.target_identities` in config.yaml.
# If the key is absent, identities 0-4 are used. If it is present but empty or
# null, fall back to the lowest NUM_FALLBACK_IDENTITIES identity labels found in
# the dataset.
NUM_FALLBACK_IDENTITIES = 5
identity_ids_to_unlearn = config['unlearning'].get('target_identities', [0, 1, 2, 3, 4])

if not identity_ids_to_unlearn:  # also covers `target_identities: null`
    from data.dataset import MS1MV2Dataset
    ms1mv2_path = config['data']['ms1mv2']['path']
    dataset = MS1MV2Dataset(ms1mv2_path, is_training=False)
    # NOTE: this iterates over every sample just to read its label, which can be
    # slow on a large dataset. Sorting makes "the first N identities" well defined
    # and reproducible; a plain set has no guaranteed order. int() accepts both
    # Python ints and 0-d tensors as labels.
    unique_identities = sorted({int(label) for _, label in dataset})
    identity_ids_to_unlearn = unique_identities[:NUM_FALLBACK_IDENTITIES]

print(f"Identities to unlearn: {identity_ids_to_unlearn}")


## 3. Perform Unlearning

In [None]:
# Create the unlearning engine. The method comes from `unlearning.method`, so
# different methods can be compared by changing the config and re-running.
unlearning_config = config['unlearning']
unlearning_method = unlearning_config.get('method', 'orthogonal_projection')
unlearning_engine = UnlearningEngine(linearizer, method=unlearning_method)

# Load the data that drives unlearning. The batch size is configurable via
# `unlearning.batch_size` (default 64, the previous hard-coded value).
# NOTE(review): `is_training=True` may enable training-time augmentation and
# shuffling in the loader; confirm that is intended for unlearning.
ms1mv2_path = config['data']['ms1mv2']['path']
unlearning_batch_size = unlearning_config.get('batch_size', 64)
dataloader = get_ms1mv2_dataloader(ms1mv2_path, batch_size=unlearning_batch_size, is_training=True)

# Perform unlearning
print(f"Unlearning using method: {unlearning_method}")
print("This may take a while...")

# NOTE(review): the returned operator is not used below. The evaluation cell
# evaluates `linearizer`, which assumes `unlearn` updates the linearizer in place.
# Confirm this, or install `updated_operator` into the linearizer before evaluating.
updated_operator = unlearning_engine.unlearn(
    dataloader,
    identity_ids_to_unlearn,
    device=device
)

print("Unlearning completed!")


## 4. Verify Unlearning

In [None]:
# Verify unlearning effectiveness by comparing the original backbone against the
# unlearned linearizer on forget vs. retain identities.
from unlearning.evaluation import compute_unlearning_metrics
from data.dataset import MS1MV2Dataset

MAX_RETAIN_IDENTITIES = 100  # cap on retain identities evaluated, for speed
EVAL_BATCH_SIZE = 64

# Retain identities are all identities in the dataset except the unlearned ones.
# NOTE: collecting labels iterates over every sample, which can be slow.
ms1mv2_path = config['data']['ms1mv2']['path']
eval_dataset = MS1MV2Dataset(ms1mv2_path, is_training=False)
forget_ids = set(identity_ids_to_unlearn)  # set for O(1) membership tests
# Sorted so the capped retain subset below is the same on every run.
all_identities = sorted({int(label) for _, label in eval_dataset})
identity_ids_to_retain = [identity for identity in all_identities if identity not in forget_ids]

# Evaluate
test_dataloader = get_ms1mv2_dataloader(ms1mv2_path, batch_size=EVAL_BATCH_SIZE, is_training=False)

# NOTE(review): `model` is the same backbone object wrapped by `linearizer`, so
# this comparison is only valid if unlearning did not modify the backbone's weights.
metrics = compute_unlearning_metrics(
    original_model=model,
    unlearned_model=linearizer,
    dataloader=test_dataloader,
    identity_ids_to_forget=identity_ids_to_unlearn,
    identity_ids_to_retain=identity_ids_to_retain[:MAX_RETAIN_IDENTITIES],
    device=device
)

print("\nUnlearning Metrics:")
for key, value in metrics.items():
    try:
        print(f"{key}: {value:.4f}")
    except (TypeError, ValueError):
        # Non-scalar metrics (lists, dicts, strings, multi-element tensors) do not
        # accept a float format spec; print them as-is instead of crashing.
        print(f"{key}: {value}")
