# Evaluate

> Evaluating and benchmarking the quality of CF explanations.

In [None]:
# | include: false
# | default_exp evaluate


In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc

In [None]:
# | export
from __future__ import annotations
from cfnet.import_essentials import *
from cfnet.train import train_model, TensorboardLogger
from cfnet.datasets import TabularDataModule
from cfnet.utils import accuracy, proximity
from cfnet.methods.base import BaseCFModule, BaseParametricCFModule, BasePredFnCFModule
from cfnet.methods.counternet import CounterNet
from copy import deepcopy
from sklearn.neighbors import NearestNeighbors


In [None]:
#| export
#| hide
@dataclass
class Explanation:
    """Generated CF Explanations class.

    When `data_module` is given, fields left at their defaults are filled from it:
    `dataset_name` from `data_module.data_name`, and `X` / `y` from
    `data_module.test_dataset[:]`.
    """
    cf_name: str  # cf method's name
    data_module: TabularDataModule  # data module
    cfs: jnp.DeviceArray  # generated cf explanation of `X`
    total_time: float  # total runtime
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]  # predict function
    dataset_name: str = str()  # dataset name
    X: jnp.ndarray = None  # input
    y: jnp.ndarray = None  # label

    def __post_init__(self):
        # Nothing to infer without a data module.
        if self.data_module is None:
            return
        if self.dataset_name == str():
            self.dataset_name = self.data_module.data_name
        # Only materialize the test split when `X` or `y` still needs a default.
        if self.X is None or self.y is None:
            test_X, label = self.data_module.test_dataset[:]
            if self.X is None:
                self.X = test_X
            if self.y is None:
                self.y = label

CFExplanationResults = Explanation

In [None]:
# Render the API reference for `Explanation` with nbdev's `show_doc`.
show_doc(Explanation)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/evaluate.py#L22){target="_blank" style="float:right; font-size:smaller"}

### Explanation

>      Explanation (cf_name:str, data_module:TabularDataModule,
>                   cfs:jnp.DeviceArray, total_time:float,
>                   pred_fn:Callable[[jnp.DeviceArray],jnp.DeviceArray],
>                   dataset_name:str='', X:jnp.ndarray=None, y:jnp.ndarray=None)

Generated CF Explanations class.

Arguments to `Explanation`:

* `cf_name`: cf method's name
* `dataset_name`: dataset name
* `X`: input
* `y`: label
* `cfs`: generated cf explanation of `X`
* `total_time`: total runtime
* `pred_fn`: predict function
* `data_module`: data module


## Generating CF Explanation Results


In [None]:
#| exporti
def _validate_configs(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule,
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None,
    t_configs=None
):
    if (pred_fn is None) and (not isinstance(cf_module, BasePredFnCFModule)):
        warnings.warn(f"`{type(cf_module).__name__}` is not a subclass of `BasePredFnCFModule`."
            "This might cause problems as you set `pred_fn=None`, "
            f"which infers that `{type(cf_module).__name__}` has an attribute `pred_fn`.")


def _prepare_module(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule
):
    cf_module.update_cat_info(datamodule)
    return cf_module

def _train_parametric_module(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule,
    t_configs=None
):
    print(f'{type(cf_module).__name__} contains parametric models. '
        'Starts training before generating explanations...')
    cf_module.train(datamodule, t_configs)
    return cf_module

def _check_pred_fn(pred_fn, cf_module):
    if pred_fn is None:
        try:
            pred_fn = cf_module.pred_fn
        except AttributeError:
            raise AttributeError(
                    "`generate_cf_explanations` is incorrectly configured."
                    f"It is supposed to be `pred_fn != None`,"
                    f"or `{type(cf_module).__name__}` has attribute `pred_fn`."
                    f"However, we got `pred_fn={pred_fn}`, "
                    f"and `{type(cf_module).__name__}` has not attribute `pred_fn`."
            )
    return pred_fn

