# Training Module

> A training module defines the optimization procedure (network, optimizers, and per-batch training/validation steps); an instance of it is passed to `train_model`.

In [None]:
# default_exp training_module

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import os
# Pin JAX to the CPU backend for this notebook. This is set before the cfnet
# imports in the next cell, presumably so it takes effect before JAX
# initializes its backend — TODO confirm nothing above imports jax first.
os.environ["JAX_PLATFORM_NAME"] = "cpu"


In [None]:
# export
from cfnet.import_essentials import *
from cfnet.nets import PredictiveMLP, CounterNetMLP, CounterNetConv
from cfnet.interfaces import BaseCFExplanationModule
from cfnet.datasets import TabularDataModule
from cfnet.logger import TensorboardLogger
from cfnet.utils import validate_configs, cat_normalize, accuracy, proximity, make_model, init_net_opt, grad_update
from functools import partial
from abc import ABC, abstractmethod
from copy import deepcopy

In [None]:
# export utils
def make_model(
    m_configs: Optional[Dict[str, Any]],
    model: hk.Module
) -> hk.Transformed:
    """Wrap a Haiku module class into an ``hk.Transformed`` (init/apply) pair.

    Args:
        m_configs: Configs passed to ``model``'s constructor. May be ``None``
            for modules built without configs (e.g. the conv nets used in
            this notebook).
        model: The Haiku module *class* (not an instance) to instantiate.

    Returns:
        The transformed function; its ``init``/``apply`` take
        ``(rng, x, is_training=...)``.

    Example:
        net = make_model(m_configs, PredictiveMLP)
        params = net.init(key, X, is_training=True)
    """
    def model_fn(x, is_training: bool = True):
        # Haiku modules must be instantiated inside the transformed function.
        return model(m_configs)(x, is_training)

    return hk.transform(model_fn)


def init_net_opt(
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    X: jnp.ndarray,
    key: random.PRNGKey
) -> Tuple[hk.Params, optax.OptState]:
    """Initialise network parameters and the matching optimizer state.

    Args:
        net: Transformed network whose ``init`` builds the parameters.
        opt: Optimizer whose state is initialised from those parameters.
        X: A sample input batch, used by ``net.init`` to trace shapes.
        key: PRNG key for parameter initialisation.

    Returns:
        ``(params, opt_state)``.
    """
    # ``jnp.ndarray`` (rather than the removed ``jnp.DeviceArray``) keeps this
    # annotation valid on both old and current JAX releases.
    X = device_put(X)
    params = net.init(key, X, is_training=True)
    opt_state = opt.init(params)
    return params, opt_state

In [None]:
#export utils
def grad_update(
    grads: Dict[str, jnp.ndarray],
    params: hk.Params,
    opt_state: optax.OptState,
    opt: optax.GradientTransformation
) -> Tuple[hk.Params, optax.OptState]:
    """Apply a single optimizer step.

    ``grads`` are transformed by ``opt`` into updates, which are then added
    to ``params``. Returns the updated parameters and the new optimizer state.
    """
    updates, new_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, updates), new_state

In [None]:
# export
class BaseTrainingModule(ABC):
    """Abstract base for training modules passed to ``train_model``.

    Subclasses implement ``init_net_opt``, ``training_step`` and
    ``validation_step``; this class provides hyperparameter storage and
    logging helpers.
    """
    hparams: Dict[str, Any]
    # Defaults to None so ``log_dict`` can raise its intended ValueError when
    # ``init_logger`` was never called (a bare annotation would instead make
    # ``self.logger`` raise AttributeError).
    logger: Optional[TensorboardLogger] = None

    def save_hyperparameters(self, configs: Dict[str, Any]) -> Dict[str, Any]:
        """Store a deep copy of ``configs`` as ``self.hparams`` and return it."""
        self.hparams = deepcopy(configs)
        return self.hparams

    def init_logger(self, logger: TensorboardLogger):
        """Attach the logger that ``log``/``log_dict`` write to."""
        self.logger = logger

    def log(self, name: str, value: Any):
        """Log a single metric ``name -> value``."""
        self.log_dict({name: value})

    def log_dict(self, dictionary: Dict[str, Any]):
        """Forward a mapping of metric names to values to the attached logger.

        Raises:
            ValueError: if no logger has been attached with ``init_logger``.
        """
        if self.logger is None:
            raise ValueError("Logger has not been initialized.")
        self.logger.log_dict(dictionary)

    @abstractmethod
    def init_net_opt(self,
        data_module: TabularDataModule,
        key: random.PRNGKey) -> Tuple[hk.Params, optax.OptState]:
        """Initialise and return ``(params, opt_state)`` for ``data_module``."""

    @abstractmethod
    def training_step(self,
        params: hk.Params,
        opt_state: optax.OptState,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array]) -> Tuple[hk.Params, optax.OptState]:
        """Run one optimisation step on ``batch``; return ``(params, opt_state)``."""

    @abstractmethod
    def validation_step(self,
        params: hk.Params,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array]) -> Dict[str, Any]:
        """Evaluate ``params`` on ``batch`` and return a dict of metrics."""

