2 changes: 1 addition & 1 deletion docs/source/child_modules.rst
@@ -3,7 +3,7 @@ Child Modules
Research projects tend to test different approaches to the same dataset.
This is very easy to do in Lightning with inheritance.

For example, imaging we now want to train an Autoencoder to use as a feature extractor for MNIST images.
For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
Recall that `LitMNIST` already defines all the dataloading, etc. The only things
that change in the `Autoencoder` model are the init, forward, training, validation and test steps.

66 changes: 61 additions & 5 deletions docs/source/introduction_guide.rst
@@ -212,6 +212,50 @@ Notice the code is exactly the same, except now the training dataloading has been
under the `train_dataloader` method. This is great because if you run into a project that uses Lightning and want
to figure out how they prepare their training data you can just look in the `train_dataloader` method.

Usually, though, we want to separate data-processing steps that write to disk (such as downloads)
from steps like transforms, which happen in memory.

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def prepare_data(self):
            # download only
            MNIST(os.getcwd(), train=True, download=True)

        def train_dataloader(self):
            # no download, just transform
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_train = MNIST(os.getcwd(), train=True, download=False,
                                transform=transform)
            return DataLoader(mnist_train, batch_size=64)

Downloading in the `prepare_data` method ensures that when you train on
multiple GPUs you won't download or overwrite the data once per process.
This is a contrived example, but it gets more complicated with things like NLP corpora or ImageNet.
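As a plain-Python sketch (not Lightning's actual implementation), the "write to disk once" idea looks like this. The `downloaded.flag` marker file and both function bodies are made up for illustration:

.. code-block:: python

    import os
    import tempfile

    def prepare_data(data_dir):
        # runs once, before any distributed workers start, so disk
        # writes like downloads cannot race or be repeated per GPU
        marker = os.path.join(data_dir, 'downloaded.flag')  # hypothetical marker file
        if not os.path.exists(marker):
            # ... download the dataset here ...
            open(marker, 'w').close()

    def train_dataloader(data_dir):
        # runs in every process: in-memory work only (transforms,
        # Dataset and DataLoader construction), no disk writes
        assert os.path.exists(os.path.join(data_dir, 'downloaded.flag'))
        return 'DataLoader would be built here'

    data_dir = tempfile.mkdtemp()
    prepare_data(data_dir)
    prepare_data(data_dir)  # safe to call again: the marker short-circuits the download
    train_dataloader(data_dir)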

In general, fill these methods with the following:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def prepare_data(self):
            # stuff here is done once at the very beginning of training
            # before any distributed training starts

            # download stuff
            # save to disk
            # etc...
            ...

        def train_dataloader(self):
            # data transforms
            # dataset creation
            # return a DataLoader
            ...



Optimizer
^^^^^^^^^

@@ -606,11 +650,11 @@ metrics we care about, generate samples or add more to our logs.
loss = loss(y_hat, x) # validation_step
outputs.append({'val_loss': loss}) # validation_step

full_loss = outputs.mean() # validation_end
full_loss = outputs.mean() # validation_epoch_end

Since the `validation_step` processes a single batch,
in Lightning we also have a `validation_end` method which allows you to compute
statistics on the full dataset and not just the batch.
in Lightning we also have a `validation_epoch_end` method, which lets you compute
statistics over the full validation set at the end of an epoch, not just over a single batch.
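Lightning collects the dict returned by each `validation_step` into a list and hands that list to `validation_epoch_end`. As a plain-Python sketch of that reduction (the function bodies here are illustrative, not Lightning internals):

.. code-block:: python

    def validation_step(batch_loss):
        # in Lightning, each validation_step returns a dict of
        # metrics computed on one batch
        return {'val_loss': batch_loss}

    def validation_epoch_end(outputs):
        # outputs is the list of every validation_step's return value;
        # reduce the per-batch values into one epoch-level statistic
        avg = sum(o['val_loss'] for o in outputs) / len(outputs)
        return {'avg_val_loss': avg}

    outputs = [validation_step(loss) for loss in (0.9, 1.1, 1.0)]
    print(validation_epoch_end(outputs))  # {'avg_val_loss': 1.0}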

In addition, we define a `val_dataloader` method which tells the trainer what data to use for validation.
Notice we split the train split of MNIST into train and validation sets. We also have to make sure to do the
@@ -640,7 +684,7 @@ sample split in the `train_dataloader` method.
return mnist_val

Again, we've just organized the regular PyTorch code into two steps, the `validation_step` method which
operates on a single batch and the `validation_end` method to compute statistics on all batches.
operates on a single batch and the `validation_epoch_end` method to compute statistics on all batches.

If you have these methods defined, Lightning will call them automatically. Now we can train
while checking the validation set.
@@ -669,7 +713,7 @@ how it will generalize in the "real world." For this, we use a held-out split of
Just like the validation loop, we define exactly the same steps for testing:

- test_step
- test_end
- test_epoch_end
- test_dataloader

.. code-block:: python
@@ -707,6 +751,17 @@ Once you train your model simply call `.test()`.
# run test set
trainer.test()

.. rst-class:: sphx-glr-script-out

Out:

.. code-block:: none

--------------------------------------------------------------
TEST RESULTS
{'test_loss': tensor(1.1703, device='cuda:0')}
--------------------------------------------------------------

You can also run the test from a saved Lightning model:

.. code-block:: python
@@ -881,6 +936,7 @@ you could do your own:
Every single part of training is configurable this way.
For a full list, look at `LightningModule <lightning-module.rst>`_.

---------

Callbacks
---------