Lightning offers two modes for managing the optimization process:
- automatic optimization
- manual optimization
For the majority of research cases, automatic optimization will do the right thing for you and it is what most users should use.
For advanced/expert users who want to do esoteric optimization schedules or techniques, use manual optimization.
For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to manually manage the optimization process.
This is only recommended for experts who need ultimate flexibility. Lightning will handle only the precision and accelerator logic. Users are left to handle optimizer.zero_grad(), gradient accumulation, model toggling, etc.
To manually optimize, do the following:
- Set self.automatic_optimization=False in your LightningModule's __init__.
- Use the following functions and call them manually:
  - self.optimizers() to access your optimizers (one or multiple)
  - optimizer.zero_grad() to clear the gradients from the previous training step
  - self.manual_backward(loss) instead of loss.backward()
  - optimizer.step() to update your model parameters
Here is a minimal example of manual optimization.
python
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
        opt.step()
Warning
Before 1.2, optimizer.step() was calling optimizer.zero_grad() internally. From 1.2, it is left to the user's expertise.
Tip
Be careful where you call optimizer.zero_grad(), or your model won't converge. It is good practice to call optimizer.zero_grad() before self.manual_backward(loss).
You can accumulate gradients over batches, similar to what the Trainer's accumulate_grad_batches argument does in automatic optimization. To perform gradient accumulation with one optimizer, do the following.
python
# accumulate gradients over n batches
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    loss = self.compute_loss(batch)
    self.manual_backward(loss)

    # accumulate gradients of n batches
    if (batch_idx + 1) % n == 0:
        opt.step()
        opt.zero_grad()
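To see what the accumulation pattern above does batch by batch, here is a minimal plain-Python sketch. It is not Lightning code: the scalar "model" w, the helper grad, and the function train_with_accumulation are all hypothetical names made up for illustration. The float accumulated stands in for the gradient buffer that self.manual_backward(loss) adds to.

```python
# Plain-Python toy, not Lightning code: a scalar "model" w with loss (w*x - y)**2.


def grad(w, x, y):
    """Gradient of (w*x - y)**2 with respect to w."""
    return 2.0 * (w * x - y) * x


def train_with_accumulation(batches, n, lr=0.01, w=0.0):
    """Backward on every batch, step and zero the buffer every n batches."""
    accumulated = 0.0  # plays the role of the parameter's gradient buffer
    steps = []  # batch indices at which a step happened
    for batch_idx, (x, y) in enumerate(batches):
        accumulated += grad(w, x, y)  # like self.manual_backward(loss)
        if (batch_idx + 1) % n == 0:
            w -= lr * accumulated  # like opt.step()
            accumulated = 0.0  # like opt.zero_grad()
            steps.append(batch_idx)
    return w, steps, accumulated


batches = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0), (5.0, 10.0)]
w, steps, leftover = train_with_accumulation(batches, n=2)
print(steps)  # [1, 3]
```

With five batches and n=2, the fifth batch contributes a gradient that is never applied inside this loop (leftover is non-zero). In the pattern above, nothing clears that buffer either, so it is worth deciding explicitly whether to step on the final incomplete group. Whether to divide each loss by n so the accumulated gradient is an average rather than a sum is a separate design choice that the pattern leaves to you.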
Here is an example training a simple GAN with multiple optimizers.
python
import torch
from torch import Tensor
from pytorch_lightning import LightningModule


class SimpleGAN(LightningModule):
    def __init__(self):
        super().__init__()
        self.G = Generator()
        self.D = Discriminator()

        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def sample_z(self, n) -> Tensor:
        sample = self._Z.sample((n,))
        return sample

    def sample_G(self, n) -> Tensor:
        z = self.sample_z(n)
        return self.G(z)

    def training_step(self, batch, batch_idx):
        # Implementation follows the PyTorch tutorial:
        # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
        g_opt, d_opt = self.optimizers()

        X, _ = batch
        batch_size = X.shape[0]

        real_label = torch.ones((batch_size, 1), device=self.device)
        fake_label = torch.zeros((batch_size, 1), device=self.device)

        g_X = self.sample_G(batch_size)

        d_x = self.D(X)
        errD_real = self.criterion(d_x, real_label)

        d_z = self.D(g_X.detach())
        errD_fake = self.criterion(d_z, fake_label)

        errD = errD_real + errD_fake

        d_opt.zero_grad()
        self.manual_backward(errD)
        d_opt.step()

        d_z = self.D(g_X)
        errG = self.criterion(d_z, real_label)

        g_opt.zero_grad()
        self.manual_backward(errG)
        g_opt.step()

        self.log_dict({"g_loss": errG, "d_loss": errD}, prog_bar=True)

    def configure_optimizers(self):
        g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
        d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
        return g_opt, d_opt
Every optimizer you use can be paired with any learning rate scheduler. Please see the documentation of LightningModule.configure_optimizers() for all the available options.
You can call lr_scheduler.step() at arbitrary intervals. Use self.lr_schedulers() in your LightningModule to access any learning rate schedulers defined in your LightningModule.configure_optimizers().
Warning
* Before 1.3, Lightning automatically called lr_scheduler.step() in both automatic and manual optimization. From 1.3, lr_scheduler.step() is left for the user to call at arbitrary intervals.
* Note that the lr_scheduler_config keys, such as "frequency" and "interval", will be ignored during manual optimization even if they are provided in your LightningModule.configure_optimizers().
Here is an example calling lr_scheduler.step() every step.
python
# step every batch
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    # do forward, backward, and optimization
    ...

    # single scheduler
    sch = self.lr_schedulers()
    sch.step()

    # multiple schedulers
    sch1, sch2 = self.lr_schedulers()
    sch1.step()
    sch2.step()
If you want to call lr_scheduler.step() every n steps/epochs, do the following.
python
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    # do forward, backward, and optimization
    ...

    sch = self.lr_schedulers()

    # step every n batches
    if (batch_idx + 1) % n == 0:
        sch.step()

    # step every n epochs
    if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % n == 0:
        sch.step()
If you want to call schedulers that require a metric value after each epoch, consider doing the following:
python
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_epoch_end(self, outputs):
    sch = self.lr_schedulers()

    # If the selected scheduler is a ReduceLROnPlateau scheduler.
    if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
        sch.step(self.trainer.callback_metrics["loss"])
It is good practice to provide the optimizer with a closure function that performs a forward, zero_grad and backward of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure, such as torch.optim.LBFGS.
See the PyTorch docs for more about the closure.
Here is an example using a closure function.
python
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def configure_optimizers(self):
    return torch.optim.LBFGS(...)


def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    def closure():
        loss = self.compute_loss(batch)
        opt.zero_grad()
        self.manual_backward(loss)
        return loss

    opt.step(closure=closure)
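To illustrate why some optimizers need a closure, here is a minimal plain-Python sketch of an optimizer that re-evaluates the loss several times within one step. Param and BacktrackingOptimizer are hypothetical classes invented for this example. They are not part of PyTorch or Lightning, and the step-halving rule is only loosely similar in spirit to what LBFGS does.

```python
class Param:
    """Scalar parameter whose gradient buffer accumulates, like a tensor's .grad."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class BacktrackingOptimizer:
    """Hypothetical optimizer: halves the step size until the loss decreases."""

    def __init__(self, param, lr=1.0):
        self.param = param
        self.lr = lr
        self.closure_calls = 0

    def zero_grad(self):
        self.param.grad = 0.0

    def _evaluate(self, closure):
        self.closure_calls += 1
        return closure()

    def step(self, closure):
        loss = self._evaluate(closure)  # first evaluation gives loss and gradient
        start, grad, lr = self.param.value, self.param.grad, self.lr
        while lr > 1e-8:
            self.param.value = start - lr * grad  # try a candidate update
            new_loss = self._evaluate(closure)  # re-run forward and backward
            if new_loss < loss:
                return new_loss
            lr *= 0.5  # candidate was not an improvement, shrink the step
        self.param.value = start  # give up and restore the parameter
        return loss


param = Param(0.0)
opt = BacktrackingOptimizer(param)


def closure():
    loss = (param.value - 3.0) ** 2  # forward
    opt.zero_grad()  # clear the buffer before accumulating
    param.grad += 2.0 * (param.value - 3.0)  # backward
    return loss


opt.step(closure)
print(param.value, opt.closure_calls)  # 3.0 3
```

Because the gradient buffer accumulates, the closure clears it before each backward pass, so the gradient the optimizer reads always belongs to the most recent evaluation.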
Warning
The torch.optim.LBFGS optimizer is not supported for apex AMP, native AMP, IPUs, or DeepSpeed.
optimizer is a LightningOptimizer object wrapping your own optimizer configured in your LightningModule.configure_optimizers(). You can access your own optimizer with optimizer.optimizer. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.
python
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    optimizer = self.optimizers()

    # optimizer is a LightningOptimizer wrapping the optimizer.
    # To access it, do the following.
    # However, it won't work on TPU, AMP, etc...
    optimizer = optimizer.optimizer
    ...
With Lightning, most users don't have to think about when to call .zero_grad(), .backward() and .step() since Lightning automates that for you.
Under the hood, Lightning does the following:
for epoch in epochs:
    for batch in data:

        def closure():
            loss = model.training_step(batch, batch_idx, ...)
            optimizer.zero_grad()
            loss.backward()
            return loss

        optimizer.step(closure)

    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()
In the case of multiple optimizers, Lightning does the following:
for epoch in epochs:
    for batch in data:
        for opt in optimizers:

            def closure():
                loss = model.training_step(batch, batch_idx, optimizer_idx)
                opt.zero_grad()
                loss.backward()
                return loss

            opt.step(closure)

    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()
As can be seen in the code snippet above, Lightning defines a closure with training_step, zero_grad and backward for the optimizer to execute. This mechanism is in place to support optimizers which operate on the output of the closure (e.g. the loss) or need to call the closure several times (e.g. torch.optim.LBFGS).
Warning
Before 1.2.2, Lightning internally called backward, step and zero_grad, in that order. From 1.2.2, the order is zero_grad, backward and step.
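The call order can be made concrete with a small trace. This is a plain-Python sketch using hypothetical stand-ins (RecordingOptimizer, RecordingLoss and a free-standing training_step), not Lightning's actual loop. It only mirrors the multiple-optimizer pseudocode shown earlier and records which call happens when.

```python
calls = []  # ordered log of every call made during one batch


class RecordingOptimizer:
    """Stand-in optimizer that logs its calls instead of updating anything."""

    def __init__(self, name):
        self.name = name

    def zero_grad(self):
        calls.append(f"{self.name}.zero_grad")

    def step(self, closure):
        closure()  # the optimizer runs the closure itself
        calls.append(f"{self.name}.step")


class RecordingLoss:
    """Stand-in for a loss tensor."""

    def backward(self):
        calls.append("backward")


def training_step(batch, batch_idx, optimizer_idx):
    calls.append(f"training_step(opt={optimizer_idx})")
    return RecordingLoss()


optimizers = [RecordingOptimizer("opt0"), RecordingOptimizer("opt1")]

for batch_idx, batch in enumerate(["batch0"]):
    for optimizer_idx, opt in enumerate(optimizers):

        def closure():
            loss = training_step(batch, batch_idx, optimizer_idx)
            opt.zero_grad()
            loss.backward()
            return loss

        opt.step(closure)

for call in calls:
    print(call)
```

For each optimizer in turn, the trace shows training_step, then zero_grad, then backward, and only then step, because the step call is what triggers the closure.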
To use multiple optimizers (optionally with learning rate schedulers), return two or more optimizers from LightningModule.configure_optimizers().
python
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau, StepLR


# two optimizers, no schedulers
def configure_optimizers(self):
    return Adam(...), SGD(...)


# two optimizers, one scheduler for adam only
def configure_optimizers(self):
    opt1 = Adam(...)
    opt2 = SGD(...)
    optimizers = [opt1, opt2]
    lr_schedulers = {"scheduler": ReduceLROnPlateau(opt1, ...), "monitor": "metric_to_track"}
    return optimizers, lr_schedulers


# two optimizers, two schedulers
def configure_optimizers(self):
    opt1 = Adam(...)
    opt2 = SGD(...)
    return [opt1, opt2], [StepLR(opt1, ...), OneCycleLR(opt2, ...)]
Under the hood, Lightning will call each optimizer sequentially:
for epoch in epochs:
    for batch in data:
        for opt in optimizers:
            loss = train_step(batch, batch_idx, optimizer_idx)
            opt.zero_grad()
            loss.backward()
            opt.step()

    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()
To do more interesting things with your optimizers, such as learning rate warm-up or odd scheduling, override the LightningModule.optimizer_step() function.
Warning
If you are overriding this method, make sure that you pass the optimizer_closure parameter to the optimizer.step() function as shown in the examples, because training_step(), optimizer.zero_grad() and backward() are called in the closure function.
For example, here we step optimizer A every batch and optimizer B every 2 batches.
python
# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    # update generator every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            # the closure (which includes the training_step) will be executed by optimizer.step
            optimizer.step(closure=optimizer_closure)
        else:
            # call the closure by itself to run training_step + backward without an optimizer step
            optimizer_closure()

    # ...
    # add as many optimizers as you want
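To check which optimizer actually steps on which batch under the alternating schedule above, the branching can be replayed in plain Python. step_plan is a hypothetical helper written for this sketch. It copies only the if/else structure of the hook and assumes exactly two optimizers.

```python
def step_plan(num_batches):
    """Replay the branching of the alternating schedule for two optimizers."""
    plan = []
    for batch_idx in range(num_batches):
        for optimizer_idx in (0, 1):
            if optimizer_idx == 0:
                action = "step"  # optimizer 0 steps on every batch
            elif (batch_idx + 1) % 2 == 0:
                action = "step"  # optimizer 1 steps on every second batch
            else:
                action = "closure only"  # training_step + backward, no update
            plan.append((batch_idx, optimizer_idx, action))
    return plan


for entry in step_plan(4):
    print(entry)
```

Going by the closure shown earlier in this document, each closure call runs zero_grad before backward. Under that assumption, the gradients computed on a "closure only" batch are cleared before the next backward pass rather than accumulated, so optimizer 1 updates from every second batch only.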
Here we add a learning rate warm-up.
python
# learning rate warm-up
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    # scale the learning rate linearly over the first 500 steps
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * self.hparams.learning_rate

    # update params
    optimizer.step(closure=optimizer_closure)
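The warm-up arithmetic can be checked in isolation. The function below is a hypothetical plain-Python helper, not a Lightning API. It reproduces the scaling formula from the hook above and assumes that nothing else changes the learning rate once warm-up has finished.

```python
def warmup_lr(global_step, base_lr, warmup_steps=500):
    """Learning rate that the warm-up hook would set at a given global step."""
    if global_step < warmup_steps:
        lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
        return lr_scale * base_lr
    # After warm-up the hook no longer touches the learning rate.
    return base_lr


for step in (0, 249, 499, 10_000):
    print(step, warmup_lr(step, base_lr=0.1))
```

The scale grows linearly from 1/500 at step 0 to exactly 1.0 at step 499, so the last warm-up step already uses the full base learning rate.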
optimizer is a LightningOptimizer object wrapping your own optimizer configured in your LightningModule.configure_optimizers(). You can access your own optimizer with optimizer.optimizer. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.
python
# function hook in LightningModule
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    optimizer.step(closure=optimizer_closure)


# optimizer is a LightningOptimizer wrapping the optimizer.
# To access it, do the following.
# However, it won't work on TPU, AMP, etc...
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    optimizer = optimizer.optimizer
    optimizer.step(closure=optimizer_closure)
To configure custom gradient clipping, consider overriding the LightningModule.configure_gradient_clipping() method. The Trainer attributes gradient_clip_val and gradient_clip_algorithm are passed in as the corresponding arguments, and Lightning will handle gradient clipping for you. If you want to use different values of your choice and still let Lightning handle the gradient clipping, you can call the built-in LightningModule.clip_gradients() method and pass the arguments along with your optimizer.
Note
Do not override the LightningModule.clip_gradients() method. If you want to customize gradient clipping, override the LightningModule.configure_gradient_clipping() method instead.
For example, here we will apply gradient clipping only to the gradients associated with optimizer A.
python
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
    if optimizer_idx == 0:
        # Lightning will handle the gradient clipping
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )
Here we configure gradient clipping differently for optimizer B.
python
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
    if optimizer_idx == 0:
        # Lightning will handle the gradient clipping
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )
    elif optimizer_idx == 1:
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val * 2, gradient_clip_algorithm=gradient_clip_algorithm
        )
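For intuition about what gradient_clip_val controls, here is a simplified plain-Python sketch of the two common clipping strategies, assuming gradient_clip_algorithm selects between clipping by value and clipping by norm. clip_by_value and clip_by_norm are hypothetical helpers that operate on a flat list of floats. They are not Lightning or PyTorch functions, and real implementations differ in details such as numerical-stability terms.

```python
import math


def clip_by_value(grads, clip_val):
    """Clamp every gradient entry independently to [-clip_val, clip_val]."""
    return [max(-clip_val, min(clip_val, g)) for g in grads]


def clip_by_norm(grads, max_norm):
    """Rescale the whole gradient vector so its L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)  # already within the limit, leave unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads]


print(clip_by_value([3.0, -4.0, 0.5], clip_val=1.0))  # [1.0, -1.0, 0.5]
print(clip_by_norm([3.0, 4.0], max_norm=1.0))  # roughly [0.6, 0.8]
```

Clipping by value can change the direction of the gradient, while clipping by norm preserves the direction and only shrinks the length. That difference is one reason to choose a different threshold or algorithm per optimizer, as in the example above.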