In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import pandas as pd
import os
import seaborn as sns
# Use seaborn's dark grid background for every figure in this notebook.
sns.set_style(style="darkgrid")

In [None]:
# Path to the directory containing pickle files
dir_path = '/storage/vbutoi/scratch/ESE/records/WMH_aug_runs'

# List all pickle files in the directory. os.listdir returns entries in arbitrary
# order, so sort them to make the concatenated row order reproducible.
pickle_files = sorted(f for f in os.listdir(dir_path) if f.endswith('.pkl'))
if not pickle_files:
    # Fail here with a clear message instead of later on a missing column.
    raise FileNotFoundError(f"No .pkl files found in {dir_path}")

# Read every file, then concatenate once (concatenating inside the loop is quadratic).
# NOTE: pd.read_pickle can execute arbitrary code -- only load files you trust.
all_logs = pd.concat(
    [pd.read_pickle(os.path.join(dir_path, p_file)) for p_file in pickle_files],
    ignore_index=True,
)

In [None]:
def has_label(value):
    """Return whether a label amount marks a case as labelled.

    Separates cases with no label (an amount equal to 0.0) from cases with at
    least some label. Only ``!=`` is used, so a pandas Series argument is
    compared element-wise. NaN compares unequal to 0.0, so a NaN amount counts
    as labelled.
    """
    is_labelled = value != 0.0
    return is_labelled

# Add a boolean column marking rows whose ground-truth label amount is non-zero.
# has_label only uses `!=`, so it is applied to the whole column at once instead
# of calling it once per row via .apply (same result, vectorised).
all_logs['has_label'] = has_label(all_logs['gt_lab_amount'])

def reorder_splits(df, splits=('train', 'val', 'cal')):
    """Return the rows of ``df`` grouped by data split, in the order of ``splits``.

    Rows are selected on the ``'split'`` column. Within each split the original
    row order and index labels are kept. Rows whose split is not listed in
    ``splits`` are dropped, so with the default only train, val and cal rows
    are returned.

    Args:
        df: DataFrame with a ``'split'`` column.
        splits: Split names, in the desired output order.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    if len(splits) == 0:
        # pd.concat raises on an empty list; return an empty frame with df's columns.
        return df.iloc[0:0]
    return pd.concat([df[df['split'] == split] for split in splits])

# Order the rows train -> val -> cal.
all_logs = reorder_splits(all_logs)

# Collapse to one row per subject (data_idx) for every configuration,
# averaging each metric column.
subject_keys = ['data_idx', 'cal_metric', 'bin_weighting', 'task', 'split']
metric_columns = [
    'pred_lab_amount',
    'gt_lab_amount',
    'cal_score',
    'accuracy',
    'dice',
    'lab_w_accuracy',
]
logs_per_subject = reorder_splits(
    all_logs.groupby(subject_keys)
    .agg({column: 'mean' for column in metric_columns})
    .reset_index()
)

# Group the raw and the per-subject metrics by the factors being compared.
# NOTE(review): bin_weighting is not among these keys, so rows with different
# bin_weighting values are pooled within each group -- confirm this is intended.
comparison_keys = ['task', 'cal_metric', 'split']
grouped_logs = all_logs.groupby(comparison_keys)
grouped_logs_per_subject = logs_per_subject.groupby(comparison_keys)

In [None]:
plt.rcParams.update({'font.size': 20})

# Correlation between each evaluation metric and cal_score over the per-run logs,
# computed separately within every (task, cal_metric, split) group.
# Only the two columns a correlation needs are selected before .apply, so the
# grouping columns are not handed to the lambda (pandas >= 2.2 warns when they are).
acc_corr, dice_corr = [
    reorder_splits(
        grouped_logs[[eval_metric, 'cal_score']]
        .apply(lambda group: group[eval_metric].corr(group['cal_score']))
        .reset_index(name='correlation')
    ).assign(eval_metric=eval_metric)
    for eval_metric in ('accuracy', 'dice')
]

# Combine the two
correlations = pd.concat([acc_corr, dice_corr])

g = sns.catplot(data=correlations,
                x="eval_metric",
                y="correlation",
                hue='cal_metric',
                col="split",
                row="task",
                kind="bar",
                height=8,
                aspect=1)
# Series.corr defaults to Pearson, which lies in [-1, 1]; fix the axis so panels
# are comparable. The trailing ';' suppresses the FacetGrid repr in the output.
g.set(ylim=(-1, 1));

In [None]:
plt.rcParams.update({'font.size': 25})

# Same correlations as for the per-run logs, but computed on the per-subject means,
# again within every (task, cal_metric, split) group.
# Only the two columns a correlation needs are selected before .apply, so the
# grouping columns are not handed to the lambda (pandas >= 2.2 warns when they are).
acc_subj_correlations, dice_subject_correlations = [
    reorder_splits(
        grouped_logs_per_subject[[eval_metric, 'cal_score']]
        .apply(lambda group: group[eval_metric].corr(group['cal_score']))
        .reset_index(name='correlation')
    ).assign(eval_metric=eval_metric)
    for eval_metric in ('accuracy', 'dice')
]

# Combine the two
subject_correlations = pd.concat([acc_subj_correlations, dice_subject_correlations])

g = sns.catplot(data=subject_correlations,
                x="eval_metric",
                y="correlation",
                hue='cal_metric',
                col="split",
                row="task",
                kind="bar",
                height=8,
                aspect=1)
# Series.corr defaults to Pearson, which lies in [-1, 1]; fix the axis so panels
# are comparable. The trailing ';' suppresses the FacetGrid repr in the output.
g.set(ylim=(-1, 1));