### Load libraries, set seeds, and sample the data

In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import sys

# Environment knobs must be set BEFORE torch / numpy / datasets are imported:
# the OpenMP / MKL / OpenBLAS runtimes read these variables when they are first
# loaded, so assigning them after `import torch` does not change the thread pools.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = '16'
os.environ["OPENBLAS_NUM_THREADS"] = '16'
os.environ["MKL_NUM_THREADS"] = '16'
# Machine-specific GPU selection; must be set before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

# multi-processing for NTK kernel
from torch.multiprocessing import set_start_method, set_sharing_strategy
import torch.multiprocessing as mp
# force=True keeps the cell re-runnable: a second plain call raises
# "RuntimeError: context has already been set".
set_start_method("spawn", force=True)
set_sharing_strategy("file_system")

import pickle
import random
from datetime import datetime

import click
import torch
import yaml
from datasets import load_dataset, concatenate_datasets, Value

# Project-local modules (dataset, probe, dvutils) live under one of these dirs,
# depending on the working directory the notebook is launched from.
sys.path.insert(0, './lmntk')
sys.path.insert(0, './vinfo/lmntk')

from dataset import *
from probe import *
from dvutils.Data_Shapley import *

# Imported after the star imports (as before) so these names are not shadowed.
import logging
import time
from dataclasses import dataclass, field

import numpy as np  # used directly below; don't rely on the star imports to provide it
from transformers import HfArgumentParser

# ---------------------------- Experiment configuration ----------------------------
dataset_name = "sst2"     # one of: sst2, mr, rte, mnli, mrpc
seed = 2023               # seeds torch / numpy / random and the training-set shuffle
num_dp = 5000             # max number of training points to value
tmc_iter = 200            # `iteration` passed to dshap_com.run
prompt = True             # usually True: whether to use prompt fine-tuning
signgd = False            # usually False: whether to use the signGD kernel; not adopted in FreeShap
approximate = "inv"       # can also be "none" (use no approximation, exact inverse); "diagonal" (use block diagonal for inverse)
per_point = True          # if True: get the instance score for each test point; if False: get instance score for test sets
# NOTE(review): this is the *string* "True", not a bool. It is passed to
# dshap_com.run and embedded verbatim in the Shapley cache file name; confirm
# what run() expects before changing it.
early_stopping = "True"
tmc_seed = 2023           # `seed` passed to dshap_com.run
val_sample_num = 1000     # max number of validation points to sample
yaml_path = "../configs/dshap/sst2/ntk_prompt.yaml"
file_path = "./freeshap_res/"  # directory for the cached NTK / Shapley pickles

# Seed torch before the YAML config below constructs the model / probe objects.
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Loading this config instantiates the dataset reader, the NTK probe and the
# Shapley computer (see the "Constructing <class ...>" output), which is why the
# full yaml.Loader is needed. That loader can construct arbitrary Python objects:
# only load config files you trust.
with open(yaml_path) as yaml_file:
    yaml_args = yaml.load(yaml_file, Loader=yaml.Loader)
list_dataset = yaml_args['dataset']
probe_model = yaml_args['probe_com']
dshap_com = yaml_args['dshap_com']

# Configure the probe according to the flags from the configuration section.
if prompt:
    # Hands the dataset's label words to the model (presumably the prompt verbalizer).
    probe_model.model.init(list_dataset.label_word_list)
if approximate != "none":
    probe_model.approximate(approximate)
if approximate == "inv":
    probe_model.normalize_ntk()
if signgd:
    probe_model.signgd()

# numpy / random are seeded here, after the probe is built. Keep this order:
# moving these earlier could change the validation indices drawn in the data
# sampling step (if any constructor above consumes these RNGs) and would
# silently invalidate previously cached results.
np.random.seed(seed)
random.seed(seed)
    
# load_dataset() arguments for each supported dataset_name.
_DATASET_SOURCES = {
    "sst2": ("sst2",),
    "mr": ("rotten_tomatoes",),
    "rte": ("glue", "rte"),    # 1: not entail; 0: entail
    "mnli": ("glue", "mnli"),
    "mrpc": ("glue", "mrpc"),
}
# Split used as the evaluation ("validation") set; defaults to "validation".
# NOTE: "subj" / "ag_news" have no entry in _DATASET_SOURCES yet.
_VAL_SPLITS = {"mnli": "validation_matched", "subj": "test", "ag_news": "test"}

if dataset_name not in _DATASET_SOURCES:
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; expected one of {sorted(_DATASET_SOURCES)}"
    )
dataset = load_dataset(*_DATASET_SOURCES[dataset_name])

