In [31]:
import os
import wandb
import numpy as np
import pandas as pd
from copy import deepcopy
from functools import partial
from typing import List, Tuple, Dict, Callable

import tensorflow as tf
import tensorflow.keras as keras

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier


from alibi.datasets import fetch_adult
from alibi.models.tflow.autoencoder import HeAE
from alibi.models.tflow.actor_critic import Actor, Critic
from alibi.models.tflow.cfrl_models import ADULTEncoder, ADULTDecoder
from alibi.explainers.cfrl_tabular import CounterfactualRLTabular
from alibi.explainers.backends.cfrl_tabular import he_preprocessor, statistics, conditional_vector, category_mapping
from alibi.explainers.cfrl_base import CounterfactualRLBase, ExperienceCallback, TrainingCallback

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Train black-box classifier

In [2]:
# fetch the Adult dataset through alibi
adult = fetch_adult()

# categorical columns are exactly the keys of `category_map`
categorical_ids = list(adult.category_map.keys())
categorical_names = [adult.feature_names[idx] for idx in categorical_ids]

# every other column is treated as numerical
numerical_ids = [idx for idx in range(len(adult.feature_names)) if idx not in categorical_ids]
numerical_names = [adult.feature_names[idx] for idx in numerical_ids]

# hold out 20% of the rows for testing; the fixed seed keeps the split reproducible
x, y = adult.data, adult.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=13)

In [3]:
# numerical columns are standardized
num_transf = StandardScaler()

# categorical columns are one-hot encoded; the categories of each column are
# declared up front as 0..k-1, where k is the number of values listed for that
# column in `adult.category_map`
cat_transf = OneHotEncoder(
    categories=[range(len(values)) for values in adult.category_map.values()],
    handle_unknown="ignore",
)

# route each column subset to its transformer; sparse_threshold=0 forces a dense output
preprocessor = ColumnTransformer(
    transformers=[("num", num_transf, numerical_ids), ("cat", cat_transf, categorical_ids)],
    sparse_threshold=0,
)

In [4]:
# fit the preprocessor on the training split only, then transform both splits with it
preprocessor.fit(x_train)
x_train_ohe, x_test_ohe = (preprocessor.transform(split) for split in (x_train, x_test))

In [5]:
# black-box model to be explained: a random forest trained on the preprocessed features
clf = RandomForestClassifier(
    n_estimators=50,
    max_depth=15,
    min_samples_split=10,
)
clf.fit(x_train_ohe, y_train)

RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50)

In [6]:
def predict_func(x: np.ndarray) -> np.ndarray:
    """Predict class labels for raw (unprocessed) instances.

    The input is first passed through the fitted `preprocessor` (scaling and
    one-hot encoding) and the result is fed to the random forest `clf`.

    Parameters
    ----------
    x
        Instances in the same column layout as `x_train`.

    Returns
    -------
    Labels returned by `clf.predict`, one per instance.
    """
    return clf.predict(preprocessor.transform(x))


# compute accuracy on the held-out test split
acc = accuracy_score(y_true=y_test, y_pred=predict_func(x_test))
print("Accuracy: %.3f" % acc)

Accuracy: 0.860


### Train autoencoder

In [7]:
# define hidden dim
hidden_dim = 128

# define latent dimension
latent_dim = 15

# output dims: the first entry is the number of numerical columns, followed by
# the number of category values of each categorical column
output_dims = [len(numerical_ids)]
output_dims += [len(adult.category_map[cat_id]) for cat_id in categorical_ids]

# define input dimension: numerical columns plus one column per category value,
# i.e. the width of the one-hot encoded input. Derived from `output_dims`
# instead of being hard-coded so it stays in sync with the dataset.
input_dim = sum(output_dims)

In [8]:
# heterogeneous auto-encoder: the encoder maps inputs to a `latent_dim`-sized code,
# the decoder is configured with one output size per entry of `output_dims`
he_ae = HeAE(
    encoder=ADULTEncoder(hidden_dim=hidden_dim, latent_dim=latent_dim),
    decoder=ADULTDecoder(hidden_dim=hidden_dim, output_dims=output_dims),
)

