Remove memory-retaining epoch-end hooks #16520

Merged
merged 15 commits into from Feb 6, 2023
16 changes: 12 additions & 4 deletions docs/source-pytorch/accelerators/accelerator_prepare.rst
@@ -105,19 +105,27 @@ Note if you use any built in metrics or custom metrics that use `TorchMetrics <h

It is possible to perform some computation manually and log the reduced result on rank 0 as follows:

.. testcode::
.. code-block:: python

def __init__(self):
super().__init__()
self.outputs = []


def test_step(self, batch, batch_idx):
x, y = batch
tensors = self(x)
self.outputs.append(tensors)
return tensors


def test_epoch_end(self, outputs):
mean = torch.mean(self.all_gather(outputs))
def on_test_epoch_end(self):
mean = torch.mean(self.all_gather(self.outputs))
self.outputs.clear() # free memory

# When logging only on rank 0, don't forget to add
# ``rank_zero_only=True`` to avoid deadlocks on synchronization.
# `rank_zero_only=True` to avoid deadlocks on synchronization.
# caveat: monitoring this is unimplemented. see https://github.com/Lightning-AI/lightning/issues/15852
if self.trainer.is_global_zero:
self.log("my_reduced_metric", mean, rank_zero_only=True)

141 changes: 47 additions & 94 deletions docs/source-pytorch/common/lightning_module.rst
@@ -182,10 +182,8 @@ Under the hood, Lightning does the following (pseudocode):
model.train()
torch.set_grad_enabled(True)

outs = []
for batch_idx, batch in enumerate(train_dataloader):
loss = training_step(batch, batch_idx)
outs.append(loss.detach())

# clear gradients
optimizer.zero_grad()
@@ -214,7 +212,7 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~pytorch_l
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss

The :meth:`~pytorch_lightning.core.module.LightningModule.log` object automatically reduces the
The :meth:`~pytorch_lightning.core.module.LightningModule.log` method automatically reduces the
requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:

.. code-block:: python
@@ -223,59 +221,45 @@ requested metrics across a complete epoch and devices. Here's the pseudocode of
for batch_idx, batch in enumerate(train_dataloader):
# forward
loss = training_step(batch, batch_idx)
outs.append(loss)
outs.append(loss.detach())

# clear gradients
optimizer.zero_grad()

# backward
loss.backward()

# update parameters
optimizer.step()

epoch_metric = torch.mean(torch.stack([x for x in outs]))
# note: in reality, we do this incrementally, instead of keeping all outputs in memory
epoch_metric = torch.mean(torch.stack(outs))

Train Epoch-level Operations
============================

If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.module.LightningModule.training_step`,
override the :meth:`~pytorch_lightning.core.module.LightningModule.training_epoch_end` method.

.. code-block:: python

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
preds = ...
return {"loss": loss, "other_stuff": preds}


def training_epoch_end(self, training_step_outputs):
all_preds = torch.stack(training_step_outputs)
...

The matching pseudocode is:
If you need to use all the outputs from each :meth:`~pytorch_lightning.LightningModule.training_step`,
override the :meth:`~pytorch_lightning.LightningModule.on_train_epoch_end` method.

.. code-block:: python

outs = []
for batch_idx, batch in enumerate(train_dataloader):
# forward
loss = training_step(batch, batch_idx)
outs.append(loss)
def __init__(self):
super().__init__()
self.training_step_outputs = []

# clear gradients
optimizer.zero_grad()

# backward
loss.backward()
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
preds = ...
self.training_step_outputs.append(preds)
return loss

# update parameters
optimizer.step()

training_epoch_end(outs)
def on_train_epoch_end(self):
all_preds = torch.stack(self.training_step_outputs)
# do something with all preds
...
self.training_step_outputs.clear() # free memory

Training with DataParallel
==========================
@@ -309,15 +293,10 @@ method which will have outputs from all the devices and you can accumulate to ge
return (losses[0] + losses[1]) / 2


def training_epoch_end(self, training_step_outputs):
for out in training_step_outputs:
...

Here is the Lightning training pseudo-code for DP:

.. code-block:: python

outs = []
for batch_idx, train_batch in enumerate(train_dataloader):
batches = split_batch(train_batch)
dp_outs = []
@@ -327,12 +306,7 @@ Here is the Lightning training pseudo-code for DP:
dp_outs.append(dp_out)

# 2
out = training_step_end(dp_outs)
outs.append(out)

# do something with the outputs for all batches
# 3
training_epoch_end(outs)
training_step_end(dp_outs)

------------------

@@ -399,22 +373,32 @@ and calling :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`.
Validation Epoch-level Metrics
==============================

If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.module.LightningModule.validation_step`,
override the :meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end` method. Note that this method is called before :meth:`~pytorch_lightning.core.module.LightningModule.training_epoch_end`.
If you need to use all the outputs from each :meth:`~pytorch_lightning.LightningModule.validation_step`,
override the :meth:`~pytorch_lightning.LightningModule.on_validation_epoch_end` method.
Note that this method is called before :meth:`~pytorch_lightning.LightningModule.on_train_epoch_end`.

.. code-block:: python

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
pred = ...
return pred
def __init__(self):
super().__init__()
self.validation_step_outputs = []


def validation_epoch_end(self, validation_step_outputs):
all_preds = torch.stack(validation_step_outputs)
...
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
pred = ...
self.validation_step_outputs.append(pred)
return pred


def on_validation_epoch_end(self):
all_preds = torch.stack(self.validation_step_outputs)
# do something with all preds
...
self.validation_step_outputs.clear() # free memory


Validating with DataParallel
============================
@@ -448,15 +432,10 @@ method which will have outputs from all the devices and you can accumulate to ge
return (losses[0] + losses[1]) / 2


def validation_epoch_end(self, validation_step_outputs):
for out in validation_step_outputs:
...

Here is the Lightning validation pseudo-code for DP:

.. code-block:: python

outs = []
for batch in dataloader:
batches = split_batch(batch)
dp_outs = []
@@ -466,12 +445,7 @@ Here is the Lightning validation pseudo-code for DP:
dp_outs.append(dp_out)

# 2
out = validation_step_end(dp_outs)
outs.append(out)

# do something with the outputs for all batches
# 3
validation_epoch_end(outs)
validation_step_end(dp_outs)

----------------

@@ -924,12 +898,6 @@ test_step_end
.. automethod:: pytorch_lightning.core.module.LightningModule.test_step_end
:noindex:

test_epoch_end
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.module.LightningModule.test_epoch_end
:noindex:

to_onnx
~~~~~~~

@@ -954,11 +922,6 @@ training_step_end
.. automethod:: pytorch_lightning.core.module.LightningModule.training_step_end
:noindex:

training_epoch_end
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.module.LightningModule.training_epoch_end
:noindex:

unfreeze
~~~~~~~~

@@ -983,12 +946,6 @@ validation_step_end
.. automethod:: pytorch_lightning.core.module.LightningModule.validation_step_end
:noindex:

validation_epoch_end
~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.module.LightningModule.validation_epoch_end
:noindex:

-----------

Properties
@@ -1246,7 +1203,8 @@ for more information.
transfer_batch_to_device()
on_after_batch_transfer()

training_step()
out = training_step()
training_step_end(out)

on_before_zero_grad()
optimizer_zero_grad()
@@ -1263,8 +1221,6 @@

if should_check_val:
val_loop()
# end training epoch
training_epoch_end()

on_train_epoch_end()

@@ -1276,7 +1232,6 @@ for more information.
on_validation_start()
on_validation_epoch_start()

val_outs = []
for batch_idx, batch in enumerate(val_dataloader()):
on_validation_batch_start(batch, batch_idx)

@@ -1285,11 +1240,9 @@ for more information.
batch = on_after_batch_transfer(batch)

out = validation_step(batch, batch_idx)
out = validation_step_end(out)

on_validation_batch_end(batch, batch_idx)
val_outs.append(out)

validation_epoch_end(val_outs)

on_validation_epoch_end()
on_validation_end()
4 changes: 2 additions & 2 deletions docs/source-pytorch/extensions/logging.rst
@@ -152,7 +152,7 @@ The :meth:`~pytorch_lightning.core.module.LightningModule.log` method has a few
* - Hook
- on_step
- on_epoch
* - on_train_start, on_train_epoch_start, on_train_epoch_end, training_epoch_end
* - on_train_start, on_train_epoch_start, on_train_epoch_end
- False
- True
* - on_before_backward, on_after_backward, on_before_optimizer_step, on_before_zero_grad
@@ -161,7 +161,7 @@ The :meth:`~pytorch_lightning.core.module.LightningModule.log` method has a few
* - on_train_batch_start, on_train_batch_end, training_step, training_step_end
- True
- False
* - on_validation_start, on_validation_epoch_start, on_validation_epoch_end, validation_epoch_end
* - on_validation_start, on_validation_epoch_start, on_validation_epoch_end
- False
- True
* - on_validation_batch_start, on_validation_batch_end, validation_step, validation_step_end
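Note that the defaults in this table only apply when ``on_step``/``on_epoch`` are not passed explicitly. A minimal sketch (illustrative, not part of this diff) of overriding them inside ``training_step``, assuming ``F`` is ``torch.nn.functional``:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # Log the per-step value and also let Lightning accumulate an epoch-level
        # mean, overriding the training_step defaults (on_step=True, on_epoch=False).
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss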
2 changes: 1 addition & 1 deletion docs/source-pytorch/model/manual_optimization.rst
@@ -275,7 +275,7 @@ If you want to call schedulers that require a metric value after each epoch, con
self.automatic_optimization = False


def training_epoch_end(self, outputs):
def on_train_epoch_end(self):
sch = self.lr_schedulers()

# If the selected scheduler is a ReduceLROnPlateau scheduler.
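The diff truncates the body of this hook. A minimal sketch of how such a hook typically continues, assuming a ``ReduceLROnPlateau`` scheduler and a metric logged as ``"train_loss"`` (both names are illustrative, not taken from this diff):

.. code-block:: python

    def on_train_epoch_end(self):
        sch = self.lr_schedulers()

        # If the selected scheduler is a ReduceLROnPlateau scheduler,
        # step it with the value it monitors (assumed to be logged as "train_loss").
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            sch.step(self.trainer.callback_metrics["train_loss"])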
6 changes: 3 additions & 3 deletions docs/source-pytorch/starter/style_guide.rst
@@ -149,19 +149,19 @@ In practice, the code looks like this:

def training_step_end(...):

def training_epoch_end(...):
def on_train_epoch_end(...):

def validation_step(...):

def validation_step_end(...):

def validation_epoch_end(...):
def on_validation_epoch_end(...):

def test_step(...):

def test_step_end(...):

def test_epoch_end(...):
def on_test_epoch_end(...):

def configure_optimizers(...):

2 changes: 1 addition & 1 deletion docs/source-pytorch/visualize/logging_advanced.rst
@@ -358,7 +358,7 @@ In LightningModule
* - on_after_backward, on_before_backward, on_before_optimizer_step, optimizer_step, configure_gradient_clipping, on_before_zero_grad, training_step, training_step_end
- True
- False
* - training_epoch_end, test_epoch_end, test_step, test_step_end, validation_epoch_end, validation_step, validation_step_end
* - test_step, test_step_end, validation_step, validation_step_end
- False
- True
