# Evaluate

> Evaluating and benchmarking the quality of counterfactual (CF) explanations.

In [None]:
# | include: false
# | default_exp evaluate


In [1]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc

In [2]:
# | export
from __future__ import annotations
from cfnet.import_essentials import *
from cfnet.train import train_model, TensorboardLogger
from cfnet.datasets import TabularDataModule
from cfnet.utils import accuracy, proximity
from cfnet.methods.base import BaseCFModule, ParametricCFModule
from cfnet.methods.counternet import CounterNet
from copy import deepcopy
from sklearn.neighbors import NearestNeighbors


In [9]:
#| export
#| hide
@dataclass
class Explanation:
    """Bundle of generated CF explanations plus what is needed to score them.

    After construction, when `data_module` is truthy, any of `dataset_name`,
    `X` and `y` still at their defaults are filled in from the data module
    (`data_name` and the two items of `test_dataset[:]` respectively).
    """
    cf_name: str  # cf method's name
    data_module: TabularDataModule  # data module
    cfs: jnp.DeviceArray  # generated cf explanation of `X`
    total_time: float  # total runtime
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]  # predict function
    dataset_name: str = ""  # dataset name
    X: jnp.ndarray = None  # input
    y: jnp.ndarray = None  # label

    def __post_init__(self):
        # Without a data module there is nothing to derive defaults from.
        if not self.data_module:
            return

        if self.dataset_name == "":
            self.dataset_name = self.data_module.data_name

        # Explicitly supplied inputs/labels take precedence over the test split.
        inputs, labels = self.data_module.test_dataset[:]
        if self.X is None:
            self.X = inputs
        if self.y is None:
            self.y = labels

# Backward-compatible name for `Explanation`.
CFExplanationResults = Explanation

In [10]:
show_doc(Explanation)

---

### Explanation

>      Explanation (cf_name:str, data_module:TabularDataModule,
>                   cfs:jnp.DeviceArray, total_time:float,
>                   pred_fn:Callable[[jnp.DeviceArray],jnp.DeviceArray],
>                   dataset_name:str='', X:jnp.ndarray=None, y:jnp.ndarray=None)

Generated CF Explanations

Arguments to `Explanation`:

* `cf_name`: cf method's name
* `dataset_name`: dataset name
* `X`: input
* `y`: label
* `cfs`: generated cf explanation of `X`
* `total_time`: total runtime
* `pred_fn`: predict function
* `data_module`: data module


## Generating CF Explanation Results


In [None]:
#| exporti
def _prepare_module(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule
):
    cf_module.update_cat_info(datamodule)
    return cf_module

def _train_parametric_module(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule,
    t_configs=None
):
    print(f'{type(cf_module)} contains parametric models. '
        'Starts training before generating explanations...')
    cf_module.train(datamodule, t_configs)
    return cf_module

def _check_pred_fn(pred_fn, cf_module):
    if pred_fn is None:
        try:
            pred_fn = cf_module.pred_fn
        except AttributeError:
            raise AttributeError(
                    "`generate_cf_explanations` is incorrectly configured."
                    f"It is supposed to be `pred_fn != None`,"
                    f"or {type(cf_module)} has attribute `pred_fn`."
                    f"However, we got `pred_fn={pred_fn}`, "
                    f"and cf_module=`{type(cf_module)}` contains no `pred_fn`."
            )
    return pred_fn

In [None]:
#| export
def generate_cf_explanations(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule,
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None,
    *,
    t_configs=None
) -> Explanation:
    """Generate CF explanations for the test split of `datamodule`.

    The module is first prepared with the data module; if it is a
    `ParametricCFModule` it is trained (using `t_configs`) before generation.
    `total_time` in the result covers only the `generate_cfs` call. The
    predict function is resolved after generation: the given `pred_fn` if not
    None, otherwise `cf_module.pred_fn`.
    """
    cf_module = _prepare_module(cf_module, datamodule)

    needs_training = isinstance(cf_module, ParametricCFModule)
    if needs_training:
        cf_module = _train_parametric_module(cf_module, datamodule, t_configs=t_configs)

    test_inputs, _ = datamodule.test_dataset[:]

    # time the generation step only
    started_at = time.time()
    generated = cf_module.generate_cfs(test_inputs, pred_fn=pred_fn)
    elapsed = time.time() - started_at

    # fall back to the module's own predictor when none was passed in
    resolved_pred_fn = _check_pred_fn(pred_fn, cf_module)

    return Explanation(
        cf_name=cf_module.name,
        data_module=datamodule,
        cfs=generated,
        total_time=elapsed,
        pred_fn=resolved_pred_fn,
    )

