# 02 ‑ Fine‑Tune SAE for DeepSeek R1‑Distill
This notebook fine‑tunes Goodfire’s layer‑19 SAE (released for Llama‑3.1‑8B‑Instruct) on
layer‑19 residual‑stream activations of DeepSeek‑R1‑Distill‑Llama‑8B, using a
reasoning‑heavy dataset (R1 traces) blended with a small sample of LMSYS chat data.
After training, we evaluate reconstruction fidelity on a sample prompt and save the adapted SAE.

In [None]:
# Install the notebook's third-party dependencies; -q keeps pip's output quiet.
!pip install -q sae-lens transformers accelerate datasets matplotlib

In [None]:
import torch, random, numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Checkpoint identifiers: the R1-distilled model, the SAE release, and the
# hook point (layer-19 residual stream) the SAE is attached to.
r1_model_name = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
sae_repo = 'Goodfire/Llama-3.1-8B-Instruct-SAE-l19'
sae_id = 'blocks.19.hook_resid_post'

# Load the base model and its tokenizer, then switch the model to eval mode.
r1_model = AutoModelForCausalLM.from_pretrained(r1_model_name, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(r1_model_name)
r1_model.eval()

# NOTE(review): the 3-tuple unpacking assumes from_pretrained returns
# (sae, config dict, extra) — confirm against the installed sae_lens version.
sae, sae_cfg, _ = SAE.from_pretrained(
    release=sae_repo,
    sae_id=sae_id,
    device=device,
)
latent_dim = sae_cfg['d_sae']
print('SAE loaded. Latent dim =', latent_dim)

## Load reasoning and chat datasets

In [None]:
# Load R1 reasoning traces (example dataset name).
reason_ds = load_dataset('phunguyen01/open-r1-math-220k', split='train[:10%]')

# Load a subset of LMSYS chat for generality.
chat_ds = load_dataset('lmsys/lmsys-chat-1m', split='train[:2%]')

# Fraction of the final blend that should come from chat data (80/20 mix).
CHAT_FRACTION = 0.2


def _chat_text(ex):
    """Flatten one chat record into a single training string."""
    # lmsys-chat-1m keeps each dialogue under 'conversation' as a list of
    # {'role', 'content'} turns; records that instead carry flat
    # 'prompt'/'response' fields are handled as before.
    if 'conversation' in ex:
        return ' '.join(turn['content'] for turn in ex['conversation'])
    # `or ''` also covers a response that is present but None.
    return ex['prompt'] + ' ' + (ex.get('response') or '')


# NOTE(review): assumes the reasoning dataset exposes a 'text' column — confirm
# against the dataset card before running.
reason_texts = [ex['text'] for ex in reason_ds]
chat_texts = [_chat_text(ex) for ex in chat_ds]

# The slice percentages alone do not yield 80/20: judging by the dataset names,
# 10% of ~220k rows vs 2% of ~1M rows is close to 50/50. Subsample chat so that
# n_chat / (n_reason + n_chat) = f, i.e. n_chat = n_reason * f / (1 - f).
n_chat = min(
    len(chat_texts),
    round(len(reason_texts) * CHAT_FRACTION / (1 - CHAT_FRACTION)),
)
combined_texts = reason_texts + random.sample(chat_texts, n_chat)
random.shuffle(combined_texts)
print('Total training sequences:', len(combined_texts))

## Data loader that streams activations

In [None]:
from torch.utils.data import Dataset, DataLoader

class ActivationDataset(Dataset):
    """Map-style dataset yielding SAE-hook activations for one text at a time.

    Args:
        texts: Sequence of raw strings; item ``idx`` is computed from ``texts[idx]``.
        model: Model to read activations from. Either an existing
            ``HookedSAETransformer`` (used as-is) or an object that is passed
            to ``HookedSAETransformer(...)`` once, at construction time.
        tok: Tokenizer callable; invoked with ``return_tensors='pt'``.
        sae_hook: Name of the cache entry to extract.
        max_len: Maximum number of tokens kept per text.
    """

    def __init__(self, texts, model, tok, sae_hook, max_len=256):
        self.texts, self.model, self.tok, self.sae_hook = texts, model, tok, sae_hook
        self.max_len = max_len
        # Wrap once here: rebuilding the hooked model for every item repeats
        # an expensive, loop-invariant setup step.
        # NOTE(review): confirm HookedSAETransformer accepts this model object
        # directly; if not, build the hooked model outside and pass it in.
        if isinstance(model, HookedSAETransformer):
            self.hooked = model
        else:
            self.hooked = HookedSAETransformer(model)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(acts, acts)`` for text ``idx``; the target equals the input."""
        # Truncate by tokens only. Slicing the string to max_len *characters*
        # first would usually cut well below max_len tokens and leave the
        # tokenizer-level limit with nothing to do.
        toks = self.tok(
            self.texts[idx],
            return_tensors='pt',
            truncation=True,
            max_length=self.max_len,
        ).to(device)
        with torch.no_grad():
            # Cache only the hook we read instead of every activation.
            _, cache = self.hooked.run_with_cache(
                toks['input_ids'], names_filter=self.sae_hook
            )
        acts = cache[self.sae_hook].detach().float()
        return acts, acts  # target is same as input for autoencoder

In [None]:
def _concat_token_activations(batch):
    """Collate ``(acts, target)`` pairs into flat one-row-per-token tensors.

    Items can cover different numbers of tokens, so the default collation
    (which stacks items) fails on them. Each tensor is instead reshaped to
    ``(-1, last_dim)`` and the rows of all items are concatenated.
    """
    inputs = torch.cat([acts.reshape(-1, acts.shape[-1]) for acts, _ in batch])
    targets = torch.cat([tgt.reshape(-1, tgt.shape[-1]) for _, tgt in batch])
    return inputs, targets


train_ds = ActivationDataset(combined_texts, r1_model, tokenizer, sae_id)
loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    collate_fn=_concat_token_activations,
)
print('Activation dataset ready. Batches:', len(loader))

## Fine‑tune loop

In [None]:
# Fine-tune the SAE with a reconstruction (MSE) loss plus an L1 sparsity
# penalty on the latent activations.
sae.train()
opt = torch.optim.AdamW(sae.parameters(), lr=1e-4)
epochs = 1  # increase as compute allows
l1_coeff = 1e-5  # weight of the L1 sparsity penalty
for ep in range(epochs):
    total = 0.0
    n_batches = 0
    for i, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        # Encode once and reuse the latents for both the reconstruction and
        # the sparsity term (the SAE exposes encode/decode, not `.encoder`).
        # NOTE(review): assumes sae(x) == sae.decode(sae.encode(x)), i.e. no
        # error term is added in forward — confirm for this SAE config.
        encoded = sae.encode(x)
        recon = sae.decode(encoded)
        loss = torch.nn.functional.mse_loss(recon, y)
        loss = loss + l1_coeff * torch.mean(torch.abs(encoded))  # sparsity L1
        loss.backward()
        opt.step()
        total += loss.item()
        n_batches += 1
        if i % 50 == 0:
            print(f'Epoch {ep+1} batch {i}: loss {loss.item():.6f}')
    # max(..., 1) avoids a ZeroDivisionError when the loader yields nothing.
    print(f'Epoch {ep+1} avg loss {total/max(n_batches, 1):.6f}')

## Evaluate post‑fine‑tune

In [None]:
# Measure how well the fine-tuned SAE reconstructs the hooked activations on
# one sample prompt, and how many latents fire on the final token.
sae.eval()
sample_prompt = 'Solve 12 * (7 + 5) and explain your steps.'
toks = tokenizer(sample_prompt, return_tensors='pt').to(device)
# NOTE(review): confirm HookedSAETransformer accepts this model object
# directly — verify against the sae_lens / TransformerLens constructor.
hooked_r1 = HookedSAETransformer(r1_model)
with torch.no_grad():  # evaluation only; no autograd graph needed
    # Cache just the hook the SAE reads, then run the SAE on it explicitly.
    # This avoids depending on SAE-specific cache key names.
    _, cache_eval = hooked_r1.run_with_cache(toks['input_ids'], names_filter=sae_id)
    acts = cache_eval[sae_id].float()
    feats = sae.encode(acts)
    recon = sae.decode(feats)
mse = ((recon - acts)**2).mean().item()
print('Post‑tune reconstruction MSE:', mse)

# count active features in last token
# Index batch 0, last position: a bare [-1] would select the last *batch*
# element and count over every token (assumes (batch, seq, d_sae) — TODO confirm).
feat = feats[0, -1]
active = (feat.abs() > 1e-6).sum().item()
print('Active features last token:', active)

In [None]:
# Save adapted SAE
# SAELens documents SAE.save_model(path), which writes into a directory, so
# the target is named as a directory rather than a single '.pth' file.
# NOTE(review): confirm against the installed sae_lens version.
sae.save_model('R1_distill_finetuned_SAE')
print('SAE saved to disk.')