In [None]:
import json
import os
import pickle

import pandas as pd
from tqdm import tqdm

from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitModel
from prediction.outcome_prediction.Transformer.testing.test_transformer_model import test_transformer_model
from prediction.outcome_prediction.data_loading.data_loader import load_data
from prediction.utils.utils import save_json, ensure_dir
from preprocessing.preprocessing_tools.normalisation.reverse_normalisation import reverse_normalisation

In [None]:
# Locations of the preprocessing run outputs. The run timestamp appears in both
# the directory and the file names, so it is defined once to keep them consistent.
# Both the output root and the timestamp can be overridden via environment
# variables instead of editing this absolute local path.
prepro_timestamp = os.environ.get('OPSUM_PREPRO_TIMESTAMP', '01012023_233050')
prepro_output_root = os.environ.get('OPSUM_PREPRO_OUTPUT_ROOT', '/Users/jk1/temp/opsum_prepro_output')
prepro_dir = os.path.join(prepro_output_root, f'gsu_prepro_{prepro_timestamp}')

features_path = os.path.join(prepro_dir, f'preprocessed_features_{prepro_timestamp}.csv')
labels_path = os.path.join(prepro_dir, f'preprocessed_outcomes_{prepro_timestamp}.csv')
normalisation_parameters_path = os.path.join(prepro_dir, f'logs_{prepro_timestamp}', 'normalisation_parameters.csv')

In [None]:
# Prediction target and split settings, forwarded positionally to load_data.
outcome = '3M mRS 0-2'  # outcome label to predict
test_size = 0.2         # presumably the fraction held out as the test set
n_splits = 5            # presumably the number of CV folds on the training data
seed = 42               # random seed forwarded to load_data

In [None]:
# Parameters saved during preprocessing; used later to undo the normalisation.
normalisation_parameters_df = pd.read_csv(filepath_or_buffer=normalisation_parameters_path)

In [None]:
# Load preprocessed features/outcomes; presumably split into train (with CV folds)
# and test sets. The return value is unpacked positionally, so the target order
# must match what load_data returns.
(
    pids,
    train_data,
    test_data,
    train_splits,
    test_features_lookup_table,
) = load_data(features_path, labels_path, outcome, test_size, n_splits, seed)

In [None]:
# Separate the held-out test features from their labels.
[test_X_np, test_y_np] = test_data

In [None]:
# Take index 0 along axis 1 for every test sample (assumes axis 1 is time, so
# this is the first time step -- TODO confirm) and label columns by feature name.
baseline_t0_test_X_np = test_X_np[:, 0, :]
baseline_t0_test_X_df = pd.DataFrame(
    data=baseline_t0_test_X_np,
    columns=test_features_lookup_table['sample_label'],
)

In [None]:
# Long format: one row per (pidx, sample_label) pair, where pidx is the row
# position of the patient in test_X_np. The frame is rebuilt from the numpy
# array rather than from baseline_t0_test_X_df, so re-running this cell does
# not try to melt an already-melted frame.
baseline_t0_test_X_df = (
    pd.DataFrame(baseline_t0_test_X_np, columns=test_features_lookup_table['sample_label'])
    .reset_index()
    .rename(columns={'index': 'pidx'})
    .melt(id_vars='pidx', var_name='sample_label', value_name='value')
)

In [None]:
# Map the long (pidx, sample_label, value) frame back to original units using the
# stored normalisation parameters. NOTE(review): presumably only normalised
# continuous features are changed and one-hot columns stay 0/1 -- the later
# `value == 0/1` filters depend on that; confirm against reverse_normalisation.
non_norm_baseline_t0_test_X_df = reverse_normalisation(baseline_t0_test_X_df, normalisation_parameters_df)

In [None]:
# Rows of test patients whose de-normalised max_NIHSS at the first time step is at most 5.
non_norm_baseline_t0_test_X_df.loc[
    non_norm_baseline_t0_test_X_df['sample_label'].eq('max_NIHSS')
    & non_norm_baseline_t0_test_X_df['value'].le(5)
]

In [None]:
# All feature labels present in the long frame.
non_norm_baseline_t0_test_X_df['sample_label'].unique()

Identify test-set patients by treatment category:
- no treatment (neither IVT nor IAT)
- IVT only
- IAT (with or without IVT)

