# Llama 3 8B SAE reuse / analysis

Load a trained SAE for Llama 3 8B, optionally attach it to the base model, and run quick sparsity + reconstruction diagnostics. Defaults to 4-bit quant, with a toggle to run without quantization.


## How to use
- Pick the latest run automatically (or set `RUN_DIR` to a specific `runs/*_llama3_8b`).
- Choose quantization mode: `"4bit"` (default), `"none"`, or `"from_run"` (HF config only; interventions strip quant flags).
- Flip `RUN_MODEL` to `True` to load the base model for text generation + feature probes (heavy); leave `False` for quick SAE-only checks.
- Run the cells in order; sparsity + reconstruction metrics help catch bad checkpoints.


In [1]:
# Imports and device pick
from pathlib import Path
import json
import re
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sae_lens import SAE, HookedSAETransformer
from sae_lens.constants import DTYPE_MAP
from safetensors.torch import load_file

# pick device; default to CPU if CUDA/MPS unavailable
DEVICE = (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
print(f"Using device: {DEVICE}")


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# Locate the most recent run directory (suffix `_llama3_8b`).
runs_root = Path("runs")
candidate_runs = sorted(
    [p for p in runs_root.glob("*_llama3_8b") if p.is_dir()],
    key=lambda p: p.stat().st_mtime,
    reverse=True,
)

# Enter your desired run directory here, or leave as None to auto-select
RUN_DIR = None  # e.g., Path("runs/20240622-120000_llama3_8b")
if RUN_DIR:
    run_dir = Path(RUN_DIR)
elif candidate_runs:
    run_dir = candidate_runs[0]
else:
    raise FileNotFoundError("No *_llama3_8b run directories found in ./runs/")
print(f"Selected run dir: {run_dir.resolve()}")

sae_dir = run_dir / "final_sae"
if not sae_dir.exists():
    # fallback to runs/<name> directly if final_sae isn't present
    sae_dir = run_dir

cfg_path = sae_dir / "cfg.json"
runner_cfg_path = run_dir / "runner_cfg.json"
assert cfg_path.exists(), f"SAE config not found at {cfg_path}"

runner_cfg: dict[str, Any] = json.loads(runner_cfg_path.read_text()) if runner_cfg_path.exists() else {}

# Convert dtype strings like "torch.bfloat16" to actual dtypes
hf_load_kwargs_raw = runner_cfg.get("model_from_pretrained_kwargs", {})

def resolve_dtype(val: Any) -> Any:
    if isinstance(val, str) and val in DTYPE_MAP:
        return DTYPE_MAP[val]
    return val

hf_load_kwargs = {k: resolve_dtype(v) for k, v in hf_load_kwargs_raw.items()}

hook_name = runner_cfg.get("hook_name", "model.layers.15.mlp.down_proj")
model_name = runner_cfg.get("model_name", "meta-llama/Meta-Llama-3-8B")
context_size = runner_cfg.get("context_size", 256)

# Translate HF module paths to TransformerLens hook names

def tl_hook_name_from_hf(path: str) -> str:
    m = re.fullmatch(r"model\.layers\.(\d+)\.mlp\.down_proj", path)
    if m:
        return f"blocks.{m.group(1)}.hook_mlp_out"
    m = re.fullmatch(r"model\.layers\.(\d+)\.self_attn\.o_proj", path)
    if m:
        return f"blocks.{m.group(1)}.attn.hook_result"
    return path

original_hook_name = hook_name
hook_name = tl_hook_name_from_hf(hook_name)
if hook_name != original_hook_name:
    print(f"Translated hook name: {original_hook_name} -> {hook_name}")

# Quantization override: "4bit" (default), "none", or "from_run" (use runner cfg as-is)
QUANTIZATION_MODE = "4bit"
allowed_modes = {"4bit", "none", "from_run"}
if QUANTIZATION_MODE not in allowed_modes:
    raise ValueError(f"QUANTIZATION_MODE must be one of {allowed_modes}")

if QUANTIZATION_MODE != "from_run":
    q_kwargs = {k: v for k, v in hf_load_kwargs.items() if not k.startswith("load_in")}
    for k in ["bnb_4bit_quant_type", "bnb_4bit_compute_dtype", "bnb_4bit_use_double_quant"]:
        q_kwargs.pop(k, None)
    if QUANTIZATION_MODE == "4bit":
        compute_dtype = q_kwargs.get("torch_dtype", torch.bfloat16 if DEVICE != "cpu" else torch.float32)
        q_kwargs.update({
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": compute_dtype,
            "bnb_4bit_use_double_quant": True,
        })
    # QUANTIZATION_MODE == "none" leaves q_kwargs without 4/8-bit flags
    hf_load_kwargs = q_kwargs
print(f"Quantization mode: {QUANTIZATION_MODE}")
# Note: HookedSAETransformer will receive sanitized kwargs; quantization is driven by the HF model when needed.

# Avoid auto offload unless you really want it; meta tensors can break hook usage
USE_DEVICE_MAP = False
if not USE_DEVICE_MAP:
    hf_load_kwargs.pop("device_map", None)
    hf_load_kwargs.pop("offload_folder", None)

model_load_kwargs = dict(hf_load_kwargs)
print(f"Hook: {hook_name}, Model: {model_name}, Context size: {context_size}")


Selected run dir: /ssd/jdh/interpretability/experiments/exp-sae-lens/runs/20251213_002417_llama3_8b
Translated hook name: model.layers.15.mlp.down_proj -> blocks.15.hook_mlp_out
Quantization mode: 4bit
Hook: blocks.15.hook_mlp_out, Model: meta-llama/Meta-Llama-3-8B, Context size: 256


In [3]:
# Load the SAE (always). Model loads happen sequentially to save VRAM.
RUN_MODEL = True  # flip to True to run generation/feature probes (heavy)

sae = SAE.load_from_disk(sae_dir, device=DEVICE)
sae.eval()
# Ensure SAE hook matches TransformerLens naming
sae.cfg.metadata.hook_name = hook_name
print(f"Loaded SAE arch={sae.cfg.architecture()} d_in={sae.cfg.d_in} d_sae={sae.cfg.d_sae} dtype={sae.dtype} hook={sae.cfg.metadata.hook_name}")

# HookedSAETransformer cannot accept raw bitsandbytes flags via from_pretrained_kwargs.
# Build a quantized HF model when needed and strip quant flags for TransformerLens kwargs.
_intervention_blocklist = {
    "load_in_4bit",
    "load_in_8bit",
    "bnb_4bit_quant_type",
    "bnb_4bit_compute_dtype",
    "bnb_4bit_use_double_quant",
    "from_pretrained_kwargs",
}
intervention_load_kwargs = {k: v for k, v in model_load_kwargs.items() if k not in _intervention_blocklist}


def wants_4bit_quant() -> bool:
    return QUANTIZATION_MODE == "4bit" or (
        QUANTIZATION_MODE == "from_run" and hf_load_kwargs.get("load_in_4bit")
    )


def build_quantized_hf_model():
    if DEVICE == "cpu":
        raise ValueError("4-bit quantization requires CUDA; switch QUANTIZATION_MODE or device.")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=model_load_kwargs.get("bnb_4bit_quant_type", "nf4"),
        bnb_4bit_compute_dtype=model_load_kwargs.get(
            "bnb_4bit_compute_dtype",
            model_load_kwargs.get("torch_dtype", torch.bfloat16),
        ),
        bnb_4bit_use_double_quant=model_load_kwargs.get("bnb_4bit_use_double_quant", True),
    )
    hf_kwargs = {
        k: v
        for k, v in model_load_kwargs.items()
        if not k.startswith("load_in") and not k.startswith("bnb_4bit")
    }
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto" if USE_DEVICE_MAP else None,
        **hf_kwargs,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    return hf_model, tokenizer


def load_base_model():
    if wants_4bit_quant():
        hf_model, tokenizer = build_quantized_hf_model()
        return HookedSAETransformer.from_pretrained(
            model_name,
            hf_model=hf_model,
            tokenizer=tokenizer,
            device=DEVICE,
            move_to_device=False,  # HF already placed weights according to device_map
            from_pretrained_kwargs=intervention_load_kwargs,
        )

    return HookedSAETransformer.from_pretrained(
        model_name,
        device=DEVICE,
        from_pretrained_kwargs=intervention_load_kwargs,
    )


Loaded SAE arch=standard d_in=4096 d_sae=16384 dtype=torch.float32 hook=blocks.15.hook_mlp_out


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [4]:
# Quick sparsity snapshot
sparsity_path = run_dir / "sparsity.safetensors"
if not sparsity_path.exists():
    alt = sae_dir / "sparsity.safetensors"
    sparsity_path = alt if alt.exists() else None

if sparsity_path and sparsity_path.exists():
    log_sparsity = load_file(sparsity_path)["sparsity"]
    est_l0 = torch.exp(log_sparsity).mean().item() * log_sparsity.numel()
    print(f"Estimated active features per token (L0): {est_l0:.1f} / {log_sparsity.numel()}")
else:
    print("No sparsity file found; skipping sparsity estimate.")


Estimated active features per token (L0): 3135.3 / 16384


In [5]:
# Helper utilities for sequential generation/loss and feature grabs.
import contextlib

prompt_text = "Once upon a time, a curious robot learned to help humans"

def run_generation_and_loss(with_sae: bool = False):
    model = load_base_model()
    ctx = model.saes(saes=[sae]) if with_sae else contextlib.nullcontext()
    with torch.no_grad(), ctx:
        story = model.generate(
            prompt_text,
            max_new_tokens=80,
            temperature=0.7,
            stop_at_eos=True,
            verbose=False,
        )
        tokens = model.to_tokens(prompt_text, prepend_bos=True)
        logits = model(tokens)
        loss = torch.nn.functional.cross_entropy(
            logits[0, :-1].reshape(-1, logits.size(-1)),
            tokens[0, 1:].reshape(-1),
        ).item()
    del model
    torch.cuda.empty_cache()
    return story, loss

