# Overview & Current Questions
This is for validating my clustering idea:
1. Store SAE activations for all prompts
2. Preprocess activations based on entropy and activation level (optional?)
3. Cluster activations using UMAP and HDBSCAN
4. Analyze clusters

Some questions: 
1. How does entropy, activation preprocessing affect clustering? How does n_prompts, length, affect this step?
2. How useful is UMAP? What are the best parameters?
3. How does HDBSCAN perform? What are the best parameters?
4. How does the clustering change when we use different datasets?
5. What commonalities do clusters have? Does this vary by dataset? Hyperparameter?

### Imports, config, model setup

In [42]:
%load_ext autoreload
%autoreload 2

from sae_lens import SAE, HookedSAETransformer
import torch
import gc
from config import config # cfg auto updates

import random
from datasets import load_dataset
import os

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

import numpy as np

from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
import umap.umap_ as umap
import plotly.graph_objects as go
from einops import rearrange
import hdbscan

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [43]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HookedSAETransformer.from_pretrained("EleutherAI/pythia-70m-deduped", device=device)
sae, _, _ = SAE.from_pretrained(
    release="pythia-70m-deduped-mlp-sm",
    sae_id="blocks.3.hook_mlp_out",
    device=device
)

Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


## Loading Activations from get_data.ipynb

In [44]:
def load_processed_data(filename):
    """Load processed data from disk."""
    if os.path.exists(filename):
        print(f"Loading processed data from {filename}...")
        try:
            data = torch.load(filename)
            print("Data loaded successfully.")
            return data
        except Exception as e:
            print(f"Error loading data: {e}")
            return None
    else:
        print(f"No cached data found at {filename}")
        return None
    
def get_cache_filename(config, n_prompts):
    """Generate a cache filename based on hierarchical config parameters."""
    # Create a unique filename based on key parameters
    
    params = [
        f"prompts_{n_prompts}",
    ]
    print(params)
    
    # Create directory if it doesn't exist
    cache_dir = config.get('cache_dir', 'cache')
    os.makedirs(cache_dir, exist_ok=True)
    
    return os.path.join(cache_dir, f"processed_data_{'_'.join(params)}.pt") 

# Determine whether to use cached data 
n_prompts = config['n_prompts']
use_cached_data = config['use_cached_data']
cache_filename = get_cache_filename(config, n_prompts)
cached_data = load_processed_data(cache_filename)


acts = cached_data['acts']
prompts = cached_data.get('prompts', [])


['prompts_1000']
Loading processed data from feature_cache/processed_data_prompts_1000.pt...
Data loaded successfully.


# Entropy and sparsity

## Filtering using entropy and sparsity

In [45]:
def get_entropy_sparsity_varentropy(acts, config, n_prompts):
    
    # Calculate entropy and sparsity
    activations = acts.abs()
    probs = activations / (activations.sum(dim=0) + 1e-10)
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=0)

    # Correct varentropy calculation
    # We want variance of entropy across prompts for each feature
    log_probs = torch.log(probs + 1e-10)
    entropy_per_prompt = -(probs * log_probs)  # Shape: [n_prompts, n_features]
    mean_entropy = entropy_per_prompt.mean(dim=0)  # Shape: [n_features]
    varentropy = torch.mean((entropy_per_prompt - mean_entropy.unsqueeze(0))**2, dim=0)  # Shape: [n_features]
    # variance of entropy is a measure of how much the entropy varies across prompts 
    # i posit that this will be useful for clustering
    
    activation_threshold = config.get('activation_threshold', 0.1)
    sparsity = (acts.abs() > activation_threshold).float().mean(dim=0)

    if config['verbose']:
        # After calculating entropy and sparsity
        print(f"\nEntropy stats:")
        print(f"Min: {entropy.min().item():.3f}")
        print(f"Max: {entropy.max().item():.3f}")
        print(f"Mean: {entropy.mean().item():.3f}")

        print(f"\nSparsity stats:")
        print(f"Min: {sparsity.min().item():.3f}")
        print(f"Max: {sparsity.max().item():.3f}")
        print(f"Mean: {sparsity.mean().item():.3f}")

        # Print the thresholds being used
        print(f"\nThresholds:")
        print(f"Entropy: [{config['entropy_threshold_low']}, {config['entropy_threshold_high']}]")
        print(f"Sparsity: [{config['sparsity_min']}, {config['sparsity_max']}]")

    return entropy, sparsity, varentropy

