LIGHTNING IN 15 MINUTES

Required background: None

Goal: In this guide, we’ll walk you through the 7 key steps of a typical Lightning workflow.

PyTorch Lightning is the deep learning framework with “batteries included” for professional AI researchers and machine learning engineers who need maximal flexibility while super-charging performance at scale.

Lightning organizes PyTorch code to remove boilerplate and unlock scalability.

In [1]:
# # For pip users
#
# pip install lightning


# # For conda users
#
# conda install lightning -c conda-forge

2: Define a LightningModule

A LightningModule enables your PyTorch nn.Module to play together in complex ways inside the training_step (there is also an optional validation_step and test_step).

In [2]:
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

3: Define a dataset

Lightning supports ANY iterable (DataLoader, numpy, etc…) for the train/val/test/predict splits.

In [3]:
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

4: Train the model

The Lightning Trainer “mixes” any LightningModule with any dataset and abstracts away all the engineering complexity needed for scale.

In [4]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

# in accordance to a previous run the following file has been manually adapted. on line 428
# /Users/stephandekker/miniconda3/envs/miw2/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


The Lightning Trainer automates 40+ tricks including:

Epoch and batch iteration
optimizer.step(), loss.backward(), optimizer.zero_grad() calls
Calling of model.eval(), enabling/disabling grads during evaluation
Checkpoint Saving and Loading
Tensorboard (see loggers options)
Multi-GPU support
TPU
16-bit precision AMP support

5: Use the model

Once you’ve trained the model you can export to onnx, torchscript and put it into production or simply load the weights and run predictions.

In [5]:
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = Tensor(4, 28 * 28)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[ 2.9280e+36, -2.6702e+36, -3.9303e+36],
        [-1.6299e+33, -1.8654e+33, -2.1126e+33],
        [ 1.1594e+33, -1.2254e+33,  8.7130e+32],
        [        nan,         nan,         nan]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡


-----------------------------------------------------------------------

6: Visualize training

If you have tensorboard installed, you can use it for visualizing experiments.

Run this on your commandline and open your browser to http://localhost:6006/

In [6]:
# tensorboard --logdir .

7: Supercharge training

Enable advanced training features using Trainer arguments. These are state-of-the-art techniques that are automatically integrated into your training loop without changes to your code.

In [23]:
# train on 4 GPUs
# # doesnt work on a Mac.
# # MisconfigurationException: You requested gpu: [0, 1]
# # But your machine only has: [0]
# trainer = pl.Trainer(
#     devices=1,
#     accelerator="gpu",
#  )

# # train 1TB+ parameter models with Deepspeed/fsdp
# trainer = pl.Trainer(
#     # devices=4,  # MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.
#     devices=1,
#     # accelerator="cpu",
#     accelerator="mps",
#     # strategy="deepspeed_stage_2",  # You set `strategy=deepspeed_stage_2` but strategies from the DDP family are not supported on the MPS accelerator. Either explicitly set `accelerator='cpu'` or change the strategy.
#     # strategy='ddp_notebook',
#     # precision=16  # but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
#     precision='bf16-mixed'
 )

# 20+ helpful flags for rapid idea iteration
trainer = pl.Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1
 )
#
# # access the latest state of the art techniques
# trainer = pl.Trainer(callbacks=[StochasticWeightAveraging(...)])