In [None]:
# %% Load Models 

import torch
from sae_lens import SAE, HookedSAETransformer
from transformer_lens import ActivationCache, HookedTransformer, utils
# from hook_sae import HookedSAETransformer
# from transformer_lens import ActivationCache, HookedTransformer, utils
# from transformer_lens.hook_points import HookPoint


# Fix the global RNG so every stochastic step below (random token sampling,
# noise added to the soft-prompt initialisation) is reproducible on
# Restart & Run All. torch.manual_seed also seeds CUDA generators.
SEED = 0
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained SAE on pythia-70m-deduped's layer-3 MLP output. This sae_lens
# version returns a (sae, cfg_dict, sparsity) tuple; the cells in this
# section only use `sae`.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="pythia-70m-deduped-mlp-sm",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="blocks.3.hook_mlp_out",  # won't always be a hook point
    device=device,
)

# The base model, loaded as a HookedSAETransformer so SAEs can be spliced in.
pythia: HookedSAETransformer = HookedSAETransformer.from_pretrained(
    "EleutherAI/pythia-70m-deduped", device=device
)

  from .autonotebook import tqdm as notebook_tqdm
  return torch._C._cuda_getDeviceCount() > 0


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [None]:
# %% Output Comparison with and without SAEs

'''
Here we can compare the performance of the model with and without SAEs.
'''

prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# Equivalent ways to run the same comparison (not executed here):
#   utils.test_prompt(prompt, answer, pythia)                       # baseline report
#   with pythia.saes(saes=[sae]): utils.test_prompt(prompt, answer, pythia)
#   pythia.add_sae(sae); ...; pythia.reset_saes()                   # always reset afterwards!

answer_token_id = pythia.to_single_token(answer)

# Pure evaluation: no_grad avoids building an autograd graph through every
# model and SAE weight for these two forward passes.
with torch.no_grad():
    logits = pythia(prompt, return_type="logits")
    # Same forward pass with the SAE reconstruction spliced in at its hook point
    logits_sae = pythia.run_with_saes(prompt, saes=[sae], return_type="logits")

# Next-token distributions at the final prompt position
probs = logits[0, -1].softmax(-1)
probs_sae = logits_sae[0, -1].softmax(-1)

# Top-1 prediction and its probability for each model variant
top_prob, token_id_prediction = probs.max(-1)
top_prob_sae, token_id_prediction_sae = probs_sae.max(-1)

