In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np

from utils.df_loader import (
    load_adult_df,
    load_compas_df,
    load_german_df,
    load_diabetes_df,
    load_breast_cancer_df,
)
from sklearn.model_selection import train_test_split
from utils.preprocessing import preprocess_df
from utils.models import (
    train_three_models,
    evaluation_test,
    save_three_models,
    load_three_models,
)
import utils.cf_proto as util_cf_proto
import utils.dice as util_dice
import utils.gs as util_gs
import utils.watcher as util_watcher
import utils.print as print_f


from utils.save import save_result_as_csv

### Seeding section. The original note reads as: the global seeding calls below
### are only wanted for the alibi runs, and are left commented out when
### generating the GS and DiCE results -- TODO confirm this reading. ####


# Shared seed; the main loop passes it to train_test_split as random_state.
seed = 123
# tf.random.set_seed(seed)
# np.random.seed(seed)


In [2]:
# Experiment switches. The "(1)" / "(2)" markers appear to record the value used
# in two separate passes over this notebook -- presumably (1) is the alibi pass
# and (2) is the GS/DiCE pass; TODO confirm.
RUN_ALIBI = False # (1) True, (2) False. True runs Watcher/Prototype; False runs GS/DiCE.
TRAIN_MODEL = False # (1) True, (2) False. True trains and saves the models before they are loaded.
num_instances = 20 # (1) & (2): 20. Forwarded to every generate_*_result call.
num_cf_per_instance = 1 # NOTE(review): the original note says 5 for (1) & (2) but the value is 1 -- confirm.

In [3]:
# NOTE: to keep results consistent, all experiments were run on a single machine.
# TODO: record the exact hardware and OS here (the original note mentions
# "1070", "3090" and "M1" GPUs without saying which one was actually used).

# TODO: add the GitHub references (tf, alibi) this note was meant to point to.
# TensorFlow is switched to its TF1-style, non-eager behavior only when the
# alibi-based algorithms are requested; the whole block is skipped otherwise.
if RUN_ALIBI:
    tf.get_logger().setLevel(40)  # 40 is logging.ERROR: hide TF info/warning logs
    tf.compat.v1.disable_v2_behavior()
    tf.keras.backend.clear_session()
    tf.compat.v1.disable_eager_execution()
    #############################################

    # Turn off pandas' chained-assignment (SettingWithCopy) check.
    pd.options.mode.chained_assignment = None

    print("TF version: ", tf.__version__)
    print("Eager execution enabled: ", tf.executing_eagerly())  # expected: False

In [140]:
#dataset_name

In [5]:
# Report the TensorFlow version and execution mode in effect for this run.
# Eager execution is reported as False only when the RUN_ALIBI setup has
# disabled it; with RUN_ALIBI = False it stays True.
tf_version = tf.__version__
eager_enabled = tf.executing_eagerly()
print("TF version: ", tf_version)
print("Eager execution enabled: ", eager_enabled)

TF version:  2.0.0
Eager execution enabled:  True


In [11]:
# NOTE(review): tensorflow is already imported as `tf` in the first cell, so this
# re-import is redundant; the cell only re-prints the installed version.
import tensorflow as tf
print(tf.__version__)

2.0.0


In [141]:
#### Select dataset ####

# Dispatch table: supported dataset name -> loader handed to preprocess_df.
DATASET_LOADERS = {
    "adult": load_adult_df,
    "german": load_german_df,
    "compas": load_compas_df,
    "diabetes": load_diabetes_df,
    "breast_cancer": load_breast_cancer_df,
}

# Datasets for which the numerical-only algorithms (Watcher / GS) are also run.
NUMERICAL_ONLY_DATASETS = ("diabetes", "breast_cancer")

