In [29]:
import pickle as pkl
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
from datasets import load_dataset
from einops import rearrange


In [30]:
# Load the pre-computed LLaMA-7B TruthfulQA (mc2) activation dumps.
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only point these paths at dumps you produced yourself.
ACTIVATION_DIR = '/data/jxf/activations'


def _load_pickle(path):
    """Unpickle `path`, closing the file handle as soon as loading finishes."""
    with open(path, 'rb') as f:
        return pkl.load(f)


# activation_head_wise = _load_pickle(os.path.join(ACTIVATION_DIR, 'llama_7B_tqa_mc2_all_head_wise.pkl'))
activation_head_wise_end = _load_pickle(os.path.join(ACTIVATION_DIR, 'llama_7B_tqa_mc2_all_100_head_wise.pkl'))
activation_categories = _load_pickle(os.path.join(ACTIVATION_DIR, 'llama_7B_tqa_mc2_all_categories.pkl'))
activation_labels = np.load(os.path.join(ACTIVATION_DIR, 'llama_7B_tqa_mc2_all_labels.npy'))
activation_tokens = _load_pickle(os.path.join(ACTIVATION_DIR, 'llama_7B_tqa_mc2_all_tokens.pkl'))

In [31]:
def get_separated_activations(labels, head_wise_activations, categories, dataset=None):
    """Group flat, per-answer sequences into one chunk per TruthfulQA question.

    The number of answers belonging to each question is read from
    ``row['mc2_targets']['labels']`` of every dataset row; the inputs are cut
    at the resulting cumulative boundaries.

    Args:
        labels: flat sequence with one label per answer.
        head_wise_activations: array (or list of equally shaped arrays) with
            one entry per answer along the first axis.
        categories: flat sequence sliced at the same boundaries as ``labels``.
        dataset: optional iterable of rows exposing
            ``row['mc2_targets']['labels']``. When omitted, the TruthfulQA
            ``multiple_choice`` validation split is loaded.

    Returns:
        ``(separated_head_wise_activations, separated_labels,
        separated_categories, idxs_to_split_at)``. The three lists hold one
        element per question; ``idxs_to_split_at`` is the cumulative answer
        count, whose last value is the total number of answers.

    Raises:
        ValueError: if ``labels`` or ``head_wise_activations`` do not contain
            exactly one entry per answer in the dataset.
    """
    if dataset is None:
        dataset = load_dataset('truthful_qa', 'multiple_choice')['validation']

    # Number of candidate answers per question, in dataset order.
    answers_per_question = [len(row['mc2_targets']['labels']) for row in dataset]
    idxs_to_split_at = np.cumsum(answers_per_question, dtype=int)
    total_answers = int(idxs_to_split_at[-1]) if len(idxs_to_split_at) else 0

    # A length mismatch means the inputs are misaligned with the dataset and
    # every later chunk would be cut at the wrong place, so fail loudly.
    if len(labels) != total_answers or len(head_wise_activations) != total_answers:
        raise ValueError(
            f'expected {total_answers} answers, got {len(labels)} labels and '
            f'{len(head_wise_activations)} activations'
        )

    labels = list(labels)
    categories = list(categories)

    # Chunk k spans [starts[k], idxs_to_split_at[k]).
    starts = [0, *idxs_to_split_at[:-1]]
    separated_labels = [labels[lo:hi] for lo, hi in zip(starts, idxs_to_split_at)]
    separated_categories = [categories[lo:hi] for lo, hi in zip(starts, idxs_to_split_at)]

    # np.split cuts *at* the given indices, so passing the final cumulative
    # value would append an empty trailing chunk; only interior boundaries
    # are passed to keep one chunk per question.
    if len(idxs_to_split_at):
        separated_head_wise_activations = np.split(head_wise_activations, idxs_to_split_at[:-1])
    else:
        separated_head_wise_activations = []

    return separated_head_wise_activations, separated_labels, separated_categories, idxs_to_split_at


In [32]:
def find_q_pos(tokens):
    """Locate, per token sequence, the token that closes the question.

    A position ``i`` matches when ``token_list[i]`` contains ``'?'`` or is
    exactly ``'.'`` and the following token is ``'▁A'`` (presumably the start
    of the answer marker in the prompt -- confirm against the tokenizer).
    Only the first match of each sequence is recorded.

    Args:
        tokens: iterable of token lists (lists of strings).

    Returns:
        List of matching indices, in input order. A sequence without a match
        contributes no entry (an error line is printed instead), so the result
        can be shorter than ``tokens``.
    """
    positions = []
    for token_list in tokens:
        for i in range(len(token_list) - 1):
            closes_question = '?' in token_list[i] or token_list[i] == '.'
            if closes_question and token_list[i + 1] == '▁A':
                positions.append(i)  # record the position that was found
                break  # assume at most one matching position per sequence
        else:
            # Runs only when the scan finished without `break`. This also
            # covers sequences shorter than two tokens and avoids a false
            # alarm when the match sits at the last valid index.
            print('Error: cannot find question')
    return positions

