In [None]:
# Configure environment variables.
# NOTE(review): HF_ENDPOINT is assigned before the library imports below, presumably so
# that the Hugging Face tooling they pull in uses the hf-mirror.com mirror — confirm
# before reordering these lines.
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Import libraries
import torch
import blobfile as bf
import transformer_lens
import sparse_autoencoder
from experiments.MBTI.utils import extract_activations, update_json_file, count_activations

In [None]:
# Load the model
def load_model(model_name, center_writing_weights=False):
    """Load a pretrained ``HookedTransformer`` and report where it lives.

    Args:
        model_name: Name handed to ``HookedTransformer.from_pretrained``.
        center_writing_weights: Forwarded unchanged to ``from_pretrained``.

    Returns:
        A ``(model, device)`` tuple, where ``device`` is the ``.device`` of
        the first parameter yielded by ``model.parameters()``.
    """
    hooked_model = transformer_lens.HookedTransformer.from_pretrained(
        model_name,
        center_writing_weights=center_writing_weights,
    )
    # The first parameter's device is used as the device of the whole model.
    first_parameter = next(hooked_model.parameters())
    return hooked_model, first_parameter.device

# Process the input
def process_input(model, prompt):
    """Tokenize ``prompt`` and run it through ``model`` while caching activations.

    Args:
        model: Object exposing ``to_tokens``, ``to_str_tokens`` and
            ``run_with_cache`` (a ``HookedTransformer`` in this notebook).
        prompt: Text to tokenize and feed to the model.

    Returns:
        A ``(tokens_id, tokens_str, activation_cache)`` tuple: the result of
        ``model.to_tokens``, the result of ``model.to_str_tokens`` and the
        cache returned by ``model.run_with_cache``.
    """
    token_ids = model.to_tokens(prompt)  # (1, n_tokens) per the original note
    string_tokens = model.to_str_tokens(prompt)
    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        forward_result = model.run_with_cache(token_ids, remove_batch_dim=True)
    # The first element (the logits) is not used by this notebook.
    _, cache = forward_result
    return token_ids, string_tokens, cache

# Extract activations
def get_activation(activation_cache, layer_index=6, location="resid_post_mlp"):
    """Return the cached activation for a named location in one transformer block.

    Args:
        activation_cache: Mapping from hook names to cached activations, as
            returned by ``run_with_cache``.
        layer_index: Index of the block whose activation is read.
        location: One of ``"mlp_post_act"``, ``"resid_delta_attn"``,
            ``"resid_post_attn"``, ``"resid_delta_mlp"`` or
            ``"resid_post_mlp"``.

    Returns:
        Whatever ``activation_cache`` stores under the resolved hook name.

    Raises:
        KeyError: If ``location`` is not a supported name (the message lists
            the valid names), or if the resolved hook name is missing from
            ``activation_cache``.
    """
    # Location name -> hook-name template; "{}" is filled with the layer index.
    hook_templates = {
        "mlp_post_act": "blocks.{}.mlp.hook_post",
        "resid_delta_attn": "blocks.{}.hook_attn_out",
        "resid_post_attn": "blocks.{}.hook_resid_mid",
        "resid_delta_mlp": "blocks.{}.hook_mlp_out",
        "resid_post_mlp": "blocks.{}.hook_resid_post",
    }
    try:
        template = hook_templates[location]
    except KeyError:
        # Still a KeyError (as before), but distinguishable from a missing
        # hook in the cache and explicit about the accepted values.
        raise KeyError(
            f"Unknown location {location!r}; expected one of {sorted(hook_templates)}"
        ) from None
    return activation_cache[template.format(layer_index)]

# Load the sparse autoencoder
def load_autoencoder(location, layer_index, device):
    """Load the sparse autoencoder for one activation site and move it to ``device``.

    The checkpoint path comes from ``sparse_autoencoder.paths.v5_32k`` and is
    read through ``blobfile``.

    Args:
        location: Activation-site name passed to ``paths.v5_32k``.
        layer_index: Layer index passed to ``paths.v5_32k``.
        device: Device the autoencoder is moved to after construction.

    Returns:
        The ``sparse_autoencoder.Autoencoder`` built from the checkpoint's
        state dict.
    """
    checkpoint_path = sparse_autoencoder.paths.v5_32k(location, layer_index)
    with bf.BlobFile(checkpoint_path, mode="rb") as f:
        # Deserialize onto CPU so the load does not depend on the device the
        # checkpoint was saved from; the module is moved to `device` below.
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # from a trusted source.
        state_dict = torch.load(f, map_location="cpu")
    # The file handle is no longer needed once the state dict is in memory.
    autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
    autoencoder.to(device)
    return autoencoder

