In [None]:
import pandas as pd
from baseline_scores import hiat_score, thrive_score, span100_score
import os
import numpy as np
import matplotlib.pyplot as plt
from prediction.outcome_prediction.baseline_models.evaluation_helper_functions import evaluate_method

In [None]:
# Input file locations.
# Each path can be overridden with an environment variable so the notebook can run on another
# machine without editing this cell; when the variable is unset the original local path is used.
stroke_registry_path = os.environ.get(
    'OPSUM_STROKE_REGISTRY_PATH',
    '/Users/jk1/OneDrive - unige.ch/stroke_research/geneva_stroke_unit_dataset/data/stroke_registry/post_hoc_modified/stroke_registry_post_hoc_modified.xlsx')
preprocessed_features_path = os.environ.get(
    'OPSUM_PREPROCESSED_FEATURES_PATH',
    '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv')
preprocessed_outcomes_path = os.environ.get(
    'OPSUM_PREPROCESSED_OUTCOMES_PATH',
    '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv')

In [None]:
# Directory that receives the exported result tables and pickles.
# Overridable through OPSUM_OUTPUT_DIR; falls back to the original local path when unset.
output_dir = os.environ.get('OPSUM_OUTPUT_DIR', '/Users/jk1/Downloads')

In [None]:
# Outcome column of the preprocessed outcomes table; the patient selection below keeps only
# admissions whose value in this column is not null.
outcome = '3M mRS'

In [None]:
from preprocessing.geneva_stroke_unit_preprocessing.utils import create_registry_case_identification_column

# Read the stroke registry and attach the identifier column that the later selection and
# merge steps use as key.
data_df = pd.read_excel(stroke_registry_path)
data_df = data_df.assign(
    case_admission_id=create_registry_case_identification_column(data_df))

In [None]:
outcomes_df = pd.read_csv(preprocessed_outcomes_path)
features_df = pd.read_csv(preprocessed_features_path)

# Keep the admissions that appear in the preprocessed features AND have a non-null value
# for the selected outcome.
has_features = outcomes_df['case_admission_id'].isin(features_df['case_admission_id'].unique())
has_outcome = outcomes_df[outcome].notnull()
patient_selection = outcomes_df.loc[has_features & has_outcome, 'case_admission_id'].unique()

In [None]:
# Restrict the registry to the selected admissions.
in_selection = data_df['case_admission_id'].isin(patient_selection)
data_df = data_df.loc[in_selection]

In [None]:
# Attach the binary 3-month outcome columns to every registry row (left join on the
# admission identifier, so registry rows without a match are kept).
binary_outcome_columns = ['case_admission_id', '3M mRS 0-1', '3M mRS 0-2', '3M Death']
data_df = data_df.merge(outcomes_df[binary_outcome_columns], on='case_admission_id', how='left')

In [None]:
# data_df['3M mRS 0-1'] = np.where(data_df['3M mRS'].isna(), np.nan, np.where(data_df['3M mRS'] <= 1, 1, 0))
# data_df['3M mRS 0-2'] = np.where(data_df['3M mRS'].isna(), np.nan, np.where(data_df['3M mRS'] <= 2, 1, 0))

In [None]:
# mRS forwarding baseline: the "model" simply carries the premorbid mRS forward, i.e. a good
# outcome is predicted whenever the prestroke Rankin score is already within the cutoff.
premorbid_mrs = data_df['Prestroke disability (Rankin)']
for forwarding_name, mrs_cutoff in (('mrs01_forwarding', 1), ('mrs02_forwarding', 2)):
    within_cutoff = premorbid_mrs <= mrs_cutoff
    # Binary prediction; as with a plain comparison, a missing premorbid mRS yields False here.
    data_df[forwarding_name + ' good outcome pred'] = within_cutoff
    # A comparison against NaN evaluates to False, so storing the raw boolean as the score would
    # silently turn a missing premorbid mRS into a "bad outcome" score, and the missing-value
    # filter applied to this column before evaluation could never exclude those subjects.
    # Store the score as 0.0 / 1.0 and propagate NaN where the premorbid mRS is missing.
    data_df[forwarding_name + '_prob'] = within_cutoff.astype(float).where(premorbid_mrs.notna())

In [None]:
def _hiat_for_subject(subject):
    """Return the HIAT score for one registry row (age, admission NIHSS, first glucose)."""
    return hiat_score(
        subject['Age (calc.)'],
        subject['NIH on admission'],
        subject['1st glucose'])


data_df['HIAT_prob'] = data_df.apply(_hiat_for_subject, axis=1)

# HIAT defines a good outcome as mRS < 4 at discharge
data_df['HIAT good outcome pred'] = data_df['HIAT_prob'].gt(0.5)

