# Evaluate

> Evaluating and benchmarking the quality of counterfactual (CF) explanations.

In [None]:
# hide
# default_exp evaluate

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# export
from cfnet.import_essentials import *
from cfnet.train import train_model, TensorboardLogger
from cfnet.datasets import TabularDataModule
from cfnet.interfaces import BaseCFExplanationModule, LocalCFExplanationModule
from copy import deepcopy

In [None]:
# export
@dataclass
class CFExplanationResults:
    """Bundle of everything produced by one CF method on one dataset.

    Instances are created by `generate_cf_results` and consumed by the
    metric functions in this module (e.g. `compute_validity`).
    """
    cf_name: str        # name of the CF method (taken from `cf_module.name`)
    dataset_name: str   # name of the dataset (taken from `dm.data_name`)
    X: jnp.DeviceArray  # inputs the explanations were generated for
    y: jnp.DeviceArray  # labels belonging to `X`
    cfs: jnp.DeviceArray # generated cf explanation of `X`
    total_time: float   # elapsed time of the generation step (seconds)
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] # predict function used by the metrics

## Generating CF Explanation Results


In [None]:
# export
def generate_cf_results(
    cf_module: BaseCFExplanationModule,
    dm: TabularDataModule,
    pred_fn: Optional[Callable[[jnp.DeviceArray], jnp.DeviceArray]] = None,
    params: Optional[hk.Params] = None, # params of `cf_module`
    rng_key: Optional[random.PRNGKey] = None 
) -> CFExplanationResults:
    """Generate CF explanations for the test split of `dm` and time the run.

    Two calling modes are supported:

    * `pred_fn` is given: `cf_module.generate_cfs(X, pred_fn)` is called and
      `pred_fn` is stored unchanged in the result.
    * `pred_fn` is omitted: `cf_module.generate_cfs(X, params, rng_key)` is
      called and the stored predict function wraps `cf_module.predict`.

    Args:
        cf_module: the CF explanation method to run.
        dm: data module; `dm.test_dataset[:]` supplies `(X, y)`.
        pred_fn: predict function handed to `generate_cfs` (first mode).
        params: parameters of `cf_module` (second mode).
        rng_key: forwarded unchanged to `generate_cfs` / `predict` in the
            second mode.

    Returns:
        A `CFExplanationResults` holding inputs, labels, CFs, the elapsed
        time of the `generate_cfs` call and the predict function.

    Raises:
        ValueError: if neither `pred_fn` nor `params` is provided.
    """
    # validate arguments: at least one of the two calling modes must be usable.
    # `rng_key` alone is not sufficient, so it is deliberately not part of
    # this check.
    if pred_fn is None and params is None:
        raise ValueError("A valid `pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]` or `params: hk.Params` needs to be passed.")

    X, y = dm.test_dataset[:]
    use_pred_fn = pred_fn is not None

    # `perf_counter` is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    # NOTE(review): if `generate_cfs` returns asynchronously dispatched JAX
    # arrays the measured time may under-report -- confirm whether a
    # `block_until_ready()` is needed here.
    start_time = time.perf_counter()
    if use_pred_fn:
        cfs = cf_module.generate_cfs(X, pred_fn)
    else:
        cfs = cf_module.generate_cfs(X, params, rng_key)
    total_time = time.perf_counter() - start_time

    if not use_pred_fn:
        # Build a predict function from the module itself so the metric
        # functions can treat both calling modes identically.
        pred_fn = lambda x: cf_module.predict(deepcopy(params), rng_key, x)

    return CFExplanationResults(
        X=X, y=y, cfs=cfs, total_time=total_time,
        pred_fn=pred_fn,
        cf_name=cf_module.name, dataset_name=dm.data_name
    )

In [None]:
# export
def generate_cf_results_local_exp(
    cf_module: LocalCFExplanationModule, 
    dm: TabularDataModule, 
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]
) -> CFExplanationResults:
    """Run `generate_cf_results` in its `pred_fn` mode.

    Convenience entry point for CF methods that are driven by an external
    predict function; `params` and `rng_key` are left at their defaults.
    """
    results = generate_cf_results(cf_module=cf_module, dm=dm, pred_fn=pred_fn)
    return results

def generate_cf_results_cfnet(
    cf_module: BaseCFExplanationModule, 
    dm: TabularDataModule, 
    params: Optional[hk.Params] = None, # params of `cf_module`
    rng_key: Optional[random.PRNGKey] = None 
) -> CFExplanationResults:
    """Run `generate_cf_results` in its `params` / `rng_key` mode.

    No `pred_fn` is passed, so `generate_cf_results` derives the predict
    function from `cf_module` itself. The `cf_module` annotation matches
    the one on `generate_cf_results`, which is where the argument ends up.

    Args:
        cf_module: the CF explanation method to run.
        dm: data module providing the test split.
        params: parameters of `cf_module`.
        rng_key: random key forwarded to `generate_cf_results`.

    Returns:
        The `CFExplanationResults` produced by `generate_cf_results`.
    """
    return generate_cf_results(cf_module, dm, params=params, rng_key=rng_key)

