In [None]:
import numpy as np
import pandas as pd

import sklearn
from sklearn import metrics

import shap
import statsmodels.api as sm

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches

# NOTE(review): imblearn resamplers, lightgbm and MLP are not referenced directly in the
# visible cells - presumably imported so pickled estimators/pipelines can be unpickled; confirm.
from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler, EditedNearestNeighbours

import lightgbm as lgb

import sys
# Make the project-level helper modules in the parent directory importable.
sys.path.append('../')
import importlib
import os
from joblib import dump, load

# Reload helper modules so edits to them are picked up without restarting the kernel.
import model_util
importlib.reload(model_util)
from model_util import get_estimator_probabilities, get_scoring_metrics

import feature_sets
importlib.reload(feature_sets)

import neural_net
importlib.reload(neural_net)
from neural_net import MLP

# NOTE(review): Intel's sklearnex documentation asks for patch_sklearn() to be called
# before scikit-learn is imported; sklearn (and sklearn.metrics) is already imported
# above, so the patch may not take effect everywhere - confirm this is intended.
from sklearnex import patch_sklearn
patch_sklearn()

# Load data

- Use the min-max-scaled X_train and X_test for all models for consistency, even for models that do not require feature scaling
- Use miceforest imputed data for all models (same reasoning)

In [None]:
# Pre-split, miceforest-imputed and min-max-scaled IOP sub-cohort data (train, then test).
# NOTE(review): joblib.load unpickles arbitrary objects - only load trusted files.
X_train_imputed_scaled, y_train = (
    load('../data/imputed/IOPsubcohort_X_train_imputed_scaled.pkl'),
    load('../data/imputed/IOPsubcohort_y_train.pkl'),
)
X_test_imputed_scaled, y_test = (
    load('../data/imputed/IOPsubcohort_X_test_imputed_scaled.pkl'),
    load('../data/imputed/IOPsubcohort_y_test.pkl'),
)

In [None]:
# Short feature-set name -> array of column names in that set (read from feature_sets).
# Combined names presumably concatenate the domain initials
# (O = ophthalmic, D = demographic, S = systemic, L = lifestyle) - confirm in feature_sets.
model_feature_dict = {
    set_name: getattr(feature_sets, attr_name)['feature'].values
    for set_name, attr_name in (
        ('ophthalmic', 'ophthalmic_features'),
        ('demographic', 'demographic_features'),
        ('systemic', 'systemic_features'),
        ('lifestyle', 'lifestyle_features'),

        ('OD', 'OD_features'),
        ('SL', 'SL_features'),
        ('ODSL', 'ODSL_features'),
        # ('ODS', 'ODS_features'),
        ('DSL', 'DSL_features'),  # Primary-care focused

        ('minimal_features_rfecv', 'minimal_features_rfecv'),
    )
}

In [None]:
# Root of the fitted-model tree; models live at <fitted_models_dir>/<feature set>/<algorithm>.pkl
fitted_models_dir = './best_hyperparams_fitted/'

# Display name used in plots/legends -> algorithm name used in the model file names.
algorithm_display_names = dict(
    LR='logistic_regression_sgd',
    SVM='svm',
    KNN='knn',
    RF='randomforest',
    XGBoost='xgboost',
    LightGBM='lightgbm',
    MLP='mlp_50trials',
)

# Model performance

- Compare algorithms per feature set
- Compare best-performing algorithms across feature sets

In [None]:
def compare_estimator_auroc_bar(estimator_dict, y, best_model, ax):
    """Draw one bar per model showing its test AUROC with a 95% CI error bar.

    Parameters
    ----------
    estimator_dict : dict
        Maps a model key to a dict with at least ``'model_name'``, ``'auroc'``,
        ``'auroc_CI_low'`` and ``'auroc_CI_high'``.
    y : array-like
        Not used here; accepted for signature parity with the ROC/PRC helpers.
    best_model : hashable
        Key of ``estimator_dict`` whose AUROC annotation and x tick label are bold.
    ax : matplotlib.axes.Axes
        Target axes; the y range is fixed to [0, 1] with ticks every 0.1.
    """
    # Chance-level reference line at AUROC = 0.5, drawn before the bars.
    ax.axhline(color='k', linestyle='dotted', y=0.5, xmin=0.05, xmax=0.95, label=f'Baseline: 0.5', alpha=0.7)

    records = list(estimator_dict.values())
    bar_names = np.array([rec['model_name'] for rec in records])
    bar_heights = np.array([rec['auroc'] for rec in records])
    err_below = np.array([rec['auroc'] - rec['auroc_CI_low'] for rec in records])
    err_above = np.array([rec['auroc_CI_high'] - rec['auroc'] for rec in records])

    sns.barplot(x=bar_names, y=bar_heights, hue=bar_names, ax=ax, legend=False)

    # AUROC value printed inside each bar near its base; the best model is bold.
    for key, rec in estimator_dict.items():
        auroc_value = rec['auroc']
        text_kwargs = {'fontweight': 'bold'} if key == best_model else {}
        ax.text(rec['model_name'], 0.1, f'{auroc_value:0.3f}', va='center', ha='center', fontsize=8, **text_kwargs)

    for tick_label in ax.get_xticklabels():
        if tick_label.get_text() == estimator_dict[best_model]['model_name']:
            tick_label.set_fontweight('bold')

    # Asymmetric error bars: distance from the point estimate to each CI bound.
    ax.errorbar(
        x=bar_names, y=bar_heights,
        yerr=[err_below, err_above],
        fmt='none', ecolor='k', elinewidth=1, capsize=6,
    )

    ax.set_ylim([0, 1])
    ax.set_yticks(np.linspace(0, 1, 11))

In [None]:
# V2

def compare_estimator_roc(estimator_dict, y, ax):
    """Overlay the ROC curve of every estimator in ``estimator_dict`` on ``ax``.

    Parameters
    ----------
    estimator_dict : dict
        Maps a model name to a dict with at least ``'estimator'``, ``'X'`` and
        ``'roc_label'``. If a precomputed ``'probabilities'`` entry is present
        (``compare_models`` stores one), it is reused instead of re-predicting.
    y : array-like
        True binary labels (positive class = 1), aligned with each ``'X'``.
    ax : matplotlib.axes.Axes
        Target axes. The dotted chance diagonal is plotted first, so it is
        ``ax.lines[0]``; figure-level legends in this notebook rely on that order.
    """
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])
    # x: 1 - specificity (FPR), y: sensitivity (TPR); axis titles are set by the caller.

    ax.plot([0, 1], [0, 1], 'k:', label='Baseline: 0.500', alpha=0.7)

    for vals in estimator_dict.values():
        probabilities = vals.get('probabilities')
        if probabilities is None:
            probabilities = get_estimator_probabilities(vals['estimator'], vals['X'])
        # Keep every threshold so the curve is drawn at full resolution.
        fpr_arr, tpr_arr, _ = metrics.roc_curve(y, probabilities, pos_label=1, drop_intermediate=False)
        ax.plot(fpr_arr, tpr_arr, label=vals['roc_label'], alpha=1)