In [None]:
# export
class PredictiveTrainingModuleConfigs(BaseParser):
    """Hyperparameters validated for ``PredictiveTrainingModule``."""
    lr: float  # learning rate of the Adam optimizer

In [None]:
# export
class PredictiveTrainingModule(BaseTrainingModule):
    """Train a purely predictive network with a mean L2 loss and Adam.

    Args:
        net: Transformed Haiku network mapping ``x`` to predictions.
        m_configs: Configs validated against ``PredictiveTrainingModuleConfigs``.
    """
    def __init__(
        self,
        net: hk.Transformed,
        m_configs: Dict[str, Any]
    ):
        self.save_hyperparameters(m_configs)
        self.net = net
        self.configs = validate_configs(m_configs, PredictiveTrainingModuleConfigs)
        self.opt = optax.adam(learning_rate=self.configs.lr)

    @partial(jax.jit, static_argnames=['self', 'is_training'])
    def forward(self, params, rng_key, x, is_training: bool = True):
        """Jitted forward pass returning the network's predictions for ``x``."""
        return self.net.apply(params, rng_key, x, is_training = is_training)

    def init_net_opt(self, data_module, key):
        """Initialise params and Adam state from a sample batch of ``data_module``."""
        params, opt_state = init_net_opt(self.net, self.opt, X=data_module.get_sample_X(), key=key)
        return params, opt_state

    def _l2(self, y_pred, y):
        # Mean over the batch of per-example L2 losses.
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    def loss_fn(self, params, rng_key, batch, is_training: bool = True):
        """Mean L2 loss of the network on ``batch = (x, y)``."""
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=is_training)
        return self._l2(y_pred, y)

    @partial(jax.jit, static_argnames=['self'])
    def _training_step(self, params, opt_state, rng_key, batch):
        """Jitted gradient step on ``loss_fn``."""
        grads = jax.grad(self.loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt)
        return upt_params, opt_state

    def training_step(self, params, opt_state, rng_key, batch):
        """Update params on one batch and log the training loss."""
        params, opt_state = self._training_step(params, opt_state, rng_key, batch)
        # The logged loss is evaluated with the *updated* params, in training mode.
        loss = self.loss_fn(params, rng_key, batch)
        self.log_dict({
            'train/train_loss_1': loss.item()
        })
        return params, opt_state

    def validation_step(self, params, rng_key, batch):
        """Log and return validation loss and accuracy for one batch."""
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=False)
        # Reuse y_pred: loss_fn(..., is_training=False) would repeat the same
        # forward pass with the same key.
        loss = self._l2(y_pred, y)
        logs = {
            'val/val_loss': loss.item(),
            'val/val_accuracy': accuracy(y, y_pred)
        }
        self.log_dict(logs)
        # Return the metrics, as the abstract contract (and CounterNet) does.
        return logs

In [None]:
# export
class PredictiveTrainingModuleMLP(PredictiveTrainingModule):
    """``PredictiveTrainingModule`` whose network is a ``PredictiveMLP`` built
    from ``m_configs``."""
    def __init__(self, m_configs):
        super().__init__(make_model(m_configs, PredictiveMLP), m_configs)

