In [2]:
import os
from types import SimpleNamespace
from typing import Dict, List, Optional, Sequence
from typing import Any, Dict, Tuple
import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, HfApi
from safetensors.torch import load_file, save_file
import os

from sae_lens import HookedSAETransformer, StandardSAE, SAEConfig
from safetensors.torch import load_file as safetensors_load_file


def _dtype_to_cfg_str(dtype: torch.dtype | str | None) -> str:
    if dtype is None:
        return "torch.float32"
    if isinstance(dtype, torch.dtype):
        return str(dtype)
    return dtype


def _device_to_cfg_str(device: torch.device | str | None) -> str:
    if device is None:
        return "cpu"
    if isinstance(device, torch.device):
        return str(device)
    return device


def _build_sae_cfg_from_training(
    *,
    model_name: str,
    hook_layer: int,
    d_in: int,
    d_sae: int,
    context_size: int = 128,
    dataset_path: str = "ashaba1in/small_openwebtext",
    hook_name: str = "blocks.{layer}.hook_{target}",  # hook_mlp_in  hook_resid_pre,
    target: str = "resid_pre",  # mid_pre, resid_pre,
    device: torch.device | str | None = "cuda",
    dtype: torch.dtype | str | None = "torch.float32",
    hook_head_index: Optional[int] = None,
) -> SAEConfig:
    target = "mlp_in" if target == "mid_pre" else "resid_pre"
    
    cfg_dict = {
        "architecture": "standard",
        "d_in": int(d_in),
        "d_sae": int(d_sae),
        "activation_fn_str": "relu",
        "activation_fn_kwargs": {},
        "apply_b_dec_to_input": True,
        "finetuning_scaling_factor": False,
        "context_size": int(context_size),
        "model_name": model_name,
        "hook_name": hook_name.format(layer=hook_layer, target=target),
        "hook_layer": int(hook_layer),
        "hook_head_index": hook_head_index,
        "prepend_bos": False,
        "dataset_path": dataset_path,
        "dataset_trust_remote_code": False,
        "normalize_activations": "none",
        "dtype": _dtype_to_cfg_str(dtype),
        "device": _device_to_cfg_str(device),
        "sae_lens_training_version": None,
        "neuronpedia_id": None,
        "model_from_pretrained_kwargs": {},
        "seqpos_slice": (None,),
    }
    return SAEConfig.from_dict(cfg_dict)


def get_custom_hf_model(model_name: str, kwargs: Dict[str, Any] = {}) -> HookedTransformer:
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        **kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
    )
    
    hf_config = hf_model.config
    
    # Создаем конфигурацию для TransformerLens
    # Ограничиваем размер контекста для экономии памяти
    max_ctx = min(hf_config.max_position_embeddings, 2048)
    
    cfg = HookedTransformerConfig(
        n_layers=hf_config.num_hidden_layers,
        d_model=hf_config.hidden_size,
        d_head=hf_config.hidden_size // hf_config.num_attention_heads,
        n_heads=hf_config.num_attention_heads,
        d_mlp=hf_config.intermediate_size,
        d_vocab=hf_config.vocab_size,
        n_ctx=max_ctx,  # Ограничиваем размер контекста
        act_fn=hf_config.hidden_act,  # Llama использует SiLU
        model_name=model_name,
        normalization_type="RMS",  # Llama использует RMSNorm
        device="cpu", 
        use_hook_mlp_in=True,
    )
    
    model = HookedTransformer(cfg)
    
    model.load_state_dict(hf_model.state_dict(), strict=False)
    model.set_tokenizer(tokenizer)
    
    return model

def load_hookedtrans_and_sae(
    hf_model = "ExplosionNuclear/Llama-2.3-3B-Instruct-special", 
    layer = 1, 
    hook_target = "mid_pre"  # mid_pre  resid_pre
) -> Tuple[HookedTransformer, StandardSAE]:

    hook_trans = get_custom_hf_model(hf_model)
    hook_trans.to("cuda")
    
    repo_id = "Lucid-Layers-Inc/Llama-2.3-3B-Instruct-special-hook_mlp_in-SAE"
    #repo_id = repo_id.replace("special", "special-merged") if "special-merged" in hf_model else repo_id
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=f"ExplosionNuclear-Llama-2.3-3B-Instruct-special_layer-{layer}.hook_mlp_in_30720.safetensors",
        repo_type="model",
    )
    
    weights = load_file(local_path)
    hidden_dim, d_in = weights["W_dec"].shape
    
    sae_cfg = _build_sae_cfg_from_training(d_sae=hidden_dim, d_in=d_in, hook_layer=layer, model_name=hf_model, target=hook_target)
    sae = StandardSAE(sae_cfg)
    sae.load_state_dict(weights, strict=True)
    return hook_trans, sae

