# Loading Autoencoders


Here we show how to load each of the autoencoder families we support. The autoencoders are "attached" to the model with NNsight's `edit` method.
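As we understand it, "attaching" an autoencoder means that whenever the chosen submodule runs, its output is also passed to the autoencoder's encoder, so the latents are available alongside the normal forward pass. The plain-Python sketch below mimics that idea with a toy layer and a toy encoder. It does not use NNsight, and `ToyLayer`, `toy_encode` and `attach` are hypothetical stand-ins, not part of NNsight or `sae_auto_interp`.

```python
from typing import Callable, List


class ToyLayer:
    """Hypothetical stand-in for one model submodule: scales its input vector."""

    def __init__(self, scale: float) -> None:
        self.scale = scale
        self.hooks: List[Callable[[List[float]], None]] = []

    def __call__(self, x: List[float]) -> List[float]:
        out = [self.scale * v for v in x]
        # Every attached hook sees the layer output; in this toy the output is returned unchanged.
        for hook in self.hooks:
            hook(out)
        return out


def toy_encode(x: List[float]) -> List[float]:
    """Hypothetical encoder: a ReLU per activation (a real SAE applies its encoder weights first)."""
    return [max(0.0, v) for v in x]


def attach(layer: ToyLayer, encode: Callable[[List[float]], List[float]], store: list) -> None:
    """Run `encode` on the layer output at every forward pass and record the latents."""
    layer.hooks.append(lambda out: store.append(encode(out)))


layer = ToyLayer(scale=2.0)
latents: list = []
attach(layer, toy_encode, latents)
output = layer([1.0, -0.5, 3.0])
print(output)   # [2.0, -1.0, 6.0]
print(latents)  # [[2.0, 0.0, 6.0]]
```

The point of the design is that the host module does not need to know anything about the autoencoder; the hook only reads its output.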

## GemmaScope SAEs

```python
from nnsight import LanguageModel
from sae_auto_interp.autoencoders import load_gemma_autoencoders

# Load the model
model = LanguageModel("google/gemma-2-9b", device_map="cuda", dispatch=True, torch_dtype="bfloat16")

# Load the autoencoders. The function takes the model, the layers to attach
# autoencoders to, the average L0 sparsity of the SAE to use at each layer,
# the SAE width ("size", the number of latents) and the SAE type
# (residual stream, as here with "res", or MLP).
# It returns a dictionary of the hooked submodules and the edited model.
submodule_dict, model = load_gemma_autoencoders(
    model,
    layers=[10],
    average_l0s={10: 47},
    size="131k",
    type="res",
)

# Each autoencoder is stored on its submodule under `.ae`; the model is edited in place.
autoencoder = submodule_dict[".model.layers.10"].ae
```
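Gemma Scope publishes several SAEs for each layer and width, distinguished by their average L0, which is why `average_l0s` needs an entry for every layer in `layers`. The sketch below shows one way those arguments could be turned into per-layer identifiers and validated up front. `gemma_scope_paths` is a hypothetical helper, not the loader's actual code, and the repository name `google/gemma-scope-{model_size}-pt-{type}` with subfolders `layer_{L}/width_{W}/average_l0_{L0}` is our assumption about the public Gemma Scope layout on the Hugging Face Hub.

```python
from typing import Dict, List


def gemma_scope_paths(
    layers: List[int],
    average_l0s: Dict[int, int],
    size: str,
    type: str,  # mirrors the loader's keyword; here it is the assumed repo suffix ("res" or "mlp")
    model_size: str = "9b",
) -> Dict[int, str]:
    """Map each requested layer to an (assumed) Gemma Scope repo/subfolder string."""
    if type not in ("res", "mlp"):
        raise ValueError(f"type must be 'res' or 'mlp', got {type!r}")
    missing = [layer for layer in layers if layer not in average_l0s]
    if missing:
        # Each layer needs its own L0 entry to pick one SAE among the released variants.
        raise KeyError(f"no average L0 given for layers {missing}")
    # Assumption: path layout follows the public Gemma Scope release.
    repo = f"google/gemma-scope-{model_size}-pt-{type}"
    return {
        layer: f"{repo}/layer_{layer}/width_{size}/average_l0_{average_l0s[layer]}"
        for layer in layers
    }


print(gemma_scope_paths([10], {10: 47}, size="131k", type="res"))
# {10: 'google/gemma-scope-9b-pt-res/layer_10/width_131k/average_l0_47'}
```

Validating before any download means a missing L0 entry fails immediately rather than after the model has been loaded.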


## EleutherAI SAEs

```python
from nnsight import LanguageModel
from sae_auto_interp.autoencoders import load_eai_autoencoders

model = LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map="cpu", dispatch=True, torch_dtype="bfloat16")

# The load function takes the model, the layers to attach autoencoders to,
# the path to the weights (a Hugging Face repo or a local directory),
# the type of module to hook (residual stream or MLP),
# whether to randomize the autoencoders,
# and k for the top-k activation: None keeps the k the SAE was trained with,
# an integer overrides it.
submodule_dict, model = load_eai_autoencoders(
    model,
    [23, 29],
    weight_dir="EleutherAI/sae-llama-3.1-8b-64x",
    module="mlp",
    randomize=False,
    k=None,
)

# Each autoencoder is stored on its submodule under `.ae`; the model is edited in place.
# With module="mlp" the hooked submodule is the MLP, so its key ends in ".mlp";
# print(submodule_dict.keys()) to check the exact keys.
autoencoder = submodule_dict[".model.layers.23.mlp"].ae
```
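The `k` argument refers to the top-k activation these SAEs use: for each token only the k largest encoder pre-activations are kept and the rest are set to zero. A minimal plain-Python sketch of that rule follows; `topk_activation` and `trained_k` are hypothetical names, and applying a ReLU before the selection is an assumption that matches common top-k SAE implementations rather than something stated here.

```python
from typing import List, Optional


def topk_activation(pre_acts: List[float], k: Optional[int], trained_k: int) -> List[float]:
    """Keep the k largest (ReLU'd) pre-activations and zero the rest.

    k=None falls back to trained_k, mirroring the loader's `k` argument.
    """
    k = trained_k if k is None else k
    acts = [max(0.0, a) for a in pre_acts]  # assumed ReLU before selection
    # Indices of the k largest activations (ties broken by lower index first).
    keep = set(sorted(range(len(acts)), key=lambda i: (-acts[i], i))[:k])
    return [a if i in keep else 0.0 for i, a in enumerate(acts)]


pre = [0.3, -1.2, 2.5, 0.9, 1.1]
print(topk_activation(pre, k=None, trained_k=2))  # [0.0, 0.0, 2.5, 0.0, 1.1]
print(topk_activation(pre, k=3, trained_k=2))     # [0.0, 0.0, 2.5, 0.9, 1.1]
```

Raising k above the trained value lets more latents fire per token, which changes the reconstruction the SAE was optimized for; keeping `k=None` is the conservative default.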

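The keys of `submodule_dict` are module path strings such as `.model.layers.10`, so a hard-coded key that does not match exactly raises a `KeyError`. A small helper can match on the layer index instead and list the available keys when nothing matches. `find_submodule_key` is hypothetical (not part of `sae_auto_interp`), and the example dictionary uses assumed key names with placeholder values rather than real loader output.

```python
import re
from typing import Any, Dict


def find_submodule_key(submodule_dict: Dict[str, Any], layer: int) -> str:
    """Return the single key whose path contains `.layers.<layer>` as a whole component."""
    pattern = re.compile(rf"\.layers\.{layer}(\.|$)")
    matches = [key for key in submodule_dict if pattern.search(key)]
    if len(matches) != 1:
        raise KeyError(
            f"expected one key for layer {layer}, found {matches}; "
            f"available keys: {sorted(submodule_dict)}"
        )
    return matches[0]


# Placeholder dictionary: key names are assumptions, values stand in for submodules.
example = {".model.layers.23.mlp": "sub23", ".model.layers.29.mlp": "sub29"}
print(find_submodule_key(example, 23))  # .model.layers.23.mlp
```

Matching `.layers.<n>` followed by a dot or the end of the string keeps layer 2 from accidentally matching layer 23.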