In [None]:
# export 
def partition_trainable_params(params: hk.Params, trainable_name: str):
    """Split ``params`` into ``(trainable, non_trainable)`` mappings.

    A parameter is trainable when ``trainable_name`` occurs as a substring of
    its Haiku module name (e.g. ``'counter_net_mlp/Explainer/dense_block/linear'``);
    all remaining parameters go into the second mapping.
    """
    def is_trainable(module_name, param_name, value):
        return trainable_name in module_name

    trainable, frozen = hk.data_structures.partition(is_trainable, params)
    return trainable, frozen

In [None]:
# export
class CounterNetTrainingModuleConfigs(BaseParser):
    """Hyperparameters validated for ``CounterNetTrainingModule``."""
    lr: float  # learning rate shared by both Adam optimizers (opt_1, opt_2)
    lambda_1: float  # weight of the prediction loss (y_pred vs. y)
    lambda_2: float  # weight of the validity loss (cf prediction vs. flipped label)
    lambda_3: float  # weight of the proximity loss (x vs. cf)
    use_immutable: bool = True  # copy immutable features of x into the cf

In [None]:
# export
def project_immutable_features(x: jnp.ndarray, cf: jnp.ndarray, imutable_idx_list: List[int]):
    """Overwrite the immutable feature columns of ``cf`` with those of ``x``.

    Args:
        x: Original inputs, same shape as ``cf``.
        cf: Counterfactuals; axis 1 indexes features.
        imutable_idx_list: Indices along axis 1 of features that must not change.

    Returns:
        A new array equal to ``cf`` except ``cf[:, i] = x[:, i]`` for each
        listed ``i``; ``cf`` itself is not modified.
    """
    # Nothing to project; also avoids indexing with an empty (non-integer
    # dtype) list.
    if len(imutable_idx_list) == 0:
        return cf
    return cf.at[:, imutable_idx_list].set(x[:, imutable_idx_list])

class CounterNetTrainingModule(BaseTrainingModule, BaseCFExplanationModule):
    name = "CounterNet"

    def __init__(
        self,
        net: hk.Transformed,
        m_configs: Dict[str, Any]
    ):
        self.save_hyperparameters(m_configs)
        self.configs = validate_configs(m_configs, CounterNetTrainingModuleConfigs)
        self.net = net
        # self.configs = CounterNetTrainingModuleConfigs(**m_configs)
        self.opt_1 = optax.adam(learning_rate=self.configs.lr)
        self.opt_2 = optax.adam(learning_rate=self.configs.lr)

    def init_net_opt(self, data_module, key):
        self.update_cat_info(data_module)
        # manually init multiple opts
        params, opt_1_state = init_net_opt(self.net, self.opt_1, X=data_module.get_sample_X(), key=key)
        trainable_params, _ = partition_trainable_params(
            params, trainable_name='counter_net_model/Explainer'
        )
        opt_2_state = self.opt_2.init(trainable_params)
        return params, (opt_1_state, opt_2_state)

    @partial(jax.jit, static_argnames=['self', 'is_training'])
    def forward(self, params, rng_key, x, is_training: bool = True):
        # first forward to get y_pred and normalized cf
        y_pred, cf = self.net.apply(params, rng_key, x, is_training=is_training)
        # cf = cf_res + x
        cf = cat_normalize(cf, self.cat_arrays, self.cat_idx, hard=not is_training)
        # project immutable features
        if self.configs.use_immutable:
            cf = project_immutable_features(x, cf, self.imutable_idx_list)
        # second forward to calulate cf_y
        cf_y, _ = self.net.apply(params, rng_key, cf, is_training=is_training)
        return y_pred, cf, cf_y

    def predict(self, params, rng_key, x):
        y_pred, _ = self.net.apply(params, rng_key, x, is_training=False)
        return y_pred

    def generate_cfs(self, X: chex.ArrayBatched, params, rng_key) -> chex.ArrayBatched:
        y_pred, cfs = self.net.apply(params, rng_key, X, is_training=False)
        # cfs = cfs + X
        cfs = cat_normalize(cfs, self.cat_arrays, self.cat_idx, hard=True)
        if self.configs.use_immutable:
            cfs = project_immutable_features(X, cfs, self.imutable_idx_list)
        return cfs

    def loss_fn_1(self, y_pred, y):
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    def loss_fn_2(self, cf_y, y_prime):
        return jnp.mean(vmap(optax.l2_loss)(cf_y, y_prime))

    def loss_fn_3(self, x, cf):
        return jnp.mean(vmap(optax.l2_loss)(x, cf))

    # def loss_fns(self, params, rng_key, batch, is_training: bool = True):
    #     x, y = batch
    #     y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=is_training)
    #     y_prime = 1 - jnp.round(y_pred)
    #     return self.loss_fn_1(y_pred, y), self.loss_fn_2(cf_y, y_prime), self.loss_fn_3(x, cf)

    def pred_loss_fn(self, params, rng_key, batch, is_training: bool = True):
        x, y = batch
        y_pred, cf = self.net.apply(params, rng_key, x, is_training=is_training)
        return self.configs.lambda_1 * self.loss_fn_1(y_pred, y)

    def exp_loss_fn(self, trainable_params, non_trainable_params, rng_key, batch, is_training: bool = True):
        # merge trainable and non_trainable params
        params = hk.data_structures.merge(trainable_params, non_trainable_params)
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=is_training)
        y_prime = 1 - jnp.round(y_pred)
        loss_2, loss_3 = self.loss_fn_2(cf_y, y_prime), self.loss_fn_3(x, cf)
        return self.configs.lambda_2 * loss_2 + self.configs.lambda_3 * loss_3

    def _predictor_step(self, params, opt_state, rng_key, batch):
        grads = jax.grad(self.pred_loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt_1)
        return upt_params, opt_state

    def _explainer_step(self, params, opt_state, rng_key, batch):
        trainable_params, non_trainable_params = partition_trainable_params(
            params, trainable_name='counter_net_model/Explainer'
        )
        grads = jax.grad(self.exp_loss_fn)(
            trainable_params, non_trainable_params, rng_key, batch)
        upt_trainable_params, opt_state = grad_update(grads, trainable_params, opt_state, self.opt_2)
        upt_params = hk.data_structures.merge(upt_trainable_params, non_trainable_params)
        return upt_params, opt_state

    @partial(jax.jit, static_argnames=['self'])
    def _training_step(self,
            params: hk.Params,
            opts_state: Tuple[optax.GradientTransformation, optax.GradientTransformation],
            rng_key: random.PRNGKey,
            batch: Tuple[jnp.array, jnp.array]):
        opt_1_state, opt_2_state = opts_state
        params, opt_1_state = self._predictor_step(params, opt_1_state, rng_key, batch)
        upt_params, opt_2_state = self._explainer_step(params, opt_2_state, rng_key, batch)
        return upt_params, (opt_1_state, opt_2_state)

    def _training_step_logs(self, params, rng_key, batch):
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=False)
        y_prime = 1 - jnp.round(y_pred)

        loss_1, loss_2, loss_3 = self.loss_fn_1(y_pred, y), self.loss_fn_2(cf_y, y_prime), self.loss_fn_3(x, cf)
        logs = {
            'train/train_loss_1': loss_1.item(),
            'train/train_loss_2': loss_2.item(),
            'train/train_loss_3': loss_3.item(),
        }
        return logs

    def training_step(self,
        params: hk.Params,
        opts_state: Tuple[optax.OptState, optax.OptState],
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array]
    ) -> Tuple[hk.Params, Tuple[optax.OptState, optax.OptState]]:
        upt_params, (opt_1_state, opt_2_state) = self._training_step(params, opts_state, rng_key, batch)

        logs = self._training_step_logs(upt_params, rng_key, batch)
        self.log_dict(logs)
        return upt_params, (opt_1_state, opt_2_state)

    def validation_step(self, params, rng_key, batch):
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=False)
        y_prime = 1 - jnp.round(y_pred)

        loss_1, loss_2, loss_3 = self.loss_fn_1(y_pred, y), self.loss_fn_2(cf_y, y_prime), self.loss_fn_3(x, cf)
        loss_1, loss_2, loss_3 = map(np.asarray, (loss_1, loss_2, loss_3))
        logs = {
            'val/accuracy': accuracy(y, y_pred),
            'val/validity': accuracy(cf_y, y_prime),
            'val/proximity': proximity(x, cf),
            'val/val_loss_1': loss_1,
            'val/val_loss_2': loss_2,
            'val/val_loss_3': loss_3,
            'val/val_loss': loss_1 + loss_2 + loss_3
        }
        self.log_dict(logs)
        return logs

