In [30]:
import torch

from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

from core.sae import SparseAutoEncoder
from core.config import SAEConfig, ActivationStoreConfig, LanguageModelConfig
from core.activation.token_source import TokenSource
from core.activation.activation_source import TokenActivationSource
from core.activation.activation_store import ActivationStore

In [44]:
# Preferred accelerator for this notebook. Fall back gracefully so the notebook
# still runs on hosts with fewer than seven GPUs (or none at all) instead of
# failing later when the model is moved to a non-existent device.
if torch.cuda.is_available() and torch.cuda.device_count() > 6:
    device = 'cuda:6'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [45]:
# Load the GPT-2 checkpoint through Hugging Face, then hand those weights to
# TransformerLens so the model exposes named hook points on `device`.
MODEL_NAME = 'gpt2'
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    hf_model=hf_model,
    device=device,
)

Loaded pretrained model gpt2 into HookedTransformer


In [132]:
# Activation site whose sparse-autoencoder features are inspected below.
hook_point = "blocks.1.hook_mlp_out"

# Every probe sentence has the form "<clause> [<continuation>": the same " ["
# token appears in varied contexts. The first entry has an empty continuation,
# i.e. the sentence ends right after the bracket.
_prompt_parts = (
    ("I went to the store", ""),
    ("She told me a secret", "I promised not to tell"),
    ("The cat", "which was black"),
    ("We planned a trip to the beach", "but it rained"),
    ("He bought a new car", "a red convertible"),
    ("They decided to cancel the event", "due to low attendance"),
    ("She brought her dog", "a golden retriever"),
    ("He received a gift", "a watch"),
    ("The book", "which was on sale"),
    ("They attended a party", "hosted by their neighbors"),
)
sentences = [f"{clause} [{continuation}" for clause, continuation in _prompt_parts]



In [133]:
# One dictionary per (layer, sub-layer) of the 12-block model: "A" names the
# dictionary for the attention output hook and "M" the one for the MLP output hook.
_N_LAYERS = 12
_SUBLAYERS = (("A", "attn"), ("M", "mlp"))

# Hook-point name -> dictionary directory name, ordered by layer with the
# attention entry before the MLP entry.
layer_to_dict = {
    f"blocks.{layer}.hook_{kind}_out": f"ft-L{layer}{tag}-bs-4096-lr-8e-5-32x"
    for layer in range(_N_LAYERS)
    for tag, kind in _SUBLAYERS
}

# Flat list of the dictionary names, in the same order as `layer_to_dict`.
dict_file_names = list(layer_to_dict.values())

In [134]:
# Build the SAE configuration for the dictionary that belongs to `hook_point`.
# NOTE(review): get_hyperparameters presumably reads the settings stored with
# the 'final.pt' checkpoint under the results directory — confirm in core.config.
_sae_hyperparameters = SAEConfig.get_hyperparameters(
    layer_to_dict[hook_point],
    '/remote-home/share/research/mechinterp/gpt2-dictionary/ftresults',
    'final.pt',
)
SAEconfig = SAEConfig(**_sae_hyperparameters, device=device, seed=42)
sae = SparseAutoEncoder(cfg=SAEconfig)

# The checkpoint is a mapping whose "sae" entry holds the state dict.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# The call is left as the last expression so the cell displays the key-matching report.
sae.load_state_dict(
    torch.load(sae.cfg.from_pretrained_path, map_location=sae.cfg.device)["sae"]
)

<All keys matched successfully>

In [135]:
# Tokenize the probe sentences (no BOS token) and run the model once, caching
# only the activation at `hook_point`, then encode it with the SAE.
#
# This is pure analysis, so the forward passes run under no_grad: without it
# every downstream tensor carries an autograd graph that is never used and
# only costs memory.
with torch.no_grad():
    tokens = model.to_tokens(sentences, prepend_bos=False)
    _, cache = model.run_with_cache(tokens, names_filter=[hook_point])
    activation = cache[hook_point]

    # Only the auxiliary dict of the SAE output is needed here.
    _, (_, aux) = sae(activation)
    feature_acts = aux["feature_acts"]

# The recorded run printed (10, 12, 24576); the later cells index this tensor
# as (sentence, token position, feature).
# NOTE(review): the sentences differ in length but the batch is rectangular, so
# shorter ones are presumably padded — confirm how padded positions should be treated.
print(feature_acts.shape)

torch.Size([10, 12, 24576])


In [136]:
# Single symbol: which SAE features are active on the " [" token in every sentence?
special_token = model.to_tokens(' [', prepend_bos=False)[0].item()
match = torch.eq(tokens, special_token)

# Keep only the FIRST match in each row. A plain torch.nonzero(match) would
# return every occurrence, so a sentence containing two brackets would
# contribute two rows. argmax over the 0/1 mask yields the first matching
# position; rows without any match are dropped via `rows`.
rows = torch.nonzero(match.any(dim=1)).squeeze(1)
first_position = match.int().argmax(dim=1)
indices = torch.stack([rows, first_position[rows]], dim=1)  # (n_matches, 2): row, position
print(indices.shape[0])
assert indices.shape[0] > 0, "no sentence contains the ' [' token"

# Gather the feature vector at each selected (row, position) with one indexing
# operation instead of round-tripping every value through Python lists.
# detach().cpu() keeps the result a grad-free CPU tensor.
features = feature_acts[indices[:, 0], indices[:, 1]].detach().cpu()
print(features.shape)

# A feature's score is its smallest activation across the selected tokens, so
# only features that fire on every one of them score above zero.
min_features, _ = torch.min(features, dim=0)
top_values, top_indices = torch.topk(min_features, k = 10)
print(top_indices)
print(top_values)

10
torch.Size([10, 24576])
tensor([21675,  4230, 15926, 21732,     3,     7,     8,     1,     4,     9])
tensor([10.4506,  2.8440,  1.0607,  0.3649,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])


In [137]:
# Overall semantics: features that are active somewhere in every sentence.
# Step 1: strongest activation of each feature over the positions of a sentence.
# NOTE(review): every position of the batch takes part in this max, including
# any padded ones — confirm that is intended.
temp = feature_acts.max(dim=1).values
# Step 2: weakest of those per-sentence peaks, so a feature only scores high
# when it is active in all sentences.
final = temp.min(dim=0).values
top_values, top_indices = final.topk(10)
print(top_indices)
print(top_values)

tensor([22223,  3220, 21675,  4230, 15074,   464, 15926,  1231, 17077, 20381],
       device='cuda:6')
tensor([14.5024, 13.3414, 10.4506,  2.8440,  2.2064,  1.5701,  1.0607,  0.9122,
         0.6146,  0.4771], device='cuda:6', grad_fn=<TopkBackward0>)
