# Vanilla CF

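Paper: Wachter, Mittelstadt & Russell, *Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR* (2017), https://doi.org/10.2139/ssrn.3063289. The method finds a counterfactual by gradient descent. It trades off a loss that pushes the model's prediction toward the opposite outcome against an L2 penalty that keeps the counterfactual close to the original input.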

In [None]:
# default_exp methods.vanilla

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# export
from cfnet.import_essentials import *
from cfnet.interfaces import BaseCFExplanationModule, LocalCFExplanationModule
from cfnet.datasets import TabularDataModule
from cfnet.training_module import grad_update, cat_normalize
from cfnet.utils import check_cat_info, validate_configs


In [None]:
# export
def binary_cross_entropy(
    y_pred: jnp.ndarray, y: jnp.ndarray, eps: float = 1e-7
) -> jnp.ndarray:
    """Element-wise binary cross-entropy between predictions and targets.

    Args:
        y_pred: predicted probabilities, expected in ``[0, 1]``.
        y: target probabilities / labels, broadcastable against ``y_pred``.
        eps: predictions are clipped to ``[eps, 1 - eps]`` before the logs
            are taken.

    Returns:
        The unreduced loss ``-(y * log(y_pred) + (1 - y) * log(1 - y_pred))``.
    """
    # Without clipping, a saturated prediction (exactly 0.0 or 1.0, which a
    # float32 sigmoid reaches easily) gives log(0) = -inf. Even a *correct*
    # saturated prediction then evaluates 0 * -inf = NaN, and the NaN
    # gradient poisons every subsequent optimizer step.
    y_pred = jnp.clip(y_pred, eps, 1. - eps)
    return -(y * jnp.log(y_pred) + (1 - y) * jnp.log(1 - y_pred))

In [None]:
# export 
class VanillaCFConfig(BaseParser):
    """Hyper-parameters for the `VanillaCF` explainer."""

    # Number of optimizer steps taken for each counterfactual.
    n_steps: int = 1_000
    # Learning rate handed to `optax.rmsprop`.
    lr: float = 1e-3


In [None]:
# export
class VanillaCF(LocalCFExplanationModule):
    """Gradient-based counterfactual explainer ("Vanilla CF").

    For an input ``x`` the explainer starts from ``cf = x`` and takes
    ``configs.n_steps`` RMSProp steps (learning rate ``configs.lr``) on::

        mean(binary_cross_entropy(pred_fn(cf), 1 - pred_fn(x)))
            + 0.5 * mean(optax.l2_loss(cf, x))

    Categorical features are passed through ``cat_normalize`` with
    ``hard=False`` after every step and with ``hard=True`` once at the end.
    """
    name = "VanillaCF"

    def __init__(self,
        configs: Union[Dict[str, Any], VanillaCFConfig],
        data_module: Optional[TabularDataModule] = None
    ):
        """
        Args:
            configs: a ``VanillaCFConfig`` or a dict of its fields; normalized
                through ``validate_configs``.
            data_module: optional data module; when given, ``update_cat_info``
                is called with it (presumably populating ``cat_arrays`` /
                ``cat_idx``, which the generation methods read).
        """
        self.configs = validate_configs(configs, VanillaCFConfig)
        if data_module is not None:
            self.update_cat_info(data_module)

    def _loss_fn_1(self,
        cf_y: jnp.ndarray,
        y_prime: jnp.ndarray
    ) -> jnp.ndarray:
        """Validity loss: mean BCE between the counterfactual's prediction and the target."""
        return jnp.mean(binary_cross_entropy(y_pred=cf_y, y=y_prime))

    def _loss_fn_2(self,
        x: jnp.ndarray,
        cf: jnp.ndarray
    ) -> jnp.ndarray:
        """Proximity loss: mean ``optax.l2_loss`` between ``cf`` and ``x``."""
        return jnp.mean(optax.l2_loss(cf, x))

    def generate_cf(self,
        x: jnp.ndarray, # `x` shape: (k,), where `k` is the number of features
        pred_fn: Callable[[jnp.ndarray], jnp.ndarray]
    ) -> jnp.ndarray:
        """Generate one counterfactual for a single instance.

        Args:
            x: feature vector of shape ``(k,)``.
            pred_fn: model prediction function; called on ``(1, k)`` inputs.

        Returns:
            The counterfactual, reshaped to ``(k,)``.
        """
        # NOTE: reads `self.cat_arrays` / `self.cat_idx`; call `update_cat_info`
        # (or pass `data_module` to the constructor) before calling this directly.
        opt = optax.rmsprop(self.configs.lr)

        def loss_fn(
            cf: jnp.ndarray, x: jnp.ndarray, y_prime: jnp.ndarray
        ) -> jnp.ndarray:
            # Validity term pulls pred_fn(cf) toward `y_prime`;
            # proximity term keeps `cf` near `x`.
            cf_y = pred_fn(cf)
            return self._loss_fn_1(cf_y, y_prime) + 0.5 * self._loss_fn_2(x, cf)

        @jax.jit
        def gen_cf_step(
            x: jnp.ndarray, cf: jnp.ndarray, opt_state: optax.OptState,
            y_prime: jnp.ndarray
        ) -> Tuple[jnp.ndarray, optax.OptState]:
            cf_grads = jax.grad(loss_fn)(cf, x, y_prime)
            cf, opt_state = grad_update(cf_grads, cf, opt_state, opt)
            cf = cat_normalize(
                cf, cat_arrays=self.cat_arrays, cat_idx=self.cat_idx, hard=False)
            return cf, opt_state

        x = x.reshape(1, -1)
        # The target depends only on `x`, so compute it once instead of
        # re-running the model on `x` at every optimization step.
        # NOTE(review): this is the *soft* complement of the prediction (0.6 ->
        # target 0.4), whereas this notebook's validity check compares against
        # the hard opposite class `1 - round(pred)`. Confirm which is intended.
        y_prime = 1. - pred_fn(x)
        cf = jnp.array(x, copy=True)
        opt_state = opt.init(cf)
        for _ in tqdm(range(self.configs.n_steps)):
            cf, opt_state = gen_cf_step(x, cf, opt_state, y_prime)

        # Final pass with `hard=True` (presumably snapping categorical groups
        # to discrete encodings — see `cat_normalize`).
        cf = cat_normalize(
            cf, cat_arrays=self.cat_arrays, cat_idx=self.cat_idx, hard=True)
        return cf.reshape(-1)

    @check_cat_info
    def generate_cfs(self,
        X: jnp.ndarray, # `X` shape: (b, k), where `b` is batch size, `k` is the number of features
        pred_fn: Callable[[jnp.ndarray], jnp.ndarray],
        is_parallel: bool = False
    ) -> jnp.ndarray:
        """Generate one counterfactual per row of ``X``.

        Args:
            X: batch of instances, shape ``(b, k)``.
            pred_fn: model prediction function.
            is_parallel: if True, map with ``jax.pmap`` (rows are spread across
                devices, so ``b`` must fit the available device count);
                otherwise vectorize with ``jax.vmap``.

        Returns:
            Counterfactuals of shape ``(b, k)``.
        """
        # `pred_fn` is a Python callable, not an array: it must be closed over
        # rather than passed as a mapped argument, which vmap/pmap would try
        # to split along axis 0 and reject.
        def _generate_cf(x: jnp.ndarray) -> jnp.ndarray:
            return self.generate_cf(x, pred_fn)

        mapper = jax.pmap if is_parallel else jax.vmap
        return mapper(_generate_cf)(X)

## Test

In [None]:
# Data module settings for the Adult dataset (assets/data/s_adult.csv).
# NOTE: the key spellings "continous_cols" / "discret_cols" are kept verbatim;
# presumably they are what TabularDataModule reads.
data_configs = dict(
    data_dir="assets/data/s_adult.csv",
    data_name="adult",
    batch_size=256,
    sample_frac=0.1,
    continous_cols=["age", "hours_per_week"],
    discret_cols=[
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "race",
        "gender",
    ],
)
# Predictive model settings ("sizes" presumably lists hidden-layer widths).
m_configs = dict(
    sizes=[50, 10, 50],
    dropout_rate=0.3,
    lr=0.03,
)
# Training-loop settings.
t_configs = dict(
    n_epochs=10,
    monitor_metrics="val/val_loss",
)

In [None]:
from cfnet.training_module import PredictiveTrainingModule
from cfnet.train import train_model

# Fit the predictive model that the counterfactual explainer will probe;
# `train_model` returns a `(params, opt_state)` pair.
training_module = PredictiveTrainingModule(m_configs)
params, opt_state = train_model(training_module, TabularDataModule(data_configs), t_configs)

Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 91.62batch/s, train/train_loss_1=0.0597]


In [None]:
dm = TabularDataModule(data_configs)


def pred_fn(x):
    """Forward pass of the trained model with `is_training=False` and a fixed PRNG key."""
    return training_module.forward(params, random.PRNGKey(0), x, is_training=False)


# `pred_fn` is not a `VanillaCFConfig` field; the explainer receives it as an
# argument of `generate_cf` / `generate_cfs` instead.
cf_exp = VanillaCF(
    configs=VanillaCFConfig(n_steps=1000)
)
cf_exp.update_cat_info(dm)

X, y = dm.test_dataset[:]


In [None]:
# Test-set accuracy: share of rounded predictions that equal the labels.
is_correct = jnp.round(pred_fn(X)) == y
jnp.sum(is_correct) / len(X)

DeviceArray(0.8152561, dtype=float32)

In [None]:
# Counterfactual for the first test instance; `pred_fn` is a required argument.
cf = cf_exp.generate_cf(X[0], pred_fn)

100%|██████████| 1000/1000 [00:00<00:00, 1780.54it/s]


In [None]:
# Counterfactuals for every test instance (vmapped over the rows of X);
# `pred_fn` is a required argument.
cfs = cf_exp.generate_cfs(X, pred_fn)

100%|██████████| 1000/1000 [00:05<00:00, 188.09it/s]


In [None]:
# Model outputs for the original test inputs and for their counterfactuals.
y_pred, cf_pred = pred_fn(X), pred_fn(cfs)

In [None]:
# Validity: fraction of counterfactuals whose rounded prediction equals the
# opposite of the original rounded prediction.
y_prime = 1. - jnp.round(y_pred)  # hard opposite class per instance
n_valid = jnp.sum(jnp.round(cf_pred) == y_prime)
validity = n_valid / len(cf_pred)

In [None]:
# Display the validity score computed in the previous cell.
validity

DeviceArray(0.79056627, dtype=float32)