In [1]:
# `%pip` installs into the interpreter running this kernel; `!pip` uses whatever pip is on PATH.
%pip install -e circuit-tracer

Defaulting to user installation because normal site-packages is not writeable
Obtaining file:///home/chriskino/Attribution-Graph-Qwen-/circuit-tracer
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: circuit-tracer
  Building editable for circuit-tracer (pyproject.toml) ... [?25ldone
[?25h  Created wheel for circuit-tracer: filename=circuit_tracer-0.1.0-py3-none-any.whl size=6281 sha256=8770852c4ef29bc5e2f327300216c4e60f2ec46c0c5e8c671cc34866d47a9af3
  Stored in directory: /tmp/pip-ephem-wheel-cache-l427jlhy/wheels/50/3e/98/6b6182be2f50cece9f65caed8fbff0135cbbdf3c8ba3867876
Successfully built circuit-tracer
Installing collected packages: circuit-tracer
  Attempting uninstall: circuit-tracer
   

In [2]:
# Configure PyTorch's CUDA allocator; this cell must run before torch/transformers are imported.
import os

os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

In [3]:
#@title Colab Setup Environment

# Detect Colab on its own, so that an ImportError raised by the setup below
# (e.g. huggingface_hub missing) is not misreported as "not running in Colab".
try:
    import google.colab  # noqa: F401  (imported only to detect Colab)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Clone circuit-tracer into ./repository and install it in editable mode with uv.
    !mkdir -p repository && cd repository && \
     git clone https://github.com/safety-research/circuit-tracer && \
     curl -LsSf https://astral.sh/uv/install.sh | sh && \
     uv pip install -e circuit-tracer/

    import sys
    from huggingface_hub import notebook_login
    sys.path.append('repository/circuit-tracer')
    sys.path.append('repository/circuit-tracer/demos')
    notebook_login(new_session=False)

In [4]:


from pathlib import Path
import torch

from circuit_tracer import attribute
from circuit_tracer.utils import create_graph_files



In [None]:
import os

from huggingface_hub import get_token, hf_api, hf_hub_download, login, snapshot_download

# Never hardcode the Hugging Face token in the notebook. Read it from the environment,
# and fall back to the interactive login only when no token is cached yet.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
elif get_token() is None:
    login()

In [6]:
# Base model to trace.
# NOTE(review): the transcoder repo below is named for Qwen3-14B and the loaded TranscoderSet
# reports 40 layers, while this checkpoint is a Qwen-based R1 distill. Confirm that the
# architectures (layer count, d_model) actually match before trusting attributions.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
# Hugging Face repo id of the per-layer transcoder set.
transcoder_name = "mwhanna/qwen3-14b-transcoders-lowl0"

In [7]:
import transformer_lens

In [8]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig, AutoModelForCausalLM, pipeline

In [9]:
from urllib.parse import parse_qs, urlparse

In [10]:
from typing import NamedTuple
class HfUri(NamedTuple):
    """Structured representation of a HuggingFace URI."""

    repo_id: str
    file_path: str | None
    revision: str | None

    @classmethod
    def from_str(cls, hf_ref: str):
        if hf_ref.startswith("hf://"):
            return parse_hf_uri(hf_ref)

        parts = hf_ref.split("@", 1)
        repo_id = parts[0]
        revision = parts[1] if len(parts) > 1 else None
        return cls(repo_id, None, revision)


In [11]:
def parse_hf_uri(uri: str) -> "HfUri":
    """Parse an HF URI into repo id, file path and revision.

    Args:
        uri: String like ``hf://org/repo/file?revision=main``.

    Returns:
        ``HfUri`` with repository id, file path and optional revision
        (an empty ``revision=`` value becomes ``None``).

    Raises:
        ValueError: if the scheme is not ``hf``, or the organisation, repository
            name or file path is missing/empty.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "hf":
        raise ValueError(f"Not a huggingface URI: {uri}")
    repo_name, _, file_path = parsed.path.lstrip("/").partition("/")
    # Reject "hf://org/repo" and "hf://org/repo/" (no/empty file path) and
    # "hf:///repo/file" (empty org), which would otherwise give a bogus HfUri.
    if not (parsed.netloc and repo_name and file_path):
        raise ValueError(f"Invalid huggingface URI: {uri}")
    revision = parse_qs(parsed.query).get("revision", [None])[0] or None
    return HfUri(f"{parsed.netloc}/{repo_name}", file_path, revision)

In [12]:
def resolve_transcoder_paths(config: dict) -> dict:
    """Map each layer index to a local transcoder file path.

    Two config layouts are supported:

    * ``config["transcoders"]``: an explicit per-layer list. Entries starting with
      ``hf://`` are downloaded (each distinct URI once). Any other entry is treated
      as a local path and returned unchanged.
    * otherwise: every ``layer_<i>.safetensors`` file of ``config["repo_id"]`` (at
      ``config["revision"]``, default ``"main"``) is fetched with ``snapshot_download``.

    Args:
        config: Parsed transcoder ``config.yaml`` contents.

    Returns:
        ``{layer_index: local_path}``, ordered by layer index.

    Raises:
        FileNotFoundError: if the repository contains no ``layer_<i>.safetensors`` files.
    """
    import glob
    import os
    import re

    if "transcoders" in config:
        paths = list(config["transcoders"])
        # Download each distinct hf:// URI once; everything else is already local.
        hf_uris = list(dict.fromkeys(p for p in paths if p.startswith("hf://")))
        local_map = {}
        if hf_uris:
            from huggingface_hub import hf_hub_download

            for uri in hf_uris:
                ref = parse_hf_uri(uri)
                local_map[uri] = hf_hub_download(
                    repo_id=ref.repo_id, filename=ref.file_path, revision=ref.revision
                )
        return {i: local_map.get(p, p) for i, p in enumerate(paths)}

    from huggingface_hub import snapshot_download

    local_path = snapshot_download(
        config["repo_id"],
        revision=config.get("revision", "main"),
        allow_patterns=["layer_*.safetensors"],
    )
    # Key files by the layer number in their name, instead of assuming the glob
    # matched exactly layer_0 .. layer_{n-1}.
    layer_pattern = re.compile(r"layer_(\d+)\.safetensors")
    transcoder_paths = {}
    for found in glob.glob(os.path.join(local_path, "layer_*.safetensors")):
        match = layer_pattern.fullmatch(os.path.basename(found))
        if match:
            transcoder_paths[int(match.group(1))] = os.path.join(local_path, match.group(0))
    if not transcoder_paths:
        raise FileNotFoundError(f"No layer_*.safetensors files found in {config['repo_id']}")
    return dict(sorted(transcoder_paths.items()))

In [13]:
import os
from collections.abc import Iterator

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn

from circuit_tracer.transcoder.activation_functions import JumpReLU
from circuit_tracer.utils import get_default_device


class SingleLayerTranscoder(nn.Module):
    """
    A per-layer transcoder (PLT) that replaces MLP computation with interpretable features.

    Per-layer transcoders decompose the output of a single MLP layer into sparsely active
    features that often correspond to interpretable concepts. Unlike cross-layer transcoders,
    each PLT operates independently on its assigned layer, which can result in longer paths
    through attribution graphs when features amplify across multiple layers.

    Attributes:
        d_model: Dimension of the transformer's residual stream
        d_transcoder: Number of learned features (typically >> d_model for superposition)
        layer_idx: Which transformer layer this transcoder replaces
        W_enc: Encoder weights mapping residual stream to feature space
        W_dec: Decoder weights mapping features back to residual stream
        b_enc: Encoder bias terms
        b_dec: Decoder bias terms (reconstruction baseline)
        W_skip: Optional skip connection weights (https://arxiv.org/abs/2501.18823)
        activation_function: Sparsity-inducing nonlinearity (e.g., ReLU, JumpReLU)
    """

    def __init__(
        self,
        d_model: int,
        d_transcoder: int,
        activation_function,
        layer_idx: int,
        skip_connection: bool = False,
        transcoder_path: str | None = None,
        lazy_encoder: bool = False,
        lazy_decoder: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()

        if device is None:
            device = get_default_device()

        self.d_model = d_model
        self.d_transcoder = d_transcoder
        self.layer_idx = layer_idx
        self.transcoder_path = transcoder_path
        self.lazy_encoder = lazy_encoder
        self.lazy_decoder = lazy_decoder

        if lazy_encoder or lazy_decoder:
            assert self.transcoder_path is not None, "Transcoder path must be set for lazy loading"

        if not lazy_encoder:
            self.W_enc = nn.Parameter(
                torch.zeros(d_transcoder, d_model, device=device, dtype=dtype)
            )

        if not lazy_decoder:
            self.W_dec = nn.Parameter(
                torch.zeros(d_transcoder, d_model, device=device, dtype=dtype)
            )

        self.b_enc = nn.Parameter(torch.zeros(d_transcoder, device=device, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_model, device=device, dtype=dtype))

        if skip_connection:
            self.W_skip = nn.Parameter(torch.zeros(d_model, d_model, device=device, dtype=dtype))
        else:
            self.W_skip = None

        self.activation_function = activation_function

    @property
    def device(self):
        """Get the device of the module's parameters."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Get the dtype of the module's parameters."""
        return self.b_enc.dtype

    def __getattr__(self, name):
        """Dynamically load weights when accessed if lazy loading is enabled."""

        if name == "W_enc" and self.lazy_encoder and self.transcoder_path is not None:
            with safe_open(self.transcoder_path, framework="pt", device=self.device.type) as f:
                return f.get_tensor("W_enc").to(self.dtype)
        elif name == "W_dec" and self.lazy_decoder and self.transcoder_path is not None:
            with safe_open(self.transcoder_path, framework="pt", device=self.device.type) as f:
                return f.get_tensor("W_dec").to(self.dtype)

        return super().__getattr__(name)

    def _get_decoder_vectors(self, feat_ids=None):
        to_read = feat_ids if feat_ids is not None else np.s_[:]
        if not self.lazy_decoder:
            return self.W_dec[to_read].to(self.dtype)

        if isinstance(to_read, torch.Tensor):
            to_read = to_read.cpu()
        with safe_open(self.transcoder_path, framework="pt", device=self.device.type) as f:
            return f.get_slice("W_dec")[to_read].to(self.dtype)

    def encode(self, input_acts, apply_activation_function: bool = True):
        W_enc = self.W_enc
        pre_acts = F.linear(input_acts.to(W_enc.dtype), W_enc, self.b_enc)
        if not apply_activation_function:
            return pre_acts
        return self.activation_function(pre_acts)

    def decode(self, acts):
        W_dec = self.W_dec
        return acts @ W_dec + self.b_dec

    def compute_skip(self, input_acts):
        if self.W_skip is not None:
            return input_acts @ self.W_skip.T
        else:
            raise ValueError("Transcoder has no skip connection")

    def forward(self, input_acts):
        transcoder_acts = self.encode(input_acts)
        decoded = self.decode(transcoder_acts)
        decoded = decoded.detach()
        decoded.requires_grad = True

        if self.W_skip is not None:
            skip = self.compute_skip(input_acts)
            decoded = decoded + skip

        return decoded

    def encode_sparse(self, input_acts, zero_first_pos: bool = True):
        """Encode and return sparse activations with active encoder vectors.

        Args:
            input_acts: Input activations
            zero_first_pos: Whether to zero out position 0

        Returns:
            sparse_acts: Sparse tensor of activations
            active_encoders: Encoder vectors for active features only
        """
        W_enc = self.W_enc
        pre_acts = F.linear(input_acts.to(W_enc.dtype), W_enc, self.b_enc)
        acts = self.activation_function(pre_acts)

        if zero_first_pos:
            acts[0] = 0

        sparse_acts = acts.to_sparse()
        _, feat_idx = sparse_acts.indices()
        active_encoders = W_enc[feat_idx]

        return sparse_acts, active_encoders

    def decode_sparse(self, sparse_acts):
        """Decode sparse activations and return reconstruction with scaled decoder vectors.

        Returns:
            reconstruction: Decoded output
            scaled_decoders: Decoder vectors scaled by activation values
        """
        pos_idx, feat_idx = sparse_acts.indices()
        values = sparse_acts.values()

        # Get decoder vectors for active features only
        W_dec = self._get_decoder_vectors(feat_idx.cpu())
        scaled_decoders = W_dec * values[:, None]

        # Reconstruct using index_add
        n_pos = sparse_acts.shape[0]
        reconstruction = torch.zeros(
            n_pos, self.d_model, device=sparse_acts.device, dtype=sparse_acts.dtype
        )
        reconstruction = reconstruction.index_add_(0, pos_idx, scaled_decoders)
        reconstruction = reconstruction + self.b_dec

        return reconstruction, scaled_decoders


class TranscoderSet(nn.Module):
    """
    A collection of per-layer transcoders that enable construction of a replacement model.

    TranscoderSet manages the collection of SingleLayerTranscoders needed for this substitution,
    where each transcoder replaces the MLP computation at its corresponding layer.

    Attributes:
        transcoders: ModuleList of SingleLayerTranscoder instances, one per layer
            (index i holds the transcoder for layer i)
        n_layers: Total number of layers covered
        d_transcoder: Common feature dimension across all transcoders
        feature_input_hook: Hook point where features read from (e.g., "hook_resid_mid")
        feature_output_hook: Hook point where features write to (e.g., "hook_mlp_out")
        scan: Optional identifier to identify corresponding feature visualization
        skip_connection: Whether transcoders include learned skip connections
            (read from the layer-0 transcoder only)
    """

    def __init__(
        self,
        transcoders: dict[int, SingleLayerTranscoder],
        feature_input_hook: str,
        feature_output_hook: str,
        scan: str | list[str] | None = None,
    ):
        """Build the set from a ``{layer_index: transcoder}`` mapping.

        Args:
            transcoders: One transcoder per layer; keys must be exactly ``0..n-1``.
            feature_input_hook: Hook name features read from.
            feature_output_hook: Hook name features write to.
            scan: Optional feature-visualization identifier(s).

        Raises:
            AssertionError: if a layer is missing or the layers disagree on ``d_transcoder``.
        """
        super().__init__()
        # Validate that we have continuous layers from 0 to max
        assert set(transcoders.keys()) == set(range(max(transcoders.keys()) + 1)), (
            f"Each layer should have a transcoder, but got transcoders for layers "
            f"{set(transcoders.keys())}"
        )

        self.transcoders = nn.ModuleList([transcoders[i] for i in range(len(transcoders))])
        self.n_layers = len(self.transcoders)
        self.d_transcoder = self.transcoders[0].d_transcoder

        # Verify all transcoders have the same d_transcoder
        for transcoder in self.transcoders:
            assert transcoder.d_transcoder == self.d_transcoder, (
                f"All transcoders must have the same d_transcoder, but got "
                f"{transcoder.d_transcoder} != {self.d_transcoder}"
            )

        # Store hook configuration
        self.feature_input_hook = feature_input_hook
        self.feature_output_hook = feature_output_hook
        self.scan = scan
        # NOTE(review): only layer 0 is inspected; assumes every layer agrees on W_skip.
        self.skip_connection = self.transcoders[0].W_skip is not None

    def __len__(self):
        """Number of layers (transcoders) in the set."""
        return self.n_layers

    def __getitem__(self, idx: int) -> SingleLayerTranscoder:
        """Return the transcoder for layer ``idx``."""
        return self.transcoders[idx]  # type: ignore

    def __iter__(self) -> Iterator[SingleLayerTranscoder]:
        """Iterate over the transcoders in layer order."""
        return iter(self.transcoders)  # type: ignore

    def apply_activation_function(self, layer_id, features):
        """Apply layer ``layer_id``'s activation function to ``features``."""
        return self.transcoders[layer_id].activation_function(features)  # type: ignore

    def encode(self, input_acts):
        """Encode ``input_acts[i]`` with layer i's transcoder and stack the results on dim 0."""
        return torch.stack(
            [transcoder.encode(input_acts[i]) for i, transcoder in enumerate(self.transcoders)],  # type: ignore
            dim=0,
        )

    def _get_decoder_vectors(self, layer_id, features):
        """Return layer ``layer_id``'s decoder rows for feature ids ``features``."""
        return self.transcoders[layer_id]._get_decoder_vectors(features)  # type: ignore

    def select_decoder_vectors(self, features):
        """Gather activation-scaled decoder vectors for every nonzero feature activation.

        Args:
            features: Activations indexed as (layer, position, feature). Dense input is
                converted to sparse COO first.

        Returns:
            Tuple ``(pos_idx, layer_idx, feat_idx, scaled_decoder_vectors, encoder_mapping)``:
            the index arrays of the nonzero entries, one decoder row per entry multiplied
            by its activation, and ``arange(nnz)`` mapping each entry to itself.
        """
        if not features.is_sparse:
            features = features.to_sparse()

        all_layer_idx, all_pos_idx, all_feat_idx = features.indices()
        all_activations = features.values()
        all_scaled_decoder_vectors = []
        # Decoder rows are concatenated layer by layer in ascending layer order (unique()
        # sorts). They line up with the index arrays above only if the COO indices are
        # sorted with layer as the leading key, presumably guaranteed by coalescing.
        # NOTE(review): verify if features may ever arrive uncoalesced.
        for unique_layer in all_layer_idx.unique():
            layer_mask = all_layer_idx == unique_layer
            feat_idx = all_feat_idx[layer_mask]
            activations = all_activations[layer_mask]

            decoder_vectors = self._get_decoder_vectors(unique_layer.item(), feat_idx)

            # Multiply each activation by its corresponding decoder vector
            scaled_decoder_vectors = activations.unsqueeze(-1) * decoder_vectors
            all_scaled_decoder_vectors.append(scaled_decoder_vectors)

        all_scaled_decoder_vectors = torch.cat(all_scaled_decoder_vectors)
        encoder_mapping = torch.arange(features._nnz(), device=features.device)

        return all_pos_idx, all_layer_idx, all_feat_idx, all_scaled_decoder_vectors, encoder_mapping

    def decode(self, acts):
        """Decode ``acts[i]`` with layer i's transcoder and stack the results on dim 0."""
        return torch.stack(
            [transcoder.decode(acts[i]) for i, transcoder in enumerate(self.transcoders)],  # type: ignore
            dim=0,
        )

    def compute_attribution_components(
        self,
        mlp_inputs: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Extract active features and their encoder/decoder vectors for attribution.

        Args:
            mlp_inputs: (n_layers, n_pos, d_model) tensor of MLP inputs

        Returns:
            Dict containing all components needed for AttributionContext:
                - activation_matrix: Sparse (n_layers, n_pos, d_transcoder) activations
                - reconstruction: (n_layers, n_pos, d_model) reconstructed outputs
                - encoder_vecs: Concatenated encoder vectors for active features
                - decoder_vecs: Concatenated decoder vectors (scaled by activations)
                - encoder_to_decoder_map: Mapping from encoder to decoder indices
                - decoder_locations: (layer, position) indices of each active feature
        """
        device = mlp_inputs.device

        reconstruction = torch.zeros_like(mlp_inputs)
        encoder_vectors = []
        decoder_vectors = []
        sparse_acts_list = []

        # Per-layer encoder/decoder rows are appended in layer order, each in that layer's
        # sparse index order. NOTE(review): this relies on the stacked + coalesced
        # activation_matrix enumerating entries in the same (layer-major) order.
        for layer, transcoder in enumerate(self.transcoders):
            # zero_first_pos=True zeroes position 0's activations before sparsifying.
            sparse_acts, active_encoders = transcoder.encode_sparse(  # type: ignore
                mlp_inputs[layer], zero_first_pos=True
            )
            reconstruction[layer], active_decoders = transcoder.decode_sparse(sparse_acts)  # type: ignore
            encoder_vectors.append(active_encoders)
            decoder_vectors.append(active_decoders)
            sparse_acts_list.append(sparse_acts)

        activation_matrix = torch.stack(sparse_acts_list).coalesce()
        encoder_to_decoder_map = torch.arange(activation_matrix._nnz(), device=device)

        return {
            "activation_matrix": activation_matrix,
            "reconstruction": reconstruction,
            "encoder_vecs": torch.cat(encoder_vectors, dim=0),
            "decoder_vecs": torch.cat(decoder_vectors, dim=0),
            "encoder_to_decoder_map": encoder_to_decoder_map,
            "decoder_locations": activation_matrix.indices()[:2],
        }

    def encode_layer(self, x, layer_id, apply_activation_function=True):
        """Encode ``x`` with the transcoder for layer ``layer_id`` only."""
        return self.transcoders[layer_id].encode(
            x, apply_activation_function=apply_activation_function
        )  # type: ignore


def load_gemma_scope_transcoder(
    path: str,
    layer: int,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
    revision: str | None = None,
    **kwargs,
) -> SingleLayerTranscoder:
    """Load one GemmaScope JumpReLU transcoder from an ``.npz`` parameter file.

    Args:
        path: Local file path, or a filename inside the
            ``google/gemma-scope-2b-pt-transcoders`` Hub repository.
        layer: Layer index recorded on the returned transcoder.
        device: Device for the loaded tensors (defaults to ``get_default_device()``).
        dtype: dtype for the loaded tensors.
        revision: Hub revision, used only when ``path`` is not a local file.
        **kwargs: Accepted and ignored, so callers can pass the same keywords
            (e.g. ``lazy_encoder``/``lazy_decoder``) as to ``load_relu_transcoder``.
            This loader always materializes every weight.

    Returns:
        A ``SingleLayerTranscoder`` whose parameters come from the file.
    """
    if device is None:
        device = get_default_device()
    if os.path.isfile(path):
        path_to_params = path
    else:
        path_to_params = hf_hub_download(
            repo_id="google/gemma-scope-2b-pt-transcoders",
            filename=path,
            revision=revision,
            force_download=False,
        )

    # np.load on an .npz returns an NpzFile that holds the archive open; read every
    # array inside ``with`` so the file handle is closed deterministically.
    with np.load(path_to_params) as npz:
        param_dict = {k: torch.tensor(v, device=device, dtype=dtype) for k, v in npz.items()}

    # Our threshold is nested inside the activation_function module, and W_enc is stored
    # transposed relative to our (d_transcoder, d_model) layout.
    param_dict["activation_function.threshold"] = param_dict.pop("threshold")
    param_dict["W_enc"] = param_dict["W_enc"].T.contiguous()
    d_transcoder, d_model = param_dict["W_enc"].shape

    # Dummy JumpReLU; its threshold is replaced by load_state_dict below.
    activation_function = JumpReLU(torch.tensor(0.0), 0.1)
    # Pass device="meta" explicitly: SingleLayerTranscoder forwards an explicit device to
    # torch.zeros, which overrides the meta context, so relying on the context alone would
    # allocate real zero weights only to discard them.
    with torch.device("meta"):
        transcoder = SingleLayerTranscoder(
            d_model, d_transcoder, activation_function, layer, device=torch.device("meta")
        )
    # assign=True adopts the loaded tensors instead of copying into meta storage.
    transcoder.load_state_dict(param_dict, assign=True)
    return transcoder


def load_relu_transcoder(
    path: str,
    layer: int,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
    lazy_encoder: bool = True,
    lazy_decoder: bool = True,
):
    """Load a ReLU ``SingleLayerTranscoder`` from a safetensors file.

    Args:
        path: Path to the layer's ``.safetensors`` file.
        layer: Layer index recorded on the transcoder.
        device: Device the eagerly loaded tensors are placed on
            (defaults to ``get_default_device()``).
        dtype: dtype the module is cast to after loading.
        lazy_encoder: If True, ``W_enc`` is not loaded now; it is re-read from ``path`` on access.
        lazy_decoder: If True, ``W_dec`` is not loaded now; it is re-read from ``path`` on access.

    Returns:
        The loaded ``SingleLayerTranscoder``, cast to ``dtype``.

    Raises:
        AssertionError: if the file contains ``log_thresholds``.
    """
    if device is None:
        device = get_default_device()
    device = torch.device(device)

    lazy_keys = {key for key, lazy in (("W_enc", lazy_encoder), ("W_dec", lazy_decoder)) if lazy}
    # str(device) keeps the GPU index ("cuda:1"). device.type alone ("cuda") drops it, so
    # layers meant for a non-default GPU would not actually land on ``device``.
    with safe_open(path, framework="pt", device=str(device)) as f:
        param_dict = {k: f.get_tensor(k) for k in f.keys() if k not in lazy_keys}

    d_sae = param_dict["b_enc"].shape[0]
    d_model = param_dict["b_dec"].shape[0]

    assert param_dict.get("log_thresholds") is None, (
        f"{path} contains 'log_thresholds', which the ReLU transcoder loader does not support"
    )
    activation_function = F.relu
    # Pass device="meta" explicitly: SingleLayerTranscoder forwards an explicit device to
    # torch.zeros, which overrides the meta context, so relying on the context alone would
    # allocate real zero-filled weights only to replace them below.
    with torch.device("meta"):
        transcoder = SingleLayerTranscoder(
            d_model,
            d_sae,
            activation_function,
            layer,
            skip_connection=param_dict.get("W_skip") is not None,
            transcoder_path=path,
            lazy_encoder=lazy_encoder,
            lazy_decoder=lazy_decoder,
            device=torch.device("meta"),
        )
    # assign=True adopts the loaded tensors (and their device) instead of copying into meta storage.
    transcoder.load_state_dict(param_dict, assign=True)
    return transcoder.to(dtype)


def load_transcoder_set(
    transcoder_paths: dict,
    scan: str,
    feature_input_hook: str,
    feature_output_hook: str,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
    gemma_scope: bool = False,
    lazy_encoder: bool = True,
    lazy_decoder: bool = True,
) -> TranscoderSet:
    """Loads either a preset set of transcoders, or a set specified by a file.

    Args:
        transcoder_paths: Dictionary mapping layer indices (exactly ``0..n-1``) to transcoder paths
        scan: Scan identifier
        feature_input_hook: Hook point where features read from
        feature_output_hook: Hook point where features write to
        device (torch.device | None, optional): Device to load every layer onto. When
            ``None``, layers are spread round-robin over all visible CUDA devices, or
            loaded on CPU when no GPU is visible.
        dtype (torch.dtype | None, optional): Data type to use
        gemma_scope: Whether to use gemma scope loader
        lazy_encoder: Whether to use lazy loading for encoder weights
        lazy_decoder: Whether to use lazy loading for decoder weights

    Returns:
        TranscoderSet: The loaded transcoder set with all configuration

    Raises:
        AssertionError: if ``transcoder_paths`` is empty or its keys are not ``0..n-1``.
    """
    # Validate layer coverage up front, before any (potentially slow) loading.
    assert transcoder_paths and set(transcoder_paths) == set(range(len(transcoder_paths))), (
        f"Each layer should have a transcoder, but got transcoders for layers "
        f"{set(transcoder_paths)}"
    )

    if device is not None:
        # An explicit device pins every layer to it.
        devices = [torch.device(device)]
    else:
        n_gpus = torch.cuda.device_count()
        devices = [torch.device(f"cuda:{i}") for i in range(n_gpus)] or [torch.device("cpu")]

    load_fn = load_gemma_scope_transcoder if gemma_scope else load_relu_transcoder
    transcoders = {
        layer: load_fn(
            transcoder_paths[layer],
            layer,
            device=devices[layer % len(devices)],
            dtype=dtype,
            lazy_encoder=lazy_encoder,
            lazy_decoder=lazy_decoder,
        )
        for layer in range(len(transcoder_paths))
    }

    return TranscoderSet(
        transcoders,
        feature_input_hook=feature_input_hook,
        feature_output_hook=feature_output_hook,
        scan=scan,
    )

In [14]:
def load_transcoders(
    config: dict,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
    lazy_encoder: bool = False,
    lazy_decoder: bool = True,
):
    """Load a transcoder from a HuggingFace URI."""

    model_kind = config["model_kind"]
    if model_kind == "transcoder_set":
        # from circuit_tracer.transcoder.single_layer_transcoder import load_transcoder_set

        transcoder_paths = resolve_transcoder_paths(config)
        is_gemma_scope = "gemma-scope" in config.get("repo_id", "")

        return load_transcoder_set(
            transcoder_paths,
            scan=config["scan"],
            feature_input_hook=config["feature_input_hook"],
            feature_output_hook=config["feature_output_hook"],
            gemma_scope=is_gemma_scope,
            dtype=dtype,
            device=device,
            lazy_encoder=lazy_encoder,
            lazy_decoder=lazy_decoder,
        )
    elif model_kind == "cross_layer_transcoder":
        from circuit_tracer.transcoder.cross_layer_transcoder import load_clt

        local_path = snapshot_download(
            config["repo_id"],
            revision=config.get("revision", "main"),
            allow_patterns=["*.safetensors"],
        )

        return load_clt(
            local_path,
            scan=config["scan"],
            feature_input_hook=config["feature_input_hook"],
            feature_output_hook=config["feature_output_hook"],
            lazy_decoder=lazy_decoder,
            lazy_encoder=lazy_encoder,
            dtype=dtype,
            device=device,
        )
    else:
        raise ValueError(f"Unknown model kind: {model_kind}")

In [15]:
import yaml
def load_transcoder_from_hub(
    hf_ref: str,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
    lazy_encoder: bool = False,
    lazy_decoder: bool = True,
):
    """Download a transcoder repo's ``config.yaml`` and load the transcoder it describes.

    Args:
        hf_ref: ``org/repo[@revision]``, an ``hf://`` URI, or a legacy alias
            (``"gemma"`` / ``"llama"``).
        device, dtype, lazy_encoder, lazy_decoder: Forwarded to ``load_transcoders``.

    Returns:
        ``(transcoder, config)``, where ``config`` is the parsed YAML augmented with
        ``repo_id``, ``revision`` and a ``scan`` identifier (``repo_id@revision`` when a
        revision is given, otherwise just ``repo_id``).

    Raises:
        FileNotFoundError: if ``config.yaml`` cannot be downloaded (for any reason).
    """
    legacy_aliases = {
        "gemma": "mntss/gemma-scope-transcoders",
        "llama": "mntss/transcoder-Llama-3.2-1B",
    }
    hf_ref = legacy_aliases.get(hf_ref, hf_ref)

    uri = HfUri.from_str(hf_ref)
    try:
        config_file = hf_hub_download(
            repo_id=uri.repo_id,
            revision=uri.revision,
            filename="config.yaml",
        )
    except Exception as err:
        raise FileNotFoundError(f"Could not download config.yaml from {uri.repo_id}") from err

    with open(config_file) as fh:
        config = yaml.safe_load(fh)

    scan_id = f"{uri.repo_id}@{uri.revision}" if uri.revision else uri.repo_id
    config["repo_id"] = uri.repo_id
    config["revision"] = uri.revision
    config["scan"] = scan_id

    loaded = load_transcoders(config, device, dtype, lazy_encoder, lazy_decoder)
    return loaded, config

In [16]:
import glob

In [17]:
import os
from collections.abc import Iterator

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn

from circuit_tracer.transcoder.activation_functions import JumpReLU
from circuit_tracer.utils import get_default_device

In [18]:
# Not passed to the loader: with device=None, load_transcoder_set spreads the layers over
# the visible GPUs. Presumably used by later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reuse the transcoder_name constant rather than repeating the repo id literal.
transcoders, config = load_transcoder_from_hub(
    transcoder_name,
    dtype=torch.bfloat16,
    lazy_encoder=True,
    lazy_decoder=True,
)

Fetching 40 files:   0%|          | 0/40 [00:00<?, ?it/s]

In [19]:
# Display the loaded TranscoderSet; its repr lists the per-layer SingleLayerTranscoder modules.
transcoders

TranscoderSet(
  (transcoders): ModuleList(
    (0-39): 40 x SingleLayerTranscoder()
  )
)

In [20]:
import warnings
from contextlib import contextmanager
from collections.abc import Callable
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from circuit_tracer.attribution.context import AttributionContext
# from circuit_tracer.transcoder import TranscoderSet
from circuit_tracer.transcoder.cross_layer_transcoder import CrossLayerTranscoder
from circuit_tracer.utils import get_default_device
from circuit_tracer.utils.hf_utils import load_transcoder_from_hub


class ReplacementModel(nn.Module):
    """
    HF-based replacement model that uses PyTorch forward hooks (no TransformerLens).

    During ``get_activations`` / ``setup_attribution`` it temporarily attaches, per layer:
      - a pre_forward hook (or a forward hook, depending on the transcoders'
        ``feature_input_hook``) on the input module to capture its inputs
      - a forward hook on the output module to capture its outputs
    All hooks are removed again before those methods return.
    """

    transcoders: TranscoderSet | CrossLayerTranscoder
    feature_input_hook: str
    feature_output_hook: str
    skip_transcoder: bool
    scan: str | list[str] | None
    tokenizer: PreTrainedTokenizerBase

    def __init__(
        self,
        model_name: str,
        transcoders: TranscoderSet | CrossLayerTranscoder,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        device_map: str | dict | None = None,
        offload_folder: str | None = None,
        use_low_cpu_mem_usage: bool = True,
        **kwargs: Any,
    ):
        """
        Load the HF causal LM and tokenizer, and attach ``transcoders``.

        Args:
            model_name: hub id or local path for ``AutoModelForCausalLM`` / ``AutoTokenizer``.
            transcoders: transcoders whose ``feature_input_hook`` / ``feature_output_hook``
                select which per-layer modules get hooked.
            device: target device when ``device_map`` is None (and fallback for the primary
                device); defaults to ``get_default_device()``.
            dtype: dtype for the model weights (``torch_dtype``) and the transcoders.
            device_map, offload_folder, use_low_cpu_mem_usage: forwarded to
                ``from_pretrained`` when not None.
            **kwargs: extra keyword arguments forwarded to ``from_pretrained``
                (e.g. ``max_memory``). The explicit parameters above take precedence.
        """
        super().__init__()
        if device is None:
            device = get_default_device()

        load_kwargs: dict[str, Any] = {
            "torch_dtype": dtype,
            "device_map": device_map,
            "offload_folder": offload_folder,
            "low_cpu_mem_usage": use_low_cpu_mem_usage,
        }
        # Drop Nones so from_pretrained uses its own defaults for them.
        load_kwargs = {k: v for k, v in load_kwargs.items() if v is not None}
        # Extra kwargs used to be accepted but silently ignored; forward them now.
        load_kwargs = {**kwargs, **load_kwargs}

        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        if device_map is None:
            # Without a device_map from_pretrained leaves the weights on CPU; move them so the
            # model lives where ensure_tokenized sends tokens and where transcoders are placed.
            self.model.to(device)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.device_hint = device
        self.device_map = device_map
        self.dtype = dtype
        self.config = self.model.config

        # Attach transcoders to a primary device
        self.primary_device = self._get_primary_device()
        self.transcoders = transcoders.to(self.primary_device, dtype)

        # Read hook configuration from transcoders and normalize to HF module + hook types
        self.feature_input_hook = transcoders.feature_input_hook
        self.feature_output_hook = transcoders.feature_output_hook
        self.input_module_name, self.input_hook_type = self._normalize_hook(self.feature_input_hook)
        self.output_module_name, self.output_hook_type = self._normalize_hook(self.feature_output_hook)

        self.skip_transcoder = getattr(transcoders, "skip_connection", False)
        self.scan = getattr(transcoders, "scan", None)

        # Freeze base model; keep embeddings trainable if needed
        for p in self.model.parameters():
            p.requires_grad = False
        if hasattr(self.model, "get_input_embeddings"):
            emb = self.model.get_input_embeddings()
            for p in emb.parameters():
                p.requires_grad = True

        # Hook registry: "<layer>.<module>.<hook_type>" -> removable handle
        self._hooks: dict[str, Any] = {}

    # ---- Public constructors (keep API parity) ----

    @classmethod
    def from_pretrained_and_transcoders(
        cls,
        model_name: str,
        transcoders: TranscoderSet | CrossLayerTranscoder,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        device_map: str | dict | None = None,
        offload_folder: str | None = None,
        use_low_cpu_mem_usage: bool = True,
        **kwargs: Any,
    ) -> "ReplacementModel":
        """Build a model from an HF checkpoint name and already-loaded transcoders."""
        return cls(
            model_name=model_name,
            transcoders=transcoders,
            device=device,
            dtype=dtype,
            device_map=device_map,
            offload_folder=offload_folder,
            use_low_cpu_mem_usage=use_low_cpu_mem_usage,
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        transcoder_set: str,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        device_map: str | dict | None = None,
        offload_folder: str | None = None,
        use_low_cpu_mem_usage: bool = True,
        **kwargs: Any,
    ) -> "ReplacementModel":
        """Load transcoders from the hub reference ``transcoder_set``, then build the model."""
        if device is None:
            device = get_default_device()
        transcoders, _ = load_transcoder_from_hub(transcoder_set, device=device, dtype=dtype)
        return cls.from_pretrained_and_transcoders(
            model_name=model_name,
            transcoders=transcoders,
            device=device,
            dtype=dtype,
            device_map=device_map,
            offload_folder=offload_folder,
            use_low_cpu_mem_usage=use_low_cpu_mem_usage,
            **kwargs,
        )

    # ---- Minimal forward ----

    def forward(self, inputs: torch.Tensor | dict[str, torch.Tensor], **kwargs: Any):
        """
        Run the wrapped HF model.

        A tensor is treated as one unbatched token sequence (a batch axis is added). A mapping
        (dict / BatchEncoding) is unpacked into keyword arguments; passing it positionally
        would hand the whole dict to the model as ``input_ids``.
        """
        if isinstance(inputs, torch.Tensor):
            return self.model(inputs.unsqueeze(0), **kwargs)
        return self.model(**inputs, **kwargs)

    # ---- Utilities ----

    def _get_primary_device(self) -> torch.device:
        """
        Device for transcoders and input tokens.

        Prefers the first non-"disk" placement in ``hf_device_map`` when the model was
        dispatched with a device_map; otherwise ``device_hint``, then the default device.
        """
        hf_map = getattr(self.model, "hf_device_map", None)
        if isinstance(hf_map, dict):
            for placement in hf_map.values():
                # "disk" is an accelerate offload marker, not a torch device.
                if placement != "disk":
                    return torch.device(placement)
        if self.device_hint is not None:
            return self.device_hint
        return get_default_device()

    def _transcoder_device(self, layer_idx: int) -> torch.device:
        """Device of layer ``layer_idx``'s transcoder (a CrossLayerTranscoder has a single device)."""
        if isinstance(self.transcoders, TranscoderSet):
            return self.transcoders[layer_idx].device
        return self.transcoders.device

    def _normalize_hook(self, name: str) -> tuple[str, str]:
        """
        Map TransformerLens-style names to HF module name + hook type.
        - '...mlp.hook_in'  -> ('mlp', 'pre_forward')
        - '...mlp.hook_out' -> ('mlp', 'forward')
        Anything else maps to (last dotted component, 'forward').
        """
        lower = name.lower()
        if "mlp" in lower and ("hook_in" in lower or "hook_mlp_in" in lower):
            return "mlp", "pre_forward"
        if "mlp" in lower and ("hook_out" in lower or "hook_mlp_out" in lower):
            return "mlp", "forward"
        # Default: treat as module path tail and use forward hook
        return name.split(".")[-1], "forward"

    def _get_layer_module(self, layer_idx: int, module_name: str):
        """Return submodule ``module_name`` of decoder layer ``layer_idx``."""
        base = getattr(self.model, "model", self.model)
        # Try common container names
        if hasattr(base, "layers"):
            layer = base.layers[layer_idx]
        elif hasattr(base, "h"):
            layer = base.h[layer_idx]
        elif hasattr(base, "transformer") and hasattr(base.transformer, "h"):
            layer = base.transformer.h[layer_idx]
        elif hasattr(base, "blocks"):
            layer = base.blocks[layer_idx]
        else:
            raise ValueError(f"Unknown base model structure: {type(base)}")

        if not hasattr(layer, module_name):
            raise AttributeError(f"Layer {layer_idx} has no submodule '{module_name}'.")
        return getattr(layer, module_name)

    def _register_hook(self, layer_idx: int, module_name: str, hook_fn: Callable, hook_type: str = "forward"):
        """
        Register ``hook_fn`` on a layer's submodule and return the removable handle.

        NOTE: a second registration with the same (layer, module, hook_type) removes the
        previously registered hook, so an input and an output hook that both map to a
        forward hook on the same module cannot coexist.
        """
        module = self._get_layer_module(layer_idx, module_name)
        if hook_type == "pre_forward":
            handle = module.register_forward_pre_hook(hook_fn)
        elif hook_type == "forward":
            handle = module.register_forward_hook(hook_fn)
        else:
            raise ValueError(f"Unknown hook type: {hook_type}")
        key = f"{layer_idx}.{module_name}.{hook_type}"
        if key in self._hooks:
            self._hooks[key].remove()
        self._hooks[key] = handle
        return handle

    @contextmanager
    def zero_softcap(self):
        # Keep API parity; HF models do not apply output softcap here.
        yield

    def ensure_tokenized(self, prompt: str | torch.Tensor | list[int]) -> torch.Tensor:
        """
        Return a 1-D token tensor on the primary device, starting with a special token.

        If the first token is not already a special id, one is prepended (bos, then pad,
        then eos, then any special id) to avoid position-0 artifacts.

        Raises:
            TypeError: unsupported prompt type.
            ValueError: the prompt has more than one non-singleton dimension.
        """
        if isinstance(prompt, str):
            tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.squeeze(0)
        elif isinstance(prompt, torch.Tensor):
            tokens = prompt.squeeze()
        elif isinstance(prompt, list):
            tokens = torch.tensor(prompt, dtype=torch.long).squeeze()
        else:
            raise TypeError(f"Unsupported prompt type: {type(prompt)}")
        # squeeze() turns a single-token prompt into a 0-d tensor; restore the sequence axis.
        tokens = torch.atleast_1d(tokens)
        if tokens.ndim > 1:
            raise ValueError(f"Tensor must be 1-D, got shape {tokens.shape}")

        # If first token is already special, keep (empty prompts always get a prefix)
        special_ids = list(getattr(self.tokenizer, "all_special_ids", None) or [])
        if tokens.numel() > 0 and int(tokens[0].item()) in special_ids:
            return tokens.to(self._get_primary_device())

        # Prepend a special token to avoid pos-0 artifacts
        candidate_bos_token_ids = [
            getattr(self.tokenizer, "bos_token_id", None),
            getattr(self.tokenizer, "pad_token_id", None),
            getattr(self.tokenizer, "eos_token_id", None),
            *special_ids,
        ]
        dummy_bos_token_id = next((t for t in candidate_bos_token_ids if t is not None), None)
        if dummy_bos_token_id is not None:
            prefix = torch.tensor([dummy_bos_token_id], dtype=tokens.dtype, device=tokens.device)
            tokens = torch.cat([prefix, tokens])
        else:
            warnings.warn("No suitable special token found; the first token will be ignored.")
        return tokens.to(self._get_primary_device())

    # ---- Optional: expose transcoder activations for debugging ----

    def get_activations(
        self,
        inputs: str | torch.Tensor,
        sparse: bool = False,
        apply_activation_function: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the model once and encode every layer's hooked input with the transcoders.

        Args:
            inputs: a prompt string (tokenized as-is, no BOS added here), a token tensor
                (a batch axis is added), or a mapping of model keyword inputs.
            sparse: store per-layer features as sparse tensors (coalesced after stacking).
            apply_activation_function: forwarded to ``transcoders.encode_layer``.

        Returns:
            (logits, activations): activations stacks the per-layer features along a new
            leading layer axis on the logits' device. NOTE(review): the batch axis from the
            hooked activations is not squeezed here.
        """
        target_device = self._get_primary_device()
        if isinstance(inputs, str):
            encoded = self.tokenizer(inputs, return_tensors="pt")
            model_inputs = {k: v.to(target_device) for k, v in encoded.items()}
        elif isinstance(inputs, torch.Tensor):
            model_inputs = inputs.to(target_device)
        else:
            model_inputs = inputs

        activation_cache: dict[int, torch.Tensor] = {}

        def encode_and_cache(layer_idx: int, acts: torch.Tensor) -> None:
            # Encode on the transcoder's device, then park the result on CPU.
            feats = self.transcoders.encode_layer(
                acts.to(self._transcoder_device(layer_idx)),
                layer_idx,
                apply_activation_function=apply_activation_function,
            ).detach()
            activation_cache[layer_idx] = (feats.to_sparse() if sparse else feats).cpu()

        def make_cache_pre(layer_idx: int):
            def pre_hook(module, inps):
                encode_and_cache(layer_idx, inps[0])
                return None
            return pre_hook

        def make_cache_fwd(layer_idx: int):
            def fwd_hook(module, inps, out):
                encode_and_cache(layer_idx, inps[0] if len(inps) else out)
                return out
            return fwd_hook

        handles = []
        try:
            # Registering inside the try guarantees partial registrations are cleaned up.
            for layer in range(self.config.num_hidden_layers):
                if self.input_hook_type == "pre_forward":
                    handles.append(self._register_hook(layer, self.input_module_name, make_cache_pre(layer), "pre_forward"))
                else:
                    handles.append(self._register_hook(layer, self.input_module_name, make_cache_fwd(layer), "forward"))

            with torch.no_grad():
                if isinstance(model_inputs, torch.Tensor):
                    out = self.model(model_inputs.unsqueeze(0))
                else:
                    # Mappings (dict / BatchEncoding) must be unpacked into keyword args.
                    out = self.model(**model_inputs)
                logits = out.logits
            acts = torch.stack([activation_cache[i].to(logits.device) for i in range(len(activation_cache))])
            if sparse:
                acts = acts.coalesce()
            return logits, acts
        finally:
            for h in handles:
                h.remove()

    # ---- Attribution context builder ----

    @torch.no_grad()
    def setup_attribution(self, inputs: str | torch.Tensor) -> AttributionContext:
        """
        Run one forward pass and collect everything attribution needs.

        Captures per-layer MLP inputs/outputs via hooks, encodes the inputs with the
        transcoders, and packages activations, reconstruction errors, token embeddings
        and logits into an ``AttributionContext``.
        """
        tokens_1d = self.ensure_tokenized(inputs)

        mlp_in_cache: dict[int, torch.Tensor] = {}
        mlp_out_cache: dict[int, torch.Tensor] = {}

        def mk_in_pre(layer: int):
            def pre_hook(module, inps):
                mlp_in_cache[layer] = inps[0].detach().cpu()
                return None
            return pre_hook

        def mk_in_fwd(layer: int):
            def fwd_hook(module, inps, out):
                mlp_in_cache[layer] = (inps[0] if len(inps) else out).detach().cpu()
                return out
            return fwd_hook

        def mk_out_fwd(layer: int):
            def fwd_hook(module, inps, out):
                mlp_out_cache[layer] = out.detach().cpu()
                return out
            return fwd_hook

        hooks = []
        try:
            # Registering inside the try guarantees partial registrations are cleaned up.
            for layer in range(self.config.num_hidden_layers):
                if self.input_hook_type == "pre_forward":
                    hooks.append(self._register_hook(layer, self.input_module_name, mk_in_pre(layer), "pre_forward"))
                else:
                    hooks.append(self._register_hook(layer, self.input_module_name, mk_in_fwd(layer), "forward"))
                hooks.append(self._register_hook(layer, self.output_module_name, mk_out_fwd(layer), "forward"))

            # Gradients are already disabled by @torch.no_grad(); torch.inference_mode() is not
            # used, so the returned tensors (e.g. logits) are ordinary, non-inference tensors.
            out = self.model(tokens_1d.unsqueeze(0))
            logits = out.logits

            tgt = self._transcoder_device(0)
            # Stack on CPU and transfer once, instead of holding per-layer copies and the
            # stacked copy on the target device at the same time.
            mlp_in = torch.stack([mlp_in_cache[i] for i in range(len(mlp_in_cache))]).to(tgt)
            mlp_out = torch.stack([mlp_out_cache[i] for i in range(len(mlp_out_cache))]).to(tgt)

            # Squeeze batch dim (hooks capture B x S x D; we ran with B=1)
            if mlp_in.ndim == 4 and mlp_in.shape[1] == 1:
                mlp_in = mlp_in[:, 0]
            if mlp_out.ndim == 4 and mlp_out.shape[1] == 1:
                mlp_out = mlp_out[:, 0]

            # Computed exactly once: a duplicate call here doubled peak memory for no benefit.
            attribution = self.transcoders.compute_attribution_components(mlp_in)
            error_vectors = mlp_out - attribution["reconstruction"]
            error_vectors[:, 0] = 0  # ignore artificial BOS

            # Token embeddings (robust to different model classes)
            embed = self.model.get_input_embeddings()
            token_vectors = embed(tokens_1d.to(embed.weight.device)).detach().to(tgt)

            return AttributionContext(
                activation_matrix=attribution["activation_matrix"],
                logits=logits,
                error_vectors=error_vectors,
                token_vectors=token_vectors,
                decoder_vecs=attribution["decoder_vecs"],
                encoder_vecs=attribution["encoder_vecs"],
                encoder_to_decoder_map=attribution["encoder_to_decoder_map"],
                decoder_locations=attribution["decoder_locations"],
            )
        finally:
            for h in hooks:
                h.remove()


# Convenience factory used by notebooks
def create_replacement_model(
    model_name: str,
    transcoders: TranscoderSet | CrossLayerTranscoder,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
    device_map: str | dict | None = "auto",
    offload_folder: str | None = None,
    use_low_cpu_mem_usage: bool = True,
    **kwargs: Any,
) -> ReplacementModel:
    """Notebook convenience wrapper around ``ReplacementModel.from_pretrained_and_transcoders``.

    It differs only in defaulting ``device_map`` to ``"auto"``, letting Hugging Face decide
    where the checkpoint's weights are placed. All other arguments pass through unchanged.
    """
    loading_options: dict[str, Any] = {
        "device": device,
        "dtype": dtype,
        "device_map": device_map,
        "offload_folder": offload_folder,
        "use_low_cpu_mem_usage": use_low_cpu_mem_usage,
    }
    build = ReplacementModel.from_pretrained_and_transcoders
    return build(model_name, transcoders, **loading_options, **kwargs)

In [21]:
# Wrap the HF checkpoint with the loaded transcoders; bf16 weights, placement left to HF ("auto").
model = create_replacement_model(model_name, transcoders, dtype=torch.bfloat16, device_map="auto")



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Candidate "planning" sentences to build attribution graphs for.
plan_generation_sentences = [
    "Wait, maybe I should try to figure out the exact path.",
    "Wait, let me try to figure out how many straight segments there are.",
    "Let me check again.",
]

In [None]:
# --- Attribution settings ---
# Sentence to build the attribution graph for.
prompt = plan_generation_sentences[0]
# Upper bound on the number of output logits to attribute from; the actual count is
# min(max_n_logits, number of logits needed to reach desired_logit_prob).
max_n_logits = 10
# Attribute from the fewest top logits whose cumulative probability reaches this value.
desired_logit_prob = 0.95
# Cap on feature nodes to attribute from: lower is faster but loses more of the graph; None = no limit.
max_feature_nodes = 8192
# Batch size used while attributing.
batch_size = 8
# Where to offload model parts during attribution to save memory: 'disk', 'cpu', or None (stay on GPU).
if IN_COLAB:
    offload = 'disk'
else:
    offload = 'cpu'
# Show a tqdm progress bar and a timing report.
verbose = True

In [23]:
# Build the attribution graph for `prompt` using the settings defined in the previous cell.
graph = attribute(
    model=model,
    prompt=prompt,
    batch_size=batch_size,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    max_feature_nodes=max_feature_nodes,
    offload=offload,
    verbose=verbose,
)



Phase 0: Precomputing activations and vectors


OutOfMemoryError: CUDA out of memory. Tried to allocate 6.21 GiB. GPU 0 has a total capacity of 23.58 GiB of which 5.63 GiB is free. Process 3746032 has 280.00 MiB memory in use. Including non-PyTorch memory, this process has 17.66 GiB memory in use. Of the allocated memory 17.32 GiB is allocated by PyTorch, and 44.11 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Persist the raw attribution graph as graphs/plan_generation_1.pt (relative to the working directory).
graph_name = 'plan_generation_1.pt'
graph_dir = Path('graphs')
graph_dir.mkdir(exist_ok=True)
graph_path = graph_dir / graph_name
graph.to_pt(graph_path)

In [None]:
# Name assigned to the graph in the generated files.
slug = "plan_generation_1_trimmed"
# Output directory for the graph files; create_graph_files creates it, so no mkdir is needed.
graph_file_dir = './graph_files'
# Pruning: keep the fewest nodes / edges whose cumulative influence reaches these fractions.
node_threshold = 0.8
edge_threshold = 0.98

create_graph_files(
    graph_or_path=graph_path,
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold,
)