In [98]:
import os
import json
import jsonlines

from itertools import combinations
from collections import defaultdict, Counter
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix
from scipy.spatial import distance

In [90]:
def read_memmap(filepath):
    with open(filepath.replace(".dat", ".conf"), "r") as fin_config:
        memmap_configs = json.load(fin_config)
        return np.memmap(filepath, mode="r", shape=tuple(memmap_configs["shape"]), dtype=memmap_configs["dtype"])

In [91]:
NUM_FORMATS = 8

In [92]:
result_dir = "../../../results"
input_dir = "../../../preprocessed_datasets"

In [104]:
model_name = "Llama-3.1-8B-Instruct"
dataset_name = "CommonsenseQA"
prompting_strategy = "zero-shot"

output_dir = f"{result_dir}/{dataset_name}/{model_name}"

predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions_validation.jsonl")
try:
    with jsonlines.open(predictions_path) as fin:
        id_predictions_map = {}
        for example in fin.iter():
            id_predictions_map[example["id"]] = example["predictions"]
except:
    pass

labels = []
for preds in id_predictions_map.values():
    pairwise_labels = []
    for i, j in list(combinations(range(8), 2)):
        label = int(preds[str(i)] == preds[str(j)])
        pairwise_labels.append(label)
    labels.append(pairwise_labels)
labels = np.array(labels)


In [113]:
def expand_pairwise_embeddings(hidden_states):
    num_samples = hidden_states.shape[0]

    pairwise_embeddings = []

    for s in range(num_samples):
        sample_pairs = []
        for i, j in list(combinations(range(8), 2)):
            emb_i = hidden_states[s, i]
            emb_j = hidden_states[s, j]
            concat = np.concatenate([emb_i, emb_j], axis=-1)
            sample_pairs.append(concat)
        sample_pairs = np.stack(sample_pairs, axis=0)
        pairwise_embeddings.append(sample_pairs)

    return np.stack(pairwise_embeddings, axis=0)

In [114]:
layer_wise_path = os.path.join(output_dir, f"{prompting_strategy}_layer_wise_hidden_states_validation.dat")
head_wise_path = os.path.join(output_dir, f"{prompting_strategy}_head_wise_hidden_states_validation.dat")

layer_wise_hidden_states = read_memmap(layer_wise_path)
head_wise_hidden_states = read_memmap(head_wise_path)

pairwise_layer_wise_hidden_states = expand_pairwise_embeddings(layer_wise_hidden_states)
pairwise_head_wise_hidden_states = expand_pairwise_embeddings(head_wise_hidden_states)

In [122]:
num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape
num_samples, num_formats, num_layers, num_heads, hidden_size2 = head_wise_hidden_states.shape

layer_accuracy_scores = []
for layer_idx in range(num_layers):
    # if layer_idx not in [31]:
    #     continue

    # Step 1: Prepare input
    X = pairwise_layer_wise_hidden_states[:,:,layer_idx,:]
    X = X.reshape(-1, hidden_size*2)
    Y = labels.reshape(-1)

    # Step 2: 80:20 split for validation to select top-K components
    X_fold1, X_fold2, y_fold1, y_fold2 = train_test_split(X, Y, test_size=0.2, random_state=0)

    for i in range(2):
        # Step 3: Standardization
        scaler = StandardScaler()
        if i == 0:
            X_train = scaler.fit_transform(X_fold1)
            X_test = scaler.transform(X_fold2)
            y_train, y_test = y_fold1, y_fold2
        else:
            break
            X_train = scaler.fit_transform(X_fold2)
            X_test = scaler.transform(X_fold1)
            y_train, y_test = y_fold2, y_fold1

        # Step 4: Train a linear probing model
        clf = LogisticRegression(solver="lbfgs", max_iter=1000)
        clf.fit(X_train, y_train)

        # Step 5: Evaluate the model
        y_pred = clf.predict(X_test)
        accuracy = balanced_accuracy_score(y_test, y_pred)
        layer_accuracy_scores.append(accuracy)
        # print(classification_report(y_test, y_pred))
        # print(confusion_matrix(y_test, y_pred))
    print(layer_accuracy_scores)
    print("*"*20)