# Sample up to num_dp training points. Each row keeps its position in the full
# training split as 'idx', so Shapley values can be traced back to the original row.
train_data = dataset['train']
train_data = train_data.map(lambda example, idx: {'idx': idx}, with_indices=True)
train_data = train_data.shuffle(seed).select(range(min(train_data.num_rows, num_dp)))
sampled_idx = train_data['idx']

# Sample up to val_sample_num evaluation points (all of them if the split is smaller).
# The same split is used for counting, sampling and materialising the rows.
val_split = _VAL_SPLITS.get(dataset_name, "validation")
valid_data = dataset[val_split]
val_num = valid_data.num_rows
if val_sample_num > val_num:
    sampled_val_idx = np.arange(val_num)
else:
    sampled_val_idx = np.random.choice(np.arange(val_num), val_sample_num, replace=False).tolist()

# Short model tag used in the cache file names. 'roberta' must be tested before
# 'bert' because RoBERTa model names also contain the substring 'bert'.
model_path = probe_model.args['model']
if 'llama' in model_path:
    model_name = 'llama'
elif 'roberta' in model_path:
    model_name = 'roberta'
elif 'bert' in model_path:
    model_name = 'bert'
else:
    raise ValueError(f"Cannot derive model_name from probe model {model_path!r}")

# Sampled evaluation rows in sampling order; later cells index this list with
# the same position they use for the Shapley-result column.
reindex_valid_data = [valid_data[int(index)] for index in sampled_val_idx]

Constructing <class 'dataset.EasyReader'>
Constructing <class 'dataset.FastListDataset'>
Label 0 to word terrible (6659)
Label 1 to word great (2307)
label_to_word:  {'0': 6659, '1': 2307}
label_list:  [6659, 2307]
Constructing <class 'probe.NTKProbe'>
Constructing NTKProbe
Constructing PromptFinetuneProbe


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Probe has 109514298 parameters
Constructing <class 'dvutils.Data_Shapley.Fast_Data_Shapley'>


### Build (or load the cached) NTK matrix

In [2]:
# Cache file for the NTK matrix; the name encodes the settings that change it.
ntk_file_path = f"{file_path}{dataset_name}_{model_name}_ntk_seed{seed}_num{num_dp}_sign{signgd}.pkl"
print(ntk_file_path)
try:
    # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook.
    with open(ntk_file_path, "rb") as f:
        ntk = pickle.load(f)
except FileNotFoundError:
    # Only a missing cache triggers recomputation. A corrupt cache or a failure
    # inside the probe now surfaces instead of silently recomputing (~12 min).
    print("++++++++++++++++++++++++++++++++++no cached ntk, computing+++++++++++++++++++++++++++++++++++")
    train_set = list_dataset.get_idx_dataset(sampled_idx, split="train")
    val_set = list_dataset.get_idx_dataset(sampled_val_idx, split="val")
    # compute ntk matrix
    ntk = probe_model.compute_ntk(train_set, val_set)
    # Save the ntk matrix. Create the cache dir first, and write to a temp file
    # then rename, so an interrupted dump never leaves a truncated cache behind.
    os.makedirs(file_path, exist_ok=True)
    tmp_path = f"{ntk_file_path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(ntk, f)
    os.replace(tmp_path, ntk_file_path)
    print("++++++++++++++++++++++++++++++++++saving ntk cache+++++++++++++++++++++++++++++++++++")
else:
    print("++++++++++++++++++++++++++++++++++++using cached ntk++++++++++++++++++++++++++++++++++++")
    probe_model.get_cached_ntk(ntk)
    probe_model.get_train_labels(list_dataset.get_idx_dataset(sampled_idx, split="train"))

./freeshap_res/sst2_bert_ntk_seed2023_num5000_signFalse.pkl
++++++++++++++++++++++++++++++++++no cached ntk, computing+++++++++++++++++++++++++++++++++++


[loading]: 10250it [00:17, 1002.01it/s]

sst2: 67349


[loading]: 67349it [00:18, 3575.48it/s] 
[loading]: 872it [00:15, 56.40it/s]