def load_sae_only(
    hf_model = "ExplosionNuclear/Llama-2.3-3B-Instruct-special", 
    layer = 1, 
    hook_target = "mid_pre"  # mid_pre  resid_pre
) -> StandardSAE:
    """Загружает только SAE без модели для проверки конфигурации"""
    
    repo_id = "Lucid-Layers-Inc/Llama-2.3-3B-Instruct-special-hook_resid_pre-SAE"
    repo_id = repo_id.replace("special", "special-merged") if "special-merged" in hf_model else repo_id
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=f"ExplosionNuclear-Llama-2.3-3B-Instruct-special_layer-{layer}.hook_mlp_in_30720.safetensors",
        repo_type="model",
    )
    
    weights = load_file(local_path)
    hidden_dim, d_in = weights["W_dec"].shape
    
    sae_cfg = _build_sae_cfg_from_training(d_sae=hidden_dim, d_in=d_in, hook_layer=layer, model_name=hf_model, target=hook_target)
    sae = StandardSAE(sae_cfg)
    sae.load_state_dict(weights, strict=True)
    return sae

def get_hook_name(layer: int, hook_target: str) -> str:
    """Получает hook_name для указанного слоя и цели"""
    target = "mlp_in" if hook_target == "mid_pre" else "resid_pre"
    return f"blocks.{layer}.hook_{target}"

In [2]:
# !pip install sae-lens transformer-lens sae-dashboard

Collecting sae-lens
  Downloading sae_lens-6.11.0-py3-none-any.whl.metadata (5.2 kB)
Collecting sae-dashboard
  Downloading sae_dashboard-0.7.2-py3-none-any.whl.metadata (9.5 kB)
Collecting datasets>=3.1.0 (from sae-lens)
  Using cached datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting simple-parsing<0.2.0,>=0.1.6 (from sae-lens)
  Downloading simple_parsing-0.1.7-py3-none-any.whl.metadata (7.3 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.16.1-py3-none-any.whl.metadata (12 kB)
Collecting transformers-stream-generator<0.0.6,>=0.0.5 (from transformer-lens)
  Downloading transformers-stream-generator-0.0.5.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting docstring-parser<1.0,>=0.15 (from simple-parsing<0.2.0,>=0.1.6->sae-lens)
  Using cached docstring_parser-0.17.0-py3-none-any.whl.metadata (3.5 kB)
Collecting datasets>=3.1.0 (from sae-lens)
  Using cached datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting decode-clt<0.0.2,>=0

In [4]:
model, sae = load_hookedtrans_and_sae(layer=10, hook_target="resid_pre")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Moving model to device:  cuda


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [5]:


prompt = "Max, Mark and Smith were in the empty dark room. Smith left. Mark gave a drink to"

tokens = model.to_tokens(prompt)
str_toks = model.to_str_tokens(prompt)

layer = 10
hook_name = f"blocks.{layer}.hook_mlp_in"
hook_name1 = f"blocks.{layer}.hook_resid_pre"

_, cache = model.run_with_cache(
    tokens,
    names_filter=[hook_name, hook_name1],
    remove_batch_dim=True
)

print(f"Cache keys: {list(cache.keys())}")
resid_pre1 = cache[hook_name1]
resid_pre = cache[hook_name] 

_, sae_cache_pre = sae.run_with_cache(resid_pre1)
sae_out  = sae_cache_pre["hook_sae_output"]

errors = resid_pre1 - sae_out
print(torch.norm(resid_pre1[1,:]))
print(errors.shape)
print("norm", torch.norm(errors, dim = -1))


acts_post = sae_cache_pre["hook_sae_acts_post"] # [n_tokens, d_sae]
top_vals, top_idx = torch.topk(acts_post, k=10, dim=-1)

print(top_idx)


Cache keys: ['blocks.10.hook_resid_pre', 'blocks.10.hook_mlp_in']
tensor(150.1659, device='cuda:0')
torch.Size([21, 3072])
norm tensor([189.4258, 186.5565, 184.8637, 183.1293, 183.7237, 181.9275, 182.0390,
        180.5799, 179.9156, 180.4675, 179.8455, 179.7326, 179.7034, 180.7181,
        180.1364, 179.9122, 179.6683, 178.5136, 178.7495, 177.5304, 177.3266],
       device='cuda:0')
tensor([[14366,  1447, 20511, 23265, 28850, 11604,  3740,  4443, 14864,  4111],
        [25326, 11559, 15454, 12909,  4624, 29818, 16946, 10761, 24041, 16843],
        [26037, 15454,     0,     1,     6,     7,     5,     4,     2,     3],
        [26037,     1,     3,     6,     7,     0,     2,     5,     8,     4],
        [    7,     6,     4,     5,     1,     0,     2,     3,     8,     9],
        [    7,     6,     4,     5,     1,     0,     2,     3,     8,     9],
        [    7,     6,     4,     5,     1,     0,     2,     3,     8,     9],
        [    7,     6,     4,     5,     1,     0,   

In [6]:
_, cache = model.run_with_cache(
    tokens,
    remove_batch_dim=True
)
cache.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.hook_mlp_in', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hoo