[0.8432177844762623]
********************
[0.8432177844762623, 0.843594574227581]
********************
[0.8432177844762623, 0.843594574227581, 0.8568481537302186]
********************
[0.8432177844762623, 0.843594574227581, 0.8568481537302186, 0.8568481537302186]
********************
[0.8432177844762623, 0.843594574227581, 0.8568481537302186, 0.8568481537302186, 0.8826017332328561]
********************
[0.8432177844762623, 0.843594574227581, 0.8568481537302186, 0.8568481537302186, 0.8826017332328561, 0.8564713639788998]
********************
[0.8432177844762623, 0.843594574227581, 0.8568481537302186, 0.8568481537302186, 0.8826017332328561, 0.8564713639788998, 0.8443481537302185]
********************
[0.8432177844762623, 0.843594574227581, 0.8568481537302186, 0.8568481537302186, 0.8826017332328561, 0.8564713639788998, 0.8443481537302185, 0.8451017332328561]
********************
[0.8432177844762623, 0.843594574227581, 0.8568481537302186, 0.8568481537302186, 0.8826017332328561, 0.856471363

In [127]:
print(sorted(layer_accuracy_scores, reverse=True)[:10])
print()
print(sorted(layer_accuracy_scores, reverse=False)[:10])

[0.9068481537302185, 0.8947249434815373, 0.8947249434815373, 0.8947249434815373, 0.8943481537302186, 0.8943481537302186, 0.8943481537302186, 0.8943481537302186, 0.8943481537302186, 0.8943481537302186]

[0.8432177844762623, 0.843594574227581, 0.8443481537302185, 0.8451017332328561, 0.8564713639788998, 0.8564713639788998, 0.8568481537302186, 0.8568481537302186, 0.8568481537302186, 0.857601733232856]


In [125]:
num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape
num_samples, num_formats, num_layers, num_heads, hidden_size2 = head_wise_hidden_states.shape

head_accuracy_scores = defaultdict(list)
for layer_idx in range(num_layers):
    # if layer_idx not in [31]:
    #     continue
    for head_idx in range(num_heads):
        # Step 1: Prepare input
        X = pairwise_head_wise_hidden_states[:,:,layer_idx,head_idx,:]
        X = X.reshape(-1, hidden_size2*2)
        Y = labels.reshape(-1)

        # Step 2: 80:20 split for validation to select top-K components
        X_fold1, X_fold2, y_fold1, y_fold2 = train_test_split(X, Y, test_size=0.2, random_state=0)

        for i in range(2):
            # Step 3: Standardization
            scaler = StandardScaler()
            if i == 0:
                X_train = scaler.fit_transform(X_fold1)
                X_test = scaler.transform(X_fold2)
                y_train, y_test = y_fold1, y_fold2
            else:
                break
                X_train = scaler.fit_transform(X_fold2)
                X_test = scaler.transform(X_fold1)
                y_train, y_test = y_fold2, y_fold1

            # Step 4: Train a linear probing model
            clf = LogisticRegression(solver="lbfgs", max_iter=1000)
            clf.fit(X_train, y_train)

            # Step 5: Evaluate the model
            y_pred = clf.predict(X_test)
            accuracy = balanced_accuracy_score(y_test, y_pred)
            head_accuracy_scores[layer_idx].append(accuracy)
            # print(classification_report(y_test, y_pred))
            # print(confusion_matrix(y_test, y_pred))
    print(head_accuracy_scores)
    print("*"*20)

defaultdict(<class 'list'>, {0: [0.8133195177091184, 0.742087415222306, 0.7553409947249434, 0.7557177844762623, 0.7795874152223059, 0.7557177844762623, 0.7553409947249434, 0.6094856819894499, 0.7553409947249434, 0.754587415222306, 0.7667106254709872, 0.7674642049736247, 0.7557177844762623, 0.754587415222306, 0.7307177844762622, 0.7553409947249434, 0.6719856819894499, 0.7193481537302185, 0.7689713639788998, 0.7553409947249434, 0.7530802562170309, 0.7553409947249434, 0.7682177844762623, 0.7792106254709872, 0.7674642049736247, 0.767087415222306, 0.7674642049736247, 0.754587415222306, 0.7424642049736248, 0.5, 0.6106160512434062, 0.5125]})
********************
defaultdict(<class 'list'>, {0: [0.8133195177091184, 0.742087415222306, 0.7553409947249434, 0.7557177844762623, 0.7795874152223059, 0.7557177844762623, 0.7553409947249434, 0.6094856819894499, 0.7553409947249434, 0.754587415222306, 0.7667106254709872, 0.7674642049736247, 0.7557177844762623, 0.754587415222306, 0.7307177844762622, 0.7553

In [137]:
print(sorted(np.array(list(head_accuracy_scores.values())).flatten().tolist(), reverse=True)[:10])
print()
print(sorted(np.array(list(head_accuracy_scores.values())).flatten().tolist(), reverse=False)[:10])

[0.857601733232856, 0.8572249434815373, 0.843594574227581, 0.8432177844762623, 0.8420874152223059, 0.8420874152223059, 0.8314713639788998, 0.8314713639788998, 0.8314713639788998, 0.831094574227581]

[0.49886963074604374, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5106160512434061, 0.5117464204973624]


In [None]:
def get_topk_heads(scores, k=10):
    num_layers, num_heads = scores.shape
    flat_scores = scores.flatten()
    topk_indices = np.argsort(flat_scores)[-k:][::-1]

    topk_info = []
    for idx in topk_indices:
        layer = idx // num_heads
        head = idx % num_heads
        score = scores[layer, head]
        topk_info.append(((layer, head), score))

    return topk_info

aaa = np.array(list(head_accuracy_scores.values()))
get_topk_heads(aaa, 25)

[((24, 26), 0.857601733232856),
 ((18, 29), 0.8572249434815373),
 ((3, 21), 0.843594574227581),
 ((11, 27), 0.8432177844762623),
 ((11, 20), 0.8420874152223059),
 ((16, 11), 0.8420874152223059),
 ((11, 16), 0.8314713639788998),
 ((6, 18), 0.8314713639788998),
 ((11, 26), 0.8314713639788998),
 ((25, 30), 0.831094574227581),
 ((23, 16), 0.831094574227581),
 ((4, 15), 0.831094574227581),
 ((20, 27), 0.831094574227581),
 ((30, 13), 0.831094574227581),
 ((1, 29), 0.831094574227581),
 ((26, 5), 0.831094574227581),
 ((10, 15), 0.8307177844762623),
 ((11, 31), 0.8307177844762623),
 ((13, 30), 0.8303409947249435),
 ((24, 10), 0.8299642049736247),
 ((23, 19), 0.8299642049736247),
 ((10, 8), 0.829587415222306),
 ((22, 12), 0.829587415222306),
 ((11, 9), 0.8204785229841749),
 ((18, 30), 0.8201017332328561)]

In [None]:
def get_topk_layers(layer_scores, k=10):
    topk_indices = np.argsort(layer_scores)[-k:][::-1]
    topk_info = [(idx, layer_scores[idx]) for idx in topk_indices]
    return topk_info

get_topk_layers(layer_accuracy_scores, 10)

[(24, 0.9068481537302185),
 (28, 0.8947249434815373),
 (29, 0.8947249434815373),
 (20, 0.8947249434815373),
 (31, 0.8943481537302186),
 (27, 0.8943481537302186),
 (26, 0.8943481537302186),
 (30, 0.8943481537302186),
 (21, 0.8943481537302186),
 (23, 0.8943481537302186)]