In [1]:
from __future__ import annotations
"""
The point of this here is basically to try using the `delphi` library from EleutherAI
(https://github.com/4gatepylon/delphi) to do autointerp and then modify it/call it
in such a way as to do iterative improvement for the autointerp.
"""


'\nThe point of this here is basically to try using the `delphi` library from EleutherAI\n(https://github.com/4gatepylon/delphi) to do autointerp and then modify it/call it\nin such a way as to do iterative improvement for the autointerp.\n'

In [4]:
"""
quick_gemma_scope_interpret.py
--------------------------------
A concise Delphi demo for interpreting the first 100 features of a
Gemma-Scope sparse auto-encoder (SAE).

Requirements (≈ the main ones)
pip install "torch>=2.2" transformers sae-lens delphi sparsify datasets orjson pydantic
"""
import asyncio
from pathlib import Path
from typing import List

import orjson
import torch
from datasets import load_dataset
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ──────────────────────────────────────────────────────────────────────────────
# 1.  Choose the model + SAE you want to probe
# ──────────────────────────────────────────────────────────────────────────────
BASE_MODEL = "google/gemma-2b"
SAE_RELEASE = "gemma-scope-2b-pt-res-canonical"          # HuggingFace repo
SAE_ID      = "layer_10/width_16k/canonical"             # particular layer
HOOKPOINT   = "layers.10"                                # module name in Gemma
MAX_LATENTS = 100                                        # how many features

# ──────────────────────────────────────────────────────────────────────────────
# 2.  Load Gemma and plug in the SAE
# ──────────────────────────────────────────────────────────────────────────────
print("⏳  Loading Gemma-2B …")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
model     = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    trust_remote_code=True
)

print("⏳  Loading Gemma-Scope SAE …")
from sae_lens import SAE                              # pip install sae-lens

sae, *_ = SAE.from_pretrained(release=SAE_RELEASE, sae_id=SAE_ID)

# Map the model hook-point to the SAE you just loaded
submodule_dict = {HOOKPOINT: sae}

# ──────────────────────────────────────────────────────────────────────────────
# 3.  Cache a *tiny* latent dataset (demo-sized on purpose)
# ──────────────────────────────────────────────────────────────────────────────
print("⏳  Caching latent activations (≈100 k tokens) …")
from sparsify.data import chunk_and_tokenize
from delphi.latents import LatentCache                 # :contentReference[oaicite:0]{index=0}

raw_text = load_dataset("EleutherAI/fineweb-edu-dedup-10b",
                        split="train[:0.05%]")
tokens = chunk_and_tokenize(raw_text,
                            tokenizer,
                            max_seq_len=256,
                            text_key="raw_content")["input_ids"]

cache = LatentCache(model, submodule_dict, batch_size=4)
cache.run(n_tokens=100_000, tokens=tokens)
cache.save_splits(n_splits=1, save_dir="latents")       # creates ./latents/*.safetensors

# ──────────────────────────────────────────────────────────────────────────────
# 4.  Build a LatentDataset targeting the first 100 features
# ──────────────────────────────────────────────────────────────────────────────
from delphi.latents import LatentDataset
from delphi.config  import SamplerConfig, ConstructorConfig

latents = {HOOKPOINT: torch.arange(MAX_LATENTS)}        # [0 … 99]
dataset = LatentDataset(
    raw_dir="latents",
    modules=[HOOKPOINT],
    latents=latents,
    sampler_cfg=SamplerConfig(),
    constructor_cfg=ConstructorConfig(),
    tokenizer=tokenizer,
)

# ──────────────────────────────────────────────────────────────────────────────
# 5.  Configure Delphi’s explainer (local Llama-3-Instruct via vLLM)
# ──────────────────────────────────────────────────────────────────────────────
# TODO(Adrianoh) we probably want to switch to OpenAI/Anthropic's API?
from delphi.clients   import Offline
from delphi.explainers import DefaultExplainer
from delphi.pipeline  import Pipeline, process_wrapper

client    = Offline("meta-llama/Meta-Llama-3.1-8B-Instruct",
                    num_gpus=1,                       # tweak to match your GPU(s)
                    max_memory=0.90,
                    max_model_len=8192)

explainer = DefaultExplainer(client, tokenizer=tokenizer)
explainer_pipe = process_wrapper(explainer)

pipeline = Pipeline(dataset, explainer_pipe)

print("🧠  Generating explanations … (this can take ~minutes-hours)")
asyncio.run(pipeline.run(n_processes=4))                # CPU-bound; bump if you wish

# ──────────────────────────────────────────────────────────────────────────────
# 6.  Shape the results with a Pydantic schema
# ──────────────────────────────────────────────────────────────────────────────
class LatentExplanation(BaseModel):
    latent_id: int
    description: str
    positive_examples: List[str]
    negative_examples: List[str]

out: List[LatentExplanation] = []

for record in dataset:          # after the pipeline each record now has .explanation
    exp = record.explanation
    out.append(
        LatentExplanation(
            latent_id   = int(record.latent),
            description = exp["explanation"],
            positive_examples = exp.get("top_activating_tokens", []),
            negative_examples = exp.get("most_confused_tokens", [])
        )
    )

dest = Path("gemma_scope_layer10_first100_explanations.json")
dest.write_bytes(orjson.dumps([e.dict() for e in out], option=orjson.OPT_INDENT_2))
print(f"✅  Saved structured explanations → {dest.resolve()}")


⏳  Loading Gemma-2B …


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.82s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


⏳  Loading Gemma-Scope SAE …
⏳  Caching latent activations (≈100 k tokens) …


Downloading readme: 100%|██████████| 721/721 [00:00<00:00, 20.7kB/s]
Downloading data: 100%|██████████| 97/97 [10:25<00:00,  6.45s/files]
Generating train split: 100%|██████████| 9508400/9508400 [03:00<00:00, 52712.46 examples/s]


ValueError: Unrecognized instruction format: train[:0.05%]