In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import json
import pickle
import sys
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from collections import defaultdict

# Make modules in the parent directory importable
# (presumably project helper code lives there — TODO confirm).
sys.path.append('../')

# Table display: allow up to 500 rows and up to 500 characters per cell.
pd.options.display.max_rows = 500
pd.options.display.max_colwidth = 500

# Output folder for figures, relative to the current working directory; created if missing.
figure_dir = Path('figures')
figure_dir.mkdir(exist_ok=True)

In [None]:
RESULTS_DIR = Path('/data/healthy-ml/scratch/haoran/clinical_fact_check/results/fact_check/results')
root_dir = RESULTS_DIR / 'eval_pipeline_claude'
root_dir2 = RESULTS_DIR / 'eval_pipeline_codellama'
root_dir_llm = RESULTS_DIR / 'eval_baseline_llm'


def load_finished_runs(root_dirs):
    """Load the args and results of every finished run found under ``root_dirs``.

    A run directory counts as finished when it contains a ``done`` marker file;
    its sibling ``args.json`` and ``res.pkl`` are then loaded. Each args dict is
    augmented with:
      * ``exp_name`` - name of the directory that contains the run directory;
      * ``name``     - display label, ``'baseline: <llm>: <select_rows>'`` for
                       experiments whose name starts with ``eval_baseline``,
                       otherwise ``'<llm>: <prompt>'``.

    Args:
        root_dirs: iterable of directories, searched recursively in the given order.

    Returns:
        ``(results, args)`` - two lists aligned by index.
    """
    results, all_args = [], []
    for root in root_dirs:
        # Sort within each root so run order (and therefore argss[0]) does not
        # depend on filesystem enumeration order.
        for done_marker in sorted(Path(root).glob('**/done')):
            run_dir = done_marker.parent
            # read_text/read_bytes open and close the file, unlike a bare .open().
            args = json.loads((run_dir / 'args.json').read_text(encoding='utf-8'))
            # NOTE(review): pickle can execute arbitrary code; only load trusted result files.
            res = pickle.loads((run_dir / 'res.pkl').read_bytes())

            exp_name = run_dir.parent.name
            args['exp_name'] = exp_name
            if exp_name.startswith('eval_baseline'):  # llm baseline
                args['name'] = 'baseline' + ': ' + args['llm'] + ': ' + args['select_rows']
            else:
                args['name'] = args['llm'] + ': ' + args['prompt']

            results.append(res)
            all_args.append(args)
    return results, all_args


ress, argss = load_finished_runs([root_dir, root_dir2, root_dir_llm])

In [None]:
# The claim table is read from the first run's `claim_df_path`. Fail loudly if no runs were
# found, or if runs point at different claim files (joining on claim_id would then be wrong).
assert argss, 'No finished runs found; check the result directories.'
claim_paths = {a['claim_df_path'] for a in argss}
assert len(claim_paths) == 1, f'Runs use different claim files: {sorted(claim_paths)}'
# claim_id is the 0-based row position in the claim CSV; result frames are joined on it.
claims = pd.read_csv(argss[0]['claim_df_path']).rename_axis('claim_id').reset_index()

In [None]:
# Column names available in the claim table.
claims.keys()

In [None]:
# Join the claim-level columns of `claims` onto each run's predictions (inner join on
# claim_id) and tag every row with the run's display name as `prompt_type`.
# Every column the results share with `claims` (other than the join key) is dropped first.
# This means the merge never produces _x/_y suffixes, and re-running this cell gives the
# same frames instead of duplicating claim columns.
claim_cols = ['label', 'claim'] + [col for col in claims.columns
                                   if col not in ('claim_id', 'label', 'claim')]
ress = [
    (res.drop(columns=claim_cols, errors='ignore')
        .merge(claims, on='claim_id', how='inner')
        .assign(prompt_type=run_args['name']))
    for res, run_args in zip(ress, argss)
]

In [None]:
# Stack all runs row-wise into one long frame with a fresh 0..n-1 index.
df = pd.concat(objs=ress, axis=0, ignore_index=True)

In [None]:
# (rows, columns) of the combined frame.
(len(df.index), len(df.columns))

In [None]:
# Number of rows in the claim table.
claims.shape[0]

### Accuracy by number of tables required

In [None]:
# Accuracy (fraction of rows with pred_label == label) per method, split by the
# `requires_global_kg` and `num_tables_required` columns.
# This takes the mean of a boolean column rather than using groupby().apply(); in recent
# pandas, apply also passes the grouping columns to the function and warns about it.
srs = (df.assign(correct=lambda d: d['label'] == d['pred_label'])
         .groupby(['prompt_type', 'requires_global_kg', 'num_tables_required'])['correct']
         .mean()
         .rename('Accuracy'))
temp = srs.reset_index().pivot_table(columns=['requires_global_kg', 'num_tables_required'],
                                     index=['prompt_type'], values=['Accuracy'])
