In [1]:
import os
import wandb
import numpy as np
import pandas as pd
from copy import deepcopy
from typing import List, Tuple, Dict, Callable
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

from alibi.datasets import fetch_adult
from alibi.models.pytorch.autoencoder import HeAE
from alibi.models.pytorch.actor_critic import Actor, Critic
from alibi.models.pytorch.cfrl_models import ADULTEncoder, ADULTDecoder
from alibi.models.pytorch.metrics import AccuracyMetric

from alibi.explainers.cfrl_tabular import CounterfactualRLTabular
from alibi.explainers.cfrl_base import CounterfactualRLBase, Callback
from alibi.explainers.backends.cfrl_tabular import get_he_preprocessor, get_statistics, \
    get_conditional_vector, apply_category_mapping


%load_ext autoreload
%autoreload 2



### Train black-box classifier

In [2]:
# fetch adult dataset
adult = fetch_adult()

# separate columns in numerical and categorical
categorical_names = [adult.feature_names[i] for i in adult.category_map.keys()]
categorical_ids = list(adult.category_map.keys())

numerical_names = [name for i, name in enumerate(adult.feature_names) if i not in adult.category_map.keys()]
numerical_ids = [i for i in range(len(adult.feature_names)) if i not in adult.category_map.keys()]

# split data into train and test
X, Y = adult.data, adult.target
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=13)

In [3]:
# data preprocessor
num_transf = StandardScaler()
cat_transf = OneHotEncoder(
    categories=[range(len(x)) for x in adult.category_map.values()],
    handle_unknown="ignore"
)
preprocessor = ColumnTransformer(
    transformers=[
        ("num", num_transf, numerical_ids),
        ("cat", cat_transf, categorical_ids)
    ],
    sparse_threshold=0
)

In [4]:
preprocessor.fit(X_train)
X_train_ohe = preprocessor.transform(X_train)
X_test_ohe = preprocessor.transform(X_test)

In [5]:
clf = RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50, random_state=0)
clf.fit(X_train_ohe, Y_train)

RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50,
                       random_state=0)

In [6]:
# define prediction function
predictor = lambda x: clf.predict_proba(preprocessor.transform(x))

# compute accuracy
acc = accuracy_score(y_true=Y_test, y_pred=np.argmax(predictor(X_test), axis=1))
print("Accuracy: %.3f" % acc)

Accuracy: 0.862


### Train autoencoder

In [7]:
# define input dimension
input_dim = 57

# define hidden dim
hidden_dim = 128

# define latent dimension
latent_dim = 15

# output dims
output_dims = [len(numerical_ids)]
output_dims += [len(adult.category_map[cat_id]) for cat_id in categorical_ids]

In [8]:
# define autoencoder
he_ae = HeAE(encoder=ADULTEncoder(hidden_dim=hidden_dim, latent_dim=latent_dim), 
             decoder=ADULTDecoder(hidden_dim=hidden_dim, output_dims=output_dims))

Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.


In [9]:
# add numerical loss
he_loss = [nn.MSELoss()]
he_loss_weights = [1.]

# add categorical losses
for i in range(len(categorical_names)):
    he_loss.append(nn.CrossEntropyLoss())
    he_loss_weights.append(1./len(categorical_names))
    
# add metrics
metrics = {}
for i, cat_name in enumerate(categorical_names):
    metrics.update({f"output_{i+2}": AccuracyMetric()})

In [10]:
# compile model
he_ae.compile(optimizer=torch.optim.Adam(he_ae.parameters(), lr=1e-3), 
              loss=he_loss, 
              loss_weights=he_loss_weights,
              metrics=metrics)

In [11]:
BATCH_SIZE = 128
NUM_WORKERS = 4

# Define attribute types, required for datatype conversion.
feature_types = {"Age": int, "Capital Gain": int, "Capital Loss": int, "Hours per week": int}

# define data preprocessor and inverse preprocessor
ae_preprocessor, ae_inv_preprocessor = get_he_preprocessor(X=X_train,
                                                           feature_names=adult.feature_names,
                                                           category_map=adult.category_map,
                                                           feature_types=feature_types)

# transform to ohe
X_trian_ohe = ae_preprocessor(X_train)
X_test_ohe = ae_preprocessor(X_test)

# define train loader
trainset_input = torch.tensor(X_train_ohe).float()
trainset_outputs = [torch.tensor(X_train_ohe).float()[:, :len(numerical_ids)]]

for cat_id in categorical_ids:
    trainset_outputs.append(torch.tensor(X_train[:, cat_id]).long())       

trainset = TensorDataset(trainset_input, *trainset_outputs)
trainloader = DataLoader(trainset,
                         batch_size=BATCH_SIZE,
                         num_workers=NUM_WORKERS,
                         shuffle=True,
                         drop_last=True)


In [12]:
he_ae_dir = "pytorch/he_autoencoder/"
he_ae_path = os.path.join(he_ae_dir, "he_autoencoder_adult.pt")

if not os.path.exists(he_ae_dir):
    os.makedirs(he_ae_dir)

if not os.path.exists(he_ae_path):
    he_ae.fit(trainloader, epochs=50)
    he_ae.save_weights(he_ae_path)
else:
    # load the model
    he_ae.load_weights(he_ae_path)

  0%|          | 0/203 [00:00<?, ?it/s]

Epoch 0/50


