In [1]:
import os
import wandb
import numpy as np
import pandas as pd
from copy import deepcopy
from functools import partial
from typing import List, Tuple, Dict, Callable

import tensorflow as tf
import tensorflow.keras as keras

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier


from alibi.datasets import fetch_adult
from alibi.models.tensorflow.autoencoder import HeAE
from alibi.models.tensorflow.actor_critic import Actor, Critic
from alibi.models.tensorflow.cfrl_models import ADULTEncoder, ADULTDecoder
from alibi.explainers.cfrl_tabular import CounterfactualRLTabular
from alibi.explainers.cfrl_base import CounterfactualRLBase, ExperienceCallback, TrainingCallback
from alibi.explainers.backends.cfrl_tabular import get_he_preprocessor, get_statistics, \
    get_conditional_vector, apply_category_mapping

%load_ext autoreload
%autoreload 2

### Train black-box classifier

In [2]:
# download / load the Adult dataset bundle
adult = fetch_adult()

# categorical columns are exactly the ones listed in the category map
categorical_ids = list(adult.category_map.keys())
categorical_names = [adult.feature_names[idx] for idx in categorical_ids]

# every remaining column index is treated as numerical
numerical_ids = [idx for idx in range(len(adult.feature_names)) if idx not in adult.category_map]
numerical_names = [adult.feature_names[idx] for idx in numerical_ids]

# hold out 20% of the rows for testing (fixed seed for a repeatable split)
X, Y = adult.data, adult.target
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=13)

In [3]:
# numerical columns are standardized
num_transf = StandardScaler()

# categorical columns are one-hot encoded; the category list of each column is
# fixed up-front to 0..n_categories-1 as given by the dataset's category map
cat_categories = [range(len(values)) for values in adult.category_map.values()]
cat_transf = OneHotEncoder(categories=cat_categories, handle_unknown="ignore")

# numerical block first, categorical block second; sparse_threshold=0 forces a dense output
preprocessor = ColumnTransformer(
    transformers=[
        ("num", num_transf, numerical_ids),
        ("cat", cat_transf, categorical_ids),
    ],
    sparse_threshold=0,
)

In [4]:
# fit the preprocessor on the training split only, then encode both splits
X_train_ohe = preprocessor.fit_transform(X_train)
X_test_ohe = preprocessor.transform(X_test)

In [5]:
# black-box model to be explained
# NOTE(review): no random_state is set, so the fitted forest (and the accuracy
# reported below) can vary slightly between runs.
rf_params = dict(max_depth=15, min_samples_split=10, n_estimators=50)
clf = RandomForestClassifier(**rf_params)
clf.fit(X_train_ohe, Y_train)

RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50)

In [6]:
def predictor(x):
    """Predict class labels for raw (un-preprocessed) instances `x`.

    The input is first passed through the fitted `preprocessor` and the result
    is fed to the trained classifier `clf`.
    """
    return clf.predict(preprocessor.transform(x))


# accuracy of the black-box model on the held-out split
Y_test_pred = predictor(X_test)
acc = accuracy_score(y_true=Y_test, y_pred=Y_test_pred)
print("Accuracy: %.3f" % acc)

Accuracy: 0.862


### Train autoencoder

In [7]:
# define hidden dim
hidden_dim = 128

# define latent dimension
latent_dim = 15

# output dims: one entry for the block of numerical features, followed by one
# entry per categorical feature holding its number of categories
output_dims = [len(numerical_ids)]
output_dims += [len(adult.category_map[cat_id]) for cat_id in categorical_ids]

# define input dimension: number of numerical features plus the total number of
# categories. Derived from the dataset instead of hard-coding 57, so it stays
# consistent if the feature set changes.
input_dim = sum(output_dims)