def filter_features(acts, config):
    """
    Filter features based on entropy and sparsity.
    We are looking at the feature behavior across prompts (we could modify this for tokens, or even for tokens within prompts)
    """

    n_prompts, n_features = acts.shape
    entropy, sparsity, varentropy = get_entropy_sparsity_varentropy(acts, config, n_prompts)
    
    # Apply filtering mask
    mask = (entropy > config['entropy_threshold_low']) & \
           (entropy < config['entropy_threshold_high']) & \
           (sparsity > config['sparsity_min']) & \
           (sparsity < config['sparsity_max'])
    
    print(f"Kept {mask.sum().item()} out of {n_features} features")
    return acts[:, mask], mask.nonzero(as_tuple=True)[0], entropy, sparsity, varentropy




filtered_acts, original_indices, entropy_data, sparsity_data, varentropy_data = filter_features(acts, config)
print(f' after filtering acts.shape: {acts.shape}')

if config['visualize_features']:
    # Create a DataFrame for entropy and sparsity
    df = pd.DataFrame({
        'entropy': entropy_data.cpu().numpy(),
        'sparsity': sparsity_data.cpu().numpy(),
        'varentropy': varentropy_data.cpu().numpy(),
    })

    px.scatter(df, x='entropy', y='sparsity').show()
    px.scatter(df, x='entropy', y='varentropy').show()
    px.scatter(df, x='sparsity', y='varentropy').show() 

OutOfMemoryError: CUDA out of memory. Tried to allocate 126.00 MiB. GPU 0 has a total capacity of 5.78 GiB of which 72.12 MiB is free. Process 214109 has 6.15 MiB memory in use. Including non-PyTorch memory, this process has 1.73 GiB memory in use. Process 231363 has 1.17 GiB memory in use. Of the allocated memory 1.43 GiB is allocated by PyTorch, and 212.62 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Entropy and sparsity visualization

In [41]:
def analyze_feature_quadrants(filtered_acts, n_examples=5):
    """Analyze features based on their entropy/varentropy quadrants.
    
    Args:
        acts: [n_prompts, n_features] activation tensor
        entropy: [n_features] entropy per feature
        varentropy: [n_features] variance of entropy per feature
        n_examples: number of example features to show per quadrant
    """

    filtered_entropy, filtered_sparsity, filtered_varentropy = get_entropy_sparsity_varentropy(filtered_acts, config, n_prompts)

    # Get medians for quadrant splitting
    entropy_median = filtered_entropy.median()
    varentropy_median = filtered_varentropy.median()
    
    # Create quadrant masks
    q1 = (filtered_entropy <= entropy_median) & (filtered_varentropy <= varentropy_median)  # Low E, Low VE
    q2 = (filtered_entropy > entropy_median) & (filtered_varentropy <= varentropy_median)   # High E, Low VE
    q3 = (filtered_entropy <= entropy_median) & (filtered_varentropy > varentropy_median)   # Low E, High VE
    q4 = (filtered_entropy > entropy_median) & (filtered_varentropy > varentropy_median)    # High E, High VE
    
    quadrants = {
        "Flowing (Low E, Low VE)": q1,
        "Careful (High E, Low VE)": q2,
        "Exploring (Low E, High VE)": q3,
        "Resampling (High E, High VE)": q4
    }
    
    # print("Feature Distribution in Quadrants:")
    for name, mask in quadrants.items():
        n_features = mask.sum().item()
        # print(f"\n{name}: {n_features} features ({n_features/len(filtered_entropy):.1%})")
        
        # Get example features from this quadrant
        feature_indices = mask.nonzero(as_tuple=True)[0]
        if len(feature_indices) > 0:
            sample_indices = feature_indices[torch.randperm(len(feature_indices))[:n_examples]]
            
            # print("\nExample features:")
            for idx in sample_indices:
                # Get activation statistics for this feature
                feature_acts = acts[:, idx]
                active_prompts = (feature_acts.abs() > config['activation_threshold']).sum().item()
                max_activation = feature_acts.abs().max().item()
                
                # print(f"\nFeature {idx}:")
                # print(f"  Entropy: {filtered_entropy[idx]:.3f}")
                # print(f"  Varentropy: {filtered_varentropy[idx]:.3f}")
                # print(f"  Active in {active_prompts}/{len(acts)} prompts")
                # print(f"  Max activation: {max_activation:.3f}")
                
    return quadrants, filtered_entropy, filtered_varentropy

# Create a visualization of the quadrants
def plot_entropy_quadrants(entropy, varentropy, quadrants):
    """Create a scatter plot showing feature distribution across quadrants."""
    import plotly.express as px
    import pandas as pd
    
    # Create DataFrame
    df = pd.DataFrame({
        'entropy': entropy.cpu().numpy(),
        'varentropy': varentropy.cpu().numpy(),
        'quadrant': 'Unknown'
    })
    
    # Assign quadrant labels
    for name, mask in quadrants.items():
        df.loc[mask.cpu().numpy(), 'quadrant'] = name
    
    # Create scatter plot
    fig = px.scatter(df, x='entropy', y='varentropy', color='quadrant',
                    title='Feature Distribution across Entropy/Varentropy Quadrants',
                    labels={'entropy': 'Entropy', 'varentropy': 'Variance of Entropy'})
    
    # Add median lines
    fig.add_hline(y=varentropy.median().item(), line_dash="dash", line_color="gray")
    fig.add_vline(x=entropy.median().item(), line_dash="dash", line_color="gray")
    
    fig.show()

