Skip to content

Commit

Permalink
ref: inner train loop (intermediate step) 16/n (#3375)
Browse files Browse the repository at this point in the history
* ref: inner train loop (intermediate step) 16/n

* ref: inner train loop (intermediate step) 16/n

* ref: inner train loop (intermediate step) 16/n

* ref: inner train loop (intermediate step) 16/n

* ref: inner train loop (intermediate step) 16/n

* ref: inner train loop (intermediate step) 16/n
  • Loading branch information
williamFalcon committed Sep 7, 2020
1 parent bce5c81 commit 69e3f90
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 159 deletions.
150 changes: 0 additions & 150 deletions pytorch_lightning/trainer/training_loop.py
Expand Up @@ -12,156 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The lightning training loop handles everything except the actual computations of your model.
To decide what will happen in your training loop, define the `training_step` function.
Below are all the things lightning automates for you in the training loop.
Accumulated gradients
---------------------
Accumulated gradients runs K small batches of size N before doing a backwards pass.
The effect is a large effective batch size of size KxN.
.. code-block:: python
# DEFAULT (ie: no accumulated grads)
trainer = Trainer(accumulate_grad_batches=1)
Force training for min or max epochs
------------------------------------
It can be useful to force training for a minimum number of epochs or limit to a max number
.. code-block:: python
# DEFAULT
trainer = Trainer(min_epochs=1, max_epochs=1000)
Force disable early stop
------------------------
To disable early stopping pass None to the early_stop_callback
.. code-block:: python
# DEFAULT
trainer = Trainer(early_stop_callback=None)
Gradient Clipping
-----------------
Gradient clipping may be enabled to avoid exploding gradients.
Specifically, this will `clip the gradient norm computed over all model parameters
`together <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_.
.. code-block:: python
# DEFAULT (ie: don't clip)
trainer = Trainer(gradient_clip_val=0)
# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)
Inspect gradient norms
----------------------
Looking at grad norms can help you figure out where training might be going wrong.
.. code-block:: python
# DEFAULT (-1 doesn't track norms)
trainer = Trainer(track_grad_norm=-1)
# track the LP norm (P=2 here)
trainer = Trainer(track_grad_norm=2)
Set how much of the training set to check
-----------------------------------------
If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.
limit_train_batches will be overwritten by overfit_batches if `overfit_batches > 0`
.. code-block:: python
# DEFAULT
trainer = Trainer(limit_train_batches=1.0)
# check 10% only
trainer = Trainer(limit_train_batches=0.1)
# check 10 batches only
trainer = Trainer(limit_train_batches=10)
Packed sequences as inputs
--------------------------
When using PackedSequence, do 2 things:
1. return either a padded tensor in dataset or a list of variable length tensors
in the dataloader collate_fn (example above shows the list implementation).
2. Pack the sequence in forward or training and validation steps depending on use case.
.. code-block:: python
# For use in dataloader
def collate_fn(batch):
x = [item[0] for item in batch]
y = [item[1] for item in batch]
return x, y
# In module
def training_step(self, batch, batch_idx):
x = rnn.pack_sequence(batch[0], enforce_sorted=False)
y = rnn.pack_sequence(batch[1], enforce_sorted=False)
Truncated Backpropagation Through Time
--------------------------------------
There are times when multiple backwards passes are needed for each batch.
For example, it may save memory to use Truncated Backpropagation Through Time when training RNNs.
When this flag is enabled each batch is split into sequences of size truncated_bptt_steps
and passed to training_step(...) separately. A default splitting function is provided,
however, you can override it for more flexibility. See `tbptt_split_batch`.
.. code-block:: python
# DEFAULT (single backwards pass per batch)
trainer = Trainer(truncated_bptt_steps=None)
# (split batch into sequences of size 2)
trainer = Trainer(truncated_bptt_steps=2)
NaN detection and intervention
------------------------------
When the `terminate_on_nan` flag is enabled, after every forward pass during training, Lightning will
check that
1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
2. the model parameters have finite values.
Lightning will terminate the training loop with an error message if NaN or infinite
values are detected. If this happens, you should investigate numerically unstable operations
in your model.
.. code-block:: python
# DEFAULT (won't perform the NaN check)
trainer = Trainer(terminate_on_nan=False)
# (NaN check each batch and terminate on NaN or infinite values)
trainer = Trainer(terminate_on_nan=True)
"""
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List

from torch.utils.data import DataLoader
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
Expand All @@ -171,12 +24,9 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.utilities.debugging import InternalDebugger


class TrainerTrainLoopMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
on_gpu: bool
use_horovod: bool
check_val_every_n_epoch: ...
Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/test_trainer_steps_dict_return.py
Expand Up @@ -47,7 +47,8 @@ def test_training_step_dict(tmpdir):
assert pbar_metrics['pbar_acc2'] == 19.0

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


Expand Down
12 changes: 8 additions & 4 deletions tests/trainer/test_trainer_steps_result_return.py
Expand Up @@ -84,7 +84,8 @@ def test_training_step_result_log_step_only(tmpdir):
assert f'step_log_acc2_b{batch_idx}' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


Expand Down Expand Up @@ -158,7 +159,8 @@ def test_training_step_result_log_epoch_only(tmpdir):
assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


Expand Down Expand Up @@ -293,7 +295,8 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert 'epoch_step_epoch_log_acc2' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


Expand Down Expand Up @@ -372,7 +375,8 @@ def test_training_step_epoch_end_result(tmpdir):
assert 'epoch_step_epoch_log_acc2' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


Expand Down
12 changes: 8 additions & 4 deletions tests/trainer/test_trainer_steps_scalar_return.py
Expand Up @@ -43,7 +43,8 @@ def test_training_step_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


Expand Down Expand Up @@ -80,7 +81,8 @@ def training_step_scalar_with_step_end(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


Expand Down Expand Up @@ -127,7 +129,8 @@ def test_full_training_loop_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


Expand Down Expand Up @@ -170,5 +173,6 @@ def test_train_step_epoch_end_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171

0 comments on commit 69e3f90

Please sign in to comment.