In [None]:
# export
class CounterNetTrainingModuleMLP(CounterNetTrainingModule):
    def __init__(self, m_configs: Dict[str, Any]):
        net = make_model(m_configs, CounterNetMLP)
        super().__init__(net, m_configs)

In [None]:
# export
class CounterNetTrainingModuleConv(CounterNetTrainingModule):
    def __init__(self, m_configs: Dict[str, Any]):
        net = make_model(None, CounterNetConv)
        super().__init__(net, m_configs)

    def init_net_opt(self, data_module, key):
        # manually init multiple opts
        params, opt_1_state = init_net_opt(self.net, self.opt_1, X=data_module.get_sample_X(), key=key)
        trainable_params, _ = partition_trainable_params(
            params, trainable_name='counter_net_model/Explainer'
        )
        opt_2_state = self.opt_2.init(trainable_params)
        return params, (opt_1_state, opt_2_state)

    @partial(jax.jit, static_argnames=['self', 'is_training'])
    def forward(self, params, rng_key, x, is_training: bool = True):
        # first forward to get y_pred and normalized cf
        y_pred, cf = self.net.apply(params, rng_key, x, is_training=is_training)
        # project immutable features
        if self.configs.use_immutable:
            cf = project_immutable_features(x, cf, self.imutable_idx_list)
        # second forward to calulate cf_y
        cf_y, _ = self.net.apply(params, rng_key, cf, is_training=is_training)
        return y_pred, cf, cf_y

## Test

In [None]:
from cfnet.train import train_model, TensorboardLogger
from cfnet.datasets import TabularDataModule

# Data configs for the sampled Adult dataset, consumed by TabularDataModule.
# NOTE(review): the key spellings 'continous_cols' / 'discret_cols' are kept
# verbatim, presumably because TabularDataModule reads them as-is — confirm
# in cfnet.datasets before renaming.
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 128,
    'sample_frac': 0.1,
    "continous_cols": [
        "age",
        "hours_per_week"
    ],
    "discret_cols": [
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "race",
        "gender"
    ],
}
# Model configs: passed to CounterNetMLP via make_model (layer sizes, dropout,
# presumably) and validated by CounterNetTrainingModuleConfigs (lr, lambdas).
m_configs = {
    "enc_sizes": [50,10],
    "dec_sizes": [10],
    "exp_sizes": [50, 50],
    "dropout_rate": 0.3,
    "lr": 0.003,
    "lambda_1": 1.0,
    "lambda_3": 0.1,
    "lambda_2": 0.2,
}

# Trainer configs passed to train_model.
t_configs = {
    'n_epochs': 1,
    'monitor_metrics': 'val/val_loss'
}



In [None]:
# Manually build a CounterNetMLP and a single Adam optimizer, then initialise
# params/opt state from a sample batch of the data module.
net = make_model(m_configs, CounterNetMLP)
opt = optax.adam(0.01)
dm = TabularDataModule(data_configs)
key = hk.PRNGSequence(42)  # iterator of PRNG keys; next(key) yields a new key

params, opt_state = init_net_opt(net, opt, dm.get_sample_X(), next(key))

In [None]:
# End-to-end run: train a CounterNetTrainingModule (wrapping `net` from the
# previous cell) on a freshly constructed TabularDataModule.
params, opts = train_model(
    CounterNetTrainingModule(net,m_configs),
    TabularDataModule(data_configs),
    t_configs
)

Epoch 0: 100%|██████████| 20/20 [00:06<00:00,  3.13batch/s, train/train_loss_1=0.0314, train/train_loss_2=0.262, train/train_loss_3=0.153]


