Make optimizers skippable when using amp #7975

Merged · 4 commits · Jun 16, 2021

Changes from 2 commits
14 changes: 8 additions & 6 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -83,19 +83,21 @@ def pre_optimizer_step(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
lambda_closure()
carmocca marked this conversation as resolved.
Show resolved Hide resolved

if not pl_module.automatic_optimization:
self.scaler.unscale_(optimizer)
pl_module.trainer.call_hook("on_after_backward")
self.scaler.step(optimizer)
self.scaler.update()
else:
result = lambda_closure()
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
self.scaler.step(optimizer)
self.scaler.update()

return False

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Updates the GradScaler"""
self.scaler.step(optimizer)
self.scaler.update()

@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
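For context, the hunk above boils down to the following stand-alone pattern. This is a sketch against plain torch.cuda.amp rather than Lightning's actual code, and optimizer_step_with_skip / make_closure are illustrative names only: the closure runs forward and backward and returns None when the step should be skipped, and only a non-None result triggers scaler.step() and scaler.update().

import torch

def optimizer_step_with_skip(scaler, optimizer, closure):
    # Sketch of the skip logic above: a None closure result means backward was
    # skipped, so there are no scaled gradients and the scaler must not step.
    result = closure()
    if result is not None:
        scaler.step(optimizer)
        scaler.update()
    return result

# Requires a CUDA device, matching the GPU-only test below.
model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

def make_closure(skip: bool):
    def closure():
        if skip:
            return None  # simulate training_step returning None for this optimizer
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(4, 32, device="cuda")).sum()
        scaler.scale(loss).backward()
        return loss
    return closure

optimizer_step_with_skip(scaler, optimizer, make_closure(skip=False))  # steps and updates the scale
optimizer_step_with_skip(scaler, optimizer, make_closure(skip=True))   # skipped: no step, no update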
41 changes: 41 additions & 0 deletions tests/plugins/test_amp_plugins.py
@@ -99,6 +99,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
     trainer.fit(model)
+
+
+@RunIf(min_gpus=1, amp_native=True)
+def test_amp_skip_optimizer(tmpdir):
+    """
+    Test that optimizers can be skipped when using amp
+    """
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(32, 32)
+            self.layer2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x: torch.Tensor):
+            x = self.layer1(x)
+            x = self.layer2(x)
+            return x
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            if optimizer_idx == 1:
+                return None
+            output = self(batch)
+            return self.loss(batch, output)
+
+        def configure_optimizers(self):
+            return [
+                torch.optim.SGD(self.layer1.parameters(), lr=0.1),
+                torch.optim.SGD(self.layer2.parameters(), lr=0.1),
+            ]
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        fast_dev_run=1,
+        amp_backend='native',
+        precision=16,
+    )
+    model = CustomBoringModel()
+    trainer.fit(model)


@RunIf(min_gpus=2, amp_apex=True, special=True)
@pytest.mark.parametrize("amp_level", ['O2'])
def test_amp_apex_ddp_fit(amp_level, tmpdir):
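The contract exercised by the new test above: returning None from training_step for a given optimizer_idx skips that optimizer, and with this change the native AMP plugin also skips the matching scaler.step() and scaler.update(), mirroring how skipped steps already behave without AMP.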