# Define the Splits and Used Patients

Either all patients or only non-TIA patients can be selected.  
Then the number of splits (Folds) is defined and the random seed is selected.  
The selection is either stratified by MRS score or by binarized MRS score (0-2 vs. 3-6).

**Note:** All parameters and version names should be defined before running all cells.

### Import Libraries and Install Packages

In [None]:
!pip install seaborn

In [None]:
%matplotlib inline

import os
import h5py
import random
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import StratifiedKFold

### Load Data

In [None]:
# Directory the image file (.h5) and the baseline table (.csv) are read from.
IMG_DIR = "/tf/notebooks/hezo/stroke_zurich/data/" 
# Alternative image location, currently unused.
# IMG_DIR2 = "/tf/notebooks/kook/data-sets/stroke-lh/"
# Directory the split tables and the image array are written to.
OUTPUT_DIR = "/tf/notebooks/brdd/xAI_stroke_3d/data/"

In [None]:
# Input files, both located in IMG_DIR.
path_img = IMG_DIR + 'dicom_2d_192x192x3_clean_interpolated_18_02_2021_preprocessed2.h5'
path_tab = IMG_DIR + 'baseline_data_zurich_prepared.csv'

# Should only non-TIA (transient ischemic attack) patients be included?
only_non_tia = True

# Read the four datasets fully into memory while the file is open.
# (Original note: IMG_DIR2 + 'dicom-3d.h5' contains the same images.)
with h5py.File(path_img, "r") as h5:
    X_in, Y_img, Y_pat, pat = (
        h5[dataset][:] for dataset in ("X", "Y_img", "Y_pat", "pat")
    )

# Append a trailing axis at position 4.
X_in = np.expand_dims(X_in, axis=4)

print("image shape in: ", X_in.shape)
print(
    "image min, max, mean, std: ",
    X_in.min(),
    X_in.max(),
    X_in.mean(),
    X_in.std(),
)

# Tabular baseline data (comma-separated, which is the read_csv default).
dat = pd.read_csv(path_tab)

print("tabular shape in: ", dat.shape)

In [None]:
# Match image and tabular data by patient id.
#
# Only images whose patient id occurs in the tabular data are kept, in the
# order in which they appear in `pat`. Produces X (images), X_tab (covariates),
# Y_mrs (mrs3 outcome), Y_eventtia (copied from Y_pat) and p_id (patient ids).

# Tabular covariates copied into X_tab, in this column order.
tab_columns = ["age", "sexm", "nihss_baseline", "mrs_before",
               "stroke_beforey", "tia_beforey", "ich_beforey",
               "rf_hypertoniay", "rf_diabetesy", "rf_hypercholesterolemiay",
               "rf_smokery", "rf_atrial_fibrillationy", "rf_chdy"]

# Extract the needed columns once as plain arrays. Rows are then addressed by
# position, so the lookup does not depend on the DataFrame's index labels.
tab_ids = dat.p_id.values
tab_values = dat[tab_columns].to_numpy(dtype="float64")
tab_mrs = dat["mrs3"].to_numpy(dtype="float64")

# (image position, tabular row position) for every image with tabular data.
matches = []
for j, p in enumerate(pat):
    rows = np.flatnonzero(tab_ids == p)
    if rows.size == 0:
        continue  # no tabular data for this image
    if rows.size > 1:
        # An ambiguous match cannot be stored in a single row.
        raise ValueError(
            f"patient id {p} occurs {rows.size} times in the tabular data"
        )
    matches.append((j, rows[0]))
n = len(matches)

# Allocate the outputs; X keeps every axis of X_in except the first.
X = np.zeros((n,) + X_in.shape[1:])
X_tab = np.zeros((n, len(tab_columns)))
Y_mrs = np.zeros(n)
Y_eventtia = np.zeros(n)
p_id = np.zeros(n)

for i, (j, row) in enumerate(matches):
    X[i] = X_in[j]
    X_tab[i] = tab_values[row]
    p_id[i] = pat[j]
    Y_eventtia[i] = Y_pat[j]
    # Scalar assignment: avoids the deprecated conversion of a length-1
    # Series to a scalar.
    Y_mrs[i] = tab_mrs[row]
p_id = p_id.astype("int")

print("X img out shape: ", X.shape)
print("X tab out shape: ", X_tab.shape)
print("Y mrs out shape: ", Y_mrs.shape)