# Hide methods whose display name ends in 'random'.
temp = temp.loc[~temp.index.str.endswith('random')]
# Buckets a method has no rows for render as '-' instead of 'nan%'.
temp.style.format('{:.1%}', na_rep='-')

In [None]:
# Overall accuracy per method, pooled over all claims (mean of a boolean match column).
srs = (df['label'].eq(df['pred_label'])
         .groupby(df['prompt_type'])
         .mean()
         .rename('Accuracy'))
srs.to_frame().style.format('{:.2%}')

In [None]:
# How many rows fall in each (method, requires_global_kg, num_tables_required) bucket?
# .size() counts rows per group directly, with no per-group Python callback.
# Buckets that are absent are genuinely empty, so they are filled with 0, and counts are
# summed (not averaged) to keep them integers.
srs = df.groupby(['prompt_type', 'requires_global_kg', 'num_tables_required']).size()
srs.rename('# Samples').reset_index().pivot_table(columns=['requires_global_kg', 'num_tables_required'], index=['prompt_type'], values=['# Samples'], aggfunc='sum', fill_value=0)

In [None]:
# Accuracy on NEI labels: among claims whose gold label is "N", the fraction the method
# also predicted as "N", per method and bucket. The table is also saved to result2_nei.pkl.
srs = (df.query('label == "N"')
         .assign(correct=lambda d: d['label'] == d['pred_label'])
         .groupby(['prompt_type', 'requires_global_kg', 'num_tables_required'])['correct']
         .mean()
         .rename('Accuracy'))
temp = srs.reset_index().pivot_table(columns=['requires_global_kg', 'num_tables_required'],
                                     index=['prompt_type'], values=['Accuracy'])
# Hide methods whose display name ends in 'random'.
temp = temp.loc[~temp.index.str.endswith('random')]
temp.to_pickle('result2_nei.pkl')
temp.style.format('{:.1%}', na_rep='-')

In [None]:
# Accuracy on non-NEI labels: among claims whose gold label is not "N", the fraction
# predicted correctly, per method and bucket. The table is also saved to result2_tf.pkl.
srs = (df.query('label != "N"')
         .assign(correct=lambda d: d['label'] == d['pred_label'])
         .groupby(['prompt_type', 'requires_global_kg', 'num_tables_required'])['correct']
         .mean()
         .rename('Accuracy'))
temp = srs.reset_index().pivot_table(columns=['requires_global_kg', 'num_tables_required'],
                                     index=['prompt_type'], values=['Accuracy'])
# Hide methods whose display name ends in 'random'.
temp = temp.loc[~temp.index.str.endswith('random')]
temp.to_pickle('result2_tf.pkl')
temp.style.format('{:.1%}', na_rep='-')

### Confusion Matrix

In [None]:
# One confusion matrix per method, in order of first appearance, with classes ordered T, F, N.
# sklearn's confusion_matrix puts gold labels on rows and predictions on columns.
cm_labels = ['T', 'F', 'N']
for prompt_type, rows in df.groupby('prompt_type', sort=False):
    fig, ax = plt.subplots()
    ConfusionMatrixDisplay(confusion_matrix(rows['label'], rows['pred_label'], labels=cm_labels),
                           display_labels=cm_labels).plot(ax=ax)
    # The title keeps each figure self-describing (replaces a separate print of the name).
    ax.set_title(prompt_type)
    plt.show()
    # Release the figure explicitly so a long loop does not accumulate open figures.
    plt.close(fig)

### Committed (non-NEI) predictions

In [None]:
# % non-NEI predictions: share of each bucket's rows where the method predicted something
# other than "N". The column was previously labelled '# Samples', but the values are fractions.
srs = (df.assign(committed=lambda d: d['pred_label'] != 'N')
         .groupby(['prompt_type', 'requires_global_kg', 'num_tables_required'])['committed']
         .mean()
         .rename('% Non-NEI'))
srs.reset_index().pivot_table(columns=['requires_global_kg', 'num_tables_required'], index=['prompt_type'], values=['% Non-NEI']).style.format('{:.2%}', na_rep='-')

In [None]:
# Accuracy when predicting non-NEI: among rows where the method predicted something other
# than "N", the fraction whose prediction matches the gold label, per method and bucket.
# Counts are summed per group and divided once. A bucket with no non-"N" predictions gives
# 0/0 = NaN (shown as '-'), without the per-group RuntimeWarning of scalar division.
counts = (df.assign(committed=lambda d: d['pred_label'] != 'N',
                    hit=lambda d: d['committed'] & (d['label'] == d['pred_label']))
            .groupby(['prompt_type', 'requires_global_kg', 'num_tables_required'])[['hit', 'committed']]
            .sum())
srs = (counts['hit'] / counts['committed']).rename('Accuracy')
srs.reset_index().pivot_table(columns=['requires_global_kg', 'num_tables_required'], index=['prompt_type'], values=['Accuracy']).style.format('{:.2%}', na_rep='-')