In [29]:
# 设置环境变量
import os
import sys
# Raw string: the backslashes in this Windows path are literal path separators.
# In a normal string literal `\C`, `\R`, `\P` and `\s` are invalid escape sequences,
# which newer Python versions warn about. The resulting string value is unchanged.
# NOTE(review): this is a machine-specific absolute path; consider making it configurable.
sys.path.append(r'D:\ComputerScience\Research\PRADA\sparse_autoencoder')
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# 导入库
import torch
import blobfile as bf
import sparse_autoencoder
from experiments.utils import update_json_file, update_numpy_file
import pandas as pd
from datetime import datetime
import numpy as np
import re
from transformers import GPT2LMHeadModel, AutoTokenizer, GPT2Tokenizer

In [30]:
# Load the model and its tokenizers
def load_model(model_name):
    """Load a GPT-2 LM head model plus two tokenizers for the same checkpoint.

    Args:
        model_name: Checkpoint identifier passed to every `from_pretrained` call.

    Returns:
        A 4-tuple `(model, auto_tokenizer, gpt2_tokenizer, device)`. The model is
        created with `output_hidden_states=True` and has been moved to `device`,
        which is CUDA when `torch.cuda.is_available()` and CPU otherwise.
    """
    lm = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
    # Two tokenizer objects are built from the same checkpoint name so callers
    # can compare them side by side.
    tokenizer_from_auto = AutoTokenizer.from_pretrained(model_name)
    tokenizer_from_gpt2 = GPT2Tokenizer.from_pretrained(model_name)

    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")
    lm.to(target_device)

    return lm, tokenizer_from_auto, tokenizer_from_gpt2, target_device

# Tokenize a prompt and run it through the model
def process_input(model, tokenizer, prompt):
    """Tokenize `prompt`, run a gradient-free forward pass and return the hidden states.

    Args:
        model: Callable model exposing `.device`; its output must expose `.hidden_states`.
        tokenizer: Callable tokenizer that also provides `convert_ids_to_tokens`.
        prompt: Text to tokenize.

    Returns:
        `(tokens_id, tokens_str, activation_cache)`: the token-id tensor (moved to
        the model's device), the token strings of the first sequence, and the
        `hidden_states` attribute of the model output.
    """
    encoded = tokenizer(prompt, return_tensors='pt')
    ids_on_device = encoded.input_ids.to(model.device)

    # Inference only: no autograd graph is needed for reading activations.
    with torch.no_grad():
        forward_result = model(ids_on_device)

    token_strings = tokenizer.convert_ids_to_tokens(ids_on_device[0])
    return ids_on_device, token_strings, forward_result.hidden_states

# Fetch one activation from a cache keyed by hook name
def get_activation(activation_cache, layer_index=6, location="resid_post_mlp"):
    """Return the cached activation for a given layer and location.

    Args:
        activation_cache: Mapping indexed by hook-name strings of the form
            `blocks.{layer_index}.<hook>`.
        layer_index: Block number inserted into the hook name.
        location: One of "mlp_post_act", "resid_delta_attn", "resid_post_attn",
            "resid_delta_mlp" or "resid_post_mlp".

    Returns:
        Whatever `activation_cache` stores under the resolved hook name.

    Raises:
        KeyError: If `location` is not one of the names listed above, or the
            resolved hook name is missing from a dict-like cache.
    """
    hook_name_by_location = {
        "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post",
        "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out",
        "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid",
        "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out",
        "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post",
    }
    hook_name = hook_name_by_location[location]
    return activation_cache[hook_name]

# Load the sparse autoencoder
def load_autoencoder(location, layer_index, device):
    """Load the sparse autoencoder checkpoint for one layer/location onto `device`.

    Args:
        location: Activation location name passed to `sparse_autoencoder.paths.v5_128k`.
        layer_index: Layer number passed to `sparse_autoencoder.paths.v5_128k`.
        device: Device the returned autoencoder should live on.

    Returns:
        The autoencoder built by `sparse_autoencoder.Autoencoder.from_state_dict`,
        moved to `device`.
    """
    checkpoint_path = sparse_autoencoder.paths.v5_128k(location, layer_index)
    with bf.BlobFile(checkpoint_path, mode="rb") as f:
        # map_location remaps tensors that were saved from a different device,
        # so loading does not depend on the device the checkpoint was written on.
        state_dict = torch.load(f, map_location=device)

    # The file handle is no longer needed once the state dict is in memory.
    autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
    autoencoder.to(device)
    return autoencoder

# Encode an activation tensor and decode it back
def encode_decode(autoencoder, input_tensor):
    """Run `input_tensor` through the autoencoder's encode and decode steps.

    Both calls happen under `torch.no_grad()`.

    Args:
        autoencoder: Object whose `encode(x)` returns a `(latents, info)` pair and
            whose `decode(latents, info)` returns the reconstruction.
        input_tensor: Activations to encode.

    Returns:
        `(latent_activations, reconstructed_activations)`.
    """
    with torch.no_grad():
        encoded = autoencoder.encode(input_tensor)
        latents, encode_info = encoded
        # `encode_info` is handed back to decode unchanged.
        reconstruction = autoencoder.decode(latents, encode_info)
    return latents, reconstruction

