In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

os.environ['OPENAI_KEY_SAE'] = ''

from hypothesaes.embedding import get_openai_embeddings
from hypothesaes.sae import SparseAutoencoder, load_model
from hypothesaes.interpret_neurons import NeuronInterpreter, SamplingConfig, LLMConfig, InterpretConfig, ScoringConfig
from hypothesaes.annotate import annotate_texts_with_concepts
from hypothesaes.evaluation import score_hypotheses
from hypothesaes.select_neurons import select_neurons

Using device: mps


In [2]:
train_df = pd.read_csv('/Users/ivanculo/Desktop/Projects/SAEs/HypotheSAEs/demo-data/reddit-depression-train.csv')
val_df = pd.read_csv('/Users/ivanculo/Desktop/Projects/SAEs/HypotheSAEs/demo-data/reddit-depression-test.csv')


text_col = 'text'
target_col = 'label'

train_texts = train_df[text_col].tolist()
val_texts = val_df[text_col].tolist()

# Compute embeddings for each split
EMBEDDER = "text-embedding-3-small"
CACHE_NAME = f"yelp_demo_{EMBEDDER}"
text2embedding = get_openai_embeddings(
    texts=train_texts + val_texts,
    model=EMBEDDER,
    cache_name=CACHE_NAME,
)

train_embeddings = np.array([text2embedding[text] for text in train_texts])
val_embeddings = np.array([text2embedding[text] for text in val_texts])

Loading embedding chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Loaded 7581 embeddings in 0.2s


In [3]:
# Train and get activations from two SAEs with different parameters
X_train = torch.tensor(train_embeddings, dtype=torch.float32).to(device)
X_val = torch.tensor(val_embeddings, dtype=torch.float32).to(device)

# Define parameters for two SAEs
sae_params = [
    {"M": 256, "K": 8},
    {"M": 64, "K": 4}
]

models = []
activations_list = []
neuron_source_info = []

for params in sae_params:
    M, K = params["M"], params["K"]
    save_dir = f'./checkpoints/{CACHE_NAME}'
    save_path = f'{save_dir}/SAE_M={M}_K={K}.pt'
    
    # Initialize and train (or load) the SAE model
    if os.path.exists(save_path):
        print(f"Loading existing model: M={M}, K={K}")
        model = load_model(save_path).to(device)
    else:
        print(f"Training new model: M={M}, K={K}")
        model = SparseAutoencoder(
            input_dim=X_train.shape[1],
            m_total_neurons=M,
            k_active_neurons=K,
            # Optional parameters:
            # aux_k=None,  # Number of neurons for dead neuron revival (None=default)
            # multi_k=None,  # Number of neurons for secondary reconstruction
            # dead_neuron_threshold_steps=256,  # Number of non-firing steps after which a neuron is considered dead
        ).to(device)
        
        model.fit(
            X_train=X_train,
            X_val=X_val,
            n_epochs=100,
            save_dir=save_dir,
            # Optional parameters:
            # batch_size=512,
            # learning_rate=5e-4,
            # aux_coef=1/32,  # Coefficient for auxiliary loss
            # multi_coef=0.0,  # Coefficient for multi-k loss
            # patience=3,     # Early stopping patience
            # clip_grad=1.0,  # Gradient clipping value
        )
    
    models.append(model)
    
    # Get activations from this model
    model_activations = model.get_activations(X_train)
    activations_list.append(model_activations)
    
    # Track source information for each neuron
    neuron_source_info.extend([(M, K) for i in range(M)])

# Concatenate activations from both models
train_activations = np.concatenate(activations_list, axis=1)
print(f"Neuron activations shape (from {len(sae_params)} models): {train_activations.shape}")

Loading existing model: M=256, K=8
Loaded model from ./checkpoints/yelp_demo_text-embedding-3-small/SAE_M=256_K=8.pt onto device mps


Computing activations (batchsize=16384):   0%|          | 0/1 [00:00<?, ?it/s]

Loading existing model: M=64, K=4
Loaded model from ./checkpoints/yelp_demo_text-embedding-3-small/SAE_M=64_K=4.pt onto device mps


Computing activations (batchsize=16384):   0%|          | 0/1 [00:00<?, ?it/s]

Neuron activations shape (from 2 models): (6128, 320)


In [4]:
# Select neurons using "lasso", "separation_score", or "correlation"
selection_method = "correlation"
top_neuron_count = 35

selected_neurons, scores = select_neurons(
    activations=train_activations,
    target=train_df[target_col].values,
    n_select=top_neuron_count,
    method=selection_method,
    # Optional parameters depend on selection method; see select_neurons.py
)

