In [1]:
import os
from copy import deepcopy
from functools import partial
from typing import List, Tuple, Dict, Callable

import wandb
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

from alibi.datasets import fetch_adult
from alibi.models.pytorch.autoencoder import HeAE
from alibi.models.pytorch.actor_critic import Actor, Critic
from alibi.models.pytorch.cfrl_models import ADULTEncoder, ADULTDecoder
from alibi.models.pytorch.metrics import AccuracyMetric

from alibi.explainers.cfrl_tabular import CounterfactualRLTabular
from alibi.explainers.cfrl_base import CounterfactualRLBase, ExperienceCallback, TrainingCallback
from alibi.explainers.backends.cfrl_tabular import get_he_preprocessor, get_statistics, \
    get_conditional_vector, apply_category_mapping
from alibi.explainers.backends.pytorch.cfrl_base import to_numpy


%load_ext autoreload
%autoreload 2



### Train black-box classifier

In [2]:
# Fetch the Adult dataset.
adult = fetch_adult()

# Split the column indices (and names) into categorical and numerical ones:
# the categorical columns are exactly the keys of `category_map`.
categorical_ids = list(adult.category_map.keys())
categorical_names = [adult.feature_names[idx] for idx in categorical_ids]

numerical_ids = [idx for idx in range(len(adult.feature_names)) if idx not in adult.category_map]
numerical_names = [adult.feature_names[idx] for idx in numerical_ids]

# Hold out 20% of the instances as a test split (fixed seed for reproducibility).
X, Y = adult.data, adult.target
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=13)

In [3]:
# data preprocessor
num_transf = StandardScaler()
cat_transf = OneHotEncoder(
    categories=[range(len(x)) for x in adult.category_map.values()],
    handle_unknown="ignore"
)
preprocessor = ColumnTransformer(
    transformers=[
        ("num", num_transf, numerical_ids),
        ("cat", cat_transf, categorical_ids)
    ],
    sparse_threshold=0
)

In [4]:
# Fit the classifier preprocessor on the training split only, then encode both splits.
preprocessor.fit(X_train)
X_train_ohe, X_test_ohe = map(preprocessor.transform, (X_train, X_test))

In [5]:
# Black-box model: a random forest trained on the one-hot encoded training split.
clf = RandomForestClassifier(
    n_estimators=50,
    max_depth=15,
    min_samples_split=10,
    random_state=0,
)
clf.fit(X_train_ohe, Y_train)

RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50,
                       random_state=0)

In [6]:
# define prediction function
predictor = lambda x: clf.predict_proba(preprocessor.transform(x))

# compute accuracy
acc = accuracy_score(y_true=Y_test, y_pred=np.argmax(predictor(X_test), axis=1))
print("Accuracy: %.3f" % acc)

Accuracy: 0.862


### Train autoencoder

In [7]:
# define input dimension
input_dim = 57

# define hidden dim
hidden_dim = 128

# define latent dimension
latent_dim = 15

# output dims
output_dims = [len(numerical_ids)]
output_dims += [len(adult.category_map[cat_id]) for cat_id in categorical_ids]

In [8]:
# define autoencoder
he_ae = HeAE(encoder=ADULTEncoder(hidden_dim=hidden_dim, latent_dim=latent_dim), 
             decoder=ADULTDecoder(hidden_dim=hidden_dim, output_dims=output_dims))

Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.


In [9]:
# add numerical loss
he_loss = [nn.MSELoss()]
he_loss_weights = [1.]

# add categorical losses
for i in range(len(categorical_names)):
    he_loss.append(nn.CrossEntropyLoss())
    he_loss_weights.append(1./len(categorical_names))
    
# add metrics
metrics = {}
for i, cat_name in enumerate(categorical_names):
    metrics.update({f"output_{i+2}": AccuracyMetric()})

In [10]:
# compile model
he_ae.compile(optimizer=torch.optim.Adam(he_ae.parameters(), lr=1e-3), 
              loss=he_loss, 
              loss_weights=he_loss_weights,
              metrics=metrics)

In [11]:
BATCH_SIZE = 128
NUM_WORKERS = 4

# Define attribute types, required for datatype conversion.
feature_types = {"Age": int, "Capital Gain": int, "Capital Loss": int, "Hours per week": int}

# Autoencoder preprocessor and its inverse, fitted on the training split.
ae_preprocessor, ae_inv_preprocessor = get_he_preprocessor(X=X_train,
                                                           feature_names=adult.feature_names,
                                                           category_map=adult.category_map,
                                                           feature_types=feature_types)