100%|██████████| 203/203 [00:03<00:00, 63.97it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.2761	output_2_loss: 0.1410	output_3_loss: 0.1558	output_4_loss: 0.0872	output_5_loss: 0.2140	output_6_loss: 0.1232	output_7_loss: 0.0723	output_8_loss: 0.0448	output_9_loss: 0.0911	loss: 1.2055	output_2_accuracy: 0.6984	output_3_accuracy: 0.5750	output_4_accuracy: 0.7407	output_5_accuracy: 0.3784	output_6_accuracy: 0.6338	output_7_accuracy: 0.8545	output_8_accuracy: 0.8372	output_9_accuracy: 0.8583	
Epoch 1/50


100%|██████████| 203/203 [00:02<00:00, 81.67it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0363	output_2_loss: 0.0800	output_3_loss: 0.0687	output_4_loss: 0.0306	output_5_loss: 0.0964	output_6_loss: 0.0414	output_7_loss: 0.0358	output_8_loss: 0.0099	output_9_loss: 0.0451	loss: 0.4443	output_2_accuracy: 0.7845	output_3_accuracy: 0.8194	output_4_accuracy: 0.9185	output_5_accuracy: 0.7511	output_6_accuracy: 0.9093	output_7_accuracy: 0.9012	output_8_accuracy: 0.9806	output_9_accuracy: 0.9019	
Epoch 2/50


100%|██████████| 203/203 [00:02<00:00, 96.65it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0277	output_2_loss: 0.0416	output_3_loss: 0.0388	output_4_loss: 0.0165	output_5_loss: 0.0424	output_6_loss: 0.0276	output_7_loss: 0.0210	output_8_loss: 0.0061	output_9_loss: 0.0324	loss: 0.2542	output_2_accuracy: 0.9021	output_3_accuracy: 0.9112	output_4_accuracy: 0.9617	output_5_accuracy: 0.9039	output_6_accuracy: 0.9388	output_7_accuracy: 0.9510	output_8_accuracy: 0.9883	output_9_accuracy: 0.9232	
Epoch 3/50


100%|██████████| 203/203 [00:02<00:00, 96.42it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0214	output_2_loss: 0.0254	output_3_loss: 0.0238	output_4_loss: 0.0104	output_5_loss: 0.0247	output_6_loss: 0.0193	output_7_loss: 0.0149	output_8_loss: 0.0044	output_9_loss: 0.0270	loss: 0.1714	output_2_accuracy: 0.9436	output_3_accuracy: 0.9513	output_4_accuracy: 0.9781	output_5_accuracy: 0.9520	output_6_accuracy: 0.9567	output_7_accuracy: 0.9646	output_8_accuracy: 0.9914	output_9_accuracy: 0.9320	
Epoch 4/50


100%|██████████| 203/203 [00:02<00:00, 80.77it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0166	output_2_loss: 0.0187	output_3_loss: 0.0179	output_4_loss: 0.0083	output_5_loss: 0.0180	output_6_loss: 0.0150	output_7_loss: 0.0117	output_8_loss: 0.0032	output_9_loss: 0.0234	loss: 0.1327	output_2_accuracy: 0.9597	output_3_accuracy: 0.9630	output_4_accuracy: 0.9809	output_5_accuracy: 0.9656	output_6_accuracy: 0.9667	output_7_accuracy: 0.9715	output_8_accuracy: 0.9942	output_9_accuracy: 0.9410	
Epoch 5/50


100%|██████████| 203/203 [00:02<00:00, 73.39it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0140	output_2_loss: 0.0147	output_3_loss: 0.0141	output_4_loss: 0.0068	output_5_loss: 0.0142	output_6_loss: 0.0121	output_7_loss: 0.0094	output_8_loss: 0.0025	output_9_loss: 0.0205	loss: 0.1083	output_2_accuracy: 0.9698	output_3_accuracy: 0.9713	output_4_accuracy: 0.9852	output_5_accuracy: 0.9730	output_6_accuracy: 0.9733	output_7_accuracy: 0.9771	output_8_accuracy: 0.9955	output_9_accuracy: 0.9493	
Epoch 6/50


100%|██████████| 203/203 [00:02<00:00, 72.11it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0121	output_2_loss: 0.0122	output_3_loss: 0.0116	output_4_loss: 0.0057	output_5_loss: 0.0122	output_6_loss: 0.0097	output_7_loss: 0.0078	output_8_loss: 0.0021	output_9_loss: 0.0176	loss: 0.0911	output_2_accuracy: 0.9746	output_3_accuracy: 0.9777	output_4_accuracy: 0.9880	output_5_accuracy: 0.9768	output_6_accuracy: 0.9789	output_7_accuracy: 0.9816	output_8_accuracy: 0.9964	output_9_accuracy: 0.9578	
Epoch 7/50


100%|██████████| 203/203 [00:02<00:00, 77.98it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0108	output_2_loss: 0.0106	output_3_loss: 0.0098	output_4_loss: 0.0049	output_5_loss: 0.0108	output_6_loss: 0.0082	output_7_loss: 0.0068	output_8_loss: 0.0017	output_9_loss: 0.0155	loss: 0.0790	output_2_accuracy: 0.9787	output_3_accuracy: 0.9821	output_4_accuracy: 0.9898	output_5_accuracy: 0.9788	output_6_accuracy: 0.9821	output_7_accuracy: 0.9837	output_8_accuracy: 0.9971	output_9_accuracy: 0.9628	
Epoch 8/50


100%|██████████| 203/203 [00:02<00:00, 85.82it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0095	output_2_loss: 0.0094	output_3_loss: 0.0086	output_4_loss: 0.0043	output_5_loss: 0.0098	output_6_loss: 0.0073	output_7_loss: 0.0061	output_8_loss: 0.0015	output_9_loss: 0.0138	loss: 0.0703	output_2_accuracy: 0.9805	output_3_accuracy: 0.9844	output_4_accuracy: 0.9903	output_5_accuracy: 0.9803	output_6_accuracy: 0.9836	output_7_accuracy: 0.9860	output_8_accuracy: 0.9972	output_9_accuracy: 0.9672	
Epoch 9/50


100%|██████████| 203/203 [00:02<00:00, 88.42it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0088	output_2_loss: 0.0084	output_3_loss: 0.0076	output_4_loss: 0.0040	output_5_loss: 0.0090	output_6_loss: 0.0065	output_7_loss: 0.0055	output_8_loss: 0.0013	output_9_loss: 0.0126	loss: 0.0636	output_2_accuracy: 0.9832	output_3_accuracy: 0.9860	output_4_accuracy: 0.9910	output_5_accuracy: 0.9825	output_6_accuracy: 0.9858	output_7_accuracy: 0.9870	output_8_accuracy: 0.9978	output_9_accuracy: 0.9706	
Epoch 10/50


100%|██████████| 203/203 [00:02<00:00, 87.49it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0081	output_2_loss: 0.0075	output_3_loss: 0.0069	output_4_loss: 0.0037	output_5_loss: 0.0081	output_6_loss: 0.0060	output_7_loss: 0.0051	output_8_loss: 0.0011	output_9_loss: 0.0114	loss: 0.0579	output_2_accuracy: 0.9851	output_3_accuracy: 0.9873	output_4_accuracy: 0.9915	output_5_accuracy: 0.9842	output_6_accuracy: 0.9868	output_7_accuracy: 0.9885	output_8_accuracy: 0.9983	output_9_accuracy: 0.9737	
Epoch 11/50


100%|██████████| 203/203 [00:02<00:00, 85.97it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0075	output_2_loss: 0.0069	output_3_loss: 0.0063	output_4_loss: 0.0034	output_5_loss: 0.0075	output_6_loss: 0.0056	output_7_loss: 0.0046	output_8_loss: 0.0010	output_9_loss: 0.0106	loss: 0.0533	output_2_accuracy: 0.9862	output_3_accuracy: 0.9881	output_4_accuracy: 0.9925	output_5_accuracy: 0.9853	output_6_accuracy: 0.9879	output_7_accuracy: 0.9894	output_8_accuracy: 0.9981	output_9_accuracy: 0.9759	
Epoch 12/50


100%|██████████| 203/203 [00:02<00:00, 77.70it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0070	output_2_loss: 0.0063	output_3_loss: 0.0058	output_4_loss: 0.0031	output_5_loss: 0.0069	output_6_loss: 0.0052	output_7_loss: 0.0042	output_8_loss: 0.0009	output_9_loss: 0.0097	loss: 0.0490	output_2_accuracy: 0.9873	output_3_accuracy: 0.9890	output_4_accuracy: 0.9928	output_5_accuracy: 0.9864	output_6_accuracy: 0.9883	output_7_accuracy: 0.9903	output_8_accuracy: 0.9983	output_9_accuracy: 0.9784	
Epoch 13/50


100%|██████████| 203/203 [00:02<00:00, 71.33it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0066	output_2_loss: 0.0057	output_3_loss: 0.0053	output_4_loss: 0.0029	output_5_loss: 0.0064	output_6_loss: 0.0048	output_7_loss: 0.0037	output_8_loss: 0.0008	output_9_loss: 0.0090	loss: 0.0452	output_2_accuracy: 0.9887	output_3_accuracy: 0.9895	output_4_accuracy: 0.9933	output_5_accuracy: 0.9880	output_6_accuracy: 0.9894	output_7_accuracy: 0.9920	output_8_accuracy: 0.9987	output_9_accuracy: 0.9796	
Epoch 14/50


100%|██████████| 203/203 [00:02<00:00, 71.47it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0064	output_2_loss: 0.0053	output_3_loss: 0.0048	output_4_loss: 0.0028	output_5_loss: 0.0059	output_6_loss: 0.0045	output_7_loss: 0.0034	output_8_loss: 0.0007	output_9_loss: 0.0084	loss: 0.0422	output_2_accuracy: 0.9891	output_3_accuracy: 0.9909	output_4_accuracy: 0.9938	output_5_accuracy: 0.9885	output_6_accuracy: 0.9900	output_7_accuracy: 0.9932	output_8_accuracy: 0.9987	output_9_accuracy: 0.9809	
Epoch 15/50


100%|██████████| 203/203 [00:02<00:00, 77.68it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0060	output_2_loss: 0.0048	output_3_loss: 0.0045	output_4_loss: 0.0027	output_5_loss: 0.0054	output_6_loss: 0.0042	output_7_loss: 0.0030	output_8_loss: 0.0007	output_9_loss: 0.0078	loss: 0.0391	output_2_accuracy: 0.9905	output_3_accuracy: 0.9914	output_4_accuracy: 0.9943	output_5_accuracy: 0.9905	output_6_accuracy: 0.9907	output_7_accuracy: 0.9940	output_8_accuracy: 0.9990	output_9_accuracy: 0.9818	
Epoch 16/50


100%|██████████| 203/203 [00:02<00:00, 85.37it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0056	output_2_loss: 0.0045	output_3_loss: 0.0041	output_4_loss: 0.0024	output_5_loss: 0.0050	output_6_loss: 0.0040	output_7_loss: 0.0028	output_8_loss: 0.0006	output_9_loss: 0.0073	loss: 0.0363	output_2_accuracy: 0.9910	output_3_accuracy: 0.9920	output_4_accuracy: 0.9947	output_5_accuracy: 0.9905	output_6_accuracy: 0.9910	output_7_accuracy: 0.9943	output_8_accuracy: 0.9991	output_9_accuracy: 0.9830	
Epoch 17/50


100%|██████████| 203/203 [00:02<00:00, 85.29it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0055	output_2_loss: 0.0042	output_3_loss: 0.0037	output_4_loss: 0.0023	output_5_loss: 0.0045	output_6_loss: 0.0037	output_7_loss: 0.0026	output_8_loss: 0.0005	output_9_loss: 0.0069	loss: 0.0339	output_2_accuracy: 0.9922	output_3_accuracy: 0.9928	output_4_accuracy: 0.9948	output_5_accuracy: 0.9916	output_6_accuracy: 0.9916	output_7_accuracy: 0.9950	output_8_accuracy: 0.9992	output_9_accuracy: 0.9849	
Epoch 18/50


100%|██████████| 203/203 [00:02<00:00, 86.50it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0052	output_2_loss: 0.0038	output_3_loss: 0.0035	output_4_loss: 0.0022	output_5_loss: 0.0042	output_6_loss: 0.0035	output_7_loss: 0.0024	output_8_loss: 0.0005	output_9_loss: 0.0064	loss: 0.0317	output_2_accuracy: 0.9925	output_3_accuracy: 0.9933	output_4_accuracy: 0.9955	output_5_accuracy: 0.9926	output_6_accuracy: 0.9923	output_7_accuracy: 0.9949	output_8_accuracy: 0.9993	output_9_accuracy: 0.9859	
Epoch 19/50


100%|██████████| 203/203 [00:02<00:00, 85.47it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0049	output_2_loss: 0.0036	output_3_loss: 0.0032	output_4_loss: 0.0021	output_5_loss: 0.0037	output_6_loss: 0.0032	output_7_loss: 0.0022	output_8_loss: 0.0005	output_9_loss: 0.0061	loss: 0.0296	output_2_accuracy: 0.9933	output_3_accuracy: 0.9942	output_4_accuracy: 0.9957	output_5_accuracy: 0.9935	output_6_accuracy: 0.9933	output_7_accuracy: 0.9954	output_8_accuracy: 0.9994	output_9_accuracy: 0.9861	
Epoch 20/50


100%|██████████| 203/203 [00:02<00:00, 92.33it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0049	output_2_loss: 0.0033	output_3_loss: 0.0030	output_4_loss: 0.0020	output_5_loss: 0.0035	output_6_loss: 0.0031	output_7_loss: 0.0021	output_8_loss: 0.0004	output_9_loss: 0.0057	loss: 0.0281	output_2_accuracy: 0.9939	output_3_accuracy: 0.9942	output_4_accuracy: 0.9960	output_5_accuracy: 0.9938	output_6_accuracy: 0.9931	output_7_accuracy: 0.9955	output_8_accuracy: 0.9995	output_9_accuracy: 0.9876	
Epoch 21/50


100%|██████████| 203/203 [00:02<00:00, 95.75it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0047	output_2_loss: 0.0031	output_3_loss: 0.0028	output_4_loss: 0.0019	output_5_loss: 0.0031	output_6_loss: 0.0030	output_7_loss: 0.0020	output_8_loss: 0.0004	output_9_loss: 0.0053	loss: 0.0263	output_2_accuracy: 0.9939	output_3_accuracy: 0.9950	output_4_accuracy: 0.9964	output_5_accuracy: 0.9945	output_6_accuracy: 0.9938	output_7_accuracy: 0.9963	output_8_accuracy: 0.9996	output_9_accuracy: 0.9885	
Epoch 22/50


100%|██████████| 203/203 [00:02<00:00, 94.91it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0048	output_2_loss: 0.0028	output_3_loss: 0.0026	output_4_loss: 0.0018	output_5_loss: 0.0029	output_6_loss: 0.0027	output_7_loss: 0.0018	output_8_loss: 0.0004	output_9_loss: 0.0051	loss: 0.0249	output_2_accuracy: 0.9950	output_3_accuracy: 0.9952	output_4_accuracy: 0.9964	output_5_accuracy: 0.9943	output_6_accuracy: 0.9943	output_7_accuracy: 0.9963	output_8_accuracy: 0.9997	output_9_accuracy: 0.9887	
Epoch 23/50


100%|██████████| 203/203 [00:02<00:00, 84.21it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0044	output_2_loss: 0.0026	output_3_loss: 0.0024	output_4_loss: 0.0017	output_5_loss: 0.0027	output_6_loss: 0.0025	output_7_loss: 0.0017	output_8_loss: 0.0004	output_9_loss: 0.0048	loss: 0.0232	output_2_accuracy: 0.9954	output_3_accuracy: 0.9957	output_4_accuracy: 0.9967	output_5_accuracy: 0.9951	output_6_accuracy: 0.9946	output_7_accuracy: 0.9969	output_8_accuracy: 0.9995	output_9_accuracy: 0.9897	
Epoch 24/50


100%|██████████| 203/203 [00:02<00:00, 78.87it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0042	output_2_loss: 0.0025	output_3_loss: 0.0022	output_4_loss: 0.0016	output_5_loss: 0.0025	output_6_loss: 0.0024	output_7_loss: 0.0017	output_8_loss: 0.0004	output_9_loss: 0.0046	loss: 0.0219	output_2_accuracy: 0.9959	output_3_accuracy: 0.9964	output_4_accuracy: 0.9971	output_5_accuracy: 0.9956	output_6_accuracy: 0.9953	output_7_accuracy: 0.9972	output_8_accuracy: 0.9994	output_9_accuracy: 0.9897	
Epoch 25/50


100%|██████████| 203/203 [00:02<00:00, 94.84it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0042	output_2_loss: 0.0023	output_3_loss: 0.0021	output_4_loss: 0.0015	output_5_loss: 0.0023	output_6_loss: 0.0022	output_7_loss: 0.0015	output_8_loss: 0.0003	output_9_loss: 0.0042	loss: 0.0206	output_2_accuracy: 0.9961	output_3_accuracy: 0.9964	output_4_accuracy: 0.9970	output_5_accuracy: 0.9957	output_6_accuracy: 0.9957	output_7_accuracy: 0.9978	output_8_accuracy: 0.9997	output_9_accuracy: 0.9910	
Epoch 26/50


100%|██████████| 203/203 [00:02<00:00, 93.28it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0041	output_2_loss: 0.0021	output_3_loss: 0.0020	output_4_loss: 0.0015	output_5_loss: 0.0021	output_6_loss: 0.0020	output_7_loss: 0.0015	output_8_loss: 0.0003	output_9_loss: 0.0040	loss: 0.0196	output_2_accuracy: 0.9963	output_3_accuracy: 0.9967	output_4_accuracy: 0.9976	output_5_accuracy: 0.9968	output_6_accuracy: 0.9962	output_7_accuracy: 0.9977	output_8_accuracy: 0.9997	output_9_accuracy: 0.9915	
Epoch 27/50


100%|██████████| 203/203 [00:02<00:00, 95.63it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0038	output_2_loss: 0.0020	output_3_loss: 0.0019	output_4_loss: 0.0013	output_5_loss: 0.0020	output_6_loss: 0.0019	output_7_loss: 0.0013	output_8_loss: 0.0003	output_9_loss: 0.0036	loss: 0.0182	output_2_accuracy: 0.9966	output_3_accuracy: 0.9968	output_4_accuracy: 0.9977	output_5_accuracy: 0.9965	output_6_accuracy: 0.9966	output_7_accuracy: 0.9977	output_8_accuracy: 0.9997	output_9_accuracy: 0.9920	
Epoch 28/50


100%|██████████| 203/203 [00:02<00:00, 95.08it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0038	output_2_loss: 0.0019	output_3_loss: 0.0017	output_4_loss: 0.0012	output_5_loss: 0.0019	output_6_loss: 0.0017	output_7_loss: 0.0013	output_8_loss: 0.0003	output_9_loss: 0.0034	loss: 0.0172	output_2_accuracy: 0.9967	output_3_accuracy: 0.9972	output_4_accuracy: 0.9980	output_5_accuracy: 0.9966	output_6_accuracy: 0.9970	output_7_accuracy: 0.9980	output_8_accuracy: 0.9998	output_9_accuracy: 0.9928	
Epoch 29/50


100%|██████████| 203/203 [00:02<00:00, 91.94it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0037	output_2_loss: 0.0017	output_3_loss: 0.0016	output_4_loss: 0.0012	output_5_loss: 0.0017	output_6_loss: 0.0016	output_7_loss: 0.0012	output_8_loss: 0.0002	output_9_loss: 0.0032	loss: 0.0163	output_2_accuracy: 0.9970	output_3_accuracy: 0.9975	output_4_accuracy: 0.9978	output_5_accuracy: 0.9970	output_6_accuracy: 0.9971	output_7_accuracy: 0.9982	output_8_accuracy: 0.9999	output_9_accuracy: 0.9931	
Epoch 30/50


100%|██████████| 203/203 [00:02<00:00, 73.12it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0035	output_2_loss: 0.0016	output_3_loss: 0.0015	output_4_loss: 0.0012	output_5_loss: 0.0016	output_6_loss: 0.0015	output_7_loss: 0.0011	output_8_loss: 0.0002	output_9_loss: 0.0030	loss: 0.0154	output_2_accuracy: 0.9976	output_3_accuracy: 0.9972	output_4_accuracy: 0.9977	output_5_accuracy: 0.9970	output_6_accuracy: 0.9973	output_7_accuracy: 0.9984	output_8_accuracy: 0.9998	output_9_accuracy: 0.9941	
Epoch 31/50


100%|██████████| 203/203 [00:02<00:00, 72.77it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0036	output_2_loss: 0.0015	output_3_loss: 0.0014	output_4_loss: 0.0011	output_5_loss: 0.0015	output_6_loss: 0.0014	output_7_loss: 0.0011	output_8_loss: 0.0002	output_9_loss: 0.0028	loss: 0.0146	output_2_accuracy: 0.9975	output_3_accuracy: 0.9976	output_4_accuracy: 0.9982	output_5_accuracy: 0.9975	output_6_accuracy: 0.9977	output_7_accuracy: 0.9984	output_8_accuracy: 0.9999	output_9_accuracy: 0.9948	
Epoch 32/50


100%|██████████| 203/203 [00:02<00:00, 72.25it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0035	output_2_loss: 0.0015	output_3_loss: 0.0014	output_4_loss: 0.0011	output_5_loss: 0.0015	output_6_loss: 0.0014	output_7_loss: 0.0010	output_8_loss: 0.0002	output_9_loss: 0.0026	loss: 0.0141	output_2_accuracy: 0.9976	output_3_accuracy: 0.9974	output_4_accuracy: 0.9982	output_5_accuracy: 0.9977	output_6_accuracy: 0.9975	output_7_accuracy: 0.9984	output_8_accuracy: 0.9998	output_9_accuracy: 0.9952	
Epoch 33/50


100%|██████████| 203/203 [00:03<00:00, 57.62it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0034	output_2_loss: 0.0014	output_3_loss: 0.0013	output_4_loss: 0.0010	output_5_loss: 0.0014	output_6_loss: 0.0013	output_7_loss: 0.0010	output_8_loss: 0.0002	output_9_loss: 0.0025	loss: 0.0134	output_2_accuracy: 0.9976	output_3_accuracy: 0.9973	output_4_accuracy: 0.9983	output_5_accuracy: 0.9981	output_6_accuracy: 0.9980	output_7_accuracy: 0.9986	output_8_accuracy: 0.9999	output_9_accuracy: 0.9954	
Epoch 34/50


100%|██████████| 203/203 [00:02<00:00, 81.06it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0032	output_2_loss: 0.0013	output_3_loss: 0.0013	output_4_loss: 0.0009	output_5_loss: 0.0013	output_6_loss: 0.0012	output_7_loss: 0.0010	output_8_loss: 0.0002	output_9_loss: 0.0023	loss: 0.0125	output_2_accuracy: 0.9982	output_3_accuracy: 0.9979	output_4_accuracy: 0.9984	output_5_accuracy: 0.9980	output_6_accuracy: 0.9980	output_7_accuracy: 0.9988	output_8_accuracy: 0.9997	output_9_accuracy: 0.9963	
Epoch 35/50


100%|██████████| 203/203 [00:02<00:00, 70.60it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0031	output_2_loss: 0.0012	output_3_loss: 0.0012	output_4_loss: 0.0009	output_5_loss: 0.0012	output_6_loss: 0.0011	output_7_loss: 0.0009	output_8_loss: 0.0002	output_9_loss: 0.0022	loss: 0.0119	output_2_accuracy: 0.9979	output_3_accuracy: 0.9980	output_4_accuracy: 0.9983	output_5_accuracy: 0.9983	output_6_accuracy: 0.9982	output_7_accuracy: 0.9988	output_8_accuracy: 0.9998	output_9_accuracy: 0.9958	
Epoch 36/50


100%|██████████| 203/203 [00:02<00:00, 81.29it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0030	output_2_loss: 0.0011	output_3_loss: 0.0011	output_4_loss: 0.0009	output_5_loss: 0.0012	output_6_loss: 0.0011	output_7_loss: 0.0009	output_8_loss: 0.0002	output_9_loss: 0.0020	loss: 0.0114	output_2_accuracy: 0.9980	output_3_accuracy: 0.9978	output_4_accuracy: 0.9988	output_5_accuracy: 0.9984	output_6_accuracy: 0.9985	output_7_accuracy: 0.9988	output_8_accuracy: 0.9998	output_9_accuracy: 0.9965	
Epoch 37/50


100%|██████████| 203/203 [00:02<00:00, 82.43it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0030	output_2_loss: 0.0011	output_3_loss: 0.0011	output_4_loss: 0.0008	output_5_loss: 0.0011	output_6_loss: 0.0010	output_7_loss: 0.0008	output_8_loss: 0.0002	output_9_loss: 0.0020	loss: 0.0110	output_2_accuracy: 0.9987	output_3_accuracy: 0.9981	output_4_accuracy: 0.9987	output_5_accuracy: 0.9985	output_6_accuracy: 0.9986	output_7_accuracy: 0.9988	output_8_accuracy: 0.9999	output_9_accuracy: 0.9966	
Epoch 38/50


100%|██████████| 203/203 [00:02<00:00, 74.52it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0029	output_2_loss: 0.0010	output_3_loss: 0.0011	output_4_loss: 0.0008	output_5_loss: 0.0011	output_6_loss: 0.0009	output_7_loss: 0.0008	output_8_loss: 0.0001	output_9_loss: 0.0018	loss: 0.0106	output_2_accuracy: 0.9986	output_3_accuracy: 0.9983	output_4_accuracy: 0.9989	output_5_accuracy: 0.9987	output_6_accuracy: 0.9984	output_7_accuracy: 0.9990	output_8_accuracy: 0.9999	output_9_accuracy: 0.9970	
Epoch 39/50


100%|██████████| 203/203 [00:02<00:00, 80.35it/s] 
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0029	output_2_loss: 0.0011	output_3_loss: 0.0010	output_4_loss: 0.0008	output_5_loss: 0.0010	output_6_loss: 0.0009	output_7_loss: 0.0008	output_8_loss: 0.0002	output_9_loss: 0.0018	loss: 0.0104	output_2_accuracy: 0.9983	output_3_accuracy: 0.9980	output_4_accuracy: 0.9984	output_5_accuracy: 0.9984	output_6_accuracy: 0.9987	output_7_accuracy: 0.9992	output_8_accuracy: 0.9997	output_9_accuracy: 0.9968	
Epoch 40/50


100%|██████████| 203/203 [00:03<00:00, 65.62it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0028	output_2_loss: 0.0009	output_3_loss: 0.0009	output_4_loss: 0.0007	output_5_loss: 0.0009	output_6_loss: 0.0009	output_7_loss: 0.0008	output_8_loss: 0.0001	output_9_loss: 0.0017	loss: 0.0097	output_2_accuracy: 0.9985	output_3_accuracy: 0.9984	output_4_accuracy: 0.9990	output_5_accuracy: 0.9987	output_6_accuracy: 0.9985	output_7_accuracy: 0.9989	output_8_accuracy: 0.9999	output_9_accuracy: 0.9970	
Epoch 41/50


100%|██████████| 203/203 [00:02<00:00, 71.05it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0028	output_2_loss: 0.0009	output_3_loss: 0.0009	output_4_loss: 0.0007	output_5_loss: 0.0009	output_6_loss: 0.0008	output_7_loss: 0.0008	output_8_loss: 0.0001	output_9_loss: 0.0016	loss: 0.0094	output_2_accuracy: 0.9988	output_3_accuracy: 0.9987	output_4_accuracy: 0.9988	output_5_accuracy: 0.9985	output_6_accuracy: 0.9987	output_7_accuracy: 0.9990	output_8_accuracy: 0.9999	output_9_accuracy: 0.9972	
Epoch 42/50


100%|██████████| 203/203 [00:02<00:00, 70.73it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0025	output_2_loss: 0.0009	output_3_loss: 0.0008	output_4_loss: 0.0007	output_5_loss: 0.0009	output_6_loss: 0.0008	output_7_loss: 0.0006	output_8_loss: 0.0001	output_9_loss: 0.0015	loss: 0.0087	output_2_accuracy: 0.9988	output_3_accuracy: 0.9985	output_4_accuracy: 0.9988	output_5_accuracy: 0.9989	output_6_accuracy: 0.9992	output_7_accuracy: 0.9993	output_8_accuracy: 1.0000	output_9_accuracy: 0.9973	
Epoch 43/50


100%|██████████| 203/203 [00:03<00:00, 67.42it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0028	output_2_loss: 0.0008	output_3_loss: 0.0008	output_4_loss: 0.0006	output_5_loss: 0.0009	output_6_loss: 0.0008	output_7_loss: 0.0007	output_8_loss: 0.0002	output_9_loss: 0.0014	loss: 0.0092	output_2_accuracy: 0.9986	output_3_accuracy: 0.9983	output_4_accuracy: 0.9992	output_5_accuracy: 0.9986	output_6_accuracy: 0.9985	output_7_accuracy: 0.9992	output_8_accuracy: 0.9997	output_9_accuracy: 0.9975	
Epoch 44/50


100%|██████████| 203/203 [00:03<00:00, 60.86it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0026	output_2_loss: 0.0008	output_3_loss: 0.0008	output_4_loss: 0.0006	output_5_loss: 0.0008	output_6_loss: 0.0007	output_7_loss: 0.0006	output_8_loss: 0.0001	output_9_loss: 0.0013	loss: 0.0084	output_2_accuracy: 0.9989	output_3_accuracy: 0.9985	output_4_accuracy: 0.9989	output_5_accuracy: 0.9988	output_6_accuracy: 0.9988	output_7_accuracy: 0.9991	output_8_accuracy: 0.9998	output_9_accuracy: 0.9975	
Epoch 45/50


100%|██████████| 203/203 [00:02<00:00, 77.57it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0024	output_2_loss: 0.0008	output_3_loss: 0.0008	output_4_loss: 0.0007	output_5_loss: 0.0008	output_6_loss: 0.0007	output_7_loss: 0.0006	output_8_loss: 0.0001	output_9_loss: 0.0014	loss: 0.0082	output_2_accuracy: 0.9989	output_3_accuracy: 0.9985	output_4_accuracy: 0.9990	output_5_accuracy: 0.9990	output_6_accuracy: 0.9990	output_7_accuracy: 0.9990	output_8_accuracy: 0.9998	output_9_accuracy: 0.9979	
Epoch 46/50


100%|██████████| 203/203 [00:02<00:00, 74.09it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0025	output_2_loss: 0.0007	output_3_loss: 0.0007	output_4_loss: 0.0005	output_5_loss: 0.0007	output_6_loss: 0.0007	output_7_loss: 0.0006	output_8_loss: 0.0001	output_9_loss: 0.0012	loss: 0.0077	output_2_accuracy: 0.9990	output_3_accuracy: 0.9987	output_4_accuracy: 0.9991	output_5_accuracy: 0.9991	output_6_accuracy: 0.9989	output_7_accuracy: 0.9993	output_8_accuracy: 1.0000	output_9_accuracy: 0.9979	
Epoch 47/50


100%|██████████| 203/203 [00:02<00:00, 79.26it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0023	output_2_loss: 0.0007	output_3_loss: 0.0007	output_4_loss: 0.0006	output_5_loss: 0.0007	output_6_loss: 0.0006	output_7_loss: 0.0006	output_8_loss: 0.0001	output_9_loss: 0.0012	loss: 0.0073	output_2_accuracy: 0.9990	output_3_accuracy: 0.9990	output_4_accuracy: 0.9991	output_5_accuracy: 0.9992	output_6_accuracy: 0.9992	output_7_accuracy: 0.9994	output_8_accuracy: 1.0000	output_9_accuracy: 0.9980	
Epoch 48/50


100%|██████████| 203/203 [00:02<00:00, 75.22it/s]
  0%|          | 0/203 [00:00<?, ?it/s]

output_1_loss: 0.0024	output_2_loss: 0.0007	output_3_loss: 0.0006	output_4_loss: 0.0005	output_5_loss: 0.0007	output_6_loss: 0.0006	output_7_loss: 0.0006	output_8_loss: 0.0001	output_9_loss: 0.0011	loss: 0.0073	output_2_accuracy: 0.9992	output_3_accuracy: 0.9988	output_4_accuracy: 0.9993	output_5_accuracy: 0.9991	output_6_accuracy: 0.9990	output_7_accuracy: 0.9993	output_8_accuracy: 0.9999	output_9_accuracy: 0.9980	
Epoch 49/50


100%|██████████| 203/203 [00:03<00:00, 67.23it/s] 

output_1_loss: 0.0025	output_2_loss: 0.0006	output_3_loss: 0.0006	output_4_loss: 0.0005	output_5_loss: 0.0007	output_6_loss: 0.0006	output_7_loss: 0.0005	output_8_loss: 0.0001	output_9_loss: 0.0011	loss: 0.0072	output_2_accuracy: 0.9990	output_3_accuracy: 0.9989	output_4_accuracy: 0.9992	output_5_accuracy: 0.9992	output_6_accuracy: 0.9991	output_7_accuracy: 0.9994	output_8_accuracy: 1.0000	output_9_accuracy: 0.9982	





### Counterfactual RL

#### Define dataset specifi attributes and constraints

In [13]:
# define immutable features
immutable_features = ['Marital Status', 'Relationship', 'Race', 'Sex']

# define ranges
ranges = {'Age': [-0.0, 1.0]}


# compute statistic for clamping
stats = get_statistics(X=X_train, 
                       preprocessor=ae_preprocessor, 
                       category_map=adult.category_map)

#### Define training callbacks

In [14]:
class RewardCallback(Callback):
    def __call__(self,
                 step: int, 
                 update: int, 
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        
        if (step + update) % 100 != 0:
            return
        
        # get the counterfactual and target
        Y_t = sample["Y_t"]
        X_cf = sample["X_cf"]
        
        # get prediction label
        Y_m_cf = predictor(X_cf)
        
        # compute reward
        reward = np.mean(model.params["reward_func"](Y_m_cf, Y_t))
        wandb.log({"reward": reward})

In [15]:
class DisplayLossCallback(Callback):
    def __call__(self,
                 step: int, 
                 update: int, 
                 model: CounterfactualRLBase,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        # log training losses
        if (step + update) % 100 == 0:
            wandb.log(losses)

#### Define explainer

In [16]:
# define ddpg
explainer = CounterfactualRLTabular(encoder=he_ae.encoder,
                                    decoder=he_ae.decoder,
                                    latent_dim=latent_dim,
                                    encoder_preprocessor=ae_preprocessor,
                                    decoder_inv_preprocessor=ae_inv_preprocessor,
                                    predictor=predictor,
                                    coeff_sparsity=0.5,
                                    coeff_consistency=0.5,
                                    category_map=adult.category_map,
                                    feature_names=adult.feature_names,
                                    ranges=ranges,
                                    immutable_features=immutable_features,
                                    train_callbacks=[DisplayLossCallback(), RewardCallback()],
                                    backend="pytorch",
                                    train_steps=1000,
                                    batch_size=100,
                                    num_workers=4,
                                    seed=9)

Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
The following keys are incorrect: train_callbacks


#### Fit explainer

In [17]:
#initialize wandb
wandb_project = "ADULT CounterfactualRL"
wandb.init(project=wandb_project)

# fit the explainers
explainer = explainer.fit(X=X_train)

# close wandb
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mrfs[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.11.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


100%|██████████| 1000/1000 [00:27<00:00, 35.95it/s]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

### Save explainer

In [18]:
explainer.save("cfrl_tabular")

### Load explainer

In [19]:
explainer = CounterfactualRLTabular.load("cfrl_tabular", predictor=predictor)

#### Test explainer

In [20]:
# select some positive examples
X_positive = X_test[np.argmax(predictor(X_test), axis=1) == 1]


X = X_positive[:100]
Y_t = np.array([0])
C = [{"Age": [0, 20], "Workclass": ["State-gov", "?", "Local-gov"]}]

In [21]:
# generate counterfactual instances
explanation = explainer.explain(X, Y_t, C)

In [22]:
# concat labels to the original instances
orig = np.concatenate(
    [explanation.data['orig']['X'], explanation.data['orig']['class']],
    axis=1
)

# concat labels to the counterfactual instances
cf = np.concatenate(
    [explanation.data['cf']['X'], explanation.data['cf']['class']],
    axis=1
)

# define new feature names and category map by including the label
feature_names = adult.feature_names + ["Label"]
category_map = deepcopy(adult.category_map)
category_map.update({feature_names.index("Label"): adult.target_names})

# replace label encodings with strings
orig_pd = pd.DataFrame(
    apply_category_mapping(orig, category_map),
    columns=feature_names
)

cf_pd = pd.DataFrame(
    apply_category_mapping(cf, category_map),
    columns=feature_names
)

In [23]:
orig_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,60,Private,High School grad,Married,Blue-Collar,Husband,White,Male,7298,0,40,United-States,>50K
1,35,Private,High School grad,Married,White-Collar,Husband,White,Male,7688,0,50,United-States,>50K
2,39,State-gov,Masters,Married,Professional,Wife,White,Female,5178,0,38,United-States,>50K
3,44,Self-emp-inc,High School grad,Married,Sales,Husband,White,Male,0,0,50,United-States,>50K
4,39,Private,Bachelors,Separated,White-Collar,Not-in-family,White,Female,13550,0,50,United-States,>50K


In [24]:
cf_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,60,Private,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,37,United-States,<=50K
1,35,Private,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,47,United-States,<=50K
2,39,Local-gov,High School grad,Married,Blue-Collar,Wife,White,Female,0,11,36,United-States,<=50K
3,44,Local-gov,High School grad,Married,Blue-Collar,Husband,White,Male,0,0,48,United-States,<=50K
4,39,Private,High School grad,Separated,Blue-Collar,Not-in-family,White,Female,87,8,48,United-States,<=50K


#### Diversity

In [25]:
# generate counterfactual instances
X = X_positive[2].reshape(1, -1)
explanation = explainer.explain(X, Y_t, C, diversity=True, num_samples=10, batch_size=100)

In [26]:
# concat label column
orig = np.concatenate(
    [explanation.data['orig']['X'], explanation.data['orig']['class']],
    axis=1
)

cf = np.concatenate(
    [explanation.data['cf']['X'], explanation.data['cf']['class']],
    axis=1
)

# transfrom label encodings to string
orig_pd = pd.DataFrame(
    apply_category_mapping(orig, category_map),
    columns=feature_names,
)

cf_pd = pd.DataFrame(
    apply_category_mapping(cf, category_map),
    columns=feature_names,
)

In [27]:
orig_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,39,State-gov,Masters,Married,Professional,Wife,White,Female,5178,0,38,United-States,>50K


In [28]:
cf_pd.head(n=5)

Unnamed: 0,Age,Workclass,Education,Marital Status,Occupation,Relationship,Race,Sex,Capital Gain,Capital Loss,Hours per week,Country,Label
0,39,Local-gov,Associates,Married,Blue-Collar,Wife,White,Female,0,0,35,United-States,<=50K
1,39,Local-gov,Associates,Married,Blue-Collar,Wife,White,Female,0,0,36,United-States,<=50K
2,39,Local-gov,Associates,Married,Blue-Collar,Wife,White,Female,0,0,37,United-States,<=50K
3,39,Local-gov,Associates,Married,Blue-Collar,Wife,White,Female,0,11,36,United-States,<=50K
4,39,Local-gov,Dropout,Married,Blue-Collar,Wife,White,Female,0,0,37,United-States,<=50K
