docs: explain how Lightning uses closures for automatic optimization (#…
awaelchli authored Jul 26, 2021
1 parent 75e18a5 commit eaa16c7
Showing 1 changed file with 24 additions and 9 deletions: docs/source/common/optimizers.rst
@@ -299,10 +299,14 @@ Under the hood, Lightning does the following:

     for epoch in epochs:
         for batch in data:
-            loss = model.training_step(batch, batch_idx, ...)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+
+            def closure():
+                loss = model.training_step(batch, batch_idx, ...)
+                optimizer.zero_grad()
+                loss.backward()
+                return loss
+
+            optimizer.step(closure)

         for lr_scheduler in lr_schedulers:
             lr_scheduler.step()
@@ -314,14 +318,22 @@ In the case of multiple optimizers, Lightning does the following:

     for epoch in epochs:
         for batch in data:
             for opt in optimizers:
-                loss = model.training_step(batch, batch_idx, optimizer_idx)
-                opt.zero_grad()
-                loss.backward()
-                opt.step()
+
+                def closure():
+                    loss = model.training_step(batch, batch_idx, optimizer_idx)
+                    opt.zero_grad()
+                    loss.backward()
+                    return loss
+
+                opt.step(closure)

         for lr_scheduler in lr_schedulers:
             lr_scheduler.step()
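
Concretely, a module that drives this multi-optimizer loop implements one ``training_step`` branch per optimizer (Lightning passes ``optimizer_idx``) and returns both optimizers from ``configure_optimizers``. A minimal sketch, assuming a toy encoder/decoder pair that is not part of this diff::

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class TwoOptimizerModule(pl.LightningModule):
        """Illustrative only: the encoder/decoder pair and losses are placeholders."""

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(32, 16)
            self.decoder = nn.Linear(16, 32)

        def training_step(self, batch, batch_idx, optimizer_idx):
            x = batch  # assumes the dataloader yields a plain tensor of shape (N, 32)
            # called once per optimizer; `optimizer_idx` says which one is being updated
            if optimizer_idx == 0:  # encoder update
                return (self.decoder(self.encoder(x)) - x).pow(2).mean()
            if optimizer_idx == 1:  # decoder update (encoder output detached)
                return (self.decoder(self.encoder(x).detach()) - x).pow(2).mean()

        def configure_optimizers(self):
            opt_enc = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            opt_dec = torch.optim.Adam(self.decoder.parameters(), lr=1e-3)
            return [opt_enc, opt_dec]
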
As shown in the snippets above, Lightning wraps ``training_step``, ``zero_grad`` and ``backward`` in a closure
and passes it to ``optimizer.step``. This mechanism supports optimizers that operate on the output of the closure
(e.g. the loss) or need to evaluate the closure several times (e.g. :class:`~torch.optim.LBFGS`).

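For intuition, this mirrors how a closure is used with plain PyTorch: ``torch.optim.LBFGS`` may evaluate the closure several times within a single ``step`` call, so the closure must recompute the loss and gradients each time. A minimal sketch outside of Lightning (the linear model and random data below are made up for illustration)::

    import torch
    from torch import nn

    model = nn.Linear(10, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    x, y = torch.randn(64, 10), torch.randn(64, 1)

    def closure():
        # LBFGS may call this several times per `step`, so it must
        # recompute the loss and gradients from scratch on each call
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    for _ in range(20):
        optimizer.step(closure)
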
.. warning::
    Before 1.2.2, Lightning internally called ``backward``, ``step`` and ``zero_grad`` in that order.
    From 1.2.2, the order is ``zero_grad``, ``backward`` and ``step``.
@@ -396,8 +408,11 @@ For example, here step optimizer A every batch and optimizer B every 2 batches.

         # update discriminator every 2 steps
         if optimizer_idx == 1:
             if (batch_idx + 1) % 2 == 0:
-                # the closure (which includes the `training_step`) won't run if the line below isn't executed
+                # the closure (which includes the `training_step`) will be executed by `optimizer.step`
                 optimizer.step(closure=optimizer_closure)
+            else:
+                # optional: call the closure by itself to run `training_step` + `backward` without an optimizer step
+                optimizer_closure()

         # ...
         # add as many optimizers as you want
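
For context, the snippet above belongs inside the ``optimizer_step`` hook of a LightningModule. A sketch of how the surrounding override might look around the time of this commit, assuming a GAN-style module where optimizer A is the generator and optimizer B the discriminator (the exact hook arguments can differ between Lightning releases)::

    import pytorch_lightning as pl


    class GAN(pl.LightningModule):
        # training_step / configure_optimizers define two optimizers (0: generator, 1: discriminator)

        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                           on_tpu=False, using_native_amp=False, using_lbfgs=False):
            # update generator every step
            if optimizer_idx == 0:
                optimizer.step(closure=optimizer_closure)

            # update discriminator every 2 steps
            if optimizer_idx == 1:
                if (batch_idx + 1) % 2 == 0:
                    # the closure (which includes the `training_step`) is executed by `optimizer.step`
                    optimizer.step(closure=optimizer_closure)
                else:
                    # run `training_step` + `backward` without stepping the discriminator
                    optimizer_closure()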