In [22]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import interp_tools.model_utils as model_utils
import interp_tools.saes.batch_topk_sae as batch_topk_sae

Very important note about my Qwen3 SAEs - when training and using these, I filtered out activations that were 10x the median, as I found that Qwen models have random attention sinks 100's of tokens into the sequence (not just on BOS). This happened on around 0.1% of tokens, with the norms being 100-1000x the median.

I'm not sure if this was the best decision, but I empirically found that it improved my final MSE by 3%. I believe it also caused some dead features, but I don't remember for sure.  Refer to last cell for details on how I did it.

This is quite an annoying detail to deal with, as these now must be dealt with whenever using the SAEs. I considered just training on these high activation tokens, but I feel like they should be dealt with no matter what. For example, if I'm calculating attribution scores, features which activate on these tokens may dominate the attribution scores on certain prompts. Maybe we also want to deal with the high activation tokens when computing max activating examples.

The dictionary_learning implementation of the filtering is here: https://github.com/saprmarks/dictionary_learning/blob/main/dictionary_learning/pytorch_buffer.py#L220

And the PR with some more details is here: https://github.com/saprmarks/dictionary_learning/pull/52


In [23]:
model_name = "Qwen/Qwen3-8B"
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=dtype)

Loading checkpoint shards: 100%|██████████| 5/5 [00:03<00:00,  1.43it/s]


In [24]:
layer = 18

repo_id = "adamkarvonen/qwen3-8b-saes"

# trainer_2 is L0 80, width 65k
filename = f"saes_Qwen_Qwen3-8B_batch_top_k/resid_post_layer_{layer}/trainer_2/ae.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sae = batch_topk_sae.load_dictionary_learning_batch_topk_sae(
    repo_id=repo_id, filename=filename, model_name=model_name, device=device, dtype=dtype
)


Original keys in state_dict: dict_keys(['b_dec', 'k', 'threshold', 'decoder.weight', 'encoder.weight', 'encoder.bias'])
Renamed keys in state_dict: dict_keys(['b_dec', 'k', 'threshold', 'W_dec', 'W_enc', 'b_enc'])


In [25]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

test_input = "The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science"

input = tokenizer(test_input, return_tensors="pt", add_special_tokens=True).to(device)

In [26]:
print(input['input_ids'].shape)

torch.Size([1, 31])


In [27]:
submodule = model_utils.get_submodule(model, layer)

acts_BLD = model_utils.collect_activations(model, submodule, input)
print(acts_BLD.shape)

torch.Size([1, 31, 4096])


In [28]:
encoded_acts_BLF = sae.encode(acts_BLD)
print(encoded_acts_BLF.shape)

decoded_acts_BLD = sae.decode(encoded_acts_BLF)
print(decoded_acts_BLD.shape)

l0_BL = (encoded_acts_BLF > 0).sum(dim=-1)
print(l0_BL[0, :10], "As we can see, the L0 norm is very high for the first BOS token")

norms_BL = acts_BLD.norm(dim=-1)
print(norms_BL[0, :10], "This is because the norms are very high for the first BOS token")

median_norm = norms_BL.median()
norm_mask_BL = norms_BL > (median_norm * 10)


encoded_acts_BLF = encoded_acts_BLF * ~norm_mask_BL[:, :, None]


l0_BL = (encoded_acts_BLF > 0).sum(dim=-1)
print(l0_BL[0, :10], "You will have to decide how to deal with this. In this case, I'll just zero out the encoded acts")

mean_l0 = l0_BL[:, 1:].float().mean()
print(f"mean l0: {mean_l0.item()}")

print("When calculating variance explained, we'll just ignore the BOS token activations for now. In real sequences, these high norm tokens should be filtered out entirely from the MSE calculation.")

total_variance = torch.var(acts_BLD[:, 1:], dim=1).sum()
residual_variance = torch.var(acts_BLD[:, 1:] - decoded_acts_BLD[:, 1:], dim=1).sum()
frac_variance_explained = (1 - residual_variance / total_variance)
print(f"frac_variance_explained: {frac_variance_explained.item()}")

torch.Size([1, 31, 65536])
torch.Size([1, 31, 4096])
tensor([27361,    69,    54,    88,    90,    91,    79,    84,    82,   123],
       device='cuda:0') As we can see, the L0 norm is very high for the first BOS token
tensor([9216.0000,  108.5000,   94.0000,  101.0000,  102.5000,   90.5000,
          87.0000,   91.5000,   99.5000,   99.0000], device='cuda:0',
       dtype=torch.bfloat16) This is because the norms are very high for the first BOS token
tensor([  0,  69,  54,  88,  90,  91,  79,  84,  82, 123], device='cuda:0') You will have to decide how to deal with this. In this case, I'll just zero out the encoded acts
mean l0: 100.9000015258789
When calculating variance explained, we'll just ignore the BOS token activations for now. In real sequences, these high norm tokens should be filtered out entirely from the MSE calculation.
frac_variance_explained: 0.66015625
