.. testsetup:: *

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.trainer.trainer import Trainer

How to organize PyTorch into Lightning

To make your existing PyTorch code work with Lightning, reorganize it into a :class:`~pytorch_lightning.core.LightningModule` by following the steps below.

1. Move your computational code

Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.LightningModule`.

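.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            # flatten each input to a vector before the linear layers
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x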
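As a quick sanity check of the forward pass, the sketch below runs the same two layers on a dummy batch using a plain ``torch.nn.Module``. The class name ``PlainModel`` and the MNIST-like input shape ``(4, 1, 28, 28)`` are assumptions made for illustration; the ``__init__`` and ``forward`` bodies are the parts that move into the ``LightningModule``.

```python
import torch
import torch.nn.functional as F


class PlainModel(torch.nn.Module):  # hypothetical pre-Lightning model
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)    # flatten (N, 1, 28, 28) -> (N, 784)
        x = F.relu(self.layer_1(x))
        return self.layer_2(x)       # raw logits, shape (N, 10)


dummy_batch = torch.randn(4, 1, 28, 28)  # assumed MNIST-sized images
logits = PlainModel()(dummy_batch)
print(logits.shape)  # torch.Size([4, 10])
```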

2. Move the optimizer(s) and schedulers

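Move your optimizers and learning-rate schedulers into the :func:`pytorch_lightning.core.LightningModule.configure_optimizers` hook. Use the hook's parameters (``self`` in this case) rather than outside variables, for example ``self.parameters()`` instead of ``model.parameters()``.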

.. code-block:: python

    class LitModel(pl.LightningModule):

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer
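The example above returns only an optimizer. One of the return forms ``configure_optimizers`` accepts is a pair of lists, optimizers first and schedulers second; the sketch below illustrates that shape. The ``StepLR`` schedule and its settings are assumptions chosen for illustration, and the class subclasses ``torch.nn.Module`` only so the snippet runs without Lightning installed; in real code it would subclass ``pl.LightningModule``.

```python
import torch


class LitModelSketch(torch.nn.Module):  # stand-in; subclass pl.LightningModule in practice
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # Assumed schedule: multiply the learning rate by 0.1 every 10 epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        # Two lists: optimizers first, then learning-rate schedulers.
        return [optimizer], [scheduler]


optimizers, schedulers = LitModelSketch().configure_optimizers()
```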

3. Find the train loop "meat"

Lightning automates most of the training for you, including the epoch and batch iterations. All you need to keep is the logic for a single training step. Move it into the :func:`pytorch_lightning.core.LightningModule.training_step` hook, and make sure to use the hook parameters (``self`` in this case):

.. code-block:: python

    class LitModel(pl.LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss
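To make the "meat" easier to spot, here is a sketch of a typical hand-written PyTorch training loop with comments marking which lines become ``training_step`` and which ones Lightning takes over. The single linear layer and the random data are stand-ins assumed for illustration.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
model = torch.nn.Linear(28 * 28, 10)  # stand-in for the model from step 1
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Random stand-in data: 32 flattened "images" with integer class labels.
dataset = TensorDataset(torch.randn(32, 28 * 28), torch.randint(0, 10, (32,)))
train_loader = DataLoader(dataset, batch_size=8)

losses = []
for epoch in range(2):                                # handled by Lightning
    for batch_idx, batch in enumerate(train_loader):  # handled by Lightning
        # ---- keep: this becomes the body of training_step ----
        x, y = batch
        y_hat = model(x)  # becomes self(x)
        loss = F.cross_entropy(y_hat, y)
        # ---- end of training_step ----
        optimizer.zero_grad()  # handled by Lightning
        loss.backward()        # handled by Lightning
        optimizer.step()       # handled by Lightning
        losses.append(loss.item())
```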

4. Find the val loop "meat"

To add an (optional) validation loop, put the logic into the :func:`pytorch_lightning.core.LightningModule.validation_step` hook (make sure to use the hook parameters, ``self`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            return val_loss

Note

``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.
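For comparison, this is roughly what a hand-written validation step has to do when those two calls are not automated. It is a minimal sketch in plain PyTorch; the linear model and random batch are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(28 * 28, 10)  # stand-in model
x = torch.randn(8, 28 * 28)           # stand-in validation batch
y = torch.randint(0, 10, (8,))

model.eval()           # switch layers such as dropout to inference behavior
with torch.no_grad():  # do not record operations for backpropagation
    val_loss = F.cross_entropy(model(x), y)
model.train()          # a manual loop must switch back before training resumes

print(val_loss.requires_grad)  # False
```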

5. Find the test loop "meat"

To add an (optional) test loop, put the logic into the :func:`pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, ``self`` in this case).

.. code-block:: python

    class LitModel(pl.LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

Note

model.eval() and torch.no_grad() are called automatically for testing.

The test loop is not run until you explicitly call:

.. code-block:: python

    trainer.test()

Note

``.test()`` loads the best checkpoint automatically.

6. Remove any .cuda() or .to(device) calls

Your :class:`~pytorch_lightning.core.LightningModule` can automatically run on any hardware, so explicit device placement in your model code is no longer needed.
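When a step creates new tensors, a common device-agnostic pattern is to place them on the device of a tensor you already have instead of hard-coding ``.cuda()``. The helper ``add_noise`` below is hypothetical and written in plain PyTorch; inside a ``LightningModule``, ``self.device`` can serve the same purpose.

```python
import torch


def add_noise(x: torch.Tensor) -> torch.Tensor:
    # Hard-coded placement such as torch.randn(...).cuda() ties the code to a GPU.
    # Creating the tensor on x's device keeps the code hardware-agnostic.
    noise = torch.randn(x.shape, device=x.device, dtype=x.dtype)
    return x + noise


x = torch.zeros(2, 3)  # on the CPU here; the function works unchanged for GPU tensors
out = add_noise(x)
```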