In [1]:
import os
import json
import jsonlines

from collections import defaultdict
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from scipy.spatial import distance

In [2]:
def read_memmap(filepath):
    """Open a raw array dump as a read-only ``np.memmap``.

    The array's shape and dtype are read from a JSON side-car file that sits
    next to the data file and carries the same name with a ``.conf``
    extension (``foo/bar.dat`` -> ``foo/bar.conf``). The JSON object must
    provide the keys ``"shape"`` (a list of ints) and ``"dtype"``.

    Args:
        filepath: Path to the ``.dat`` file holding the raw array bytes.

    Returns:
        A read-only ``np.memmap`` over ``filepath`` with the configured
        shape and dtype.

    Raises:
        FileNotFoundError: If the side-car ``.conf`` file or the data file
            does not exist.
        KeyError: If the config lacks ``"shape"`` or ``"dtype"``.
    """
    # Swap only the file extension. A plain str.replace(".dat", ".conf") would
    # also rewrite any ".dat" occurring earlier in the path (for example a
    # "run.data/" directory) and point at a config file that does not exist.
    config_path = os.path.splitext(filepath)[0] + ".conf"
    with open(config_path, "r") as fin_config:
        memmap_configs = json.load(fin_config)

    # The config handle is closed before mapping; only the parsed values are needed.
    return np.memmap(
        filepath,
        mode="r",
        shape=tuple(memmap_configs["shape"]),
        dtype=memmap_configs["dtype"],
    )

In [3]:
# Number of formats scored per example. Per-format predictions are looked up
# under the string keys "0" .. str(NUM_FORMATS - 1) of each "predictions" dict.
# NOTE(review): presumably must equal axis 1 of the hidden-state array — confirm.
NUM_FORMATS = 8

In [4]:
# Root directories, relative to the notebook's working directory.
# result_dir is the parent of the per-dataset / per-model output folders that
# hold the prediction and hidden-state files read below.
result_dir = "../../../results"
# NOTE(review): input_dir is not referenced in the visible cells — confirm it is still needed.
input_dir = "../../../preprocessed_datasets"

In [5]:
# Experiment selection: which model / dataset / prompting run to analyse.
model_name = "Llama-3.1-8B-Instruct"
dataset_name = "CommonsenseQA"
prompting_strategy = "zero-shot"

output_dir = f"{result_dir}/{dataset_name}/{model_name}"
predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions_validation.jsonl")

# Read the predictions file exactly once. No exception is swallowed here: a
# missing or malformed file must fail loudly at this point instead of
# surfacing later as a NameError on `id_predictions_map`.
#
# id_predictions_map: example id -> {format key: prediction}, with the
#   aggregated "majority_voting" entry left out.
# agreement_rows: one row per example, in file order; entry k is how many of
#   that example's per-format predictions equal the prediction of format k
#   (the format itself included, so the minimum is 1).
id_predictions_map = {}
agreement_rows = []
with jsonlines.open(predictions_path) as fin:
    for example in fin.iter():
        # Build a filtered copy rather than pop()-ing in place, so a repeated
        # id or an absent "majority_voting" key cannot raise KeyError.
        format_predictions = {
            key: value
            for key, value in example["predictions"].items()
            if key != "majority_voting"
        }
        id_predictions_map[example["id"]] = format_predictions

        votes = list(format_predictions.values())
        agreement_rows.append(
            [votes.count(format_predictions[str(format_idx)]) for format_idx in range(NUM_FORMATS)]
        )
        # Alternative targets tried earlier on each count `c`:
        #   binarize: (c > 7) * 1.0
        #   pairing:  max(8 - c, c)

# Flatten (num_examples, NUM_FORMATS) -> (num_examples * NUM_FORMATS,) so that
# each (example, format) pair becomes one probing label.
# NOTE(review): this assumes the examples in the JSONL file appear in the same
# order as the rows of the hidden-state array — confirm against the writer.
id_majority_level_map = np.array(agreement_rows).reshape(-1)

In [6]:
# Paths of the raw hidden-state dumps stored under output_dir for this run.
layer_wise_path = os.path.join(output_dir, f"{prompting_strategy}_layer_wise_hidden_states_validation.dat")
# NOTE(review): head_wise_path is built but never loaded in the visible cells.
head_wise_path = os.path.join(output_dir, f"{prompting_strategy}_head_wise_hidden_states_validation.dat")

# Read-only memory map; shape and dtype come from the JSON config that
# read_memmap loads alongside the .dat file.
layer_wise_hidden_states = read_memmap(layer_wise_path)

In [7]:
num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape

# Layers whose hidden states are probed. Indices outside [0, num_layers) are
# skipped silently, matching the previous "loop over all layers and filter"
# behaviour.
probe_layers = [31]

for layer_idx in probe_layers:
    if not 0 <= layer_idx < num_layers:
        continue

    # Step 1: Prepare input — one row per (sample, format) pair, one label each.
    X = layer_wise_hidden_states[:, :, layer_idx, :].reshape(-1, hidden_size)
    Y = id_majority_level_map
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"Feature/label mismatch: {X.shape[0]} hidden-state rows vs {Y.shape[0]} labels"
        )

    print("Linear probing")
    # Step 2: single 80:20 hold-out split (fixed seed for reproducibility).
    # The former swapped second fold was unreachable code and has been removed.
    X_train_raw, X_test_raw, y_train, y_test = train_test_split(
        X, Y, test_size=0.2, random_state=0
    )

    # Step 3: Standardization — statistics are fitted on the training part only
    # so that nothing from the held-out part leaks into the probe.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_raw)
    X_test = scaler.transform(X_test_raw)

    # Step 4: Train a linear probing model
    clf = LogisticRegression(solver="lbfgs", max_iter=1000)
    clf.fit(X_train, y_train)

    # Step 5: Evaluate the model
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    layer_accuracy_scores = [accuracy]
    # zero_division=0 reports the same 0.00 for classes that are never
    # predicted, but without emitting an undefined-metric warning per average.
    print(classification_report(y_test, y_pred, zero_division=0))
    print(confusion_matrix(y_test, y_pred))
    print("*"*20)

print("*"*30)

Linear probing
              precision    recall  f1-score   support

           2       0.00      0.00      0.00         2
           3       0.00      0.00      0.00         3
           4       1.00      1.00      1.00         4
           5       0.57      0.67      0.62         6
           7       1.00      1.00      1.00         6
           8       1.00      1.00      1.00       370

    accuracy                           0.98       391
   macro avg       0.60      0.61      0.60       391
weighted avg       0.98      0.98      0.98       391

[[  0   2   0   0   0   0]
 [  0   0   0   3   0   0]
 [  0   0   4   0   0   0]
 [  0   2   0   4   0   0]
 [  0   0   0   0   6   0]
 [  0   0   0   0   0 370]]
********************
******************************


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