def grab_feature_activations():
    model = load_base_model()
    with torch.no_grad():
        tokens = model.to_tokens(prompt_text, prepend_bos=True)
        _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
        mlp_out = cache[hook_name]
        feature_acts = sae.encode(mlp_out)
    del model
    torch.cuda.empty_cache()
    return mlp_out, feature_acts

def feature_subset_logits(num_features: int = 10):
    model = load_base_model()
    with torch.no_grad():
        tokens = model.to_tokens(prompt_text, prepend_bos=True)
        base_logits = model(tokens)
        _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
        mlp_out = cache[hook_name]
        feature_acts = sae.encode(mlp_out)

        rand_idx = torch.randperm(feature_acts.size(-1), device=feature_acts.device)[:num_features]
        kept_features = torch.zeros_like(feature_acts)
        kept_features[..., rand_idx] = feature_acts[..., rand_idx]
        recon_from_subset = sae.decode(kept_features)

        def swap_mlp_out(acts, hook):
            return recon_from_subset

        hooked_logits = model.run_with_hooks(
            tokens, fwd_hooks=[(hook_name, swap_mlp_out)]
        )
    del model
    torch.cuda.empty_cache()
    return base_logits, hooked_logits, rand_idx





In [6]:
# Compare generations and prompt loss with/without SAE reconstruction (sequential loads).
if RUN_MODEL:
    base_story, base_loss = run_generation_and_loss(with_sae=False)
    sae_story, sae_loss = run_generation_and_loss(with_sae=True)

    print("--- Base model ---")
    print(base_story)
    print("--- With SAE reconstruction ---")
    print(sae_story)
    print(f"Next-token loss without SAE: {base_loss:.3f}")
    print(f"Next-token loss with SAE:    {sae_loss:.3f}")
else:
    print("Set RUN_MODEL=True to run generation and loss comparisons.")



`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.76s/it]


RuntimeError: The size of tensor a (8388608) must match the size of tensor b (4096) at non-singleton dimension 1

In [None]:
# Top activated SAE features on the last token of the prompt (sequential load).
if RUN_MODEL:
    _, feature_acts = grab_feature_activations()
    last_token_acts = feature_acts[0, -1]
    values, indices = torch.topk(last_token_acts, k=10)
    print("Top activated features on the last token:")
    for val, idx in zip(values.tolist(), indices.tolist()):
        print(f"  Feature {idx:5d} -> activation {val:.2f}")
else:
    print("Set RUN_MODEL=True to inspect activated features.")



In [None]:
# Effect of a random subset of SAE features on logits for the prompt (sequential load).
if RUN_MODEL:
    base_logits, hooked_logits, rand_idx = feature_subset_logits(num_features=10)

    def show_top_ids(logits, label, k=10):
        probs = logits.softmax(-1)
        vals, idxs = torch.topk(probs, k)
        print(label)
        for v, i in zip(vals.tolist(), idxs.tolist()):
            print(f"  token_id={i:5d} prob={v:.3f}")
        print()

    base_last = base_logits[0, -1]
    hooked_last = hooked_logits[0, -1]
    print(f"Random feature IDs: {sorted(rand_idx.tolist())}")
    show_top_ids(base_last, "Base model last-token distribution:")
    show_top_ids(hooked_last, "Only these features reconstructed:")
else:
    print("Set RUN_MODEL=True to run feature-subset logits.")



In [None]:
# Reconstruction quality on the prompt (sequential). Falls back to synthetic if RUN_MODEL=False.
if RUN_MODEL:
    mlp_out, feature_acts = grab_feature_activations()
else:
    mlp_out = torch.randn(4, 8, sae.cfg.d_in, device=sae.device, dtype=sae.dtype)
    feature_acts = sae.encode(mlp_out)

with torch.no_grad():
    recon = sae.decode(feature_acts)

    mse = torch.nn.functional.mse_loss(recon, mlp_out).item()
    l2_orig = mlp_out.pow(2).mean().sqrt().item()
    l2_err = (recon - mlp_out).pow(2).mean().sqrt().item()
    explained = 1 - (l2_err**2 / (l2_orig**2 + 1e-9))
    density = (feature_acts.abs() > 1e-6).float().mean().item()

print(
    f"Recon MSE: {mse:.6f}
"
    f"Orig L2 (mean): {l2_orig:.4f}
"
    f"Error L2 (mean): {l2_err:.4f}
"
    f"Explained variance (approx): {explained:.4f}
"
    f"Mean activation density: {density:.4f}"
)




### Reading the diagnostics
- High `Recon MSE` / `Error L2` or low `Explained variance` may indicate a bad checkpoint or mismatched hook.
- `Mean activation density` tracks sparsity; big jumps suggest L1/normalization changes.
- Compare base vs SAE next-token loss to see whether reconstruction is neutral or harmful for the prompt.
- Top-activated features and feature-subset logits help sanity-check which directions the SAE is using.
