.. testsetup:: *

    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer


Step-by-step walk-through

PyTorch Lightning provides a very simple template for organizing your PyTorch code. Once you've organized it into a LightningModule, it automates most of the training for you.

To illustrate, here's the typical PyTorch project structure organized in a LightningModule.

As your project grows in complexity with things like 16-bit precision and distributed training, the part in blue quickly becomes onerous and starts distracting from the core research code.


Goal of this guide

This guide walks through the major parts of the library to help you understand what each part does. At the end of the day, you write the same PyTorch code; you just organize it into the LightningModule template, which means you keep ALL the flexibility without having to deal with any of the boilerplate code.

To show how Lightning works, we'll start with an MNIST classifier. We'll end showing how to use inheritance to very quickly create an AutoEncoder.

Note

Any DL/ML PyTorch project fits into the Lightning structure. Here we just focus on 3 types of research to illustrate.


Installing Lightning

Lightning is trivial to install.

conda activate my_env
pip install pytorch-lightning

Or without conda environments, anywhere you can use pip.

pip install pytorch-lightning

Or with conda

conda install pytorch-lightning -c conda-forge

Lightning Philosophy

Lightning factors DL/ML code into three types:

  • Research code
  • Engineering code
  • Non-essential code

Research code

In the MNIST generation example, the research code would be the particular system and how it's trained (i.e., a GAN or VAE). In Lightning, this code is abstracted out by the LightningModule.

l1 = nn.Linear(...)
l2 = nn.Linear(...)
decoder = Decoder()

x1 = l1(x)
x2 = l2(x1)
out = decoder(features, x)

loss = perceptual_loss(x1, x2, x) + CE(out, x)

Engineering code

Engineering code is all the code related to training this system, such as early stopping, distribution over GPUs, and 16-bit precision. This code is normally THE SAME across most projects.

In Lightning, this code is abstracted out by the Trainer.

model.cuda(0)
x = x.cuda(0)

distributed = DistributedDataParallel(model)

with gpu_zero:
    download_data()

dist.barrier()

Non-essential code

This is code that helps the research but isn't relevant to the research code. Some examples might be:

  1. Inspect gradients
  2. Log to tensorboard

In Lightning this code is abstracted out by Callbacks.

# log samples
z = Q.rsample()
generated = decoder(z)
self.experiment.log('images', generated)

Elements of a research project

Every research project requires the same core ingredients:

  1. A model
  2. Train/val/test data
  3. Optimizer(s)
  4. Training step computations
  5. Validation step computations
  6. Test step computations

The Model

The LightningModule provides the structure on how to organize these 6 ingredients.

Let's first start with the model. In this case we'll design a 3-layer neural network.

.. testcode::

    import torch
    from torch.nn import functional as F
    from torch import nn
    from pytorch_lightning.core.lightning import LightningModule

    class LitMNIST(LightningModule):

      def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

      def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3
        x = self.layer_3(x)

        # probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

Notice this is a LightningModule instead of a torch.nn.Module. A LightningModule is equivalent to a PyTorch Module except it has added functionality. However, you can use it EXACTLY the same as you would a PyTorch Module.

.. testcode::

    net = LitMNIST()
    x = torch.randn(1, 1, 28, 28)  # a random batch holding one MNIST-sized image
    out = net(x)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: python

    torch.Size([1, 10])
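
The 10 values the network returns are log-probabilities, because forward ends with log_softmax. That is why the training step later in this guide pairs it with F.nll_loss. As a rough illustration only, here is a framework-free sketch; the ``log_softmax`` helper below is written by hand for this example and is not the torch implementation:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats (illustrative helper)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


log_probs = log_softmax([2.0, 1.0, 0.1])

# exponentiating log-probabilities recovers a distribution that sums to 1
probs = [math.exp(lp) for lp in log_probs]
total = sum(probs)

# the negative log-likelihood of a target class is just minus its log-probability
target = 0
nll = -log_probs[target]
```

Every log-probability is at most zero, so the negative log-likelihood is never negative, and it approaches zero as the model grows more confident in the correct class.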

Data

Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST.

.. testcode::
    :skipif: not TORCHVISION_AVAILABLE

    import os
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST

    # transforms standard to MNIST: convert to tensor, then normalize
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])

    # data: pass the transform so each sample is a normalized tensor
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_train = DataLoader(mnist_train, batch_size=64)

.. testoutput::
    :hide:
    :skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not TORCHVISION_AVAILABLE

    Downloading ...
    Extracting ...
    Downloading ...
    Extracting ...
    Downloading ...
    Extracting ...
    Processing...
    Done!

There's nothing special you need to do with PyTorch Lightning! Just pass in the dataloaders to the .fit() function.

model = LitMNIST()
trainer = Trainer()
trainer.fit(model, mnist_train)

DataModules

Defining free-floating dataloaders, splits, download instructions and such can get messy. In this case, it's better to group the full definition of a dataset into a DataModule which includes:

  • Download instructions
  • Processing instructions
  • Split instructions
  • Train dataloader
  • Val dataloader(s)
  • Test dataloader(s)

class MyDataModule(pl.LightningDataModule):

    def __init__(self):
        super().__init__()
        self.train_dims = None
        self.vocab_size = 0

    def prepare_data(self):
        # called only on 1 GPU
        download_dataset()
        tokenize()
        build_vocab()

    def setup(self, stage=None):
        # called on every GPU
        vocab = load_vocab()
        self.vocab_size = len(vocab)

        self.train, self.val, self.test = load_datasets()
        self.train_dims = self.train.next_batch.size()

    def train_dataloader(self):
        # transforms belong on the dataset; DataLoader's second positional argument is batch_size
        return DataLoader(self.train, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64)

Using DataModules allows easier sharing of full dataset definitions.

# use an MNIST dataset
mnist_dm = MNISTDataModule()
model = LitModel(num_classes=mnist_dm.num_classes)
trainer.fit(model, mnist_dm)

# or other datasets with the same model
imagenet_dm = ImagenetDataModule()
model = LitModel(num_classes=imagenet_dm.num_classes)
trainer.fit(model, imagenet_dm)

Note

prepare_data is called on only 1 GPU in distributed training (automatically)

Note

setup is called on every GPU (automatically)

Models defined by data

When your models need to know about the data, it's best to process the data before passing it to the model.

# init dm AND call the processing manually
dm = ImagenetDataModule()
dm.prepare_data()
dm.setup()

model = LitModel(out_features=dm.num_classes, img_width=dm.img_width, img_height=dm.img_height)
trainer.fit(model, dm)

  1. use prepare_data to download and process the dataset.
  2. use setup to do splits, and build your model internals

.. testcode::

    class LitMNIST(LightningModule):

        def __init__(self):
            super().__init__()
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

        def setup(self, stage):
            # stage is either 'fit' or 'test'; 90% of the time it is not relevant
            data = load_data()
            num_classes = data.classes
            self.l1 = nn.Linear(..., num_classes)

Optimizer

Next we choose what optimizer to use for training our system. In PyTorch we do it as follows:

from torch.optim import Adam
optimizer = Adam(LitMNIST().parameters(), lr=1e-3)

In Lightning we do the same but organize it under the configure_optimizers method.

.. testcode::

    class LitMNIST(LightningModule):

        def configure_optimizers(self):
            return Adam(self.parameters(), lr=1e-3)

Note

The LightningModule itself has the parameters, so pass in self.parameters()

However, if you have multiple optimizers, pass each one its matching parameters.

.. testcode::

    class LitMNIST(LightningModule):

        def configure_optimizers(self):
            return Adam(self.generator.parameters(), lr=1e-3), Adam(self.discriminator.parameters(), lr=1e-3)

Training step

The training step is what happens inside the training loop.

for epoch in epochs:
    for batch in data:
        # TRAINING STEP
        # ....
        # TRAINING STEP
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In the case of MNIST we do the following

for epoch in epochs:
    for batch in data:
        # TRAINING STEP START
        x, y = batch
        logits = model(x)
        loss = F.nll_loss(logits, y)
        # TRAINING STEP END

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In Lightning, everything that is in the training step gets organized under the training_step function in the LightningModule

.. testcode::

    class LitMNIST(LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss

Again, this is the same PyTorch code except that it has been organized by the LightningModule. The code is not restricted in any way, which means it can be as complicated as a full seq-2-seq model, an RL loop, a GAN, etc.
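
To make the split between the training step and the surrounding loop concrete, here is a framework-free sketch. Everything in it is hypothetical and written for illustration: ``ToyModule`` stands in for a LightningModule, ``fit`` stands in for the Trainer, and the gradient is computed by hand instead of by autograd.

```python
class ToyModule:
    """Holds the 'research code': one weight and what happens to a single batch."""

    def __init__(self):
        self.w = 0.0

    def training_step(self, batch):
        x, y = batch
        pred = self.w * x
        loss = (pred - y) ** 2
        grad = 2 * (pred - y) * x  # d(loss)/dw, derived by hand
        return loss, grad


def fit(module, data, lr=0.05, epochs=20):
    """Holds the 'engineering code': the loop and the parameter update."""
    for _ in range(epochs):
        for batch in data:
            loss, grad = module.training_step(batch)
            module.w -= lr * grad  # plain gradient descent step
    return module


# samples drawn from y = 2 * x, so the weight should approach 2
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
module = fit(ToyModule(), data)
```

The loop never needs to know what the model computes, which is the same division of labor the LightningModule and the Trainer follow.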

TrainResult

Whenever you'd like more control over the outputs of the training_step use a TrainResult object which can:

  • log to Tensorboard or the other logger of your choice.
  • log to the progress-bar.
  • log on every step.
  • log aggregate epoch metrics.

def training_step(...):
    return loss

    # equivalent
    return pl.TrainResult(loss)

    # log a metric
    result = pl.TrainResult(loss)
    result.log('train_loss', loss)

    # equivalent
    result.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True, reduce_fx=torch.mean)

If you are only using a training loop (training_step) without a validation or test loop (validation_step, test_step), you can still use EarlyStopping or automatic checkpointing:

result = pl.TrainResult(loss, checkpoint_on=loss, early_stop_on=loss)
return result
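
The idea behind checkpointing on a monitored value can be sketched without any framework. The function below is a hypothetical helper, not Lightning's checkpointing code; it assumes lower is better and only reports at which steps a "save when improved" policy would write a checkpoint.

```python
def checkpoint_steps(monitored_values):
    """Return the steps at which the monitored value improved on the best seen so far."""
    best = float("inf")
    saved = []
    for step, value in enumerate(monitored_values):
        if value < best:
            best = value
            saved.append(step)  # a real implementation would write weights to disk here
    return saved


saved_at = checkpoint_steps([1.2, 0.9, 1.0, 0.7, 0.8])
```

With these values the policy saves at steps 0, 1 and 3, and skips the two steps where the monitored value got worse.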

Training

So far we defined 4 key ingredients in pure PyTorch but organized the code with the LightningModule.

  1. Model.
  2. Training data.
  3. Optimizer.
  4. What happens in the training loop.

For clarity, we'll recall that the full LightningModule now looks like this.

class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # using TrainResult to enable logging
        result = pl.TrainResult(loss)
        result.log('train_loss', loss)

        return result

Again, this is the same PyTorch code, except that it's organized by the LightningModule. This organization now lets us train this model.

Train on CPU

from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer()
trainer.fit(model, train_loader)

You should see the following weights summary and progress bar

mnist CPU bar

Logging

When we returned the TrainResult from training_step, the logged value went to the built-in TensorBoard logger. You could also have logged it by calling the logger directly:

def training_step(self, batch, batch_idx):
    # ...
    loss = ...
    self.logger.experiment.add_scalar('loss', loss, self.global_step)

    # equivalent
    result = pl.TrainResult(loss)
    result.log('loss', loss)

This generates TensorBoard logs automatically.

mnist CPU bar

You can also use any of the other loggers we support.

GPU training

But the beauty is all the magic you can do with the trainer flags. For instance, to run this model on a GPU:

model = LitMNIST()
trainer = Trainer(gpus=1)
trainer.fit(model, train_loader)

mnist GPU bar

Multi-GPU training

Or you can also train on multiple GPUs.

model = LitMNIST()
trainer = Trainer(gpus=8)
trainer.fit(model, train_loader)

Or multiple nodes

# (32 GPUs)
model = LitMNIST()
trainer = Trainer(gpus=8, num_nodes=4, distributed_backend='ddp')
trainer.fit(model, train_loader)

Refer to the distributed computing guide for more details.

TPUs

Did you know you can use PyTorch on TPUs? It's very hard to do, but we've worked with the xla team to use their awesome library to get this to work out of the box!

Let's train on Colab (full demo available here)

First, change the runtime to TPU (and reinstall lightning).

mnist GPU bar

Next, install the required xla library (adds support for PyTorch on TPUs)

!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy of this program. This means that without taking any care you will download the dataset N times which will cause all sorts of issues.

To solve this problem, make sure your download code is in the prepare_data method in the DataModule. In this method we do all the preparation we need to do once (instead of on every gpu).

prepare_data can be called in two ways, once per node or only on the root node (Trainer(prepare_data_per_node=False)).
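
The two modes can be pictured as a simple rule over process ranks. This is only a sketch of the idea, not Lightning's internal logic: ``should_prepare_data`` is a hypothetical helper, and it assumes each process knows its global rank and its local rank on its node.

```python
def should_prepare_data(global_rank, local_rank, prepare_data_per_node=True):
    """Decide whether this process is the one that downloads or prepares the data."""
    if prepare_data_per_node:
        # one process per node: the first one on each machine
        return local_rank == 0
    # only the root process of the whole job
    return global_rank == 0


# 2 nodes with 2 GPUs each -> (global_rank, local_rank) pairs
ranks = [(0, 0), (1, 1), (2, 0), (3, 1)]

per_node = [should_prepare_data(g, l, prepare_data_per_node=True) for g, l in ranks]
root_only = [should_prepare_data(g, l, prepare_data_per_node=False) for g, l in ranks]
```

Preparing once per node suits machines with their own local disks, while root-only preparation suits a filesystem shared by all nodes.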

class MNISTDataModule(LightningDataModule):
    def prepare_data(self):
        # download only
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage):
        # transform
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        # keep references to the datasets so they can be split and assigned below
        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)

The prepare_data method is also a good place to do any data processing that needs to be done only once (e.g., downloading or tokenizing).

Note

Lightning inserts the correct DistributedSampler for distributed training. No need to add it yourself!
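
For intuition, a distributed sampler gives each process its own slice of the dataset indices. The sketch below is a simplified, unshuffled approximation written for this guide; ``shard_indices`` is a hypothetical helper, and the real torch sampler also supports shuffling per epoch.

```python
def shard_indices(num_samples, world_size, rank):
    """Return the dataset indices one process would see (no shuffling)."""
    indices = list(range(num_samples))

    # pad by repeating early indices so every rank gets the same number of samples
    per_rank = -(-num_samples // world_size)  # ceiling division
    total = per_rank * world_size
    indices += indices[: total - num_samples]

    # each rank takes every world_size-th index, starting at its own rank
    return indices[rank:total:world_size]


shards = [shard_indices(10, world_size=4, rank=r) for r in range(4)]
```

Here 10 samples across 4 processes gives 3 indices each, with two indices repeated so the shards have equal length.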

Now we can train the LightningModule on a TPU without doing anything else!

dm = MNISTDataModule()
model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model, dm)

You'll now see the TPU cores booting up.

TPU start

Notice the epoch is MUCH faster!

TPU speed


Validating

In most cases, we stop training the model when the loss on a validation split of the data reaches a minimum.
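
A common way to detect that minimum is patience-based early stopping. The function below is a minimal sketch of that idea under stated assumptions (lower is better, one value per epoch); ``early_stop_epoch`` is a hypothetical helper and does not mirror the options of Lightning's EarlyStopping.

```python
def early_stop_epoch(val_losses, patience=3):
    """Return the epoch at which training would stop, or None if it never triggers."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


stop_at = early_stop_epoch([0.9, 0.7, 0.6, 0.61, 0.62, 0.65, 0.5])
```

The loss bottoms out at epoch 2, and after three epochs without improvement the sketch stops at epoch 5. The later value of 0.5 is never seen, which is the usual trade-off when choosing a patience value.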

Just like the training_step, we can define a validation_step to check whatever metrics we care about, generate samples or add more to our logs.

Since the validation_step processes a single batch, use the EvalResult to log metrics for the full epoch.

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)

    result = pl.EvalResult(checkpoint_on=loss)
    result.log('val_loss', loss)

    # equivalent
    result.log('val_loss', loss, prog_bar=False, logger=True, on_step=False, on_epoch=True, reduce_fx=torch.mean)
    return result

Now we can train with a validation loop as well.

from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model, train_loader, val_loader)

You may have noticed the words Validation sanity check logged. This is because Lightning runs 2 batches of validation before starting to train. This is a kind of unit test to make sure that if you have a bug in the validation loop, you won't need to potentially wait a full epoch to find out.
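
The sanity check amounts to running a couple of validation batches before any training happens. Here is a framework-free sketch of that idea; ``fit_with_sanity_check`` and its callables are hypothetical stand-ins, with the parameter named after the Trainer's ``num_sanity_val_steps`` flag.

```python
from itertools import islice


def fit_with_sanity_check(val_batches, validate_batch, train, num_sanity_val_steps=2):
    """Run a few validation batches first so validation bugs surface immediately."""
    for batch in islice(val_batches, num_sanity_val_steps):
        validate_batch(batch)  # an exception here stops us before training starts
    train()


log = []
fit_with_sanity_check(
    val_batches=range(100),
    validate_batch=lambda batch: log.append(("val", batch)),
    train=lambda: log.append("train"),
)
```

Only the first two validation batches are touched before training begins, so a broken validation loop fails within seconds instead of after a full epoch.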

Note

Lightning disables gradients, puts model in eval mode and does everything needed for validation.

Val loop under the hood

Under the hood, Lightning does the following:

model = Model()

for epoch in epochs:
    # train
    model.train()
    torch.set_grad_enabled(True)
    for batch in data:
        # ...
        # train

    # validate
    model.eval()
    torch.set_grad_enabled(False)

    outputs = []
    for batch in val_data:
        x, y = batch                        # validation_step
        y_hat = model(x)                    # validation_step
        loss = loss_fn(y_hat, y)            # validation_step
        outputs.append({'val_loss': loss})  # validation_step

    # validation_epoch_end: average the per-batch losses
    full_loss = torch.stack([out['val_loss'] for out in outputs]).mean()

Optional methods

If you still need even more fine-grained control, define the other optional methods for the loop.

def validation_step(self, batch, batch_idx):
    x, y = batch
    val_step_output = {'step_output': x}
    return val_step_output

def validation_epoch_end(self, val_step_outputs):
    for val_step_output in val_step_outputs:
        # each object here is what you passed back at each validation_step
        pass

Testing

Once our research is done and we're about to publish or deploy a model, we normally want to figure out how it will generalize in the "real world." For this, we use a held-out split of the data for testing.

Just like the validation loop, we define a test loop

class LitMNIST(LightningModule):
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        result = pl.EvalResult()
        result.log('test_loss', loss)
        return result

However, to make sure the test set isn't used inadvertently, Lightning has a separate API to run tests. Once you train your model simply call .test().

from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model)

# run test set
result = trainer.test()
print(result)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

        --------------------------------------------------------------
        TEST RESULTS
        {'test_loss': tensor(1.1703, device='cuda:0')}
        --------------------------------------------------------------

You can also run the test from a saved lightning model

model = LitMNIST.load_from_checkpoint(PATH)
trainer = Trainer(tpu_cores=8)
trainer.test(model)

Note

Lightning disables gradients, puts model in eval mode and does everything needed for testing.

Warning

.test() is not stable yet on TPUs. We're working on getting around the multiprocessing challenges.


Predicting

Again, a LightningModule is exactly the same as a PyTorch module. This means you can load it and use it for prediction.

model = LitMNIST.load_from_checkpoint(PATH)
x = torch.randn(1, 1, 28, 28)
out = model(x)

On the surface, forward and training_step look similar. Generally, forward should define what we want the model to do at prediction time, whereas training_step defines the training logic and likely calls forward from within it.

.. testcode::

    class MNISTClassifier(LightningModule):

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x = self.layer_1(x)
            x = torch.relu(x)
            x = self.layer_2(x)
            x = torch.relu(x)
            x = self.layer_3(x)
            x = torch.log_softmax(x, dim=1)
            return x

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss

model = MNISTClassifier()
x = mnist_image()
logits = model(x)

In this case, we've set this LightningModule to predict logits. But we could also have it predict feature maps:

.. testcode::

    class MNISTRepresentator(LightningModule):

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x = self.layer_1(x)
            x1 = torch.relu(x)
            x = self.layer_2(x1)
            x2 = torch.relu(x)
            x3 = self.layer_3(x2)
            return [x, x1, x2, x3]

        def training_step(self, batch, batch_idx):
            x, y = batch
            _, l1_feats, l2_feats, l3_feats = self(x)
            # layer 3 produces the class scores, so the log-softmax is taken over l3_feats
            logits = torch.log_softmax(l3_feats, dim=1)
            ce_loss = F.nll_loss(logits, y)
            loss = perceptual_loss(l1_feats, l2_feats, l3_feats) + ce_loss
            return loss

model = MNISTRepresentator.load_from_checkpoint(PATH)
x = mnist_image()
feature_maps = model(x)

Or maybe we have a model that we use to do generation

.. testcode::

    class LitMNISTDreamer(LightningModule):

        def forward(self, z):
            imgs = self.decoder(z)
            return imgs

        def training_step(self, batch, batch_idx):
            x, y = batch
            representation = self.encoder(x)
            imgs = self(representation)

            loss = perceptual_loss(imgs, x)
            return loss

model = LitMNISTDreamer.load_from_checkpoint(PATH)
z = sample_noise()
generated_imgs = model(z)

How you split up what goes in forward vs training_step depends on how you want to use this model for prediction.


Extensibility

Although lightning makes everything super simple, it doesn't sacrifice any flexibility or control. Lightning offers multiple ways of managing the training state.

Training overrides

Any part of the training, validation and testing loop can be modified. For instance, if you wanted to do your own backward pass, you would override the default implementation

.. testcode::

    def backward(self, use_amp, loss, optimizer, optimizer_idx):
        loss.backward()

With your own

.. testcode::

    class LitMNIST(LightningModule):

        def backward(self, use_amp, loss, optimizer, optimizer_idx):
            # do a custom way of backward
            loss.backward(retain_graph=True)

Or if you wanted to initialize ddp in a different way than the default one

.. testcode::

    def configure_ddp(self, model, device_ids):
        # Lightning DDP simply routes to test_step, val_step, etc...
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        return model

you could do your own:

.. testcode::

    class LitMNIST(LightningModule):

        def configure_ddp(self, model, device_ids):

            model = Horovod(model)
            # model = Ray(model)
            return model

Every single part of training is configurable this way. For a full list look at LightningModule.
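
The mechanism behind these overrides is ordinary method overriding: the loop calls a hook, and whichever class defines the hook decides what happens. A minimal sketch, using toy classes that are hypothetical and unrelated to Lightning's actual classes:

```python
class BaseModule:
    """Toy stand-in for a module with an overridable hook."""

    def backward(self, loss, log):
        # default behaviour
        log.append(f"default backward on {loss}")

    def run_step(self, loss, log):
        # the loop only ever calls the hook; it does not care who implements it
        self.backward(loss, log)


class CustomModule(BaseModule):
    def backward(self, loss, log):
        # user-supplied behaviour replaces the default
        log.append(f"custom backward on {loss}")


log = []
BaseModule().run_step(0.5, log)
CustomModule().run_step(0.5, log)
```

The loop code in ``run_step`` is identical in both cases; only the hook implementation changes.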


Callbacks

Another way to add arbitrary functionality is to add a custom callback for hooks that you might care about

.. testcode::

    from pytorch_lightning.callbacks import Callback

    class MyPrintingCallback(Callback):

        def on_init_start(self, trainer):
            print('Starting to init trainer!')

        def on_init_end(self, trainer):
            print('Trainer is init now')

        def on_train_end(self, trainer, pl_module):
            print('do something when training ends')

And pass the callbacks into the trainer

.. testcode::

    trainer = Trainer(callbacks=[MyPrintingCallback()])

.. testoutput::
    :hide:

    Starting to init trainer!
    Trainer is init now

Note

See full list of 12+ hooks in the :ref:`callbacks`.
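
To picture how a trainer can dispatch to callbacks, here is a toy sketch. ``ToyCallback`` and ``ToyTrainer`` are hypothetical classes written for this example; they only mimic the general pattern of no-op hooks on a base class plus a loop that calls them by name.

```python
class ToyCallback:
    """Base class with no-op hooks, so subclasses override only what they need."""

    def on_train_start(self, trainer):
        pass

    def on_train_end(self, trainer):
        pass


class RecordingCallback(ToyCallback):
    def __init__(self):
        self.events = []

    def on_train_end(self, trainer):
        self.events.append("do something when training ends")


class ToyTrainer:
    def __init__(self, callbacks=()):
        self.callbacks = list(callbacks)

    def _call_hook(self, name):
        # invoke the named hook on every registered callback, in order
        for callback in self.callbacks:
            getattr(callback, name)(self)

    def fit(self):
        self._call_hook("on_train_start")
        # ... the training loop would run here ...
        self._call_hook("on_train_end")


callback = RecordingCallback()
ToyTrainer(callbacks=[callback]).fit()
```

Because the base class supplies empty hooks, a callback that implements a single method still works with every hook the trainer calls.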