[hotfix] Extend Optimizer + update doc #5095

Merged
21 commits merged on Dec 11, 2020
49 changes: 36 additions & 13 deletions docs/source/optimizers.rst
@@ -191,46 +191,69 @@ override the :meth:`optimizer_step` function.
 
 For example, here step optimizer A every 2 batches and optimizer B every 4 batches
 
-.. testcode::
+.. note:: When using Trainer(enable_pl_optimizer=True), there is no need to call `.zero_grad()`.
 
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-        optimizer.step()
+.. testcode::
 
     def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
         optimizer.zero_grad()
 
     # Alternating schedule for optimizer steps (ie: GANs)
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # update generator opt every 2 steps
         if optimizer_i == 0:
             if batch_nb % 2 == 0 :
-                optimizer.step()
-                optimizer.zero_grad()
+                optimizer.step(closure=closure)
 
         # update discriminator opt every 4 steps
         if optimizer_i == 1:
             if batch_nb % 4 == 0 :
-                optimizer.step()
-                optimizer.zero_grad()
+                optimizer.step(closure=closure)
+
+.. note:: When using ``Trainer(enable_pl_optimizer=True)``, ``.step`` accepts a boolean ``make_optimizer_step`` which can be used as follow.
+
+.. testcode::
+
+    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
+        optimizer.zero_grad()
+
+    # Alternating schedule for optimizer steps (ie: GANs)
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+        # update generator opt every 2 steps
+        if optimizer_i == 0:
+            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)
 
-            # ...
-            # add as many optimizers as you want
+        # update discriminator opt every 4 steps
+        if optimizer_i == 1:
+            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)
 
 Here we add a learning-rate warm up
 
 .. testcode::
 
     # learning rate warm-up
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # warm up lr
         if self.trainer.global_step < 500:
             lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
             for pg in optimizer.param_groups:
                 pg['lr'] = lr_scale * self.hparams.learning_rate
 
         # update params
-        optimizer.step()
-        optimizer.zero_grad()
+        optimizer.step(closure=closure)
+
+The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step.
+
+.. testcode::
+
+    from pytorch_lightning.core.optimizer import LightningOptimizer
+
+    # function hook in LightningModule
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+        if not isinstance(optimizer, LightningOptimizer):
+            # wraps into LightingOptimizer only for running step
+            optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
+        optimizer.step(closure=closure)
 
 ----------
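
Not part of the diff: a minimal sketch of turning on the ``enable_pl_optimizer`` flag referenced in the notes above (the same flag the new test further down passes to ``Trainer``).

    from pytorch_lightning import Trainer

    # enable_pl_optimizer wraps each optimizer returned by configure_optimizers in a
    # LightningOptimizer; that wrapper is what the notes above rely on when they say
    # `.zero_grad()` is not needed and that `.step` accepts `make_optimizer_step`
    trainer = Trainer(enable_pl_optimizer=True)
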
4 changes: 1 addition & 3 deletions pytorch_lightning/core/lightning.py
@@ -1170,7 +1170,6 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
 
     def optimizer_step(
         self,
-        *args,
         epoch: int = None,
         batch_idx: int = None,
         optimizer: Optimizer = None,
@@ -1179,7 +1178,6 @@ def optimizer_step(
         on_tpu: bool = None,
         using_native_amp: bool = None,
         using_lbfgs: bool = None,
-        **kwargs,
     ) -> None:
         r"""
         Override this method to adjust the default way the
@@ -1254,7 +1252,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
         if not isinstance(optimizer, LightningOptimizer):
             # wraps into LightingOptimizer only for running step
             optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
-        optimizer.step(closure=optimizer_closure, *args, **kwargs)
+        optimizer.step(closure=optimizer_closure)
 
     def optimizer_zero_grad(
         self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
34 changes: 26 additions & 8 deletions pytorch_lightning/core/optimizer.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import types
 from typing import Any, Callable, Optional
 from weakref import proxy
@@ -58,12 +57,35 @@ def __init__(self,
         else:
             self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
 
-        self._trainer = None
         self._optimizer = optimizer
+        self._trainer = None
         self._accumulate_grad_batches = accumulate_grad_batches
-        self._support_closure = 'closure' in inspect.signature(optimizer.step).parameters
         self._optimizer_idx = None
 
+    @property
+    def defaults(self):
+        return self._optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults):
+        self._optimizer.defaults = defaults
+
+    @property
+    def state(self):
+        return self._optimizer.state
+
+    @state.setter
+    def state(self, state):
+        self._optimizer.state = state
+
+    @property
+    def param_groups(self):
+        return self._optimizer.param_groups
+
+    @param_groups.setter
+    def param_groups(self, param_groups):
+        self._optimizer.param_groups = param_groups
+
     @property
     def accumulate_grad_batches(self):
         return self._accumulate_grad_batches
@@ -111,11 +133,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_name
 
         else:
             with trainer.profiler.profile(profiler_name):
-                if self._support_closure:
-                    optimizer.step(closure=closure, *args, **kwargs)
-                else:
-                    closure()
-                    optimizer.step(*args, **kwargs)
+                optimizer.step(closure=closure, *args, **kwargs)
 
         accelerator_backend = trainer.accelerator_backend
         if accelerator_backend is not None and accelerator_backend.rpc_enabled:
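
Not part of the diff: because the ``_support_closure`` fallback above is removed, ``LightningOptimizer`` now always forwards ``step(closure=closure, ...)`` to the wrapped optimizer, so a custom optimizer must accept a ``closure`` keyword in its ``step`` (which is also why the wrapper test deleted further down no longer applies). A minimal sketch of a compliant optimizer; the name and update rule are illustrative only, not part of the PR.

    import torch
    from torch.optim import Optimizer

    class PlainSGD(Optimizer):
        """Illustrative optimizer; the only point is the closure-aware ``step``."""

        def __init__(self, params, lr=0.01):
            super().__init__(params, defaults=dict(lr=lr))

        def step(self, closure=None):
            loss = None
            if closure is not None:
                # the closure runs training_step + backward, so call it exactly
                # once before updating the parameters
                loss = closure()
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.data.add_(p.grad, alpha=-group["lr"])
            return loss
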
14 changes: 6 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -477,7 +477,7 @@ def _process_result(self, training_step_output, split_batch):
 
         return training_step_output_for_epoch_end
 
-    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs):
+    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
         model_ref = self.trainer.get_model()
 
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
@@ -491,16 +491,14 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure
 
         # model hook
         model_ref.optimizer_step(
-            epoch=self.trainer.current_epoch,
-            batch_idx=batch_idx,
-            optimizer=optimizer,
-            optimizer_idx=opt_idx,
-            optimizer_closure=train_step_and_backward_closure,
+            self.trainer.current_epoch,
+            batch_idx,
+            optimizer,
+            opt_idx,
+            train_step_and_backward_closure,
             on_tpu=self.trainer.use_tpu and TPU_AVAILABLE,
             using_native_amp=using_native_amp,
             using_lbfgs=is_lbfgs,
-            *args,
-            **kwargs,
         )
 
     def on_before_zero_grad(self, optimizer):
44 changes: 44 additions & 0 deletions tests/core/test_lightning_module.py
@@ -103,3 +103,47 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
         assert sgd_zero_grad.call_count == 4
         assert adam_step.call_count == 2
         assert adam_zero_grad.call_count == 2
+
+
+@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
+def test_params_groups_and_state_are_accessible(enable_pl_optimizer, tmpdir):
+
+    with patch("torch.optim.SGD.step") as sgd_step, \
+            patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
+            patch("torch.optim.Adam.step") as adam_step, \
+            patch("torch.optim.Adam.zero_grad") as adam_zero_grad:
+
+        class TestModel(BoringModel):
+
+            def training_step(self, batch, batch_idx, optimizer_idx):
+                output = self.layer(batch)
+                loss = self.loss(batch, output)
+                return {"loss": loss}
+
+            def configure_optimizers(self):
+                optimizer = SGD(self.layer.parameters(), lr=0.1)
+                optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
+                return [optimizer, optimizer_2]
+
+            def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
+                               on_tpu=False, using_native_amp=False, using_lbfgs=False):
+                # warm up lr
+                if self.trainer.global_step < 500:
+                    lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
+                    for pg in optimizer.param_groups:
+                        pg['lr'] = lr_scale * 0.01
+
+                optimizer.step(closure=closure)
+
+        model = TestModel()
+        model.training_epoch_end = None
+
+        trainer = Trainer(
+            max_epochs=1,
+            default_root_dir=tmpdir,
+            limit_train_batches=8,
+            accumulate_grad_batches=1,
+            enable_pl_optimizer=enable_pl_optimizer
+        )
+
+        trainer.fit(model)
68 changes: 18 additions & 50 deletions tests/core/test_lightning_optimizer.py
@@ -193,12 +193,29 @@ def test_state(tmpdir):
     model = torch.nn.Linear(3, 4)
     optimizer = torch.optim.Adam(model.parameters())
     lightning_optimizer = LightningOptimizer(optimizer)
+
+    # test state
+    assert optimizer.state == lightning_optimizer.state
+    lightning_optimizer.state = optimizer.state
+    assert optimizer.state == lightning_optimizer.state
+
+    # test param_groups
+    assert optimizer.param_groups == lightning_optimizer.param_groups
+    lightning_optimizer.param_groups = optimizer.param_groups
+    assert optimizer.param_groups == lightning_optimizer.param_groups
+
+    # test defaults
+    assert optimizer.defaults == lightning_optimizer.defaults
+    lightning_optimizer.defaults = optimizer.defaults
+    assert optimizer.defaults == lightning_optimizer.defaults
+
     assert isinstance(lightning_optimizer, LightningOptimizer)
     assert isinstance(lightning_optimizer, Adam)
     assert isinstance(lightning_optimizer, Optimizer)
     lightning_dict = {}
     special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
-                     "_trainer"]
+                     "_trainer", "__getstate__", "__setstate__", "state_dict", "load_state_dict",
+                     "zero_grad", "__setstate__", "add_param_group"]
     for k, v in lightning_optimizer.__dict__.items():
         if k not in special_attrs:
             lightning_dict[k] = v
@@ -207,55 +224,6 @@ def test_state(tmpdir):
     assert optimizer.state == lightning_optimizer.state
 
 
-def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
-    class OptimizerWrapper(object):
-        def __init__(self, optimizer):
-            self.optim = optimizer
-            self.state_dict = self.optim.state_dict
-            self.load_state_dict = self.optim.load_state_dict
-            self.zero_grad = self.optim.zero_grad
-            self.add_param_group = self.optim.add_param_group
-            self.__setstate__ = self.optim.__setstate__
-            self.__getstate__ = self.optim.__getstate__
-            self.__repr__ = self.optim.__repr__
-
-        @property
-        def __class__(self):
-            return Optimizer
-
-        @property
-        def state(self):
-            return self.optim.state
-
-        @property
-        def param_groups(self):
-            return self.optim.param_groups
-
-        @param_groups.setter
-        def param_groups(self, value):
-            self.optim.param_groups = value
-
-        def step(self):
-            # wrongly defined step. Should contain closure
-            self.optim.step(closure=None)
-
-    class TestLightningOptimizerModel(BoringModel):
-
-        def configure_optimizers(self):
-            optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
-            optimizer = OptimizerWrapper(optimizer)
-            return [optimizer]
-
-    model = TestLightningOptimizerModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        weights_summary=None,
-        log_every_n_steps=1,
-    )
-    trainer.fit(model)
-
-
 def test_lightning_optimizer_automatic_optimization(tmpdir):
     """
     Test lightning optimize works with make_optimizer_step in automatic_optimization
6 changes: 3 additions & 3 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
     )
 
     trainer.fit(model)
-    expected_calls = [call() for s in range(2)]
+    expected_calls = [call(closure=ANY) for s in range(2)]
     step_mock.assert_has_calls(expected_calls)
 
 
@@ -933,9 +933,9 @@ def automatic_optimization(self) -> bool:
     )
 
     trainer.fit(model)
-    expected_calls = [call(optim='sgd') for s in range(4)]
+    expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
     mock_sgd_step.assert_has_calls(expected_calls)
-    expected_calls = [call() for s in range(2)]
+    expected_calls = [call(closure=ANY) for s in range(2)]
     mock_adam_step.assert_has_calls(expected_calls)
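
Not part of the diff: a small illustration of why the expected calls switch to ``call(closure=ANY)``. The training loop hands ``step`` a fresh closure object on every call, so the mocked calls can only be matched with ``unittest.mock.ANY``.

    from unittest.mock import MagicMock, call, ANY

    step_mock = MagicMock()
    # the training loop passes a different closure object on each step
    step_mock(closure=lambda: None, optim='sgd')
    step_mock(closure=lambda: None, optim='sgd')

    # ANY compares equal to any closure instance, so the assertion passes
    step_mock.assert_has_calls([call(closure=ANY, optim='sgd') for _ in range(2)])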