In [1]:
import json
import os
import warnings
from collections import Counter, defaultdict
from itertools import combinations

import jsonlines
import numpy as np
from scipy.spatial import distance
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [2]:
def read_memmap(filepath):
    """Open a read-only ``np.memmap`` whose layout is described by a sibling JSON config.

    The config file sits next to the data file, with the data file's extension replaced
    by ``.conf``. It must contain a ``"shape"`` list and a ``"dtype"`` string.

    Args:
        filepath: Path to the ``.dat`` data file.

    Returns:
        A read-only ``np.memmap`` with the configured shape and dtype.
    """
    # Swap only the trailing extension. ``str.replace(".dat", ...)`` would also rewrite any
    # ".dat" earlier in the path (e.g. a directory named "run.data" -> "run.confa").
    config_path = os.path.splitext(filepath)[0] + ".conf"
    with open(config_path, "r") as fin_config:
        memmap_configs = json.load(fin_config)
    return np.memmap(filepath, mode="r", shape=tuple(memmap_configs["shape"]), dtype=memmap_configs["dtype"])

In [3]:
# Number of prompt formats per example; each predictions dict is keyed "0" .. str(NUM_FORMATS - 1).
NUM_FORMATS: int = 8

In [4]:
# Root folders, relative to this notebook; per-model outputs are read from under `result_dir`.
result_dir, input_dir = "../../../results", "../../../preprocessed_datasets"

In [5]:
model_name = "Llama-3.1-8B-Instruct"
dataset_name = "CommonsenseQA"
prompting_strategy = "zero-shot"

output_dir = f"{result_dir}/{dataset_name}/{model_name}"
predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions_validation.jsonl")

# example id -> that example's "predictions" dict (per-format keys "0" .. str(NUM_FORMATS - 1),
# possibly plus extra entries such as "majority_voting").
id_predictions_map = {}
try:
    with jsonlines.open(predictions_path) as fin:
        for example in fin.iter():
            id_predictions_map[example["id"]] = example["predictions"]
except FileNotFoundError:
    # Best effort: continue without predictions, but make it visible instead of silently
    # producing empty arrays below. Malformed records now raise rather than being
    # swallowed by a bare `except`.
    warnings.warn(f"Predictions file not found, continuing without predictions: {predictions_path}")

# Pairwise agreement labels, shape (num_examples, C(NUM_FORMATS, 2)): for each unordered pair
# of prompt formats (i, j), in `itertools.combinations` order, 1 if the model gave the same
# prediction under both formats, else 0.
format_pairs = list(combinations(range(NUM_FORMATS), 2))
labels = np.array(
    [[int(preds[str(i)] == preds[str(j)]) for i, j in format_pairs] for preds in id_predictions_map.values()],
    dtype=int,
).reshape(-1, len(format_pairs))  # keeps a 2-D shape even when there are no examples


In [6]:
def expand_pairwise_embeddings(hidden_states, num_formats=8):
    """Concatenate embeddings for every unordered pair of prompt formats.

    Args:
        hidden_states: Array of shape ``(num_samples, F, ..., D)`` with ``F >= num_formats``.
            Axis 1 indexes the prompt format and the last axis is concatenated.
        num_formats: How many leading formats on axis 1 to pair up (default 8, matching
            the previously hard-coded value).

    Returns:
        ``np.ndarray`` of shape ``(num_samples, P, ..., 2 * D)`` with
        ``P = num_formats * (num_formats - 1) // 2``. Entry ``[s, p]`` is
        ``concatenate([hidden_states[s, i], hidden_states[s, j]], axis=-1)``, where ``(i, j)``
        is the ``p``-th pair of ``itertools.combinations(range(num_formats), 2)``.
    """
    pairs = list(combinations(range(num_formats), 2))
    dim = hidden_states.shape[-1]
    out_shape = (hidden_states.shape[0], len(pairs)) + tuple(hidden_states.shape[2:-1]) + (2 * dim,)

    # Preallocate once and fill pair by pair. This avoids a per-sample Python loop and the
    # extra full-size copy made by stacking a list of per-sample arrays.
    pairwise = np.empty(out_shape, dtype=hidden_states.dtype)
    for p, (i, j) in enumerate(pairs):
        pairwise[:, p, ..., :dim] = hidden_states[:, i]
        pairwise[:, p, ..., dim:] = hidden_states[:, j]
    return pairwise

In [7]:
# SAE hidden states stored as a raw memmap; read_memmap takes shape/dtype from the sibling .conf file.
layer_wise_path = os.path.join(
    output_dir,
    f"{prompting_strategy}_layer_wise_sae_hidden_states_validation.dat",
)
layer_wise_hidden_states = read_memmap(layer_wise_path)

# Every format pair with the two embeddings concatenated on the last axis.
# NOTE(review): at the shapes printed in the next cell this materialises ~14e9 elements in RAM
# (~28 GB if float16) and is not used by the steering-direction cells visible here.
pairwise_layer_wise_hidden_states = expand_pairwise_embeddings(layer_wise_hidden_states)

# Head-wise variant (disabled):
# head_wise_path = os.path.join(output_dir, f"{prompting_strategy}_head_wise_hidden_states_validation.dat")
# head_wise_hidden_states = read_memmap(head_wise_path)
# pairwise_head_wise_hidden_states = expand_pairwise_embeddings(head_wise_hidden_states)

In [8]:
# Sanity-check shapes of the raw and pairwise hidden states.
(layer_wise_hidden_states.shape, pairwise_layer_wise_hidden_states.shape)

((244, 8, 32, 32768), (244, 28, 32, 65536))

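## Computing the steering direction

The steering direction is the mean SAE activation over majority-format rows minus the mean over minority-format rows. Only examples whose formats disagree are used.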

In [12]:
# For each example and prompt format: 1.0 if that format's prediction is shared by a strict
# majority of the NUM_FORMATS formats, else 0.0. Shape: (num_examples, NUM_FORMATS).
# Only the per-format keys "0" .. str(NUM_FORMATS - 1) are counted. Auxiliary entries such as
# "majority_voting" are therefore ignored without popping them out of `id_predictions_map`;
# the previous in-place `pop` made this cell mutate shared state.
majority_rows = []
for predictions in id_predictions_map.values():
    format_predictions = [predictions[str(f)] for f in range(NUM_FORMATS)]
    majority_rows.append(
        [float(format_predictions.count(p) > NUM_FORMATS / 2) for p in format_predictions]
    )
majority_minority_array = np.array(majority_rows, dtype=float).reshape(-1, NUM_FORMATS)

In [31]:
def extract_embeddings_by_class(embeddings, majority_minority_array):
    """Split per-format embeddings into majority (positive) and minority (negative) rows.

    Only "mixed" samples are kept, i.e. those with at least one majority and at least one
    minority format. Samples whose formats all share one label provide no contrast.

    Args:
        embeddings: Array of shape ``(N, F, ...)`` whose first two axes align with
            ``majority_minority_array``. Any number of trailing axes is allowed
            (e.g. ``(N, F, D)`` or ``(N, F, L, H, D)``).
        majority_minority_array: Array of shape ``(N, F)``. Values > 0.5 mark majority
            formats and values < 0.5 mark minority formats.

    Returns:
        ``(positive_embeddings, negative_embeddings)``, each of shape ``(K, ...)``. They hold
        the selected ``(sample, format)`` rows in row-major order. Samples are not re-weighted,
        so a sample contributes one row per format with the corresponding label.
    """
    labels = np.asarray(majority_minority_array)
    per_sample_mean = labels.mean(axis=1)
    mixed = (per_sample_mean > 0) & (per_sample_mean < 1)

    selected_embeddings = embeddings[mixed]
    selected_labels = labels[mixed]

    # A 2-D boolean mask over the leading (sample, format) axes yields rows in row-major order.
    # This matches reshaping to (selected_N * F, ...) and applying a flat mask.
    positive_embeddings = selected_embeddings[selected_labels > 0.5]
    negative_embeddings = selected_embeddings[selected_labels < 0.5]
    return positive_embeddings, negative_embeddings

In [32]:
# Shape of the raw SAE hidden-state memmap.
np.shape(layer_wise_hidden_states)

(244, 8, 32, 32768)

In [86]:
# Index along axis 2 of `layer_wise_hidden_states` (presumably the layer axis, per the variable
# name) used to build the steering direction.
STEERING_LAYER = 0

pos, neg = extract_embeddings_by_class(layer_wise_hidden_states[:, :, STEERING_LAYER, :], majority_minority_array)
# If either class were empty, its mean would be NaN and a NaN direction would be saved silently.
assert len(pos) > 0 and len(neg) > 0, "need both majority and minority rows to build a steering direction"
steering_direction = pos.mean(0) - neg.mean(0)

steering_direction_path = os.path.join(output_dir, f"{prompting_strategy}_SAE_steering_direction.npy")
np.save(steering_direction_path, steering_direction)

In [88]:
# Indices of SAE features with non-zero mean activation over the majority rows.
np.nonzero(pos.mean(axis=0))

(array([ 1404,  2133,  3052,  3553, 10055, 10494, 12286, 13744, 15820,
        18457, 19999, 21820, 26148, 26675]),)

In [97]:
# Spot-check one SAE feature: mean activation over majority rows vs. minority rows.
idx = 13744
(pos.mean(axis=0)[idx], neg.mean(axis=0)[idx])

(0.852, 0.844)

In [100]:
# List every SAE feature with a non-zero steering weight, one "index <tab> value" per line.
for idx in np.flatnonzero(steering_direction):
    print(idx, '\t', steering_direction[idx])


1404 	 0.01563
2133 	 0.02313
3052 	 -0.21
3553 	 0.02333
10055 	 -0.005127
10494 	 -0.219
12286 	 -0.04297
13744 	 0.007812
15820 	 -0.0332
16486 	 -0.0694
18457 	 0.04678
19999 	 0.001953
21820 	 0.001953
26148 	 0.03027


In [37]:
# All pairwise differences: result[a, b] = pos[a] - neg[b], shape (n_pos, n_neg, D).
result = pos[:, np.newaxis] - neg[np.newaxis]

In [77]:
# One row per (positive, negative) pair: shape (n_pos * n_neg, D).
flattened = np.reshape(result, (-1, result.shape[-1]))

In [80]:
# Inspect only the first (positive, negative) difference vector: print each non-zero SAE
# feature index together with its difference value.
for Z in flattened[:1]:
    nonzero_idx = np.flatnonzero(Z)
    fired = np.column_stack((nonzero_idx, Z[nonzero_idx]))
    for k, v in fired:
        print(int(k), v)

1404 -0.0625
3052 -2.171875
12286 -0.40625
13744 -2.0625
15820 -0.53125
19999 -0.03125
21820 -0.140625


In [81]:
# Average the pairwise differences over both pair axes. Mathematically this equals
# pos.mean(0) - neg.mean(0); the two-step reduction can differ slightly by rounding.
z = result.mean(axis=0).mean(axis=0)

In [82]:
# Non-zero entries of `z` as rows of [feature index, value]. NaNs are excluded because both
# comparisons are False for them.
fired = np.array([[int(k), z[k]] for k in np.flatnonzero((z > 0) | (z < 0))])

In [83]:
# One "index <tab> value" line per non-zero feature.
for row in fired:
    idx, value = row
    print(int(idx), "\t", value)

1404 	 0.015838623046875
2133 	 0.02313232421875
3052 	 -0.2109375
3553 	 0.0233306884765625
10055 	 -0.005107879638671875
10494 	 -0.2188720703125
12286 	 -0.045257568359375
13744 	 0.0078125
15820 	 -0.035064697265625
16486 	 -0.06939697265625
18457 	 0.046783447265625
19999 	 0.0003952980041503906
21820 	 0.0018167495727539062
26148 	 0.030120849609375
26675 	 -3.2842159271240234e-05


In [None]:
# NOTE(review): this cell was a never-executed verbatim copy of the cell that builds `fired`
# from `z`. Instead of recomputing it, check that `fired` is still in sync with `z`.
assert len(fired) == np.count_nonzero((z > 0) | (z < 0)), "`fired` is stale; re-run the cell that builds it from `z`"

In [56]:
# Length of the averaged difference vector (one entry per SAE feature).
np.shape(z)

(32768,)

In [60]:
# Largest TOP_K entries of `z`, in ascending order (argsort sorts ascending, so take the tail).
# Slicing from the end works for any vector length instead of assuming exactly 32768 features.
TOP_K = 8
zidx = np.argsort(z)
z[zidx[-TOP_K:]]

array([ 0.8125,  3.656 , 15.19  , 32.12  , 47.06  , 47.2   , 61.12  ,
       94.75  ], dtype=float16)