# Overview & Current Questions
This notebook validates my clustering idea:
1. Store SAE activations for all prompts
2. Preprocess activations based on entropy and activation level (optional?)
3. Reduce dimensionality with UMAP, then cluster the embeddings with HDBSCAN
4. Analyze clusters

Some questions:
1. How do entropy- and activation-based preprocessing affect clustering? How do n_prompts and prompt length affect this step?
2. How useful is UMAP? What are the best parameters?
3. How does HDBSCAN perform? What are the best parameters?
4. How does the clustering change when we use different datasets?
5. What do the prompts within a cluster have in common? Does this vary by dataset or by hyperparameter?

### Imports, config, model setup

In [1]:
%load_ext autoreload
%autoreload 2

from sae_lens import SAE, HookedSAETransformer
import torch
import gc
from config import config # cfg auto updates

import random
from datasets import load_dataset
import os

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

import numpy as np

from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
import umap.umap_ as umap
import plotly.graph_objects as go
from einops import rearrange
import hdbscan

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Run on the GPU when one is visible; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Pythia-70m (deduplicated), loaded into a transformer that can host SAEs.
model = HookedSAETransformer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    device=device,
)

# SAE for layer 3's MLP output. from_pretrained returns a 3-tuple here; only
# the first element (the SAE itself) is kept.
sae, _, _ = SAE.from_pretrained(
    device=device,
    sae_id="blocks.3.hook_mlp_out",
    release="pythia-70m-deduped-mlp-sm",
)


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


# Collecting activations

In [3]:
def should_invalidate_cache(cached_data, n_prompts):
    """Decide whether a cached activation payload is too small to reuse.

    Args:
        cached_data: Loaded cache payload, or ``None`` when nothing was loaded.
            Its ``'acts'`` entry must expose ``.shape`` with rows first.
        n_prompts: Number of prompts the current run asks for.

    Returns:
        True when there is no cache, or when the cache holds fewer than 85% of
        ``n_prompts`` rows; False otherwise.
    """
    if cached_data is not None:
        n_cached = cached_data['acts'].shape[0]
        return n_cached < 0.85 * n_prompts
    return True

