In [1]:
import random

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

from utils.df_loader import load_adult_df, load_compas_df, load_german_df
from utils.dice import generate_dice_result, process_results
from utils.models import train_three_models, evaluation_test, save_three_models, load_three_models
from utils.preprocessing import preprocess_df
from utils.save import save_result_as_csv

# Disable pandas' SettingWithCopyWarning (chained-assignment check).
pd.options.mode.chained_assignment = None

print('TF version: ', tf.__version__)
# TF 2.x runs eagerly by default; the recorded output of this cell shows True.
print('Eager execution enabled: ', tf.executing_eagerly())


# Seed the Python, NumPy and TensorFlow RNGs so Restart & Run All is reproducible.
seed = 123
tf.random.set_seed(seed)
np.random.seed(seed)
random.seed(seed)


TF version:  2.4.0-rc0
Eager execution enabled:  True


In [2]:
# Dataset to run the experiment on; must be one of the keys of DATASET_LOADERS.
dataset_name = 'adult'

# Maps each supported dataset name to the function that loads its raw DataFrame.
DATASET_LOADERS = {
    'adult': load_adult_df,
    'german': load_german_df,
    'compas': load_compas_df,
}

dataset_loading_fn = DATASET_LOADERS.get(dataset_name)
if dataset_loading_fn is None:
    # ValueError subclasses Exception, so any handler for the old bare Exception still works.
    raise ValueError(
        f"Unsupported dataset {dataset_name!r}; expected one of {sorted(DATASET_LOADERS)}"
    )

In [3]:
# Load and preprocess the selected dataset. Later cells use the returned object's
# `dummy_df`, `ohe_feature_names` and `target_name` attributes.
df_info = preprocess_df(dataset_loading_fn)

In [4]:
# Shuffle and split `df_info.dummy_df` 80/20 into train/test rows; the fixed
# `seed` makes the split reproducible across runs.
train_df, test_df = train_test_split(
    df_info.dummy_df,
    shuffle=True,
    random_state=seed,
    train_size=0.8,
)

In [5]:
# Convert each split to NumPy: features come from the one-hot encoded columns,
# labels from the column(s) named by `df_info.target_name`.
X_train, X_test = (np.array(frame[df_info.ohe_feature_names]) for frame in (train_df, test_df))
y_train, y_test = (np.array(frame[df_info.target_name]) for frame in (train_df, test_df))

In [6]:
# Train the three classifiers compared in this notebook (reported below as DT, RF and NN).
# The per-epoch log presumably comes from the NN; epochs are configured inside
# train_three_models — confirm there.
models = train_three_models(X_train, y_train)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [7]:
### Print out accuracy on the test set for each of the three models.
evaluation_test(models, X_test, y_test)

DT: [0.8216] | RF [0.8465] | NN [0.8508]


In [8]:
# Persist the trained models under `dataset_name` so they can be reloaded without
# retraining; the storage location and format are defined in utils.models.
save_three_models(models, dataset_name)

In [9]:
# Reload the models saved by the previous cell, replacing the in-memory ones.
# X_train.shape[-1] is the number of one-hot feature columns (presumably needed to
# rebuild the NN input layer — confirm in utils.models.load_three_models).
models = load_three_models(X_train.shape[-1], dataset_name)

# DiCE counterfactual generation

In [10]:
# Ask DiCE for `num_cf_per_instance` counterfactual(s) on each of `num_instances`
# test instances, for every model, then post-process and save the results as CSV.
# NOTE(review): in the recorded run DiCE found no counterfactuals for the `nn`
# model on any instance — its settings may need tuning for the NN.
num_cf_per_instance = 1
num_instances = 5

results = generate_dice_result(
    df_info,
    test_df,
    models,
    num_instances,
    num_cf_per_instance,
    50,  # NOTE(review): unnamed positional setting of generate_dice_result — confirm its meaning and name it
)
result_dfs = process_results(df_info, results)
save_result_as_csv("dice", dataset_name, result_dfs)

100%|██████████| 1/1 [00:00<00:00,  6.52it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

Finding counterfactual for dt
instance 0
CF 0
instance 1
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.85it/s]
100%|██████████| 1/1 [00:00<00:00,  6.84it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

instance 2
CF 0
instance 3
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.76it/s]
100%|██████████| 1/1 [00:00<00:00,  6.89it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

instance 4
CF 0
Finding counterfactual for rfc
instance 0
CF 0


100%|██████████| 1/1 [00:00<00:00,  5.78it/s]
100%|██████████| 1/1 [00:00<00:00,  5.97it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

instance 1
CF 0
instance 2
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.08it/s]
100%|██████████| 1/1 [00:00<00:00,  6.09it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

instance 3
CF 0
instance 4
CF 0


100%|██████████| 1/1 [00:00<00:00,  6.12it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

Finding counterfactual for nn
instance 0
CF 0


100%|██████████| 1/1 [00:00<00:00,  2.03it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec
instance 1
CF 0


100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec
instance 2
CF 0


100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec
instance 3
CF 0


100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec
instance 4
CF 0


100%|██████████| 1/1 [00:00<00:00,  2.14it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec



