In [None]:
import os
import numpy as np
import re
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import roc_utils as ru

In [None]:
# Input / output locations.
# The defaults are the original machine-specific paths; set the environment
# variables PUPILLOMETRY_DATA_DIR / PUPILLOMETRY_OUTPUT_DIR to run the notebook
# on another machine without editing this cell.
data_dir = os.environ.get(
    'PUPILLOMETRY_DATA_DIR',
    '/Users/jk1/Downloads/data_saving/exclude_nan_outcome_False',
)
output_dir = os.environ.get('PUPILLOMETRY_OUTPUT_DIR', '/Users/jk1/Downloads')

In [None]:
# Collect every reassembled-pupillometry timebin CSV found in data_dir.
# Sorted so that the row order of the combined frame does not depend on the
# arbitrary order returned by os.listdir.
data_filenames = sorted(
    f for f in os.listdir(data_dir)
    if f.endswith('.csv') and 'timebin' in f and 'reassembled_pupillometry' in f
)
if not data_filenames:
    raise FileNotFoundError(
        f'No reassembled pupillometry timebin CSV files found in {data_dir}'
    )

# Read each file once, tag it with the metadata encoded in its filename, and
# concatenate a single time at the end (concatenating inside the loop copies
# the accumulated frame on every iteration).
loaded_frames = []
for data_filename in data_filenames:
    # timebin size is encoded in the filename as the pattern: _<x>h_
    timebin_match = re.search(r'_(\d+)h_', data_filename)
    if timebin_match is None:
        raise ValueError(
            f'Cannot extract timebin size (pattern "_<n>h_") from filename: {data_filename}'
        )
    timebin_size = int(timebin_match.group(1))
    # NOTE(review): plain substring test — a name such as "non_normalized"
    # would also be flagged as normalized; confirm against the real filenames.
    data_is_normalized = int(('normalized' in data_filename) or ('normalised' in data_filename))
    # outcome name = first two underscore-separated tokens of the filename
    outcome = '_'.join(data_filename.split('_')[0:2])

    df = pd.read_csv(os.path.join(data_dir, data_filename))
    df['timebin_size'] = timebin_size
    df['normalized'] = data_is_normalized
    df['outcome'] = outcome
    loaded_frames.append(df)

pupillometry_df = pd.concat(loaded_frames, axis=0, ignore_index=True)
# 'Unnamed: 0' is only present when the CSVs were written with their index;
# errors='ignore' keeps this cell working when the column is absent.
pupillometry_df = pupillometry_df.drop(columns=['Unnamed: 0'], errors='ignore')

In [None]:
# Preview the first five rows of the combined frame.
pupillometry_df.head(n=5)

In [None]:
# Selection of the data subset and metric analysed in the ROC cell.
outcome, timebin_size = 'DCI_ischemia', 8  # outcome label and timebin width in hours
metric = 'CV_inter_eye_min_timebin_max'    # column whose discriminative power is evaluated
data_is_normalized = 1                      # 1 = normalized data, 0 = non-normalized

In [None]:
# Colour palette shared by the figures; indexed by position in the plotting cell.
palette_hex_codes = ['#f61067', '#049b9a', '#012D98', '#a76dfe']
all_colors_palette = sns.color_palette(palette_hex_codes, n_colors=len(palette_hex_codes))
# Last expression of the cell: renders the palette swatches.
all_colors_palette

In [None]:
from sklearn.metrics import roc_curve, auc
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

# --- Figure options ---------------------------------------------------------
tick_label_size = 11
label_font_size = 13
n_samples = 10      # number of bootstrap resamples handed to roc_utils
plot_legend = True
plot_NPI = True     # also draw the ROC curve of the corresponding NPI column

title = f"{metric}, {timebin_size}h timebin"


def _select_rows(normalized, column):
    """Return the rows of pupillometry_df for the current outcome / timebin.

    Reads the notebook-level `outcome` and `timebin_size` at call time.

    Parameters
    ----------
    normalized : int
        Value the 'normalized' column must equal (1 or 0).
    column : str
        Column that must not be NaN in the returned rows.

    Returns
    -------
    pandas.DataFrame
        A new frame (not a view that is mutated in place), which avoids the
        SettingWithCopyWarning raised by `dropna(inplace=True)` on a slice.
    """
    mask = (
        (pupillometry_df['outcome'] == outcome)
        & (pupillometry_df['timebin_size'] == timebin_size)
        & (pupillometry_df['normalized'] == normalized)
    )
    return pupillometry_df.loc[mask].dropna(subset=[column])


def _roc_with_auc(labels, scores):
    """Return (fpr, tpr, thresholds, auc) for `scores` against `labels` with positive class 1."""
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    return fpr, tpr, thresholds, auc(fpr, tpr)


