## Simulation on the semi-synthetic background-MNIST classification task

In [1]:
from mechanism_learn import pipeline as mlpipe
import numpy as np
import pandas as pd
from scipy.ndimage import maximum_filter
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

semisyn_data_dir = r"../test_data/semi_synthetic_data/"

def maxPooling_imgArr(img_flatArr, kernel_size, padding = "nearest", flatten = False):
    """Downsample a batch of flattened square images with a strided maximum filter.

    Every row of ``img_flatArr`` is reshaped into a ``side x side`` image, a
    ``kernel_size x kernel_size`` maximum filter is applied to it, and every
    ``kernel_size``-th pixel of the filtered image is kept along both axes.
    The windows are therefore the ones ``scipy.ndimage.maximum_filter``
    places around the retained pixels; windows that reach past the image
    border are completed according to ``padding``.

    Parameters
    ----------
    img_flatArr : array-like of shape (n_imgs, side * side)
        One flattened square image per row.
    kernel_size : int
        Edge length of the maximum-filter window and the sub-sampling stride.
    padding : str or sequence of two str, default "nearest"
        Border mode(s) forwarded to ``scipy.ndimage.maximum_filter``.
    flatten : bool, default False
        If True, each pooled image is flattened back into a row.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_imgs, m, m)`` with ``m = ceil(side / kernel_size)``, or
        ``(n_imgs, m * m)`` when ``flatten`` is True.

    Raises
    ------
    ValueError
        If the input is not 2-D or its rows do not hold a square image.
    """
    img_flatArr = np.asarray(img_flatArr)
    if img_flatArr.ndim != 2:
        raise ValueError(
            "img_flatArr must be 2-D (n_imgs, n_pixels); got shape %s" % (img_flatArr.shape,)
        )
    n_imgs, n_pixels = img_flatArr.shape

    # Round instead of truncating so a float square root that lands a hair
    # below the true integer side cannot yield side - 1.
    img_size = int(round(n_pixels ** 0.5))
    if img_size * img_size != n_pixels:
        raise ValueError(
            "each row must hold a square image; %d pixels is not a perfect square" % n_pixels
        )
    img_arr = img_flatArr.reshape(n_imgs, img_size, img_size)

    # A window of size 1 along the batch axis keeps images independent, so a
    # single filter call replaces the per-image Python loop. The mode given
    # for that axis is never used; it only pads a per-axis mode sequence to
    # the three axes of the batch.
    if isinstance(padding, str):
        modes = padding
    else:
        modes = ["nearest", *padding]
    filtered = maximum_filter(img_arr, size=(1, kernel_size, kernel_size), mode=modes)

    # Copy the strided view so the full-resolution filtered batch can be freed.
    resized_imgs = np.ascontiguousarray(filtered[:, ::kernel_size, ::kernel_size])
    if flatten:
        resized_imgs = resized_imgs.reshape(n_imgs, -1)
    return resized_imgs

### Load datasets

In [2]:
# Confounded training split: max-pooled images (X), plus the Y and Z columns
# that are later passed to the pipeline as cause and mechanism data.
X_train_conf = maxPooling_imgArr(
    np.array(pd.read_csv(semisyn_data_dir + "X_train_conf.csv")),
    kernel_size=3,
    flatten=True,
)
Y_train_conf = np.array(pd.read_csv(semisyn_data_dir + "Y_train_conf.csv")).reshape(-1, 1)
Z_train_conf = np.array(pd.read_csv(semisyn_data_dir + "Z_train_conf.csv")).reshape(-1, 1)

# Non-confounded test split, pooled exactly like the training images.
X_test_unconf = maxPooling_imgArr(
    np.array(pd.read_csv(semisyn_data_dir + "X_test_unconf.csv")),
    kernel_size=3,
    flatten=True,
)
Y_test_unconf = np.array(pd.read_csv(semisyn_data_dir + "Y_test_unconf.csv")).reshape(-1, 1)

# Confounded test split, pooled exactly like the training images.
X_test_conf = maxPooling_imgArr(
    np.array(pd.read_csv(semisyn_data_dir + "X_test_conf.csv")),
    kernel_size=3,
    flatten=True,
)
Y_test_conf = np.array(pd.read_csv(semisyn_data_dir + "Y_test_conf.csv")).reshape(-1, 1)

### Train a deconfounded KNN classifier using mechanism learning

In [None]:
# Set up the mechanism-learning pipeline on the confounded training data:
# Y is supplied as the cause, Z as the mechanism and the pooled images X as
# the effect. One intervention value is used per distinct value in Y.
ml_gmm_pipeline = mlpipe.mechanism_learning_process(
    cause_data=Y_train_conf,
    mechanism_data=Z_train_conf,
    effect_data=X_train_conf,
    intv_values=np.unique(Y_train_conf),
    dist_map=None,
    est_method="histogram",
    n_bins=[0, 0],  # NOTE(review): meaning of a zero bin count is library-defined — confirm
)

# Fit the CW-GMMs; the fitted model is not returned here (return_model=False).
# NOTE(review): random_seed=None presumably leaves the fit unseeded, so reruns
# may not reproduce the recorded numbers exactly — confirm against the library.
ml_gmm_pipeline.cwgmm_fit(
    comp_k=1000,
    max_iter=500,
    tol=1e-3,
    init_method="kmeans++",
    cov_type="diag",
    random_seed=None,
    return_model=False,
    verbose=2,
)

# Resample from the fitted model; the samples are not returned here
# (return_samples=False).
ml_gmm_pipeline.cwgmm_resample(n_samples=10000, return_samples=False)


CW-GMMs fitting:   0%|          | 0/2 [00:00<?, ?model/s]

EM iter:   0%|          | 0/500 [00:00<?, ?it/s]

EM iter:   0%|          | 0/500 [00:00<?, ?it/s]

In [10]:
# Fit a 20-nearest-neighbour classifier through the mechanism-learning pipeline.
deconf_gmm_clf = ml_gmm_pipeline.deconf_model_fit(
    ml_model=KNeighborsClassifier(n_neighbors=20),
)

### Train a deconfounded KNN classifier using the CB-based deconfounding method

In [None]:
# Second pipeline with the same inputs and settings as the mechanism-learning
# one; here the deconfounded data come from CB resampling instead of a CW-GMM.
ml_cb_pipeline = mlpipe.mechanism_learning_process(
    cause_data=Y_train_conf,
    mechanism_data=Z_train_conf,
    effect_data=X_train_conf,
    intv_values=np.unique(Y_train_conf),
    dist_map=None,
    est_method="histogram",
    n_bins=[0, 0],  # NOTE(review): meaning of a zero bin count is library-defined — confirm
)

# return_samples=True hands the resampled features and labels back to the caller.
deconf_X_cb, deconf_Y_cb = ml_cb_pipeline.cb_resample(
    n_samples=10000,
    cb_mode="fast",
    return_samples=True,
)



CB Resampling:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
# Fit a 20-nearest-neighbour classifier through the CB-based pipeline.
deconf_cb_clf = ml_cb_pipeline.deconf_model_fit(
    ml_model=KNeighborsClassifier(n_neighbors=20),
)

### Train a confounded KNN classifier

In [5]:
# Baseline: the same 20-nearest-neighbour classifier trained directly on the
# confounded training data. The (n, 1) label column is flattened to 1-D for fit().
conf_clf = KNeighborsClassifier(n_neighbors=20).fit(X_train_conf, Y_train_conf.ravel())

### Model performance comparison

In [12]:
# NOTE(review): "confonded" in the printed titles below is a typo for
# "confounded"; it is left as-is so the text matches the recorded output.

# Predictions of the three classifiers on the non-confounded test set.
y_pred_gmm_deconf_unconf = deconf_gmm_clf.predict(X_test_unconf)
y_pred_cb_deconf_unconf = deconf_cb_clf.predict(X_test_unconf)
y_pred_conf_unconf = conf_clf.predict(X_test_unconf)

# Predictions of the three classifiers on the confounded test set.
y_pred_gmm_deconf_conf = deconf_gmm_clf.predict(X_test_conf)
y_pred_cb_deconf_conf = deconf_cb_clf.predict(X_test_conf)
y_pred_conf_conf = conf_clf.predict(X_test_conf)

# Each title/report pair is emitted on consecutive lines via sep="\n".
print("Test on the non-confounded test set:")
print(
    "Report of deconfounded model using mechanism learning:",
    classification_report(Y_test_unconf, y_pred_gmm_deconf_unconf, digits=4),
    sep="\n",
)
print("-" * 20)
print(
    "Report of deconfounded model using CB-based method:",
    classification_report(Y_test_unconf, y_pred_cb_deconf_unconf, digits=4),
    sep="\n",
)
print("-" * 20)
print(
    "Report of confonded model:",
    classification_report(Y_test_unconf, y_pred_conf_unconf, digits=4),
    sep="\n",
)

print("*" * 30)
print("Test on the confounded test set:")
print(
    "Report of deconfounded model using mechanism learning:",
    classification_report(Y_test_conf, y_pred_gmm_deconf_conf, digits=4),
    sep="\n",
)
print("-" * 20)
print(
    "Report of deconfounded model using CB-based method:",
    classification_report(Y_test_conf, y_pred_cb_deconf_conf, digits=4),
    sep="\n",
)
print("-" * 20)
print(
    "Report of confonded model:",
    classification_report(Y_test_conf, y_pred_conf_conf, digits=4),
    sep="\n",
)

Test on the non-confounded test set:
Report of deconfounded model using mechanism learning:
              precision    recall  f1-score   support

           1     0.9347    0.9326    0.9336       430
           2     0.9437    0.9455    0.9446       514

    accuracy                         0.9396       944
   macro avg     0.9392    0.9390    0.9391       944
weighted avg     0.9396    0.9396    0.9396       944

--------------------
Report of deconfounded model using CB-based method:
              precision    recall  f1-score   support

           1     0.9366    0.9279    0.9322       430
           2     0.9402    0.9475    0.9438       514

    accuracy                         0.9386       944
   macro avg     0.9384    0.9377    0.9380       944
weighted avg     0.9385    0.9386    0.9385       944

--------------------
Report of confonded model:
              precision    recall  f1-score   support

           1     0.7828    0.6791    0.7273       430
           2     0.7583 