In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
from sae_lens import SAE
import numpy as np
from typing import List, Dict

ModuleNotFoundError: No module named 'sae_lens'

In [None]:
class SAEAnalyzer:
    """Report which sparse-autoencoder (SAE) features a text activates most.

    On construction this loads the ``gpt2-small`` HookedTransformer and the
    ``gpt2-small-res-jb`` SAE for ``blocks.8.hook_resid_pre``. Text can then
    be analysed per sentence, per paragraph, or as a whole file.
    """

    # Honorifics whose trailing period must not be treated as a sentence end.
    _ABBREVIATIONS = ('Mr.', 'Mrs.', 'Dr.', 'Ms.')
    # Characters that may terminate a sentence.
    _TERMINATORS = '.!?'
    # Closing quotes/brackets that still belong to the sentence they follow.
    _CLOSERS = '"\')]\u201d\u2019'

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Load the base model and the SAE onto ``device``.

        Args:
            device: Device string handed to both loaders; defaults to
                ``"cuda"`` when available, otherwise ``"cpu"``.
        """
        self.device = device

        # Initialize base model and SAE
        print("Loading models...")
        self.model = HookedTransformer.from_pretrained("gpt2-small", device=device)

        # Load the SAE.
        # NOTE(review): unpacking three values assumes the SAELens release in
        # use returns (sae, cfg_dict, sparsity) here -- confirm against the
        # installed version.
        self.sae, self.cfg_dict, self.sparsity = SAE.from_pretrained(
            release="gpt2-small-res-jb",
            sae_id="blocks.8.hook_resid_pre",
            device=device,
        )

    def get_first_n_sentences(self, text: str, n: int = 5) -> List[str]:
        """Extract the first ``n`` sentences from ``text``.

        A sentence ends at a run of ``.``, ``!`` or ``?`` (plus any closing
        quotes/brackets) that is followed by whitespace or the end of the
        text. This keeps ellipses ("Wait...") and decimals ("3.14") inside
        their sentence. A candidate ending in one of the known honorifics
        ("Dr.", "Mr.", ...) is not split. Trailing text without a terminator
        is returned as a final sentence, mirroring how
        ``get_first_n_paragraphs`` keeps a trailing paragraph.

        Args:
            text: Raw text to split.
            n: Maximum number of sentences to return.

        Returns:
            Up to ``n`` whitespace-stripped sentences in document order.
        """
        if n <= 0:
            return []

        sentences: List[str] = []
        length = len(text)
        start = 0  # index where the sentence being built begins
        pos = 0

        while pos < length and len(sentences) < n:
            if text[pos] not in self._TERMINATORS:
                pos += 1
                continue

            # Absorb the whole punctuation run ("...", "?!") and any closers.
            end = pos + 1
            while end < length and text[end] in self._TERMINATORS:
                end += 1
            while end < length and text[end] in self._CLOSERS:
                end += 1

            at_boundary = end == length or text[end].isspace()
            candidate = text[start:end].strip()
            if (
                at_boundary
                # Punctuation-only fragments are not sentences.
                and any(ch.isalnum() for ch in candidate)
                and not candidate.endswith(self._ABBREVIATIONS)
            ):
                sentences.append(candidate)
                start = end
            pos = end

        if len(sentences) < n:
            # Keep unterminated trailing text instead of silently dropping it.
            tail = text[start:].strip()
            if any(ch.isalnum() for ch in tail):
                sentences.append(tail)

        return sentences[:n]

    def get_first_n_paragraphs(self, text: str, n: int = 5) -> List[str]:
        """Extract the first ``n`` paragraphs from ``text``.

        Paragraphs are runs of non-blank lines separated by one or more blank
        lines; the stripped lines of a paragraph are joined with single
        spaces.

        Args:
            text: Raw text to split.
            n: Maximum number of paragraphs to return.

        Returns:
            Up to ``n`` paragraphs in document order.
        """
        if n <= 0:
            return []

        paragraphs: List[str] = []
        current: List[str] = []

        for line in text.split('\n'):
            stripped = line.strip()
            if stripped:
                current.append(stripped)
            elif current:
                # A blank line closes the paragraph collected so far.
                paragraphs.append(' '.join(current))
                current = []
                if len(paragraphs) >= n:
                    break

        # The last paragraph may not be followed by a blank line.
        if current and len(paragraphs) < n:
            paragraphs.append(' '.join(current))

        return paragraphs[:n]

    def get_active_neurons(self, text: str, top_k: int = 10) -> Dict:
        """Get the most active SAE features for a piece of text.

        The text is tokenized with a BOS token prepended, the model is run
        while caching only the SAE's hook point, and the cached activations
        are encoded by the SAE. Absolute feature activations are averaged
        over dimension 1 (assumed to be token position, i.e. an encoding of
        shape [batch, pos, features]) and the ``top_k`` largest are reported.

        NOTE(review): TransformerLens' ``to_tokens`` presumably truncates to
        the model's context window by default, so very long texts would only
        be analysed from their beginning -- confirm.

        Args:
            text: Text to analyse.
            top_k: Number of features to report (default 10); clamped to the
                number of features available.

        Returns:
            Dict with numpy arrays: ``'indices'`` and ``'values'`` of the top
            features, and ``'encoded'``, the full SAE encoding.
        """
        # NOTE(review): reading hook_name from sae.cfg assumes the installed
        # SAELens version exposes it there -- confirm.
        hook_name = self.sae.cfg.hook_name

        # Inference only: without no_grad the SAE output would presumably
        # require grad (module parameters do by default), and Tensor.numpy()
        # raises on tensors that require grad.
        with torch.no_grad():
            tokens = self.model.to_tokens(text, prepend_bos=True)

            # Run the model and keep only the activations at the SAE's hook point
            _, cache = self.model.run_with_cache(
                tokens,
                names_filter=lambda name: name == hook_name
            )

            # Encode the hook-point activations with the SAE
            encoded = self.sae.encode(cache[hook_name])

            # Average over dim 1, then pick the strongest features
            feature_means = encoded.abs().mean(dim=1)
            k = max(0, min(top_k, feature_means.shape[-1]))
            top_neurons = torch.topk(feature_means, k=k, dim=-1)

        return {
            'indices': top_neurons.indices.detach().cpu().numpy(),
            'values': top_neurons.values.detach().cpu().numpy(),
            'encoded': encoded.detach().cpu().numpy()
        }

    def analyze_text_file(self, file_path: str) -> Dict:
        """Analyze a UTF-8 text file at sentence, paragraph and full-text level.

        Args:
            file_path: Path of the text file to read.

        Returns:
            Dict with keys ``'sentences'`` and ``'paragraphs'`` (each holding
            ``'texts'`` and a parallel ``'neurons'`` list for the first five
            items) and ``'full_text'`` (holding ``'text'`` and ``'neurons'``
            for the whole file).
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            full_text = file.read()

        # Get text at different levels
        sentences = self.get_first_n_sentences(full_text, 5)
        paragraphs = self.get_first_n_paragraphs(full_text, 5)

        results = {
            'sentences': {
                'texts': sentences,
                'neurons': [self.get_active_neurons(s) for s in sentences]
            },
            'paragraphs': {
                'texts': paragraphs,
                'neurons': [self.get_active_neurons(p) for p in paragraphs]
            },
            'full_text': {
                'text': full_text,
                'neurons': self.get_active_neurons(full_text)
            }
        }

        return results