# Compute the normalized reconstruction error
def calculate_normalized_mse(input_tensor, reconstructed_activations):
    """Return the squared reconstruction error divided by the squared input, along dim 1.

    Both the numerator and the denominator are sums over `dim=1`, so the result
    has that dimension removed. Nothing is printed; the value is only returned.

    Args:
        input_tensor: Original activations.
        reconstructed_activations: Autoencoder reconstruction of `input_tensor`.

    Returns:
        Tensor of `sum((reconstruction - input)^2, dim=1) / sum(input^2, dim=1)`.
    """
    residual = reconstructed_activations - input_tensor
    squared_error = residual.pow(2).sum(dim=1)
    input_energy = input_tensor.pow(2).sum(dim=1)
    return squared_error / input_energy

def extract_activations(prompt, tokens, latent_activations, top_k=32, activation_threshold=3):
    """Collect, per latent feature, the tokens with the strongest activations.

    An activation is considered only if it is non-zero and at least
    `activation_threshold`. For every feature with at least one such activation,
    the `top_k` largest are recorded together with their token positions.

    Token positions are carried through the top-k selection directly, so tied
    activation values at different positions are all kept.

    Args:
        prompt: Key under which this prompt's entries are stored for each feature.
        tokens: Sequence of token strings, indexed by the row index of
            `latent_activations`.
        latent_activations: 2-D tensor indexed as `[position, feature]`.
        top_k: Maximum number of activations kept per feature.
        activation_threshold: Minimum activation value to be considered.

    Returns:
        Dict of the form
        `{"Feature <i>": {prompt: {token: ["idx:<position>", value], ...}}}`.
        Features without any qualifying activation are omitted.

    Note:
        Entries are keyed by token string, so if the same token string is selected
        at several positions for one feature, the later (weaker or equal) entry
        overwrites the earlier one.
    """
    activations_dict = {}
    total_activations_count = 0

    # Build the selection mask once for the whole matrix instead of once per feature.
    selected = (latent_activations != 0) & (latent_activations >= activation_threshold)

    # Only visit feature columns that have at least one qualifying activation;
    # every other column would contribute nothing.
    _, selected_features = selected.nonzero(as_tuple=True)
    active_features = sorted(set(selected_features.tolist()))

    for feature_index in active_features:
        # Row positions (token indices) of the qualifying activations, in order.
        candidate_positions = selected[:, feature_index].nonzero(as_tuple=True)[0]
        candidate_values = latent_activations[candidate_positions, feature_index]

        top_k_values, top_k_order = torch.topk(
            candidate_values, min(top_k, candidate_values.numel())
        )
        # Map indices within the filtered candidates back to token positions.
        top_positions = candidate_positions[top_k_order]

        per_token = {}
        for value, token_index in zip(top_k_values.tolist(), top_positions.tolist()):
            token = tokens[token_index]
            per_token[token] = [f"idx:{token_index}", value]

        activations_dict[f"Feature {feature_index}"] = {prompt: per_token}
        total_activations_count += top_k_values.numel()

    # Print the total number of activations extracted
    print(f"Total activations extracted: {total_activations_count}")

    return activations_dict


In [31]:
# Which activations to analyse: layer number and location name of the autoencoder.
layer_index = 6
location = "resid_post_mlp"

# Language model, both tokenizers and the compute device.
model_name = "gpt2"
model, auto_tokenizer, gpt2_tokenizer, device = load_model(model_name)

# Sparse autoencoder for the chosen layer/location, placed on the same device.
autoencoder = load_autoencoder(location, layer_index, device)

In [20]:
# Results are grouped into one folder per calendar day: output/YYYY-MM-DD
today = format(datetime.today(), '%Y-%m-%d')
output_folder = f'output/{today}'
# exist_ok makes re-running this cell on the same day harmless.
os.makedirs(output_folder, exist_ok=True)

In [33]:
prompt = 'Pretend you\'re a introverted person to answer the question.'

# Run the same prompt through both tokenizers so their outputs can be compared.
# With AutoTokenizer
tokens_id_auto, tokens_str_auto, activation_cache_auto = process_input(model, auto_tokenizer, prompt)
# With GPT2Tokenizer
tokens_id_gpt2, tokens_str_gpt2, activation_cache_gpt2 = process_input(model, gpt2_tokenizer, prompt)

print("Tokens ID (AutoTokenizer):", tokens_id_auto)
print("Tokens String (AutoTokenizer):", tokens_str_auto)
# Show the GPT2Tokenizer results as well; they were computed but never displayed.
print("Tokens ID (GPT2Tokenizer):", tokens_id_gpt2)
print("Tokens String (GPT2Tokenizer):", tokens_str_gpt2)

# Print hidden_states information to confirm it was extracted correctly.
if activation_cache_auto is not None:
    print(f"Number of layers in hidden_states: {len(activation_cache_auto)}")
    # Index layer_index + 1 is used throughout this notebook for the output of
    # block `layer_index`.
    # NOTE(review): presumably entry 0 is the embedding output - confirm against
    # the Transformers documentation.
    resid_post_mlp = activation_cache_auto[layer_index+1]
