.. testsetup:: *

    import torch
    from torch.nn import functional as F
    import pytorch_lightning as pl
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.trainer.trainer import Trainer
To enable your code to work with Lightning, here's how to organize your PyTorch code into Lightning.
Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.LightningModule`.
.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x
Move your optimizers to the :func:`pytorch_lightning.core.LightningModule.configure_optimizers` hook. Make sure to use the hook parameters (``self`` in this case).
.. testcode::

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer
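If your training uses a learning-rate scheduler, ``configure_optimizers`` can also return one alongside the optimizer. A minimal sketch, assuming a ``StepLR`` scheduler (the scheduler choice and its parameters here are illustrative, not prescribed by Lightning):

.. code-block:: python

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            # Illustrative schedule: decay the learning rate by 10x every 10 epochs
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
            # Lightning accepts a list of optimizers and a list of schedulers
            return [optimizer], [scheduler]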
Lightning automates most of the training loop for you (the epoch and batch iterations); all you need to keep is the training step logic. This goes into the :func:`pytorch_lightning.core.LightningModule.training_step` hook (make sure to use the hook parameters, ``self`` in this case):
.. testcode::

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss
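Once the training step is defined, the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls it for every batch. A minimal sketch of wiring it up, assuming an MNIST ``DataLoader`` (the dataset and transform here are illustrative; any ``DataLoader`` works):

.. code-block:: python

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    # Illustrative dataset; matches the 28 * 28 input size of LitModel
    dataset = MNIST(".", train=True, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset, batch_size=32)

    model = LitModel()
    trainer = Trainer(max_epochs=1)
    # The Trainer runs the epoch/batch loops and calls training_step on each batch
    trainer.fit(model, train_loader)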
To add an (optional) validation loop, add logic to the :func:`pytorch_lightning.core.LightningModule.validation_step` hook (make sure to use the hook parameters, ``self`` in this case).
.. testcode::

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            return val_loss
.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.
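The validation loop runs during fitting once a validation ``DataLoader`` is supplied. A minimal sketch, assuming ``train_loader`` and ``val_loader`` are already built (hypothetical names):

.. code-block:: python

    model = LitModel()
    trainer = Trainer()
    # Lightning runs validation_step over val_loader batches during fit
    trainer.fit(model, train_loader, val_loader)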
To add an (optional) test loop, add logic to the :func:`pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, ``self`` in this case).
.. testcode::

    class LitModel(pl.LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss
.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.
The test loop will not be used until you call:

.. code-block:: python

    trainer.test()
.. note:: ``.test()`` loads the best checkpoint automatically.
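Putting it together, a minimal sketch assuming a ``test_loader`` ``DataLoader`` (a hypothetical name; the keyword for passing dataloaders to ``trainer.test`` has varied across Lightning versions):

.. code-block:: python

    trainer = Trainer()
    trainer.fit(model, train_loader, val_loader)
    # Runs test_step over the test set; the best checkpoint is loaded first.
    # Older Lightning releases use test_dataloaders= instead of dataloaders=.
    trainer.test(dataloaders=test_loader)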
Your :class:`~pytorch_lightning.core.LightningModule` can automatically run on any hardware!
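For example, moving training to GPUs is just a :class:`~pytorch_lightning.trainer.trainer.Trainer` flag; the module itself is unchanged. A minimal sketch, assuming at least one GPU is available (the flag spelling has varied across Lightning versions):

.. code-block:: python

    # The LightningModule code stays exactly the same
    trainer = Trainer(gpus=1)  # newer versions: Trainer(accelerator="gpu", devices=1)
    trainer.fit(model, train_loader)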