In [8]:
# build the two halves of the heterogeneous auto-encoder, then wrap them
encoder = ADULTEncoder(hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = ADULTDecoder(hidden_dim=hidden_dim, output_dims=output_dims)
he_ae = HeAE(encoder=encoder, decoder=decoder)

In [9]:
num_categorical = len(categorical_names)

# first loss: mean squared error with weight 1; then one sparse categorical
# cross-entropy (on logits) per categorical feature, each weighted by
# 1 / number of categorical features
he_loss = [keras.losses.MeanSquaredError()]
he_loss += [keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            for _ in range(num_categorical)]
he_loss_weights = [1.] + [1. / num_categorical for _ in range(num_categorical)]

# one accuracy metric per categorical feature, keyed "output_2", "output_3", ...
metrics = {
    f"output_{pos + 2}": keras.metrics.SparseCategoricalAccuracy()
    for pos in range(num_categorical)
}

In [10]:
# attach optimizer, per-output losses (with weights) and metrics to the auto-encoder
he_optimizer = keras.optimizers.Adam(learning_rate=1e-3)
he_ae.compile(
    optimizer=he_optimizer,
    loss=he_loss,
    loss_weights=he_loss_weights,
    metrics=metrics,
)

In [11]:
# columns whose values are declared as integers
# (presumably the numerical columns of the dataset -- TODO confirm against `numerical_ids`)
feature_types = {0: int, 8: int, 9: int, 10: int}

# preprocessor used in front of the auto-encoder and its inverse
ae_preprocessor, ae_inv_preprocessor = get_he_preprocessor(
    X=X_train,
    feature_names=adult.feature_names,
    category_map=adult.category_map,
    feature_types=feature_types,
)

# auto-encoder input
trainset_input = ae_preprocessor(X_train)

# auto-encoder targets: the leading `len(numerical_ids)` columns of the encoded
# training data, followed by one column of raw category codes per categorical feature
num_numerical = len(numerical_ids)
trainset_outputs = [X_train_ohe[:, :num_numerical]]
trainset_outputs += [X_train[:, col].reshape(-1, 1) for col in categorical_ids]

In [12]:
# checkpoint directory of the auto-encoder (created if missing)
he_ae_path = "tensorflow/he_autoencoder"
os.makedirs(he_ae_path, exist_ok=True)

if os.listdir(he_ae_path):
    # a checkpoint is already present: reuse it instead of training again
    he_ae = keras.models.load_model(he_ae_path, compile=False)
else:
    # empty directory: train from scratch and store the result
    he_ae.fit(trainset_input, trainset_outputs, epochs=50)
    he_ae.save(he_ae_path, save_format="tf")

### Counterfactual RL

#### Define dataset specific attributes and constraints

In [13]:
# number of output classes of the classifier
num_classes = 2

# define immutable features (passed to the explainer as `immutable_features`)
immutable_features = ['Marital Status', 'Relationship', 'Race', 'Sex']

# define ranges (passed to the explainer as `ranges`)
# NOTE(review): presumably [lower, upper] bounds on the allowed relative change
# of the feature -- confirm against the explainer documentation
ranges = {'Age': [-0.0, 1.0]}


# compute statistic for clamping
# NOTE(review): `stats` is not referenced by any later cell visible here
stats = get_statistics(X=X_train, 
                       preprocessor=ae_preprocessor, 
                       category_map=adult.category_map)

#### Define experience callbacks

In [14]:
class RewardCallback(ExperienceCallback):
    """Experience callback that logs the mean reward of the sampled counterfactuals to wandb."""

    def __call__(self,
                 step: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray]):
        """Log the reward for `sample` on every 100th step; do nothing otherwise.

        Parameters
        ----------
        step
            Current step index.
        model
            Explainer whose `params` provide "ae_inv_preprocessor" and "reward_func".
        sample
            Dictionary holding the counterfactuals ("X_cf") and the targets ("Y_t").
        """
        if step % 100 == 0:
            params = model.params

            # map the counterfactuals back through the inverse preprocessor
            counterfactuals = params["ae_inv_preprocessor"](sample["X_cf"])
            targets = sample["Y_t"]

            # classifier labels of the counterfactuals
            cf_labels = predictor(counterfactuals)

            # average reward over the sample
            mean_reward = np.mean(params["reward_func"](cf_labels, targets))
            wandb.log({"reward": mean_reward})

