# Overview & Current Questions
This is for validating my clustering idea:
1. Store SAE activations for all prompts
2. Preprocess activations based on entropy and activation level (optional?)
3. Cluster activations using UMAP and HDBSCAN
4. Analyze clusters

Some questions: 
1. How does entropy, activation preprocessing affect clustering? How does n_prompts, length, affect this step?
2. How useful is UMAP? What are the best parameters?
3. How does HDBSCAN perform? What are the best parameters?
4. How does the clustering change when we use different datasets?
5. What commonalities do clusters have? Does this vary by dataset? Hyperparameter?

### Imports, config, model setup

In [1]:
%load_ext autoreload
%autoreload 2

from sae_lens import SAE, HookedSAETransformer
import torch
import gc
from config import config # cfg auto updates

import random
from datasets import load_dataset
import os

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

import numpy as np

from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
import umap.umap_ as umap
import plotly.graph_objects as go
from einops import rearrange
import hdbscan

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HookedSAETransformer.from_pretrained("EleutherAI/pythia-70m-deduped", device=device)
sae, _, _ = SAE.from_pretrained(
    release="pythia-70m-deduped-mlp-sm",
    sae_id="blocks.3.hook_mlp_out",
    device=device
)


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


# Collecting activations

In [3]:
def should_invalidate_cache(cached_data, n_prompts):
    """Check if cache should be invalidated based on n_prompts."""
    if cached_data is None:
        return True
    return cached_data['acts'].shape[0] < n_prompts * 0.85

def load_diverse_prompts(config):
    """Load diverse prompts from multiple sources for maximum SAE latent activation."""
    from datasets import load_dataset
    import random
    
    n_prompts = config.get('n_prompts', 500)
    prompts = []
    
    # Define datasets with correct configurations
    datasets = [
        ('wikipedia', '20220301.en', 'train'),  # Wikipedia articles
        ('c4', 'en', 'train'),                  # Web text
        ('bookcorpus', None, 'train'),          # Books
        ('wikitext', 'wikitext-103-raw-v1', 'train')  # More wiki text
    ]
    
    prompts_per_dataset = n_prompts // len(datasets)
    
    for name, subset, split in datasets:
        try:
            print(f"Loading prompts from {name}...")
            if subset:
                dataset = load_dataset(name, subset, split=split, streaming=True)
            else:
                dataset = load_dataset(name, split=split, streaming=True)
            
            dataset_prompts = []
            text_key = 'text' if 'text' in next(iter(dataset)) else 'content'
            
            for item in dataset:
                text = item[text_key]
                if isinstance(text, str) and len(text.strip()) >= config.get('min_prompt_length', 10):
                    dataset_prompts.append(text.strip())
                if len(dataset_prompts) >= prompts_per_dataset * 1.2:  # Get extra for deduplication
                    break
            
            prompts.extend(dataset_prompts[:prompts_per_dataset])
            print(f"Added {len(dataset_prompts[:prompts_per_dataset])} prompts from {name}")
            
        except Exception as e:
            print(f"Error with {name}: {e}")
            continue
    
    # Deduplicate and shuffle
    prompts = list(set(prompts))
    random.shuffle(prompts)
    prompts = prompts[:n_prompts]
    
    print(f"\nTotal unique prompts loaded: {len(prompts)}")
    
    # Add diagnostics about prompt diversity
    print("\nPrompt statistics:")
    lengths = [len(p.split()) for p in prompts]
    print(f"Average words per prompt: {sum(lengths)/len(lengths):.1f}")
    print(f"Min length: {min(lengths)}, Max length: {max(lengths)}")
    
    return prompts

def get_cache_filename(config, n_prompts):
    """Generate a cache filename based on hierarchical config parameters."""
    # Create a unique filename based on key parameters
    
    params = [
        f"prompts_{n_prompts}",
    ]
    print(params)
    
    # Create directory if it doesn't exist
    cache_dir = config.get('cache_dir', 'cache')
    os.makedirs(cache_dir, exist_ok=True)
    
    return os.path.join(cache_dir, f"processed_data_{'_'.join(params)}.pt") 