# Un-comment the datasets to process in this run. For every selected dataset the
# loop: preprocesses it, makes a seeded 80/20 split, optionally trains and saves
# the three models, loads them, prints test accuracy, then generates and saves
# counterfactuals with the algorithms selected by RUN_ALIBI.
for dataset_name in [
    # "adult",
    # "german",
    "compas",
    # "diabetes",
    # "breast_cancer",
]:
    print(f"Dataset Name: [{dataset_name}]")

    dataset_loading_fn = DATASET_LOADERS.get(dataset_name)
    if dataset_loading_fn is None:
        # ValueError is a subclass of Exception, so any handler written for the
        # previous bare Exception still catches it.
        raise ValueError(
            f"Unsupported dataset: {dataset_name!r}; "
            f"expected one of {sorted(DATASET_LOADERS)}"
        )

    df_info = preprocess_df(dataset_loading_fn)

    # Seeded split so the same rows land in train/test on every run.
    train_df, test_df = train_test_split(
        df_info.dummy_df, train_size=0.8, random_state=seed, shuffle=True
    )
    X_train = np.array(train_df[df_info.ohe_feature_names])
    y_train = np.array(train_df[df_info.target_name])
    X_test = np.array(test_df[df_info.ohe_feature_names])
    y_test = np.array(test_df[df_info.target_name])

    if TRAIN_MODEL:
        ## Train models.
        models = train_three_models(X_train, y_train)
        ## Save models.
        save_three_models(models, dataset_name)

    ### Load models (done unconditionally, i.e. also right after training).
    models = load_three_models(X_train.shape[-1], dataset_name)

    ### Print out accuracy on testset.
    evaluation_test(models, X_test, y_test)

    if dataset_name in NUMERICAL_ONLY_DATASETS:
        # Watcher and GS can only run for datasets containing numerical data
        # only, so they are restricted to these datasets.
        if RUN_ALIBI:
            print_f.print_block(title="Counterfactual Algorithm", content="Watcher")
            results = util_watcher.generate_watcher_result(
                df_info,
                train_df,
                models,
                num_instances,
                num_cf_per_instance,
                X_train,
                X_test,
                y_test,
                max_iters=1000,
                models_to_run=["dt", "rfc", "nn"],
                output_int=True,
            )
            result_dfs = util_watcher.process_result(results, df_info)
            save_result_as_csv("watcher", dataset_name, result_dfs)

        else:
            print_f.print_block(title="Counterfactual Algorithm", content="GS")
            # NOTE(review): the meaning of the trailing positional 2000 is not
            # visible here -- check generate_gs_result's signature.
            results = util_gs.generate_gs_result(
                df_info, test_df, models, num_instances, num_cf_per_instance, 2000
            )
            result_dfs = util_gs.process_results(df_info, results)
            save_result_as_csv("GS", dataset_name, result_dfs)

    # Prototype (alibi) or DiCE is run for every selected dataset.
    if RUN_ALIBI:
        print_f.print_block(title="Counterfactual Algorithm", content="Prototype")
        results = util_cf_proto.generate_cf_proto_result(
            df_info,
            train_df,
            models,
            num_instances,
            num_cf_per_instance,
            X_train,
            X_test,
            y_test,
            max_iters=1000,
            models_to_run=["dt", "rfc", "nn"],
            output_int=True,
        )
        result_dfs = util_cf_proto.process_result(results, df_info)
        save_result_as_csv("proto", dataset_name, result_dfs)

    else:
        print_f.print_block(title="Counterfactual Algorithm", content="DiCE")
        results = util_dice.generate_dice_result(
            df_info,
            test_df,
            models,
            num_instances,
            num_cf_per_instance,
            sample_size=50,
            models_to_run=["dt", "rfc", "nn"],
        )
        result_dfs = util_dice.process_results(df_info, results)
        save_result_as_csv("dice", dataset_name, result_dfs)


Dataset Name: [german]
DT: [0.6800] | RF [0.7850] | NN [0.7700]

| DiCE 
Finding counterfactual for dt
instance 0
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.02it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 1
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.69it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 2





CF 0


100%|██████████| 1/1 [00:00<00:00,  6.69it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 3
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.90it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 4
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.70it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 5
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.89it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 6
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.52it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 7
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.72it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 8
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.68it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 9
CF 0



100%|██████████| 1/1 [00:00<00:00,  4.31it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 10
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.92it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 11
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.98it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 12
CF 0


100%|██████████| 1/1 [00:00<00:00,  9.06it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 13
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.16it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 14
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.90it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 15
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 16
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.41it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 17
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.66it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 18
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.84it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 19
CF 0


100%|██████████| 1/1 [00:00<00:00,  8.75it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
Finding counterfactual for rfc
instance 0
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.35it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 1
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.21it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 2





CF 0


100%|██████████| 1/1 [00:00<00:00,  6.24it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 3
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.53it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 4
CF 0


100%|██████████| 1/1 [00:00<00:00,  5.25it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 5
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.90it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 6
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 7
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.80it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 8
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.97it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 9
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.22it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 10
CF 0


100%|██████████| 1/1 [00:00<00:00,  5.23it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 11
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.50it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 12
CF 0


100%|██████████| 1/1 [00:00<00:00,  2.65it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 13
CF 0


100%|██████████| 1/1 [00:00<00:00,  5.48it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 14
CF 0


100%|██████████| 1/1 [00:00<00:00,  3.86it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 15
CF 0


100%|██████████| 1/1 [00:00<00:00,  3.98it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 16
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.64it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 17
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.68it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 18
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.29it/s]


Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7
instance 19
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.30it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





Finding counterfactual for nn
instance 0
CF 0


  0%|          | 0/1 [00:00<?, ?it/s]



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



100%|██████████| 1/1 [00:00<00:00,  6.02it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7


100%|██████████| 1/1 [00:00<00:00,  6.00it/s]


instance 1
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.29it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 2
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.28it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 3
CF 0


100%|██████████| 1/1 [00:00<00:00,  7.14it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
cp7





instance 4
CF 0


  0%|          | 0/1 [00:00<?, ?it/s]


KeyboardInterrupt: 