In [8]:
# 设置环境变量
import os
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# 导入库
import torch
import blobfile as bf
import transformer_lens
import sparse_autoencoder
from experiments.utils import update_json_file, update_numpy_file
import pandas as pd
from datetime import datetime

In [12]:
import numpy as np
import re
# Load the model
def load_model(model_name, center_writing_weights=False):
    """Load a pretrained ``HookedTransformer`` and report where its weights live.

    Args:
        model_name: Name forwarded to ``HookedTransformer.from_pretrained``.
        center_writing_weights: Forwarded unchanged to ``from_pretrained``.

    Returns:
        A ``(model, device)`` tuple, where ``device`` is the device of the
        model's first parameter.
    """
    hooked_model = transformer_lens.HookedTransformer.from_pretrained(
        model_name,
        center_writing_weights=center_writing_weights,
    )
    first_parameter = next(hooked_model.parameters())
    return hooked_model, first_parameter.device

# Process the input
def process_input(model, prompt):
    """Tokenize ``prompt`` and run the model once while caching activations.

    Args:
        model: Object exposing ``to_tokens``, ``to_str_tokens`` and
            ``run_with_cache`` (e.g. the model returned by ``load_model``).
        prompt: Text to tokenize and feed through the model.

    Returns:
        A ``(tokens_id, tokens_str, activation_cache)`` tuple: the token ids,
        the string form of each token, and the cache returned by
        ``run_with_cache`` (called with ``remove_batch_dim=True``).
    """
    token_ids = model.to_tokens(prompt)  # (1, n_tokens) per the original author's note
    str_tokens = model.to_str_tokens(prompt)
    with torch.no_grad():
        # The logits are not needed here; only the cached activations are kept.
        _, cache = model.run_with_cache(token_ids, remove_batch_dim=True)
    return token_ids, str_tokens, cache

# Extract activations
def get_activation(activation_cache, layer_index=6, location="resid_post_mlp"):
    """Look up one layer's activation in the cache by a symbolic location name.

    Args:
        activation_cache: Mapping from hook name (``"blocks.<layer>.<hook>"``)
            to the cached activation.
        layer_index: Index of the block whose activation is wanted.
        location: One of ``"mlp_post_act"``, ``"resid_delta_attn"``,
            ``"resid_post_attn"``, ``"resid_delta_mlp"``, ``"resid_post_mlp"``.

    Returns:
        ``activation_cache[<hook name for location at layer_index>]``.

    Raises:
        KeyError: If ``location`` is not one of the names above, or the hook
            name is missing from ``activation_cache``.
    """
    hook_suffix_by_location = {
        "mlp_post_act": "mlp.hook_post",
        "resid_delta_attn": "hook_attn_out",
        "resid_post_attn": "hook_resid_mid",
        "resid_delta_mlp": "hook_mlp_out",
        "resid_post_mlp": "hook_resid_post",
    }
    hook_suffix = hook_suffix_by_location[location]
    hook_name = f"blocks.{layer_index}.{hook_suffix}"
    return activation_cache[hook_name]

# Load the sparse autoencoder
def load_autoencoder(location, layer_index, device):
    """Load the ``v5_128k`` sparse autoencoder for one location/layer.

    Args:
        location: Location name passed to ``sparse_autoencoder.paths.v5_128k``.
        layer_index: Layer index passed to ``sparse_autoencoder.paths.v5_128k``.
        device: Device the autoencoder is moved to after construction.

    Returns:
        The ``sparse_autoencoder.Autoencoder`` built from the loaded state
        dict, after ``.to(device)`` has been called on it.
    """
    checkpoint_path = sparse_autoencoder.paths.v5_128k(location, layer_index)
    with bf.BlobFile(checkpoint_path, mode="rb") as checkpoint_file:
        # NOTE(review): torch.load unpickles the file, so the path must point
        # at a trusted checkpoint.
        state_dict = torch.load(checkpoint_file)
    autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
    autoencoder.to(device)
    return autoencoder