## Evaluating Metrics

In [None]:
# export
def _compute_acc(x: jnp.ndarray, y: jnp.ndarray):
    """Return the share of entries where `x` equals `y`.

    The number of element-wise matches is divided by `len(x)`, i.e. the
    size of the first axis of `x`.
    """
    n_matches = jnp.sum(x == y)
    n_rows = len(x)
    return n_matches / n_rows

In [None]:
# export
def compute_predictive_acc(cf_results: CFExplanationResults):
    """Accuracy of the stored predict function on the original inputs.

    Predictions for `cf_results.X` are reshaped to one column, rounded, and
    compared with `cf_results.y` (also reshaped to one column). The result
    is returned as a Python scalar via `.item()`.
    """
    raw_pred = cf_results.pred_fn(cf_results.X)
    rounded_pred = jnp.round(raw_pred.reshape(-1, 1))
    targets = cf_results.y.reshape(-1, 1)
    accuracy = _compute_acc(rounded_pred, targets)
    return accuracy.item()

def compute_validity(cf_results: CFExplanationResults):
    """Share of counterfactuals whose rounded prediction is flipped.

    The rounded prediction on each original input is inverted (`1 - y`) and
    compared with the rounded prediction on the matching counterfactual.
    The result is returned as a Python scalar via `.item()`.
    """
    predict = cf_results.pred_fn

    def _rounded_column(inputs):
        # one-column, rounded predictions for `inputs`
        return predict(inputs).reshape(-1, 1).round()

    flipped_label = 1 - _rounded_column(cf_results.X)
    cf_label = _rounded_column(cf_results.cfs)
    return _compute_acc(flipped_label, cf_label).item()

def compute_proximity(cf_results: CFExplanationResults):
    """Mean L1 distance between the inputs and their counterfactuals.

    Absolute differences are summed along axis 1 and then averaged over the
    remaining entries; the result is returned as a Python scalar.
    """
    abs_diff = jnp.abs(cf_results.X - cf_results.cfs)
    per_row_distance = abs_diff.sum(axis=1)
    return per_row_distance.mean().item()

def get_runtime(cf_results: CFExplanationResults):
    """Return the `total_time` recorded on `cf_results`."""
    elapsed = cf_results.total_time
    return elapsed

In [None]:
# export
# Registry mapping each metric name accepted by `evaluate_cfs` to the
# function that computes it from a `CFExplanationResults`.
metrics2fn = dict(
    acc=compute_predictive_acc,
    validity=compute_validity,
    proximity=compute_proximity,
    runtime=get_runtime,
)

In [None]:
# export
# Metric names used by `evaluate_cfs` when no `metrics` argument is given.
# "runtime" is registered in `metrics2fn` but is not part of the defaults.
DEFAULT_METRICS = ['acc', 'validity', 'proximity']

def evaluate_cfs(cf_results: CFExplanationResults,
                 metrics: Optional[List[str]] = None,
                 return_dict: bool = True,
                 return_df: bool = False):
    """Compute the requested metrics for one `CFExplanationResults`.

    Args:
        cf_results: results of a single CF method on a single dataset.
        metrics: names of the metrics to compute (keys of `metrics2fn`);
            `DEFAULT_METRICS` is used when `None`.
        return_dict: include the nested dict `{cf_name: {metric: value}}`.
        return_df: include a one-row `pd.DataFrame` indexed by `cf_name`.

    Returns:
        * `(result_dict, result_df)` when both flags are true,
        * the dict or the DataFrame when exactly one flag is true,
        * `None` when both flags are false.

    Raises:
        KeyError: if `metrics` contains a name that is not in `metrics2fn`.
    """
    # Materialise once so any iterable can be validated and then evaluated.
    metric_names = list(DEFAULT_METRICS if metrics is None else metrics)

    # Fail fast: report every unknown name before running any metric, since
    # some metrics call the predict function and may be slow.
    unknown = [name for name in metric_names if name not in metrics2fn]
    if unknown:
        raise KeyError(
            f"Unknown metric(s) {unknown}; available metrics: {list(metrics2fn)}"
        )

    cf_name = cf_results.cf_name
    result_dict = {
        cf_name: {name: metrics2fn[name](cf_results) for name in metric_names}
    }

    if not return_df:
        # The DataFrame is never needed on this path, so skip building it.
        return result_dict if return_dict else None

    result_df = pd.DataFrame.from_dict(result_dict, orient='index')
    if return_dict:
        return (result_dict, result_df)
    return result_df

