In [None]:
import numpy as np
import torch
from alibi.explainers import CounterfactualRLTabular
from alibi.explainers.backends.cfrl_tabular import get_he_preprocessor
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from models.data_process import load_adult_income_dataset
from models.run_MLP import load_model
from utils.helper import adult_process_names
from utils.parser import *

# Parse the run configuration; `args.random_state` is used to seed the train/test split.
# NOTE(review): if parse_args reads sys.argv (argparse-style), the arguments injected by the
# Jupyter kernel may make it fail inside a notebook — confirm against utils.parser.
args = parse_args()

In [None]:
# Restore the classifier that the explainer will query.
model = load_model()

# Load the dataset, then hold out 20% of it as a stratified test split.
dataset, target, encoder, categorical_names = load_adult_income_dataset()
train_dataset, test_dataset, y_train, y_test = train_test_split(
    dataset,
    target,
    test_size=0.2,
    stratify=target,
    random_state=args.random_state,
)


def predictor(x):
    """Return the model's output for `x`, delegating to `model.predict_anchor` with the loaded encoder."""
    return model.predict_anchor(x, encoder)


# Report accuracy on the held-out split; the predicted class is the argmax over axis 1.
acc = accuracy_score(y_true=y_test, y_pred=predictor(test_dataset).argmax(axis=1))
print("Accuracy: %.3f" % acc)


In [None]:
class HeAE(torch.nn.Module):
    """Autoencoder wrapper that chains an encoder module and a decoder module.

    Args:
        encoder: Module mapping the input to a latent representation.
        decoder: Module mapping the latent representation back to a reconstruction.
        **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module, **kwargs) -> None:
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor, **kwargs):
        """Encode ``x`` and decode the result.

        ``torch.nn.Module.__call__`` dispatches to ``forward``; without this method,
        calling an instance (``heae(x)``) fails because ``forward`` is unimplemented.

        Args:
            x: Input passed to the encoder.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            The decoder's output for the encoded input.
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def call(self, x: torch.Tensor, **kwargs):
        """Keras-style alias for :meth:`forward`, kept so existing ``.call(x)`` callers still work."""
        return self.forward(x, **kwargs)

In [None]:
# Define attribute types, required for datatype conversion.
feature_types = {"age": int, "hours-per-week": int}

# Define the data preprocessor and the inverse preprocessor. The inverse preprocessor
# includes the datatype conversions declared in `feature_types`.
heae_preprocessor, heae_inv_preprocessor = get_he_preprocessor(X=train_dataset,
                                                               feature_names=adult_process_names,
                                                               category_map=categorical_names,
                                                               feature_types=feature_types)

# Split the column indices into categorical and numerical ones. They are derived from the
# same category map / feature-name list handed to the preprocessor, so this cell no longer
# depends on names that are never defined in the notebook.
# NOTE(review): assumes `categorical_names` maps column index -> category values and that
# `adult_process_names` lists every feature column — confirm against the data loader.
categorical_ids = list(categorical_names.keys())
numerical_ids = [i for i in range(len(adult_process_names)) if i not in categorical_names]

# Define the training set: the preprocessed input plus one target per output head.
trainset_input = heae_preprocessor(train_dataset).astype(np.float32)
trainset_outputs = {
    # The leading len(numerical_ids) columns of the preprocessed input are taken as the numerical target.
    "output_1": trainset_input[:, :len(numerical_ids)]
}

# One additional target per categorical column, read from the raw training split
# (previously the undefined name `X_train`).
# NOTE(review): positional slicing assumes `train_dataset` is a NumPy array — confirm.
for i, cat_id in enumerate(categorical_ids):
    trainset_outputs.update({
        f"output_{i + 2}": train_dataset[:, cat_id]
    })

## Counterfactual with Reinforcement Learning

In [None]:
# Hyper-parameters for the counterfactual explainer.

# Sparsity coefficient.
COEFF_SPARSITY = 0.5

# Consistency coefficient.
COEFF_CONSISTENCY = 0.5

# Number of training steps -> consider increasing the number of steps.
TRAIN_STEPS = 10_000

# Batch size.
BATCH_SIZE = 100

In [None]:
# Features that the counterfactual search is not allowed to modify.
immutable_features = [
    'marital-status',
    'relationship',
    'race',
    'sex',
]

# Allowed change range per feature. The [0.0, 1.0] range on `age` is intended to
# mean that the age feature cannot decrease.
ranges = {
    'age': [0.0, 1.0],
}

In [None]:
# Build the counterfactual explainer from the black-box predictor, the feature metadata,
# the change constraints and the training hyper-parameters.
#
# NOTE(review): the autoencoder-related arguments are currently disabled:
#     encoder=heae.encoder,
#     decoder=heae.decoder,
#     latent_dim=LATENT_DIM,
#     encoder_preprocessor=heae_preprocessor,
#     decoder_inv_preprocessor=heae_inv_preprocessor,
# alibi's documented CounterfactualRLTabular signature appears to require the encoder,
# decoder and both preprocessors, so this call is expected to fail until they are
# supplied — confirm. Also confirm the backend: the notebook defines its autoencoder
# with torch while "tensorflow" is requested here.
explainer = CounterfactualRLTabular(
    predictor=predictor,
    # Feature metadata.
    feature_names=adult_process_names,
    category_map=categorical_names,
    # Constraints on what may change.
    immutable_features=immutable_features,
    ranges=ranges,
    # Loss coefficients and training schedule.
    coeff_sparsity=COEFF_SPARSITY,
    coeff_consistency=COEFF_CONSISTENCY,
    train_steps=TRAIN_STEPS,
    batch_size=BATCH_SIZE,
    backend="tensorflow",
)

# Train on the training split and rebind `explainer` to whatever `fit` returns.
explainer = explainer.fit(X=train_dataset)