# 1️⃣ Introduction

> ##### Objectives
>
> * Load a model using the `nnsight` library
> * Learn some basics of HuggingFace models (e.g. tokenization, model output)
> * Use it to extract & visualise Gemma 3 1B's internal activations
> * Load the SAE corresponding to the model

### Reference:  
Tutorials:  
> [Gemma Scope 2](https://colab.research.google.com/drive/1NhWjg7n0nhfW--CjtsOdw5A5J_-Bzn4r#scrollTo=nOBcV4om7mrT)

> [SAE_Lens](https://decoderesearch.github.io/SAELens/latest/usage/#using-saes-without-transformerlens)  
https://colab.research.google.com/drive/1RMOvARSFvyqig8yHdsT7lfRQmXOpFlE4#scrollTo=yfDUxRx0wSRl

> [ARENA](https://arena-chapter1-transformer-interp.streamlit.app/)

# Set-Up

In [22]:
try:
    import google.colab  # type: ignore
    from google.colab import output

    IS_COLAB = True
    %pip install sae-lens transformer-lens sae-dashboard datasets
except ImportError:
    IS_COLAB = False
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

# import circuitsvis as cv
import einops
import torch
import torch as t
import torch.nn as nn
from IPython.display import display, HTML
from jaxtyping import Float

from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from safetensors.torch import load_file

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer


# Inference only: disable gradient tracking to save memory
t.set_grad_enabled(False)

# Silence all logging output (this hides nnsight's info messages, along with everything else)
logging.disable(sys.maxsize)


# Check device
if t.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if t.cuda.is_available() else "cpu"

print(f"Torch Device: {device}")




Torch Device: cuda


In [24]:

# =========================================
# NNSight and HuggingFace API Configuration
# =========================================

from nnsight import CONFIG

# If you have an API key & want to work remotely, 
# then set REMOTE = True, if not, then leave REMOTE = False.

# ===========
REMOTE = True
# ===========


if REMOTE and not IS_COLAB:
    from dotenv import load_dotenv

    if not load_dotenv():
        raise ValueError(
            "REMOTE is set to True but no .env file was found.\n"
            "Please create a .env file or set the API keys in the code."
        )

    # == Set API keys from the .env file ==
    print("> Running on local device. API loaded from .env file")

    # > NNsight (NDIF) API key
    CONFIG.set_default_api_key(os.getenv("NDIF_API_KEY"))
    # > HF token: load_dotenv() has already copied it into os.environ,
    #   so only check that it is actually present
    if not os.getenv("HF_TOKEN"):
        raise ValueError("HF_TOKEN is missing from the .env file.")

if IS_COLAB and REMOTE:
    print("> Running on Colab")
    from google.colab import userdata

    print("API loaded from Colab userdata")
    # > NNsight (NDIF) API key
    CONFIG.set_default_api_key(userdata.get("NDIF_API_KEY"))
    # > HF token from Colab userdata
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

> Running on local device. API loaded from .env file
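
The cell above reads two secrets, `NDIF_API_KEY` and `HF_TOKEN`. One way to fail early with a readable message when either is missing is a small lookup helper. This is only a sketch: `require_env` and `fake_env` are hypothetical names introduced here, not part of `nnsight` or `python-dotenv`, and the key values are made up.

```python
import os
from typing import Mapping


def require_env(name: str, env: Mapping[str, str] = os.environ) -> str:
    """Return env[name], or raise a readable error if it is missing or empty."""
    value = env.get(name)
    if not value:
        raise ValueError(f"{name} is not set; add it to your .env file or Colab secrets.")
    return value


# Toy mapping standing in for the real environment (values are made up)
fake_env = {"NDIF_API_KEY": "ndif-example", "HF_TOKEN": ""}

print(require_env("NDIF_API_KEY", fake_env))

try:
    require_env("HF_TOKEN", fake_env)
except ValueError as err:
    print(f"Caught: {err}")
```

In the notebook itself the helper would be called without the second argument, e.g. `require_env("HF_TOKEN")`, once `load_dotenv()` has populated `os.environ`.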


# 2️⃣ Load Model Using NNSIGHT
Since `nnsight` models are extensions of HuggingFace models, some of this information (e.g. tokenization) applies to plain HuggingFace models as well, while some of it (e.g. forward passes) is specific to `nnsight` and would work differently with a standard HuggingFace model. Keep this distinction in mind, otherwise the syntax can get confusing!  

1. Tutorial: [NNSIGHT Main Page](https://nnsight.net/)
2. Page: [NNSIGHT Remote Status](https://nnsight.net/status/)

### Model config

Each model comes with a `model.config`, which contains lots of useful information about the model (e.g. number of heads and layers, size of hidden layers, etc.). Run the code below to load the model, print its config, and define `model` and `tokenizer` for later.

In [32]:
from nnsight import LanguageModel

# Load model
model = LanguageModel("google/gemma-3-1b-pt", device_map="auto", torch_dtype=t.bfloat16)
tokenizer = model.tokenizer

# Print model data
from tabulate import tabulate
print(
    tabulate(
        [(k, str(v)) for k, v in model.config.to_dict().items()],
        headers=["Config Key", "Value"],
        tablefmt="github",
    )
)

| Config Key | Value |
|------------|-------|
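
Once the config is available, the few dimensions that interpretability code uses most often can be pulled into short names. The sketch below assumes the config exposes the HuggingFace-style fields `num_hidden_layers`, `hidden_size`, `num_attention_heads`, `head_dim` and `vocab_size`; check the printed table for the exact keys. So that it runs without loading the model, a `SimpleNamespace` stands in for `model.config`. Its `hidden_size` and `vocab_size` match the tensor shapes printed later in this notebook, while the other values are illustrative placeholders. `summarize_config` is a hypothetical helper name.

```python
from types import SimpleNamespace


def summarize_config(config) -> dict[str, int]:
    """Collect commonly used model dimensions under short names."""
    n_heads = config.num_attention_heads
    return {
        "n_layers": config.num_hidden_layers,
        "d_model": config.hidden_size,
        "n_heads": n_heads,
        # Fall back to d_model // n_heads if the config has no explicit head_dim
        "d_head": getattr(config, "head_dim", config.hidden_size // n_heads),
        "d_vocab": config.vocab_size,
    }


# Stand-in for `model.config` (assumed field names, partly placeholder values)
toy_config = SimpleNamespace(
    num_hidden_layers=26,    # placeholder
    hidden_size=1152,        # matches d_model in the residual stream shape
    num_attention_heads=4,   # placeholder
    head_dim=256,            # placeholder
    vocab_size=262144,       # matches the logits shape
)

dims = summarize_config(toy_config)
for name, value in dims.items():
    print(f"{name:>8} = {value}")
```

With the real model loaded, `summarize_config(model.config)` should work the same way, provided the field names match.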

In [None]:

prompt = "The Eiffel Tower is in the city of"

# Gemma 3: run locally (remote=False) rather than on NDIF
with model.trace(prompt, remote=False):
    # Save the final layer's hidden states.
    # Gemma exposes its blocks at model.model.layers (GPT-2-style models use model.transformer.h).
    hidden_states = model.model.layers[-1].output[0].save()  # (batch_size, seq_len, d_model)

    # Save the logits at the last sequence position of the first batch item.
    # lm_head.output.shape = (batch_size, seq_len, vocab_size)
    logits = model.lm_head.output[0, -1].save()

# Get the model's logit output, and its next token prediction
print(f"logits.shape = {logits.shape} = (vocab_size,)")
print("Predicted token ID =", predicted_token_id := logits.argmax().item())
print(f"Predicted token = {tokenizer.decode(predicted_token_id)!r}")

# Print the shape of the model's residual stream
print(f"\nresid.shape = {hidden_states.shape} = (batch_size, seq_len, d_model)")

logits.shape = torch.Size([262144]) = (vocab_size,)
Predicted token ID = 9079
Predicted token = ' Paris'

resid.shape = torch.Size([1, 9, 1152]) = (batch_size, seq_len, d_model)
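
The predicted token above is simply the argmax of the final-position logits. To also see the runner-up candidates, the logits can be turned into probabilities with a softmax and the top few entries kept. The sketch below does this in plain Python on a toy list so that it runs without the model; `top_k_probs` and `toy_logits` are hypothetical names, and the numbers are made up.

```python
import math


def top_k_probs(logits: list[float], k: int = 3) -> list[tuple[int, float]]:
    """Return the k most likely (index, probability) pairs under a softmax."""
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Toy logits over a 4-token vocabulary (made-up values)
toy_logits = [2.0, 0.5, -1.0, 3.0]

top = top_k_probs(toy_logits, k=2)
for token_id, prob in top:
    print(f"token {token_id}: p = {prob:.3f}")
```

On the real tensor, the same idea could be written as `logits.float().softmax(dim=-1).topk(k)`, with each returned index passed to `tokenizer.decode`.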


# 3️⃣ Load SAE Lens

### Load SAE utilities

In [27]:
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
from tabulate import tabulate

metadata_rows = [
    [data.model, data.release, data.repo_id, len(data.saes_map)]
    for data in get_pretrained_saes_directory().values()
]

# Print all SAE releases, sorted by base model
print(
    tabulate(
        sorted(metadata_rows, key=lambda x: x[0]),
        headers=["model", "release", "repo_id", "n_saes"],
        tablefmt="simple_outline",
    )
)

┌──────────────────────────────────────────┬─────────────────────────────────────────────────────┬───────────────────────────────────────────────────────────┬──────────┐
│ model                                    │ release                                             │ repo_id                                                   │   n_saes │
├──────────────────────────────────────────┼─────────────────────────────────────────────────────┼───────────────────────────────────────────────────────────┼──────────┤
│ Qwen/Qwen2.5-7B-Instruct                 │ qwen2.5-7b-instruct-andyrdt                         │ andyrdt/saes-qwen2.5-7b-instruct                          │        7 │
│ deepseek-ai/DeepSeek-R1-Distill-Llama-8B │ llama_scope_r1_distill                              │ fnlp/Llama-Scope-R1-Distill                               │       96 │
│ deepseek-ai/DeepSeek-R1-Distill-Llama-8B │ deepseek-r1-distill-llama-8b-qresearch              │ qresearch/DeepSeek-R1-Distill-Llama-8B-SAE-l19     

### Different SAEs in model
Any given SAE release may contain multiple different SAEs. These might have been trained on different hookpoints or layers in the model, or with different hyperparameters. You can see the data associated with each release as follows:

In [28]:
def format_value(value):
    return (
        "{{{0!r}: {1!r}, ...}}".format(*next(iter(value.items())))
        if isinstance(value, dict)
        else repr(value)
    )


release = get_pretrained_saes_directory()["gemma-scope-2-1b-it-res"]

print(
    tabulate(
        [[k, format_value(v)] for k, v in release.__dict__.items()],
        headers=["Field", "Value"],
        tablefmt="simple_outline",
    )
)

┌────────────────────────┬────────────────────────────────────────────────────────────────────────────┐
│ Field                  │ Value                                                                      │
├────────────────────────┼────────────────────────────────────────────────────────────────────────────┤
│ release                │ 'gemma-scope-2-1b-it-res'                                                  │
│ repo_id                │ 'google/gemma-scope-2-1b-it'                                               │
│ model                  │ 'google/gemma-3-1b-it'                                                     │
│ conversion_func        │ 'gemma_3'                                                                  │
│ saes_map               │ {'layer_13_width_16k_l0_big': 'resid_post/layer_13_width_16k_l0_big', ...} │
│ expected_var_explained │ {'layer_13_width_16k_l0_big': 1.0, ...}                                    │
│ expected_l0            │ {'layer_13_width_16k_l0_big': 150, ..
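
The keys of `saes_map` shown above appear to encode the layer, the dictionary width and an L0 label in the id itself (e.g. `layer_13_width_16k_l0_big`). Assuming that naming pattern holds across the release, which is inferred only from the ids visible in this notebook, a small parser makes it easy to filter SAEs by layer or width. `SaeId` and `parse_sae_id` are hypothetical names, not part of `sae_lens`.

```python
import re
from typing import NamedTuple, Optional


class SaeId(NamedTuple):
    layer: int
    width: str  # e.g. "16k" or "1m"
    l0: str     # e.g. "small", "medium", "big"


# Assumed pattern: layer_<int>_width_<width>_l0_<label>
_SAE_ID = re.compile(r"layer_(?P<layer>\d+)_width_(?P<width>\w+?)_l0_(?P<l0>\w+)")


def parse_sae_id(sae_id: str) -> Optional[SaeId]:
    """Split an SAE id into its parts, or return None if it does not match."""
    match = _SAE_ID.fullmatch(sae_id)
    if match is None:
        return None
    return SaeId(int(match["layer"]), match["width"], match["l0"])


# Ids copied from the release shown in this notebook
sae_ids = [
    "layer_13_width_16k_l0_big",
    "layer_13_width_16k_l0_medium",
    "layer_13_width_16k_l0_small",
    "layer_13_width_1m_l0_big",
]

# Keep only the widest dictionaries
wide = [s for s in sae_ids if (p := parse_sae_id(s)) and p.width == "1m"]
print(wide)
```

In the notebook, `sae_ids` would be replaced by `release.saes_map.keys()`.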

Let's get some more info about each of the SAEs in this release. We can print out the SAE id, the path (i.e. the location in the HuggingFace repo that points to the SAE model weights) and the Neuronpedia ID (which is how we'll get feature dashboards - more on this soon).

In [29]:
# Use `sae_id` rather than `id` to avoid shadowing the Python builtin
data = [[sae_id, path, release.neuronpedia_id[sae_id]] for sae_id, path in release.saes_map.items()]

print(
    tabulate(
        data,
        headers=["SAE id", "SAE path (HuggingFace)", "Neuronpedia ID"],
        tablefmt="simple_outline",
    )
)

┌──────────────────────────────────────┬─────────────────────────────────────────────────┬────────────────────────────────────────┐
│ SAE id                               │ SAE path (HuggingFace)                          │ Neuronpedia ID                         │
├──────────────────────────────────────┼─────────────────────────────────────────────────┼────────────────────────────────────────┤
│ layer_13_width_16k_l0_big            │ resid_post/layer_13_width_16k_l0_big            │                                        │
│ layer_13_width_16k_l0_medium         │ resid_post/layer_13_width_16k_l0_medium         │ gemma-3-1b-it/13-gemmascope-2-res-16k  │
│ layer_13_width_16k_l0_small          │ resid_post/layer_13_width_16k_l0_small          │                                        │
│ layer_13_width_1m_l0_big             │ resid_post/layer_13_width_1m_l0_big             │                                        │
│ layer_13_width_1m_l0_medium          │ resid_post/layer_13_width_1m_l0_med