# 01 ‑ Evaluate Goodfire SAE on LLaMA 3.1‑8B vs DeepSeek R1‑Distill
This notebook loads Goodfire’s sparse auto‑encoder (layer 19) and:
1. Tests it on the original *LLaMA 3.1‑8B Instruct* model.
2. Runs the same SAE on *DeepSeek‑R1‑Distill‑Llama‑8B* (without fine‑tuning).
3. Visualises reconstruction error and top feature activations.
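
To make the two quantities in step 3 concrete, here is a minimal NumPy sketch of what a sparse auto-encoder computes. It assumes a plain ReLU encoder with a linear decoder and uses random toy weights; the sizes, the weights, and the `toy_sae` helper are all hypothetical and are not Goodfire's actual architecture or parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_sae, n_tokens = 8, 32, 5  # toy sizes, not the real model's

# Randomly initialised toy weights (hypothetical; a trained SAE would load these).
W_enc = rng.normal(size=(d_model, d_sae)) / np.sqrt(d_model)
W_dec = rng.normal(size=(d_sae, d_model)) / np.sqrt(d_sae)
b_enc = np.zeros(d_sae)
b_dec = np.zeros(d_model)

def toy_sae(acts):
    """Return (features, reconstruction) for activations of shape [n_tokens, d_model]."""
    feats = np.maximum(0.0, (acts - b_dec) @ W_enc + b_enc)  # ReLU encoder
    recon = feats @ W_dec + b_dec                            # linear decoder
    return feats, recon

acts = rng.normal(size=(n_tokens, d_model))
feats, recon = toy_sae(acts)

# Reconstruction error per token: mean squared error over the model dimension.
mse_per_token = ((recon - acts) ** 2).mean(axis=-1)

# Top feature activations on the final token, strongest first.
top5_last = np.argsort(feats[-1])[::-1][:5]
print(mse_per_token.shape, top5_last)
```

The cells further down compute the same two things, but on real residual-stream activations from layer 19.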

In [None]:
# Install core libs (run once)
!pip install -q sae-lens transformers accelerate datasets matplotlib

In [6]:
# Release cached CUDA memory
import torch
torch.cuda.empty_cache()

# Run the garbage collector to free unreferenced Python objects
import gc
gc.collect()

132

In [8]:
# Drop the R1 model to free memory (run only after it has been loaded)
del r1_model

In [1]:
import torch, matplotlib.pyplot as plt, numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sae_lens import SAE, HookedSAETransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct.
401 Client Error. (Request ID: Root=1-6802ddf2-741555c36ccc7654336d5519;61636158-ab2a-440b-a4c7-dd2fb0304b40)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.1-8B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in.

In [2]:
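# Authenticate with Hugging Face (required for the gated Llama repo)
from huggingface_hub import login
login()  # prompts for an access token; avoid hard-coding it in the notebook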

In [3]:
# --- load models ---
base_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
r1_model_name   = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map='auto')
base_model.eval()
r1_model = AutoModelForCausalLM.from_pretrained(r1_model_name, device_map='auto')
r1_model.eval()

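# --- load SAE ---
# NOTE: as the traceback below shows, this call fails with a 404: the repo has no
# 'blocks.19.hook_resid_post/cfg.json', which SAE.from_pretrained tries to fetch.
sae_repo = 'Goodfire/Llama-3.1-8B-Instruct-SAE-l19'
sae_id   = 'blocks.19.hook_resid_post'
sae, sae_cfg, _ = SAE.from_pretrained(release=sae_repo, sae_id=sae_id, device=device)
print('Loaded SAE with latent dim', sae_cfg['d_sae'])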

Fetching 4 files: 100%|██████████| 4/4 [00:17<00:00,  4.40s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.60it/s]
Fetching 2 files: 100%|██████████| 2/2 [00:21<00:00, 10.61s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.22s/it]


EntryNotFoundError: 404 Client Error. (Request ID: Root=1-6802de65-55ec3fd51da25fd3657b8791;931b2fe3-517c-4055-8c73-02a7d70ed185)

Entry Not Found for url: https://huggingface.co/Goodfire/Llama-3.1-8B-Instruct-SAE-l19/resolve/main/blocks.19.hook_resid_post/cfg.json.

## Evaluate on the original model (LLaMA 3.1‑8B)

In [None]:
prompt = (
    'Alice and Bob are planning a party. Alice has 3 apples, Bob brings 5 more. '
    'How many apples do they have in total? Explain step by step.'
)
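inputs = tokenizer(prompt, return_tensors='pt').to(device)
# HookedSAETransformer cannot wrap a Hugging Face model directly; build it with
# from_pretrained and pass the already-loaded weights via hf_model.
hooked_base = HookedSAETransformer.from_pretrained(
    base_model_name, hf_model=base_model, tokenizer=tokenizer, device=device
)
with torch.no_grad():
    _, cache_base = hooked_base.run_with_cache_with_saes(inputs['input_ids'], saes=[sae])

# SAE hook points are cached as '<hook name>.hook_sae_*'; this assumes the SAE's
# hook name equals sae_id. [0] drops the batch dimension, leaving [seq_len, ...].
acts_base  = cache_base[f'{sae_id}.hook_sae_input'][0]
recon_base = cache_base[f'{sae_id}.hook_sae_recon'][0]
feat_base  = cache_base[f'{sae_id}.hook_sae_acts_post'][0]

mse_base = ((recon_base - acts_base) ** 2).mean(dim=-1).float().cpu().numpy()
print('Average reconstruction MSE (base):', mse_base.mean())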

In [None]:
# Visualise reconstruction error per token
plt.figure(figsize=(6,4))
plt.plot(mse_base, marker='o')
plt.title('LLaMA 3.1‑8B reconstruction error per token')
plt.xlabel('Token index')
plt.ylabel('MSE')
plt.show()

In [None]:
# Show top‑5 features for final token (base)
final_idx = feat_base.shape[0] - 1
vals = feat_base[final_idx].detach().cpu().numpy()
top = np.argsort(vals)[::-1][:5]
for i in top:
    print(f'Feature {i}: {vals[i]:.4f}')

## Evaluate on DeepSeek R1‑Distill (no fine‑tune)

In [None]:
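# Wrap the R1 weights via hf_model; the base model name supplies the config,
# assuming R1-Distill keeps the Llama 3.1 8B architecture.
hooked_r1 = HookedSAETransformer.from_pretrained(
    base_model_name, hf_model=r1_model, tokenizer=tokenizer, device=device
)
with torch.no_grad():
    _, cache_r1 = hooked_r1.run_with_cache_with_saes(inputs['input_ids'], saes=[sae])
# [0] drops the batch dimension; keys assume the SAE's hook name equals sae_id
acts_r1   = cache_r1[f'{sae_id}.hook_sae_input'][0]
recon_r1  = cache_r1[f'{sae_id}.hook_sae_recon'][0]
feat_r1   = cache_r1[f'{sae_id}.hook_sae_acts_post'][0]
mse_r1 = ((recon_r1 - acts_r1) ** 2).mean(dim=-1).float().cpu().numpy()
print('Average reconstruction MSE (R1‑distill, pre‑tune):', mse_r1.mean())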

In [None]:
# Compare error curves
plt.figure(figsize=(6,4))
plt.plot(mse_base, label='Base', marker='o')
plt.plot(mse_r1,  label='R1‑Distill', marker='s')
plt.legend()
plt.title('Reconstruction error per token: Base vs R1‑Distill')
plt.xlabel('Token index')
plt.ylabel('MSE')
plt.show()
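
Raw MSE depends on the scale of the activations, which may differ between the two models, so a scale-free measure can make the comparison easier to read. The sketch below computes the fraction of variance unexplained (FVU) on toy data. The `fraction_variance_unexplained` helper is hypothetical and not part of SAELens; it is an additional metric suggested here, not one the notebook already uses.

```python
import numpy as np

def fraction_variance_unexplained(acts, recon):
    """Squared reconstruction error divided by the variance of acts around its per-dimension mean.

    Both inputs have shape [n_tokens, d_model]. 0.0 means perfect reconstruction;
    1.0 is what you get by always predicting the mean activation.
    """
    acts = np.asarray(acts, dtype=np.float64)
    recon = np.asarray(recon, dtype=np.float64)
    resid = ((recon - acts) ** 2).sum()
    total = ((acts - acts.mean(axis=0, keepdims=True)) ** 2).sum()
    return resid / total

# Toy data standing in for real activations.
rng = np.random.default_rng(0)
acts = rng.normal(size=(10, 16))
noisy = acts + 0.1 * rng.normal(size=acts.shape)
mean_only = np.broadcast_to(acts.mean(axis=0), acts.shape)

fvu_perfect = fraction_variance_unexplained(acts, acts.copy())
fvu_noisy = fraction_variance_unexplained(acts, noisy)
fvu_mean = fraction_variance_unexplained(acts, mean_only)
print(fvu_perfect, fvu_noisy, fvu_mean)
```

In this notebook the inputs would be `acts_base` / `recon_base` and `acts_r1` / `recon_r1`, converted to NumPy arrays of shape `[n_tokens, d_model]`.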

In [None]:
# Top‑5 features for final token on R1‑distill
vals_r1 = feat_r1[final_idx].detach().cpu().numpy()
top_r1 = np.argsort(vals_r1)[::-1][:5]
for i in top_r1:
    print(f'Feature {i}: {vals_r1[i]:.4f}')
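
The two top-5 lists can also be compared directly. The sketch below, run on toy vectors, reports which feature indices appear in both top-k sets and their Jaccard overlap. `top_k_overlap` is a hypothetical helper written for illustration.

```python
import numpy as np

def top_k_overlap(vals_a, vals_b, k=5):
    """Return (sorted shared indices, Jaccard overlap) of the top-k entries of two vectors."""
    top_a = set(np.argsort(vals_a)[::-1][:k].tolist())
    top_b = set(np.argsort(vals_b)[::-1][:k].tolist())
    shared = top_a & top_b
    return sorted(shared), len(shared) / len(top_a | top_b)

# Toy feature activations for one token under two models.
a = np.array([0.0, 3.0, 0.0, 2.0, 1.0, 0.0])
b = np.array([0.0, 2.5, 4.0, 0.0, 1.0, 0.0])
shared, jaccard = top_k_overlap(a, b, k=3)
print(shared, jaccard)  # [1, 4] 0.5
```

With the notebook's variables this would be `top_k_overlap(vals, vals_r1, k=5)`. A low overlap would suggest the SAE's features fire differently on the distilled model for this prompt.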