In [None]:
# V2

def compare_estimator_prc(estimator_dict, y, ax):
    """Overlay the precision-recall curve of every estimator on ``ax``.

    Parameters
    ----------
    estimator_dict : dict
        Maps a model name to a dict with at least ``'estimator'``, ``'X'`` and a
        label: ``'prc_label'`` (AUPRC summary) is used, falling back to
        ``'roc_label'`` if absent. A precomputed ``'probabilities'`` entry
        (stored by ``compare_models``) is reused when present.
    y : array-like
        True binary labels (positive class = 1).
    ax : matplotlib.axes.Axes
        Target axes. A dotted horizontal baseline at the positive-class
        prevalence is drawn first.
    """
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])
    # x: recall (sensitivity), y: precision (PPV); axis titles are set by the caller.

    # Baseline = glaucoma prevalence, i.e. the precision of a no-skill classifier.
    n_pos = (y == 1).sum()
    n_total = len(y)
    baseline_y = n_pos / n_total
    ax.hlines(colors='k', linestyles='dotted', y=baseline_y, xmin=0, xmax=1, label=f'Baseline: {baseline_y:0.3f}', alpha=0.7)

    for vals in estimator_dict.values():
        probabilities = vals.get('probabilities')
        if probabilities is None:
            probabilities = get_estimator_probabilities(vals['estimator'], vals['X'])
        precision, recall, _ = metrics.precision_recall_curve(y, probabilities, pos_label=1, drop_intermediate=True)
        # Label PR curves with the AUPRC summary, not the AUROC one used for ROC plots.
        label = vals['prc_label'] if 'prc_label' in vals else vals['roc_label']
        ax.plot(recall, precision, label=label, alpha=1)

In [None]:
def _format_ci(value, ci_low, ci_high):
    """Format a point estimate and its 95% CI as ``'v.vvv (l.lll - h.hhh)'``."""
    return f'{value:0.3f} ({ci_low:0.3f} - {ci_high:0.3f})'


def compare_models(model_dict, save_dir=None, auroc_bar_ax=None, roc_ax=None, prc_ax=None):
    """Evaluate fitted models on the test set, tabulate metrics and draw comparison plots.

    For every entry of ``model_dict`` the pickled estimator ``<save_dir>.pkl`` and
    its ``<save_dir>_best_params.txt`` are loaded. Test-set metrics are computed on
    the notebook globals ``X_test_imputed_scaled`` / ``y_test``. Bootstrapped 95% CIs
    come from ``model_util.get_bootstrapped_measures`` at the threshold returned by
    ``model_util.get_ideal_threshold``. The best model is the one with the highest
    test AUROC.

    Parameters
    ----------
    model_dict : dict
        Maps a model name to ``{'feature_set': <key of model_feature_dict>,
        'save_dir': <path prefix of the fitted-model files>, 'algorithm': <str>}``.
    save_dir : str, optional
        If given, the directory is created if needed and ``comparison_results.tsv``,
        ``best_model.txt`` and ``best_algorithm.txt`` are written into it.
    auroc_bar_ax, roc_ax, prc_ax : matplotlib.axes.Axes, optional
        Axes for the AUROC bar chart, ROC curves and PR curves. Each plot is
        skipped when its axes is None.

    Raises
    ------
    ValueError
        If ``model_dict`` is empty.
    """
    if not model_dict:
        raise ValueError('model_dict is empty: nothing to compare')

    model_df = pd.DataFrame(columns=['Model']).set_index('Model', drop=True)
    estimator_dict = {}

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for model_name, model_properties in model_dict.items():
        print(model_name)

        feature_set = model_feature_dict[model_properties['feature_set']]
        X = X_test_imputed_scaled[feature_set]

        # A plain variable instead of nested same-type quotes inside the f-string,
        # which only parse on Python >= 3.12 (PEP 701).
        model_path = model_properties['save_dir']
        estimator = load(f'{model_path}.pkl')
        with open(f'{model_path}_best_params.txt', 'r') as txt:
            params = txt.read()

        test_scoring_metrics = get_scoring_metrics(estimator, X, y_test)
        probabilities = get_estimator_probabilities(estimator, X)
        auroc = test_scoring_metrics['roc_auc']
        auprc = test_scoring_metrics['pr_auc']
        sensitivity = test_scoring_metrics['sensitivity_recall']
        specificity = test_scoring_metrics['specificity']
        ppv = test_scoring_metrics['ppv_precision']
        npv = test_scoring_metrics['npv']
        f1 = test_scoring_metrics['f1_score']

        # Threshold-dependent CIs are bootstrapped at the threshold chosen by model_util.
        t = model_util.get_ideal_threshold(y_test, probabilities)
        btstrp_measures = model_util.get_bootstrapped_measures(y_test, probabilities, t).confidence_interval
        # Assumes model_util returns the CIs in this order - TODO confirm against model_util.
        auroc_CI_low, auprc_CI_low, sensitivity_CI_low, specificity_CI_low, ppv_CI_low, npv_CI_low, f1_CI_low = btstrp_measures.low
        auroc_CI_high, auprc_CI_high, sensitivity_CI_high, specificity_CI_high, ppv_CI_high, npv_CI_high, f1_CI_high = btstrp_measures.high

        # Column insertion order determines the column order of comparison_results.tsv.
        model_df.loc[model_name, 'Algorithm'] = model_properties['algorithm']
        model_df.loc[model_name, 'Parameters'] = params
        model_df.loc[model_name, 'AUROC'] = auroc
        model_df.loc[model_name, 'AUROC (95% CI)'] = _format_ci(auroc, auroc_CI_low, auroc_CI_high)
        model_df.loc[model_name, 'Sensitivity (95% CI)'] = _format_ci(sensitivity, sensitivity_CI_low, sensitivity_CI_high)
        model_df.loc[model_name, 'Specificity (95% CI)'] = _format_ci(specificity, specificity_CI_low, specificity_CI_high)
        model_df.loc[model_name, 'PPV (95% CI)'] = _format_ci(ppv, ppv_CI_low, ppv_CI_high)
        model_df.loc[model_name, 'NPV (95% CI)'] = _format_ci(npv, npv_CI_low, npv_CI_high)
        model_df.loc[model_name, 'F1 score (95% CI)'] = _format_ci(f1, f1_CI_low, f1_CI_high)
        model_df.loc[model_name, 'AUPRC (95% CI)'] = _format_ci(auprc, auprc_CI_low, auprc_CI_high)

        estimator_dict[model_name] = {
            'estimator': estimator,
            'model_name': model_name,
            'X': X,
            'probabilities': probabilities,
            'roc_label': f'{model_name}: {_format_ci(auroc, auroc_CI_low, auroc_CI_high)}',
            'prc_label': f'{model_name}: {_format_ci(auprc, auprc_CI_low, auprc_CI_high)}',

            'auroc': auroc,
            'auroc_CI_low': auroc_CI_low,
            'auroc_CI_high': auroc_CI_high,

            'auprc': auprc,
            'auprc_CI_low': auprc_CI_low,
            'auprc_CI_high': auprc_CI_high,

            'f1_score': f1,
            'f1_CI_low': f1_CI_low,
            'f1_CI_high': f1_CI_high,

            'sensitivity_recall': sensitivity,
            'sensitivity_CI_low': sensitivity_CI_low,
            'sensitivity_CI_high': sensitivity_CI_high,

            'specificity': specificity,
            'specificity_CI_low': specificity_CI_low,
            'specificity_CI_high': specificity_CI_high,
        }

    # Best model = highest test AUROC; cast in case the column ended up object dtype.
    best_model = model_df['AUROC'].astype(float).idxmax()
    best_model_algorithm = model_dict[best_model]['algorithm']
    print(f'Best model: {best_model} ({best_model_algorithm})')

    # Each plot is optional: the axes default to None.
    if auroc_bar_ax is not None:
        compare_estimator_auroc_bar(estimator_dict, y_test, best_model=best_model, ax=auroc_bar_ax)
    if roc_ax is not None:
        compare_estimator_roc(estimator_dict, y_test, ax=roc_ax)
    if prc_ax is not None:
        compare_estimator_prc(estimator_dict, y_test, ax=prc_ax)

    display(model_df.style.set_properties(**{
        'white-space': 'pre-wrap',
    }))

    if save_dir:
        model_df.to_csv(f'{save_dir}/comparison_results.tsv', sep='\t')
        # best_model.txt / best_algorithm.txt are read back by the cross-feature-set comparisons.
        with open(f'{save_dir}/best_model.txt', 'w') as txt:
            txt.write(best_model)
        with open(f'{save_dir}/best_algorithm.txt', 'w') as txt:
            txt.write(best_model_algorithm)

