In [1]:
import os
import json
import jsonlines

from itertools import combinations
from collections import defaultdict, Counter
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix
from scipy.spatial import distance

In [2]:
def read_memmap(filepath):
    """Open a raw array file as a read-only ``np.memmap``.

    The array's shape and dtype are read from a JSON sidecar file that sits
    next to ``filepath`` and has the same name with its final extension
    swapped for ``.conf`` (``foo.dat`` -> ``foo.conf``).

    Args:
        filepath: Path to the raw data file.

    Returns:
        A read-only ``np.memmap`` over ``filepath`` with the shape and dtype
        taken from the sidecar's ``"shape"`` and ``"dtype"`` entries.

    Raises:
        FileNotFoundError: If the sidecar (or the data file) does not exist.
        KeyError: If the sidecar lacks a ``"shape"`` or ``"dtype"`` entry.
    """
    # Swap only the trailing extension. ``str.replace`` would rewrite every
    # ".dat" in the path -- including one inside a directory name such as
    # "run.data/" -- and point at a sidecar that does not exist.
    config_path = os.path.splitext(filepath)[0] + ".conf"
    with open(config_path, "r") as fin_config:
        memmap_configs = json.load(fin_config)

    return np.memmap(
        filepath,
        mode="r",
        shape=tuple(memmap_configs["shape"]),
        dtype=memmap_configs["dtype"],
    )

In [3]:
# Number of per-example prediction slots; a later cell reads each example's
# predictions under the string keys "0" .. str(NUM_FORMATS - 1).
# NOTE(review): presumably one slot per prompt format -- confirm.
NUM_FORMATS = 8

In [4]:
# Root folders, given relative to the notebook's working directory.
# result_dir is the parent of the per-dataset / per-model output folder.
result_dir = "../../../results"
# NOTE(review): input_dir is not referenced by any cell visible in this notebook.
input_dir = "../../../preprocessed_datasets"

In [5]:
# Run configuration: which model / dataset / prompting strategy to analyse.
model_name = "Llama-3.1-8B-Instruct"
dataset_name = "CommonsenseQA"
prompting_strategy = "zero-shot"

output_dir = f"{result_dir}/{dataset_name}/{model_name}"
predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions_validation.jsonl")

# Map each example id to its "predictions" record from the JSONL file.
id_predictions_map = {}
try:
    with jsonlines.open(predictions_path) as fin:
        for example in fin.iter():
            id_predictions_map[example["id"]] = example["predictions"]
except Exception as exc:
    # Loading stays best-effort (records read before the failure are kept),
    # but the failure is reported instead of being swallowed silently.
    print(f"[warning] could not fully load {predictions_path}: {exc!r}")

# For every example, one 0/1 label per unordered pair of format slots:
# 1 when both slots hold the same prediction, 0 otherwise.
format_pairs = list(combinations(range(NUM_FORMATS), 2))
labels = np.array([
    [int(preds[str(i)] == preds[str(j)]) for i, j in format_pairs]
    for preds in id_predictions_map.values()
])


In [6]:
def expand_pairwise_embeddings(hidden_states, num_formats=None):
    """Concatenate the embeddings of every unordered pair of formats.

    Args:
        hidden_states: Array of shape ``(num_samples, num_formats, ..., dim)``.
        num_formats: How many leading entries of axis 1 to pair up. Defaults
            to ``hidden_states.shape[1]`` so the function is not tied to a
            fixed number of formats.

    Returns:
        Array of shape ``(num_samples, num_pairs, ..., 2 * dim)`` with
        ``num_pairs = C(num_formats, 2)``. Pairs follow
        ``itertools.combinations`` order, and entry ``[s, p]`` is
        ``hidden_states[s, i]`` followed by ``hidden_states[s, j]`` along the
        last axis for the p-th pair ``(i, j)``.
    """
    if num_formats is None:
        num_formats = hidden_states.shape[1]

    # (num_pairs, 2) table of format indices; the reshape keeps it 2-D even
    # when there are no pairs at all (fewer than two formats).
    pair_index = np.array(
        list(combinations(range(num_formats), 2)), dtype=np.intp
    ).reshape(-1, 2)
    first, second = pair_index[:, 0], pair_index[:, 1]

    # Gather both members of every pair for all samples at once, then join
    # them along the feature axis. This also works for an empty batch.
    return np.concatenate(
        [hidden_states[:, first], hidden_states[:, second]], axis=-1
    )