In [None]:
def _span100_for_subject(subject):
    """Return the SPAN-100 score for one registry row (age, admission NIHSS)."""
    return span100_score(
        subject['Age (calc.)'],
        subject['NIH on admission'])


data_df['span100_prob'] = data_df.apply(_span100_for_subject, axis=1)
# Binary prediction: score above 0.5
data_df['span100 good outcome pred'] = data_df['span100_prob'].gt(0.5)

In [None]:
def _thrive_for_subject(subject):
    """Return the THRIVE score for one registry row.

    Inputs passed on, in order: age, admission NIHSS, and the medical-history columns for
    hypertension, diabetes and atrial fibrillation.
    """
    return thrive_score(
        subject['Age (calc.)'],
        subject['NIH on admission'],
        subject['MedHist Hypertension'],
        subject['MedHist Diabetes'],
        subject['MedHist Atrial Fibr.'])


data_df['THRIVE_prob'] = data_df.apply(_thrive_for_subject, axis=1)

# Binary prediction: score above 0.5
data_df['THRIVE good outcome pred'] = data_df['THRIVE_prob'].gt(0.5)

In [None]:
from prediction.outcome_prediction.baseline_models.baseline_scores import thriveC_score


def _thrivec_for_subject(subject):
    """Return the THRIVE-C score for one registry row.

    Inputs passed on, in order: age, admission NIHSS, and the medical-history columns for
    hypertension, diabetes and atrial fibrillation.
    """
    return thriveC_score(
        subject['Age (calc.)'],
        subject['NIH on admission'],
        subject['MedHist Hypertension'],
        subject['MedHist Diabetes'],
        subject['MedHist Atrial Fibr.'])


data_df['THRIVEC_prob'] = data_df.apply(_thrivec_for_subject, axis=1)

# Binary prediction: score above 0.5
data_df['THRIVEC good outcome pred'] = data_df['THRIVEC_prob'].gt(0.5)

In [None]:
# THRIVE / THRIVE-C inputs next to the resulting scores and binary predictions, for inspection.
thrive_inspection_columns = [
    'Age (calc.)',
    'NIH on admission',
    'MedHist Hypertension',
    'MedHist Diabetes',
    'MedHist Atrial Fibr.',
    'THRIVE_prob',
    'THRIVE good outcome pred',
    'THRIVEC_prob',
    'THRIVEC good outcome pred',
]
extracted_df = data_df[thrive_inspection_columns]
extracted_df

In [None]:
# Rows for which no THRIVE score is available.
extracted_df.loc[extracted_df['THRIVE_prob'].isna()]

In [None]:
# Evaluate every clinical score against the '3M mRS 0-2' ground truth.
# For each score, subjects with a missing (NaN) score value are left out of that evaluation.
# evaluate_method returns four values; the first is a result table, the second is bound to
# roc_auc_figure, and the last two are kept only for THRIVE-C (exported further below).
# plt.show() after each call renders the pending figure before the next evaluation runs; the
# former bare `roc_auc_figure` statements were no-ops (only a cell's last expression is
# displayed) and have been dropped.
thrive_df, roc_auc_figure, _, _ = evaluate_method(
    'THRIVE', data_df[data_df['THRIVE_prob'].notna()], ground_truth='3M mRS 0-2')
plt.show()
thriveC_df, roc_auc_figure, THRIVE_C_bootstrapping_data, THRIVE_C_testing_data = evaluate_method(
    'THRIVEC', data_df[data_df['THRIVEC_prob'].notna()], ground_truth='3M mRS 0-2')
plt.show()
hiat_df, roc_auc_figure, _, _ = evaluate_method(
    'HIAT', data_df[data_df['HIAT_prob'].notna()], ground_truth='3M mRS 0-2')
plt.show()
span100_df, roc_auc_figure, _, _ = evaluate_method(
    'span100', data_df[data_df['span100_prob'].notna()], ground_truth='3M mRS 0-2')
plt.show()
mrs02_forwarding_df, roc_auc_figure, _, _ = evaluate_method(
    'mrs02_forwarding', data_df[data_df['mrs02_forwarding_prob'].notna()], ground_truth='3M mRS 0-2')
# Shown explicitly like the other figures (this call was previously missing).
plt.show()

# One table with the results of all scores for the mRS 0-2 outcome.
mrs02_result_df = pd.concat([thrive_df, thriveC_df, hiat_df, span100_df, mrs02_forwarding_df])
mrs02_result_df

In [None]:
from prediction.utils.utils import ensure_dir
import pickle

# Save the THRIVE-C result table together with the bootstrapping / testing data returned by
# its evaluation (per the original note: bootstrapped ground truth and predictions).
THRIVE_C_output_dir = os.path.join(output_dir, 'THRIVE_C_predictions')
ensure_dir(THRIVE_C_output_dir)
thriveC_df.to_csv(os.path.join(THRIVE_C_output_dir, 'thriveC_mrs02_results.csv'))

