# CounterNet

> A prediction-aware recourse model

* Paper link: https://arxiv.org/abs/2109.07557

In [None]:
#| default_exp methods.counternet

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
from relax.utils import show_doc as show_doc_parser

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.module import MLP, BaseTrainingModule
from relax.methods.base import BaseCFModule, BaseParametricCFModule, BasePredFnCFModule
from relax.trainer import TrainingConfigs, train_model
from relax.data import TabularDataModule
from relax.utils import validate_configs, sigmoid, accuracy, proximity, make_model, init_net_opt, grad_update
from functools import partial

## CounterNet Model


In [None]:
#| exporti
class CounterNetModelConfigs(BaseParser):
    """Configurator of `CounterNetModel`.

    Each `*_sizes` field is the list of hidden-layer widths of one MLP branch.
    """

    # Required: hidden-layer widths of the shared encoder MLP.
    enc_sizes: List[int] = Field(..., description='Encoder sizes.')
    # Required: hidden-layer widths of the predictor MLP.
    dec_sizes: List[int] = Field(..., description='Predictor sizes.')
    # Required: hidden-layer widths of the CF generator (explainer) MLP.
    exp_sizes: List[int] = Field(..., description='CF generator sizes.')
    # Dropout rate handed to each MLP branch of `CounterNetModel`.
    dropout_rate: float = Field(default=0.3, description='Dropout rate.')

