In [3]:
import torch
import sparse_autoencoder

import blobfile as bf
import transformer_lens

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
# Load GPT-2 small through transformer_lens so intermediate activations can be read via hooks.
# NOTE(review): center_writing_weights=False presumably keeps the residual stream unmodified
# for the pretrained autoencoders used later — confirm.
model_ht = transformer_lens.HookedTransformer.from_pretrained(
    "gpt2",
    center_writing_weights=False,
)

# List every entry of the state dict together with its shape.
sd_ht = model_ht.state_dict()
for param_name, param in sd_ht.items():
    print(param_name, param.shape)

# Show the module tree, then the first parameter tensor as a quick spot check.
print(model_ht)
first_param = next(model_ht.parameters())
print(first_param)

Loaded pretrained model gpt2 into HookedTransformer
embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.attn.mask torch.Size([1024, 1024])
blocks.0.attn.IGNORE torch.Size([])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
blocks.1.attn.W_Q torch.Size([12, 768, 64])
blocks.1.attn.W_K torch.Size([12, 768, 64])
blocks.1.attn.W_V torch.Size([12, 768, 64])
blocks.1.attn.W_O torch.Size([12, 64, 768])
blocks.1.attn.b_Q torch.Size([12, 64])
blocks.1.attn.b_K torch.Size([12, 64])
blocks.1.attn.b_V torch.Size([12, 64])
blocks.1.att

In [22]:
# Sanity-check the tokenizer on a short string; the bare expression displays the token-id tensor.
# NOTE(review): the first id in the displayed result looks like a prepended BOS token —
# confirm against the to_tokens defaults.
model_ht.to_tokens("hello world")

tensor([[50256, 31373,   995]], device='mps:0')

In [28]:
# Pick a compute backend: CUDA first, then Apple MPS (when this torch build exposes it), else CPU.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif mps_backend is not None and mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device: ", device)

# Tokenize a single prompt; the result has shape (1, n_tokens).
prompt = "This is an example of a prompt that"
tokens = model_ht.to_tokens(prompt)

# One forward pass without autograd, recording the activation at every hook point.
# remove_batch_dim=True is passed so the cached tensors drop their batch axis.
with torch.no_grad():
    logits, activation_cache = model_ht.run_with_cache(
        tokens,
        remove_batch_dim=True,
    )

print(logits.size())
print(activation_cache)

Using device:  mps
torch.Size([1, 9, 50257])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 

In [31]:
# Which transformer block and which activation site inside it to analyse.
layer_index = 6
location = "resid_post_mlp"

# Site names (as passed to the sparse_autoencoder path helper) mapped to the
# transformer_lens hook name for the same activation in the chosen block.
hook_name_by_location = {
    "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post",
    "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out",
    "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid",
    "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out",
    "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post",
}
transformer_lens_loc = hook_name_by_location[location]

print("transformer_lens_loc :", transformer_lens_loc)




transformer_lens_loc : blocks.6.hook_resid_post


In [50]:
# Fetch the pretrained sparse autoencoder checkpoint for the chosen site and layer.
with bf.BlobFile(sparse_autoencoder.paths.v5_32k(location, layer_index), mode="rb") as f:
    # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
    # map_location="cpu" lets the checkpoint deserialize on machines that lack the
    # accelerator it was saved from; the model is moved to `device` afterwards.
    state_dict = torch.load(f, map_location="cpu")

# Build the model once the file handle is closed, then place it on the selected device.
autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
autoencoder.to(device)

# Inspect the autoencoder's state dict: shapes for tensors, raw values for everything else
# (the printed state dict also contains non-tensor entries such as the activation settings).
sd_oa = autoencoder.state_dict()

for k, v in sd_oa.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)
    else:
        print(k, v)

print(autoencoder)

pre_bias torch.Size([768])
latent_bias torch.Size([32768])
stats_last_nonzero torch.Size([32768])
latents_activation_frequency torch.Size([32768])
latents_mean_square torch.Size([32768])
encoder.weight torch.Size([32768, 768])
activation.k 32
activation.postact_fn ReLU
decoder.weight torch.Size([768, 32768])
activation TopK
activation_state_dict OrderedDict([('k', 32), ('postact_fn', 'ReLU')])
Autoencoder(
  (encoder): Linear(in_features=768, out_features=32768, bias=False)
  (activation): TopK(
    (postact_fn): ReLU()
  )
  (decoder): Linear(in_features=32768, out_features=768, bias=False)
)




In [49]:
# Activations captured at the selected hook during the earlier forward pass.
# `device` is detected independently of where transformer_lens placed the model, so move the
# activations to the autoencoder's device explicitly (a no-op when they already match).
input_tensor = activation_cache[transformer_lens_loc].to(device)

print("input tensor ", input_tensor.shape)

# Encode to sparse latents and decode back, without tracking gradients.
# `info` is returned by encode and handed back to decode unchanged.
with torch.no_grad():
    latent_activations, info = autoencoder.encode(input_tensor)
    print("latent_activations: ", latent_activations.shape)
    print("info: ",info)
    reconstructed_activations = autoencoder.decode(latent_activations, info)
    print("reconstructed_activations: ", reconstructed_activations.shape)

# Per-token normalized MSE: squared reconstruction error divided by the squared norm of the
# input. Reducing over the last axis (the feature axis) keeps this correct whether or not the
# cached activations still carry a leading batch dimension.
squared_error = (reconstructed_activations - input_tensor).pow(2).sum(dim=-1)
input_energy = input_tensor.pow(2).sum(dim=-1)
normalized_mse = squared_error / input_energy
print(location, normalized_mse)

input tensor  torch.Size([9, 768])
latent_activations:  torch.Size([9, 32768])
info:  {'mu': tensor([[4.7377],
        [0.0663],
        [0.0386],
        [0.0584],
        [0.0425],
        [0.0338],
        [0.0644],
        [0.0438],
        [0.0534]], device='mps:0'), 'std': tensor([[111.3588],
        [  3.3521],
        [  3.2112],
        [  3.0729],
        [  3.3707],
        [  3.0138],
        [  2.8864],
        [  3.6442],
        [  3.2611]], device='mps:0')}
reconstructed_activations:  torch.Size([9, 768])
resid_post_mlp tensor([6.4424e-05, 3.9219e-02, 3.1569e-02, 4.6317e-02, 7.1058e-02, 4.7744e-02,
        6.2675e-02, 6.7039e-02, 7.5507e-02], device='mps:0')