sst2: 872



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [05:30<00:00, 17.74it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:33<00:00,  5.34it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [05:29<00:00, 17.79it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████

current mean value:  tensor(0.4992)
++++++++++++++++++++++++++++++++++saving ntk cache+++++++++++++++++++++++++++++++++++


### Compute (or load the cached) Shapley values

In [3]:
# Cache file for the Shapley results; the name encodes every setting that affects them.
shapley_file_path = f"{file_path}{dataset_name}_{model_name}_shapley_result_seed{seed}_num{num_dp}_appro{approximate}_sign{signgd}_earlystop{early_stopping}_tmc{tmc_seed}_iter{tmc_iter}.pkl"
try:
    # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook.
    with open(shapley_file_path, 'rb') as f:
        result_dict = pickle.load(f)
except FileNotFoundError:
    # Only a missing cache triggers the (very expensive) TMC run. Other errors
    # surface instead of being swallowed by a bare `except:`.
    print("Computing FreeShap Results")
    dv_result = dshap_com.run(data_idx=sampled_idx, val_data_idx=sampled_val_idx, iteration=tmc_iter,
                              use_cache_ntk=True, prompt=prompt, seed=tmc_seed, num_dp=num_dp,
                              checkpoint=False, per_point=per_point, early_stopping=early_stopping)

    # NOTE(review): mc_com is only defined on a cache miss and is not stored in
    # result_dict, so any later use of it fails when the cache is hit.
    mc_com = np.array(dshap_com.mc_cache)
    result_dict = {'dv_result': dv_result,  # entropy, accuracy
                   'sampled_idx': sampled_idx}
    # Create the cache dir first, and write to a temp file then rename, so an
    # interrupted dump never leaves a truncated (unloadable) cache behind.
    os.makedirs(file_path, exist_ok=True)
    tmp_path = f"{shapley_file_path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(result_dict, f)
    os.replace(tmp_path, shapley_file_path)
    print(f"Saving FreeShap result to {shapley_file_path}")
else:
    print(f"Loading FreeShap result from {shapley_file_path}")

Computing FreeShap Results
start to compute shapley value


[loading]: 872it [00:16, 53.52it/s]

sst2: 872



[TMC iterations]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [5:47:11<00:00, 104.16s/it]


Saving FreeShap result to ./freeshap_res/sst2_bert_shapley_result_seed2023_num5000_approinv_signFalse_earlystopTrue_tmc2023_iter200.pkl


### Explain individual test predictions (validation points at index 535 and 1): which training points helped or hurt each one most

In [4]:
# Per-point Shapley scores: rows = training points (aligned with train_data),
# columns = sampled validation points (aligned with reindex_valid_data).
# [:, 1, :] selects the accuracy-based value (per the 'entropy, accuracy' note on dv_result).
acc = result_dict['dv_result'][:, 1, :]
# A stale cache built from a different training sample would silently
# misattribute scores, so check that the rows still line up with train_data.
assert list(result_dict['sampled_idx']) == list(train_data['idx'])


def show_influential_train_points(scores, test_idx, train_rows, test_rows, k=10):
    """Print the k highest- and k lowest-scoring training points for one test point.

    Args:
        scores: 2-D array; rows = training points (aligned with ``train_rows``),
            columns = test points (aligned with ``test_rows``).
        test_idx: column of ``scores`` / position in ``test_rows`` to explain.
        train_rows: indexable training examples (e.g. ``train_data``).
        test_rows: indexable test examples (e.g. ``reindex_valid_data``).
        k: number of training points to show at each end.

    Returns:
        (top_high, top_low): row indices of the k highest scores (highest first)
        and of the k lowest scores (lowest first).
    """
    column = scores[:, test_idx]
    order = np.argsort(column)   # ascending
    top_high = order[::-1][:k]   # k highest, highest first
    top_low = order[:k]          # k lowest, lowest first

    print(f"{test_rows[int(test_idx)]}")
    print("================================ Most influential ==========================")
    for row in top_high:
        print(f"score: {column[int(row)]}  |  {train_rows[int(row)]}")
    print("================================ Least influential ==========================")
    for row in top_low:
        print(f"score: {column[int(row)]}  |  {train_rows[int(row)]}")
    print()
    print()
    return top_high, top_low


idx = 535
top_10_high = {}
top_10_low = {}
top_10_high[idx], top_10_low[idx] = show_influential_train_points(acc, idx, train_data, reindex_valid_data)

{'idx': 535, 'sentence': 'a quiet , pure , elliptical film ', 'label': 1}
score: 0.02  |  {'idx': 10497, 'sentence': 'a funny little film ', 'label': 1}
score: 0.0175  |  {'idx': 3364, 'sentence': 'pretty good little movie ', 'label': 1}
score: 0.015  |  {'idx': 53851, 'sentence': "it 's refreshing that someone understands the need for the bad boy ", 'label': 1}
score: 0.015  |  {'idx': 7139, 'sentence': 'its ripe recipe , inspiring ingredients ', 'label': 1}
score: 0.015  |  {'idx': 5403, 'sentence': "is that it 's a rock-solid little genre picture ", 'label': 1}
score: 0.015  |  {'idx': 33494, 'sentence': 'trial movie , escape movie and unexpected fable ', 'label': 1}
score: 0.015  |  {'idx': 55842, 'sentence': 'dazzling entertainment ', 'label': 1}
score: 0.015  |  {'idx': 52439, 'sentence': 'the ya-ya sisterhood ', 'label': 1}
score: 0.015  |  {'idx': 28783, 'sentence': 'a well-put-together piece ', 'label': 1}
score: 0.0125  |  {'idx': 55951, 'sentence': 'a serious drama ', 'label

In [5]:
# Per-point Shapley scores: rows = training points (aligned with train_data),
# columns = sampled validation points (aligned with reindex_valid_data).
acc = result_dict['dv_result'][:, 1, :]

top_10_high = {}
top_10_low = {}

idx = 1
column_vector = acc[:, idx]  # score of every training point w.r.t. validation point `idx`
print(f"{reindex_valid_data[int(idx)]}")
ranking = np.argsort(column_vector)  # ascending order of scores
top_10_high[idx] = ranking[::-1][:10]  # indices of the 10 highest values, highest first
print("================================ Most influential ==========================")
for aindex in top_10_high[idx]:
    print(f"score: {column_vector[int(aindex)]}  |  {train_data[int(aindex)]}")
print("================================ Least influential ==========================")
top_10_low[idx] = ranking[:10]  # indices of the 10 lowest values, lowest first
for aindex in top_10_low[idx]:
    print(f"score: {column_vector[int(aindex)]}  |  {train_data[int(aindex)]}")
print()
print()

{'idx': 1, 'sentence': 'unflinchingly bleak and desperate ', 'label': 0}
score: 0.05999999999999999  |  {'idx': 25883, 'sentence': 'self-defeatingly decorous ', 'label': 0}
score: 0.045  |  {'idx': 23066, 'sentence': 'oddly moving ', 'label': 1}
score: 0.04  |  {'idx': 45633, 'sentence': 'vicious and absurd ', 'label': 0}
score: 0.04  |  {'idx': 36959, 'sentence': 'scare ', 'label': 0}
score: 0.04  |  {'idx': 11913, 'sentence': 'a deep vein of sadness ', 'label': 0}
score: 0.035  |  {'idx': 39254, 'sentence': 'could be a passable date film ', 'label': 1}
score: 0.035  |  {'idx': 51429, 'sentence': 'druggy and self-indulgent ', 'label': 0}
score: 0.035  |  {'idx': 56599, 'sentence': 'predictable and cloying ', 'label': 0}
score: 0.035  |  {'idx': 53672, 'sentence': 'multi-layered ', 'label': 1}
score: 0.035  |  {'idx': 51283, 'sentence': 'absolutely , inescapably gorgeous , ', 'label': 1}
score: -0.10000000000000002  |  {'idx': 29181, 'sentence': 'excessively quirky ', 'label': 1}
score

### Find the most helpful / harmful training points (Shapley value summed over all validation points)

In [6]:
# Per-point Shapley scores: rows = training points (aligned with train_data),
# columns = validation points; [:, 1, :] selects the accuracy-based value
# (per the 'entropy, accuracy' note on dv_result).
acc = result_dict['dv_result'][:, 1, :]
# Rows must line up with train_data for the lookups below to be meaningful.
assert list(result_dict['sampled_idx']) == list(train_data['idx'])
# Total value of each training point, summed over all validation points.
acc_sum = np.sum(acc, axis=1)
# Training-point indices from most helpful (largest total) to most harmful.
sorted_indices = np.argsort(acc_sum)[::-1]

top = 5  # how many points to show at each end of the ranking
equal_symbol = "=" * 35
print(f"{equal_symbol} Most Helpful {equal_symbol}")
for index in sorted_indices[:top]:
    print(f"score: {acc_sum[int(index)]} | {train_data[int(index)]}")
print(f"{equal_symbol} Most Harmful {equal_symbol}")
# These are the last `top` entries, so the single most harmful point is printed last.
for index in sorted_indices[-top:]:
    print(f"score: {acc_sum[int(index)]} | {train_data[int(index)]}")



score: 2.505 | {'idx': 13235, 'sentence': 'a dramatic comedy as pleasantly dishonest and pat as any hollywood fluff . ', 'label': 0}
score: 2.28 | {'idx': 40598, 'sentence': "it 's anchored by splendid performances from an honored screen veteran and a sparkling newcomer who instantly transform themselves into a believable mother/daughter pair . ", 'label': 1}
score: 1.645 | {'idx': 54789, 'sentence': ', the humor dwindles . ', 'label': 0}
score: 1.58 | {'idx': 13053, 'sentence': 'is highly pleasurable . ', 'label': 1}
score: 1.56 | {'idx': 49861, 'sentence': ', alas , it collapses like an overcooked soufflé . ', 'label': 0}
score: -1.6100000000000003 | {'idx': 43875, 'sentence': 'could possibly be more contemptuous of the single female population . ', 'label': 1}
score: -2.245 | {'idx': 51208, 'sentence': ', the more outrageous bits achieve a shock-you-into-laughter intensity of almost dadaist proportions . ', 'label': 1}
score: -2.785 | {'idx': 64148, 'sentence': 'could have easily be