In [None]:
# Shape of every trained parameter, keyed by Haiku module name.
# (jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'counter_net_mlp/Encoder/dense_block/linear': {'b': (50,), 'w': (29, 50)},
 'counter_net_mlp/Encoder/dense_block_1/linear': {'b': (10,), 'w': (50, 10)},
 'counter_net_mlp/Explainer/dense_block/linear': {'b': (50,), 'w': (20, 50)},
 'counter_net_mlp/Explainer/dense_block_1/linear': {'b': (50,), 'w': (50, 50)},
 'counter_net_mlp/Explainer_1': {'b': (29,), 'w': (50, 29)},
 'counter_net_mlp/Predictor/dense_block/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_mlp/Predictor_1': {'b': (1,), 'w': (10, 1)}}

In [None]:
# Inspect how partition_trainable_params splits the params. As the output
# below shows, 'counter_net_model/Explainer' matches none of the module names
# (they are prefixed 'counter_net_mlp/'), so the trainable part is empty.
jax.tree_util.tree_map(
    lambda x: x.shape,
    partition_trainable_params(params, trainable_name='counter_net_model/Explainer')
)

({},
 {'counter_net_mlp/Encoder/dense_block/linear': {'b': (50,), 'w': (29, 50)},
  'counter_net_mlp/Encoder/dense_block_1/linear': {'b': (10,), 'w': (50, 10)},
  'counter_net_mlp/Explainer/dense_block/linear': {'b': (50,), 'w': (20, 50)},
  'counter_net_mlp/Explainer/dense_block_1/linear': {'b': (50,),
   'w': (50, 50)},
  'counter_net_mlp/Explainer_1': {'b': (29,), 'w': (50, 29)},
  'counter_net_mlp/Predictor/dense_block/linear': {'b': (10,), 'w': (10, 10)},
  'counter_net_mlp/Predictor_1': {'b': (1,), 'w': (10, 1)}})

## MNIST

In [None]:
from cfnet.nets import CounterNetConv, PredictivConvNet
from cfnet.datasets import MNISTDataModule, MNISTDataConfigs
from cfnet.train import train_model


# MNIST data module plus a purely predictive conv net. The net is built
# without model configs; only the learning rate is configured.
d_configs = MNISTDataConfigs(batch_size=128)
dm = MNISTDataModule(d_configs)
model = PredictiveTrainingModule(
    net=make_model(None, PredictivConvNet),
    m_configs={"lr": 0.003}
)

train_X.shape: (13007, 28, 28); train_y.shape: (13007, 1) 
test_X.shape: (2163, 28, 28); test_y.shape: (2163, 1) 


In [None]:
# One-epoch run; 'val/val_loss' is a metric PredictiveTrainingModule logs in
# validation_step (its use by train_model, e.g. for checkpointing, is
# presumed — confirm in cfnet.train).
t_configs = {
    'n_epochs': 1,
    'monitor_metrics': 'val/val_loss'
}

train_model(
    model, dm, t_configs
)

Epoch 0: 100%|██████████| 102/102 [00:04<00:00, 20.96batch/s, train/train_loss_1=0.00561]


({'predictiv_conv_net/conv2_d': {'b': DeviceArray([-0.00140353,  0.01591166,  0.00818928,  0.04579916,
                 0.01824267, -0.00758895, -0.00354012, -0.03288223,
                -0.00728636,  0.00757818, -0.00079604,  0.03169249,
                -0.00482811, -0.00781124,  0.00478214, -0.00221053,
                -0.0005601 ,  0.00864014, -0.01045936,  0.00089821,
                -0.00253251, -0.0164368 ,  0.02049785,  0.00344305,
                 0.00771122,  0.01764267, -0.02076225,  0.01419374,
                 0.01891027,  0.00951661,  0.01519189,  0.01351604],            dtype=float32),
   'w': DeviceArray([[[[-0.03138227, -0.07555365,  0.04476513, ...,
                   -0.06140496,  0.10528217, -0.05212118],
                  [ 0.04279129, -0.03736086,  0.03191335, ...,
                   -0.10117479, -0.03430128, -0.02101295],
                  [ 0.06817198, -0.03804344,  0.05765162, ...,
                    0.09488443,  0.04754618,  0.11872477],
                  ...,