# Basketball/Baseball SAE Feature Analysis

This notebook demonstrates sparse autoencoder (SAE) feature analysis for basketball and baseball concepts using the Llama 3.1 8B model (`LLAMA_3_1_8B`).


## 1. Setup and Configuration
We initialize the environment, select the model (Llama 3.1 8B), and import vLLM, which is used for fast text generation (the engine itself is created in Section 3).


In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import torch
from dataclasses import dataclass
from typing import List
from vllm import LLM, SamplingParams
from globals import LLAMA_3_1_8B
from crisp import CRISP, CRISPConfig, LayerFeatures
from sae import TopkSae
from utils import load_cached_features, save_cached_features
from plot import plot_features_scatter
from datasets import load_dataset
from data import prepare_text

# Pin the process to a single GPU.
# NOTE(review): this assignment runs after `torch` and `vllm` are imported
# above; presumably it still takes effect because CUDA has not been
# initialized yet -- confirm, or move it ahead of those imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Adjust as needed

# Warn when no Hugging Face token is available. An unset variable and an
# empty string are both treated as "missing".
if not os.environ.get("HF_TOKEN"):
    print("Warning: HF_TOKEN environment variable not set. You may need it for model access.")

# Configuration for LLAMA_3_1_8B
MODEL_CARD = LLAMA_3_1_8B
LAYER_TO_ANALYZE = 16
SAE_LAYERS = [LAYER_TO_ANALYZE]  # Only the analyzed layer
SAE_SAVE_PATH = "llama_sae_cache"  # Directory the SAE download cell checks and saves into

print(f"Using model: {MODEL_CARD}")
print(f"Operating on layer: {LAYER_TO_ANALYZE}")


## 2. Download SAE for Layer 16


In [None]:
# Fetch the SAE for the analyzed layer unless its directory is already on disk
print(f"Checking/Downloading SAE for layer {LAYER_TO_ANALYZE}...")
layer_path = os.path.join(SAE_SAVE_PATH, f"layer_{LAYER_TO_ANALYZE}")
if os.path.exists(layer_path):
    print(f"SAE for layer {LAYER_TO_ANALYZE} already cached.")
else:
    # presumably download_and_save writes into SAE_SAVE_PATH/layer_<N> -- confirm
    print(f"Downloading SAE for layer {LAYER_TO_ANALYZE}...")
    TopkSae.download_and_save(save_path=SAE_SAVE_PATH, layer=LAYER_TO_ANALYZE)


## 3. Initialize VLLM for Fast Generation


In [None]:
# Bring up the vLLM engine used by the generation cells below.
# NOTE(review): the engine is given 90% of GPU memory and is still alive when
# CRISP is created later in this notebook; presumably the device has room for
# both -- confirm.
print("Initializing VLLM...")
engine_options = {
    "model": MODEL_CARD,
    "trust_remote_code": True,
    "gpu_memory_utilization": 0.9,
    "dtype": "bfloat16",
}
vllm_model = LLM(**engine_options)
print("VLLM initialized successfully.")


## 4. Generate Basketball and Baseball Sentences


In [None]:
# 50 basketball keywords; each one is turned into a generation prompt later.
# NOTE(review): some entries (e.g. "coach", "center", "block") are not
# sport-specific on their own, and "coach" appears in both lists -- prompts
# built from them may not always yield on-topic text.
basketball_keywords = [
    "dunk", "three-pointer", "rebound", "free throw", "layup",
    "point guard", "shooting guard", "small forward", "power forward", "center",
    "basketball court", "hoop", "backboard", "foul", "traveling",
    "jump shot", "alley-oop", "fast break", "pick and roll", "zone defense",
    "man-to-man defense", "full court press", "half court", "buzzer beater", "slam dunk",
    "NBA", "college basketball", "March Madness", "basketball player", "coach",
    "timeout", "substitution", "technical foul", "flagrant foul", "ejection",
    "basketball team", "championship", "playoffs", "regular season", "overtime",
    "basketball game", "scoring", "assist", "steal", "block",
    "basketball skills", "ball handling", "shooting form", "defensive stance", "offensive play",
]

# 50 baseball keywords.
# "referee" was replaced with "designated hitter": baseball is officiated by
# umpires (already listed), so "referee" steered generation toward other sports.
baseball_keywords = [
    "home run", "pitcher", "strikeout", "baseball", "batting",
    "first base", "second base", "third base", "home plate", "outfield",
    "infield", "catcher", "shortstop", "second baseman", "third baseman",
    "baseball diamond", "mound", "batter's box", "dugout", "bullpen",
    "fastball", "curveball", "slider", "changeup", "knuckleball",
    "MLB", "World Series", "baseball game", "inning", "strike",
    "ball", "walk", "hit", "double", "triple",
    "baseball team", "manager", "coach", "umpire", "designated hitter",
    "baseball player", "pitcher's mound", "base running", "stealing base", "bunt",
    "baseball skills", "pitching", "hitting", "fielding", "catching",
]

