# Diverse CF


In [None]:
#| default_exp methods.diverse

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from relax.import_essentials import *
from relax.methods.base import BaseCFModule
from relax.utils import *

In [None]:
#| exporti
def hinge_loss(input: jax.Array, target: jax.Array):
    """Hinge loss between predicted probabilities and binary targets.

    `input` is first mapped to logit space (with a `1e-6` shift), `target`
    is mapped from {0, 1} to {-1, +1}, and the L2 norm of the element-wise
    hinge `relu(1 - target * logit)` is returned.

    reference:
    - https://github.com/interpretml/DiCE/blob/a772c8d4fcd88d1cab7f2e02b0bcc045dc0e2eab/dice_ml/explainer_interfaces/dice_pytorch.py#L196-L202
    - https://en.wikipedia.org/wiki/Hinge_loss
    """
    # Shifted magnitude of the prediction, used for the log-odds transform.
    shifted = jnp.abs(input - 1e-6)
    logits = jnp.log(shifted / (1 - shifted))

    ones = jnp.ones_like(target)
    signed_target = 2 * target - ones  # {0, 1} -> {-1, +1}

    margins = jax.nn.relu(ones - signed_target * logits)
    return jnp.linalg.norm(margins)


In [None]:
#| exporti
def l1_mean(X, cfs):
    """Mean absolute difference between `X` and `cfs`, scaled by mean |X|."""
    scale = jnp.mean(jnp.abs(X))
    gap = jnp.mean(jnp.abs(X - cfs))
    return gap / scale


In [None]:
#| exporti
def dpp_style(cf: jax.Array, n_cfs: int):
    """Determinant of a distance-based similarity kernel over the first `n_cfs` rows of `cf`.

    Entry `(i, j)` of the kernel is `1 / (1 + dist(cf[i], cf[j], ord=1))`;
    a small multiple of the identity (`1e-4`) is added to the diagonal before
    taking the determinant.
    """
    det_entries = jnp.ones((n_cfs, n_cfs))
    for i in range(n_cfs):
        for j in range(n_cfs):
            # JAX arrays are immutable: `.at[...].set(...)` returns a NEW array,
            # so the result must be re-bound or the update is silently dropped.
            det_entries = det_entries.at[i, j].set(dist(cf[i], cf[j], ord=1))

    det_entries = 1.0 / (1.0 + det_entries)
    det_entries += jnp.eye(n_cfs) * 0.0001
    return jnp.linalg.det(det_entries)


In [None]:
#| exporti
def _compute_regularization_loss(cfs, cat_idx, cat_arrays, n_cfs):
    # cat_idx = len(self.model.continous_cols)
    regularization_loss = 0.0
    for i in range(n_cfs):
        for col in cat_arrays:
            cat_idx_end = cat_idx + len(col)
            regularization_loss += jnp.power(
                (jnp.sum(cfs[i][cat_idx:cat_idx_end]) - 1.0), 2
            )
    return regularization_loss


In [None]:
#| exporti
@auto_reshaping('x')
def _diverse_cf(
    x: jax.Array,  # `x` shape: (k,), where `k` is the number of features
    pred_fn: Callable[[jax.Array], jax.Array],  # y = pred_fn(x)
    n_cfs: int,  # number of counterfactual candidates optimized jointly
    n_steps: int,  # number of gradient steps
    lr: float,  # learning rate for each `cf` optimization step
    lambda_: float,  # NOTE(review): accepted but not used by the loss below
    key: jax.random.PRNGKey,
    projection_fn: Callable,
    regularization_fn: Callable
) -> jax.Array:  # return `cf` shape: (k,)
    """Optimize `n_cfs` counterfactual candidates for `x` and return the first one.

    Candidates are initialized from a standard normal, updated for `n_steps`
    Adam steps on a loss combining validity, proximity, diversity and the
    regularization supplied by `regularization_fn`, and the first candidate is
    finally passed through `projection_fn(x, ..., hard=True)`.
    """
    @jit
    def loss_fn_1(cf_y: Array, y_prime: Array):
        # Validity: mean candidate prediction should match the flipped label.
        return optax.l2_loss(cf_y.mean(axis=0, keepdims=True), y_prime).mean()

    @jit
    def loss_fn_2(x: Array, cf: Array):
        # Proximity: mean absolute deviation of the candidates from `x`.
        return jnp.mean(jnp.abs(cf - x))

    @partial(jit, static_argnums=(1,))
    def loss_fn_3(cfs: jax.Array, n_cfs: int):
        # Diversity: DPP-style determinant over the candidates.
        return dpp_style(cfs, n_cfs)

    @jit
    def loss_fn_4(x: Array, cfs: Array):
        # Regularization: sum of `regularization_fn` over every candidate.
        reg_loss = 0.
        for i in range(n_cfs):
            reg_loss += regularization_fn(x, cfs[i])
        return reg_loss

    @partial(jit, static_argnums=(2,))
    def loss_fn(
        cf: jax.Array,  # `cf` shape: (n_cfs, k)
        x: jax.Array,  # the input as received by `_diverse_cf`
        pred_fn: Callable[[Array], Array],
    ):
        y_pred = pred_fn(x)
        y_prime = 1.0 - y_pred
        cf_y = pred_fn(cf)

        validity = loss_fn_1(cf_y, y_prime)
        proximity = loss_fn_2(x, cf)
        diversity = loss_fn_3(cf, n_cfs)
        # Use the differentiated argument `cf`, not the enclosing `cfs`
        # (the initial random draw), so this term contributes gradients.
        regularization = loss_fn_4(x, cf)
        # Diversity is a quantity to maximize, hence it is subtracted from the
        # minimized objective (as in the DiCE formulation).
        return validity + 0.01 * proximity - diversity + 0.1 * regularization

    @loop_tqdm(n_steps)
    def gen_cf_step(
        i, cf_opt_state: Tuple[Array, optax.OptState]
    ) -> Tuple[Array, optax.OptState]:
        cf, opt_state = cf_opt_state
        cf_grads = jax.grad(loss_fn)(cf, x, pred_fn)
        cf, opt_state = grad_update(cf_grads, cf, opt_state, opt)
        return cf, opt_state

    # `key` is re-bound to the first half of the split and used for sampling;
    # `subkey` is currently unused.
    key, subkey = jax.random.split(key)
    cfs = jax.random.normal(key, shape=(n_cfs, x.shape[-1]))
    opt = optax.adam(lr)
    opt_state = opt.init(cfs)
    cfs, opt_state = lax.fori_loop(0, n_steps, gen_cf_step, (cfs, opt_state))
    # Only the first optimized candidate is projected and returned.
    cf = projection_fn(x, cfs[:1, :], hard=True)
    return cf

