In [None]:
#| default_exp experiment

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from cfnet.import_essentials import *
from cfnet.datasets import TabularDataModule
from cfnet.train import train_model, train_model_with_states
from cfnet.evaluate import evaluate_cfs, benchmark_cfs, generate_cf_results, generate_cf_results_local_exp, generate_cf_results_cfnet, CFExplanationResults
from cfnet.training_module import BaseTrainingModule, CounterNetTrainingModule
from cfnet.interfaces import BaseCFExplanationModule, LocalCFExplanationModule
from cfnet.utils import accuracy
from rocourset_net.training_module import RoCourseNetTrainingModule
from copy import deepcopy
from functools import partial
from sklearn.preprocessing import StandardScaler,MinMaxScaler,OneHotEncoder
import wandb
from pydantic import validator

### Evaluation Metrics

In [None]:
#| export
def compute_rob_validity(cf_results: CFExplanationResults, shifted_pred_fn: Callable):
    """Score stored counterfactuals against a (possibly different) predictor.

    The original predictor (`cf_results.pred_fn`) is applied to `cf_results.X`
    and rounded; the counterfactual target is the flipped label `1 - y_pred`.
    `shifted_pred_fn` is then applied to `cf_results.cfs`, rounded, and compared
    with that target through `accuracy`.

    Args:
        cf_results: results object providing `pred_fn`, `X` and `cfs`.
        shifted_pred_fn: predictor used to re-score the counterfactuals.

    Returns:
        The value of `accuracy(target, shifted_cf_labels)` converted with `.item()`.
    """
    def _hard_labels(predict, inputs):
        # column vector of rounded predictions
        return predict(inputs).reshape(-1, 1).round()

    original_labels = _hard_labels(cf_results.pred_fn, cf_results.X)
    target_labels = 1. - original_labels
    shifted_cf_labels = _hard_labels(shifted_pred_fn, cf_results.cfs)

    return accuracy(target_labels, shifted_cf_labels).item()

### Util Methods 

In [None]:
#| export
def _aggregate_default_data_encoders(default_data_config: Dict[str, Any], data_dir_list: List[str]):
    # data encoding
    data = pd.concat(
        [pd.read_csv(data_dir) for data_dir in data_dir_list]
    )
    print(f"total data length: {len(data)}")
    if len(default_data_config['continous_cols']) != 0:
        print("preprocessing continuous features...")
        normalizer = MinMaxScaler().fit(
            data[default_data_config['continous_cols']]
        )
        default_data_config.update({"normalizer": normalizer})

    if len(default_data_config['discret_cols']) != 0:
        print("preprocessing discret features...")
        encoder = OneHotEncoder(sparse=False).fit(
            data[default_data_config['discret_cols']]
        )
        default_data_config.update({"encoder": encoder})
    return default_data_config

In [None]:
#| export
def calculate_validity_matrix(
    cf_results_list: Iterable[CFExplanationResults]
) -> pd.DataFrame:
    """Cross-evaluate every set of counterfactuals against every predictor.

    For each pair `(i, j)` the counterfactuals stored in result `i` are scored
    with the `pred_fn` of result `j` via `compute_rob_validity`.

    Args:
        cf_results_list: results to cross-evaluate. Any iterable is accepted;
            it is materialized once because it is traversed repeatedly.

    Returns:
        A DataFrame whose columns are the `dataset_name` of result `i` and
        whose index is the `cf_name` of result `j`.
    """
    # The nested traversal below would drain a one-shot iterator on the first
    # inner pass, so take a concrete snapshot up front.
    cf_results_list = list(cf_results_list)

    validity_matrix_dict = {}
    for cf_results_i in cf_results_list:
        validity_matrix_dict[cf_results_i.dataset_name] = {
            cf_results_j.cf_name: compute_rob_validity(cf_results_i, cf_results_j.pred_fn)
            for cf_results_j in cf_results_list
        }
    return pd.DataFrame.from_dict(validity_matrix_dict)

