# Advanced customized model base

Some low-level methods are provided in `AbstractModel` and `TorchModel`. They offer more flexibility for customizing training and testing.

## Advanced customizations of `AbstractModel`

Assume that a model base `TabNetFromAbstractInherited` is built upon `TabNetFromAbstract` introduced in "Customized model base".

```python
class TabNetFromAbstractInherited(TabNetFromAbstract):
```

### Training parameters

`_custom_training_params` returns a dictionary containing items that override settings in the configuration file for the model base. For example:

```python
    def _custom_training_params(self, model_name) -> Dict:
        return {"epoch": 100}
```
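The override semantics can be pictured as a plain dictionary merge. The following is a minimal sketch of that idea, not the library's actual implementation: `resolve_training_params` and `config_params` are hypothetical names, and the configuration values are made up for illustration.

```python
from typing import Dict


def resolve_training_params(config_params: Dict, custom_params: Dict) -> Dict:
    """Hypothetical helper: entries in custom_params take precedence."""
    merged = dict(config_params)  # copy, so the configuration itself is untouched
    merged.update(custom_params)  # custom items win on key collisions
    return merged


# Illustrative values only; not taken from a real configuration file.
config_params = {"epoch": 300, "patience": 50}
effective = resolve_training_params(config_params, {"epoch": 100})
```

Under this assumption, only the keys returned by `_custom_training_params` change, and every other setting keeps the value from the configuration.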

### Bayesian optimization criterion

During Bayesian hyperparameter optimization, the objective might be the validation loss, the training loss, or something else. By default, the larger of the validation loss and the training loss is returned (the former is usually higher, but randomness may occasionally make the latter higher). For example, the following code returns the validation loss:

```python
    def _bayes_eval(self, model, X_train, y_train, X_val, y_val):
        y_val_pred = self._pred_single_model(model, X_val, verbose=False)
        _, val_loss = self._default_metric_sklearn(y_val, y_val_pred)
        return val_loss
```

where `_default_metric_sklearn` returns MSE loss for regression tasks and log loss for classification tasks.
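To make the default criterion concrete, the following standalone sketch computes the MSE on both splits of a regression task and returns the larger one. It only illustrates the behavior described above: `mse` and `default_bayes_objective` are hypothetical helpers, not tabensemb functions, and the numbers are invented.

```python
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error of two equally long sequences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def default_bayes_objective(y_train, y_train_pred, y_val, y_val_pred) -> float:
    """Hypothetical stand-in: the larger of the training and validation losses."""
    return max(mse(y_train, y_train_pred), mse(y_val, y_val_pred))


objective = default_bayes_objective(
    y_train=[1.0, 2.0], y_train_pred=[1.0, 2.5],  # training MSE = 0.125
    y_val=[3.0, 4.0], y_val_pred=[2.0, 4.0],      # validation MSE = 0.5
)
```

Here the validation loss is the larger value, so it becomes the objective. If the training loss were larger, it would be returned instead.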

### Validity of a model

`_conditional_validity` is used to check the validity of a model under certain circumstances. For example, some models might be invalid if a certain feature `A_FEATURE` is not provided:

```python
    def _conditional_validity(self, model_name: str) -> bool:
        if model_name == "SOME_MODEL" and "A_FEATURE" not in self.trainer.cont_feature_names:
            return False
        # All other cases are valid; without this line the method would return None.
        return True
```

**Remark**: We do not recommend modifying other methods in `AbstractModel` except for those introduced in this part and in "Customized model base" unless you know what you are doing.

## Advanced customizations of `TorchModel`

The above customizations of `AbstractModel` can also be applied to `TorchModel`. `TorchModel` works within a narrower framework but exposes more APIs for flexibility. Some customizations are provided in `AbstractNN` at a lower and more specific level.

```python
class TabNetFromTorchInherited(TabNetFromTorch):
```

### Customized data processing

In `TorchModel._train_data_preprocess`, a model base processes tabular or multimodal datasets for its own use. The method `_prepare_custom_datamodule` is called at the beginning and should return a `DataModule` instance (`self.trainer.datamodule` by default), which is used to generate the final datasets (`torch.utils.data.Dataset` instances) and provides other information. For example, the following code builds a `DataModule` that additionally records unscaled data as an item of derived data (multimodal data) by using `UnscaledDataDeriver`:

```python
    def _prepare_custom_datamodule(self, model_name):
        from tabensemb.data import DataModule

        base = self.trainer.datamodule
        datamodule = DataModule(config=self.trainer.datamodule.args, initialize=False)
        datamodule.set_data_imputer("MeanImputer")
        datamodule.set_data_derivers(
            [("UnscaledDataDeriver", {"derived_name": "Unscaled"})]
        )
        datamodule.set_data_processors(
            [("CategoricalOrdinalEncoder", {}), ("StandardScaler", {})]
        )
        datamodule.set_data(
            base.df,
            cont_feature_names=base.cont_feature_names,
            cat_feature_names=base.cat_feature_names,
            label_name=base.label_name,
            train_indices=base.train_indices,
            val_indices=base.val_indices,
            test_indices=base.test_indices,
            verbose=False,
        )
        tmp_derived_data = base.derived_data.copy()
        tmp_derived_data.update(datamodule.derived_data)
        datamodule.derived_data = tmp_derived_data
        self.datamodule = datamodule
        return datamodule
```

In `TorchModel._data_preprocess`, `_run_custom_data_module` is called first to transform the incoming data into a consistent form. A common implementation is as follows:

```python
    def _run_custom_data_module(self, df, derived_data, model_name):
        df, my_derived_data = self.datamodule.prepare_new_data(df, ignore_absence=True)
        derived_data = derived_data.copy()
        derived_data.update(my_derived_data)
        derived_data = self.datamodule.sort_derived_data(derived_data)
        return df, derived_data, self.datamodule
```

### Output normalization

This functionality is provided in `AbstractNN`. Different normalizations are used for different tasks: `torch.nn.Identity()` for regression, so that the output is left unchanged; `torch.nn.Softmax(dim=-1)` for multi-class classification and `torch.nn.Sigmoid()` for binary classification, to calculate probabilities from logits. The behavior can be changed by overriding `output_norm`. For example, a model with the following code always returns non-negative predictions:

```python
class TabNetNNInherited(TabNetNN):
    def output_norm(self, y_pred):
        return torch.abs(y_pred)
```

**Remark**: Normalization is not related to the calculation of the loss function.
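The three default normalizations can be illustrated without PyTorch. The pure-Python sketch below mirrors what identity, sigmoid, and softmax do to a list of logits. `normalize_output` is a hypothetical helper, and the task name `"regression"` is an assumption (only `"binary"` and `"multiclass"` appear in this document).

```python
import math
from typing import List


def normalize_output(logits: List[float], task: str) -> List[float]:
    """Hypothetical pure-Python analogue of the default normalizations."""
    if task == "regression":  # assumed task name
        return list(logits)  # identity: the output is left unchanged
    if task == "binary":
        return [1.0 / (1.0 + math.exp(-z)) for z in logits]  # element-wise sigmoid
    if task == "multiclass":
        shift = max(logits)  # subtract the maximum for numerical stability
        exps = [math.exp(z - shift) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]  # softmax over the class dimension
    raise ValueError(f"Unknown task: {task}")


binary_prob = normalize_output([0.0], "binary")
class_probs = normalize_output([1.0, 2.0, 3.0], "multiclass")
```

A logit of zero maps to a probability of 0.5 in the binary case, and the multi-class probabilities sum to one while preserving the ordering of the logits.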

### Loss function

The functionality is provided in `AbstractNN`. By default, `torch.nn.BCEWithLogitsLoss()` is used for binary classification; `torch.nn.CrossEntropyLoss()` is used for multi-class classification; `torch.nn.MSELoss()` (`loss=="mse"`) or `torch.nn.L1Loss()` (`loss=="mae"`) is used for regression. For example, a model with the following code uses `torch.nn.SmoothL1Loss`:

```python
class TabNetNNInherited(TabNetNN):
    def loss_fn(self, y_pred, y_true, *data, **kwargs):
        return torch.nn.SmoothL1Loss()(y_pred, y_true)
```
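For reference, `torch.nn.SmoothL1Loss` with its default `beta=1.0` and mean reduction is quadratic for small residuals and linear for large ones. The pure-Python sketch below reproduces that formula for plain lists of numbers. `smooth_l1` is a hypothetical illustration, not a function of tabensemb or PyTorch.

```python
from typing import Sequence


def smooth_l1(y_pred: Sequence[float], y_true: Sequence[float], beta: float = 1.0) -> float:
    """Mean smooth L1 loss: quadratic below beta, linear above it."""
    total = 0.0
    for p, t in zip(y_pred, y_true):
        diff = abs(p - t)
        if diff < beta:
            total += 0.5 * diff ** 2 / beta  # quadratic branch
        else:
            total += diff - 0.5 * beta       # linear branch
    return total / len(y_pred)


small_residual = smooth_l1([0.5], [0.0])  # quadratic branch is used
large_residual = smooth_l1([3.0], [0.0])  # linear branch is used
```

Compared with MSE, the linear branch makes the loss less sensitive to outliers, which is a common reason to choose it for regression.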

`before_loss_fn` is called before `loss_fn` to transform the output (from `forward`) and the target into the desired format. For the default `loss_fn` (`self.default_loss_fn`, returned by `AbstractNN.get_loss_fn`), a common implementation of `before_loss_fn` that meets the requirements of `torch.nn.BCEWithLogitsLoss()` and `torch.nn.CrossEntropyLoss()` is as follows:

```python
class TabNetNNInherited(TabNetNN):
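    def before_loss_fn(self, y, yhat):
        if self.task == "binary":
            # BCEWithLogitsLoss expects both tensors to have the same shape.
            y = torch.flatten(y)
            yhat = torch.flatten(yhat)
        elif self.task == "multiclass":
            # CrossEntropyLoss expects 1-D integer class indices as its target.
            yhat = torch.flatten(yhat).long()
        return y, yhat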
```

### `pytorch_lightning` functionalities

`AbstractNN` is based on `pytorch_lightning.LightningModule`, so most methods of `LightningModule` can be used directly in `AbstractNN`. Note that some of those methods are already implemented, such as `training_step`, `validation_step`, and `configure_optimizers`. Others, such as `on_train_start` and `on_train_epoch_end`, are called automatically by `pytorch_lightning.Trainer`. See the official [documentation](https://lightning.ai/docs/pytorch/stable/) for advanced usage.

### Backward propagation

This functionality is provided in `AbstractNN`. Backward propagation and optimization are performed using the loss value returned by `loss_fn` (or attributes registered while calling `loss_fn`) and the optimizers returned by `configure_optimizers`. The default implementation, which uses only one optimizer and one loss item, is as follows:

```python
class TabNetNNInherited(TabNetNN):
    def cal_backward_step(self, loss):
        self.manual_backward(loss)
        opt = self.optimizers()
        opt.step()
```

`self.manual_backward` should be used instead of `loss.backward`, as required by `LightningModule`.

### Early stopping criterion

The functionality is provided in `AbstractNN`. Early stopping is used to reduce over-fitting risks. `_early_stopping_eval` returns the monitored value of early stopping. By default, the validation loss is returned:

```python
class TabNetNNInherited(TabNetNN):
    def _early_stopping_eval(self, train_loss: float, val_loss: float) -> float:
        return val_loss + 0.0 * train_loss
```

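The second term leaves a finite validation loss unchanged but turns the returned value into NaN whenever the training loss is NaN, so that a NaN training loss can be identified from the monitored value.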
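The effect of the `0.0 * train_loss` term follows from floating-point arithmetic and can be checked with plain Python floats. The sketch below uses a free function, `early_stopping_value`, which is a hypothetical name for the same expression.

```python
import math


def early_stopping_value(train_loss: float, val_loss: float) -> float:
    """Same expression as the default criterion, written as a free function."""
    return val_loss + 0.0 * train_loss


finite_case = early_stopping_value(train_loss=0.7, val_loss=0.3)
nan_case = early_stopping_value(train_loss=float("nan"), val_loss=0.3)
inf_case = early_stopping_value(train_loss=float("inf"), val_loss=0.3)

print(finite_case, math.isnan(nan_case), math.isnan(inf_case))
```

A finite training loss contributes exactly zero, while a NaN or infinite training loss makes the monitored value NaN, since multiplying NaN or infinity by zero yields NaN.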