print(f"Created {len(basketball_keywords)} basketball keywords")
print(f"Created {len(baseball_keywords)} baseball keywords")


In [None]:
# Build one prompt per keyword, then repeat the whole list so each sport
# ends up with two prompts per keyword (100 prompts from 50 keywords).
basketball_prompts = [f"Write a sentence about {kw}" for kw in basketball_keywords] * 2
baseball_prompts = [f"Write a sentence about {kw}" for kw in baseball_keywords] * 2

print(f"Generated {len(basketball_prompts)} basketball prompts")
print(f"Generated {len(baseball_prompts)} baseball prompts")


In [None]:
# Sampling settings reused by both generation cells
sampling_params = SamplingParams(
    max_tokens=50,  # cap each completion at 50 tokens
    temperature=0.7,
    top_p=0.9,
    stop=None,
)

# Run the batched basketball generation and time it
print("Generating basketball sentences...")
start_time = time.time()
basketball_outputs = vllm_model.generate(basketball_prompts, sampling_params)
basketball_generation_time = time.time() - start_time

# Each stored sentence is the prompt, a single space, then the stripped completion.
# NOTE(review): the instruction prefix is therefore part of every sentence that is
# later passed to feature processing -- confirm that is intended.
basketball_sentences = [
    " ".join((result.prompt, result.outputs[0].text.strip()))
    for result in basketball_outputs
]

print(f"Generated {len(basketball_sentences)} basketball sentences")
print(f"Time taken: {basketball_generation_time:.2f} seconds")


In [None]:
# Run the batched baseball generation and time it
print("Generating baseball sentences...")
start_time = time.time()
baseball_outputs = vllm_model.generate(baseball_prompts, sampling_params)
baseball_generation_time = time.time() - start_time

# Each stored sentence is the prompt, a single space, then the stripped completion
baseball_sentences = [
    " ".join((result.prompt, result.outputs[0].text.strip()))
    for result in baseball_outputs
]

print(f"Generated {len(baseball_sentences)} baseball sentences")
print(f"Time taken: {baseball_generation_time:.2f} seconds")

# Combined wall-clock time of the basketball and baseball runs
total_time = baseball_generation_time + basketball_generation_time
print(f"\nTotal generation time: {total_time:.2f} seconds")


## 5. Load Wikipedia Retain Examples


In [None]:
# Build the retain set from the WikiText-2 (raw) training split
print("Loading Wikipedia retain examples...")
wiki_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
# presumably prepare_text filters/truncates the raw rows to max_len -- confirm in data.py
wiki_texts = prepare_text(wiki_data["text"], max_len=1000)

# Materialize the prepared texts and keep only the first 100
wiki_retain = list(wiki_texts)
wiki_retain = wiki_retain[:100]
print(f"Loaded {len(wiki_retain)} Wikipedia retain examples")


## 6. Initialize CRISP for Feature Analysis


In [None]:
# Minimal data-config object handed to the feature-processing and cache helpers
@dataclass
class BasketballBaseballDataConfig:
    """Describe one target-vs-benign comparison for feature caching.

    Instances are passed as ``data_config`` to ``crisp.process_multi_texts_batch``
    and to ``load_cached_features`` in the cells below. How those helpers
    interpret each field is not visible in this notebook -- presumably the
    fields mirror the project's own data-config objects (TODO confirm).
    """

    max_length: int = 1000
    min_length: int = 100
    n_examples: int = 100
    data_type: str = "basketball_baseball"
    forget_type: str = "basketball"  # For caching compatibility
    retain_type: str = "wiki"  # For caching compatibility

    def to_dict(self):
        """Return the six config fields as a plain dict, in declaration order."""
        field_names = (
            "max_length",
            "min_length",
            "n_examples",
            "data_type",
            "forget_type",
            "retain_type",
        )
        return {name: getattr(self, name) for name in field_names}

# Build the CRISP analyzer for the configured SAE layer(s).
# Per the original note, "llama" is expected to be mapped to LLAMA_3_1_8B
# inside CRISP -- confirm against the crisp module.
config = CRISPConfig(bf16=True, model_name="llama", layers=SAE_LAYERS)
crisp = CRISP(config)
print("CRISP initialized successfully.")


## 7. Process Features: Basketball vs Wikipedia


In [None]:
# Data config identifying the basketball (target) vs Wikipedia (benign) comparison
data_config_basketball_wiki = BasketballBaseballDataConfig(
    forget_type="basketball",
    retain_type="wiki",
    data_type="basketball_wiki",
    n_examples=100,
)