def _plot_bootstrap_roc(ax, scores, labels, color):
    """Draw one bootstrapped ROC curve on `ax`, recolour it, and return its line.

    The artists to recolour are located by counting the lines / collections
    present before the roc_utils call, instead of using absolute indices that
    silently break when the number or order of curves changes.
    NOTE(review): assumes ru.plot_roc_bootstrap adds the mean ROC line as its
    first new line and the shaded band as its first new collection, which is
    what the previously hard-coded indices (lines 0/2/4, children 0/3/6)
    implied — confirm against the installed roc_utils version.
    """
    n_lines_before = len(ax.get_lines())
    n_collections_before = len(ax.collections)

    ru.plot_roc_bootstrap(X=scores, y=labels, ax=ax,
                          pos_label=1,
                          n_bootstrap=n_samples,
                          random_state=42, show_ti=False)

    curve = ax.get_lines()[n_lines_before]
    band = ax.collections[n_collections_before]
    curve.set_color(color)
    band.set_facecolor(color)
    band.set_edgecolor(color)
    band.set_alpha(0.1)
    return curve


# --- Data subsets -----------------------------------------------------------
normalized_metric_df = _select_rows(normalized=1, column=metric)
non_normalized_metric_df = _select_rows(normalized=0, column=metric)

# NPI column name = metric name with its first two characters replaced by 'NPI'
corresponding_npi = 'NPI' + metric[2:]
non_normalized_npi_df = _select_rows(normalized=0, column=corresponding_npi)

# --- Point-estimate ROC / AUC -----------------------------------------------
# The metric is negated before scoring (lower metric values rank as more
# positive); the NPI column is used as-is.
norm_metric_fpr, norm_metric_tpr, norm_metric_thresholds, norm_metric_roc_auc = _roc_with_auc(
    normalized_metric_df['label'], -1 * normalized_metric_df[metric]
)
non_norm_metric_fpr, non_norm_metric_tpr, non_norm_metric_thresholds, non_norm_metric_roc_auc = _roc_with_auc(
    non_normalized_metric_df['label'], -1 * non_normalized_metric_df[metric]
)
non_norm_npi_fpr, non_norm_npi_tpr, non_norm_npi_thresholds, non_norm_npi_roc_auc = _roc_with_auc(
    non_normalized_npi_df['label'], non_normalized_npi_df[corresponding_npi]
)

# --- Figure -----------------------------------------------------------------
fig, ax = plt.subplots(figsize=(10, 10))

# Normalized metric
norm_curve = _plot_bootstrap_roc(
    ax, -1 * normalized_metric_df[metric], normalized_metric_df['label'], all_colors_palette[0]
)
# Non-normalized metric
non_norm_curve = _plot_bootstrap_roc(
    ax, -1 * non_normalized_metric_df[metric], non_normalized_metric_df['label'], all_colors_palette[3]
)

# Curves shown in the legend, kept together with their label and band colour.
legend_markers = [norm_curve, non_norm_curve]
legend_labels = [
    f'Normalized (AUC = {norm_metric_roc_auc:.2f})',
    f'Non-normalized (AUC = {non_norm_metric_roc_auc:.2f})',
]
band_colors = [all_colors_palette[0], all_colors_palette[3]]

if plot_NPI:
    # Non-normalized NPI
    npi_curve = _plot_bootstrap_roc(
        ax, non_normalized_npi_df[corresponding_npi], non_normalized_npi_df['label'], all_colors_palette[2]
    )
    legend_markers.append(npi_curve)
    legend_labels.append(f'NPI (AUC = {non_norm_npi_roc_auc:.2f})')
    band_colors.append(all_colors_palette[2])

# Plot chance
ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

if plot_legend:
    # One combined legend entry (tuple of patches) for all shaded bands.
    # NOTE(review): the band is labelled '95% CI' while only n_samples
    # bootstrap resamples are drawn — confirm the interval roc_utils shades.
    sd_marker = tuple(mpatches.Patch(color=band_color, alpha=0.3) for band_color in band_colors)
    sd_labels = '95% CI'
    legend_markers.append(sd_marker)
    legend_labels.append(sd_labels)
    ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
              handler_map={tuple: HandlerTuple(ndivide=None)})
else:
    # Remove the legend only if one was actually created on the axes;
    # ax.get_legend() returns None otherwise.
    existing_legend = ax.get_legend()
    if existing_legend is not None:
        existing_legend.remove()

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
ax.tick_params('x', labelsize=tick_label_size)
ax.tick_params('y', labelsize=tick_label_size)

ax.set_title(title)
# clear any figure-level suptitle; trailing ';' suppresses the Text repr output
fig.suptitle('');

In [None]:
# Save figure
# fig.savefig(os.path.join(output_dir, f'{outcome}_{timebin_size}h_{metric}_roc.png'), dpi=300, bbox_inches='tight')