# 1️⃣ Introduction

> ##### Objectives
>
> * Load a model using the `nnsight` library
> * Learn some basics of HuggingFace models (e.g. tokenization, model output)
> * Use it to extract & visualise Gemma 3's internal activations
> * Load the SAE corresponding to the model

### Reference:  
Tutorials:  
> [Gemma Scope 2](https://colab.research.google.com/drive/1NhWjg7n0nhfW--CjtsOdw5A5J_-Bzn4r#scrollTo=nOBcV4om7mrT)

> [SAE_Lens](https://decoderesearch.github.io/SAELens/latest/usage/#using-saes-without-transformerlens)  
> [Colab notebook](https://colab.research.google.com/drive/1RMOvARSFvyqig8yHdsT7lfRQmXOpFlE4#scrollTo=yfDUxRx0wSRl)

# Set-Up

In [None]:
try:
    import google.colab  # type: ignore
    from google.colab import output

    COLAB = True
    %pip install nnsight sae-lens transformer-lens sae-dashboard datasets
except ImportError:
    COLAB = False
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")




In [None]:
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

#import circuitsvis as cv
import einops
import torch
import torch as t
import torch.nn as nn
from IPython.display import display, HTML
from jaxtyping import Float

from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from safetensors.torch import load_file

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login



# Only doing inference, so disable gradient tracking to save memory
t.set_grad_enabled(False)


if t.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if t.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


# 2️⃣ Load Model Using NNSIGHT
Here, we'll discuss some important syntax for interacting with `nnsight` models. Since these models are extensions of HuggingFace models, some of this information (e.g. tokenization) applies to plain HuggingFace models as well as `nnsight` models, and some of it (e.g. forward passes) is specific to `nnsight`, i.e. it would work differently if you just had a standard HuggingFace model. Make sure to keep this distinction in mind, otherwise syntax can get confusing!  

1. Tutorial: [NNSIGHT Main Page](https://nnsight.net/)
2. Page: [NNSIGHT Remote Status](https://nnsight.net/status/)

### Model config

Each model comes with a `model.config`, which contains lots of useful information about the model (e.g. number of heads and layers, size of hidden layers, etc.). Run the code below to load the model and its tokenizer and print the full config.

In [None]:
from nnsight import CONFIG, LanguageModel


model = LanguageModel("google/gemma-3-1b-pt", device_map="auto", torch_dtype=t.bfloat16)
tokenizer = model.tokenizer

print("=== Entire config === \n", model.config)

`torch_dtype` is deprecated! Use `dtype` instead!



=== Entire config === 
 Gemma3TextConfig {
  "_sliding_window_pattern": 6,
  "architectures": [
    "Gemma3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_logit_softcapping": null,
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "dtype": "bfloat16",
  "eos_token_id": 1,
  "final_logit_softcapping": null,
  "head_dim": 256,
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 1152,
  "initializer_range": 0.02,
  "intermediate_size": 6912,
  "layer_types": [
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attentio
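The `layer_types` list above follows `_sliding_window_pattern: 6`: in the visible prefix, every sixth layer (indices 5, 11, 17) uses full attention and the rest use sliding-window attention. The sketch below reproduces that visible pattern, assuming it continues for every layer; the layer count of 26 is an assumption here, since the printed config is cut off before `num_hidden_layers`. With the real model you can simply read `model.config.layer_types`.

```python
def layer_types(num_layers: int, pattern: int = 6) -> list[str]:
    """Every `pattern`-th layer (1-indexed) uses full attention; the rest use a sliding window."""
    return [
        "full_attention" if (i + 1) % pattern == 0 else "sliding_attention"
        for i in range(num_layers)
    ]


# Assumption: 26 decoder layers (not visible in the truncated config above)
types = layer_types(26)
full_layers = [i for i, kind in enumerate(types) if kind == "full_attention"]
print(full_layers)
```

This is only a reading aid for the config; the authoritative list is the one stored in the config itself.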

## Remote Execution

1. [NDIF](https://ndif.us/) is the companion service to nnsight that lets users run interventions on models without loading them locally or needing a GPU. Through NDIF, users get remote access to a variety of models across a wide range of sizes (up to 400+ billion parameters!).  
  > You can visit https://nnsight.net/status/ to check current model serving, or, use the nnsight API to get the current status of the backend.

2. API Key: To access remote functionality on NDIF, you need to claim an API key at https://login.ndif.us.
3. Note: Make sure your `HF_TOKEN` is set in your environment. This is required for both local and remote execution, since Gemma models are gated on Hugging Face. More info at: https://huggingface.co/docs/huggingface_hub/en/quick-star

4. NDIF Tutorial:  
https://arena-chapter1-transformer-interp.streamlit.app/[1.4.2]_Function_Vectors_&_Model_Steering#remote-execution

In [None]:
# Hide some info logging messages from nnsight
logging.disable(sys.maxsize)


# If you have an NDIF API key and want to work remotely, set REMOTE = True.
# The key is read from the NDIF_API_KEY environment variable (loaded from a .env
# file locally, or from Colab secrets on Colab). Otherwise, leave REMOTE = False.

# ===========
REMOTE = True
# ===========

if REMOTE and not COLAB:
  from dotenv import load_dotenv
  load_dotenv()
  CONFIG.set_default_api_key(os.getenv("NDIF_API_KEY"))
if COLAB and REMOTE:
  from google.colab import userdata
  CONFIG.set_default_api_key(userdata.get('NDIF_API_KEY'))
  os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')


### Check Available NDIF Remote Models

In [None]:
import nnsight

nnsight.ndif_status()

NDIF Service: Down 🔴
Visit our community support at https://discuss.ndif.us/ or try again later.


{}

### Test Model outputs

At a high level, there are 2 ways to run our model: using the `trace` method (a single forward pass) and the `generate` method (generating multiple tokens). We'll focus on `trace` for now, and we'll discuss `generate` when it comes to multi-token generation later.

The default behaviour of forward passes in normal HuggingFace models is to return an object containing logits (and optionally a bunch of other things). The default behaviour of `trace` in `nnsight` is to not return anything: any value we want to keep must be explicitly saved with `.save()` inside the context manager, and it is then available after the `with` block exits.

Below is the simplest example of code to run the model (and also access the internal states of the model). Run it and look at the output, then read the explanation below. Remember to obtain and set an API key first if you're using remote execution!

In [None]:
# Gemma 3 is not available remotely on NDIF, so run this trace locally
prompt = "The Eiffel Tower is in the city of"

with model.trace(prompt, remote=False):
    # Save the final decoder layer's hidden states
    # (Gemma uses model.model.layers, not GPT-2-style model.transformer.h)
    hidden_states = model.model.layers[-1].output[0].save()  # (batch_size, seq_len, d_model)

    # Save the logits at the last position:
    # lm_head.output.shape = (batch, seq, vocab_size), so [0, -1] selects the final token's logits
    logits = model.lm_head.output[0, -1].save()

# Get the model's logit output, and its next token prediction
print(f"logits.shape = {logits.shape} = (vocab_size,)")
print("Predicted token ID =", predicted_token_id := logits.argmax().item())
print(f"Predicted token = {tokenizer.decode(predicted_token_id)!r}")

# Print the shape of the model's residual stream
print(f"\nresid.shape = {hidden_states.shape} = (batch_size, seq_len, d_model)")

logits.shape = torch.Size([262144]) = (vocab_size,)
Predicted token ID = 9079
Predicted token = ' Paris'

resid.shape = torch.Size([1, 9, 1152]) = (batch_size, seq_len, d_model)
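`argmax` only gives the single most likely token. To inspect the runners-up, convert the logits to probabilities with a softmax and take the top k; with the saved tensor this would be `logits.float().softmax(-1).topk(5)` followed by `tokenizer.decode` on each index (casting from bfloat16 to float first for precision). The pure-Python sketch below shows the same computation on a made-up toy logit vector; the values are illustrative, not model output.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: list[float], k: int) -> list[tuple[int, float]]:
    # Return (token_id, probability) pairs sorted by probability, highest first
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in order[:k]]


toy_logits = [2.0, 0.5, 3.0, -1.0]  # hypothetical 4-token vocabulary
for token_id, p in top_k(softmax(toy_logits), k=2):
    print(token_id, round(p, 3))
```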


### Load SAE utilities

In [None]:
# Load SAE utilities
from sae_lens import SAE
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

# Example: list the first 3 available SAE releases
sae_dir = get_pretrained_saes_directory()
sae_df = pd.DataFrame.from_records({k: v.__dict__ for k, v in sae_dir.items()}).T

display(sae_df.head(3))


| | release | repo_id | model | conversion_func | saes_map | expected_var_explained | expected_l0 | neuronpedia_id | config_overrides |
|---|---|---|---|---|---|---|---|---|---|
| deepseek-r1-distill-llama-8b-qresearch | deepseek-r1-distill-llama-8b-qresearch | qresearch/DeepSeek-R1-Distill-Llama-8B-SAE-l19 | deepseek-ai/DeepSeek-R1-Distill-Llama-8B | deepseek_r1 | `{'blocks.19.hook_resid_post': 'DeepSeek-R1-Dis...` | `{'blocks.19.hook_resid_post': 1.0}` | `{'blocks.19.hook_resid_post': 0.0}` | `{'blocks.19.hook_resid_post': 'deepseek-r1-dis...` | |
| gemma-2-2b-res-matryoshka-dc | gemma-2-2b-res-matryoshka-dc | chanind/gemma-2-2b-batch-topk-matryoshka-saes-... | gemma-2-2b | | `{'blocks.0.hook_resid_post': 'standard/blocks....` | `{'blocks.0.hook_resid_post': 1.0, 'blocks.1.ho...` | `{'blocks.0.hook_resid_post': 40.0, 'blocks.1.h...` | `{'blocks.0.hook_resid_post': None, 'blocks.1.h...` | |
| gemma-2-2b-res-snap-matryoshka-dc | gemma-2-2b-res-snap-matryoshka-dc | chanind/gemma-2-2b-batch-topk-matryoshka-saes-... | gemma-2-2b | | `{'blocks.0.hook_resid_post': 'snap/blocks.0.ho...` | `{'blocks.0.hook_resid_post': 1.0, 'blocks.1.ho...` | `{'blocks.0.hook_resid_post': 40.0, 'blocks.1.h...` | `{'blocks.0.hook_resid_post': None, 'blocks.1.h...` | |


In [None]:
# Collect the unique model names as a pandas Series
unique_models_series = pd.Series(sae_df['model'].unique())
print("\nUnique Model Names available on SAE_lens:\n", unique_models_series.to_string())


Unique Model Names available on SAE_lens:
 0     deepseek-ai/DeepSeek-R1-Distill-Llama-8B
1                                   gemma-2-2b
2                                   gemma-2-9b
3                                  gemma-2b-it
4                                     gemma-2b
5                         google/gemma-3-1b-pt
6                        google/gemma-3-12b-it
7                        google/gemma-3-12b-pt
8                         google/gemma-3-1b-it
9                       google/gemma-3-270m-it
10                         google/gemma-3-270m
11                       google/gemma-3-27b-it
12                       google/gemma-3-27b-pt
13                        google/gemma-3-4b-it
14                        google/gemma-3-4b-pt
15                                 gemma-2-27b
16                               gemma-2-9b-it
17            meta-llama/Llama-3.1-8B-Instruct
18           meta-llama/Llama-3.3-70B-Instruct
19                          openai/gpt-oss-20b
20              

In [None]:
# Filter to only show entries whose release name contains 'gemma-scope-2'
gemma_saes = sae_df[sae_df['release'].str.contains('gemma-scope-2', case=False, na=False)]
display(gemma_saes[['release','saes_map']].head(40))

| | release | saes_map |
|---|---|---|
| gemma-scope-2-12b-it-att | gemma-scope-2-12b-it-att | `{'layer_12_width_16k_l0_big': 'attn_out/layer_...` |
| gemma-scope-2-12b-it-att-all | gemma-scope-2-12b-it-att-all | `{'layer_0_width_16k_l0_big': 'attn_out_all/lay...` |
| gemma-scope-2-12b-it-mlp | gemma-scope-2-12b-it-mlp | `{'layer_12_width_16k_l0_big': 'mlp_out/layer_1...` |
| gemma-scope-2-12b-it-mlp-all | gemma-scope-2-12b-it-mlp-all | `{'layer_0_width_16k_l0_big': 'mlp_out_all/laye...` |
| gemma-scope-2-12b-it-res | gemma-scope-2-12b-it-res | `{'layer_12_width_16k_l0_big': 'resid_post/laye...` |
| gemma-scope-2-12b-it-res-all | gemma-scope-2-12b-it-res-all | `{'layer_0_width_16k_l0_big': 'resid_post_all/l...` |
| gemma-scope-2-12b-it-transcoders | gemma-scope-2-12b-it-transcoders | `{'transcoder/layer_12_width_16k_l0_big': 'tran...` |
| gemma-scope-2-12b-it-transcoders-all | gemma-scope-2-12b-it-transcoders-all | `{'layer_0_width_16k_l0_big': 'transcoder_all/l...` |
| gemma-scope-2-12b-pt-att | gemma-scope-2-12b-pt-att | `{'layer_12_width_16k_l0_big': 'attn_out/layer_...` |
| gemma-scope-2-12b-pt-att-all | gemma-scope-2-12b-pt-att-all | `{'layer_0_width_16k_l0_big': 'attn_out_all/lay...` |


In [None]:
# Narrow down to the Gemma Scope 2 residual-stream releases for gemma-3-1b-it
gemma_saes = sae_df[sae_df['release'].str.contains('gemma-scope-2-1b-it-res', case=False, na=False)]
display(gemma_saes[['release','saes_map']].head(40))

| | release | saes_map |
|---|---|---|
| gemma-scope-2-1b-it-res | gemma-scope-2-1b-it-res | `{'layer_13_width_16k_l0_big': 'resid_post/laye...` |
| gemma-scope-2-1b-it-res-all | gemma-scope-2-1b-it-res-all | `{'layer_0_width_16k_l0_big': 'resid_post_all/l...` |
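The keys of `saes_map` encode the hook layer, dictionary width, and L0 setting in the id string (e.g. `layer_13_width_16k_l0_big`). The sketch below parses such ids, assuming they all follow the `layer_{L}_width_{W}_l0_{size}` pattern visible in the tables above; `parse_sae_id` and `summarize` are hypothetical helpers, and the example ids are illustrative. With the real directory you would pass `sae_df.loc["gemma-scope-2-1b-it-res", "saes_map"].keys()`.

```python
import re
from collections import defaultdict
from typing import Optional

# Assumed naming scheme; a prefix such as "transcoder/" is tolerated by re.search
SAE_ID_RE = re.compile(r"layer_(\d+)_width_(\w+?)_l0_(\w+)$")


def parse_sae_id(sae_id: str) -> Optional[dict]:
    # Returns None for ids that don't follow the assumed naming scheme
    m = SAE_ID_RE.search(sae_id)
    if m is None:
        return None
    layer, width, l0 = m.groups()
    return {"layer": int(layer), "width": width, "l0": l0}


def summarize(sae_ids) -> dict:
    # Group the available (width, l0) combinations by layer, sorted by layer
    by_layer = defaultdict(list)
    for sae_id in sae_ids:
        parsed = parse_sae_id(sae_id)
        if parsed is not None:
            by_layer[parsed["layer"]].append((parsed["width"], parsed["l0"]))
    return dict(sorted(by_layer.items()))


example_ids = ["layer_13_width_16k_l0_big", "layer_22_width_65k_l0_medium", "layer_22_width_1m_l0_small"]
print(summarize(example_ids))
```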


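### Load a specific SAE
Pick the hook point you want (e.g. a residual-stream SAE). SAE Lens usage example:
https://decoderesearch.github.io/SAELens/latest/usage/

Note: the `gemma-scope-2-1b-it-res` release loaded below was trained on the instruction-tuned `google/gemma-3-1b-it` (see the "Model name" line in its output), whereas the model loaded earlier in this notebook is the base model `google/gemma-3-1b-pt`. SAE features are most trustworthy on the model they were trained on, so if a matching `-pt-` release is listed in the SAE directory, it is likely the better choice.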

In [None]:
# Load a specific SAE (pick the hook point you want, e.g., a residual stream SAE)
# Format for loading SAE :
#   https://decoderesearch.github.io/SAELens/latest/#loading-sparse-autoencoders-from-huggingface
#   https://decoderesearch.github.io/SAELens/latest/usage/
# Available SAEs:
#   https://huggingface.co/models?library=saelens

release = "gemma-scope-2-1b-it-res"
LAYER = 22  # options are {7, 13, 17, 22}
WIDTH = "65k"   # options are {16k, 65k, 262k, 1m}
L0 = "medium"  # options are {small, medium, big}
# Build the id from the settings above so the printed id always matches the loaded SAE
sae_id = f"layer_{LAYER}_width_{WIDTH}_l0_{L0}"

sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    release = release,
    sae_id = sae_id,
    device = "cpu", # Load to CPU first to avoid SafetensorError
)

print("="*3, "Loaded SAE for hook","="*3)
print(f"{release}/{sae_id}")
print(f"Layer: {LAYER}")
print(f"Width: {WIDTH}")
print(f"L0: {L0}")
print("="*25)

# ====== Key SAE Attributes ======
# Model dimensions
print("\n> Model dimensions")
print(f"Input dimension (d_in): {sae.cfg.d_in}")
print(f"SAE dimension (d_sae): {sae.cfg.d_sae}")
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in}")

# Metadata about the SAE
print("\n> SAE metadata")
print(f"Hook name: {sae.cfg.metadata.hook_name}")
print(f"Model name: {sae.cfg.metadata.model_name}")
print(f"Context size: {sae.cfg.metadata.context_size}")

# Hugging Face / NNsight Hook Name (if present)
print("\n> Hugging Face / NNsight Hook Name (if present)")
print(f"Hook name: {sae.cfg.metadata.hf_hook_name}")

# Weights
print("\n> Weight")
print(f"Encoder weights shape: {sae.W_enc.shape}")  # (d_in, d_sae)
print(f"Decoder weights shape: {sae.W_dec.shape}")  # (d_sae, d_in)
print(f"Decoder bias shape: {sae.b_dec.shape}")     # (d_in,)

print()
print("== SAE cfg ==\n")
sae.to(device) # Then move the SAE to the actual device (CUDA in this case)

=== Loaded SAE for hook ===
gemma-scope-2-1b-it-res/layer_22_width_65k_l0_medium
Layer: 22
Width: 65k
L0: medium

> Model dimensions
Input dimension (d_in): 1152
SAE dimension (d_sae): 65536
Expansion factor: 56.888888888888886

> SAE metadata
Hook name: blocks.22.hook_resid_post
Model name: google/gemma-3-1b-it
Context size: 1024

> Hugging Face / NNsight Hook Name (if present)
Hook name: model.layers.22.output

> Weight
Encoder weights shape: torch.Size([1152, 65536])
Decoder weights shape: torch.Size([65536, 1152])
Decoder bias shape: torch.Size([1152])

== SAE cfg ==



JumpReLUSAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)
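The printed shapes explain why a single 65k-wide SAE is a roughly 600 MB object. A back-of-the-envelope count, assuming the JumpReLU SAE stores `W_enc`, `W_dec`, `b_enc`, `b_dec` and a per-feature `threshold`, all in float32 (the parameter list and dtype are assumptions, not read from the config above):

```python
d_in, d_sae = 1152, 65536
bytes_per_param = 4  # float32 (assumed)

# Assumed parameter set of a JumpReLU SAE
params = {
    "W_enc": d_in * d_sae,
    "W_dec": d_sae * d_in,
    "b_enc": d_sae,
    "b_dec": d_in,
    "threshold": d_sae,
}
total_params = sum(params.values())
total_mb = total_params * bytes_per_param / 1e6  # decimal megabytes

print(f"expansion factor: {d_sae / d_in:.2f}")
print(f"parameters: {total_params:,}")
print(f"size: {total_mb:.1f} MB")
```

The two weight matrices account for almost all of the size, so memory grows roughly linearly with the SAE width.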

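### Using SAEs with NNsight
nnsight provides a clean interface for model interventions, and SAEs integrate naturally with its tracing API: save the activations at the SAE's hook point inside `model.trace` (here `model.layers.22.output`, the `hf_hook_name` printed above), then pass them to `sae.encode` outside the trace.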
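Conceptually, `sae.encode` maps a residual-stream vector to sparse feature activations and `sae.decode` maps them back. For a JumpReLU SAE like the one printed above, a simplified version is `pre = x @ W_enc + b_enc`, `acts = pre * (pre > threshold)`, `x_hat = acts @ W_dec + b_dec`. The pure-Python sketch below uses tiny made-up weights; it omits details such as optional input normalisation or subtracting `b_dec` from the input, which depend on the SAE's config, so treat it as an illustration rather than SAE Lens's exact code path.

```python
def matvec(x: list[float], W: list[list[float]]) -> list[float]:
    # x has length n, W has shape (n, m); returns x @ W with length m
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def jumprelu_encode(x, W_enc, b_enc, threshold):
    pre = [p + b for p, b in zip(matvec(x, W_enc), b_enc)]
    # JumpReLU: keep the pre-activation only where it exceeds the per-feature threshold
    return [p if p > th else 0.0 for p, th in zip(pre, threshold)]


def decode(acts, W_dec, b_dec):
    return [r + b for r, b in zip(matvec(acts, W_dec), b_dec)]


# Toy sizes: d_in = 2, d_sae = 3 (illustrative values only)
W_enc = [[1.0, 0.0, 0.5],
         [0.0, 1.0, 0.5]]
b_enc = [0.0, 0.0, -0.2]
threshold = [0.3, 0.3, 0.3]
W_dec = [[1.0, 0.0],
         [0.0, 1.0],
         [0.5, 0.5]]
b_dec = [0.0, 0.0]

x = [0.8, 0.1]
acts = jumprelu_encode(x, W_enc, b_enc, threshold)
print(acts)
print(decode(acts, W_dec, b_dec))
```

Features 1 and 2 have positive pre-activations (0.1 and 0.25) but stay below the threshold, so they are zeroed; this hard cut-off is what keeps the JumpReLU code sparse.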

In [None]:

prompt = "The Eiffel Tower is located in"

# Extract activations and compute SAE features
with model.trace(prompt):
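    # Access the hidden states at the SAE's hook layer (LAYER = 22, i.e. model.layers.22.output).
    # .output[0] is the hidden-state tensor of shape (batch, seq_len, d_model);
    # any further tuple elements depend on the transformers version.
    hidden_states = model.model.layers[LAYER].output[0]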

    # Save the hidden states
    hidden_states_saved = hidden_states.save()

# Get SAE features outside the trace
with torch.no_grad():
    features = sae.encode(hidden_states_saved)

print(f"Feature activations shape: {features.shape}")
# Average number of active features per token, skipping position 0 (the BOS token)
print(f"Average L0: {(features[:, 1:, :] > 0).sum(dim=-1).float().mean().item():.1f}")

Feature activations shape: torch.Size([1, 7, 65536])
Average L0: 979.0
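Average L0 only says how sparse the code is, not how faithful it is. A common complementary check is the fraction of variance explained by the reconstruction; the SAE directory's `expected_var_explained` column presumably reports a figure of this kind. With the real tensors you would compute `x_hat = sae.decode(features)` and compare it with `hidden_states_saved`, typically skipping the BOS position. If the result is poor, or L0 is unexpectedly high, a model/SAE mismatch (e.g. base vs instruction-tuned) is one thing to check. A pure-Python version of the metric on made-up vectors:

```python
def variance_explained(xs: list[list[float]], x_hats: list[list[float]]) -> float:
    """1 - (residual sum of squares) / (total sum of squares), pooled over tokens and dimensions."""
    n, d = len(xs), len(xs[0])
    # Per-dimension mean over tokens
    mean = [sum(x[j] for x in xs) / n for j in range(d)]
    rss = sum((x[j] - xh[j]) ** 2 for x, xh in zip(xs, x_hats) for j in range(d))
    tss = sum((x[j] - mean[j]) ** 2 for x in xs for j in range(d))
    return 1.0 - rss / tss


xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]            # illustrative activations: 3 tokens, d = 2
noisy = [[1.5, 2.0], [3.0, 3.5], [5.0, 0.0]]         # illustrative imperfect reconstruction
print(variance_explained(xs, xs))     # perfect reconstruction
print(variance_explained(xs, noisy))
```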


In [None]:
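# Top 10 features by mean activation across positions (the mean includes the BOS position)
values, indices = features[0].mean(0).topk(10)
for value, index in zip(values, indices):
  print(f"{value:>6.1f}|{index}")

# See the SAE tutorials linked in the references for how to interpret these features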

1726.6|1217
1648.1|1904
1593.5|1497
1435.4|1503
1419.1|908
1284.2|1169
1240.4|121
1090.9|94
1027.1|65093
1012.2|613


In [None]:
import html

feature_idx = 18126  # feature to visualise

str_toks = tokenizer.tokenize(prompt, add_special_tokens=True)
activations = features[0, :, feature_idx].tolist()

def html_activations(str_toks: list[str], activations: list[float]):
  # Shade each token red in proportion to its activation (normalised by the max activation).
  # html.escape keeps tokens such as "<bos>" from being parsed as HTML tags.
  return "".join(
      f'<span style="background-color: rgba(255,0,0,{v}); padding: 4px 0px;">{html.escape(tok)}</span>'
      for tok, v in zip(str_toks, np.array(activations) / (1e-6 + np.max(activations)), strict=True)
  )

display(HTML(html_activations(str_toks, activations)))

# Other Tests

In [None]:
# GPU Memory measurement helper
class GPUMemMeasure:
  def __init__(self):
    self.model_mem_dict = {}
    self.initial_mem = 0

  def begin(self, model_name: str):
    if torch.cuda.is_available():
      # Clear cache for more accurate measurement
      torch.cuda.empty_cache()
      # Initial memory usage
      self.initial_mem = torch.cuda.memory_allocated()
      print(f"Initial GPU memory allocated: {self.initial_mem / (1024**2):.2f} MB")
      self.model_mem_dict[model_name] = self.initial_mem
    else:
      print("CUDA is not available. Cannot measure GPU memory.")

  def after(self, model_name: str):
    if torch.cuda.is_available():
      model_mem_after = torch.cuda.memory_allocated()
      model_allocated_mem = model_mem_after - self.model_mem_dict[model_name]
      self.model_mem_dict[model_name] = model_allocated_mem # Store for later use if needed
      print(f"GPU memory used by {model_name}: {model_allocated_mem / (1024**2):.2f} MB")
    else:
      print("CUDA is not available. Cannot measure GPU memory.")

  def print_mem(self):
    if torch.cuda.is_available():
      total_mem = torch.cuda.memory_allocated()
      print(f"Total GPU memory currently allocated: {total_mem / (1024**2):.2f} MB")
      total_gpu_mem = torch.cuda.get_device_properties(0).total_memory
      print(f"Total GPU memory available on device 0: {total_gpu_mem / (1024**3):.2f} GB")
      reserved_mem = torch.cuda.memory_reserved()
      print(f"Total GPU memory reserved by PyTorch: {reserved_mem / (1024**2):.2f} MB")
      print()
      for model_name, model_mem in self.model_mem_dict.items():
        print(f"GPU memory used by {model_name}: {model_mem / (1024**2):.2f} MB")

    else:
      print("CUDA is not available. Cannot measure GPU memory.")
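The `begin`/`after` pair can also be written as a context manager, so a measurement can never be left unfinished. Below is a sketch using `contextlib.contextmanager`; the memory reader is injected so the logic can be exercised without a GPU (in practice you would pass `torch.cuda.memory_allocated`). `measure_memory` is a hypothetical helper, not part of any library.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator


@contextmanager
def measure_memory(name: str, read_bytes: Callable[[], int], results: Dict[str, int]) -> Iterator[None]:
    # Record the allocated bytes before and after the block and store the difference
    before = read_bytes()
    try:
        yield
    finally:
        results[name] = read_bytes() - before
        print(f"Memory used by {name}: {results[name] / 1024**2:.2f} MB")


# Fake reader standing in for torch.cuda.memory_allocated
fake_allocated = [0]
usage: Dict[str, int] = {}
with measure_memory("toy", lambda: fake_allocated[0], usage):
    fake_allocated[0] += 300 * 1024**2  # pretend 300 MB was allocated inside the block
```

One caveat applies to both versions: the result is a net change, so if the block rebinds a name such as `sae`, the memory released by the previous object is subtracted from the total.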

In [None]:
from nnsight import LanguageModel
from sae_lens import SAE

# Measure GPU Memory usage
mem_measure = GPUMemMeasure()


mem_measure.begin("SAE")
# Load SAE
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
    device="cuda"
)
mem_measure.after("SAE")


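# Load model with nnsight. dispatch=True loads the weights immediately; by default
# nnsight defers loading until the first trace, so the measurement would read ~0 MB.

mem_measure.begin("model")
model = LanguageModel("google/gemma-2-2b", device_map="auto", torch_dtype=t.bfloat16, dispatch=True)
mem_measure.after("model")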

Initial GPU memory allocated: 10270.53 MB
GPU memory used by SAE: 0.00 MB
Initial GPU memory allocated: 10270.53 MB
GPU memory used by model: 0.00 MB


In [None]:

prompt = "The Eiffel Tower is in"

# Extract activations and compute SAE features
with model.trace(prompt):
    # Access hidden states at layer 12
    hidden_states = model.model.layers[12].output[0]

    # Save the hidden states
    hidden_states_saved = hidden_states.save()


# Get SAE features outside the trace
with torch.no_grad():
    features = sae.encode(hidden_states_saved)

print(f"Feature activations shape: {features.shape}")
# Average number of active features per token, skipping position 0 (the BOS token)
print(f"Average L0: {(features[:, 1:, :] > 0).sum(dim=-1).float().mean().item():.1f}")

NameError: name 'model' is not defined

In [None]:
hidden_states_memory_bytes = hidden_states_saved.element_size() * hidden_states_saved.nelement()
print(f"Hidden states memory usage: {hidden_states_memory_bytes / (1024**2):.2f} MB")

Hidden states memory usage: 0.03 MB
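The figure above is simply `element_size() × nelement()`. A quick check of the arithmetic, using an illustrative shape of `(1, 7, 2304)` in bfloat16 (2 bytes per element); the token count and dtype are assumptions for illustration, and 2304 is the hidden size of `google/gemma-2-2b`. `tensor_megabytes` is a hypothetical helper.

```python
from math import prod


def tensor_megabytes(shape: tuple, bytes_per_element: int) -> float:
    # Same quantity as tensor.element_size() * tensor.nelement(), divided by 1024**2 as above
    return prod(shape) * bytes_per_element / 1024**2


# Illustrative: batch 1, 7 tokens, d_model 2304, bfloat16
print(f"{tensor_megabytes((1, 7, 2304), 2):.2f} MB")
# The same layer's activations for a 1024-token context
print(f"{tensor_megabytes((1, 1024, 2304), 2):.2f} MB")
```

Saving a single layer's activations is cheap; the cost grows linearly with sequence length, batch size and the number of layers saved.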
