# Setup (just run all)

## Auxilliary functions

In [1]:
import os
from types import SimpleNamespace
from typing import Dict, List, Optional, Sequence

import torch
from safetensors.torch import load_file as safetensors_load_file

from sae_lens.sae import SAE, SAEConfig


def _dtype_to_cfg_str(dtype: torch.dtype | str | None) -> str:
    if dtype is None:
        return "torch.float32"
    if isinstance(dtype, torch.dtype):
        return str(dtype)
    return dtype


def _device_to_cfg_str(device: torch.device | str | None) -> str:
    if device is None:
        return "cpu"
    if isinstance(device, torch.device):
        return str(device)
    return device


def _build_sae_cfg_from_training(
    *,
    model_name: str,
    hook_layer: int,
    d_in: int,
    d_sae: int,
    context_size: int = 128,
    dataset_path: str = "ashaba1in/small_openwebtext",
    hook_name: str = "blocks.{layer}.hook_{target}",  # hook_mlp_in  hook_resid_pre,
    target: str = "resid_pre",  # mid_pre, resid_pre,
    device: torch.device | str | None = "cuda",
    dtype: torch.dtype | str | None = "torch.float32",
    hook_head_index: Optional[int] = None,
) -> SAEConfig:
    target = "mlp_in" if target == "mid_pre" else "resid_pre"
    
    cfg_dict = {
        "architecture": "standard",
        "d_in": int(d_in),
        "d_sae": int(d_sae),
        "activation_fn_str": "relu",
        "activation_fn_kwargs": {},
        "apply_b_dec_to_input": True,
        "finetuning_scaling_factor": False,
        "context_size": int(context_size),
        "model_name": model_name,
        "hook_name": hook_name.format(layer=hook_layer, target=target),
        "hook_layer": int(hook_layer),
        "hook_head_index": hook_head_index,
        "prepend_bos": False,
        "dataset_path": dataset_path,
        "dataset_trust_remote_code": False,
        "normalize_activations": "none",
        "dtype": _dtype_to_cfg_str(dtype),
        "device": _device_to_cfg_str(device),
        "sae_lens_training_version": None,
        "neuronpedia_id": None,
        "model_from_pretrained_kwargs": {},
        "seqpos_slice": (None,),
    }
    return SAEConfig.from_dict(cfg_dict)


In [2]:
from typing import Any, Dict, Tuple

import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, HfApi
from safetensors.torch import load_file, save_file
import os

def get_custom_hf_model(model_name: str, kwargs: Dict[str, Any] = {}) -> HookedTransformer:
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        **kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
    )
    
    hf_config = hf_model.config
    
    # Создаем конфигурацию для TransformerLens
    # Ограничиваем размер контекста для экономии памяти
    max_ctx = min(hf_config.max_position_embeddings, 2048)
    
    cfg = HookedTransformerConfig(
        n_layers=hf_config.num_hidden_layers,
        d_model=hf_config.hidden_size,
        d_head=hf_config.hidden_size // hf_config.num_attention_heads,
        n_heads=hf_config.num_attention_heads,
        d_mlp=hf_config.intermediate_size,
        d_vocab=hf_config.vocab_size,
        n_ctx=max_ctx,  # Ограничиваем размер контекста
        act_fn=hf_config.hidden_act,  # Llama использует SiLU
        model_name=model_name,
        normalization_type="RMS",  # Llama использует RMSNorm
        device="cpu", 
        use_hook_mlp_in=True,
    )
    
    model = HookedTransformer(cfg)
    
    model.load_state_dict(hf_model.state_dict(), strict=False)
    model.set_tokenizer(tokenizer)
    
    return model

def get_sae_and_logs(
    src_file = "Llama-2.3-3B-Instruct-special_blocks.-4.hook_resid_pre_18432.safetensors",
    log_file = None,
    layer_range = None,
):

    # ==== 1. качаем исходный файл ====
    local_path = hf_hub_download(
        repo_id=SRC_REPO,
        filename=src_file,
        repo_type="model",
    )
    log_path = hf_hub_download(
        repo_id=SRC_REPO,
        filename=src_file.replace(".safetensors", "_log_feature_sparsity.pt") if log_file is None else log_file,
        repo_type="model",
    )

    weights = load_file(local_path)
    logs = torch.load(log_path, weights_only=True)

    if layer_range is None:
        start, _ = extract_layer_range(src_file)
    else:
        start, _ = layer_range

    sae_by_layer = {}
    for k, v in weights.items():
        layer_id, weight_name = k.split(".")
        layer_id = start + int(layer_id)
        if layer_id not in sae_by_layer:
            sae_by_layer[layer_id] = {weight_name: v}
        sae_by_layer[layer_id][weight_name] = v

    return sae_by_layer, logs

