In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import numpy as np
import logging
from src import models
from src.utils import env_utils

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

  from .autonotebook import tqdm as notebook_tqdm


('2.4.1', '4.44.2', '12.1')

In [3]:
from datasets import load_dataset

wiki = load_dataset(
    os.path.join(
        env_utils.DEFAULT_DATA_DIR, "hf_datasets",
        "wikimedia/wikipedia", "20231101.en"
    )
)

2024-10-14 18:33:07 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-10-14 18:33:07 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-10-14 18:33:08 datasets INFO     PyTorch version 2.4.1 available.


In [4]:
# wiki.save_to_disk(
#     os.path.join(
#         env_utils.DEFAULT_DATA_DIR,
#         "wikimedia/wikipedia", "20231101.en"
#     )
# )

In [5]:
print(len(wiki["train"]))
wiki["train"][:5]

6407814


{'id': ['12', '39', '290', '303', '305'],
 'url': ['https://en.wikipedia.org/wiki/Anarchism',
  'https://en.wikipedia.org/wiki/Albedo',
  'https://en.wikipedia.org/wiki/A',
  'https://en.wikipedia.org/wiki/Alabama',
  'https://en.wikipedia.org/wiki/Achilles'],
 'title': ['Anarchism', 'Albedo', 'A', 'Alabama', 'Achilles'],
 'text': ['Anarchism is a political philosophy and movement that is skeptical of all justifications for authority and seeks to abolish the institutions it claims maintain unnecessary coercion and hierarchy, typically including nation-states, and capitalism. Anarchism advocates for the replacement of the state with stateless societies and voluntary free associations. As a historically left-wing movement, this reading of anarchism is placed on the farthest left of the political spectrum, usually described as the libertarian wing of the socialist movement (libertarian socialism).\n\nHumans have lived in societies without formal hierarchies long before the establishment o

In [6]:
# print(len(wiki_loaded["train"]))
# wiki_loaded["train"][:5]

In [7]:
# wiki_loaded = load_dataset(
#     os.path.join(
#         env_utils.DEFAULT_DATA_DIR,
#         "wikimedia/wikipedia", "20231101.en"
#     )
# )

In [8]:
# tiny = load_dataset("roneneldan/TinyStories")

# print(len(tiny["train"]))
# tiny["train"][5055]

In [9]:
from src.models import ModelandTokenizer

# model_name = "openai-community/gpt2-xl"
model_name = "openai-community/gpt2"
# model_name = "EleutherAI/pythia-410m"
# model_name = "google/gemma-2-2b"
# model_name = "meta-llama/Llama-3.2-1B"

mt = ModelandTokenizer(
    model_key=model_name,
    torch_dtype=torch.float32,
)

mt.n_embd, mt.n_layer

If not found in cache, model will be downloaded from HuggingFace to cache directory




2024-10-14 18:33:18 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
2024-10-14 18:33:19 src.models INFO     loaded model <openai-community/gpt2> | size: 486.700 MB | dtype: torch.float32 | device: cuda:0


(768, 12)

In [10]:
from transformer_lens import HookedTransformer

tl_model = HookedTransformer.from_pretrained(
    model_name="gpt2",
    hf_model=mt._model,
    tokenizer=mt.tokenizer,
    device=mt.device,
    default_prepend_bos=False,
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
    refactor_factored_attn_matrices=False
) 



Loaded pretrained model gpt2 into HookedTransformer


In [11]:
from src.functional import prepare_input
prompt = "A quick brown fox jumps over the lazy"
tokens = prepare_input(prompt, tokenizer=mt)

_logits, cache = tl_model.run_with_cache(prompt)
cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [12]:
type(cache)

transformer_lens.ActivationCache.ActivationCache

In [13]:
print(cache["blocks.6.hook_resid_post"].shape)
cache["blocks.6.hook_resid_post"]

torch.Size([1, 8, 768])


tensor([[[-0.9570, -2.0778,  1.0478,  ..., -1.1265, -0.6831,  0.4655],
         [ 3.2263, -1.5937, -5.3277,  ..., -1.0556,  1.0042, -2.1204],
         [ 3.9958,  3.3184, -1.5888,  ...,  0.6263,  0.6513, -2.4669],
         ...,
         [ 0.1355,  0.1361, -6.1814,  ...,  0.4304, -0.5404,  0.3766],
         [-0.4256,  1.2168, -4.6042,  ...,  1.8664,  2.0063, -0.9233],
         [ 3.9466, -1.1716, -4.6144,  ...,  0.5877,  1.3280, -1.0215]]],
       device='cuda:0')

In [14]:
tokens.input_ids.shape

torch.Size([1, 8])

In [15]:
from src.functional import get_module_nnsight

with mt.trace(tokens) as tr:
    module = get_module_nnsight(mt, mt.layer_name_format.format(6))
    resid_out = module.output[0].save()

print(resid_out.shape)
resid_out

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([1, 8, 768])


tensor([[[-0.9570, -2.0778,  1.0478,  ..., -1.1265, -0.6831,  0.4655],
         [ 3.2263, -1.5937, -5.3277,  ..., -1.0556,  1.0042, -2.1204],
         [ 3.9958,  3.3184, -1.5888,  ...,  0.6263,  0.6513, -2.4669],
         ...,
         [ 0.1355,  0.1361, -6.1814,  ...,  0.4304, -0.5404,  0.3766],
         [-0.4256,  1.2168, -4.6042,  ...,  1.8664,  2.0063, -0.9233],
         [ 3.9466, -1.1716, -4.6144,  ...,  0.5877,  1.3280, -1.0215]]],
       device='cuda:0')

In [16]:
torch.allclose(resid_out, cache["blocks.6.hook_resid_post"], atol=1e-5)

True

In [17]:
mt.n_embd

768

In [20]:
from sae_lens import CacheActivationsRunnerConfig, CacheActivationsRunner

cfg = CacheActivationsRunnerConfig(
    model_name="gpt2",
    # hook
    hook_name="blocks.6.hook_resid_post",
    hook_layer=6,
    d_in=mt.n_embd,

    # dataset
    dataset_path=os.path.join(
        env_utils.DEFAULT_DATA_DIR, "hf_datasets",
        "wikimedia/wikipedia", "20231101.en"
    ),
    new_cached_activations_path=os.path.join(
        env_utils.DEFAULT_DATA_DIR, "cached_activations",
        "wikimedia/wikipedia", "20231101.en",
    ),
    is_dataset_tokenized=False,

    # activation store params
    training_tokens=1024,
    n_batches_in_buffer=1,
    store_batch_size_prompts=1,
    context_size=1024,

    # activation cachching stuff


    # MISCELLANEOUS
    device="cuda",
    seed=42,
    prepend_bos=False,

    # # model params
    # model_kwargs=dict(
    #     # default_prepend_bos=False,
    #     fold_ln=False,
    #     center_unembed=False,
    #     center_writing_weights=False,
    #     refactor_factored_attn_matrices=False
    # )
)

CacheActivationsRunner(cfg).run()

Loaded pretrained model gpt2 into HookedTransformer
Started caching 1024 activations
--------------------------------------------------------------------------------
n_buffers=1 | tokens_per_buffer=1024
--------------------------------------------------------------------------------


Caching activations:   0%|          | 0/1 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (8668 > 1024). Running this sequence through the model will result in indexing errors
Caching activations: 100%|██████████| 1/1 [00:00<00:00, 20.76it/s]


In [21]:
from safetensors import safe_open

with safe_open(
    os.path.join(
        env_utils.DEFAULT_DATA_DIR, "cached_activations",
        "wikimedia/wikipedia", "20231101.en",
        "0.safetensors"
    ),
    framework="pt",
    device="cuda"
) as f:
    activations = {}
    for k in f.keys():
        activations[k] = f.get_tensor(k)

In [22]:
activations["activations"].shape

torch.Size([1024, 1, 768])

In [41]:
activations["activations"]

tensor([[[ 2.9304,  0.7253,  1.4060,  ..., -1.3532, -1.3724, -0.7944]],

        [[-0.4555,  0.8672,  0.5951,  ...,  2.0086,  2.6229, -0.8465]],

        [[ 1.0792, -2.2333,  1.9231,  ...,  0.3579, -1.7863, -0.1422]],

        ...,

        [[-3.4977, -4.3733, -1.9184,  ..., -0.8113, -0.6607, -1.8814]],

        [[ 1.1604,  1.4785,  5.4576,  ...,  1.2939, -1.6880, -0.8128]],

        [[ 0.6926, -2.5608,  0.3468,  ..., -0.1245, -0.2460,  1.4254]]],
       device='cuda:0')

In [23]:
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
from src.utils import env_utils

total_training_steps = 300  # probably we should do more
batch_size = 128
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name="gpt2",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=6,  # Only one layer in the model.
    d_in=mt.n_embd,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=8,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would effect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would effect resampling or ghost grads if we were using it.
    # WANDB
    # log_to_wandb=True,  # always use wandb unless you are just testing code.
    # wandb_project="sae_lens_tutorial",
    # wandb_log_frequency=30,
    # eval_every_n_wandb_logs=20,
    # Misc
    device="cuda",  # we'll use a GPU for this.
    seed=42,
    n_checkpoints=5,
    checkpoint_path=os.path.join(env_utils.DEFAULT_RESULTS_DIR, "SAE_trainig", "checkpoints"),
    dtype="float32",
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Run name: 6144-L1-5-LR-5e-05-Tokens-3.840e+04
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 300
Total wandb updates: 30
n_tokens_per_feature_sampling_window (millions): 65.536
n_tokens_per_dead_feature_window (millions): 65.536
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 1.28e+05
2024-10-14 11:45:48 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /gpt2/resolve/main/config.json HTTP/11" 200 0
2024-10-14 11:45:48 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /gpt2/resolve/main/config.json HTTP/11" 200 0
2024-10-14 11:45:48 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /gpt2/resolve/main/generation_config.json HTTP/11" 200 0
2024-10-14 11:45:48 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /gpt2/resolve/main/tokenizer_config.json HTTP/11" 200 0
Loaded pretrained model gpt2 into HookedTransformer

2024-10-14 11:45:53 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): o151352.ingest.sentry.io:443
2024-10-14 11:45:53 urllib3.connectionpool DEBUG    https://o151352.ingest.sentry.io:443 "POST /api/4504800232407040/envelope/ HTTP/11" 200 0


2024-10-14 11:45:53 git.cmd DEBUG    Popen(['git', 'cat-file', '--batch-check'], cwd=/home/local_arnab/Codes/Projects/sae, stdin=<valid stream>, shell=False, universal_newlines=False)


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)

