# 01 ‑ Evaluate Goodfire SAE on LLaMA 3.1‑8B vs DeepSeek R1‑Distill
This notebook loads Goodfire’s sparse auto‑encoder (layer 19) and:
1. Tests it on the original *LLaMA 3.1‑8B Instruct* model.
2. Runs the same SAE, without fine‑tuning it, on *DeepSeek‑R1‑Distill‑Llama‑8B*.
3. Visualises reconstruction error and top feature activations.

In [None]:
# Install core libs (run once)
!pip install -q sae-lens transformers accelerate datasets matplotlib

In [None]:
import torch, matplotlib.pyplot as plt, numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sae_lens import SAE, HookedSAETransformer

# Default to CPU and switch to the GPU only when torch reports one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# --- load models ---
# Both checkpoints are loaded through the same Hugging Face entry point and
# switched to eval mode straight away. A single tokenizer (the base model's)
# is created here and shared by the cells below.
base_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
r1_model_name   = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map='auto',
).eval()
r1_model = AutoModelForCausalLM.from_pretrained(
    r1_model_name,
    device_map='auto',
).eval()

# --- load SAE ---
# `sae_id` names the hook point the SAE is attached to; the third element
# returned by `from_pretrained` is not used in this notebook.
sae_repo = 'Goodfire/Llama-3.1-8B-Instruct-SAE-l19'
sae_id   = 'blocks.19.hook_resid_post'
sae, sae_cfg, _ = SAE.from_pretrained(
    release=sae_repo,
    sae_id=sae_id,
    device=device,
)
print('Loaded SAE with latent dim', sae_cfg['d_sae'])

: 

## Evaluate on the original model (LLaMA 3.1‑8B)

In [None]:
# Run one prompt through the base model with the SAE spliced in and measure,
# per token, how well the SAE reconstructs the activations it receives.
prompt = (
    'Alice and Bob are planning a party. Alice has 3 apples, Bob brings 5 more. '
    'How many apples do they have in total? Explain step by step.'
)
inputs = tokenizer(prompt, return_tensors='pt').to(device)

# HookedSAETransformer is a TransformerLens model: its constructor takes a
# config, not a Hugging Face module. Build it from the already-loaded HF
# weights instead. The `_no_processing` variant skips LayerNorm folding and
# weight centering, so activations stay comparable to those of the HF model.
# NOTE(review): this holds a second copy of the weights in memory.
hooked_base = HookedSAETransformer.from_pretrained_no_processing(
    base_model_name,
    hf_model=base_model,
    tokenizer=tokenizer,
    device=device,
)
hooked_base.eval()

# `run_with_cache_with_saes` attaches the SAE for this call only. Nothing is
# trained here, so skip autograd bookkeeping.
with torch.no_grad():
    _, cache_base = hooked_base.run_with_cache_with_saes(
        inputs['input_ids'], saes=[sae]
    )

# The SAE exposes its own hook points underneath the hook it is attached to.
# Assumes `sae_id` equals the SAE's configured hook name - TODO confirm.
# `[0]` drops the batch axis (one prompt) so downstream cells see per-token rows.
acts_base  = cache_base[f'{sae_id}.hook_sae_input'][0]      # SAE input
recon_base = cache_base[f'{sae_id}.hook_sae_recons'][0]     # SAE reconstruction
feat_base  = cache_base[f'{sae_id}.hook_sae_acts_post'][0]  # feature activations

# Mean squared error over the hidden dimension -> one value per token.
mse_base = ((recon_base - acts_base) ** 2).mean(dim=-1).cpu().numpy()
print('Average reconstruction MSE (base):', mse_base.mean())

In [None]:
# Visualise reconstruction error per token
# Flatten first: if `mse_base` still carries a leading batch axis of size 1,
# matplotlib would treat every token as its own one-point series instead of
# drawing a single curve over token positions.
mse_base_per_token = np.asarray(mse_base).reshape(-1)