In [33]:
# Disabled cell: would slice every activation tensor at the position returned
# by find_q_pos (`activations[:, pos, :]`) and pickle the result.
# It reads `activation_head_wise`, whose load statement is also commented out,
# so both have to be re-enabled together.
# q_pos = find_q_pos(activation_tokens)
# q_activation_head_wise = [activations[:, pos, :] for activations, pos in zip(activation_head_wise, q_pos)]
# pkl.dump(q_activation_head_wise, open('/data/jxf/activations/llama_7B_tqa_mc2_question_head_wise.pkl', 'wb'))

In [34]:
# Sanity check: shape of the first loaded activation entry (the recorded output is (32, 4096)).
# NOTE(review): presumably one entry per answer, laid out as (layers, heads * head_dim) -- confirm.
activation_head_wise_end[0].shape

(32, 4096)

In [35]:
# Split the flat activation / label / category sequences into per-question groups.
(
    separated_head_wise_activations,
    separated_labels,
    separated_categories,
    idxs_to_split_at,
) = get_separated_activations(
    labels=activation_labels,
    head_wise_activations=activation_head_wise_end,
    categories=activation_categories,
)

Found cached dataset truthful_qa (/data/wtl/hf_cache/truthful_qa/multiple_choice/1.1.0/63502f6bc6ee493830ce0843991b028d0ab568d221896b2ee3b8a5dfdaa9d7f4)
100%|██████████| 1/1 [00:00<00:00, 726.54it/s]


In [36]:
# One direction per question: mean activation over the answers labelled 1
# minus mean activation over the answers labelled 0.
head_wise_activation_directions = []
for question_activations, question_labels in zip(separated_head_wise_activations, separated_labels):
    label_array = np.array(question_labels)
    mean_labelled_one = question_activations[label_array == 1].mean(axis=0)
    mean_labelled_zero = question_activations[label_array == 0].mean(axis=0)
    head_wise_activation_directions.append(mean_labelled_one - mean_labelled_zero)

In [37]:
# Split the last axis into (head, per-head dim) so every head gets its own vector.
# NOTE: the result is bound to the same name, so this cell is not idempotent:
# the pattern expects three axes, and re-running it on the four-axis result fails.
num_heads = 32
head_wise_activation_directions = rearrange(
    head_wise_activation_directions,
    'b s (h d) -> b s h d',
    h=num_heads,
)

In [38]:
# Sanity check of the rearranged array, axes 'b s h d' (the recorded output is (817, 32, 32, 128)).
head_wise_activation_directions.shape

(817, 32, 32, 128)

In [42]:
# Validation split of the TruthfulQA multiple-choice config; its row order is
# what indexes the first axis of head_wise_activation_directions below.
dataset = load_dataset("truthful_qa", "multiple_choice")['validation']
# TruthfulQA table, read relative to the notebook's working directory; rows are
# matched to the dataset through the 'Question' column in the next cell.
df = pd.read_csv('./TruthfulQA/data/v0/TruthfulQA.csv')

Found cached dataset truthful_qa (/data/wtl/hf_cache/truthful_qa/multiple_choice/1.1.0/63502f6bc6ee493830ce0843991b028d0ab568d221896b2ee3b8a5dfdaa9d7f4)
100%|██████████| 1/1 [00:00<00:00, 568.80it/s]


In [43]:
# Attach each question's direction (as a nested list) to the matching CSV rows.
# Directions are indexed by dataset row, so both must have the same length.
assert len(dataset) == len(head_wise_activation_directions), \
    'expected one direction per dataset question'

# One lookup table instead of re-scanning the whole frame for every question.
# If a question text occurs more than once, the last occurrence wins.
direction_by_question = {}
for q, direction in tqdm(zip(dataset['question'], head_wise_activation_directions), total=len(dataset)):
    direction_by_question[q] = direction.tolist()

# Build the column in one pass as an object column. Pre-filling an int column
# and then writing lists through a boolean .loc mask relies on implicit dtype
# upcasting, which newer pandas versions deprecate. Rows whose question text
# is absent from the dataset keep the placeholder 0.
df['Direction'] = df['Question'].map(lambda q: direction_by_question.get(q, 0))

n_unmatched = sum(q not in direction_by_question for q in df['Question'])
if n_unmatched:
    print(f'Warning: {n_unmatched} CSV rows have no matching dataset question')


100%|██████████| 817/817 [00:27<00:00, 29.92it/s]


In [None]:
# Persist the table including the new 'Direction' column. `index` is left at
# its default, so the row index is written as an extra leading column.
# NOTE(review): this cell has no execution count (In [None]) -- confirm the file was actually written.
df.to_csv('./TruthfulQA/data/v0/TruthfulQA_head_wise_end_direction.csv')

In [None]:
# Persist the (question, s, head, dim) direction array. The context manager
# guarantees the file is flushed and closed even if pickling fails midway.
with open('/data/jxf/activations/llama_7B_tqa_mc2_all_100_head_wise_directions.pkl', 'wb') as f:
    pkl.dump(head_wise_activation_directions, f)