# def generate_cf_results(
#     cf_module: BaseCFExplanationModule,
#     dm: TabularDataModule,
#     pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None,
#     params: hk.Params = None,  # params of `cf_module`
#     rng_key: Optional[random.PRNGKey] = None,
# ) -> CFExplanationResults:
#     # validate arguments
#     if (pred_fn is None) and (params is None) and (rng_key is None):
#         raise ValueError(
#             "A valid `pred_fn: Callable[jnp.DeviceArray], jnp.DeviceArray]` or `params: hk.Params` needs to be passed."
#         )
#     # prepare
#     X, y = dm.test_dataset[:]
#     cf_module.update_cat_info(dm)
#     # generate cfs
#     current_time = time.time()
#     if pred_fn:
#         cfs = cf_module.generate_cfs(X, pred_fn)
#     else:
#         cfs = cf_module.generate_cfs(X, params, rng_key)
#         pred_fn = lambda x: cf_module.predict(deepcopy(params), rng_key, x)
#     total_time = time.time() - current_time

#     return CFExplanationResults(
#         cf_name=cf_module.name,
#         data_module=dm,
#         cfs=cfs,
#         total_time=total_time,
#         pred_fn=pred_fn,
#     )
    # return CFExplanationResults(
    #     X=X, y=y, cfs=cfs, total_time=total_time,
    #     pred_fn=pred_fn,
    #     cf_name=cf_module.name, dataset_name=dm.data_name
    # )



In [None]:
# | export
@deprecated(removed_in='0.1.0', deprecated_in='0.0.9')
def generate_cf_results_local_exp(
    cf_module: BaseCFModule,
    dm: TabularDataModule,
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray],
) -> CFExplanationResults:
    """Deprecated: forwards to `generate_cf_explanations` with the given `pred_fn`."""
    explanation = generate_cf_explanations(cf_module, dm, pred_fn=pred_fn)
    return explanation


@deprecated(removed_in='0.1.0', deprecated_in='0.0.9')
def generate_cf_results_cfnet(
    cf_module: CounterNet,
    dm: TabularDataModule,
    params: hk.Params = None,  # params of `cf_module`
    rng_key: random.PRNGKey = None,
) -> CFExplanationResults:
    """Deprecated: forwards to `generate_cf_explanations` with `pred_fn=None`.

    `params` and `rng_key` are accepted for signature compatibility only; they
    are not passed on.
    """
    explanation = generate_cf_explanations(cf_module, dm, pred_fn=None)
    return explanation


## Evaluating Metrics

In [None]:
# | export
def compute_predictive_acc(cf_results: CFExplanationResults):
    """Score `pred_fn` on the data module's test split with `accuracy`.

    Predictions are rounded and both sides reshaped to a single column before
    comparison; the result is returned via `.item()`.
    """
    inputs, labels = cf_results.data_module.test_dataset[:]
    predict = cf_results.pred_fn

    pred_column = predict(inputs).reshape(-1, 1)
    label_column = labels.reshape(-1, 1)
    score = accuracy(jnp.round(pred_column), label_column)
    return score.item()


def compute_validity(cf_results: CFExplanationResults):
    """Score how often the counterfactuals reach the flipped prediction.

    The target is `1 - round(pred_fn(X))` on the test inputs; it is compared,
    via `accuracy`, against the rounded prediction on `cf_results.cfs`.
    """
    inputs, _ = cf_results.data_module.test_dataset[:]
    predict = cf_results.pred_fn

    original_pred = predict(inputs).reshape(-1, 1).round()
    flipped_target = 1 - original_pred
    cf_pred = predict(cf_results.cfs).reshape(-1, 1).round()
    score = accuracy(flipped_target, cf_pred)
    return score.item()


