# The experiment on real-world ICH detection using mechanism learning

In [29]:
import causalBootstrapping as cb
from distEst_lib import MultivarContiDistributionEstimator
import numpy as np
import pandas as pd
import cv2
import os 
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split


#### Set random seeds for reproducibility

In [30]:
# Dedicated generators: one drives the causal-bootstrap draws, the other the
# train/test index split, so the two streams never share state.
rng_train_test = np.random.RandomState(seed=6)
rng_bootstrap = np.random.RandomState(seed=42)
# Global TensorFlow seed.
tf.random.set_seed(seed=42)

In [31]:
def img_read(dir_list, img_size):
    """Load images as grayscale, resize them and flatten each into one row.

    Parameters
    ----------
    dir_list : iterable of str
        Paths of the image files to load; output rows follow this order.
    img_size : tuple of int
        Target size passed straight to ``cv2.resize``.

    Returns
    -------
    numpy.ndarray
        Array built from one flattened image per input path.

    Raises
    ------
    FileNotFoundError
        If ``cv2.imread`` cannot load one of the paths.
    """
    img_list = []
    for img_path in dir_list:
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque error inside cv2.resize.
        if img is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {img_path!r}")
        img = cv2.resize(img, img_size)
        img_list.append(img.flatten())
    return np.array(img_list)

In [32]:
# Root folder of the ICH data set. Named `data_dir` (not `dir`) so the
# builtin `dir()` is not shadowed for the rest of the kernel session.
data_dir = r"../../test_data/ICH_data/"
# The paths keep their trailing slash because later cells build file paths
# by plain string concatenation (e.g. `effect_dir + img_name`).
effect_dir = data_dir + r"ct_clean/"
mediator_dir = data_dir
cause_dir = data_dir
# Directory listing order is arbitrary; it is sorted numerically afterwards.
imgs_names = os.listdir(effect_dir)

In [33]:
# Order the files by the integer that precedes the first dot in each name,
# so "10.x" sorts after "2.x" instead of lexicographically.
imgs_names = sorted(imgs_names, key=lambda file_name: int(file_name.partition('.')[0]))

In [34]:
# Effect variable: every image in sorted order, resized to 128x128 and flattened.
effect_img_paths = [effect_dir + img_name for img_name in imgs_names]
effect_imgs = img_read(effect_img_paths, (128, 128))

# Cause variable (diagnosis table) and mediator variable (embedding table).
cause_table = pd.read_csv(cause_dir + "hemorrhage_diagnosis_ct_clean.csv")
mediator_table = pd.read_csv(mediator_dir + "mediator_embedding.csv")

## Cause Var Transformation

In [35]:
mapping = {0: "No Hemorrhage", 1: "Intraventricular", 2: "Intraparenchymal", 3: "Subarachnoid", 4: "Epidural", 5: "Subdural"}

# Collapse the one-hot diagnosis columns into a single integer code per row.
# Each column name is the mapping label with spaces replaced by underscores.
# Codes are applied in ascending order, so when several flags are set on the
# same row the highest code wins. Rows matching no column stay NaN, which the
# final integer cast does not accept.
cause_table["category"] = np.nan
for code, label in mapping.items():
    flag_column = label.replace(" ", "_")
    cause_table.loc[cause_table[flag_column] == 1, "category"] = code
cause_table["category"] = cause_table["category"].astype(int)

In [36]:
# Integer class code of every row as a plain NumPy array.
cause_category = cause_table["category"].to_numpy()

In [37]:
# Number of distinct class codes present in the data.
n_class = cause_table["category"].nunique()

## Mediator Var Exploration

In [38]:
# Peek at the first rows of the mediator embedding.
mediator_table.head(n=5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,7.445242,11.299638,8.707513,0.0,3.309047,7.546306,0.0,10.398108,0.0,12.386361
1,7.445242,11.299638,8.707513,0.0,3.309047,7.546306,0.0,10.398108,0.0,12.386361
2,7.445242,11.299638,8.707513,0.0,3.309047,7.546306,0.0,10.398108,0.0,12.386361
3,7.445242,11.299638,8.707513,0.0,3.309047,7.546306,0.0,10.398108,0.0,12.386361
4,7.445242,11.299638,8.707513,0.0,3.309047,7.546306,0.0,10.398108,0.0,12.386361


In [39]:
# Distribution of values in embedding column "3".
mediator_table["3"].value_counts()

3
0.0    2411
Name: count, dtype: int64

In [40]:
# Drop the all-zero column.
# Reassign instead of mutating in place, and ignore an already-missing column,
# so re-running this cell is a no-op rather than a KeyError.
mediator_table = mediator_table.drop(columns=['3'], errors='ignore')

## Causal Bootstrapping

#### Format variables and normalize the images

In [41]:
# Scale pixel intensities by 1/255 exactly once. The array is presumably
# 8-bit integer when it comes out of img_read (cv2 grayscale read) and becomes
# floating point after the division, so the dtype check keeps a re-run of this
# cell from dividing a second time.
if np.issubdtype(effect_imgs.dtype, np.integer):
    effect_imgs = effect_imgs/255.0
# Column vector so it can be concatenated with the mediator matrix.
cause_category = cause_category.reshape(-1,1)
# NOTE: the name `mediaor_values` (sic) is kept because later cells use it.
mediaor_values = mediator_table.values

#### Estimate required distributions for causal bootstrapping

In [42]:

# Variables keyed by the symbol they carry in the causal graph.
cause_data = {"Y": cause_category}
mediator_data = {"Z": mediaor_values}
effect_data = {"X": effect_imgs}

# Histogram bin counts per dimension: 0 for the class column, 6 for every
# mediator dimension.
# NOTE(review): 0 presumably tells the estimator the column is categorical — confirm in distEst_lib.
n_bins_y = [0]
n_bins_yz = n_bins_y + [6] * mediaor_values.shape[1]

# Joint sample matrix [Y | Z]; both inputs are 2-D, stacked column-wise.
joint_yz_data = np.hstack((cause_category, mediaor_values))

# Fit P(Y, Z) first, then the marginal P(Y).
dist_estimator_yz = MultivarContiDistributionEstimator(data_fit=joint_yz_data, n_bins=n_bins_yz)
pdf_yz, pyz = dist_estimator_yz.fit_histogram()

dist_estimator_y = MultivarContiDistributionEstimator(data_fit=cause_category, n_bins=n_bins_y)
pdf_y, py = dist_estimator_y.fit_histogram()

# Density callables keyed by the distribution names used in the weight function.
# The lambda parameter names are kept as-is.
dist_map = {
    "Y,Z": lambda Y, Z: pdf_yz([Y, Z]),
    "Y',Z": lambda Y_prime, Z: pdf_yz([Y_prime, Z]),
    "Y": lambda Y: pdf_y(Y),
    "Y'": lambda Y_prime: pdf_y(Y_prime),
}

#### Prepare inputs for causal bootstrapping

In [43]:
causal_graph = '"Front-door"; \
                Y; X; Z; \
                Y -> Z; \
                Z -> X; \
                X <-> Y;'

# Derive the bootstrap weight function for the effect of Y on X.
weight_func_lam, weight_func_str = cb.general_cb_analysis(
    causal_graph=causal_graph,
    effect_var_name='X',
    cause_var_name='Y',
    info_print=True,
)

# Bind the fitted densities and the sample count to the weight function.
N = len(cause_category)
w_func = weight_func_lam(dist_map=dist_map, N=N, kernel=None)

# Each *_data dict holds a single entry; its key is the variable name.
cause_var_name = next(iter(cause_data))
effect_var_name = next(iter(effect_data))
mediator_var_name = next(iter(mediator_data))

# Observed data for the weight computation: the observed cause is stored
# under the primed name, followed by the effect and mediator entries.
data = {
    cause_var_name + "'": cause_data[cause_var_name],
    **effect_data,
    **mediator_data,
}

# Number of bootstrap draws per class (class 0 first).
n_sample = [1000] + [800] * 5

Interventional prob.:p_{Y}(X)=\sum_{Z,Y'}[p(X|Z,Y')p(Z|Y)p(Y')]
Causal bootstrapping weights function: [P(Y,Z)P(Y')]/N*[P(Y)P(Y',Z)]
Required distributions:
1: P(Y,Z)
2: P(Y')
3: P(Y)
4: P(Y',Z)


#### Split the training and test datasets by identifying the causal bootstrap indices

In [44]:
# Empty (0-row) accumulators for the deconfounded sets; the next cell
# concatenates the per-class bootstrap samples onto them.
trainVal_X_deconf = np.empty((0, effect_imgs.shape[1]))
trainVal_Y_deconf = np.empty((0, 1))
test_X_deconf = np.empty((0, effect_imgs.shape[1]))
test_Y_deconf = np.empty((0, 1))

# Placeholders for the confounded sets; the next cell overwrites them by
# direct indexing rather than appending.
trainVal_X_conf = np.empty((0, effect_imgs.shape[1]))
trainVal_Y_conf = np.empty((0, 1))
test_X_conf = np.empty((0, effect_imgs.shape[1]))
test_Y_conf = np.empty((0, 1))

# Fraction of the unique bootstrapped source indices assigned to train+val.
train_size = 0.85
# One row of weights per intervened class, one column per original sample.
weights = np.empty((n_class, N))
idx_ib_by_class = []
idx_ib_all_unique = []
# NOTE: this loop advances rng_bootstrap, so re-running the cell without
# re-creating the generators produces different draws.
for class_i in range(n_class):
    # Weights of all N samples under the intervention Y = class_i.
    weights_itv = cb.weight_compute(w_func, data, intv_var = {"Y":[class_i for i in range(N)]})
    # Causal Bootstrapping weights (kept un-normalised in `weights`)
    weights[class_i] = weights_itv
    # Normalise to a probability vector for sampling.
    weights_norm = weights_itv/np.sum(weights_itv)
    # Draw n_sample[class_i] source indices with replacement.
    idx_ib = rng_bootstrap.choice(range(N), size = n_sample[class_i], replace = True, p = weights_norm)
    idx_ib_by_class.append(list(idx_ib))
    idx_ib_unique = np.unique(idx_ib)
    idx_ib_all_unique.append(idx_ib_unique)
# Union of every source index drawn for any class.
idx_ib_all_unique = np.unique(np.concatenate(idx_ib_all_unique))
# Split on unique source indices (without replacement) so a given original
# sample is never used on both sides of the train/test boundary.
idx_ib_train = rng_train_test.choice(idx_ib_all_unique, size = int(train_size*idx_ib_all_unique.shape[0]), replace = False)
# Bootstrapped indices not chosen for training.
idx_ib_test = np.setdiff1d(idx_ib_all_unique, idx_ib_train)
# Every original index not used for training, including never-drawn ones.
idx_all_test = np.setdiff1d(range(N), idx_ib_train)

In [45]:
# Build the deconfounded sets from fresh part lists instead of appending to
# the arrays left by an earlier run, so re-running this cell does not
# duplicate the data. Each list starts with an empty float block, which keeps
# the result floating point (labels included).
n_features = effect_imgs.shape[1]
trainVal_X_parts = [np.empty((0, n_features))]
trainVal_Y_parts = [np.empty((0, 1))]
test_X_parts = [np.empty((0, n_features))]
test_Y_parts = [np.empty((0, 1))]

for class_i in range(n_class):
    # Bootstrap draws for this class, in draw order and with duplicates kept.
    idx_ib_class = np.asarray(idx_ib_by_class[class_i], dtype=int)

    # Draws whose source sample belongs to the training split; the label is
    # the intervened class, not the observed one.
    idx_ib_train_bootstrap = idx_ib_class[np.isin(idx_ib_class, idx_ib_train)]
    trainVal_X_parts.append(effect_imgs[idx_ib_train_bootstrap, :])
    trainVal_Y_parts.append(np.full((idx_ib_train_bootstrap.shape[0], 1), class_i))

    # Draws whose source sample belongs to the bootstrap test split.
    idx_ib_test_bootstrap = idx_ib_class[np.isin(idx_ib_class, idx_ib_test)]
    test_X_parts.append(effect_imgs[idx_ib_test_bootstrap, :])
    test_Y_parts.append(np.full((idx_ib_test_bootstrap.shape[0], 1), class_i))

trainVal_X_deconf = np.concatenate(trainVal_X_parts, axis = 0)
trainVal_Y_deconf = np.concatenate(trainVal_Y_parts, axis = 0)
test_X_deconf = np.concatenate(test_X_parts, axis = 0)
test_Y_deconf = np.concatenate(test_Y_parts, axis = 0)

# Confounded sets: the original samples with their observed labels.
trainVal_X_conf = effect_imgs[idx_ib_train, :]
trainVal_Y_conf = cause_category[idx_ib_train].reshape(-1,1)
test_X_conf = effect_imgs[idx_all_test, :]
test_Y_conf = cause_category[idx_all_test].reshape(-1,1)

#### Resample the confounded datasets

In [46]:
def _match_class_counts(X_conf, Y_conf, Y_deconf, n_class, random_state = 42):
    """Resample a confounded set so each class has as many rows as in Y_deconf.

    Classes with more rows than the deconfounded reference are first randomly
    under-sampled; classes with fewer rows are then randomly over-sampled.

    Parameters
    ----------
    X_conf, Y_conf : array-like
        Confounded features and labels to resample.
    Y_deconf : numpy.ndarray
        Deconfounded labels providing the target count of every class.
    n_class : int
        Number of classes; labels are the integers ``0 .. n_class - 1``.
    random_state : int
        Seed handed to both samplers.

    Returns
    -------
    tuple
        The resampled ``(X, Y)`` as returned by ``fit_resample``.
    """
    n_target = [int(np.sum(Y_deconf == i)) for i in range(n_class)]
    n_available = [int(np.sum(Y_conf == i)) for i in range(n_class)]

    # Under-sampling cannot create rows, so cap each target at what exists.
    rus = RandomUnderSampler(sampling_strategy={i: min(n_target[i], n_available[i]) for i in range(n_class)},
                             random_state = random_state)
    X_conf, Y_conf = rus.fit_resample(X_conf, Y_conf)

    # Top the short classes up to the exact target counts.
    ros = RandomOverSampler(sampling_strategy = {i: n_target[i] for i in range(n_class)},
                            random_state = random_state)
    return ros.fit_resample(X_conf, Y_conf)


# Resample the training set
trainVal_X_conf, trainVal_Y_conf = _match_class_counts(trainVal_X_conf, trainVal_Y_conf, trainVal_Y_deconf, n_class)

# Resample the test set (class range now follows n_class, not a literal 6)
test_X_conf, test_Y_conf = _match_class_counts(test_X_conf, test_Y_conf, test_Y_deconf, n_class)

#### Split the validation set for both datasets

In [47]:
# Hold out the same stratified fraction, with the same seed, from both
# training pools.
val_size = 0.1
split_options = {"test_size": val_size, "random_state": 17}

# Deconfounded data split
train_X_deconf, val_X_deconf, train_Y_deconf, val_Y_deconf = train_test_split(
    trainVal_X_deconf, trainVal_Y_deconf, stratify=trainVal_Y_deconf, **split_options
)

# Confounded data split
train_X_conf, val_X_conf, train_Y_conf, val_Y_conf = train_test_split(
    trainVal_X_conf, trainVal_Y_conf, stratify=trainVal_Y_conf, **split_options
)

### De-confounded model

In [48]:
def resNetCNN_model(input_shape, num_class):
    """Build the (uncompiled) CNN classifier with one skip connection.

    Architecture, in layer-creation order: three Conv2D + 2x2 pooling stages
    (16@7x7 max, 32@5x5 max, 64@3x3 average), an 8x8 average pooling of the
    raw input added onto the conv features, a 1x1 Conv2D with 16 filters,
    then a Dropout/Dense head (64, 64, 32 units) and a softmax output.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, passed to ``layers.Input``.
    num_class : int
        Number of units of the softmax output layer.

    Returns
    -------
    A ``models.Model`` mapping the input image to class probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    # Convolutional trunk: (filters, kernel size, pooling layer type).
    conv_stages = (
        (16, (7, 7), layers.MaxPooling2D),
        (32, (5, 5), layers.MaxPooling2D),
        (64, (3, 3), layers.AveragePooling2D),
    )
    features = inputs
    for n_filters, kernel_size, pooling_layer in conv_stages:
        features = layers.Conv2D(n_filters, kernel_size, activation='relu', padding='same')(features)
        features = pooling_layer((2, 2), padding='same')(features)

    # Skip branch: one 8x8 pooling of the raw input, matching the three 2x2
    # poolings applied on the trunk.
    # NOTE(review): the skip branch keeps the input's channel count while the
    # trunk has 64 channels, so the Add depends on Keras accepting that
    # mismatch — confirm before using inputs that are not single-channel.
    skip = layers.AveragePooling2D((8, 8), padding='same')(inputs)
    features = layers.Add()([features, skip])
    features = layers.Activation('relu')(features)

    # 1x1 convolution to shrink the channel dimension before flattening.
    features = layers.Conv2D(16, (1, 1), activation='relu', padding='same')(features)

    # Dense head: each Dense layer is preceded by its own Dropout.
    hidden = layers.Flatten()(features)
    for drop_rate, n_units in ((0.3, 64), (0.2, 64), (0.2, 32)):
        hidden = layers.Dropout(drop_rate)(hidden)
        hidden = layers.Dense(n_units, activation='relu')(hidden)

    class_probs = layers.Dense(num_class, activation='softmax')(hidden)

    return models.Model(inputs, class_probs)

# Classifier to be trained on the deconfounded (causally bootstrapped) data.
model_deconf = resNetCNN_model(input_shape=(128, 128, 1), num_class=n_class)
model_deconf.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model_deconf.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 16  800         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 64, 64, 16)   0           ['conv2d[0][0]']                 
                                                                                              

#### Train the deconfounded model

In [49]:
# One-hot encode the labels. The width follows n_class (not a literal 6) so
# it always matches the output layer created by resNetCNN_model(..., n_class).
train_Y_deconf_oh = to_categorical(train_Y_deconf.reshape(-1), num_classes=n_class)
val_Y_deconf_oh = to_categorical(val_Y_deconf.reshape(-1), num_classes=n_class)

# Stop after 8 epochs without a better validation accuracy and roll back to
# the best weights seen.
early_stopping=EarlyStopping(monitor='val_accuracy', min_delta=0,
                            patience=8, verbose=0, mode='max',
                            baseline=None, restore_best_weights=True)

# Flattened images are reshaped back to (128, 128, 1) for the network input.
model_deconf.fit(train_X_deconf.reshape(-1, 128, 128, 1), train_Y_deconf_oh, 
                 epochs=75, batch_size=4, shuffle=True,
                 validation_data=(val_X_deconf.reshape(-1, 128, 128, 1), val_Y_deconf_oh),
                 callbacks=[early_stopping])

Epoch 1/75
Epoch 2/75
Epoch 3/75
Epoch 4/75
Epoch 5/75
Epoch 6/75
Epoch 7/75
Epoch 8/75
Epoch 9/75
Epoch 10/75
Epoch 11/75
Epoch 12/75
Epoch 13/75
Epoch 14/75
Epoch 15/75
Epoch 16/75
Epoch 17/75
Epoch 18/75
Epoch 19/75
Epoch 20/75
Epoch 21/75


<keras.callbacks.History at 0x203c1ecad60>

In [50]:
# Predicted class (arg-max of the softmax output) of the DECONFOUNDED model on
# all four evaluation sets.
# NOTE: the confounded-model cell further down reuses these variable names, so
# the report cells show whichever prediction cell ran last.
Y_pred_conf_train = np.argmax(
    model_deconf.predict(trainVal_X_conf.reshape(-1, 128, 128, 1)), axis=1
)
Y_pred_conf_test = np.argmax(
    model_deconf.predict(test_X_conf.reshape(-1, 128, 128, 1)), axis=1
)

Y_pred_deconf_train = np.argmax(
    model_deconf.predict(trainVal_X_deconf.reshape(-1, 128, 128, 1)), axis=1
)
Y_pred_deconf_test = np.argmax(
    model_deconf.predict(test_X_deconf.reshape(-1, 128, 128, 1)), axis=1
)



#### Test on confounded test set

In [51]:
# Per-class precision/recall/F1 on the confounded test set.
print(classification_report(y_true=test_Y_conf.ravel(), y_pred=Y_pred_conf_test))

              precision    recall  f1-score   support

           0       0.90      0.96      0.93       152
           1       0.97      1.00      0.99       269
           2       1.00      0.72      0.83        95
           3       0.79      1.00      0.88        74
           4       0.95      0.84      0.89       100
           5       0.98      1.00      0.99        90

    accuracy                           0.94       780
   macro avg       0.93      0.92      0.92       780
weighted avg       0.94      0.94      0.94       780



#### Test on deconfounded test set

In [52]:
# Per-class precision/recall/F1 on the deconfounded test set. These labels are
# stored as floats, hence the "0.0"-style class names in the report.
print(classification_report(y_true=test_Y_deconf.ravel(), y_pred=Y_pred_deconf_test))

              precision    recall  f1-score   support

         0.0       0.90      0.93      0.92       152
         1.0       0.92      1.00      0.96       269
         2.0       1.00      0.52      0.68        95
         3.0       0.72      0.46      0.56        74
         4.0       0.84      0.85      0.85       100
         5.0       0.67      1.00      0.80        90

    accuracy                           0.86       780
   macro avg       0.84      0.79      0.79       780
weighted avg       0.87      0.86      0.85       780



### Confounded model

In [53]:
# One-hot encode the confounded labels. The width follows n_class (not a
# literal 6) so it always matches the model's output layer; both label arrays
# are flattened the same way for consistency.
train_Y_oh_conf = to_categorical(train_Y_conf.reshape(-1), num_classes=n_class)
val_Y_oh_conf = to_categorical(val_Y_conf.reshape(-1), num_classes=n_class)

# Same architecture and training configuration as the deconfounded model;
# only the training data differs.
model_conf = resNetCNN_model((128, 128, 1), n_class)
model_conf.compile(optimizer='adam',
                loss='categorical_crossentropy',
                metrics=['accuracy'])
model_conf.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_4 (Conv2D)              (None, 128, 128, 16  800         ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 max_pooling2d_2 (MaxPooling2D)  (None, 64, 64, 16)  0           ['conv2d_4[0][0]']               
                                                                                            

#### Train the confounded model

In [54]:
# Fresh callback for this run: stop after 8 epochs without a better validation
# accuracy and restore the best weights.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=0,
    patience=8,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

# Flattened images are reshaped back to (128, 128, 1) for the network input.
model_conf.fit(
    train_X_conf.reshape(-1, 128, 128, 1),
    train_Y_oh_conf,
    validation_data=(val_X_conf.reshape(-1, 128, 128, 1), val_Y_oh_conf),
    epochs=75,
    batch_size=4,
    shuffle=True,
    callbacks=[early_stopping],
)

Epoch 1/75
Epoch 2/75
Epoch 3/75
Epoch 4/75
Epoch 5/75
Epoch 6/75
Epoch 7/75
Epoch 8/75
Epoch 9/75
Epoch 10/75
Epoch 11/75
Epoch 12/75
Epoch 13/75
Epoch 14/75
Epoch 15/75
Epoch 16/75


<keras.callbacks.History at 0x205c4ad2730>

In [55]:
# Predicted class (arg-max of the softmax output) of the CONFOUNDED model on
# all four evaluation sets.
# NOTE: this overwrites the Y_pred_* variables produced earlier for the
# deconfounded model.
Y_pred_conf_train = np.argmax(
    model_conf.predict(trainVal_X_conf.reshape(-1, 128, 128, 1)), axis=1
)
Y_pred_conf_test = np.argmax(
    model_conf.predict(test_X_conf.reshape(-1, 128, 128, 1)), axis=1
)

Y_pred_deconf_train = np.argmax(
    model_conf.predict(trainVal_X_deconf.reshape(-1, 128, 128, 1)), axis=1
)
Y_pred_deconf_test = np.argmax(
    model_conf.predict(test_X_deconf.reshape(-1, 128, 128, 1)), axis=1
)



#### Test on confounded test set

In [56]:
# Per-class precision/recall/F1 on the confounded test set.
print(classification_report(y_true=test_Y_conf.ravel(), y_pred=Y_pred_conf_test))

              precision    recall  f1-score   support

           0       0.96      0.93      0.95       152
           1       1.00      1.00      1.00       269
           2       0.98      0.66      0.79        95
           3       0.70      1.00      0.82        74
           4       0.94      0.94      0.94       100
           5       0.97      1.00      0.98        90

    accuracy                           0.94       780
   macro avg       0.92      0.92      0.91       780
weighted avg       0.95      0.94      0.94       780



#### Test on deconfounded test set

In [57]:
# Per-class precision/recall/F1 on the deconfounded test set. These labels are
# stored as floats, hence the "0.0"-style class names in the report.
print(classification_report(y_true=test_Y_deconf.ravel(), y_pred=Y_pred_deconf_test))

              precision    recall  f1-score   support

         0.0       0.96      0.88      0.91       152
         1.0       0.93      0.39      0.55       269
         2.0       0.24      0.54      0.33        95
         3.0       0.58      0.46      0.51        74
         4.0       0.79      0.94      0.86       100
         5.0       0.67      1.00      0.80        90

    accuracy                           0.65       780
   macro avg       0.69      0.70      0.66       780
weighted avg       0.77      0.65      0.66       780