In [None]:
# Binary outcome: MRS 0, 1 or 2 -> 0 (favorable), everything else -> 1 (unfavorable).
# NOTE(review): any value that is not exactly 0, 1 or 2 (including NaN) is
# coded as unfavorable.
Y_new = np.array([0 if score in (0, 1, 2) else 1 for score in Y_mrs])

# 1-based running index of every patient.
p_idx = np.arange(1, len(p_id) + 1)

In [None]:
# Side-by-side count plots: full MRS score (left) and binary outcome (right).
fig, axes = plt.subplots(1, 2, figsize=(15, 8))
ax1, ax2 = axes
for values, axis in ((Y_mrs, ax1), (Y_new, ax2)):
    sns.countplot(x=values, ax=axis)

# Number of favorable (0) and unfavorable (1) outcomes.
print((Y_new == 0).sum(), (Y_new == 1).sum())

Left: Distribution of the MRS score  
Right: Distribution of the binary outcome (unfavorable = MRS > 2)

In [None]:
# Optionally restrict the data to non-TIA patients (rows with Y_eventtia == 1).
# X is not reduced here; p_idx keeps each remaining patient's 1-based running index.
if only_non_tia:
    keep = Y_eventtia == 1
    p_idx, X_tab, Y_mrs, p_id, Y_new, Y_eventtia = (
        values[keep]
        for values in (p_idx, X_tab, Y_mrs, p_id, Y_new, Y_eventtia)
    )

    # Outcome distributions after the reduction.
    fig, axes = plt.subplots(1, 2)
    ax1, ax2 = axes
    for values, axis in ((Y_mrs, ax1), (Y_new, ax2)):
        sns.countplot(x=values, ax=axis)

    print((Y_new == 0).sum(), (Y_new == 1).sum())

In [None]:
# Save the ids and outcomes in a DataFrame, one row per patient.
id_columns = {
    "p_idx": p_idx,            # 1-based running index
    "p_id": p_id,              # patient id
    "mrs": Y_mrs,              # MRS score
    "unfavorable": Y_new,      # binary outcome (0 = MRS 0-2, 1 = otherwise)
}
id_tab = pd.DataFrame(id_columns)

## Define Splits

In [None]:
# Stratified k-fold splitter. Seeds used for the saved versions:
#   10 Fold V0: random_state 100
#   10 Fold V1: random_state 999
#   10 Fold V2: random_state 500
#   10 Fold V3: random_state 200
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=200)

# Stratify on the binary outcome Y_new (V0, V2, V3).
# For V1, stratify on the MRS score instead: skf.split(p_id, Y_mrs)
test_indices = [test_index for _, test_index in skf.split(p_id, Y_new)]

# Each fold is stored as the array of patient ids in its test part.
folds = [p_id[test_index] for test_index in test_indices]

# Number of unfavorable outcomes in each test part.
for test_index in test_indices:
    print(Y_new[test_index].sum())

In [None]:
# Number of patients in the test part of every fold.
for fold_size in map(len, folds):
    print(fold_size)

In [None]:
# Add one column per split ("fold0", "fold1", ...) labelling every patient as
# "train", "val" or "test".
#
# The validation patients of split i are the test patients of the split half a
# cycle ahead (offset 5 for 10 folds, as before). Deriving the offset and the
# wrap-around from len(folds) keeps val and test disjoint for other fold counts;
# the former hard-coded "+5 / -10" mapped val onto the test fold for 5 folds.
n_folds = len(folds)
val_offset = n_folds // 2

for i, fold in enumerate(folds):
    column = "fold" + str(i)
    val_fold = folds[(i + val_offset) % n_folds]

    # Everyone starts as "train"; test is set before val, as in the original order.
    id_tab[column] = "train"
    id_tab.loc[id_tab["p_id"].isin(fold), column] = "test"
    id_tab.loc[id_tab["p_id"].isin(val_fold), column] = "val"

In [None]:
# Display the table of ids, outcomes and per-split labels.
id_tab

In [None]:
# Number of train / val / test patients in every split.
for split_no, _ in enumerate(folds):
    split_column = "fold" + str(split_no)
    print(id_tab[split_column].value_counts())

## Save Data

Version overview:

- andrea_split: splits and training as in paper 
- 10Fold_sigmoid_V0 (old name: 10Fold_sigmoid): 10 stratified (with outcome mrs > 2 or mrs <= 2) Folds trained with the last layer being activated with sigmoid (5 ensembles per split)
- 10Fold_softmax_V0: same Folds as 10Fold_sigmoid but last layer activated with softmax (5 ensembles per split)
- 10Fold_softmax_V1: new 10 Fold stratified (with mrs) and last layer activated with softmax (10 ensembles per split)
- 10Fold_sigmoid_V1: same Folds as 10Fold_softmax_V1 and last layer activated with sigmoid (10 ensembles per split)
- 10Fold_sigmoid_V2: 10 Fold binary stratified (mrs > or <= 2) other seed than V0, and last layer activated with sigmoid (5 ensembles per split)
- 10Fold_sigmoid_V2f: same as 10Fold_sigmoid_V2 but with flatten Layer
- 10Fold_sigmoid_V3: 10 Fold binary stratified (mrs > or <= 2) without TIA patients, other seed than V0 and V2 and last layer activated with sigmoid (5 ensembles per split)

In [None]:
# Write the split table without the DataFrame index. Exactly one line should be
# active; the file version has to match the seed and stratification used when
# the folds were created.
# id_tab.to_csv(OUTPUT_DIR + "10Fold_ids_V0.csv",  index=False)
# id_tab.to_csv(OUTPUT_DIR + "10Fold_ids_V1.csv",  index=False)
# id_tab.to_csv(OUTPUT_DIR + "10Fold_ids_V2.csv",  index=False)
id_tab.to_csv(OUTPUT_DIR + "10Fold_ids_V3.csv",  index=False)

In [None]:
# Remove only the trailing length-1 axis before saving. A bare squeeze() would
# also drop any other length-1 axis (for example the first axis if X held a
# single patient). The guard keeps the cell safe to re-run once X has already
# been reduced.
if X.ndim == 5 and X.shape[-1] == 1:
    X = X[..., 0]
X = X.astype(np.float32)

np.save(OUTPUT_DIR + "prepocessed_dicom_3d.npy", X)

## Analyze Data

Check distribution of each split.

In [None]:
# Reload the split table (one version active at a time) and the saved image
# array from OUTPUT_DIR.
# id_tab = pd.read_csv(OUTPUT_DIR + "10Fold_ids_V0.csv", sep=",")
# id_tab = pd.read_csv(OUTPUT_DIR + "10Fold_ids_V1.csv", sep=",")
# id_tab = pd.read_csv(OUTPUT_DIR + "10Fold_ids_V2.csv", sep=",")
id_tab = pd.read_csv(OUTPUT_DIR + "10Fold_ids_V3.csv", sep=",")
X = np.load(OUTPUT_DIR + "prepocessed_dicom_3d.npy")

In [None]:
# Number of patients per value of the binary outcome column.
id_tab["unfavorable"].value_counts()

In [None]:
# Shape of the image array.
print(X.shape)

In [None]:
# For every split column ("fold0", "fold1", ...) plot the distribution of the
# MRS score (left) and of the binary outcome (right) among its test patients.
# The columns are taken from the table instead of assuming exactly 10 splits.
fold_columns = [col for col in id_tab.columns if str(col).startswith("fold")]

for fold_column in fold_columns:
    # Select the test patients once and reuse them for both plots.
    test_rows = id_tab[id_tab[fold_column] == "test"]
    fig, (ax1, ax2) = plt.subplots(1, 2)
    sns.countplot(x=test_rows.mrs, ax=ax1)
    sns.countplot(x=test_rows.unfavorable, ax=ax2)

### Check Images

Check if images are the same when accessing them.

In [None]:
# Patient id used for the spot check.
patient = 460

# Position in X: p_idx is 1-based, hence the -1.
patient_p_idx = id_tab.loc[id_tab.p_id == patient, "p_idx"]
index1 = patient_p_idx.values[0] - 1
print(index1)

# Position of the same patient in the raw id array `pat`.
index2 = np.argwhere(pat == patient).squeeze()
print(index2)

In [None]:
# Compare the image stored in X with the one in X_in for the chosen patient;
# both are cast to float64 first.
im1 = X[index1].astype("float64")
raw_images = X_in.squeeze()
im2 = raw_images[index2].astype("float64")
np.allclose(im1, im2)