# sae

model: `sae.ep1.pt`

* model: recurrentgemma-9b (d_model=4,096, layer=30, activation=rg_lru)
* Dataset: minipile (110K data instances = 1/10 of the entire dataset)
* num sae feat: 131K (32x d_model)
* sae type: topk, k=32
* training tokens: 259M (1 ep.)

In [1]:
import sys
print(sys.executable)

/home/tg352/.conda/envs/python311/bin/python


In [2]:
import os
import pickle
import torch
import torch.nn.functional as F
import jaxtyping as jt
import kagglehub
from pathlib import Path
from openai_sparse_autoencoder.train import FastAutoencoder
from typing import List
import sentencepiece as spm
import gzip
import random

Fl = lambda size: jt.Float[torch.Tensor, size]

In [3]:
DATA_PATH = "/share/rush/tg352/sae/minipile/9b/artefacts/"
MODEL_PATH = "/home/tg352/sparse_autoencoder/artefacts"

In [4]:
weights_dir = Path(
    kagglehub.model_download(f"google/recurrentgemma/pyTorch/9b")
)
vocab_path = weights_dir / "tokenizer.model"
vocab = spm.SentencePieceProcessor()
vocab.Load(str(vocab_path))
len(vocab)

256000

In [5]:
sae_model_pt = "sae.ep1.pt"
sae_model_pt

'sae.ep1.pt'

In [6]:
d_model = 4_096
sae = FastAutoencoder(
    n_dirs_local=d_model * 32,
    d_model=d_model,
    k=32,
    dead_steps_threshold=None,
    auxk=None,
)
sae.eval()
sae.load_state_dict(
    torch.load(
        os.path.join(MODEL_PATH, sae_model_pt),
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
sae

  return self.fget.__get__(instance, owner)()


FastAutoencoder(
  (encoder): Linear(in_features=4096, out_features=131072, bias=False)
  (decoder): Linear(in_features=131072, out_features=4096, bias=False)
)

In [7]:
print("frac-dead: ", end="")
print(torch.sum(sae.stats_last_nonzero > (10_000_000 // 16384)) / (32 * 4_096))
print(f"total: {32 * 4_096}, dead = {torch.sum(sae.stats_last_nonzero > (10_000_000 // 16384))}")

frac-dead: tensor(0.0149)
total: 131072, dead = 1958


In [8]:
act_filenames = os.listdir(DATA_PATH)
act_filename = random.choice(act_filenames)
act_filename

'pid.0-batch.37491.pkl.gz'

In [9]:
def get_acts(filename: str) -> [list[str], Fl("l d")]:
    act_dict = pickle.load(gzip.open(os.path.join(DATA_PATH, filename), "rb"))
    tokens = [vocab.IdToPiece(input_id) for input_id in act_dict["input_ids"][0].tolist()]
    return tokens, act_dict[f"blocks.30"]["rg_lru_states"][0].float()


toks, acts = get_acts(act_filename)
len(toks), acts.shape, acts.dtype

(1346, torch.Size([1346, 4096]), torch.float32)

In [10]:
acts.max(), acts.min()

(tensor(18.8750), tensor(-20.1250))

In [11]:
def get_feats(acts: Fl("l d"), k: int = 1) -> [Fl("l k"), Fl("l k")]:
    return sae.encode(acts, k=k)
    

feat_idxs, feat_vals = get_feats(acts)
feat_idxs.shape

torch.Size([1346, 1])

In [12]:
feat_vals.shape, feat_idxs

(torch.Size([1346, 1]),
 tensor([[97060],
         [97060],
         [97060],
         ...,
         [93799],
         [77678],
         [77678]]))

In [32]:
def highlight_tok(tok: str, act_val: float) -> str:
    if act_val < 0:
        # Red spectrum (160-196 in 256-color mode).
        intensity = int(160 + (1 + act_val) * 36)
    else:
        # Blue spectrum (34-46 in 256-color mode).
        intensity = int(34 + act_val * 12)
    return f"\033[38;5;{intensity}m{tok}\033[0m"

print(highlight_tok("tok", -0.5))

[38;5;178mtok[0m


In [34]:
def highlight_tok(tok: str, act_val: float) -> str:
    if act_val >= 0:
        return f"\033[32m{tok}\033[0m"
    else:
        return f"\033[31m{tok}\033[0m"

In [35]:
def view_feat(feat_idx: int, max_num_files: int = 10, min_tok_len: int | None = None):
    num_files_w_feat = 0
    for filename in act_filenames:
        if num_files_w_feat > max_num_files:
            break
        
        toks, acts = get_acts(filename)
        if min_tok_len is not None and len(toks) < min_tok_len:
            continue
        topk_idxs, topk_vals = get_feats(acts, k=1)  # (seq_len, k), (seq_len, k)
        # print(topk_idxs)
        if feat_idx not in topk_idxs:
            continue
        
        mask = (topk_idxs == feat_idx).squeeze(1)
        text_to_print = []
        for tok, val, is_feat in zip(toks, topk_vals.squeeze(1), mask):
            if is_feat:
                text_to_print.append(f"{highlight_tok(tok, val)}")
            else:
                text_to_print.append(tok)
        print(f"{filename}\n{'-' * len(filename)}")
        print("".join(text_to_print))
        print("\n\n")
        
        num_files_w_feat += 1

In [None]:
view_feat(103, max_num_files=2)