# Overview & Current Questions
This is for validating my clustering idea:
1. Store SAE activations for all prompts
2. Preprocess activations based on entropy and activation level (optional?)
3. Cluster activations using UMAP and HDBSCAN
4. Analyze clusters

Some questions: 
1. How do entropy and activation preprocessing affect clustering? How do n_prompts and prompt length affect this step?
2. How useful is UMAP? What are the best parameters?
3. How does HDBSCAN perform? What are the best parameters?
4. How does the clustering change when we use different datasets?
5. What commonalities do clusters have? Does this vary by dataset or by hyperparameters?

### Imports, config, model setup

In [1]:
%load_ext autoreload
%autoreload 2

from sae_lens import SAE, HookedSAETransformer
import torch
import gc
from config import config # cfg auto updates

import random
from datasets import load_dataset
import os

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

import numpy as np

from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
import umap.umap_ as umap
import plotly.graph_objects as go
from einops import rearrange
import hdbscan

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Prefer the GPU when one is available; the model and the SAE are both placed on this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the "EleutherAI/pythia-70m-deduped" checkpoint through sae_lens' HookedSAETransformer.
model = HookedSAETransformer.from_pretrained("EleutherAI/pythia-70m-deduped", device=device)
# Load the pretrained SAE attached to `blocks.3.hook_mlp_out`. Only the first of the
# three values returned by `SAE.from_pretrained` is kept; the other two are discarded.
sae, _, _ = SAE.from_pretrained(
    release="pythia-70m-deduped-mlp-sm",
    sae_id="blocks.3.hook_mlp_out",
    device=device
)


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


## Loading Activations from get_data.ipynb

In [None]:
def load_processed_data(filename):
    """Return the object stored at ``filename`` via ``torch.load``, or ``None``.

    ``None`` is returned both when the file does not exist and when loading
    raises; in each case a short message is printed instead of propagating.
    """
    if not os.path.exists(filename):
        print(f"No cached data found at {filename}")
        return None

    print(f"Loading processed data from {filename}...")
    try:
        loaded = torch.load(filename)
    except Exception as e:
        # Best-effort cache: report the problem and let the caller decide what to do.
        print(f"Error loading data: {e}")
        return None

    print("Data loaded successfully.")
    return loaded
    
def get_cache_filename(config, n_prompts):
    """Build the cache file path for a run with ``n_prompts`` prompts.

    The directory is ``config['cache_dir']`` (default ``'cache'``) and is created
    if it does not exist yet. The list of filename components is printed as a
    side effect.
    """
    # Components that make the filename unique; currently only the prompt count.
    name_parts = [f"prompts_{n_prompts}"]
    print(name_parts)

    # Make sure the target directory exists before handing back a path inside it.
    directory = config.get('cache_dir', 'cache')
    os.makedirs(directory, exist_ok=True)

    suffix = '_'.join(name_parts)
    return os.path.join(directory, f"processed_data_{suffix}.pt")

In [6]:
# Load the activations cached for the configured number of prompts.
n_prompts = config['n_prompts']
use_cached_data = config['use_cached_data']
cache_filename = get_cache_filename(config, n_prompts)
cached_data = load_processed_data(cache_filename)

# load_processed_data returns None when the cache file is missing or unreadable;
# fail here with an actionable message instead of a TypeError on the subscript below.
if cached_data is None:
    raise RuntimeError(
        f"No usable cached activations at {cache_filename}; "
        "generate them first (see get_data.ipynb)."
    )

acts = cached_data['acts']
# Prompts are optional in the cache; fall back to an empty list when absent.
prompts = cached_data.get('prompts', [])


['prompts_25000']
No cached data found at feature_cache/processed_data_prompts_25000.pt
True
Processing data from scratch with 25000 prompts
Loading prompts from wikipedia...


KeyboardInterrupt: 

# Entropy and sparsity

## Filtering using entropy and sparsity

