# Llama 3 8B SAE Analysis

Notebook to analyze the locally trained SAE saved under `runs/20251213_174540_llama3_8b`.
The flow mirrors `tutorial-analysis.ipynb`, but loads weights from disk instead of the hub.


In [1]:
# Core imports
import json
from pathlib import Path
import math
from itertools import islice

import torch
from datasets import load_dataset, Dataset
from safetensors.torch import load_file
import plotly.express as px

from transformer_lens import HookedTransformer, utils
from transformer_lens.utils import tokenize_and_concatenate

from sae_lens import SAE
from sae_lens.saes.sae import SAEConfig


In [2]:
# Enumerate the CUDA devices PyTorch can see, one "<index> <name>" line each,
# so a GPU index can be chosen in the configuration cell.
gpu_count = torch.cuda.device_count()
for gpu_id in range(gpu_count):
    print(gpu_id, torch.cuda.get_device_name(gpu_id))

0 NVIDIA RTX A6000
1 NVIDIA RTX A6000
2 NVIDIA RTX A6000
3 NVIDIA RTX A6000


In [3]:
# Preferred CUDA device index; edit this to target a different GPU.
GPU_INDEX = 3

# Pick the compute device: the preferred CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    # Guard against a preferred index that does not exist on this machine,
    # which would otherwise produce an invalid "cuda:<n>" device string that
    # every later cell inherits.
    if GPU_INDEX >= torch.cuda.device_count():
        print(f"GPU {GPU_INDEX} not available; falling back to GPU 0")
        GPU_INDEX = 0
    DEVICE = f"cuda:{GPU_INDEX}"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# bfloat16 is selected only for CUDA devices; every other backend uses float32.
DTYPE = torch.bfloat16 if DEVICE.startswith('cuda') else torch.float32
print(f"Using device: {DEVICE}, dtype: {DTYPE}")

Using device: cuda:3, dtype: torch.bfloat16


In [4]:
# Locate the artifacts written by the training run
RUN_DIR = Path('runs/20251213_174540_llama3_8b')
SAE_DIR = RUN_DIR / 'final_sae'
SAE_CFG = SAE_DIR / 'cfg.json'
SAE_WEIGHTS = SAE_DIR / 'sae_weights.safetensors'
SPARSITY_STATS = RUN_DIR / 'sparsity.safetensors'

# Read the saved config and point it at the device selected for this session
sae_cfg_dict = json.loads(SAE_CFG.read_text())
sae_cfg_dict['device'] = DEVICE
sae_config = SAEConfig.from_dict(sae_cfg_dict)

# Build the SAE class matching the configured architecture, then restore the
# trained weights and switch to eval mode
sae_cls = SAE.get_sae_class_for_architecture(sae_config.architecture())
sae = sae_cls(sae_config)
state_dict = load_file(str(SAE_WEIGHTS), device=DEVICE)
sae.load_state_dict(state_dict)
sae.to(DEVICE)
sae.eval()

# Sparsity statistics are optional: stay at None when the file is absent or
# has no 'sparsity' entry
log_sparsity = (
    load_file(str(SPARSITY_STATS), device='cpu').get('sparsity')
    if SPARSITY_STATS.exists()
    else None
)

print(f"Loaded SAE with d_in={sae.cfg.d_in}, d_sae={sae.cfg.d_sae}, hook={sae.cfg.metadata.hook_name}")
if log_sparsity is not None:
    print(f'Mean log sparsity: {log_sparsity.mean().item():.3f}')


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


Loaded SAE with d_in=4096, d_sae=16384, hook=blocks.15.hook_resid_pre
Mean log sparsity: -4.510


## Load model and dataset

We load the matching Llama 3 8B model through TransformerLens and tokenize the same dataset that the SAE was trained on.


In [5]:
# Load the base model named in the SAE metadata (may take time and VRAM;
# adjust kwargs if needed). The dtype comes from the SAE config, defaulting to
# float32, and LayerNorm folding / weight centering are switched off.
model_dtype = getattr(torch, sae_cfg_dict.get('dtype', 'float32'))
load_kwargs = dict(
    device=DEVICE,
    dtype=model_dtype,
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)
model = HookedTransformer.from_pretrained(sae.cfg.metadata.model_name, **load_kwargs)

# Pull the settings the later cells need out of the SAE metadata
sae_meta = sae.cfg.metadata
hook_name = sae_meta.hook_name
context_size = sae_meta.context_size
prepend_bos = sae_meta.prepend_bos
print(f'Model loaded; hook point: {hook_name}')


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer
Model loaded; hook point: blocks.15.hook_resid_pre


In [6]:
# Dataset path stored in the SAE metadata; the tokenization cell streams from it
sae_dataset_path = sae.cfg.metadata.dataset_path
print(sae_dataset_path)

monology/pile-uncopyrighted


In [7]:
# Stream a bounded prefix of the dataset instead of downloading everything.
# The bound is ceil(MAX_TOKENS / context_size) streamed rows.
# NOTE(review): this caps raw documents, not tokens, so the tokenized total
# may not equal MAX_TOKENS - confirm this is the intended budget.
MAX_TOKENS = 10_000_000
max_sequences = max(1, math.ceil(MAX_TOKENS / context_size))

raw_stream = load_dataset(
    path=sae.cfg.metadata.dataset_path,
    split='train',
    streaming=True,
)

# Materialize the prefix in memory, keeping only rows that have a text field
raw_samples = list(islice(raw_stream, max_sequences))
texts = [sample['text'] for sample in raw_samples if 'text' in sample]
sample_dataset = Dataset.from_dict({'text': texts})

# Tokenize and pack the text into rows of context_size tokens
tokenize_kwargs = dict(
    max_length=context_size,
    column_name='text',
    add_bos_token=prepend_bos,
    streaming=False,
)
token_dataset = tokenize_and_concatenate(sample_dataset, model.tokenizer, **tokenize_kwargs)

# Take up to 32 rows as the evaluation batch; the slice is stacked first when
# it does not come back as a single tensor
num_batch = min(32, len(token_dataset))
tokens_slice = token_dataset[:num_batch]['tokens']
stacked_tokens = tokens_slice if isinstance(tokens_slice, torch.Tensor) else torch.stack(tokens_slice)
batch_tokens = stacked_tokens.to(DEVICE)
print('Tokenized dataset ready:', batch_tokens.shape)


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Map (num_proc=10):   0%|          | 0/391 [00:00<?, ? examples/s]

Tokenized dataset ready: torch.Size([32, 256])


## L0 sparsity check

Compute the number of active SAE features per token to sanity-check sparsity.


In [8]:
sae.eval()

# For a 'blocks.<n>...' hook, pass stop_at_layer = n + 1 so the forward pass
# does not have to run the whole model; any other hook name runs it in full.
if hook_name.startswith('blocks.'):
    target_layer = int(hook_name.split('.')[1])
    stop_layer = target_layer + 1
else:
    target_layer = None
    stop_layer = None

with torch.no_grad():
    tokens = batch_tokens
    _, cache = model.run_with_cache(
        tokens,
        prepend_bos=prepend_bos,
        names_filter=[hook_name],
        stop_at_layer=stop_layer,
    )
    acts = cache[hook_name]
    feature_acts = sae.encode(acts)
    # L0 = number of strictly positive feature activations at each position
    l0 = feature_acts.gt(0).sum(dim=-1).float()
    if prepend_bos:
        # Leave the first position of every sequence out of the statistics
        l0 = l0[:, 1:]
    l0_flat = l0.flatten().cpu()
    l0_values = l0_flat.numpy()

l0_mean = l0_flat.mean().item()
l0_std = l0_flat.std().item()
print(f'Mean L0: {l0_mean:.2f} +/- {l0_std:.2f}')

# Drop the cached activations and hand freed blocks back to the CUDA allocator
del cache, acts, feature_acts
if torch.cuda.is_available():
    torch.cuda.empty_cache()

hist_labels = {
    'x': 'Number of SAE features active per token (L0 count)',
    'y': 'Token positions in sampled batch',
}
fig = px.histogram(
    x=l0_values,
    labels=hist_labels,
    title='Distribution of SAE feature sparsity (L0)',
)
fig.show()


Mean L0: 148.01 +/- 32.93


## Reconstruction vs ablation

Compare the model's loss on the sampled batch in three settings: (a) unmodified, (b) with the hooked activations replaced by the SAE reconstruction, and (c) with the hooked activations zeroed out.


