In [1]:
import os
import json
import jsonlines

from itertools import combinations
from collections import defaultdict, Counter
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix
from scipy.spatial import distance

In [2]:
def read_memmap(filepath):
    """Open a read-only ``np.memmap`` described by a JSON sidecar file.

    The sidecar lives next to ``filepath`` and has the same name with the
    final extension replaced by ``.conf``. It must be a JSON object with the
    keys ``"shape"`` (list of ints) and ``"dtype"`` (numpy dtype string).

    Args:
        filepath: Path (``str`` or ``os.PathLike``) of the raw ``.dat`` file.

    Returns:
        A read-only ``np.memmap`` over ``filepath`` with the shape and dtype
        recorded in the sidecar.

    Raises:
        FileNotFoundError: If the sidecar or the data file does not exist.
        KeyError: If the sidecar lacks ``"shape"`` or ``"dtype"``.
    """
    filepath = os.fspath(filepath)
    # Swap only the trailing extension; a plain str.replace would also rewrite
    # any ".dat" occurring inside a directory name (e.g. "run.data/x.dat").
    config_path = os.path.splitext(filepath)[0] + ".conf"
    with open(config_path, "r") as fin_config:
        memmap_configs = json.load(fin_config)
    return np.memmap(
        filepath,
        mode="r",
        shape=tuple(memmap_configs["shape"]),
        dtype=memmap_configs["dtype"],
    )

In [3]:
# Presumably the number of prompt-format variants evaluated per example
# (the pairwise code in this notebook enumerates pairs over 8 indices) -- TODO confirm.
NUM_FORMATS = 8

In [4]:
# Root directories, given relative to the notebook's working directory.
# result_dir is the base under which per-dataset / per-model outputs are read.
result_dir = "../../../results"
# NOTE(review): input_dir is not read by any cell visible here -- confirm it is still needed.
input_dir = "../../../preprocessed_datasets"

In [5]:
# Experiment selection: which model / dataset / prompting strategy to analyse.
model_name = "Llama-3.1-8B-Instruct"
dataset_name = "CommonsenseQA"
prompting_strategy = "zero-shot-cot"

output_dir = f"{result_dir}/{dataset_name}/{model_name}"

predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions_validation.jsonl")

# Load example id -> per-format predictions. Deliberately no try/except here:
# swallowing a failed load would leave `id_predictions_map` undefined (NameError
# below) or, worse, silently reuse a stale value from an earlier kernel run.
with jsonlines.open(predictions_path) as fin:
    id_predictions_map = {example["id"]: example["predictions"] for example in fin.iter()}

# Every unordered pair of format indices, in the order produced by
# itertools.combinations; the pairwise embeddings are built in the same order.
format_pairs = list(combinations(range(NUM_FORMATS), 2))

# labels[s, p] == 1 when the two formats of pair p gave the same prediction for
# example s, else 0. Predictions are keyed by the format index as a string.
labels = np.array([
    [int(preds[str(i)] == preds[str(j)]) for i, j in format_pairs]
    for preds in id_predictions_map.values()
])


In [6]:
def expand_pairwise_embeddings(hidden_states, num_formats=None):
    """Concatenate the embeddings of every unordered pair of formats.

    For each sample and each pair ``(i, j)`` with ``i < j`` (in
    ``itertools.combinations`` order), the slices ``hidden_states[s, i]`` and
    ``hidden_states[s, j]`` are concatenated along their last axis.

    Args:
        hidden_states: Array of shape ``(num_samples, num_formats, ..., dim)``
            with at least three dimensions.
        num_formats: How many leading entries of axis 1 to pair up. Defaults
            to ``hidden_states.shape[1]`` (all of them).

    Returns:
        ``np.ndarray`` of shape ``(num_samples, num_pairs, ..., 2 * dim)``
        where ``num_pairs = num_formats * (num_formats - 1) // 2``.

    Raises:
        ValueError: If ``hidden_states`` has fewer than three dimensions.
        IndexError: If ``num_formats`` exceeds ``hidden_states.shape[1]``.
    """
    if hidden_states.ndim < 3:
        raise ValueError(
            "hidden_states must have shape (num_samples, num_formats, ..., dim); "
            f"got {hidden_states.ndim} dimension(s)"
        )
    if num_formats is None:
        # Derive the format count from the data instead of hard-coding it.
        num_formats = hidden_states.shape[1]

    # (num_pairs, 2) table of format indices; reshape keeps it 2-D even when
    # there are no pairs (num_formats < 2).
    pair_index = np.array(
        list(combinations(range(num_formats), 2)), dtype=np.intp
    ).reshape(-1, 2)

    # Gather both sides of every pair for all samples at once, then join them
    # feature-wise. This yields the same layout as stacking per-pair concats.
    first = hidden_states[:, pair_index[:, 0]]
    second = hidden_states[:, pair_index[:, 1]]
    return np.concatenate([first, second], axis=-1)