# Encode and decode the activation tensor
def encode_decode(autoencoder, input_tensor):
    """Pass ``input_tensor`` through the autoencoder's encoder and decoder.

    Both calls run with gradient tracking disabled.

    Args:
        autoencoder: Object exposing ``encode(x) -> (latents, info)`` and
            ``decode(latents, info)``.
        input_tensor: Activations to encode.

    Returns:
        A ``(latent_activations, reconstructed_activations)`` tuple.
    """
    with torch.no_grad():
        encoded = autoencoder.encode(input_tensor)
        # `encode` yields the latents plus auxiliary info that `decode` needs.
        latents, encode_info = encoded
        reconstruction = autoencoder.decode(latents, encode_info)
    return latents, reconstruction

# Compute the normalized reconstruction error (nothing is printed here)
def calculate_normalized_mse(input_tensor, reconstructed_activations):
    """Return the squared reconstruction error normalized by the input's squared norm.

    Both sums run over the last dimension, so the result has the shape of the
    inputs with that dimension removed. For the 2-D tensors used in this
    notebook this is the same as reducing over ``dim=1``; 1-D and batched
    inputs are handled as well.

    Args:
        input_tensor: Original activations.
        reconstructed_activations: Reconstruction of ``input_tensor``; must be
            broadcastable against it.

    Returns:
        Tensor of ``sum((recon - x)^2) / sum(x^2)`` along the last dimension.
        Entries whose input is all zeros divide by zero and come out as
        ``nan``/``inf`` rather than raising.
    """
    squared_error = (reconstructed_activations - input_tensor).pow(2).sum(dim=-1)
    input_energy = input_tensor.pow(2).sum(dim=-1)
    return squared_error / input_energy

In [7]:
# Configuration: which block and which activation site to analyse.
layer_index, location = 6, "resid_post_mlp"

# Load the "gpt2" model, then the sparse autoencoder for the configured site,
# placed on the same device as the model.
model, device = load_model("gpt2")
autoencoder = load_autoencoder(location, layer_index, device)

Loaded pretrained model gpt2 into HookedTransformer


In [8]:
# Run one example prompt through the model and the sparse autoencoder.
prompt = "This is an example of a prompt that"
# Tokenize the prompt and cache the activations of a forward pass.
tokens_id, tokens_str, activation_cache = process_input(model, prompt)
# Pick the cached activation at the configured layer/location as the autoencoder input.
input_tensor = get_activation(activation_cache, layer_index, location)
# Encode to latent activations, then decode back to a reconstruction of input_tensor.
latent_activations, reconstructed_activations = encode_decode(autoencoder, input_tensor)

In [9]:
# Show the shapes of the latent activations, the autoencoder input and its
# reconstruction, one per line.
print(
    latent_activations.shape,
    input_tensor.shape,
    reconstructed_activations.shape,
    sep="\n",
)
# Count the latent entries that are non-zero.
non_zero_count = latent_activations.ne(0).sum().item()
print("Non-zero activation count:", non_zero_count)

torch.Size([9, 32768])
torch.Size([9, 768])
torch.Size([9, 768])
Non-zero activation count: 288


In [10]:
# Build activations_dict from the prompt, its string tokens and the latent activations.
# NOTE(review): extract_activations is imported from experiments.MBTI.utils and is not
# shown in this notebook — confirm the structure of the value it returns.
activations_dict = extract_activations(prompt, tokens_str, latent_activations)

In [11]:
# Hand activations_dict to the project helper for "activation.json".
# NOTE(review): whether update_json_file merges into or overwrites existing file
# content is not visible here — confirm, since it decides whether re-running this
# cell is idempotent.
update_json_file("activation.json", activations_dict)

In [13]:
# Count the activations recorded in the JSON file and report the total.
# NOTE(review): count_activations is a project helper not shown here; presumably the
# total should match the non-zero latent count printed earlier in the notebook —
# confirm against the helper.
filename = "activation.json"
activation_count = count_activations(filename)
print(f"Total number of activations: {activation_count}")

Total number of activations: 288
