# Module

> Modules used for defining model architecture and training procedure, which are passed to `train_model`.

In [87]:
#| default_exp module

In [88]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [89]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.data import TabularDataModule, load_data
from relax.logger import TensorboardLogger
from relax.utils import validate_configs, sigmoid, accuracy, init_net_opt, grad_update, make_hk_module, show_doc as show_parser_doc, load_json
from relax.trainer import train_model
from fastcore.basics import patch
from functools import partial
from abc import ABC, abstractmethod
from copy import deepcopy
from relax._ckpt_manager import load_checkpoint, save_checkpoint

## Networks

Networks are [haiku.module](https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules), 
which define model architectures.

In [90]:
#| export
class BaseNetwork(ABC):
    """BaseNetwork needs a `is_training` argument"""

    def __call__(self, *, is_training: bool):
        pass


In [91]:
#| export
class DenseBlock(hk.Module):
    """A `DenseBlock` consists of a dense layer, followed by Leaky Relu and a dropout layer."""
    
    def __init__(
        self,
        output_size: int,  # Output dimensionality.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: str | None = None,  # Name of the Module
    ):
        super().__init__(name=name)
        self.output_size = output_size
        self.dropout_rate = dropout_rate

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        dropout_rate = self.dropout_rate if is_training else 0.0
        # he_uniform
        w_init = hk.initializers.VarianceScaling(2.0, "fan_in", "uniform")
        x = hk.Linear(self.output_size, w_init=w_init)(x)
        x = jax.nn.leaky_relu(x)
        x = hk.dropout(hk.next_rng_key(), dropout_rate, x)
        return x


In [92]:
#| export
class MLP(hk.Module):
    """A `MLP` consists of a list of `DenseBlock` layers."""
    
    def __init__(
        self,
        sizes: Iterable[int],  # Sequence of layer sizes.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: str | None = None,  # Name of the Module
    ):
        super().__init__(name=name)
        self.sizes = sizes
        self.dropout_rate = dropout_rate

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        for size in self.sizes:
            x = DenseBlock(size, self.dropout_rate)(x, is_training)
        return x


## Predictive Model

In [93]:
#| exporti
class PredictiveModelConfigs(BaseParser):
    """Configurator of `PredictiveModel`."""

    sizes: List[int]  # Sequence of layer sizes.
    dropout_rate: float = 0.3  # Dropout rate.


In [94]:
#| export
class PredictiveModel(hk.Module):
    """A basic predictive model for binary classification."""
    
    def __init__(
        self,
        sizes: List[int], # Sequence of layer sizes.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: Optional[str] = None,  # Name of the module.
    ):
        """A basic predictive model for binary classification."""
        super().__init__(name=name)
        self.configs = PredictiveModelConfigs(
            sizes=sizes, dropout_rate=dropout_rate
        )

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        x = MLP(sizes=self.configs.sizes, dropout_rate=self.configs.dropout_rate)(
            x, is_training
        )
        x = hk.Linear(1)(x)
        x = jax.nn.sigmoid(x)
        # x = sigmoid(x)
        return x


Use `make_hk_module` to create a `haiku.Transformed` model.

In [95]:
from relax.utils import make_hk_module

In [96]:
net = make_hk_module(PredictiveModel, sizes=[50, 20, 10], dropout_rate=0.3)

We make some random data.

In [97]:
key = hk.PRNGSequence(42)
xs = random.normal(next(key), (1000, 10))

We can then initalize the model

In [98]:
params = net.init(next(key), xs, is_training=True)

We can view model's structure via `jax.tree_map`.

In [99]:
jax.tree_map(lambda x: x.shape, params)

{'predictive_model/linear': {'b': (1,), 'w': (10, 1)},
 'predictive_model/mlp/dense_block/linear': {'b': (50,), 'w': (10, 50)},
 'predictive_model/mlp/dense_block_1/linear': {'b': (20,), 'w': (50, 20)},
 'predictive_model/mlp/dense_block_2/linear': {'b': (10,), 'w': (20, 10)}}

Model output is produced via `apply` function.

In [100]:
y = net.apply(params, next(key), xs, is_training=True)

For more usage of `haiku.module`, please refer to 
[Haiku documentation](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku-fundamentals).

## Training Modules API

In [101]:
#| hide
class BaseTrainingModule(ABC):
    pass

@patch(as_prop=True)
def logger(
    self:BaseTrainingModule
) -> TensorboardLogger | None:
    """A logger property"""
    pass

@patch
def log(self:BaseTrainingModule, 
        name: str, # Name of the log
        value: Any # value
    ) -> None:
    pass

In [102]:
#| export
class BaseTrainingModule(ABC):
    hparams: Dict[str, Any]
    logger: TensorboardLogger | None

    def save_hyperparameters(self, configs: Dict[str, Any]) -> Dict[str, Any]:
        self.hparams = deepcopy(configs)
        return self.hparams

    def init_logger(self, logger: TensorboardLogger):
        self.logger = logger

    def log(self, name: str, value: Any):
        self.log_dict({name: value})

    def log_dict(self, dictionary: Dict[str, Any]):
        if self.logger:
            # self.logger.log({k: np.asarray(v) for k, v in dictionary.items()})
            self.logger.log_dict(dictionary)
        else:
            raise ValueError("Logger has not been initliazed.")

    @abstractmethod
    def init_net_opt(
        self, data_module: TabularDataModule, key: random.PRNGKey
    ) -> Tuple[hk.Params, optax.OptState]:
        pass

    @abstractmethod
    def training_step(
        self,
        params: hk.Params,
        opt_state: optax.OptState,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Tuple[hk.Params, optax.OptState]:
        pass

    @abstractmethod
    def validation_step(
        self,
        params: hk.Params,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Dict[str, Any]:
        pass


## Predictive Training Module

In [103]:
#| export
class PredictiveTrainingModuleConfigs(BaseParser):
    """Configurator of `PredictiveTrainingModule`."""
    
    lr: float = Field(description='Learning rate.')
    sizes: List[int] = Field(description='Sequence of layer sizes.')
    dropout_rate: float = Field(0.3, description='Dropout rate') 

In [104]:
#| export
class PredictiveTrainingModule(BaseTrainingModule):
    """A training module for predictive models."""
    
    def __init__(self, m_configs: Dict | PredictiveTrainingModuleConfigs):
        self.save_hyperparameters(m_configs)
        self.configs = validate_configs(m_configs, PredictiveTrainingModuleConfigs)
        self.net = make_hk_module(
            PredictiveModel, 
            sizes=self.configs.sizes, 
            dropout_rate=self.configs.dropout_rate
        )
        self.opt = optax.adam(learning_rate=self.configs.lr)

    @partial(jax.jit, static_argnames=["self", "is_training"])
    def forward(self, params, rng_key, x, is_training: bool = True):
        return self.net.apply(params, rng_key, x, is_training=is_training)
    
    def pred_fn(self, x, params, rng_key):
        return self.forward(params, rng_key, x, is_training=False)

    def init_net_opt(self, data_module, key):
        X, _ = data_module.train_dataset[:100]
        params, opt_state = init_net_opt(
            self.net, self.opt, X=X, key=key
        )
        return params, opt_state

    @partial(jax.jit, static_argnames=["self", "is_training"])
    def loss_fn(self, params, rng_key, batch, is_training: bool = True):
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=is_training)
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    # def _training_step(self, params, opt_state, rng_key, batch):
    #     grads = jax.grad(self.loss_fn)(params, rng_key, batch)
    #     upt_params, opt_state = grad_update(grads, params, opt_state, self.opt)
    #     return upt_params, opt_state

    @partial(jax.jit, static_argnames=["self"])
    def _training_step(self, params, opt_state, rng_key, batch):
        loss, grads = jax.value_and_grad(self.loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt)
        return upt_params, opt_state, loss

    def training_step(self, params, opt_state, rng_key, batch):
        params, opt_state, loss = self._training_step(params, opt_state, rng_key, batch)
        self.log_dict({"train/train_loss_1": loss.item()})
        return params, opt_state

    def validation_step(self, params, rng_key, batch):
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=False)
        loss = self.loss_fn(params, rng_key, batch, is_training=False)
        logs = {"val/val_loss": loss.item(), "val/val_accuracy": accuracy(y, y_pred)}
        self.log_dict(logs)


In [105]:
# export
def load_pred_model(data_name: str) -> Tuple[hk.Params, PredictiveTrainingModule]:

    # Fetch the sizes and lr from the configs file
    data_dir = Path(os.getcwd()) / "cf_data" / data_name 
    mlp_configs = load_json(data_dir / "configs.json" )['mlp_configs']
    sizes = mlp_configs["sizes"]
    lr = mlp_configs["lr"]

    module = PredictiveTrainingModule({'sizes': sizes, 'lr': lr})
    param = load_checkpoint(data_dir / "model")
    return (param, module)

# Pretrain model

In [106]:
# slow
import shutil

DATASET_NAMES = ["adult","credit","heloc","oulad","student_performance","titanic","german","cancer","spam", "ozone", "qsar", "bioresponse", "churn", "road"]

for data_name in DATASET_NAMES:
    datamodule, data_configs = load_data(data_name = data_name, return_config=True)

    # Fetch the sizes and lr from the configs file
    data_dir = Path(os.getcwd()) / "cf_data" / data_name / "configs.json"
    print(data_dir)
    mlp_configs = load_json(data_dir)['mlp_configs']
    sizes = mlp_configs["sizes"]
    lr = mlp_configs["lr"]
    batch_size = load_json(data_dir)["data_configs"]['batch_size']

    params, opt_state = train_model(
        PredictiveTrainingModule({'sizes': sizes, 'lr': lr}),
        datamodule, t_configs={
            'n_epochs': 10, 'batch_size': batch_size, 'monitor_metrics': 'val/val_loss',
            'max_n_checkpoints': 1, 'logger_name': data_name
        }
    )

    # get the most recent version and the best epoch stored in the version
    version_dir = "log/{data_name}/".format(data_name = data_name) # Obtain the all version
    latest_version = max([os.path.join(version_dir,v) for v in os.listdir(version_dir) if v.startswith("version_")], key=os.path.getmtime)
    epoch = [d for d in os.listdir(latest_version + "/checkpoints/".format(data_name = data_name, version = latest_version)) if d.startswith("epoch")][0] # Obtain the epoch value
    model_dir = latest_version + "/checkpoints/{epoch}/model".format(epoch = epoch)

    # update model to the assets
    shutil.rmtree("assets/{data_name}/model".format(data_name=data_name), ignore_errors=True)
    shutil.copytree(model_dir, "assets/{data_name}/model".format(data_name = data_name))

    # test: save model under cf_data
    shutil.rmtree("cf_data/{data_name}/model".format(data_name=data_name), ignore_errors=True)
    shutil.copytree(model_dir, "cf_data/{data_name}/model".format(data_name = data_name))



/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/adult/configs.json


Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 394.19batch/s, train/train_loss_1=0.0706]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/credit/configs.json


Epoch 9: 100%|██████████| 88/88 [00:00<00:00, 395.03batch/s, train/train_loss_1=0.087] 


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/heloc/configs.json


Epoch 9: 100%|██████████| 31/31 [00:00<00:00, 344.25batch/s, train/train_loss_1=0.117]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/oulad/configs.json


Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 378.58batch/s, train/train_loss_1=0.0331]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/student_performance/configs.json


Epoch 9: 100%|██████████| 16/16 [00:00<00:00, 757.64batch/s, train/train_loss_1=0.129]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/titanic/configs.json


Epoch 9: 100%|██████████| 11/11 [00:00<00:00, 620.70batch/s, train/train_loss_1=0.0687]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/german/configs.json


Epoch 9: 100%|██████████| 12/12 [00:00<00:00, 647.53batch/s, train/train_loss_1=0.057]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/cancer/configs.json


Epoch 9: 100%|██████████| 14/14 [00:00<00:00, 734.22batch/s, train/train_loss_1=0.0509]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/spam/configs.json


Epoch 9: 100%|██████████| 14/14 [00:00<00:00, 412.38batch/s, train/train_loss_1=0.0338]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/ozone/configs.json


Epoch 9: 100%|██████████| 8/8 [00:00<00:00, 406.84batch/s, train/train_loss_1=0.0232]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/qsar/configs.json


Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 324.92batch/s, train/train_loss_1=0.1] 


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/bioresponse/configs.json


Epoch 9: 100%|██████████| 11/11 [00:00<00:00, 260.42batch/s, train/train_loss_1=0.0556]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/churn/configs.json


Epoch 9: 100%|██████████| 21/21 [00:00<00:00, 423.40batch/s, train/train_loss_1=0.0746]


/Users/chuck/PycharmProjects/temp/ReLax/nbs/cf_data/road/configs.json


Epoch 9: 100%|██████████| 655/655 [00:01<00:00, 471.80batch/s, train/train_loss_1=0.0668]


# Test

In [107]:
from sklearn.metrics import accuracy_score
log = {"name":[], "accuracy":[],"rfc accuracy":[]}
for data_name in DATASET_NAMES:
    datamodule = load_data(data_name = data_name)
    params, module = load_pred_model(data_name)
    x,y_true = datamodule.test_dataset[:]
    y_pred = module.pred_fn(x = x, params = params, rng_key = random.PRNGKey(0))
    assert y_pred.shape == (x.shape[0],1)

    # calculate accuracy
    y_pred = y_pred > 0.5
    y_pred = y_pred.astype(int)
    accuracy = accuracy_score(y_true,y_pred)

    log["name"].append(data_name)
    log["accuracy"].append(accuracy)




In [108]:
from sklearn.ensemble import RandomForestClassifier

log["rfc accuracy"] = []
for data_name in DATASET_NAMES:
    rfc = RandomForestClassifier(random_state=0)
    datamodule = load_data(data_name = data_name)
    X_train, y_train = datamodule.train_dataset[:]
    rfc.fit(X_train, y_train)
    X_test, y_test = datamodule.test_dataset[:]
    y_pred = rfc.predict(X_test)

    # calculate accuracy
    y_pred = y_pred > 0.5
    y_pred = y_pred.astype(int)
    accuracy = accuracy_score(y_test,y_pred)
    log["rfc accuracy"].append(accuracy)

pd.DataFrame.from_dict(log)



  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)


Unnamed: 0,name,accuracy,rfc accuracy
0,adult,0.8241,0.806166
1,credit,0.8132,0.813467
2,heloc,0.702868,0.719312
3,oulad,0.926739,0.940361
4,student_performance,0.90184,0.920245
5,titanic,0.816143,0.802691
6,german,0.756,0.756
7,cancer,0.909091,0.916084
8,spam,0.93397,0.943527
9,ozone,0.933754,0.949527
