In [17]:
import os
import pathlib

import torch
import blobfile as bf
import transformer_lens
import sparse_autoencoder
import argparse
import pandas as pd

from pathlib import Path


In [28]:
import json

# Folder that holds the prompt file for this experiment.
prompt_dir = '/raid/home/s2521923/Pycharm/sparse_autoencoder'

# Read the prompt records. The context manager closes the file handle as soon
# as parsing is done, and the explicit encoding avoids depending on the locale.
with open(os.path.join(prompt_dir, 'personality_mbti.json'), encoding='utf-8') as prompt_file:
    prompts = json.load(prompt_file)

# Fixed text wrapped around every record: story header, question, answer options.
prefix = 'Consider the following story:\n'
# Earlier multi-label variant of the question, kept for reference:
# suffix = '\n\nChoose correct answers from options given below. The options can be multiple.\nA. Introversion\nB. Extraversion\nC. Sensing\nD. Intuition\nE. Thinking\nF. Feeling\nG. Judging\nH. Perceiving\nThe current personality is <BLANK>. Replace the <BLANK> with chioces.'
suffix = '\n\nChoose a single answer from options given below to identify how accurately aligned with the above personality:\n'
finals = '\nOptions:\nA. Very Accurate\nB. Moderately Accurate\nC. Neither Accurate Nor Inaccurate\nD. Moderately Inaccurate\nE. Very Inaccurate'

# Overwrite each record's 'statement' with the full prompt: the story, then the
# record's 'description' as the personality, then the question and the options.
for prompt in prompts:
    prompt['statement'] = f"{prefix}{prompt['statement']}\nPersonality:\n{prompt['description']}{suffix}{finals}"

    print(prompt['statement'])

Consider the following story:
Growing up, I was the kid who had a perpetual glow of excitement around him, kindled by the prospect of meeting new friends and sparking conversations. I reveled in the richness of human connections and the stories we shared, each one adding a hue to my vibrant tapestry of experiences. Of all the anecdotes from my ebullient childhood, one strikingly colorful memory continues to influence my life—a lemonade stand project that turned into a community festival.

It was a radiant summer morning, the sun painted the sky a brilliant cerulean blue, almost as if it too celebrated the day's promise. At the tender age of ten, armed with just a plastic pitcher, packets of lemonade mix, and an infectious enthusiasm, I set out to conquer the business world one cup at a time. Our family's old wooden table, draped in a zesty yellow tablecloth, served as my headquarters, located strategically at the corner of our bustling neighborhood street.

The initial motive was simpl

In [29]:
# Load GPT-2 through TransformerLens and note which device its weights live on.
model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
device = next(model.parameters()).device

# Site of the sparse autoencoder: which layer and which activation location.
layer_index = 6
location = "resid_post_mlp"

# Checkpoint to use. For the 32k variant swap in
# sparse_autoencoder.paths.v5_32k(location, layer_index) instead.
autoencoder_path = sparse_autoencoder.paths.v5_128k(location, layer_index)

with bf.BlobFile(autoencoder_path, mode="rb") as f:
    state_dict = torch.load(f)
    # Build the autoencoder from the loaded weights and move it next to the model.
    autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
    autoencoder.to(device)

Loaded pretrained model gpt2 into HookedTransformer


In [30]:
# Directory that receives the activation JSON file written by the loop below.
output_folder = '/disk/nfs/gazinasvolume2/s2521923/Data/sparse_ae'
# Project-local helper; it is called below with (json_path, activations_dict).
# NOTE(review): presumably it merges the dict into the existing file - confirm in utils.
from utils import update_json_file

def process_input(model, prompt):
    """Tokenize ``prompt`` and run ``model`` on it, collecting its activations.

    Args:
        model: Object exposing ``to_tokens``, ``to_str_tokens`` and
            ``run_with_cache`` (a TransformerLens ``HookedTransformer`` in
            this notebook).
        prompt: Text to tokenize and feed through the model.

    Returns:
        A tuple ``(tokens_id, tokens_str, activation_cache)``: the result of
        ``model.to_tokens`` (noted by the original author as shaped
        ``(1, n_tokens)``), the result of ``model.to_str_tokens``, and the
        cache returned by ``model.run_with_cache`` with the batch dimension
        removed.
    """
    token_ids = model.to_tokens(prompt)
    token_strings = model.to_str_tokens(prompt)

    # Inference only: no gradients are needed, and the logits are discarded.
    with torch.no_grad():
        _, cache = model.run_with_cache(token_ids, remove_batch_dim=True)

    return token_ids, token_strings, cache

def get_activation(activation_cache, layer_index=6, location="resid_post_mlp"):
    """Look up one layer's activations in a cache by location name.

    Args:
        activation_cache: Mapping indexed by hook name, e.g. the cache
            returned by ``process_input``.
        layer_index: Block number substituted into the hook name.
        location: One of ``"mlp_post_act"``, ``"resid_delta_attn"``,
            ``"resid_post_attn"``, ``"resid_delta_mlp"``, ``"resid_post_mlp"``.

    Returns:
        ``activation_cache`` indexed by the hook name for that location.

    Raises:
        KeyError: If ``location`` is not one of the names above, or the hook
            name is missing from ``activation_cache``.
    """
    # Location name -> hook name for the requested layer.
    hook_names = {
        "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post",
        "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out",
        "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid",
        "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out",
        "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post",
    }
    hook_name = hook_names[location]
    return activation_cache[hook_name]

