# Diverse CF

In [None]:
# | default_exp methods.dice

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# | export
from __future__ import annotations
from relax.import_essentials import *
from relax.methods.base import CFModule
from relax.base import BaseConfig
from relax.utils import auto_reshaping, grad_update, validate_configs

Using JAX backend.


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [None]:
#| hide
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from relax.ml_model import MLModule
from relax.data_module import load_data
from relax.ml_model import load_ml_module
import torch
import relax

## Util Functions



In [None]:
#| export
@jit
def dpp_style_vmap(cfs: Array):
    """DPP-style diversity score of a set of counterfactuals.

    Builds the kernel K[i, j] = 1 / (1 + ||cfs[i] - cfs[j]||_1), adds a
    1e-8 jitter to its diagonal and returns det(K). Larger values mean the
    rows of `cfs` are more spread out (in L1 distance).
    """
    n_cfs = cfs.shape[0]

    def inverse_l1(a, b):
        # Similarity of two counterfactuals: 1 when identical, -> 0 when far apart.
        return 1 / (1 + jnp.linalg.norm(a - b, ord=1))

    # Inner vmap walks the second argument, outer vmap walks the first -> (n, n).
    row_fn = vmap(inverse_l1, in_axes=(None, 0))
    kernel = vmap(row_fn, in_axes=(0, None))(cfs, cfs)
    kernel = kernel + 1e-8 * jnp.eye(n_cfs)
    assert kernel.shape == (n_cfs, n_cfs)
    return jnp.linalg.det(kernel)

In [None]:
# From the original dice implementation
# https://github.com/interpretml/DiCE/blob/a772c8d4fcd88d1cab7f2e02b0bcc045dc0e2eab/dice_ml/explainer_interfaces/dice_pytorch.py#L222-L227
def dpp_style_torch(cfs: torch.Tensor):
    """Reference DPP-style diversity score, kept as in DiCE for comparison.

    Builds K[i, j] = 1 / (1 + sum|cfs[i] - cfs[j]|) with a 1e-8 jitter added
    to the diagonal and returns det(K). Entries are filled one at a time with
    Python loops (O(n^2) small tensor ops), which is what the benchmark below
    compares the vectorized JAX version against.
    """
    # L1 distance between two counterfactual rows.
    compute_dist = lambda x, y: torch.abs(x-y).sum()

    total_CFs = len(cfs)
    det_entries = torch.ones((total_CFs, total_CFs))
    for i in range(total_CFs):
        for j in range(total_CFs):
            det_entries[(i,j)] = 1.0/(1.0 + compute_dist(cfs[i], cfs[j]))
            if i == j:
                # Small diagonal jitter, as in DiCE.
                det_entries[(i,j)] += 1e-8
    return torch.det(det_entries)

In [None]:
def jax2torch(x: Array):
    """Convert a JAX array to a (CPU) torch tensor.

    `x.__array__()` exposes JAX's read-only host buffer, and `torch.from_numpy`
    warns about (and cannot safely share) non-writable arrays. `np.array` makes
    a writable copy first, so the resulting tensor is an independent, mutable
    tensor and no warning is emitted.
    """
    return torch.from_numpy(np.array(x))

In [None]:
# Sanity check: the JAX and the reference PyTorch DPP scores must agree
# on the same random 100 x 100 set of counterfactuals.
cfs = jrand.normal(jrand.PRNGKey(0), (100, 100))
cfs_tensor = jax2torch(cfs)
torch_det = dpp_style_torch(cfs_tensor).numpy()
jax_det = dpp_style_vmap(cfs)
assert np.allclose(torch_det, jax_det)

  return torch.from_numpy(x.__array__())


Our JAX-based implementation is roughly 500× faster than DiCE's PyTorch implementation (about 571 µs vs. 318 ms per call on the 100 × 100 input above, running on CPU).

In [None]:
%%timeit -r 5
# Benchmark DiCE's loop-based PyTorch reference on the 100 x 100 `cfs_tensor`.
torch_res = dpp_style_torch(cfs_tensor)

318 ms ± 4.24 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


In [None]:
%%timeit -n 50
# Benchmark the jitted JAX implementation on the same 100 x 100 `cfs`.
jax_res = dpp_style_vmap(cfs)

