In [1]:
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from transformer_lens import HookedTransformer

# from IPython import get_ipython
# ipython = get_ipython()
# ipython.run_line_magic("load_ext", "autoreload")
# ipython.run_line_magic("autoreload", "2")
from IPython.display import display, HTML
from natsort import natsorted
import os
import pickle
from pathlib import Path
import torch
from tqdm.auto import tqdm
import yaml

from sparsify.models.transformers import SAETransformer
from sparsify.log import logger
from sparsify.utils import filter_names, load_config
from sparsify.data import DataConfig
from sparsify.loader import load_tlens_model, load_pretrained_saes
from sparsify.scripts.train_tlens_saes.run_train_tlens_saes import Config
from sparsify.scripts.train_tlens_saes.run_train_tlens_saes import main as run_train
from sparsify.scripts.generate_dashboards import DashboardsConfig, PromptDashboardsConfig, generate_dashboards
from sparsify.utils import replace_pydantic_model
# All artefacts of this notebook are stored relative to the kernel's working directory.
current_dir = Path.cwd()
sae_save_dir = current_dir / "run_for_testing_feature_dashboards"
dashboard_data_save_dir = sae_save_dir / "feature_dashboard_data"
# Passed as `sae_position_names` in the training config; the loader matches it against
# the model's hook names to decide where the SAE is attached.
sae_position_name = "blocks.1.hook_resid_post"

In [2]:
# Train an SAE on tiny-stories-1M and save it to sae_save_dir.
# If a checkpoint is already present (directly in sae_save_dir or one level below it),
# reuse it instead of retraining.
# pathlib accepts "/" as the pattern separator on every platform, so no separate
# backslash pattern is needed for Windows. `any(...)` stops at the first match instead of
# materializing the whole directory listing.
if any(any(sae_save_dir.glob(pattern)) for pattern in ("*.pt", "*/*.pt")):
    print(f"SAEs already exist in sae_save_dir = {sae_save_dir}\nUsing those.")
else:
    # No checkpoint yet: train a small throwaway SAE, just good enough to drive the dashboards.
    # NOTE: this path is relative to the kernel's working directory.
    config_path = Path("../sparsify/scripts/train_tlens_saes/tinystories_1M.yaml")
    with open(config_path) as f:
        base_config = Config(**yaml.safe_load(stream=f))
    # Overrides applied on top of the YAML config: a short run (one save at the end),
    # saved into sae_save_dir, with no wandb project.
    update_dict = {
        "train": {
            "save_dir": sae_save_dir,
            "save_every_n_samples": 20000,
            "n_samples": 20000,
            "loss_configs": {
                "inp_to_out": {"coeff": 1.0},
                # presumably None switches this loss term off -- confirm against LossConfigs
                "logits_kl": None,
            },
        },
        "saes": {"sae_position_names": sae_position_name},
        "wandb_project": None,
    }
    new_config = replace_pydantic_model(base_config, update_dict)
    print(new_config)
    run_train(new_config)

SAEs already exist in sae_save_dir = /mnt/c/Users/nadro/Documents/AI_safety/MATS5/Sparsify/sparsify/notebooks/run_for_testing_feature_dashboards
Using those.


