In [None]:
#| default_exp methods.vanilla

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from cfnet.import_essentials import *
from cfnet.interfaces import BaseCFExplanationModule, LocalCFExplanationModule
from cfnet.datasets import TabularDataModule
from cfnet.utils import check_cat_info, validate_configs, binary_cross_entropy, cat_normalize, grad_update

In [None]:
#| exporti
def _vanilla_cf(
    x: jnp.DeviceArray, # `x` shape: (k,) or (1, k), where `k` is the number of features
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray], # y = pred_fn(x)
    n_steps: int, # number of optimization steps applied to `cf`
    lr: float, # learning rate for each `cf` optimization step
    lambda_: float, # loss = validity_loss + lambda_ * cost
    cat_arrays: List[List[str]], # forwarded unchanged to `cat_normalize`
    cat_idx: int # forwarded unchanged to `cat_normalize`
) -> jnp.DeviceArray: # return `cf` with the same shape as the input `x`
    """Search for a counterfactual of a single instance `x` by gradient descent.

    Starting from a copy of `x`, `cf` is updated for `n_steps` RMSProp steps to
    minimize `validity_loss + lambda_ * cost`. The validity loss is the mean
    binary cross-entropy between `pred_fn(cf)` and the flipped prediction
    `1 - pred_fn(x)`; the cost is the mean `optax.l2_loss` between `cf` and `x`.
    After every step `cf` goes through `cat_normalize(..., hard=False)` and is
    clipped to [0, 1]. A final `cat_normalize(..., hard=True)` is applied
    before the result is reshaped back to the shape of the input.

    Raises `ValueError` when `x` is not a single instance, i.e. its shape is
    neither (k,) nor (1, k).
    """
    # Validate first: only a flat vector or a single-row matrix is accepted.
    x_size = x.shape
    is_single_instance = len(x_size) == 1 or (len(x_size) == 2 and x_size[0] == 1)
    if not is_single_instance:
        raise ValueError(f"""Invalid Input Shape: Require `x.shape` = (1, k) or (k, ),
but got `x.shape` = {x.shape}. This method expects a single input instance.""")

    # Work on a (1, k) row internally; the caller's shape is restored on return.
    x = x.reshape(1, -1)

    # The target depends on `x` only, so it is computed once here instead of
    # being re-evaluated inside every optimization step.
    y_prime = 1. - pred_fn(x)

    opt = optax.rmsprop(lr)

    def validity_loss(cf_y: jnp.DeviceArray, y_prime: jnp.DeviceArray):
        return jnp.mean(binary_cross_entropy(y_pred=cf_y, y=y_prime))

    def cost_loss(x: jnp.DeviceArray, cf: jnp.DeviceArray):
        return jnp.mean(optax.l2_loss(cf, x))

    def loss_fn(
        cf: jnp.DeviceArray, # `cf` shape: (1, k)
        x: jnp.DeviceArray,  # `x` shape: (1, k)
        y_prime: jnp.DeviceArray # precomputed target `1 - pred_fn(x)`
    ):
        cf_y = pred_fn(cf)
        return validity_loss(cf_y, y_prime) + lambda_ * cost_loss(x, cf)

    @jax.jit
    def gen_cf_step(
        x: jnp.DeviceArray,
        cf: jnp.DeviceArray,
        y_prime: jnp.DeviceArray,
        opt_state: optax.OptState
    ) -> Tuple[jnp.DeviceArray, optax.OptState]:
        # `jax.grad` differentiates w.r.t. the first argument (`cf`) only.
        cf_grads = jax.grad(loss_fn)(cf, x, y_prime)
        cf, opt_state = grad_update(cf_grads, cf, opt_state, opt)
        cf = cat_normalize(
            cf, cat_arrays=cat_arrays, cat_idx=cat_idx, hard=False)
        cf = jnp.clip(cf, 0., 1.)
        return cf, opt_state

    # Start the search from the input itself.
    cf = jnp.array(x, copy=True)
    opt_state = opt.init(cf)
    for _ in tqdm(range(n_steps)):
        cf, opt_state = gen_cf_step(x, cf, y_prime, opt_state)

    cf = cat_normalize(
        cf, cat_arrays=cat_arrays, cat_idx=cat_idx, hard=True)
    return cf.reshape(x_size)

In [None]:
#| export 
class VanillaCFConfig(BaseParser):
    """Hyperparameters read by `VanillaCF` when generating counterfactuals."""
    # number of optimization steps applied to each counterfactual
    n_steps: int = 1_000
    # learning rate for each `cf` optimization step
    lr: float = 1e-2
    # weight of the cost term: loss = validity_loss + lambda_ * cost
    lambda_: float = 1e-2

