In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import re
from sklearn.metrics import auc, precision_recall_curve
from sklearn.utils import resample
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Input directory holding the time-binned, reassembled pupillometry CSV files,
# and output directory for exported results. Both default to the original
# author's local paths; set CEREBLINK_DATA_DIR / CEREBLINK_OUTPUT_DIR to run
# the notebook on another machine without editing code.
data_dir = os.environ.get('CEREBLINK_DATA_DIR', '/Users/jk1/temp/cereblink/data_saving/exclude_nan_outcome_False')
output_dir = os.environ.get('CEREBLINK_OUTPUT_DIR', '/Users/jk1/Downloads')

In [None]:

def boostrapped_pr_curve(gt, predictions, interpolations_points=200, n_samples=100):
    """Bootstrap a precision-recall curve and its area under the curve.

    Draws ``n_samples`` resamples (with replacement) of the paired ``gt`` /
    ``predictions`` values, seeding resample ``i`` with ``random_state=i`` so
    results are reproducible. For each resample it computes the PR curve and
    its AUC, then linearly interpolates precision onto a common, evenly spaced
    recall grid on [0, 1] so the curves can be aggregated/plotted together.

    Parameters
    ----------
    gt : array-like
        Ground-truth labels, passed to ``sklearn.metrics.precision_recall_curve``.
    predictions : array-like
        Scores for the positive class, same length as ``gt``.
    interpolations_points : int, default 200
        Number of points in the common recall grid.
    n_samples : int, default 100
        Number of bootstrap resamples.

    Returns
    -------
    bs_pr_aucs : list of float
        ``auc(recall, precision)`` of each resample, in resample order.
    resampled_pr_df : pandas.DataFrame
        Long-format frame with columns ``precision``, ``recall`` and
        ``bootstrap_idx``; ``interpolations_points`` rows per resample.
    """
    # Common recall grid shared by all resamples (loop-invariant).
    recall_grid = np.linspace(0, 1, interpolations_points)
    bs_pr_aucs = []
    resampled_frames = []

    for idx in tqdm(range(n_samples)):
        bs_gt, bs_predictions = resample(gt, predictions, replace=True, random_state=idx)
        bs_precision, bs_recall, _ = precision_recall_curve(bs_gt, bs_predictions)
        bs_pr_aucs.append(auc(bs_recall, bs_precision))

        # np.interp expects increasing x values; sort the (recall, precision)
        # pairs by recall before interpolating onto the grid.
        bs_recall, bs_precision = zip(*sorted(zip(bs_recall, bs_precision)))
        bs_resampled_precision = np.interp(recall_grid, bs_recall, bs_precision)
        resampled_frames.append(pd.DataFrame({
            'precision': bs_resampled_precision,
            'recall': recall_grid,
            'bootstrap_idx': idx,
        }))

    if not resampled_frames:
        # n_samples == 0: pd.concat would raise on an empty list.
        return bs_pr_aucs, pd.DataFrame(columns=['precision', 'recall', 'bootstrap_idx'])

    # Concatenate once instead of growing a DataFrame inside the loop (quadratic copying).
    resampled_pr_df = pd.concat(resampled_frames, axis=0)
    return bs_pr_aucs, resampled_pr_df

In [None]:
# When True, the loading cell keeps at most one row per patient (pNr),
# relative time window and label, so consecutive timebins do not overlap.
drop_overlapping_timebins: bool = True

In [None]:
# Pupillometry measurements and the inter-eye aggregations applied to them
# (presumably combining left and right eye values at one time point).
pupillometry_metrics = ['NPI', 'CV']
inter_eye_metrics = ['mean', 'min', 'max', 'delta']
# Cartesian product measurement x inter-eye aggregation, e.g. 'NPI_inter_eye_mean'.
single_timepoint_metrics = [
    '{}_inter_eye_{}'.format(measurement, eye_aggregation)
    for measurement in pupillometry_metrics
    for eye_aggregation in inter_eye_metrics
]
# Aggregations applied to each single-timepoint metric within a timebin.
over_time_metrics = ['max', 'min', 'median']
# Cartesian product single-timepoint metric x over-time aggregation,
# e.g. 'NPI_inter_eye_mean_timebin_max'.
timebin_metrics = [
    '{}_timebin_{}'.format(base_metric, time_aggregation)
    for base_metric in single_timepoint_metrics
    for time_aggregation in over_time_metrics
]