@torch.no_grad()
def encode_decode(autoencoder, input_tensor):
    """Pass ``input_tensor`` through ``autoencoder`` with gradients disabled.

    Args:
        autoencoder: Object whose ``encode(x)`` returns ``(latents, info)``
            and whose ``decode(latents, info)`` returns the reconstruction.
        input_tensor: Activations to encode.

    Returns:
        A tuple ``(latent_activations, reconstructed_activations)``.
    """
    latents, encode_info = autoencoder.encode(input_tensor)
    reconstruction = autoencoder.decode(latents, encode_info)
    return latents, reconstruction

def calculate_normalized_mse(input_tensor, reconstructed_activations):
    """Return the reconstruction error relative to the input's squared norm.

    Both sums run over dimension 1, so the result has one entry per index
    along dimension 0.

    Args:
        input_tensor: Original activations.
        reconstructed_activations: Reconstruction of ``input_tensor``.

    Returns:
        Tensor of ``sum((reconstruction - input)^2) / sum(input^2)`` values.
    """
    squared_error = (reconstructed_activations - input_tensor).pow(2).sum(dim=1)
    squared_norm = (input_tensor).pow(2).sum(dim=1)
    return squared_error / squared_norm

def extract_activations(prompt, tokens, latent_activations, top_k=32, activation_threshold=3):
    """Collect the strongest activations of every feature for one prompt.

    For each feature (column of ``latent_activations``) the positions whose
    activation is non-zero and at least ``activation_threshold`` are ranked,
    and up to ``top_k`` of the largest are kept and reported by token string.

    Args:
        prompt: Identifier used as the key for this prompt in the result.
        tokens: Sequence of token strings; ``tokens[i]`` labels row ``i`` of
            ``latent_activations``.
        latent_activations: 2-D tensor indexed as ``[position, feature]``.
        top_k: Maximum number of positions kept per feature.
        activation_threshold: Minimum activation for a position to qualify.

    Returns:
        ``{"Feature <index>": {prompt: {token: activation}}}`` containing only
        features with at least one qualifying position. When the same token
        string is selected at several positions, its largest activation is
        reported.

    Side effects:
        Prints the total number of selected activations, counted before
        repeated tokens are merged.
    """
    activations_dict = {}
    total_activations_count = 0

    # Positions that are non-zero and reach the threshold.
    keep_mask = (latent_activations != 0) & (latent_activations >= activation_threshold)

    # Visit only features with at least one qualifying position, instead of
    # looping over every column in Python. nonzero() yields ascending indices,
    # so the result keeps ascending feature order.
    candidate_features = keep_mask.any(dim=0).nonzero(as_tuple=True)[0].tolist()

    for feature_index in candidate_features:
        feature_activations = latent_activations[:, feature_index]

        # Track the qualifying positions themselves so each selected value maps
        # back to its own row. Looking positions up by value would be ambiguous
        # whenever two positions share the same activation.
        positions = keep_mask[:, feature_index].nonzero(as_tuple=True)[0]
        top_k_values, top_k_slots = torch.topk(
            feature_activations[positions], min(top_k, positions.numel())
        )
        token_positions = positions[top_k_slots].tolist()

        # The result is keyed by token string, so keep the largest activation
        # when a token repeats rather than whichever value was written last.
        per_token = {}
        for value, token_index in zip(top_k_values.tolist(), token_positions):
            token = tokens[token_index]
            if token not in per_token or value > per_token[token]:
                per_token[token] = value

        activations_dict[f"Feature {feature_index}"] = {prompt: per_token}
        total_activations_count += len(token_positions)

    # Report how many activations were selected across all features.
    print(f"Total activations extracted: {total_activations_count}")

    return activations_dict

# Every prompt's activations go into the same JSON file, so build the path once
# rather than on every iteration.
activations_file_name = 'activations_128k_with_personality.json'
activations_file_path = os.path.join(output_folder, activations_file_name)

# `count` is the 1-based prompt number; its string form keys the prompt in the
# output. It is initialised so the name exists even if `prompts` is empty.
count = 0
for count, prompt in enumerate(prompts, start=1):
    tokens_id, tokens_str, activation_cache = process_input(model, prompt['statement'])

    # Pass `location` explicitly so the activations are read from the same site
    # that was used to select the autoencoder, not from the function default.
    activation = get_activation(activation_cache, layer_index, location)
    latent_activations, reconstructed_activations = encode_decode(autoencoder, activation)

    # Shape and sparsity diagnostics for this prompt.
    print(latent_activations.shape)
    print(activation.shape)
    print(reconstructed_activations.shape)
    non_zero_count = (latent_activations != 0).sum().item()
    print("Non-zero activation count:", non_zero_count)

    activations_dict = extract_activations(str(count), tokens_str, latent_activations, top_k=5)

    update_json_file(activations_file_path, activations_dict)


torch.Size([774, 131072])
torch.Size([774, 768])
torch.Size([774, 768])
Non-zero activation count: 24768
Total activations extracted: 2218
torch.Size([775, 131072])
torch.Size([775, 768])
torch.Size([775, 768])
Non-zero activation count: 24800
Total activations extracted: 2377
torch.Size([827, 131072])
torch.Size([827, 768])
torch.Size([827, 768])
Non-zero activation count: 26464
Total activations extracted: 2375
torch.Size([729, 131072])
torch.Size([729, 768])
torch.Size([729, 768])
Non-zero activation count: 23328
Total activations extracted: 2164
torch.Size([735, 131072])
torch.Size([735, 768])
torch.Size([735, 768])
Non-zero activation count: 23520
Total activations extracted: 2113
torch.Size([760, 131072])
torch.Size([760, 768])
torch.Size([760, 768])
Non-zero activation count: 24320
Total activations extracted: 2232
torch.Size([771, 131072])
torch.Size([771, 768])
torch.Size([771, 768])
Non-zero activation count: 24672
Total activations extracted: 2236
torch.Size([728, 131072])
t