In [2]:
!pip install natsort

Collecting natsort
  Using cached natsort-8.4.0-py3-none-any.whl.metadata (21 kB)
Using cached natsort-8.4.0-py3-none-any.whl (38 kB)
Installing collected packages: natsort
Successfully installed natsort-8.4.0


In [3]:
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from transformer_lens import HookedTransformer
from natsort import natsorted
import os
from pathlib import Path
import torch
import yaml

from sparsify.models.transformers import SAETransformer
from sparsify.log import logger
from sparsify.utils import filter_names, load_config
from sparsify.data import DatasetConfig
from sparsify.loader import load_tlens_model, load_pretrained_saes
from sparsify.scripts.train_tlens_saes.run_train_tlens_saes import Config
from sparsify.scripts.train_tlens_saes.run_train_tlens_saes import main as run_train
from sparsify.scripts.generate_dashboards import DashboardsConfig, PromptDashboardsConfig, generate_dashboards
from sparsify.utils import replace_pydantic_model

# Paths and constants shared by the cells below; all paths are relative to the kernel's cwd.
current_dir = Path.cwd()
sae_save_dir = current_dir / "out" / "run_for_testing_feature_dashboards"
dashboard_data_save_dir = sae_save_dir / "feature_dashboard_data"
sae_position_name = "blocks.1.hook_resid_post"
# NOTE: despite the name, this is the path of the YAML training-config *file*.
config_dir = Path("../sparsify/scripts/train_tlens_saes/tinystories_1M.yaml")

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
# Train a small SAE on TinyStories-1M and save it to `sae_save_dir`, or reuse the
# checkpoint(s) already there.
# The base config is read unconditionally so `base_config` is defined whichever
# branch runs (later cells inspect it).
with open(config_dir) as f:
    base_config = Config(**yaml.safe_load(stream=f))

# Look for checkpoints directly in `sae_save_dir` or one directory below it, with the
# same extensions `load_SAETransformer_from_saes_path` accepts. pathlib accepts "/" as a
# separator in glob patterns on Windows too, so no backslash variant is needed.
checkpoint_patterns = ("*.pt", "*.pth", "*/*.pt", "*/*.pth")
sae_checkpoint_exists = any(any(sae_save_dir.glob(pattern)) for pattern in checkpoint_patterns)

if sae_checkpoint_exists:
    print(f"SAEs already exist in sae_save_dir = {sae_save_dir}\nUsing those.")
else:
    # Short, cheap run for testing dashboards: 20k samples (save interval equals the run
    # length), SAE(s) at `sae_position_name`, out_to_in loss enabled with coeff 1.0 and
    # logits_kl disabled; other loss terms keep their base-config values.
    update_dict = {
        "save_dir": sae_save_dir,
        "save_every_n_samples": 20000,
        "n_samples": 20000,
        "warmup_samples": 5000,
        "cooldown_samples": 5000,
        "loss": {
            "out_to_in": {"coeff": 1.0},
            "logits_kl": None,
        },
        "saes": {"sae_positions": sae_position_name},
        "wandb_project": None,
    }
    new_config = replace_pydantic_model(base_config, update_dict)
    print(new_config)
    run_train(new_config)

