## Simulation on the semi-synthetic background-MNIST classification task

In [1]:
from mechanism_learn import pipeline as mlpipe
import numpy as np
import pandas as pd
from scipy.ndimage import maximum_filter
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier


def maxPooling_imgArr(img_flatArr, kernel_size, padding = "nearest", flatten = False):
    n_imgs = img_flatArr.shape[0]
    img_size = int(img_flatArr.shape[1]**0.5)
    img_arr = img_flatArr.reshape(n_imgs, img_size, img_size)
    resized_imgs = []
    for i in range(n_imgs):
        resized_imgs.append(maximum_filter(img_arr[i], size=kernel_size, mode=padding)[::kernel_size, ::kernel_size])
    resized_imgs = np.array(resized_imgs)
    if flatten:
        resized_imgs = resized_imgs.reshape(n_imgs, -1)
    return resized_imgs

### Load datasets

In [None]:
semisyn_data_dir = r"../../test_data/semi_synthetic_data/"

X_train_conf = pd.read_csv(semisyn_data_dir + "X_train_conf.csv").to_numpy()
Y_train_conf = pd.read_csv(semisyn_data_dir + "Y_train_conf.csv").to_numpy().ravel()
Z_train_conf = pd.read_csv(semisyn_data_dir + "Z_train_conf.csv").to_numpy()
X_train_conf = maxPooling_imgArr(X_train_conf, kernel_size=3, flatten=True)

X_train_unconf = pd.read_csv(semisyn_data_dir + "X_train_unconf.csv").to_numpy()
Y_train_unconf = pd.read_csv(semisyn_data_dir + "Y_train_unconf.csv").to_numpy().ravel()
X_train_unconf = maxPooling_imgArr(X_train_unconf, kernel_size=3, flatten=True)

X_test_unconf = pd.read_csv(semisyn_data_dir + "X_test_unconf.csv").to_numpy()
Y_test_unconf = pd.read_csv(semisyn_data_dir + "Y_test_unconf.csv").to_numpy().ravel()
X_test_unconf = maxPooling_imgArr(X_test_unconf, kernel_size=3, flatten=True)

X_test_conf = pd.read_csv(semisyn_data_dir + "X_test_conf.csv").to_numpy()
Y_test_conf = pd.read_csv(semisyn_data_dir + "Y_test_conf.csv").to_numpy().ravel()
X_test_conf = maxPooling_imgArr(X_test_conf, kernel_size=3, flatten=True)

### Parameters

In [3]:
# Parameters for resampling
n_samples = [(Y_train_conf == 1).sum()*10, (Y_train_conf == 2).sum()*10]
# Parameters for CWGMM
comp_k = 300
max_iter = 500
cov_reg = 1e-3
min_variance_value = 2e-3
tol = 1e-2
cov_type = "diag"
# Parameters for weights estimation
est_method = "histogram"
n_bins = [0, 0]

### Train a deconfounded KNN using mechanism learning

In [4]:
ml_gmm_pipeline = mlpipe.mechanism_learning_process(cause_data = Y_train_conf,
                                                    mechanism_data = Z_train_conf, 
                                                    effect_data = X_train_conf, 
                                                    intv_values = np.unique(Y_train_conf), 
                                                    est_method = est_method, 
                                                    n_bins = n_bins
                                                    )
ml_gmm_pipeline.cwgmm_fit(comp_k = comp_k,
                          max_iter = max_iter,
                          tol = tol,
                          cov_type = cov_type,
                          cov_reg = cov_reg,
                          min_variance_value = min_variance_value,
                          return_model = False)
ml_gmm_pipeline.cwgmm_resample(n_samples = n_samples,
                               return_samples = True)

CW-GMMs fitting:   0%|          | 0/2 [00:00<?, ?model/s]

EM iter:   0%|          | 0/500 [00:00<?, ?it/s]

EM iter:   0%|          | 0/500 [00:00<?, ?it/s]

(array([[55.00956498, 53.96200299, 54.69251438, ..., 54.02702136,
         52.67984517, 54.22774055],
        [16.4056427 , 10.8624333 , 21.13862286, ...,  8.09794241,
         10.73080455, 13.05492523],
        [27.75727001, 27.12058413, 24.56626347, ..., 24.12716242,
         25.47274132, 25.20865076],
        ...,
        [14.21128874, 15.04580393, 22.45451987, ..., 18.70757964,
         20.50699405, 18.78499743],
        [66.43827139, 65.15617884, 60.77913206, ..., 68.36368225,
         67.79030627, 68.05524449],
        [13.0358062 , 15.97460717, 18.26924429, ..., 14.45692635,
         15.89877306, 12.98984422]]),
 array([[2],
        [2],
        [2],
        ...,
        [2],
        [2],
        [2]], dtype=int64))

