# Basketball/Baseball SAE Feature Analysis

This notebook demonstrates SAE feature analysis for basketball and baseball concepts using LLAMA_3_1_8B model.


## 1. Setup and Configuration
We initialize the environment, select the model (Llama 3.1 8B), and set up VLLM for fast text generation.


In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import torch
from dataclasses import dataclass
from typing import List
from vllm import LLM, SamplingParams
from globals import LLAMA_3_1_8B
from crisp import CRISP, CRISPConfig, LayerFeatures
from sae import TopkSae
from utils import load_cached_features, save_cached_features, get_feature_tokens
from plot import plot_features_scatter
from datasets import load_dataset
from data import prepare_text
from IPython.display import HTML, display

# Set GPU device
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Adjust as needed

# Check for HF token
if 'HF_TOKEN' not in os.environ or os.environ.get('HF_TOKEN') is None:
    print("Warning: HF_TOKEN environment variable not set. You may need it for model access.")

# Configuration for LLAMA_3_1_8B
MODEL_CARD = LLAMA_3_1_8B
LAYER_TO_ANALYZE = 16
SAE_LAYERS = [LAYER_TO_ANALYZE]  # Only layer 16
SAE_SAVE_PATH = "llama_sae_cache"

print(f"Using model: {MODEL_CARD}")
print(f"Operating on layer: {LAYER_TO_ANALYZE}")


  from .autonotebook import tqdm as notebook_tqdm
[2026-01-08 15:48:20] INFO config.py:54: PyTorch version 2.9.0 available.


Using model: meta-llama/Llama-3.1-8B
Operating on layer: 16


## 2. Download SAE for Layer 16


In [2]:
# Download SAE for layer 16 if not already cached
print(f"Checking/Downloading SAE for layer {LAYER_TO_ANALYZE}...")
layer_path = os.path.join(SAE_SAVE_PATH, f"layer_{LAYER_TO_ANALYZE}")
if not os.path.exists(layer_path):
    print(f"Downloading SAE for layer {LAYER_TO_ANALYZE}...")
    TopkSae.download_and_save(layer=LAYER_TO_ANALYZE, save_path=SAE_SAVE_PATH)
else:
    print(f"SAE for layer {LAYER_TO_ANALYZE} already cached.")


Checking/Downloading SAE for layer 16...
SAE for layer 16 already cached.


## 3. Initialize VLLM for Fast Generation


In [None]:
# Initialize VLLM for fast generation
print("Initializing VLLM...")
vllm_model = LLM(
    model=MODEL_CARD,
    trust_remote_code=True,
    gpu_memory_utilization=0.9,
    dtype="bfloat16"
)
print("VLLM initialized successfully.")


Initializing VLLM...
INFO 01-08 15:48:21 [utils.py:253] non-default args: {'trust_remote_code': True, 'dtype': 'bfloat16', 'disable_log_stats': True, 'model': 'meta-llama/Llama-3.1-8B'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 01-08 15:48:22 [model.py:514] Resolved architecture: LlamaForCausalLM
INFO 01-08 15:48:22 [model.py:1661] Using max model len 131072


2026-01-08 15:48:23,432	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 01-08 15:48:23 [scheduler.py:230] Chunked prefill is enabled with max_num_batched_tokens=8192.
[0;36m(EngineCore_DP0 pid=3630299)[0;0m INFO 01-08 15:48:24 [core.py:93] Initializing a V1 LLM engine (v0.13.0) with config: model='meta-llama/Llama-3.1-8B', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_ve

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.52it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.28it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.67it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.60it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.55it/s]
[0;36m(EngineCore_DP0 pid=3630299)[0;0m 


[0;36m(EngineCore_DP0 pid=3630299)[0;0m INFO 01-08 15:48:32 [default_loader.py:308] Loading weights took 2.80 seconds
[0;36m(EngineCore_DP0 pid=3630299)[0;0m INFO 01-08 15:48:33 [gpu_model_runner.py:3659] Model loading took 14.9889 GiB memory and 4.502505 seconds
[0;36m(EngineCore_DP0 pid=3630299)[0;0m ERROR 01-08 15:48:42 [core.py:866] EngineCore failed to start.
[0;36m(EngineCore_DP0 pid=3630299)[0;0m ERROR 01-08 15:48:42 [core.py:866] Traceback (most recent call last):
[0;36m(EngineCore_DP0 pid=3630299)[0;0m ERROR 01-08 15:48:42 [core.py:866]   File "/dsi/fetaya-lab/noam_diamant/conda/envs/crisp_env/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1433, in _compile
[0;36m(EngineCore_DP0 pid=3630299)[0;0m ERROR 01-08 15:48:42 [core.py:866]     guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
[0;36m(EngineCore_DP0 pid=3630299)[0;0m ERROR 01-08 15:48:42 [core.py:866]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[0;36m(EngineCore_DP0 pid=3630299)[0;0m Process EngineCore_DP0:
[0;36m(EngineCore_DP0 pid=3630299)[0;0m Traceback (most recent call last):
[0;36m(EngineCore_DP0 pid=3630299)[0;0m   File "/dsi/fetaya-lab/noam_diamant/conda/envs/crisp_env/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1433, in _compile
[0;36m(EngineCore_DP0 pid=3630299)[0;0m     guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
[0;36m(EngineCore_DP0 pid=3630299)[0;0m                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[0;36m(EngineCore_DP0 pid=3630299)[0;0m   File "/dsi/fetaya-lab/noam_diamant/conda/envs/crisp_env/lib/python3.11/site-packages/torch/_utils_internal.py", line 92, in wrapper_function
[0;36m(EngineCore_DP0 pid=3630299)[0;0m     return function(*args, **kwargs)
[0;36m(EngineCore_DP0 pid=3630299)[0;0m            ^^^^^^^^^^^^^^^^^^^^^^^^^
[0;36m(EngineCore_DP0 pid=3630299)[0;0m   File "/dsi/fetaya-lab/noam_diamant/conda/envs/crisp_env/l

## 4. Generate Basketball and Baseball Sentences


In [None]:
# Create 50 keywords for basketball
basketball_keywords = [
    "dunk", "three-pointer", "rebound", "free throw", "layup",
    "point guard", "shooting guard", "small forward", "power forward", "center",
    "basketball court", "hoop", "backboard", "foul", "traveling",
    "jump shot", "alley-oop", "fast break", "pick and roll", "zone defense",
    "man-to-man defense", "full court press", "half court", "buzzer beater", "slam dunk",
    "NBA", "college basketball", "March Madness", "basketball player", "coach",
    "timeout", "substitution", "technical foul", "flagrant foul", "ejection",
    "basketball team", "championship", "playoffs", "regular season", "overtime",
    "basketball game", "scoring", "assist", "steal", "block",
    "basketball skills", "ball handling", "shooting form", "defensive stance", "offensive play"
]

# Create 50 keywords for baseball
baseball_keywords = [
    "home run", "pitcher", "strikeout", "baseball", "batting",
    "first base", "second base", "third base", "home plate", "outfield",
    "infield", "catcher", "shortstop", "second baseman", "third baseman",
    "baseball diamond", "mound", "batter's box", "dugout", "bullpen",
    "fastball", "curveball", "slider", "changeup", "knuckleball",
    "MLB", "World Series", "baseball game", "inning", "strike",
    "ball", "walk", "hit", "double", "triple",
    "baseball team", "manager", "coach", "umpire", "referee",
    "baseball player", "pitcher's mound", "base running", "stealing base", "bunt",
    "baseball skills", "pitching", "hitting", "fielding", "catching"
]

print(f"Created {len(basketball_keywords)} basketball keywords")
print(f"Created {len(baseball_keywords)} baseball keywords")


In [None]:
# Generate 250 prompts for basketball (using keywords 10 times to get 250)
basketball_prompts = [f"Write a sentence about {keyword}" for keyword in basketball_keywords * 5]

# Generate 250 prompts for baseball (using keywords 10 times to get 250)
baseball_prompts = [f"Write a sentence about {keyword}" for keyword in baseball_keywords * 5]

print(f"Generated {len(basketball_prompts)} basketball prompts")
print(f"Generated {len(baseball_prompts)} baseball prompts")


In [None]:
# Configure sampling parameters for generation
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=50,  # Generate up to 50 tokens per sentence
    stop=None
)

# Generate basketball sentences
print("Generating basketball sentences...")
start_time = time.time()
basketball_outputs = vllm_model.generate(basketball_prompts, sampling_params)
basketball_generation_time = time.time() - start_time

# Extract generated text
basketball_sentences = []
for output in basketball_outputs:
    generated_text = output.outputs[0].text.strip()
    # Combine prompt and generated text to form complete sentence
    prompt = output.prompt
    full_sentence = prompt + " " + generated_text
    basketball_sentences.append(full_sentence)

print(f"Generated {len(basketball_sentences)} basketball sentences")
print(f"Time taken: {basketball_generation_time:.2f} seconds")


In [None]:
# Generate baseball sentences
print("Generating baseball sentences...")
start_time = time.time()
baseball_outputs = vllm_model.generate(baseball_prompts, sampling_params)
baseball_generation_time = time.time() - start_time

# Extract generated text
baseball_sentences = []
for output in baseball_outputs:
    generated_text = output.outputs[0].text.strip()
    # Combine prompt and generated text to form complete sentence
    prompt = output.prompt
    full_sentence = prompt + " " + generated_text
    baseball_sentences.append(full_sentence)

print(f"Generated {len(baseball_sentences)} baseball sentences")
print(f"Time taken: {baseball_generation_time:.2f} seconds")

total_time = basketball_generation_time + baseball_generation_time
print(f"\nTotal generation time: {total_time:.2f} seconds")


## 5. Load Wikipedia Retain Examples


In [None]:
# Load 250 Wikipedia retain examples
print("Loading Wikipedia retain examples...")
wiki_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
wiki_texts = prepare_text(wiki_data['text'], max_len=1000)

# Take 250 examples
wiki_retain = list(wiki_texts)[:250]
print(f"Loaded {len(wiki_retain)} Wikipedia retain examples")


## 6. Initialize CRISP for Feature Analysis


In [None]:
# Create a simple data config class for caching
@dataclass
class BasketballBaseballDataConfig:
    max_length: int = 1000
    min_length: int = 100
    n_examples: int = 250
    data_type: str = "basketball_baseball"
    forget_type: str = "basketball"  # For caching compatibility
    retain_type: str = "wiki"  # For caching compatibility
    
    def to_dict(self):
        return {
            "max_length": self.max_length,
            "min_length": self.min_length,
            "n_examples": self.n_examples,
            "data_type": self.data_type,
            "forget_type": self.forget_type,
            "retain_type": self.retain_type
        }

# Initialize CRISP
config = CRISPConfig(
    layers=SAE_LAYERS,
    model_name="llama",  # This will be converted to LLAMA_3_1_8B
    bf16=True
)
crisp = CRISP(config)
print("CRISP initialized successfully.")


## 7. Process Features: Basketball vs Wikipedia


In [None]:
# Create data config for basketball vs wiki
data_config_basketball_wiki = BasketballBaseballDataConfig(
    n_examples=250,
    data_type="basketball_wiki",
    forget_type="basketball",
    retain_type="wiki"
)

# Process basketball vs Wikipedia
print("Processing features: Basketball vs Wikipedia...")
crisp.process_multi_texts_batch(
    text_target=basketball_sentences,
    text_benign=wiki_retain,
    data_config=data_config_basketball_wiki,
    batch_size=32
)
print("Feature processing complete for Basketball vs Wikipedia.")

# Load features for layer 16
basketball_wiki_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_basketball_wiki,
    model_name=MODEL_CARD
)
if basketball_wiki_features is None:
    # If not cached, get from crisp (features are already processed and stored)
    basketball_wiki_features = list(crisp.features_dict[LAYER_TO_ANALYZE].features.values())
else:
    # If cached, it's already a list of Feature objects
    basketball_wiki_features = list(basketball_wiki_features)

basketball_wiki_layer_features = LayerFeatures(basketball_wiki_features)
print(f"Loaded {len(basketball_wiki_layer_features)} features for Basketball vs Wikipedia")


## 8. Process Features: Baseball vs Wikipedia


In [None]:
# Create data config for baseball vs wiki
data_config_baseball_wiki = BasketballBaseballDataConfig(
    n_examples=250,
    data_type="baseball_wiki",
    forget_type="baseball",
    retain_type="wiki"
)

# Clear previous features
crisp.features_dict.clear()
torch.cuda.empty_cache()

# Process baseball vs Wikipedia
print("Processing features: Baseball vs Wikipedia...")
crisp.process_multi_texts_batch(
    text_target=baseball_sentences,
    text_benign=wiki_retain,
    data_config=data_config_baseball_wiki,
    batch_size=8
)
print("Feature processing complete for Baseball vs Wikipedia.")

# Load features for layer 16
baseball_wiki_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_baseball_wiki,
    model_name=MODEL_CARD
)
if baseball_wiki_features is None:
    # If not cached, get from crisp (features are already processed and stored)
    baseball_wiki_features = list(crisp.features_dict[LAYER_TO_ANALYZE].features.values())
else:
    # If cached, it's already a list of Feature objects
    baseball_wiki_features = list(baseball_wiki_features)

baseball_wiki_layer_features = LayerFeatures(baseball_wiki_features)
print(f"Loaded {len(baseball_wiki_layer_features)} features for Baseball vs Wikipedia")


In [None]:
# Clear previous features
import gc
crisp.features_dict.clear()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

## 9. Process Features: Basketball vs Baseball


In [None]:
# Create data config for basketball vs baseball
data_config_basketball_baseball = BasketballBaseballDataConfig(
    n_examples=250,
    data_type="basketball_baseball",
    forget_type="basketball",
    retain_type="baseball"
)

# Clear previous features
crisp.features_dict.clear()
torch.cuda.empty_cache()

# Process basketball vs baseball
print("Processing features: Basketball vs Baseball...")
crisp.process_multi_texts_batch(
    text_target=basketball_sentences,
    text_benign=baseball_sentences,
    data_config=data_config_basketball_baseball,
    batch_size=64
)
print("Feature processing complete for Basketball vs Baseball.")

# Load features for layer 16
basketball_baseball_features = load_cached_features(
    LAYER_TO_ANALYZE,
    data_config_basketball_baseball,
    model_name=MODEL_CARD
)
if basketball_baseball_features is None:
    # If not cached, get from crisp (features are already processed and stored)
    basketball_baseball_features = list(crisp.features_dict[LAYER_TO_ANALYZE].features.values())
else:
    # If cached, it's already a list of Feature objects
    basketball_baseball_features = list(basketball_baseball_features)

basketball_baseball_layer_features = LayerFeatures(basketball_baseball_features)
print(f"Loaded {len(basketball_baseball_layer_features)} features for Basketball vs Baseball")


## 10. Visualization: Plot Feature Activations


In [None]:
# Plot 1: Basketball vs General (Wikipedia)
print("Plot 1: Active features in basketball vs. active features in general text")
plot_features_scatter(
    layer_features=basketball_wiki_layer_features,
    k_features=5,
    top_percentile=0.05,
    title="Basketball vs General Text (Layer 16)"
)


In [None]:
# Plot 2: Baseball vs General (Wikipedia)
print("Plot 2: Active features in baseball vs. active features in general text")
plot_features_scatter(
    layer_features=baseball_wiki_layer_features,
    k_features=5,
    top_percentile=0.05,
    title="Baseball vs General Text (Layer 16)"
)


In [None]:
# Plot 3: Basketball vs Baseball
print("Plot 3: Active features in basketball vs. active features in baseball")
plot_features_scatter(
    layer_features=basketball_baseball_layer_features,
    k_features=5,
    top_percentile=0.05,
    title="Basketball vs Baseball (Layer 16)"
)


## 11. Inspect Top Features with Neuronpedia

We inspect the top concept-related features to see which tokens activate them.


In [None]:
# Configuration for Neuronpedia
NEURONPEDIA_ID = "llama3.1-8b"
NEURONPEDIA_SOURCE_SUFFIX = "-llamascope-res-32k"

def display_feature_info(feature, layer, model_id, source_suffix, title_prefix=""):
    """Display feature information with top activating tokens from Neuronpedia."""
    feature_index = feature.index
    
    print(f"\n{title_prefix}Feature {feature_index} (Layer {layer})")
    print(f"Target Count: {feature.target_count}, Benign Count: {feature.benign_count}")
    print(f"Target Acts Relative: {feature.target_acts_relative:.4f}")
    
    feature_data = get_feature_tokens(model_id, layer, feature_index, top_k=5)
    
    if feature_data:
        source = f"{layer}{source_suffix}"
        neuronpedia_url = f"https://www.neuronpedia.org/{model_id}/{source}/{feature_index}"
        
        # Create HTML content
        tokens_html = ' '.join([
            f'<span style="background-color: #e1ecf4; color: #2c5282; padding: 2px 8px; border-radius: 4px; margin-right: 5px; display: inline-block; border: 1px solid #b3d4fc;">{token}</span>' 
            for token in feature_data.get('pos_str', [])
        ])
        
        description = feature_data['explanations'][0]['description'] if feature_data.get('explanations') else "No description available"
        
        html_content = f"""
        <div style="border: 1px solid #e0e0e0; padding: 20px; border-radius: 8px; background-color: #ffffff; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; box-shadow: 0 2px 4px rgba(0,0,0,0.05); max-width: 600px; margin-bottom: 20px;">
            <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px; border-bottom: 1px solid #eee; padding-bottom: 10px;">
                <h3 style="margin: 0; color: #333; font-size: 1.2em;">Feature {feature_index}</h3>
                <span style="background-color: #f0f0f0; color: #666; padding: 2px 8px; border-radius: 12px; font-size: 0.8em;">Layer {layer}</span>
            </div>
            
            <div style="margin-bottom: 20px;">
                <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 5px; letter-spacing: 0.5px;">Auto Description</div>
                <div style="font-size: 1.1em; color: #1a1a1a; line-height: 1.4;">{description}</div>
            </div>
            
            <div style="margin-bottom: 20px;">
                <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 8px; letter-spacing: 0.5px;">Top Tokens (Logit Lens)</div>
                <div style="display: flex; flex-wrap: wrap; gap: 5px;">
                    {tokens_html}
                </div>
            </div>
            
            <div style="text-align: right;">
                <a href="{neuronpedia_url}" target="_blank" style="color: #0969da; text-decoration: none; font-size: 0.9em; font-weight: 500;">View on Neuronpedia &rarr;</a>
            </div>
        </div>
        """
        display(HTML(html_content))
    else:
        print(f"Could not fetch feature data from Neuronpedia for feature {feature_index}")


### Top 5 Basketball Features (Basketball vs Wikipedia)


In [None]:
# Get top 5 basketball features (highest target_acts_relative)
top_basketball_features = basketball_wiki_layer_features.topk_filtered(5)

print("="*80)
print("TOP 5 BASKETBALL FEATURES (Basketball vs Wikipedia)")
print("="*80)

for i, feature in enumerate(top_basketball_features, 1):
    display_feature_info(
        feature, 
        LAYER_TO_ANALYZE, 
        NEURONPEDIA_ID, 
        NEURONPEDIA_SOURCE_SUFFIX,
        title_prefix=f"[{i}] "
    )


### Top 5 Baseball Features (Baseball vs Wikipedia)


In [None]:
# Get top 5 baseball features (highest target_acts_relative)
top_baseball_features = baseball_wiki_layer_features.topk_filtered(5)

print("="*80)
print("TOP 5 BASEBALL FEATURES (Baseball vs Wikipedia)")
print("="*80)

for i, feature in enumerate(top_baseball_features, 1):
    display_feature_info(
        feature, 
        LAYER_TO_ANALYZE, 
        NEURONPEDIA_ID, 
        NEURONPEDIA_SOURCE_SUFFIX,
        title_prefix=f"[{i}] "
    )


### Top 5 Basketball Features (Basketball vs Baseball)


In [None]:
# Get top 5 basketball features when compared to baseball (highest target_acts_relative)
top_basketball_vs_baseball_features = basketball_baseball_layer_features.topk_filtered(5)

print("="*80)
print("TOP 5 BASKETBALL FEATURES (Basketball vs Baseball)")
print("="*80)

for i, feature in enumerate(top_basketball_vs_baseball_features, 1):
    display_feature_info(
        feature, 
        LAYER_TO_ANALYZE, 
        NEURONPEDIA_ID, 
        NEURONPEDIA_SOURCE_SUFFIX,
        title_prefix=f"[{i}] "
    )


### Top 5 Baseball Features (Basketball vs Baseball)


In [None]:
# Get top 5 baseball features when compared to basketball (highest benign_acts_relative)
# For basketball vs baseball, "benign" is baseball, so we want features with high benign_acts_relative
top_baseball_vs_basketball_features = basketball_baseball_layer_features.bottomk_filtered(5)

print("="*80)
print("TOP 5 BASEBALL FEATURES (Basketball vs Baseball)")
print("="*80)
print("Note: These are features highly active on baseball (benign) compared to basketball (target)")

for i, feature in enumerate(top_baseball_vs_basketball_features, 1):
    # For this comparison, we want to show features that are more active on baseball
    # So we'll display the benign_acts_relative instead
    print(f"\n[{i}] Feature {feature.index} (Layer {LAYER_TO_ANALYZE})")
    print(f"Basketball Count: {feature.target_count}, Baseball Count: {feature.benign_count}")
    print(f"Baseball Acts Relative: {feature.benign_acts_relative:.4f}")
    
    feature_data = get_feature_tokens(NEURONPEDIA_ID, LAYER_TO_ANALYZE, feature.index, top_k=5)
    
    if feature_data:
        source = f"{LAYER_TO_ANALYZE}{NEURONPEDIA_SOURCE_SUFFIX}"
        neuronpedia_url = f"https://www.neuronpedia.org/{NEURONPEDIA_ID}/{source}/{feature.index}"
        
        tokens_html = ' '.join([
            f'<span style="background-color: #e1ecf4; color: #2c5282; padding: 2px 8px; border-radius: 4px; margin-right: 5px; display: inline-block; border: 1px solid #b3d4fc;">{token}</span>' 
            for token in feature_data.get('pos_str', [])
        ])
        
        description = feature_data['explanations'][0]['description'] if feature_data.get('explanations') else "No description available"
        
        html_content = f"""
        <div style="border: 1px solid #e0e0e0; padding: 20px; border-radius: 8px; background-color: #ffffff; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; box-shadow: 0 2px 4px rgba(0,0,0,0.05); max-width: 600px; margin-bottom: 20px;">
            <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px; border-bottom: 1px solid #eee; padding-bottom: 10px;">
                <h3 style="margin: 0; color: #333; font-size: 1.2em;">Feature {feature.index}</h3>
                <span style="background-color: #f0f0f0; color: #666; padding: 2px 8px; border-radius: 12px; font-size: 0.8em;">Layer {LAYER_TO_ANALYZE}</span>
            </div>
            
            <div style="margin-bottom: 20px;">
                <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 5px; letter-spacing: 0.5px;">Auto Description</div>
                <div style="font-size: 1.1em; color: #1a1a1a; line-height: 1.4;">{description}</div>
            </div>
            
            <div style="margin-bottom: 20px;">
                <div style="text-transform: uppercase; font-size: 0.75em; color: #888; margin-bottom: 8px; letter-spacing: 0.5px;">Top Tokens (Logit Lens)</div>
                <div style="display: flex; flex-wrap: wrap; gap: 5px;">
                    {tokens_html}
                </div>
            </div>
            
            <div style="text-align: right;">
                <a href="{neuronpedia_url}" target="_blank" style="color: #0969da; text-decoration: none; font-size: 0.9em; font-weight: 500;">View on Neuronpedia &rarr;</a>
            </div>
        </div>
        """
        display(HTML(html_content))
    else:
        print(f"Could not fetch feature data from Neuronpedia for feature {feature.index}")