### Training Models

In [None]:
#| export
def train_models(
    training_module: BaseTrainingModule,
    default_data_configs: Dict[str, Any],
    data_dir_list: List[str],
    t_configs: Dict[str, Any],
    return_data_module_list: bool = False
):
    """Train `training_module` once per data file, starting from a fresh init each time.

    For every path in `data_dir_list` a deep copy of `default_data_configs` is
    made, its `data_dir` entry is set to that path, and a `TabularDataModule`
    is built from it. Parameters and optimizer state are initialized with
    `random.PRNGKey(0)` and then trained with `train_model_with_states`.

    Args:
        training_module: module providing `init_net_opt`.
        default_data_configs: base data config; it is copied, not modified.
        data_dir_list: one data path per training run.
        t_configs: training configuration forwarded to `train_model_with_states`.
        return_data_module_list: when true, also return the data modules.

    Returns:
        A list of `(params, opt_state)` tuples, or that list together with the
        list of data modules when `return_data_module_list` is true.
    """
    trained_states, data_modules = [], []
    for csv_path in data_dir_list:
        run_config = deepcopy(default_data_configs)
        run_config['data_dir'] = csv_path
        data_module = TabularDataModule(run_config)

        # every run uses the same seed for initialization
        init_params, init_opt_state = training_module.init_net_opt(
            data_module, random.PRNGKey(0)
        )
        final_params, final_opt_state = train_model_with_states(
            training_module, init_params, init_opt_state, data_module, t_configs
        )

        trained_states.append((final_params, final_opt_state))
        data_modules.append(data_module)

    return (trained_states, data_modules) if return_data_module_list else trained_states

### Log to Wandb

In [None]:
#| export
def wandb_vis_table(table: pd.DataFrame):
    """Wrap a DataFrame in a `wandb.Table`."""
    return wandb.Table(dataframe=table)

def wandb_vis_heatmap(heatmap: pd.DataFrame):
    """Build a `wandb.plots.HeatMap` from a DataFrame.

    Columns become the x labels, the index becomes the y labels, and the cell
    values are passed as the matrix; cell text is turned off.
    """
    # NOTE(review): `wandb.plots` may not exist in recent wandb releases — confirm.
    heatmap_kwargs = dict(
        x_labels=heatmap.columns,
        y_labels=heatmap.index,
        matrix_values=heatmap.values,
        show_text=False,
    )
    return wandb.plots.HeatMap(**heatmap_kwargs)

# Registry: result `data_type` -> function that turns the data into a wandb object.
DATA_TYPE_TO_VIS_FN = dict(
    table=wandb_vis_table,
    heatmap=wandb_vis_heatmap,
)

In [None]:
#| exporti
class ExperimentResult(BaseParser):
    """A named experiment artifact plus the tag that selects its visualization."""
    name: str       # key under which the result is logged
    data_type: str  # must be a key of `DATA_TYPE_TO_VIS_FN`
    data: Any       # payload handed to the visualization function

    @validator('data_type')
    def validate_data_type(cls, v):
        """Accept only `data_type` values that have a registered visualization."""
        if v in DATA_TYPE_TO_VIS_FN:
            return v
        raise ValueError(f"`data_type` should be one of {DATA_TYPE_TO_VIS_FN.keys()}, but got {v}")

    def wandb_vis(self):
        """Convert `data` with the visualization function registered for `data_type`."""
        vis_fn = DATA_TYPE_TO_VIS_FN[self.data_type]
        return vis_fn(self.data)

In [None]:
#| export
class ExperimentLogger(ABC):
    """Abstract interface for storing a batch of experiment results."""
    @abstractmethod
    def store_results(self, results: List[ExperimentResult]):
        """Store `results`; concrete loggers must override this method."""
        raise NotImplementedError