# Context managers make sure each pickle file is flushed and closed right after writing;
# passing a bare open(...) to pickle.dump leaves the handle open.
with open(os.path.join(THRIVE_C_output_dir, 'bootstrapped_gt_and_pred.pkl'), 'wb') as bootstrapped_file:
    pickle.dump(THRIVE_C_bootstrapping_data, bootstrapped_file)
with open(os.path.join(THRIVE_C_output_dir, 'test_gt_and_pred.pkl'), 'wb') as testing_file:
    pickle.dump(THRIVE_C_testing_data, testing_file)

In [None]:
# Evaluate every clinical score against the '3M mRS 0-1' ground truth.
# For each score, subjects with a missing (NaN) score value are left out of that evaluation.
# The result tables are bound to `mrs01_`-prefixed names so that this cell does not overwrite
# the mRS 0-2 tables (thrive_df, thriveC_df, hiat_df, span100_df). thriveC_df in particular is
# written to 'thriveC_mrs02_results.csv' by the export cell; with shared names, re-running that
# export after this cell would save mRS 0-1 results under an mRS 0-2 file name.
mrs01_thrive_df, roc_auc_figure, _, _ = evaluate_method(
    'THRIVE', data_df[data_df['THRIVE_prob'].notna()], ground_truth='3M mRS 0-1')
plt.show()
mrs01_thriveC_df, roc_auc_figure, _, _ = evaluate_method(
    'THRIVEC', data_df[data_df['THRIVEC_prob'].notna()], ground_truth='3M mRS 0-1')
plt.show()
mrs01_hiat_df, roc_auc_figure, _, _ = evaluate_method(
    'HIAT', data_df[data_df['HIAT_prob'].notna()], ground_truth='3M mRS 0-1')
plt.show()
mrs01_span100_df, roc_auc_figure, _, _ = evaluate_method(
    'span100', data_df[data_df['span100_prob'].notna()], ground_truth='3M mRS 0-1')
plt.show()
mrs01_forwarding_df, roc_auc_figure, _, _ = evaluate_method(
    'mrs01_forwarding', data_df[data_df['mrs01_forwarding_prob'].notna()], ground_truth='3M mRS 0-1')
plt.show()

# One table with the results of all scores for the mRS 0-1 outcome.
mrs01_result_df = pd.concat(
    [mrs01_thrive_df, mrs01_thriveC_df, mrs01_hiat_df, mrs01_span100_df, mrs01_forwarding_df])
mrs01_result_df

In [None]:
# mrs02_result_df.to_csv(os.path.join(output_dir, 'mrs02_clinical_scores_results.csv'))
# mrs01_result_df.to_csv(os.path.join(output_dir, 'mrs01_clinical_scores_results.csv'))

Evaluating for death at 3 months

In [None]:
# Complement of the THRIVE-C score (1 - score), used as the score for the 3-month death
# evaluation. The "good outcome pred" suffix is kept because evaluate_method is called with
# the method name 'inv_THRIVEC' (presumably it looks the columns up by that name — confirm).
data_df['inv_THRIVEC_prob'] = data_df['THRIVEC_prob'].rsub(1)
data_df['inv_THRIVEC good outcome pred'] = data_df['inv_THRIVEC_prob'].gt(0.5)

In [None]:
# Evaluate the inverted THRIVE-C score against the '3M Death' ground truth, using only the
# subjects for which a THRIVE-C score is available.
death_evaluation = evaluate_method(
    'inv_THRIVEC',
    data_df[data_df['THRIVEC_prob'].notna()],
    ground_truth='3M Death')
(death_thriveC_df,
 roc_auc_figure,
 death_THRIVE_C_bootstrapping_data,
 death_THRIVE_C_testing_data) = death_evaluation
plt.show()

In [None]:
# Write the 3-month death result table into its own sub-directory of output_dir.
death_THRIVE_C_output_dir = os.path.join(output_dir, 'THRIVE_C_3m_death_predictions')
ensure_dir(death_THRIVE_C_output_dir)
death_results_path = os.path.join(death_THRIVE_C_output_dir, '3m_death_results.csv')
death_thriveC_df.to_csv(death_results_path)
# pickle.dump(death_THRIVE_C_bootstrapping_data, open(os.path.join(death_THRIVE_C_output_dir, '3m_death_bootstrapped_gt_and_pred.pkl'), 'wb'))
# pickle.dump(death_THRIVE_C_testing_data, open(os.path.join(death_THRIVE_C_output_dir, '3m_death_test_gt_and_pred.pkl'), 'wb'))