# Imports and Initialization

In [1]:
!pip install sae-lens transformer_lens circuitsvis git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python git+https://github.com/callummcdougall/sae_vis.git@callum/v3

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-coo36b4a
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-coo36b4a
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit 1e6129d08cae7af9242d9ab5d3ed322dd44b4dd3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting git+https://github.com/callummcdougall/sae_vis.git@callum/v3
  Cloning https://github.com/callummcdougall/sae_vis.git (to revision callum/v3) to /tmp/pip-req-build-_wlscu7j
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/sae_vis.git /tmp/pip-req-build-_wlscu7j
  Running command git checkout -b callum/v3 --track origin/callum/v3
  Switched to a new b

In [2]:
import sae_lens
print(sae_lens.__version__)

6.5.1


In [3]:
import torch
import os

from sae_lens import (
    LanguageModelSAERunnerConfig,
    SAETrainingRunner,
    StandardTrainingSAEConfig,
    LoggingConfig,
)

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


# Training the SAE

In [4]:
import torch as t
total_training_steps = 10000
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 200
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name="tiny-stories-1L-21M",
    hook_name="blocks.0.hook_mlp_out",
    dataset_path="roneneldan/TinyStories",
    is_dataset_tokenized=True,
    streaming=True,

    sae=StandardTrainingSAEConfig(
        d_in=1024,
        d_sae=4096,
        apply_b_dec_to_input=False,
        normalize_activations="expected_average_only_in",
        l1_coefficient=4,
        l1_warm_up_steps=l1_warm_up_steps,
    ),

    # Training Parameters
    lr=4e-4,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="cosineannealing",
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    train_batch_size_tokens=batch_size,
    context_size=512,
    # Activation Store Parameters
    n_batches_in_buffer=32,
    training_tokens=total_training_tokens,  #
    store_batch_size_prompts=16,
    # Resampling protocol
    feature_sampling_window=300,
    dead_feature_window=300,
    dead_feature_threshold=1e-4,
    # WANDB
    logger=LoggingConfig(
        log_to_wandb=False,
        wandb_project="Sparse_Autoencoder_Training",
        wandb_log_frequency=30,
        eval_every_n_wandb_logs=20,
    ),
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)
t.set_grad_enabled(True)
runner = SAETrainingRunner(cfg)
sparse_autoencoder_2 = runner.run()