In [None]:
# Partition test patients by treatment. The one-hot columns are named
# '..._no_iat' / '..._no_ivt', so a value of 0 is taken to mean the treatment
# WAS given (NOTE(review): confirm the encoding).
all_pidx = set(non_norm_baseline_t0_test_X_df['pidx'].unique())
pidx_with_IAT = set(
    non_norm_baseline_t0_test_X_df.loc[
        non_norm_baseline_t0_test_X_df['sample_label'].eq('categorical_iat_no_iat')
        & non_norm_baseline_t0_test_X_df['value'].eq(0),
        'pidx',
    ].unique()
)
pidx_with_IVT = set(
    non_norm_baseline_t0_test_X_df.loc[
        non_norm_baseline_t0_test_X_df['sample_label'].eq('categorical_ivt_no_ivt')
        & non_norm_baseline_t0_test_X_df['value'].eq(0),
        'pidx',
    ].unique()
)
# IVT without IAT, and patients with neither treatment.
pidx_with_only_IVT = pidx_with_IVT.difference(pidx_with_IAT)
pidx_with_no_ttt = all_pidx.difference(pidx_with_IAT, pidx_with_IVT)

In [None]:
# Patient counts: total, IAT (+/- IVT), IVT only, untreated.
tuple(len(group) for group in (all_pidx, pidx_with_IAT, pidx_with_only_IVT, pidx_with_no_ttt))

In [None]:
# Shape of the test labels; presumably the first dimension should equal
# test_X_np.shape[0] (one label per test sample) -- worth checking here.
test_y_np.shape

In [None]:
# Restrict the test set to patients who received neither IVT nor IAT.
# The patient indices are sorted so the selection order is reproducible, and the
# SAME indices are applied to the labels so features and labels stay row-aligned.
no_ttt_idx = sorted(pidx_with_no_ttt)
test_X_np_no_ttt = test_X_np[no_ttt_idx, :, :]
test_y_np_no_ttt = test_y_np[no_ttt_idx]
test_X_np_no_ttt.shape, test_y_np_no_ttt.shape

In [None]:
# Split test patients by sex using the 'sex_male' feature (value 1 = male).
# Every patient not flagged male is assigned to the female group.
pidx_sex_male = set(
    non_norm_baseline_t0_test_X_df.loc[
        non_norm_baseline_t0_test_X_df['sample_label'].eq('sex_male')
        & non_norm_baseline_t0_test_X_df['value'].eq(1),
        'pidx',
    ].unique()
)
pidx_sex_female = all_pidx.difference(pidx_sex_male)
len(pidx_sex_male), len(pidx_sex_female)

In [None]:
# Split test patients by de-normalised age.
# NOTE(review): the boundary is inclusive (age <= 70), so 70-year-olds land in the
# 'under_70' group. Any patient whose age does not satisfy the comparison
# (e.g. NaN) lands in 'over_70'. Confirm this is intended.
pidx_age_under_70 = set(
    non_norm_baseline_t0_test_X_df.loc[
        non_norm_baseline_t0_test_X_df['sample_label'].eq('age')
        & non_norm_baseline_t0_test_X_df['value'].le(70),
        'pidx',
    ].unique()
)
pidx_age_over_70 = all_pidx.difference(pidx_age_under_70)
len(pidx_age_under_70), len(pidx_age_over_70)

In [None]:
# Split test patients by prestroke disability (mRS). The feature appears to be
# one-hot encoded per mRS level. A patient flagged (value 1) in any of the
# 3.0 / 4.0 / 5.0 columns joins the 3-5 group. Everyone else, including
# patients with none of these columns set, is assigned to 0-2.
prestroke_mrs3_to_5_labels = [
    'prestroke_disability_(rankin)_3.0',
    'prestroke_disability_(rankin)_4.0',
    'prestroke_disability_(rankin)_5.0',
]
pidx_mrs3_to_5 = set(
    non_norm_baseline_t0_test_X_df.loc[
        non_norm_baseline_t0_test_X_df['sample_label'].isin(prestroke_mrs3_to_5_labels)
        & non_norm_baseline_t0_test_X_df['value'].eq(1),
        'pidx',
    ].unique()
)
pidx_mrs0_to_2 = all_pidx - pidx_mrs3_to_5
len(pidx_mrs3_to_5), len(pidx_mrs0_to_2)