In [None]:
import pandas as pd
from baseline_scores import hiat_score, thrive_score, span100_score
from sklearn import metrics
import tensorflow as tf
import seaborn as sns
import matplotlib.pyplot as plt
import os
import numpy as np

In [None]:
# Location of the post-hoc-modified Geneva stroke registry export. The default is the
# original author's local path; set STROKE_REGISTRY_PATH to run the notebook elsewhere.
stroke_registry_path = os.environ.get(
    'STROKE_REGISTRY_PATH',
    '/Users/jk1/OneDrive - unige.ch/stroke_research/geneva_stroke_unit_dataset/data/stroke_registry/post_hoc_modified/stroke_registry_post_hoc_modified.xlsx',
)

In [None]:
# Load the raw registry spreadsheet (first sheet) into a DataFrame.
data_df = pd.read_excel(io=stroke_registry_path)

In [None]:
# Summarise the loaded registry instead of rendering it: the frame holds patient-level
# records, and displaying it in full would embed that data in the saved notebook output.
data_df.shape

In [None]:
# Restrict the cohort to patients with a recorded 3-month mRS (source of the outcome labels).
data_df = data_df.dropna(subset=['3M mRS'])

In [None]:
# Binarise the 3-month mRS into good-outcome labels at two cut-offs:
# 1 when mRS <= cutoff, 0 otherwise, NaN when the score is missing.
mrs_3m = data_df['3M mRS']
for cutoff in (1, 2):
    data_df[f'3M mRS 0-{cutoff}'] = np.where(mrs_3m.isna(), np.nan, (mrs_3m <= cutoff).astype(int))

In [None]:
# mRS forwarding baseline: the "model" simply forwards the premorbid (pre-stroke) mRS and
# predicts a good outcome when that score already lies within the good-outcome range.
# Its "probability" is therefore a hard 0/1 value, identical to the prediction.
# A missing pre-stroke mRS yields NaN instead of a silent "poor outcome" (NaN <= k is False),
# so evaluate_method drops those patients exactly as it does for the other scores.
prestroke_mrs = data_df['Prestroke disability (Rankin)']
for cutoff in (1, 2):
    forwarded = (prestroke_mrs <= cutoff).astype(float).where(prestroke_mrs.notna())
    data_df[f'mrs0{cutoff}_forwarding good outcome pred'] = forwarded
    data_df[f'mrs0{cutoff}_forwarding_prob'] = forwarded

In [None]:
# HIAT score, computed row by row from age, admission NIH score and first glucose value.
hiat_inputs = ['Age (calc.)', 'NIH on admission', '1st glucose']
data_df['HIAT_prob'] = data_df.apply(
    lambda row: hiat_score(*[row[col] for col in hiat_inputs]),
    axis=1,
)

# HIAT's good outcome is defined as mRS < 4 at discharge; a score above 0.5
# counts as a predicted good outcome.
data_df['HIAT good outcome pred'] = data_df['HIAT_prob'].gt(0.5)

In [None]:
def _span100_row_prob(row):
    """Return span100_score for one registry row (age and admission NIH score)."""
    age, nih_admission = row['Age (calc.)'], row['NIH on admission']
    return span100_score(age, nih_admission)


data_df['span100_prob'] = data_df.apply(_span100_row_prob, axis=1)
# A score above 0.5 counts as a predicted good outcome.
data_df['span100 good outcome pred'] = data_df['span100_prob'].gt(0.5)

In [None]:
# THRIVE score from age, admission NIH score and three medical-history columns,
# passed positionally in this order.
thrive_inputs = [
    'Age (calc.)',
    'NIH on admission',
    'MedHist Hypertension',
    'MedHist Diabetes',
    'MedHist Atrial Fibr.',
]
data_df['THRIVE_prob'] = data_df.apply(
    lambda row: thrive_score(*(row[col] for col in thrive_inputs)),
    axis=1,
)

# A score above 0.5 counts as a predicted good outcome.
data_df['THRIVE good outcome pred'] = data_df['THRIVE_prob'].gt(0.5)

In [None]:
from prediction.mrs_outcome_prediction.baseline_models.baseline_scores import thriveC_score


def _thrivec_row_prob(row):
    """Return thriveC_score for one registry row, using the same inputs as THRIVE."""
    return thriveC_score(
        row['Age (calc.)'],
        row['NIH on admission'],
        row['MedHist Hypertension'],
        row['MedHist Diabetes'],
        row['MedHist Atrial Fibr.'],
    )


data_df['THRIVEC_prob'] = data_df.apply(_thrivec_row_prob, axis=1)
# A score above 0.5 counts as a predicted good outcome.
data_df['THRIVEC good outcome pred'] = data_df['THRIVEC_prob'].gt(0.5)

In [None]:
# Side-by-side view of the THRIVE / THRIVE-C inputs, scores and thresholded predictions.
data_df.loc[:, [
    'Age (calc.)', 'NIH on admission',
    'MedHist Hypertension', 'MedHist Diabetes', 'MedHist Atrial Fibr.',
    'THRIVE_prob', 'THRIVE good outcome pred',
    'THRIVEC_prob', 'THRIVEC good outcome pred',
]]

