In [46]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

import interp_tools.saes.jumprelu_sae as jumprelu_sae
import interp_tools.model_utils as model_utils

In [47]:
model_name = "google/gemma-2-2b"
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=dtype)

Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.69s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


In [48]:
layer = 20

repo_id = "google/gemma-scope-2b-pt-res"
filename = f"layer_{layer}/width_16k/average_l0_71/params.npz"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "google/gemma-2-2b"

sae = jumprelu_sae.load_gemma_scope_jumprelu_sae(repo_id, filename, layer, model_name, device, dtype)


In [53]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

test_input = "The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science"

input = tokenizer(test_input, return_tensors="pt", add_special_tokens=True).to(device)

In [54]:
print(input['input_ids'].shape)

torch.Size([1, 32])


In [55]:
submodule = model_utils.get_submodule(model, layer)

acts_BLD = model_utils.collect_activations(model, submodule, input)
print(acts_BLD.shape)

encoded_acts_BLF = sae.encode(acts_BLD)
print(encoded_acts_BLF.shape)

decoded_acts_BLD = sae.decode(encoded_acts_BLF)
print(decoded_acts_BLD.shape)

torch.Size([1, 32, 2304])
torch.Size([1, 32, 16384])
torch.Size([1, 32, 2304])


In [None]:
l0_BL = (encoded_acts_BLF > 0).sum(dim=-1)
print(l0_BL[0, :10], "As we can see, the L0 norm is very high for the first BOS token, so we'll skip it.")

mean_l0 = l0_BL[:, 1:].float().mean()
print(f"mean l0: {mean_l0.item()}")

total_variance = torch.var(acts_BLD[:, 1:], dim=1).sum()
residual_variance = torch.var(acts_BLD[:, 1:] - decoded_acts_BLD[:, 1:], dim=1).sum()
frac_variance_explained = (1 - residual_variance / total_variance)
print(f"frac_variance_explained: {frac_variance_explained.item()}")

tensor([7019,   29,   97,   68,   68,   96,   69,   59,   65,   80],
       device='cuda:0') As we can see, the L0 norm is very high for the first BOS token, so we'll skip it.
mean l0: 73.61289978027344
frac_variance_explained: 0.7421875
