# Modular coding when creating machine learning models

## Why does modularity matter?

A typical machine learning model training script looks something like this:

1. Start the training script from the command line
2. Set run configuration based on command line flags or configuration files
3. Specify data sources for training and validation
4. Construct model based on configuration
5. Feed data to model during training
6. Log training information and checkpoints
7. Save model once a stopping condition is reached

```mermaid
flowchart TD
    CL[Start the training script from the command line] --> C[Set run configuration based on command line flags or configuration files]   
    C --> D[Specify data sources for training and validation];
    D --> M[Construct model based on configuration];
    M --> T[Feed data to model during training];
    T --> L[Log training information and checkpoints];
    L --> T;
    L --> O[Save model once a stopping condition is reached]; 
```
**Figure 1: Typical machine learning training script**

You could write the whole procedure as a single script that reads from top to bottom, and many people do.

However, this can cause serious problems with reproducibility:

- Code gets copy-pasted across multiple scripts (command line interfaces (CLI), configuration handling, and training, logging and checkpointing code)
- Code is not shared across experiments, so fixing a bug in one experiment might leave the same bug in place in another
- Combining multiple datasets is difficult because there is no clear definition of what kind of data your model expects
- Data loading, model creation and training code can get mixed together, which can cause problems with the frameworks
- It is hard to keep track of different experiments

## Modularity is used by every framework out there

To solve these problems, the major frameworks provide tools that make it easy to write code in a modular form. In fact, they expect you to use these tools.

Typically, each framework defines its own version of the following components:

1. A dataset specification that defines how to load a single data point from the stored files
2. A data loader that creates batches out of the dataset
3. A model specification that defines the layers of the model and the forward pass
4. Optimizers that specify how the model weights should be updated
5. A model trainer that trains the model with a chosen dataset and optimizer

Many frameworks build on existing components such as PyTorch's [torch.nn.Module](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html), so they extend the functionality of those modules rather than replace them.

Here are some examples of these modules:

| Framework | Dataset specification | Data loader | Model specification | Optimizers | Model trainer | 
| --------- | --------------------- | ----------- | ------------------- | ---------- | ------------- |
| [PyTorch](https://docs.pytorch.org/tutorials/beginner/nn_tutorial.html#what-is-torch-nn-really) | [torch.utils.data.Dataset](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) | [torch.utils.data.DataLoader](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | [torch.nn.Module](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html) | [torch.optim](https://docs.pytorch.org/docs/stable/optim.html) |  - |
| [HuggingFace](https://huggingface.co/docs) | [Datasets](https://huggingface.co/docs/datasets/main/en/create_dataset) | [Datasets](https://huggingface.co/docs/datasets/main/en/create_dataset) with [torch.utils.data.DataLoader](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | [transformers.PretrainedConfig](https://huggingface.co/docs/transformers/main/en/custom_models#configuration) & [transformers.PreTrainedModel](https://huggingface.co/docs/transformers/main/en/custom_models#model) (see also [huggingface_hub.PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)) | Provided by [Trainer](https://huggingface.co/docs/transformers/main/en/optimizers) | [transformers.Trainer](https://huggingface.co/docs/transformers/trainer) |
| [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) | [lightning.LightningDataModule](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html#lightning.pytorch.core.LightningDataModule) wraps [torch.utils.data.Dataset](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) | [lightning.LightningDataModule](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html#lightning.pytorch.core.LightningDataModule) wraps [torch.utils.data.DataLoader](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | [lightning.LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) wraps [torch.nn.Module](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html) | - | [lightning.pytorch.Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) |

The idea behind these modules is to create a clear separation between different responsibilities.

### Dataset & data loader modularity

```mermaid
flowchart LR  
    D1[Dataset 1] --> DL[Data loader];
    D2[Dataset 2] --> DL;
    D3[Dataset ...] --> DL;
    DL --> B[Batch];
```
**Having a dataset that provides data in a consistent way makes it possible to switch data sources**
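As a minimal sketch of the diagram above, the following PyTorch example defines two datasets that return samples in the same `(features, target)` format, so the same data loader code works with either of them. The class names `SineDataset` and `LineDataset` and the helper `make_loader` are hypothetical and only serve as an illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SineDataset(Dataset):
    """Hypothetical dataset: points on a sine curve."""

    def __init__(self, n_samples: int = 100):
        self.x = torch.linspace(0.0, 6.28, n_samples).unsqueeze(1)  # shape (n, 1)
        self.y = torch.sin(self.x)

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int):
        # Shared contract: every sample is a (features, target) pair of tensors.
        return self.x[idx], self.y[idx]


class LineDataset(Dataset):
    """Hypothetical dataset: points on the line y = 2x + 1."""

    def __init__(self, n_samples: int = 100):
        self.x = torch.linspace(-1.0, 1.0, n_samples).unsqueeze(1)
        self.y = 2 * self.x + 1

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx]


def make_loader(dataset: Dataset, batch_size: int = 16) -> DataLoader:
    # The loader only relies on __len__ and __getitem__, not on the data source.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Switching the data source does not change any of the loading code.
for dataset in (SineDataset(), LineDataset()):
    features, target = next(iter(make_loader(dataset)))
    print(type(dataset).__name__, tuple(features.shape), tuple(target.shape))
```

Because both datasets follow the same contract, everything downstream of the data loader can stay unchanged when the data source is swapped.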

```mermaid
flowchart LR  
    D1[Dataset instance 1] --> DL1[Data loader process 1];
    D2[Dataset instance 2] --> DL2[Data loader process 2];
    D3[Dataset instance ...] --> DL3[Data loader process ...];
    DL1 --> DC[Data loader collator];
    DL2 --> DC;
    DL3 --> DC;
    DC --> B[Batch];
```
**When a dataset provides data one sample at a time, the data loader can parallelize data loading**

```mermaid
flowchart LR  
    D1[Dataset] --> DL[Data loader];
    DL --> T1[Transform process 1];
    DL --> T2[Transform process 2];
    DL --> T3[Transform process ...];
    T1 --> DC[Data loader collator];
    T2 --> DC;
    T3 --> DC;
    DC --> B[Batch];
```
**When a single dataset is used, the data loader can parallelize data transforms (e.g. augmentation, normalization)**
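The two diagrams above could be sketched roughly as follows in PyTorch. The dataset applies an optional transform inside `__getitem__`, and the `num_workers` argument of `torch.utils.data.DataLoader` decides how many worker processes load and transform samples. The names `RandomVectors`, `normalize` and `build_loader` are hypothetical, and the data is randomly generated for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def normalize(sample: torch.Tensor) -> torch.Tensor:
    """Hypothetical transform: scale one sample to zero mean and unit variance."""
    return (sample - sample.mean()) / (sample.std() + 1e-8)


class RandomVectors(Dataset):
    """Hypothetical dataset of random vectors with an optional per-sample transform."""

    def __init__(self, n_samples: int = 64, dim: int = 8, transform=None):
        generator = torch.Generator().manual_seed(0)
        self.data = torch.randn(n_samples, dim, generator=generator) * 5 + 3
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        sample = self.data[idx]
        if self.transform is not None:
            # The transform runs wherever __getitem__ is called,
            # i.e. inside a worker process when num_workers > 0.
            sample = self.transform(sample)
        return sample


def build_loader(dataset: Dataset, batch_size: int = 16, num_workers: int = 0) -> DataLoader:
    # num_workers=0 loads data in the main process; a positive value starts
    # that many worker processes. The default collate function stacks the
    # individual samples into one batch tensor.
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)


batch = next(iter(build_loader(RandomVectors(transform=normalize))))
print(tuple(batch.shape))
```

Calling `build_loader(dataset, num_workers=4)` would request four worker processes without any change to the dataset or the transform. On platforms that start processes with `spawn`, code that uses worker processes usually has to be run under an `if __name__ == "__main__":` guard.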

### Model modularity

```mermaid
flowchart LR
    DL[Data loader] --> T[Trainer];
    M1[Model 1] --> T;
    T --> O[Model 1 outputs];
    DL2[Data loader] --> T2[Trainer];
    M2[Model 2] --> T2;
    T2 --> O2[Model 2 outputs];
```
**When the model is written as a separate module, it can be modified while reusing the other parts**
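A minimal sketch of this idea in PyTorch: two models that share the same input and output shapes can be passed to the same evaluation code. `LinearModel`, `MLPModel` and `evaluate` are hypothetical names, and the input data is random.

```python
import torch
from torch import nn


class LinearModel(nn.Module):
    """Hypothetical model 1: a single linear layer."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.layer = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


class MLPModel(nn.Module):
    """Hypothetical model 2: a small multilayer perceptron."""

    def __init__(self, in_features: int, out_features: int, hidden: int = 32):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


def evaluate(model: nn.Module, features: torch.Tensor, targets: torch.Tensor) -> float:
    """Model-agnostic: works with any module that maps features to predictions."""
    model.eval()
    with torch.no_grad():
        return nn.functional.mse_loss(model(features), targets).item()


features = torch.randn(16, 4)
targets = torch.randn(16, 1)
for model in (LinearModel(4, 1), MLPModel(4, 1)):
    print(type(model).__name__, evaluate(model, features, targets))
```

The evaluation function never inspects the model internals, so a new architecture only needs to respect the same input and output shapes.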

### Optimizer modularity

```mermaid
flowchart LR
    DL[Data loader] --> T[Trainer];
    M[Model] --> T;
    Op1[Optimizer 1] --> T;
    T --> O[Model outputs with optimizer 1];
    DL2[Data loader] --> T2[Trainer];
    M2[Model] --> T2;
    Op2[Optimizer 2] --> T2;
    T2 --> O2[Model outputs with optimizer 2];
```
**When the optimizer is written as a separate module, it can be modified while reusing the other parts**
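One possible way to keep the optimizer swappable is to select it by name, as in the sketch below. The helpers `make_optimizer` and `train_step` are hypothetical; `torch.optim.SGD` and `torch.optim.Adam` are the standard PyTorch optimizers.

```python
import torch
from torch import nn


def make_optimizer(name: str, model: nn.Module, lr: float) -> torch.optim.Optimizer:
    """Hypothetical factory: pick the optimizer class by name."""
    optimizers = {"sgd": torch.optim.SGD, "adam": torch.optim.Adam}
    return optimizers[name](model.parameters(), lr=lr)


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               features: torch.Tensor, targets: torch.Tensor) -> float:
    """One update step that works with any torch.optim optimizer."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(features), targets)
    loss.backward()
    optimizer.step()
    # The returned value is the loss computed before the weight update.
    return loss.item()


features = torch.randn(8, 4)
targets = torch.randn(8, 1)
for name in ("sgd", "adam"):
    model = nn.Linear(4, 1)
    optimizer = make_optimizer(name, model, lr=0.01)
    loss = train_step(model, optimizer, features, targets)
    print(f"{name}: loss at this step = {loss:.4f}")
```

The training step only uses the common optimizer interface (`zero_grad` and `step`), so the choice of optimizer can become a configuration value.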

### Trainer modularity

```mermaid
flowchart LR
    DL[Data loader] --> T1[Trainer with configuration 1];
    M[Model] --> T1;
    Op1[Optimizer] --> T1;
    T1 --> O[Model outputs with trainer configuration 1];
    DL2[Data loader] --> T2[Trainer with configuration 2];
    M2[Model] --> T2;
    Op2[Optimizer] --> T2;
    T2 --> O2[Model outputs with trainer configuration 2];
```
**When the trainer is written as a separate module, it can be modified while reusing the other parts**
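Plain PyTorch does not ship a trainer, so the sketch below writes a very small one by hand to illustrate the idea. `TrainerConfig` and `train` are hypothetical names, and the toy data follows the line y = 2x + 1. Changing the trainer configuration does not require touching the data loader, model or optimizer code.

```python
from dataclasses import dataclass

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@dataclass
class TrainerConfig:
    """Hypothetical trainer settings."""
    epochs: int = 10
    log_every: int = 5  # print the mean loss every N epochs


def train(model: nn.Module, loader: DataLoader,
          optimizer: torch.optim.Optimizer, config: TrainerConfig) -> list[float]:
    """Train the model and return the mean loss of each epoch."""
    epoch_losses = []
    for epoch in range(config.epochs):
        total = 0.0
        for features, targets in loader:
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(features), targets)
            loss.backward()
            optimizer.step()
            total += loss.item()
        epoch_losses.append(total / len(loader))
        if (epoch + 1) % config.log_every == 0:
            print(f"epoch {epoch + 1}: mean loss {epoch_losses[-1]:.4f}")
    return epoch_losses


torch.manual_seed(0)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
loader = DataLoader(TensorDataset(x, 2 * x + 1), batch_size=16)
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
losses = train(model, loader, optimizer, TrainerConfig(epochs=10))
```

Framework trainers such as `transformers.Trainer` or the Lightning trainer follow the same pattern on a larger scale and add logging, checkpointing and device handling.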

## Modular machine learning training script

So, what parts make up a machine learning training script?

```mermaid
mindmap
    root((Training script))
        Dataset
            Specifies what the data is
        Data loader
            Specifies how data is loaded and transformed
        Model
            Specifies the model structure
        Optimizer
            Specifies the strategy for updating the model weights
        Trainer
            Specifies how the model is trained
        CLI & Configuration
            Specifies command line interface and configuration reading
```
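The "CLI & Configuration" part of the mindmap could be sketched with the Python standard library alone. In this hypothetical example a `RunConfig` dataclass holds the settings and `parse_config` fills it from command line flags. The flag names and allowed values are assumptions chosen for illustration.

```python
import argparse
from dataclasses import asdict, dataclass


@dataclass
class RunConfig:
    """Hypothetical run configuration shared by all parts of the script."""
    dataset: str = "line"
    model: str = "linear"
    optimizer: str = "sgd"
    lr: float = 0.1
    epochs: int = 10


def parse_config(argv=None) -> RunConfig:
    """Build a RunConfig from command line flags (argv=None reads sys.argv)."""
    defaults = RunConfig()
    parser = argparse.ArgumentParser(description="Train a model (sketch).")
    parser.add_argument("--dataset", choices=["line", "sine"], default=defaults.dataset)
    parser.add_argument("--model", choices=["linear", "mlp"], default=defaults.model)
    parser.add_argument("--optimizer", choices=["sgd", "adam"], default=defaults.optimizer)
    parser.add_argument("--lr", type=float, default=defaults.lr)
    parser.add_argument("--epochs", type=int, default=defaults.epochs)
    args = parser.parse_args(argv)
    return RunConfig(**vars(args))


# Flags are passed explicitly here; a real script would call parse_config().
config = parse_config(["--model", "mlp", "--lr", "0.01"])
print(asdict(config))
```

Keeping the configuration in one object makes it straightforward to save it next to the checkpoints, which helps with keeping track of different experiments.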

## Modular does not always mean object oriented programming

Modular coding does not necessarily mean that everything needs to be written as classes. Whether you should use classes or functions depends on your preference and on the framework's conventions. Typically, datasets and model specifications are written as classes because they need to inherit features from the framework.
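As an illustration, the sketch below keeps the pieces modular while using only functions. It relies on `torch.nn.Sequential` and `torch.utils.data.TensorDataset`, which are existing PyTorch classes, so no custom subclass is needed. The function names `build_model` and `build_loader` and the toy data are hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def build_model(in_features: int, hidden: int, out_features: int) -> nn.Module:
    # nn.Sequential composes existing layers, so no custom class is required.
    return nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.ReLU(),
        nn.Linear(hidden, out_features),
    )


def build_loader(n_samples: int, batch_size: int) -> DataLoader:
    # TensorDataset wraps in-memory tensors as (features, target) samples.
    x = torch.linspace(-1, 1, n_samples).unsqueeze(1)
    return DataLoader(TensorDataset(x, x ** 2), batch_size=batch_size)


model = build_model(in_features=1, hidden=16, out_features=1)
features, targets = next(iter(build_loader(n_samples=32, batch_size=8)))
print(tuple(model(features).shape), tuple(targets.shape))
```

Custom classes become useful once the dataset has to read files lazily or the forward pass is more than a chain of layers.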