### Compare algorithms for each feature set

In [None]:
# Feature sets whose algorithms are compared head-to-head: every entry of
# model_feature_dict except the RFECV minimal set, which is compared separately
# in the "Compare minimal model from RFECV" section. Deriving it (instead of
# re-listing every set) keeps the two dicts from drifting apart.
compare_feature_set_dict = {
    set_name: features
    for set_name, features in model_feature_dict.items()
    if set_name != 'minimal_features_rfecv'
}

In [None]:
# One panel per feature set in each of three grid figures: AUROC bars, ROC curves and
# PR curves, each comparing all algorithms fitted on that feature set.
auroc_bar_fig, auroc_bar_axes = plt.subplots(4, 2, figsize=(35/2.54, 23/2.54), dpi=600, sharex=False, sharey=True)
roc_fig, roc_axes = plt.subplots(2, 4, figsize=(24/2.54, 14/2.54), dpi=600, sharex=False, sharey=False)
prc_fig, prc_axes = plt.subplots(2, 4, figsize=(24/2.54, 14/2.54), dpi=600, sharex=False, sharey=False)

auroc_bar_axes = auroc_bar_axes.flatten()
roc_axes = roc_axes.flatten()
prc_axes = prc_axes.flatten()

title_arr = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K']

# Fail fast with a clear message instead of an IndexError part-way through the loop.
n_panels = min(len(auroc_bar_axes), len(roc_axes), len(prc_axes), len(title_arr))
if len(compare_feature_set_dict) > n_panels:
    raise ValueError(f'{len(compare_feature_set_dict)} feature sets but only {n_panels} subplot panels')

for i, feature_set_name in enumerate(compare_feature_set_dict):
    print(feature_set_name)

    subplot_title = title_arr[i]
    auroc_bar_ax = auroc_bar_axes[i]
    roc_ax = roc_axes[i]
    prc_ax = prc_axes[i]

    # os.path.join avoids the '././...//' paths produced by string concatenation.
    feature_set_dict = {
        display_name: {
            'feature_set': feature_set_name,
            'save_dir': os.path.join(fitted_models_dir, feature_set_name, algorithm),
            'algorithm': algorithm,
        }
        for display_name, algorithm in algorithm_display_names.items()
    }

    compare_models(
        feature_set_dict, f'./evaluation/{feature_set_name}',
        auroc_bar_ax=auroc_bar_ax,
        roc_ax=roc_ax,
        prc_ax=prc_ax,
    )

    # Shared panel styling for the three figures.
    panel_title = f'{subplot_title}. {feature_set_name}'
    for panel_ax in (auroc_bar_ax, roc_ax, prc_ax):
        panel_ax.text(-0.05, 1.03, panel_title, transform=panel_ax.transAxes, fontsize=8, fontweight='bold', va='bottom', ha='left')
        panel_ax.tick_params(axis='x', labelsize=8)
        panel_ax.tick_params(axis='y', labelsize=8)
        panel_ax.spines['top'].set_visible(False)
        panel_ax.spines['right'].set_visible(False)

    auroc_bar_ax.set_yticks(np.linspace(0, 1, 6))
    for panel_ax in (roc_ax, prc_ax):
        panel_ax.set_xticks(np.linspace(0, 1, 6))
        panel_ax.set_yticks(np.linspace(0, 1, 6))

# Figure-level axis titles, legend and export.
# Legend handles for all three figures come from the first ROC panel: its lines are the
# dotted baseline followed by one curve per algorithm. (The PR baseline is drawn with
# hlines, which returns a LineCollection, so prc_axes[0].lines would lack it.)
# NOTE(review): the bar figure reuses these handles on the assumption that the seaborn
# bar colours follow the same colour cycle as the ROC curves - confirm.
algorithm_labels = list(algorithm_display_names.keys())
legend_handles = roc_axes[0].lines

for summary_fig, x_label, y_label, legend_labels, file_name in (
    (auroc_bar_fig, 'Algorithm', 'AUROC',
     ['Baseline (0.500)'] + algorithm_labels, 'compare_feature_sets_auroc_bar.png'),
    (roc_fig, '1 - Specificity (false positive rate)', 'Sensitivity (true positive rate)',
     ['Baseline'] + algorithm_labels, 'compare_feature_sets_roc.png'),
    (prc_fig, 'Recall (sensitivity)', 'Precision (positive predictive value)',
     ['Baseline'] + algorithm_labels, 'compare_feature_sets_prc.png'),
):
    summary_fig.text(0.5, -0.01, x_label, ha='center', va='center', fontsize=8)
    summary_fig.text(-0.01, 0.5, y_label, ha='left', va='center', rotation='vertical', fontsize=8)
    summary_fig.legend(loc='center', bbox_to_anchor=(1.01, 0, 0.1, 1), fontsize=8, labels=legend_labels, handles=legend_handles)
    summary_fig.tight_layout()
    summary_fig.savefig(f'./evaluation/{file_name}', dpi=600, bbox_inches='tight', pad_inches=0)
    summary_fig.show()

