In [None]:
import glob
import wandb
import numpy as np
import pandas as pd
from functools import partial
from typing import List, Tuple, Dict, Callable

import tensorflow as tf
import tensorflow.keras as keras

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier


from alibi.datasets import fetch_adult
from alibi.models.tensorflow.autoencoder import HeAE
from alibi.models.tensorflow.actor_critic import Actor, Critic
from alibi.models.tensorflow.cfrl_models import ADULTEncoder, ADULTDecoder
from alibi.explainers.cfrl_tabular import CounterfactualRLTabular
from alibi.explainers.backends.cfrl_tabular import he_preprocessor, statistics, conditional_vector, category_mapping
from alibi.explainers.cfrl_base import CounterfactualRLBase, ExperienceCallback, TrainingCallback

# IPython autoreload: re-import edited modules automatically before each cell runs
%load_ext autoreload
%autoreload 2

### Train black-box classifier

In [None]:
# load the Adult dataset
adult = fetch_adult()

# categorical features are exactly those that appear in the category map; all others are numerical
categorical_ids = list(adult.category_map.keys())
categorical_names = [adult.feature_names[idx] for idx in categorical_ids]

numerical_ids = [idx for idx in range(len(adult.feature_names)) if idx not in adult.category_map]
numerical_names = [adult.feature_names[idx] for idx in numerical_ids]

# hold out 20% of the rows for evaluation (seeded for reproducibility)
x, y = adult.data, adult.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=13)

In [None]:
# preprocessing for the black-box classifier:
# standardize the numerical columns and one-hot encode the categorical columns.
# One category list per categorical feature, in the same order as categorical_ids
# (both iterate the same category_map).
one_hot_categories = [range(len(cat_values)) for cat_values in adult.category_map.values()]

num_transf = StandardScaler()
cat_transf = OneHotEncoder(categories=one_hot_categories, handle_unknown="ignore")

preprocessor = ColumnTransformer(
    transformers=[("num", num_transf, numerical_ids), ("cat", cat_transf, categorical_ids)],
    sparse_threshold=0,  # always produce a dense array
)

In [None]:
# fit the preprocessor on the training split only, then encode both splits with it
x_train_ohe = preprocessor.fit_transform(x_train)
x_test_ohe = preprocessor.transform(x_test)

In [None]:
# black-box model to be explained. Seeded (same seed as the train/test split) so the
# trained forest, and hence every downstream counterfactual, is reproducible on
# Restart & Run All.
clf = RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50, random_state=13)
clf.fit(x_train_ohe, y_train)

In [None]:
def predict_func(x):
    """Black-box prediction function handed to the explainer.

    Encodes raw rows `x` (same format as `x_train`) with the fitted `preprocessor`
    and returns the random forest's predicted class labels.
    """
    return clf.predict(preprocessor.transform(x))


# compute accuracy on the held-out split
acc = accuracy_score(y_true=y_test, y_pred=predict_func(x_test))
print(f"Accuracy: {acc:.3f}")

### Train autoencoder

In [None]:
# width of the autoencoder's hidden layers
hidden_dim = 128

# latent dimension of the autoencoder
latent_dim = 15

# decoder output heads: one head for all numerical features, followed by one head per
# categorical feature with as many outputs as that feature has categories
output_dims = [len(numerical_ids)]
output_dims += [len(adult.category_map[cat_id]) for cat_id in categorical_ids]

# input dimension: numerical columns + one column per category value. Derived from the
# data instead of being hard-coded, so it cannot silently drift from the feature set.
input_dim = sum(output_dims)
assert input_dim == x_train_ohe.shape[1], "input_dim must match the one-hot encoded width"

In [None]:
# build the heterogeneous autoencoder from the Adult-specific encoder/decoder pair;
# the decoder gets one output head per entry of output_dims
ae_encoder = ADULTEncoder(hidden_dim=hidden_dim, latent_dim=latent_dim)
ae_decoder = ADULTDecoder(hidden_dim=hidden_dim, output_dims=output_dims)
he_ae = HeAE(encoder=ae_encoder, decoder=ae_decoder)