[A
Refilling buffer: 100%|██████████| 16/16 [00:03<00:00,  6.10it/s][A
                                                                 [A
Refilling buffer:   0%|          | 0/16 [00:00<?, ?it/s][A
Refilling buffer:   6%|▋         | 1/16 [00:00<00:12,  1.20it/s][A
Refilling buffer:  12%|█▎        | 2/16 [00:00<00:06,  2.31it/s][A
Refilling buffer:  19%|█▉        | 3/16 [00:01<00:03,  3.28it/s][A
Refilling buffer:  25%|██▌       | 4/16 [00:01<00:02,  4.09it/s][A
Refilling buffer:  31%|███▏      | 5/16 [00:01<00:02,  4.72it/s][A
Refilling buffer:  38%|███▊      | 6/16 [00:01<00:01,  5.16it/s][A
Refilling buffer:  44%|████▍     | 7/16 [00:01<00:01,  5.44it/s][A
Refilling buffer:  50%|█████     | 8/16 [00:01<00:01,  5.64it/s][A
Refilling buffer:  56%|█████▋    | 9/16 [00:02<00:01,  5.74it/s][A
Refilling buffer:  62%|██████▎   | 10/16 [00:02<00:01,  5.75it/s][A
Refilling buffer:  69%|██████▉   | 11/16 [00:02<00:00,  5.85it/s][A
Refilling buffer:  75%|███████▌  | 12/16 [00:02<

# Saving and Recovering sae from drive

In [None]:
# Code for saving trained SAE to google drive, permission is needed
'''
from google.colab import drive
import os
drive.mount('/content/drive')
drive_save_path = "/content/drive/MyDrive/4_model"
os.makedirs(drive_save_path, exist_ok=True)
sparse_autoencoder_2.save_model(drive_save_path)

print(f"Successfully saved trained SAE to: {drive_save_path}")
'''

In [None]:
#Code for extracting Trained SAE from drive
'''
from google.colab import drive
from sae_lens import SAE
drive.mount('/content/drive')
save_folder = "/content/drive/MyDrive/4_model"
sparse_autoencoder = SAE.load_from_disk(save_folder)
print(sparse_autoencoder)
'''

# Evaluation

In [6]:

import webbrowser
import http.server
import socketserver
import threading
from google.colab import output
PORT = 8000

torch.set_grad_enabled(False)
def display_vis_inline(filename: str, height: int = 850):
    """
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    """
    global PORT

    def serve(directory):
        os.chdir(directory)
        handler = http.server.SimpleHTTPRequestHandler
        with socketserver.TCPServer(("", PORT), handler) as httpd:
            print(f"Serving files from {directory} on port {PORT}")
            httpd.serve_forever()

    thread = threading.Thread(target=serve, args=("/content",))
    thread.start()

    output.serve_kernel_port_as_iframe(
        PORT, path=f"/{filename}", height=height, cache_in_notebook=True
    )

    PORT += 1

In [7]:

from datasets import load_dataset
from pathlib import Path
from IPython.display import HTML, IFrame, display
from sae_lens import HookedSAETransformer
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
import torch as t

sparse_autoencoder = sparse_autoencoder_2
section_dir = Path("sae_vis_outputs")
section_dir.mkdir(parents=True, exist_ok=True)
model = HookedSAETransformer.from_pretrained("tiny-stories-1L-21M")
dataset = load_dataset(cfg.dataset_path, streaming=True)


# Tokenizing, long way
data_iter = iter(dataset["train"])
token_list = []
vis_context_size = 128  # We need a fixed length for the dashboard
print("Collecting and tokenizing data...")
while len(token_list) < batch_size:
    # 1. Get the raw text
    try:
        text = next(data_iter)["text"]
    except StopIteration:
        break
    batch_tokens = model.to_tokens(text, prepend_bos=True)
    if batch_tokens.shape[1] >= vis_context_size:
        token_list.append(batch_tokens[0, :vis_context_size])
tokens = t.stack(token_list)
print(f"Final tokens shape: {tokens.shape}")

#Solve mismatch problems
sparse_autoencoder.cfg.hook_name = "blocks.0.hook_mlp_out"
sparse_autoencoder.cfg.hook_layer = 0
#Move everything to cuda
device = "cuda" if t.cuda.is_available() else "cpu"
print(f"Using device: {device}")
sparse_autoencoder = sparse_autoencoder.to(device)
tinystories_model = model.to(device)
tokens = tokens.to(device)
print("All components moved to CUDA successfully.")

#From 3.2.1 evaluationlesson
sae_vis_data = SaeVisData.create(
    sparse_autoencoder,
    model=tinystories_model,
    tokens=tokens,
    cfg=SaeVisConfig(features=range(16)),
    verbose=True,
)
sae_vis_data.save_feature_centric_vis(
    filename=str(section_dir / "feature_vis.html"),
    verbose=True,
)

display_vis_inline(section_dir / "feature_vis.html")

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 0cc40801-b78f-49f4-af41-cd318536dd88)')' thrown while requesting GET https://huggingface.co/datasets/roneneldan/TinyStories/resolve/f54c09fd23315a6f9c86f9dc80f725de7d8f9c64/data/train-00000-of-00004-2d5a1467fff1081b.parquet
Retrying in 1s [Retry 1/5].


Collecting and tokenizing data...
Final tokens shape: torch.Size([4096, 128])
Using device: cuda
Moving model to device:  cuda
All components moved to CUDA successfully.


Forward passes to cache data for vis:   0%|          | 0/64 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/16 [00:00<?, ?it/s]

  assert torch.tensor(shape).prod().item() == index_tensor[idx].numel(), \
  full_idx_item = index_tensor[idx].reshape(*shape)
  arr_indexed = arr[full_idx]


Saving feature-centric vis:   0%|          | 0/16 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

Serving files from /content on port 8000


# **L0 & MSEComparison**

## *What does L0 mean?*
### L0 computes the average number of neurons activated per neuron




In [8]:
import plotly.express as px
import torch

# Ensure we aren't tracking gradients to save memory
with torch.no_grad():
    batch_tokens = tokens[:32]
    hook_name = sparse_autoencoder.cfg.metadata.hook_name
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    model_acts = cache[hook_name]
    model_l0 = (model_acts[:, 1:] > 0).float().sum(-1).detach()

    print("Average Model L0:", model_l0.mean().item())

    fig = px.histogram(
        model_l0.flatten().cpu().numpy(),
        title=f"Model L0 Distribution (Layer: {hook_name})",
        labels={'value': 'Number of Active Neurons (L0)', 'count': 'Frequency'}
    )
    fig.show()

    # Cleanup
    del cache

Average Model L0: 499.3648986816406


In [9]:

sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who grads

with torch.no_grad():
    # activation store can give us tokens.
    batch_tokens = tokens[:32]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    original_acts = cache[sparse_autoencoder.cfg.metadata.hook_name]
    original_var = (original_acts - original_acts.mean(0)).pow(2).mean()
    feature_acts = sparse_autoencoder.encode(original_acts)
    sae_out = sparse_autoencoder.decode(feature_acts)
    mse_manual = (original_acts - sae_out).pow(2).mean()
    mse = mse_manual.item()
    explained_variance = 1 - (mse / original_var.item())
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("Average L0:", l0.mean().item())
    print(f"Original Variance: {original_var.item():.4f}")
    print(f"Explained Variance: {explained_variance:.2%}")
    px.histogram(l0.flatten().cpu().numpy()).show()

Average L0: 86.32579040527344
Original Variance: 0.1454
Explained Variance: 57.53%


# Checking for dead features

In [10]:
import plotly.express as px
import pandas as pd
import numpy as np
import torch

# Ensure we aren't tracking gradients
with torch.no_grad():
    batch_tokens = tokens[:32] # Note: With a small batch, rare features might appear "False Dead"

    # 1. Run Model & SAE
    hook_name = sparse_autoencoder.cfg.metadata.hook_name
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    feature_acts = sparse_autoencoder.encode(cache[hook_name])

    # Flatten: [Batch * Seq_Len, n_features]
    flat_acts = feature_acts.reshape(-1, feature_acts.shape[-1])

    # 2. Count Activations
    # (How many tokens triggered each specific feature?)
    did_fire = (flat_acts > 0).float().sum(dim=0)

    # Convert to numpy for plotting
    fire_counts = did_fire.cpu().numpy()

    # 3. Calculate Dead Stats
    n_features = len(fire_counts)
    n_dead = (fire_counts == 0).sum()
    pct_dead = n_dead / n_features

    print(f"Dead Feature Frequency: {pct_dead:.2%}")
    print(f"Alive Features: {n_features - n_dead}")

    # --- PLOTTING CODE ---

    # We filter out dead features to plot the log-distribution of the alive ones
    # (Log scale is essential because feature frequency usually follows a Power Law)
    alive_counts = fire_counts[fire_counts > 0]

    if len(alive_counts) > 0:
        # Create a DataFrame for Plotly
        # We take log10 so the x-axis is readable (1, 10, 100, 1000 activations)
        df_plot = pd.DataFrame({
            'Log10(Activations)': np.log10(alive_counts)
        })

        fig = px.histogram(
            df_plot,
            x="Log10(Activations)",
            nbins=100,
            title=f"Feature Activation Distribution (Dead: {pct_dead:.2%})",
            labels={'count': 'Number of Features'},
            color_discrete_sequence=['#636EFA'] # Standard Plotly Blue
        )

        # Add a red box/annotation to explicitly show the Dead Count
        fig.add_annotation(
            x=0.95, y=0.95,
            xref="paper", yref="paper",
            text=f"<b>Dead Features:</b><br>{n_dead} ({pct_dead:.2%})",
            showarrow=False,
            bgcolor="crimson",
            bordercolor="black",
            font=dict(color="white")
        )

        # Update x-axis to show real numbers (1, 10, 100) instead of just 0, 1, 2
        fig.update_layout(
            xaxis_title="Log10(Count) - (0=1 firing, 1=10 firings, 2=100 firings)"
        )

        fig.show()
    else:
        print("All features are dead! (Check your model inputs or SAE hook)")

    # Cleanup
    del cache

Dead Feature Frequency: 4.57%
Alive Features: 3909


In [11]:
import plotly.express as px
import torch
import pandas as pd
import numpy as np
import umap

sparse_autoencoder.eval()

with torch.no_grad():

    batch_tokens = tokens[:128]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Get SAE activations
    hook_name = sparse_autoencoder.cfg.metadata.hook_name
    feature_acts = sparse_autoencoder.encode(cache[hook_name])

    # Count firings per feature
    flat_acts = feature_acts.reshape(-1, feature_acts.shape[-1])
    fire_counts = (flat_acts > 0).float().sum(dim=0).cpu().numpy()
    log_freq = np.log10(fire_counts + 1)
    weights = sparse_autoencoder.W_dec.detach().cpu().numpy()
    norms = np.linalg.norm(weights, axis=1, keepdims=True)
    weights_normalized = weights / (norms + 1e-8)

    print("Running UMAP projection")

    # Run UMAP
    reducer = umap.UMAP(
        n_neighbors=15,
        min_dist=0.1,
        metric='cosine', # Cosine distance is best for semantic vectors
        random_state=42
    )
    embedding = reducer.fit_transform(weights_normalized)

    # plot
    df_map = pd.DataFrame({
        'x': embedding[:, 0],
        'y': embedding[:, 1],
        'activations': fire_counts,
        'log_activations': log_freq,
        'feature_index': range(len(weights))
    })

    fig = px.scatter(
        df_map,
        x='x',
        y='y',
        color='log_activations', # Color by how "alive" the neuron is
        hover_data=['feature_index', 'activations'],
        title='Map of SAE Neurons (UMAP of Decoder Weights)',
        color_continuous_scale='Viridis', # 'Viridis' or 'Plasma' look good for heatmaps
        labels={'log_activations': 'Log10(Freq)'}
    )

    # Make points smaller so dense clusters are visible
    fig.update_traces(marker=dict(size=3, opacity=0.7))
    fig.show()

    del cache

Running UMAP projection



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.