In [32]:
def get_entropy_sparsity_varentropy(acts, config, n_prompts):
    """Compute per-feature entropy, sparsity and varentropy, reducing over dim 0.

    Args:
        acts: activation tensor; callers in this notebook pass [n_prompts, n_features].
        config: mapping; ``activation_threshold`` (default 0.1) is the sparsity cutoff.
        n_prompts: accepted for interface compatibility; not used in the computation.

    Returns:
        Tuple ``(entropy, sparsity, varentropy)``, each reduced over dim 0.
    """
    eps = 1e-10
    magnitudes = acts.abs()

    # Normalise each column's magnitudes into a distribution over dim 0.
    probs = magnitudes / (magnitudes.sum(dim=0) + eps)
    log_probs = torch.log(probs + eps)

    # Entropy is the summed -p*log(p) over dim 0.
    entropy = -torch.sum(probs * log_probs, dim=0)

    # Varentropy here is the population variance of the individual -p*log(p)
    # terms over dim 0 (mean of squared deviations from their mean).
    entropy_terms = -(probs * log_probs)
    term_mean = entropy_terms.mean(dim=0)
    deviations = entropy_terms - term_mean.unsqueeze(0)
    varentropy = (deviations ** 2).mean(dim=0)

    # "Sparsity" is the fraction of rows whose magnitude exceeds the threshold.
    threshold = config.get('activation_threshold', 0.1)
    sparsity = (magnitudes > threshold).float().mean(dim=0)

    return entropy, sparsity, varentropy

In [33]:
def filter_features(acts, config):
    """
    Keep the features whose entropy and sparsity lie strictly inside the configured bands.

    The statistics describe each feature's behaviour across prompts. Bounds are read
    from ``config``: ``entropy_threshold_low`` / ``entropy_threshold_high`` and
    ``sparsity_min`` / ``sparsity_max``.

    Returns:
        ``(kept_acts, kept_indices, entropy, sparsity, varentropy)`` where
        ``kept_acts`` is ``acts`` restricted to the kept columns, ``kept_indices``
        holds those columns' positions in ``acts``, and the three statistics cover
        all features, not just the kept ones.
    """
    n_prompts, n_features = acts.shape
    entropy, sparsity, varentropy = get_entropy_sparsity_varentropy(acts, config, n_prompts)

    # Both bands are open intervals: values equal to a bound are rejected.
    entropy_ok = (entropy > config['entropy_threshold_low']) & (entropy < config['entropy_threshold_high'])
    sparsity_ok = (sparsity > config['sparsity_min']) & (sparsity < config['sparsity_max'])
    keep = entropy_ok & sparsity_ok

    print(f"Kept {keep.sum().item()} out of {n_features} features")

    kept_indices = keep.nonzero(as_tuple=True)[0]
    return acts[:, keep], kept_indices, entropy, sparsity, varentropy

In [None]:
# Filter the features and keep the per-feature statistics for plotting.
filtered_acts, original_indices, entropy_data, sparsity_data, varentropy_data = filter_features(acts, config)
# Report the shape of the filtered tensor; `acts` itself is not modified by filtering.
print(f' after filtering acts.shape: {filtered_acts.shape}')

if config['visualize_features']:
    # One row per (unfiltered) feature, one column per statistic.
    df = pd.DataFrame({
        'entropy': entropy_data.cpu().numpy(),
        'sparsity': sparsity_data.cpu().numpy(),
        'varentropy': varentropy_data.cpu().numpy(),
    })

    # Pairwise scatter plots of the three statistics.
    px.scatter(df, x='entropy', y='sparsity').show()
    px.scatter(df, x='entropy', y='varentropy').show()
    px.scatter(df, x='sparsity', y='varentropy').show()

## Entropy and sparsity visualization