### Compare best algorithm across feature sets

In [None]:
# V2
# Compare the best algorithm of each feature set (as recorded by compare_models in
# ./evaluation/<feature set>/best_algorithm.txt) across feature sets.

comparison_dict = {}

for feature_set_name in compare_feature_set_dict:
    with open(f'./evaluation/{feature_set_name}/best_algorithm.txt', 'r') as txt:
        best_algorithm = txt.read().strip()

    print(f'{feature_set_name} --> {best_algorithm}')

    comparison_dict[feature_set_name] = {
        'feature_set': feature_set_name,
        'save_dir': os.path.join(fitted_models_dir, feature_set_name, best_algorithm),
        'algorithm': best_algorithm,
    }

fig, axes = plt.subplots(1, 2, figsize=(14/2.54, 8/2.54), dpi=600, sharex=False, sharey=False)
roc_ax, prc_ax = axes

# Only ROC and PR panels here (no AUROC bar axis).
compare_models(
    comparison_dict,
    './evaluation/compare_feature_sets',
    roc_ax=roc_ax,
    prc_ax=prc_ax,
)

# Only style the two panels created in this cell. `auroc_bar_ax` must not be touched:
# it would still point at the last panel of the earlier (already saved) bar-chart grid.
for panel_ax, panel_letter in ((roc_ax, 'A.'), (prc_ax, 'B.')):
    panel_ax.text(-0.1, 1.1, panel_letter, transform=panel_ax.transAxes, fontsize=10, fontweight='bold', va='top', ha='right')

roc_ax.set_xlabel('1 - Specificity (false positive rate)', fontsize=8)
roc_ax.set_ylabel('Sensitivity (true positive rate)', fontsize=8)
prc_ax.set_xlabel('Recall (sensitivity)', fontsize=8)
prc_ax.set_ylabel('Precision (positive predictive value)', fontsize=8)

for panel_ax in (roc_ax, prc_ax):
    panel_ax.tick_params(axis='x', labelsize=8)
    panel_ax.tick_params(axis='y', labelsize=8)
    panel_ax.set_xticks(np.linspace(0, 1, 6))
    panel_ax.set_yticks(np.linspace(0, 1, 6))
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)

# roc_ax.lines = dotted baseline followed by one curve per feature set, in dict order.
labels = ['Baseline'] + list(compare_feature_set_dict.keys())
fig.legend(loc='center', bbox_to_anchor=(1.05, 0, 0.1, 1), fontsize=8, labels=labels, handles=roc_ax.lines)

fig.tight_layout()

fig.savefig(f'./evaluation/compare_feature_sets_roc_prc.png', dpi=600, bbox_inches='tight', pad_inches=0)
fig.show()

## Compare minimal model from RFECV

In [None]:
# Feature sets compared against the RFECV minimal model. Values are taken from
# model_feature_dict so the two definitions cannot drift apart.
minimal_model_comp_feature_dict = {
    set_name: model_feature_dict[set_name]
    for set_name in ('OD', 'SL', 'ODSL', 'DSL', 'minimal_features_rfecv')
}

In [None]:
# Compare the best algorithm of each selected feature set with the RFECV minimal model.
comparison_dict = {}

for feature_set_name in minimal_model_comp_feature_dict:
    best_algorithm_path = f'./evaluation/{feature_set_name}/best_algorithm.txt'
    # best_algorithm.txt is written by compare_models(save_dir=...). The per-feature-set
    # algorithm comparison above excludes 'minimal_features_rfecv', so on a fresh run
    # that file may only exist from an earlier session - fail with a clear message.
    if not os.path.exists(best_algorithm_path):
        raise FileNotFoundError(
            f'{best_algorithm_path} not found: run compare_models for the '
            f'{feature_set_name!r} feature set with save_dir=./evaluation/{feature_set_name} first.'
        )
    with open(best_algorithm_path, 'r') as txt:
        best_algorithm = txt.read().strip()

    print(f'{feature_set_name} --> {best_algorithm}')

    comparison_dict[feature_set_name] = {
        'feature_set': feature_set_name,
        'save_dir': os.path.join(fitted_models_dir, feature_set_name, best_algorithm),
        'algorithm': best_algorithm,
    }

fig, axes = plt.subplots(1, 2, figsize=(14/2.54, 8/2.54), dpi=600, sharex=False, sharey=False)
roc_ax, prc_ax = axes

# Only ROC and PR panels here (no AUROC bar axis).
compare_models(
    comparison_dict,
    './evaluation/compare_minimal_model',
    roc_ax=roc_ax,
    prc_ax=prc_ax,
)

# Only style the two panels created in this cell. `auroc_bar_ax` must not be touched:
# it would still point at the last panel of the earlier (already saved) bar-chart grid.
for panel_ax, panel_letter in ((roc_ax, 'A.'), (prc_ax, 'B.')):
    panel_ax.text(-0.1, 1.1, panel_letter, transform=panel_ax.transAxes, fontsize=10, fontweight='bold', va='top', ha='right')

roc_ax.set_xlabel('1 - Specificity (false positive rate)', fontsize=8)
roc_ax.set_ylabel('Sensitivity (true positive rate)', fontsize=8)
prc_ax.set_xlabel('Recall (sensitivity)', fontsize=8)
prc_ax.set_ylabel('Precision (positive predictive value)', fontsize=8)

for panel_ax in (roc_ax, prc_ax):
    panel_ax.tick_params(axis='x', labelsize=8)
    panel_ax.tick_params(axis='y', labelsize=8)
    panel_ax.set_xticks(np.linspace(0, 1, 6))
    panel_ax.set_yticks(np.linspace(0, 1, 6))
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)

# roc_ax.lines = dotted baseline followed by one curve per feature set, in dict order.
labels = ['Baseline'] + ['OD', 'SL', 'ODSL', 'DSL', 'minimal']
fig.legend(loc='center', bbox_to_anchor=(1.05, 0, 0.1, 1), fontsize=8, labels=labels, handles=roc_ax.lines)

fig.tight_layout()

fig.savefig(f'./evaluation/compare_minimal_model_roc_prc.png', dpi=600, bbox_inches='tight', pad_inches=0)
fig.show()

### Analyse the number of features selected by RFECV

In [None]:
# Fitted RFECV object; its cv_results_ drive the CV-AUROC vs number-of-features plot below.
# NOTE(review): joblib.load unpickles arbitrary objects - only load trusted files.
rfe_obj = load('./rfecv_fitted.pkl')

In [None]:
# Per-fold RFECV test scores, used to shade the fold-to-fold range in the plot below.
cv_results = rfe_obj.cv_results_
split_test_scores = pd.DataFrame(data=cv_results)
# Fold columns are named split0_test_score, split1_test_score, ...; select them by
# pattern so the min/max envelope adapts to however many CV folds RFECV used
# (previously hard-coded to exactly five folds).
fold_score_columns = split_test_scores.filter(regex=r'^split\d+_test_score$').columns
split_test_scores['min_split_score'] = split_test_scores[fold_score_columns].min(axis=1)
split_test_scores['max_split_score'] = split_test_scores[fold_score_columns].max(axis=1)