In [3]:
# !pip install -U sae-lens transformer-lens sae-dashboard

# SAE gathering

In [4]:
import torch
from sae_lens import HookedSAETransformer


def load_hookedtrans_and_sae(
    hf_model = "ExplosionNuclear/Llama-2.3-3B-Instruct-special", 
    layer = 1, 
    hook_target = "mid_pre"  # mid_pre  resid_pre
) -> Tuple[HookedTransformer, SAE]:
    hook_trans = get_custom_hf_model(hf_model)
    hook_trans.to("cuda")
    
    repo_id = "Lucid-Layers-Inc/Llama-3.2-3B-Instruct-special-SAE"
    repo_id = repo_id.replace("special", "special-merged") if "special-merged" in hf_model else repo_id
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=f"Llama-2.3-3B-Instruct-special_layer-{layer}.hook_{hook_target}_18432.safetensors",
        repo_type="model",
    )
    
    weights = load_file(local_path)
    hidden_dim, d_in = weights["W_dec"].shape
    
    sae_cfg = _build_sae_cfg_from_training(d_sae=hidden_dim, d_in=d_in, hook_layer=layer, model_name=hf_model, target=hook_target)
    sae = SAE(sae_cfg)
    sae.load_state_dict(weights, strict=True)
    return hook_trans, sae

model, sae = load_hookedtrans_and_sae(layer=10)

# Если нужна поддержка HookedSAETransformer - вроде трюк сработал
model.__class__ = HookedSAETransformer
if not hasattr(model, "acts_to_saes"):
    model.acts_to_saes = {}

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Moving model to device:  cuda


In [None]:
import torch
import gradio as gr

def render_colored(tokens, strengths):
    # нормируем в [0,1] для альфы цвета
    s = torch.clamp(strengths, min=0)
    s = (s / (s.max() + 1e-8)).tolist()
    html = []
    for tok, a in zip(tokens, s):
        # оранжевый heatmap по активации
        html.append(f'<span style="background-color: rgba(255,165,0,{a:.3f}); padding:2px; margin:1px; border-radius:3px">{tok}</span>')
    return "<div style='font-family:monospace; line-height:2.0'>" + " ".join(html) + "</div>"

@torch.inference_mode()
def inspect_feature(prompt: str, feature_id: int):
    # 1) токенизация
    toks = model.to_tokens(prompt)
    str_toks = model.to_str_tokens(prompt)

    # 2) получаем активации узла для SAE и считаем фичи
    # ВАЖНО: SAE не должен быть "приклеен" к модели во время кэширования.
    out, cache = model.run_with_cache(
        toks, return_type="logits",
        names_filter=[sae.cfg.hook_name],
        remove_batch_dim=True
    )
    print(cache.keys())
    x = cache[sae.cfg.hook_name]                      # [pos, d_in]
    feats = sae.encode(x)                             # [pos, d_sae]
    f = feats[:, feature_id].float().cpu()            # [pos]

    # 3) топ токены словаря по направлению декодера (логит-влияние при активации=1)
    # dir_vocab[v] ≈ (W_dec[f] · W_U[:, v])
    W_U = model.W_U                                   # [d_model, d_vocab]
    wdec_f = sae.W_dec[feature_id].to(W_U.dtype).to(W_U.device)  # [d_in]
    dir_vocab = (wdec_f @ W_U).float().cpu()          # [d_vocab]
    vals, idx = torch.topk(dir_vocab, 15)
    toks_top = [model.to_string([i.item()]) for i in idx]

    heatmap_html = render_colored(str_toks, f)
    table = {t: float(v) for t, v in zip(toks_top, vals)}
    stats = {
        "mean_act": float(torch.relu(f).mean()),
        "max_act": float(torch.relu(f).max()),
        "sparsity(frac>0)": float((torch.relu(f) > 0).float().mean()),
        "d_sae": int(sae.cfg.d_sae),
        "hook": sae.cfg.hook_name,
    }
    return heatmap_html, table, stats

demo = gr.Interface(
    fn=inspect_feature,
    inputs=[
        gr.Textbox(lines=4, label="Prompt", value="The capital of France is Paris."),
        gr.Slider(0, sae.cfg.d_sae-1, value=0, step=1, label="Feature ID"),
    ],
    outputs=[
        gr.HTML(label="Token activations (highlighted)"),
        gr.Label(num_top_classes=15, label="Top vocab (decoder direction)"),
        gr.JSON(label="Stats"),
    ],
    title="SAE Feature Explorer",
    description="Подсветка активаций признака и его влияния на логиты.",
)
demo.launch(inline=True, share=True, debug=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://9dc51eb84c1723b73a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


dict_keys(['blocks.10.hook_mlp_in'])