In [7]:
# Memory-map the stored hidden states for this run. read_memmap takes the
# array's shape and dtype from the ".conf" file that accompanies the ".dat".
layer_wise_path = os.path.join(output_dir, f"{prompting_strategy}_layer_wise_sae_hidden_states_validation.dat")
# head_wise_path = os.path.join(output_dir, f"{prompting_strategy}_head_wise_hidden_states_validation.dat")

layer_wise_hidden_states = read_memmap(layer_wise_path)
# head_wise_hidden_states = read_memmap(head_wise_path)

# Disabled: the head-wise variant and the pairwise expansion are not used by
# the cells visible in this notebook.
# pairwise_layer_wise_hidden_states = expand_pairwise_embeddings(layer_wise_hidden_states)
# pairwise_head_wise_hidden_states = expand_pairwise_embeddings(head_wise_hidden_states)

In [8]:
# Shape check; the recorded output is (244, 8, 32, 32768).
# NOTE(review): axes look like (example, format, layer/SAE index, SAE feature) -- confirm.
layer_wise_hidden_states.shape#, pairwise_layer_wise_hidden_states.shape

(244, 8, 32, 32768)

## Computing Steering Direction

In [9]:
# Label every (example, format) slot: 1.0 when that slot's prediction is shared
# by a strict majority of the NUM_FORMATS slots of the same example, else 0.0.
# Only the numbered slots "0" .. str(NUM_FORMATS - 1) are counted, so an extra
# entry such as "majority_voting" is ignored without mutating id_predictions_map.
majority_threshold = NUM_FORMATS / 2
format_keys = [str(slot) for slot in range(NUM_FORMATS)]

majority_rows = []
for predictions in id_predictions_map.values():
    votes = [predictions[key] for key in format_keys]
    # list.count (not a hash-based counter) so unhashable predictions still work.
    majority_rows.append([float(votes.count(vote) > majority_threshold) for vote in votes])

# Shape (num_examples, NUM_FORMATS), float entries in {0.0, 1.0}.
majority_minority_array = np.array(majority_rows)

In [10]:
# def extract_embeddings_by_class(embeddings, majority_minority_array):
#     N, F, L, H, D = embeddings.shape

#     flat_embeddings = embeddings.reshape(N * F, L, H, D)
#     flat_labels = majority_minority_array.reshape(N * F)

#     positive_embeddings = flat_embeddings[flat_labels > 0.5]
#     negative_embeddings = flat_embeddings[flat_labels < 0.5]

#     return positive_embeddings, negative_embeddings

def extract_embeddings_by_class(embeddings, majority_minority_array):
    """Split embeddings into label-1 and label-0 rows, using mixed samples only.

    Args:
        embeddings: Array of shape ``(N, F, D)``.
        majority_minority_array: Array of shape ``(N, F)`` holding one label
            per (sample, format) slot; values above 0.5 count as positive and
            values below 0.5 as negative.

    Returns:
        ``(positive_embeddings, negative_embeddings)``, each of shape
        ``(count, D)``. Samples whose labels average to exactly 0 or exactly 1
        (no contrast between formats) are left out. No per-sample balancing is
        applied: every retained (sample, format) slot contributes one row.
    """
    _, num_formats, dim = embeddings.shape

    # Keep only samples that contain both label values.
    label_fraction = majority_minority_array.mean(1)
    has_both_classes = np.logical_and(label_fraction < 1, label_fraction > 0)
    mixed_embeddings = embeddings[has_both_classes]
    mixed_labels = majority_minority_array[has_both_classes]

    # Flatten (sample, format) into a single row axis for embeddings and labels.
    num_rows = int(has_both_classes.sum()) * num_formats
    rows = mixed_embeddings.reshape(num_rows, dim)
    row_labels = mixed_labels.reshape(num_rows)

    return rows[row_labels > 0.5], rows[row_labels < 0.5]

In [11]:
# Shape of the full hidden-state array, before a single index of axis 2 is
# sliced out for the steering-direction computation.
layer_wise_hidden_states.shape

(244, 8, 32, 32768)

In [12]:
# NOTE(review): identical to the preceding cell; one of the two can be dropped.
layer_wise_hidden_states.shape

