In [1]:
from typing import Tuple

import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# HuggingFace hub id of the model to train the SAE on; change as needed.
model_name = (
    "ExplosionNuclear/Llama-2.3-3B-Instruct-special"
)

# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def get_model(model_name: str) -> HookedTransformer:
    '''
    Load ``model_name`` as a HookedTransformer.

    Names TransformerLens knows -- an official name or one of its aliases,
    compared case-insensitively -- are loaded with
    ``HookedTransformer.from_pretrained``. Anything else is treated as a
    HuggingFace hub id / local path and converted by ``get_custom_hf_model``.

    Args:
        model_name: TransformerLens model name/alias, or a HuggingFace model id.

    Returns:
        The loaded HookedTransformer.
    '''
    # Local import: MODEL_ALIASES maps each official name to its short aliases
    # (e.g. "gpt2-small"), which from_pretrained also accepts. Checking only
    # OFFICIAL_MODEL_NAMES would send such aliases down the HF path and fail.
    from transformer_lens.loading_from_pretrained import MODEL_ALIASES

    known_names = {name.lower() for name in OFFICIAL_MODEL_NAMES}
    known_names.update(
        alias.lower() for aliases in MODEL_ALIASES.values() for alias in aliases
    )

    if model_name.lower() not in known_names:
        return get_custom_hf_model(model_name)

    return HookedTransformer.from_pretrained(model_name)

def get_custom_hf_model(model_name: str) -> HookedTransformer:
    '''
    Build a HookedTransformer from a Llama-architecture HuggingFace checkpoint
    that TransformerLens does not know by name.

    HF parameter names (``model.layers.0.self_attn.q_proj.weight`` ...) do not
    match TransformerLens' (``blocks.0.attn.W_Q`` ...), so the HF weights are
    first converted with TransformerLens' own Llama converter. Loading the raw
    HF ``state_dict()`` with ``strict=False`` would match no keys and silently
    leave the model randomly initialised.

    Args:
        model_name: HuggingFace hub id or local path of a Llama-style causal LM.

    Returns:
        A float32 HookedTransformer on CPU with the converted weights and the
        checkpoint's tokenizer attached.

    Raises:
        RuntimeError: if the converted weights do not cover every model
            parameter, or contain keys the model does not have.
    '''
    # Local import: the Llama weight converter is exposed from this module.
    from transformer_lens.loading_from_pretrained import convert_llama_weights

    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_config = hf_model.config

    n_heads = hf_config.num_attention_heads
    # Newer Llama configs carry an explicit head_dim; otherwise use the usual split.
    d_head = getattr(hf_config, "head_dim", None) or hf_config.hidden_size // n_heads

    # Cap the context size to save memory.
    max_ctx = min(hf_config.max_position_embeddings, 2048)

    # NOTE(review): Llama-3.x checkpoints also use "llama3" RoPE scaling
    # (hf_config.rope_scaling), which this config does not reproduce. If the
    # checkpoint is a fine-tune of a model TransformerLens supports, prefer
    # HookedTransformer.from_pretrained(<base name>, hf_model=hf_model,
    # tokenizer=tokenizer) so TransformerLens supplies the exact architecture.
    cfg = HookedTransformerConfig(
        n_layers=hf_config.num_hidden_layers,
        d_model=hf_config.hidden_size,
        d_head=d_head,
        n_heads=n_heads,
        # Grouped-query attention: K/V may use fewer heads than Q.
        # None (attribute absent) means ordinary multi-head attention.
        n_key_value_heads=getattr(hf_config, "num_key_value_heads", None),
        d_mlp=hf_config.intermediate_size,
        d_vocab=hf_config.vocab_size,
        n_ctx=max_ctx,
        act_fn=hf_config.hidden_act,  # Llama uses SiLU
        model_name=model_name,
        normalization_type="RMS",  # Llama uses RMSNorm
        eps=hf_config.rms_norm_eps,
        positional_embedding_type="rotary",
        rotary_dim=d_head,
        rotary_base=getattr(hf_config, "rope_theta", 10000),
        gated_mlp=True,  # gate_proj / up_proj / down_proj MLP
        device="cpu",
    )

    # Translate HF parameter names and layouts into TransformerLens' layout.
    state_dict = convert_llama_weights(hf_model, cfg)

    model = HookedTransformer(cfg)

    # Fail loudly rather than silently train an SAE on random weights: every
    # learnable parameter must be provided, and nothing unknown may be passed.
    # Non-parameter entries (buffers) keep the model's own initialisation.
    provided = set(state_dict)
    missing = sorted({name for name, _ in model.named_parameters()} - provided)
    unexpected = sorted(provided - set(model.state_dict()))
    if missing or unexpected:
        raise RuntimeError(
            f"Could not map {model_name!r} onto a HookedTransformer: "
            f"missing={missing[:5]} unexpected={unexpected[:5]}"
        )

    model.load_state_dict(state_dict, strict=False)
    model.set_tokenizer(tokenizer)

    return model