In [None]:
#| export
class ExperimentLoggerWanbConfigs(BaseParser):
    """Settings that `ExperimentLoggerWanb` forwards to `wandb.init`."""
    project_name: str                           # passed as `project`
    user_name: str                              # passed as `entity`
    experiment_name: str                        # passed as `name`
    hparams: Optional[Dict[str, Any]] = None    # hyperparameters, passed as `config`


class ExperimentLoggerWanb(ExperimentLogger):
    """Experiment logger that sends results to a Weights & Biases run.

    The run is created in `__init__` via `wandb.init`. `store_results` uses the
    run as a context manager, so the run context is exited when that call
    returns.
    """
    def __init__(self, configs: ExperimentLoggerWanbConfigs):
        """Start a wandb run from `configs` (project, entity, name, config)."""
        super().__init__()
        self.run = wandb.init(
            project=configs.project_name, entity=configs.user_name, name=configs.experiment_name, config=configs.hparams,
            settings=wandb.Settings(start_method="fork"))

    def store_results(self, results: List[ExperimentResult]):
        """Log every result under its `name` and return the run directory.

        Args:
            results: results to log; each is converted with `wandb_vis()`.

        Returns:
            The `dir` attribute of the wandb run.
        """
        with self.run as run:
            run.log({
                r.name: r.wandb_vis() for r in results
            })
            # Read the directory while the run context is still open instead of
            # touching the run handle after the `with` block has exited it.
            run_dir = self.run.dir
        return run_dir

### Experiment

In [None]:
#| exporti
def calculate_validity_changes(val_matrix_df: pd.DataFrame):
    """Summarize a square validity matrix into mean/std statistics.

    Args:
        val_matrix_df: square DataFrame of validity scores.

    Returns:
        A DataFrame indexed by `mean` and `std` with one column per statistic:
        `cf_validity` (diagonal entries), `cf_validity (w=1)` (entries one
        column to the right of the diagonal), `cf_validity (all)` (every
        off-diagonal entry), and the matching `validity_decrease` columns
        (the row's diagonal entry minus the off-diagonal entry).

    NOTE(review): whether `[i][i+1]` is the intended shift direction depends on
    how the caller orients rows and columns — confirm.
    """
    assert len(val_matrix_df.columns) == len(val_matrix_df.index), \
        f"val_matrix_df.columns={val_matrix_df.columns}, but val_matrix_df.index={val_matrix_df.index}"

    def _summarize(values):
        return {'mean': np.average(values), 'std': np.std(values)}

    grid = val_matrix_df.values
    size = len(grid[0])

    diagonal = [grid[k][k] for k in range(size)]

    # neighbours directly right of the diagonal
    next_scores = [grid[k][k + 1] for k in range(size - 1)]
    next_drops = [grid[k][k] - grid[k][k + 1] for k in range(size - 1)]

    # every off-diagonal cell, in row-major order
    off_diagonal = [(r, c) for r in range(size) for c in range(size) if r != c]
    all_scores = [grid[r][c] for r, c in off_diagonal]
    all_drops = [grid[r][r] - grid[r][c] for r, c in off_diagonal]

    summary = {
        'cf_validity': _summarize(diagonal),
        'cf_validity (w=1)': _summarize(next_scores),
        'cf_validity (all)': _summarize(all_scores),
        'validity_decrease (w=1)': _summarize(next_drops),
        'validity_decrease (all)': _summarize(all_drops),
    }
    return pd.DataFrame.from_dict(summary)

