# CRISP: Unlearning Harry Potter Demo

This notebook demonstrates the CRISP method for unlearning a concept (here, Harry Potter).

## 1. Setup and Model Loading
We initialize the environment, select the model (Gemma 2 2B or Llama 3.1 8B), and download the necessary Sparse Autoencoders (SAEs).

In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
from globals import GEMMA_2_2B, LLAMA_3_1_8B
from crisp import CRISP, CRISPConfig
from unlearn import unlearn_lora, UnlearnConfig
from data import load_hp_data, HPDataConfig, genenrate_hp_eval_text
from sae import JumpReLUSAE, TopkSae
from eval import get_mcq_accuracy
from utils import load_cached_features, get_feature_tokens
from crisp import LayerFeatures
from plot import plot_features_scatter

# The Hugging Face token is read from the environment. Prefer exporting it in your
# shell; if you must set it here, do not commit the value.
# os.environ['HF_TOKEN'] = <YOUR_HF_TOKEN>
# Use .get(): indexing os.environ with a missing key raises KeyError, so a plain
# `os.environ['HF_TOKEN'] is None` test could never reach the error message below.
# An empty string is treated as "not set" as well.
if not os.environ.get('HF_TOKEN'):
    raise ValueError("HF_TOKEN environment variable not set. Please set it to your Hugging Face token.")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Model selection: assign LLAMA_3_1_8B here to run the Llama variant instead.
MODEL_CARD = GEMMA_2_2B
is_gemma = MODEL_CARD == GEMMA_2_2B

# Per-model settings. Both dictionaries share the same keys:
#   sae_layers                 - layers whose SAEs are downloaded and used
#   save_path                  - local directory where the SAEs are stored
#   sae_class                  - SAE implementation used for this model
#   model_name_short           - name handed to CRISPConfig
#   unlearn                    - model-specific values forwarded to UnlearnConfig
#   neuronpedia_id / _suffix   - pieces used to build Neuronpedia lookups and URLs
#   layer_to_plot              - layer shown in the analysis section
GEMMA_CONFIG = dict(
    sae_layers=list(range(4, 15, 2)),  # 4, 6, ..., 14
    save_path="gemma_sae_cache",
    sae_class=JumpReLUSAE,
    model_name_short="gemma",
    unlearn=dict(
        learning_rate=1e-5,
        k_features=10,
        alpha=5,
    ),
    neuronpedia_id="gemma-2-2b",
    neuronpedia_source_suffix="-gemmascope-res-16k",
    layer_to_plot=10,
)

LLAMA_CONFIG = dict(
    sae_layers=list(range(4, 30, 2)),  # 4, 6, ..., 28
    save_path="llama_sae_cache",
    sae_class=TopkSae,
    model_name_short="llama",
    unlearn=dict(
        learning_rate=2e-5,
        k_features=10,
        alpha=30,
    ),
    neuronpedia_id="llama3.1-8b",
    neuronpedia_source_suffix="-llamascope-res-32k",
    layer_to_plot=20,
)

# Activate the settings that match the selected model.
if is_gemma:
    CONFIG = GEMMA_CONFIG
else:
    CONFIG = LLAMA_CONFIG

SAE_LAYERS = CONFIG["sae_layers"]

print(f"Using model: {MODEL_CARD}")
print(f"Operating on layers: {SAE_LAYERS}")

In [None]:
# Fetch the SAE for every configured layer that is not already on disk.
SAE_CLASS = CONFIG["sae_class"]
save_path = CONFIG["save_path"]

print(f"Checking/Downloading SAEs to {save_path}...")
for layer in SAE_LAYERS:
    layer_dir = os.path.join(save_path, f"layer_{layer}")
    if os.path.exists(layer_dir):
        # Already present from an earlier run; nothing to do for this layer.
        continue
    print(f"Downloading SAE for layer {layer}...")
    SAE_CLASS.download_and_save(layer=layer, save_path=save_path)

