# Evaluate

> Evaluating and benchmarking the quality of CF explanations.

In [None]:
# | include: false
# | default_exp evaluate


In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.filterwarnings('ignore')

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.trainer import TrainingConfigs
from relax.data import TabularDataModule
from relax.utils import accuracy, proximity
from relax.methods.base import BaseCFModule, BaseParametricCFModule, BasePredFnCFModule
from relax.methods.counternet import CounterNet
from copy import deepcopy
from sklearn.neighbors import NearestNeighbors
from fastcore.test import test_fail

In [None]:
#| export
#| hide
@dataclass
class Explanation:
    """Container for the counterfactual (CF) explanations produced by one CF method.

    After initialization, if `data_module` is truthy, every one of
    `dataset_name`, `X`, and `y` that was left at its default is filled in
    from the data module (its `data_name` and its test split).
    """
    cf_name: str  # cf method's name
    data_module: TabularDataModule  # data module
    cfs: jnp.DeviceArray  # generated cf explanation of `X`
    total_time: float  # total runtime
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]  # predict function
    dataset_name: str = ""  # dataset name
    X: jnp.ndarray = None  # input
    y: jnp.ndarray = None  # label

    def __post_init__(self):
        # Without a data module there is nothing to infer the defaults from.
        if not self.data_module:
            return

        if self.dataset_name == "":
            self.dataset_name = self.data_module.data_name

        # The test split is the fallback for both the inputs and the labels.
        inputs, labels = self.data_module.test_dataset[:]
        self.X = inputs if self.X is None else self.X
        self.y = labels if self.y is None else self.y


# Alias kept for code that still refers to the old class name.
CFExplanationResults = Explanation

In [None]:
show_doc(Explanation)

---

[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/evaluate.py#L22){target="_blank" style="float:right; font-size:smaller"}

### EXPLANATION

::: {.doc-sig}

 CLASS relax.evaluate.<b>Explanation</b> <em>(cf_name, data_module, cfs, total_time, pred_fn, dataset_name='', X=None, y=None)</em>

:::

Generated CF Explanations class.

Arguments to `Explanation`:

* `cf_name`: cf method's name
* `dataset_name`: dataset name
* `X`: input
* `y`: label
* `cfs`: generated cf explanation of `X`
* `total_time`: total runtime
* `pred_fn`: predict function with only one input argument, 
and output a label (i.e., its format is `y=pred_fn(x)`).
* `data_module`: data module


## Parallelism Strategy

In [None]:
#| export
class BaseGenerationStrategy:
    """Interface for strategies that map a single-input CF generator over a batch.

    Subclasses override `__call__`; this base implementation only raises
    `NotImplementedError`.
    """

    def __call__(
        self,
        fn: Callable, # Function to generate cf for a single input
        X: jnp.ndarray, # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array: # Generated counterfactual explanations
        # No default mapping exists; concrete strategies must supply one.
        raise NotImplementedError

In [None]:
#| export
class IterativeGenerationStrategy(BaseGenerationStrategy):
    """Iteratively generate counterfactuals, one input row at a time."""

    def __call__(
        self,
        fn: Callable, # Function to generate cf for a single input
        X: jnp.ndarray, # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array: # Generated counterfactual explanations

        assert X.ndim == 2
        # Call `fn` on each row in order and collect the per-row results.
        per_row_cfs = []
        for row_idx in range(X.shape[0]):
            per_row_cfs.append(fn(X[row_idx], pred_fn=pred_fn, **kwargs))
        cfs = jnp.stack(per_row_cfs)
        # The stacked result must have the same shape as the input.
        assert X.shape == cfs.shape
        return cfs

In [None]:
#| export
class VmapGenerationStrategy(BaseGenerationStrategy):
    """Generate counterfactuals for all rows at once via `jax.vmap`."""

    def __call__(
        self,
        fn: Callable, # Function to generate cf for a single input
        X: jnp.ndarray, # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array: # Generated counterfactual explanations

        assert X.ndim == 2
        # Bind `pred_fn` and the extra keyword arguments so that only the
        # input row is mapped over.
        single_input_fn = partial(fn, pred_fn=pred_fn, **kwargs)
        batched_fn = jax.vmap(single_input_fn)
        return batched_fn(X)

In [None]:
#| exporti
def _pad_divisible_X(
    X: Array,
    n_devices: int
):
    """Pad `X` to be divisible by `n_devices`."""
    if X.shape[0] % n_devices != 0:
        pad_size = n_devices - X.shape[0] % n_devices
        X = jnp.concatenate([X, jnp.zeros((pad_size, *X.shape[1:]))])
    X_padded = X.reshape(n_devices, -1, *X.shape[1:])
    return X_padded

In [None]:
#| hide
# 5 rows over 2 devices: one zero row is appended, giving 2 chunks of 3 rows.
X = jnp.ones((5, 29))
X_padded = _pad_divisible_X(X, 2)
assert X_padded.shape == (2, 3, 29)
# Padding rows are zeros, so the total sum is unchanged.
assert X.sum() == X_padded.sum()

# More devices than rows: padded up to 6 rows, one row per device.
X = jnp.ones((5, 29))
X_padded = _pad_divisible_X(X, 6)
assert X_padded.shape == (6, 1, 29)

# A single device needs no padding; only a leading device axis is added.
X = jnp.ones((5, 29))
X_padded = _pad_divisible_X(X, 1)
assert X_padded.shape == (1, 5, 29)


In [None]:
#| export
class PmapGenerationStrategy(BaseGenerationStrategy):
    """Generate counterfactuals across devices via `jax.pmap` over `jax.vmap`."""

    def __init__(
        self,
        n_devices: int = None, # Number of devices. If None, use all available devices
        strategy: str = 'auto', # Strategy to generate counterfactuals
        **kwargs
    ):
        # `strategy` is only stored; extra keyword arguments are accepted and ignored.
        self.strategy = strategy
        self.n_devices = n_devices if n_devices else jax.device_count()

    def __call__(
        self,
        fn: Callable, # Function to generate cf for a single input
        X: jnp.ndarray, # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array: # Generated counterfactual explanations

        assert X.ndim == 2
        n_inputs = X.shape[0]
        # Shape the inputs as (n_devices, rows_per_device, ...), padding if needed.
        sharded_X = _pad_divisible_X(X, self.n_devices)
        per_device_fn = jax.vmap(partial(fn, pred_fn=pred_fn, **kwargs))
        sharded_cfs = jax.pmap(per_device_fn)(sharded_X)
        # Merge the device axis back into the batch axis, then drop the
        # rows that correspond to padding.
        flat_cfs = sharded_cfs.reshape(-1, *sharded_cfs.shape[2:])
        return flat_cfs[:n_inputs]

In [None]:
# Ask XLA to expose 8 host (CPU) devices so the pmap strategy can be exercised.
# NOTE(review): XLA flags are presumably read when JAX first initializes its
# backend; earlier cells already create jnp arrays, so this assignment may
# have no effect at this point — confirm, and consider moving it to the first cell.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

# Random weights and inputs for a toy linear "model".
# NOTE(review): `w` and `X` are drawn from the same PRNG key — confirm this is intended.
w = jrand.normal(jrand.PRNGKey(0), (100, 100))
X = jrand.normal(jrand.PRNGKey(0), (1000, 100))

@jit
def pred_fn(x): return jnp.dot(x, w.T)

# Stand-in for a CF generator: it simply returns the prediction for `x`,
# ignoring any extra keyword arguments.
def f(x, pred_fn=None, **kwargs):
    return pred_fn(x)

# One instance of each strategy, compared against each other in the cells below.
iter_gen = IterativeGenerationStrategy()
vmap_gen = VmapGenerationStrategy()
pmap_gen = PmapGenerationStrategy()

In [None]:
cf_iter = iter_gen(f, X, pred_fn=pred_fn).block_until_ready()

In [None]:
cf_vmap = vmap_gen(f, X, pred_fn=pred_fn).block_until_ready()

In [None]:
cf_pmap = pmap_gen(f, X, pred_fn=pred_fn).block_until_ready()

In [None]:
assert jnp.allclose(cf_iter, cf_vmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_pmap, atol=1e-4)

In [None]:
#| export
class StrategyFactory(object):
    """Factory class for Parallelism Strategy.

    Only the class-level helpers are meant to be used; instantiating the
    factory raises `ValueError`.
    """

    # Strategy instances keyed by short name. They are created once, when the
    # class body runs, and the same instances are returned on every lookup.
    __strategy_map = dict(
        iter=IterativeGenerationStrategy(),
        vmap=VmapGenerationStrategy(),
        pmap=PmapGenerationStrategy(),
    )

    def __init__(self) -> None:
        raise ValueError("This class should not be instantiated.")

    @staticmethod
    def get_default_strategy() -> BaseGenerationStrategy:
        """Return a new instance of the default (vmap) strategy."""
        return VmapGenerationStrategy()

    @classmethod
    def get_strategy(cls, strategy: str | BaseGenerationStrategy) -> BaseGenerationStrategy:
        """Resolve `strategy` to a strategy object.

        A `BaseGenerationStrategy` instance is returned as-is; a known name
        is looked up in the registry; anything else raises `ValueError`.
        """
        if isinstance(strategy, BaseGenerationStrategy):
            return strategy
        is_known_name = isinstance(strategy, str) and strategy in cls.__strategy_map
        if not is_known_name:
            raise ValueError(f"Invalid strategy: {strategy}")
        return cls.__strategy_map[strategy]

In [None]:
#| hide
# Look up each registered strategy by its short name.
it = StrategyFactory.get_strategy('iter')
vm = StrategyFactory.get_strategy('vmap')
pm = StrategyFactory.get_strategy('pmap')
default = StrategyFactory.get_default_strategy()
# A strategy instance is accepted directly, not only a name.
cus = StrategyFactory.get_strategy(VmapGenerationStrategy())

assert isinstance(it, IterativeGenerationStrategy)
assert isinstance(vm, VmapGenerationStrategy)
assert isinstance(pm, PmapGenerationStrategy)
assert isinstance(default, VmapGenerationStrategy)
assert isinstance(cus, VmapGenerationStrategy)


## Generating CF Explanation Results


In [None]:
#| exporti
def _validate_configs(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule,
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None,
    t_configs=None
):
    """Warn when no `pred_fn` is given and `cf_module` is not a `BasePredFnCFModule`.

    `datamodule` and `t_configs` are accepted but not inspected.
    """
    # An explicit prediction function was supplied: nothing to check.
    if pred_fn is not None:
        return
    # The module is declared to bring its own prediction function.
    if isinstance(cf_module, BasePredFnCFModule):
        return
    warnings.warn(f"`{type(cf_module).__name__}` is not a subclass of `BasePredFnCFModule`."
        "This might cause problems as you set `pred_fn=None`, "
        f"which infers that `{type(cf_module).__name__}` has an attribute `pred_fn`.")


def _prepare_module(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule
):
    cf_module.hook_data_module(datamodule)
    return cf_module

def _train_parametric_module(
    cf_module: BaseParametricCFModule,
    datamodule: TabularDataModule,
    t_configs=None,
    pred_fn=None
):
    if not cf_module._is_module_trained():
        print(f'{type(cf_module).__name__} contains parametric models. '
            'Starts training before generating explanations...')
        cf_module.train(datamodule, t_configs, pred_fn=pred_fn)
    return cf_module

In [None]:
#| exporti
def _check_aux_pred_fn_args(pred_fn_args: dict | None):
    if pred_fn_args is None:
        return dict()
    elif isinstance(pred_fn_args, dict):
        return pred_fn_args
    else:
        raise ValueError(f'`pred_fn_args` should be a `dict`,',
            f'but got `{type(pred_fn_args).__name__}`')

class _AuxPredFn:
    """Wrap `pred_fn` so that it can be called with the input `x` alone.

    The auxiliary keyword arguments are validated and deep-copied once at
    construction time, then forwarded on every call.
    """

    def __init__(self, pred_fn, pred_fn_args: dict | None):
        self.pred_fn = pred_fn
        validated_args = _check_aux_pred_fn_args(pred_fn_args)
        # Keep a private copy so that later changes to the caller's
        # arguments do not affect this wrapper.
        self.fn_args = deepcopy(validated_args)

    def __call__(self, x: jnp.DeviceArray) -> jnp.DeviceArray:
        extra_kwargs = self.fn_args
        return self.pred_fn(x, **extra_kwargs)


def _check_pred_fn(
    pred_fn: callable | None, 
    cf_module: BaseCFModule
) -> callable:
    if pred_fn is None:
        try:
            pred_fn = cf_module.pred_fn
        except AttributeError:
            raise AttributeError(
                    "`generate_cf_explanations` is incorrectly configured."
                    f"It is supposed to be `pred_fn != None`,"
                    f"or `{type(cf_module).__name__}` has attribute `pred_fn`."
                    f"However, we got `pred_fn={pred_fn}`, "
                    f"and `{type(cf_module).__name__}` has not attribute `pred_fn`."
            )
    elif isinstance(cf_module, BasePredFnCFModule):
        # override pred_fn if `cf_module` has `pred_fn`
        pred_fn = cf_module.pred_fn
    return pred_fn

In [None]:
#| export
def generate_cf_explanations(
    cf_module: BaseCFModule, # CF Explanation Module
    datamodule: TabularDataModule, # Data Module
    pred_fn: callable = None, # Predictive function
    strategy: str | BaseGenerationStrategy = 'vmap', # Parallelism Strategy for generating CFs
    t_configs: TrainingConfigs = None, # training configs for `BaseParametricCFModule`
    pred_fn_args: dict = None # auxiliary arguments for `pred_fn` 
) -> Explanation:
    """Generate CF explanations for the test split of `datamodule`.

    Steps: validate the configuration, hook the data module into `cf_module`,
    train the module first if it is a `BaseParametricCFModule`, generate CFs
    for every test input with the chosen `strategy`, and package the result
    (including the wall-clock generation time) into an `Explanation`.
    """

    _validate_configs(cf_module, datamodule, pred_fn, t_configs)
    cf_module = _prepare_module(cf_module, datamodule)

    if isinstance(cf_module, BaseParametricCFModule):
        # NOTE(review): training receives the raw `pred_fn`, i.e. without
        # `pred_fn_args` bound — confirm that `train` does not need them.
        cf_module = _train_parametric_module(
            cf_module, datamodule, t_configs=t_configs, pred_fn=pred_fn
        )
    X, _ = datamodule.test_dataset[:]
    
    # create `pred_fn` which only takes `x` as an input
    if pred_fn is not None:
        pred_fn = _AuxPredFn(pred_fn, pred_fn_args=pred_fn_args)

    strategy = StrategyFactory.get_strategy(strategy)
    start_time = time.time()
    cfs = strategy(cf_module.generate_cf, X, pred_fn=pred_fn)
    # JAX dispatches computations asynchronously: the call above can return
    # before the result is computed. Wait for it so that `total_time` covers
    # the actual generation rather than only the dispatch.
    if hasattr(cfs, "block_until_ready"):
        cfs = cfs.block_until_ready()
    total_time = time.time() - start_time

    # check pred_fn
    pred_fn = _check_pred_fn(pred_fn, cf_module)

    return Explanation(
        cf_name=cf_module.name,
        data_module=datamodule,
        cfs=cfs,
        total_time=total_time,
        pred_fn=pred_fn,
    )


The `pred_fn` in `generate_cf_explanations` is a model's prediction function. 
The general format is `y = pred_fn(x, **pred_fn_args)`. 
If `pred_fn` is not parameterized by other variables (except input `x`), 
then `pred_fn_args` is set to `None`, which is the default setting.
Otherwise, you should pass these arguments as a `dict`.

For example, we have a simple linear function

In [None]:
def linear_pred_fn(x: jnp.DeviceArray, params: jnp.DeviceArray):
    """Toy linear model: the matrix product of the input `x` and `params`."""
    prediction = x @ params
    return prediction

To pass `linear_pred_fn` to `generate_cf_explanations`, 
we can either create an auxiliary function of `linear_pred_fn`,
or pass `params` into `pred_fn_args`.

Assuming we now have the input `x` and `params`:

In [None]:
x = jax.random.normal(random.PRNGKey(0), shape=(5, 10)) # input
params = jnp.ones((10, 1)) # params



1. Create an auxiliary function (Not recommended)

```python
aux_linear_pred_fn = lambda x: linear_pred_fn(x, params)
explanations = generate_cf_explanations(
    cf_module, datamodule, aux_linear_pred_fn
)
```

This approach could work, but if `params` is changed, 
`explanations.pred_fn` might not work as expected.

2. Pass `params` into `pred_fn_args`

```python
explanations = generate_cf_explanations(
    cf_module, datamodule, linear_pred_fn, 
    pred_fn_args=dict(params=params)
)
```

This is the recommended approach, as `params` is deep-copied inside `generate_cf_explanations`.


The `pred_fn` in `explanations` only takes `x: jnp.DeviceArray` as an input.
For example, to make predictions, we use

```python
y = explanations.pred_fn(x)
```

## Evaluating Metrics

In [None]:
#| export
class BaseEvalMetrics(ABC):
    """Base evaluation metrics class.

    Subclasses implement `__call__`, which receives an `Explanation` and
    returns the metric value. `str(metric)` returns the metric's `name`.
    """

    def __init__(self, name: str = None):
        # Default to the concrete class name when no explicit name is given.
        if name is None: name = type(self).__name__
        self.name = name

    def __str__(self) -> str:
        # `name` is missing only when a subclass overrides `__init__` without
        # calling `super().__init__()`.
        if not hasattr(self, 'name'):
            # Instances have no `__name__`; the class name must be read from
            # the type. Formatting the message with `self.__name__` raised an
            # uninformative AttributeError before the message could be built,
            # so an AttributeError carrying the intended hint is raised here.
            cls_name = type(self).__name__
            raise AttributeError(
                "EvalMetrics must have a name. Add the following as the first line in your "
                f"__init__ method:\n\nsuper({cls_name}, self).__init__()")
        return self.name

    def __call__(self, cf_explanations: Explanation) -> Any:
        raise NotImplementedError

In [None]:
#| exporti
def _compute_acc(
    input: jnp.DeviceArray, # input dim: [N, k]
    label: jnp.DeviceArray, # label dim: [N] or [N, 1]
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]
) -> float:
    """Accuracy of the rounded predictions of `pred_fn` on `input` against `label`."""
    # Bring predictions and labels to a common [N, 1] column layout.
    raw_pred = pred_fn(input)
    rounded_pred = raw_pred.reshape(-1, 1).round()
    label_col = label.reshape(-1, 1)
    acc = accuracy(rounded_pred, label_col)
    return acc.item()

In [None]:
#| hide
fake_pred_fn = lambda x: x.clip(0, 1).mean(axis=1)
inputs = jnp.array([
    [0., 0.99], [0.1, 0.1], [0.99, 0.1], [0.99, 0.99] # [0, 0, 1, 1]
])
labels_1 = jnp.array([0, 0, 1, 1])
labels_2 = jnp.array([0, 1, 0, 1])
assert _compute_acc(inputs, labels_1, fake_pred_fn) == 1.0
assert _compute_acc(inputs, labels_2, fake_pred_fn) == 0.5

In [None]:
#| export
class PredictiveAccuracy(BaseEvalMetrics):
    """Accuracy of the stored predict function on the data module's test set."""

    def __init__(self, name: str = "accuracy"):
        super().__init__(name=name)

    def __call__(self, cf_explanations: Explanation) -> float:
        test_inputs, test_labels = cf_explanations.data_module.test_dataset[:]
        predict = cf_explanations.pred_fn
        return _compute_acc(test_inputs, test_labels, predict)

In [None]:
#| hide
_acc = PredictiveAccuracy()
assert _acc.name == "accuracy"
assert str(_acc) == "accuracy"
_acc = PredictiveAccuracy('acc')
assert _acc.name == "acc"

In [None]:
#| exporti
def _compute_val(
    input: jnp.DeviceArray, # input dim: [N, k]
    cfs: jnp.DeviceArray, # cfs dim: [N, k]
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]
):
    """Fraction of CFs whose rounded prediction is the flipped rounded prediction of the input."""
    original_label = pred_fn(input).reshape(-1, 1).round()
    # The desired outcome is the opposite binary label: 1 - y.
    desired_label = jnp.ones_like(original_label) - original_label
    cf_label = pred_fn(cfs).reshape(-1, 1).round()
    validity = accuracy(desired_label, cf_label)
    return validity.item()

In [None]:
#| hide
fake_pred_fn = lambda x: x.clip(0, 1).mean(axis=1)
inputs = jnp.array([
    [0., 0.99], [0.1, 0.1], [0.99, 0.1], [0.99, 0.99] # [0, 0, 1, 1]
])
cfs_1 = jnp.array([
    [0.99, 0.99], [0.1, 0.1], [0.1, 0.1], [0., 0.99] # [1, 0, 0, 0]
])
cfs_2 = jnp.array([
    [0.99, 0.], [0.1, 0.1], [0.1, 0.1], [0., 0.99] # [0, 0, 0, 0]
])

assert _compute_val(inputs, cfs_1, fake_pred_fn) == 0.75
assert _compute_val(inputs, cfs_2, fake_pred_fn) == 0.5

In [None]:
#| export
class Validity(BaseEvalMetrics):
    """Fraction of test inputs for which the CF method produced a valid CF example."""

    def __init__(self, name: str = "validity"):
        super().__init__(name=name)

    def __call__(self, cf_explanations: Explanation) -> float:
        test_inputs, _ = cf_explanations.data_module.test_dataset[:]
        generated_cfs = cf_explanations.cfs
        predict = cf_explanations.pred_fn
        return _compute_val(test_inputs, generated_cfs, predict)

In [None]:
#| exporti
def _compute_proximity(
    inputs: jnp.DeviceArray, # input dim: [N, k]
    cfs: jnp.DeviceArray, # cfs dim: [N, k]
):
    prox = jnp.linalg.norm(inputs - cfs, ord=1, axis=1).mean()
    return prox.item()

In [None]:
# Three 2-feature inputs and two sets of CFs with hand-computed L1 distances.
inputs = jnp.array([
    [0, 1], [1, 0], [1, 1]])
# Per-row L1 distances to `inputs`: 1, 1, 1 -> mean 1.0
cfs_1 = jnp.array([
    [0, 0], [0, 0], [0, 1]])
# Per-row L1 distances to `inputs`: 2, 2, 2 -> mean 2.0
cfs_2 = jnp.array([
    [1, 0], [1, -2], [0, 0]])
assert _compute_proximity(inputs, cfs_1) == 1.0
assert _compute_proximity(inputs, cfs_2) == 2.0

In [None]:
#| export
class Proximity(BaseEvalMetrics):
    """Compute the mean L1 norm distance between the test inputs and their CF examples."""

    def __init__(self, name: str = "proximity"):
        super().__init__(name=name)

    def __call__(self, cf_explanations: Explanation) -> float:
        test_inputs, _ = cf_explanations.data_module.test_dataset[:]
        generated_cfs = cf_explanations.cfs
        return _compute_proximity(test_inputs, generated_cfs)

In [None]:
#| exporti
def _compute_spar(
    input: jnp.DeviceArray,
    cfs: jnp.DeviceArray,
    cat_idx: int
):
    """Sparsity of the CFs: continuous part plus categorical part.

    Columns before `cat_idx` are treated as continuous, the remaining
    columns as categorical.
    """
    # Continuous part: number of non-zero differences per row (L0 "norm"),
    # averaged over rows.
    cont_diff = jnp.abs(input[:, :cat_idx] - cfs[:, :cat_idx])
    cont_sparsity = jnp.linalg.norm(cont_diff, ord=0, axis=1).mean()
    # Categorical part: half of the `proximity` of the categorical columns.
    # NOTE(review): presumably halved because changing a one-hot category
    # flips two entries — confirm against `relax.utils.proximity`.
    cat_sparsity = proximity(input[:, cat_idx:], cfs[:, cat_idx:]) / 2
    return (cont_sparsity + cat_sparsity).item()


In [None]:
#| export
class Sparsity(BaseEvalMetrics):
    """Compute the number of feature changes between input datasets and CF examples."""

    def __init__(self, name: str = "sparsity"):
        super().__init__(name=name)
    
    def __call__(self, cf_explanations: Explanation) -> float:
        data_module = cf_explanations.data_module
        X, _ = data_module.test_dataset[:]
        # `cat_idx` lives on the data module; `Explanation` itself has no
        # such attribute, so reading it from there raised AttributeError.
        return _compute_spar(X, cf_explanations.cfs, data_module.cat_idx)

In [None]:
#| exporti
def _compute_manifold_dist(
    input: jnp.DeviceArray,
    cfs: jnp.DeviceArray,
    n_neighbors: int = 1,
    p: int = 2
):
    """Mean distance from each CF to its `n_neighbors` nearest points in `input`.

    A nearest-neighbor index is fitted on `input` (`p` is the distance
    parameter handed to `NearestNeighbors`) and queried with `cfs`; the
    returned distances are averaged over all CFs and neighbors.
    """
    knn = NearestNeighbors(n_neighbors=n_neighbors, p=p)
    knn.fit(input)
    # Query with the neighbor count configured on the estimator. A hard-coded
    # count of 1 here made the `n_neighbors` argument silently ineffective.
    nearest_dist, _ = knn.kneighbors(cfs, return_distance=True)
    return jnp.mean(nearest_dist).item()

In [None]:
#| export
class ManifoldDist(BaseEvalMetrics):
    """Compute the mean distance from the CF examples to their nearest neighbors among the test inputs.

    `n_neighbors` and `p` are forwarded to the nearest-neighbor search.
    """

    def __init__(self, n_neighbors: int = 1, p: int = 2, name: str = "manifold_dist"):
        super().__init__(name=name)
        self.n_neighbors, self.p = n_neighbors, p

    def __call__(self, cf_explanations: Explanation) -> float:
        test_inputs, _ = cf_explanations.data_module.test_dataset[:]
        generated_cfs = cf_explanations.cfs
        return _compute_manifold_dist(
            test_inputs, generated_cfs, self.n_neighbors, self.p
        )

In [None]:
#| export
class Runtime(BaseEvalMetrics):
    """Report the total time recorded while generating the CF examples."""

    def __init__(self, name: str = "runtime"):
        super().__init__(name=name)

    def __call__(self, cf_explanations: Explanation) -> float:
        # The time is measured when the explanations are generated; this
        # metric only reads it back.
        elapsed = cf_explanations.total_time
        return elapsed

In [None]:
#| hide
# Toy predictor: shift by 0.5 and clip to [0, 1]; the helpers round the result.
pred_fn_test = lambda x: jnp.clip(x + 0.5, 0., 1)

# Case 1 uses 1-D labels, case 2 uses [N, 1] labels; both layouts are accepted.
x_1 = jnp.array([[0.1], [0.34], [0.4], [-0.2], [0.7]])
cf_1 = jnp.array([[-0.2], [0.4], [-0.1], [0.7], [-0.1]])
y_1 = jnp.array([1., 0., 1., 1., 1.])
x_2 = jnp.array([[-0.5], [0.34], [0.4], [-0.2], [0.7]])
cf_2 = jnp.array([[0.2], [0.4], [-0.1], [0.7], [-0.1]])
y_2 = jnp.array([[0.], [1.], [1.], [1.], [1.]])

_acc_1 = _compute_acc(x_1, y_1, pred_fn_test)
_acc_2 = _compute_acc(x_2, y_2, pred_fn_test)
_val_1 = _compute_val(x_1, cf_1, pred_fn_test)
_val_2 = _compute_val(x_2, cf_2, pred_fn_test)

# Compare with a tolerance because the results are floating-point fractions.
assert jnp.isclose(_acc_1, 0.6)
assert jnp.isclose(_acc_2, 0.8)
assert jnp.isclose(_val_1, 0.8)
assert jnp.isclose(_val_2, 0.8)

In [None]:
#| export
def _create_second_order_cfs(cf_results: CFExplanationResults, threshold: float = 2.0):
    X, y = cf_results.data_module.test_dataset[:]
    cfs = cf_results.cfs
    scaler = cf_results.data_module.normalizer
    cat_idx = cf_results.data_module.cat_idx

    # get normalized threshold = threshold / (max - min)
    data_range = scaler.data_range_
    thredshold_normed = threshold / data_range

    # select continous features
    x_cont = X[:, :cat_idx]
    cf_cont = cfs[:, :cat_idx]
    # calculate the diff between x and c
    cont_diff = jnp.abs(x_cont - cf_cont) <= thredshold_normed
    # new cfs
    cfs_cont_hat = jnp.where(cont_diff, x_cont, cf_cont)

    cfs_hat = jnp.concatenate((cfs_cont_hat, cfs[:, cat_idx:]), axis=-1)
    return cfs_hat


def compute_so_validity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Validity evaluated on the second-order CFs derived from `cf_results`."""
    # Evaluate on a deep copy so that the caller's results stay untouched.
    so_results = deepcopy(cf_results)
    so_results.cfs = _create_second_order_cfs(cf_results, threshold)
    return Validity()(so_results)


def compute_so_proximity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Proximity evaluated on the second-order CFs derived from `cf_results`."""
    # Evaluate on a deep copy so that the caller's results stay untouched.
    so_results = deepcopy(cf_results)
    so_results.cfs = _create_second_order_cfs(cf_results, threshold)
    return Proximity()(so_results)


def compute_so_sparsity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Sparsity evaluated on the second-order CFs derived from `cf_results`."""
    # Evaluate on a deep copy so that the caller's results stay untouched.
    so_results = deepcopy(cf_results)
    so_results.cfs = _create_second_order_cfs(cf_results, threshold)
    return Sparsity()(so_results)


## Benchmarking

In [None]:
#| exporti
def fake_explanations():
    """Generate pseudo explanations for testing.

    The "CFs" are the adult test inputs themselves, and the predict function
    returns Bernoulli(0.5) draws from a fixed PRNG key, one per input row.
    """
    from relax.data import load_data

    def pred_fn(x):
        return jax.random.bernoulli(jax.random.PRNGKey(0), 0.5, (x.shape[0], 1)).astype(float)

    dm = load_data("adult")
    X, y = dm.test_dataset[:]
    # The fake predictions must have the same shape as the labels.
    assert y.shape == pred_fn(X).shape
    return Explanation(
        cf_name='sudo',
        data_module=dm,
        cfs=X,
        pred_fn=pred_fn,
        total_time=0.0,
        dataset_name=dm.data_name,
    )


In [None]:
#| export
# Built-in metric instances. `PredictiveAccuracy` is registered twice so that
# both the short name 'acc' and the full name 'accuracy' resolve to it.
# `Sparsity` and the second-order `compute_so_*` functions are not registered.
METRICS_CALLABLE = [
    PredictiveAccuracy('acc'),
    PredictiveAccuracy('accuracy'),
    Validity(),
    Proximity(),
    Runtime(),
    ManifoldDist(),
]

# Lookup table from metric name to metric instance.
METRICS = { m.name: m for m in METRICS_CALLABLE }

# Metric names used when the caller does not specify any.
DEFAULT_METRICS = ["acc", "validity", "proximity"]

In [None]:
#| hide
for m in METRICS.keys(): assert isinstance(m, str)

In [None]:
#| exporti
def _get_metric(metric: str | BaseEvalMetrics, cf_exp: Explanation):
    """Resolve `metric` to a metric object, evaluate it on `cf_exp`, and return the result.

    `metric` is either the name of a registered metric or a
    `BaseEvalMetrics` instance; anything else raises `ValueError`.
    """
    if isinstance(metric, str):
        if metric not in METRICS.keys():
            raise ValueError(f"'{metric}' is not supported. Must be one of {METRICS.keys()}")
        metric_fn = METRICS[metric]
    elif callable(metric):
        # f(cf_exp) not supported for now
        if not isinstance(metric, BaseEvalMetrics):
            raise ValueError(f"metric needs to be a subclass of `BaseEvalMetrics`.")
        metric_fn = metric
    else:
        raise ValueError(f"{type(metric).__name__} is not supported as a metric.")

    res = metric_fn(cf_exp)
    # Unwrap one-element arrays into plain Python scalars.
    is_single_element_array = isinstance(res, jnp.ndarray) and res.shape == (1,)
    return res.item() if is_single_element_array else res

In [None]:
#| hide
exp = fake_explanations()
_acc_1 = _get_metric('acc', exp)
test_fail(lambda: _get_metric('acc_1', exp), contains='is not supported')
_acc_2 = _get_metric(PredictiveAccuracy(), exp)
assert jnp.allclose(_acc_1, _acc_2)
# functional callable not supported
test_fail(lambda: _get_metric(Proximity, exp), contains='needs to be a subclass')
test_fail(lambda: _get_metric(lambda: 1., exp), contains='needs to be a subclass') 

for m in METRICS_CALLABLE:
    _res = _get_metric(m, exp)
    assert isinstance(_res, (int, float))
    assert not isinstance(_res, jnp.ndarray)

In [None]:
# | export
def evaluate_cfs(
    cf_exp: Explanation, # CF Explanations
    metrics: Iterable[Union[str, BaseEvalMetrics]] = None, # A list of Metrics. Can be `str` or a subclass of `BaseEvalMetrics`
    return_dict: bool = True, # return a dictionary or not (default: True)
    return_df: bool = False # return a pandas Dataframe or not (default: False)
):
    """Evaluate the given metrics on one set of CF explanations.

    Results are keyed by `(data_name, cf_name)`. Depending on the flags, the
    return value is a dict, a DataFrame, a `(dict, DataFrame)` tuple when
    both flags are set, or `None` when neither is.
    """
    result_key = (cf_exp.data_module.data_name, cf_exp.cf_name)

    if metrics is None:
        metrics = DEFAULT_METRICS

    # Each metric is stored under its string form (the metric's name).
    scores = { str(metric): _get_metric(metric, cf_exp) for metric in metrics }
    result_dict = { result_key: scores }
    result_df = pd.DataFrame.from_dict(result_dict, orient="index")

    if return_dict and return_df:
        return (result_dict, result_df)
    if return_df:
        return result_df
    if return_dict:
        return result_dict
    return None

In [None]:
#| hide
exp = fake_explanations()
# Default metrics, then an explicit list of metric names.
evaluate_cfs(exp)
evaluate_cfs(exp, metrics=["acc", "validity", "proximity", "runtime"])
# `return_dict` defaults to True, so adding `return_df=True` yields a (dict, DataFrame) tuple.
d, df = evaluate_cfs(exp, metrics=["acc", "validity", "proximity", "runtime"], return_df=True)
assert isinstance(d, dict)
assert isinstance(df, pd.DataFrame)
# Only the DataFrame is returned when `return_dict=False`.
df = evaluate_cfs(exp, metrics=["acc", "validity", "proximity", "runtime"], return_df=True, return_dict=False)
assert isinstance(df, pd.DataFrame)

# Metric instances are accepted in place of names.
evaluate_cfs(exp, metrics=[PredictiveAccuracy(), Validity()])

{('adult', 'sudo'): {'accuracy': 0.4939196705818176, 'validity': 0.0}}

In [None]:
# | export
def benchmark_cfs(
    cf_results_list: Iterable[CFExplanationResults],
    metrics: Optional[Iterable[str]] = None,
):
    """Evaluate every set of CF explanations and stack the results into one DataFrame."""
    per_method_dfs = []
    for cf_results in cf_results_list:
        # One single-row DataFrame per (dataset, method) pair.
        per_method_dfs.append(
            evaluate_cfs(
                cf_exp=cf_results, metrics=metrics, return_dict=False, return_df=True
            )
        )
    return pd.concat(per_method_dfs)

## How to evaluate a CF Explanation Module

In [None]:
from relax.module import PredictiveTrainingModule
from relax.trainer import train_model
from relax.utils import load_json

In [None]:
configs = load_json('assets/configs/data_configs/adult.json')
m_configs = configs['mlp_configs']
data_configs = configs['data_configs']
data_configs['sample_frac'] = 0.1

t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'seed': 42,
    "batch_size": 256
} 

We first train a model

In [None]:
training_module = PredictiveTrainingModule(m_configs)
dm = TabularDataModule(data_configs)

params, opt_state = train_model(
    training_module, 
    dm, 
    t_configs
)
pred_fn = lambda x, params, prng_key: \
    training_module.forward(params, prng_key, x, is_training=False)

Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 53.81batch/s, train/train_loss_1=0.0791]


Now, we can start to benchmark different methods

In [None]:
from relax.methods import VanillaCF, CounterNet

Generate CF explanations for `VanillaCF`

In [None]:
#| slow
vanillacf = VanillaCF(dict(n_steps=1000, lr=0.001))
vanillacf_exp = generate_cf_explanations(
    vanillacf, dm, pred_fn,
    pred_fn_args=dict(params=params, prng_key=random.PRNGKey(0))
)

100%|██████████| 1000/1000 [00:10<00:00, 92.93it/s]


In [None]:
#| slow
assert vanillacf_exp.cf_name == vanillacf.name
assert vanillacf_exp.dataset_name == dm.data_name
assert vanillacf_exp.X.shape == vanillacf_exp.cfs.shape
assert vanillacf_exp.pred_fn(vanillacf_exp.X).shape == vanillacf_exp.y.shape

Generate CF explanations for `CounterNet`

In [None]:
#| slow
counternet = CounterNet()
counternet_exp = generate_cf_explanations(counternet, dm, pred_fn=None)

CounterNet contains parametric models. Starts training before generating explanations...


Epoch 99: 100%|██████████| 191/191 [00:03<00:00, 58.07batch/s, train/train_loss_1=0.0657, train/train_loss_2=0.000985, train/train_loss_3=0.0963]


Note that `CounterNet` contains a predictive module, so we set `pred_fn=None`

In [None]:
#| slow
assert counternet_exp.cf_name == counternet.name
assert counternet_exp.dataset_name == dm.data_name
assert counternet_exp.X.shape == counternet_exp.cfs.shape
assert counternet_exp.pred_fn(counternet_exp.X).shape == counternet_exp.y.shape
assert counternet_exp.pred_fn == counternet.pred_fn

If `cf_module` is a subclass of `BasePredFnCFModule` (e.g., `CounterNet`),
the `pred_fn` in `Explanation` will be set to `cf_module.pred_fn`,
and the `pred_fn` argument passed to `generate_cf_explanations` will be ignored.

In [None]:
#| slow
counternet_exp_1 = generate_cf_explanations(counternet, dm, pred_fn=pred_fn)
assert counternet_exp_1.pred_fn != pred_fn
assert counternet_exp_1.pred_fn == counternet.pred_fn

CounterNet contains parametric models. Starts training before generating explanations...


Epoch 99: 100%|██████████| 191/191 [00:02<00:00, 64.02batch/s, train/train_loss_1=0.0713, train/train_loss_2=0.000427, train/train_loss_3=0.0944]


Now, we can compute metrics for benchmarking different CF explanation methods.

In [None]:
#| slow
evaluate_cfs(vanillacf_exp, return_df=True)[1]

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,VanillaCF,0.822012,0.93674,7.62256


In [None]:
#| slow
evaluate_cfs(counternet_exp, return_df=True)[1]

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,CounterNet,0.831347,0.958605,5.9374576


In [None]:
#| slow
benchmark_cfs([vanillacf_exp, counternet_exp])

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,VanillaCF,0.822012,0.93674,7.62256
adult,CounterNet,0.831347,0.958605,5.9374576