if config['visualize_features']:
    quadrants, filtered_entropy, filtered_varentropy = analyze_feature_quadrants(filtered_acts)
    plot_entropy_quadrants(filtered_entropy, filtered_varentropy, quadrants)

# Clustering

## Applying UMAP and HDBSCAN

In [None]:
from sklearn.metrics.pairwise import cosine_distances


def apply_umap_preprocessing(normalized_acts, config):

    
    n_samples, n_dims = normalized_acts.shape
        
    target_dims = config['umap_components']
    
    reducer = umap.UMAP(
        n_components=target_dims,
        n_neighbors=min(config['umap_neighbors'], n_samples - 1),
        min_dist=config['umap_min_dist'],
        metric=config['umap_metric'],
        random_state=42
    )
    
    try:
        reduced_acts = reducer.fit_transform(normalized_acts)
        print(f"Applied UMAP reduction: {n_dims}d → {target_dims}d")
        return reduced_acts, normalized_acts
    except Exception as e:
        print(f"Error during UMAP reduction: {e}")
        return normalized_acts, normalized_acts


def run_hdbscan(reduced_acts, config):
    """Run HDBSCAN clustering."""
    # clusterer = hdbscan.HDBSCAN(
    #     min_cluster_size=config['hdbscan_min_cluster_size'],
    #     min_samples=config['hdbscan_min_samples'],
    #     metric=config['hdbscan_metric'],
    #     cluster_selection_epsilon=config['hdbscan_cluster_selection_epsilon']
    # ) 

    clusterer = hdbscan.HDBSCAN(algorithm='best', alpha=1.0, approx_min_span_tree=True,
    gen_min_span_tree=False, leaf_size=40, memory=Memory(None),
    metric='precomputed', min_cluster_size=5, min_samples=None, p=None)

    return clusterer.fit_predict(distance_matrix)

def cluster_features(acts, config):
    """Cluster features using UMAP preprocessing (if needed) and HDBSCAN."""
    
    acts = acts.T.cpu().numpy()
    normalized_acts = acts / (np.linalg.norm(acts, axis=1, keepdims=True) + 1e-10)
    
    # Preprocess with UMAP if enabled
    if config['use_umap']:
        reduced_acts, normalized_acts = apply_umap_preprocessing(normalized_acts, config)
    else:
        reduced_acts = normalized_acts
    
    # Run HDBSCAN clustering
    labels = run_hdbscan(reduced_acts, config)
    
    n_clusters = len(np.unique(labels)) - (1 if -1 in labels else 0)
    noise_ratio = np.mean(labels == -1) if -1 in labels else 0
    print(f"Clustering complete: found {n_clusters} clusters with {noise_ratio:.2%} noise points")
    
    return labels, normalized_acts


labels, reduced_acts = cluster_features(filtered_acts, config)

print(f' after clustering reduced_acts.shape: {reduced_acts.shape}')
print(f' after clustering labels.shape: {labels.shape}')

Clustering complete: found 3 clusters with 75.15% noise points
 after clustering reduced_acts.shape: (7272, 1000)
 after clustering labels.shape: (7272,)


## Cluster prompt analysis

In [189]:
# Track results for each cluster
cluster_analysis = {}

# Get unique valid cluster labels (including noise points)
all_labels = np.unique(labels)
unique_labels = [label for label in all_labels if label != -1]

# Calculate global statistics
total_features = len(labels)
noise_points = np.sum(labels == -1)
noise_ratio = noise_points / total_features if total_features > 0 else 0

print(f"\n=== Cluster Analysis Summary ===")
print(f"Total Features: {total_features}")
print(f"Number of Clusters: {len(unique_labels)}")
print(f"Noise Points: {noise_points} ({noise_ratio:.2%})")