In [10]:
def analyze_feature_quadrants(filtered_acts, n_examples=5, verbose=False):
    """Split features into four quadrants by median entropy and median varentropy.

    Reads the module-level ``config`` (forwarded to
    ``get_entropy_sparsity_varentropy`` and used for the activation threshold in
    the verbose report).

    Args:
        filtered_acts: [n_prompts, n_features] activation tensor (already filtered).
        n_examples: number of randomly sampled example features reported per
            quadrant when ``verbose`` is true.
        verbose: when true, print quadrant sizes and statistics for the sampled
            example features; when false nothing is printed or sampled.

    Returns:
        ``(quadrants, filtered_entropy, filtered_varentropy)`` where ``quadrants``
        maps a quadrant name to a boolean mask over the columns of ``filtered_acts``.
    """
    # Use the row count of the tensor actually passed in rather than a notebook global.
    filtered_entropy, _, filtered_varentropy = get_entropy_sparsity_varentropy(
        filtered_acts, config, filtered_acts.shape[0]
    )

    # Get medians for quadrant splitting
    entropy_median = filtered_entropy.median()
    varentropy_median = filtered_varentropy.median()

    # Values equal to a median fall on the "low" side.
    quadrants = {
        "Flowing (Low E, Low VE)":
            (filtered_entropy <= entropy_median) & (filtered_varentropy <= varentropy_median),
        "Careful (High E, Low VE)":
            (filtered_entropy > entropy_median) & (filtered_varentropy <= varentropy_median),
        "Exploring (Low E, High VE)":
            (filtered_entropy <= entropy_median) & (filtered_varentropy > varentropy_median),
        "Resampling (High E, High VE)":
            (filtered_entropy > entropy_median) & (filtered_varentropy > varentropy_median),
    }

    if verbose:
        # Same default as get_entropy_sparsity_varentropy uses for its sparsity cutoff.
        threshold = config.get('activation_threshold', 0.1)
        total = len(filtered_entropy)
        n_rows = filtered_acts.shape[0]

        print("Feature Distribution in Quadrants:")
        for name, mask in quadrants.items():
            feature_indices = mask.nonzero(as_tuple=True)[0]
            count = len(feature_indices)
            share = count / total if total else 0.0
            print(f"\n{name}: {count} features ({share:.1%})")
            if count == 0:
                continue

            # Random sample of up to n_examples features from this quadrant.
            sample = feature_indices[torch.randperm(count)[:n_examples]]
            print("\nExample features:")
            for idx in sample.tolist():
                # idx is a column of filtered_acts, so the statistics must be
                # read from filtered_acts and not from the unfiltered tensor.
                magnitudes = filtered_acts[:, idx].abs()
                active_prompts = (magnitudes > threshold).sum().item()
                max_activation = magnitudes.max().item()

                print(f"\nFeature {idx}:")
                print(f"  Entropy: {filtered_entropy[idx].item():.3f}")
                print(f"  Varentropy: {filtered_varentropy[idx].item():.3f}")
                print(f"  Active in {active_prompts}/{n_rows} prompts")
                print(f"  Max activation: {max_activation:.3f}")

    return quadrants, filtered_entropy, filtered_varentropy

# Create a visualization of the quadrants
def plot_entropy_quadrants(entropy, varentropy, quadrants):
    """Scatter entropy against varentropy, coloured by quadrant, with median guide lines.

    ``quadrants`` maps a label to a boolean mask over the features; features not
    covered by any mask keep the label ``'Unknown'``.
    """
    import plotly.express as px
    import pandas as pd

    # Every feature starts as 'Unknown' and is relabelled by whichever mask covers it.
    frame = pd.DataFrame({
        'entropy': entropy.cpu().numpy(),
        'varentropy': varentropy.cpu().numpy(),
        'quadrant': 'Unknown'
    })
    for label, member_mask in quadrants.items():
        rows = member_mask.cpu().numpy()
        frame.loc[rows, 'quadrant'] = label

    axis_labels = {'entropy': 'Entropy', 'varentropy': 'Variance of Entropy'}
    fig = px.scatter(
        frame,
        x='entropy',
        y='varentropy',
        color='quadrant',
        title='Feature Distribution across Entropy/Varentropy Quadrants',
        labels=axis_labels,
    )

    # Dashed grey guides at the two medians.
    guide_style = dict(line_dash="dash", line_color="gray")
    fig.add_hline(y=varentropy.median().item(), **guide_style)
    fig.add_vline(x=entropy.median().item(), **guide_style)

    fig.show()

In [None]:
# Compute the quadrant masks and per-feature statistics for the filtered features,
# then plot the features coloured by quadrant.
quadrants, filtered_entropy, filtered_varentropy = analyze_feature_quadrants(filtered_acts)
plot_entropy_quadrants(filtered_entropy, filtered_varentropy, quadrants)

# Clustering

## Applying UMAP and HDBSCAN

