Following the [PyTorch Lightning introduction tutorial](https://pytorch-lightning.readthedocs.io/en/stable/starter/introduction.html).

In [4]:
# Define a LightningModule

import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl  # NOTE.

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):  # NOTE.
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):  # NOTE.
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):  # NOTE.
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

In [5]:
# Define a dataset
# Lightning supports ANY iterable (DataLoader, numpy, etc…) for the train/val/test/predict splits.

dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)  # pyright: ignore

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/train-images-idx3-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/train-labels-idx1-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/t10k-images-idx3-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw/t10k-labels-idx1-ubyte.gz to /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/MNIST/raw
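
Note that `DataLoader(dataset)` with no further arguments yields one sample per batch, since the default `batch_size` is 1. Below is a minimal sketch of how batching and shuffling could be configured. It uses a synthetic stand-in for MNIST so that it runs without a download; the tensor shapes, the dataset size of 256, and `batch_size=64` are illustrative assumptions, not values from the tutorial.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST (assumption): 256 grayscale 28x28 "images" with integer labels.
images = torch.rand(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
fake_dataset = TensorDataset(images, labels)

# Default DataLoader: batch_size=1 and no shuffling.
default_loader = DataLoader(fake_dataset)
x, y = next(iter(default_loader))
print(x.shape)  # a single sample per batch

# Batched and shuffled loader; batch_size=64 is an arbitrary illustrative choice.
batched_loader = DataLoader(fake_dataset, batch_size=64, shuffle=True)
xb, yb = next(iter(batched_loader))
print(xb.shape, len(batched_loader))
```

With the default loader, `limit_train_batches=100` in the next cell therefore corresponds to 100 individual images per epoch.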



### Train the model

The Lightning `Trainer` “mixes” any `LightningModule` with any dataset and abstracts away all the engineering complexity needed for scale.

The Lightning Trainer automates [40+ tricks](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags) including:
* Epoch and batch iteration
* optimizer.step(), loss.backward(), optimizer.zero_grad() calls
* Calling of model.eval(), enabling/disabling grads during evaluation
* Checkpoint Saving and Loading
* TensorBoard (see logger options)
* Multi-GPU support
* TPU
* 16-bit precision AMP support
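
To make the list above concrete, here is a rough sketch of the plain PyTorch loop that the `Trainer` replaces for this autoencoder. It trains on one fixed batch of random tensors instead of MNIST and runs a fixed 20 steps (both assumptions for brevity). It also omits everything else the `Trainer` handles, such as checkpointing, logging, and device placement.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Same architecture as the tutorial's encoder/decoder, under separate names.
sketch_encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
sketch_decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
params = list(sketch_encoder.parameters()) + list(sketch_decoder.parameters())
optimizer = optim.Adam(params, lr=1e-3)

# Random stand-in for one batch of flattened MNIST images (assumption).
x = torch.rand(32, 28 * 28)

losses = []
for step in range(20):
    # Forward pass and reconstruction loss, as in training_step.
    x_hat = sketch_decoder(sketch_encoder(x))
    loss = nn.functional.mse_loss(x_hat, x)

    # The three calls the Trainer issues on your behalf.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```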


In [6]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /mnt/data-linux/Dropbox/Programming/wsl_repos/practice_py/lightning_tutorials/lightning_logs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
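
The parameter counts in the summary table can be checked by hand: an `nn.Linear(in, out)` layer holds `in * out` weights plus `out` biases. The sketch below redoes that arithmetic in plain Python. The 4 bytes per parameter assumes 32-bit floats, and the size figure assumes 1 MB = 10^6 bytes, which is the convention that reproduces the reported 0.407.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Trainable parameters in a fully connected layer: weights plus biases."""
    return in_features * out_features + out_features


# Encoder: Linear(784, 64) -> ReLU -> Linear(64, 3); ReLU has no parameters.
encoder_params = linear_params(28 * 28, 64) + linear_params(64, 3)
# Decoder: Linear(3, 64) -> ReLU -> Linear(64, 784).
decoder_params = linear_params(3, 64) + linear_params(64, 28 * 28)
total_params = encoder_params + decoder_params

# Assumes float32 parameters (4 bytes each) and decimal megabytes.
size_mb = total_params * 4 / 1e6

print(encoder_params)      # 50435  (shown as 50.4 K)
print(decoder_params)      # 51216  (shown as 51.2 K)
print(total_params)        # 101651 (shown as 101 K)
print(round(size_mb, 3))   # 0.407
```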


In [7]:
# Use the model

# Once you’ve trained the model you can export it to ONNX or TorchScript and put it into production,
# or simply load the weights and run predictions.

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"  # NOTE.
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

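# embed 4 fake images!
# NOTE: Tensor(4, 28 * 28) allocates uninitialized memory, so the inputs are arbitrary
# values (hence the huge numbers and nan in the output); torch.rand(4, 28 * 28) avoids this.
fake_image_batch = Tensor(4, 28 * 28)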
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[-5.3405e+29,  8.6866e+29,  2.0069e+29],
        [-4.5485e+34,  8.2463e+36,  2.9913e+36],
        [        nan,         nan,         nan],
        [        nan,         nan,         nan]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
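
The huge and `nan` values above are most likely an artifact of `Tensor(4, 28 * 28)`, which allocates a tensor without initializing its memory. Below is a sketch of an alternative that feeds random values in [0, 1) and disables gradient tracking for inference. It uses a freshly initialized encoder of the same architecture rather than the trained checkpoint (an assumption, so that the snippet runs on its own).

```python
import torch
from torch import nn

torch.manual_seed(0)

# Same architecture as the tutorial's encoder, but untrained
# (assumption: no checkpoint is loaded in this sketch).
untrained_encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
untrained_encoder.eval()

# torch.rand fills the tensor with uniform samples in [0, 1),
# the same range that ToTensor() produces for pixel values.
fake_image_batch = torch.rand(4, 28 * 28)

# no_grad() skips building the autograd graph, so the result carries no grad_fn.
with torch.no_grad():
    embeddings = untrained_encoder(fake_image_batch)

print(embeddings.shape)
print(torch.isfinite(embeddings).all().item())
```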


### Visualize training

Lightning comes with a *lot* of batteries included. A helpful one is TensorBoard for visualizing experiments.

Run this on your command line and open your browser to http://localhost:6006/

```sh
tensorboard --logdir .
```

In [14]:
# Supercharge training

# Enable advanced training features using Trainer arguments.
# These are state-of-the-art techniques that are automatically integrated into your training loop
# without changes to your code.

# train on 1 GPU (the original tutorial uses devices=4)
trainer = pl.Trainer(
    devices=1,
    accelerator="gpu",
    # --
    max_epochs=3, # For illustration speed.
    limit_train_batches=100,  # For illustration speed.
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

print()
print("=" * 80)
print()

# 20+ helpful flags for rapid idea iteration
trainer = pl.Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1,
    # --
    limit_train_batches=100,  # For illustration speed.
    log_every_n_steps=1,  # For illustration speed.
)

trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)






Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


More examples:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# train on 4 GPUs
trainer = Trainer(
    devices=4,
    accelerator="gpu",
)

# train 1TB+ parameter models with DeepSpeed/FSDP
trainer = Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16,
)

# 20+ helpful flags for rapid idea iteration
trainer = Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1,
)

# access the latest state of the art techniques
trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
```

### Maximize flexibility

Lightning’s core guiding principle is to always provide maximal flexibility without ever hiding any of the PyTorch.

Lightning offers 5 added degrees of flexibility depending on your project’s complexity.

#### Customize training loop
Inject custom code anywhere in the Training loop using any of the 20+ methods (Hooks) available in the LightningModule.

<img src="./assets/custom_loop.png" alt="drawing" style="width:700px;"/>

```python
class LitAutoEncoder(pl.LightningModule):
    def backward(self, loss, optimizer, optimizer_idx):
        loss.backward()
```
* Hooks: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#lightning-hooks

#### Extend the Trainer
If you have multiple lines of code with similar functionalities, you can use callbacks to easily group them together and toggle all of those lines on or off at the same time.

<img src="./assets/ext_loop.png" alt="drawing" style="width:200px;"/>

```python
trainer = Trainer(callbacks=[AWSCheckpoints()])
```
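
To illustrate the idea of grouping related code behind a single switch, here is a toy, framework-free sketch of the callback pattern. `ToyTrainer`, `ToyCallback`, `CountBatches`, and `RecordEvents` are hypothetical names invented for this example. Only the hook names `on_train_start`, `on_train_batch_end`, and `on_train_end` mirror hooks that Lightning's `Callback` class also defines, and the real `Trainer` passes different arguments to them.

```python
class ToyCallback:
    """Base class with no-op hooks; subclasses override only what they need."""

    def on_train_start(self, trainer):
        pass

    def on_train_batch_end(self, trainer, batch_idx):
        pass

    def on_train_end(self, trainer):
        pass


class CountBatches(ToyCallback):
    """Counts how many batches were processed."""

    def __init__(self):
        self.seen = 0

    def on_train_batch_end(self, trainer, batch_idx):
        self.seen += 1


class RecordEvents(ToyCallback):
    """Records when training starts and ends."""

    def __init__(self):
        self.events = []

    def on_train_start(self, trainer):
        self.events.append("start")

    def on_train_end(self, trainer):
        self.events.append("end")


class ToyTrainer:
    """Hypothetical stand-in for a trainer: it only dispatches hooks."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def fit(self, batches):
        for cb in self.callbacks:
            cb.on_train_start(self)
        for batch_idx, _batch in enumerate(batches):
            # A real trainer would run the training step here.
            for cb in self.callbacks:
                cb.on_train_batch_end(self, batch_idx)
        for cb in self.callbacks:
            cb.on_train_end(self)


counter, recorder = CountBatches(), RecordEvents()
ToyTrainer(callbacks=[counter, recorder]).fit(range(5))
print(counter.seen, recorder.events)  # 5 ['start', 'end']

# Leaving a callback out of the list switches off all of its code at once.
ToyTrainer(callbacks=[recorder]).fit(range(5))
print(counter.seen)  # still 5, because the counter was not attached
```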

#### Use a raw PyTorch loop
For certain types of work at the bleeding-edge of research, Lightning offers experts full control of their training loops in various ways.

ℹ️ See tutorials:
* Manual optimization: https://pytorch-lightning.readthedocs.io/en/stable/model/build_model_advanced.html#manual-optimization
* Lightning Lite: https://pytorch-lightning.readthedocs.io/en/stable/model/build_model_expert.html
* Loops: https://pytorch-lightning.readthedocs.io/en/stable/extensions/loops.html
