In [1]:
from ipynb_path import *

In [12]:
from cfnet.import_essentials import *
from cfnet.datasets import TabularDataModule, MinMaxScaler, OneHotEncoder, find_imutable_idx_list, NumpyDataset
from cfnet.train import train_model
from cfnet.training_module import CounterNetTrainingModule, PredictiveTrainingModule
from cfnet.evaluate import generate_cf_results, evaluate_cfs
from cfnet.utils import load_json
from copy import deepcopy

In [5]:
class TabularDataModulePosthoc(TabularDataModule):
    """Tabular data module whose targets are the outputs of a prediction function.

    Features are built exactly as in ``prepare_data`` below (scaled continuous
    columns concatenated with one-hot encoded discrete columns), but the label
    column of the raw data is ignored: the targets stored in the datasets are
    ``pred_fn(X)`` evaluated on the preprocessed feature matrix.
    """

    def __init__(self, data_configs: Dict, pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray]):
        """Store ``pred_fn`` and delegate the rest of the setup to the parent.

        Args:
            data_configs: configuration dictionary forwarded unchanged to
                ``TabularDataModule.__init__``.
            pred_fn: callable applied to the preprocessed feature matrix in
                ``prepare_data``; its output replaces the dataset labels.
        """
        # Assigned before the parent constructor runs so that `pred_fn` is
        # already available if the parent triggers `prepare_data` during
        # construction (NOTE(review): parent class not visible here - confirm).
        self.pred_fn = pred_fn
        super().__init__(data_configs)

    def prepare_data(self):
        """Preprocess the features, label them with ``pred_fn`` and build the datasets.

        Sets ``normalizer``, ``encoder``, ``cat_arrays``, ``imutable_idx_list``,
        ``train_dataset``, ``val_dataset`` and ``test_dataset`` on ``self``.
        """
        def fit_or_transform(frame: pd.DataFrame, cols, attr_name: str, make_transformer):
            """Return the transformed ``frame[cols]`` as an array.

            The transformer stored at ``self.<attr_name>`` is reused when it
            exists; otherwise one is created with ``make_transformer``, stored
            on ``self`` and fitted. An empty ``cols`` yields an ``(n_rows, 0)``
            array without touching the transformer, so a preset or still
            unfitted transformer is never asked to transform zero columns.
            """
            transformer = getattr(self, attr_name)
            needs_fit = not transformer
            if needs_fit:
                transformer = make_transformer()
                setattr(self, attr_name, transformer)
            if not cols:
                # Stays 2-D even for zero rows, so the concatenation along
                # axis=1 below always receives a matrix.
                return np.empty((len(frame), 0))
            selected = frame[cols]
            if needs_fit:
                return transformer.fit_transform(selected)
            return transformer.transform(selected)

        # Every column except the last one is a feature. The last (label)
        # column is intentionally unused: targets come from `pred_fn` instead.
        features = self.data[self.data.columns[:-1]]

        # preprocessing
        X_cont = fit_or_transform(
            features, self.continous_cols, 'normalizer', MinMaxScaler)
        # NOTE(review): `sparse=False` is kept exactly as in the original call;
        # confirm the keyword is still accepted by the installed encoder.
        X_cat = fit_or_transform(
            features, self.discret_cols, 'encoder', lambda: OneHotEncoder(sparse=False))
        X = np.concatenate((X_cont, X_cat), axis=1)

        # get categorical arrays
        self.cat_arrays = self.encoder.categories_ if self.discret_cols else []
        self.imutable_idx_list = find_imutable_idx_list(
            imutable_col_names=self.imutable_cols,
            discrete_col_names=self.discret_cols,
            continuous_col_names=self.continous_cols,
            cat_arrays=self.cat_arrays
        )

        # Post-hoc targets: whatever the prediction function returns for X.
        y = self.pred_fn(X)

        # prepare train & test (no shuffling, so the split is order-preserving)
        train_test_tuple = train_test_split(X, y, shuffle=False)
        train_X, test_X, train_y, test_y = map(lambda x: x.astype(jnp.float32), train_test_tuple)
        if self.sample_frac:
            # Keep only the leading fraction of the training rows.
            train_size = int(len(train_X) * self.sample_frac)
            train_X, train_y = train_X[:train_size], train_y[:train_size]
        self.train_dataset = NumpyDataset(train_X, train_y)
        self.val_dataset = NumpyDataset(test_X, test_y)
        # The held-out split doubles as the test set.
        self.test_dataset = self.val_dataset


In [7]:
# Load the Adult configuration file; its 'data_configs', 'mlp_configs' and
# 'cfnet_configs' sections are consumed by the three constructors below.
adult_configs = load_json('assets/configs/data_configs/adult.json')
# Data module built from the 'data_configs' section.
dm = TabularDataModule(adult_configs['data_configs'])
# Predictive training module configured from the 'mlp_configs' section.
mlp = PredictiveTrainingModule(adult_configs['mlp_configs'])
# CounterNet training module configured from the 'cfnet_configs' section.
cfnet = CounterNetTrainingModule(adult_configs['cfnet_configs'])

In [15]:
# Training settings passed to `train_model` for the predictive model.
mlp_t_configs = {'n_epochs': 10, 'monitor_metrics': 'val/val_loss', 'logger_name': 'pred'}

# Training settings for the CounterNet run: a separate dict carrying the same
# monitored metric and logger name as above, with only the epoch count changed.
# NOTE(review): both runs therefore use logger_name 'pred' - confirm that the
# CounterNet run is not meant to log under its own name.
cfnet_t_configs = {**mlp_t_configs, 'n_epochs': 100}


In [11]:
# Train the predictive model `mlp` on `dm` with the MLP training settings.
# `train_model` returns a pair; only its first element is kept (as `params`),
# the second is discarded.
params, _ = train_model(
    mlp, dm, mlp_t_configs
)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)
  for x in jax.tree_leaves(state):
Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 71.86batch/s, train/train_loss_1=0.0663]


In [13]:
# Work on an independent copy so later changes to the object bound to `params`
# do not affect the prediction function defined below.
_params = deepcopy(params)


def pred_fn(x):
    """Run `mlp.forward` on `x` with the copied parameters, in inference mode.

    A fixed PRNG key (seed 0) is supplied on every call and
    `is_training=False` is passed through to `mlp.forward`.
    """
    return mlp.forward(_params, random.PRNGKey(0), x, is_training=False)

In [14]:
# Build the post-hoc data module from the same 'data_configs' section used for
# `dm`, passing `pred_fn` so that its outputs are used as the targets.
dm_posthoc = TabularDataModulePosthoc(
    adult_configs['data_configs'], pred_fn
)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


In [16]:
# Train the CounterNet module on the post-hoc data module with the CounterNet
# training settings; as above, only the first element of the returned pair is
# kept.
# NOTE(review): the captured output of this cell ends in KeyboardInterrupt, so
# `cfnet_params` may never have been assigned in the recorded run.
cfnet_params, _ = train_model(
    cfnet, dm_posthoc, cfnet_t_configs
)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)
                                                                                                                                         

KeyboardInterrupt: 