In [None]:
#| export
class VanillaCF(LocalCFExplanationModule):
    """Counterfactual explainer that optimizes every instance independently with `_vanilla_cf`."""
    name = "VanillaCF"

    def __init__(self,
        configs: Union[Dict[str, Any], VanillaCFConfig], # validated against `VanillaCFConfig`
        data_module: Optional[TabularDataModule] = None # if given, passed to `update_cat_info`
    ):
        self.configs = validate_configs(configs, VanillaCFConfig)
        # Compare against `None` explicitly so a supplied data module is never
        # skipped merely because it happens to evaluate as falsy.
        if data_module is not None:
            self.update_cat_info(data_module)

    def generate_cf(self,
        x: jnp.ndarray, # `x` shape: (k,), where `k` is the number of features
        pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]
    ) -> jnp.DeviceArray:
        """Generate the counterfactual for a single instance `x`.

        Delegates to `_vanilla_cf` using the step count, learning rate and
        `lambda_` stored in `self.configs`, together with `self.cat_arrays`
        and `self.cat_idx`.
        """
        return _vanilla_cf(
            x=x, # `x` shape: (k,), where `k` is the number of features
            pred_fn=pred_fn, # y = pred_fn(x)
            n_steps=self.configs.n_steps,
            lr=self.configs.lr, # learning rate for each `cf` optimization step
            lambda_=self.configs.lambda_, # loss = validity_loss + lambda_ * cost
            cat_arrays=self.cat_arrays,
            cat_idx=self.cat_idx
        )

    @check_cat_info
    def generate_cfs(
        self,
        X: jnp.DeviceArray, # `X` shape: (b, k), where `b` is batch size, `k` is the number of features
        pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray],
        is_parallel: bool = False
    ) -> jnp.DeviceArray:
        """Generate one counterfactual per row of `X`.

        With `is_parallel=False` the per-instance search is mapped over the
        leading axis with `jax.vmap`; with `is_parallel=True` it is mapped
        with `jax.pmap` instead.

        NOTE: `jax.pmap` places each mapped element on its own device, so the
        batch size `b` must not exceed the number of available devices when
        `is_parallel=True`.
        """
        def _generate_cf(x: jnp.DeviceArray) -> jnp.ndarray:
            return self.generate_cf(x, pred_fn)

        # Pick the mapping transform first, then apply it to the whole batch.
        batched_generate = jax.pmap(_generate_cf) if is_parallel else jax.vmap(_generate_cf)
        return batched_generate(X)

## Test

In [None]:
# Dataset settings; passed to `TabularDataModule` in the cells below.
data_configs = dict(
    data_dir="assets/data/s_adult.csv",
    data_name="adult",
    batch_size=256,
    sample_frac=0.1,
    continous_cols=["age", "hours_per_week"],
    discret_cols=[
        "workclass", "education", "marital_status",
        "occupation", "race", "gender",
    ],
)
# Model settings; passed to `PredictiveTrainingModule`.
m_configs = dict(sizes=[50, 10, 50], dropout_rate=0.3, lr=0.03)
# Training settings; passed to `train_model`.
t_configs = dict(n_epochs=1, monitor_metrics="val/val_loss", batch_size=256)

In [None]:
from cfnet.training_module import PredictiveTrainingModule
from cfnet.train import train_model

# Build the predictive training module from the model settings.
training_module = PredictiveTrainingModule(m_configs)

# Train on the dataset described by `data_configs`; `train_model` returns a
# pair that is unpacked into the fitted parameters and the optimizer state.
train_dm = TabularDataModule(data_configs)
params, opt_state = train_model(training_module, train_dm, t_configs)

Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 91.62batch/s, train/train_loss_1=0.0597]


In [None]:
# Data module that supplies the test split and is handed to `update_cat_info`.
dm = TabularDataModule(data_configs)


def pred_fn(x):
    """Run the trained model on `x` with `is_training=False` and a fixed PRNG key."""
    return training_module.forward(params, random.PRNGKey(0), x, is_training=False)


cf_exp = VanillaCF(
    configs=VanillaCFConfig(n_steps=1000)
)
cf_exp.update_cat_info(dm)

X, y = dm.test_dataset[:]
# Single-instance path, then the batched path on the first 100 test rows.
cf = cf_exp.generate_cf(X[0], pred_fn)
cfs = cf_exp.generate_cfs(X[:100, :], pred_fn)

# A counterfactual must have exactly the shape of the input it explains.
assert cf.shape == X[0].shape
assert cfs.shape == X[:100, :].shape