def compute_proximity(cf_results: CFExplanationResults):
    """Apply `proximity` to the test inputs and the counterfactuals; return `.item()`."""
    inputs, _ = cf_results.data_module.test_dataset[:]
    distance = proximity(inputs, cf_results.cfs)
    return distance.item()


def compute_sparsity(cf_results: CFExplanationResults):
    """Sparsity of the counterfactuals relative to the test inputs.

    Columns before `data_module.cat_idx` contribute the mean per-row L0 norm
    of the absolute difference; columns from `cat_idx` onward contribute half
    of their `proximity`. The sum is returned as-is (no `.item()` is applied).
    """
    inputs, _ = cf_results.data_module.test_dataset[:]
    counterfactuals = cf_results.cfs
    split = cf_results.data_module.cat_idx

    # columns from `split` onward: half of the proximity between the two sides
    cat_part = proximity(inputs[:, split:], counterfactuals[:, split:]) / 2

    # columns before `split`: number of changed entries per row, averaged
    cont_delta = jnp.abs(inputs[:, :split] - counterfactuals[:, :split])
    cont_part = jnp.linalg.norm(cont_delta, ord=0, axis=1).mean()

    return cont_part + cat_part


def compute_manifold_dist(cf_results: CFExplanationResults):
    """Mean distance from each counterfactual to its nearest test input.

    A `NearestNeighbors` index is fitted on the test inputs and queried with
    one neighbour per counterfactual; the mean distance is returned via `.item()`.
    """
    inputs, _ = cf_results.data_module.test_dataset[:]
    index = NearestNeighbors()
    index.fit(inputs)
    distances, _ = index.kneighbors(cf_results.cfs, 1, return_distance=True)
    return jnp.mean(distances).item()


def get_runtime(cf_results: CFExplanationResults):
    """Return the `total_time` recorded on `cf_results`."""
    runtime = cf_results.total_time
    return runtime


def _create_second_order_cfs(cf_results: CFExplanationResults, threshold: float = 2.0):
    X, y = cf_results.data_module.test_dataset[:]
    cfs = cf_results.cfs
    scaler = cf_results.data_module.normalizer
    cat_idx = cf_results.data_module.cat_idx

    # get normalized threshold = threshold / (max - min)
    data_range = scaler.data_range_
    thredshold_normed = threshold / data_range

    # select continous features
    x_cont = X[:, :cat_idx]
    cf_cont = cfs[:, :cat_idx]
    # calculate the diff between x and c
    cont_diff = jnp.abs(x_cont - cf_cont) <= thredshold_normed
    # new cfs
    cfs_cont_hat = jnp.where(cont_diff, x_cont, cf_cont)

    cfs_hat = jnp.concatenate((cfs_cont_hat, cfs[:, cat_idx:]), axis=-1)
    return cfs_hat


def compute_so_validity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Validity of the second-order counterfactuals derived from `cf_results`.

    The second-order counterfactuals come from `_create_second_order_cfs` with
    the given `threshold` and are scored with `compute_validity`. `cf_results`
    itself is not modified.
    """
    from copy import copy  # local import: the module only imports `deepcopy`

    cfs_hat = _create_second_order_cfs(cf_results, threshold)
    # A shallow copy is enough: only `cfs` is rebound on the copy, so there is
    # no need to deep-copy the data module and everything else it holds.
    cf_results_so = copy(cf_results)
    cf_results_so.cfs = cfs_hat
    return compute_validity(cf_results_so)


def compute_so_proximity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Proximity of the second-order counterfactuals derived from `cf_results`.

    The second-order counterfactuals come from `_create_second_order_cfs` with
    the given `threshold` and are scored with `compute_proximity`. `cf_results`
    itself is not modified.
    """
    from copy import copy  # local import: the module only imports `deepcopy`

    cfs_hat = _create_second_order_cfs(cf_results, threshold)
    # A shallow copy is enough: only `cfs` is rebound on the copy, so there is
    # no need to deep-copy the data module and everything else it holds.
    cf_results_so = copy(cf_results)
    cf_results_so.cfs = cfs_hat
    return compute_proximity(cf_results_so)


