In [1]:
from typing import Tuple

import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

model_name = "ExplosionNuclear/Llama-2.3-3B-Instruct-special"  # замените при необходимости
device = "cuda" if torch.cuda.is_available() else "cpu"

def get_model(model_name: str, kwargs=dict(torch_dtype=torch.float32)) -> HookedTransformer:
    '''
    Loads a model from transformer lens or HuggingFace
    '''
    
    if model_name not in OFFICIAL_MODEL_NAMES:
        return get_custom_hf_model(model_name, kwargs)
    
    return HookedTransformer.from_pretrained(model_name)

def get_custom_hf_model(model_name: str, kwargs: dict = {}) -> HookedTransformer:
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        **kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
    )
    
    hf_config = hf_model.config
    
    # Создаем конфигурацию для TransformerLens
    # Ограничиваем размер контекста для экономии памяти
    max_ctx = min(hf_config.max_position_embeddings, 2048)
    
    cfg = HookedTransformerConfig(
        n_layers=hf_config.num_hidden_layers,
        d_model=hf_config.hidden_size,
        d_head=hf_config.hidden_size // hf_config.num_attention_heads,
        n_heads=hf_config.num_attention_heads,
        d_mlp=hf_config.intermediate_size,
        d_vocab=hf_config.vocab_size,
        n_ctx=max_ctx,  # Ограничиваем размер контекста
        act_fn=hf_config.hidden_act,  # Llama использует SiLU
        model_name=model_name,
        normalization_type="RMS",  # Llama использует RMSNorm
        device="cpu"
    )
    
    model = HookedTransformer(cfg)
    
    model.load_state_dict(hf_model.state_dict(), strict=False)
    model.set_tokenizer(tokenizer)
    
    return model 


In [2]:
# 1) Грузим модель, чтобы автоматически определить d_in
d_type = torch.float32
model = get_model(model_name, dict(torch_dtype=d_type))
d_in = int(model.cfg.d_model)


# 2) Конфиг тренировки SAE
cfg = LanguageModelSAERunnerConfig(
    # Модель и датасет (лучше нетокенизированный, чтобы токенизировать LLAMA-токенайзером)
    model = model,
    model_name=model_name,
    hook_point="blocks.{layer}.hook_resid_pre",
    hook_point_layer=[5],  # выберите слой; можно начать с середины
    d_in=d_in,
    dataset_path="ashaba1in/small_openwebtext",
    is_dataset_tokenized=False,
    context_size=128,

    # SAE
    expansion_factor=1,  # d_sae = d_in * expansion_factor
    b_dec_init_method="mean",

    # Тренировка
    lr=3e-4,
    l1_coefficient=1e-3,
    lr_scheduler_name="constantwithwarmup",
    lr_warm_up_steps=1000,
    train_batch_size=10,         # уменьшите, если не хватает памяти
    n_batches_in_buffer=64,
    total_training_tokens=2_000,  # увеличьте для серьёзного запуска
    store_batch_size=16,           # уменьшите/увеличьте под вашу GPU

    wandb_project="mats_sae_training_llama32",
    wandb_log_frequency=10,
    wandb_api_key="a89e0ceef33f3c2cc4b7d9d9d5795fa238b4a60c",
    wandb_entity="rokser9-lucid-layers",
    eval_every_n_steps=100_000,

    logger_backend="clearml",
    
    n_checkpoints=2,
    checkpoint_path="checkpoints",

    push_to_hub=True,
    hub_repo_id="Lucid-Layers-Inc/llama23-sae-resid_pre",
    hub_private=False,

    # Прочее
    device=device,
    seed=42,
    dtype=d_type,
)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Run name: 3072-L1-0.001-LR-0.0003-Tokens-2.000e+03
n_tokens_per_buffer (millions): 0.131072
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 200
Total wandb updates: 20
n_tokens_per_feature_sampling_window (millions): 2.56
n_tokens_per_dead_feature_window (millions): 1.28
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 2.00e+04