In [None]:
# export
def benchmark_cfs(cf_results_list: Iterable[CFExplanationResults],
                  metrics: Optional[List[str]] = None):
    """Evaluate several CF results and stack them into one DataFrame.

    Args:
        cf_results_list: iterable of `CFExplanationResults` to compare.
        metrics: metric names forwarded to `evaluate_cfs` for every entry.

    Returns:
        The per-method DataFrames from `evaluate_cfs` concatenated row-wise,
        or an empty `pd.DataFrame` when `cf_results_list` yields nothing.
    """
    dfs = [
        evaluate_cfs(cf_results=cf_results, metrics=metrics, return_dict=False, return_df=True)
            for cf_results in cf_results_list
    ]
    if not dfs:
        # `pd.concat` raises ValueError on an empty sequence; an empty
        # benchmark is better represented by an empty frame.
        return pd.DataFrame()
    return pd.concat(dfs)

## Test

### VanillaCF

In [None]:
# Settings passed to `TabularDataModule` in the following cells.
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 256,
    # presumably the fraction of rows to sample -- confirm in TabularDataModule
    'sample_frac': 0.1,
    "continous_cols": [
        "age",
        "hours_per_week"
    ],
    "discret_cols": [
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "race",
        "gender"
    ],
}
# Settings passed to `PredictiveTrainingModule` in the next cell.
m_configs = {
    "sizes": [50, 10, 50],
    'lr': 0.003,
    "dropout_rate": 0.3
}
# Settings passed to `train_model` in the next cell.
t_configs = {
    'n_epochs': 20,
    'monitor_metrics': 'val/val_loss'
}

In [None]:
from cfnet.training_module import PredictiveTrainingModule

# Build the predictive model and the data module from the configs above.
training_module = PredictiveTrainingModule(m_configs)
dm = TabularDataModule(data_configs)

# Train the model; `params` are reused by `pred_fn` below.
params, opt_state = train_model(
    training_module, 
    dm, 
    t_configs
)
# Predict function wrapping the trained model's forward pass with
# `is_training=False`; a fixed PRNGKey(0) is created on every call.
pred_fn = lambda x: training_module.forward(params, random.PRNGKey(0), x, is_training=False)


Epoch 19: 100%|██████████| 10/10 [00:00<00:00, 62.79batch/s, train/train_loss_1=0.0653]


In [None]:
# NOTE(review): `VanillaCF` is not imported in any cell shown here --
# confirm it is in scope before running this cell.
cf_exp = VanillaCF({ "n_steps": 1000, 'pred_fn': pred_fn })
# `generate_cf_results_local_exp` declares `pred_fn` without a default, so
# it has to be passed explicitly.
cf_res = generate_cf_results_local_exp(cf_exp, dm, pred_fn=pred_fn)

100%|██████████| 1000/1000 [00:06<00:00, 153.05it/s]


In [None]:
# Validity of the results held in `cf_res`.
compute_validity(cf_res)

0.7997788786888123

In [None]:
# `return_dict` defaults to True, so with `return_df=True` a
# `(dict, DataFrame)` tuple comes back; `[1]` selects the DataFrame.
evaluate_cfs(
    cf_res, return_df=True
)[1]

Unnamed: 0,acc,validity,proximity
VanillaCF,0.822012,0.799779,6.929937


### CounterNet

In [None]:
# Settings passed to `CounterNetTrainingModule` in the next code cell.
m_configs = {
    "enc_sizes": [50,10],
    "dec_sizes": [10],
    "exp_sizes": [50, 50],
    "dropout_rate": 0.3,    
    'lr': 0.003,
    "lambda_1": 1.0,
    "lambda_3": 0.1,
    "lambda_2": 0.2,
}
# Settings passed to `train_model` in the next code cell.
t_configs = {
    'n_epochs': 100,
    'monitor_metrics': 'val/val_loss'
}


In [None]:
from cfnet.training_module import CounterNetTrainingModule

# Build the CounterNet module and a fresh data module from `data_configs`.
cfnet = CounterNetTrainingModule(m_configs)
dm = TabularDataModule(data_configs)

# Train CounterNet; the returned `params` are used to generate CFs below.
params, opt_state = train_model(
    cfnet, dm, 
    t_configs
)



Epoch 99: 100%|██████████| 10/10 [00:00<00:00, 64.40batch/s, train/train_loss_1=0.0485, train/train_loss_2=0.000302, train/train_loss_3=0.105]


In [None]:
# Pass `params` and `rng_key` by keyword: the third positional parameter of
# `generate_cf_results` is `pred_fn`, so positional passing would bind
# `params` to `pred_fn` and the PRNG key to `params`.
cf_res = generate_cf_results(cfnet, dm, params=params, rng_key=random.PRNGKey(0))

In [None]:
# `return_dict` defaults to True, so with `return_df=True` a
# `(dict, DataFrame)` tuple comes back; `[1]` selects the DataFrame.
evaluate_cfs(
    cf_res, return_df=True
)[1]

Unnamed: 0,acc,validity,proximity
CounterNet,0.81931,0.997912,6.497492