In [None]:
#| export
#| hide
class CounterNetModel(hk.Module):
    """CounterNet Model

    A shared encoder feeds two heads: a predictor that outputs a sigmoid
    score, and a CF generator ("Explainer") that outputs a counterfactual with
    as many features as the input. Calling the module returns ``(y_hat, cf)``.
    """
    def __init__(
        self,
        m_config: Dict | CounterNetModelConfigs,  # Model configs which contain configs in `CounterNetModelConfigs`.
        name: str | None = None,  # Name of the module.
    ):
        """CounterNet model architecture."""
        super().__init__(name=name)
        self.configs = validate_configs(m_config, CounterNetModelConfigs)

    def __call__(
        self,
        x: jnp.ndarray,  # Input features; the last axis is the feature axis.
        is_training: bool = True,  # Forwarded to every MLP branch.
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(y_hat, cf)``.

        ``y_hat`` is the sigmoid prediction (last dim 1). ``cf`` is the raw
        counterfactual (last dim equal to ``x``'s). No data constraints are
        applied here.
        """
        input_shape = x.shape[-1]
        # NOTE(review): each branch reuses one module name for both its MLP and
        # its Linear head ("Predictor" / "Explainer"). Haiku presumably
        # de-duplicates them (e.g. "Predictor_1"). Renaming them would change the
        # parameter paths that the training module partitions on, so keep them.

        # encoder
        z = MLP(self.configs.enc_sizes, self.configs.dropout_rate, name="Encoder")(
            x, is_training
        )

        # prediction
        pred = MLP(self.configs.dec_sizes, self.configs.dropout_rate, name="Predictor")(
            z, is_training
        )
        y_hat = hk.Linear(1, name="Predictor")(pred)
        y_hat = jax.nn.sigmoid(y_hat)

        # explain: the CF generator sees the encoding concatenated with the
        # predictor's hidden representation.
        z_exp = jnp.concatenate((z, pred), axis=-1)
        cf = MLP(self.configs.exp_sizes, self.configs.dropout_rate, name="Explainer")(
            z_exp, is_training
        )
        # Project back to the input feature dimension.
        cf = hk.Linear(input_shape, name="Explainer")(cf)
        return y_hat, cf


In [None]:
# Render the generated documentation table for `CounterNetModelConfigs`.
show_doc_parser(CounterNetModelConfigs)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L80){target="_blank" style="float:right; font-size:smaller"}

### CounterNetModelConfigs

>      CounterNetModelConfigs (enc_sizes:List[int], dec_sizes:List[int],
>                              exp_sizes:List[int], dropout_rate:float=0.3)

Configurator of `CounterNetModel`.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| enc_sizes | List[int] |  | Encoder sizes. |
| dec_sizes | List[int] |  | Predictor sizes. |
| exp_sizes | List[int] |  | CF generator sizes. |
| dropout_rate | float | 0.3 | Dropout rate. |

In [None]:
# Render the signature and parameter documentation of `CounterNetModel.__init__`.
show_doc(CounterNetModel.__init__)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L91){target="_blank" style="float:right; font-size:smaller"}

### CounterNetModel.__init__

>      CounterNetModel.__init__
>                                (m_config:Union[Dict,__main__.CounterNetModelCo
>                                nfigs], name:str=None)

CounterNet model architecture.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| m_config | Dict \| CounterNetModelConfigs |  | Model configs which contain configs in `CounterNetModelConfigs`. |
| name | str | None | Name of the module. |

## CounterNet Training Module

Define the `CounterNetTrainingModule` for training `CounterNetModel`.

In [None]:
#| export 
def partition_trainable_params(params: hk.Params, trainable_name: str):
    """Split a Haiku parameter tree into ``(trainable, non_trainable)``.

    A parameter is considered trainable when `trainable_name` occurs as a
    substring of its module name; everything else goes to the second tree.
    """
    def _belongs_to_trainable(module_name, _param_name, _value):
        return trainable_name in module_name

    selected, remaining = hk.data_structures.partition(_belongs_to_trainable, params)
    return selected, remaining


In [None]:
#| export
def project_immutable_features(
    x: jnp.ndarray,  # Original inputs.
    cf: jnp.ndarray,  # Counterfactuals; same shape as `x`.
    imutable_idx_list: List[int],  # (sic) feature indices (last axis) that must not change.
) -> jnp.ndarray:
    """Return a copy of `cf` whose immutable features are reset to `x`'s values.

    Indexing is along the last axis, so both a batch ``(N, d)`` and a single
    instance ``(d,)`` are supported. An empty index list returns `cf` as is.
    """
    # Nothing to project. This also sidesteps indexing with an empty Python
    # list, which JAX may not accept as an integer indexer.
    if len(imutable_idx_list) == 0:
        return cf
    # `.at[...].set` is JAX's functional update: `cf` itself is left untouched.
    return cf.at[..., imutable_idx_list].set(x[..., imutable_idx_list])

In [None]:
#| export
class CounterNetTrainingModuleConfigs(BaseParser):
    """Optimization hyperparameters of `CounterNetTrainingModule`."""

    # Adam learning rate, shared by the predictor and explainer optimizers.
    lr: float = Field(0.003)
    # Weight of the prediction loss L1.
    lambda_1: float = Field(1.0)
    # Weight of the validity loss L2.
    lambda_2: float = Field(0.2)
    # Weight of the cost-of-change loss L3.
    lambda_3: float = Field(0.1)


In [None]:
#| export
class CounterNetTrainingModule(BaseTrainingModule):
    """Training logic for `CounterNetModel`.

    Every training step applies two gradient updates (see `_training_step`):

    1. all parameters are updated on ``lambda_1 * L1`` (prediction loss) using `opt_1`;
    2. only the CF generator ("Explainer") parameters are updated on
       ``lambda_2 * L2 + lambda_3 * L3`` (validity + cost of change) using `opt_2`.
    """

    _data_module: TabularDataModule
    # Substring of Haiku module names selecting the CF generator parameters
    # that the second (explainer) update trains.
    _explainer_name = "counter_net_model/Explainer"

    def __init__(self, m_configs: Dict[str, Any] | BaseParser):
        # `m_configs` is consumed both by `make_model` (model fields) and by
        # `validate_configs` (training fields such as `lr` and the lambdas).
        self.save_hyperparameters(m_configs)
        self.net = make_model(m_configs, CounterNetModel)
        self.configs = validate_configs(m_configs, CounterNetTrainingModuleConfigs)
        # `opt_1` updates the full network (prediction step); `opt_2` only the
        # explainer sub-tree (CF generation step).
        self.opt_1 = optax.adam(learning_rate=self.configs.lr)
        self.opt_2 = optax.adam(learning_rate=self.configs.lr)

    def init_net_opt(self, data_module: TabularDataModule, key):
        """Initialize network parameters and both optimizer states.

        Also stores `data_module`, whose `apply_constraints` is used by
        `forward` and `generate_cfs`.

        Returns ``(params, (opt_1_state, opt_2_state))``.
        """
        self._data_module = data_module
        X, _ = data_module.train_dataset[:]

        # Module-level helper: initializes params from the first 100 training
        # inputs together with `opt_1`'s state.
        params, opt_1_state = init_net_opt(
            self.net, self.opt_1, X=X[:100], key=key
        )
        # `opt_2` only ever updates the explainer sub-tree, so initialize it on that.
        trainable_params, _ = partition_trainable_params(
            params, trainable_name=self._explainer_name
        )
        opt_2_state = self.opt_2.init(trainable_params)
        return params, (opt_1_state, opt_2_state)

    @partial(jax.jit, static_argnames=["self", "is_training"])
    def forward(self, params, rng_key, x, is_training: bool = True):
        """Run the network on `x`, constrain the CF, then run it on the CF.

        Returns ``(y_pred, cf, cf_y)``: the prediction for `x`, the counterfactual
        after `apply_constraints` (hard constraints only outside training), and
        the prediction for that counterfactual.
        """
        y_pred, cf = self.net.apply(params, rng_key, x, is_training=is_training)
        cf = self._data_module.apply_constraints(x, cf, hard=not is_training)

        # second forward to calculate cf_y
        cf_y, _ = self.net.apply(params, rng_key, cf, is_training=is_training)
        return y_pred, cf, cf_y

    def predict(self, params, rng_key, x):
        """Return the predictor's output for `x` (inference mode)."""
        y_pred, _ = self.net.apply(params, rng_key, x, is_training=False)
        return y_pred

    def generate_cfs(self, X: chex.ArrayBatched, params, rng_key) -> chex.ArrayBatched:
        """Generate counterfactuals for `X` with hard data constraints applied."""
        _, cfs = self.net.apply(params, rng_key, X, is_training=False)
        return self._data_module.apply_constraints(X, cfs, hard=True)

    @staticmethod
    def _cf_target(y_pred):
        """Desired prediction for the CF: the opposite of the rounded prediction."""
        return 1 - jnp.round(y_pred)

    def loss_fn_1(self, y_pred, y):
        """Prediction loss L1: mean per-example `optax.l2_loss` of `y_pred` vs `y`."""
        # vmap over the batch axis pairs each prediction with its own label,
        # so (N, 1) predictions against (N,) labels do not broadcast to (N, N).
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    def loss_fn_2(self, cf_y, y_prime):
        """Validity loss L2: mean per-example `optax.l2_loss` of `cf_y` vs `y_prime`."""
        return jnp.mean(vmap(optax.l2_loss)(cf_y, y_prime))

    def loss_fn_3(self, x, cf):
        """Cost-of-change loss L3: mean per-example `optax.l2_loss` of `x` vs `cf`."""
        return jnp.mean(vmap(optax.l2_loss)(x, cf))

    def pred_loss_fn(self, params, rng_key, batch, is_training: bool = True):
        """Weighted prediction loss ``lambda_1 * L1`` as a function of all params."""
        x, y = batch
        y_pred, _ = self.net.apply(params, rng_key, x, is_training=is_training)
        return self.configs.lambda_1 * self.loss_fn_1(y_pred, y)

    def exp_loss_fn(
        self,
        trainable_params,
        non_trainable_params,
        rng_key,
        batch,
        is_training: bool = True,
    ):
        """Weighted explainer loss ``lambda_2 * L2 + lambda_3 * L3``.

        Takes the parameter tree split in two so that `jax.grad` differentiates
        only with respect to `trainable_params` (the first argument).
        """
        params = hk.data_structures.merge(trainable_params, non_trainable_params)
        x, _ = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=is_training)
        y_prime = self._cf_target(y_pred)
        loss_2, loss_3 = self.loss_fn_2(cf_y, y_prime), self.loss_fn_3(x, cf)
        return self.configs.lambda_2 * loss_2 + self.configs.lambda_3 * loss_3

    def _predictor_step(self, params, opt_state, rng_key, batch):
        """First update: all params on the prediction loss, using `opt_1`."""
        grads = jax.grad(self.pred_loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt_1)
        return upt_params, opt_state

    def _explainer_step(self, params, opt_state, rng_key, batch):
        """Second update: only explainer params on the CF losses, using `opt_2`."""
        trainable_params, non_trainable_params = partition_trainable_params(
            params, trainable_name=self._explainer_name
        )
        grads = jax.grad(self.exp_loss_fn)(
            trainable_params, non_trainable_params, rng_key, batch
        )
        upt_trainable_params, opt_state = grad_update(
            grads, trainable_params, opt_state, self.opt_2
        )
        upt_params = hk.data_structures.merge(
            upt_trainable_params, non_trainable_params
        )
        return upt_params, opt_state

    @partial(jax.jit, static_argnames=["self"])
    def _training_step(
        self,
        params: hk.Params,
        opts_state: Tuple[optax.OptState, optax.OptState],
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ):
        """Apply the predictor update, then the explainer update on its result."""
        # NOTE(review): the same `rng_key` drives both updates (and both forward
        # passes inside `forward`), so any randomness such as dropout is shared.
        # Confirm that this is intended.
        opt_1_state, opt_2_state = opts_state
        params, opt_1_state = self._predictor_step(params, opt_1_state, rng_key, batch)
        upt_params, opt_2_state = self._explainer_step(
            params, opt_2_state, rng_key, batch
        )
        return upt_params, (opt_1_state, opt_2_state)

    def _eval_losses(self, params, rng_key, batch):
        """Inference-mode forward pass plus the three unweighted losses.

        Returns ``(y_pred, cf, cf_y, y_prime, (loss_1, loss_2, loss_3))``.
        """
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=False)
        y_prime = self._cf_target(y_pred)
        losses = (
            self.loss_fn_1(y_pred, y),
            self.loss_fn_2(cf_y, y_prime),
            self.loss_fn_3(x, cf),
        )
        return y_pred, cf, cf_y, y_prime, losses

    def _training_step_logs(self, params, rng_key, batch):
        """Unweighted training losses as Python floats, keyed for logging."""
        *_, (loss_1, loss_2, loss_3) = self._eval_losses(params, rng_key, batch)
        logs = {
            "train/train_loss_1": loss_1.item(),
            "train/train_loss_2": loss_2.item(),
            "train/train_loss_3": loss_3.item(),
        }
        return logs

    def training_step(
        self,
        params: hk.Params,
        opts_state: Tuple[optax.OptState, optax.OptState],
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Tuple[hk.Params, Tuple[optax.OptState, optax.OptState]]:
        """One two-stage update, followed by logging losses at the updated params."""
        upt_params, (opt_1_state, opt_2_state) = self._training_step(
            params, opts_state, rng_key, batch
        )

        logs = self._training_step_logs(upt_params, rng_key, batch)
        self.log_dict(logs)
        return upt_params, (opt_1_state, opt_2_state)

    def validation_step(self, params, rng_key, batch):
        """Compute, log and return validation metrics.

        `val/val_loss` is the unweighted sum of the three losses, unlike the
        lambda-weighted training objective.
        """
        x, y = batch
        y_pred, cf, cf_y, y_prime, losses = self._eval_losses(params, rng_key, batch)
        loss_1, loss_2, loss_3 = map(np.asarray, losses)
        logs = {
            "val/accuracy": accuracy(y, y_pred),
            "val/validity": accuracy(cf_y, y_prime),
            "val/proximity": proximity(x, cf),
            "val/val_loss_1": loss_1,
            "val/val_loss_2": loss_2,
            "val/val_loss_3": loss_3,
            "val/val_loss": loss_1 + loss_2 + loss_3,
        }
        self.log_dict(logs)
        return logs

## CounterNet Explanation Module


![CounterNet architecture](images/CounterNet-architecture.svg)

`CounterNet` consists of three objectives:

1. **predictive accuracy**: the predictor network should output accurate predictions $\hat{y}_x$; 
2. **counterfactual validity**: CF examples $x'$ produced by the CF generator network should be valid, i.e., they should flip the prediction ($\hat{y}_{x} + \hat{y}_{x'}=1$);
3. **minimizing cost of change**: minimal modifications should be required to change input instance $x$ to CF example $x'$.

