In [1]:
import os
import json
import jsonlines

from collections import defaultdict
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from scipy.spatial import distance

In [2]:
def read_memmap(filepath):
    with open(filepath.replace(".dat", ".conf"), "r") as fin_config:
        memmap_configs = json.load(fin_config)
        return np.memmap(filepath, mode="r", shape=tuple(memmap_configs["shape"]), dtype=memmap_configs["dtype"])

In [3]:
NUM_FORMATS = 8

In [4]:
result_dir = "../../../results"
input_dir = "../../../preprocessed_datasets"

In [5]:
model_name = "Llama-3.1-8B-Instruct"
dataset_name = "CommonsenseQA"
prompting_strategy = "zero-shot"

output_dir = f"{result_dir}/{dataset_name}/{model_name}"

predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions.jsonl")
raw_predictions_path = os.path.join(output_dir, f"{prompting_strategy}_raw_predictions.jsonl")
try:
    with jsonlines.open(predictions_path) as fin:
        id_predictions_map = {}
        for example in fin.iter():
            id_predictions_map[example["id"]] = example["predictions"]
except:
    pass

id_majority_level_map = defaultdict(list)
with jsonlines.open(raw_predictions_path) as fin:
    for example in fin.iter():
        id_predictions_map[example["id"]].pop('majority_voting')
        for ii in range(NUM_FORMATS):
            zz = list(id_predictions_map[example["id"]].values()).count(id_predictions_map[example["id"]][str(ii)])
            # print(zz)
            # zz = ((zz>7)*1.0) # binarize
            # zz = max(8-zz, zz) # pairing
            id_majority_level_map[example["id"]].append(zz)
        # print()

id_majority_level_map = list(id_majority_level_map.values())
id_majority_level_map = np.array(id_majority_level_map)

id_majority_level_map = id_majority_level_map.reshape(-1)

In [6]:
layer_wise_path = os.path.join(output_dir, f"{prompting_strategy}_layer_wise_hidden_states.dat")
head_wise_path = os.path.join(output_dir, f"{prompting_strategy}_head_wise_hidden_states.dat")

layer_wise_hidden_states = read_memmap(layer_wise_path)

In [7]:
num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape
for layer_idx in range(num_layers):
    if layer_idx not in [31]:
        continue

    # Step 1: Prepare input
    X = layer_wise_hidden_states[:,:,layer_idx,:]
    X = X.reshape(-1, hidden_size)
    # Y = np.tile(np.arange(num_formats), num_samples)
    Y = id_majority_level_map

    Xs = []
    for zi in range(1, 8+1):
        Xs.append(X[Y == zi].mean(0))
    Xs = np.array(Xs)
    Xsvar = Xs.var(0)
    target_idx = np.argsort(Xsvar)[::-1][:5]

    print("Mean embedding vectors for 8 formats (first 5 dimensions)")
    print(Xs[:,:5])
    print()

    print("Pairwise embedding distance (Euclidean)")
    for ii in range(8):
        for jj in range(8):
            # print(f"{distance.cosine(Xs[ii], Xs[jj])*100:2.2f}", end=' ')
            print(f"{distance.euclidean(Xs[ii], Xs[jj])*100:8.2f}", end='')
        print()
    print()

    print("Linear probing")
    # Step 2: Train-test split (2-fold)
    X_fold1, X_fold2, y_fold1, y_fold2 = train_test_split(X, Y, test_size=0.5, random_state=0)

    layer_accuracy_scores = []
    for i in range(2):
        # Step 3: Standardization
        scaler = StandardScaler()
        if i == 0:
            X_train = scaler.fit_transform(X_fold1)
            X_test = scaler.transform(X_fold2)
            y_train, y_test = y_fold1, y_fold2
        else:
            break
            X_train = scaler.fit_transform(X_fold2)
            X_test = scaler.transform(X_fold1)
            y_train, y_test = y_fold2, y_fold1

        # Step 4: Train a linear probing model
        clf = LogisticRegression(solver="lbfgs", max_iter=1000)
        clf.fit(X_train, y_train)

        # Step 5: Evaluate the model
        y_pred = clf.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        layer_accuracy_scores.append(accuracy)
        print(classification_report(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
    # print(f"Layer {layer_idx}: {np.mean(layer_accuracy_scores):.4f}")
    print("*"*20)

print("*"*30)

Mean embedding vectors for 8 formats (first 5 dimensions)
[[-0.1315  -0.4983   0.3645  -0.772   -0.02591]
 [-0.1427  -0.5137   0.3198  -0.7725  -0.0371 ]
 [-0.1595  -0.507    0.3645  -0.788   -0.0375 ]
 [-0.1433  -0.4856   0.3845  -0.793   -0.04163]
 [-0.1385  -0.4937   0.3862  -0.776   -0.03073]
 [-0.1332  -0.508    0.3699  -0.7837  -0.0452 ]
 [-0.1321  -0.5015   0.387   -0.7803  -0.03253]
 [-0.132   -0.5107   0.378   -0.78    -0.03114]]

Pairwise embedding distance (Euclidean)
    0.00  118.65  116.02  143.85  118.16  114.65   71.04   88.67
  118.65    0.00  115.62  136.82  145.02  119.24  149.80  111.13
  116.02  115.62    0.00  105.96   79.00   93.95  119.34   82.28
  143.85  136.82  105.96    0.00  122.17  115.62  144.34  100.59
  118.16  145.02   79.00  122.17    0.00  103.52  115.23   94.38
  114.65  119.24   93.95  115.62  103.52    0.00   90.97   70.61
   71.04  149.80  119.34  144.34  115.23   90.97    0.00   77.49
   88.67  111.13   82.28  100.59   94.38   70.61   77.49    0