
Commit 5a474c4

Authored Sep 5, 2020
ref: inner train loop (intermediate step) 1/n (Lightning-AI#3359)
1 parent: 43b8d62
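
In short (paraphrased from the diff below): the backward pass, AMP loss scaling, and the apex context handling move out of optimizer_closure in training_loop.py and into a new Accelerator.backward hook. TPUBackend overrides that hook to skip AMP scaling, and the training loop now simply delegates. A condensed sketch of the resulting call path, with the AMP/apex branches omitted; see the full code in the diff below:

    # Condensed from the diff below; AMP/apex handling omitted, not a drop-in implementation.
    class Accelerator(object):

        def backward(self, closure_loss, optimizer, opt_idx):
            model_ref = self.trainer.get_model()

            # delegate the actual backward to the LightningModule hook
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # release the graph once backward has been applied
            return closure_loss.detach()

    # and in optimizer_closure (training_loop.py) the old inline block becomes:
    #     closure_loss = self.accelerator_backend.backward(closure_loss, optimizer, opt_idx)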

File tree

4 files changed: +55 -30 lines
 

pytorch_lightning/accelerators/base_backend.py (+34)
@@ -1,6 +1,7 @@
 import torch
 from typing import Any
 from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities import AMPType, rank_zero_warn
 
 
 class Accelerator(object):
@@ -31,3 +32,36 @@ def validation_step_end(self, output):
 
     def process_dataloader(self, dataloader):
         return dataloader
+
+    def backward(self, closure_loss, optimizer, opt_idx):
+        model_ref = self.trainer.get_model()
+
+        # scale loss for 16 bit
+        if self.trainer.precision == 16:
+            closure_loss = model_ref.amp_scale_loss(
+                closure_loss,
+                optimizer,
+                opt_idx,
+                amp_backend=self.trainer.amp_backend
+            )
+
+        # enter amp context
+        if self.trainer.amp_backend == AMPType.APEX:
+            self.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
+            context = closure_loss
+            closure_loss = closure_loss.__enter__()
+
+        # do backward pass
+        model_ref.backward(self, closure_loss, optimizer, opt_idx)
+
+        # exit amp context
+        if self.trainer.precision == 16 and self.trainer.amp_backend == AMPType.APEX:
+            a, b, c = None, None, None
+            error = context.__exit__(a, b, c)
+            if error:
+                rank_zero_warn(a, b, c)
+                raise Exception('apex unscale error')
+
+        # once backward has been applied, release graph
+        closure_loss = closure_loss.detach()
+        return closure_loss
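
A note on the APEX branch in Accelerator.backward above: when the AMP backend is APEX, the scaled loss returned by amp_scale_loss is treated as a context manager and is entered/exited by hand rather than via a with statement. For comparison, this is the standard apex pattern that the manual __enter__/__exit__ calls unroll, as a minimal sketch assuming the apex package and a CUDA device are available:

    # Sketch only: the idiomatic apex AMP usage that the manual
    # __enter__/__exit__ handling above corresponds to. Assumes apex + CUDA.
    import torch
    from apex import amp

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # O1 = mixed precision with dynamic loss scaling
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    loss = model(torch.randn(4, 10, device='cuda')).mean()

    # the scaled loss lives inside a context manager; gradients are
    # unscaled when the context exits
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    optimizer.step()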

pytorch_lightning/accelerators/cpu_backend.py (+1, -1)
@@ -14,7 +14,7 @@
 import torch
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.accelerators.base_backend import Accelerator
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import AMPType, rank_zero_warn
 
 
 class CPUBackend(Accelerator):

pytorch_lightning/accelerators/tpu_backend.py (+11)
@@ -220,3 +220,14 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
         log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},'
                  f' global rank: {trainer.tpu_global_core_rank}'
                  f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')
+
+    def backward(self, closure_loss, optimizer, opt_idx):
+        model_ref = self.trainer.get_model()
+
+        # do backward pass
+        model_ref.backward(self, closure_loss, optimizer, opt_idx)
+
+        # detach after backward
+        closure_loss = closure_loss.detach()
+
+        return closure_loss
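
The TPU override above illustrates the point of this refactor: each backend can now implement precision handling in backward without the training loop branching on on_tpu or the AMP backend. As a purely hypothetical example of the same hook (the GradClipBackend name and the clipping step are illustrative and not part of this commit), another backend could extend it like this:

    # Hypothetical illustration only: a backend overriding the backward hook
    # added in this commit; GradClipBackend is not part of the codebase.
    import torch
    from pytorch_lightning.accelerators.base_backend import Accelerator

    class GradClipBackend(Accelerator):

        def backward(self, closure_loss, optimizer, opt_idx):
            model_ref = self.trainer.get_model()

            # do backward pass through the LightningModule hook
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # illustrative extra step: clip gradients right after backward
            torch.nn.utils.clip_grad_norm_(model_ref.parameters(), max_norm=1.0)

            # release the graph, mirroring the base implementation
            return closure_loss.detach()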

pytorch_lightning/trainer/training_loop.py (+9, -29)
@@ -988,36 +988,16 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         untouched_loss = closure_loss.detach().clone()
 
         # backward pass
-        model_ref = self.get_model()
         with self.profiler.profile('model_backward'):
-            # scale loss for 16 bit
-            if self.precision == 16 and not self.on_tpu:
-                closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_backend=self.amp_backend)
-
-            # enter amp context
-            if self.amp_backend == AMPType.APEX:
-                self.dev_debugger.track_event('AMP', str(AMPType.APEX))
-                context = closure_loss
-                closure_loss = closure_loss.__enter__()
-
-            # do backward pass
-            model_ref.backward(self, closure_loss, optimizer, opt_idx)
-
-            # exit amp context
-            if self.precision == 16 and self.amp_backend == AMPType.APEX and not self.on_tpu:
-                a, b, c = None, None, None
-                error = context.__exit__(a, b, c)
-                if error:
-                    rank_zero_warn(a, b, c)
-                    raise Exception('apex unscale error')
-
-            # once backward has been applied, release graph
-            closure_loss = closure_loss.detach()
-
-            if is_result_obj:
-                training_step_output.detach()
-            else:
-                training_step_output.batch_loss = training_step_output.batch_loss.detach()
+            closure_loss = self.accelerator_backend.backward(closure_loss, optimizer, opt_idx)
+
+        # --------------------
+        # ON AFTER BACKWARD TODO
+        # --------------------
+        if is_result_obj:
+            training_step_output.detach()
+        else:
+            training_step_output.batch_loss = training_step_output.batch_loss.detach()
 
         if self.use_horovod:
             # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
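
Background on why backward runs inside optimizer_closure at all: PyTorch optimizers accept a closure that re-evaluates the model, runs backward, and returns the loss; optimizers such as LBFGS may call it several times per step. A minimal plain-PyTorch sketch of that pattern, independent of Lightning and shown only for context:

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    x, y = torch.randn(32, 10), torch.randn(32, 1)

    def closure():
        # LBFGS may invoke this several times per optimizer.step()
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    optimizer.step(closure)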