In [None]:
from dotenv import load_dotenv


load_dotenv()
_ = language_model_sae_runner(cfg)
print("Training finished")

ClearML Task: created new task id=0b505cfada784de6a68cbd73221d7efc
2025-08-27 13:43:41,331 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/396a5b3095514240954c229f8cf08618/experiments/0b505cfada784de6a68cbd73221d7efc/output/log


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Moving model to device:  cuda
Run name: 3072-L1-0.001-LR-0.0003-Tokens-2.000e+03
n_tokens_per_buffer (millions): 0.131072
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 200
Total wandb updates: 20
n_tokens_per_feature_sampling_window (millions): 2.56
n_tokens_per_dead_feature_window (millions): 1.28
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 2.00e+04
Run name: 3072-L1-0.001-LR-0.0003-Tokens-2.000e+03
n_tokens_per_buffer (millions): 0.131072
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 200
Total wandb updates: 20
n_tokens_per_feature_sampling_window (millions): 2.56
n_tokens_per_dead_feature_window (millions): 1.28
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 2.00e+04
Reinitializing b_dec with mean of activations
Previous distances: 93.31073760986328
New distances: 64.46891784667969


100| MSE Loss 0.035 | L1 1.919:  50%|███████████████████████████████████████████████                                                | 990/2000 [00:00<00:00, 1537.26it/s]

Saved model to checkpoints/dr7rq6rv/1010_sae_group_ExplosionNuclear/Llama-2.3-3B-Instruct-special_blocks.5.hook_resid_pre_3072.safetensors


Llama-2.3-3B-Instruct-special_blocks.5.hook_resid_pre_3072.safetensors:   0%|          | 0.00/75.5M [00:00<?, …

100| MSE Loss 0.035 | L1 1.919:  50%|███████████████████████████████████████████████                                               | 1000/2000 [00:12<00:00, 1537.26it/s]

Llama-2.3-3B-Instruct-special_blocks.5.hook_resid_pre_3072_log_feature_sparsity.pt:   0%|          | 0.00/14.1…

200| MSE Loss 0.032 | L1 1.787: 100%|██████████████████████████████████████████████████████████████████████████████████████████████▌| 1990/2000 [00:18<00:00, 204.65it/s]

Saved model to checkpoints/dr7rq6rv/final_sae_group_ExplosionNuclear/Llama-2.3-3B-Instruct-special_blocks.5.hook_resid_pre_3072.safetensors


Llama-2.3-3B-Instruct-special_blocks.5.hook_resid_pre_3072.safetensors:   0%|          | 0.00/75.5M [00:00<?, …

Llama-2.3-3B-Instruct-special_blocks.5.hook_resid_pre_3072_log_feature_sparsity.pt:   0%|          | 0.00/14.1…

200| MSE Loss 0.032 | L1 1.787: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:30<00:00, 65.52it/s]


In [3]:
import dataclasses
from sae_training.utils import LMSparseAutoencoderSessionloader


path = "checkpoints/dr7rq6rv/1010_sae_group_ExplosionNuclear/Llama-2.3-3B-Instruct-special_blocks.5.hook_resid_pre_3072.safetensors"
model, sparse_autoencoders, activation_store = (
    LMSparseAutoencoderSessionloader.load_session_from_pretrained(path, dataclasses.asdict(cfg))
)

Moving model to device:  cuda


In [4]:
sparse_autoencoder = sparse_autoencoders.autoencoders[0]

In [5]:
sparse_autoencoder.eval()
model.eval() # prevents error if we're expecting a dead neuron mask for who grads
with torch.no_grad():
    batch_tokens = activation_store.get_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 MiB. GPU 0 has a total capacity of 23.53 GiB of which 20.69 MiB is free. Process 73563 has 23.50 GiB memory in use. Of the allocated memory 22.55 GiB is allocated by PyTorch, and 517.84 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)