In [9]:
def sae_reconstruction_hook(activation, hook):
    """Forward hook that replaces the hooked activation with the SAE's output.

    Args:
        activation: Activation tensor at the hook point.
        hook: Hook point object; accepted for the hook signature, not used.

    Returns:
        The result of calling the module-level `sae` on `activation`.
    """
    reconstruction = sae(activation)
    return reconstruction

def zero_ablation_hook(activation, hook):
    """Forward hook that replaces the hooked activation with zeros.

    Args:
        activation: Activation tensor at the hook point.
        hook: Hook point object; accepted for the hook signature, not used.

    Returns:
        A new all-zero tensor with the shape, dtype and device of `activation`.
    """
    return activation.new_zeros(activation.shape)

# Evaluation only, so run every forward pass under no_grad. With autograd
# enabled each pass keeps its intermediate activations alive for a backward
# pass that never happens; the recorded run of this cell without the guard
# ended in a CUDA out-of-memory error.
with torch.no_grad():
    # Baseline: unmodified model
    orig_loss = model(batch_tokens, return_type='loss', prepend_bos=prepend_bos).item()
    # Hooked activations replaced by the SAE output
    recon_loss = model.run_with_hooks(
        batch_tokens,
        return_type='loss',
        fwd_hooks=[(hook_name, sae_reconstruction_hook)],
        prepend_bos=prepend_bos,
    ).item()
    # Hooked activations replaced by zeros
    ablate_loss = model.run_with_hooks(
        batch_tokens,
        return_type='loss',
        fwd_hooks=[(hook_name, zero_ablation_hook)],
        prepend_bos=prepend_bos,
    ).item()

print(f'Original loss:     {orig_loss:.4f}')
print(f'Reconstruction loss:{recon_loss:.4f}')
print(f'Zero ablation loss:{ablate_loss:.4f}')


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 3 has a total capacity of 47.40 GiB of which 12.62 MiB is free. Including non-PyTorch memory, this process has 47.37 GiB memory in use. Of the allocated memory 47.05 GiB is allocated by PyTorch, and 5.14 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Quick prompt check

Run a short prompt through the SAE-reconstructed activations to spot obvious regressions.


In [10]:
prompt = 'When John and Mary went to the shops, John gave the bag to'
answer = ' Mary'

# Inference only: no_grad stops autograd from retaining activations for
# either forward pass.
with torch.no_grad():
    print('Baseline:')
    utils.test_prompt(prompt, answer, model, prepend_bos=True)

    print('\nWith SAE reconstruction:')
    # run_with_hooks is the call that installs fwd_hooks for a single pass and
    # returns the model output directly; no activation cache is needed here.
    recon_logits = model.run_with_hooks(
        prompt,
        prepend_bos=True,
        fwd_hooks=[(hook_name, sae_reconstruction_hook)],
    )
    # Probability assigned to the answer token at the final prompt position
    probs = torch.softmax(recon_logits[0, -1], dim=-1)
    answer_token = model.to_single_token(answer)
    print(f'P(answer) with SAE: {probs[answer_token].item():.4f}')


Baseline:
Tokenized prompt: ['<|begin_of_text|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']
Tokenized answer: [' Mary']


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 3 has a total capacity of 47.40 GiB of which 4.62 MiB is free. Including non-PyTorch memory, this process has 47.38 GiB memory in use. Of the allocated memory 47.06 GiB is allocated by PyTorch, and 8.46 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## (Optional) Feature dashboards

If `sae-dashboard` is installed, uncomment the cell below to generate feature-centric HTML dashboards for a few features.


In [None]:
# from sae_dashboard.sae_vis_data import SaeVisConfig
# from sae_dashboard.sae_vis_runner import SaeVisRunner
# from sae_dashboard.data_writing_fns import save_feature_centric_vis
#
# feature_ids = list(range(10))
# vis_runner = SaeVisRunner(
#     model=model,
#     sae=sae,
#     hook_name=hook_name,
#     dataset=token_dataset,
#     cfg=SaeVisConfig(num_examples=256),
# )
# vis_data = vis_runner.run(feature_ids)
# save_feature_centric_vis(vis_data, filename='demo_feature_dashboards.html')