The objective function of `CounterNet`:

$$
\operatorname*{argmin}_{\mathbf{\theta}} \frac{1}{N}\sum\nolimits_{i=1}^{N} 
    \bigg[ 
    \lambda_1 \cdot \! \underbrace{\left(y_i- \hat{y}_{x_i}\right)^2}_{\text{Prediction Loss}\ (\mathcal{L}_1)} + 
    \;\lambda_2 \cdot \;\; \underbrace{\left(\hat{y}_{x_i}- \left(1 - \hat{y}_{x_i'}\right)\right)^2}_{\text{Validity Loss}\ (\mathcal{L}_2)} \,+ 
    \;\lambda_3 \cdot \!\! \underbrace{\left(x_i- x'_i\right)^2}_{\text{Cost of change Loss}\ (\mathcal{L}_3)}
    \bigg]
$$

`CounterNet` applies two-stage gradient updates to `CounterNetModel` 
for each `training_step` (see `CounterNetTrainingModule`).

1. The first gradient update optimizes for predictive accuracy: 
$\theta^{(1)} = \theta^{(0)} - \nabla_{\theta^{(0)}} (\lambda_1 \cdot \mathcal{L}_1)$.
2. The second gradient update optimizes for generating CF explanations and updates only the CF generator parameters $\theta_g$:
$\theta^{(2)}_g = \theta^{(1)}_g - \nabla_{\theta^{(1)}_g} (\lambda_2 \cdot \mathcal{L}_2 + \lambda_3 \cdot \mathcal{L}_3)$.

This optimization procedure is chosen because it *improves the convergence of the model*
and *improves the adversarial robustness of the predictor network*. 
The [CounterNet paper](https://arxiv.org/abs/2109.07557) elaborates the design choices.


In [None]:
#| hide
#| export
class CounterNetConfigs(CounterNetTrainingModuleConfigs, CounterNetModelConfigs):
    """Configurator of `CounterNet`."""

    # Model architecture (see `CounterNetModelConfigs`), now with defaults.
    enc_sizes: List[int] = Field(
        [50, 10], description="Sequence of layer sizes for encoder network."
    )
    dec_sizes: List[int] = Field(
        [10], description="Sequence of layer sizes for predictor."
    )
    exp_sizes: List[int] = Field(
        [50, 50], description="Sequence of layer sizes for CF generator."
    )
    dropout_rate: float = Field(
        0.3, description="Dropout rate."
    )

    # Optimization (see `CounterNetTrainingModuleConfigs`).
    lr: float = Field(
        0.003, description="Learning rate for training `CounterNet`."
    )
    # Raw strings: "\l" and "\m" are invalid escape sequences in normal literals.
    lambda_1: float = Field(
        1.0, description=r"$\lambda_1$ for balancing the prediction loss $\mathcal{L}_1$."
    )
    lambda_2: float = Field(
        0.2, description=r"$\lambda_2$ for balancing the validity loss $\mathcal{L}_2$."
    )
    lambda_3: float = Field(
        0.1, description=r"$\lambda_3$ for balancing the cost-of-change loss $\mathcal{L}_3$."
    )


In [None]:
# Render the generated documentation table (fields, defaults, descriptions) for `CounterNetConfigs`.
show_doc_parser(CounterNetConfigs)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/methods/counternet.py#L260){target="_blank" style="float:right; font-size:smaller"}

### CounterNetConfigs

>      CounterNetConfigs (enc_sizes:List[int]=[50, 10],
>                         dec_sizes:List[int]=[10], exp_sizes:List[int]=[50,
>                         50], dropout_rate:float=0.3, lr:float=0.003,
>                         lambda_1:float=1.0, lambda_2:float=0.2,
>                         lambda_3:float=0.1)

Configurator of `CounterNet`.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| enc_sizes | List[int] | [50, 10] | Sequence of layer sizes for encoder network. |
| dec_sizes | List[int] | [10] | Sequence of layer sizes for predictor. |
| exp_sizes | List[int] | [50, 50] | Sequence of layer sizes for CF generator. |
| dropout_rate | float | 0.3 | Dropout rate. |
| lr | float | 0.003 | Learning rate for training `CounterNet`. |
| lambda_1 | float | 1.0 |  $\lambda_1$ for balancing the prediction loss $\mathcal{L}_1$. |
| lambda_2 | float | 0.2 |  $\lambda_2$ for balancing the prediction loss $\mathcal{L}_2$. |
| lambda_3 | float | 0.1 |  $\lambda_3$ for balancing the prediction loss $\mathcal{L}_3$. |

In [None]:
#| export
class CounterNet(BaseCFModule, BaseParametricCFModule, BasePredFnCFModule):
    """API for CounterNet Explanation Module.

    Wraps a `CounterNetTrainingModule`. `train` stores the learned parameters
    in `params`, which `generate_cfs` and `pred_fn` then use.
    """
    params: hk.Params = None  # None until the module has been trained
    module: CounterNetTrainingModule
    name: str = 'CounterNet'

    def __init__(
        self, 
        m_configs: dict | CounterNetConfigs = None # configurator of hyperparameters; see `CounterNetConfigs`
    ):
        if m_configs is None:
            m_configs = CounterNetConfigs()
        self.module = CounterNetTrainingModule(m_configs)

    def _is_module_trained(self) -> bool:
        """Whether parameters are available (i.e. `params` has been set)."""
        return self.params is not None

    def _require_trained(self) -> None:
        """Fail early with a clear message instead of deep inside the network."""
        if not self._is_module_trained():
            raise ValueError(
                "CounterNet has no trained parameters; call `train(...)` before "
                "generating CFs or predictions."
            )

    def train(
        self, 
        datamodule: TabularDataModule, # data module
        t_configs: TrainingConfigs | dict = None # training configs
    ):
        """Train CounterNet on `datamodule` and store the resulting parameters.

        Without `t_configs`, it trains for 100 epochs with batch size 128.
        """
        if t_configs is None:
            t_configs = dict(n_epochs=100, batch_size=128)
        params, _ = train_model(self.module, datamodule, t_configs)
        self.params = params

    def generate_cfs(self, X: jnp.ndarray, pred_fn = None) -> jnp.ndarray:
        """Generate counterfactuals for `X`.

        `pred_fn` is accepted for interface compatibility and ignored, because
        CF generation uses CounterNet's own network.
        """
        self._require_trained()
        # Fixed key: generation runs with `is_training=False`.
        return self.module.generate_cfs(X, self.params, rng_key=jax.random.PRNGKey(0))

    def pred_fn(self, X: jnp.ndarray):
        """Predict with CounterNet's own predictor head."""
        self._require_trained()
        rng_key = jax.random.PRNGKey(0)
        return self.module.predict(self.params, rng_key, X)


#### Basic usage of `CounterNet`

Prepare data:

In [None]:
from relax.data import load_data

In [None]:

# Load the "adult" dataset; `sample_frac=0.1` presumably keeps a 10% subsample.
dm = load_data("adult", data_configs={"sample_frac": 0.1})

Define `CounterNet`:

In [None]:
# Without arguments, CounterNet falls back to the default `CounterNetConfigs()`.
counternet = CounterNet()

In [None]:
# CounterNet must implement all three CF-module interfaces and expose `pred_fn`.
for _cf_base in (BaseParametricCFModule, BaseCFModule, BasePredFnCFModule):
    assert isinstance(counternet, _cf_base)
assert hasattr(counternet, 'pred_fn')

Train the model:

In [None]:
#| output: false
# Train for a single epoch with mini-batches of 128.
t_configs = {"n_epochs": 1, "batch_size": 128}
counternet.train(dm, t_configs=t_configs)

  "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."
Epoch 0: 100%|██████████| 191/191 [00:09<00:00, 20.36batch/s, train/train_loss_1=0.0595, train/train_loss_2=0.0545, train/train_loss_3=0.1]    


Predict labels:

In [None]:
X, y = dm.test_dataset[:]
y_pred = counternet.pred_fn(X)
# One prediction per test example, returned as an (N, 1) column.
assert y_pred.shape == (len(y), 1)

Generate CF explanations for the test inputs `X`:

In [None]:
X, _ = dm.test_dataset[:]
cfs = counternet.generate_cfs(X)
# Counterfactuals live in the same feature space as the inputs.
assert X.shape == cfs.shape