
How to Organize PyTorch Into Lightning
=======================================

To enable your code to work with Lightning, organize your PyTorch code into the following steps.


1. Keep Your Computational Code
--------------------------------

Keep your regular ``nn.Module`` architecture:

.. testcode::

    import pytorch_lightning as pl
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class LitModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x


2. Configure Training Logic
----------------------------

In the ``training_step`` of the ``LightningModule``, configure how your training routine behaves with a batch of training data:

.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            loss = F.cross_entropy(y_hat, y)
            return loss
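If you also want to track the training loss, you can log it from ``training_step`` with ``self.log`` (the same call used for validation below). A minimal sketch; the logged name ``"train_loss"`` is chosen here for illustration:

.. code-block:: python

    class LitModel(pl.LightningModule):
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            loss = F.cross_entropy(y_hat, y)
            # log to the configured logger; prog_bar=True also shows it on the progress bar
            self.log("train_loss", loss, prog_bar=True)
            return loss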

.. note::

    If you need to fully own the training loop for complicated legacy projects, check out :doc:`Own your loop <../model/own_your_loop>`.


3. Move Optimizer(s) and LR Scheduler(s)
-----------------------------------------

Move your optimizers to the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` hook.

.. testcode::

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]
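If a scheduler should step based on a logged metric (for example :class:`~torch.optim.lr_scheduler.ReduceLROnPlateau`), ``configure_optimizers`` can also return a dictionary. A minimal sketch, assuming a ``"val_loss"`` metric is logged as in the validation step below:

.. code-block:: python

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
            # "monitor" tells Lightning which logged metric the scheduler should watch
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
            }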


4. Organize Validation Logic (optional)
----------------------------------------

If you need a validation loop, configure how your validation routine behaves with a batch of validation data:

.. testcode::

    class LitModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            val_loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", val_loss)

.. tip::

    ``trainer.validate()`` loads the best checkpoint automatically by default if checkpointing was enabled during fitting.


5. Organize Testing Logic (optional)
-------------------------------------

If you need a test loop, configure how your testing routine behaves with a batch of test data:

.. testcode::

    class LitModel(pl.LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            test_loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", test_loss)


6. Configure Prediction Logic (optional)
-----------------------------------------

If you need a prediction loop, configure how your prediction routine behaves with a batch of data:

.. testcode::

    class LitModel(pl.LightningModule):
        def predict_step(self, batch, batch_idx):
            x, y = batch
            pred = self.encoder(x)
            return pred


7. Remove any .cuda() or .to(device) Calls
-------------------------------------------

Your :doc:`LightningModule <../common/lightning_module>` can automatically run on any hardware!

If you have any explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them, since Lightning makes sure that the data coming from :class:`~torch.utils.data.DataLoader` and all the :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` are moved to the respective devices automatically. If you still need to access the current device, you can use ``self.device`` anywhere in your LightningModule except in the ``__init__`` and ``setup`` methods.

.. testcode::

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            z = torch.randn(4, 5, device=self.device)
            ...

.. hint::

    If you are initializing a :class:`~torch.Tensor` within the ``LightningModule.__init__`` method and want it to be moved to the device automatically, you should call :meth:`~torch.nn.Module.register_buffer` to register it as a buffer.

.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, num_features):
            super().__init__()
            self.register_buffer("running_mean", torch.zeros(num_features))


8. Use your own data
---------------------

Regular PyTorch DataLoaders work with Lightning. For more modular and scalable datasets, check out :doc:`LightningDataModule <../data/datamodule>`.
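For example, a model from the steps above can be trained directly from plain DataLoaders. A minimal sketch; the random ``TensorDataset``, batch size, and ``max_epochs`` below are placeholders rather than part of the original example:

.. code-block:: python

    from torch.utils.data import DataLoader, TensorDataset

    # placeholder random data standing in for your real dataset
    train_set = TensorDataset(torch.randn(256, 28 * 28), torch.randint(0, 10, (256,)))
    val_set = TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,)))

    model = LitModel(encoder)  # "encoder" is the nn.Module from step 1
    trainer = pl.Trainer(max_epochs=3)
    trainer.fit(
        model,
        train_dataloaders=DataLoader(train_set, batch_size=32),
        val_dataloaders=DataLoader(val_set, batch_size=32),
    )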


Good to know
------------

Additionally, you can run only the validation loop using the :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` method.

.. code-block:: python

    model = LitModel()
    trainer.validate(model)

.. note::

    ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.

The test loop isn't used within :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`, so you need to call :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` explicitly.

.. code-block:: python

    model = LitModel()
    trainer.test(model)

.. note::

    ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

.. tip::

    ``trainer.test()`` loads the best checkpoint automatically by default if checkpointing is enabled.

The predict loop will not be used until you call :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

.. code-block:: python

    model = LitModel()
    trainer.predict(model)
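In practice you usually also pass the data to predict on; ``predict_loader`` below is a placeholder :class:`~torch.utils.data.DataLoader`, not part of the original snippet:

.. code-block:: python

    predictions = trainer.predict(model, dataloaders=predict_loader)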

.. note::

    ``model.eval()`` and ``torch.no_grad()`` are called automatically for prediction.

.. tip::

    ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled.