# Caching Autoencoder Activations

Here we show a minimal example of how to cache autoencoder activations after loading the autoencoders into a model. We use the Gemma autoencoders for this example, but the procedure is the same for the other autoencoders.

## Loading the Autoencoders

The loading code below can be replaced with the loader for any of the other supported autoencoders (see the [loading autoencoders](loading_saes.ipynb) example). The caching steps stay the same.


```python
from nnsight import LanguageModel
from sae_auto_interp.autoencoders import load_gemma_autoencoders
```


```python
# Load the model
model = LanguageModel("google/gemma-2-9b", device_map="cuda", dispatch=True, torch_dtype="bfloat16")
```
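Loading the model in `bfloat16` halves the weight memory compared with `float32`. The sketch below does that arithmetic, assuming a round figure of 9 billion parameters for `gemma-2-9b` and counting the weights only. Activations and the autoencoders themselves need additional memory. `weight_memory_gib` is a hypothetical helper.

```python
# Rough weight-memory estimate for loading a model in different dtypes.
# The 9e9 parameter count is an assumed round figure for gemma-2-9b;
# activations, KV caches and the SAEs are NOT included.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gib(n_params: float, dtype: str) -> float:
    """Return the memory needed for the weights alone, in GiB."""
    return n_params * BYTES_PER_PARAM[dtype] / 2**30


n_params = 9e9  # assumed approximate size of gemma-2-9b
for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(n_params, dtype):.1f} GiB")
# float32: 33.5 GiB
# bfloat16: 16.8 GiB
```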

```python
# Load the autoencoders into the model. The function returns a dictionary of the
# submodules that now have autoencoders attached, together with the edited model.
# It takes as arguments the model, the layers to attach autoencoders to, the
# average L0 sparsity of the autoencoder to load for each layer, the autoencoder
# width (here "131k" latents) and the autoencoder type (residual stream or MLP;
# "res" selects the residual-stream autoencoders).
submodule_dict, model = load_gemma_autoencoders(
    model,
    layers=[10],
    average_l0s={10: 47},
    size="131k",
    type="res",
)
```
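The `average_l0s` argument refers to the average L0 of an autoencoder: the mean number of latents that are non-zero per token. As we read it, `{10: 47}` asks for the layer-10 autoencoder whose average L0 is about 47. The toy function below, `average_l0` (a hypothetical helper, not part of `sae_auto_interp`), shows how that statistic is computed from a matrix of latent activations.

```python
from typing import Sequence


def average_l0(activations: Sequence[Sequence[float]]) -> float:
    """Mean number of non-zero latents per token.

    `activations` holds one row per token and one column per SAE latent.
    """
    if not activations:
        raise ValueError("need at least one token")
    nonzero_counts = [sum(1 for a in row if a != 0.0) for row in activations]
    return sum(nonzero_counts) / len(nonzero_counts)


# Toy example: 3 tokens, 5 latents.
acts = [
    [0.0, 1.2, 0.0, 0.0, 0.3],  # 2 active latents
    [0.0, 0.0, 0.0, 0.0, 0.0],  # 0 active latents
    [2.0, 0.1, 0.5, 0.0, 0.0],  # 3 active latents
]
print(average_l0(acts))  # 5 / 3, about 1.667
```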

## Loading the tokens and creating the cache

```python
from sae_auto_interp.config import CacheConfig
from sae_auto_interp.features import FeatureCache
from sae_auto_interp.utils import load_tokenized_data
```


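```python
# CacheConfig comes with default values; here we override some of them.
# The same config can be adjusted when running a "production" caching script.
cfg = CacheConfig(
    dataset_repo="EleutherAI/rpj-v2-sample",
    dataset_split="train[:1%]",
    batch_size=32,        # sequences per forward pass
    ctx_len=256,          # tokens per sequence
    n_tokens=10_000_000,  # total number of tokens to cache
    n_splits=5,           # number of files the cached results are split into
)
```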
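`n_tokens`, `ctx_len` and `batch_size` together determine how many forward passes the cache makes: each batch covers `batch_size * ctx_len` tokens (32 * 256 = 8,192 here). The sketch below assumes that only whole batches are processed; if `FeatureCache` rounds differently, the exact counts will shift slightly. `batches_for_budget` is a hypothetical helper.

```python
def batches_for_budget(n_tokens: int, ctx_len: int, batch_size: int) -> tuple[int, int]:
    """Return (number of full batches, tokens actually processed).

    Assumes any partial final batch is dropped.
    """
    tokens_per_batch = ctx_len * batch_size
    n_batches = n_tokens // tokens_per_batch
    return n_batches, n_batches * tokens_per_batch


n_batches, used = batches_for_budget(10_000_000, ctx_len=256, batch_size=32)
print(n_batches, used)  # 1220 9994240
```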



```python
tokens = load_tokenized_data(
    ctx_len=cfg.ctx_len,
    tokenizer=model.tokenizer,
    dataset_repo=cfg.dataset_repo,
    dataset_split=cfg.dataset_split,
)
# `tokens` has shape (n_sequences, ctx_len): one row per fixed-length sequence.
```
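The token tensor is two-dimensional, with one row of `ctx_len` tokens per sequence. That layout comes from cutting tokenized text into fixed-length rows. The pure-Python sketch below illustrates the idea on a flat list of token ids; it is not the implementation of `load_tokenized_data`, which may also handle document separators or special tokens. `chunk_tokens` is a hypothetical name.

```python
def chunk_tokens(token_ids: list[int], ctx_len: int) -> list[list[int]]:
    """Split a flat token stream into rows of exactly ctx_len tokens.

    A trailing remainder shorter than ctx_len is dropped so that every row
    has the same length, as a rectangular (n_sequences, ctx_len) tensor requires.
    """
    n_sequences = len(token_ids) // ctx_len
    return [token_ids[i * ctx_len:(i + 1) * ctx_len] for i in range(n_sequences)]


stream = list(range(10))  # stand-in for concatenated token ids
rows = chunk_tokens(stream, ctx_len=4)
print(rows)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```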



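```python
# FeatureCache runs the model over the tokens in batches and records the
# activations of the autoencoders attached to the submodules in submodule_dict.
cache = FeatureCache(
    model,
    submodule_dict,
    batch_size=cfg.batch_size,
)
```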

## Running the cache and saving the results

```python
# Cache autoencoder activations for up to cfg.n_tokens tokens.
cache.run(cfg.n_tokens, tokens)

# Split the activation and location indices into several files to make loading faster.
cache.save_splits(
    n_splits=cfg.n_splits,
    save_dir="latents_gemma",
)
```
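The comment above refers to activation and location indices, which suggests a sparse format: for every non-zero latent activation, a location (sequence, position, latent index) and the activation value. The sketch below illustrates that representation and one way of splitting it into `n_splits` parts. Splitting over contiguous ranges of the latent index is an assumption on our part, and `sparse_locations` and `split_by_feature` are hypothetical helpers, not the `FeatureCache` internals.

```python
def sparse_locations(acts):
    """Turn dense acts[seq][pos][latent] into (locations, values) for non-zero entries."""
    locations, values = [], []
    for s, seq in enumerate(acts):
        for p, token_acts in enumerate(seq):
            for f, a in enumerate(token_acts):
                if a != 0.0:
                    locations.append((s, p, f))
                    values.append(a)
    return locations, values


def split_by_feature(locations, values, n_features, n_splits):
    """Partition entries into n_splits contiguous latent-index ranges (hypothetical scheme)."""
    bounds = [i * n_features // n_splits for i in range(n_splits + 1)]
    splits = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        keep = [i for i, (_, _, f) in enumerate(locations) if start <= f < stop]
        splits.append(
            (
                (start, stop - 1),  # inclusive latent range covered by this split
                [locations[i] for i in keep],
                [values[i] for i in keep],
            )
        )
    return splits


# Toy activations: 1 sequence, 2 token positions, 4 latents.
acts = [[[0.0, 0.7, 0.0, 0.2], [1.1, 0.0, 0.0, 0.0]]]
locs, vals = sparse_locations(acts)
splits = split_by_feature(locs, vals, n_features=4, n_splits=2)
for latent_range, split_locs, split_vals in splits:
    print(latent_range, split_locs, split_vals)
# (0, 1) [(0, 0, 1), (0, 1, 0)] [0.7, 1.1]
# (2, 3) [(0, 0, 3)] [0.2]
```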

```python
# Save the cache config in the same directory as the results so that
# they can be loaded together later.
cache.save_config(
    save_dir="latents_gemma",
    cfg=cfg,
    model_name="google/gemma-2-9b",
)
```
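Saving the config next to the cached splits means a later loading step can recover settings such as `ctx_len` and the model name without guessing. The sketch below mimics this with a JSON file; `ToyCacheConfig`, `save_config_json`, `load_config_json` and the `config.json` filename are hypothetical stand-ins, not the format `cache.save_config` actually writes.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class ToyCacheConfig:
    """Hypothetical stand-in mirroring the fields passed to CacheConfig above."""
    dataset_repo: str
    dataset_split: str
    batch_size: int
    ctx_len: int
    n_tokens: int
    n_splits: int


def save_config_json(save_dir, cfg, model_name):
    """Write cfg plus the model name to save_dir/config.json and return the path."""
    path = Path(save_dir)
    path.mkdir(parents=True, exist_ok=True)
    out = path / "config.json"
    out.write_text(json.dumps({**asdict(cfg), "model_name": model_name}, indent=2))
    return out


def load_config_json(save_dir):
    """Read back the dictionary written by save_config_json."""
    return json.loads((Path(save_dir) / "config.json").read_text())


cfg = ToyCacheConfig("EleutherAI/rpj-v2-sample", "train[:1%]", 32, 256, 10_000_000, 5)
with tempfile.TemporaryDirectory() as tmp:
    save_config_json(tmp, cfg, "google/gemma-2-9b")
    loaded = load_config_json(tmp)
print(loaded["ctx_len"], loaded["model_name"])  # 256 google/gemma-2-9b
```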