571 µs ± 44.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [None]:
#| exporti
@ft.partial(jit, static_argnums=(2, 3, 4, 8, 9, 10, 11))
def _diverse_cf(
    x: Array,  # input instance; `k` features on the last axis (2-D after `generate_cf`'s reshaping — confirm)
    y_target: Array,  # desired prediction for every counterfactual
    pred_fn: Callable[[Array], Array],  # y = pred_fn(x)
    n_cfs: int,  # number of counterfactuals optimized jointly
    n_steps: int,  # number of optimization steps
    lr: float,  # learning rate for each `cf` optimization step
    lambdas: Tuple[float, float, float, float],  # (lambda_1, lambda_2, lambda_3, lambda_4)
    key: jrand.PRNGKey,
    validity_fn: Callable,
    cost_fn: Callable,
    apply_constraints_fn: Callable,
    compute_reg_loss_fn: Callable,
) -> Array:  # return `cfs` shape: (n_cfs, k)
    """Diverse Counterfactuals (DiCE) algorithm.

    Jointly optimizes `n_cfs` counterfactuals with Adam, minimizing

        mean_i[ lambda_1 * validity(y_target, pred_fn(cf_i))
              + lambda_2 * cost(x, cf_i)
              + lambda_4 * reg(x, cf_i) ]
        - lambda_3 * dpp(cfs)

    where `dpp(cfs)` is the DPP-style diversity of the whole set of
    counterfactuals (see `dpp_style_vmap`). Counterfactuals are initialized
    uniformly in [0, 1); constraints are applied with `hard=True` to the final
    result.
    """
    lambda_1, lambda_2, lambda_3, lambda_4 = lambdas

    def loss_fn(
        cfs: Array,  # shape: (n_cfs, k)
        x: Array,
        pred_fn: Callable[[Array], Array],  # y = pred_fn(x)
        y_target: Array,
    ):
        def loss_fn_per_sample(cf, x):
            cf = cf.reshape(1, -1)
            cf_y_pred = pred_fn(cf)
            loss_1 = validity_fn(y_target, cf_y_pred)
            loss_2 = cost_fn(x, cf)
            loss_4 = compute_reg_loss_fn(x, cf)
            return (
                lambda_1 * loss_1 +
                lambda_2 * loss_2 +
                lambda_4 * loss_4
            )

        per_sample_loss = jax.vmap(loss_fn_per_sample, in_axes=(0, None))(cfs, x).mean()
        # Diversity is a property of the *set* of counterfactuals, so it must be
        # computed over all `cfs` at once. Evaluating it per sample (on a single
        # 1 x k row) gives a constant 1x1 determinant with zero gradient, which
        # would silently disable the diversity term.
        loss_3 = -dpp_style_vmap(cfs)
        return per_sample_loss + lambda_3 * loss_3

    @loop_tqdm(n_steps)
    def gen_cf_step(i, states: Tuple[Array, optax.OptState]):
        cfs, opt_state = states
        grads = jax.grad(loss_fn)(cfs, x, pred_fn, y_target)
        cfs, opt_state = grad_update(grads, cfs, opt_state, opt)
        return cfs, opt_state

    key, _ = jrand.split(key)
    cfs = jrand.uniform(key, (n_cfs, x.shape[-1]))
    opt = optax.adam(lr)
    opt_state = opt.init(cfs)

    cfs, opt_state = lax.fori_loop(0, n_steps, gen_cf_step, (cfs, opt_state))
    cfs = apply_constraints_fn(x, cfs, hard=True)
    return cfs


## Config

In [None]:
#| export
class DiverseCFConfig(BaseConfig):
    """Hyper-parameters of `DiverseCF`."""
    n_cfs: int = 5  # number of counterfactuals generated per input
    n_steps: int = 1000  # number of optimization steps
    lr: float = 0.001  # learning rate of the Adam optimizer
    lambda_1: float = 1.0  # weight of the validity loss
    lambda_2: float = 1.0  # weight of the cost (`cost_fn`) loss
    lambda_3: float = 1.0  # weight of the DPP-style diversity term
    lambda_4: float = 0.1  # weight of the `compute_reg_loss` term
    validity_fn: str = 'KLDivergence'  # keras loss class name, resolved via `keras.losses.get`
    cost_fn: str = 'MeanSquaredError'  # keras loss class name, resolved via `keras.losses.get`
    seed: int = 42  # NOTE(review): not read by `DiverseCF` in this file (it requires an explicit `rng_key`)