In [9]:
# first loss entry: mean squared error with unit weight
he_loss = [keras.losses.MeanSquaredError()]
he_loss_weights = [1.]

# one cross-entropy loss (on logits) per categorical feature; the categorical
# weights are all equal and sum to 1
he_loss += [keras.losses.SparseCategoricalCrossentropy(from_logits=True) for _ in categorical_names]
he_loss_weights += [1. / len(categorical_names) for _ in categorical_names]

# one accuracy metric per categorical feature, keyed "output_2", "output_3", ...
# NOTE(review): presumably these keys match the names Keras assigns to the model
# outputs, with "output_1" being the numerical head - confirm
metrics = {
    f"output_{head}": keras.metrics.SparseCategoricalAccuracy()
    for head in range(2, len(categorical_names) + 2)
}

In [10]:
# compile the auto-encoder with Adam and the per-output losses, weights and metrics
he_ae.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=he_loss,
    loss_weights=he_loss_weights,
    metrics=metrics,
)

In [11]:
# column indices whose type is declared as int
feature_types = {idx: int for idx in (0, 8, 9, 10)}

# build the auto-encoder preprocessor and its inverse from the training data
ae_preprocessor, ae_inv_preprocessor = he_preprocessor(
    x=x_train,
    feature_names=adult.feature_names,
    category_map=adult.category_map,
    feature_types=feature_types,
)

# auto-encoder input: the preprocessed training data
trainset_input = ae_preprocessor(x_train)

# auto-encoder targets: the leading `len(numerical_ids)` columns of the
# preprocessed data (the "num" transformer is listed first), followed by each
# raw categorical column reshaped to a column vector
trainset_outputs = [x_train_ohe[:, :len(numerical_ids)]]
trainset_outputs.extend(x_train[:, cat_id].reshape(-1, 1) for cat_id in categorical_ids)

In [12]:
# fit model and then save, or if checkpoint already exists, just load the model
he_ae_path = "tensorflow/he_autoencoder"

# create the checkpoint directory if it is missing; `exist_ok=True` replaces the
# separate existence check
os.makedirs(he_ae_path, exist_ok=True)

if not os.listdir(he_ae_path):
    # empty directory: no checkpoint yet, so train and save
    he_ae.fit(trainset_input, trainset_outputs, epochs=50)
    he_ae.save(he_ae_path, save_format="tf")
else:
    # checkpoint present: load it without compiling
    he_ae = keras.models.load_model(he_ae_path, compile=False)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50


Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50


Epoch 50/50
INFO:tensorflow:Assets written to: tensorflow/he_autoencoder/assets


### Counterfactual RL

#### Define dataset specific attributes and constraints

In [13]:
# number of classes
num_classes = 2

# features declared immutable for the explainer
immutable_features = ['Marital Status', 'Relationship', 'Race', 'Sex']

# per-feature ranges
# NOTE(review): presumably [lower, upper] bounds on the allowed change of a
# numerical feature - confirm against the alibi documentation
ranges = {'Age': [-0.0, 1.0]}

# compute statistics used for clamping
stats = statistics(
    x=x_train,
    preprocessor=ae_preprocessor,
    category_map=adult.category_map,
)

#### Define experience callbacks

In [14]:
class RewardCallback(ExperienceCallback):
    """Experience callback that logs the mean reward of a sample to wandb.

    Logging only happens on steps that are multiples of 100; all other steps
    are a no-op.
    """

    def __call__(self,
                 step: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray]):
        if step % 100 == 0:
            # map the counterfactuals back through the inverse auto-encoder preprocessor
            inverse_preprocessor = model.params["ae_inv_preprocessor"]
            counterfactuals = inverse_preprocessor(sample["x_cf"])
            targets = sample["y_t"]

            # labels assigned by the black-box classifier to the counterfactuals
            predictions = predict_func(counterfactuals)

            # average the model's reward over the sample and log it
            reward_func = model.params["reward_func"]
            mean_reward = np.mean(reward_func(predictions, targets))
            wandb.log({"reward": mean_reward})