# Encode and decode the activation tensor
def encode_decode(autoencoder, input_tensor):
    """Run ``input_tensor`` through the autoencoder and back, without gradients.

    Args:
        autoencoder: Object whose ``encode(x)`` returns ``(latents, info)`` and
            whose ``decode(latents, info)`` returns the reconstruction.
        input_tensor: Activations to encode.

    Returns:
        A ``(latent_activations, reconstructed_activations)`` tuple.
    """
    with torch.no_grad():
        encoded = autoencoder.encode(input_tensor)
        latents, info = encoded
        # ``info`` is passed back to decode exactly as encode produced it.
        reconstruction = autoencoder.decode(latents, info)
        return latents, reconstruction

# Compute the normalized reconstruction error (nothing is printed here)
def calculate_normalized_mse(input_tensor, reconstructed_activations):
    """Return the squared reconstruction error relative to the input's squared norm.

    Both sums are taken over ``dim=1``, so one value is produced per index
    along ``dim=0``.

    Args:
        input_tensor: Original activations (at least 2-D).
        reconstructed_activations: Reconstruction with a shape compatible with
            ``input_tensor``.

    Returns:
        ``sum((reconstructed - input)^2, dim=1) / sum(input^2, dim=1)``.
        No guard is applied when the denominator is zero.
    """
    residual = reconstructed_activations - input_tensor
    squared_error = residual.pow(2).sum(dim=1)
    input_energy = input_tensor.pow(2).sum(dim=1)
    return squared_error / input_energy

def extract_activations(prompt, tokens, latent_activations, top_k=32, activation_threshold=3):
    """Collect the strongest latent activations per feature for one prompt.

    For every feature (column of ``latent_activations``) that has at least one
    non-zero activation ``>= activation_threshold``, the ``top_k`` largest such
    activations are recorded together with the token position they occur at.
    The total number of selected activations is printed.

    Args:
        prompt: Key used to identify the prompt in the result (callers may pass
            an id rather than the prompt text).
        tokens: Sequence of string tokens; ``tokens[i]`` is the token for row
            ``i`` of ``latent_activations``.
        latent_activations: 2-D tensor indexed as
            ``[token_position, feature_index]``.
        top_k: Maximum number of activations kept per feature.
        activation_threshold: Minimum activation value for an entry to be kept.

    Returns:
        A dict of the form
        ``{"Feature <j>": {prompt: {token: ["idx:<position>", value]}}}``.
        Features without any qualifying activation are omitted.

        NOTE(review): entries are keyed by the token *string*, so when the same
        token text qualifies at several positions of one feature, the entry
        written last (the weakest of them, since values are visited in
        descending order) replaces the earlier ones. This matches the original
        output format and is kept for compatibility with existing consumers.
    """
    activations_dict = {}
    total_activations_count = 0

    # Entries that qualify: non-zero and at least the threshold.
    keep_mask = (latent_activations != 0) & (latent_activations >= activation_threshold)

    # Only visit features that have at least one qualifying entry instead of
    # looping over every column in Python; ascending order is preserved.
    active_features = keep_mask.any(dim=0).nonzero(as_tuple=True)[0].tolist()

    for feature_index in active_features:
        feature_activations = latent_activations[:, feature_index]

        # Remember the original token positions of the qualifying entries so
        # that top-k indices can be mapped back directly. Searching for the
        # value instead would be ambiguous whenever two positions share the
        # same activation value.
        candidate_positions = keep_mask[:, feature_index].nonzero(as_tuple=True)[0]
        candidate_values = feature_activations[candidate_positions]

        top_k_values, top_k_indices = torch.topk(
            candidate_values, min(top_k, candidate_values.numel())
        )
        token_positions = candidate_positions[top_k_indices]

        per_token = {}
        for value, token_index in zip(top_k_values.tolist(), token_positions.tolist()):
            per_token[tokens[token_index]] = [f"idx:{token_index}", value]
        activations_dict[f"Feature {feature_index}"] = {prompt: per_token}

        total_activations_count += len(top_k_values)

    # Report how many activations were selected across all features.
    print(f"Total activations extracted: {total_activations_count}")

    return activations_dict


In [3]:
# Configuration: which layer / hook location the autoencoder is attached to.
layer_index = 6
location = "resid_post_mlp"

# Load the "gpt2" model first so the autoencoder can be moved to its device.
model, device = load_model("gpt2")
autoencoder = load_autoencoder(location, layer_index, device)

Loaded pretrained model gpt2 into HookedTransformer


In [10]:
# Results go into a per-day folder named output/YYYY-MM-DD.
today = f"{datetime.today():%Y-%m-%d}"
output_folder = "output/" + today
os.makedirs(output_folder, exist_ok=True)

In [13]:
# Paths and run settings
prompt_folder_path = 'dataset/prompt_1000_pro_2'
parquet_suffix = '.parquet'
max_prompts_per_file = 50  # only the first N prompts of each file are processed
top_k_per_feature = 5      # strongest activations kept per feature
os.makedirs(output_folder, exist_ok=True)

# Process every .parquet prompt file; sorted because os.listdir order is arbitrary.
for file_name in sorted(os.listdir(prompt_folder_path)):
    if not file_name.endswith(parquet_suffix):
        continue

    prompt_file_path = os.path.join(prompt_folder_path, file_name)
    data = pd.read_parquet(prompt_file_path).head(max_prompts_per_file)
    n_prompts = len(data)  # may be fewer than max_prompts_per_file

    # One output JSON per input file. The path does not depend on the row, so
    # it is computed once; only the trailing suffix is replaced.
    activations_file_name = (
        file_name[:-len(parquet_suffix)] + '_1000_activation_pro_2_128k.json'
    )
    activations_file_path = os.path.join(output_folder, activations_file_name)

    for count, (index, row) in enumerate(data.iterrows(), start=1):
        print(f"current file: {prompt_file_path}")
        prompt_id = row['prompt_id']
        prompt = row['prompt']
        tokens_id, tokens_str, activation_cache = process_input(model, prompt)
        # Pass `location` explicitly so the activation always comes from the
        # same hook the autoencoder was loaded for, not from the default.
        activation = get_activation(activation_cache, layer_index, location)
        latent_activations, reconstructed_activations = encode_decode(autoencoder, activation)

        print(latent_activations.shape)
        print(activation.shape)
        print(reconstructed_activations.shape)
        non_zero_count = (latent_activations != 0).sum().item()
        print("Non-zero activation count:", non_zero_count)
        print(f"This is {count}/{n_prompts} prompt")

        # Results are keyed by prompt_id rather than by the prompt text.
        activations_dict = extract_activations(
            prompt_id, tokens_str, latent_activations, top_k=top_k_per_feature
        )
        update_json_file(activations_file_path, activations_dict)

current file: dataset/prompt_1000_pro_2\decision_feeling.parquet
torch.Size([174, 131072])
torch.Size([174, 768])
torch.Size([174, 768])
Non-zero activation count: 5568
This is 1/50 prompt
Total activations extracted: 577
current file: dataset/prompt_1000_pro_2\decision_feeling.parquet
torch.Size([184, 131072])
torch.Size([184, 768])
torch.Size([184, 768])
Non-zero activation count: 5888
This is 2/50 prompt
Total activations extracted: 586
current file: dataset/prompt_1000_pro_2\decision_feeling.parquet
torch.Size([225, 131072])
torch.Size([225, 768])
torch.Size([225, 768])
Non-zero activation count: 7200
This is 3/50 prompt
Total activations extracted: 693
current file: dataset/prompt_1000_pro_2\decision_feeling.parquet
torch.Size([95, 131072])
torch.Size([95, 768])
torch.Size([95, 768])
Non-zero activation count: 3040
This is 4/50 prompt
Total activations extracted: 317
current file: dataset/prompt_1000_pro_2\decision_feeling.parquet
torch.Size([217, 131072])
torch.Size([217, 768])
t

In [8]:
# Tokenize a sample prompt and display its string tokens.
# Note: the prompt contains a curly apostrophe (’), not an ASCII one. In the
# displayed output it appears as two separate tokens rendered as '�' between
# ' you' and 're' (presumably because the tokenizer splits the multi-byte
# character — TODO confirm).
prompt = 'Pretend you’re a perceiving person to answer the question.'
tokens_id, tokens_str, activation_cache = process_input(model, prompt)
# Bare last expression so the notebook displays the token list.
tokens_str

['<|endoftext|>',
 'P',
 'ret',
 'end',
 ' you',
 '�',
 '�',
 're',
 ' a',
 ' perce',
 'iving',
 ' person',
 ' to',
 ' answer',
 ' the',
 ' question',
 '.']