In [None]:
#| export
class DiverseCF(CFModule):
    """Diverse Counterfactual Explanations (DiCE) as a relax `CFModule`.

    Generates `config.n_cfs` counterfactuals per input by jointly optimizing
    validity, cost, regularization and a DPP-style diversity term.
    """

    def __init__(self, config: dict | DiverseCFConfig = None, *, name: str = None):
        if config is None:
            config = DiverseCFConfig()
        config = validate_configs(config, DiverseCFConfig)
        name = "DiverseCF" if name is None else name
        super().__init__(config, name=name)

    def save(self, path: str):
        """Save the config to `<path>/config.json`."""
        self.config.save(Path(path) / 'config.json')

    @classmethod
    def load_from_path(cls, path: str):
        """Build a `DiverseCF` from the config stored in `<path>/config.json`."""
        config = DiverseCFConfig.load_from_json(Path(path) / 'config.json')
        return cls(config=config)

    @auto_reshaping('x', reshape_output=False)
    def generate_cf(
        self,
        x: Array,  # `x` shape: (k,), where `k` is the number of features
        pred_fn: Callable[[Array], Array],
        y_target: Array = None,
        rng_key: jnp.ndarray = None,
        **kwargs,
    ) -> Array:
        """Generate `config.n_cfs` diverse counterfactuals for `x`.

        Args:
            x: input instance; reshaped by `auto_reshaping` before use.
            pred_fn: model prediction function, y = pred_fn(x).
            y_target: desired prediction. Defaults to `1 - pred_fn(x)`, i.e.
                the flipped binary prediction.
            rng_key: JAX PRNG key used to initialize the counterfactuals.

        Returns:
            The counterfactuals, shape (n_cfs, k).

        Raises:
            ValueError: if `rng_key` is missing or `y_target` does not have
                shape (1, 2).
        """
        # Validate cheap arguments first, before calling the model.
        if rng_key is None:
            raise ValueError("`rng_key` must be provided.")

        # TODO: Currently assumes binary classification.
        if y_target is None:
            y_target = 1 - pred_fn(x)
        else:
            y_target = jnp.array(y_target, copy=True).reshape(1, -1)
        if y_target.shape != (1, 2):
            raise ValueError(
                "`y_target` must have shape (1, 2) (binary classification), "
                f"but got {y_target.shape}."
            )

        validity_fn = keras.losses.get(
            {'class_name': self.config.validity_fn, 'config': {'reduction': None}}
        )
        cost_fn = keras.losses.get(
            {'class_name': self.config.cost_fn, 'config': {'reduction': None}}
        )
        return _diverse_cf(
            x=x,
            y_target=y_target,
            pred_fn=pred_fn,  # y = pred_fn(x)
            n_cfs=self.config.n_cfs,
            n_steps=self.config.n_steps,
            lr=self.config.lr,  # learning rate for each `cf` optimization step
            lambdas=(
                self.config.lambda_1, self.config.lambda_2,
                self.config.lambda_3, self.config.lambda_4
            ),
            key=rng_key,
            validity_fn=validity_fn,
            cost_fn=cost_fn,
            apply_constraints_fn=self.apply_constraints,
            compute_reg_loss_fn=self.compute_reg_loss,
        )


In [None]:
# Load the bundled 'dummy' dataset and its ML module.
dm = load_data('dummy')
model = load_ml_module('dummy')
(xs_train, ys_train), (xs_test, ys_test) = dm['train'], dm['test']
x_shape = xs_test.shape  # (n_test, k)



In [None]:
dcf = DiverseCF({'lambda_2': 4.0})
dcf.set_apply_constraints_fn(dm.apply_constraints)
dcf.set_compute_reg_loss_fn(dm.compute_reg_loss)
n_cfs = dcf.config.n_cfs  # don't hard-code the default number of CFs

# Single input -> `n_cfs` counterfactuals.
cf = dcf.generate_cf(xs_test[0], model.pred_fn, rng_key=jrand.PRNGKey(0))
assert cf.shape == (n_cfs, x_shape[1])

# Batched generation over the whole test set.
partial_gen = ft.partial(dcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(partial_gen)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))

assert cfs.shape == (x_shape[0], n_cfs, x_shape[1])

# Validity over *all* generated counterfactuals, not just the first one:
# each input's flipped prediction is the target for every one of its CFs.
y_targets = jnp.repeat((1 - model.pred_fn(xs_test)).round(), n_cfs, axis=0)
cf_preds = model.pred_fn(cfs.reshape(-1, x_shape[1]))
print("Validity: ", keras.metrics.binary_accuracy(y_targets, cf_preds).mean())

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Validity:  1.0


In [None]:
dcf.save('tmp/dice/')
dcf_1 = DiverseCF.load_from_path('tmp/dice/')
# Configure the reloaded module exactly like `dcf` (both hooks), so the
# comparison below checks only the save/load round-trip of the config.
dcf_1.set_apply_constraints_fn(dm.apply_constraints)
dcf_1.set_compute_reg_loss_fn(dm.compute_reg_loss)
partial_gen_1 = ft.partial(dcf_1.generate_cf, pred_fn=model.pred_fn)
cfs_1 = jax.vmap(partial_gen_1)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))

assert jnp.allclose(cfs, cfs_1)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# Explain the whole dataset with DiverseCF and report the benchmark metrics.
exp = relax.generate_cf_explanations(dcf, dm, model.pred_fn)
exps = [exp]
relax.benchmark_cfs(exps)

  0%|          | 0/1000 [00:00<?, ?it/s]

Unnamed: 0,Unnamed: 1,acc,validity,proximity
dummy,DiverseCF,0.983,1.0,1.264459