In [None]:
#| exporti
def _evaluate_adversarial_model(
    cf_results_list: Iterable[CFExplanationResults],
    experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None
):
    """Benchmark counterfactual results and package them as experiment results.

    Produces five results: the per-result benchmark table, its mean/std
    summary, the validity matrix (once as a heatmap, once as a table), and the
    validity-change statistics. When `experiment_logger_configs` is given, the
    results are also stored through `ExperimentLoggerWanb`.

    Args:
        cf_results_list: counterfactual results to evaluate. Any iterable is
            accepted; it is materialized once because it is consumed twice.
        experiment_logger_configs: optional wandb logger settings.

    Returns:
        The list of `ExperimentResult` objects.
    """
    # Both `benchmark_cfs` and `calculate_validity_matrix` walk the input, so a
    # one-shot iterator would arrive empty at the second call.
    cf_results_list = list(cf_results_list)

    cf_results_df = benchmark_cfs(cf_results_list)
    # keep only the mean/std rows and expose the statistic name as a column
    cf_aggre_df = cf_results_df.describe().loc[['mean', 'std']].reset_index().rename(columns={'index': 'stat'})

    print("calculating the validity matrix...")
    validity_matrix_df = calculate_validity_matrix(cf_results_list=cf_results_list)
    validity_changes = calculate_validity_changes(validity_matrix_df)

    experiment_results = [
        ExperimentResult(name='CF Results', data_type='table', data=cf_results_df),
        ExperimentResult(name='CF Metrics', data_type='table', data=cf_aggre_df),
        ExperimentResult(name='Heatmap', data_type='heatmap', data=validity_matrix_df),
        ExperimentResult(name='Validity Matrix', data_type='table', data=validity_matrix_df),
        ExperimentResult(name='Validity Changes', data_type='table', data=validity_changes),
    ]

    if experiment_logger_configs:
        logger = ExperimentLoggerWanb(experiment_logger_configs)
        dir_path = logger.store_results(experiment_results)
        print(f"Results stored at {dir_path}")
    return experiment_results

In [None]:
#| export
def adversarial_experiment(
    pred_training_module: BaseTrainingModule,
    cf_module: BaseCFExplanationModule,
    default_data_config: Dict[str, Any],
    data_dir_list: List[str],
    t_config: Dict[str, Any],
    is_local_cf_module: bool,
    use_prev_model_params: bool = False,
    return_best_model: bool = False, # return last model by default
    experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None
):
    """Train one model per data file, generate counterfactuals, and evaluate them.

    Steps: fit shared encoders over all data files, train a model on each file,
    generate counterfactual results per (model, data module) pair, then run
    `_evaluate_adversarial_model` on the collected results.

    Args:
        pred_training_module: module that is trained and used for prediction.
        cf_module: module that produces the counterfactuals.
        default_data_config: base data config, handed to
            `_aggregate_default_data_encoders` and then to `train_models`.
        data_dir_list: one data path per model.
        t_config: training configuration.
        is_local_cf_module: selects `generate_cf_results_local_exp` (true) or
            `generate_cf_results_cfnet` (false).
        use_prev_model_params: accepted but not read in this function.
        return_best_model: accepted but not read in this function.
        experiment_logger_configs: optional logger settings; its `hparams`
            field is overwritten with the merged hyperparameters.

    Returns:
        `(cf_results_list, experiment_results)`.
    """
    # Snapshot the hyperparameters before the fitted encoders are attached to
    # the data config.
    hparams = deepcopy(default_data_config)
    for extra in (pred_training_module.hparams, t_config):
        hparams.update(extra)
    if experiment_logger_configs:
        experiment_logger_configs.hparams = hparams

    print("aggregating data...")
    default_data_config = _aggregate_default_data_encoders(default_data_config, data_dir_list)

    print("start training...")
    model_params_opt_list, data_module_list = train_models(
        pred_training_module, default_data_config,
        data_dir_list, t_config, return_data_module_list=True
    )

    print("generating cfs...")

    def _explain(explainer, data_module, params, rng_key):
        # work on a private copy of the trained parameters
        frozen_params = deepcopy(params)
        if is_local_cf_module:
            def pred_fn(x):
                return pred_training_module.forward(frozen_params, rng_key, x)
            return generate_cf_results_local_exp(explainer, data_module, pred_fn=pred_fn)
        return generate_cf_results_cfnet(explainer, data_module, params=frozen_params, rng_key=rng_key)

    cf_results_list = []
    for idx, ((params, _opt_state), dm) in enumerate(zip(model_params_opt_list, data_module_list)):
        cf_results = _explain(cf_module, dm, params=params, rng_key=random.PRNGKey(0))
        # label each result by its position so the validity matrix is readable
        cf_results.cf_name = f"model_{idx}"
        cf_results.dataset_name = f"data_{idx}"
        cf_results_list.append(cf_results)

    experiment_results = _evaluate_adversarial_model(
        cf_results_list, experiment_logger_configs=experiment_logger_configs)
    return cf_results_list, experiment_results