In [7]:
# Locations of the memory-mapped hidden-state dumps for the selected run.
layer_wise_path = os.path.join(output_dir, f"{prompting_strategy}_layer_wise_hidden_states_validation.dat")
head_wise_path = os.path.join(output_dir, f"{prompting_strategy}_head_wise_hidden_states_validation.dat")

# Open both dumps read-only; shape and dtype come from the sidecar config files.
layer_wise_hidden_states = read_memmap(layer_wise_path)
head_wise_hidden_states = read_memmap(head_wise_path)

# Build, for every sample, the concatenated embedding of each pair of formats.
# NOTE(review): this materialises the pairwise arrays in memory -- confirm that
# they fit for the head-wise dump.
pairwise_layer_wise_hidden_states = expand_pairwise_embeddings(layer_wise_hidden_states)
pairwise_head_wise_hidden_states = expand_pairwise_embeddings(head_wise_hidden_states)

In [8]:
# Layer-wise linear probing: for each layer, fit a logistic-regression probe on
# the concatenated pair embeddings to predict whether the two formats of a pair
# produced the same prediction, and record balanced accuracy on a held-out split.
num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape
num_samples, num_formats, num_layers, num_heads, hidden_size2 = head_wise_hidden_states.shape

# The pair labels do not depend on the layer, so flatten them once.
Y = labels.reshape(-1)

layer_accuracy_scores = []
for layer_idx in range(num_layers):
    # if layer_idx not in [31]:
    #     continue

    # Step 1: Prepare input -- one row per (sample, pair), features are the two
    # concatenated hidden states of that pair at this layer.
    X = pairwise_layer_wise_hidden_states[:, :, layer_idx, :]
    X = X.reshape(-1, hidden_size * 2)

    # Step 2: 80:20 train/test split with a fixed seed.
    # NOTE(review): rows are split per pair, so pairs from the same example can
    # land in both train and test -- confirm this is intended.
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

    # Step 3: Standardization, fitted on the training split only.
    # (The former second "swap the folds" pass was unreachable and is removed.)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Step 4: Train a linear probing model
    # NOTE(review): the recorded output shows lbfgs hitting max_iter for at
    # least one layer; consider a larger max_iter if exact convergence matters.
    clf = LogisticRegression(solver="lbfgs", max_iter=1000)
    clf.fit(X_train, y_train)

    # Step 5: Evaluate the model
    y_pred = clf.predict(X_test)
    accuracy = balanced_accuracy_score(y_test, y_pred)
    layer_accuracy_scores.append(accuracy)
    # print(classification_report(y_test, y_pred))
    # print(confusion_matrix(y_test, y_pred))
    print(layer_accuracy_scores)
    print("*"*20)