In [None]:
# Display the evaluated feature counts in reverse order (largest first if stored ascending).
np.flip(cv_results['n_features'])

In [None]:
# RFECV: mean cross-validated AUROC against the number of features, with the
# min-max range across CV folds shaded.
fig, rfecv_ax = plt.subplots(1, 1, figsize=(15.9, 9), dpi=600)

n_features_arr = cv_results['n_features']
# Axis extent follows the data instead of a hard-coded 95-feature maximum.
max_n_features = int(np.max(n_features_arr))
# NOTE(review): presumably the size of the chosen minimal feature set - TODO confirm.
chosen_n_features = 15

sns.lineplot(x=n_features_arr, y=cv_results['mean_test_score'], ax=rfecv_ax)

rfecv_ax.fill_between(
    n_features_arr,
    split_test_scores['min_split_score'],
    split_test_scores['max_split_score'],
    alpha=0.1,
)

rfecv_ax.set_xticks(np.arange(1, max_n_features, 2))
rfecv_ax.set_xlim(0.5, max_n_features + 0.5)

rfecv_ax.axvline(x=chosen_n_features, color='k', linestyle='dotted', alpha=0.7)
rfecv_ax.spines['top'].set_visible(False)
rfecv_ax.spines['right'].set_visible(False)

rfecv_ax.set_ylabel('Mean cross-validation AUROC')
rfecv_ax.set_xlabel('Number of features')

# NOTE(review): dpi=600/2.54 (~236) differs from the 600 dpi used elsewhere - confirm intended.
fig.savefig('./evaluation/rfecv_minimal_model_n_features.png', dpi=600/2.54, bbox_inches='tight', pad_inches=0)

fig.show()

# Probability analysis

In [None]:
# Final model: the LightGBM model fitted on the RFECV minimal feature set.
# NOTE(review): the algorithm is hard-coded here rather than read from best_algorithm.txt - confirm.
final_model_rfecv = load(f'{fitted_models_dir}/minimal_features_rfecv/lightgbm.pkl')

In [None]:
# Test-set probability of class 1 (column 1 of predict_proba; presumably glaucoma = 1) from the final model.
final_model_test_predictions = final_model_rfecv.predict_proba(X_test_imputed_scaled[feature_sets.minimal_features_rfecv['feature'].values])[:, 1]

In [None]:
# One row per test-set individual: predicted probability, predicted-risk quintile
# (1 = lowest, 5 = highest) and the true label.
predictions_df = pd.DataFrame(final_model_test_predictions, columns=['Prediction'])
predictions_df['Quantile'] = pd.qcut(predictions_df['Prediction'].values, q=5, labels=False) + 1
# Assign labels positionally. predictions_df has a fresh RangeIndex, so assigning a
# pandas Series whose index is not 0..n-1 (e.g. kept from a train/test split) would
# align on index labels and silently yield NaN / mismatched labels.
# np.asarray(...).ravel() is a no-op for a plain 1-D array.
predictions_df['Glaucoma'] = np.asarray(y_test).ravel()

In [None]:
# Distinct risk quintiles present in the predictions (used as the table index).
quantile_vals = np.unique(predictions_df['Quantile'])

# Empty per-quintile summary table; counts and percentages are filled in by the next cell.
quantiles_df = pd.DataFrame(
    columns=['N glaucoma', 'N glaucoma (%)', 'N control', 'N control (%)'],
    index=pd.Index(quantile_vals, name='Quantile'),
)

In [None]:
# Table with proportions
# For every risk quintile, count glaucoma cases and controls and express each as a
# share of the individuals in that quintile (not of the class total).

is_glaucoma_mask = predictions_df['Glaucoma'] == 1
is_control_mask = predictions_df['Glaucoma'] == 0

total_n_glaucoma = int(is_glaucoma_mask.sum())
total_n_control = int(is_control_mask.sum())

for quintile in quantile_vals:
    in_quintile = predictions_df['Quantile'] == quintile
    quintile_size = int(in_quintile.sum())

    for class_mask, count_col, prop_col, pct_col in (
        (is_glaucoma_mask, 'N glaucoma', 'N glaucoma prop', 'N glaucoma (%)'),
        (is_control_mask, 'N control', 'N control prop', 'N control (%)'),
    ):
        class_count = int((class_mask & in_quintile).sum())
        quantiles_df.loc[quintile, count_col] = class_count
        quantiles_df.loc[quintile, prop_col] = class_count / quintile_size
        quantiles_df.loc[quintile, pct_col] = f'{class_count} ({(class_count / quintile_size) * 100:0.2f}%)'

In [None]:
# Odds ratio of glaucoma in each risk quintile relative to quintile 1 (reference),
# from a logistic regression fitted on the individuals in {quintile 1, that quintile}.
for quantile in quantile_vals:
    if quantile == 1:
        quantiles_df.loc[quantile, 'OR'] = 1  # reference category
        continue

    pair_df = predictions_df[predictions_df['Quantile'].isin([1, quantile])]
    # 1 for the quintile of interest, 0 for the reference quintile; the Series name
    # ('Quantile') becomes the parameter name in the fitted model.
    in_target_quintile = (pair_df['Quantile'] == quantile).astype(int)
    design = sm.tools.tools.add_constant(in_target_quintile)

    fit = sm.Logit(endog=pair_df['Glaucoma'], exog=design, missing='drop').fit(disp=False)
    or_ci = np.exp(fit.conf_int())

    quantiles_df.loc[quantile, 'OR'] = np.exp(fit.params)['Quantile']
    quantiles_df.loc[quantile, 'CI_low'] = or_ci[0]['Quantile']
    quantiles_df.loc[quantile, 'CI_high'] = or_ci[1]['Quantile']
    quantiles_df.loc[quantile, 'p'] = fit.pvalues['Quantile']

In [None]:
# Three-panel summary of the final RFECV minimal-feature model on the test set:
#   A. density of predicted probabilities for controls vs glaucoma,
#   B. control / glaucoma composition of each predicted-risk quintile,
#   C. odds ratio (95% CI) of glaucoma in each quintile relative to quintile 1.
quantile_vals_str = ["%.f" % quantile for quantile in quantile_vals]
fig, axes = plt.subplots(1, 3, figsize=(15.9, 7), dpi=600)

# ax1 probs: one filled KDE per true class; the labels feed the figure legend below.
for glaucoma_value, class_label, class_color in ((0, 'Control', 'tab:blue'), (1, 'Glaucoma', 'tab:orange')):
    sns.kdeplot(
        ax=axes[0],
        data=predictions_df[predictions_df['Glaucoma'] == glaucoma_value],
        x='Prediction',
        common_norm=False,
        multiple='layer',
        fill=True,
        bw_adjust=0.6,
        legend=False,
        label=class_label,
        color=class_color,
    )

