# Basketball/Baseball SAE Feature Analysis

This notebook demonstrates sparse autoencoder (SAE) feature analysis for basketball and baseball concepts using the Llama 3.1 8B model (`LLAMA_3_1_8B`).


## 1. Setup and Configuration
We initialize the environment, select the model (Llama 3.1 8B), choose the layer to analyze (layer 16), and import vLLM for fast text generation.


In [1]:
%load_ext autoreload
%autoreload 2

import os

# Select the GPU before importing torch/vLLM so the setting reliably takes effect
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Adjust as needed

import time
import torch
from dataclasses import dataclass
from typing import List
from vllm import LLM, SamplingParams
from globals import LLAMA_3_1_8B
from crisp import CRISP, CRISPConfig, LayerFeatures
from sae import TopkSae
from utils import load_cached_features, save_cached_features
from plot import plot_features_scatter
from datasets import load_dataset
from data import prepare_text

# Check for HF token
if not os.environ.get("HF_TOKEN"):  # Missing or empty token
    print("Warning: HF_TOKEN environment variable not set. You may need it for model access.")

# Configuration for LLAMA_3_1_8B
MODEL_CARD = LLAMA_3_1_8B
LAYER_TO_ANALYZE = 16
SAE_LAYERS = [LAYER_TO_ANALYZE]  # Only layer 16
SAE_SAVE_PATH = "llama_sae_cache"

print(f"Using model: {MODEL_CARD}")
print(f"Operating on layer: {LAYER_TO_ANALYZE}")


[2026-01-08 15:02:06] INFO config.py:54: PyTorch version 2.9.0 available.


Using model: meta-llama/Llama-3.1-8B
Operating on layer: 16


## 2. Download SAE for Layer 16


In [2]:
# Download SAE for layer 16 if not already cached
print(f"Checking/Downloading SAE for layer {LAYER_TO_ANALYZE}...")
layer_path = os.path.join(SAE_SAVE_PATH, f"layer_{LAYER_TO_ANALYZE}")
if not os.path.exists(layer_path):
    print(f"Downloading SAE for layer {LAYER_TO_ANALYZE}...")
    TopkSae.download_and_save(layer=LAYER_TO_ANALYZE, save_path=SAE_SAVE_PATH)
else:
    print(f"SAE for layer {LAYER_TO_ANALYZE} already cached.")


Checking/Downloading SAE for layer 16...
SAE for layer 16 already cached.


## 3. Initialize vLLM for Fast Generation


In [3]:
# Initialize VLLM for fast generation
print("Initializing VLLM...")
vllm_model = LLM(
    model=MODEL_CARD,
    trust_remote_code=True,
    gpu_memory_utilization=0.9,
    dtype="bfloat16"
)
print("VLLM initialized successfully.")


Initializing VLLM...
INFO 01-08 15:02:06 [utils.py:253] non-default args: {'trust_remote_code': True, 'dtype': 'bfloat16', 'disable_log_stats': True, 'model': 'meta-llama/Llama-3.1-8B'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 01-08 15:02:07 [model.py:514] Resolved architecture: LlamaForCausalLM
INFO 01-08 15:02:07 [model.py:1661] Using max model len 131072


2026-01-08 15:02:08,268	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 01-08 15:02:08 [scheduler.py:230] Chunked prefill is enabled with max_num_batched_tokens=8192.
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:09 [core.py:93] Initializing a V1 LLM engine (v0.13.0) with config: model='meta-llama/Llama-3.1-8B', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_ve

Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.54it/s]

(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:17 [default_loader.py:308] Loading weights took 2.83 seconds
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:17 [gpu_model_runner.py:3659] Model loading took 14.9889 GiB memory and 4.482622 seconds
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:25 [backends.py:643] Using cache directory: /home/dsi/diamann2/.cache/vllm/torch_compile_cache/b7b3b3ce0c/rank_0_0/backbone for vLLM's torch.compile
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:25 [backends.py:703] Dynamo bytecode transform time: 7.56 s
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:30 [backends.py:226] Directly load the compiled graph(s) for compile range (1, 8192) from the cache, took 1.964 s
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:30 [monitor.py:34] torch.compile takes 9.52 s in total
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:31 [gpu_worker.py:375] Available KV cache memory: 23.69 GiB

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 51/51 [00:04<00:00, 11.57it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 35/35 [00:02<00:00, 13.38it/s]

(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:40 [gpu_model_runner.py:4587] Graph capturing finished in 8 secs, took 0.57 GiB
(EngineCore_DP0 pid=3599596) INFO 01-08 15:02:40 [core.py:259] init engine (profile, create kv cache, warmup model) took 22.25 seconds
INFO 01-08 15:02:40 [llm.py:360] Supported tasks: ['generate']
VLLM initialized successfully.


## 4. Generate Basketball and Baseball Sentences


In [4]:
# Create 50 keywords for basketball
basketball_keywords = [
    "dunk", "three-pointer", "rebound", "free throw", "layup",
    "point guard", "shooting guard", "small forward", "power forward", "center",
    "basketball court", "hoop", "backboard", "foul", "traveling",
    "jump shot", "alley-oop", "fast break", "pick and roll", "zone defense",
    "man-to-man defense", "full court press", "half court", "buzzer beater", "slam dunk",
    "NBA", "college basketball", "March Madness", "basketball player", "coach",
    "timeout", "substitution", "technical foul", "flagrant foul", "ejection",
    "basketball team", "championship", "playoffs", "regular season", "overtime",
    "basketball game", "scoring", "assist", "steal", "block",
    "basketball skills", "ball handling", "shooting form", "defensive stance", "offensive play"
]

# Create 50 keywords for baseball
baseball_keywords = [
    "home run", "pitcher", "strikeout", "baseball", "batting",
    "first base", "second base", "third base", "home plate", "outfield",
    "infield", "catcher", "shortstop", "second baseman", "third baseman",
    "baseball diamond", "mound", "batter's box", "dugout", "bullpen",
    "fastball", "curveball", "slider", "changeup", "knuckleball",
    "MLB", "World Series", "baseball game", "inning", "strike",
    "ball", "walk", "hit", "double", "triple",
    "baseball team", "manager", "coach", "umpire", "referee",
    "baseball player", "pitcher's mound", "base running", "stealing base", "bunt",
    "baseball skills", "pitching", "hitting", "fielding", "catching"
]

print(f"Created {len(basketball_keywords)} basketball keywords")
print(f"Created {len(baseball_keywords)} baseball keywords")


Created 50 basketball keywords
Created 50 baseball keywords


In [5]:
# Generate 250 prompts for basketball (50 keywords repeated 5 times)
basketball_prompts = [f"Write a sentence about {keyword}" for keyword in basketball_keywords * 5]

# Generate 250 prompts for baseball (50 keywords repeated 5 times)
baseball_prompts = [f"Write a sentence about {keyword}" for keyword in baseball_keywords * 5]

print(f"Generated {len(basketball_prompts)} basketball prompts")
print(f"Generated {len(baseball_prompts)} baseball prompts")


Generated 250 basketball prompts
Generated 250 baseball prompts


In [6]:
# Configure sampling parameters for generation
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=50,  # Generate up to 50 tokens per sentence
    stop=None
)

# Generate basketball sentences
print("Generating basketball sentences...")
start_time = time.time()
basketball_outputs = vllm_model.generate(basketball_prompts, sampling_params)
basketball_generation_time = time.time() - start_time

# Extract generated text
basketball_sentences = []
for output in basketball_outputs:
    generated_text = output.outputs[0].text.strip()
    # Combine prompt and generated text to form complete sentence
    prompt = output.prompt
    full_sentence = prompt + " " + generated_text
    basketball_sentences.append(full_sentence)

print(f"Generated {len(basketball_sentences)} basketball sentences")
print(f"Time taken: {basketball_generation_time:.2f} seconds")


Generating basketball sentences...


Adding requests: 100%|██████████| 250/250 [00:00<00:00, 4640.95it/s]
Processed prompts: 100%|██████████| 250/250 [00:02<00:00, 97.02it/s, est. speed input: 661.92 toks/s, output: 4824.43 toks/s]

Generated 250 basketball sentences
Time taken: 2.64 seconds





In [7]:
# Generate baseball sentences
print("Generating baseball sentences...")
start_time = time.time()
baseball_outputs = vllm_model.generate(baseball_prompts, sampling_params)
baseball_generation_time = time.time() - start_time

# Extract generated text
baseball_sentences = []
for output in baseball_outputs:
    generated_text = output.outputs[0].text.strip()
    # Combine prompt and generated text to form complete sentence
    prompt = output.prompt
    full_sentence = prompt + " " + generated_text
    baseball_sentences.append(full_sentence)

print(f"Generated {len(baseball_sentences)} baseball sentences")
print(f"Time taken: {baseball_generation_time:.2f} seconds")

total_time = basketball_generation_time + baseball_generation_time
print(f"\nTotal generation time: {total_time:.2f} seconds")


Generating baseball sentences...


Adding requests: 100%|██████████| 250/250 [00:00<00:00, 5946.74it/s]
Processed prompts: 100%|██████████| 250/250 [00:02<00:00, 97.82it/s, est. speed input: 643.90 toks/s, output: 4877.18 toks/s]

Generated 250 baseball sentences
Time taken: 2.60 seconds

Total generation time: 5.24 seconds
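
Because each keyword is prompted five times, the variety among the 250 sentences per domain comes entirely from sampling (`temperature=0.7`, `top_p=0.9`). A quick sanity check, sketched here with a hypothetical helper `duplicate_report` that is not part of the notebook, counts how many generated sentences are exact repeats:

```python
from collections import Counter

def duplicate_report(sentences):
    """Return (number of distinct sentences, [(sentence, count)] for exact repeats)."""
    counts = Counter(s.strip() for s in sentences)
    repeats = [(s, n) for s, n in counts.items() if n > 1]
    return len(counts), repeats

# Toy input standing in for basketball_sentences / baseball_sentences
demo = ["a dunk", "a dunk", "a layup", "a rebound "]
n_distinct, repeats = duplicate_report(demo)
print(n_distinct, repeats)
```

In the notebook this could be called as `duplicate_report(basketball_sentences)`; a large number of repeats would suggest raising the temperature or varying the prompt wording.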





## 5. Load Wikipedia Retain Examples


In [8]:
# Load 250 Wikipedia retain examples
print("Loading Wikipedia retain examples...")
wiki_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
wiki_texts = prepare_text(wiki_data['text'], max_len=1000)

# Take 250 examples
wiki_retain = list(wiki_texts)[:250]
print(f"Loaded {len(wiki_retain)} Wikipedia retain examples")


Loading Wikipedia retain examples...
Loaded 250 Wikipedia retain examples
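
`prepare_text` is a project helper whose implementation is not shown in this notebook. The hypothetical stand-in below only illustrates the kind of filtering such a step might perform on raw WikiText lines (dropping blank lines and `= Title =` headings, enforcing a minimum length, truncating to `max_len` characters). The `min_len` parameter and every rule here are assumptions, not the project's actual behavior:

```python
def prepare_text_sketch(lines, max_len=1000, min_len=100):
    """Hypothetical stand-in for the project's `prepare_text` helper."""
    for line in lines:
        text = line.strip()
        # Skip blank lines and WikiText section headings such as "= Title ="
        if not text or text.startswith("="):
            continue
        # Skip fragments too short to be useful retain examples (assumed rule)
        if len(text) < min_len:
            continue
        # Truncate long paragraphs to max_len characters
        yield text[:max_len]

demo_lines = ["", " = Heading = \n", "short", "x" * 1500, "y" * 150]
kept = list(prepare_text_sketch(demo_lines))
print([len(t) for t in kept])
```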


## 6. Initialize CRISP for Feature Analysis


In [9]:
from dataclasses import asdict

# Create a simple data config class for caching
@dataclass
class BasketballBaseballDataConfig:
    max_length: int = 1000
    min_length: int = 100
    n_examples: int = 250
    data_type: str = "basketball_baseball"
    forget_type: str = "basketball"  # For caching compatibility
    retain_type: str = "wiki"  # For caching compatibility

    def to_dict(self):
        # asdict returns the fields in declaration order
        return asdict(self)

# Initialize CRISP
config = CRISPConfig(
    layers=SAE_LAYERS,
    model_name="llama",  # This will be converted to LLAMA_3_1_8B
    bf16=True
)
crisp = CRISP(config)
print("CRISP initialized successfully.")


Loading from cache: /private/fetaya-lab/noam_diamant/projects/Unlearning_with_SAE/CRISP/crisp/llama_sae_cache


Loading SAEs:   0%|          | 0/1 [00:00<?, ?it/s]

Loading layers.16 on cuda:0


Loading SAEs: 100%|██████████| 1/1 [00:00<00:00,  1.67it/s]
`torch_dtype` is deprecated! Use `dtype` instead!
[2026-01-08 15:02:54] INFO modeling.py:987: We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 19.63it/s]


CRISP initialized successfully.


ERROR 01-08 15:03:08 [core_client.py:606] Engine core proc EngineCore_DP0 died unexpectedly, shutting down client.
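
The `Engine core proc EngineCore_DP0 died unexpectedly` error appears shortly after CRISP loads its own copy of the model onto the same GPU that vLLM was started on with `gpu_memory_utilization=0.9`. The log does not state the cause. Since all sentences have already been generated by this point, one cautious option is to release the vLLM engine explicitly before initializing CRISP. A minimal sketch, assuming a hypothetical helper `release_generation_engine`; whether dropping the reference fully frees memory held by vLLM's engine process depends on the vLLM version:

```python
import gc

def release_generation_engine(namespace, name="vllm_model"):
    """Drop `name` from `namespace`, collect garbage, and empty the CUDA cache.

    Returns True if a CUDA cache was emptied, False otherwise.
    """
    namespace.pop(name, None)   # remove the reference to the engine, if present
    gc.collect()                # let Python reclaim the object
    try:
        import torch
    except ImportError:
        return False            # nothing more to do without PyTorch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        return True
    return False

# Toy namespace standing in for the notebook's globals()
ns = {"vllm_model": object(), "basketball_sentences": ["..."]}
release_generation_engine(ns)
print(sorted(ns))
```

In the notebook the call would be `release_generation_engine(globals())`, placed after the last `vllm_model.generate(...)` call and before `CRISP(config)`.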


## 7. Process Features: Basketball vs Wikipedia


In [10]:
# Create data config for basketball vs wiki
data_config_basketball_wiki = BasketballBaseballDataConfig(
    n_examples=250,
    data_type="basketball_wiki",
    forget_type="basketball",
    retain_type="wiki"
)

# Process basketball vs Wikipedia
print("Processing features: Basketball vs Wikipedia...")
crisp.process_multi_texts_batch(
    text_target=basketball_sentences,
    text_benign=wiki_retain,
    data_config=data_config_basketball_wiki,
    batch_size=32
)
print("Feature processing complete for Basketball vs Wikipedia.")

# Load features for layer 16
basketball_wiki_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_basketball_wiki,
    model_name=MODEL_CARD
)
if basketball_wiki_features is None:
    # If not cached, get from crisp (features are already processed and stored)
    basketball_wiki_features = list(crisp.features_dict[LAYER_TO_ANALYZE].features.values())
else:
    # If cached, it's already a list of Feature objects
    basketball_wiki_features = list(basketball_wiki_features)

basketball_wiki_layer_features = LayerFeatures(basketball_wiki_features)
print(f"Loaded {len(basketball_wiki_layer_features)} features for Basketball vs Wikipedia")


Processing features: Basketball vs Wikipedia...
Found 0 cached layers and 1 uncached layers.
Need to process 1 uncached layers: [16]


Processing text batches: 100%|██████████| 8/8 [00:45<00:00,  5.70s/it]
Processing layers:   0%|          | 0/1 [00:00<?, ?it/s]

Processing features for layer 16.


Processing layers: 100%|██████████| 1/1 [00:00<00:00,  3.57it/s]

Saving features for layer 16 to cache.

Feature processing complete for Basketball vs Wikipedia.
Loaded 23328 features for Basketball vs Wikipedia
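
The cache-then-fallback logic in the cell above is repeated verbatim in sections 8 and 9. It could be factored into a small helper; the sketch below uses a hypothetical function `load_or_collect` that takes two callables, so it makes no assumptions about the signatures of the project's own functions:

```python
def load_or_collect(load_cached, collect_fresh):
    """Return cached features as a list, or fall back to freshly computed ones."""
    cached = load_cached()
    if cached is None:
        return list(collect_fresh())
    return list(cached)

# Toy callables standing in for load_cached_features(...) and
# crisp.features_dict[LAYER_TO_ANALYZE].features.values()
from_cache = load_or_collect(lambda: ("f1", "f2"), lambda: ["unused"])
from_fresh = load_or_collect(lambda: None, lambda: iter(["f3"]))
print(from_cache, from_fresh)
```

In the notebook, the first argument would wrap the `load_cached_features(...)` call in a `lambda` and the second would wrap the `crisp.features_dict[...]` lookup, so the lookup only runs when the cache misses.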





## 8. Process Features: Baseball vs Wikipedia


In [12]:
# Create data config for baseball vs wiki
data_config_baseball_wiki = BasketballBaseballDataConfig(
    n_examples=250,
    data_type="baseball_wiki",
    forget_type="baseball",
    retain_type="wiki"
)

# Clear previous features
crisp.features_dict.clear()
torch.cuda.empty_cache()

# Process baseball vs Wikipedia
print("Processing features: Baseball vs Wikipedia...")
crisp.process_multi_texts_batch(
    text_target=baseball_sentences,
    text_benign=wiki_retain,
    data_config=data_config_baseball_wiki,
    batch_size=8
)
print("Feature processing complete for Baseball vs Wikipedia.")

# Load features for layer 16
baseball_wiki_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_baseball_wiki,
    model_name=MODEL_CARD
)
if baseball_wiki_features is None:
    # If not cached, get from crisp (features are already processed and stored)
    baseball_wiki_features = list(crisp.features_dict[LAYER_TO_ANALYZE].features.values())
else:
    # If cached, it's already a list of Feature objects
    baseball_wiki_features = list(baseball_wiki_features)

baseball_wiki_layer_features = LayerFeatures(baseball_wiki_features)
print(f"Loaded {len(baseball_wiki_layer_features)} features for Baseball vs Wikipedia")


Processing features: Baseball vs Wikipedia...
Found 0 cached layers and 1 uncached layers.
Need to process 1 uncached layers: [16]


Processing text batches: 100%|██████████| 32/32 [01:40<00:00,  3.13s/it]
Processing layers:   0%|          | 0/1 [00:00<?, ?it/s]

Processing features for layer 16.


Processing layers: 100%|██████████| 1/1 [00:00<00:00,  3.58it/s]

Saving features for layer 16 to cache.

Feature processing complete for Baseball vs Wikipedia.
Loaded 23125 features for Baseball vs Wikipedia





In [15]:
# Clear previous features
import gc
crisp.features_dict.clear()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

## 9. Process Features: Basketball vs Baseball


In [16]:
# Create data config for basketball vs baseball
data_config_basketball_baseball = BasketballBaseballDataConfig(
    n_examples=250,
    data_type="basketball_baseball",
    forget_type="basketball",
    retain_type="baseball"
)

# Clear previous features
crisp.features_dict.clear()
torch.cuda.empty_cache()

# Process basketball vs baseball
print("Processing features: Basketball vs Baseball...")
crisp.process_multi_texts_batch(
    text_target=basketball_sentences,
    text_benign=baseball_sentences,
    data_config=data_config_basketball_baseball,
    batch_size=64
)
print("Feature processing complete for Basketball vs Baseball.")

# Load features for layer 16
basketball_baseball_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_basketball_baseball,
    model_name=MODEL_CARD
)
if basketball_baseball_features is None:
    # If not cached, get from crisp (features are already processed and stored)
    basketball_baseball_features = list(crisp.features_dict[LAYER_TO_ANALYZE].features.values())
else:
    # If cached, it's already a list of Feature objects
    basketball_baseball_features = list(basketball_baseball_features)

basketball_baseball_layer_features = LayerFeatures(basketball_baseball_features)
print(f"Loaded {len(basketball_baseball_layer_features)} features for Basketball vs Baseball")


Processing features: Basketball vs Baseball...
Found 0 cached layers and 1 uncached layers.
Need to process 1 uncached layers: [16]


Processing text batches: 100%|██████████| 4/4 [00:14<00:00,  3.63s/it]
Processing layers:   0%|          | 0/1 [00:00<?, ?it/s]

Processing features for layer 16.
Saving features for layer 16 to cache.



Processing layers: 100%|██████████| 1/1 [00:00<00:00,  4.35it/s]

Feature processing complete for Basketball vs Baseball.
Loaded 19122 features for Basketball vs Baseball
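
The three runs use different batch sizes (32, 8, and 64), and the progress bars report 8, 32, and 4 batches respectively. These counts are consistent with splitting 250 examples into batches and rounding up, as the short check below shows. How `process_multi_texts_batch` pairs target and benign texts internally is not shown here, so this is only an arithmetic check:

```python
import math

def n_batches(n_examples, batch_size):
    """Number of batches needed to cover n_examples, with a partial final batch."""
    return math.ceil(n_examples / batch_size)

batch_counts = {bs: n_batches(250, bs) for bs in (32, 8, 64)}
print(batch_counts)  # {32: 8, 8: 32, 64: 4}
```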





## 10. Visualization: Plot Feature Activations


In [17]:
# Plot 1: Basketball vs General (Wikipedia)
print("Plot 1: Active features in basketball vs. active features in general text")
plot_features_scatter(
    layer_features=basketball_wiki_layer_features,
    k_features=5,
    top_percentile=0.05,
    title="Basketball vs General Text (Layer 16)"
)


Plot 1: Active features in basketball vs. active features in general text
Plotting top 5.0% most frequent features: 1166 out of 23328
Target features: [17949, 22260, 6763, 12100, 11138]
Benign features: [6986, 27812, 20166, 22593, 8804]
Shared features: [1457, 5391, 10758, 12514, 12171]
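
The "1166 out of 23328" figure is consistent with truncating `n_features * top_percentile` to an integer, and the same holds for the other two plots (1156 of 23125 and 956 of 19122). The internals of `plot_features_scatter` are not shown, so the helper below is only a hypothetical reconstruction of that count:

```python
def top_feature_count(n_features, top_percentile):
    """Number of features kept when truncating n_features * top_percentile."""
    return int(n_features * top_percentile)

kept = [top_feature_count(n, 0.05) for n in (23328, 23125, 19122)]
print(kept)  # [1166, 1156, 956]
```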


In [18]:
# Plot 2: Baseball vs General (Wikipedia)
print("Plot 2: Active features in baseball vs. active features in general text")
plot_features_scatter(
    layer_features=baseball_wiki_layer_features,
    k_features=5,
    top_percentile=0.05,
    title="Baseball vs General Text (Layer 16)"
)


Plot 2: Active features in baseball vs. active features in general text
Plotting top 5.0% most frequent features: 1156 out of 23125
Target features: [17949, 22260, 6763, 20945, 31348]
Benign features: [6986, 27812, 20166, 8804, 22593]
Shared features: [1457, 5391, 10758, 12514, 12171]


In [20]:
# Plot 3: Basketball vs Baseball
print("Plot 3: Active features in basketball vs. active features in baseball")
plot_features_scatter(
    layer_features=basketball_baseball_layer_features,
    k_features=10,
    top_percentile=0.05,
    title="Basketball vs Baseball (Layer 16)"
)


Plot 3: Active features in basketball vs. active features in baseball
Plotting top 5.0% most frequent features: 956 out of 19122
Target features: [1294, 32303, 24259, 6065, 1267, 25751, 32496, 9850, 13814, 9187]
Benign features: [13448, 21101, 29566, 24811, 27527, 2002, 21930, 2572, 17019, 11335]
Shared features: [6986, 10758, 27812, 1457, 5391, 20166, 22593, 12171, 22639, 13819]
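
The printed lists can be compared directly. The sketch below copies the feature IDs from the outputs of Plots 1 to 3 and uses set operations to show which target features both sports share against Wikipedia, and which Wikipedia-side ("benign") features reappear as shared features in the basketball vs. baseball comparison. It reports overlaps only and makes no claim about what the individual features represent:

```python
# Feature IDs copied from the plot outputs above
basketball_vs_wiki_target = [17949, 22260, 6763, 12100, 11138]
baseball_vs_wiki_target = [17949, 22260, 6763, 20945, 31348]
wiki_benign = [6986, 27812, 20166, 22593, 8804]
sports_shared = [6986, 10758, 27812, 1457, 5391, 20166, 22593, 12171, 22639, 13819]

# Target features that both sports have in common when compared with Wikipedia
common_targets = sorted(set(basketball_vs_wiki_target) & set(baseball_vs_wiki_target))
# Target features specific to one sport
basketball_only = sorted(set(basketball_vs_wiki_target) - set(baseball_vs_wiki_target))
baseball_only = sorted(set(baseball_vs_wiki_target) - set(basketball_vs_wiki_target))
# Wikipedia-side features that show up as shared between the two sports
benign_in_shared = sorted(set(wiki_benign) & set(sports_shared))

print(common_targets)    # [6763, 17949, 22260]
print(basketball_only)   # [11138, 12100]
print(baseball_only)     # [20945, 31348]
print(benign_in_shared)  # [6986, 20166, 22593, 27812]
```

Three of the five top target features coincide across the two sports, so the sport-specific candidates from the Wikipedia comparisons are the remaining two IDs in each list.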