In [None]:
#| export 
class DiverseCFConfig(BaseParser):
    """Configuration values consumed by `DiverseCF`."""

    n_cfs: int = 5  # number of counterfactual candidates per input
    n_steps: int = 1000  # number of optimization steps
    lr: float = 0.01  # learning rate of the optimizer
    lambda_: float = 0.01  # loss = validity_loss + lambda_params * cost
    seed: int = 42  # seed for the PRNG sequence returned by `keys`

    @property
    def keys(self):
        """Build a new `hk.PRNGSequence` from `seed` on every access."""
        seed = self.seed
        return hk.PRNGSequence(seed)


In [None]:
#| export
class DiverseCF(BaseCFModule):
    """Counterfactual module that delegates generation to `_diverse_cf`."""

    name = "DiverseCF"

    def __init__(
        self,
        configs: Union[Dict[str, Any], DiverseCFConfig] = None,
    ):
        # Fall back to the default configuration when none is supplied.
        resolved = DiverseCFConfig() if configs is None else configs
        self.configs = validate_configs(resolved, DiverseCFConfig)

    def generate_cf(
        self,
        x: jnp.ndarray,  # `x` shape: (k,), where `k` is the number of features
        pred_fn: Callable[[jax.Array], jax.Array],
    ) -> jax.Array:
        """Generate a counterfactual for a single input `x`."""
        cfg = self.configs
        # NOTE(review): `self.data_module` is not assigned in this class;
        # presumably it is attached by the base class or the caller.
        return _diverse_cf(
            x=x,
            pred_fn=pred_fn,
            n_cfs=cfg.n_cfs,
            n_steps=cfg.n_steps,
            lr=cfg.lr,
            lambda_=cfg.lambda_,
            key=next(cfg.keys),
            projection_fn=self.data_module.apply_constraints,
            regularization_fn=self.data_module.apply_regularization,
        )

    def generate_cfs(
        self,
        X: jax.Array,  # `x` shape: (b, k), where `b` is batch size, `k` is the number of features
        pred_fn: Callable[[jax.Array], jax.Array],
        is_parallel: bool = False,
    ) -> jax.Array:
        """Generate counterfactuals for a batch, via `jax.vmap` or `jax.pmap`."""
        def _single(x: jax.Array) -> jnp.ndarray:
            return self.generate_cf(x, pred_fn)

        if is_parallel:
            return jax.pmap(_single)(X)
        return jax.vmap(_single)(X)

In [None]:
from relax.data import load_data
from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from relax.trainer import train_model

Load data:

In [None]:
# Load the 'adult' data module with `sample_frac=0.1`
# (presumably a 10% subsample of the dataset — confirm against `load_data`).
dm = load_data('adult', data_configs=dict(sample_frac=0.1))



Load the pre-trained predictive model:

In [None]:
# Load the model parameters and the training module for the 'adult' dataset.
params, training_module = load_pred_model('adult')

# Prediction function exposed by the training module; used below to
# generate counterfactual explanations.
pred_fn = training_module.pred_fn


Define `DiverseCF`:

In [None]:
# Instantiate `DiverseCF` with its default configuration (`DiverseCFConfig()`).
diversecf = DiverseCF()

Generate explanations:

In [None]:
#| output: false
# Generate counterfactual explanations with `diversecf` on `dm`;
# `pred_fn_args` supplies the loaded parameters and a fixed PRNG key.
cf_exp = generate_cf_explanations(
    diversecf, dm, pred_fn=pred_fn, 
    pred_fn_args=dict(
        params=params, rng_key=random.PRNGKey(0)
    )
)

  0%|          | 0/1000 [00:00<?, ?it/s]

Evaluate explanations:

In [None]:
# Summarize the explanations; the output table reports acc, validity and proximity.
benchmark_cfs([cf_exp])

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,DiverseCF,0.8241,0.393932,1.913267