ValidationError: 1 validation error for Config
loss.inp_to_out
  Extra inputs are not permitted [type=extra_forbidden, input_value={'coeff': 1.0}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden

In [6]:
# Show the base training config from `config_dir`. It is read here rather than reused
# from the training cell so this cell does not depend on which branch that cell took.
with open(config_dir) as f:
    base_config = Config(**yaml.safe_load(stream=f))
print(base_config)

wandb_project='tinystories-1m_play' wandb_run_name=None wandb_run_name_prefix='' seed=0 tlens_model_name='roneneldan/TinyStories-1M' tlens_model_path=None save_dir=PosixPath('/mnt/c/Users/nadro/Documents/AI_safety/MATS5/Sparsify/sparsify/sparsify/scripts/train_tlens_saes/out') n_samples=900000 save_every_n_samples=None eval_every_n_samples=10000 eval_n_samples=500 batch_size=10 effective_batch_size=10 lr=0.001 adam_beta1=0.0 warmup_samples=50000 cooldown_samples=200000 max_grad_norm=1.0 log_every_n_grad_steps=20 collect_act_frequency_every_n_samples=10000 act_frequency_n_tokens=500000 collect_output_metrics_every_n_samples=0 loss=LossConfigs(sparsity=SparsityLoss(coeff=0.1, p_norm=1.0), in_to_orig=InToOrigLoss(coeff=0.0, hook_positions=['hook_resid_post']), out_to_orig=None, out_to_in=OutToInLoss(coeff=0.0), logits_kl=LogitsKLLoss(coeff=1.0)) train_data=DatasetConfig(dataset_name='apollo-research/roneneldan-TinyStories-tokenizer-gpt2', is_tokenized=True, tokenizer_name='gpt2', streamin

In [None]:
# Load the saved SAEs and the corresponding model
def _find_latest_sae_checkpoint(search_dir: Path) -> Path:
    """Return the last ``.pt``/``.pth`` file (natural sort order) in `search_dir` or one level below.

    Files directly in `search_dir` are considered first; the immediate subdirectories are
    searched only when none are found at the top level. pathlib accepts "/" as a glob
    separator on Windows as well, so no backslash pattern is needed.

    Raises:
        FileNotFoundError: If no ``.pt``/``.pth`` file is found at either level.
    """
    for prefix in ("", "*/"):
        candidates = natsorted(
            list(search_dir.glob(f"{prefix}*.pt")) + list(search_dir.glob(f"{prefix}*.pth"))
        )
        if candidates:
            # Presumed to be the most recent checkpoint when file names encode the sample/step count.
            return candidates[-1]
    raise FileNotFoundError(f"Could not find any .pt or .pth files in the saes_path: {search_dir}")


def load_SAETransformer_from_saes_path(
    saes_path: str | Path,
    config_path: str | Path | None = None,
    tlens_model: HookedTransformer | None = None,
) -> tuple[SAETransformer, Config, list[str]]:
    """Build an SAETransformer and load trained SAE weights from disk.

    Args:
        saes_path: A checkpoint file (``.pt``/``.pth``) or an existing location to search.
            For a directory, the naturally-sorted last checkpoint directly inside it (or,
            failing that, one level below) is used. For any other existing file (e.g. the
            run's ``config.yaml``), its parent directory is searched the same way.
        config_path: Training config for the run. Defaults to ``config.yaml`` in the same
            folder as the selected checkpoint.
        tlens_model: An already-loaded HookedTransformer to wrap. If None, one is loaded
            from ``config.tlens_model_name`` / ``config.tlens_model_path``.

    Returns:
        ``(model, config, trainable_param_names)``: the SAETransformer moved to CUDA when
        available (otherwise CPU), the loaded config, and the value returned by
        ``load_pretrained_saes``.

    Raises:
        FileNotFoundError: If ``saes_path``, a checkpoint under it, or the config is missing.
    """
    saes_path = Path(saes_path)
    # A missing path is an error rather than falling back to its parent, which could
    # silently select a checkpoint belonging to a different run.
    if not saes_path.exists():
        raise FileNotFoundError(f"saes_path does not exist: {saes_path}")
    if saes_path.suffix not in (".pt", ".pth"):
        search_dir = saes_path if saes_path.is_dir() else saes_path.parent
        saes_path = _find_latest_sae_checkpoint(search_dir)

    if config_path is None:
        config_path = saes_path.parent / "config.yaml"
    else:
        config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(
            f"Could not find the config_path ({config_path}): config.yaml should be in the "
            "same folder as the saes_path"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = load_config(config_path, config_model=Config)
    logger.info(config)

    if tlens_model is None:
        tlens_model = load_tlens_model(
            tlens_model_name=config.tlens_model_name, tlens_model_path=config.tlens_model_path
        )
    # Keep the model's hook names selected by `config.saes.sae_positions` via `filter_names`.
    raw_sae_positions = filter_names(list(tlens_model.hook_dict.keys()), config.saes.sae_positions)
    model = SAETransformer(
        config=config, tlens_model=tlens_model, raw_sae_positions=raw_sae_positions
    ).to(device=device)

    all_param_names = [name for name, _ in model.saes.named_parameters()]
    # The selected checkpoint is listed first, followed by any pretrained paths from the config.
    pretrained_sae_paths = [saes_path]
    if config.saes.pretrained_sae_paths is not None:
        pretrained_sae_paths += config.saes.pretrained_sae_paths
    trainable_param_names = load_pretrained_saes(
        saes=model.saes,
        pretrained_sae_paths=pretrained_sae_paths,
        all_param_names=all_param_names,
        retrain_saes=config.saes.retrain_saes,
    )
    return model, config, trainable_param_names

print("Loading the model and SAEs")
# Bind the third return value to a real name: assigning to `_` in IPython shadows the
# last-output variable and stops IPython from updating `_`/`__`/`___` afterwards.
model, config, trainable_param_names = load_SAETransformer_from_saes_path(sae_save_dir)
print("done")

In [None]:
# Dataset the dashboards draw activations from.
# NOTE(review): the printed base config trains on
# 'apollo-research/roneneldan-TinyStories-tokenizer-gpt2' with is_tokenized=True; this is a
# different dataset and `is_tokenized` is left at DatasetConfig's default — confirm intended.
dataset_config = DatasetConfig(
    dataset_name="apollo-research/sae-skeskinen-TinyStories-hf-tokenizer-gpt2",
    tokenizer_name="gpt2",
    split="train",
    n_ctx=512,
)

In [None]:
# Generate dashboards for SAE features 0-49, plus prompt-centric dashboards for one
# hand-written prompt and `n_random_prompt_dashboards` further prompts.
# NOTE(review): `dashboard_data_save_dir` from the setup cell is unused; output goes to
# `sae_save_dir` — confirm that is the intended location.
prompt_dashboards_config = PromptDashboardsConfig(
    prompts=["Sally met Mike at the show. She brought popcorn for him."],
    n_random_prompt_dashboards=10,
)
dashboards_config = DashboardsConfig(
    n_samples=2000,
    batch_size=20,
    minibatch_size_features=100,
    save_dir=sae_save_dir,
    data=dataset_config,
    feature_indices=list(range(50)),
    prompt_centric=prompt_dashboards_config,
)
generate_dashboards(model, dashboards_config)