# Caching Autoencoder Activations

This notebook shows a minimal example of caching autoencoder activations after loading the autoencoders into a model. We use the Gemma Scope autoencoders here, but the procedure is the same for the other supported autoencoders.

## Loading the Autoencoders

The loading code below can be replaced with the code for any of the other supported autoencoders (see the [loading autoencoders](loading_saes.ipynb) example).
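Gemma Scope hookpoints such as `layer_10/width_16k/average_l0_39` pack the layer index, the dictionary width, and the average L0 of the chosen autoencoder into one path-like string. The helper below is hypothetical (it is not part of delphi) and only illustrates how such a string decomposes, assuming the `layer_<n>/width_<w>/average_l0_<k>` pattern used in this example.

```python
import re
from typing import NamedTuple


class GemmaScopeHookpoint(NamedTuple):
    layer: int
    width: str  # dictionary width label, e.g. "16k"
    average_l0: int


# Hypothetical pattern, assumed from the hookpoint used in this notebook.
_HOOKPOINT_RE = re.compile(r"layer_(\d+)/width_(\w+)/average_l0_(\d+)")


def parse_gemma_scope_hookpoint(hookpoint: str) -> GemmaScopeHookpoint:
    """Split a Gemma Scope hookpoint string into layer, width, and average L0."""
    match = _HOOKPOINT_RE.fullmatch(hookpoint)
    if match is None:
        raise ValueError(f"unrecognized hookpoint: {hookpoint!r}")
    layer, width, l0 = match.groups()
    return GemmaScopeHookpoint(int(layer), width, int(l0))


print(parse_gemma_scope_hookpoint("layer_10/width_16k/average_l0_39"))
# GemmaScopeHookpoint(layer=10, width='16k', average_l0=39)
```

Choosing a different layer, width, or sparsity level therefore only means changing this string in `hookpoints`.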


In [5]:
from transformers import AutoModel

from delphi.config import RunConfig
from delphi.sparse_coders import load_hooks_sparse_coders


In [8]:
# Load the model
model = AutoModel.from_pretrained(
    "google/gemma-2-2b",
    device_map="cuda",
    torch_dtype="float16",
)

# Load the autoencoders. The function takes the model and the run config
# and returns a dictionary mapping each hookpoint to its sparse encoding
# function. The hookpoint string selects the layer (10), the dictionary
# width (16k), and the average L0 (39); the sparse_model repository
# selects the autoencoder type ("-res" = residual stream).

run_cfg = RunConfig(
    sparse_model="google/gemma-scope-2b-pt-res",
    hookpoints=["layer_10/width_16k/average_l0_39"],
)

hookpoint_to_sparse_encode = load_hooks_sparse_coders(model, run_cfg)

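Each value in `hookpoint_to_sparse_encode` turns a hidden state at its hookpoint into sparse latent activations. Gemma Scope autoencoders use a JumpReLU activation: a latent's pre-activation is kept only where it exceeds a learned, positive per-latent threshold, and is zeroed otherwise. The pure-Python sketch below illustrates that rule on made-up toy weights; the names `w_enc`, `b_enc`, and `threshold` are illustrative, and this is not delphi's implementation.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # shape (d_model, n_latents)


def jumprelu_encode(x: Vector, w_enc: Matrix, b_enc: Vector, threshold: Vector) -> Vector:
    """Encode one hidden state with a JumpReLU sparse autoencoder (toy sketch)."""
    acts = []
    for j in range(len(b_enc)):
        # Pre-activation for latent j: sum_i x[i] * w_enc[i][j] + b_enc[j]
        pre = sum(x[i] * w_enc[i][j] for i in range(len(x))) + b_enc[j]
        # JumpReLU: keep the value only if it clears the threshold.
        # Thresholds are assumed positive, so no separate ReLU is needed.
        acts.append(pre if pre > threshold[j] else 0.0)
    return acts


# Toy example: a 2-dimensional hidden state and 3 latents.
x = [1.0, 2.0]
w_enc = [[0.5, -1.0, 0.2],
         [0.25, 0.5, 0.1]]
b_enc = [0.0, 0.0, -0.1]
threshold = [0.5, 0.5, 0.5]
print(jumprelu_encode(x, w_enc, b_enc, threshold))  # [1.0, 0.0, 0.0]
```

Only the first latent clears its threshold, which is why most entries of an encoded activation vector are zero and why the cache can store them sparsely.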

## Loading the Tokens and Creating the Cache

In [3]:
from transformers import AutoTokenizer

from delphi.config import CacheConfig
from delphi.latents import LatentCache
from delphi.utils import load_tokenized_data

In [None]:
# A default cache config exists; override its fields as needed,
# for example when writing a "production" script.
cfg = CacheConfig(
    dataset_repo="EleutherAI/rpj-v2-sample",
    dataset_split="train[:1%]",
    batch_size=8,
    ctx_len=256,
    n_tokens=1_000_000,
    n_splits=5,
)

# AutoModel does not attach a tokenizer, so load the matching one separately.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

tokens = load_tokenized_data(
    ctx_len=cfg.ctx_len,
    tokenizer=tokenizer,
    dataset_repo=cfg.dataset_repo,
    dataset_split=cfg.dataset_split,
)
# tokens should have shape (n_sequences, ctx_len)

cache = LatentCache(
    model,
    hookpoint_to_sparse_encode,
    batch_size=cfg.batch_size,
)
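With the configuration above, each batch holds `batch_size × ctx_len = 8 × 256 = 2,048` tokens. The progress bar and total printed after `cache.run` (488 batches, 999,424 tokens) are consistent with the cache processing only whole batches, so the remainder of `n_tokens` is dropped. The helper below is a hypothetical illustration of that arithmetic, not a delphi function.

```python
def cache_schedule(n_tokens: int, batch_size: int, ctx_len: int) -> tuple[int, int]:
    """Return (number of full batches, tokens actually processed)."""
    tokens_per_batch = batch_size * ctx_len
    n_batches = n_tokens // tokens_per_batch  # any partial batch is dropped
    return n_batches, n_batches * tokens_per_batch


n_batches, total = cache_schedule(n_tokens=1_000_000, batch_size=8, ctx_len=256)
print(n_batches, f"{total:,}")  # 488 999,424
```

To process exactly `n_tokens`, pick a value that is a multiple of `batch_size * ctx_len`.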

## Running the Cache and Saving the Results

In [6]:
cache.run(cfg.n_tokens, tokens)

# Split the activation and location indices into several files
# for faster loading.
cache.save_splits(
    n_splits=cfg.n_splits,
    save_dir="latents",
)

# Save the cache config alongside the results
# so that it can be loaded later.
cache.save_config(
    save_dir="latents",
    cfg=cfg,
    model_name="google/gemma-2-2b",
)

Caching features: 100%|██████████| 488/488 [04:59<00:00,  1.63it/s, Total Tokens=999,424]


Total tokens processed: 999,424
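Conceptually, the cache stores each nonzero latent activation as a location (sequence index, token position, latent index) together with its value, and splitting groups those records by latent index so each file covers one contiguous range of latents. The sketch below assumes that layout and an even split of the 16k-wide dictionary (taken here as 16,384 latents) into `n_splits=5` ranges. It is plain Python with hypothetical helper names and does not reproduce delphi's actual file format.

```python
from typing import Dict, List, Tuple

Location = Tuple[int, int, int]  # (sequence index, token position, latent index)
Range = Tuple[int, int]          # half-open latent interval [start, end)


def latent_ranges(n_latents: int, n_splits: int) -> List[Range]:
    """Divide [0, n_latents) into n_splits contiguous, near-equal ranges."""
    base, extra = divmod(n_latents, n_splits)
    ranges, start = [], 0
    for i in range(n_splits):
        # The first `extra` ranges get one additional latent.
        end = start + base + (1 if i < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges


def split_by_latent(
    locations: List[Location], activations: List[float], ranges: List[Range]
) -> Dict[Range, List[Tuple[Location, float]]]:
    """Group (location, activation) records by the latent range they fall in."""
    buckets: Dict[Range, List[Tuple[Location, float]]] = {r: [] for r in ranges}
    for loc, act in zip(locations, activations):
        latent = loc[2]
        for start, end in ranges:
            if start <= latent < end:
                buckets[(start, end)].append((loc, act))
                break
    return buckets


ranges = latent_ranges(n_latents=16_384, n_splits=5)
print(ranges)
# [(0, 3277), (3277, 6554), (6554, 9831), (9831, 13108), (13108, 16384)]

# Toy records: three nonzero activations found in the first sequence.
locations = [(0, 3, 12), (0, 3, 9000), (0, 7, 16000)]
activations = [1.5, 0.8, 2.1]
for (start, end), records in split_by_latent(locations, activations, ranges).items():
    print(f"latents [{start}, {end}): {len(records)} record(s)")
```

Grouping by latent range means that a later step interested in only some latents can, under this assumption, read just the files covering those ranges.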
