# Gemma‑2 + SAE (sae‑lens) — Document Feature Extraction
This notebook walks you step‑by‑step through loading **Gemma‑2 9B**, attaching an **SAE** from `sae_lens`, gathering residual activations at a target layer, and producing per‑document feature vectors (sum/mean/max across tokens).  
It also reconstructs documents from chunked JSONL rows and aligns them with metadata for downstream tasks.


In [None]:
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


## Step 0 — Prerequisites
- **Hardware:** A CUDA‑capable GPU is strongly recommended (Gemma‑2 9B is large).  
- **Python:** 3.9+
- **Accounts:** Access to the gated Hugging Face models if required (e.g., Gemma).  
- **Auth:** Put your Hugging Face token in a `secret.txt` file next to the notebook; Step 1 reads it into the environment variable `HF_HUB_TOKEN`.

> If you don't have access to `google/gemma-2-9b` or a sufficiently large GPU, consider swapping in a smaller model and a matching SAE release.


In [2]:
# (Optional) Install dependencies in your environment
# If you're running locally and need to install packages, uncomment these lines.
# Note: Some installs require system CUDA toolkits preconfigured.
# !pip install -U 'transformers>=4.44' 'torch' 'huggingface_hub' 'pandas' 'numpy' 'sae-lens'

## Step 1 — Imports, device, and logging
We configure a single‑device setup (CUDA if available), set up logging to both file and console, and disable gradients for inference‑only execution.


In [1]:
import os
import json
import numpy as np
import pandas as pd
import logging
import torch
import torch.nn as nn

# force huggingface download path
# DO NOT CHANGE THIS CELL
# RUN THIS CELL ONLY IF RUNNING ON PACE-ICE

# override the huggingface cache path and nltk cache path
dirs = {
    "HF_HOME":"~/scratch/hf_cache",
    "TRITON_CACHE_DIR":"~/scratch/triton_cache",
    "TORCHINDUCTOR_CACHE_DIR":"~/scratch/inductor_cache",
    'NLTK_DATA':"~/scratch/nltk_data"
}

for name in dirs:
    d = dirs[name]
    path = os.path.expanduser(d)
    print(name)
    print(path)
    os.makedirs(path, exist_ok=True)
    # making sure the cache dirs are rwx for owner
    os.chmod(path, 0o700)
    os.environ[name] = path





# os.environ["HF_HUB_TOKEN"] = ""

# --- Read token from secret.txt ---
try:
    with open("secret.txt", "r") as f:
        token = f.read().strip()
        os.environ["HF_HUB_TOKEN"] = token
except FileNotFoundError as err:
    # Raise instead of calling exit(), which would shut down the notebook kernel.
    raise FileNotFoundError(
        "secret.txt not found. Please create it and add your Hugging Face token."
    ) from err

from sae_lens import SAE
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
login(token=os.getenv("HF_HUB_TOKEN"))

# ——— 1. Single device definition ———
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

logging.basicConfig(
    filename="doc_processing.log",
    filemode="w",
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
logging.getLogger().addHandler(console)

# === Login & model setup ===
torch.set_grad_enabled(False)
hf_token = os.getenv("HF_HUB_TOKEN")
if hf_token is None:
    logger.warning("HF_HUB_TOKEN is not set; private/gated models may fail to load.")

HF_HOME
/home/hice1/mzhang754/scratch/hf_cache
TRITON_CACHE_DIR
/home/hice1/mzhang754/scratch/triton_cache
TORCHINDUCTOR_CACHE_DIR
/home/hice1/mzhang754/scratch/inductor_cache
NLTK_DATA
/home/hice1/mzhang754/scratch/nltk_data
Using device: cuda:0


## Step 2 — Load model & tokenizer, then move to device
This uses `google/gemma-2-9b`. If you do not have access or sufficient VRAM, substitute with a smaller model (and a compatible SAE release).

In [2]:
# ——— 2. Load model + tokenizer, then send to device ———
# Note: 'token' replaces the deprecated 'use_auth_token' argument in recent transformers versions.
model = AutoModelForCausalLM.from_pretrained(
    # uncomment below if use Gemma-2 2B
    # "google/gemma-2-2b",
    # uncomment below if use Gemma-2 9B
    "google/gemma-2-9b",
    token=hf_token
).to(device)

tokenizer = AutoTokenizer.from_pretrained(
    # uncomment below if use Gemma-2 2B
    # "google/gemma-2-2b",
    # uncomment below if use Gemma-2 9B
    "google/gemma-2-9b",
    token=hf_token
)

# Load the SAE that matches Gemma‑2 model
sae, _, _ = SAE.from_pretrained(
    # uncomment below if use Gemma-2 2B
    # release="gemma-scope-2b-pt-res-canonical",
    # uncomment below if use Gemma-2 9B
    release="gemma-scope-9b-pt-res-canonical",

    # uncomment below if use Gemma-2 2B
    # sae_id="layer_12/width_16k/canonical",
    # uncomment below if use Gemma-2 9B
    sae_id="layer_20/width_131k/canonical",

    device="cuda" if torch.cuda.is_available() else "cpu",
)

model.eval()
sae.eval()



layer_20/width_131k/average_l0_114/param(…):   0%|          | 0.00/3.76G [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
  sae, _, _ = SAE.from_pretrained(


JumpReLUSAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

### (Optional) Step 2b — Load SAE weights manually
If you prefer to download SAE params yourself and define a simple JumpReLU SAE, use the template below instead of the `SAE.from_pretrained(...)` snippet. This is provided as a commented reference.


In [5]:
# # ——— 3. Download SAE params and move to device ———
# from huggingface_hub import hf_hub_download
# path_to_params = hf_hub_download(
#     repo_id="google/gemma-scope-9b-pt-res",
#     filename="layer_20/width_16k/average_l0_68/params.npz",
#     token=hf_token
# )
# params = np.load(path_to_params)
# pt_params = {k: torch.from_numpy(v).to(device) for k, v in params.items()}

# # Define the SAE module
# class JumpReLUSAE(nn.Module):
#     def __init__(self, d_model, d_sae):
#         super().__init__()
#         self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
#         self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
#         self.threshold = nn.Parameter(torch.zeros(d_sae))
#         self.b_enc = nn.Parameter(torch.zeros(d_sae))
#         self.b_dec = nn.Parameter(torch.zeros(d_model))

#     def encode(self, input_acts):
#         pre_acts = input_acts @ self.W_enc + self.b_enc
#         mask = pre_acts > self.threshold
#         return mask * torch.nn.functional.relu(pre_acts)

#     def decode(self, acts):
#         return acts @ self.W_dec + self.b_dec

#     def forward(self, acts):
#         return self.decode(self.encode(acts))

# sae = JumpReLUSAE(
#     d_model=params["W_enc"].shape[0],
#     d_sae=params["W_enc"].shape[1]
# ).to(device)
# sae.load_state_dict(pt_params)
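
To make the JumpReLU gating in the template concrete, here is a small NumPy sketch. The weights, biases, and thresholds are made-up toy values (assumptions for illustration only, not Gemma Scope parameters), and `jumprelu_encode` is a hypothetical helper name. It mirrors the `encode` logic above: a feature is kept only when its pre-activation exceeds that feature's threshold.

```python
import numpy as np

def jumprelu_encode(x, W_enc, b_enc, threshold):
    """JumpReLU encoding: keep a pre-activation only where it exceeds its threshold."""
    pre_acts = x @ W_enc + b_enc
    mask = pre_acts > threshold
    return mask * np.maximum(pre_acts, 0.0)

# Toy values (hypothetical): d_model = 2, d_sae = 3.
W_enc = np.array([[1.0, 0.0, -1.0],
                  [0.0, 1.0,  1.0]])
b_enc = np.zeros(3)
threshold = np.array([0.5, 0.5, 0.5])

x = np.array([[1.0, 0.25]])  # a single token's residual vector
acts = jumprelu_encode(x, W_enc, b_enc, threshold)
print(acts)  # [[1. 0. 0.]]
```

The pre-activations here are `[1.0, 0.25, -0.75]`. The second feature is positive but below its threshold, so JumpReLU zeroes it, where a plain ReLU would have kept it.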

## Step 3 — Gather residual activations at a target layer
We register a forward hook on the chosen layer (here, `layer_20`) and capture the residual stream activations for the current inputs.


In [4]:
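def gather_residual_activations(model, target_layer, inputs):
    """Return the hidden states output by decoder layer `target_layer` for `inputs`."""
    target_act = None

    def gather_target_act_hook(mod, hook_inputs, outputs):
        nonlocal target_act
        # Assumes the layer returns a tuple whose first element is the hidden states.
        target_act = outputs[0]
        return outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    try:
        _ = model(inputs)
    finally:
        # Always detach the hook, even if the forward pass raises.
        handle.remove()
    return target_act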

## Step 4 — Define data locations
Set the input/output paths and list the JSONL + metadata files to process.


In [None]:
# File lists
jsonl_files = [
    "transcript_componenttext_2012_1.jsonl",
    "transcript_componenttext_2012_2.jsonl",
    # "transcript_componenttext_2013_1.jsonl",
    # "transcript_componenttext_2013_2.jsonl",
    # "transcript_componenttext_2014_1.jsonl",
    # "transcript_componenttext_2014_2.jsonl",
]
meta_files = [
    "transcript_metadata_2012_1.csv",
    "transcript_metadata_2012_2.csv",
    # "transcript_metadata_2013_1.csv",
    # "transcript_metadata_2013_2.csv",
    # "transcript_metadata_2014_1.csv",
    # "transcript_metadata_2014_2.csv",
]
input_dir = "./data/train_test_data"
jsonl_files = [os.path.join(input_dir, fn) for fn in jsonl_files]
meta_files = [os.path.join(input_dir, fn) for fn in meta_files]

# uncomment below if use Gemma-2 2B
# output_dir = "./data/doc_features/2b_canonical_16k"
# uncomment below if use Gemma-2 9B
output_dir = "./data/doc_features/9b_canonical_131k"

os.makedirs(output_dir, exist_ok=True)

## Step 5 — Reconstruct documents from chunked JSONL
Each JSONL row contains a mapping from chunk keys to text. We group by `transcriptid` and order by the embedded chunk index to rebuild the full document.


In [6]:
def process_jsonl(path):
    temp = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            for key, text in obj.items():
                parts = key.split("_")
                tid, order = parts[2], int(parts[3])
                temp.setdefault(tid, []).append((order, text))
    return {
        tid: "\n".join(str(txt) for _, txt in sorted(lst, key=lambda x: x[0]))
        for tid, lst in temp.items()
    }
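
As a quick check of the grouping logic, the sketch below applies the same parsing to two in-memory rows. The key format shown (`transcript_text_<transcriptid>_<chunk index>`) is a hypothetical stand-in: the code only relies on the third and fourth underscore-separated fields, and `rebuild_docs` is an illustrative name, not part of the notebook.

```python
import json

def rebuild_docs(lines):
    """Group chunk texts by transcript id and join them in chunk order."""
    chunks = {}
    for line in lines:
        obj = json.loads(line)
        for key, text in obj.items():
            parts = key.split("_")
            tid, order = parts[2], int(parts[3])
            chunks.setdefault(tid, []).append((order, text))
    return {
        tid: "\n".join(str(txt) for _, txt in sorted(lst, key=lambda x: x[0]))
        for tid, lst in chunks.items()
    }

# Hypothetical rows; chunks arrive out of order and spread across lines.
rows = [
    json.dumps({"transcript_text_1001_1": "second chunk"}),
    json.dumps({"transcript_text_1001_0": "first chunk",
                "transcript_text_2002_0": "only chunk"}),
]
docs = rebuild_docs(rows)
print(docs["1001"])
```

Because the chunk index is cast to `int` before sorting, chunk 10 sorts after chunk 9 rather than after chunk 1.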

## Step 6 — Run feature extraction
For each transcript:
1. Tokenize (truncate to 20k tokens).
2. Capture residual activations at `layer_20`.
3. Encode with the SAE.
4. Aggregate features across tokens (`sum`, `mean`, `max`).
5. Periodically flush to disk in `.npz` and write aligned metadata in `.csv`.
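
The aggregation in step 4 can be illustrated on a toy activation matrix. The values below are made up; in the notebook, `acts` has one row per token and one column per SAE feature.

```python
import numpy as np

# Toy SAE activations (hypothetical): 3 tokens x 4 features.
acts = np.array([
    [0.0, 2.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 3.0],
    [0.0, 4.0, 0.0, 2.0],
])

sum_vec = acts.sum(axis=0)    # total activation per feature
mean_vec = acts.mean(axis=0)  # sum divided by the number of tokens
max_vec = acts.max(axis=0)    # strongest single-token activation per feature

print(sum_vec)   # [0. 6. 0. 6.]
print(mean_vec)  # [0. 2. 0. 2.]
print(max_vec)   # [0. 4. 0. 3.]
```

The sum grows with document length while the mean and max do not, which is one reason the token count is stored alongside the features.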


In [None]:
# Main processing
for jpath, mpath in zip(jsonl_files, meta_files):
    prefix = os.path.splitext(os.path.basename(jpath))[0]
    logger.info(f"Starting processing for {prefix}")

    # load & align metadata once per file
    meta = pd.read_csv(mpath, dtype={"transcriptid": str})
    meta_unique = (
        meta[["transcriptid", "SUESCORE"]]
        .drop_duplicates(subset="transcriptid", keep="first")
        .set_index("transcriptid")
    )

    try:
        docs = process_jsonl(jpath)
        tids = list(docs.keys())

        # initialize buffers & counter
        feats_mean, feats_max = [], []
        feats_sum = []
        feats_ntokens = []
        success_tids = []
        count = 0

        for tid in tids[:]:
            logger.info(f"processing tid: {tid}")
            try:
                text = docs[tid]
                # tokenize & truncate to 20k tokens
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    add_special_tokens=True,
                    truncation=True,
                    max_length=20000,
                ).input_ids.to(device)

                # get activations & SAE encodings
                res_act = gather_residual_activations(model, 20, inputs)
                acts = sae.encode(res_act.float()).cpu().numpy().squeeze(0)

                # compute feature vectors
                sum_vec  = acts.sum(axis=0)
                mean_vec = acts.mean(axis=0)
                max_vec  = acts.max(axis=0)

                # append all per-doc results together, only after every step succeeded,
                # so features, token counts, and transcript ids stay row-aligned
                feats_sum.append(sum_vec)
                feats_mean.append(mean_vec)
                feats_max.append(max_vec)
                feats_ntokens.append(inputs.size(1))
                success_tids.append(tid)

            except Exception as e:
                logger.exception(f"Doc {tid} failed: {e}")

            finally:
                # drop references to per-doc tensors so their GPU memory can be freed
                inputs = res_act = acts = None
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # increment & maybe flush
                count += 1
                if count % 100 == 0:
                    part = count // 100
                    logger.info(f"Flushing part {part} ({count} docs)")
                    # stack features
                    # uncomment below if use Gemma-2 2B
                    # npz_path = os.path.join(output_dir, f"{prefix}_part{part}_2b_canon_features.npz")
                    # uncomment below if use Gemma-2 9B
                    npz_path = os.path.join(output_dir, f"{prefix}_part{part}_9b_canon_features.npz")
                    np.savez(
                        npz_path,
                        X_sum        = np.vstack(feats_sum),
                        X_mean       = np.vstack(feats_mean),
                        X_max        = np.vstack(feats_max),
                        token_counts = np.array(feats_ntokens, int),
                        transcriptids = np.array(success_tids, dtype=str)
                    )
                    # write meta CSV
                    meta_batch = (
                        meta_unique
                        .reindex(success_tids)
                        .reset_index()
                    )
                    meta_batch["SUESCORE"] = meta_batch["SUESCORE"].astype(float)
                    meta_batch["label"] = meta_batch["SUESCORE"].apply(
                        lambda s: 1 if s >= 0.5 else (0 if s <= -0.5 else np.nan)
                    )
                    meta_batch.to_csv(
                        # uncomment below if use Gemma-2 2B
                        # os.path.join(output_dir, f"{prefix}_part{part}_2b_canon_features_meta.csv"),
                        # uncomment below if use Gemma-2 9B
                        os.path.join(output_dir, f"{prefix}_part{part}_9b_canon_features_meta.csv"),
                        index=False
                    )
                    # clear buffers
                    feats_sum.clear()
                    feats_mean.clear()
                    feats_max.clear()
                    feats_ntokens.clear()
                    success_tids.clear()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        # final flush for any remainder
        if count % 100 != 0:
            part = (count // 100) + 1
            logger.info(f"Flushing final part {part} ({count % 100} docs)")
            np.savez(
                # uncomment below if use Gemma-2 2B
                # os.path.join(output_dir, f"{prefix}_part{part}_2b_canon_features.npz"),
                # uncomment below if use Gemma-2 9B
                os.path.join(output_dir, f"{prefix}_part{part}_9b_canon_features.npz"),
                X_sum        = np.vstack(feats_sum),
                X_mean       = np.vstack(feats_mean),
                X_max        = np.vstack(feats_max),
                token_counts = np.array(feats_ntokens, int),
                transcriptids = np.array(success_tids, dtype=str)
            )
            meta_batch = (
                meta_unique
                .reindex(success_tids)
                .reset_index()
            )
            meta_batch["SUESCORE"] = meta_batch["SUESCORE"].astype(float)
            meta_batch["label"] = meta_batch["SUESCORE"].apply(
                lambda s: 1 if s >= 0.5 else (0 if s <= -0.5 else np.nan)
            )
            meta_batch.to_csv(
                # uncomment below if use Gemma-2 2B
                # os.path.join(output_dir, f"{prefix}_part{part}_2b_canon_features_meta.csv"),
                # uncomment below if use Gemma-2 9B
                os.path.join(output_dir, f"{prefix}_part{part}_9b_canon_features_meta.csv"),
                index=False
            )

        logger.info(f"Finished processing & flushing all parts for {prefix}")

    except Exception as e:
        logger.exception(f"Failed to process {prefix}: {e}")

INFO: Starting processing for transcript_componenttext_2014_1


INFO: processing tid: 564725
INFO: processing tid: 564749
INFO: processing tid: 564875
INFO: processing tid: 564953
INFO: processing tid: 565016
INFO: processing tid: 565036
INFO: processing tid: 565105
INFO: processing tid: 565122
INFO: processing tid: 565139
INFO: processing tid: 565161
INFO: processing tid: 565168
INFO: processing tid: 565178
INFO: processing tid: 565214
INFO: processing tid: 565220
INFO: processing tid: 565279
INFO: processing tid: 565299
INFO: processing tid: 565337
INFO: processing tid: 565351
INFO: processing tid: 565355
INFO: processing tid: 565384
INFO: processing tid: 565393
INFO: processing tid: 565404
INFO: processing tid: 565411
INFO: processing tid: 565429
INFO: processing tid: 565441
INFO: processing tid: 565461
INFO: processing tid: 565510
INFO: processing tid: 565606
INFO: processing tid: 565624
INFO: processing tid: 565632
INFO: processing tid: 565634
INFO: processing tid: 565646
INFO: processing tid: 565667
INFO: processing tid: 565684
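
The label rule applied when writing the metadata CSV in Step 6 can be checked in isolation. This is a minimal sketch of the same thresholds; `suescore_to_label` is a hypothetical helper name and the scores are made-up values.

```python
import math

def suescore_to_label(score):
    """Map a SUESCORE to 1 (>= 0.5), 0 (<= -0.5), or NaN (in between or missing)."""
    if score >= 0.5:
        return 1
    if score <= -0.5:
        return 0
    return math.nan

scores = [1.2, 0.5, 0.1, -0.5, -2.0, math.nan]
labels = [suescore_to_label(s) for s in scores]
print(labels)  # [1, 1, nan, 0, 0, nan]
```

A missing score also maps to NaN, because both comparisons are false for NaN. In the notebook this covers transcripts that `reindex` could not find in the metadata file, so such rows should be dropped before training a classifier.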

In [None]:
output_dir

'./data/doc_features/2b_canonical_16k'

## Step 7 — Inspect outputs (optional)
Quickly list what was saved and preview shapes from an `.npz` file.


In [1]:
import os
from glob import glob

import numpy as np

# Assumes `output_dir` is defined (Step 4); re-run that cell after a kernel restart.
saved = sorted(glob(os.path.join(output_dir, "*.npz")))
print("Saved NPZ files:", saved[:5], ("... ({} total)".format(len(saved)) if len(saved) > 5 else ""))

if saved:
    sample = np.load(saved[0])
    for k in sample.files:
        print(k, sample[k].shape)
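
For downstream use, the arrays in each `.npz` file are row-aligned: row `i` of every feature matrix belongs to `transcriptids[i]`. The sketch below writes a tiny stand-in file with the same keys the notebook saves (all values and the file name are made up) and then reads it back, checking that the mean equals the sum divided by the token count.

```python
import os
import tempfile

import numpy as np

# Write a tiny stand-in file with the notebook's keys (hypothetical values).
tmp_dir = tempfile.mkdtemp()
npz_path = os.path.join(tmp_dir, "example_part1_features.npz")
np.savez(
    npz_path,
    X_sum=np.array([[2.0, 0.0], [6.0, 3.0]]),
    X_mean=np.array([[1.0, 0.0], [2.0, 1.0]]),
    X_max=np.array([[2.0, 0.0], [4.0, 2.0]]),
    token_counts=np.array([2, 3]),
    transcriptids=np.array(["1001", "2002"], dtype=str),
)

# Read the arrays back; the context manager closes the file afterwards.
with np.load(npz_path) as data:
    X_sum = data["X_sum"]
    X_mean = data["X_mean"]
    token_counts = data["token_counts"]
    tids = data["transcriptids"]

# Recompute the per-document mean from the sum and the token count.
recomputed_mean = X_sum / token_counts[:, None]
print(tids.tolist(), X_sum.shape, np.allclose(recomputed_mean, X_mean))
```

The matching `*_meta.csv` file is written in the same row order, so it can be joined to these arrays by position or by `transcriptid`.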