#### Define training callbacks

In [15]:
class DisplayLossCallback(TrainingCallback):
    """Training callback that forwards the training losses to wandb."""

    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        """Log `losses` whenever `step + update` is a multiple of 100.

        `model` and `sample` are accepted as part of the callback signature
        but are not used here.
        """
        is_logging_step = (step + update) % 100 == 0
        if not is_logging_step:
            return

        wandb.log(losses)

#### Define explainer

In [16]:
# define the counterfactual RL explainer
# `num_classes` reuses the constant defined with the dataset attributes instead
# of repeating the literal, so the two cannot drift apart.
explainer = CounterfactualRLTabular(ae=he_ae,
                                    latent_dim=latent_dim,
                                    ae_preprocessor=ae_preprocessor,
                                    ae_inv_preprocessor=ae_inv_preprocessor,
                                    predictor=predictor,
                                    coeff_sparsity=0.5,
                                    coeff_consistency=0.5,
                                    num_classes=num_classes,
                                    category_map=adult.category_map,
                                    feature_names=adult.feature_names,
                                    ranges=ranges,
                                    immutable_features=immutable_features,
                                    experience_callbacks=[RewardCallback()],
                                    train_callbacks=[DisplayLossCallback()],
                                    weight_cat=1.0,
                                    weight_num=1.0,
                                    backend="tensorflow",
                                    train_steps=100,
                                    batch_size=100)

#### Fit explainer

In [17]:
# initialize wandb
wandb_project = "ADULT CounterfactualRL"
wandb.init(project=wandb_project)

try:
    # fit the explainer (the callbacks log to the active wandb run)
    explainer = explainer.fit(X=X_train)