def load_diverse_prompts(config):
    """Load diverse, de-duplicated prompts from several Hugging Face datasets.

    Sources (all streamed, so no dataset is downloaded in full):
    - Wikipedia, Wikitext (academic/factual)
    - C4 (web text), Reddit and TweetEval (social media)
    - BookCorpus (fiction/narrative)
    - CodeSearchNet Python (code + docstrings), GLUE CoLA (single sentences)
    - DailyDialog (dialogue), SQuAD contexts (QA passages)

    Each source contributes at most ``ceil(n_prompts / n_sources)`` prompts.
    A source that fails to load, or errors mid-stream, is reported and
    skipped. Prompts gathered before the error are kept. The result may
    therefore be shorter than ``n_prompts``.

    Config keys read (all optional):
        n_prompts (default 500): maximum number of prompts returned.
        min_prompt_length, max_prompt_length (defaults 50, 1000): bounds on
            the *character* length of a cleaned prompt.

    Returns:
        list[str]: shuffled, unique prompts (at most ``n_prompts``).
    """
    from collections import Counter
    import random
    import re

    from datasets import load_dataset

    n_prompts = config.get('n_prompts', 500)
    min_len = config.get('min_prompt_length', 50)
    max_len = config.get('max_prompt_length', 1000)
    ngram_size = 10
    max_overlap = 0.5  # reject if > 50% of a candidate's n-grams occur in one accepted prompt

    sources = [
        # (name, config/subset, split)
        # General knowledge / factual
        ('wikipedia', '20220301.en', 'train'),
        ('wikitext', 'wikitext-103-raw-v1', 'train'),
        # Web/social text
        ('allenai/c4', 'en', 'train'),
        ('reddit', None, 'train'),
        ('tweet_eval', 'sentiment', 'train'),
        # Books and stories
        ('bookcorpus', None, 'train'),
        # Technical/specialized
        ('code_search_net', 'python', 'train'),
        ('glue', 'cola', 'train'),
        # Dialogue and conversation
        ('daily_dialog', None, 'train'),
        ('squad', 'plain_text', 'train'),
    ]
    # Text field for sources whose schema isn't covered by the generic
    # 'text'/'content' auto-detection below.
    text_fields = {
        'code_search_net': 'whole_func_string',
        'glue': 'sentence',
        'squad': 'context',
        'tweet_eval': 'text',
        'daily_dialog': 'dialog',
    }

    # Use ceiling division. Floor division dropped the remainder, and gave a
    # quota of 0 when n_prompts < len(sources). Any excess is trimmed at the end.
    prompts_per_dataset = max(1, -(-n_prompts // len(sources)))

    url_re = re.compile(r'http\S+|www\.\S+')
    special_re = re.compile(r'[^\w\s.,!?-]')
    space_re = re.compile(r'\s+')

    def clean_text(text):
        """Drop URLs and uncommon symbols, then normalize whitespace."""
        if not isinstance(text, str):
            return ""
        text = url_re.sub('', text)
        # Keep word characters, whitespace and basic punctuation.
        text = special_re.sub('', text)
        # Collapse whitespace last, so gaps left by the removals are cleaned up too.
        return space_re.sub(' ', text).strip()

    def char_ngrams(text):
        """Return the set of character n-grams of ``text``.

        A text shorter than one n-gram is treated as its own single n-gram.
        """
        if len(text) < ngram_size:
            return {text}
        return {text[i:i + ngram_size] for i in range(len(text) - ngram_size + 1)}

    seen_content = set()  # every prompt accepted so far (across all sources)
    # Inverted index: n-gram -> ids of accepted prompts containing it.
    # Memory grows with the total number of n-grams kept. In exchange, we no
    # longer rebuild the n-gram set of every accepted prompt for each new
    # candidate (the old approach was quadratic with a large constant).
    ngram_index = {}

    def is_near_duplicate(text_ngrams):
        """Return True if more than ``max_overlap`` of ``text_ngrams`` appear in one accepted prompt.

        The per-prompt counts equal ``len(text_ngrams & existing_ngrams)``
        from a pairwise comparison. However, only prompts sharing at least
        one n-gram are touched.
        """
        shared = Counter()
        for gram in text_ngrams:
            shared.update(ngram_index.get(gram, ()))
        limit = max_overlap * len(text_ngrams)
        return any(count > limit for count in shared.values())

    def accept(text, text_ngrams):
        """Record ``text`` as seen and index its n-grams under a fresh id."""
        prompt_id = len(seen_content)
        seen_content.add(text)
        for gram in text_ngrams:
            ngram_index.setdefault(gram, []).append(prompt_id)

    prompts = []
    for name, subset, split in sources:
        dataset_prompts = []
        try:
            print(f"Loading prompts from {name}...")
            if subset:
                dataset = load_dataset(name, subset, split=split, streaming=True)
            else:
                dataset = load_dataset(name, split=split, streaming=True)

            text_key = text_fields.get(name)
            if text_key is None:
                # Peek at the first record to pick the generic text field.
                text_key = 'text' if 'text' in next(iter(dataset)) else 'content'

            for item in dataset:
                text = item[text_key]
                if isinstance(text, list):  # e.g. DailyDialog stores a list of turns
                    text = " ".join(text)
                text = clean_text(text)

                # Cheap checks first; n-grams are built only for in-range candidates.
                if not text or not (min_len <= len(text) <= max_len) or text in seen_content:
                    continue
                text_ngrams = char_ngrams(text)
                if is_near_duplicate(text_ngrams):
                    continue

                accept(text, text_ngrams)
                dataset_prompts.append(text)
                if len(dataset_prompts) >= prompts_per_dataset:
                    break
        except Exception as e:  # best effort: one broken source shouldn't abort the run
            print(f"Error with {name}: {e}")
            if dataset_prompts:
                print(f"Kept {len(dataset_prompts)} prompts from {name} collected before the error")
        else:
            print(f"Added {len(dataset_prompts)} prompts from {name}")
        # Keep whatever was gathered, even on error; those texts are already in seen_content.
        prompts.extend(dataset_prompts)

    # seen_content already guarantees uniqueness, so dict.fromkeys is only a
    # cheap safety net. Unlike set(), it keeps a deterministic order, which
    # makes the shuffle reproducible under random.seed().
    prompts = list(dict.fromkeys(prompts))
    random.shuffle(prompts)
    prompts = prompts[:n_prompts]

    print(f"\nTotal unique prompts loaded: {len(prompts)}")
    if not prompts:
        # Nothing loaded (e.g. every source failed). The statistics below
        # would raise on an empty list.
        print("No prompts loaded; skipping statistics.")
        return prompts

    print("\nPrompt statistics:")
    lengths = [len(p.split()) for p in prompts]
    print(f"Average words per prompt: {sum(lengths) / len(lengths):.1f}")
    print(f"Min length: {min(lengths)}, Max length: {max(lengths)}")

    # Sample and print some prompts for inspection
    print("\nSample prompts:")
    for p in random.sample(prompts, min(5, len(prompts))):
        print(f"- {p[:100]}...")

    return prompts

def get_cache_filename(config, n_prompts):
    """Build the path of the activation cache file for this configuration.

    The file name always encodes ``n_prompts``. The prompt-length filters are
    appended only when the config sets them to values other than the
    defaults used by ``load_diverse_prompts`` (50 / 1000). This has two
    effects:
    - Changing those filters no longer silently reuses a cache built with
      different filtering.
    - Existing cache files for default settings keep their old names.

    Side effect: creates ``config['cache_dir']`` (default ``'cache'``) if it
    does not exist.

    Returns:
        str: ``<cache_dir>/processed_data_<params>.pt``.
    """
    params = [f"prompts_{n_prompts}"]
    # Defaults mirror load_diverse_prompts; only deviations change the key.
    min_len = config.get('min_prompt_length', 50)
    max_len = config.get('max_prompt_length', 1000)
    if min_len != 50:
        params.append(f"minlen_{min_len}")
    if max_len != 1000:
        params.append(f"maxlen_{max_len}")
    print(params)

    cache_dir = config.get('cache_dir', 'cache')
    os.makedirs(cache_dir, exist_ok=True)

    return os.path.join(cache_dir, f"processed_data_{'_'.join(params)}.pt")

In [4]:
def clear_cache():
    """Release PyTorch's cached CUDA memory (if CUDA exists) and run the garbage collector."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()

def create_embed_hook(P):
    """Build a forward hook that substitutes ``P`` for the incoming activation.

    The returned callable ignores the tensor it is given and returns ``P``
    with a new leading axis of size 1. ``unsqueeze`` is applied on every call
    rather than once up front.
    """
    def replace_embedding(value, hook):
        # `value` is intentionally unused: the hook's output replaces it.
        return P.unsqueeze(dim=0)

    return replace_embedding

def get_feature_activations(model, sae, tokens, P=None, hook_name='blocks.3.hook_mlp_out'):
    """Run ``model`` with ``sae`` attached and return the SAE's post-activation latents.

    Args:
        model: model exposing ``hooks`` and ``run_with_cache_with_saes``
            (a HookedSAETransformer in this notebook).
        sae: SAE passed to ``run_with_cache_with_saes``.
        tokens: token ids for the forward pass.
        P: optional replacement for the token embeddings. When given, a
            ``hook_embed`` forward hook (from ``create_embed_hook``) returns
            ``P.unsqueeze(0)`` in place of the real embedding.
        hook_name: hook point ``sae`` is attached to. The latents are read
            from ``f"{hook_name}.hook_sae_acts_post"``, so this must match the
            SAE's own hook point. The default is layer 3's MLP output, which
            the SAE loaded in this notebook uses.

    Returns:
        The cached ``<hook_name>.hook_sae_acts_post`` tensor.
    """
    acts_name = f"{hook_name}.hook_sae_acts_post"
    hooks = [('hook_embed', create_embed_hook(P))] if P is not None else []
    with model.hooks(fwd_hooks=hooks):
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            names_filter=lambda name: name == acts_name,  # only cache what we return
        )
    return cache[acts_name]

def get_model_activations(model, tokens, hook_name='blocks.3.hook_mlp_out'):
    """Run ``model`` without SAEs and return the raw activation at ``hook_name``.

    Args:
        model: model exposing ``run_with_cache``.
        tokens: token ids for the forward pass.
        hook_name: hook point to read (default: layer 3's MLP output, the
            same point the notebook's SAE reads).

    Returns:
        The cached activation tensor for ``hook_name``.
    """
    _, cache = model.run_with_cache(
        tokens,
        names_filter=lambda name: name == hook_name,  # only cache what we return
    )
    return cache[hook_name]

In [5]:
def collect_activations(model, sae, prompts, config, return_indices=False):
    """Compute one sequence-averaged SAE activation vector per prompt.

    Each prompt is tokenized with ``model.to_tokens`` and run through
    ``get_feature_activations``. The result is averaged over dim 1 (the
    sequence axis), discarding position-specific information. Prompts are
    processed one at a time; ``batch_size`` only groups progress messages.

    A prompt whose forward pass raises is skipped and reported. Rows of the
    result therefore line up with ``prompts`` only if nothing was skipped.
    Pass ``return_indices=True`` to get the prompt index behind each row.

    Args:
        model: model with ``to_tokens`` (a HookedSAETransformer here).
        sae: SAE forwarded to ``get_feature_activations``.
        prompts: sequence of prompt strings.
        config: reads ``batch_size`` (default 10).
        return_indices: if True, also return the list of kept prompt indices.

    Returns:
        ``torch.stack`` of the per-prompt vectors. If ``return_indices`` is
        True, returns ``(acts, kept_indices)`` instead.

    Raises:
        RuntimeError: if no prompt produced activations.
    """
    batch_size = config.get('batch_size', 10)
    # range() raises for batch_size == 0 exactly as before. Its len() is the
    # true batch count; the old "// batch_size + 1" overcounted by one
    # whenever len(prompts) was a multiple of batch_size.
    batch_starts = range(0, len(prompts), batch_size)
    n_batches = len(batch_starts)

    all_acts = []
    kept_indices = []
    n_skipped = 0
    for batch_num, start in enumerate(batch_starts, start=1):
        print(f"Processing batch {batch_num}/{n_batches}...")
        for offset, prompt in enumerate(prompts[start:start + batch_size]):
            idx = start + offset
            try:
                tokens = model.to_tokens(prompt)
                acts = get_feature_activations(model, sae, tokens)
                #acts = get_model_activations(model, tokens)
                # Average over the sequence axis: position-specific info isn't needed.
                # NOTE(review): to_tokens presumably prepends a BOS token, which
                # is then included in this mean. Confirm that's intended.
                all_acts.append(acts.mean(dim=1).squeeze(0))
                kept_indices.append(idx)
            except Exception as e:  # best effort: one bad prompt shouldn't abort the run
                n_skipped += 1
                print(f"Skipping prompt {idx}: {e}")
            finally:
                # Also run after a failure (e.g. CUDA OOM), when freeing memory matters most.
                clear_cache()

    if not all_acts:
        raise RuntimeError("No activations collected: every prompt was skipped or none were given.")
    if n_skipped:
        print(f"Warning: skipped {n_skipped} prompt(s); rows of the result no longer "
              f"line up one-to-one with `prompts`.")

    acts = torch.stack(all_acts)
    print(f"Collected activations for {acts.shape[0]} prompts, {acts.shape[1]} features")
    return (acts, kept_indices) if return_indices else acts

In [6]:
# Reuse the on-disk activation cache when allowed and present; otherwise
# rebuild prompts + activations and write them back to the cache file.
n_prompts = config['n_prompts']
use_cached_data = config['use_cached_data']
cache_filename = get_cache_filename(config, n_prompts)
cached_data = load_processed_data(cache_filename)

# Diagnostic only: the invalidation check is printed, but it is deliberately
# not part of the reuse condition below (it is commented out of the original).
print(should_invalidate_cache(cached_data, n_prompts))


if not use_cached_data or cached_data is None:
    print(f"Processing data from scratch with {n_prompts} prompts")
    prompts = load_diverse_prompts(config)
    acts = collect_activations(model, sae, prompts, config)

    # Persist both tensors and prompts so later runs can skip the work above.
    torch.save({'acts': acts, 'prompts': prompts}, cache_filename)
    print(f"Data saved successfully to {cache_filename}")
else:
    print(f"Using cached data with {cached_data['acts'].shape[0]} prompts")
    acts = cached_data['acts']
    prompts = cached_data.get('prompts', [])

['prompts_25000']
No cached data found at feature_cache/processed_data_prompts_25000.pt
True
Processing data from scratch with 25000 prompts
Loading prompts from wikipedia...


KeyboardInterrupt: 