In [None]:
#| export
def generate_cf_explanations(
    cf_module: BaseCFModule, # CF explanation module
    datamodule: TabularDataModule, # data module whose test split is explained
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None, # predict function; if None, `cf_module.pred_fn` is used
    *,
    t_configs=None # training configs, used only when `cf_module` is a `BaseParametricCFModule`
) -> Explanation:
    """Generate CF explanations.

    If `cf_module` is a `BaseParametricCFModule`, it is trained first. CFs are then
    generated for every sample in `datamodule.test_dataset`, and the time spent in
    generation is recorded in seconds.

    Raises:
        AttributeError: if `pred_fn` is None and `cf_module` has no `pred_fn`.
    """
    _validate_configs(cf_module, datamodule, pred_fn, t_configs)
    cf_module = _prepare_module(cf_module, datamodule)

    if isinstance(cf_module, BaseParametricCFModule):
        cf_module = _train_parametric_module(
            cf_module, datamodule, t_configs=t_configs
        )
    X, _ = datamodule.test_dataset[:]

    # Generate cfs. `perf_counter` is monotonic and high-resolution, unlike
    # `time.time`, so system clock adjustments cannot skew the measured runtime.
    # NOTE(review): if `generate_cfs` returns a lazily computed JAX array, the
    # timing may exclude pending device work; consider blocking on `cfs`. TODO confirm.
    start = time.perf_counter()
    cfs = cf_module.generate_cfs(X, pred_fn=pred_fn)
    total_time = time.perf_counter() - start
    # Resolve the predict function stored on the explanation (falls back to the module's).
    pred_fn = _check_pred_fn(pred_fn, cf_module)

    return Explanation(
        cf_name=cf_module.name,
        data_module=datamodule,
        cfs=cfs,
        total_time=total_time,
        pred_fn=pred_fn,
    )



In [None]:
# | export
@deprecated(removed_in='0.1.0', deprecated_in='0.0.9')
def generate_cf_results_local_exp(
    cf_module: BaseCFModule,
    dm: TabularDataModule,
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray],
) -> CFExplanationResults:
    """Deprecated: forwards to `generate_cf_explanations(cf_module, dm, pred_fn=pred_fn)`."""
    explanation = generate_cf_explanations(cf_module, dm, pred_fn=pred_fn)
    return explanation


@deprecated(removed_in='0.1.0', deprecated_in='0.0.9')
def generate_cf_results_cfnet(
    cf_module: CounterNet,
    dm: TabularDataModule,
    params: hk.Params = None,  # params of `cf_module`
    rng_key: random.PRNGKey = None,
) -> CFExplanationResults:
    """Deprecated: forwards to `generate_cf_explanations(cf_module, dm, pred_fn=None)`.

    `params` and `rng_key` are accepted for backward compatibility only and are
    ignored. The module's own `pred_fn` is used instead.
    """
    explanation = generate_cf_explanations(cf_module, dm, pred_fn=None)
    return explanation


## Evaluating Metrics

In [None]:
#| export
class BaseEvalMetrics(ABC):
    """Base class for CF evaluation metrics.

    A metric is a callable that maps an `Explanation` to a value. Subclasses
    override `__call__`.
    """
    def __call__(self, cf_explanations: Explanation) -> Any:
        # Concrete metrics must provide their own implementation.
        raise NotImplementedError

In [None]:
#| exporti
def _compute_acc(
    input: jnp.DeviceArray, # input dim: [N, k]
    label: jnp.DeviceArray, # label dim: [N] or [N, 1]
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]
) -> float:
    """Accuracy of the rounded outputs of `pred_fn(input)` against `label`."""
    # Flatten both sides to column vectors so `[N]` and `[N, 1]` labels compare alike.
    predicted = pred_fn(input).reshape(-1, 1).round()
    targets = label.reshape(-1, 1)
    score = accuracy(predicted, targets)
    return score.item()

In [None]:
#| export
class PredictiveAccuracy(BaseEvalMetrics):
    """Accuracy of the explanation's `pred_fn` on the data module's test split."""
    def __call__(self, cf_explanations: Explanation) -> float:
        # Use the test split's own (inputs, labels) pair so the two always line up.
        test_split = cf_explanations.data_module.test_dataset
        inputs, labels = test_split[:]
        return _compute_acc(inputs, labels, cf_explanations.pred_fn)

In [None]:
#| exporti
def _compute_val(
    input: jnp.DeviceArray, # input dim: [N, k]
    cfs: jnp.DeviceArray, # cfs dim: [N, k]
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]
):
    """Fraction of CFs whose rounded prediction equals `1 - rounded prediction` of the input."""
    def _rounded(batch):
        return pred_fn(batch).reshape(-1, 1).round()

    input_pred = _rounded(input)
    # Target class for each CF: the complement of the input's prediction
    # (the opposite class when outputs round to 0/1).
    target_pred = jnp.ones_like(input_pred) - input_pred
    cf_pred = _rounded(cfs)
    return accuracy(target_pred, cf_pred).item()

In [None]:
#| export
class Validity(BaseEvalMetrics):
    """Fraction of CFs whose prediction flips the prediction of the input they explain."""
    def __call__(self, cf_explanations: Explanation) -> float:
        # `cfs` explain `Explanation.X` (which defaults to the data module's test
        # split), so validity must be measured against those same inputs.
        return _compute_val(
            cf_explanations.X, cf_explanations.cfs, cf_explanations.pred_fn
        )