[0.7786634460547504]
********************
[0.7786634460547504, 0.8022866344605475]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466, 0.7938840579710145]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466, 0.7938840579710145, 0.8062866344605475]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466, 0.7938840579710145, 0.8062866344605475, 0.7942866344605475]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466, 0.7938840579710145, 0.8062866344605475, 0.7942866344605475, 0.8018840579710145]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466, 0.7938840579710145, 0.8

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466, 0.7938840579710145, 0.8062866344605475, 0.7942866344605475, 0.8018840579710145, 0.8134814814814815, 0.8022866344605475, 0.8010789049919484, 0.8058840579710145, 0.8138840579710145, 0.8258840579710145, 0.8138840579710145, 0.8306892109500805, 0.8258840579710145, 0.8346892109500805, 0.8226892109500805, 0.8274943639291465, 0.8270917874396135, 0.8234943639291465]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605475, 0.8314943639291466, 0.7938840579710145, 0.8062866344605475, 0.7942866344605475, 0.8018840579710145, 0.8134814814814815, 0.8022866344605475, 0.8010789049919484, 0.8058840579710145, 0.8138840579710145, 0.8258840579710145, 0.8138840579710145, 0.8306892109500805, 0.8258840579710145, 0.8346892109500805, 0.8226892109500805, 0.8274943639291465, 0.8270917874396135, 0.8234943639291465, 0.8390917874396135]
********************
[0.7786634460547504, 0.8022866344605475, 0.8062866344605

In [9]:
# Ten highest layer-wise balanced accuracies (descending) ...
print(sorted(layer_accuracy_scores, reverse=True)[:10])
print()
# ... and the ten lowest (ascending).
print(sorted(layer_accuracy_scores, reverse=False)[:10])

[0.8390917874396135, 0.8346892109500805, 0.8346892109500805, 0.8314943639291466, 0.8310917874396135, 0.8306892109500805, 0.8274943639291465, 0.8270917874396135, 0.8258840579710145, 0.8258840579710145]

[0.7786634460547504, 0.7938840579710145, 0.7942866344605475, 0.8010789049919484, 0.8018840579710145, 0.8022866344605475, 0.8022866344605475, 0.8058840579710145, 0.8062866344605475, 0.8062866344605475]


In [10]:
# Head-wise linear probing: same procedure as the layer-wise probe, but one
# probe per (layer, head). Scores are collected per layer, in head order.
num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape
num_samples, num_formats, num_layers, num_heads, hidden_size2 = head_wise_hidden_states.shape

# The pair labels do not depend on the layer or head, so flatten them once.
Y = labels.reshape(-1)

head_accuracy_scores = defaultdict(list)
for layer_idx in range(num_layers):
    # if layer_idx not in [31]:
    #     continue
    for head_idx in range(num_heads):
        # Step 1: Prepare input -- one row per (sample, pair), features are the
        # two concatenated head outputs of that pair.
        X = pairwise_head_wise_hidden_states[:, :, layer_idx, head_idx, :]
        X = X.reshape(-1, hidden_size2 * 2)

        # Step 2: 80:20 train/test split with a fixed seed.
        # NOTE(review): rows are split per pair, so pairs from the same example
        # can land in both train and test -- confirm this is intended.
        X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

        # Step 3: Standardization, fitted on the training split only.
        # (The former second "swap the folds" pass was unreachable and is removed.)
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        # Step 4: Train a linear probing model
        clf = LogisticRegression(solver="lbfgs", max_iter=1000)
        clf.fit(X_train, y_train)

        # Step 5: Evaluate the model
        y_pred = clf.predict(X_test)
        accuracy = balanced_accuracy_score(y_test, y_pred)
        head_accuracy_scores[layer_idx].append(accuracy)
        # print(classification_report(y_test, y_pred))
        # print(confusion_matrix(y_test, y_pred))
    print(head_accuracy_scores)
    print("*"*20)

defaultdict(<class 'list'>, {0: [0.7258067632850241, 0.6806763285024156, 0.6194943639291466, 0.6578840579710145, 0.6386634460547504, 0.6711175523349436, 0.6818582930756844, 0.5467665056360709, 0.6538582930756844, 0.6606763285024155, 0.6622866344605475, 0.6694299516908213, 0.6806505636070853, 0.6438454106280194, 0.6063381642512078, 0.6422608695652174, 0.6043252818035427, 0.6634943639291465, 0.6287020933977455, 0.6734557165861514, 0.6938582930756844, 0.6162737520128825, 0.6155201288244767, 0.6554685990338165, 0.6894557165861513, 0.6514943639291465, 0.6674685990338165, 0.6762479871175524, 0.6374814814814814, 0.5, 0.5487536231884058, 0.508]})
********************
defaultdict(<class 'list'>, {0: [0.7258067632850241, 0.6806763285024156, 0.6194943639291466, 0.6578840579710145, 0.6386634460547504, 0.6711175523349436, 0.6818582930756844, 0.5467665056360709, 0.6538582930756844, 0.6606763285024155, 0.6622866344605475, 0.6694299516908213, 0.6806505636070853, 0.6438454106280194, 0.6063381642512078,

In [11]:
# Collapse the per-layer score lists into one flat list of head scores, then
# show the ten best (descending) followed by the ten worst (ascending).
flat_head_scores = np.array(list(head_accuracy_scores.values())).ravel().tolist()
print(sorted(flat_head_scores, reverse=True)[:10])
print()
print(sorted(flat_head_scores)[:10])

[0.7278454106280193, 0.7262608695652174, 0.7258582930756844, 0.7258067632850241, 0.7254299516908213, 0.7242737520128825, 0.7242479871175523, 0.7226634460547504, 0.7186376811594203, 0.7170531400966184]

[0.499597423510467, 0.5, 0.5, 0.504, 0.508, 0.5087793880837359, 0.5119742351046699, 0.514792270531401, 0.515194847020934, 0.5171819645732689]


In [12]:
def get_topk_heads(scores, k=10):
    """Return the ``k`` highest-scoring (layer, head) positions.

    Args:
        scores: 2-D array-like of shape ``(num_layers, num_heads)``.
        k: Number of entries to return. Values ``<= 0`` give an empty list;
            values larger than the number of heads return every head.

    Returns:
        List of ``((layer, head), score)`` tuples sorted by descending score.

    Raises:
        ValueError: If ``scores`` is not two-dimensional.
    """
    scores = np.asarray(scores)
    if scores.ndim != 2:
        raise ValueError(f"scores must be 2-D (num_layers, num_heads); got shape {scores.shape}")
    if k <= 0:
        # Guard needed because `[-0:]` selects the whole array, not nothing.
        return []

    num_layers, num_heads = scores.shape
    flat_scores = scores.flatten()
    # argsort is ascending: take the last k indices and reverse for descending.
    topk_indices = np.argsort(flat_scores)[-k:][::-1]

    topk_info = []
    for idx in topk_indices:
        # Recover the 2-D position from the row-major flat index.
        layer, head = divmod(idx, num_heads)
        topk_info.append(((layer, head), scores[layer, head]))

    return topk_info

# Stack the per-layer score lists into one 2-D array (rows follow the dict's
# insertion order) and list the 25 best-scoring (layer, head) positions.
# NOTE(review): `aaa` is an uninformative name; consider renaming it if no
# other cell relies on it.
aaa = np.array(list(head_accuracy_scores.values()))
get_topk_heads(aaa, 25)

[((12, 31), 0.7278454106280193),
 ((13, 24), 0.7262608695652174),
 ((2, 0), 0.7258582930756844),
 ((0, 0), 0.7258067632850241),
 ((20, 28), 0.7254299516908213),
 ((11, 14), 0.7242737520128825),
 ((5, 28), 0.7242479871175523),
 ((26, 25), 0.7226634460547504),
 ((4, 18), 0.7186376811594203),
 ((12, 14), 0.7170531400966184),
 ((23, 4), 0.7150660225442834),
 ((2, 4), 0.7146634460547504),
 ((6, 16), 0.7146376811594203),
 ((6, 19), 0.7138325281803543),
 ((6, 21), 0.7126505636070853),
 ((6, 11), 0.7114170692431562),
 ((4, 26), 0.7090531400966184),
 ((10, 24), 0.7090531400966184),
 ((31, 1), 0.7082737520128825),
 ((2, 28), 0.7066634460547504),
 ((31, 21), 0.7058582930756844),
 ((5, 14), 0.7058582930756844),
 ((20, 21), 0.7050789049919485),
 ((17, 30), 0.7050531400966183),
 ((12, 22), 0.7046763285024155)]

In [13]:
def get_topk_layers(layer_scores, k=10):
    """Return the ``k`` highest-scoring layers.

    Args:
        layer_scores: 1-D sequence of scores indexed by layer.
        k: Number of entries to return. Values ``<= 0`` give an empty list;
            values larger than the number of layers return every layer.

    Returns:
        List of ``(layer_index, score)`` tuples sorted by descending score.
    """
    if k <= 0:
        # Guard needed because `[-0:]` selects the whole array, not nothing.
        return []
    # argsort is ascending: take the last k indices and reverse for descending.
    topk_indices = np.argsort(layer_scores)[-k:][::-1]
    topk_info = [(idx, layer_scores[idx]) for idx in topk_indices]
    return topk_info

# Show the ten layers with the highest probe balanced accuracy.
get_topk_layers(layer_accuracy_scores, 10)

[(22, 0.8390917874396135),
 (24, 0.8346892109500805),
 (17, 0.8346892109500805),
 (3, 0.8314943639291466),
 (27, 0.8310917874396135),
 (15, 0.8306892109500805),
 (19, 0.8274943639291465),
 (20, 0.8270917874396135),
 (16, 0.8258840579710145),
 (13, 0.8258840579710145)]