# Vanilla CF

> Vanilla counterfactual explanation.

In [1]:
# | default_exp methods.vanilla

In [2]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [3]:
# | export
from __future__ import annotations
from cfnet.import_essentials import *
from cfnet.methods.base import BaseCFModule
from cfnet.data import TabularDataModule
from cfnet.utils import validate_configs, binary_cross_entropy, grad_update

In [4]:
# | exporti
def _vanilla_cf(
    x: jnp.DeviceArray,  # `x` shape: (k,) or (1, k), where `k` is the number of features
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray],  # y = pred_fn(x)
    n_steps: int,  # number of `cf` optimization steps
    lr: float,  # learning rate for each `cf` optimization step
    lambda_: float,  # loss = validity_loss + lambda_ * cost
    apply_fn: Callable  # called as `apply_fn(x, cf, hard=False)` after each update
) -> jnp.DeviceArray:  # return `cf` with the same shape as the input `x`
    """Generate a counterfactual for a single instance by gradient descent.

    Starting from a copy of `x`, `cf` is optimized for `n_steps` steps with
    RMSProp to minimize

        binary_cross_entropy(pred_fn(cf), 1 - pred_fn(x)) + lambda_ * l2_loss(cf, x)

    After every update `cf` is passed through `apply_fn(x, cf, hard=False)`
    and clipped to `[0, 1]`.

    Raises:
        ValueError: if `x` has more than one dimension and its first
            dimension is not 1 (i.e. `x` is not a single instance).
    """

    def validity_loss(cf_y: jnp.DeviceArray, y_prime: jnp.DeviceArray):
        # How far the prediction on `cf` is from the flipped target.
        return jnp.mean(binary_cross_entropy(y_pred=cf_y, y=y_prime))

    def cost_loss(x: jnp.DeviceArray, cf: jnp.DeviceArray):
        # Distance between the counterfactual and the original input.
        return jnp.mean(optax.l2_loss(cf, x))

    def loss_fn(
        cf: jnp.DeviceArray,  # `cf` shape: (1, k)
        x: jnp.DeviceArray,  # `x` shape: (1, k)
        y_prime: jnp.DeviceArray,  # target prediction: 1 - pred_fn(x)
    ):
        cf_y = pred_fn(cf)
        return validity_loss(cf_y, y_prime) + lambda_ * cost_loss(x, cf)

    @jax.jit
    def gen_cf_step(
        x: jnp.DeviceArray,
        cf: jnp.DeviceArray,
        y_prime: jnp.DeviceArray,
        opt_state: optax.OptState,
    ) -> Tuple[jnp.DeviceArray, optax.OptState]:
        # Differentiate w.r.t. the first argument (`cf`) only.
        cf_grads = jax.grad(loss_fn)(cf, x, y_prime)
        cf, opt_state = grad_update(cf_grads, cf, opt_state, opt)
        cf = apply_fn(x, cf, hard=False)
        cf = jnp.clip(cf, 0.0, 1.0)
        return cf, opt_state

    x_size = x.shape
    if len(x_size) > 1 and x_size[0] != 1:
        raise ValueError(
            f"""Invalid Input Shape: Require `x.shape` = (1, k) or (k, ),
but got `x.shape` = {x.shape}. This method expects a single input instance."""
        )
    if len(x_size) == 1:
        # Work on a (1, k) batch internally; the input shape is restored on return.
        x = x.reshape(1, -1)

    # The target depends only on `x`, so evaluate the model on `x` once here
    # instead of re-running it inside every optimization step.
    y_prime = 1.0 - pred_fn(x)

    cf = jnp.array(x, copy=True)
    opt = optax.rmsprop(lr)
    opt_state = opt.init(cf)
    for _ in tqdm(range(n_steps)):
        cf, opt_state = gen_cf_step(x, cf, y_prime, opt_state)

    # NOTE(review): the final call uses `hard=False`, same as in each step;
    # confirm whether a hard projection is intended for the returned `cf`.
    cf = apply_fn(x, cf, hard=False)
    return cf.reshape(x_size)