In [None]:
num_categorical = len(categorical_names)

# reconstruction losses: MSE on the numerical head (weight 1), then one cross-entropy on
# logits per categorical head, with the categorical weights together summing to 1
he_loss = [keras.losses.MeanSquaredError()] + [
    keras.losses.SparseCategoricalCrossentropy(from_logits=True) for _ in range(num_categorical)
]
he_loss_weights = [1.] + [1. / num_categorical for _ in range(num_categorical)]

# accuracy for every categorical head.
# NOTE(review): assumes the model outputs are named output_1 (numerical head),
# output_2, output_3, ... (categorical heads) — confirm against HeAE.
metrics = {
    f"output_{head + 2}": keras.metrics.SparseCategoricalAccuracy()
    for head in range(num_categorical)
}

In [None]:
# compile the autoencoder with Adam and the per-head losses, loss weights and metrics
he_optimizer = keras.optimizers.Adam(learning_rate=1e-3)
he_ae.compile(
    optimizer=he_optimizer,
    loss=he_loss,
    loss_weights=he_loss_weights,
    metrics=metrics,
)

In [None]:
# Python types for selected features, keyed by feature index
# (NOTE(review): presumably the numerical features, cast back on inverse preprocessing — confirm)
feature_types = {0: int, 8: int, 9: int, 10: int}

# autoencoder preprocessor (raw -> AE input) and its inverse, built from the training split
ae_preprocessor, ae_inv_preprocessor = he_preprocessor(
    x=x_train,
    feature_names=adult.feature_names,
    category_map=adult.category_map,
    feature_types=feature_types,
)

# autoencoder training data: preprocessed inputs; targets are the standardized numerical
# columns (first block of x_train_ohe) followed by one (n, 1) column of raw categorical
# values per categorical feature, used as sparse class targets
trainset_input = ae_preprocessor(x_train)
trainset_outputs = [x_train_ohe[:, :len(numerical_ids)]] + [
    x_train[:, cat_id].reshape(-1, 1) for cat_id in categorical_ids
]

In [None]:
# reuse a saved checkpoint if one exists, otherwise train the autoencoder and save it.
# NOTE: the misspelled file name is kept on purpose so existing checkpoints keep loading.
he_ae_path = "tensorflow/he_autoencoder/autencoder_adult.tf"
existing_checkpoint_files = glob.glob(he_ae_path + "*")

if existing_checkpoint_files:
    # expect_partial() silences warnings about checkpoint values that are not restored
    he_ae.load_weights(he_ae_path).expect_partial()
else:
    he_ae.fit(trainset_input, trainset_outputs, epochs=500)
    he_ae.save_weights(he_ae_path)

### Counterfactual RL

#### Define dataset-specific attributes and constraints

In [None]:
# number of classes predicted by the black-box classifier
num_classes = 2

# features the counterfactual must keep unchanged
immutable_features = ["Marital Status", "Relationship", "Race", "Sex"]

# allowed change ranges for numerical features
# (NOTE(review): presumably relative bounds, i.e. Age may only increase — confirm in CounterfactualRLTabular)
ranges = {"Age": [-0.0, 1.0]}

# per-feature statistics of the training split, used for clamping
stats = statistics(
    x=x_train,
    preprocessor=ae_preprocessor,
    category_map=adult.category_map,
)

#### Define experience callbacks

In [None]:
class RewardCallback(ExperienceCallback):
    """Experience callback that logs the mean counterfactual reward to wandb.

    Acts only on steps divisible by 100. The sampled counterfactuals are mapped back
    to the raw feature space, classified with `predict_func`, and scored against
    their targets with the model's reward function.
    """

    def __call__(self,
                 step: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray]):
        # throttle: skip every step that is not a multiple of 100
        if step % 100:
            return

        # counterfactuals in the raw feature space, and their target labels
        counterfactuals = model.params["ae_inv_preprocessor"](sample["x_cf"])
        targets = sample["y_t"]

        # black-box labels of the counterfactuals, then mean reward over the batch
        predicted_labels = predict_func(counterfactuals)
        mean_reward = np.mean(model.params["reward_func"](predicted_labels, targets))
        wandb.log({"reward": mean_reward})