In [None]:
def analyze_all_files(data_dir: str) -> Dict:
    """Analyze every ``.txt`` file in ``data_dir`` and print a summary.

    One ``SAEAnalyzer`` is created and reused for all files. Files are
    processed in sorted name order so that repeated runs print and store
    results in the same order.

    Args:
        data_dir: Directory to scan (non-recursively) for ``.txt`` files.

    Returns:
        Dict mapping each file's name without its trailing ``.txt`` to the
        result of ``SAEAnalyzer.analyze_text_file`` for that file.
    """
    suffix = '.txt'
    analyzer = SAEAnalyzer()
    all_results = {}

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(suffix):
            continue

        print(f"\nProcessing {filename}...")
        file_path = os.path.join(data_dir, filename)
        # Strip only the trailing suffix: str.replace would also delete a
        # ".txt" in the middle of the name and could make two files collide.
        category = filename[:-len(suffix)]
        category_results = analyzer.analyze_text_file(file_path)
        all_results[category] = category_results

        # Print summary for this category
        print(f"\nResults for {category}:")

        # Show top neurons for each level
        for level in ['sentences', 'paragraphs', 'full_text']:
            print(f"\n{level.upper()}:")
            if level in ['sentences', 'paragraphs']:
                # One entry per sentence/paragraph; level[:-1] drops the plural "s".
                for i, result in enumerate(category_results[level]['neurons']):
                    print(f"\n{level[:-1].capitalize()} {i+1}:")
                    print(f"Top 10 active neurons: {result['indices']}")
                    print(f"Activation values: {result['values'].round(3)}")
            else:
                # The full-text level holds a single result dict, not a list.
                result = category_results[level]['neurons']
                print("\nTop 10 active neurons:", result['indices'])
                print("Activation values:", result['values'].round(3))

    return all_results

In [None]:
if __name__ == "__main__":
    # Entry point: analyse every .txt file under the local data directory.
    # Both names stay at module level so later cells can inspect them.
    data_dir = "./data"
    results = analyze_all_files(data_dir=data_dir)