In [2]:
import itertools
import os
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int

from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisData, SaeVisLayoutConfig
from sae_dashboard.data_writing_fns import save_feature_centric_vis
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

In [3]:
t.set_grad_enabled(False)
hf_repo_id = "talibk/TinyStories-33M-SAE-lower-dead-threshold"
dataset_path = "roneneldan/TinyStories"
path_for_vis = Path(f"{os.getcwd()}").resolve()

# Load the model
ts_model: HookedSAETransformer = HookedSAETransformer.from_pretrained("roneneldan/TinyStories-33M")

# Load SAEs
ts_saes = {
    layer: SAE.from_pretrained(
        release=hf_repo_id,
        sae_id=f"blocks.{layer}.hook_mlp_out",
        device=str(device)
    )[0]
    for layer in [0, 1, 2, 3]
}

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

Loaded pretrained model roneneldan/TinyStories-33M into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [4]:
dataset = load_dataset(dataset_path, streaming=True)
batch_size = 256
layer = 1

# Get batch_size number of examples from the dataset
batch = list(itertools.islice(dataset["train"], batch_size))
tokenizer.pad_token = tokenizer.eos_token

# Extract the text from each example and tokenize
tokens = t.tensor(
    tokenizer(
        [example["text"] for example in batch],
        padding=True,
        truncation=True,
        return_tensors="pt"
    )["input_ids"],
    device=str(device)
)
print(tokens.shape)



torch.Size([256, 304])


  tokens = t.tensor(


In [5]:
cfg=SaeVisConfig(
        hook_point=f"blocks.{layer}.hook_mlp_out",
        features=list(range(256)),
        minibatch_size_features=64,
        minibatch_size_tokens=256,
        device="cuda",
        dtype="bfloat16"
    )

data = SaeVisRunner(cfg).run(encoder=ts_saes[layer], model=ts_model, tokens=tokens)


save_feature_centric_vis(sae_vis_data=data, filename=f"feature_dashboard-tinystories33m-layer{layer}.html")



'sae_vis_data = SaeVisData.create(\n    encoder=ts_saes[0],\n    model=ts_model,\n    tokens=tokens,\n    cfg=SaeVisConfig(\n        hook_point="blocks.0.hook_mlp_out",\n        features=list(range(256)),\n        minibatch_size_features=64,\n        minibatch_size_tokens=256,\n        device="cuda",\n        dtype="bfloat16"\n    )\n)'

#### Layer 0 Basic Analysis
4 - materails related to arts and crafts

21 - words related to speaking, "said, cried, shouted, etc"

24 - possesive adjectives, "my, your, his, her, etc"

26 - words related to time, "yesterday, today, tomorrow, etc"

29 - words related to noises

36 - Names of people, characters