In [None]:
#| export
class Proximity(BaseEvalMetrics):
    """`proximity` between the explained inputs and their CFs."""
    def __call__(self, cf_explanations: Explanation) -> float:
        # Compare each CF with the input it explains (`Explanation.X`, which
        # defaults to the data module's test split).
        return proximity(cf_explanations.X, cf_explanations.cfs)

In [None]:
#| exporti
def _compute_spar(
    input: jnp.DeviceArray,
    cfs: jnp.DeviceArray,
    cat_idx: int
):
    """Average number of changed features between `input` and `cfs`.

    Columns before `cat_idx` are treated as continuous and columns from `cat_idx`
    on as categorical.
    """
    x_cont, cf_cont = input[:, :cat_idx], cfs[:, :cat_idx]
    x_cat, cf_cat = input[:, cat_idx:], cfs[:, cat_idx:]
    # Halved `proximity` of the categorical block. This assumes `proximity` is an L1-style
    # distance over one-hot columns, where each changed feature contributes 2. TODO confirm.
    n_cat_changed = proximity(x_cat, cf_cat) / 2
    # L0 "norm" per row = number of continuous features that changed; averaged over rows.
    n_cont_changed = jnp.linalg.norm(
        jnp.abs(x_cont - cf_cont), ord=0, axis=1
    ).mean()
    return n_cont_changed + n_cat_changed


In [None]:
#| export
class Sparsity(BaseEvalMetrics):
    """Average number of features changed by the CFs (see `_compute_spar`)."""
    def __call__(self, cf_explanations: Explanation) -> float:
        # `Explanation` has no `cat_idx` attribute. The index of the first
        # categorical column is stored on the data module.
        cat_idx = cf_explanations.data_module.cat_idx
        # Compare against the inputs the CFs explain (`Explanation.X`).
        return _compute_spar(cf_explanations.X, cf_explanations.cfs, cat_idx)

In [None]:
#| exporti
def _compute_manifold_dist(
    input: jnp.DeviceArray,
    cfs: jnp.DeviceArray,
    n_neighbors: int = 1,
    p: int = 2
):
    """Mean distance from each CF to its `n_neighbors` nearest points in `input`.

    Distances use the Minkowski metric of order `p` (2 = Euclidean). The mean is
    taken over all CFs and all of their `n_neighbors` neighbor distances.
    """
    knn = NearestNeighbors(n_neighbors=n_neighbors, p=p)
    knn.fit(input)
    # Query with the configured `n_neighbors`. A hard-coded 1 here would silently
    # ignore the argument.
    nearest_dist = knn.kneighbors(cfs, n_neighbors=n_neighbors, return_distance=True)[0]
    return jnp.mean(nearest_dist).item()

In [None]:
#| export
class ManifoldDist(BaseEvalMetrics):
    """Mean distance from each CF to its nearest neighbor(s) among the explained inputs."""
    def __init__(self, n_neighbors: int = 1, p: int = 2):
        self.n_neighbors = n_neighbors  # neighbors averaged per CF
        self.p = p  # Minkowski order (2 = Euclidean)

    def __call__(self, cf_explanations: Explanation) -> float:
        # Use the inputs the CFs explain (`Explanation.X`, which defaults to the
        # data module's test split).
        return _compute_manifold_dist(
            cf_explanations.X, cf_explanations.cfs, self.n_neighbors, self.p
        )

In [None]:
#| export
class Runtime(BaseEvalMetrics):
    """The `total_time` recorded on the explanation.

    This is in seconds when the explanation was produced by `generate_cf_explanations`.
    """
    def __call__(self, cf_explanations: Explanation) -> float:
        elapsed = cf_explanations.total_time
        return elapsed

In [None]:
#| hide
def pred_fn_test(x):
    """Toy model: shift inputs by 0.5 and clip the output to [0, 1]."""
    return jnp.clip(x + 0.5, 0., 1)

# Two fixtures of five 1-feature samples each: inputs, CFs and labels.
x_1 = jnp.array([[0.1], [0.34], [0.4], [-0.2], [0.7]])
cf_1 = jnp.array([[-0.2], [0.4], [-0.1], [0.7], [-0.1]])
y_1 = jnp.array([1., 0., 1., 1., 1.])
x_2 = jnp.array([[-0.5], [0.34], [0.4], [-0.2], [0.7]])
cf_2 = jnp.array([[0.2], [0.4], [-0.1], [0.7], [-0.1]])
y_2 = jnp.array([[0.], [1.], [1.], [1.], [1.]])  # 2-D labels exercise the reshape in `_compute_acc`

