# SAE + Llama3 8B minimal test
Load the latest SAE, load the model directly via `HookedSAETransformer.from_pretrained` (no HF wrapper), and generate text with the SAE attached.


In [None]:
# Imports
from pathlib import Path
import torch
from sae_lens import SAE, HookedSAETransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load latest SAE
# Collect every run directory matching the Llama3-8B naming pattern and order
# them newest-first by directory modification time.
runs_root = Path('runs')
candidate_runs = [p for p in runs_root.glob('*_llama3_8b') if p.is_dir()]
candidate_runs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
if not candidate_runs:
    raise FileNotFoundError('No *_llama3_8b run directories found.')

# Prefer the run's `final_sae` sub-directory; fall back to the run directory
# itself when that sub-directory is absent.
run_dir = candidate_runs[0]
sae_dir = run_dir / 'final_sae'
sae_dir = sae_dir if sae_dir.exists() else run_dir

# Pick device and matching dtype: bfloat16 on GPU, float32 on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = 'cuda', torch.bfloat16
else:
    DEVICE, DTYPE = 'cpu', torch.float32

sae = SAE.load_from_disk(sae_dir, device=DEVICE)
sae.eval()

# Translate HF module path to TransformerLens hook name
import re
def tl_hook_name(path: str) -> str:
    """Translate a Hugging Face module path into a TransformerLens hook name.

    Recognised paths:
      * ``model.layers.<N>.mlp.down_proj``    -> ``blocks.<N>.hook_mlp_out``
      * ``model.layers.<N>.self_attn.o_proj`` -> ``blocks.<N>.hook_attn_out``

    The output of ``o_proj`` is the attention block's combined output, which
    TransformerLens exposes as ``hook_attn_out``. ``attn.hook_result`` is the
    per-head result and is only computed when ``use_attn_result`` is enabled,
    so it is not the right target for an SAE trained on ``o_proj`` outputs.

    Args:
        path: A Hugging Face module path or an already-translated hook name.

    Returns:
        The TransformerLens hook name for a recognised path; otherwise ``path``
        unchanged, which makes the function safe to apply more than once.
    """
    suffix_to_hook = {
        'mlp.down_proj': 'hook_mlp_out',
        'self_attn.o_proj': 'hook_attn_out',
    }
    m = re.fullmatch(r'model\.layers\.(\d+)\.(mlp\.down_proj|self_attn\.o_proj)', path)
    if m is None:
        return path
    layer, suffix = m.groups()
    return f'blocks.{layer}.{suffix_to_hook[suffix]}'

# Rewrite the SAE's stored hook name into TransformerLens form so that later
# cells can look it up in the model, then report what was loaded.
sae.cfg.metadata.hook_name = tl_hook_name(sae.cfg.metadata.hook_name)
print(
    f"Loaded SAE from {sae_dir}, "
    f"hook={sae.cfg.metadata.hook_name}, "
    f"d_in={sae.cfg.d_in}, d_sae={sae.cfg.d_sae}"
)


Loaded SAE from runs/20251213_002417_llama3_8b/final_sae, hook=model.layers.15.mlp.down_proj, d_in=4096, d_sae=16384


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [3]:
# Load model directly via TransformerLens
# Use the no-processing loader: the SAE-loading warning recommends
# `from_pretrained_no_processing(..., **model_from_pretrained_kwargs)`, and
# weight processing is discouraged when loading in reduced precision.
model_name = sae.cfg.metadata.model_name

# Loader kwargs saved with the SAE, if any.
# NOTE(review): their location differs between sae_lens versions, so both
# `cfg.metadata` and `cfg` are checked — confirm against the installed version.
pretrained_kwargs = dict(
    getattr(sae.cfg.metadata, 'model_from_pretrained_kwargs', None)
    or getattr(sae.cfg, 'model_from_pretrained_kwargs', None)
    or {}
)
# Explicit notebook settings take precedence over anything stored with the SAE.
pretrained_kwargs.update(device=DEVICE, dtype=DTYPE, move_to_device=True)

model = HookedSAETransformer.from_pretrained_no_processing(model_name, **pretrained_kwargs)
model.eval()
print(f"Model ready on {DEVICE} with dtype {DTYPE}")


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 85.60it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer
Model ready on cuda with dtype torch.bfloat16


In [4]:
# Quick activation sanity check
# Fail fast with a helpful message if the SAE's hook is unknown to the model,
# then run one short prompt and report the cached activation's shape.
hook_name = sae.cfg.metadata.hook_name
if hook_name not in model.hook_dict:
    raise KeyError(
        f"Hook {hook_name} not in model.hook_dict; "
        f"sample keys: {list(model.hook_dict.keys())[:5]} ..."
    )

with torch.no_grad():
    tokens = model.to_tokens('Hello world', prepend_bos=True)
    # Cache only the SAE's hook point; the logits are not needed here.
    _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
    acts = cache[hook_name]
    print('Hook activation shape', acts.shape)


KeyError: "Hook model.layers.15.mlp.down_proj not in model.hook_dict; sample keys: ['hook_embed', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized'] ..."

In [None]:
# Generate with SAE attached
prompt = "Once upon a time, a helpful robot"

# Sampling at temperature 0.7 is stochastic; seed the torch RNG so the
# generated text is reproducible under Restart & Run All.
torch.manual_seed(0)

# The SAE is attached only for the duration of the `with` block.
with model.saes(saes=[sae]):
    output = model.generate(
        prompt,
        max_new_tokens=60,
        temperature=0.7,
        stop_at_eos=True,
        verbose=False,
    )
print(output)