axes[0].set_xlabel('Predicted probability', fontsize=14)
axes[0].set_ylabel('Density', fontsize=14)
axes[0].set_xticks(np.linspace(0, 1, 6))
# NOTE(review): fixed density ticks tuned to the current results; revisit if they change.
axes[0].set_yticks(np.linspace(0, 40, 6))

# ax2 quintiles: full-height control bars overlaid by the glaucoma proportion, so the
# visible blue part of each bar is the control share.
sns.barplot(x=quantile_vals, y=np.ones(len(quantile_vals)), color='tab:blue', ax=axes[1], width=0.85)
sns.barplot(x=quantile_vals, y=list(quantiles_df['N glaucoma prop']), color='tab:orange', ax=axes[1], width=0.85)

axes[1].set_xlabel('Glaucoma risk quintile', fontsize=14)
axes[1].set_ylabel('Proportion of individuals', fontsize=14)

# Count labels: controls near the top (containers[0]), glaucoma near the bottom (containers[1]).
for bar_container, count_col, prop_col, label_y in (
    (axes[1].containers[0], 'N control', 'N control prop', 0.95),
    (axes[1].containers[1], 'N glaucoma', 'N glaucoma prop', 0.05),
):
    for bar, n, prop in zip(bar_container, quantiles_df[count_col], quantiles_df[prop_col]):
        txt = f'{n}\n({prop * 100:0.2f}%)'
        axes[1].text(bar.get_x() + (bar.get_width() / 2), label_y, txt, fontsize=10, va='center', ha='center', color='white')
        bar.set_alpha(0.8)

axes[1].set_yticks(np.linspace(0, 1., 6))
axes[1].set_ylim(0, 1.02)

# ax3 LR: odds ratios from the per-quintile logistic regressions. The reference
# quintile has no CI (CI_low/CI_high left NaN), so it gets no error bar.
axes[2].errorbar(
    x=quantile_vals_str, y=quantiles_df['OR'],
    yerr=np.array([quantiles_df['OR'] - quantiles_df['CI_low'], quantiles_df['CI_high'] - quantiles_df['OR']]),
    fmt='none',
    ecolor='k',
    elinewidth=1,
    capsize=6,
)
axes[2].plot(quantile_vals_str, quantiles_df['OR'], 'ks')

axes[2].set_xlabel('Glaucoma risk quintile', fontsize=14)
axes[2].set_ylabel('Odds ratio', fontsize=14)
# NOTE(review): fixed OR ticks tuned to the current results; revisit if they change.
axes[2].set_yticks(np.linspace(0, 120, 7))

# Shared styling for all three panels (panel C now gets the same tick-label size as A and B).
for panel_ax, panel_letter in zip(axes, ('A.', 'B.', 'C.')):
    panel_ax.text(-0.03, 1.03, panel_letter, transform=panel_ax.transAxes, fontsize=16, fontweight='bold', va='top', ha='right')
    panel_ax.tick_params(axis='x', labelsize=14)
    panel_ax.tick_params(axis='y', labelsize=14)
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)

# fig.legend() collects labelled artists from all panels (the 'Control'/'Glaucoma' KDEs).
fig.legend(loc='center left', bbox_to_anchor=(0.03, 0, 1, -0.05), fontsize=14, ncols=2)

fig.tight_layout()

# NOTE(review): 'robability' looks like a typo for 'probability'; the filename is kept
# unchanged so existing references to the output still resolve.
fig.savefig(f'evaluation/rfe_cv_robability_analysis.png', dpi=600, bbox_inches='tight', pad_inches=0)

plt.show()

In [None]:
# Persist the per-quintile counts, proportions and odds ratios as a TSV.
quantiles_df.to_csv('./evaluation/quantile_analysis.tsv', sep='\t')

In [None]:
# Display the per-quintile summary table.
quantiles_df

# SHAP for RFECV model

In [None]:
# Column names of the RFECV minimal feature set (the same array used for the test predictions above).
minimal_features_rfecv = feature_sets.minimal_features_rfecv['feature'].values

In [None]:
# SHAP values of the final model, computed on the scaled training rows restricted to
# the minimal feature set (column order = minimal_features_rfecv).
explainer = shap.Explainer(final_model_rfecv)
shap_values = explainer(X_train_imputed_scaled.loc[:, minimal_features_rfecv])