def compute_so_sparsity(cf_results: CFExplanationResults, threshold: float = 2.0):
    """Sparsity of the second-order counterfactuals derived from `cf_results`.

    The second-order counterfactuals come from `_create_second_order_cfs` with
    the given `threshold` and are scored with `compute_sparsity`. `cf_results`
    itself is not modified.
    """
    from copy import copy  # local import: the module only imports `deepcopy`

    cfs_hat = _create_second_order_cfs(cf_results, threshold)
    # A shallow copy is enough: only `cfs` is rebound on the copy, so there is
    # no need to deep-copy the data module and everything else it holds.
    cf_results_so = copy(cf_results)
    cf_results_so.cfs = cfs_hat
    return compute_sparsity(cf_results_so)


In [None]:
# | export
# Registry mapping the metric names accepted by `evaluate_cfs` to the function
# that computes them. Each function takes the explanation results object.
metrics2fn = {
    "acc": compute_predictive_acc,
    "validity": compute_validity,
    "proximity": compute_proximity,
    # registered alongside its second-order counterpart `so_sparsity`
    "sparsity": compute_sparsity,
    "runtime": get_runtime,
    "manifold_dist": compute_manifold_dist,
    "so_validity": compute_so_validity,
    "so_proximity": compute_so_proximity,
    "so_sparsity": compute_so_sparsity,
}


In [None]:
# | export
# Metrics computed when `evaluate_cfs` is called without an explicit list.
DEFAULT_METRICS = ["acc", "validity", "proximity"]


def evaluate_cfs(
    cf_results: CFExplanationResults,
    metrics: Optional[Iterable[str]] = None,
    return_dict: bool = True,
    return_df: bool = False,
):
    """Compute the requested metrics for one set of CF explanations.

    Each name in `metrics` (or `DEFAULT_METRICS` when `metrics` is None) is
    looked up in `metrics2fn` and applied to `cf_results`. Results are keyed
    by `(data_name, cf_name)`.

    Returns a `(dict, DataFrame)` tuple when both flags are set, the single
    requested form when only one is set, and None when neither is set.
    An unknown metric name raises `KeyError` from the registry lookup.
    """
    row_key = (cf_results.data_module.data_name, cf_results.cf_name)
    selected = DEFAULT_METRICS if metrics is None else metrics

    scores = dict()
    for name in selected:
        scores[name] = metrics2fn[name](cf_results)

    result_dict = {row_key: scores}
    result_df = pd.DataFrame.from_dict(result_dict, orient="index")

    if return_dict and return_df:
        return (result_dict, result_df)
    if return_df:
        return result_df
    if return_dict:
        return result_dict
    return None


In [None]:
# | export
def benchmark_cfs(
    cf_results_list: Iterable[CFExplanationResults],
    metrics: Optional[Iterable[str]] = None,
):
    """Evaluate several CF results and stack the per-result DataFrames.

    Each item is passed to `evaluate_cfs` with `return_df=True` only, and the
    resulting frames are combined with `pd.concat`.
    """
    frames = []
    for cf_results in cf_results_list:
        frame = evaluate_cfs(
            cf_results=cf_results, metrics=metrics, return_dict=False, return_df=True
        )
        frames.append(frame)
    return pd.concat(frames)

## Test

### VanillaCF

In [None]:
# NOTE(review): `data_configs` and `m_configs` are both rebound from
# `assets/configs/data_configs/adult.json` in a later cell, before either is
# used, so these inline values do not reach the data or training modules.
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 256,
    "sample_frac": 0.1,
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": [
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "race",
        "gender",
    ],
}
m_configs = {"sizes": [50, 10, 50], "lr": 0.003, "dropout_rate": 0.3}


In [None]:
from cfnet.training_module import PredictiveTrainingModule
from cfnet.utils import load_json


In [None]:

# Load the model (`mlp_configs`) and data (`data_configs`) settings from the adult JSON config.
configs = load_json('assets/configs/data_configs/adult.json')
m_configs = configs['mlp_configs']
data_configs = configs['data_configs']

# Training settings for this test run (2 epochs, fixed seed).
t_configs = {
    'n_epochs': 2,
    'monitor_metrics': 'val/val_loss',
    'seed': 42,
    "batch_size": 256
} 