_acc_1, _acc_2 = [_compute_acc(x, y, pred_fn_test) for x, y in ((x_1, y_1), (x_2, y_2))]
_val_1, _val_2 = [_compute_val(x, cf, pred_fn_test) for x, cf in ((x_1, cf_1), (x_2, cf_2))]

# Expected values follow from rounding the outputs of `pred_fn_test` above.
assert jnp.isclose(_acc_1, 0.6) and jnp.isclose(_acc_2, 0.8)
assert jnp.isclose(_val_1, 0.8) and jnp.isclose(_val_2, 0.8)



In [None]:
#| export
def _create_second_order_cfs(cf_results: CFExplanationResults, threshold: float = 2.0):
    X, y = cf_results.data_module.test_dataset[:]
    cfs = cf_results.cfs
    scaler = cf_results.data_module.normalizer
    cat_idx = cf_results.data_module.cat_idx

    # get normalized threshold = threshold / (max - min)
    data_range = scaler.data_range_
    thredshold_normed = threshold / data_range

    # select continous features
    x_cont = X[:, :cat_idx]
    cf_cont = cfs[:, :cat_idx]
    # calculate the diff between x and c
    cont_diff = jnp.abs(x_cont - cf_cont) <= thredshold_normed
    # new cfs
    cfs_cont_hat = jnp.where(cont_diff, x_cont, cf_cont)

    cfs_hat = jnp.concatenate((cfs_cont_hat, cfs[:, cat_idx:]), axis=-1)
    return cfs_hat


def _compute_so_metric(
    cf_results: CFExplanationResults,
    threshold: float,
    metric: BaseEvalMetrics,
):
    """Evaluate `metric` on the second-order CFs of `cf_results`.

    `cf_results` itself is not modified. A shallow copy carrying the replaced
    `cfs` is evaluated, so the data module is shared rather than deep-copied.
    """
    from copy import copy as shallow_copy

    cfs_hat = _create_second_order_cfs(cf_results, threshold)
    cf_results_so = shallow_copy(cf_results)
    cf_results_so.cfs = cfs_hat
    return metric(cf_results_so)


def compute_so_validity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Validity of the second-order CFs built with `threshold`."""
    return _compute_so_metric(cf_results, threshold, Validity())


def compute_so_proximity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Proximity of the second-order CFs built with `threshold`."""
    return _compute_so_metric(cf_results, threshold, Proximity())


