# Caching Autoencoder Activations

Here we show a minimal example of how to cache autoencoder activations after loading the autoencoders into a model. We use the Gemma autoencoders for this example, but the procedure is the same for the other autoencoders.

## Loading the Autoencoders

The loading code below can be swapped for that of any of the other autoencoders (see the [loading autoencoders](loading_saes.ipynb) example).


In [2]:
from nnsight import LanguageModel
from sae_auto_interp.autoencoders import load_gemma_autoencoders


In [3]:
# Load the model
model = LanguageModel("google/gemma-2-9b", device_map="cuda:1", dispatch=True, torch_dtype="float16")

# Load the autoencoders. The function returns a dictionary mapping submodule names
# to the submodules with autoencoders attached, together with the edited model.
# It takes the model, the layers to load autoencoders into, the average L0 sparsity
# per layer, the autoencoder size, and the autoencoder type (residual stream or MLP).

submodule_dict, model = load_gemma_autoencoders(
    model,
    ae_layers=[10],
    average_l0s={10: 47},
    size="131k",
    type="res",
)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

## Loading the tokens and creating the cache

In [4]:
from sae_auto_interp.config import CacheConfig
from sae_auto_interp.features import FeatureCache
from sae_auto_interp.utils import load_tokenized_data


In [5]:
# CacheConfig provides default values; they can also be overridden when using a "production" script.
cfg = CacheConfig(
    dataset_repo="EleutherAI/rpj-v2-sample",
    dataset_split="train[:1%]",
    batch_size=8,
    ctx_len=256,
    n_tokens=1_000_000,
    n_splits=5,
)



tokens = load_tokenized_data(
        ctx_len=cfg.ctx_len,
        tokenizer=model.tokenizer,
        dataset_repo=cfg.dataset_repo,
        dataset_split=cfg.dataset_split,
        dataset_row=cfg.dataset_row,
)
# Tokens should have the shape (n_batches, ctx_len)



cache = FeatureCache(
    model,
    submodule_dict,
    batch_size=cfg.batch_size,
)

EleutherAI/rpj-v2-sample  train[:1%]


Resolving data files:   0%|          | 0/150 [00:00<?, ?it/s]

dict_keys(['.model.layers.10'])
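
The number of batches the cache will process follows from the config values. The sketch below is only an illustration of that arithmetic, not the library's code: the helper name `cache_batches` is hypothetical, and it assumes that only full batches of `batch_size * ctx_len` tokens are processed, which is consistent with the 488 iterations reported by the progress bar of the caching run.

```python
def cache_batches(n_tokens: int, batch_size: int, ctx_len: int) -> tuple[int, int]:
    """Return (number of full batches, tokens actually covered by those batches).

    Hypothetical helper; assumes partial batches are dropped.
    """
    tokens_per_batch = batch_size * ctx_len
    n_batches = n_tokens // tokens_per_batch
    return n_batches, n_batches * tokens_per_batch


# Values taken from the CacheConfig above.
n_batches, cached_tokens = cache_batches(n_tokens=1_000_000, batch_size=8, ctx_len=256)
print(n_batches, cached_tokens)  # 488 999424
```

Under this assumption each batch holds 8 × 256 = 2,048 tokens, so seven batches account for the 14,336 tokens shown in the progress bar.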


## Running the cache and saving the results

In [6]:


cache.run(cfg.n_tokens, tokens)

cache.save_splits(
    n_splits=cfg.n_splits,  # Activations and their location indices are split across several files to make loading faster
    save_dir="latents",
)

# Save the cache config alongside the results so that it can be loaded later.

cache.save_config(
    save_dir="latents",
    cfg=cfg,
    model_name="google/gemma-2-9b"
)

Caching features:   0%|          | 0/488 [00:00<?, ?it/s]
You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
Caching features:   1%|▏         | 7/488 [00:06<07:17,  1.10it/s, Total Tokens=14,336]


KeyboardInterrupt: 
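
To give a rough idea of what splitting the cache into `n_splits` files means, the sketch below partitions a range of feature indices into contiguous, near-equal chunks. It is a minimal illustration under stated assumptions rather than the behaviour of `save_splits`: the helper `split_feature_ranges` is hypothetical, the width of 131,072 features is an assumed reading of the `"131k"` size, and the library may choose different boundaries or file layouts.

```python
def split_feature_ranges(n_features: int, n_splits: int) -> list[tuple[int, int]]:
    """Partition [0, n_features) into n_splits contiguous half-open ranges.

    Hypothetical helper; the first `n_features % n_splits` ranges get one extra index.
    """
    base, extra = divmod(n_features, n_splits)
    ranges = []
    start = 0
    for i in range(n_splits):
        size = base + (1 if i < extra else 0)
        ranges.append((start, start + size))
        start += size
    return ranges


# Assumed width for the "131k" autoencoder; n_splits matches the config above.
ranges = split_feature_ranges(n_features=131_072, n_splits=5)
for start, end in ranges:
    print(f"features {start} to {end - 1}")
```

Loading a single feature then only requires reading the one file whose range contains it, which is the motivation for splitting given in the comment on `save_splits`.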