In [5]:
TASK_SPECIFIC_INSTRUCTIONS = """
All texts are expressions of cognitive patterns and mental states. Your task is to identify and describe features that capture HOW people think, process information, and manage mental states, rather than just WHAT they think about.

Focus on cognitive processes such as:
- Information processing patterns (analytical vs. intuitive, systematic vs. scattered)
- Mental effort and cognitive load indicators (complexity, elaboration, fatigue signs)
- Reasoning styles (logical, emotional, biased, flexible)
- Metacognitive awareness (self-reflection, monitoring one's own thinking)
- Cognitive control and regulation (emotional management, attention control)
- Mental flexibility vs. rigidity (openness to change, fixed thinking patterns)
- Cognitive biases and heuristics (shortcuts, systematic errors in thinking)
- Pattern recognition and connection-making abilities

Features should be formulated as descriptive statements that explain the cognitive mechanism being expressed through language patterns, word choice, sentence structure, reasoning flow, and conceptual organization.

Examples:
- "demonstrates analytical thinking through systematic breakdown of complex ideas into component parts"
- "exhibits cognitive load through fragmented sentence structures and incomplete thoughts"
- "shows metacognitive awareness by explicitly reflecting on and questioning one's own reasoning process"
- "manifests cognitive flexibility by actively considering and integrating opposing viewpoints"
- "reveals confirmation bias through selective attention to supporting evidence while dismissing contradictory information"
"""

interpreter = NeuronInterpreter(
    interpreter_model="gpt-4o",
    annotator_model="gpt4.1-nano",
    n_workers_interpretation=10,
    n_workers_annotation=50,
    cache_name=CACHE_NAME,
)

interpret_config = InterpretConfig(
    sampling=SamplingConfig(
        n_examples=20,
        max_words_per_example=128,
    ),
    llm=LLMConfig(
        temperature=0.7,
        max_interpretation_tokens=75,
    ),
    n_candidates=1,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
)

interpretations = interpreter.interpret_neurons(
    texts=train_texts,
    activations=train_activations,
    neuron_indices=selected_neurons,
    config=interpret_config,
)

Generating 1 interpretation(s) per neuron:   0%|          | 0/35 [00:00<?, ?it/s]

In [6]:
scoring_config = ScoringConfig(
    n_examples=200,
    max_words_per_example=128,
)

all_metrics = interpreter.score_interpretations(
    texts=train_texts,
    activations=train_activations,
    interpretations=interpretations,
    config=scoring_config,
)

Found 0 cached items; annotating 7000 uncached items


Scoring neuron interpretation fidelity (35 neurons; 1 candidate interps per neuron; 200 examples to score each…

In [7]:
# Create DataFrame with best and worst interpretations
interpretations_data = []
for neuron_idx in selected_neurons:
    neuron_metrics = all_metrics[neuron_idx]
    best_interp, best_metrics = max(neuron_metrics.items(), key=lambda x: x[1]['f1'])
    worst_interp, worst_metrics = min(neuron_metrics.items(), key=lambda x: x[1]['f1'])
    
    interpretations_data.append({
        'neuron_idx': neuron_idx,
        'source_sae': neuron_source_info[neuron_idx],
        f'{selection_method}': scores[selected_neurons.index(neuron_idx)],
        'best_interpretation': best_interp,
        'best_f1': best_metrics['f1'],
        'worst_interpretation': worst_interp,
        'worst_f1': worst_metrics['f1']
    })

best_interp_df = pd.DataFrame(interpretations_data).sort_values(by=f'{selection_method}', ascending=False)

display(
    best_interp_df.style.format({
        'separation_score': '{:.2f}',
        'best_f1': '{:.2f}', 
        'worst_f1': '{:.2f}'
    })
)
# Save DataFrame to CSV file in root folder
best_interp_df.to_csv('./best_interpretations.csv', index=False)
print(f"DataFrame saved to: ./best_interpretations.csv")


Unnamed: 0,neuron_idx,source_sae,correlation,best_interpretation,best_f1,worst_interpretation,worst_f1
0,303,"(64, 4)",0.370044,"demonstrates cognitive overload through detailed recounting of emotionally intense personal experiences, often accompanied by fragmented thoughts and a lack of clear resolution",0.78,"demonstrates cognitive overload through detailed recounting of emotionally intense personal experiences, often accompanied by fragmented thoughts and a lack of clear resolution",0.78
1,315,"(64, 4)",0.36619,"demonstrates dismissive or reductive reasoning by trivializing or minimizing the concept of depression through casual language, slang, and humor",0.52,"demonstrates dismissive or reductive reasoning by trivializing or minimizing the concept of depression through casual language, slang, and humor",0.52
2,131,"(256, 8)",0.346054,"demonstrates metacognitive awareness by expressing inner conflict and self-reflection about one's emotional state and actions, often questioning personal beliefs or societal norms",0.76,"demonstrates metacognitive awareness by expressing inner conflict and self-reflection about one's emotional state and actions, often questioning personal beliefs or societal norms",0.76
3,98,"(256, 8)",0.335258,"uses colloquial and casual language to express emotional states, often with abbreviated or fragmented phrasing that conveys a dismissive or resigned attitude toward the concept of depression",0.4,"uses colloquial and casual language to express emotional states, often with abbreviated or fragmented phrasing that conveys a dismissive or resigned attitude toward the concept of depression",0.4
4,262,"(64, 4)",0.308485,expresses heightened metacognitive awareness by frequently describing and analyzing one's own physical sensations and emotional states in detail,0.78,expresses heightened metacognitive awareness by frequently describing and analyzing one's own physical sensations and emotional states in detail,0.78
5,295,"(64, 4)",0.293443,"demonstrates cognitive overload and emotional exhaustion through repetitive expressions of hopelessness, self-doubt, and perceived lack of personal value in life",0.81,"demonstrates cognitive overload and emotional exhaustion through repetitive expressions of hopelessness, self-doubt, and perceived lack of personal value in life",0.81
6,272,"(64, 4)",0.276806,"expresses a sense of cognitive exhaustion and mental overload through repetitive phrasing, fragmented expressions of hopelessness, and a focus on the inability to endure or escape current circumstances",0.77,"expresses a sense of cognitive exhaustion and mental overload through repetitive phrasing, fragmented expressions of hopelessness, and a focus on the inability to endure or escape current circumstances",0.77
8,62,"(256, 8)",0.258017,"expresses metacognitive awareness by explicitly seeking strategies, tools, or advice to manage and understand their own anxiety or emotional state",0.79,"expresses metacognitive awareness by explicitly seeking strategies, tools, or advice to manage and understand their own anxiety or emotional state",0.79
10,273,"(64, 4)",0.233584,"demonstrates a pervasive focus on suicidal ideation and planning, characterized by detailed descriptions of methods, potential consequences, and emotional rationalizations for self-harm or death",0.83,"demonstrates a pervasive focus on suicidal ideation and planning, characterized by detailed descriptions of methods, potential consequences, and emotional rationalizations for self-harm or death",0.83
11,293,"(64, 4)",0.226236,demonstrates metacognitive awareness by explicitly reflecting on the effectiveness of therapeutic approaches and questioning personal progress in managing mental health challenges,0.18,demonstrates metacognitive awareness by explicitly reflecting on the effectiveness of therapeutic approaches and questioning personal progress in managing mental health challenges,0.18


DataFrame saved to: ./best_interpretations.csv


In [9]:
import csv
with open('interpretations_data.csv', 'w', newline='', encoding='utf-8') as csvfile:
    if interpretations_data:
        writer = csv.DictWriter(csvfile, fieldnames=interpretations_data[0].keys())
        writer.writeheader()
        writer.writerows(interpretations_data)
    else:
        print("interpretations_data is empty, nothing to write.")
print("interpretations_data saved to interpretations_data.csv in the current folder.")

interpretations_data saved to interpretations_data.csv in the current folder.


In [32]:
# Sample 500 random examples from holdout set
np.random.seed(42)
holdout_df = pd.read_json(f'./demo-data/demo-holdout-2K.json', lines=True)
holdout_texts = holdout_df[text_col].tolist()
holdout_labels = holdout_df[target_col].values

# Annotate texts with best interpretations
holdout_annotations = annotate_texts_with_concepts(
    texts=holdout_texts,
    concepts=best_interp_df['best_interpretation'].tolist(),
    max_words_per_example=128,
    cache_name=CACHE_NAME,
    n_workers=50,
)

# Evaluate on holdout set
metrics, hypothesis_df = score_hypotheses(
    hypothesis_annotations=holdout_annotations,
    y_true=holdout_labels,
    classification=False,
)

pd.set_option('display.max_colwidth', None)
display(hypothesis_df.round(3))
pd.reset_option('display.max_colwidth')

print("\nHoldout Set Metrics:")
print(f"R² Score: {metrics['r2']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Found 0 cached items; annotating 40000 uncached items


Annotating:   0%|          | 0/40000 [00:00<?, ?it/s]

Unnamed: 0,hypothesis,separation_score,separation_pval,regression_coef,regression_pval,feature_prevalence
2,"expresses personal love or favorite status for the restaurant, using phrases like 'I love this place', 'my favorite restaurant', or 'go-to spot'",1.468,0.0,0.378,0.0,0.389
0,emphasizes the friendliness and warmth of the staff using enthusiastic language,1.271,0.0,0.243,0.0,0.301
3,explicitly mentions a desire or intention to return to the restaurant,1.013,0.0,0.123,0.005,0.203
1,mentions specific positive interactions with named staff members,0.886,0.0,0.085,0.153,0.096
6,mentions long wait times or delays in receiving service or food,-1.122,0.0,0.267,0.0,0.138
10,"mentions long wait times for food or service, often with specific durations given (e.g., 20 minutes, 30 minutes, over an hour)",-1.435,0.0,-0.31,0.0,0.06
17,expresses disappointment in the flavor or seasoning of the food,-1.69,0.0,0.177,0.031,0.213
12,"mentions errors or mistakes in the food order, such as missing items, incorrect ingredients, or receiving the wrong dish",-1.805,0.0,0.01,0.868,0.144
16,expresses disappointment with food being bland or lacking flavor,-1.863,0.0,-0.266,0.002,0.166
7,mentions repeated issues with the restaurant's service or food quality across multiple visits or orders,-2.086,0.0,0.128,0.054,0.136



Holdout Set Metrics:
R² Score: 0.738
Significant hypotheses: 11/20 (p < 5.000e-03)