finally:
    # close wandb even if fitting fails, so no run is left open
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrfs[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.11.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


100%|██████████| 100/100 [00:04<00:00, 23.58it/s]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
reward,0.42
_runtime,7.0
_timestamp,1627834443.0
_step,0.0


0,1
reward,▁
_runtime,▁
_timestamp,▁
_step,▁


### Save explainer

In [18]:
# persist the fitted explainer under the "cfrl_tabular" directory
explainer.save("cfrl_tabular")



INFO:tensorflow:Assets written to: cfrl_tabular/ae.tf/assets


INFO:tensorflow:Assets written to: cfrl_tabular/ae.tf/assets


INFO:tensorflow:Assets written to: cfrl_tabular/actor.tf/assets


INFO:tensorflow:Assets written to: cfrl_tabular/actor.tf/assets


INFO:tensorflow:Assets written to: cfrl_tabular/critic.tf/assets


INFO:tensorflow:Assets written to: cfrl_tabular/critic.tf/assets


### Load explainer

In [19]:
# restore the explainer from disk; the predictor is supplied again at load time
explainer = CounterfactualRLTabular.load("cfrl_tabular", predictor=predictor)

#### Test explainer

In [20]:
# keep only the training instances that the classifier labels as class 1
positive_mask = predictor(X_train) == 1
X_positive = X_train[positive_mask]

# instances to explain: the first 100 positives
X = X_positive[:100]

# requested target label for the counterfactuals
Y_t = np.array([0])

# condition passed to the explainer
# NOTE(review): presumably restricts the allowed change in Age and the admissible
# Workclass values -- confirm against the explainer documentation
C = [{"Age": [0, 20], "Workclass": ["State-gov", "?", "Local-gov"]}]

In [21]:
# generate counterfactual instances for X with target label Y_t under condition C
explanation = explainer.explain(X, Y_t, C)

In [22]:
# append the class label as an extra trailing column to the original instances
orig_data = explanation.data['orig']
orig = np.hstack([orig_data['X'], orig_data['class'].reshape(-1, 1)])

# same for the counterfactual instances
cf_data = explanation.data['cf']
cf = np.hstack([cf_data['X'], cf_data['class'].reshape(-1, 1)])

# extend the feature names and the category map with the label column
feature_names = adult.feature_names + ["Label"]
category_map = deepcopy(adult.category_map)
category_map[feature_names.index("Label")] = adult.target_names

# replace the category / label encodings with their string values
orig_pd = pd.DataFrame(apply_category_mapping(orig, category_map), columns=feature_names)
cf_pd = pd.DataFrame(apply_category_mapping(cf, category_map), columns=feature_names)

In [23]:
# preview the first five original instances
orig_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,46,Private,High School grad,Married,Sales,Husband,Black,Male,0,0,60,United-States,>50K
1,41,Private,High School grad,Never-Married,Blue-Collar,Unmarried,White,Male,0,3004,60,?,>50K
2,44,Private,High School grad,Married,White-Collar,Husband,White,Male,0,0,40,United-States,>50K
3,36,Private,Bachelors,Married,Professional,Husband,White,Male,0,0,45,United-States,>50K
4,47,State-gov,Masters,Married,White-Collar,Husband,White,Male,0,0,47,United-States,>50K


In [24]:
# preview the first five counterfactual instances
cf_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,46,?,Dropout,Married,Blue-Collar,Husband,Black,Male,813,0,49,United-States,<=50K
1,47,?,Dropout,Never-Married,Blue-Collar,Unmarried,White,Male,347,0,53,United-States,<=50K
2,47,?,Dropout,Married,Blue-Collar,Husband,White,Male,686,0,40,United-States,<=50K
3,43,?,Dropout,Married,Blue-Collar,Husband,White,Male,826,0,42,United-States,<=50K
4,49,?,Dropout,Married,Blue-Collar,Husband,White,Male,604,0,45,United-States,<=50K


#### Diversity

In [25]:
# generate counterfactual instances for a single instance (the second positive
# one, reshaped to a 1-row batch) with diversity enabled and num_samples=100
X = X_positive[1].reshape(1, -1)
explanation = explainer.explain(X, Y_t, C, diversity=True, num_samples=100, batch_size=100)

In [26]:
# append the class label as an extra trailing column
orig = np.hstack([
    explanation.data['orig']['X'],
    explanation.data['orig']['class'].reshape(-1, 1),
])
cf = np.hstack([
    explanation.data['cf']['X'],
    explanation.data['cf']['class'].reshape(-1, 1),
])

# transform label encodings to strings
orig_decoded = apply_category_mapping(orig, category_map)
orig_pd = pd.DataFrame(orig_decoded, columns=feature_names)

cf_decoded = apply_category_mapping(cf, category_map)
cf_pd = pd.DataFrame(cf_decoded, columns=feature_names)

In [27]:
# show the original instance (a single row in this section)
orig_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,41,Private,High School grad,Never-Married,Blue-Collar,Unmarried,White,Male,0,3004,60,?,>50K


In [28]:
# preview the first five of the diverse counterfactuals
cf_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,41,?,Dropout,Never-Married,Blue-Collar,Unmarried,White,Male,0,0,55,United-States,<=50K
1,41,?,Dropout,Never-Married,Blue-Collar,Unmarried,White,Male,0,189,58,United-States,<=50K
2,41,?,Dropout,Never-Married,Blue-Collar,Unmarried,White,Male,0,274,50,Other,<=50K
3,41,Local-gov,Bachelors,Never-Married,Blue-Collar,Unmarried,White,Male,0,705,52,United-States,<=50K
4,41,Local-gov,Dropout,Never-Married,Blue-Collar,Unmarried,White,Male,0,1996,59,?,<=50K
