In [None]:
import matplotlib.pyplot as plt
import argparse
import os
import torch
import numpy as np 

from linear_probe_utils import *
from models import download_and_load_model

# Run configuration, expressed as an argparse parser so the same code can be
# driven from the command line or executed inside a notebook with the defaults.
parser = argparse.ArgumentParser(description='Template Label Counter')
parser.add_argument("--dataset",     type=str, default='trivia_qa')  # coqa, trivia_qa
parser.add_argument("--model",       type=str, default='google/gemma-2-2b-it')  # meta-llama/Llama-3.1-8B-Instruct
parser.add_argument("--judge_model", type=str, default='google/gemma-2-9b-it')
parser.add_argument("--output_dir",  type=str, default='./cache')  # required=True,
parser.add_argument("--data_split",  type=str, default='train')
# Dummy option; presumably here to absorb the `--f=<connection file>` argument
# that some Jupyter front ends append to the kernel's argv. Kept for
# backward compatibility.
parser.add_argument("--f", type=str, default='')

# Inside a notebook sys.argv belongs to the kernel launcher, not to this code.
# Depending on the front end the connection file arrives as `--f=<file>` or as
# `-f <file>`; the latter is not matched by `--f` and would make parse_args()
# abort with "unrecognized arguments". parse_known_args() ignores whatever it
# does not recognise, so the cell works in both cases and as a plain script.
args, _unknown_args = parser.parse_known_args()

# Load the probed model and its tokenizer (args.output_dir is passed as the
# second argument of the project helper).
model, tokenizer = download_and_load_model(args.model, args.output_dir)

In [None]:
# Probe-training hyperparameters: batch size, learning rate, number of epochs.
bs, lr, epoch = 8, 1e-4, 10

# Only the model's config object is needed from here on, so keep a handle on it.
model_config = model.config
probe_trainer = ProbeTrainer(model_config, bs=bs, lr=lr, epoch=epoch)

# Drop this notebook's references to the model and tokenizer; the names stay
# defined (as None) but the objects can be reclaimed once nothing else holds them.
model = tokenizer = None

# Load the dataset and classify the samples into two classes

In [None]:
from manifold_learning_isomap_viz import load_judged_data

# Number of extra copies of the label-0 ("unknown") training rows appended to
# the training set, i.e. that class ends up (1 + copies) times as frequent.
UNKNOWN_OVERSAMPLE_COPIES = 5
# Fixed seed so re-running this cell reproduces the same per-layer numbers
# (assumes probe initialisation / batching draw from torch's global RNG).
PROBE_SEED = 0

# Directories with the judged answers for the train and test splits.
workspace_dir = os.path.join(args.output_dir, 'answers', args.dataset, 'train', args.model)
test_workspace_dir = os.path.join(args.output_dir, 'answers', args.dataset, 'test', args.model)
print(model_config.num_hidden_layers)

torch.manual_seed(PROBE_SEED)
# Probe input width; identical for every layer, so computed once.
state_dim = model_config.hidden_size

# One entry per layer: the metrics dict of the evaluation with the highest
# test AUROC for that layer's probe. Despite the historical "acc" in the
# names, selection is by 'auroc'.
max_acc_hidden = []
for layer in range(model_config.num_hidden_layers):
    # --- training split: oversample the label-0 class -----------------------
    train_data, train_gt_labels = load_judged_data(workspace_dir, layer)
    # Labels are indexed as a 2-D array below, so flatten them for the mask.
    unknown_mask = train_gt_labels.reshape(-1) == 0.0
    # Boolean-mask indexing already yields fresh arrays; no explicit copy needed.
    unknown_data = train_data[unknown_mask, :]
    unknown_labels = train_gt_labels[unknown_mask, :]
    train_data = np.concatenate(
        [train_data] + [unknown_data] * UNKNOWN_OVERSAMPLE_COPIES, axis=0)
    train_gt_labels = np.concatenate(
        [train_gt_labels] + [unknown_labels] * UNKNOWN_OVERSAMPLE_COPIES, axis=0)
    train_data = torch.from_numpy(train_data).float()
    train_gt_labels = torch.from_numpy(train_gt_labels).float()

    # --- test split: used as-is ----------------------------------------------
    test_data, test_gt_labels = load_judged_data(test_workspace_dir, layer)
    test_data = torch.from_numpy(test_data).float()
    test_gt_labels = torch.from_numpy(test_gt_labels).float()

    # --- fit a fresh linear probe for this layer -----------------------------
    prober = LinearProbe(state_dim)
    optimizer = torch.optim.SGD(prober.parameters(), lr=probe_trainer.lr, momentum=0.9)
    test_results, prober = probe_trainer.fit_one_probe(prober, optimizer,
                                                       train_data, train_gt_labels,
                                                       test_data, test_gt_labels)

    # Keep the evaluation with the best test AUROC for this layer.
    max_acc = max(test_results, key=lambda x: x['auroc'])
    max_acc_hidden.append(max_acc)
    print(layer, max_acc)

# Report the layer whose best evaluation has the highest AUROC overall
# (first such layer on ties).
best_layer = max(range(len(max_acc_hidden)), key=lambda i: max_acc_hidden[i]['auroc'])
print("probe:", best_layer, max_acc_hidden[best_layer])

meta-llama/Llama-3.1-8B-Instruct:   
probe: 26 {'acc': 0.578, 'precision': 0.616, 'recall': 0.746, 'specificity': 0.338, 'macro_f1': 0.536, 'auroc': 0.5420561325420377}  
gemma2-9b:  
probe: 35 {'acc': 0.539, 'precision': 0.623, 'recall': 0.489, 'specificity': 0.607, 'macro_f1': 0.539, 'auroc': 0.5476651420061696}  
gemma2-2b:  
probe: 18 {'acc': 0.477, 'precision': 0.396, 'recall': 0.807, 'specificity': 0.284, 'macro_f1': 0.469, 'auroc': 0.5453987884342413}  

In [None]:
# Plot, for every layer, the macro-F1 stored in max_acc_hidden (one metrics
# dict per layer, in layer order).
fig, ax = plt.subplots()
ax.plot([d["macro_f1"] for d in max_acc_hidden], label="hidden")
ax.set_xlabel("layer")
ax.set_ylabel("test macro_f1")
ax.set_title("Linear probe on hidden states: test macro_f1 per layer")
# fig.savefig('hidden_probe_sft0911_1.png')
# A single show() is enough; a second call would have nothing left to draw.
plt.show()

***