In [None]:
# Module-level Keras metric objects. NOTE(review): Keras metrics are stateful — every
# call updates a running total — so reusing one object for several evaluations mixes
# their results unless reset_state() is called in between.
precision, recall, auc, accuracy = (
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
    tf.keras.metrics.AUC(),
    tf.keras.metrics.Accuracy(),
)

def plot_roc_curve(fpr, tpr, name: str):
    """Plot a ROC curve against the chance diagonal and display it.

    Args:
        fpr: false positive rates, e.g. as returned by sklearn.metrics.roc_curve.
        tpr: true positive rates matching ``fpr`` element-wise.
        name: method name used in the legend and title.
    """
    _, ax = plt.subplots()
    # Draw the points exactly as given. sns.lineplot's default estimator averages all
    # tpr values sharing one fpr (the vertical steps of a ROC curve) and adds a CI band,
    # which distorts the curve.
    ax.plot(fpr, tpr, color='orange', label=name)
    ax.plot([0, 1], [0, 1], color='darkblue', linestyle='--')  # chance level
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'Receiver Operating Characteristic (ROC) Curve of {name}')
    ax.legend()
    plt.show()

In [None]:
# Column layout of one evaluation row: identification, scalar metrics, then ROC arrays.
result_columns = (
    ['ground truth', 'method']
    + ['auc', 'accuracy', 'f1', 'precision', 'recall']
    + ['fpr', 'tpr', 'roc_thresholds']
)

In [None]:
def evaluate_method(method_name: str, data_df, ground_truth: str = '3M mRS 0-1', plot: bool = True):
    """Evaluate one baseline score against a binary outcome column.

    ``data_df`` must contain ``f'{method_name}_prob'`` (score; higher = more likely good
    outcome), ``f'{method_name} good outcome pred'`` (thresholded 0/1 or bool prediction)
    and ``ground_truth`` (1 = good outcome, 0 = poor outcome). Rows where the score or the
    ground truth is missing are excluded.

    Args:
        method_name: prefix of the score columns, e.g. 'THRIVE'.
        data_df: frame holding the columns above.
        ground_truth: name of the binary outcome column.
        plot: if True (default), draw the ROC curve with plot_roc_curve.

    Returns:
        One-row DataFrame with columns ``result_columns``. It holds AUC, accuracy, F1,
        precision and recall, plus the ROC curve arrays (fpr, tpr, thresholds). AUC is
        NaN when only one outcome class is present.

    Raises:
        ValueError: if no row has both a score and a ground-truth value.
    """
    prob_col = f'{method_name}_prob'
    pred_col = f'{method_name} good outcome pred'

    evaluable = data_df[data_df[prob_col].notnull() & data_df[ground_truth].notnull()]
    if evaluable.empty:
        raise ValueError(f'No rows with both {prob_col!r} and {ground_truth!r} available.')

    y_true = evaluable[ground_truth].astype(int)
    y_pred = evaluable[pred_col].astype(int)
    y_score = evaluable[prob_col].astype(float)

    # Stateless sklearn metrics, always called as (y_true, y_pred/y_score). Keras metric
    # objects keep running state across calls (pooling different methods together), and
    # expect (y_true, y_pred) rather than the prediction first.
    method_auc = metrics.roc_auc_score(y_true, y_score) if y_true.nunique() == 2 else np.nan
    method_acc = metrics.accuracy_score(y_true, y_pred)
    method_f1 = metrics.f1_score(y_true, y_pred)
    method_precision = metrics.precision_score(y_true, y_pred)
    method_recall = metrics.recall_score(y_true, y_pred)
    method_fpr, method_tpr, method_thresholds = metrics.roc_curve(y_true, y_score)

    method_df = pd.DataFrame(
        [[ground_truth, method_name, method_auc, method_acc, method_f1, method_precision, method_recall,
          method_fpr, method_tpr, method_thresholds]],
        columns=result_columns)

    if plot:
        plot_roc_curve(method_fpr, method_tpr, method_name)

    return method_df

In [None]:
# Evaluate every baseline against the 3-month mRS 0-2 outcome (one ROC plot per method).
mrs02_methods = ['THRIVE', 'THRIVEC', 'HIAT', 'span100', 'mrs02_forwarding']
mrs02_result_df = pd.concat(
    [evaluate_method(method, data_df, ground_truth='3M mRS 0-2') for method in mrs02_methods],
    ignore_index=True,  # each per-method frame has index 0; give rows distinct labels
)
mrs02_result_df

In [None]:
# Evaluate every baseline against the 3-month mRS 0-1 outcome (one ROC plot per method).
mrs01_methods = ['THRIVE', 'THRIVEC', 'HIAT', 'span100', 'mrs01_forwarding']
mrs01_result_df = pd.concat(
    [evaluate_method(method, data_df, ground_truth='3M mRS 0-1') for method in mrs01_methods],
    ignore_index=True,  # each per-method frame has index 0; give rows distinct labels
)
mrs01_result_df

In [None]:
# Destination of the exported result tables; override with RESULTS_OUTPUT_DIR.
output_dir = os.environ.get('RESULTS_OUTPUT_DIR', '/Users/jk1/Downloads')
# Export is disabled by default; set to True to write both result tables as CSV.
save_results = False
if save_results:
    mrs02_result_df.to_csv(os.path.join(output_dir, 'mrs02_clinical_scores_results.csv'))
    mrs01_result_df.to_csv(os.path.join(output_dir, 'mrs01_clinical_scores_results.csv'))