In [1]:
# Enable IPython's autoreload extension for this kernel.
%load_ext autoreload
# Mode 2: re-import all loaded modules before each cell runs, so edits to
# imported source files are picked up without restarting the kernel.
%autoreload 2

In [2]:
import torch
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_saes import SparseAutoEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
model_name = "meta-llama/Llama-3.1-8B"

# Load the Hugging Face weights directly in bfloat16 -- the same dtype that
# HookedTransformer is asked for below. Without an explicit dtype,
# `from_pretrained` materialises the parameters in float32 first, which
# roughly doubles host memory use and load time for an 8B-parameter model
# only for the weights to be cast down again afterwards.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

# Fast tokenizer that prepends <bos> to every encoded sequence.
hf_tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_fast=True,
    add_bos_token=True,
)

# Wrap the already-loaded HF model in a HookedTransformer so intermediate
# activations can be cached by hook name. `.eval()` puts the module in
# inference mode.
model = HookedTransformer.from_pretrained_no_processing(
    model_name,
    device="cuda",
    hf_model=hf_model,
    tokenizer=hf_tokenizer,
    dtype=torch.bfloat16,
).eval()

Loading checkpoint shards: 100%|██████████| 4/4 [02:07<00:00, 31.87s/it]


Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer


In [7]:
# Llama Scope SAE checkpoint id; when no local path of this name exists it is
# fetched from the Hugging Face model hub.
sae_name_or_path = "fnlp/Llama3_1-8B-Base-L15R-8x"
sae = SparseAutoEncoder.from_pretrained(sae_name_or_path)

Local path `fnlp/Llama3_1-8B-Base-L15R-8x` not found. Downloading from huggingface model hub.
Downloading Llama Scope SAEs.


Fetching 3 files: 100%|██████████| 3/3 [00:36<00:00, 12.26s/it]


Local path `fnlp/Llama3_1-8B-Base-L15R-8x` not found. Downloading from huggingface model hub.
Downloading Llama Scope SAEs.


Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 13443.28it/s]


In [9]:
text = "The quick brown fox jumps over the lazy dog and then sprints through the forest while the wind howls through the tall trees, shaking the branches as birds scatter into the sky."

tokens = model.to_tokens(text)

# Inference only: run the forward pass with autograd disabled so no
# computation graph is built and the cached activations do not hold on to one.
with torch.no_grad():
    _, cache = model.run_with_cache(tokens)

# Show which hook names were cached.
cache

ActivationCache with keys ['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_gate', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', '

In [10]:
# L0 sparsity: number of active (> 0) SAE features at each token position.
# The first token is <bos>, which is extremely out-of-distribution.
resid_post = cache["blocks.15.hook_resid_post"]
loss_outputs = sae.compute_loss(resid_post)
feature_acts = loss_outputs[1][1]["feature_acts"]
(feature_acts > 0).sum(-1)

tensor([[17303,    12,    29,    30,    54,    42,    47,    37,    57,    61,
            39,    44,    29,    82,    50,    40,    64,    49,    50,    48,
            37,    63,    44,    40,    59,    49,    34,    40,    48,    62,
            50,    51,    68,    49,    46,    53,    42]], device='cuda:0')

In [11]:
# Reconstruction loss: mean squared error between the SAE reconstruction and
# the layer-15 residual stream, with the first (<bos>) position dropped.
resid_post_no_bos = cache["blocks.15.hook_resid_post"][:, 1:]
reconstructed = sae.compute_loss(resid_post_no_bos)[1][1]["reconstructed"]
(reconstructed - resid_post_no_bos).pow(2).mean()

tensor(0.0080, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)

In [19]:
# Display the configuration object of the loaded SAE.
sae.cfg

SAEConfig(device='cuda:0', seed=42, dtype=torch.bfloat16, hook_point_in='blocks.15.hook_resid_post', hook_point_out='blocks.15.hook_resid_post', sae_pretrained_name_or_path='/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/hezhengfu-240208120186/projects/llamascope_ckpts/Llama3_1Base-LXR-8x-topk/Llama3_1Base-L15R-8x', strict_loading=True, use_decoder_bias=True, apply_decoder_bias_to_pre_encoder=False, expansion_factor=8, d_model=4096, d_sae=32768, bias_init_method='all_zero', act_fn='jumprelu', jump_relu_threshold=0.35546875, norm_activation='inference', dataset_average_activation_norm={'in': 10.8125, 'out': 10.8125}, decoder_exactly_fixed_norm=False, sparsity_include_decoder_norm=True, use_glu_encoder=False, init_decoder_norm=0.5, init_encoder_norm=None, init_encoder_with_decoder_transpose=True, lp=1, l1_coefficient=8e-05, l1_coefficient_warmup_steps=39062, top_k=50, k_warmup_steps=39062, use_batch_norm_mse=True, use_ghost_grads=False, tp_size=1, ddp_size=1)