# Run feature processing on the two text sets
print("Processing features: Basketball vs Wikipedia...")
crisp.process_multi_texts_batch(
    text_target=basketball_sentences,
    text_benign=wiki_retain,
    data_config=data_config_basketball_wiki,
    batch_size=8,
)
print("Feature processing complete for Basketball vs Wikipedia.")

# Try the feature cache first; when it returns None, fall back to the
# features held on the CRISP object for the analyzed layer.
basketball_wiki_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_basketball_wiki,
    model_name=MODEL_CARD,
)
basketball_wiki_features = list(
    crisp.features_dict[LAYER_TO_ANALYZE].features.values()
    if basketball_wiki_features is None
    else basketball_wiki_features
)

basketball_wiki_layer_features = LayerFeatures(basketball_wiki_features)
print(f"Loaded {len(basketball_wiki_layer_features)} features for Basketball vs Wikipedia")


## 8. Process Features: Baseball vs Wikipedia


In [None]:
# Data config identifying the baseball (target) vs Wikipedia (benign) comparison
data_config_baseball_wiki = BasketballBaseballDataConfig(
    forget_type="baseball",
    retain_type="wiki",
    data_type="baseball_wiki",
    n_examples=100,
)

# Drop features left over from the previous comparison and release cached GPU memory
crisp.features_dict.clear()
torch.cuda.empty_cache()

# Run feature processing on the two text sets
print("Processing features: Baseball vs Wikipedia...")
crisp.process_multi_texts_batch(
    text_target=baseball_sentences,
    text_benign=wiki_retain,
    data_config=data_config_baseball_wiki,
    batch_size=8,
)
print("Feature processing complete for Baseball vs Wikipedia.")

# Try the feature cache first; when it returns None, fall back to the
# features held on the CRISP object for the analyzed layer.
baseball_wiki_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_baseball_wiki,
    model_name=MODEL_CARD,
)
baseball_wiki_features = list(
    crisp.features_dict[LAYER_TO_ANALYZE].features.values()
    if baseball_wiki_features is None
    else baseball_wiki_features
)

baseball_wiki_layer_features = LayerFeatures(baseball_wiki_features)
print(f"Loaded {len(baseball_wiki_layer_features)} features for Baseball vs Wikipedia")


## 9. Process Features: Basketball vs Baseball


In [None]:
# Data config identifying the basketball (target) vs baseball (benign) comparison
data_config_basketball_baseball = BasketballBaseballDataConfig(
    forget_type="basketball",
    retain_type="baseball",
    data_type="basketball_baseball",
    n_examples=100,
)

# Drop features left over from the previous comparison and release cached GPU memory
crisp.features_dict.clear()
torch.cuda.empty_cache()

# Run feature processing on the two generated text sets
print("Processing features: Basketball vs Baseball...")
crisp.process_multi_texts_batch(
    text_target=basketball_sentences,
    text_benign=baseball_sentences,
    data_config=data_config_basketball_baseball,
    batch_size=8,
)
print("Feature processing complete for Basketball vs Baseball.")

# Try the feature cache first; when it returns None, fall back to the
# features held on the CRISP object for the analyzed layer.
basketball_baseball_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_basketball_baseball,
    model_name=MODEL_CARD,
)
basketball_baseball_features = list(
    crisp.features_dict[LAYER_TO_ANALYZE].features.values()
    if basketball_baseball_features is None
    else basketball_baseball_features
)

basketball_baseball_layer_features = LayerFeatures(basketball_baseball_features)
print(f"Loaded {len(basketball_baseball_layer_features)} features for Basketball vs Baseball")


## 10. Visualization: Plot Feature Activations


In [None]:
# Plot 1: scatter of the basketball-vs-Wikipedia layer features
print("Plot 1: Active features in basketball vs. active features in general text")
plot_features_scatter(
    title="Basketball vs General Text (Layer 16)",
    layer_features=basketball_wiki_layer_features,
    top_percentile=0.05,
    k_features=5,
)


In [None]:
# Plot 2: scatter of the baseball-vs-Wikipedia layer features
print("Plot 2: Active features in baseball vs. active features in general text")
plot_features_scatter(
    title="Baseball vs General Text (Layer 16)",
    layer_features=baseball_wiki_layer_features,
    top_percentile=0.05,
    k_features=5,
)


In [None]:
# Plot 3: scatter of the basketball-vs-baseball layer features
print("Plot 3: Active features in basketball vs. active features in baseball")
plot_features_scatter(
    title="Basketball vs Baseball (Layer 16)",
    layer_features=basketball_baseball_layer_features,
    top_percentile=0.05,
    k_features=5,
)
