Following [this](https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html)

## Configure hyperparameters from the CLI

Lightning has utilities to interact seamlessly with the command line `ArgumentParser` and plays well with the
hyperparameter optimization framework of your choice.

### ArgumentParser

Lightning is designed to augment a lot of the functionality of the built-in Python `ArgumentParser`.

```python
from argparse import ArgumentParser

parser = ArgumentParser()  # the built-in ArgumentParser
parser.add_argument("--layer_1_dim", type=int, default=128)
args = parser.parse_args()
```

This allows calls such as:
```sh
python trainer.py --layer_1_dim 64
```
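For quick experiments or tests, `parse_args` can also be given an explicit list of strings instead of reading `sys.argv`. Here is a minimal standard-library sketch that uses the same `--layer_1_dim` flag:

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--layer_1_dim", type=int, default=128)

# parse_args() reads sys.argv[1:] by default; passing a list makes runs reproducible
default_args = parser.parse_args([])
cli_args = parser.parse_args(["--layer_1_dim", "64"])

print(default_args.layer_1_dim)  # 128, the default
print(cli_args.layer_1_dim)  # 64, converted from the string "64" by type=int
```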

### Argparser Best Practices

ℹ️ **Useful idea**

It is best practice to layer your arguments in three sections.
* Trainer args (`accelerator`, `devices`, `num_nodes`, etc…)
* Model specific arguments (`layer_dim`, `num_layers`, `learning_rate`, etc…)
* Program arguments (`data_path`, `cluster_email`, etc…)

We can do this as follows.

1. First, in your `LightningModule`, define the arguments specific to that module.
   Remember that data splits or data paths may also be specific to a module
   (e.g., if your project has one model that trains on ImageNet and another that trains on CIFAR-10).
```python
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser):  # NOTE: parent_parser is passed in, extended, and returned
        parser = parent_parser.add_argument_group("LitModel")
        parser.add_argument("--encoder_layers", type=int, default=12)
        parser.add_argument("--data_path", type=str, default="/some/path")
        return parent_parser
```
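`add_argument_group` only changes how `--help` is laid out. The parsed values still end up as flat attributes on a single `Namespace`. The small sketch below shows this. It uses a hypothetical stand-in class (`LitModelStandIn`) so that it runs with the standard library alone, without Lightning installed:

```python
from argparse import ArgumentParser


class LitModelStandIn:
    """Hypothetical stand-in for LitModel, so the sketch runs without Lightning."""

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("LitModel")
        parser.add_argument("--encoder_layers", type=int, default=12)
        parser.add_argument("--data_path", type=str, default="/some/path")
        return parent_parser


parser = ArgumentParser()
parser = LitModelStandIn.add_model_specific_args(parser)
args = parser.parse_args(["--encoder_layers", "6"])

# The group only affects the --help layout; values land on the top-level Namespace
print(args.encoder_layers)  # 6
print(args.data_path)  # /some/path
```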

2. Now, in your main trainer file, add the program args, the model args, and the Trainer args.
```python
# ----------------
# trainer_main.py
# ----------------
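from argparse import ArgumentParser

from pytorch_lightning import Trainer

# LitModel is the LightningModule from step 1; import it from wherever it lives in your project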

parser = ArgumentParser()

# add PROGRAM level args
parser.add_argument("--conda_env", type=str, default="some_name")
parser.add_argument("--notification_email", type=str, default="will@email.com")

# add model specific args
parser = LitModel.add_model_specific_args(parser)  # NOTE.

# add all the available trainer options to argparse
# i.e. now --accelerator, --devices, --num_nodes, ..., --fast_dev_run all work on the CLI
parser = Trainer.add_argparse_args(parser)  # NOTE.

args = parser.parse_args()
```
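Sometimes you may want the three sections from the list above as separate objects rather than one flat `Namespace`. One plain-argparse option is to use one parser per section and chain them with `parse_known_args`. This is a sketch of an alternative, not what Lightning does. The flag names come from the list above and the defaults are made up:

```python
from argparse import ArgumentParser

# One parser per section; add_help=False avoids several competing -h flags, and
# allow_abbrev=False stops a parser from claiming a prefix of another section's flag.
trainer_parser = ArgumentParser(add_help=False, allow_abbrev=False)
trainer_parser.add_argument("--accelerator", type=str, default="cpu")
trainer_parser.add_argument("--devices", type=int, default=1)

model_parser = ArgumentParser(add_help=False, allow_abbrev=False)
model_parser.add_argument("--num_layers", type=int, default=4)
model_parser.add_argument("--learning_rate", type=float, default=1e-3)

program_parser = ArgumentParser(allow_abbrev=False)
program_parser.add_argument("--data_path", type=str, default="/some/path")


def split_args(argv):
    """Peel off each section in turn; unrecognised strings flow on to the next parser."""
    trainer_args, rest = trainer_parser.parse_known_args(argv)
    model_args, rest = model_parser.parse_known_args(rest)
    program_args = program_parser.parse_args(rest)  # strict: leftover unknown flags error out here
    return trainer_args, model_args, program_args


trainer_args, model_args, program_args = split_args(
    ["--devices", "2", "--learning_rate", "0.01", "--data_path", "/data"]
)
print(vars(trainer_args))
print(vars(model_args))
print(vars(program_args))
```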

Now you can initialize the `Trainer` and the model from the parsed args like so:
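```python
from pytorch_lightning.callbacks import EarlyStopping

# init the trainer like this; extra kwargs such as callbacks are passed on to the Trainer
# (monitor="val_loss" assumes the model logs a metric named val_loss)
trainer = Trainer.from_argparse_args(args, callbacks=[EarlyStopping(monitor="val_loss")])  # NOTE: Like so.

# NOT like this
trainer = Trainer(accelerator=args.accelerator, devices=args.devices, ...)  # NOTE: NOT like so.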

# init the model with Namespace directly
model = LitModel(args)

# or init the model with all the key-value pairs
# NOTE this pattern.
dict_args = vars(args)
model = LitModel(**dict_args)
```
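`vars(args)` contains *every* parsed flag: program, model and Trainer args. A module constructed with `**dict_args` therefore has to tolerate keywords it does not use. Below is a minimal sketch with a hypothetical stand-in class and a hand-built `Namespace` in place of real CLI parsing:

```python
from argparse import Namespace


class LitModelStandIn:
    """Hypothetical stand-in for LitModel (no Lightning import needed for the sketch)."""

    def __init__(self, encoder_layers=12, data_path="/some/path", **kwargs):
        # **kwargs swallows the program and Trainer flags that also live in vars(args)
        self.encoder_layers = encoder_layers
        self.data_path = data_path
        self.ignored = sorted(kwargs)


# Stand-in for the parsed Namespace: model, program and Trainer args all mixed together
args = Namespace(encoder_layers=24, data_path="/data", conda_env="some_name", devices=2)

dict_args = vars(args)  # plain dict view of the Namespace
model = LitModelStandIn(**dict_args)
print(model.encoder_layers, model.ignored)  # 24 ['conda_env', 'devices']
```

Without the `**kwargs` catch-all, the same call would raise `TypeError` for the unexpected `conda_env` and `devices` keywords.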

### LightningModule hyperparameters

Often we train many versions of a model. You might share a model or come back to it a few months later, at which point it is very useful to know how it was trained (e.g., which learning rate, which network architecture, etc.).

Lightning has a standardized way of saving this information for you in checkpoints and YAML files. The goal is to improve readability and reproducibility.

#### `save_hyperparameters()`

Use `save_hyperparameters()` within your `LightningModule`’s `__init__` method.

It lets Lightning store the `__init__` arguments (all of them, by default) under the `self.hparams` attribute.
These hyperparameters will also be stored within the model checkpoint, which simplifies model re-instantiation after training.

⚠️ Note the following carefully: it is relevant to things like delayed initialisation of parameters.
```python
from pytorch_lightning import LightningModule


class LitMNIST(LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        super().__init__()
        # call this to save (layer_1_dim=128, learning_rate=1e-2) to the checkpoint
        self.save_hyperparameters()

        # equivalent
        self.save_hyperparameters("layer_1_dim", "learning_rate")

        # Now possible to access layer_1_dim from hparams
        self.hparams.layer_1_dim
```
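To make the idea concrete, here is a simplified standard-library sketch that *records* `__init__` arguments, defaults included, into `self.hparams`. It is an illustration under stated assumptions, not how Lightning implements `save_hyperparameters()`. The decorator `records_hparams` and the class `LitMNISTStandIn` are hypothetical. Also, `hparams` here is a plain `dict`, whereas the Lightning version above supports attribute access (`self.hparams.layer_1_dim`).

```python
import functools
import inspect


def records_hparams(init):
    """Hypothetical decorator: store every __init__ argument (defaults included) in self.hparams."""
    signature = inspect.signature(init)

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()  # fill in the parameters the caller did not pass
        self.hparams = {name: value for name, value in bound.arguments.items() if name != "self"}
        init(self, *args, **kwargs)  # hparams already exist while the real __init__ runs

    return wrapper


class LitMNISTStandIn:
    """Hypothetical stand-in for LitMNIST; no Lightning import needed."""

    @records_hparams
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        # like self.hparams.layer_1_dim in the Lightning version, but with dict access
        self.layer_desc = f"Linear(784, {self.hparams['layer_1_dim']})"


model = LitMNISTStandIn(layer_1_dim=64)
print(model.hparams)  # {'layer_1_dim': 64, 'learning_rate': 0.01}
print(LitMNISTStandIn().hparams)  # {'layer_1_dim': 128, 'learning_rate': 0.01}
```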

In addition, loggers that support it will automatically log the contents of `self.hparams`.


#### Excluding hyperparameters

By default, every parameter of the `__init__` method will be considered a hyperparameter to the `LightningModule`.

However, sometimes some parameters need to be excluded from saving, for example when they are not serializable.

Those parameters then have to be passed in again when reloading the `LightningModule`.

In this case, exclude them explicitly:

```python
from pytorch_lightning import LightningModule


class LitMNIST(LightningModule):
    def __init__(self, loss_fx, generator_network, layer_1_dim=128):
        super().__init__()
        self.layer_1_dim = layer_1_dim
        self.loss_fx = loss_fx

        # call this to save only (layer_1_dim=128) to the checkpoint
        self.save_hyperparameters("layer_1_dim")  # NOTE.

        # equivalent
        self.save_hyperparameters(ignore=["loss_fx", "generator_network"])  # NOTE.
```
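The need for exclusion shows up with any text serializer. The sketch below uses JSON only because it ships with Python. This is an assumption for illustration and says nothing about Lightning's own checkpoint format. A function object cannot be encoded, while the remaining plain hyperparameters can. `loss_fx` here is a hypothetical stand-in:

```python
import json


def loss_fx(prediction, target):  # hypothetical loss function passed into __init__
    return (prediction - target) ** 2


init_args = {"loss_fx": loss_fx, "layer_1_dim": 128}

try:
    json.dumps(init_args)
    serializable = True
except TypeError as err:
    serializable = False
    print("not serializable:", err)

# Dropping the non-serializable entry (what ignore=[...] is for) leaves something storable
saved = {name: value for name, value in init_args.items() if name not in {"loss_fx"}}
print(json.dumps(saved))  # {"layer_1_dim": 128}
```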

#### `load_from_checkpoint()`

A `LightningModule` whose hyperparameters were saved automatically with `save_hyperparameters()` can conveniently be
loaded and instantiated directly from a checkpoint with `load_from_checkpoint()`.

If parameters were excluded, they need to be provided at the time of loading.

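```python
import torch

# when loading, pass the excluded args explicitly
# the excluded parameters were `loss_fx` and `generator_network`
# (PATH, SomeOtherLoss and MyGenerator are placeholders for your own checkpoint path, loss and generator)
model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss(), generator_network=MyGenerator())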
```
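The load-time contract is "saved hyperparameters plus whatever was excluded". It can be sketched as a round trip. Everything here is hypothetical: `LitMNISTStandIn`, its `save`/`load` methods, and the JSON file standing in for a real checkpoint. The sketch only illustrates why `loss_fx` and `generator_network` must be passed again:

```python
import json
import os
import tempfile


class LitMNISTStandIn:
    """Hypothetical stand-in; only layer_1_dim is treated as a saved hyperparameter."""

    def __init__(self, loss_fx, generator_network, layer_1_dim=128):
        self.loss_fx = loss_fx
        self.generator_network = generator_network
        self.hparams = {"layer_1_dim": layer_1_dim}  # loss_fx and generator_network are excluded

    def save(self, path):
        # Hypothetical "checkpoint": just the saved hyperparameters as JSON
        with open(path, "w") as f:
            json.dump({"hparams": self.hparams}, f)

    @classmethod
    def load(cls, path, **excluded):
        with open(path) as f:
            hparams = json.load(f)["hparams"]
        # saved hparams plus the excluded args the caller must supply again
        return cls(**hparams, **excluded)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.json")
    LitMNISTStandIn(loss_fx=abs, generator_network=None, layer_1_dim=64).save(path)
    restored = LitMNISTStandIn.load(path, loss_fx=abs, generator_network="new generator")

print(restored.hparams, restored.generator_network)
```

In this sketch, calling `load(path)` without the excluded arguments fails with a `TypeError` about the missing `loss_fx` and `generator_network`.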

### Trainer args

To recap, add ALL possible trainer flags to the argparser and init the `Trainer` this way:

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
hparams = parser.parse_args()

trainer = Trainer.from_argparse_args(hparams)

# or if you need to pass in callbacks
trainer = Trainer.from_argparse_args(hparams, enable_checkpointing=..., callbacks=[...])
```
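Conceptually, "add ALL possible trainer flags" means two things. First, derive CLI flags from a constructor signature. Later, pass back only the values that signature accepts. The sketch below does both with `inspect.signature`. It is not Lightning's implementation. `add_signature_args`, `from_signature_args` and the `make_trainer` stand-in (with made-up defaults) are all hypothetical:

```python
import inspect
from argparse import ArgumentParser


def str_to_bool(value):
    # argparse's type=bool would turn the string "False" into True, so parse explicitly
    return value.lower() in {"1", "true", "yes"}


def add_signature_args(parser, fn):
    """Hypothetical helper: one --flag per parameter of fn, typed from its default."""
    for name, param in inspect.signature(fn).parameters.items():
        if param.default is inspect.Parameter.empty or param.default is None:
            continue  # no usable default to infer a type from (e.g. callbacks), so skip it
        default = param.default
        arg_type = str_to_bool if isinstance(default, bool) else type(default)
        parser.add_argument(f"--{name}", type=arg_type, default=default)
    return parser


def from_signature_args(fn, args, **overrides):
    """Hypothetical helper: call fn with only the parsed values its signature accepts."""
    accepted = inspect.signature(fn).parameters
    kwargs = {name: value for name, value in vars(args).items() if name in accepted}
    kwargs.update(overrides)  # e.g. callbacks, which do not come from the CLI
    return fn(**kwargs)


def make_trainer(accelerator="cpu", devices=1, max_epochs=10, enable_checkpointing=True, callbacks=None):
    """Stand-in for the Trainer constructor; returns its settings as a dict."""
    return {
        "accelerator": accelerator,
        "devices": devices,
        "max_epochs": max_epochs,
        "enable_checkpointing": enable_checkpointing,
        "callbacks": callbacks or [],
    }


parser = add_signature_args(ArgumentParser(), make_trainer)
parser.add_argument("--data_path", type=str, default="/some/path")  # program arg, filtered out below
hparams = parser.parse_args(["--devices", "2", "--enable_checkpointing", "false"])
trainer = from_signature_args(make_trainer, hparams, callbacks=["early_stopping"])
print(trainer)
```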

### Multiple Lightning Modules

We often have multiple `LightningModule`s, each with different arguments.

Instead of polluting `main.py`, each `LightningModule` can define its own arguments.

```python
from pytorch_lightning import LightningModule
from torch import nn


class LitMNIST(LightningModule):
    def __init__(self, layer_1_dim, **kwargs):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, layer_1_dim)

    # NOTE:
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("LitMNIST")
        parser.add_argument("--layer_1_dim", type=int, default=128)
        return parent_parser

class GoodGAN(LightningModule):
    def __init__(self, encoder_layers, **kwargs):
        super().__init__()
        self.encoder = Encoder(layers=encoder_layers)  # Encoder: placeholder for your own encoder module

    # NOTE:
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("GoodGAN")
        parser.add_argument("--encoder_layers", type=int, default=12)
        return parent_parser
```
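Each module writes its flags into the same parent parser. Two modules that declare the same flag therefore cannot both be added, because argparse raises `ArgumentError` by default. Here is a standard-library sketch, assuming (hypothetically) that both modules expose a `--learning_rate` flag:

```python
from argparse import ArgumentError, ArgumentParser


def add_lit_mnist_args(parent_parser):
    # hypothetical: both modules happen to expose a --learning_rate flag
    group = parent_parser.add_argument_group("LitMNIST")
    group.add_argument("--learning_rate", type=float, default=1e-2)
    return parent_parser


def add_good_gan_args(parent_parser):
    group = parent_parser.add_argument_group("GoodGAN")
    group.add_argument("--learning_rate", type=float, default=2e-4)
    return parent_parser


parser = add_lit_mnist_args(ArgumentParser())
try:
    add_good_gan_args(parser)
    conflict = False
except ArgumentError as err:
    conflict = True
    print("conflict:", err)
```

Passing `conflict_handler="resolve"` to `ArgumentParser` would instead let the later definition silently replace the earlier one. That is rarely what you want when the defaults differ.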

Now each model can inject the arguments it needs in `main.py`:
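```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

# LitMNIST and GoodGAN are the modules defined above


def main(args):
    dict_args = vars(args)

    # pick model
    if args.model_name == "gan":
        model = GoodGAN(**dict_args)
    elif args.model_name == "mnist":
        model = LitMNIST(**dict_args)
    else:
        raise ValueError(f"unknown model_name: {args.model_name}")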

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)

    # figure out which model to use
    parser.add_argument("--model_name", type=str, default="gan", help="gan or mnist")

    # THIS LINE IS KEY TO PULL THE MODEL NAME
    temp_args, _ = parser.parse_known_args()

    # let the model add what it wants
    if temp_args.model_name == "gan":
        parser = GoodGAN.add_model_specific_args(parser)
    elif temp_args.model_name == "mnist":
        parser = LitMNIST.add_model_specific_args(parser)

    args = parser.parse_args()

    # train
    main(args)
```

⚠️ BTW, in this example it's literally just *this model or that model*. Only the model picked via `--model_name` gets its arguments added and gets trained; both are never loaded together. For example:

```sh
$ python main.py --model_name gan --encoder_layers 24
$ python main.py --model_name mnist --layer_1_dim 128
```