print(f"""Standard model: top prediction = {pythia.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {pythia.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")
# Probability of the expected answer itself, which the top-1 lines hide
# whenever the answer is not the argmax.
print(f"P({answer!r}): standard = {probs[answer_token_id].item():.2%}, "
      f"SAE = {probs_sae[answer_token_id].item():.2%}")

Standard model: top prediction = ' priority', prob = 11.90%
SAE reconstruction: top prediction = ' priority', prob = 7.77%



In [None]:
# %% Training Setup

# Diagnostics filled in by the training-loop cell. Most keys get one entry per
# optimisation step; 'similarity_top' is only appended at checkpoints.
training_stats = {
    'loss': [],
    'feature_loss': [],
    'reg_loss': [],
    'target_activation': [],
    'gradient_norm': [],
    'embedding_distance': [],
    'similarity_top': [],
    'learning_rates': []
}

length = 4  # number of soft-prompt positions being optimised
d_model = pythia.cfg.d_model

# Start each position from a random real token embedding plus small Gaussian
# noise, so optimisation begins near actual tokens rather than at the mean.
noise_scale = 0.1  # std of the initialisation noise
# NOTE(review): d_vocab may exceed the tokenizer's real vocabulary (padded
# embedding rows); such ids would start from untrained embeddings — confirm.
init_token_ids = torch.randint(0, pythia.cfg.d_vocab, (length,), device=device)
P = torch.nn.Parameter(
    # detach(): the starting value must not stay tied to W_E's autograd graph
    pythia.W_E[init_token_ids].detach()
    + torch.randn(length, d_model, device=device) * noise_scale,
    requires_grad=True
)

print(f"Initial P shape: {P.shape}")

# Aggressive optimizer for the soft prompt only
optimizer = torch.optim.AdamW([P], lr=1e-1, weight_decay=0)
# Cut the LR by 30% when the loss plateaus. `verbose` is omitted because it is
# deprecated in recent PyTorch; the LR is recorded in
# training_stats['learning_rates'] by the training loop instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.7, patience=20
)

# Placeholder token ids: their embeddings are replaced by P in embed_hook,
# they only fix the batch size (1) and sequence length.
dummy_tokens = torch.zeros(1, length, dtype=torch.long, device=device)


def embed_hook(value, hook):
    """Replace the token-embedding output with the learnable soft prompt P.

    `value` is only used for its batch size; P is broadcast across the batch
    so every sequence sees the same soft prompt.
    """
    return P.unsqueeze(0).expand(value.shape[0], -1, -1)

Initial P shape: torch.Size([4, 512])


In [None]:
# %% Option 1: High Activation Feature from Random Tokens

SAE_ACTS_BATCH = 10  # sequences per forward pass while probing


def _sae_acts_for(token_batch):
    """Return the SAE feature activations [batch, seq, d_sae] for `token_batch`."""
    _, batch_cache = pythia.run_with_cache_with_saes(
        input=token_batch,
        return_type="logits",
        saes=[sae]
    )
    return batch_cache['blocks.3.hook_mlp_out.hook_sae_acts_post']


# Probe the SAE with random token sequences, pick the feature that fires the
# hardest, then re-seed the soft prompt P from the token that drove it.
with torch.no_grad():
    n_samples = 100
    sample_tokens = torch.randint(0, pythia.cfg.d_vocab, (n_samples, length), device=device)

    # [n_samples, seq_len, n_features], built chunk by chunk
    combined_acts = torch.cat(
        [_sae_acts_for(chunk) for chunk in torch.split(sample_tokens, SAE_ACTS_BATCH)],
        dim=0
    )

    # Peak activation of each feature over every sample and position
    max_activations = combined_acts.amax(dim=(0, 1))
    top_max_features = torch.argsort(max_activations, descending=True)[:10]

    print("\nTop features by maximum activation:")
    for feat_idx in top_max_features:
        print(f"Feature {feat_idx}: Max activation = {max_activations[feat_idx]:.4f}")

    target_feature = int(top_max_features[0])
    print(f"\nSelected feature {target_feature} with max activation {max_activations[target_feature]:.4f}")

    # Locate the (sample, position) cell where the chosen feature peaks
    feature_activations = combined_acts[..., target_feature]  # [n_samples, seq_len]
    sample_idx, pos_idx = divmod(int(feature_activations.argmax()), length)

    activating_token_id = sample_tokens[sample_idx, pos_idx]
    activating_token = pythia.to_string(activating_token_id)
    activation_value = feature_activations[sample_idx, pos_idx]

    print(f"Token that most activates feature {target_feature}: '{activating_token}' with activation {activation_value:.4f}")

    # Every position starts from that token's embedding, jittered by noise
    seed_embedding = pythia.W_E[activating_token_id]
    init_noise = torch.randn(length, d_model, device=device) * noise_scale
    P.data = seed_embedding + init_noise


Top features by maximum activation:
Feature 20383: Max activation = 14.0576
Feature 7082: Max activation = 3.9123
Feature 16939: Max activation = 3.6901
Feature 18192: Max activation = 3.5958
Feature 21462: Max activation = 3.5672
Feature 27926: Max activation = 3.2479
Feature 16970: Max activation = 2.4913
Feature 1932: Max activation = 2.3153
Feature 17485: Max activation = 2.1687
Feature 11786: Max activation = 1.9848

Selected feature 20383 with max activation 14.0576
Token that most activates feature 20383: ',\' with activation 14.0576


In [None]:
# %% Option 2: Contextual Features

print("\nSearching for contextual features...")

# Generate diverse prompts
test_prompts = [
    "The president of the United States",
    "Once upon a time in a galaxy",
    "The quick brown fox jumps over",
    "To be or not to be that",
    "Four score and seven years ago"
]

# Dictionary size of the SAE (d_sae, not n_components)
n_features = sae.cfg.d_sae

# Running sum (then mean over prompts) of each feature's variance across positions
feature_position_variance = torch.zeros(n_features, device=device)

# NOTE(review): position 0 is included; if the model prepends a BOS token to
# string inputs, features that fire only on that first token will dominate
# this ranking — confirm before reading the top entries as "contextual".
with torch.no_grad():
    # Loop variable is not named `prompt`, so the earlier global is not clobbered
    for test_prompt in test_prompts:
        print(f"\nAnalyzing prompt: '{test_prompt}'")
        _, cache = pythia.run_with_cache_with_saes(test_prompt, saes=[sae])
        acts = cache['blocks.3.hook_mlp_out.hook_sae_acts_post'][0]  # [seq_len, n_features]

        # One (unbiased) variance per feature column in a single vectorised
        # call, instead of a Python loop launching one kernel per feature.
        feature_position_variance += acts.var(dim=0)

# Normalize by number of prompts
feature_position_variance /= len(test_prompts)

# Get top contextual features (highest position variance)
top_contextual_features = torch.argsort(feature_position_variance, descending=True)[:20]

print("\nTop contextual features (highest position variance):")
for feat_idx in top_contextual_features:
    print(f"Feature {feat_idx}: Position variance = {feature_position_variance[feat_idx]:.4f}")

# Choose a contextual feature for optimization
contextual_feature = top_contextual_features[0].item()
print(f"\nSelected contextual feature {contextual_feature} for optimization")

target_feature = contextual_feature


Searching for contextual features...

Analyzing prompt: 'The president of the United States'

Analyzing prompt: 'Once upon a time in a galaxy'

Analyzing prompt: 'The quick brown fox jumps over'

Analyzing prompt: 'To be or not to be that'

Analyzing prompt: 'Four score and seven years ago'

Top contextual features (highest position variance):
Feature 20383: Position variance = 26.3358
Feature 7082: Position variance = 1.6614
Feature 21462: Position variance = 1.6153
Feature 18192: Position variance = 1.2300
Feature 24661: Position variance = 0.4472
Feature 22510: Position variance = 0.2908
Feature 28580: Position variance = 0.2782
Feature 2082: Position variance = 0.2656
Feature 29131: Position variance = 0.2395
Feature 22514: Position variance = 0.1949
Feature 8178: Position variance = 0.1913
Feature 26284: Position variance = 0.1718
Feature 6884: Position variance = 0.1251
Feature 9359: Position variance = 0.0908
Feature 29366: Position variance = 0.0814
Feature 27926: Position var

In [None]:
# %% Option 3: Co-Activated Features

print("\nSearching for feature combinations...")

# Generate many random sequences
n_samples = 50
sample_tokens = torch.randint(0, pythia.cfg.d_vocab, (n_samples, length), device=device)

# Collect SAE activations for all samples (evaluation only, so no_grad)
with torch.no_grad():
    all_acts = []
    for i in range(0, n_samples, 10):
        batch = sample_tokens[i:i+10]
        _, cache = pythia.run_with_cache_with_saes(
            input=batch,
            return_type="logits",
            saes=[sae]
        )
        all_acts.append(cache['blocks.3.hook_mlp_out.hook_sae_acts_post'])

combined_acts = torch.cat(all_acts, dim=0)  # [n_samples, seq_len, n_features]

# Dictionary size of the SAE (d_sae, not n_components)
n_features = sae.cfg.d_sae

# One row per (sample, position): [n_samples*seq_len, n_features]
reshaped_acts = combined_acts.reshape(-1, n_features)

# Pearson correlation of every feature with the target only. A full
# torch.corrcoef would be n_features x n_features (tens of thousands squared),
# and features that never fire in this sample have zero variance, giving NaN
# correlations that a descending argsort ranks first.
centered = reshaped_acts - reshaped_acts.mean(dim=0, keepdim=True)
column_norms = centered.norm(dim=0)  # proportional to each feature's std
denom = column_norms * column_norms[target_feature]
target_correlations = (centered[:, target_feature] @ centered) / denom.clamp_min(1e-12)

# Eligible: features with nonzero variance, excluding the target itself
valid = denom > 0
valid[target_feature] = False
n_valid = int(valid.sum())

if n_valid == 0:
    # Either the target never fired here or nothing else varies: correlation is undefined
    print(f"\nFeature {target_feature} has no defined correlations on these samples; "
          "keeping it as the target.")
    top_correlated = torch.empty(0, dtype=torch.long, device=target_correlations.device)
else:
    ranking_scores = target_correlations.abs().masked_fill(~valid, -1.0)
    top_correlated = torch.argsort(ranking_scores, descending=True)[:min(5, n_valid)]

    print(f"\nFeatures most correlated with feature {target_feature}:")
    for feat_idx in top_correlated:
        corr = target_correlations[feat_idx].item()
        print(f"Feature {feat_idx}: Correlation = {corr:.4f}")

    # You could instead optimize a weighted combination of these features
    target_feature = top_correlated[0].item()


Searching for feature combinations...

Features most correlated with feature 20383:
Feature 21123: Correlation = nan
Feature 21122: Correlation = nan
Feature 21120: Correlation = nan
Feature 21119: Correlation = nan
Feature 21116: Correlation = nan


In [None]:
# %% Training Loop

# SAE hook point whose live (non-detached) output is optimised below.
SAE_ACTS_HOOK = 'blocks.3.hook_mlp_out.hook_sae_acts_post'

# Only the soft prompt P is trained. Freezing the model and SAE stops
# loss.backward() from computing gradients for every transformer/SAE weight.
# Those gradients would otherwise accumulate across steps, because the
# optimizer (and its zero_grad) only manages P.
pythia.requires_grad_(False)
sae.requires_grad_(False)

vocab_embeds = pythia.W_E  # [d_vocab, d_model]
# Unit-norm vocabulary embeddings, reused by every cosine-similarity lookup
vocab_unit = torch.nn.functional.normalize(vocab_embeds, dim=-1)
# Marks the diagonal of the length x length position-similarity matrix
self_mask = torch.eye(length, dtype=torch.bool, device=P.device)


def sae_acts_for_soft_prompt():
    """Run pythia on the soft prompt P (via `embed_hook`) with the SAE spliced in.

    Returns the SAE feature activations of the single batch element,
    shape [length, d_sae]. TransformerLens' run_with_cache stores *detached*
    copies of activations, so this captures the live tensor with a forward
    hook instead. Gradients then flow from the result back into P.
    """
    captured = {}

    def capture_acts(value, hook):
        captured['acts'] = value  # no .detach(): keep the autograd graph

    with pythia.saes(saes=[sae]):
        pythia.run_with_hooks(
            dummy_tokens,
            return_type=None,  # only the hooked activations are needed
            fwd_hooks=[('hook_embed', embed_hook), (SAE_ACTS_HOOK, capture_acts)],
        )
    return captured['acts'][0]


def mean_top_similarity(embeddings):
    """Mean over rows of each row's highest cosine similarity to any vocab token."""
    with torch.no_grad():
        sims = torch.nn.functional.normalize(embeddings, dim=-1) @ vocab_unit.T
        return sims.max(dim=-1).values.mean().item()


def get_similar_tokens(embeddings, top_k=5):
    """For each row of `embeddings`, return (top_k token strings, cosine scores)
    for the most similar vocabulary embeddings, as a list of pairs."""
    with torch.no_grad():
        sims = torch.nn.functional.normalize(embeddings, dim=-1) @ vocab_unit.T
        top = sims.topk(top_k, dim=-1)
        return [
            ([pythia.to_string(idx) for idx in ids], scores)
            for ids, scores in zip(top.indices, top.values)
        ]


# Track initial similarity and activation
with torch.no_grad():
    initial_top_sim = mean_top_similarity(P)
    training_stats['similarity_top'].append(initial_top_sim)
    initial_activation = sae_acts_for_soft_prompt()[:, target_feature].max().item()
print(f"Initial activation for feature {target_feature}: {initial_activation:.4f}")
print(f"Initial top token similarity: {initial_top_sim:.4f}")
if initial_activation <= 0:
    # Assuming a ReLU-style SAE activation, an inactive feature passes zero
    # gradient, so the feature term cannot move P from here.
    print(f"Warning: feature {target_feature} is inactive at every position for the "
          "initial P; the feature loss will likely provide no gradient.")

# Adjust optimization parameters
lambda_reg = 1e-3  # weight of the pull toward the nearest real token embedding
max_steps = 300

print("\nInitial similar tokens:")
for pos, (tokens, scores) in enumerate(get_similar_tokens(P)):
    print(f"Position {pos}: {tokens[0]} ({scores[0]:.3f})")

for step in range(max_steps):
    optimizer.zero_grad()

    # SAE activations that stay connected to P's autograd graph
    sae_acts = sae_acts_for_soft_prompt()  # [length, d_sae]

    # Maximise the strongest activation of the target feature over positions
    target_activation = sae_acts[:, target_feature].max()
    loss_feature = -target_activation

    # Pull each position toward its currently closest real token embedding.
    # The target token is picked without grad; the distance is differentiable.
    with torch.no_grad():
        closest_tokens = (torch.nn.functional.normalize(P, dim=-1) @ vocab_unit.T).argmax(dim=-1)
        closest_embeddings = vocab_embeds[closest_tokens]
    embedding_dist = torch.norm(P - closest_embeddings, p='fro')
    loss_reg = lambda_reg * embedding_dist

    # Diversity penalty: 0.1 * mean of the length x length cosine-similarity
    # matrix of the positions with its diagonal zeroed. Computed WITH grad so
    # it actually pushes positions apart; under no_grad it would only add a
    # constant to the loss.
    P_unit = torch.nn.functional.normalize(P, dim=-1)
    position_similarity = (P_unit @ P_unit.T).masked_fill(self_mask, 0.0)
    token_diversity_penalty = position_similarity.mean() * 0.1

    loss = loss_feature + loss_reg + token_diversity_penalty

    # Record stats
    training_stats['loss'].append(loss.item())
    training_stats['feature_loss'].append(loss_feature.item())
    training_stats['reg_loss'].append(loss_reg.item())
    training_stats['target_activation'].append(target_activation.item())
    training_stats['embedding_distance'].append(embedding_dist.item())
    training_stats['learning_rates'].append(optimizer.param_groups[0]['lr'])

    if step % 20 == 0:
        print(f"\nStep {step}")
        print(f"Target feature {target_feature} values: {sae_acts[:, target_feature].detach()}")
        print(f"Feature loss: {loss_feature.item():.4f}")
        print(f"Reg loss: {loss_reg.item():.4f}")
        print(f"Total loss: {loss.item():.4f}")

    loss.backward()

    # Gradient norm is recorded before clipping
    grad_norm = P.grad.norm().item()
    training_stats['gradient_norm'].append(grad_norm)

    torch.nn.utils.clip_grad_norm_([P], max_norm=10.0)
    optimizer.step()

    # The plateau scheduler only sees every 10th loss, so its patience is
    # counted in units of 10 optimisation steps.
    if step % 10 == 0:
        scheduler.step(loss.item())

    # Periodic checkpoint of how token-like P is
    if step % 20 == 0 or step == max_steps - 1:
        top_sim = mean_top_similarity(P)
        training_stats['similarity_top'].append(top_sim)

        tokens_str = " ".join(tokens[0] for tokens, _ in get_similar_tokens(P, top_k=1))
        print(f"Current tokens: {tokens_str}")
        print(f"Top similarity: {top_sim:.4f}")

Initial activation for feature 21123: 0.0000
Initial top token similarity: 0.2836

Initial similar tokens:
Position 0: ,\ (0.267)
Position 1: ,\ (0.318)
Position 2: ,\ (0.271)
Position 3: ,\ (0.279)

Step 0
Target feature 21123 values: tensor([0., 0., 0., 0.])
Feature loss: -0.0000
Reg loss: 0.0045
Total loss: 0.0088
Current tokens: ,\ ,\ ,\ ,\
Top similarity: 0.4435

Step 20
Target feature 21123 values: tensor([0., 0., 0., 0.])
Feature loss: -0.0000
Reg loss: 0.0006
Total loss: 0.0655
Current tokens: ,\ ,\ ,\ ,\
Top similarity: 0.9415

Step 40
Target feature 21123 values: tensor([0., 0., 0., 0.])
Feature loss: -0.0000
Reg loss: 0.0002
Total loss: 0.0735
Current tokens: ,\ ,\ ,\ ,\
Top similarity: 0.9845

Step 60
Target feature 21123 values: tensor([0., 0., 0., 0.])
Feature loss: -0.0000
Reg loss: 0.0003
Total loss: 0.0727
Current tokens: ,\ ,\ ,\ ,\
Top similarity: 0.9914

Step 80
Target feature 21123 values: tensor([0., 0., 0., 0.])
Feature loss: -0.0000
Reg loss: 0.0002
Total loss: 

In [None]:
# %% Visualize Training Progress

import plotly.subplots as sp
import plotly.graph_objects as go


# Create a figure with subplots
fig = sp.make_subplots(
    rows=3, cols=2,
    subplot_titles=(
        'Loss Components', 'Target Feature Activation',
        'Gradient Norm (pre-clipping)', 'Embedding Distance from Nearest Token',
        'Learning Rate', 'Top Token Similarity'
    )
)

# Series logged once per optimisation step: (stats key, legend name, row, col)
per_step_series = [
    ('loss', 'Total Loss', 1, 1),
    ('feature_loss', 'Feature Loss', 1, 1),
    ('reg_loss', 'Regularization Loss', 1, 1),
    ('target_activation', 'Target Activation', 1, 2),
    ('gradient_norm', 'Gradient Norm', 2, 1),
    ('embedding_distance', 'Embedding Distance', 2, 2),
    ('learning_rates', 'Learning Rate', 3, 1),
]
for key, trace_name, row, col in per_step_series:
    fig.add_trace(
        go.Scatter(y=training_stats[key], mode='lines', name=trace_name),
        row=row, col=col
    )

# 'similarity_top' is sparse. The training cell appends one value before the
# first step (plotted at step -1), then one after every 20th step and after the
# last step. Build exactly that x axis so x and y line up, and keep only the
# most recent run's values in case the training cell was executed repeatedly.
sim_steps = [-1] + [s for s in range(max_steps) if s % 20 == 0 or s == max_steps - 1]
sim_values = training_stats['similarity_top'][-len(sim_steps):]
fig.add_trace(
    go.Scatter(x=sim_steps[-len(sim_values):], y=sim_values,
               mode='lines+markers', name='Top Token Similarity'),
    row=3, col=2
)

fig.update_layout(height=900, width=1000, title_text="Feature Optimization Training Progress")
fig.show()

In [None]:
# %% Analyze Results

with torch.no_grad():
    # Re-run the model on the optimised soft prompt with the SAE spliced in
    with pythia.hooks(fwd_hooks=[('hook_embed', embed_hook)]):
        _, final_cache = pythia.run_with_cache_with_saes(
            input=dummy_tokens,
            return_type="logits",
            saes=[sae]
        )
    final_acts = final_cache['blocks.3.hook_mlp_out.hook_sae_acts_post'][0]  # [length, d_sae]

    # --- 1. Nearest vocabulary tokens for each optimised position ---
    learned_embeds, vocab_embeds = P, pythia.W_E  # [length, d_model], [vocab_size, d_model]
    unit_rows = torch.nn.functional.normalize
    similarity = unit_rows(learned_embeds, dim=-1) @ unit_rows(vocab_embeds, dim=-1).T
    top_k = 5
    top_tokens = similarity.topk(top_k, dim=-1)

    print(f"\n\n=== RESULTS FOR FEATURE {target_feature} ===")
    print("\nOptimized sequence as tokens:")
    for pos, (pos_ids, pos_scores) in enumerate(zip(top_tokens.indices, top_tokens.values)):
        print(f"Position {pos}:")
        for token_id, score in zip(pos_ids, pos_scores):
            print(f"  {pythia.to_string(token_id):20} (similarity: {score:.3f})")

    # --- 2. Per-position activation of the target feature ---
    target_trace = final_acts[:, target_feature]
    print(f"\nFeature {target_feature} activation pattern:")
    for pos, value in enumerate(target_trace):
        print(f"Position {pos}: {value:.4f}")

    # --- 3. Bar chart of that pattern ---
    fig = go.Figure(go.Bar(
        x=list(range(length)),
        y=target_trace.cpu().numpy(),
        marker_color='red'
    ))
    fig.update_layout(
        title=f'Feature {target_feature} Activation Pattern',
        xaxis_title='Sequence Position',
        yaxis_title='Activation Value',
    )
    fig.show()

    # --- 4. Features with the highest mean activation over positions ---
    mean_acts = final_acts.mean(dim=0)
    top_activated = torch.argsort(mean_acts, descending=True)[:10]

    print("\nOther strongly activated features:")
    for feat_idx in top_activated:
        marker = " (target feature)" if feat_idx == target_feature else ""
        print(f"Feature {feat_idx}: {mean_acts[feat_idx]:.4f}{marker}")

    # --- 5. Grouped bars for the five strongest features ---
    position_labels = [f"Pos {i}" for i in range(length)]
    fig = go.Figure()
    for feat_idx in top_activated[:5]:
        label = f"Feature {feat_idx}" + (" (target)" if feat_idx == target_feature else "")
        fig.add_trace(go.Bar(
            x=position_labels,
            y=final_acts[:, feat_idx].cpu().numpy(),
            name=label
        ))
    fig.update_layout(
        title='Activation Patterns for Top Features',
        xaxis_title='Sequence Position',
        yaxis_title='Activation Value',
        barmode='group'
    )
    fig.show()



=== RESULTS FOR FEATURE 21123 ===

Optimized sequence as tokens:
Position 0:
  ,\                   (similarity: 0.997)
  },\                  (similarity: 0.598)
   ,\                  (similarity: 0.576)
  ),\                  (similarity: 0.484)
  ',\                  (similarity: 0.479)
Position 1:
  ,\                   (similarity: 0.997)
  },\                  (similarity: 0.604)
   ,\                  (similarity: 0.575)
  ),\                  (similarity: 0.485)
  ',\                  (similarity: 0.479)
Position 2:
  ,\                   (similarity: 0.997)
  },\                  (similarity: 0.600)
   ,\                  (similarity: 0.578)
  ),\                  (similarity: 0.487)
  ',\                  (similarity: 0.483)
Position 3:
  ,\                   (similarity: 0.997)
  },\                  (similarity: 0.602)
   ,\                  (similarity: 0.577)
  ),\                  (similarity: 0.485)
  ',\                  (similarity: 0.479)

Feature 21123 activation


Other strongly activated features:
Feature 20383: 14.0950
Feature 18192: 3.4520
Feature 21462: 3.1163
Feature 7082: 2.3640
Feature 27926: 0.5587
Feature 5167: 0.4007
Feature 26654: 0.3800
Feature 9123: 0.3457
Feature 16970: 0.3168
Feature 31054: 0.2910