In [16]:
# | export
class VanillaCFConfig(BaseParser):
    """Hyperparameters that `VanillaCF.generate_cf` forwards to `_vanilla_cf`."""

    n_steps: int = 1000  # number of `cf` optimization steps
    lr: float = 0.001  # learning rate for each `cf` optimization step
    lambda_: float = 0.1  # loss = validity_loss + lambda_ * cost


In [17]:
# | export
class VanillaCF(BaseCFModule):
    """Counterfactual explanation module built on `_vanilla_cf`."""

    name = "VanillaCF"

    def __init__(
        self,
        configs: dict | VanillaCFConfig = None
    ):
        # With no configs given, start from the `VanillaCFConfig` defaults.
        raw_configs = VanillaCFConfig() if configs is None else configs
        self.configs = validate_configs(raw_configs, VanillaCFConfig)

    def generate_cf(
        self,
        x: jnp.ndarray,  # `x` shape: (k,), where `k` is the number of features
        pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray],
    ) -> jnp.DeviceArray:
        """Generate a counterfactual for one instance `x`.

        Reads `self.data_module.apply_constraints`; `data_module` is not set
        in `__init__` here, so it is presumably attached elsewhere before
        this is called — TODO confirm.
        """
        cfg = self.configs
        return _vanilla_cf(
            x=x,
            pred_fn=pred_fn,  # y = pred_fn(x)
            n_steps=cfg.n_steps,
            lr=cfg.lr,  # learning rate for each `cf` optimization step
            lambda_=cfg.lambda_,  # loss = validity_loss + lambda_ * cost
            apply_fn=self.data_module.apply_constraints
        )

    def generate_cfs(
        self,
        X: jnp.DeviceArray,  # `X` shape: (b, k), where `b` is batch size, `k` is the number of features
        pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray],
        is_parallel: bool = False,
    ) -> jnp.DeviceArray:
        """Generate counterfactuals for every row of `X`.

        `generate_cf` is mapped over the leading axis of `X` with `jax.vmap`,
        or with `jax.pmap` when `is_parallel` is true.
        """
        def _explain_one(x: jnp.DeviceArray) -> jnp.ndarray:
            return self.generate_cf(x, pred_fn)

        batch_transform = jax.pmap if is_parallel else jax.vmap
        return batch_transform(_explain_one)(X)


In [18]:
from cfnet.data import load_data
from cfnet.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs
from cfnet.evaluate import generate_cf_explanations, benchmark_cfs
from cfnet.train import train_model

Load data:

In [19]:
# Load the 'adult' dataset with `sample_frac=0.1` (presumably a 10% subsample — confirm in `load_data`).
dm = load_data("adult", data_configs={"sample_frac": 0.1})

Train predictive model:

In [9]:
#| output: false
# Configs for the predictive model and for its training run.
m_config = {"sizes": [50, 10, 50], "lr": 0.03}
t_config = {"n_epochs": 10, "batch_size": 256}

training_module = PredictiveTrainingModule(m_config)
params, opt_state = train_model(training_module, dm, t_config)


# predict function
def pred_fn(x, params, key):
    """Run the trained module's forward pass on `x` with `is_training=False`."""
    return training_module.forward(params, key, x, is_training=False)


  "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."
Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 87.77batch/s, train/train_loss_1=0.0615]


Define `VanillaCF`:

In [20]:
# No configs are passed, so the `VanillaCFConfig` defaults apply
# (n_steps=1000, lr=0.001, lambda_=0.1).
vanillacf = VanillaCF()

Generate explanations:

In [21]:
#| output: false
# `pred_fn` takes `params` and `key` in addition to `x`; both are supplied
# through `pred_fn_args`.
cf_exp = generate_cf_explanations(
    vanillacf,
    dm,
    pred_fn=pred_fn,
    t_configs={"n_epochs": 5, "batch_size": 128},
    pred_fn_args={"params": params, "key": random.PRNGKey(0)},
)

100%|██████████| 1000/1000 [00:05<00:00, 185.48it/s]


Evaluate explanations:

In [22]:
# `benchmark_cfs` takes a list of explanation results; the table shown below
# reports `acc`, `validity` and `proximity` for each one.
benchmark_cfs([cf_exp])

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,VanillaCF,0.825329,0.167424,7.3851247