In [5]:
deconf_gmm_clf = ml_gmm_pipeline.deconf_model_fit(ml_model = KNeighborsClassifier(n_neighbors = 3))

### Train a deconfounded KNN classifier using CB-based deconfounding method

In [6]:
ml_cb_pipeline = mlpipe.mechanism_learning_process(cause_data = Y_train_conf,
                                                    mechanism_data = Z_train_conf, 
                                                    effect_data = X_train_conf, 
                                                    intv_values = np.unique(Y_train_conf), 
                                                    est_method = est_method, 
                                                    n_bins = n_bins)
ml_cb_pipeline.cb_resample(n_samples = n_samples,
                            return_samples = False,
                            verbose = 1)


CB Resampling:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
deconf_cb_clf = ml_cb_pipeline.deconf_model_fit(ml_model = KNeighborsClassifier(n_neighbors = 3))

### Train confounded and unconfounded KNN classifiers

In [None]:
conf_clf = KNeighborsClassifier(n_neighbors = 3)
conf_clf = conf_clf.fit(X_train_conf, Y_train_conf)

unconf_clf = KNeighborsClassifier(n_neighbors = 3)
unconf_clf.fit(X_train_unconf, Y_train_unconf)

### Model performance comparison

In [10]:
print("Test on the non-confounded test set:")

y_pred_gmm_deconf_unconf = deconf_gmm_clf.predict(X_test_unconf)
print("Report of deconfounded model using mechanism learning:")
print(classification_report(Y_test_unconf, y_pred_gmm_deconf_unconf, digits=4))
print("-"*20)
y_pred_cb_deconf_unconf = deconf_cb_clf.predict(X_test_unconf)
print("Report of deconfounded model using CB-based method:")
print(classification_report(Y_test_unconf, y_pred_cb_deconf_unconf, digits=4))
print("-"*20)
y_pred_conf_unconf = conf_clf.predict(X_test_unconf)
print("Report of confonded model:")
print(classification_report(Y_test_unconf, y_pred_conf_unconf, digits=4))
print("-"*20)
y_pred_unconf_unconf = unconf_clf.predict(X_test_unconf)
print("Report of unconfounded model:")
print(classification_report(Y_test_unconf, y_pred_unconf_unconf, digits=4))

print("*"*30)
print("Test on the confounded test set:")

y_pred_gmm_deconf_conf = deconf_gmm_clf.predict(X_test_conf)
print("Report of deconfounded model using mechanism learning:")
print(classification_report(Y_test_conf, y_pred_gmm_deconf_conf, digits=4))
print("-"*20)
y_pred_cb_deconf_conf = deconf_cb_clf.predict(X_test_conf)
print("Report of deconfounded model using CB-based method:")
print(classification_report(Y_test_conf, y_pred_cb_deconf_conf, digits=4))
print("-"*20)
y_pred_conf_conf = conf_clf.predict(X_test_conf)
print("Report of confonded model:")
print(classification_report(Y_test_conf, y_pred_conf_conf, digits=4))
print("-"*20)
y_pred_unconf_conf = unconf_clf.predict(X_test_conf)
print("Report of unconfounded model:")
print(classification_report(Y_test_conf, y_pred_unconf_conf, digits=4))

Test on the non-confounded test set:
Report of deconfounded model using mechanism learning:
              precision    recall  f1-score   support

           1     0.9485    0.9132    0.9305      1129
           2     0.9162    0.9503    0.9329      1127

    accuracy                         0.9317      2256
   macro avg     0.9323    0.9318    0.9317      2256
weighted avg     0.9323    0.9317    0.9317      2256

--------------------
Report of deconfounded model using CB-based method:
              precision    recall  f1-score   support

           1     0.9445    0.9043    0.9240      1129
           2     0.9081    0.9468    0.9270      1127

    accuracy                         0.9255      2256
   macro avg     0.9263    0.9256    0.9255      2256
weighted avg     0.9263    0.9255    0.9255      2256

--------------------
Report of confonded model:
              precision    recall  f1-score   support

           1     0.7556    0.6572    0.7030      1129
           2     0.6962 