#### Define training callbacks

In [None]:
class DisplayLossCallback(TrainingCallback):
    """Training callback that forwards the loss dictionary to wandb.

    Logs only when `step + update` is a multiple of 100.
    """

    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        # throttle logging
        if (step + update) % 100 != 0:
            return
        wandb.log(losses)

#### Define explainer

In [None]:
# build the tabular counterfactual-RL explainer on top of the trained autoencoder and the
# black-box predict_func; constraints and callbacks come from the cells above.
# num_classes is taken from the constants cell rather than repeated as a literal.
explainer = CounterfactualRLTabular(ae=he_ae,
                                    latent_dim=latent_dim,
                                    ae_preprocessor=ae_preprocessor,
                                    ae_inv_preprocessor=ae_inv_preprocessor,
                                    predict_func=predict_func,
                                    coeff_sparsity=0.5,
                                    coeff_consistency=0.5,
                                    num_classes=num_classes,
                                    category_map=adult.category_map,
                                    feature_names=adult.feature_names,
                                    ranges=ranges,
                                    immutable_features=immutable_features,
                                    experience_callbacks=[RewardCallback()],
                                    train_callbacks=[DisplayLossCallback()],
                                    weight_cat=1.0,
                                    weight_num=0.2,
                                    backend="tensorflow",
                                    train_steps=10000)

#### Fit explainer

In [None]:
# initialize wandb; the explainer callbacks log rewards and losses to this run
wandb_project = "ADULT CounterfactualRL"
wandb.init(project=wandb_project)

try:
    # fit the explainer
    explainer = explainer.fit(x=x_train)
finally:
    # always close the run, even if training raises or is interrupted,
    # so the wandb run is not left open
    wandb.finish()

#### Test explainer

In [None]:
# training rows that the black-box classifies as positive (label 1)
is_positive = predict_func(x_train) == 1
x_positive = x_train[is_positive]

# explain the first two positive rows
x = x_positive[:2]
# desired counterfactual label
# (NOTE(review): a single target, presumably broadcast over both rows — confirm)
y_t = np.array([0])
# conditioning for the explainer
# (NOTE(review): presumably restricts Age changes to [0, 20] and Workclass to the listed values — confirm)
c = [{"Age": [0, 20], "Workclass": ["State-gov", "?", "Local-gov"]}]

In [None]:
# generate counterfactual instances for x, targeting labels y_t under conditioning c
x_cf = explainer.explain(x, y_t, c)

In [None]:
# compare black-box predictions for the inputs and for their counterfactuals
print("Input labels:", predict_func(x))
print("Counterfactual labels:", predict_func(x_cf))

In [None]:
# show the original instances, with categorical values passed through category_mapping
# (presumably code -> category name) for readability
x_readable = category_mapping(x, adult.category_map)
pd.DataFrame(x_readable, columns=adult.feature_names)

In [None]:
# show the counterfactual instances in the same readable form as the inputs
pd.DataFrame(
    data=category_mapping(x_cf, adult.category_map),
    columns=adult.feature_names,
)

#### Diversity

In [None]:
# generate several diverse counterfactuals for a single positive instance,
# kept as a 2-D array with one row
x = x_positive[1][np.newaxis, :]
x_cf = explainer.explain(x, y_t, c, diversity=True, num_samples=5)

In [None]:
# black-box label of the selected instance (shown as the cell output)
predict_func(x)

In [None]:
# the selected instance, with categorical values passed through category_mapping
pd.DataFrame(category_mapping(x, adult.category_map), columns=adult.feature_names)

In [None]:
# the diverse counterfactuals generated for the selected instance, in the same readable form
pd.DataFrame(category_mapping(x_cf, adult.category_map), columns=adult.feature_names)