# SAE + Llama3 8B (4-bit) quick test
Minimal example mirroring the tutorial style: load the SAE, load a quantized model, attach the SAE, and generate text.


In [8]:
# Imports
from pathlib import Path
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sae_lens import SAE, HookedSAETransformer


In [9]:
# Load latest SAE from runs
runs_root = Path("runs")
candidate_runs = sorted([p for p in runs_root.glob("*_llama3_8b") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not candidate_runs:
    raise FileNotFoundError("No *_llama3_8b run directories found.")
run_dir = candidate_runs[0]
sae_dir = run_dir / "final_sae"
if not sae_dir.exists():
    sae_dir = run_dir

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sae = SAE.load_from_disk(sae_dir, device=DEVICE)
sae.eval()

# Map HF module paths to TransformerLens hook points (same pattern as tutorial uses TL hook names)
def tl_hook_name(path: str) -> str:
    m = re.fullmatch(r"model\.layers\.(\d+)\.mlp\.down_proj", path)
    if m:
        return f"blocks.{m.group(1)}.hook_mlp_out"
    m = re.fullmatch(r"model\.layers\.(\d+)\.self_attn\.o_proj", path)
    if m:
        return f"blocks.{m.group(1)}.attn.hook_result"
    return path

sae.cfg.metadata.hook_name = tl_hook_name(sae.cfg.metadata.hook_name)
print(f"Loaded SAE {sae_dir.name} -> hook {sae.cfg.metadata.hook_name}, d_in={sae.cfg.d_in}, d_sae={sae.cfg.d_sae}")


Loaded SAE final_sae -> hook blocks.15.hook_mlp_out, d_in=4096, d_sae=16384


In [10]:
# Inspect SAE encode/decode shapes on dummy activation
with torch.no_grad():
    dummy = torch.zeros(1, 1, sae.cfg.d_in, device=sae.device, dtype=sae.dtype)
    encoded = sae.encode(dummy)
    decoded = sae.decode(encoded)
print(f"encode shape: {encoded.shape}, decode shape: {decoded.shape}")


encode shape: torch.Size([1, 1, 16384]), decode shape: torch.Size([1, 1, 4096])


In [11]:
# Load 4-bit quantized model and wrap in HookedSAETransformer
model_name = sae.cfg.metadata.model_name
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
    bnb_4bit_use_double_quant=True,
)

hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=None,
    torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

model = HookedSAETransformer.from_pretrained(
    model_name,
    hf_model=hf_model,
    tokenizer=tokenizer,
    device=DEVICE,
    move_to_device=False,
)
model.eval()
print("Model ready (4-bit)")


Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.76s/it]


RuntimeError: The size of tensor a (8388608) must match the size of tensor b (4096) at non-singleton dimension 1

In [None]:
# Sanity check: ensure hook exists and activation width matches SAE d_in
assert sae.cfg.metadata.hook_name in model.hook_dict, f'Hook {sae.cfg.metadata.hook_name} not in model.hook_dict'
with torch.no_grad():
    tokens = model.to_tokens('Hello world', prepend_bos=True)
    _, cache = model.run_with_cache(tokens, names_filter=[sae.cfg.metadata.hook_name])
    acts = cache[sae.cfg.metadata.hook_name]
    print('Hook activation shape', acts.shape)
    assert acts.size(-1) == sae.cfg.d_in, f'Activation last dim {acts.size(-1)} != SAE d_in {sae.cfg.d_in}. Check hook name or SAE config.'
    flat = acts.view(-1, acts.size(-1))
    encoded = sae.encode(flat)
    print('SAE encode shape', encoded.shape)


In [None]:
# Generate with SAE attached
prompt = "Once upon a time, a helpful robot"
with model.saes(saes=[sae]):
    output = model.generate(
        prompt,
        max_new_tokens=60,
        temperature=0.7,
        stop_at_eos=True,
        verbose=False,
    )
print(output)