plt.figure(figsize=(6,4))
plt.plot(mse_base_per_token, marker='o')
plt.title('LLaMA 3.1‑8B reconstruction error per token')
plt.xlabel('Token index')
plt.ylabel('MSE')
plt.show()

In [None]:
# Show top‑5 features for final token (base)
# Collapse any leading batch axis so rows are tokens and columns are SAE
# features; with a single prompt the last row is then the final token whether
# or not `feat_base` still has a batch dimension.
feat_base_2d = feat_base.reshape(-1, feat_base.shape[-1])
final_idx = feat_base_2d.shape[0] - 1
vals = feat_base_2d[final_idx].detach().cpu().numpy()

# Sort ascending, reverse, keep the five largest activations.
top = np.argsort(vals)[::-1][:5]
for i in top:
    print(f'Feature {i}: {vals[i]:.4f}')

## Evaluate on DeepSeek R1‑Distill (no fine‑tune)

In [None]:
# Run the same token ids through R1-Distill with the unchanged SAE attached.
# HookedSAETransformer must be built via `from_pretrained*`; it cannot wrap a
# Hugging Face module directly.
# NOTE(review): TransformerLens resolves the architecture from the model name.
# The base Llama name is used as the template here on the assumption that
# R1-Distill-Llama-8B shares that architecture and that the installed
# TransformerLens may not list the R1 name - confirm for your version.
hooked_r1 = HookedSAETransformer.from_pretrained_no_processing(
    base_model_name,
    hf_model=r1_model,
    tokenizer=tokenizer,
    device=device,
)
hooked_r1.eval()

# Inference only: attach the SAE for this call and skip autograd bookkeeping.
with torch.no_grad():
    _, cache_r1 = hooked_r1.run_with_cache_with_saes(
        inputs['input_ids'], saes=[sae]
    )

# SAE hook points live underneath the hook the SAE is attached to.
# Assumes `sae_id` equals the SAE's configured hook name - TODO confirm.
# `[0]` drops the batch axis (one prompt) so rows are tokens.
acts_r1   = cache_r1[f'{sae_id}.hook_sae_input'][0]      # SAE input
recon_r1  = cache_r1[f'{sae_id}.hook_sae_recons'][0]     # SAE reconstruction
feat_r1   = cache_r1[f'{sae_id}.hook_sae_acts_post'][0]  # feature activations

# Mean squared error over the hidden dimension -> one value per token.
mse_r1 = ((recon_r1 - acts_r1) ** 2).mean(dim=-1).cpu().numpy()
print('Average reconstruction MSE (R1‑distill, pre‑tune):', mse_r1.mean())

In [None]:
# Compare error curves
# Flatten both arrays so each model contributes exactly one curve over token
# positions, even if a leading batch axis of size 1 is still present.
mse_base_curve = np.asarray(mse_base).reshape(-1)
mse_r1_curve = np.asarray(mse_r1).reshape(-1)

plt.figure(figsize=(6,4))
plt.plot(mse_base_curve, label='Base', marker='o')
plt.plot(mse_r1_curve,  label='R1‑Distill', marker='s')
plt.legend()
plt.title('Reconstruction error per token: Base vs R1‑Distill')
plt.xlabel('Token index')
plt.ylabel('MSE')
plt.show()

In [None]:
# Top‑5 features for final token on R1‑distill
# Collapse any leading batch axis so rows are tokens, then take the last row.
# The index is computed here rather than reused from another cell, so this
# cell does not depend on the base-model cell having been run first.
feat_r1_2d = feat_r1.reshape(-1, feat_r1.shape[-1])
vals_r1 = feat_r1_2d[-1].detach().cpu().numpy()

# Sort ascending, reverse, keep the five largest activations.
top_r1 = np.argsort(vals_r1)[::-1][:5]
for i in top_r1:
    print(f'Feature {i}: {vals_r1[i]:.4f}')