In [None]:
# Build the CRISP object for the selected model and SAE layers.
# bf16=True is passed through to CRISPConfig
# (NOTE(review): presumably selects bfloat16 weights - confirm in crisp.py).
config = CRISPConfig(layers=SAE_LAYERS, model_name=CONFIG["model_name_short"], bf16=True)
crisp = CRISP(config)

## 3. Load Data
We load the forget dataset (Harry Potter text) and a retain dataset (general book text) to help the model distinguish between what to remove and what to keep.

In [None]:
# Forget set: Harry Potter text to unlearn. Retain set: text whose behavior should be kept.
data_config = HPDataConfig(n_examples=2500)

# Loader arguments are read from the data config so the two cannot drift apart.
data = load_hp_data(
    max_len=data_config.max_length,
    benign=data_config.retain_type,
    n_examples=data_config.n_examples,
)

n_forget, n_retain = len(data['forget']), len(data['retain'])
print(f"Loaded {n_forget} HP examples and {n_retain} retain examples")

## 4. Identify Salient Features
CRISP uses SAEs to identify features that are highly active on the forget data but not on the retain data. These salient features represent the specific knowledge we want to unlearn.

In [None]:
# Feed the forget (target) and retain (benign) texts to CRISP so it can pick out
# the Harry Potter-specific features.
print("Processing features (this may take a moment)...")
feature_texts = {
    "text_target": data['forget'],
    "text_benign": data['retain'],
}
crisp.process_multi_texts_batch(**feature_texts, data_config=data_config, batch_size=8)
print("Feature processing complete.")

## 5. Unlearn
We apply the unlearning process. This involves fine-tuning the model (using LoRA) to suppress the identified salient features while maintaining performance on general tasks.

In [None]:
# Reset state before (re-)running the unlearning cell below.
# NOTE(review): presumably unload_lora() detaches any LoRA adapter left over from a
# previous run so unlearning starts from the base model - confirm in crisp.py.
crisp.unload_lora()
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [None]:
# Model-specific hyperparameters come from CONFIG; the remaining values are the
# same for both supported models.
model_hparams = CONFIG["unlearn"]

uconfig = UnlearnConfig(
    learning_rate=model_hparams["learning_rate"],
    k_features=model_hparams["k_features"],
    alpha=model_hparams["alpha"],
    beta=0.99,
    gamma=0.01,
    batch_size=4,
    lora_rank=4,
    data_type="hp",
    verbose=True,
)

print("Starting unlearning process...")
unlearn_lora(
    crisp,
    text_target=data['forget'],
    text_benign=data['retain'],
    config=uconfig,
    data_config=data_config,
)
print("Unlearning complete.")

In [None]:
# Compare the model before and after unlearning on two benchmarks:
#   - "hp":   Harry Potter multiple-choice questions (should drop after unlearning)
#   - "mmlu": general-knowledge multiple-choice questions (should stay roughly the same)
print("-"*50)
print("Original Model")
print("-"*50)

# Everything inside this context runs with the adapter disabled, so the numbers
# are intended to reflect the original (pre-unlearning) model.
with crisp.model.disable_adapter():
    print("Evaluating original Harry Potter accuracy...")
    original_hp_acc = get_mcq_accuracy(crisp, type="hp")

    print("Generating Harry Potter evaluation text of original model...")
    genenrate_hp_eval_text(crisp)

    print("Evaluating original MMLU accuracy...")
    original_mmlu_acc = get_mcq_accuracy(crisp, type="mmlu")


print("-"*50)
print("After Unlearning")
print("-"*50)

print("Evaluating Harry Potter accuracy after unlearning...")
hp_acc_after = get_mcq_accuracy(crisp, type="hp")
print(f"HP Accuracy after unlearning: {hp_acc_after:.2%} vs original {original_hp_acc:.2%}")

print("Generating Harry Potter evaluation text after unlearning...")
genenrate_hp_eval_text(crisp)

print("Evaluating MMLU accuracy after unlearning...")
after_acc = get_mcq_accuracy(crisp, type="mmlu")
# Report the retain-side metric too; without it the before/after MMLU values are
# computed but never shown.
print(f"MMLU Accuracy after unlearning: {after_acc:.2%} vs original {original_mmlu_acc:.2%}")

## 6. Analysis and Visualization
We visualize the features to see which ones were identified as salient. We also inspect the top features using Neuronpedia to understand what concepts they represent.

In [None]:
# Load the features for the layer chosen in CONFIG and plot them.
# NOTE(review): model_name is MODEL_CARD here while CRISPConfig received
# CONFIG["model_name_short"]; confirm load_cached_features expects the full model card.
layer_to_plot = CONFIG["layer_to_plot"]
cached_features = load_cached_features(layer_to_plot, data_config, model_name=MODEL_CARD)
layer_features = LayerFeatures(cached_features)

n_loaded = len(layer_features.features)
print(f"Loaded {n_loaded} features for layer {layer_to_plot}")

plot_features_scatter(layer_features=layer_features, k_features=5, top_percentile=0.05)

In [None]:
# Inspect the most salient feature using Neuronpedia
from html import escape as html_escape

from IPython.display import HTML, display

model_id = CONFIG["neuronpedia_id"]

# Find feature with highest target_acts_relative among the top-5 filtered features
best_feature = max(layer_features.topk_filtered(5), key=lambda f: f.target_acts_relative)
feature_index = best_feature.index

print(f"Inspecting Feature {feature_index} from Layer {layer_to_plot}...")

feature_data = get_feature_tokens(model_id, layer_to_plot, feature_index, top_k=5)

if feature_data:
    source = f"{layer_to_plot}{CONFIG['neuronpedia_source_suffix']}"

    neuronpedia_url = f"https://www.neuronpedia.org/{model_id}/{source}/{feature_index}"

    # Tokens and descriptions come from an external service and can contain HTML
    # metacharacters (e.g. a "<bos>" token or "&"); escape them so they render as
    # literal text instead of being parsed as markup.
    token_style = (
        "background-color: #e1ecf4; color: #2c5282; padding: 2px 8px; border-radius: 4px; "
        "margin-right: 5px; display: inline-block; border: 1px solid #b3d4fc;"
    )
    tokens_html = ' '.join(
        f'<span style="{token_style}">{html_escape(str(token))}</span>'
        for token in feature_data['pos_str']
    )

    explanations = feature_data.get('explanations')
    raw_description = explanations[0]['description'] if explanations else "No description available"
    description = html_escape(str(raw_description))

    html_content = f"""
    <div style="border: 1px solid #e0e0e0; padding: 20px; border-radius: 8px; background-color: #ffffff; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; box-shadow: 0 2px 4px rgba(0,0,0,0.05); max-width: 600px;">
        <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px; border-bottom: 1px solid #eee; padding-bottom: 10px;">
            <h3 style="margin: 0; color: #333; font-size: 1.2em;">Feature {feature_index}</h3>
            <span style="background-color: #f0f0f0; color: #666; padding: 2px 8px; border-radius: 12px; font-size: 0.8em;">Layer {layer_to_plot}</span>
        </div>
        
        <div style="margin-bottom: 20px;">
            <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 5px; letter-spacing: 0.5px;">Auto Description</div>
            <div style="font-size: 1.1em; color: #1a1a1a; line-height: 1.4;">{description}</div>
        </div>
        
        <div style="margin-bottom: 20px;">
            <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 8px; letter-spacing: 0.5px;">Top Tokens (Logit Lens)</div>
            <div style="display: flex; flex-wrap: wrap; gap: 5px;">
                {tokens_html}
            </div>
        </div>
        
        <div style="text-align: right;">
            <a href="{neuronpedia_url}" target="_blank" style="color: #0969da; text-decoration: none; font-size: 0.9em; font-weight: 500;">View on Neuronpedia &rarr;</a>
        </div>
    </div>
    """
    display(HTML(html_content))
else:
    # Make a failed or empty lookup visible instead of producing no output at all.
    print(f"No Neuronpedia data returned for feature {feature_index} (layer {layer_to_plot}).")