In [None]:
#| export
def adversarial_experiment_cfnet(
    training_module: CounterNetTrainingModule,
    default_data_config: Dict[str, Any],
    data_dir_list: List[str],
    t_config: Dict[str, Any],
    use_prev_model_params: bool = False,
    return_best_model: bool = False, # return last model by default
    experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None
):
    """Run `adversarial_experiment` with one module acting as predictor and explainer.

    `training_module` is passed as both `pred_training_module` and `cf_module`,
    and `is_local_cf_module` is fixed to `False`. All other arguments are
    forwarded unchanged.

    Returns:
        Whatever `adversarial_experiment` returns.
    """
    forwarded = dict(
        pred_training_module=training_module,
        cf_module=training_module,   # same object plays both roles
        default_data_config=default_data_config,
        data_dir_list=data_dir_list,
        t_config=t_config,
        is_local_cf_module=False,
        use_prev_model_params=use_prev_model_params,
        return_best_model=return_best_model,
        experiment_logger_configs=experiment_logger_configs,
    )
    return adversarial_experiment(**forwarded)

def adversarial_experiment_local_exp(
    pred_training_module: CounterNetTrainingModule,
    cf_moudle: LocalCFExplanationModule,
    default_data_config: Dict[str, Any],
    data_dir_list: List[str],
    t_config: Dict[str, Any],
    use_prev_model_params: bool = False,
    return_best_model: bool = False, # return last model by default
    experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None
):
    """Run `adversarial_experiment` with a separate local explanation module.

    `cf_moudle` (spelling kept for caller compatibility) is passed as
    `cf_module`, and `is_local_cf_module` is fixed to `True`. All other
    arguments are forwarded unchanged.

    Returns:
        Whatever `adversarial_experiment` returns.
    """
    forwarded = dict(
        pred_training_module=pred_training_module,
        cf_module=cf_moudle,
        default_data_config=default_data_config,
        data_dir_list=data_dir_list,
        t_config=t_config,
        is_local_cf_module=True,
        use_prev_model_params=use_prev_model_params,
        return_best_model=return_best_model,
        experiment_logger_configs=experiment_logger_configs,
    )
    return adversarial_experiment(**forwarded)

### Test

In [None]:
# Two hand-built results for checking the validity metrics by hand.
# Each `pred_fn` simply reads one input column: the first column for `m1`,
# the last column for `m2`.
cf_results_list = [
    CFExplanationResults(
        cf_name='m1',
        dataset_name='d1',
        X=jnp.array([
            [1, 0, 1],
            [1, 1, 0],
            [0, 1, 0],
            [0, 1, 1],
        ]), # y_pred = [1, 1, 0, 0]
        y=jnp.ones((4,1)),
        cfs=jnp.array([
            [0, 0, 1],
            [0, 1, 0],
            [1, 1, 0],
            [1, 1, 1]
        ]),
        pred_fn=lambda x: x[:, 0], 
        total_time=0.1
    ),
    CFExplanationResults(
        cf_name='m2',
        dataset_name='d2',
        X=jnp.array([
            [1, 0, 1],
            [1, 1, 0],
            [0, 1, 1],
            [1, 0, 0]
        ]), # y_pred = [1, 0, 1, 0]
        y=jnp.ones((4,1)),
        cfs=jnp.array([
            [1, 0, 0],
            [1, 1, 1],
            [0, 1, 0],
            [1, 0, 1]
        ]),
        pred_fn=lambda x: x[:, -1],
        total_time=0.1
    ),
]