#### Define training callbacks

In [15]:
class DisplayLossCallback(TrainingCallback):
    """Training callback that forwards the loss dictionary to wandb.

    The losses are logged only when `step + update` is a multiple of 100.
    """

    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        # skip everything except every 100th (step + update)
        if (step + update) % 100 != 0:
            return

        wandb.log(losses)

#### Define explainer

In [16]:
# define the counterfactual RL explainer
explainer = CounterfactualRLTabular(ae=he_ae,
                                    latent_dim=latent_dim,
                                    ae_preprocessor=ae_preprocessor,
                                    ae_inv_preprocessor=ae_inv_preprocessor,
                                    predict_func=predict_func,
                                    coeff_sparsity=0.5,
                                    coeff_consistency=0.5,
                                    # reuse the constant declared with the dataset
                                    # attributes instead of repeating the literal
                                    num_classes=num_classes,
                                    category_map=adult.category_map,
                                    feature_names=adult.feature_names,
                                    ranges=ranges,
                                    immutable_features=immutable_features,
                                    # callbacks that log reward and losses to wandb
                                    experience_callbacks=[RewardCallback()],
                                    train_callbacks=[DisplayLossCallback()],
                                    weight_cat=1.0,
                                    weight_num=0.2,
                                    backend="tensorflow",
                                    train_steps=100,
                                    batch_size=128)

#### Fit explainer

In [17]:
# initialize wandb
wandb_project = "ADULT CounterfactualRL"
wandb.init(project=wandb_project)

try:
    # fit the explainer
    explainer = explainer.fit(x=x_train)
