In [1]:
import os
import json
import jsonlines

from itertools import combinations
from collections import defaultdict, Counter
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix
from scipy.spatial import distance

In [2]:
def read_memmap(filepath):
    with open(filepath.replace(".dat", ".conf"), "r") as fin_config:
        memmap_configs = json.load(fin_config)
        return np.memmap(filepath, mode="r", shape=tuple(memmap_configs["shape"]), dtype=memmap_configs["dtype"])

In [3]:
NUM_FORMATS = 8

In [4]:
result_dir = "../../../results"
input_dir = "../../../preprocessed_datasets"

In [5]:
model_name = "Llama-3.1-8B-Instruct"
dataset_name = "CommonsenseQA"
prompting_strategy = "zero-shot"

id_predictions_map = {}

output_dir = f"{result_dir}/{dataset_name}/{model_name}"

predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions_validation.jsonl")
try:
    with jsonlines.open(predictions_path) as fin:
        id_predictions_map = {}
        for example in fin.iter():
            id_predictions_map[example["id"]] = example["predictions"]
except:
    pass

labels = []
for preds in id_predictions_map.values():
    pairwise_labels = []
    for i, j in list(combinations(range(8), 2)):
        label = int(preds[str(i)] == preds[str(j)])
        pairwise_labels.append(label)
    labels.append(pairwise_labels)
labels = np.array(labels)


In [6]:
def expand_pairwise_embeddings(hidden_states):
    num_samples = hidden_states.shape[0]

    pairwise_embeddings = []

    for s in range(num_samples):
        sample_pairs = []
        for i, j in list(combinations(range(8), 2)):
            emb_i = hidden_states[s, i]
            emb_j = hidden_states[s, j]
            concat = np.concatenate([emb_i, emb_j], axis=-1)
            sample_pairs.append(concat)
        sample_pairs = np.stack(sample_pairs, axis=0)
        pairwise_embeddings.append(sample_pairs)

    return np.stack(pairwise_embeddings, axis=0)

In [7]:
layer_wise_path = os.path.join(output_dir, f"{prompting_strategy}_layer_wise_hidden_states_validation.dat")
head_wise_path = os.path.join(output_dir, f"{prompting_strategy}_head_wise_hidden_states_validation.dat")

layer_wise_hidden_states = read_memmap(layer_wise_path)
head_wise_hidden_states = read_memmap(head_wise_path)

pairwise_layer_wise_hidden_states = expand_pairwise_embeddings(layer_wise_hidden_states)
pairwise_head_wise_hidden_states = expand_pairwise_embeddings(head_wise_hidden_states)

In [8]:
num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape
num_samples, num_formats, num_layers, num_heads, hidden_size2 = head_wise_hidden_states.shape

layer_accuracy_scores = []
for layer_idx in range(num_layers):
    # if layer_idx not in [31]:
    #     continue

    # Step 1: Prepare input
    X = pairwise_layer_wise_hidden_states[:,:,layer_idx,:]
    X = X.reshape(-1, hidden_size*2)
    Y = labels.reshape(-1)

    # Step 2: 80:20 split for validation to select top-K components
    X_fold1, X_fold2, y_fold1, y_fold2 = train_test_split(X, Y, test_size=0.2, random_state=0)

    for i in range(2):
        # Step 3: Standardization
        scaler = StandardScaler()
        if i == 0:
            X_train = scaler.fit_transform(X_fold1)
            X_test = scaler.transform(X_fold2)
            y_train, y_test = y_fold1, y_fold2
        else:
            break
            X_train = scaler.fit_transform(X_fold2)
            X_test = scaler.transform(X_fold1)
            y_train, y_test = y_fold2, y_fold1

        # Step 4: Train a linear probing model
        clf = LogisticRegression(solver="lbfgs", max_iter=1000)
        clf.fit(X_train, y_train)

        # Step 5: Evaluate the model
        y_pred = clf.predict(X_test)
        accuracy = balanced_accuracy_score(y_test, y_pred)
        layer_accuracy_scores.append(accuracy)
        print(classification_report(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
    print(layer_accuracy_scores)
    print("*"*20)

              precision    recall  f1-score   support

           0       0.62      0.70      0.66        40
           1       0.99      0.99      0.99      1327

    accuracy                           0.98      1367
   macro avg       0.81      0.84      0.82      1367
weighted avg       0.98      0.98      0.98      1367

[[  28   12]
 [  17 1310]]
[0.843594574227581]
********************
              precision    recall  f1-score   support

           0       0.64      0.70      0.67        40
           1       0.99      0.99      0.99      1327

    accuracy                           0.98      1367
   macro avg       0.81      0.84      0.83      1367
weighted avg       0.98      0.98      0.98      1367

[[  28   12]
 [  16 1311]]
[0.843594574227581, 0.8439713639788997]
********************
              precision    recall  f1-score   support

           0       0.67      0.75      0.71        40
           1       0.99      0.99      0.99      1327

    accuracy              

In [9]:
print(sorted(layer_accuracy_scores, reverse=True)[:10])
print()
print(sorted(layer_accuracy_scores, reverse=False)[:10])

[0.9068481537302185, 0.8947249434815373, 0.8943481537302186, 0.8943481537302186, 0.8943481537302186, 0.8943481537302186, 0.8939713639788998, 0.8939713639788998, 0.8939713639788998, 0.8939713639788998]

[0.8314713639788998, 0.843594574227581, 0.8439713639788997, 0.8447249434815373, 0.8451017332328561, 0.8451017332328561, 0.8568481537302186, 0.8682177844762622, 0.8682177844762622, 0.868594574227581]


In [10]:
def get_topk_layers(layer_scores, k=10):
    topk_indices = np.argsort(layer_scores)[-k:][::-1]
    topk_info = [(idx, layer_scores[idx]) for idx in topk_indices]
    return topk_info

get_topk_layers(layer_accuracy_scores, 32)

[(29, 0.9068481537302185),
 (26, 0.8947249434815373),
 (28, 0.8943481537302186),
 (31, 0.8943481537302186),
 (27, 0.8943481537302186),
 (17, 0.8943481537302186),
 (24, 0.8939713639788998),
 (25, 0.8939713639788998),
 (18, 0.8939713639788998),
 (23, 0.8939713639788998),
 (16, 0.893594574227581),
 (19, 0.893594574227581),
 (4, 0.8826017332328561),
 (21, 0.8822249434815372),
 (22, 0.8818481537302185),
 (20, 0.8818481537302185),
 (13, 0.8818481537302185),
 (30, 0.8814713639788998),
 (11, 0.8810945742275811),
 (15, 0.8810945742275811),
 (9, 0.870101733232856),
 (2, 0.8693481537302186),
 (10, 0.868594574227581),
 (14, 0.8682177844762622),
 (12, 0.8682177844762622),
 (3, 0.8568481537302186),
 (5, 0.8451017332328561),
 (7, 0.8451017332328561),
 (6, 0.8447249434815373),
 (1, 0.8439713639788997),
 (0, 0.843594574227581),
 (8, 0.8314713639788998)]