To enable your code to work with Lightning, organize your PyTorch code as follows.
Keep your regular ``nn.Module`` architecture:
.. testcode::

    import pytorch_lightning as pl
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class LitModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x
In the ``training_step`` of the ``LightningModule``, configure how your training routine behaves with a batch of training data:
.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            loss = F.cross_entropy(y_hat, y)
            return loss
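Once ``training_step`` is defined, the :class:`~pytorch_lightning.trainer.trainer.Trainer` runs the training loop for you. A minimal sketch of how the pieces connect, assuming ``encoder`` is an ``nn.Module`` such as the one defined above and ``train_loader`` is a :class:`~torch.utils.data.DataLoader` you provide:

.. code-block:: python

    # Sketch: ``encoder`` and ``train_loader`` are placeholders you supply.
    model = LitModel(encoder)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, train_dataloaders=train_loader)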
.. note::

    If you need to fully own the training loop for complicated legacy projects, check out :doc:`Own your loop <../model/own_your_loop>`.
Move your optimizers to the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` hook:
.. testcode::

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]
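``configure_optimizers`` also accepts other return formats, such as a dictionary that bundles the optimizer with scheduler options. A sketch of one such form:

.. code-block:: python

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            # step the scheduler once per epoch (the default interval)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "epoch"},
            }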
If you need a validation loop, configure how your validation routine behaves with a batch of validation data:
.. testcode::

    class LitModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            val_loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", val_loss)
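When a validation dataloader is provided, Lightning runs ``validation_step`` automatically during ``fit``; by default, validation runs at the end of every training epoch. A minimal sketch, assuming ``train_loader`` and ``val_loader`` are dataloaders you provide:

.. code-block:: python

    # Sketch: validation runs at the end of each training epoch by default.
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)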
.. tip::

    ``trainer.validate()`` loads the best checkpoint automatically by default if checkpointing was enabled during fitting.
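For example, after fitting with a :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback enabled, you can also request the best checkpoint explicitly; a sketch:

.. code-block:: python

    # Sketch: "best" resolves to the best checkpoint saved during fitting.
    trainer.validate(ckpt_path="best")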
If you need a test loop, configure how your testing routine behaves with a batch of test data:
.. testcode::

    class LitModel(pl.LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            test_loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", test_loss)
If you need a prediction loop, configure how your prediction routine behaves with a batch of data:
.. testcode::

    class LitModel(pl.LightningModule):
        def predict_step(self, batch, batch_idx):
            x, y = batch
            pred = self.encoder(x)
            return pred
Your :doc:`LightningModule <../common/lightning_module>` can automatically run on any hardware!
If you have any explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them, since Lightning makes sure that the data coming from :class:`~torch.utils.data.DataLoader` and all the :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` are moved to the respective devices automatically.
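For example, a manual transfer you would write in plain PyTorch becomes unnecessary; a sketch:

.. code-block:: python

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            # plain PyTorch would need: x = x.to(device)
            x, y = batch  # already on the right device; no .cuda() or .to() needed
            ...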
If you still need to access the current device, you can use ``self.device`` anywhere in your ``LightningModule`` except in the ``__init__`` and ``setup`` methods.
.. testcode::

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            z = torch.randn(4, 5, device=self.device)
            ...
.. hint::

    If you are initializing a :class:`~torch.Tensor` within the ``LightningModule.__init__`` method and want it to be moved to the device automatically, you should call :meth:`~torch.nn.Module.register_buffer` to register it as a buffer.
.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, num_features):
            super().__init__()
            # buffers move with the module across devices but are not trained
            self.register_buffer("running_mean", torch.zeros(num_features))
Regular PyTorch DataLoaders work with Lightning. For more modular and scalable datasets, check out :doc:`LightningDataModule <../data/datamodule>`.
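A minimal sketch, assuming ``dataset`` is any map-style :class:`~torch.utils.data.Dataset` you have constructed:

.. code-block:: python

    from torch.utils.data import DataLoader

    # Sketch: a plain DataLoader is passed straight to the Trainer.
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    trainer.fit(model, train_dataloaders=train_loader)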
Additionally, you can run only the validation loop using the :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` method:
.. code-block:: python

    model = LitModel()
    trainer = pl.Trainer()
    trainer.validate(model)
.. note::

    ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.
The test loop isn't used within :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`; you therefore need to call :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` explicitly:
.. code-block:: python

    model = LitModel()
    trainer = pl.Trainer()
    trainer.test(model)
.. note::

    ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.
.. tip::

    ``trainer.test()`` loads the best checkpoint automatically by default if checkpointing is enabled.
The predict loop will not be used until you call :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`:
.. code-block:: python

    model = LitModel()
    trainer = pl.Trainer()
    predictions = trainer.predict(model)
.. note::

    ``model.eval()`` and ``torch.no_grad()`` are called automatically for prediction.
.. tip::

    ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled.