# Module

> Modules used for defining model architecture and training procedure, which are passed to `train_model`.

In [None]:
#| default_exp module

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.data import TabularDataModule
from relax.logger import TensorboardLogger
from relax.utils import validate_configs, sigmoid, accuracy, init_net_opt, grad_update, make_hk_module, show_doc as show_parser_doc
from fastcore.basics import patch
from functools import partial
from abc import ABC, abstractmethod
from copy import deepcopy

## Networks

Networks are [Haiku modules](https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules) (`hk.Module`), 
which define model architectures.

In [None]:
#| export
class BaseNetwork(ABC):
    """Base class for networks; calling one requires a keyword-only `is_training` flag."""

    def __call__(self, *, is_training: bool):
        # Placeholder body: does nothing and returns None.
        return None


In [None]:
#| export
#| hide
class DenseBlock(hk.Module):
    def __init__(
        self,
        output_size: int,  # Output dimensionality.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: str | None = None,  # Name of the Module
    ):
        """A `DenseBlock` consists of a dense layer, followed by Leaky Relu and a dropout layer."""
        super().__init__(name=name)
        self.output_size = output_size
        self.dropout_rate = dropout_rate

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        """Apply dense layer -> leaky ReLU -> dropout (dropout only when `is_training`).

        `is_training` must be a Python bool, since it selects control flow here.
        """
        # he_uniform: variance scaling 2.0 over fan-in with a uniform distribution.
        w_init = hk.initializers.VarianceScaling(2.0, "fan_in", "uniform")
        x = hk.Linear(self.output_size, w_init=w_init)(x)
        x = jax.nn.leaky_relu(x)
        if is_training:
            x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x)
        # In eval mode the previous code ran dropout with rate 0.0, which is the
        # identity; skipping it avoids requiring/consuming an RNG key for nothing.
        return x


In [None]:
# Render the API documentation for `DenseBlock.__init__` (output below).
show_doc(DenseBlock.__init__)

---

[source](https://github.com/birkhoffg/cfnet/blob/master/relax/module.py#L28){target="_blank" style="float:right; font-size:smaller"}

### DenseBlock.__init__

>      DenseBlock.__init__ (output_size:int, dropout_rate:float=0.3,
>                           name:Union[str,NoneType]=None)

A `DenseBlock` consists of a dense layer, followed by Leaky Relu and a dropout layer.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| output_size | int |  | Output dimensionality. |
| dropout_rate | float | 0.3 | Dropout rate. |
| name | str \| None | None | Name of the Module |

In [None]:
#| export
#| hide
class MLP(hk.Module):
    def __init__(
        self,
        sizes: Iterable[int],  # Sequence of layer sizes.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: str | None = None,  # Name of the Module
    ):
        """A `MLP` consists of a list of `DenseBlock` layers."""
        super().__init__(name=name)
        # Materialize `sizes`: a one-shot iterable (e.g. a generator) would
        # otherwise be exhausted by the first `__call__`, silently producing
        # an MLP with no layers on later calls. Copying also decouples us from
        # later mutation of the caller's list.
        self.sizes = list(sizes)
        self.dropout_rate = dropout_rate

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        """Apply one `DenseBlock` per entry of `sizes`, in order."""
        for size in self.sizes:
            x = DenseBlock(size, self.dropout_rate)(x, is_training)
        return x


In [None]:
# Render the API documentation for `MLP.__init__` (output below).
show_doc(MLP.__init__)

---

[source](https://github.com/birkhoffg/cfnet/blob/master/relax/module.py#L51){target="_blank" style="float:right; font-size:smaller"}

### MLP.__init__

>      MLP.__init__ (sizes:Iterable[int], dropout_rate:float=0.3,
>                    name:Union[str,NoneType]=None)

A `MLP` consists of a list of `DenseBlock` layers.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| sizes | Iterable[int] |  | Sequence of layer sizes. |
| dropout_rate | float | 0.3 | Dropout rate. |
| name | str \| None | None | Name of the Module |

## Predictive Model

In [None]:
#| exporti
class PredictiveModelConfigs(BaseParser):
    """Configurator of `PredictiveModel`."""

    # Declared with `Field(description=...)`, matching
    # `PredictiveTrainingModuleConfigs`, so the descriptions are part of the
    # parser schema rather than living only in source comments.
    sizes: List[int] = Field(description='Sequence of layer sizes.')  # required
    dropout_rate: float = Field(0.3, description='Dropout rate.')


In [None]:
#| export
#| hide
class PredictiveModel(hk.Module):
    def __init__(
        self,
        sizes: List[int], # Sequence of layer sizes.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: Optional[str] = None,  # Name of the module.
    ):
        """A basic predictive model for binary classification."""
        super().__init__(name=name)
        # Validate and keep the architecture settings in one config object.
        self.configs = PredictiveModelConfigs(sizes=sizes, dropout_rate=dropout_rate)

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        """Run the MLP body, a single-unit linear head, then a sigmoid."""
        cfg = self.configs
        body = MLP(sizes=cfg.sizes, dropout_rate=cfg.dropout_rate)
        hidden = body(x, is_training)
        logit = hk.Linear(1)(hidden)
        return jax.nn.sigmoid(logit)


In [None]:
# Render the API documentation for `PredictiveModel.__init__` (output below).
show_doc(PredictiveModel.__init__)

---

[source](https://github.com/birkhoffg/cfnet/blob/master/relax/module.py#L78){target="_blank" style="float:right; font-size:smaller"}

### PredictiveModel.__init__

>      PredictiveModel.__init__ (sizes:List[int], dropout_rate:float=0.3,
>                                name:Union[str,NoneType]=None)

A basic predictive model for binary classification.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| sizes | List[int] |  | Sequence of layer sizes. |
| dropout_rate | float | 0.3 | Dropout rate. |
| name | Optional[str] | None | Name of the module. |

Use `make_hk_module` to create a `haiku.Transformed` model.

In [None]:
from relax.utils import make_hk_module

In [None]:
# Build a `haiku.Transformed` model (exposing `init` / `apply`) from `PredictiveModel`.
net = make_hk_module(PredictiveModel, sizes=[50, 20, 10], dropout_rate=0.3)

Next, we generate some random input data.

In [None]:
# Seeded PRNG key stream; each `next(key)` yields a fresh key.
key = hk.PRNGSequence(42)
# Random normal inputs of shape (1000, 10).
xs = random.normal(next(key), (1000, 10))

We can then initialize the model:

In [None]:
# Initialize the parameters; input dimensions are taken from the sample batch `xs`.
params = net.init(next(key), xs, is_training=True)

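We can view the model's structure (its parameter shapes) by mapping over the parameter tree with `jax.tree_util.tree_map` (available as `jax.tree_map` in older JAX versions).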

In [None]:
# Map every parameter array to its shape. `jax.tree_map` is a deprecated alias
# of `jax.tree_util.tree_map` (removed in recent JAX releases).
jax.tree_util.tree_map(lambda x: x.shape, params)

{'predictive_model/linear': {'b': (1,), 'w': (10, 1)},
 'predictive_model/mlp/dense_block/linear': {'b': (50,), 'w': (10, 50)},
 'predictive_model/mlp/dense_block_1/linear': {'b': (20,), 'w': (50, 20)},
 'predictive_model/mlp/dense_block_2/linear': {'b': (10,), 'w': (20, 10)}}

The model output is produced by the `apply` function.

In [None]:
# Forward pass in training mode (dropout active), so an RNG key is passed.
y = net.apply(params, next(key), xs, is_training=True)

For more on using `haiku.Module`, please refer to the 
[Haiku documentation](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku-fundamentals).

## Training Modules API

In [None]:
#| hide
class BaseTrainingModule(ABC):
    """Documentation-only stand-in for the training-module interface."""

@patch(as_prop=True)
def logger(self: BaseTrainingModule) -> TensorboardLogger | None:
    """A logger property"""
    # Documentation stub only; the real attribute lives on the exported class.
    return None

@patch
def log(
    self: BaseTrainingModule,
    name: str,  # Name of the log
    value: Any,  # Value to log
) -> None:
    """Log `value` under `name`."""
    # Documentation stub only; the real method lives on the exported class.
    return None

In [None]:
#| export
class BaseTrainingModule(ABC):
    """Abstract base class for training modules passed to `train_model`.

    Subclasses implement `init_net_opt`, `training_step` and `validation_step`.
    This base class stores hyperparameters and forwards metrics to an attached
    `TensorboardLogger`.
    """

    hparams: Dict[str, Any]
    # Defaults to None so that logging before `init_logger` raises the intended
    # ValueError in `log_dict` rather than an AttributeError.
    logger: TensorboardLogger | None = None

    def save_hyperparameters(self, configs: Dict[str, Any]) -> Dict[str, Any]:
        """Store a deep copy of `configs` as `self.hparams` and return it."""
        # Deep copy so later mutation of the caller's configs does not leak in.
        self.hparams = deepcopy(configs)
        return self.hparams

    def init_logger(self, logger: TensorboardLogger):
        """Attach the logger used by `log` and `log_dict`."""
        self.logger = logger

    def log(self, name: str, value: Any):
        """Log a single `value` under `name`."""
        self.log_dict({name: value})

    def log_dict(self, dictionary: Dict[str, Any]):
        """Forward `dictionary` to the attached logger's `log_dict`.

        Raises:
            ValueError: If no logger has been attached via `init_logger`.
        """
        if self.logger is None:
            raise ValueError("Logger has not been initialized.")
        self.logger.log_dict(dictionary)

    @abstractmethod
    def init_net_opt(
        self, data_module: TabularDataModule, key: random.PRNGKey
    ) -> Tuple[hk.Params, optax.OptState]:
        """Return the initial network parameters and optimizer state."""

    @abstractmethod
    def training_step(
        self,
        params: hk.Params,
        opt_state: optax.OptState,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Tuple[hk.Params, optax.OptState]:
        """Run one optimization step on `batch`; return updated params and optimizer state."""

    @abstractmethod
    def validation_step(
        self,
        params: hk.Params,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Dict[str, Any]:
        """Evaluate on `batch` and return a dict of metrics."""


## Predictive Training Module

In [None]:
#| export
#| hide
class PredictiveTrainingModuleConfigs(BaseParser):
    """Configurator of `PredictiveTrainingModule`."""
    # Learning rate for the Adam optimizer (required).
    lr: float = Field(description='Learning rate.')
    # Hidden layer sizes forwarded to `PredictiveModel` (required).
    sizes: List[int] = Field(description='Sequence of layer sizes.')
    # Dropout rate forwarded to `PredictiveModel`.
    dropout_rate: float = Field(0.3, description='Dropout rate') 

In [None]:
#| export
class PredictiveTrainingModule(BaseTrainingModule):
    """Trains a `PredictiveModel` with Adam on an L2 loss; reports loss and accuracy."""

    def __init__(self, m_configs: Dict | PredictiveTrainingModuleConfigs):
        # Keep a copy of the raw configs, then validate them.
        self.save_hyperparameters(m_configs)
        self.configs = validate_configs(m_configs, PredictiveTrainingModuleConfigs)
        self.net = make_hk_module(
            PredictiveModel,
            sizes=self.configs.sizes,
            dropout_rate=self.configs.dropout_rate,
        )
        self.opt = optax.adam(learning_rate=self.configs.lr)

    @partial(jax.jit, static_argnames=["self", "is_training"])
    def forward(self, params, rng_key, x, is_training: bool = True):
        """Jitted forward pass of the network."""
        return self.net.apply(params, rng_key, x, is_training=is_training)

    def init_net_opt(self, data_module, key):
        """Initialize network params and optimizer state from a slice of training data."""
        # Only the first 100 training rows are passed; presumably `init_net_opt`
        # just needs them to infer input shapes — TODO confirm.
        X, _ = data_module.train_dataset[:100]
        params, opt_state = init_net_opt(self.net, self.opt, X=X, key=key)
        return params, opt_state

    @staticmethod
    def _l2_loss(y_pred, y):
        """Mean optax L2 loss; `vmap` pairs row i of `y_pred` with `y[i]`."""
        # vmap (rather than plain broadcasting) keeps a (batch, 1) prediction
        # from broadcasting against a (batch,) label into a (batch, batch) grid.
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    def loss_fn(self, params, rng_key, batch, is_training: bool = True):
        """L2 loss of the network's predictions on `batch`."""
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=is_training)
        return self._l2_loss(y_pred, y)

    @partial(jax.jit, static_argnames=["self"])
    def _training_step(self, params, opt_state, rng_key, batch):
        """Jitted gradient computation and optimizer update."""
        grads = jax.grad(self.loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt)
        return upt_params, opt_state

    def training_step(self, params, opt_state, rng_key, batch):
        """Update params on `batch`, log the training loss, return new params/opt state."""
        params, opt_state = self._training_step(params, opt_state, rng_key, batch)
        # NOTE: the logged loss uses the *updated* params, with dropout active
        # and the same rng key that was used for the update.
        loss = self.loss_fn(params, rng_key, batch)
        self.log_dict({"train/train_loss_1": loss.item()})
        return params, opt_state

    def validation_step(self, params, rng_key, batch):
        """Evaluate on `batch` with dropout disabled; log and return the metrics."""
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=False)
        # Reuse these predictions: `loss_fn(..., is_training=False)` would run
        # exactly the same `apply` call a second time.
        loss = self._l2_loss(y_pred, y)
        logs = {"val/val_loss": loss.item(), "val/val_accuracy": accuracy(y, y_pred)}
        self.log_dict(logs)
        # Honour the `BaseTrainingModule.validation_step -> Dict[str, Any]` contract.
        return logs