In [12]:
def apply_umap_preprocessing(acts, preprocessing_config):
    """
    L2-normalise each feature's activation vector and, when that vector has more
    than 50 dimensions, project the features with UMAP.

    Args:
        acts: activation tensor; it is transposed so that each row is one feature.
        preprocessing_config: mapping whose optional ``'umap'`` entry may supply
            ``n_components`` (default 50), ``n_neighbors`` (15), ``min_dist`` (0.1)
            and ``metric`` ('cosine').

    Returns:
        ``(reduced, normalized)``. ``reduced`` is the normalised matrix itself when
        the reduction is skipped or when UMAP raises.
    """
    per_feature = acts.T.cpu().numpy()
    row_norms = np.linalg.norm(per_feature, axis=1, keepdims=True)
    normalized_acts = per_feature / (row_norms + 1e-10)

    n_samples, n_dims = normalized_acts.shape
    max_dims_without_reduction = 50  # at or below this, UMAP is not applied

    if n_dims <= max_dims_without_reduction:
        print(f"Skipping UMAP reduction - input dimensionality ({n_dims}) is already manageable")
        return normalized_acts, normalized_acts

    # Both the output dimensionality and the neighbourhood size are capped at n_samples - 1.
    settings = preprocessing_config.get('umap', {})
    sample_cap = n_samples - 1
    target_dims = min(settings.get('n_components', 50), sample_cap)

    reducer = umap.UMAP(
        n_components=target_dims,
        n_neighbors=min(settings.get('n_neighbors', 15), sample_cap),
        min_dist=settings.get('min_dist', 0.1),
        metric=settings.get('metric', 'cosine'),
        random_state=42,
    )

    try:
        reduced_acts = reducer.fit_transform(normalized_acts)
    except Exception as e:
        # Best-effort: fall back to the normalised activations if UMAP fails.
        print(f"Error during UMAP reduction: {e}")
        return normalized_acts, normalized_acts

    print(f"Applied UMAP reduction: {n_dims}d → {target_dims}d")
    return reduced_acts, normalized_acts


def run_hdbscan(reduced_acts, clustering_config):
    """Fit HDBSCAN on ``reduced_acts`` and return the label array from ``fit_predict``.

    Parameters are taken from ``clustering_config['hdbscan']`` when present,
    otherwise the defaults listed below are used.
    """
    defaults = {
        'min_cluster_size': 5,
        'min_samples': 1,
        'metric': 'euclidean',
        'cluster_selection_epsilon': 0.0,
    }
    overrides = clustering_config.get('hdbscan', {})
    params = {key: overrides.get(key, fallback) for key, fallback in defaults.items()}

    return hdbscan.HDBSCAN(**params).fit_predict(reduced_acts)

def cluster_features(acts, clustering_config):
    """Cluster the columns of ``acts`` with HDBSCAN, optionally after UMAP.

    Args:
        acts: activation tensor whose columns are the features to cluster.
        clustering_config: mapping; ``use_umap_preprocessing`` (default True)
            selects the UMAP path, and the mapping is forwarded to the helpers.

    Returns:
        ``(labels, normalized_acts)``. Note that the second value is the
        L2-normalised per-feature matrix, not the UMAP projection. With at most
        one feature, all-zero labels and the transposed raw activations are returned.
    """
    n_features = acts.shape[1]
    if n_features <= 1:
        return np.zeros(n_features, dtype=int), acts.T.cpu().numpy()

    if not clustering_config.get('use_umap_preprocessing', True):
        # No UMAP: cluster directly on the row-normalised activations.
        per_feature = acts.T.cpu().numpy()
        normalized_acts = per_feature / (np.linalg.norm(per_feature, axis=1, keepdims=True) + 1e-10)
        reduced_acts = normalized_acts
    else:
        reduced_acts, normalized_acts = apply_umap_preprocessing(acts, clustering_config)

    labels = run_hdbscan(reduced_acts, clustering_config)

    # Label -1 marks noise and is not counted as a cluster.
    is_noise = labels == -1
    n_clusters = len(np.unique(labels[~is_noise]))
    noise_ratio = np.mean(is_noise) if is_noise.any() else 0
    print(f"Clustering complete: found {n_clusters} clusters with {noise_ratio:.2%} noise points")

    return labels, normalized_acts

In [None]:
# NOTE(review): the second value returned by cluster_features is its `normalized_acts`
# (the per-feature matrix that was clustered or fed to UMAP), not the UMAP output,
# despite the name it is bound to here.
labels, reduced_acts = cluster_features(filtered_acts, config)

## Cluster correlation heatmap

