In [13]:
import torch

import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from typing import Callable, Any
import einops
class TopK(nn.Module):
    """Top-k sparsifying activation.

    Keeps the ``k`` largest entries along the last dimension of the input,
    applies ``postact_fn`` to those entries and sets every other position to 0.
    """

    def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
        """
        Args:
            k: Number of entries kept along the last dimension.
            postact_fn: Callable applied to the kept values (a ReLU by default).
        """
        super().__init__()
        self.k = k
        self.postact_fn = postact_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a tensor shaped like ``x`` that holds only its activated top-k values."""
        topk = torch.topk(x, k=self.k, dim=-1)
        values = self.postact_fn(topk.values)
        # Write the kept values back at their original positions; the rest stay 0.
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, values)
        return result

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the module state extended with ``k`` and the post-activation class name.

        The two extra entries are stored under ``prefix + "k"`` and
        ``prefix + "postact_fn"``; they are a plain int and a str, not tensors.
        """
        # Forward by keyword: handing these to nn.Module.state_dict positionally
        # is deprecated and emits a warning in recent PyTorch releases.
        state_dict = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        state_dict.update({prefix + "k": self.k, prefix + "postact_fn": self.postact_fn.__class__.__name__})
        return state_dict

    @classmethod
    def from_state_dict(cls, state_dict: dict[str, Any], strict: bool = True) -> "TopK":
        """Rebuild a ``TopK`` from the un-prefixed ``"k"`` and ``"postact_fn"`` entries.

        The activation class is looked up by name in ``ACTIVATIONS_CLASSES`` and
        instantiated without arguments. ``strict`` is accepted but not used.

        Raises:
            KeyError: If either entry is missing or the activation name is unknown.
        """
        k = state_dict["k"]
        postact_fn = ACTIVATIONS_CLASSES[state_dict["postact_fn"]]()
        return cls(k=k, postact_fn=postact_fn)
# Lookup table from an activation's class name (the string TopK.state_dict stores
# under "postact_fn") to the class itself. TopK.from_state_dict instantiates the
# looked-up class with no arguments.
ACTIVATIONS_CLASSES = {
    activation_cls.__name__: activation_cls
    for activation_cls in (nn.ReLU, nn.Identity, TopK)
}
class SAE(nn.Module):
  def __init__(self, batch_size: int,input_dim: int, expansion_factor: float = 8, device: str = 'cuda'): # Reorder arguments
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = 16384
        self.dtype=torch.float32
        self.W_dec = nn.Parameter(
            torch.nn.init.normal_(
                torch.empty(
                    self.latent_dim, input_dim, dtype=self.dtype
                )
            )
        )
        self.dtype=torch.float32
        #nn.init.kaiming_uniform_(self.decoder)
        self.W_enc = nn.Parameter(
            torch.empty( self.input_dim, self.latent_dim, dtype=self.dtype)  
        )
        self.W_dec.data = (
            self.W_dec.data / self.W_dec.data.norm(dim=-1, keepdim=True) * .08
        )
        self.W_enc.data = einops.rearrange(
            self.W_dec.data.clone(),
            "d_hidden d_model ->  d_model d_hidden",
        )
        self.b_enc = nn.Parameter(torch.zeros(self.latent_dim, dtype=self.dtype))
        self.b_dec = nn.Parameter(
            torch.zeros((self.input_dim), dtype=self.dtype)
        )
        self.batch_size=batch_size
        self.device=device
        self.l1_coefficient=3.2e-5
  def encode(self, x: torch.Tensor,k) -> torch.Tensor:
        
            
      
      topk=TopK(k=k)
      return topk(x@self.W_enc+self.b_enc)
  def decode(self,encoded: torch.Tensor)-> torch.Tensor:
        return encoded@self.W_dec+self.b_dec
  @torch.autocast(
        "cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_bf16_supported() #speeds up forward by "2x"
    )
  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    encoded=self.encode(x,128)
    final=self.decode(encoded)
    reconstruction_error_BD = (final - x).pow(2)
    reconstruction_error_B = einops.reduce(reconstruction_error_BD, 'B D -> B', 'sum')
    l2_loss = reconstruction_error_B.mean()

    
    loss = l2_loss

    nonzeros_per_sample = encoded.count_nonzero(dim=1)       

    
    avg_nonzeros = nonzeros_per_sample.float().mean() 
    return final,encoded,loss,avg_nonzeros

In [2]:
pip install sae-vis