In [None]:
# Load every time-binned, reassembled pupillometry CSV in data_dir and tag each
# row with the timebin size, normalisation flag and outcome parsed from its filename.
# Filenames are sorted so the row order of pupillometry_df is reproducible
# (os.listdir returns entries in arbitrary order).
data_filenames = sorted(
    f for f in os.listdir(data_dir)
    if f.endswith('.csv') and 'timebin' in f and 'reassembled_pupillometry' in f
)

pupillometry_frames = []
for data_filename in data_filenames:
    # find timebin size (hours) with regex identifying pattern : _xh_
    timebin_match = re.search(r'_(\d+)h_', data_filename)
    if timebin_match is None:
        raise ValueError(f'Cannot parse timebin size (pattern _<n>h_) from filename: {data_filename}')
    timebin_size = int(timebin_match.group(1))
    # NOTE(review): substring test — a name like 'non_normalized_...' would also be
    # flagged as normalised; confirm against the file naming scheme.
    data_is_normalized = int(('normalized' in data_filename) or ('normalised' in data_filename))
    # outcome = first two underscore-separated tokens of the filename
    outcome = '_'.join(data_filename.split('_')[0:2])

    df = pd.read_csv(os.path.join(data_dir, data_filename))
    df['timebin_size'] = timebin_size
    df['normalized'] = data_is_normalized
    df['outcome'] = outcome

    if drop_overlapping_timebins:
        # hours elapsed since the first timebin of each patient (pNr)
        df['timebin_end'] = pd.to_datetime(df['timebin_end'])
        df['first_timebin'] = df.groupby('pNr')['timebin_end'].transform('min')
        df['relative_timebin_end'] = (df['timebin_end'] - df['first_timebin']).dt.total_seconds() / 3600
        # index of the timebin_size-long window each row falls into; np.trunc
        # truncates like int() but leaves NaN rows as NaN
        df['relative_timebin_end_cat'] = np.trunc(df['relative_timebin_end'] / df['timebin_size'])
        # drop row if all timebin metrics in row are NaN
        df = df.dropna(subset=timebin_metrics, how='all')
        # keep the first row per patient, window and label -> non-overlapping timebins
        df = df.drop_duplicates(subset=['pNr', 'relative_timebin_end_cat', 'label'])

    pupillometry_frames.append(df)

# Concatenate once (growing a DataFrame in the loop is quadratic); pd.concat
# raises on an empty list, so fall back to an empty frame when no file matched.
pupillometry_df = (
    pd.concat(pupillometry_frames, axis=0) if pupillometry_frames else pd.DataFrame()
).reset_index(drop=True)
# 'Unnamed: 0' is how read_csv names an unlabelled header column (presumably a saved
# index); errors='ignore' avoids a KeyError when it is absent.
pupillometry_df = pupillometry_df.drop(columns=['Unnamed: 0'], errors='ignore')

In [None]:
# Selection of the slice analysed in the cells below:
#   outcome            - outcome tag, as parsed from the data filenames
#   timebin_size       - timebin length in hours
#   data_is_normalized - 1 for normalised input files, 0 for raw
#   metric             - feature column used as the prediction score
outcome, timebin_size, data_is_normalized = 'DCI_ischemia', 8, 1
metric = 'CV_inter_eye_mean_timebin_max'

In [None]:
# Fixed four-colour palette used for the PR curves; the bare name on the
# last line renders the swatches as cell output.
all_colors_palette = sns.color_palette(
    ['#f61067', '#049b9a', '#012D98', '#a76dfe'],
    n_colors=4,
)
all_colors_palette


In [None]:
# Preview of the normalised-file rows for the selected outcome and timebin size
# that have a value for `metric`. Built inline so this cell does not depend on
# `normalized_metric_df`, which is only created by the plotting cell further down
# (referencing it here fails with NameError on Restart & Run All).
pupillometry_df[
    (pupillometry_df['outcome'] == outcome)
    & (pupillometry_df['timebin_size'] == timebin_size)
    & (pupillometry_df['normalized'] == 1)
].dropna(subset=[metric]).head()