def load_processed_data(filename):
    """Load processed data from disk."""
    if os.path.exists(filename):
        print(f"Loading processed data from {filename}...")
        try:
            data = torch.load(filename)
            print("Data loaded successfully.")
            return data
        except Exception as e:
            print(f"Error loading data: {e}")
            return None
    else:
        print(f"No cached data found at {filename}")
        return None


In [4]:
def clear_cache():
    """Clear CUDA cache to free memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

def create_embed_hook(P):
    def hook(value, hook):
        return P.unsqueeze(0)
    return hook

def get_feature_activations(model, sae, tokens, P=None):
    hooks = [('hook_embed', create_embed_hook(P))] if P is not None else []
    with model.hooks(fwd_hooks=hooks):
        _, cache = model.run_with_cache_with_saes(
            tokens, 
            saes=[sae],
            names_filter=lambda name: name == 'blocks.3.hook_mlp_out.hook_sae_acts_post'  # Only cache what you need
        )
    return cache['blocks.3.hook_mlp_out.hook_sae_acts_post']

def get_model_activations(model, tokens):
    _, cache = model.run_with_cache(
        tokens,
        names_filter=lambda name: name == 'blocks.3.hook_mlp_out'  # Only cache what you need
    )
    return cache['blocks.3.hook_mlp_out']

In [5]:
def collect_activations(model, sae, prompts, config):
    """Collect feature activations from prompts."""
    all_acts = []
    batch_size = config.get('batch_size', 10)
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]
        print(f"Processing batch {i // batch_size + 1}/{len(prompts) // batch_size + 1}...")
        batch_acts = []
        for prompt in batch_prompts:
            try:
                tokens = model.to_tokens(prompt)
                #acts = get_feature_activations(model, sae, tokens)
                acts = get_model_activations(model, tokens)
                batch_acts.append(acts.mean(dim=1).squeeze(0)) # dont care about position specific information
                clear_cache()
            except Exception as e:
                print(f"Skipping prompt '{e}")
        
        if batch_acts:
            all_acts.extend(batch_acts)
            
    acts = torch.stack(all_acts)
    print(f"Collected activations for {acts.shape[0]} prompts, {acts.shape[1]} features")
    return acts

In [6]:
# Determine whether to use cached data 
n_prompts = config['n_prompts']
use_cached_data = config['use_cached_data']
cache_filename = get_cache_filename(config, n_prompts)
cached_data = load_processed_data(cache_filename)


if use_cached_data and cached_data is not None and not should_invalidate_cache(cached_data, n_prompts):
    print(f"Using cached data with {cached_data['acts'].shape[0]} prompts")
    acts = cached_data['acts']
    prompts = cached_data.get('prompts', [])
else:
    print(f"Processing data from scratch with {n_prompts} prompts")
    prompts = load_diverse_prompts(config)
    acts = collect_activations(model, sae, prompts, config)
    
    # Save the new data
    torch.save({
        'acts': acts,
        'prompts': prompts
    }, cache_filename)
    print(f"Data saved successfully to {cache_filename}")

['prompts_10000']
Loading processed data from feature_cache/processed_data_prompts_10000.pt...
Data loaded successfully.
Using cached data with 9791 prompts


# Entropy and sparsity

## Filtering using entropy and sparsity

In [7]:
def get_entropy_sparsity_varentropy(acts, config, n_prompts):
    
    # Calculate entropy and sparsity
    activations = acts.abs()
    probs = activations / (activations.sum(dim=0) + 1e-10)
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=0)

    # Correct varentropy calculation
    # We want variance of entropy across prompts for each feature
    log_probs = torch.log(probs + 1e-10)
    entropy_per_prompt = -(probs * log_probs)  # Shape: [n_prompts, n_features]
    mean_entropy = entropy_per_prompt.mean(dim=0)  # Shape: [n_features]
    varentropy = torch.mean((entropy_per_prompt - mean_entropy.unsqueeze(0))**2, dim=0)  # Shape: [n_features]
    # variance of entropy is a measure of how much the entropy varies across prompts 
    # i posit that this will be useful for clustering
    
    activation_threshold = config.get('activation_threshold', 0.1)
    sparsity = (acts.abs() > activation_threshold).float().mean(dim=0)

    if config['verbose']:
        # After calculating entropy and sparsity
        print(f"\nEntropy stats:")
        print(f"Min: {entropy.min().item():.3f}")
        print(f"Max: {entropy.max().item():.3f}")
        print(f"Mean: {entropy.mean().item():.3f}")

        print(f"\nSparsity stats:")
        print(f"Min: {sparsity.min().item():.3f}")
        print(f"Max: {sparsity.max().item():.3f}")
        print(f"Mean: {sparsity.mean().item():.3f}")

        # Print the thresholds being used
        print(f"\nThresholds:")
        print(f"Entropy: [{config['entropy_threshold_low']}, {config['entropy_threshold_high']}]")
        print(f"Sparsity: [{config['sparsity_min']}, {config['sparsity_max']}]")

    return entropy, sparsity, varentropy

In [8]:
def filter_features(acts, config):
    """
    Filter features based on entropy and sparsity.
    We are looking at the feature behavior across prompts (we could modify this for tokens, or even for tokens within prompts)
    """

    n_prompts, n_features = acts.shape
    entropy, sparsity, varentropy = get_entropy_sparsity_varentropy(acts, config, n_prompts)
    
    # Apply filtering mask
    mask = (entropy > config['entropy_threshold_low']) & \
           (entropy < config['entropy_threshold_high']) & \
           (sparsity > config['sparsity_min']) & \
           (sparsity < config['sparsity_max'])
    
    print(f"Kept {mask.sum().item()} out of {n_features} features")
    return acts[:, mask], mask.nonzero(as_tuple=True)[0], entropy, sparsity, varentropy

In [9]:
filtered_acts, original_indices, entropy_data, sparsity_data, varentropy_data = filter_features(acts, config)

print(acts.shape)

# Create a DataFrame for entropy and sparsity
df = pd.DataFrame({
    'entropy': entropy_data.cpu().numpy(),
    'sparsity': sparsity_data.cpu().numpy(),
    'varentropy': varentropy_data.cpu().numpy(),
})

# px.scatter(df, x='entropy', y='sparsity').show()
# px.scatter(df, x='entropy', y='varentropy').show()
# px.scatter(df, x='sparsity', y='varentropy').show() 


Entropy stats:
Min: -0.000
Max: 8.984
Mean: 1.937

Sparsity stats:
Min: 0.000
Max: 1.000
Mean: 0.028

Thresholds:
Entropy: [0.0, 15.0]
Sparsity: [0.0, 1.0]
Kept 8012 out of 32768 features
torch.Size([9791, 32768])


## Entropy and sparsity visualization

In [10]:
def analyze_feature_quadrants(filtered_acts, n_examples=5):
    """Analyze features based on their entropy/varentropy quadrants.
    
    Args:
        acts: [n_prompts, n_features] activation tensor
        entropy: [n_features] entropy per feature
        varentropy: [n_features] variance of entropy per feature
        n_examples: number of example features to show per quadrant
    """

    filtered_entropy, filtered_sparsity, filtered_varentropy = get_entropy_sparsity_varentropy(filtered_acts, config, n_prompts)

    # Get medians for quadrant splitting
    entropy_median = filtered_entropy.median()
    varentropy_median = filtered_varentropy.median()
    
    # Create quadrant masks
    q1 = (filtered_entropy <= entropy_median) & (filtered_varentropy <= varentropy_median)  # Low E, Low VE
    q2 = (filtered_entropy > entropy_median) & (filtered_varentropy <= varentropy_median)   # High E, Low VE
    q3 = (filtered_entropy <= entropy_median) & (filtered_varentropy > varentropy_median)   # Low E, High VE
    q4 = (filtered_entropy > entropy_median) & (filtered_varentropy > varentropy_median)    # High E, High VE
    
    quadrants = {
        "Flowing (Low E, Low VE)": q1,
        "Careful (High E, Low VE)": q2,
        "Exploring (Low E, High VE)": q3,
        "Resampling (High E, High VE)": q4
    }
    
    print("Feature Distribution in Quadrants:")
    for name, mask in quadrants.items():
        n_features = mask.sum().item()
        print(f"\n{name}: {n_features} features ({n_features/len(filtered_entropy):.1%})")
        
        # Get example features from this quadrant
        feature_indices = mask.nonzero(as_tuple=True)[0]
        if len(feature_indices) > 0:
            sample_indices = feature_indices[torch.randperm(len(feature_indices))[:n_examples]]
            
            print("\nExample features:")
            for idx in sample_indices:
                # Get activation statistics for this feature
                feature_acts = acts[:, idx]
                active_prompts = (feature_acts.abs() > config['activation_threshold']).sum().item()
                max_activation = feature_acts.abs().max().item()
                
                print(f"\nFeature {idx}:")
                print(f"  Entropy: {filtered_entropy[idx]:.3f}")
                print(f"  Varentropy: {filtered_varentropy[idx]:.3f}")
                print(f"  Active in {active_prompts}/{len(acts)} prompts")
                print(f"  Max activation: {max_activation:.3f}")
                
    return quadrants, filtered_entropy, filtered_varentropy

# Create a visualization of the quadrants
def plot_entropy_quadrants(entropy, varentropy, quadrants):
    """Create a scatter plot showing feature distribution across quadrants."""
    import plotly.express as px
    import pandas as pd
    
    # Create DataFrame
    df = pd.DataFrame({
        'entropy': entropy.cpu().numpy(),
        'varentropy': varentropy.cpu().numpy(),
        'quadrant': 'Unknown'
    })
    
    # Assign quadrant labels
    for name, mask in quadrants.items():
        df.loc[mask.cpu().numpy(), 'quadrant'] = name
    
    # Create scatter plot
    fig = px.scatter(df, x='entropy', y='varentropy', color='quadrant',
                    title='Feature Distribution across Entropy/Varentropy Quadrants',
                    labels={'entropy': 'Entropy', 'varentropy': 'Variance of Entropy'})
    
    # Add median lines
    fig.add_hline(y=varentropy.median().item(), line_dash="dash", line_color="gray")
    fig.add_vline(x=entropy.median().item(), line_dash="dash", line_color="gray")
    
    fig.show()

In [11]:
quadrants, filtered_entropy, filtered_varentropy = analyze_feature_quadrants(filtered_acts)
plot_entropy_quadrants(filtered_entropy, filtered_varentropy, quadrants)


Entropy stats:
Min: 1.628
Max: 8.984
Mean: 7.538

Sparsity stats:
Min: 0.000
Max: 0.903
Mean: 0.115

Thresholds:
Entropy: [0.0, 15.0]
Sparsity: [0.0, 1.0]
Feature Distribution in Quadrants:

Flowing (Low E, Low VE): 89 features (1.1%)

Example features:

Feature 3489:
  Entropy: 7.853
  Varentropy: 0.000
  Active in 0/9791 prompts
  Max activation: 0.000

Feature 3231:
  Entropy: 7.811
  Varentropy: 0.000
  Active in 0/9791 prompts
  Max activation: 0.000

Feature 725:
  Entropy: 7.815
  Varentropy: 0.000
  Active in 0/9791 prompts
  Max activation: 0.000

Feature 269:
  Entropy: 7.878
  Varentropy: 0.000
  Active in 3139/9791 prompts
  Max activation: 0.186

Feature 451:
  Entropy: 7.874
  Varentropy: 0.000
  Active in 0/9791 prompts
  Max activation: 0.000

Careful (High E, Low VE): 3917 features (48.9%)

Example features:

Feature 339:
  Entropy: 8.749
  Varentropy: 0.000
  Active in 0/9791 prompts
  Max activation: 0.000

Feature 7060:
  Entropy: 8.393
  Varentropy: 0.000
  Active

# Correlation analysis

In [None]:
# Compute and visualize correlation matrix for filtered activations
def plot_activation_correlations(filtered_acts, max_features=2000):
    """
    Create a correlation matrix heatmap for filtered activations.
    
    Args:
        filtered_acts: Tensor of shape [n_prompts, n_features]
        max_features: Maximum number of features to include in visualization
    """
    # Convert to numpy and transpose to get [n_features, n_prompts]
    acts = filtered_acts.T.cpu().numpy()
    
    # If we have too many features, sample randomly
    if acts.shape[0] > max_features:
        indices = np.random.choice(acts.shape[0], max_features, replace=False)
        acts = acts[indices]
        print(f"Sampled {max_features} features randomly for visualization")
    
    # Compute correlation matrix
    corr_matrix = np.corrcoef(acts)
    
    # Create heatmap
    fig = go.Figure(data=go.Heatmap(
        z=corr_matrix,
        colorscale='RdBu',
        zmid=0,  # Center the colorscale at 0
        colorbar=dict(
            title="Correlation",
            titleside="right"
        )
    ))
    
    fig.update_layout(
        title="Feature Activation Correlations",
        width=800,
        height=800,
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False)
    )
    
    # Add correlation statistics
    corr_flat = corr_matrix[np.triu_indices_from(corr_matrix, k=1)]
    mean_corr = np.mean(np.abs(corr_flat))
    median_corr = np.median(np.abs(corr_flat))
    high_corr = np.mean(np.abs(corr_flat) > 0.5)
    
    print(f"\nCorrelation Statistics:")
    print(f"Mean absolute correlation: {mean_corr:.3f}")
    print(f"Median absolute correlation: {median_corr:.3f}")
    print(f"Fraction of high correlations (|r| > 0.5): {high_corr:.1%}")
    
    fig.show()

# Plot correlation matrix
plot_activation_correlations(filtered_acts)


Correlation Statistics:
Mean absolute correlation: 0.042
Median absolute correlation: 0.024
Fraction of high correlations (|r| > 0.5): 0.0%


I notice that there are some features which create a 'grid' in the data. I wonder if these are the most important features.

# Clustering

## Applying UMAP and HDBSCAN

In [12]:
def apply_umap_preprocessing(acts, preprocessing_config):
    """
    Apply UMAP dimensionality reduction if beneficial.
    
    
    """
    feature_acts = acts.T.cpu().numpy()
    normalized_acts = feature_acts / (np.linalg.norm(feature_acts, axis=1, keepdims=True) + 1e-10)
    
    n_samples, n_dims = normalized_acts.shape
    min_dims_for_reduction = 50  # Only reduce if we have more than 50 dimensions
    
    if n_dims <= min_dims_for_reduction:
        print(f"Skipping UMAP reduction - input dimensionality ({n_dims}) is already manageable")
        return normalized_acts, normalized_acts
        
    # For high-dimensional data, apply UMAP reduction
    umap_config = preprocessing_config.get('umap', {})
    target_dims = min(umap_config.get('n_components', 50), n_samples - 1)
    
    reducer = umap.UMAP(
        n_components=target_dims,
        n_neighbors=min(umap_config.get('n_neighbors', 15), n_samples - 1),
        min_dist=umap_config.get('min_dist', 0.1),
        metric=umap_config.get('metric', 'cosine'),
        random_state=42
    )
    
    try:
        reduced_acts = reducer.fit_transform(normalized_acts)
        print(f"Applied UMAP reduction: {n_dims}d → {target_dims}d")
        return reduced_acts, normalized_acts
    except Exception as e:
        print(f"Error during UMAP reduction: {e}")
        return normalized_acts, normalized_acts


def run_hdbscan(reduced_acts, clustering_config):
    """Run HDBSCAN clustering."""
    hdbscan_config = clustering_config.get('hdbscan', {})
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=hdbscan_config.get('min_cluster_size', 5),
        min_samples=hdbscan_config.get('min_samples', 1),
        metric=hdbscan_config.get('metric', 'euclidean'),
        cluster_selection_epsilon=hdbscan_config.get('cluster_selection_epsilon', 0.0)
    )
    return clusterer.fit_predict(reduced_acts)

def cluster_features(acts, clustering_config):
    """Cluster features using UMAP preprocessing (if needed) and HDBSCAN."""
    if acts.shape[1] <= 1:
        return np.zeros(acts.shape[1], dtype=int), acts.T.cpu().numpy()
    
    # Preprocess with UMAP if enabled and beneficial
    if clustering_config.get('use_umap_preprocessing', True):
        reduced_acts, normalized_acts = apply_umap_preprocessing(acts, clustering_config)
    else:
        feature_acts = acts.T.cpu().numpy()
        normalized_acts = feature_acts / (np.linalg.norm(feature_acts, axis=1, keepdims=True) + 1e-10)
        reduced_acts = normalized_acts
    
    # Run HDBSCAN clustering
    labels = run_hdbscan(reduced_acts, clustering_config)
    
    n_clusters = len(np.unique(labels)) - (1 if -1 in labels else 0)
    noise_ratio = np.mean(labels == -1) if -1 in labels else 0
    print(f"Clustering complete: found {n_clusters} clusters with {noise_ratio:.2%} noise points")
    
    return labels, normalized_acts

In [13]:
labels, reduced_acts = cluster_features(filtered_acts, config)


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


The TBB threading layer requires TBB version 2021 update 6 or later i.e., TBB_INTERFACE_VERSION >= 12060. Found TBB_INTERFACE_VERSION = 12050. The TBB threading layer is disabled.



Applied UMAP reduction: 9791d → 50d



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



Clustering complete: found 479 clusters with 31.42% noise points


## Cluster visualization

In [None]:
# You can also look at correlations within specific clusters
def plot_cluster_correlations(filtered_acts, labels, cluster_id, max_features=1000):
    """Plot correlation matrix for features within a specific cluster"""
    cluster_mask = labels == cluster_id
    cluster_acts = filtered_acts[:, cluster_mask]
    
    print(f"\nAnalyzing Cluster {cluster_id}")
    print(f"Number of features: {cluster_acts.shape[1]}")
    
    plot_activation_correlations(cluster_acts, max_features)

# Example: Plot correlations for a specific cluster
# Replace 0 with any cluster ID you're interested in
plot_cluster_correlations(filtered_acts, labels, 0)

In [14]:
# Track results for each cluster
cluster_analysis = {}

# Get unique valid cluster labels (including noise points)
all_labels = np.unique(labels)
unique_labels = [label for label in all_labels if label != -1]

# Calculate global statistics
total_features = len(labels)
noise_points = np.sum(labels == -1)
noise_ratio = noise_points / total_features if total_features > 0 else 0

print(f"\n=== Cluster Analysis Summary ===")
print(f"Total Features: {total_features}")
print(f"Number of Clusters: {len(unique_labels)}")
print(f"Noise Points: {noise_points} ({noise_ratio:.2%})")

# For each cluster (including noise)
for label in all_labels:
    cluster_type = "Noise Cluster" if label == -1 else f"Cluster {label}"
    cluster_mask = labels == label
    cluster_size = np.sum(cluster_mask)
    cluster_indices = original_indices[cluster_mask]
    
    print(f"\n=== {cluster_type} ===")
    print(f"Size: {cluster_size} features ({(cluster_size/total_features):.2%} of total)")
    print(f"Feature Indices: {cluster_indices.tolist()}")
    
    if label == -1:  # Skip activation analysis for noise cluster
        cluster_analysis[label] = {
            'size': cluster_size,
            'indices': cluster_indices.tolist(),
            'ratio': cluster_size/total_features,
            'is_noise': True
        }
        continue
        
    # Get activations for all prompts on this cluster's features
    print(reduced_acts.shape)
    print(cluster_mask.shape)
    cluster_acts = filtered_acts[:, cluster_mask]  # [n_prompts, n_cluster_features]
    
    # Compute activation statistics
    mean_activation = torch.mean(torch.abs(cluster_acts)).item()
    max_activation = torch.max(torch.abs(cluster_acts)).item()
    std_activation = torch.std(torch.abs(cluster_acts)).item()
    sparsity = (torch.abs(cluster_acts) > 0.1).float().mean().item()
    
    print(f"\nActivation Statistics:")
    print(f"  Mean Activation: {mean_activation:.4f}")
    print(f"  Max Activation: {max_activation:.4f}")
    print(f"  Std Deviation: {std_activation:.4f}")
    print(f"  Sparsity: {sparsity:.4f}")
    
    # Compute average activation of each prompt on this cluster's features
    prompt_activations = torch.mean(torch.abs(cluster_acts), dim=1)  # [n_prompts]
    
    # Find top activating prompts
    top_k = min(5, len(prompts))
    top_prompt_indices = torch.argsort(prompt_activations, descending=True)[:top_k]
    top_prompts = [(prompts[i], prompt_activations[i].item()) for i in top_prompt_indices]
    
    print(f"\nTop {top_k} Activating Prompts:")
    for i, (prompt, act) in enumerate(top_prompts, 1):
        # Show first 100 chars, with special formatting for section headers
        truncated = prompt[:100]
        if len(prompt) > 100:
            truncated += "..."
        
        # Format section headers more clearly
        if "=" in truncated:
            sections = [s.strip() for s in truncated.split("=") if s.strip()]
            if sections:
                truncated = f"[SECTION] {' > '.join(sections)}"
        
        # Split into tokens if the prompt contains spaces
        tokens = truncated.split()
        if len(tokens) > 15:
            token_display = " ".join(tokens[:15]) + " ..."
        else:
            token_display = truncated
            
        print(f"  {i}. \"{token_display}\"")
        print(f"     Activation: {act:.4f}")
        print(f"     Total Length: {len(prompt)} chars, {len(prompt.split())} tokens")
    # Store detailed results
    cluster_analysis[label] = {
        'size': cluster_size,
        'indices': cluster_indices.tolist(),
        'ratio': cluster_size/total_features,
        'is_noise': False,
        'mean_activation': mean_activation,
        'max_activation': max_activation,
        'std_activation': std_activation,
        'sparsity': sparsity,
        'top_prompts': top_prompts,
        'prompt_activations': prompt_activations.cpu().numpy()
    }


# Print cluster similarity analysis
if len(unique_labels) > 1:
    print("\n=== Cluster Similarity Analysis ===")
    for i, label1 in enumerate(unique_labels):
        for label2 in unique_labels[i+1:]:
            acts1 = cluster_analysis[label1]['prompt_activations']
            acts2 = cluster_analysis[label2]['prompt_activations']
            correlation = np.corrcoef(acts1, acts2)[0, 1]
            if abs(correlation) > 0.5:  # Only show significant correlations
                print(f"Clusters {label1} and {label2}: correlation = {correlation:.4f}")



=== Cluster Analysis Summary ===
Total Features: 8012
Number of Clusters: 479
Noise Points: 2517 (31.42%)

=== Noise Cluster ===
Size: 2517 features (31.42% of total)
Feature Indices: [4, 5, 14, 35, 37, 42, 44, 48, 94, 125, 128, 154, 197, 213, 216, 220, 222, 227, 254, 257, 265, 296, 310, 332, 340, 341, 354, 374, 375, 378, 392, 395, 436, 439, 447, 448, 453, 455, 477, 483, 527, 538, 543, 544, 561, 569, 578, 584, 607, 617, 621, 628, 649, 655, 676, 709, 722, 751, 784, 804, 805, 830, 845, 852, 869, 887, 900, 933, 937, 947, 952, 987, 995, 1047, 1053, 1066, 1076, 1082, 1120, 1151, 1165, 1189, 1209, 1211, 1255, 1285, 1292, 1296, 1321, 1344, 1365, 1367, 1391, 1397, 1420, 1427, 1435, 1440, 1442, 1452, 1471, 1477, 1507, 1520, 1524, 1533, 1534, 1581, 1609, 1610, 1618, 1624, 1640, 1643, 1669, 1671, 1698, 1699, 1719, 1732, 1778, 1780, 1781, 1787, 1809, 1843, 1846, 1878, 1881, 1902, 1903, 1915, 1918, 1919, 1932, 1941, 1944, 1976, 1979, 1998, 2037, 2103, 2106, 2113, 2144, 2164, 2187, 2193, 2203, 2206

## Cluster scoring