# Training the SAE on Phi-2

## Preprocessing

In [1]:
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

from sae_lens import SAE, HookedSAETransformer, LanguageModelSAERunnerConfig, SAETrainingRunner
from transformers import AutoModelForCausalLM, AutoTokenizer

# Turn autograd off globally for everything executed after this cell.
torch.set_grad_enabled(False)

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


### Load Phi-2

In [None]:
# Load the Hugging Face checkpoint and its tokenizer from the same repo id.
hf_repo_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(hf_repo_id, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(hf_repo_id, trust_remote_code=True)

In [7]:
# Show the module tree and the model's device, one per line,
# followed by the tokenizer's BOS / EOS / PAD tokens on a single line.
print(model, model.device, sep="\n")
special_tokens = (tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)
print(*special_tokens)

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2560)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x PhiDecoderLayer(
        (self_attn): PhiSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (dense): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((256

In [None]:
# Load Phi-2 as a hooked model. The tokenizer must be passed by keyword:
# the second positional parameter of `from_pretrained` is `fold_ln`, so a
# positional tokenizer would be taken as that flag instead of being used
# as the model's tokenizer.
model_hooked = HookedSAETransformer.from_pretrained("microsoft/phi-2", tokenizer=tokenizer)

In [8]:
# Print the hooked model's module tree; the HookPoint entries are the
# names that can be used as `hook_name` when configuring the SAE.
print(model_hooked)

HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoi

### Test the tokenizer

In [17]:
# Two sentences that tokenize to different lengths, so the batch exercises padding.
prompt = ["I'm interested in mechanistic interpretability.",
          "Attention is all you need."]
# Padding needs a pad token; reuse the EOS token for it.
tokenizer.pad_token = tokenizer.eos_token
# Call the tokenizer directly: `batch_encode_plus` is deprecated in favour of `__call__`.
tokens = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
print(tokens)

{'input_ids': tensor([[   40,  1101,  4609,   287,  3962,  2569,  6179,  1799,    13],
        [ 8086,  1463,   318,   477,   345,   761,    13, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0]])}


### Load the dataset

In [18]:
import jsonlines
import random
import tqdm

# Input: one raw JSON-lines shard (the path points at a WanJuan 1.0 English WebText part).
data_path_en = "/home/liyy/datasets/wanjuan_1_0/OpenDataLab___WanJuan1_dot_0/raw/nlp/EN/WebText-en/part-000020-a894b46e.jsonl"
# Output: where the tokenized copy of the data is written by a later cell.
data_path = "/home/liyy/Interpretability/MATS/SAE/dataset_en/data.jsonl"

# Read every record of the shard into memory.
with jsonlines.open(data_path_en) as reader:
    prompt_list = [record for record in reader]
print("Read data done.")

Read data done.


In [None]:
# Tokenize every document in `prompt_list` and stream the encodings to
# `data_path` as JSON lines: {"id", "input_ids", "attention_mask"}.
TOKENIZE_BATCH_SIZE = 10000  # documents handed to the tokenizer per call
MAX_SEQ_LEN = 512  # documents are truncated to at most this many tokens

# Padding requires a pad token. Set it here as well so this cell does not
# depend on the tokenizer test cell having been run first.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def _write_tokenized_batch(writer, ids, texts):
    """Tokenize `texts` as one padded batch and write one record per document.

    Args:
        writer: open jsonlines writer that receives the records.
        ids: document ids, aligned with `texts`.
        texts: raw document strings.
    """
    # No `return_tensors`: the encodings are serialized as plain lists anyway,
    # so building tensors only to call `.tolist()` on them is wasted work.
    encoded = tokenizer(texts, padding=True, truncation=True, max_length=MAX_SEQ_LEN)
    for data_id, input_ids, attn_mask in zip(ids, encoded["input_ids"], encoded["attention_mask"]):
        writer.write({"id": data_id, "input_ids": input_ids, "attention_mask": attn_mask})


data_id_list = []
data_content_list = []
with jsonlines.open(data_path, "w") as f:
    for line in tqdm.tqdm(prompt_list):
        data_id_list.append(line["id"])
        data_content_list.append(line["content"])
        # Flush a full batch and start collecting the next one.
        if len(data_content_list) == TOKENIZE_BATCH_SIZE:
            _write_tokenized_batch(f, data_id_list, data_content_list)
            data_id_list = []
            data_content_list = []
    # Flush the final, partially filled batch.
    # NOTE(review): padding=True pads to the longest document *within a batch*,
    # so rows from different batches could in principle differ in length — confirm
    # this is acceptable for the downstream context_size.
    if data_content_list:
        _write_tokenized_batch(f, data_id_list, data_content_list)

Total number of samples: 3.5M  
Total number of tokens: 3.5M × 512 ≈ 1.8B (each sample is truncated to at most 512 tokens)

I use 1.7M of these samples to train the SAE.  

total_training_steps = 60k  

In [None]:
from datasets import Dataset, load_from_disk
import pandas as pd

# Directory that receives the on-disk `datasets` copy of the tokenized data.
save_path = "/home/liyy/Interpretability/MATS/SAE/dataset_en"

# JSON lines -> pandas DataFrame -> `Dataset`, then persist it to `save_path`.
df = pd.read_json(data_path, lines=True)
dataset = Dataset.from_pandas(df)
dataset.save_to_disk(save_path)

## Start to train the SAE

In [2]:
# Number of CUDA devices visible to this process; passed later as
# `n_devices` in `model_from_pretrained_kwargs`.
n_gpus = torch.cuda.device_count()
print(f"Number of GPUs: {n_gpus}")

Number of GPUs: 3


In [3]:

# Configure and launch SAE training on Phi-2 (hook: blocks.4.hook_resid_pre).
#
# NOTE(review): in the recorded output the run completes all 60000 steps and then
# raises `TypeError: Object of type device is not JSON serializable`; where that
# device object comes from is not visible in this notebook.
# NOTE(review): the first cell calls torch.set_grad_enabled(False); confirm that
# autograd is enabled for training when this cell runs on a fresh kernel.

# Training budget.
total_training_steps = 60000
batch_size = 4096
total_training_tokens = total_training_steps * batch_size  # 60000 * 4096 = 245,760,000, the total shown by the progress bar

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 12,000 steps = 20% of training
l1_warm_up_steps = total_training_steps // 20  # 3,000 steps = 5% of training

from datasets import load_from_disk
save_path = "/home/liyy/Interpretability/MATS/SAE/dataset_en"
dataset = load_from_disk(save_path)
# Print the first sample as a sanity check (fields: id, input_ids, attention_mask).
print(dataset[0])

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="microsoft/phi-2",
    model_class_name = "HookedTransformer",
    hook_name="blocks.4.hook_resid_pre",  # hook point to read activations from (see https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=4,  # block index used in hook_name above
    d_in=2560,  # width of the residual stream; Phi-2's hidden size is 2560 in the model printout
    streaming=True,  # NOTE(review): presumably only relevant when the runner loads the dataset itself; override_dataset is passed below — confirm
    # SAE Parameters
    mse_loss_normalization=None,  # no normalization of the MSE loss
    expansion_factor=32,  # SAE width as a multiple of d_in (the "32x" in run_name). Larger is slower to train.
    b_dec_init_method="zeros",  # start the decoder bias from zeros
    apply_b_dec_to_input=False,  # b_dec is not applied to the SAE input
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # learning rate (also recorded in run_name)
    adam_beta1=0.9,  # Adam betas
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant schedule; lr_warm_up_steps is 0 here, so there is no warm-up phase
    lr_warm_up_steps=lr_warm_up_steps,  # 0 = no LR warm-up
    lr_decay_steps=lr_decay_steps,  # 12,000 steps; presumably a decay phase at the end of training — confirm against the scheduler
    l1_coefficient=5,  # weight of the sparsity penalty; controls how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,  # 4096 per training step
    context_size=512,  # same value as max_length=512 used when the dataset was tokenized
    prepend_bos=None,  # NOTE(review): None rather than True/False — confirm how the runner treats it for this pre-tokenized dataset
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 245,760,000 (see total_training_tokens above)
    store_batch_size_prompts=8,  # presumably the number of prompts per model forward pass when filling the buffer — confirm
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="Phi-2-SAE",
    run_name="phi-2-sae-layer4-resid_pre-32x-lr_5e-5_constant",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    device_sae=1,  # the author's own modification ("my modification"); NOTE(review): presumably the GPU index that holds the SAE — confirm
    seed=14,
    n_checkpoints=1,
    checkpoint_path="/home/liyy/Interpretability/MATS/SAE/ckpts",
    dtype="float32",
    # other
    # model_from_pretrained_kwargs={"n_devices": n_gpus, "device_model": [0, 2]},
    model_from_pretrained_kwargs={"n_devices": n_gpus},  # n_gpus comes from the GPU-count cell; presumably spreads the model over all visible GPUs — confirm
)

# Run training on the pre-tokenized dataset loaded above; the value returned
# by run() is kept as `sparse_autoencoder`.
sparse_autoencoder = SAETrainingRunner(cfg,
                                       override_model=None,  # no pre-loaded model is passed in (original note: save memory)
                                       override_dataset=dataset).run()

Loading dataset from disk:   0%|          | 0/29 [00:00<?, ?it/s]



{'id': 'BkZg--3xK1TgopTHYJTw', 'input_ids': [7120, 10976, 2158, 1588, 393, 1402, 481, 7139, 772, 517, 13097, 290, 40840, 4493, 284, 307, 10588, 416, 262, 12624, 37, 13, 6914, 2174, 284, 1064, 503, 703, 284, 16565, 284, 262, 12624, 37, 393, 3904, 319, 262, 10655, 4936, 284, 1064, 503, 546, 4305, 257, 6979, 284, 262, 12624, 37, 287, 534, 2561, 13, 198, 1532, 345, 423, 587, 7867, 416, 644, 345, 423, 1775, 290, 1100, 319, 674, 3052, 345, 743, 765, 284, 2074, 4305, 257, 10655, 284, 262, 4564, 290, 8108, 7557, 287, 534, 481, 13, 2750, 14771, 287, 262, 7557, 345, 481, 307, 14771, 287, 7325, 1919, 23424, 290, 4365, 4213, 326, 481, 1037, 262, 4564, 10630, 4691, 511, 1957, 5348, 13, 29898, 4855, 416, 262, 12624, 37, 787, 262, 11557, 1254, 636, 286, 257, 2055, 11, 1037, 262, 4928, 284, 8209, 351, 1862, 661, 11, 2834, 661, 503, 286, 10681, 290, 6133, 262, 4564, 284, 307, 257, 262, 2612, 286, 2055, 1204, 13, 8013, 4075, 2055, 12352, 460, 287, 1210, 1085, 284, 8557, 290, 29052, 3349, 287, 1957, 1442



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model microsoft/phi-2 into HookedTransformer


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msiriuslala[0m. Use [1m`wandb login --relogin`[0m to force relogin


  yield torch.tensor(
Estimating norm scaling factor: 100%|██████████| 1000/1000 [05:39<00:00,  2.94it/s]
60000| MSE Loss 45894.898 | L1 889.109: 100%|██████████| 245760000/245760000 [11:43:25<00:00, 4671.25it/s]  

TypeError: Object of type device is not JSON serializable

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f1a3b128910>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f1bb862dd90, execution_count=3 error_before_exec=None error_in_exec=Object of type device is not JSON serializable info=<ExecutionInfo object at 7f1a7586e950, raw_cell="
total_training_steps = 60000
batch_size = 4096
to.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B51.cist.cc/home/liyy/Interpretability/MATS/SAE/train_phi2_sae.ipynb#X22sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

60000| MSE Loss 45894.898 | L1 889.109: 100%|██████████| 245760000/245760000 [11:43:35<00:00, 4671.25it/s]