(244, 8, 32, 32768)

In [13]:
# One row per example and one column per format slot (recorded output: (244, 8)).
majority_minority_array.shape

(244, 8)

In [20]:
# Index along axis 2 of layer_wise_hidden_states to analyse; the same value is
# embedded in the saved file name.
sae_idx = 15

pos, neg = extract_embeddings_by_class(layer_wise_hidden_states[:, :, sae_idx, :], majority_minority_array)

# With either class empty, the mean below would be all-NaN and that NaN vector
# would be written to disk without any error -- fail loudly instead.
if len(pos) == 0 or len(neg) == 0:
    raise ValueError(
        f"Cannot compute a steering direction: {len(pos)} positive and "
        f"{len(neg)} negative embeddings were selected."
    )

# Steering direction = mean positive embedding minus mean negative embedding.
steering_direction = pos.mean(0) - neg.mean(0)

steering_direction_path = os.path.join(output_dir, f"{prompting_strategy}_SAE_steering_direction_{sae_idx}.npy")
np.save(steering_direction_path, steering_direction)

In [21]:
import http.client

In [22]:
# Ten largest per-feature means over the positive embeddings, largest first.
pos_feature_means = pos.mean(0)
i = pos_feature_means.argsort()[::-1]
pos_feature_means[i[:10]]

array([10.766,  4.145,  3.658,  3.367,  2.941,  2.371,  2.336,  2.19 ,
        2.14 ,  1.929], dtype=float16)

In [23]:
# How many entries of the steering direction are non-zero.
nonzero_indices = np.flatnonzero(steering_direction)
nonzero_indices.size

19

In [24]:
# Rank the non-zero entries of the steering direction by absolute value
# (largest first) and print each one with its Neuronpedia description.
nonzero_indices = np.nonzero(steering_direction)[0]
indices = nonzero_indices[np.argsort(np.abs(steering_direction[nonzero_indices]))[::-1]]

for idx in indices[:50]:
    # The timeout keeps one stalled request from hanging the whole cell.
    conn = http.client.HTTPSConnection("www.neuronpedia.org", timeout=10)
    try:
        conn.request("GET", f"/api/feature/llama3.1-8b/{sae_idx}-llamascope-res-32k/{idx}")

        res = conn.getresponse()
        data = res.read()
        output = json.loads(data.decode("utf-8"))
        print(idx, '\t', steering_direction[idx], '\t', output['explanations'][0]['description'])
    except (OSError, http.client.HTTPException, ValueError, LookupError, TypeError):
        # Best-effort lookup: on a network failure, a non-JSON body, or a
        # response without a usable explanation, print index and value only.
        print(idx, '\t', steering_direction[idx])
    finally:
        # Release the socket for every request, successful or not.
        conn.close()

19507 	 0.253 	  programming constructs and functions related to data handling and manipulation
7026 	 0.2305 	  elements related to blogging and online article engagement
6335 	 -0.1934 	 mathematical and scientific concepts related to functions and categories
20752 	 0.168 	 exclamations and punctuations indicating emphasis or a strong emotional response
27173 	 -0.142 	  references to publications and articles
11405 	 -0.1221 	 JavaScript coding patterns and structures related to functions and data management
1954 	 -0.105 	  punctuation marks and question marks
17036 	 -0.0957 	  references to correct answers and evaluation processes in assessments
29766 	 0.09375 	  structures related to code functions and their parameters
7028 	 0.08606 	  code snippets and elements related to programming syntax
20326 	 0.0713 	  significant numerical values or statistics
3179 	 -0.0547 	 references to discussions or mentions of climate change and its impacts
23990 	 0.02745 	  terms related to d

In [None]:
# Activations of a single feature across the negative embeddings (first array)
# and the positive embeddings (second array). Feature 7026 has the
# second-largest absolute value in the ranked listing printed above.
target = 7026
neg[:,target], pos[:,target]