In [None]:
# Beeswarm of per-sample SHAP values (up to 21 features), ordered by mean |SHAP|.
# NOTE(review): dpi=600/2.54 (~236) differs from the 600 dpi used for other figures - confirm intended.
ax = shap.plots.beeswarm(shap_values, max_display=21, order=shap_values.abs.mean(0), color=sns.color_palette('crest', as_cmap=True), plot_size=(15, 12), show=False)
plt.savefig(fname='./interpretation/rfecv_shap_beeswarm.png', dpi=600/2.54, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Rank predictors by global importance (mean |SHAP| over training rows) and save the table.
shap_df = (
    pd.DataFrame({
        'Predictor': feature_sets.minimal_features_rfecv['feature'].values,
        'Mean absolute SHAP value': shap_values.abs.mean(0).values,
    })
    .sort_values(by='Mean absolute SHAP value', ascending=False)
    .reset_index(drop=True)
)
shap_df.to_csv('./interpretation/rfecv_shap_values.tsv', sep='\t')
shap_df.head(20)

In [None]:
# Scaler used to produce the *_imputed_scaled data; used below to map SHAP feature values back to original units.
# NOTE(review): joblib.load unpickles arbitrary objects - only load trusted files.
min_max_scaler = load('../data/imputed/min_max_scaler.pkl')

In [None]:
# Revert feature scaling so the SHAP dependence plots show values in original units.
# The MinMaxScaler was fitted on all columns of X_train_imputed_scaled, so invert the
# full scaled training matrix (min-max inversion is column-wise) and then pick the
# minimal-feature columns *by name*. This
#   (a) keeps the columns in minimal_features_rfecv order, which is the column order of
#       shap_values.data (the explainer was called on X_train_imputed_scaled[minimal
#       features]); selecting the non-NaN columns positionally would instead follow
#       X_train_imputed_scaled's column order and could mis-pair features and values;
#   (b) makes the cell idempotent: re-running it never re-inverts already-unscaled data.
train_unscaled = pd.DataFrame(
    min_max_scaler.inverse_transform(X_train_imputed_scaled),
    columns=X_train_imputed_scaled.columns,
    index=X_train_imputed_scaled.index,
)
shap_values.data = train_unscaled[minimal_features_rfecv].to_numpy()

In [None]:
# Predictor names in descending mean |SHAP| order (shap_df is sorted that way); sets the panel order below.
ordered_features_by_shap = shap_df['Predictor'].values

In [None]:
# Data column name -> axis label with units for the SHAP dependence plots.
# Keys must match the column names in the data exactly, so the misspelled
# 'Age at initial assesement' key is intentional; only the display label is corrected.
predictor_names_fixed = {
    'IOPg pre-treatment': 'IOPg pre-treatment (mmHg)',
    'IOPg pre-treatment inter-eye difference': 'IOPg pre-treatment inter-eye difference (mmHg)',
    'Corneal hysteresis': 'Corneal hysteresis (mmHg)',
    'Corneal hysteresis inter-eye difference': 'Corneal hysteresis inter-eye difference (mmHg)',
    'Age at initial assesement': 'Age at initial assessment (years)',
    'Spherical equivalent': 'Spherical equivalent (D)',
    'Ethnicity': 'Ethnicity (white)',
    'Diastolic blood pressure': 'Diastolic blood pressure (mmHg)',
    'HbA1c': 'HbA1c (mmol/mol)',
    'Total cholesterol': 'Total cholesterol (mmol/L)',
    'eGFR serum creatinine': 'eGFR (mL/min/1.73m2)',
    'Urinary sodium:creatinine ratio': 'Urinary sodium:creatinine ratio (mmol:mmol)',
}

In [None]:
# Grid of SHAP dependence (scatter) plots for the minimal RFECV model, one panel per
# predictor in descending mean |SHAP| order. zip() stops at the 15 axes of the 5x3
# grid, so only the top-15 predictors are drawn. Every panel is coloured by
# 'IOPg pre-treatment', and a single shared colour bar is added at the end.
# NOTE(review): dpi=600/2.54 (~236 dpi) divides a dots-per-inch value by the cm->inch
# factor; confirm this is the intended output resolution.
shap_fig, shap_axes = plt.subplots(5, 3, figsize=(24, 14), dpi=600/2.54, sharex=False, sharey=False)

# Axes created by shap.plots.scatter that are removed after the loop.
del_axes = []

for i, feature, ax in zip(range(len(ordered_features_by_shap)), ordered_features_by_shap, shap_axes.flatten()):
    shap_feature = shap_values[:, feature]
    shap.plots.scatter(
        shap_feature,
        # Alternative colouring tried previously: shap_values[:, 'Age at initial assesement']
        color=shap_values[:, 'IOPg pre-treatment'],
        ax=ax,
        show=False,
        cmap=sns.color_palette('crest', as_cmap=True),
        dot_size=10,
        alpha=1,
        x_jitter=0.1,
        # x-axis bounds near the 0.5th / 99.5th percentiles of this feature
        # (via shap's Explanation.percentile).
        xmin=shap_feature.percentile(0.5),
        xmax=shap_feature.percentile(99.5),
        # shap option: draw a light histogram of the feature values along the x-axis.
        hist=True,
    )

    ax.set_ylabel('SHAP value')

    # Use the unit-bearing label where one is defined.
    if feature in predictor_names_fixed:
        ax.set_xlabel(predictor_names_fixed[feature])

    # Low-cardinality (categorical/binary) predictors: one tick per observed value.
    unique_vals = np.unique(shap_feature.data)
    if len(unique_vals) <= 5:
        ax.set_xticks(unique_vals)

    # The indexing below assumes each scatter call appends exactly two axes after the
    # 15 grid axes, so index 15 + 2*i is the first axis added in this iteration.
    # NOTE(review): it also assumes that first axis is the per-panel colour bar. The
    # ordering is a shap implementation detail, so verify it against the installed
    # shap version. Deletion is deferred until after the loop because deleting here
    # would shift the indices of later axes.
    del_axes.append(shap_fig.axes[15+i*2])

for ax in del_axes:
    shap_fig.delaxes(ax)

shap_fig.tight_layout()

# Leave a strip on the right and place one shared colour bar there (figure coordinates).
shap_fig.subplots_adjust(right=0.97)
cbar_ax = shap_fig.add_axes([0.98, 0.375, 0.01, 0.25])
# Reuse the colour mapping of the first panel's scatter points. Every panel is coloured
# by the same feature with the same cmap.
shap_fig.colorbar(
    shap_fig.axes[0].collections[0],
    cax = cbar_ax,
    label = 'IOPg pre-treatment',
    location = 'right',
    fraction = 1,
    aspect = 20,
    orientation = 'vertical',
)

# plt.savefig writes the *current* figure, which is expected to be shap_fig here.
plt.savefig(fname='./interpretation/rfecv_shap_scatter.png', dpi=600/2.54, bbox_inches='tight', pad_inches=0)

shap_fig.show()

# SHAP values for the ODSL LightGBM model

In [None]:
# Fitted LightGBM model, presumably trained on the ODSL feature set (per the path).
# joblib.load unpickles the file, so only load trusted artefacts.
odsl_lgbm_estimator = load('./best_hyperparams_fitted/ODSL/lightgbm.pkl')

In [None]:
# Explain the ODSL LightGBM model on the training set. shap.Explainer selects an
# algorithm from the model type.
# NOTE(review): this passes every column of X_train_imputed_scaled, whereas other cells
# here select X_train_imputed_scaled[minimal_features_rfecv]. This is only correct if the
# ODSL model was fitted on exactly these columns in this order; confirm against
# model_feature_dict['ODSL'].
explainer_odsl = shap.Explainer(odsl_lgbm_estimator)
shap_values_odsl = explainer_odsl(X_train_imputed_scaled)

In [None]:
# Beeswarm summary of SHAP values for the ODSL LightGBM model: up to 16 predictors,
# ordered by mean absolute SHAP value, coloured with seaborn's 'crest' colormap.
# NOTE(review): dpi=600/2.54 (~236 dpi) divides a dots-per-inch value by the cm->inch
# factor; confirm this is the intended output resolution.
ax = shap.plots.beeswarm(
    shap_values_odsl,
    max_display=16,
    order=shap_values_odsl.abs.mean(0),
    color=sns.color_palette('crest', as_cmap=True),
    plot_size=(15, 12),
    show=False,
)
plt.savefig(
    fname='./interpretation/odsl_lgbm_shap_beeswarm.png',
    dpi=600/2.54,
    bbox_inches='tight',
    pad_inches=0,
)
plt.show()

# Class imbalance models

- SMOTE, random oversampling, random undersampling
- Refit each with its best Optuna hyperparameters, save it, then compare against the baseline minimal model

In [None]:
# SMOTE: synthesise extra minority-class samples, then refit LightGBM on the minimal
# (RFECV) features using the best trial of the matching Optuna study.
# NOTE(review): sampling_strategy here is 0.333, while the random over/under-sampling
# cells use 0.33. Confirm each matches the value used in its Optuna study.

algorithm = 'imbalanced_lightgbm_SMOTE'
study = load(f'./optuna_results/{algorithm}/minimal_features_rfecv/optuna_study_minimal_features_rfecv.pkl')

# user_attrs['all_params'] holds the LGBMClassifier keyword arguments; `params` holds
# the values Optuna sampled for the trial (including k_neighbors).
best_params_estimator = study.best_trial.user_attrs['all_params']
best_params_all = study.best_trial.params
best_params_all_str = '\n'.join(f'{key}: {val}' for key, val in best_params_all.items())
k_neighbors = best_params_all['k_neighbors']

resample_model = SMOTE(
    sampling_strategy=0.333,
    k_neighbors=k_neighbors,
    random_state=2024,
)
X_train_resample, y_train_resample = resample_model.fit_resample(
    X_train_imputed_scaled[minimal_features_rfecv], y_train
)

estimator = lgb.LGBMClassifier(**best_params_estimator)
estimator.fit(X_train_resample[minimal_features_rfecv], y_train_resample)

# Persist the fitted model and a plain-text record of the chosen hyperparameters.
dump(estimator, f'./best_hyperparams_fitted/minimal_features_rfecv/{algorithm}.pkl')
with open(f'./best_hyperparams_fitted/minimal_features_rfecv/{algorithm}_best_params.txt', 'w+') as params_file:
    params_file.write(best_params_all_str)

In [None]:
# Random oversampling: duplicate minority-class rows, then refit LightGBM on the
# minimal (RFECV) features using the best trial of the matching Optuna study.

algorithm = 'imbalanced_lightgbm_RandomOverSampler'
study = load(f'./optuna_results/{algorithm}/minimal_features_rfecv/optuna_study_minimal_features_rfecv.pkl')

# user_attrs['all_params'] holds the LGBMClassifier keyword arguments; `params` holds
# the values Optuna sampled for the trial.
best_params_estimator = study.best_trial.user_attrs['all_params']
best_params_all = study.best_trial.params
best_params_all_str = '\n'.join(f'{key}: {val}' for key, val in best_params_all.items())

resample_model = RandomOverSampler(
    sampling_strategy=0.33,
    random_state=2024,
)
X_train_resample, y_train_resample = resample_model.fit_resample(
    X_train_imputed_scaled[minimal_features_rfecv], y_train
)

estimator = lgb.LGBMClassifier(**best_params_estimator)
estimator.fit(X_train_resample[minimal_features_rfecv], y_train_resample)

# Persist the fitted model and a plain-text record of the chosen hyperparameters.
dump(estimator, f'./best_hyperparams_fitted/minimal_features_rfecv/{algorithm}.pkl')
with open(f'./best_hyperparams_fitted/minimal_features_rfecv/{algorithm}_best_params.txt', 'w+') as params_file:
    params_file.write(best_params_all_str)

In [None]:
# Random undersampling: drop majority-class rows, then refit LightGBM on the minimal
# (RFECV) features using the best trial of the matching Optuna study.

algorithm = 'imbalanced_lightgbm_RandomUnderSampler'
study = load(f'./optuna_results/{algorithm}/minimal_features_rfecv/optuna_study_minimal_features_rfecv.pkl')

# user_attrs['all_params'] holds the LGBMClassifier keyword arguments; `params` holds
# the values Optuna sampled for the trial.
best_params_estimator = study.best_trial.user_attrs['all_params']
best_params_all = study.best_trial.params
best_params_all_str = '\n'.join(f'{key}: {val}' for key, val in best_params_all.items())

resample_model = RandomUnderSampler(
    sampling_strategy=0.33,
    random_state=2024,
)
X_train_resample, y_train_resample = resample_model.fit_resample(
    X_train_imputed_scaled[minimal_features_rfecv], y_train
)

estimator = lgb.LGBMClassifier(**best_params_estimator)
estimator.fit(X_train_resample[minimal_features_rfecv], y_train_resample)

# Persist the fitted model and a plain-text record of the chosen hyperparameters.
dump(estimator, f'./best_hyperparams_fitted/minimal_features_rfecv/{algorithm}.pkl')
with open(f'./best_hyperparams_fitted/minimal_features_rfecv/{algorithm}_best_params.txt', 'w+') as params_file:
    params_file.write(best_params_all_str)

In [None]:
# Compare: the baseline minimal LightGBM model against the three class-imbalance
# variants. All entries share the feature set and algorithm and differ only in the
# path of the saved model.

compare_imbalance_dict = {
    label: {
        'feature_set': 'minimal_features_rfecv',
        'save_dir': save_dir,
        'algorithm': 'lightgbm',
    }
    for label, save_dir in (
        ('Minimal model', './best_hyperparams_fitted/minimal_features_rfecv/lightgbm'),
        ('SMOTE', './best_hyperparams_fitted/minimal_features_rfecv/imbalanced_lightgbm_SMOTE'),
        ('Random over-sampling', './best_hyperparams_fitted/minimal_features_rfecv/imbalanced_lightgbm_RandomOverSampler'),
        ('Random under-sampling', './best_hyperparams_fitted/minimal_features_rfecv/imbalanced_lightgbm_RandomUnderSampler'),
    )
}

In [None]:
# Side-by-side ROC (panel A) and precision-recall (panel B) curves for the models in
# compare_imbalance_dict. figsize is given in cm and converted to inches (1 in = 2.54 cm).
fig, axes = plt.subplots(1, 2, figsize=(14/2.54, 8/2.54), dpi=600, sharex=False, sharey=False)
roc_ax, prc_ax = axes

compare_models(
    compare_imbalance_dict,
    './evaluation/compare_class_imbalance',
    roc_ax=roc_ax,
    prc_ax=prc_ax,
)

# Both panels share identical styling and differ only in letter and axis labels.
panel_specs = (
    (roc_ax, 'A.', '1 - Specificity (false positive rate)', 'Sensitivity (true positive rate)'),
    (prc_ax, 'B.', 'Recall (sensitivity)', 'Precision (positive predictive value)'),
)
unit_ticks = np.linspace(0, 1, 6)
for panel_ax, panel_letter, x_label, y_label in panel_specs:
    panel_ax.text(-0.1, 1.1, panel_letter, transform=panel_ax.transAxes, fontsize=10, fontweight='bold', va='top', ha='right')
    panel_ax.set_xlabel(x_label, fontsize=8)
    panel_ax.set_ylabel(y_label, fontsize=8)
    panel_ax.tick_params(axis='both', labelsize=8)
    panel_ax.set_xticks(unit_ticks)
    panel_ax.set_yticks(unit_ticks)
    for side in ('top', 'right'):
        panel_ax.spines[side].set_visible(False)

# NOTE(review): pairing labels with roc_ax.lines assumes compare_models draws a baseline
# line first and then one line per model in dict order. Verify this against compare_models.
labels = ['Baseline', *compare_imbalance_dict]
fig.legend(handles=roc_ax.lines, labels=labels, loc='center', bbox_to_anchor=(1.05, 0, 0.1, 1), fontsize=8)

fig.tight_layout()

fig.savefig('./evaluation/compare_class_imbalance_roc_prc.png', dpi=600, bbox_inches='tight', pad_inches=0)
fig.show()