[A

2024-10-14 11:46:01 fsspec DEBUG    <File-like object HfFileSystem, datasets/apollo-research/roneneldan-TinyStories-tokenizer-gpt2@bc8db71bbc792977b43d430bddeeb9906e193f8d/data/train-00000-of-00004.parquet> read: 197020944 - 197086480
2024-10-14 11:46:01 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /datasets/apollo-research/roneneldan-TinyStories-tokenizer-gpt2/resolve/bc8db71bbc792977b43d430bddeeb9906e193f8d/data/train-00000-of-00004.parquet HTTP/11" 302 1126
2024-10-14 11:46:01 urllib3.connectionpool DEBUG    https://cdn-lfs-us-1.hf.co:443 "GET /repos/d9/82/d9825e13c625bba2429d3bbaccad722573c35efb72a6dafb7c1360f56867fc87/8015fda36edcaad836761cffb6c29493d058cffe76a1e2f9544a4f8f31bfa999?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27train-00000-of-00004.parquet%3B+filename%3D%22train-00000-of-00004.parquet%22%3B&Expires=1729179961&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyOTE3OTk2MX19LCJSZXNvdXJjZSI6Imh0


[A
[A
Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:08<00:00, 116.36it/s]


2024-10-14 11:46:10 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:10 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42
2024-10-14 11:46:10 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:10 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42




2024-10-14 11:46:11 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:11 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42
2024-10-14 11:46:11 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:11 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42
2024-10-14 11:46:11 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:11 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42
2024-10-14 11:46:11 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:11 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42




2024-10-14 11:46:11 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:12 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42
2024-10-14 11:46:12 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:12 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42




2024-10-14 11:46:12 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:12 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42
2024-10-14 11:46:12 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): api.wandb.ai:443
2024-10-14 11:46:12 urllib3.connectionpool DEBUG    https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 42


300| MSE Loss 364.449 | L1 95.872: 100%|██████████| 38400/38400 [00:11<00:00, 3467.04it/s]


0,1
details/current_l1_coefficient,▁█████████████████████████████
details/current_learning_rate,████████████████████████▇▆▅▃▂▁
details/n_training_tokens,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
losses/auxiliary_reconstruction_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,▂▁▁▁▁▁▁▁▁▂▂▃▃▃▄▄▄▅▅▆▆▆▇▇▇█████
losses/mse_loss,█▅▆▇▆▆▇▆▆▅▅▅▅▆▄▄▄▃▃▄▃▂▂▂▂▁▁▁▁▁
losses/overall_loss,█▅▅▇▅▅▇▅▅▄▅▅▄▆▄▄▃▃▃▅▃▂▂▂▂▁▁▁▁▁
metrics/explained_variance,▁▂▂▂▂▂▂▂▃▂▂▃▂▃▄▄▄▅▅▆▆▆▆▇▇█████
metrics/explained_variance_std,█▆▇▆▆▆▅▇▄▆▆▄▇▄▄▃▃▃▃▃▂▂▂▁▁▁▁▁▁▁
metrics/l0,█▃▁▁▁▁▁▁▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▅▅▅▅▅

0,1
details/current_l1_coefficient,5.0
details/current_learning_rate,0.0
details/n_training_tokens,38400.0
losses/auxiliary_reconstruction_loss,0.0
losses/l1_loss,19.17431
losses/mse_loss,364.44943
losses/overall_loss,460.32098
metrics/explained_variance,0.2516
metrics/explained_variance_std,0.20366
metrics/l0,810.07812


2024-10-14 11:46:38 urllib3.connectionpool DEBUG    https://o151352.ingest.sentry.io:443 "POST /api/4504800232407040/envelope/ HTTP/11" 200 0