(array([2.297, 0.   , 0.   , 0.   , 0.   , 2.297, 2.547, 0.   , 2.922,
        2.766, 2.531, 2.656, 2.5  , 0.   , 2.234, 2.188, 2.531, 2.688,
        2.906, 2.953, 0.   , 3.344, 3.281, 2.438, 2.156, 2.438],
       dtype=float16),
 array([2.125, 0.   , 2.125, 0.   , 2.172, 0.   , 0.   , 2.328, 2.438,
        2.203, 2.312, 2.234, 2.422, 2.375, 0.   , 0.   , 0.   , 0.   ,
        0.   , 2.328, 2.422, 2.39 , 2.14 , 2.406, 2.766, 2.844, 2.547,
        2.766, 2.766, 2.844, 2.719, 2.75 , 2.484, 2.75 , 2.547, 2.781,
        2.36 , 2.172, 2.422, 2.156, 2.406, 0.   , 2.39 , 2.234, 2.453,
        0.   , 2.344, 0.   , 2.531, 2.797, 2.625, 2.281, 2.578, 2.219,
        2.562, 2.89 , 3.016, 2.89 , 2.969, 2.703, 2.39 , 2.625, 2.203,
        2.453, 2.281, 2.531, 2.438, 3.094, 3.078, 2.906, 3.203, 2.89 ,
        3.234, 2.344, 2.453, 2.297, 2.375, 2.156], dtype=float16))

In [None]:
# Broadcast difference between every positive and every negative embedding:
# result[a, b] == pos[a] - neg[b], so the shape is (len(pos), len(neg), feature_dim).
result = pos[:, None, :] - neg[None, :, :]

In [None]:
# Merge the two pair axes so each row is one (positive, negative) difference vector.
flattened = result.reshape(-1, result.shape[-1])  # shape: [78*26, 32768]

In [None]:
# Show (feature index, difference) for the non-zero entries of the first row only.
for Z in flattened[:1]:
    nonzero_idx = np.flatnonzero(Z)
    # Two columns: feature index and the difference value at that index.
    fired = np.column_stack((nonzero_idx, Z[nonzero_idx]))
    for feature_idx, delta in fired:
        print(int(feature_idx), delta)

3179 0.3125
3998 -0.0625
6335 -2.9375
7026 -0.171875
9478 -0.03125
11627 0.125
14312 -0.1875
15559 0.15625
17036 0.203125
20326 -0.046875


In [None]:
# Average the pairwise differences over both pair axes, leaving one value per feature.
# Algebraically this equals pos.mean(0) - neg.mean(0), up to floating-point rounding.
z = result.mean(0).mean(0)

In [None]:
# Collect [feature index, value] for every entry of z that is strictly
# negative or strictly positive (zeros and NaNs are left out).
fired = np.array([
    [zidx, zz]
    for zidx, zz in enumerate(z)
    if zz < 0 or zz > 0
])

In [None]:
# One line per retained feature: index, a tab, then its value.
for feature_idx, mean_delta in fired:
    print(int(feature_idx), "\t", mean_delta)

1954 	 -0.1051025390625
3179 	 -0.05047607421875
3998 	 -0.025665283203125
6335 	 -0.1934814453125
7026 	 0.2298583984375
7028 	 0.08599853515625
9478 	 -0.01062774658203125
11405 	 -0.12127685546875
11627 	 0.01099395751953125
14312 	 0.019195556640625
15559 	 0.0238189697265625
17036 	 -0.093994140625
17594 	 0.00978851318359375
19507 	 0.2529296875
20326 	 0.0714111328125
20752 	 0.16796875
23990 	 0.0274505615234375
27173 	 -0.14208984375
29766 	 0.09381103515625


In [None]:
# NOTE(review): rebuilds `fired` from `z` the same way an earlier cell does.
# Each row is [feature index, value] for an entry of z that is non-zero (and not NaN).
fired_rows = [[zidx, zz] for zidx, zz in enumerate(z) if zz < 0 or zz > 0]
fired = np.array(fired_rows)

In [None]:
# One entry per feature (recorded output: (32768,)).
z.shape

(32768,)

In [None]:
# Eight largest entries of z, in ascending order. Slicing from the end (rather
# than from the literal offset 32760) keeps this "top 8" for any length of z.
# NOTE(review): the recorded output (values up to 12.65) does not match the
# per-feature values printed earlier (all below 0.3 in magnitude); the cell may
# have been run while `z` held something else -- re-run from a fresh kernel.
zidx = np.argsort(z)
z[zidx[-8:]]

array([ 6.75 ,  6.758,  6.887,  7.83 ,  9.414, 10.1  , 10.63 , 12.65 ],
      dtype=float16)