In [3]:
import os
import pandas as pd

In [4]:
RUN_PATH = '/mnt/ssd-1/sai/semantic-memorization/experiments/'
base_path = os.path.join(RUN_PATH, '2024-04-16_11-07-57', 'deduped', '12b')

In [5]:
model_evals_path = os.path.join(base_path, 'model_taxonomy', 'predictions.parquet')

In [6]:
data = pd.read_parquet(model_evals_path)

In [7]:
import scipy.stats as stats
import numpy as np

In [8]:
model_probs = np.array(data['model_prediction_probs'].to_list())
baseline_probs = np.array(data['baseline_prediction_probs'].to_list())
model_predictions = (model_probs[:, 1] > 0.5).astype(np.int32)
baseline_predictions = (baseline_probs[:, 1] > 0.5).astype(np.int32)
labels = data['labels'].to_numpy()

In [23]:
from sklearn.metrics import precision_score, recall_score, f1_score, average_precision_score

In [10]:
f1_score(labels, (model_probs[:, 1] > 0.5).astype(np.int32))

0.6682080924855492

In [11]:
f1_score(labels, model_predictions)

0.6682080924855492

In [12]:
f1_score(labels, baseline_predictions)

0.6271013606036769

In [13]:
f1_score(labels, (baseline_probs[:, 1] > 0.5).astype(np.int32))

0.6271013606036769

In [18]:
import sys
sys.path.append('../') 

In [19]:
from model_utils import expected_calibration_error

In [26]:
from sklearn.model_selection import KFold

metrics = {
    'precision': {'model': [], 'baseline': []},
    'recall': {'model': [], 'baseline': []},
    'ece': {'model': [], 'baseline': []},
    'pr_auc': {'model': [], 'baseline': []},
}

fold = KFold(n_splits=100, shuffle=True)
for _, indicies in fold.split(labels):
    split_model_probs = model_probs[indicies, :]
    split_baseline_probs = baseline_probs[indicies, :]
    split_model_predictions = model_predictions[indicies]
    split_baseline_predictions = baseline_predictions[indicies]
    split_labels = labels[indicies]
    metrics['precision']['model'].append(precision_score(split_labels, split_model_predictions))
    metrics['recall']['model'].append(recall_score(split_labels, split_model_predictions))
    metrics['ece']['model'].append(expected_calibration_error(split_model_probs, split_labels))
    metrics['pr_auc']['model'].append(average_precision_score(split_labels, split_model_probs[:, 0]))

    metrics['precision']['baseline'].append(precision_score(split_labels, split_baseline_predictions))
    metrics['recall']['baseline'].append(recall_score(split_labels, split_baseline_predictions))
    metrics['ece']['baseline'].append(expected_calibration_error(split_baseline_probs, split_labels))
    metrics['pr_auc']['baseline'].append(average_precision_score(split_labels, split_baseline_probs[:, 0]))
    
   

In [27]:
for metric in metrics:
    print("#"*10, end="\t")
    print(metric, end="\t")
    print("#"*10)
    print(stats.ttest_rel(metrics[metric]['model'], metrics[metric]['baseline']))
    print(stats.wilcoxon(metrics[metric]['model'], metrics[metric]['baseline'], zero_method='zsplit'))
    

##########	precision	##########
TtestResult(statistic=31.044878362310417, pvalue=7.91995142541859e-53, df=99)
WilcoxonResult(statistic=0.0, pvalue=3.896559845095909e-18)
##########	recall	##########
TtestResult(statistic=-17.42727806974501, pvalue=6.30548055206751e-32, df=99)
WilcoxonResult(statistic=3.0, pvalue=4.2634423565499585e-18)
##########	ece	##########
TtestResult(statistic=-34.4245862429096, pvalue=6.748238588355811e-57, df=99)
WilcoxonResult(statistic=0.0, pvalue=3.896559845095909e-18)
##########	pr_auc	##########
TtestResult(statistic=-4.22060853567759, pvalue=5.409042556981045e-05, df=99)
WilcoxonResult(statistic=1616.0, pvalue=0.0017754092402886207)


In [131]:
np.mean(metrics['ece']['baseline'])

0.00870116132314611

In [121]:
stats.wilcoxo(metrics['recall']['model'], metrics['recall']['baseline'], zero_method='zsplit')

WilcoxonResult(statistic=1600.0, pvalue=0.0011633823709150059)

In [123]:
stats.ttest_rel(metrics['recall']['model'], metrics['recall']['baseline'])

TtestResult(statistic=-3.708965251518794, pvalue=0.00034353642222615456, df=99)

In [92]:
for i in range(100):
    # print(metrics['precision']['model'][i], metrics['precision']['baseline'][i], metrics['precision']['model'][i] - metrics['precision']['baseline'][i])
    print(metrics['recall']['model'][i],  metrics['recall']['baseline'][i], metrics['recall']['model'][i] - metrics['recall']['baseline'][i])

0.9789473684210527 0.9894736842105263 -0.010526315789473606
0.9830508474576272 0.9830508474576272 0.0
1.0 1.0 0.0
0.9893617021276596 0.9787234042553191 0.010638297872340496
0.9739130434782609 0.9739130434782609 0.0
0.9807692307692307 0.9903846153846154 -0.009615384615384692
1.0 1.0 0.0
1.0 1.0 0.0
0.9893617021276596 1.0 -0.010638297872340385
0.9907407407407407 0.9907407407407407 0.0
0.9905660377358491 1.0 -0.009433962264150941
0.98989898989899 0.98989898989899 0.0
0.9908256880733946 0.981651376146789 0.00917431192660556
1.0 1.0 0.0
0.9690721649484536 0.9587628865979382 0.010309278350515427
1.0 1.0 0.0
0.979381443298969 0.9896907216494846 -0.010309278350515538
0.9900990099009901 0.9801980198019802 0.00990099009900991
1.0 1.0 0.0
0.970873786407767 0.9805825242718447 -0.009708737864077666
0.9626168224299065 0.9719626168224299 -0.009345794392523366
0.9719626168224299 0.9813084112149533 -0.009345794392523366
0.9894736842105263 0.9894736842105263 0.0
0.9906542056074766 0.9906542056074766 0.0

WilcoxonResult(statistic=218.0, pvalue=0.0004251080668211116)

In [35]:
np.mean(metrics['precision']['model']), np.std(metrics['precision']['model'])

(0.5342705427255143, 0.03752549571306509)

In [36]:
np.mean(metrics['precision']['baseline']), np.std(metrics['precision']['baseline'])

(0.5338319597353697, 0.03617128010054141)

In [37]:
np.mean(metrics['recall']['model']), np.std(metrics['recall']['model'])

(0.9875399595350456, 0.010646342251008932)

In [38]:
np.mean(metrics['recall']['baseline']), np.std(metrics['recall']['baseline'])

(0.990465075591582, 0.008838522837603375)

In [20]:
stats.wilcoxon(model_probs, baseline_probs)

WilcoxonResult(statistic=31280028987.0, pvalue=0.0)

In [22]:
stats.wilcoxon(model_predictions, baseline_predictions)

WilcoxonResult(statistic=395317.0, pvalue=0.048323911287848964)

In [16]:
res.statistic

31280028987.0

In [17]:
res.pvalue

0.0