In [3]:
# Load the saved SAEs and the corresponding model
def load_SAETransformer_from_saes_path(
    saes_path: str | Path,
    config_path: str | Path | None = None,
    tlens_model: HookedTransformer | None = None,
) -> tuple[SAETransformer, Config, list[str]]:
    """Build an ``SAETransformer`` and load pretrained SAE weights into it.

    Args:
        saes_path: Either a ``.pt``/``.pth`` checkpoint file, or a directory to search.
            For a directory, checkpoints directly inside it are considered first and, if
            there are none, checkpoints one level below it. The last candidate in natural
            sort order is used. A path that is neither a checkpoint file nor a directory
            falls back to searching its parent directory.
        config_path: Path to the training config YAML. Defaults to ``config.yaml`` next to
            the selected checkpoint.
        tlens_model: An already-loaded TransformerLens model to wrap. When ``None``, the
            model named in the config is loaded via ``load_tlens_model``.

    Returns:
        ``(model, config, trainable_param_names)``: the ``SAETransformer`` moved to CUDA if
        available (CPU otherwise), the parsed ``Config``, and the value returned by
        ``load_pretrained_saes``.

    Raises:
        AssertionError: If no checkpoint can be found, the checkpoint does not exist, or
            the config file does not exist.
    """
    checkpoint_suffixes = (".pt", ".pth")
    saes_path = Path(saes_path)

    # Allow passing in a directory and finding the latest .pt or .pth file in it.
    if saes_path.suffix not in checkpoint_suffixes:
        search_dir = saes_path if saes_path.is_dir() else saes_path.parent
        candidates: list[Path] = []
        # Search the directory itself first, then one level down. pathlib accepts "/" as
        # the pattern separator on every platform, so no backslash variant is required.
        for prefix in ("", "*/"):
            candidates = natsorted(
                [
                    found
                    for suffix in checkpoint_suffixes
                    for found in search_dir.glob(f"{prefix}*{suffix}")
                ]
            )
            if candidates:
                break
        assert (
            len(candidates) > 0
        ), f"Could not find any .pt or .pth files in the saes_path: {search_dir}"
        # Natural sort puts e.g. "..._20000.pt" after "..._10000.pt"; take the last one.
        saes_path = candidates[-1]

    assert saes_path.exists(), f"saes_path does not exist: {saes_path}"
    config_path = saes_path.parent / "config.yaml" if config_path is None else Path(config_path)
    assert config_path.exists(), (
        "Could not find the config_path: config.yaml should be in the same folder as the "
        f"saes_path (looked for {config_path})"
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = load_config(config_path, config_model=Config)
    logger.info(config)

    if tlens_model is None:
        tlens_model = load_tlens_model(
            tlens_model_name=config.tlens_model_name, tlens_model_path=config.tlens_model_path
        )

    # Resolve the configured position names against the hooks the model actually has.
    raw_sae_position_names = filter_names(
        list(tlens_model.hook_dict.keys()), config.saes.sae_position_names
    )
    model = SAETransformer(
        config=config, tlens_model=tlens_model, raw_sae_position_names=raw_sae_position_names
    ).to(device=device)

    # The selected checkpoint goes first, followed by any paths listed in the config.
    extra_sae_paths = config.saes.pretrained_sae_paths
    pretrained_sae_paths = [saes_path] if extra_sae_paths is None else [saes_path, *extra_sae_paths]

    all_param_names = [name for name, _ in model.saes.named_parameters()]
    trainable_param_names = load_pretrained_saes(
        saes=model.saes,
        pretrained_sae_paths=pretrained_sae_paths,
        all_param_names=all_param_names,
        retrain_saes=config.saes.retrain_saes,
    )
    return model, config, trainable_param_names

print("Loading the model and SAEs")
# Bind the third return value to a real name rather than `_`: assigning to `_` in a
# notebook hides IPython's automatic "last output" variable for the rest of the session.
model, config, trainable_param_names = load_SAETransformer_from_saes_path(sae_save_dir)
print("done")

Loading the model and SAEs


2024-03-07 12:51:49 - INFO - seed=0 tlens_model_name='roneneldan/TinyStories-1M' tlens_model_path=None train=TrainConfig(save_dir=PosixPath('/mnt/c/Users/nadro/Documents/AI_safety/MATS5/Sparsify/sparsify/notebooks/run_for_testing_feature_dashboards'), save_every_n_samples=20000, n_samples=20000, batch_size=10, effective_batch_size=10, lr=0.001, warmup_samples=20000, cooldown_samples=0, max_grad_norm=1.0, log_every_n_grad_steps=20, collect_discrete_metrics_every_n_samples=10000, discrete_metrics_n_tokens=500000, collect_output_metrics_every_n_samples=0, loss_configs=LossConfigs(sparsity=SparsityLossConfig(coeff=5e-06, p_norm=0.6), inp_to_orig=None, out_to_orig=None, inp_to_out=InpToOutLossConfig(coeff=1.0), logits_kl=None)) data=DataConfig(dataset_name='apollo-research/sae-skeskinen-TinyStories-hf-tokenizer-gpt2', is_tokenized=True, tokenizer_name='gpt2', streaming=True, split='train', n_ctx=512, column_name='input_ids') saes=SparsifiersConfig(type_of_sparsifier='sae', dict_size_to_inpu

Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer
Moving model to device:  cuda
done


In [9]:
# Sanity check: display which TransformerLens model name the loaded config refers to.
config.tlens_model_name

'roneneldan/TinyStories-1M'

In [4]:
# Ask PyTorch to release its cached GPU memory, then show the current GPU usage.
torch.cuda.empty_cache()
# IPython shell escape: runs the external `nvidia-smi` tool, which must be on PATH.
! nvidia-smi

Thu Mar  7 12:51:52 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 511.04.01    Driver Version: 511.09       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   67C    P2    29W /  N/A |    552MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [5]:
# Dataset settings handed to the dashboard generation below (via DashboardsConfig.data).
data_config = DataConfig(
    dataset_name="apollo-research/sae-skeskinen-TinyStories-hf-tokenizer-gpt2",
    tokenizer_name="gpt2",
    split="train",
    n_ctx=512,
)

In [16]:
# Generate the dashboards
# Covers the first 50 feature indices, plus a prompt-centric dashboard for one
# hand-written prompt and no randomly selected prompts.
dashboards_config = DashboardsConfig(
    n_samples=200,
    batch_size=2,
    minibatch_size_features=100,
    data=data_config,
    feature_indices=[*range(50)],
    prompt_centric=PromptDashboardsConfig(
        prompts=["Sally met Mike at the show. She brought popcorn for him."],
        n_random_prompt_dashboards=0,
    ),
)
generate_dashboards(model, dashboards_config)

ValidationError: 1 validation error for DashboardsConfig
store_features_as_sparse
  Extra inputs are not permitted [type=extra_forbidden, input_value=True, input_type=bool]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden