## The experiment on real-world ICH detection using mechanism learning

In [1]:
from mechanism_learn import pipeline as mlpipe
import numpy as np
import pandas as pd
import cv2
import os 
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import warnings
warnings.simplefilter('ignore')
import gc

def img_read(dir_list, img_size):
    """Load images as grayscale, resize them and flatten each one into a row.

    Parameters
    ----------
    dir_list : iterable of str
        Paths of the image files to load; the output rows follow this order.
    img_size : tuple of int
        Target size forwarded unchanged to ``cv2.resize``.

    Returns
    -------
    numpy.ndarray
        ``np.array`` built from the flattened images, one image per entry.

    Raises
    ------
    OSError
        If ``cv2.imread`` cannot read one of the files.
    """
    img_list = []
    for img_path in dir_list:
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        if img is None:
            raise OSError(f"Could not read image file: {img_path!r}")
        img_list.append(cv2.resize(img, img_size).flatten())
    return np.array(img_list)


### Setup the GPU to accelerate the computation

In [2]:
# Report the TensorFlow build and the GPUs visible to it.
print("TensorFlow version:", tf.__version__)
print("Built with CUDA?:", tf.test.is_built_with_cuda())
print("Built with GPU?:", tf.test.is_built_with_gpu_support())
gpu_devices = tf.config.list_physical_devices('GPU')
print("Available GPU device:", gpu_devices)
# Enable on-demand memory allocation on every visible GPU. Iterating (instead of
# indexing [0]) keeps the cell runnable on CPU-only machines, where the list is empty,
# and applies the same setting to all GPUs on multi-GPU machines.
for gpu_device in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)


TensorFlow version: 2.10.1
Built with CUDA?: True
Built with GPU?: True
Available GPU device: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


### ResNet-CNN structure

In [3]:
def resNetCNN_model(input_shape, num_class):
    """Build the small residual CNN classifier used throughout this notebook.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, passed to ``layers.Input``.
    num_class : int
        Number of output units of the final softmax layer.

    Returns
    -------
    keras Model
        Uncompiled model mapping an input image to ``num_class`` softmax scores.
    """
    inputs = layers.Input(shape=input_shape)

    # Main branch: three conv + pooling stages, each pooling with a 2x2 window.
    features = layers.Conv2D(filters=16, kernel_size=(7, 7), padding='same', activation='relu')(inputs)
    features = layers.MaxPooling2D(pool_size=(2, 2), padding='same')(features)

    features = layers.Conv2D(filters=32, kernel_size=(5, 5), padding='same', activation='relu')(features)
    features = layers.MaxPooling2D(pool_size=(2, 2), padding='same')(features)

    features = layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(features)
    features = layers.AveragePooling2D(pool_size=(2, 2), padding='same')(features)

    # Shortcut branch: the raw input pooled with an 8x8 window so it can be added
    # to the main branch (three 2x2 poolings).
    # NOTE(review): the shortcut keeps the input's channel count while the main branch
    # has 64 channels, so the Add presumably relies on broadcasting -- confirm.
    residual = layers.AveragePooling2D(pool_size=(8, 8), padding='same')(inputs)
    features = layers.Add()([features, residual])
    features = layers.Activation('relu')(features)

    # 1x1 convolution down to 16 channels before flattening.
    features = layers.Conv2D(filters=16, kernel_size=(1, 1), padding='same', activation='relu')(features)

    # Fully connected classification head with dropout regularisation.
    hidden = layers.Flatten()(features)
    hidden = layers.Dropout(rate=0.3)(hidden)

    hidden = layers.Dense(units=64, activation='relu')(hidden)
    hidden = layers.Dropout(rate=0.2)(hidden)
    hidden = layers.Dense(units=64, activation='relu')(hidden)
    hidden = layers.Dropout(rate=0.2)(hidden)
    hidden = layers.Dense(units=32, activation='relu')(hidden)

    class_probs = layers.Dense(units=num_class, activation='softmax')(hidden)

    return models.Model(inputs=inputs, outputs=class_probs)

### Load datasets

In [4]:
# Root folder of the ICH dataset. Named `data_dir` so the builtin `dir()` is not
# shadowed for the rest of the notebook session.
data_dir = r"../test_data/ICH_data/"
effect_dir = data_dir + r"ct_clean/"   # CT image files
mediator_dir = data_dir                 # folder holding mediator_embedding.csv
cause_dir = data_dir                    # folder holding the diagnosis table
# Keep only files whose stem is a number: the numeric sort key below would raise
# ValueError on any other entry (e.g. hidden/system files in the folder).
imgs_names = [name for name in os.listdir(effect_dir) if name.split('.')[0].isdecimal()]
# Sort numerically so that e.g. "10.png" follows "9.png" rather than "1.png".
imgs_names = sorted(imgs_names, key=lambda name: int(name.split('.')[0]))

In [5]:
# Effect variable: every CT image, read in numeric file-name order and resized to 128x128.
effect_img_paths = [effect_dir + img_name for img_name in imgs_names]
effect_imgs = img_read(effect_img_paths, (128, 128))

# Cause variable (diagnosis table) and mediator variable (embedding table).
# NOTE(review): the rows of both tables are assumed to follow the same order as
# the sorted image files -- confirm against the dataset.
cause_table = pd.read_csv(cause_dir + "hemorrhage_diagnosis_ct_clean.csv")
mediator_table = pd.read_csv(mediator_dir + "mediator_embedding.csv")

### Cause variable encoding

In [6]:

# Indicator columns of the diagnosis table; the list position is the class code.
# Codes are assigned in this order, so a row flagged in several columns ends up
# with the code of the last matching column. Rows flagged in none stay NaN.
hemorrhage_columns = [
    "No_Hemorrhage",
    "Intraparenchymal",
    "Epidural",
    "Subdural",
    "Intraventricular",
    "Subarachnoid",
]
cause_table["category"] = np.nan
for class_code, indicator_column in enumerate(hemorrhage_columns):
    cause_table.loc[cause_table[indicator_column] == 1, "category"] = class_code


### Mediator variable cleaning

In [7]:
# Per column: True when every value in that mediator column is exactly zero.
mediator_table.eq(0).all()

0    False
1    False
2    False
3     True
4    False
5    False
6    False
7    False
8    False
9    False
dtype: bool

In [8]:
# Class distribution of the encoded cause variable.
category_unique, category_cnt = np.unique(cause_table["category"], return_counts=True)
print(f"Categories: {category_unique}")
print(f"Counts: {category_cnt}")

Categories: [0. 1. 2. 3. 4. 5.]
Counts: [2093   52  171   56   21   18]


In [9]:
# Drop mediator column '3', which the check above showed to be all zeros.
# Rebinding (no inplace) together with errors='ignore' keeps this cell safe to
# re-run: a second execution no longer raises KeyError for the missing column.
mediator_table = mediator_table.drop(columns=['3'], errors='ignore')

### Prepare the final datasets

In [10]:
# A row flagged in none of the indicator columns keeps the NaN placeholder; it would
# be counted as an extra class and break the one-hot encoding later, so fail early.
if cause_table["category"].isna().any():
    raise ValueError("cause_table contains rows without an assigned hemorrhage category")

# Cause labels as an (n_samples, 1) column vector.
cause_category = cause_table["category"].values
cause_category = cause_category.reshape(-1,1)
# NOTE: the misspelled name `mediaor_values` is kept because later cells refer to it.
mediaor_values = mediator_table.values
# Number of distinct classes (nunique ignores NaN).
n_class = cause_table["category"].nunique()

# Flattened image length and the side length of the square image it came from.
X_d = effect_imgs.shape[1]
image_h = int(np.sqrt(X_d))
image_w = int(np.sqrt(X_d))
# The reshapes in later cells assume square images; int(sqrt()) would silently truncate otherwise.
if image_h * image_w != X_d:
    raise ValueError(f"Flattened image length {X_d} is not a perfect square")

### Reduce the dimensionality of the effect variable using PCA

In [11]:
# Fit PCA on the flattened images; a fractional n_components keeps as many
# components as needed to reach that share (95%) of explained variance.
img_pca = PCA(n_components=0.95).fit(effect_imgs)

# Low-dimensional embedding of every image, used as the effect data by the pipelines below.
effect_imgs_lowd_embedding = img_pca.transform(effect_imgs)
reduced_X_d = effect_imgs_lowd_embedding.shape[1]
print("PCA Reduced dimension of effect images:", reduced_X_d)

PCA Reduced dimension of effect images: 753


### Split training and test datasets for confounded data

In [12]:
# Share of all samples held out from training, and the share of that held-out
# part used for validation (the rest becomes the test set).
test_prop = 0.4
val_prop = 0.4

# Stratified split: training set vs. held-out (validation + test) set.
X_train_conf, X_testval_conf, Y_train_conf, Y_testval_conf = train_test_split(
    effect_imgs,
    cause_category,
    test_size=test_prop,
    random_state=42,
    stratify=cause_category,
)

# Stratified split of the held-out part into validation and test sets.
X_val_conf, X_test_conf, Y_val_conf, Y_test_conf = train_test_split(
    X_testval_conf,
    Y_testval_conf,
    test_size=1-val_prop,
    random_state=42,
    stratify=Y_testval_conf,
)

### Deconfounded ResNet-CNN using mechanism learning

In [None]:
# Set up the mechanism learning pipeline on the cause labels, the mediator
# embedding and the PCA embedding of the images (effect).
ml_gmm_pipeline = mlpipe.mechanism_learning_process(
    cause_data = cause_category,
    mechanism_data = mediaor_values,
    effect_data = effect_imgs_lowd_embedding,
    intv_values = np.unique(cause_category),
    dist_map = None,
    est_method = "kde",
    bandwidth = "scott",
)

# Fit the CWGMM model only; sampling from it happens in the next cell.
# comp_k holds one component count per intervention category, chosen smaller for
# the rare classes because of the class imbalance.
# NOTE(review): random_seed = None presumably makes the fit non-reproducible
# between runs -- confirm and consider fixing a seed.
ml_gmm_pipeline.cwgmm_fit(
    comp_k = [400, 10, 55, 11, 4, 3],
    max_iter = 500,
    tol = 1e-5,
    init_method = "kmeans++",
    cov_type = "diag",
    random_seed = None,
    return_model = False,
    verbose = 2,
)


#### Sample the deconfounded data (i.i.d) to form the deconfounded training, validation and test datasets.

In [14]:
def _sample_deconf_images(n_samples):
    """Draw deconfounded samples from the fitted CWGMM and decode them to images.

    Parameters
    ----------
    n_samples : sequence of int
        Number of samples per intervention category, forwarded to ``cwgmm_resample``.

    Returns
    -------
    tuple
        ``(X, Y)``: ``X`` reshaped to ``(-1, image_h, image_w, 1)`` and ``Y`` exactly
        as returned by ``cwgmm_resample``.
    """
    X_embedding, Y = ml_gmm_pipeline.cwgmm_resample(n_samples=n_samples, return_samples = True)
    # Inverse transform the sampled image embedding to the original pixel space
    X_imgs = img_pca.inverse_transform(X_embedding)
    # Clip the X values to be in the range of [0, 255] as the original images
    X_imgs = np.clip(X_imgs, 0, 255.0)
    # Reshape the data to the original image shape
    return X_imgs.reshape(-1,image_h,image_w,1), Y


# Balanced training set: 5000 samples per class. Validation and test sets mirror
# the per-class counts of the confounded validation and test splits.
n_train_sample = [5000] * len(np.unique(cause_category))
n_val_sample = np.unique(Y_val_conf, return_counts=True)[1]
n_test_sample = np.unique(Y_test_conf, return_counts=True)[1]

# Sampling order (train, validation, test) is kept as in the original cell.
X_train_deconf_gmm, Y_train_deconf_gmm = _sample_deconf_images(n_train_sample)
X_val_deconf, Y_val_deconf = _sample_deconf_images(n_val_sample)
X_test_deconf, Y_test_deconf = _sample_deconf_images(n_test_sample)

#### Compile and train the mechanism learning-based deconfounded ResNet-CNN model

In [14]:
# One-hot encode the labels of the deconfounded training and validation samples.
Y_train_deconf_gmm_oh = to_categorical(Y_train_deconf_gmm.reshape(-1), num_classes=n_class)
Y_val_deconf_oh = to_categorical(Y_val_deconf.reshape(-1), num_classes=n_class)

# Build and compile the classifier for single-channel images.
ResNet_gmm_deconf = resNetCNN_model((image_h, image_w, 1), n_class)
ResNet_gmm_deconf.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Stop once validation accuracy has not improved for 10 epochs and restore the best weights.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=0,
    patience=10,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

ResNet_gmm_deconf.fit(
    X_train_deconf_gmm,
    Y_train_deconf_gmm_oh,
    validation_data=(X_val_deconf, Y_val_deconf_oh),
    epochs=75,
    batch_size=4,
    shuffle=True,
    callbacks=[early_stopping],
)

Epoch 1/75
Epoch 2/75
Epoch 3/75
Epoch 4/75
Epoch 5/75
Epoch 6/75
Epoch 7/75
Epoch 8/75
Epoch 9/75
Epoch 10/75
Epoch 11/75
Epoch 12/75
Epoch 13/75
Epoch 14/75
Epoch 15/75
Epoch 16/75
Epoch 17/75
Epoch 18/75
Epoch 19/75
Epoch 20/75
Epoch 21/75
Epoch 22/75
Epoch 23/75
Epoch 24/75
Epoch 25/75
Epoch 26/75
Epoch 27/75
Epoch 28/75
Epoch 29/75
Epoch 30/75


<keras.callbacks.History at 0x25390a69d90>

#### Predict on synthetic "non-confounded" and confounded test dataset.

In [15]:
# Predicted class = index of the largest softmax score.
# First the real (confounded) test images, then the sampled deconfounded test images.
Y_pred_deconfModel_confTest_gmm = np.argmax(
    ResNet_gmm_deconf.predict(X_test_conf.reshape(-1, image_h, image_w, 1)),
    axis=1,
)
Y_pred_deconfModel_deconfTest_gmm = np.argmax(
    ResNet_gmm_deconf.predict(X_test_deconf),
    axis=1,
)



#### Evaluate the model performance on confounded and synthetic "non-confounded" test dataset

In [16]:
# Mechanism-learning model evaluated on the confounded test split.
print(classification_report(
    y_true=Y_test_conf.reshape(-1),
    y_pred=Y_pred_deconfModel_confTest_gmm,
    digits=4,
))

              precision    recall  f1-score   support

         0.0     0.9790    0.9264    0.9520       503
         1.0     0.7273    0.6154    0.6667        13
         2.0     0.5429    0.9268    0.6847        41
         3.0     0.6667    0.7692    0.7143        13
         4.0     1.0000    0.4000    0.5714         5
         5.0     0.8000    1.0000    0.8889         4

    accuracy                         0.9119       579
   macro avg     0.7860    0.7730    0.7463       579
weighted avg     0.9344    0.9119    0.9176       579



In [17]:
# Mechanism-learning model evaluated on the sampled deconfounded test set.
print(classification_report(
    y_true=Y_test_deconf.reshape(-1),
    y_pred=Y_pred_deconfModel_deconfTest_gmm,
    digits=4,
))

              precision    recall  f1-score   support

         0.0     0.9959    0.9742    0.9849       503
         1.0     0.6842    1.0000    0.8125        13
         2.0     1.0000    0.9512    0.9750        41
         3.0     0.7647    1.0000    0.8667        13
         4.0     1.0000    1.0000    1.0000         5
         5.0     0.5714    1.0000    0.7273         4

    accuracy                         0.9741       579
   macro avg     0.8360    0.9876    0.8944       579
weighted avg     0.9811    0.9741    0.9760       579



#### Clear the GPU memory

In [None]:
# Reset the Keras global state (K is tensorflow.keras.backend), then run Python
# garbage collection before the next model is built.
K.clear_session()
gc.collect()

2103

### Deconfounded ResNet-CNN using CB-based deconfounding method

In [None]:
# Set up a second mechanism learning pipeline, on the same cause / mediator /
# effect data, for the CB-based deconfounding method.
ml_cb_pipeline = mlpipe.mechanism_learning_process(
    cause_data = cause_category,
    mechanism_data = mediaor_values,
    effect_data = effect_imgs_lowd_embedding,
    intv_values = np.unique(cause_category),
    dist_map = None,
    est_method = "kde",
    bandwidth = "scott",
)

# Resample the data using the front-door CB; the samples are still PCA embeddings here.
X_train_deconf_cb, Y_train_deconf_cb = ml_cb_pipeline.cb_resample(
    n_samples = 5000,
    cb_mode = "fast",
    return_samples = True,
)

In [None]:
# Decode the CB-resampled embeddings into images in one chain:
# inverse PCA to pixel space -> clip to the [0, 255] range of the original
# images -> reshape to (n, image_h, image_w, 1).
# This overwrites X_train_deconf_cb, so re-run the resampling cell before re-running this one.
X_train_deconf_cb = np.clip(
    img_pca.inverse_transform(X_train_deconf_cb), 0, 255.0
).reshape(-1,image_h,image_w,1)