else:
    print("No hidden_states found.")
    # Nothing to index; keep the name defined so the summary line below still works.
    resid_post_mlp = None

print(f"resid_post_mlp for layer {layer_index}:", resid_post_mlp[0].shape if resid_post_mlp is not None else "None")
if resid_post_mlp is not None:
    print(resid_post_mlp[0])

Tokens ID (AutoTokenizer): tensor([[   47,  1186,   437,   345,   821,   257, 18951, 13658,  1048,   284,
          3280,   262,  1808,    13]])
Tokens String (AutoTokenizer): ['P', 'ret', 'end', 'Ġyou', "'re", 'Ġa', 'Ġintro', 'verted', 'Ġperson', 'Ġto', 'Ġanswer', 'Ġthe', 'Ġquestion', '.']
Tokens ID (GPT2Tokenizer): tensor([[   47,  1186,   437,   345,   821,   257, 18951, 13658,  1048,   284,
          3280,   262,  1808,    13]])
Tokens String (GPT2Tokenizer): ['P', 'ret', 'end', 'Ġyou', "'re", 'Ġa', 'Ġintro', 'verted', 'Ġperson', 'Ġto', 'Ġanswer', 'Ġthe', 'Ġquestion', '.']
Number of layers in hidden_states: 13
resid_post_mlp for layer 6: torch.Size([14, 768])
tensor([[-7.2834e-01, -3.3760e+00,  9.7962e-01,  ..., -1.7737e+00,
          1.3011e+00, -4.3048e-01],
        [ 2.3055e+00,  1.3471e+00,  2.8651e+00,  ...,  1.6660e+00,
         -4.2424e-01, -1.2197e+00],
        [-2.5004e-01,  1.7333e+00, -1.1717e+00,  ...,  9.8642e-01,
          3.5260e+00, -1.0241e+00],
        ...,
      

In [23]:
# Paths
prompt_folder_path = 'dataset/prompt_1000_pro_2'
os.makedirs(output_folder, exist_ok=True)
# Only the first N prompts of each file are processed.
max_prompts_per_file = 50

# Process every .parquet file in the prompt folder. The listing is sorted so the
# processing order is the same on every run.
for file_name in sorted(os.listdir(prompt_folder_path)):
    if not file_name.endswith('.parquet'):
        continue

    prompt_file_path = os.path.join(prompt_folder_path, file_name)
    data = pd.read_parquet(prompt_file_path)
    data = data[:max_prompts_per_file]

    # One JSON output file per input file; the path does not depend on the row,
    # so it is computed once here rather than inside the row loop.
    activations_file_name = file_name.replace('.parquet', '_1000_activation_pro_2_128k.json')
    activations_file_path = os.path.join(output_folder, activations_file_name)

    # `count` is the 1-based position of the prompt within the current file.
    for count, (index, row) in enumerate(data.iterrows(), start=1):
        print(f"current file: {prompt_file_path}")
        prompt_id = row['prompt_id']
        prompt = row['prompt']
        # Uses `auto_tokenizer`, one of the tokenizer objects returned by
        # load_model; no plain `tokenizer` name is defined in this notebook.
        tokens_id, tokens_str, activation_cache = process_input(model, auto_tokenizer, prompt)
        activation = activation_cache[layer_index+1]
        latent_activations, reconstructed_activations = encode_decode(autoencoder, activation[0])

        print(latent_activations.shape)
        print(activation[0].shape)
        print(reconstructed_activations.shape)
        non_zero_count = (latent_activations != 0).sum().item()
        print("Non-zero activation count:", non_zero_count)
        # Report progress against the number of prompts actually loaded.
        print(f"This is {count}/{len(data)} prompt")

        # Entries are keyed by prompt_id rather than by the prompt text.
        activations_dict = extract_activations(prompt_id, tokens_str, latent_activations, top_k=5)

        update_json_file(activations_file_path, activations_dict)

current file: dataset/prompt_1000_pro_2\energy_introversion.parquet
torch.Size([57, 131072])
torch.Size([1, 57, 768])
torch.Size([57, 768])
Non-zero activation count: 1823
This is 1/50 prompt
Total activations extracted: 201
current file: dataset/prompt_1000_pro_2\energy_introversion.parquet
torch.Size([94, 131072])
torch.Size([1, 94, 768])
torch.Size([94, 768])
Non-zero activation count: 3007
This is 2/50 prompt
Total activations extracted: 313
current file: dataset/prompt_1000_pro_2\energy_introversion.parquet
torch.Size([122, 131072])
torch.Size([1, 122, 768])
torch.Size([122, 768])
Non-zero activation count: 3903
This is 3/50 prompt
Total activations extracted: 407
current file: dataset/prompt_1000_pro_2\energy_introversion.parquet
torch.Size([135, 131072])
torch.Size([1, 135, 768])
torch.Size([135, 768])
Non-zero activation count: 4319
This is 4/50 prompt
Total activations extracted: 443
current file: dataset/prompt_1000_pro_2\energy_introversion.parquet
torch.Size([138, 131072])