Defaulting to user installation because normal site-packages is not writeable
Collecting sae-vis
  Downloading sae_vis-0.3.6-py3-none-any.whl.metadata (5.1 kB)
Collecting datasets<3.0.0,>=2.0.0 (from sae-vis)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting einops<0.8.0,>=0.7.0 (from sae-vis)
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Collecting jaxtyping<0.3.0,>=0.2.28 (from sae-vis)
  Downloading jaxtyping-0.2.38-py3-none-any.whl.metadata (6.6 kB)
Downloading sae_vis-0.3.6-py3-none-any.whl (10.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.6/10.6 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hDownloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m75.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading einops-0.7.0-py3-none-any.whl (44 kB)
Downloading jaxtyping-0.2.38-py3-none-any.whl (56 kB)
Installing collected p

In [16]:
# Checkpoint files for the SAEs of the chat (-it) and base Gemma-2-2B models.
CHAT_SAE_PATH = "gemma-2-2b-it-baseline_sae_layer_14_pre_resid-16384.pth"
BASE_SAE_PATH = "gemma-2-2b-baseline_sae_layer_14_pre_resid-16384.pth"

# weights_only=True restricts unpickling to tensors and plain containers, which
# avoids executing arbitrary pickled code and silences torch.load's warning.
chat_dict = torch.load(CHAT_SAE_PATH, weights_only=True)
base_dict = torch.load(BASE_SAE_PATH, weights_only=True)

# SAE(batch_size, input_dim); the module is moved to the GPU before loading.
chat_sae = SAE(8, 2304).to('cuda')
base_sae = SAE(8, 2304).to('cuda')
chat_sae.load_state_dict(chat_dict)
base_sae.load_state_dict(base_dict)

  chat_dict=torch.load("gemma-2-2b-it-baseline_sae_layer_14_pre_resid-16384.pth")
  base_dict=torch.load("gemma-2-2b-baseline_sae_layer_14_pre_resid-16384.pth")


<All keys matched successfully>

In [20]:
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# Base and instruction-tuned Gemma-2-2B as HookedTransformers, both on the GPU.
base_model = HookedTransformer.from_pretrained("google/gemma-2-2b").cuda()
chat_model = HookedTransformer.from_pretrained("google/gemma-2-2b-it").cuda()

# Hugging Face tokenizer of the instruction-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b into HookedTransformer
Moving model to device:  cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer
Moving model to device:  cuda


In [50]:
import os

from datasets import load_dataset

# Never hardcode the access token: read it from the HF_TOKEN environment
# variable. When it is unset, None is passed and the locally stored
# `huggingface-cli login` credentials are used instead.
ds2 = load_dataset("lmsys/lmsys-chat-1m", token=os.environ.get("HF_TOKEN"))['train']

# Second message of conversation 120, used as the probe text.
sample_text = ds2[120]['conversation'][1]['content']
print(sample_text)

# Run the chat model once, caching only the layer-14 pre-residual activations.
_, cache = chat_model.run_with_cache(
    sample_text,
    names_filter="blocks.14.hook_resid_pre",
    return_type=None,
)
acts = cache["blocks.14.hook_resid_pre"]
print(acts.shape)


Using the latest cached version of the dataset since lmsys/lmsys-chat-1m couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/user/.cache/huggingface/datasets/lmsys___lmsys-chat-1m/default/0.0.0/200748d9d3cddcc9d782887541057aca0b18c5da (last modified on Tue May 20 00:39:57 2025).


Hello! How can I assist you today?
torch.Size([1, 10, 2304])


In [52]:
# Activations of the single prompt in the batch, one row per token. Plain
# indexing (no squeeze) keeps this 2-D even when the prompt is one token long.
new = acts[0]
size = new.shape[0]
inputs = tokenizer(ds2[120]['conversation'][1]['content'], return_tensors="pt", truncation=True, padding=True).to('cuda')
token_ids = inputs['input_ids'][0]
# The labels below are only meaningful if the Hugging Face tokenizer splits the
# text exactly like the model that produced `acts`.
assert token_ids.shape[0] == size, "tokenizer output and cached activations differ in length"

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for i in range(size):
        # The SAE is fed one token at a time as a (1, d_model) batch.
        _, encoded, _, _ = chat_sae(new[i:i + 1])
        indices = torch.topk(encoded, 10).indices

        print(i)
        print(indices)
        print(tokenizer.decode(token_ids[i]))

0
tensor([[11389,  9022, 10867,  9285,  4063, 11498, 10344,  6440,  5050, 15557]],
       device='cuda:0')
<bos>
1
tensor([[11389,  9022, 10867,  9285,  4063,    22, 15719, 11498,  9873,  7449]],
       device='cuda:0')
Hello
2
tensor([[ 8662,  4055,  9505,  7449, 11389,  9873, 11947,  7271,  8044, 15313]],
       device='cuda:0')
!
3
tensor([[ 1138,  9873, 11389,  7449,  9505,  8662, 13715,  9022, 10867, 15313]],
       device='cuda:0')
 How
4
tensor([[ 8807, 11389,  7449,  9505,  9873,  8662,  9022, 15313,  7904,   387]],
       device='cuda:0')
 can
5
tensor([[ 9505,  7449, 11389,  8662,  7809, 13715,  9873,  9285,   387, 10867]],
       device='cuda:0')
 I
6
tensor([[ 7257,  8662, 14637, 11389,  7449,  7904,  9873,  9505,  9022, 10867]],
       device='cuda:0')
 assist
7
tensor([[ 8662,  9505,  8448, 11389,  7449,  7904, 13715,  9285,  9873,  9022]],
       device='cuda:0')
 you
8
tensor([[ 5213,  8662, 11389,  9505, 12006,  7449,  7904, 11856,  9022,  9873]],
       device='cuda:0

In [7]:
pip install sae_lens

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Defaulting to user installation because normal site-packages is not writeable
Collecting sae_lens
  Downloading sae_lens-5.10.3-py3-none-any.whl.metadata (5.3 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae_lens)
  Downloading automated_interpretability-0.0.9-py3-none-any.whl.metadata (822 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae_lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae_lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae_lens)
  Downloading pytest_profiling-1.8.1-py3-none-any.whl.metadata (15 kB)
Collecting python-dotenv<2.0.0,>=1.0.1 (from sae_lens)
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting pyzmq==26.0.0 (from sae_lens)
  Downloading pyzmq-26.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.2 kB)
Collecting safetensors<0.5.0,>=0.4.2 (from sae_lens)
  Downloading safete

In [54]:
from datasets import load_dataset
# Aliased so the sae_lens class does not shadow the notebook's own `SAE`
# nn.Module; otherwise re-running the checkpoint-loading cell would break.
from sae_lens import SAE as SaeLensSAE, HookedSAETransformer, SAEConfig
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

test_feature_idx = [11328]#[147, 507, 963, 994, 1383, 2026, 3738, 3982, 5044, 6310, 6348, 6518, 6592, 6918, 6983, 7079, 7748, 8081, 8489, 8752, 9034, 9134, 9291, 9418, 10335, 10615, 11379, 13708, 14643, 16379]
sae_vis_config = SaeVisConfig(
    features = 7257,
    minibatch_size_tokens=8,
    minibatch_size_features=32,
)

# weights_only=True restricts unpickling to tensors and plain containers, which
# avoids executing arbitrary pickled code and silences torch.load's warning.
tokens = torch.load('datasets/lmsys_1m_tokens.pt', weights_only=True)

model = HookedSAETransformer.from_pretrained("google/gemma-2-2b-it")
cfg = SAEConfig(
            architecture="standard",
            # forward pass details.
            d_in=2304,
            d_sae=16384,
            activation_fn_str="topk",
            activation_fn_kwargs={"k": 128},
            apply_b_dec_to_input=False,
            finetuning_scaling_factor=False,
            # dataset it was trained on details.
            context_size=1024,
            model_name='google/gemma-2-2b-it',
            hook_name='blocks.14.hook_resid_pre',
            hook_layer=14,
            hook_head_index=None,
            prepend_bos=True,
            dataset_path='lmsys/lmsys-chat-1m',
            dataset_trust_remote_code=False,
            normalize_activations="None",
            # misc
            sae_lens_training_version=None,
            dtype="float32",
            device='cuda',
        )
# Wrap the chat-model SAE weights (loaded earlier as `chat_dict`) in a
# sae_lens SAE so that sae_vis can consume them.
sae = SaeLensSAE(cfg)
sae.load_state_dict(chat_dict)

sae_vis_data = SaeVisData.create(
    sae=sae,
    model=model,
    tokens=tokens[:4096],  # 8192
    cfg=sae_vis_config  # 256
)

  tokens = torch.load('datasets/lmsys_1m_tokens.pt')


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [55]:
# Write the feature-centric dashboard for feature 7257 to an HTML file.
sae_vis_data.save_feature_centric_vis(
    filename="demo_feature_vis.html",
    feature=7257,
)