In [3]:
# 1) Load credentials and the model; d_in is detected from the model itself
import os

from dotenv import load_dotenv

# Secrets come from the environment or a local .env file, never from the
# notebook. The W&B key and HF token previously hard-coded here were saved in
# plain text and should be treated as compromised: revoke and rotate them.
# If a variable is unset, None is passed to the config -- TODO confirm the
# runner then falls back to cached wandb / huggingface logins.
load_dotenv()
wandb_key = os.environ.get("WANDB_API_KEY")
hf_token = os.environ.get("HF_TOKEN")

model = get_model(model_name)
d_in = int(model.cfg.d_model)  # residual-stream width seen at hook_resid_pre

# 2) SAE training config
cfg = LanguageModelSAERunnerConfig(
    # Model and dataset (prefer an untokenized dataset so it gets tokenized
    # with the Llama tokenizer)
    model=model,
    model_name=model_name,
    hook_point="blocks.{layer}.hook_resid_pre",
    hook_point_layer=[5],  # pick a layer; a middle layer is a reasonable start
    d_in=d_in,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    context_size=128,

    # SAE
    expansion_factor=1,  # d_sae = d_in * expansion_factor
    b_dec_init_method="mean",

    # Training
    lr=3e-4,
    l1_coefficient=1e-3,
    lr_scheduler_name="constantwithwarmup",
    lr_warm_up_steps=1000,
    train_batch_size=10,  # reduce if you run out of memory
    n_batches_in_buffer=64,
    total_training_tokens=20_000,  # increase for a serious run
    store_batch_size=16,  # tune for your GPU

    wandb_project="mats_sae_training_llama32",
    wandb_log_frequency=10,
    wandb_api_key=wandb_key,
    wandb_entity="rokser9-lucid-layers",
    eval_every_n_steps=100_000,

    logger_backend="clearml",

    n_checkpoints=2,
    checkpoint_path="checkpoints",

    push_to_hub=True,
    hub_repo_id="ExplosionNuclear/llama23-sae-resid_pre",
    hub_private=False,
    hub_token=hf_token,

    # Misc
    device=device,
    seed=42,
    dtype=torch.float32,
)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Run name: 3072-L1-0.001-LR-0.0003-Tokens-2.000e+04
n_tokens_per_buffer (millions): 0.131072
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 2000
Total wandb updates: 200
n_tokens_per_feature_sampling_window (millions): 2.56
n_tokens_per_dead_feature_window (millions): 1.28
We will reset the sparsity calculation 1 times.
Number tokens in sparsity calculation window: 2.00e+04


In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into the environment before training
# (presumably credentials for the ClearML logger selected in the config --
# TODO confirm which variables the runner reads).
load_dotenv()

# Keep the runner's result under a real name instead of `_`, which IPython
# uses for its "last output" history and overwrites on later cells.
trained_sae = language_model_sae_runner(cfg)
print("Training finished")

ClearML Task: created new task id=b2043eb17a3a49b58f0c252476cb981a
2025-08-26 15:53:15,568 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/396a5b3095514240954c229f8cf08618/experiments/b2043eb17a3a49b58f0c252476cb981a/output/log


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Moving model to device:  cuda