In [None]:
# First argument supplies the inputs and counterfactuals; second argument is the
# predictor used to re-score those counterfactuals.
# Scoring with a result's own predictor gives 1.0; the other predictor gives less.
assert compute_rob_validity(cf_results_list[0], cf_results_list[0].pred_fn) == 1.0
assert compute_rob_validity(cf_results_list[0], cf_results_list[1].pred_fn) == 0.5
assert compute_rob_validity(cf_results_list[1], cf_results_list[0].pred_fn) == 0.75
assert compute_rob_validity(cf_results_list[1], cf_results_list[1].pred_fn) == 1.0

In [None]:
# Columns are dataset names, rows are the names of the scoring predictors.
calculate_validity_matrix(cf_results_list)

Unnamed: 0,d1,d2
m1,1.0,0.75
m2,0.5,1.0


In [None]:
# Data settings: which columns are scaled vs. one-hot encoded, plus loader options.
data_config = {
    "data_name": "loan",
    "continous_cols": [
        "NoEmp", "NewExist", "CreateJob", "RetainedJob", "DisbursementGross", "GrAppv", "SBA_Appv"
    ],
    "discret_cols": [
        "State", "Term", "UrbanRural", "LowDoc", "Sector_Points"
    ],
    "batch_size": 128,
    'sample_frac': 0.1,
}
# Model settings passed to the training module constructor.
m_config = {
    "enc_sizes": [200,10],
    "dec_sizes": [10],
    "exp_sizes": [50],
    "dropout_rate": 0.3,    
    'lr': 0.003,
    "lambda_1": 1.0,
    "lambda_3": 0.1,
    "lambda_2": 0.2,
}
# Training settings.
t_configs = {
    # alternative setting kept for reference: 'n_epochs': 100
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss'
}

In [None]:
# Logger settings are built here but only used when the keyword argument at the
# bottom of the call is uncommented.
experiment_logger_configs = ExperimentLoggerWanbConfigs(
    project_name='debug',
    user_name='birkhoffg',
    experiment_name='rocoursenet',
)
# `adversarial_experiment` has no `training_module` parameter and requires
# `pred_training_module`, `cf_module` and `is_local_cf_module`; the wrapper
# below is the entry point that accepts a single `training_module`.
cf_results_list, experiment_results = adversarial_experiment_cfnet(
    training_module=RoCourseNetTrainingModule(m_config),
    default_data_config=data_config,
    data_dir_list=[ 
        f"assets/data/loan/year={year}.csv" for year in range(2008, 2010) 
    ],
    # wider alternative:
    # data_dir_list=[ 
    #     f"assets/data/loan/year={year}.csv" for year in range(1994, 2010) 
    # ],
    t_config=t_configs,
    # experiment_logger_configs=experiment_logger_configs
)

aggregating data...
total data length: 38341
preprocessing continuous features...
preprocessing discret features...
start training...


Epoch 9: 100%|██████████| 17/17 [00:00<00:00, 52.24batch/s, train/train_loss_1=0.38231245, train/train_loss_2=4.258957e-09, train/train_loss_3=0.051329758]   
Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 61.73batch/s, train/train_loss_1=0.068208955, train/train_loss_2=0.120453365, train/train_loss_3=0.04799317]


generating cfs...
calculating the validity matrix...


In [None]:
# Validity matrix for the trained models: columns `data_i`, rows `model_i`.
calculate_validity_matrix(cf_results_list)

Unnamed: 0,data_0,data_1
model_0,0.994878,0.995641
model_1,0.695449,0.994915