# Encode both splits with the autoencoder preprocessor. This deliberately overwrites the arrays
# produced by the classifier's sklearn preprocessor: the autoencoder must be trained on the same
# encoding that `ae_inv_preprocessor` and the explainer use (previously a typo, `X_trian_ohe`,
# left the sklearn encoding in `X_train_ohe`).
X_train_ohe = ae_preprocessor(X_train)
X_test_ohe = ae_preprocessor(X_test)

# Autoencoder input: the full encoded row.
trainset_input = torch.tensor(X_train_ohe).float()

# Targets: the first `len(numerical_ids)` encoded columns for the MSE head (this slice assumes the
# encoding places numerical columns first, as `output_dims` does), followed by one label-encoded
# target per categorical feature for the cross-entropy heads.
trainset_outputs = [trainset_input[:, :len(numerical_ids)]]
trainset_outputs += [torch.tensor(X_train[:, cat_id]).long() for cat_id in categorical_ids]

trainset = TensorDataset(trainset_input, *trainset_outputs)
trainloader = DataLoader(trainset,
                         batch_size=BATCH_SIZE,
                         num_workers=NUM_WORKERS,
                         shuffle=True,
                         drop_last=True)


In [12]:
he_ae_dir = "pytorch/he_autoencoder/"
he_ae_path = os.path.join(he_ae_dir, "he_autoencoder_adult.pt")

# make sure the checkpoint directory exists
if not os.path.exists(he_ae_dir):
    os.makedirs(he_ae_dir)

# Reuse cached weights when available; otherwise train from scratch and cache the result.
if os.path.exists(he_ae_path):
    he_ae.load_weights(he_ae_path)
else:
    he_ae.fit(trainloader, epochs=50)
    he_ae.save_weights(he_ae_path)

In [13]:
# Reconstruct the encoded training set with the autoencoder. Inputs are moved to whichever device
# the model's parameters live on (instead of assuming a GPU via `.cuda()`), and no autograd graph
# is recorded for this full-dataset forward pass.
he_ae_device = next(he_ae.parameters()).device
with torch.no_grad():
    X_train_hat_ohe = np.concatenate(
        to_numpy(he_ae(torch.tensor(X_train_ohe).float().to(he_ae_device))),
        axis=1,
    )

# Map the reconstruction back to the original (label-encoded) feature space and show one row.
# `int` replaces the `np.int` alias, which was removed in NumPy 1.24.
X_train_hat = ae_inv_preprocessor(X_train_hat_ohe)
X_train_hat[0].astype(int)

array([ 45,   4,   4,   0,   6,   0,   2,   1, 166,  14,  59,   9])

In [15]:
# Mean squared reconstruction error over the numerical columns (the first `len(numerical_ids)`
# encoded columns), derived from the data rather than a hard-coded column count.
np.mean((X_train_ohe[:, :len(numerical_ids)] - X_train_hat_ohe[:, :len(numerical_ids)]) ** 2)

0.002064818837136055

### Counterfactual RL

#### Define dataset-specific attributes and constraints

In [None]:
# define immutable features
immutable_features = ['Marital Status', 'Relationship', 'Race', 'Sex']

# define ranges
ranges = {'Age': [-0.0, 1.0]}


# compute statistic for clamping
stats = get_statistics(X=X_train, 
                       preprocessor=ae_preprocessor, 
                       category_map=adult.category_map)

#### Define training callbacks

In [None]:
class RewardCallback(ExperienceCallback):
    """Experience callback that logs the mean reward of sampled counterfactuals to wandb.

    On iterations where ``step + update`` is a multiple of 100, the black-box ``predictor`` is
    queried on ``sample["X_cf"]`` and the explainer's ``reward_func`` is averaged against the
    targets ``sample["Y_t"]``.
    """

    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        should_log = (step + update) % 100 == 0
        if not should_log:
            return

        # NOTE(review): the counterfactuals are passed to `predictor` as-is; confirm that
        # `sample["X_cf"]` is in the raw input space and not the autoencoder's encoded space
        # (otherwise `model.params["decoder_inv_preprocessor"]` has to be applied first).
        target = sample["Y_t"]
        counterfactual = sample["X_cf"]
        prediction = predictor(counterfactual)

        reward_per_instance = model.params["reward_func"](prediction, target)
        wandb.log({"reward": np.mean(reward_per_instance)})

In [None]:
class DisplayLossCallback(TrainingCallback):
    """Training callback that logs the current training losses to wandb.

    Logging happens only on iterations where ``step + update`` is a multiple of 100.
    """

    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        if (step + update) % 100 != 0:
            return
        wandb.log(losses)

