# CRISP: Unlearning WMDP-Bio Demo

This notebook demonstrates the CRISP method for making a language model unlearn the biosecurity knowledge covered by the WMDP-bio dataset.


## 1. Setup and Model Loading
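We initialize the environment, select the model (Gemma 2 2B or Llama 3.1 8B; the run recorded here uses Llama 3.1 8B), and define the matching Sparse Autoencoder (SAE) configuration. The SAEs themselves are downloaded in Section 2.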


In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Set your GPU device

import torch
from globals import GEMMA_2_2B, LLAMA_3_1_8B
from crisp import CRISP, CRISPConfig
from unlearn import unlearn_lora, UnlearnConfig
from data import load_wmdp_data, WMDPDataConfig, generate_bio_eval_text
from sae import JumpReLUSAE, TopkSae
from eval import get_mcq_accuracy
from utils import load_cached_features, get_feature_tokens
from crisp import LayerFeatures
from plot import plot_features_scatter

# Set your HuggingFace token (needed for model access)
# os.environ['HF_TOKEN'] = <YOUR_HF_TOKEN>
if not os.environ.get('HF_TOKEN'):  # os.environ values are never None, so test for missing or empty
    print("Warning: HF_TOKEN environment variable not set. You may need it for model access.")

# Configuration for WMDP-Bio: set MODEL_CARD to GEMMA_2_2B or LLAMA_3_1_8B
MODEL_CARD = LLAMA_3_1_8B  # this run uses Llama 3.1 8B
is_gemma = (MODEL_CARD == GEMMA_2_2B)

GEMMA_CONFIG = {
    "sae_layers": list(range(4, 15, 2)),  # [4, 6, 8, 10, 12, 14]
    "save_path": "gemma_sae_cache",
    "sae_class": JumpReLUSAE,
    "model_name_short": "gemma",
    "unlearn": {
        "learning_rate": 1e-5,
        "k_features": 10,
        "alpha": 5,
    },
    "neuronpedia_id": "gemma-2-2b",
    "neuronpedia_source_suffix": "-gemmascope-res-16k",
    "layer_to_plot": 10
}

LLAMA_CONFIG = {
    "sae_layers": list(range(4, 30, 2)),  # [4, 6, 8, ..., 28]
    "save_path": "llama_sae_cache",
    "sae_class": TopkSae,
    "model_name_short": "llama",
    "unlearn": {
        "learning_rate": 2e-5,
        "k_features": 10,
        "alpha": 30,
    },
    "neuronpedia_id": "llama3.1-8b",
    "neuronpedia_source_suffix": "-llamascope-res-32k",
    "layer_to_plot": 20
}

CONFIG = GEMMA_CONFIG if is_gemma else LLAMA_CONFIG
SAE_LAYERS = CONFIG["sae_layers"]

print(f"Using model: {MODEL_CARD}")
print(f"Operating on layers: {SAE_LAYERS}")



Using model: meta-llama/Llama-3.1-8B
Operating on layers: [4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]


## 2. Download SAEs
Download and cache the Sparse Autoencoders for the selected layers.


In [2]:
save_path = CONFIG["save_path"]
SAE_CLASS = CONFIG["sae_class"]

print(f"Checking/Downloading SAEs to {save_path}...")
for layer in SAE_LAYERS:
    layer_path = os.path.join(save_path, f"layer_{layer}")
    if not os.path.exists(layer_path):
        print(f"Downloading SAE for layer {layer}...")
        SAE_CLASS.download_and_save(layer=layer, save_path=save_path)


Checking/Downloading SAEs to llama_sae_cache...
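
The cache check in the cell above can be factored into a small helper that reports which layers still need a download. This is a minimal sketch, not part of the CRISP codebase: `missing_layers` is a hypothetical name, and the `layer_{N}` directory layout is taken from the cell above. The temporary directory only stands in for a real cache.

```python
import os
import tempfile


def missing_layers(save_path, layers):
    """Return the layers that have no `layer_{N}` entry under save_path."""
    return [
        layer
        for layer in layers
        if not os.path.exists(os.path.join(save_path, f"layer_{layer}"))
    ]


# Demo with a throwaway cache directory in which only layer 4 is present.
with tempfile.TemporaryDirectory() as cache_dir:
    os.makedirs(os.path.join(cache_dir, "layer_4"))
    todo = missing_layers(cache_dir, [4, 6, 8])

print(todo)
```

Computing the list up front makes it easy to log how many downloads are pending before the loop starts.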


## 3. Initialize CRISP
Create the CRISP instance with the model and SAEs.


In [3]:
config = CRISPConfig(
    layers=SAE_LAYERS, 
    model_name=CONFIG["model_name_short"], 
    bf16=True
)
crisp = CRISP(config)


Loading from cache: /private/fetaya-lab/noam_diamant/projects/Unlearning_with_SAE/CRISP/crisp/llama_sae_cache


Loading SAEs:   0%|          | 0/13 [00:00<?, ?it/s]

Loading layers.4 on cuda:0


Loading SAEs:   8%|▊         | 1/13 [00:06<01:14,  6.20s/it]

Loading layers.6 on cuda:0


Loading SAEs:  15%|█▌        | 2/13 [00:10<00:54,  4.97s/it]

Loading layers.8 on cuda:0


Loading SAEs:  23%|██▎       | 3/13 [00:13<00:41,  4.17s/it]

Loading layers.10 on cuda:0


Loading SAEs:  31%|███       | 4/13 [00:16<00:31,  3.53s/it]

Loading layers.12 on cuda:0


Loading SAEs:  38%|███▊      | 5/13 [00:19<00:28,  3.51s/it]

Loading layers.14 on cuda:0


Loading SAEs:  46%|████▌     | 6/13 [00:23<00:25,  3.60s/it]

Loading layers.16 on cuda:0


Loading SAEs:  54%|█████▍    | 7/13 [00:27<00:21,  3.63s/it]

Loading layers.18 on cuda:0


Loading SAEs:  62%|██████▏   | 8/13 [00:30<00:17,  3.55s/it]

Loading layers.20 on cuda:0


Loading SAEs:  69%|██████▉   | 9/13 [00:33<00:13,  3.38s/it]

Loading layers.22 on cuda:0


Loading SAEs:  77%|███████▋  | 10/13 [00:36<00:10,  3.39s/it]

Loading layers.24 on cuda:0


Loading SAEs:  85%|████████▍ | 11/13 [00:39<00:06,  3.29s/it]

Loading layers.26 on cuda:0


Loading SAEs:  92%|█████████▏| 12/13 [00:42<00:03,  3.16s/it]

Loading layers.28 on cuda:0


Loading SAEs: 100%|██████████| 13/13 [00:45<00:00,  3.53s/it]
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.19it/s]


## 4. Load WMDP-Bio Data
Load the WMDP-bio forget dataset (biosecurity knowledge) and retain dataset (general Wikipedia).


In [4]:
# Configure WMDP-Bio data loading
data_config = WMDPDataConfig(
    n_examples=2500,
    forget_type="bio",
    retain_type="wiki",
    max_length=1000,
    min_length=1000
)

# Load WMDP-Bio data
# Note: This will try to load pre-cleaned files, or download from HuggingFace if not available
data = load_wmdp_data(
    target_type=data_config.forget_type,
    retain_type=data_config.retain_type,
    n_examples=data_config.n_examples
)

print(f"Loaded {len(data['forget'])} WMDP-bio forget examples")
print(f"Loaded {len(data['retain'])} Wikipedia retain examples")


Loaded 2500 WMDP-bio forget examples
Loaded 2500 Wikipedia retain examples


## 5. Identify Salient Features
Process the text to identify SAE features that are highly active on biosecurity knowledge but not on general Wikipedia text.


In [5]:
# Process texts to identify biosecurity-specific features
print("Processing features (this may take several minutes)...")
crisp.process_multi_texts_batch(
    text_target=data['forget'],
    text_benign=data['retain'],
    data_config=data_config,
    batch_size=8
)
print("Feature processing complete.")


Processing features (this may take several minutes)...
Found 13 cached layers and 0 uncached layers.
All layers have cached features.
Feature processing complete.
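
To make the idea of "active on the forget text but not on the retain text" concrete, here is a toy sketch that ranks features by the gap between their activation frequencies on the two corpora. It is an illustration under stated assumptions, not CRISP's actual scoring rule: `salient_features`, the count dictionaries, and the frequency-difference score are all hypothetical.

```python
def salient_features(target_counts, benign_counts, n_target_tokens, n_benign_tokens, k=10):
    """Rank features by (target frequency - benign frequency) and keep the top k.

    target_counts / benign_counts map a feature index to the number of tokens
    on which that feature was active in the respective corpus.
    """
    scores = {}
    for index, target_count in target_counts.items():
        target_freq = target_count / n_target_tokens
        benign_freq = benign_counts.get(index, 0) / n_benign_tokens
        scores[index] = target_freq - benign_freq
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:k]


# Feature 0 fires on both corpora, feature 1 mostly on the target corpus,
# and feature 2 is rare everywhere.
target = {0: 900, 1: 800, 2: 50}
benign = {0: 850, 1: 20, 2: 5}
top = salient_features(target, benign, n_target_tokens=1000, n_benign_tokens=1000, k=2)
print(top)
```

Feature 0 is the most frequent on the target text but ranks below feature 1, because it is nearly as frequent on the benign text.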


## 6. Unlearn Biosecurity Knowledge
Apply LoRA fine-tuning to suppress the identified biosecurity-specific features.


In [6]:
# Unload any existing LoRA adapters and free memory
crisp.unload_lora()
torch.cuda.empty_cache()


In [7]:
# Configure unlearning
uconfig = UnlearnConfig(
    learning_rate=CONFIG["unlearn"]["learning_rate"],
    k_features=CONFIG["unlearn"]["k_features"],
    alpha=CONFIG["unlearn"]["alpha"],
    beta=0.99,
    gamma=0.01,
    batch_size=4,
    lora_rank=4,
    data_type="bio",  # This run logs data_type "bio"; the MCQ evaluator in Section 7 takes type="wmdp_bio"
    verbose=True,
    save_model=True
)

print("Starting unlearning process...")
unlearn_lora(
    crisp, 
    text_target=data['forget'], 
    text_benign=data['retain'], 
    config=uconfig, 
    data_config=data_config
)
print("Unlearning complete.")


Starting unlearning process...
Unlearn Config: {
  "learning_rate": 2e-05,
  "num_epochs": 1,
  "batch_size": 4,
  "k_features": 10,
  "alpha": 30,
  "beta": 0.99,
  "gamma": 0.01,
  "lora_rank": 4,
  "save_model": true,
  "save_path": "/private/fetaya-lab/noam_diamant/projects/Unlearning_with_SAE/CRISP/CRISP/saved_models/crisp",
  "data_type": "bio",
  "verbose": "bio"
}
CRISP Config: {
  "model_name": "meta-llama/Llama-3.1-8B",
  "layers": [
    4,
    6,
    8,
    10,
    12,
    14,
    16,
    18,
    20,
    22,
    24,
    26,
    28
  ],
  "saes_model_name": "fnlp/Llama3_1-8B-Base-LXR-8x",
  "bf16": true
}
Data Config: WMDPDataConfig(max_length=1000, min_length=1000, n_examples=2500, wiki_config='wikitext-2-raw-v1', forget_type='bio', retain_type='wiki')
SEED: 0




Batch 25/625, Loss: 1.64e-01 (Unlearn: 1.64e-01, Reg: 1.06e-05, Coherency: 9.30e-06)




Batch 50/625, Loss: 1.50e-01 (Unlearn: 1.50e-01, Reg: 5.27e-05, Coherency: 4.89e-05)




Batch 75/625, Loss: 1.21e-01 (Unlearn: 1.15e-01, Reg: 3.87e-04, Coherency: 5.83e-03)




Batch 100/625, Loss: 7.91e-02 (Unlearn: 7.70e-02, Reg: 1.30e-03, Coherency: 8.20e-04)




Batch 125/625, Loss: 6.91e-02 (Unlearn: 6.84e-02, Reg: 5.68e-04, Coherency: 1.14e-04)




Batch 150/625, Loss: 6.09e-02 (Unlearn: 5.95e-02, Reg: 1.30e-03, Coherency: 7.30e-05)




Batch 175/625, Loss: 5.31e-02 (Unlearn: 4.91e-02, Reg: 3.37e-03, Coherency: 6.45e-04)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Batch 200/625, Loss: 4.37e-02 (Unlearn: 4.29e-02, Reg: 3.03e-04, Coherency: 5.23e-04)

----------------------------------------------------------------------

Viral replication requires the synthesis of viral proteins and the assembly of viral particles. The synthesis of viral proteins is carried out
Viruses spread through populations by infecting new hosts. The rate at which a virus spreads through a population depends on the number of
Bacterial infections can be treated with antibiotics, but they can also be treated with other methods. One of these methods is the use of
The viral envelope is derived from the cellular membrane, and glycoproteins on the viral envelope are coded by genes that are part
Bacterial spores form when bacteria are exposed to harsh conditions, such as extreme heat or cold, or lack of water. Sp
Antiviral medications work by stopping the virus from replicating. They do not kill the virus, but they can help to prevent
Viral mutations occur because of the high rate



Overall accuracy for bio_mcq: 0.604



Evaluating batches: 100%|██████████| 39/39 [00:05<00:00,  7.50it/s, Batch Acc=0.833, Total Acc=0.735]


Overall accuracy for high_school_bio_mcq: 0.735



Evaluating batches: 100%|██████████| 18/18 [00:02<00:00,  7.17it/s, Batch Acc=0.875, Total Acc=0.750]


Overall accuracy for college_bio_mcq: 0.750




Batch 225/625, Loss: 3.94e-02 (Unlearn: 3.88e-02, Reg: 4.50e-04, Coherency: 2.10e-04)




Batch 250/625, Loss: 2.96e-02 (Unlearn: 2.88e-02, Reg: 7.82e-04, Coherency: 7.30e-05)




Batch 275/625, Loss: 2.54e-02 (Unlearn: 2.44e-02, Reg: 5.04e-04, Coherency: 5.19e-04)




Batch 300/625, Loss: 2.66e-02 (Unlearn: 2.57e-02, Reg: 3.57e-04, Coherency: 5.04e-04)




Batch 325/625, Loss: 2.58e-02 (Unlearn: 2.48e-02, Reg: 6.68e-04, Coherency: 3.09e-04)




Batch 350/625, Loss: 2.31e-02 (Unlearn: 2.17e-02, Reg: 1.30e-03, Coherency: 7.39e-05)




Batch 375/625, Loss: 3.24e-02 (Unlearn: 2.59e-02, Reg: 5.77e-03, Coherency: 6.52e-04)




Batch 400/625, Loss: 3.59e-02 (Unlearn: 3.14e-02, Reg: 1.12e-03, Coherency: 3.31e-03)

----------------------------------------------------------------------

Viral replication requires the synthesis of viral proteins, which are encoded by viral genes. The viral genes are transcribed into
Viruses spread through populations by infecting new hosts. The rate at which a virus spreads through a population depends on the number of
Bacterial infections can be treated with antibiotics. But what if the bacteria are resistant to antibiotics? This is a growing problem, and it
The viral envelope is derived from the viral’s’ The be’ The The’ The’ in The’ s The’ The’
Bacterial spores form when a bacterium is exposed to harsh conditions, such as extreme heat or cold, or a lack of
Antiviral medications work by stopping the virus from replicating. They do not kill the virus, but they can help to prevent
Viral mutations occur because of the high rate of replication of the virus. The virus replicates at 



Overall accuracy for bio_mcq: 0.307



Evaluating batches: 100%|██████████| 39/39 [00:05<00:00,  7.59it/s, Batch Acc=0.833, Total Acc=0.684]


Overall accuracy for high_school_bio_mcq: 0.684



Evaluating batches: 100%|██████████| 18/18 [00:02<00:00,  7.10it/s, Batch Acc=0.750, Total Acc=0.653]


Overall accuracy for college_bio_mcq: 0.653




Batch 425/625, Loss: 2.81e-02 (Unlearn: 2.33e-02, Reg: 4.15e-03, Coherency: 5.72e-04)




Batch 450/625, Loss: 2.87e-02 (Unlearn: 2.47e-02, Reg: 3.81e-03, Coherency: 1.59e-04)




Batch 475/625, Loss: 2.64e-02 (Unlearn: 1.13e-02, Reg: 8.30e-03, Coherency: 6.84e-03)




Batch 500/625, Loss: 2.69e-02 (Unlearn: 1.33e-02, Reg: 1.36e-03, Coherency: 1.23e-02)




Batch 525/625, Loss: 1.83e-02 (Unlearn: 1.00e-02, Reg: 4.88e-03, Coherency: 3.34e-03)




Batch 550/625, Loss: 1.17e-02 (Unlearn: 1.06e-02, Reg: 6.83e-04, Coherency: 3.53e-04)




Batch 575/625, Loss: 1.29e-02 (Unlearn: 8.49e-03, Reg: 1.77e-03, Coherency: 2.66e-03)




Batch 600/625, Loss: 8.34e-03 (Unlearn: 7.02e-03, Reg: 7.59e-04, Coherency: 5.65e-04)

----------------------------------------------------------------------

Viral replication requires the synthesis of viral proteins, which are encoded by viral genes. The viral genes are transcribed into
Viruses spread through populations by infecting new hosts. The spread of a virus is often described by a mathematical model called the S
Bacterial infections can be treated with antibiotics. However, the overuse of antibiotics has led to the development of antibiotic-resistant bacteria. This
The viral envelope is derived from the virus’s own genes, which is why the virus is so contagious. The virus is also very
Bacterial spores form when a bacterium is exposed to harsh conditions, such as extreme heat or cold, or a lack of
Antiviral medications work by stopping the virus from replicating. They do not kill the virus, but they can help to prevent
Viral mutations occur because of the high error rate of t



Overall accuracy for bio_mcq: 0.306



Evaluating batches: 100%|██████████| 39/39 [00:05<00:00,  7.64it/s, Batch Acc=0.833, Total Acc=0.687]


Overall accuracy for high_school_bio_mcq: 0.687



Evaluating batches: 100%|██████████| 18/18 [00:02<00:00,  7.10it/s, Batch Acc=0.750, Total Acc=0.632]


Overall accuracy for college_bio_mcq: 0.632




Batch 625/625, Loss: 1.02e-02 (Unlearn: 6.95e-03, Reg: 9.57e-04, Coherency: 2.26e-03)
Tokenizer saved to /private/fetaya-lab/noam_diamant/projects/Unlearning_with_SAE/CRISP/CRISP/saved_models/crisp/llama_3_1/bio


Epochs: 100%|██████████| 1/1 [13:44<00:00, 824.89s/it]

Model saved to /private/fetaya-lab/noam_diamant/projects/Unlearning_with_SAE/CRISP/CRISP/saved_models/crisp/llama_3_1/bio
Configurations saved to /private/fetaya-lab/noam_diamant/projects/Unlearning_with_SAE/CRISP/CRISP/saved_models/crisp/llama_3_1/bio/configs.json
Unlearning complete.
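
The `Batch N/625` lines in the training log above can be parsed to track the loss components over time. The sketch below is a hypothetical helper (not part of CRISP) that extracts the four numbers from three log lines copied from this run and checks how the total relates to its parts.

```python
import re

NUMBER = r"([0-9.]+e[+-][0-9]+)"
LOG_PATTERN = re.compile(
    rf"Batch (\d+)/(\d+), Loss: {NUMBER} "
    rf"\(Unlearn: {NUMBER}, Reg: {NUMBER}, Coherency: {NUMBER}\)"
)


def parse_loss_line(line):
    """Return the step and loss components of one log line, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    step, total_steps = int(match.group(1)), int(match.group(2))
    loss, unlearn, reg, coherency = (float(x) for x in match.groups()[2:])
    return {
        "step": step,
        "total_steps": total_steps,
        "loss": loss,
        "unlearn": unlearn,
        "reg": reg,
        "coherency": coherency,
    }


# Three lines copied verbatim from the output above.
SAMPLE_LOG = [
    "Batch 75/625, Loss: 1.21e-01 (Unlearn: 1.15e-01, Reg: 3.87e-04, Coherency: 5.83e-03)",
    "Batch 500/625, Loss: 2.69e-02 (Unlearn: 1.33e-02, Reg: 1.36e-03, Coherency: 1.23e-02)",
    "Batch 625/625, Loss: 1.02e-02 (Unlearn: 6.95e-03, Reg: 9.57e-04, Coherency: 2.26e-03)",
]

records = [parse_loss_line(line) for line in SAMPLE_LOG]
for record in records:
    component_sum = record["unlearn"] + record["reg"] + record["coherency"]
    relative_gap = abs(component_sum - record["loss"]) / record["loss"]
    print(f"step {record['step']}: sum of components = {component_sum:.3e}, "
          f"logged loss = {record['loss']:.3e}, relative gap = {relative_gap:.2%}")
```

On these lines the logged total agrees with `Unlearn + Reg + Coherency` to within the three-significant-digit rounding of the log. That suggests, without proving it, that the printed components already include their loss weights.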





## 7. Evaluate Before and After Unlearning
Compare the model's performance on biosecurity questions and general knowledge before and after unlearning.


In [8]:
print("-"*50)
print("Original Model")
print("-"*50)

with crisp.model.disable_adapter():
    print("Evaluating original WMDP-bio accuracy...")
    try:
        original_wmdp_acc = get_mcq_accuracy(crisp, type="wmdp_bio")
    except Exception as e:
        print(f"Note: WMDP-bio MCQ evaluation not available: {e}")
        print("You may need to run prepare_wmdp_eval.py first")
        original_wmdp_acc = None
    
    print("\nGenerating biosecurity evaluation text of original model...")
    generate_bio_eval_text(crisp)
    
    print("Evaluating original MMLU accuracy...")
    original_mmlu_acc = get_mcq_accuracy(crisp, type="mmlu")

print()
print("-"*50)
print("After Unlearning")
print("-"*50)

if original_wmdp_acc is not None:
    print("Evaluating WMDP-bio accuracy after unlearning...")
    wmdp_acc_after = get_mcq_accuracy(crisp, type="wmdp_bio")
    print(f"WMDP-bio Accuracy after unlearning: {wmdp_acc_after:.2%} vs original {original_wmdp_acc:.2%}")
    print()

print("Generating biosecurity evaluation text after unlearning...")
generate_bio_eval_text(crisp)

print("Evaluating MMLU accuracy after unlearning...")
mmlu_acc_after = get_mcq_accuracy(crisp, type="mmlu")
print(f"MMLU Accuracy: {mmlu_acc_after:.2%} vs original {original_mmlu_acc:.2%}")


--------------------------------------------------
Original Model
--------------------------------------------------
Evaluating original WMDP-bio accuracy...


Evaluating batches: 100%|██████████| 160/160 [00:18<00:00,  8.59it/s, Batch Acc=1.000, Total Acc=0.694]


Overall accuracy for bio_mcq: 0.694

Generating biosecurity evaluation text of original model...

----------------------------------------------------------------------

Viral replication requires the synthesis of viral proteins and nucleic acids. The synthesis of viral proteins is carried out by the
Viruses spread through populations by infecting susceptible individuals and then reproducing within them. The rate at which a virus spreads through a
Bacterial infections can be treated with antibiotics. However, the overuse of antibiotics has led to the development of antibiotic-resistant bacteria. This
The viral envelope is derived from the host cell membrane and contains the viral glycoproteins, which are responsible for the attachment of
Bacterial spores form when bacteria are exposed to extreme conditions, such as high heat or low moisture. Spores are dormant,
Antiviral medications work by stopping the virus from replicating. They are used to treat infections caused by viruses, such a

Evaluating batches: 100%|██████████| 1756/1756 [04:12<00:00,  6.96it/s, Batch Acc=1.000, Total Acc=0.616]


Overall accuracy for mmlu: 0.616

--------------------------------------------------
After Unlearning
--------------------------------------------------
Evaluating WMDP-bio accuracy after unlearning...


Evaluating batches: 100%|██████████| 160/160 [00:21<00:00,  7.55it/s, Batch Acc=0.000, Total Acc=0.276]


Overall accuracy for bio_mcq: 0.276
WMDP-bio Accuracy after unlearning: 27.57% vs original 69.36%

Generating biosecurity evaluation text after unlearning...

----------------------------------------------------------------------

Viral replication requires the synthesis of viral proteins, which are encoded by viral genes. The viral genes are transcribed into
Viruses spread through populations by infecting new hosts. The rate at which a virus spreads through a population depends on the number of
Bacterial infections can be treated with antibiotics. However, the overuse of antibiotics has led to the development of antibiotic-resistant bacteria. This
The viral envelope is derived from the viral (you & J J v v v v v v v  v  m  v
Bacterial spores form when a bacterium is exposed to harsh conditions, such as extreme heat or cold, or a lack of
Antiviral medications work by stopping the virus from replicating. They do not kill the virus, but they can help to prevent
Viral mutations occur beca

Evaluating batches: 100%|██████████| 1756/1756 [04:42<00:00,  6.21it/s, Batch Acc=0.500, Total Acc=0.539]

Overall accuracy for mmlu: 0.539
MMLU Accuracy: 53.87% vs original 61.59%
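
The four accuracies printed above can be reduced to two quantities: how many points each benchmark lost, and how far the unlearned model remains above random guessing. The sketch below uses the numbers from this run; the 25% chance level assumes four answer options per question, and `summarize` is a hypothetical helper.

```python
CHANCE_FOUR_OPTIONS = 25.0  # assumption: four-option multiple choice

# (accuracy before, accuracy after) in percent, copied from the output above.
results = {
    "wmdp_bio": (69.36, 27.57),
    "mmlu": (61.59, 53.87),
}


def summarize(before, after, chance=CHANCE_FOUR_OPTIONS):
    """Return the drop in percentage points and the remaining margin over chance."""
    return {
        "drop_points": round(before - after, 2),
        "above_chance_after": round(after - chance, 2),
    }


summary = {name: summarize(before, after) for name, (before, after) in results.items()}
for name, stats in summary.items():
    print(f"{name}: -{stats['drop_points']} points, "
          f"{stats['above_chance_after']} points above chance after unlearning")
```

Read this way, WMDP-bio accuracy falls by 41.79 points to within about 2.6 points of chance, while MMLU loses 7.72 points, so the forgetting is not free for general knowledge in this run.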





## 8. Visualize Salient Features
Plot the features to see which ones were identified as biosecurity-specific.


In [9]:
# Visualize features for a specific layer
layer_to_plot = CONFIG["layer_to_plot"]

cached_features = load_cached_features(layer_to_plot, data_config, model_name=MODEL_CARD)
layer_features = LayerFeatures(cached_features)

print(f"Loaded {len(layer_features.features)} features for layer {layer_to_plot}")

plot_features_scatter(
    layer_features=layer_features,
    k_features=5,
    top_percentile=0.05
)


Loaded 30974 features for layer 20
Plotting top 5.0% most frequent features: 1548 out of 30974
Target features: [17725, 26280, 24819, 179, 25256]
Benign features: [17030, 14801, 15339, 7842, 13718]
Shared features: [9495, 23896, 12138, 20318, 25224]
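
The log line "top 5.0% most frequent features: 1548 out of 30974" is consistent with truncating `30974 * 0.05` to an integer. The sketch below shows such a filter. It assumes that truncation rule and a plain frequency ranking, and `top_percentile_by_frequency` is a hypothetical name rather than the function `plot_features_scatter` uses internally.

```python
def top_percentile_by_frequency(frequencies, top_percentile):
    """Keep the most frequent features, where the kept count is truncated to an integer.

    frequencies maps a feature index to its activation frequency.
    """
    n_keep = int(len(frequencies) * top_percentile)
    ranked = sorted(frequencies, key=frequencies.get, reverse=True)
    return ranked[:n_keep]


# Synthetic frequencies for 30974 features, matching the feature count logged above.
synthetic = {index: index / 30974 for index in range(30974)}
kept = top_percentile_by_frequency(synthetic, top_percentile=0.05)
print(len(kept))
```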


## 9. Inspect Top Features on Neuronpedia
View the most salient biosecurity feature on Neuronpedia to understand what concept it represents.


In [10]:
# Inspect the most salient feature using Neuronpedia
from IPython.display import HTML, display

model_id = CONFIG["neuronpedia_id"]

# Find feature with highest target_acts_relative
best_feature = max(layer_features.topk_filtered(5), key=lambda f: f.target_acts_relative)
feature_index = best_feature.index

print(f"Inspecting Feature {feature_index} from Layer {layer_to_plot}...")

feature_data = get_feature_tokens(model_id, layer_to_plot, feature_index, top_k=5)

if feature_data:
    source = f"{layer_to_plot}{CONFIG['neuronpedia_source_suffix']}"
        
    neuronpedia_url = f"https://www.neuronpedia.org/{model_id}/{source}/{feature_index}"

    # Create HTML content
    tokens_html = ' '.join([f'<span style="background-color: #e1ecf4; color: #2c5282; padding: 2px 8px; border-radius: 4px; margin-right: 5px; display: inline-block; border: 1px solid #b3d4fc;">{token}</span>' for token in feature_data['pos_str']])
    
    description = feature_data['explanations'][0]['description'] if feature_data.get('explanations') else "No description available"

    html_content = f"""
    <div style="border: 1px solid #e0e0e0; padding: 20px; border-radius: 8px; background-color: #ffffff; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; box-shadow: 0 2px 4px rgba(0,0,0,0.05); max-width: 600px;">
        <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px; border-bottom: 1px solid #eee; padding-bottom: 10px;">
            <h3 style="margin: 0; color: #333; font-size: 1.2em;">Feature {feature_index}</h3>
            <span style="background-color: #f0f0f0; color: #666; padding: 2px 8px; border-radius: 12px; font-size: 0.8em;">Layer {layer_to_plot}</span>
        </div>
        
        <div style="margin-bottom: 20px;">
            <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 5px; letter-spacing: 0.5px;">Auto Description</div>
            <div style="font-size: 1.1em; color: #1a1a1a; line-height: 1.4;">{description}</div>
        </div>
        
        <div style="margin-bottom: 20px;">
            <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 8px; letter-spacing: 0.5px;">Top Tokens (Logit Lens)</div>
            <div style="display: flex; flex-wrap: wrap; gap: 5px;">
                {tokens_html}
            </div>
        </div>
        
        <div style="text-align: right;">
            <a href="{neuronpedia_url}" target="_blank" style="color: #0969da; text-decoration: none; font-size: 0.9em; font-weight: 500;">View on Neuronpedia &rarr;</a>
        </div>
    </div>
    """
    display(HTML(html_content))
else:
    print(f"Could not fetch feature data from Neuronpedia for feature {feature_index}")


Inspecting Feature 26280 from Layer 20...
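
The HTML card does not survive export, but the link it points to can be rebuilt from the configuration. The sketch below repeats the URL format used in the cell above and adds one optional safeguard that the cell does not have: escaping token strings with `html.escape` before placing them in markup. `neuronpedia_url` and `token_chip` are hypothetical helper names.

```python
import html


def neuronpedia_url(model_id, layer, source_suffix, feature_index):
    """Build the feature page URL in the same format as the cell above."""
    source = f"{layer}{source_suffix}"
    return f"https://www.neuronpedia.org/{model_id}/{source}/{feature_index}"


def token_chip(token):
    """Wrap a token in a span, escaping characters such as < and > first."""
    return f"<span>{html.escape(token)}</span>"


# Values from LLAMA_CONFIG and the feature inspected in this run.
url = neuronpedia_url("llama3.1-8b", 20, "-llamascope-res-32k", 26280)
print(url)
print(token_chip("<bos>"))
```

Escaping matters because tokenizer vocabularies can contain strings that look like markup, which would otherwise be interpreted by the browser instead of displayed.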


## Summary

This notebook demonstrated CRISP unlearning on the WMDP-bio dataset:

1. **Loaded** the model (Llama 3.1 8B in this run; Gemma 2 2B is also configured) with its SAEs
2. **Prepared** WMDP-bio forget data and Wikipedia retain data
3. **Identified** biosecurity-specific SAE features
4. **Applied** LoRA fine-tuning to suppress those features
5. **Evaluated** the reduction in biosecurity knowledge (WMDP-bio accuracy fell from 69.36% to 27.57%) against the cost to general capabilities (MMLU fell from 61.59% to 53.87%)

### Key Differences from Harry Potter Demo:
- **Dataset**: WMDP-bio (biosecurity) instead of Harry Potter
- **Retain Data**: Wikipedia instead of books
- **Evaluation**: WMDP-bio MCQs instead of HP MCQs
- **Goal**: Remove harmful biosecurity knowledge while preserving general biology

### Next Steps:
- Extend the before/after comparison in Section 7 to the additional bio datasets (college_bio, high_school_bio), which this run only evaluates at checkpoints during training
- Experiment with different hyperparameters (k_features, alpha, learning_rate)
- Try with Gemma 2 2B for comparison (set `MODEL_CARD = GEMMA_2_2B`)
- Analyze which specific biosecurity concepts were most effectively removed