finally:
    # close the wandb run even if fitting raises, so the run is not left open
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrfs[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.11.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


100%|██████████| 100/100 [00:05<00:00, 19.84it/s]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
reward,0.46875
_runtime,8.0
_timestamp,1627634283.0
_step,0.0


0,1
reward,▁
_runtime,▁
_timestamp,▁
_step,▁


### Save explainer

In [18]:
# save the fitted explainer to the "cfrl_tabular" path
explainer.save("cfrl_tabular")

INFO:tensorflow:Assets written to: cfrl_tabular/ae.tf/assets
INFO:tensorflow:Assets written to: cfrl_tabular/actor.tf/assets
INFO:tensorflow:Assets written to: cfrl_tabular/critic.tf/assets


### Load explainer

In [19]:
# load the explainer back from "cfrl_tabular"; the prediction function is passed in again at load time
explainer = CounterfactualRLTabular.load("cfrl_tabular", predictor=predict_func)

#### Test explainer

In [39]:
# training instances that the black-box classifier labels as class 1
x_positive = x_train[predict_func(x_train) == 1]

# explain the first 100 of them, with class 0 as the target label
x = x_positive[:100]
y_t = np.array([0])

# condition passed to the explainer
# NOTE(review): presumably an allowed change interval for "Age" and the set of
# admissible values for "Workclass" - confirm against the alibi documentation
c = [
    {
        "Age": [0, 20],
        "Workclass": ["State-gov", "?", "Local-gov"],
    }
]

In [40]:
# generate counterfactual instances for the inputs `x`, given the target labels `y_t` and the condition `c`
explanation = explainer.explain(x, y_t, c)

In [41]:
# extend the feature names and the category map with the label column
feature_names = adult.feature_names + ["Label"]
category_map = deepcopy(adult.category_map)
category_map[feature_names.index("Label")] = adult.target_names

# append the class labels as an extra column to the original and to the
# counterfactual instances
orig, cf = (
    np.concatenate(
        [explanation.data[part]['X'], explanation.data[part]['class'].reshape(-1, 1)],
        axis=1,
    )
    for part in ('orig', 'cf')
)

# replace label encodings with strings and wrap the results in data frames
orig_pd, cf_pd = (
    pd.DataFrame(category_mapping(instances, category_map), columns=feature_names)
    for instances in (orig, cf)
)

In [42]:
# display the original instances with their labels decoded to strings
orig_pd

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,41,Private,High School grad,Never-Married,Blue-Collar,Unmarried,White,Male,0,3004,60,?,>50K
1,44,Private,High School grad,Married,White-Collar,Husband,White,Male,0,0,40,United-States,>50K
2,36,Private,Bachelors,Married,Professional,Husband,White,Male,0,0,45,United-States,>50K
3,47,State-gov,Masters,Married,White-Collar,Husband,White,Male,0,0,47,United-States,>50K
4,42,Private,Masters,Married,Sales,Husband,Asian-Pac-Islander,Male,0,0,40,?,>50K
...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,60,Private,High School grad,Married,Sales,Husband,White,Male,7298,0,65,United-States,>50K
96,44,Private,Bachelors,Married,Professional,Husband,White,Male,0,0,35,United-States,>50K
97,73,Self-emp-inc,Bachelors,Married,Sales,Husband,White,Male,0,0,50,United-States,>50K
98,53,Private,High School grad,Married,Professional,Husband,White,Male,15024,0,40,United-States,>50K


In [38]:
# display the counterfactual instances with their labels decoded to strings
cf_pd

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,44,?,High School grad,Married,?,Husband,White,Male,0,0,36,United-States,<=50K
1,44,?,High School grad,Married,Admin,Husband,White,Male,0,0,34,United-States,<=50K
2,44,?,High School grad,Married,Admin,Husband,White,Male,0,0,40,United-States,<=50K
3,44,?,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,32,United-States,<=50K
4,44,?,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,35,United-States,<=50K
...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,49,Local-gov,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,38,United-States,<=50K
96,49,Local-gov,High School grad,Married,Military,Husband,White,Male,0,0,29,United-States,<=50K
97,49,Private,High School grad,Married,Admin,Husband,White,Male,0,0,33,United-States,<=50K
98,49,Private,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,34,United-States,<=50K


#### Diversity

In [46]:
# take the second positive instance as a single-row batch
x = x_positive[1][np.newaxis, :]

# generate counterfactual instances for it with diversity enabled
explanation = explainer.explain(
    x,
    y_t,
    c,
    diversity=True,
    num_samples=100,
    batch_size=12,
)

In [47]:
# append the class labels as an extra column to the original and to the
# counterfactual instances
orig, cf = (
    np.concatenate(
        [explanation.data[part]['X'], explanation.data[part]['class'].reshape(-1, 1)],
        axis=1,
    )
    for part in ('orig', 'cf')
)

# transform label encodings to strings and wrap the results in data frames
orig_pd, cf_pd = (
    pd.DataFrame(category_mapping(instances, category_map), columns=feature_names)
    for instances in (orig, cf)
)

In [48]:
# display the single original instance with its label decoded to a string
orig_pd

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,44,Private,High School grad,Married,White-Collar,Husband,White,Male,0,0,40,United-States,>50K


In [49]:
# display the counterfactuals generated with diversity enabled, labels decoded to strings
cf_pd

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,44,?,High School grad,Married,Admin,Husband,White,Male,0,0,40,United-States,<=50K
1,44,?,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,34,United-States,<=50K
2,44,?,High School grad,Married,Other,Husband,White,Male,0,0,34,United-States,<=50K
3,44,?,High School grad,Married,Other,Husband,White,Male,0,0,39,United-States,<=50K
4,44,Local-gov,Dropout,Married,?,Husband,White,Male,0,0,32,United-States,<=50K
...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,50,?,High School grad,Married,Military,Husband,White,Male,0,0,30,United-States,<=50K
96,50,?,High School grad,Married,Military,Husband,White,Male,0,0,37,United-States,<=50K
97,50,Local-gov,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,28,United-States,<=50K
98,50,Local-gov,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,34,United-States,<=50K