In [None]:
# Compute and visualize correlation matrix for filtered activations
def plot_activation_correlations(filtered_acts, max_features=1000):
    """
    Create a correlation matrix heatmap for filtered activations and print
    summary statistics of the absolute pairwise correlations.

    Args:
        filtered_acts: Tensor of shape [n_prompts, n_features]
        max_features: Maximum number of features to include in visualization;
            when exceeded, a random subset of this size is drawn without replacement.

    Returns:
        None. Prints a notice and returns early when fewer than two features
        are available, since no pairwise correlation exists in that case.
    """
    # Convert to numpy and transpose to get [n_features, n_prompts]
    acts = filtered_acts.T.cpu().numpy()

    # If we have too many features, sample randomly
    if acts.shape[0] > max_features:
        indices = np.random.choice(acts.shape[0], max_features, replace=False)
        acts = acts[indices]
        print(f"Sampled {max_features} features randomly for visualization")

    # np.corrcoef returns a scalar for a single row, and an empty selection
    # (e.g. a cluster ID that does not exist) has nothing to correlate.
    if acts.shape[0] < 2:
        print(f"Need at least 2 features to plot correlations (got {acts.shape[0]}); skipping")
        return

    # Compute correlation matrix
    corr_matrix = np.corrcoef(acts)

    # Create heatmap. The colorbar title uses the nested `title` dict form because
    # the flat `titleside` attribute is deprecated and rejected by newer Plotly releases.
    fig = go.Figure(data=go.Heatmap(
        z=corr_matrix,
        colorscale='RdBu',
        zmid=0,  # Center the colorscale at 0
        colorbar=dict(
            title=dict(text="Correlation", side="right")
        )
    ))

    fig.update_layout(
        title="Feature Activation Correlations",
        width=800,
        height=800,
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False)
    )

    # Statistics over the strict upper triangle, i.e. each unordered pair once.
    abs_corr = np.abs(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
    mean_corr = np.mean(abs_corr)
    median_corr = np.median(abs_corr)
    high_corr = np.mean(abs_corr > 0.5)

    print(f"\nCorrelation Statistics:")
    print(f"Mean absolute correlation: {mean_corr:.3f}")
    print(f"Median absolute correlation: {median_corr:.3f}")
    print(f"Fraction of high correlations (|r| > 0.5): {high_corr:.1%}")

    fig.show()

# You can also look at correlations within specific clusters
def plot_cluster_correlations(filtered_acts, labels, cluster_id, max_features=1000):
    """Show the correlation heatmap restricted to the features labelled ``cluster_id``.

    ``labels`` is compared element-wise against ``cluster_id`` and the result is
    used to select columns of ``filtered_acts``.
    """
    in_cluster = labels == cluster_id
    member_acts = filtered_acts[:, in_cluster]
    n_members = member_acts.shape[1]

    print(f"\nAnalyzing Cluster {cluster_id}")
    print(f"Number of features: {n_members}")

    plot_activation_correlations(member_acts, max_features)

# Example: plot correlations for a specific cluster.
# Replace 20 with any cluster ID that is present in `labels`.
plot_cluster_correlations(filtered_acts, labels, 20)

## Cluster prompt analysis

In [None]:
# Per-cluster report: for every label produced by the clustering step, print size,
# activation statistics and the top activating prompts, and collect the same
# information in `cluster_analysis` keyed by cluster label.
cluster_analysis = {}

# `all_labels` includes the noise label (-1) when present; `unique_labels` excludes it.
all_labels = np.unique(labels)
unique_labels = [label for label in all_labels if label != -1]

# Calculate global statistics
total_features = len(labels)
noise_points = np.sum(labels == -1)
noise_ratio = noise_points / total_features if total_features > 0 else 0

print(f"\n=== Cluster Analysis Summary ===")
print(f"Total Features: {total_features}")
print(f"Number of Clusters: {len(unique_labels)}")
print(f"Noise Points: {noise_points} ({noise_ratio:.2%})")

# For each cluster (including noise)
for label in all_labels:
    cluster_type = "Noise Cluster" if label == -1 else f"Cluster {label}"
    # Boolean mask over the filtered features; `original_indices` maps them back
    # to their positions in the unfiltered activation tensor.
    cluster_mask = labels == label
    cluster_size = np.sum(cluster_mask)
    cluster_indices = original_indices[cluster_mask]
    
    print(f"\n=== {cluster_type} ===")
    print(f"Size: {cluster_size} features ({(cluster_size/total_features):.2%} of total)")
    print(f"Feature Indices: {cluster_indices.tolist()}")
    
    if label == -1:  # Skip activation analysis for noise cluster
        # Noise entries carry no activation statistics and no 'prompt_activations'.
        cluster_analysis[label] = {
            'size': cluster_size,
            'indices': cluster_indices.tolist(),
            'ratio': cluster_size/total_features,
            'is_noise': True
        }
        continue
        
    # Get activations for all prompts on this cluster's features
    # NOTE(review): the two shape prints below look like leftover debugging output.
    print(reduced_acts.shape)
    print(cluster_mask.shape)
    cluster_acts = filtered_acts[:, cluster_mask]  # [n_prompts, n_cluster_features]
    
    # Compute activation statistics
    mean_activation = torch.mean(torch.abs(cluster_acts)).item()
    max_activation = torch.max(torch.abs(cluster_acts)).item()
    std_activation = torch.std(torch.abs(cluster_acts)).item()
    # NOTE(review): the 0.1 cutoff is hard-coded here, while
    # get_entropy_sparsity_varentropy reads config['activation_threshold'] - confirm intent.
    sparsity = (torch.abs(cluster_acts) > 0.1).float().mean().item()
    
    print(f"\nActivation Statistics:")
    print(f"  Mean Activation: {mean_activation:.4f}")
    print(f"  Max Activation: {max_activation:.4f}")
    print(f"  Std Deviation: {std_activation:.4f}")
    print(f"  Sparsity: {sparsity:.4f}")
    
    # Compute average activation of each prompt on this cluster's features
    prompt_activations = torch.mean(torch.abs(cluster_acts), dim=1)  # [n_prompts]
    
    # Find top activating prompts (at most 5, fewer if there are fewer prompts)
    top_k = min(5, len(prompts))
    top_prompt_indices = torch.argsort(prompt_activations, descending=True)[:top_k]
    top_prompts = [(prompts[i], prompt_activations[i].item()) for i in top_prompt_indices]
    
    print(f"\nTop {top_k} Activating Prompts:")
    for i, (prompt, act) in enumerate(top_prompts, 1):
        # Show first 100 chars, with special formatting for section headers
        truncated = prompt[:100]
        if len(prompt) > 100:
            truncated += "..."
        
        # Any '=' in the preview is treated as a section-header delimiter: the
        # non-empty pieces between '=' signs are joined with ' > '.
        if "=" in truncated:
            sections = [s.strip() for s in truncated.split("=") if s.strip()]
            if sections:
                truncated = f"[SECTION] {' > '.join(sections)}"
        
        # Limit the preview to 15 whitespace-separated words
        tokens = truncated.split()
        if len(tokens) > 15:
            token_display = " ".join(tokens[:15]) + " ..."
        else:
            token_display = truncated
            
        print(f"  {i}. \"{token_display}\"")
        print(f"     Activation: {act:.4f}")
        # The "tokens" count below is whitespace-separated words, not model tokens.
        print(f"     Total Length: {len(prompt)} chars, {len(prompt.split())} tokens")
    # Store detailed results
    cluster_analysis[label] = {
        'size': cluster_size,
        'indices': cluster_indices.tolist(),
        'ratio': cluster_size/total_features,
        'is_noise': False,
        'mean_activation': mean_activation,
        'max_activation': max_activation,
        'std_activation': std_activation,
        'sparsity': sparsity,
        'top_prompts': top_prompts,
        'prompt_activations': prompt_activations.cpu().numpy()
    }


# Print cluster similarity analysis: correlate the per-prompt mean activation
# vectors of every pair of non-noise clusters.
if len(unique_labels) > 1:
    print("\n=== Cluster Similarity Analysis ===")
    for i, label1 in enumerate(unique_labels):
        for label2 in unique_labels[i+1:]:
            acts1 = cluster_analysis[label1]['prompt_activations']
            acts2 = cluster_analysis[label2]['prompt_activations']
            correlation = np.corrcoef(acts1, acts2)[0, 1]
            if abs(correlation) > 0.5:  # Only report pairs with |r| above 0.5
                print(f"Clusters {label1} and {label2}: correlation = {correlation:.4f}")


## Relationships between clusters

In [None]:
# Analyze feature correlations more directly
def _find_correlation_communities(pairs):
    """Return the connected components (size >= 2) of the graph whose edges are ``pairs``.

    Each element of ``pairs`` is ``(i, j, corr)``; only ``i`` and ``j`` are used.
    Components are listed by ascending smallest member, and members appear in
    breadth-first order with neighbours visited in ascending index order.
    """
    from collections import defaultdict, deque

    neighbours = defaultdict(set)
    for i, j, _ in pairs:
        neighbours[i].add(j)
        neighbours[j].add(i)

    visited = set()
    communities = []
    # Only nodes that appear in some pair are considered, so every component
    # found here has at least two members.
    for start in sorted(neighbours):
        if start in visited:
            continue
        visited.add(start)
        queue = deque([start])
        community = []
        while queue:
            node = queue.popleft()
            community.append(node)
            for other in sorted(neighbours[node]):
                if other not in visited:
                    visited.add(other)
                    queue.append(other)
        communities.append(community)

    return communities


def analyze_feature_correlations(filtered_acts, original_indices, labels, threshold=0.7, max_features=1000):
    """
    Identify highly correlated feature pairs directly from the correlation matrix.

    Args:
        filtered_acts: Tensor of shape [n_prompts, n_features]
        original_indices: Original indices of the filtered features (one per
            column of ``filtered_acts``)
        labels: Cluster label per column of ``filtered_acts``, or None to skip
            the cluster reporting
        threshold: Correlation threshold to consider (absolute value)
        max_features: Maximum number of features to include in analysis; when
            exceeded, a fixed-seed random subset is analysed instead

    Returns:
        ``(high_corr_pairs, feature_indices, communities)``:
        ``high_corr_pairs`` is a list of ``(i, j, corr)`` with ``i < j``, sorted by
        ``|corr|`` descending; ``feature_indices`` holds the original index of each
        analysed feature; ``communities`` is a list of index lists, largest first.
        All ``i``/``j``/community entries index the analysed (possibly sampled)
        feature set, i.e. they index ``feature_indices``. They coincide with
        columns of ``filtered_acts`` only when no sampling took place.
    """
    from collections import Counter

    # Convert to numpy and transpose to get [n_features, n_prompts]
    acts = filtered_acts.T.cpu().numpy()
    n_total = acts.shape[0]

    # `positions[k]` is the column of filtered_acts behind analysed row k. It is
    # what `labels` must be indexed with (labels are per filtered column, not
    # per original feature index).
    if n_total > max_features:
        # A private RandomState gives a reproducible sample without reseeding
        # NumPy's global generator for the rest of the session.
        rng = np.random.RandomState(42)
        positions = rng.choice(n_total, max_features, replace=False)
        acts = acts[positions]
        feature_indices = original_indices[positions]
        print(f"Sampled {max_features} features randomly for analysis")
    else:
        positions = np.arange(n_total)
        feature_indices = original_indices

    n_features = acts.shape[0]
    if n_features < 2:
        # No pair exists; np.corrcoef would also collapse to a scalar for one row.
        print(f"\nFound 0 feature pairs with |correlation| >= {threshold}")
        return [], feature_indices, []

    # Compute correlation matrix
    corr_matrix = np.corrcoef(acts)

    # Strict upper triangle only, so each unordered pair is seen once.
    # NaN correlations (constant features) compare False and are dropped.
    rows, cols = np.triu_indices(n_features, k=1)
    values = corr_matrix[rows, cols]
    strong = np.abs(values) >= threshold
    high_corr_pairs = [
        (int(i), int(j), float(c))
        for i, j, c in zip(rows[strong], cols[strong], values[strong])
    ]

    # Sort by absolute correlation (highest first)
    high_corr_pairs.sort(key=lambda pair: abs(pair[2]), reverse=True)

    print(f"\nFound {len(high_corr_pairs)} feature pairs with |correlation| >= {threshold}")

    n_labels = len(labels) if labels is not None else 0

    def original_index(local_idx):
        # Plain int so printed output does not show tensor/array wrappers.
        return int(feature_indices[local_idx])

    def cluster_of(local_idx):
        pos = int(positions[local_idx])
        return labels[pos] if pos < n_labels else "Unknown"

    # Print top correlated pairs
    top_n = min(20, len(high_corr_pairs))
    if top_n > 0:
        print(f"\nTop {top_n} correlated feature pairs:")
        for i, j, corr in high_corr_pairs[:top_n]:
            print(f"Features {original_index(i)} and {original_index(j)}: correlation = {corr:.4f}")
            if labels is not None:
                print(f"  Cluster assignments: {cluster_of(i)} and {cluster_of(j)}")

    # Count, per feature, how many strong correlations it takes part in.
    corr_counts = Counter()
    for i, j, _ in high_corr_pairs:
        corr_counts[i] += 1
        corr_counts[j] += 1

    n_connected = len(corr_counts)
    print(f"\nFeatures with at least one strong correlation: {n_connected} out of {n_features} ({n_connected/n_features:.1%})")

    # Correlation hubs (features with many strong correlations)
    if corr_counts:
        print("\nTop correlation hubs (features with many strong correlations):")
        for i, count in corr_counts.most_common(10):
            print(f"Feature {original_index(i)}: {count} strong correlations")
            if labels is not None:
                print(f"  Cluster assignment: {cluster_of(i)}")

    # Communities are connected components of the strong-correlation graph:
    # members are linked through chains of strong correlations, not necessarily
    # all pairwise correlated with each other.
    communities = _find_correlation_communities(high_corr_pairs)
    communities.sort(key=len, reverse=True)

    print(f"\nFound {len(communities)} correlation communities (groups of strongly correlated features)")

    for rank, community in enumerate(communities[:5], 1):  # Show top 5 communities
        print(f"\nCommunity {rank}: {len(community)} features")
        print(f"  Original indices: {[original_index(k) for k in community]}")

        if labels is None:
            continue

        cluster_counts = Counter(
            labels[int(positions[k])] for k in community if int(positions[k]) < n_labels
        )
        print("  Cluster distribution:")
        for cluster, count in cluster_counts.most_common():
            print(f"    Cluster {cluster}: {count} features ({count/len(community):.1%})")

    return high_corr_pairs, feature_indices, communities

# Run the analysis with a threshold of 0.7 (max_features is left at its default of 1000).
# The three results are kept as notebook globals for the visualization cell below.
high_corr_pairs, feature_indices, communities = analyze_feature_correlations(filtered_acts, original_indices, labels, threshold=0.7)

In [None]:
# Visualize one of the correlation communities
def visualize_correlation_community(filtered_acts, community_indices, feature_indices, prompts, n_top_prompts=5):
    """Visualize a community of correlated features.

    Draws a heatmap of the community's activations on its top activating prompts
    and prints those prompts (truncated to 100 characters).

    Args:
        filtered_acts: Tensor of shape [n_prompts, n_features].
        community_indices: indices of the community's features. They are used both
            as columns of ``filtered_acts`` and as positions in ``feature_indices``.
            NOTE(review): these two uses only agree if ``feature_indices`` is
            aligned with the columns of ``filtered_acts`` - confirm with the producer.
        feature_indices: original feature index for each position, used for axis labels.
        prompts: sequence of prompt strings, indexed by row of ``filtered_acts``.
        n_top_prompts: number of prompts to show, ranked by their maximum absolute
            activation over the community's features.

    Returns:
        None. Prints a notice and returns early for an empty community.
    """
    columns = [int(i) for i in community_indices]
    if not columns:
        # Taking a max over zero features is undefined; there is nothing to draw.
        print("Community is empty; nothing to visualize")
        return

    # Plain ints so the axis labels read "Feature 123" instead of a tensor repr.
    orig_indices = [int(feature_indices[i]) for i in columns]

    # [n_prompts, community_size] activations for the community's features.
    community_acts = filtered_acts[:, columns].cpu().numpy()

    # Rank prompts by their strongest activation on any feature in the community.
    max_activations = np.max(np.abs(community_acts), axis=1)
    top_prompt_indices = np.argsort(max_activations)[::-1][:n_top_prompts]

    # Heatmap of the top prompts (rows) against the community's features (columns).
    top_acts = community_acts[top_prompt_indices]

    fig = go.Figure(data=go.Heatmap(
        z=top_acts,
        x=[f"Feature {idx}" for idx in orig_indices],
        y=[f"Prompt {i+1}" for i in range(len(top_prompt_indices))],
        colorscale='RdBu',
        zmid=0
    ))

    fig.update_layout(
        title=f"Activation Patterns for Correlation Community (Size: {len(columns)})",
        width=max(800, 100 + 50 * len(columns)),
        height=500
    )

    fig.show()

    # Print the top prompts in the same order as the heatmap rows.
    print("\nTop activating prompts for this community:")
    for rank, prompt_idx in enumerate(top_prompt_indices, 1):
        prompt = prompts[prompt_idx]
        max_act = max_activations[prompt_idx]

        # Truncate prompt for display
        display_prompt = prompt[:100] + "..." if len(prompt) > 100 else prompt

        print(f"{rank}. \"{display_prompt}\"")
        print(f"   Max activation: {max_act:.4f}")

# Visualize the largest correlation community.
# `communities` is sorted by size in descending order, so the largest one is at
# index 0 (index 1 would be the second largest and fails when only one exists).
if high_corr_pairs and communities:
    largest_community = communities[0]
    visualize_correlation_community(filtered_acts, largest_community, feature_indices, prompts)

## Cluster scoring