# For each cluster (including noise)
for label in all_labels:
    cluster_type = "Noise Cluster" if label == -1 else f"Cluster {label}"
    cluster_mask = labels == label
    cluster_size = np.sum(cluster_mask)
    cluster_indices = original_indices[cluster_mask]
    
    print(f"\n=== {cluster_type} ===")
    print(f"Size: {cluster_size} features ({(cluster_size/total_features):.2%} of total)")
    print(f"Feature Indices: {cluster_indices.tolist()}")
    
    if label == -1:  # Skip activation analysis for noise cluster
        cluster_analysis[label] = {
            'size': cluster_size,
            'indices': cluster_indices.tolist(),
            'ratio': cluster_size/total_features,
            'is_noise': True
        }
        continue
        
    # Get activations for all prompts on this cluster's features
    print(reduced_acts.shape)
    print(cluster_mask.shape)
    cluster_acts = filtered_acts[:, cluster_mask]  # [n_prompts, n_cluster_features]
    
    # Compute activation statistics
    mean_activation = torch.mean(torch.abs(cluster_acts)).item()
    max_activation = torch.max(torch.abs(cluster_acts)).item()
    std_activation = torch.std(torch.abs(cluster_acts)).item()
    sparsity = (torch.abs(cluster_acts) > 0.1).float().mean().item()
    
    print(f"\nActivation Statistics:")
    print(f"  Mean Activation: {mean_activation:.4f}")
    print(f"  Max Activation: {max_activation:.4f}")
    print(f"  Std Deviation: {std_activation:.4f}")
    print(f"  Sparsity: {sparsity:.4f}")
    
    # Compute average activation of each prompt on this cluster's features
    prompt_activations = torch.mean(torch.abs(cluster_acts), dim=1)  # [n_prompts]
    
    # Find top activating prompts
    top_k = min(5, len(prompts))
    top_prompt_indices = torch.argsort(prompt_activations, descending=True)[:top_k]
    top_prompts = [(prompts[i], prompt_activations[i].item()) for i in top_prompt_indices]
    
    print(f"\nTop {top_k} Activating Prompts:")
    for i, (prompt, act) in enumerate(top_prompts, 1):
        # Show first 100 chars, with special formatting for section headers
        truncated = prompt[:100]
        if len(prompt) > 100:
            truncated += "..."
        
        # Format section headers more clearly
        if "=" in truncated:
            sections = [s.strip() for s in truncated.split("=") if s.strip()]
            if sections:
                truncated = f"[SECTION] {' > '.join(sections)}"
        
        # Split into tokens if the prompt contains spaces
        tokens = truncated.split()
        if len(tokens) > 15:
            token_display = " ".join(tokens[:15]) + " ..."
        else:
            token_display = truncated
            
        print(f"  {i}. \"{token_display}\"")
        print(f"     Activation: {act:.4f}")
        print(f"     Total Length: {len(prompt)} chars, {len(prompt.split())} tokens")
    # Store detailed results
    cluster_analysis[label] = {
        'size': cluster_size,
        'indices': cluster_indices.tolist(),
        'ratio': cluster_size/total_features,
        'is_noise': False,
        'mean_activation': mean_activation,
        'max_activation': max_activation,
        'std_activation': std_activation,
        'sparsity': sparsity,
        'top_prompts': top_prompts,
        'prompt_activations': prompt_activations.cpu().numpy()
    }


# Print cluster similarity analysis
if len(unique_labels) > 1:
    print("\n=== Cluster Similarity Analysis ===")
    for i, label1 in enumerate(unique_labels):
        for label2 in unique_labels[i+1:]:
            acts1 = cluster_analysis[label1]['prompt_activations']
            acts2 = cluster_analysis[label2]['prompt_activations']
            correlation = np.corrcoef(acts1, acts2)[0, 1]
            if abs(correlation) > 0.5:  # Only show significant correlations
                print(f"Clusters {label1} and {label2}: correlation = {correlation:.4f}")



=== Cluster Analysis Summary ===
Total Features: 7272
Number of Clusters: 3
Noise Points: 5465 (75.15%)

=== Noise Cluster ===
Size: 5465 features (75.15% of total)
Feature Indices: [4, 5, 9, 10, 14, 21, 25, 26, 33, 37, 40, 42, 44, 45, 48, 65, 69, 83, 89, 94, 99, 100, 110, 115, 125, 128, 140, 144, 154, 160, 174, 176, 186, 197, 200, 209, 213, 216, 218, 219, 222, 227, 235, 240, 253, 257, 265, 267, 290, 292, 298, 310, 314, 317, 321, 325, 332, 340, 341, 344, 345, 354, 356, 363, 372, 381, 383, 387, 392, 395, 402, 410, 418, 426, 427, 436, 439, 442, 444, 447, 455, 473, 477, 483, 484, 502, 507, 513, 521, 527, 529, 532, 537, 538, 541, 543, 547, 558, 564, 569, 574, 584, 603, 607, 614, 617, 621, 628, 635, 649, 669, 672, 687, 695, 701, 706, 708, 709, 738, 745, 751, 766, 767, 771, 779, 780, 781, 782, 784, 786, 790, 791, 797, 798, 803, 804, 805, 817, 821, 835, 836, 847, 858, 859, 868, 883, 887, 898, 900, 901, 922, 923, 925, 930, 933, 935, 946, 948, 952, 953, 963, 964, 987, 991, 1021, 1028, 1032, 10