# SAE + Llama3 8B minimal test
Load the most recent SAE run, load the model directly with `HookedSAETransformer.from_pretrained` (no Hugging Face wrapper), and generate text with the SAE attached at its hook point.


In [4]:
import torch
# List the visible CUDA devices so GPU_INDEX (next cell) can be chosen deliberately.
# (The original bare `torch.cuda.device_count()` was not the last expression of the
# cell, so Jupyter never displayed it; keep the count in a name instead.)
n_gpus = torch.cuda.device_count()
if n_gpus == 0:
    print("No CUDA devices visible")
for gpu_idx in range(n_gpus):
    print(gpu_idx, torch.cuda.get_device_name(gpu_idx))

0 NVIDIA RTX A6000
1 NVIDIA RTX A6000
2 NVIDIA RTX A6000
3 NVIDIA RTX A6000


In [10]:
# Which CUDA device to use when CUDA is available; see the device listing above.
GPU_INDEX = 3

# Preference order: the chosen CUDA device, then Apple MPS, then CPU.
if torch.cuda.is_available():
    # Fail fast with a clear message instead of letting an out-of-range index
    # surface later as an obscure error during SAE/model loading.
    if not 0 <= GPU_INDEX < torch.cuda.device_count():
        raise ValueError(
            f"GPU_INDEX={GPU_INDEX} is out of range: only "
            f"{torch.cuda.device_count()} CUDA device(s) are visible"
        )
    DEVICE = f"cuda:{GPU_INDEX}"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# bfloat16 on CUDA to halve memory for the 8B model; float32 elsewhere.
DTYPE = torch.bfloat16 if DEVICE.startswith('cuda') else torch.float32
print(f"Using device: {DEVICE}, dtype: {DTYPE}")

Using device: cuda:3, dtype: torch.bfloat16


In [11]:
# Imports
from pathlib import Path
from sae_lens import SAE, HookedSAETransformer

In [12]:
# Load the SAE from the most recent *_llama3_8b run.
# "Most recent" means largest directory mtime. A directory's mtime changes whenever
# entries are added or removed inside it, so an older run that was touched later can
# win. Set run_dir explicitly if that matters.
runs_root = Path('runs')
candidate_runs = sorted(
    (run for run in runs_root.glob('*_llama3_8b') if run.is_dir()),
    key=lambda run: run.stat().st_mtime,
    reverse=True,
)
if len(candidate_runs) == 0:
    raise FileNotFoundError('No *_llama3_8b run directories found.')
run_dir = candidate_runs[0]

# Prefer the run's final_sae/ checkpoint; fall back to the run directory itself.
sae_dir = run_dir / 'final_sae' if (run_dir / 'final_sae').exists() else run_dir

sae = SAE.load_from_disk(sae_dir, device=DEVICE)
sae.eval()

# Translate HF module path to TransformerLens hook name
import re
def tl_hook_name(path: str) -> str:
    """Map a Hugging Face Llama module path to the matching TransformerLens hook name.

    Recognised paths (N is the decoder-layer index):
      - ``model.layers.N.mlp`` or ``model.layers.N.mlp.down_proj``
        -> ``blocks.N.hook_mlp_out``
      - ``model.layers.N.self_attn`` or ``model.layers.N.self_attn.o_proj``
        -> ``blocks.N.hook_attn_out``

    Any other string is returned unchanged, including names that are already
    TransformerLens hooks (e.g. ``blocks.15.hook_resid_pre``). That makes the
    function safe to apply more than once.

    Args:
        path: HF module path or TransformerLens hook name.

    Returns:
        The TransformerLens hook name, or ``path`` itself if it is not recognised.
    """
    # The MLP's output is the output of its final projection, down_proj.
    m = re.fullmatch(r'model\.layers\.(\d+)\.mlp(?:\.down_proj)?', path)
    if m:
        return f'blocks.{m.group(1)}.hook_mlp_out'
    # o_proj produces the head-summed attention output, i.e. hook_attn_out.
    # Do NOT use attn.hook_result: it is per-head ([batch, pos, head, d_model]) and
    # only fires when cfg.use_attn_result=True. Otherwise the SAE would be silently skipped.
    m = re.fullmatch(r'model\.layers\.(\d+)\.self_attn(?:\.o_proj)?', path)
    if m:
        return f'blocks.{m.group(1)}.hook_attn_out'
    return path

# Rewrite the SAE's stored hook name into TransformerLens form so later cells can
# look it up in model.hook_dict. Names that are already TransformerLens hooks pass
# through tl_hook_name unchanged.
sae.cfg.metadata.hook_name = tl_hook_name(
    sae.cfg.metadata.hook_name
)
print(
    f"Loaded SAE from {sae_dir}, "
    f"hook={sae.cfg.metadata.hook_name}, "
    f"d_in={sae.cfg.d_in}, d_sae={sae.cfg.d_sae}"
)


Loaded SAE from runs/20251213_174540_llama3_8b/final_sae, hook=blocks.15.hook_resid_pre, d_in=4096, d_sae=16384


In [13]:
# Load the base model named in the SAE's metadata straight into TransformerLens,
# so the model always matches the one the SAE was trained against.
model_name = sae.cfg.metadata.model_name
model = HookedSAETransformer.from_pretrained(
    model_name, device=DEVICE, dtype=DTYPE, move_to_device=True
)
model.eval()
print(f"Model ready on {DEVICE} with dtype {DTYPE}")


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer
Model ready on cuda:3 with dtype torch.bfloat16


In [14]:
# Quick activation sanity check. The SAE's hook must (1) be registered on the model,
# (2) actually fire during a forward pass, and (3) produce activations whose last
# dimension equals the SAE's input width d_in.
hook_name = sae.cfg.metadata.hook_name
if hook_name not in model.hook_dict:
    raise KeyError(f"Hook {hook_name} not in model.hook_dict; sample keys: {list(model.hook_dict.keys())[:5]} ...")
with torch.no_grad():
    tokens = model.to_tokens('Hello world', prepend_bos=True)
    _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
    # A hook point can be registered but never called (e.g. attn.hook_result unless
    # use_attn_result=True). In that case nothing is cached, and the SAE would never run.
    if hook_name not in cache.keys():
        raise RuntimeError(f"Hook {hook_name} is registered but did not fire during the forward pass")
    acts = cache[hook_name]
    print('Hook activation shape', acts.shape)
if acts.shape[-1] != sae.cfg.d_in:
    raise ValueError(
        f"Activation width {acts.shape[-1]} at {hook_name} does not match SAE d_in={sae.cfg.d_in}"
    )


Hook activation shape torch.Size([1, 3, 4096])


In [15]:
# Generate with the SAE attached.
# temperature > 0 means sampling, so seed the global torch RNG (all devices) to make
# the printed completion reproducible under Restart & Run All.
GEN_SEED = 0
prompt = "Once upon a time, a helpful robot"
torch.manual_seed(GEN_SEED)
# The SAE is attached only inside this `with` block (HookedSAETransformer.saes).
with model.saes(saes=[sae]):
    output = model.generate(
        prompt,
        max_new_tokens=60,
        temperature=0.7,
        stop_at_eos=True,
        verbose=False,
    )
print(output)


Once upon a time, a helpful robot named Nelly, a dog named Pip, and an unnamed baby were left behind on the surface of Mars.
The Moon is 4.5 billion years old, and its remaining magnetism and various other geologic features are clues to its origins.