#### Define explainer

In [None]:
# Counterfactual RL explainer for tabular data, built on the pretrained heterogeneous
# autoencoder (as the latent space) and the black-box `predictor`.
explainer = CounterfactualRLTabular(
    # autoencoder, its pre-/post-processing and the black-box model
    encoder=he_ae.encoder,
    decoder=he_ae.decoder,
    latent_dim=latent_dim,
    encoder_preprocessor=ae_preprocessor,
    decoder_inv_preprocessor=ae_inv_preprocessor,
    predictor=predictor,
    # loss coefficients
    coeff_sparsity=0.5,
    coeff_consistency=0.5,
    # dataset description and constraints
    category_map=adult.category_map,
    feature_names=adult.feature_names,
    ranges=ranges,
    immutable_features=immutable_features,
    # only the loss-logging callback is registered; `RewardCallback` is an ExperienceCallback
    train_callbacks=[DisplayLossCallback()],
    weight_cat=1.0,
    weight_num=1.0,
    # training configuration
    backend="pytorch",
    train_steps=100000,
    batch_size=100,
    num_workers=4,
    seed=9,
)

#### Fit explainer

In [None]:
# Initialize a Weights & Biases run to collect the training logs.
wandb_project = "ADULT CounterfactualRL"
wandb.init(project=wandb_project)

try:
    # fit the explainer
    explainer = explainer.fit(X=X_train)
finally:
    # Close the run even if training fails or is interrupted, so that a re-run of this
    # cell starts a fresh run instead of logging into a dangling one.
    wandb.finish()

### Save explainer

In [None]:
# persist the fitted explainer under the path "cfrl_tabular"
explainer.save("cfrl_tabular")

### Load explainer

In [None]:
# reload the explainer from "cfrl_tabular"; the black-box `predictor` is passed in again at load time
explainer = CounterfactualRLTabular.load("cfrl_tabular", predictor=predictor)

#### Test explainer

In [None]:
# Test instances that the black-box classifier assigns to the positive class (label 1).
X_positive = X_test[predictor(X_test).argmax(axis=1) == 1]

# Explain the first 100 of them, asking for counterfactuals of target class 0.
X = X_positive[:100]
Y_t = np.array([0])

# Conditioning. NOTE(review): presumably restricts the change of "Age" to [0, 20] and the
# allowed "Workclass" values to the listed ones — confirm against the `explain` docs.
C = [{"Age": [0, 20], "Workclass": ["State-gov", "?", "Local-gov"]}]

In [None]:
# generate counterfactual instances for the selected positives, with target `Y_t` and conditioning `C`
explanation = explainer.explain(X, Y_t, C)

In [None]:
# Append the predicted class as an extra column, both to the originals and to the counterfactuals.
orig, cf = (
    np.concatenate([explanation.data[key]['X'], explanation.data[key]['class']], axis=1)
    for key in ('orig', 'cf')
)

# Feature names and category map extended with the label column, so the label is decoded too.
feature_names = adult.feature_names + ["Label"]
category_map = deepcopy(adult.category_map)
category_map[feature_names.index("Label")] = adult.target_names

# Replace the label encodings with their string values.
orig_pd, cf_pd = (
    pd.DataFrame(apply_category_mapping(rows, category_map), columns=feature_names)
    for rows in (orig, cf)
)

In [None]:
# first five original instances, with their predicted label
orig_pd.head(n=5)

In [None]:
# first five counterfactual instances, with their predicted label
cf_pd.head(n=5)

#### Diversity

In [None]:
# Diversity: explain a single positive instance (kept as a batch of one row) and request
# several counterfactuals for it.
X = np.reshape(X_positive[2], (1, -1))
explanation = explainer.explain(X, Y_t, C, diversity=True, num_samples=10, batch_size=100)

In [None]:
# Append the predicted class as an extra column, both to the original and to the counterfactuals.
orig, cf = (
    np.concatenate([explanation.data[key]['X'], explanation.data[key]['class']], axis=1)
    for key in ('orig', 'cf')
)

# Transform the label encodings to strings.
orig_pd, cf_pd = (
    pd.DataFrame(apply_category_mapping(rows, category_map), columns=feature_names)
    for rows in (orig, cf)
)

In [None]:
# the single original instance, with its predicted label
orig_pd.head(n=5)

In [None]:
# first five of the diverse counterfactuals generated for that instance
cf_pd.head(n=5)