def compute_so_sparsity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Sparsity of the second-order CFs built with `threshold`."""
    return _compute_so_metric(cf_results, threshold, Sparsity())


## Benchmarking

In [None]:
#| export
# Metrics addressable by name in `evaluate_cfs`. "acc" and "accuracy" are aliases.
METRICS = {
    "acc": PredictiveAccuracy(),
    "accuracy": PredictiveAccuracy(),
    "validity": Validity(),
    "proximity": Proximity(),
    "runtime": Runtime(),
    "manifold_dist": ManifoldDist(),
}
# The second-order metrics (`compute_so_validity`, `compute_so_proximity`,
# `compute_so_sparsity`) are not registered by name. Pass them to `evaluate_cfs`
# directly as callables.

# Metrics computed by `evaluate_cfs` when `metrics=None`.
DEFAULT_METRICS = ["acc", "validity", "proximity"]

In [None]:
#| exporti
def _get_metric(metric: str | callable, cf_exp: Explanation):
    if isinstance(metric, str):
        try:
            res = METRICS[metric](cf_exp)
        except KeyError:
            raise ValueError(f"'{metric}' is not supported. Must be one of {METRICS.keys()}")
    elif callable(metric):
        res = metric(cf_exp)
    else:
        raise ValueError(f"{type(metric).__name__} is not supported as a metric.")
    return res

In [None]:
# | export
def evaluate_cfs(
    cf_exp: Explanation, # CF Explanations
    metrics: Iterable[str | callable] = None, # A list of Metrics. Can be `str` or a subclass of `BaseEvalMetrics`
    return_dict: bool = True, # return a dictionary or not (default: True)
    return_df: bool = False # return a pandas Dataframe or not (default: False)
):
    """Evaluate a CF explanation on a list of metrics.

    Results are keyed by `(data_module.data_name, cf_name)`. The return value is
    `(result_dict, result_df)` when both flags are set, `result_df` when only
    `return_df` is set, `result_dict` when only `return_dict` is set, and `None`
    when neither is set.
    """
    row_key = (cf_exp.data_module.data_name, cf_exp.cf_name)
    metric_list = DEFAULT_METRICS if metrics is None else metrics

    scores = {}
    for metric in metric_list:
        scores[metric] = _get_metric(metric, cf_exp)
    result_dict = {row_key: scores}
    result_df = pd.DataFrame.from_dict(result_dict, orient="index")

    if return_dict and return_df:
        return (result_dict, result_df)
    if return_df:
        return result_df
    if return_dict:
        return result_dict
    return None


In [None]:
# | export
def benchmark_cfs(
    cf_results_list: Iterable[CFExplanationResults],
    metrics: Optional[Iterable[str]] = None,
):
    """Evaluate several CF explanations and stack the results into one DataFrame.

    Each row is indexed by `(data_name, cf_name)`, as produced by `evaluate_cfs`.
    """
    # Materialize once: a one-shot iterator (e.g. a generator) would otherwise be
    # exhausted by the first `evaluate_cfs` call, leaving later rows without metrics.
    if metrics is not None:
        metrics = list(metrics)
    dfs = [
        evaluate_cfs(
            cf_exp=cf_results, metrics=metrics, return_dict=False, return_df=True
        )
        for cf_results in cf_results_list
    ]
    return pd.concat(dfs)

## How to evaluate a CF Explanation Module

In [None]:
from cfnet.module import PredictiveTrainingModule
from cfnet.utils import load_json

In [None]:
configs = load_json('assets/configs/data_configs/adult.json')
m_configs, data_configs = configs['mlp_configs'], configs['data_configs']
# Subsample the data so the demo trains quickly (`sample_frac` is presumably the kept fraction).
data_configs['sample_frac'] = 0.1

# Training configs for the predictive model.
t_configs = dict(
    n_epochs=10,
    monitor_metrics='val/val_loss',
    seed=42,
    batch_size=256,
)

We first train a predictive model whose predictions we want to explain.

In [None]:
training_module = PredictiveTrainingModule(m_configs)
dm = TabularDataModule(data_configs)

params, opt_state = train_model(training_module, dm, t_configs)

def pred_fn(x):
    """Predictions of the trained model in inference mode (fixed PRNG key)."""
    return training_module.forward(params, random.PRNGKey(0), x, is_training=False)

Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 25.01batch/s, train/train_loss_1=0.0618]


Now we can benchmark different CF explanation methods.

In [None]:
from cfnet.methods import VanillaCF, CounterNet

Generate CF explanations for `VanillaCF`

In [None]:
#| slow
vanillacf = VanillaCF({'n_steps': 1000, 'lr': 0.001})
vanillacf_res = generate_cf_explanations(cf_module=vanillacf, datamodule=dm, pred_fn=pred_fn)

100%|██████████| 1000/1000 [00:15<00:00, 65.06it/s]


Generate CF explanations for `CounterNet`

In [None]:
#| slow
#| output: false
counternet = CounterNet()
# CounterNet is trained inside `generate_cf_explanations`. With `pred_fn=None`,
# the explanation uses CounterNet's own `pred_fn`.
counternet_res = generate_cf_explanations(cf_module=counternet, datamodule=dm, pred_fn=None)

CounterNet contains parametric models. Starts training before generating explanations...


  "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."
Epoch 99: 100%|██████████| 20/20 [00:00<00:00, 63.62batch/s, train/train_loss_1=0.0319, train/train_loss_2=0.000102, train/train_loss_3=0.103]  


Note that `CounterNet` contains its own predictive module, so we set `pred_fn=None`; the explanation then uses `CounterNet`'s own `pred_fn`.

In [None]:
#| hide
#| slow
# `dataset_name` should default to the data module's name.
assert vanillacf_res.dataset_name == dm.data_name, (
    f"vanillacf_res.dataset_name={vanillacf_res.dataset_name}, but dm.data_name={dm.data_name}"
)
assert counternet_res.dataset_name == dm.data_name, (
    f"counternet_res.dataset_name={counternet_res.dataset_name}, but dm.data_name={dm.data_name}"
)

Now, we can compute metrics for benchmarking different CF explanation methods.

In [None]:
#| slow
evaluate_cfs(vanillacf_res, return_dict=False, return_df=True)

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,VanillaCF,0.807395,0.178725,8.370252


In [None]:
#| slow
evaluate_cfs(counternet_res, return_dict=False, return_df=True)

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,CounterNet,0.821029,0.999631,5.764046


In [None]:
#| slow
benchmark_cfs(cf_results_list=[vanillacf_res, counternet_res])