#### Compile and train the CB-based deconfounded ResNet-CNN model

In [21]:
# One-hot encode the labels of the CB-resampled training set and of the
# deconfounded validation set.
Y_train_deconf_cb_oh = to_categorical(Y_train_deconf_cb.reshape(-1), num_classes=n_class)
Y_val_deconf_oh = to_categorical(Y_val_deconf.reshape(-1), num_classes=n_class)

# Same architecture and training configuration as the other models, for a like-for-like comparison.
ResNet_cb_deconf = resNetCNN_model((image_h, image_w, 1), n_class)
ResNet_cb_deconf.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Stop once validation accuracy has not improved for 10 epochs and restore the best weights.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=0,
    patience=10,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

ResNet_cb_deconf.fit(
    X_train_deconf_cb,
    Y_train_deconf_cb_oh,
    validation_data=(X_val_deconf, Y_val_deconf_oh),
    epochs=75,
    batch_size=4,
    shuffle=True,
    callbacks=[early_stopping],
)

Epoch 1/75
Epoch 2/75
Epoch 3/75
Epoch 4/75
Epoch 5/75
Epoch 6/75
Epoch 7/75
Epoch 8/75
Epoch 9/75
Epoch 10/75
Epoch 11/75
Epoch 12/75
Epoch 13/75


<keras.callbacks.History at 0x25397bcf490>

#### Predict on synthetic "non-confounded" and confounded test dataset.

In [22]:
# Predicted class = index of the largest softmax score.
# First the real (confounded) test images, then the sampled deconfounded test images.
Y_pred_deconfModel_confTest_cb = np.argmax(
    ResNet_cb_deconf.predict(X_test_conf.reshape(-1, image_h, image_w, 1)),
    axis=1,
)
Y_pred_deconfModel_deconfTest_cb = np.argmax(
    ResNet_cb_deconf.predict(X_test_deconf.reshape(-1, image_h, image_w, 1)),
    axis=1,
)



#### Evaluate the model performance on confounded and synthetic "non-confounded" test dataset

In [23]:
# CB-based model evaluated on the confounded test split.
print(classification_report(
    y_true=Y_test_conf.reshape(-1),
    y_pred=Y_pred_deconfModel_confTest_cb,
    digits=4,
))

              precision    recall  f1-score   support

         0.0     0.9886    0.8588    0.9191       503
         1.0     0.4762    0.7692    0.5882        13
         2.0     0.4872    0.9268    0.6387        41
         3.0     0.4194    1.0000    0.5909        13
         4.0     0.5556    1.0000    0.7143         5
         5.0     1.0000    0.7500    0.8571         4

    accuracy                         0.8653       579
   macro avg     0.6545    0.8842    0.7181       579
weighted avg     0.9251    0.8653    0.8823       579



In [24]:
# CB-based model evaluated on the sampled deconfounded test set.
print(classification_report(
    y_true=Y_test_deconf.reshape(-1),
    y_pred=Y_pred_deconfModel_deconfTest_cb,
    digits=4,
))

              precision    recall  f1-score   support

         0.0     0.9332    0.9722    0.9523       503
         1.0     1.0000    0.1538    0.2667        13
         2.0     0.7500    0.7317    0.7407        41
         3.0     0.6250    0.3846    0.4762        13
         4.0     0.5000    0.4000    0.4444         5
         5.0     0.0000    0.0000    0.0000         4

    accuracy                         0.9119       579
   macro avg     0.6347    0.4404    0.4801       579
weighted avg     0.9046    0.9119    0.9003       579



#### Clear the GPU memory

In [None]:
# Reset the Keras global state (K is tensorflow.keras.backend), then run Python
# garbage collection before the next model is built.
K.clear_session()
gc.collect()

2203

### Confounded model
#### Resample the confounded training dataset using random oversampling

In [26]:
# Random oversampling of the confounded training set: draw n_train_sample[class_i]
# rows with replacement from every class so the baseline model sees the same
# balanced class sizes as the deconfounded models.
# A seeded generator makes the resampled set reproducible between runs.
oversample_rng = np.random.default_rng(42)

# Collect the sampled row indices of all classes first and index once, instead of
# growing the (large) image matrix with np.vstack on every iteration.
resampled_idx = np.concatenate([
    oversample_rng.choice(
        np.where(Y_train_conf == class_i)[0],   # row indices belonging to class_i
        n_train_sample[class_i],
        replace=True,
    )
    for class_i in range(n_class)
])

# float64 matches the dtype the original vstack-onto-np.empty construction produced.
X_train_conf_resampled = X_train_conf[resampled_idx].astype(np.float64)
Y_train_conf_resampled = Y_train_conf[resampled_idx].astype(np.float64)

# NOTE: this overwrites the original training split; re-run the train/test split
# cell before re-running this cell.
X_train_conf = X_train_conf_resampled.reshape(-1, image_h, image_w, 1)
Y_train_conf = Y_train_conf_resampled.reshape(-1, 1)
    

#### Compile and train the ResNet-CNN model

In [27]:
# One-hot encode the labels of the oversampled training set and the confounded validation set.
Y_train_conf_oh = to_categorical(Y_train_conf, num_classes=n_class)
Y_val_conf_oh = to_categorical(Y_val_conf, num_classes=n_class)

# Baseline model: same architecture and training configuration, trained on confounded data.
ResNet_conf = resNetCNN_model((image_h, image_w, 1), n_class)
ResNet_conf.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Stop once validation accuracy has not improved for 10 epochs and restore the best weights.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=0,
    patience=10,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

ResNet_conf.fit(
    X_train_conf.reshape(-1, image_h, image_w, 1),
    Y_train_conf_oh,
    validation_data=(X_val_conf.reshape(-1, image_h, image_w, 1), Y_val_conf_oh),
    epochs=75,
    batch_size=4,
    shuffle=True,
    callbacks=[early_stopping],
)

Epoch 1/75
Epoch 2/75
Epoch 3/75
Epoch 4/75
Epoch 5/75
Epoch 6/75
Epoch 7/75
Epoch 8/75
Epoch 9/75
Epoch 10/75
Epoch 11/75
Epoch 12/75
Epoch 13/75
Epoch 14/75
Epoch 15/75
Epoch 16/75
Epoch 17/75
Epoch 18/75
Epoch 19/75
Epoch 20/75
Epoch 21/75
Epoch 22/75
Epoch 23/75


<keras.callbacks.History at 0x253909a3c40>

#### Predict on synthetic "non-confounded" and confounded test dataset.

In [28]:
# Predicted class = index of the largest softmax score.
# First the real (confounded) test images, then the sampled deconfounded test images.
Y_pred_confModel_confTest = np.argmax(
    ResNet_conf.predict(X_test_conf.reshape(-1, image_h, image_w, 1)),
    axis=1,
)
Y_pred_confModel_deconfTest = np.argmax(
    ResNet_conf.predict(X_test_deconf.reshape(-1, image_h, image_w, 1)),
    axis=1,
)



#### Evaluate the model performance on confounded and synthetic "non-confounded" test dataset

In [29]:
# Confounded baseline model evaluated on the confounded test split.
print(classification_report(
    y_true=Y_test_conf.reshape(-1),
    y_pred=Y_pred_confModel_confTest,
    digits=4,
))

              precision    recall  f1-score   support

         0.0     0.9592    0.9821    0.9705       503
         1.0     0.4545    0.3846    0.4167        13
         2.0     0.9667    0.7073    0.8169        41
         3.0     0.8000    0.9231    0.8571        13
         4.0     0.5714    0.8000    0.6667         5
         5.0     0.0000    0.0000    0.0000         4

    accuracy                         0.9396       579
   macro avg     0.6253    0.6329    0.6213       579
weighted avg     0.9349    0.9396    0.9353       579



In [34]:
# Confounded baseline model evaluated on the sampled deconfounded test set.
print(classification_report(
    y_true=Y_test_deconf.reshape(-1),
    y_pred=Y_pred_confModel_deconfTest,
    digits=4,
))

              precision    recall  f1-score   support

         0.0     0.8822    0.9980    0.9366       503
         1.0     0.7500    0.2308    0.3529        13
         2.0     1.0000    0.0976    0.1778        41
         3.0     1.0000    0.1538    0.2667        13
         4.0     0.0000    0.0000    0.0000         5
         5.0     0.0000    0.0000    0.0000         4

    accuracy                         0.8826       579
   macro avg     0.6054    0.2467    0.2890       579
weighted avg     0.8765    0.8826    0.8401       579



#### Clear the GPU memory

In [None]:
# Reset the Keras global state (K is tensorflow.keras.backend), then run Python
# garbage collection.
K.clear_session()
gc.collect()

1441