In [None]:
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

# Figure styling and bootstrap settings for the PR-curve comparison below.
tick_label_size = 11
label_font_size = 13
# NOTE(review): 10 resamples is quick but gives unstable 95% bands; consider >= 100 for final figures.
n_samples = 10
plot_legend = True
plot_NPI = True

interpolations_points = 200
title = f"{metric}, {timebin_size}h timebin"


def _select_rows(normalized, column):
    """Return pupillometry_df rows for the current outcome and timebin size.

    Keeps rows whose ``normalized`` flag equals ``normalized`` and drops rows
    where ``column`` is NaN. Returns a new frame rather than calling
    ``dropna(inplace=True)`` on a filtered slice (SettingWithCopyWarning).
    """
    mask = (
        (pupillometry_df['outcome'] == outcome)
        & (pupillometry_df['timebin_size'] == timebin_size)
        & (pupillometry_df['normalized'] == normalized)
    )
    return pupillometry_df[mask].dropna(subset=[column])


normalized_metric_df = _select_rows(1, metric)
non_normalized_metric_df = _select_rows(0, metric)

# Same aggregation suffix as `metric`, applied to NPI: replace the leading
# measurement name (e.g. 'CV') with 'NPI', whatever its length.
corresponding_npi = 'NPI_' + metric.split('_', 1)[1]
# NPI reference curve taken from the non-normalised files, as the name states.
non_normalized_npi_df = _select_rows(0, corresponding_npi)

normalized_metric_bs_pr_aucs, normalized_metric_resampled_pr_df = boostrapped_pr_curve(normalized_metric_df.label, normalized_metric_df[metric], interpolations_points=interpolations_points, n_samples=n_samples)
non_normalized_metric_bs_pr_aucs, non_normalized_metric_resampled_pr_df = boostrapped_pr_curve(non_normalized_metric_df.label, non_normalized_metric_df[metric], interpolations_points=interpolations_points, n_samples=n_samples)
corresponding_npi_bs_pr_aucs, corresponding_npi_resampled_pr_df = boostrapped_pr_curve(non_normalized_npi_df.label, non_normalized_npi_df[corresponding_npi], interpolations_points=interpolations_points, n_samples=n_samples)

fig, ax = plt.subplots(figsize=(10, 10))

# errorbar=('pi', 95) draws the 2.5th-97.5th percentile band across the
# bootstrapped curves (the bootstrap interval). ('ci', 95) would instead
# bootstrap a CI of the *mean* curve, which shrinks as n_samples grows.
sns.lineplot(x='recall', y='precision', data=normalized_metric_resampled_pr_df, color=all_colors_palette[0], label=f'Normalized, AUC: {np.median(normalized_metric_bs_pr_aucs):.3f}', ax=ax, errorbar=('pi', 95))
sns.lineplot(x='recall', y='precision', data=non_normalized_metric_resampled_pr_df, color=all_colors_palette[1], label=f'Non-normalized, AUC: {np.median(non_normalized_metric_bs_pr_aucs):.3f}', ax=ax, errorbar=('pi', 95))

if plot_NPI:
    sns.lineplot(x='recall', y='precision', data=corresponding_npi_resampled_pr_df, color=all_colors_palette[2], label=f'NPI, AUC: {np.median(corresponding_npi_bs_pr_aucs):.3f}', ax=ax, errorbar=('pi', 95))

ax.set_title(title, fontsize=label_font_size)
ax.set_xlabel('Recall (Sensitivity)', fontsize=label_font_size)
ax.set_ylabel('Precision (PPV)', fontsize=label_font_size)
ax.tick_params(axis='both', labelsize=tick_label_size)

ax.set_yscale('log')

# Honour the plot_legend switch (seaborn adds a legend because labels are set).
legend = ax.get_legend()
if not plot_legend and legend is not None:
    legend.remove()


In [None]:
# Display the per-resample PR AUCs of the normalised metric computed in the plotting cell.
normalized_metric_bs_pr_aucs