training_module = PredictiveTrainingModule(m_configs)
dm = TabularDataModule(data_configs)

params, opt_state = train_model(
    training_module, 
    dm, 
    t_configs
)
# Expose the trained model as `pred_fn(x)`: a forward pass with `is_training=False`
# and a fixed PRNG key, using the `params` returned by `train_model`.
pred_fn = lambda x: training_module.forward(params, random.PRNGKey(0), x, is_training=False)

In [None]:
from cfnet.methods import VanillaCF


In [None]:
# Generate counterfactuals through the deprecated `generate_cf_results_local_exp`
# wrapper, which forwards to `generate_cf_explanations`.
cf_exp = VanillaCF({"n_steps": 1000})
cf_res = generate_cf_results_local_exp(cf_exp, dm, pred_fn)
# `dataset_name` is not passed explicitly, so it must have been filled in from
# the data module by `Explanation.__post_init__`.
assert cf_res.dataset_name == dm.data_name, (
    f"cf_res.dataset_name={cf_res.dataset_name}," f"but dm.data_name={dm.data_name}"
)


100%|██████████| 1000/1000 [00:06<00:00, 153.05it/s]


In [None]:
# Validity of the VanillaCF counterfactuals generated in the previous cell.
compute_validity(cf_res)


0.7997788786888123

In [None]:
# `return_dict` defaults to True, so with `return_df=True` the call returns
# `(dict, DataFrame)`; `[1]` selects the DataFrame for display.
evaluate_cfs(cf_res, return_df=True)[1]


Unnamed: 0,acc,validity,proximity
VanillaCF,0.822012,0.799779,6.929937


### CounterNet

In [None]:
# NOTE(review): `m_configs` and `t_configs` are both rebound in a later cell
# (from `adult.json` and a new literal) before `CounterNetTrainingModule` and
# `train_model` are called, so these inline values are not the ones used.
m_configs = {
    "enc_sizes": [50, 10],
    "dec_sizes": [10],
    "exp_sizes": [50, 50],
    "dropout_rate": 0.3,
    "lr": 0.003,
    "lambda_1": 1.0,
    "lambda_3": 0.1,
    "lambda_2": 0.2,
}
t_configs = {
    "n_epochs": 1,
    "monitor_metrics": "val/val_loss",
    "seed": 42,
    "batch_size": 256,
}


In [None]:
from cfnet.training_module import CounterNetTrainingModule


In [None]:

# Load the CounterNet (`cfnet_configs`) and data (`data_configs`) settings from the adult JSON config.
configs = load_json("assets/configs/data_configs/adult.json")
m_configs = configs["cfnet_configs"]
data_configs = configs["data_configs"]

# Training settings for this test run (2 epochs, fixed seed).
t_configs = {
    "n_epochs": 2,
    "monitor_metrics": "val/val_loss",
    "seed": 42,
    "batch_size": 256,
}

cfnet = CounterNetTrainingModule(m_configs)
# Rebinds `dm` from the VanillaCF section with a fresh data module.
dm = TabularDataModule(data_configs)

params, opt_state = train_model(cfnet, dm, t_configs)


Epoch 99: 100%|██████████| 10/10 [00:00<00:00, 64.40batch/s, train/train_loss_1=0.0485, train/train_loss_2=0.000302, train/train_loss_3=0.105]


In [None]:
# Uses the deprecated `generate_cf_results_cfnet` wrapper. It accepts `params`
# and a PRNG key but does not forward them; it calls
# `generate_cf_explanations(cf_module, dm, pred_fn=None)`.
cf_res = generate_cf_results_cfnet(cfnet, dm, params, random.PRNGKey(0))
# `dataset_name` should have been filled in from the data module.
assert cf_res.dataset_name == dm.data_name, (
    f"cf_res.dataset_name={cf_res.dataset_name}," f"but dm.data_name={dm.data_name}"
)


In [None]:
# `return_dict` defaults to True, so with `return_df=True` the call returns
# `(dict, DataFrame)`; `[1]` selects the DataFrame for display.
evaluate_cfs(cf_res, return_df=True)[1]


Unnamed: 0,acc,validity,proximity
CounterNet,0.81931,0.997912,6.497492
