
Commit 4b52142

Authored Sep 5, 2020
ref: inner train loop (intermediate step) 3/n (Lightning-AI#3362)
* ref: inner train loop (intermediate step) 3/n
* ref: inner train loop (intermediate step) 3/n
* ref: inner train loop (intermediate step) 3/n
* ref: inner train loop (intermediate step) 3/n
* ref: inner train loop (intermediate step) 3/n
* ref: inner train loop (intermediate step) 3/n
1 parent: f55efb7 · commit: 4b52142

2 files changed: +75 -63 lines
 

‎pytorch_lightning/trainer/training_loop.py

+4 -62

@@ -934,73 +934,15 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         """
         wrap the forward step in a closure so second order methods work
         """
-        # ---------------------------
-        # FORWARD (TRAINING STEP + TRAIN STEP END)
-        # ---------------------------
-        with self.profiler.profile('model_forward'):
-            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
-            training_step_output = self.accelerator_backend.training_step(args)
-            training_step_output = self.call_hook('training_step_end', training_step_output)
-
-            # ----------------------------
-            # PROCESS THE RESULT
-            # ----------------------------
-            # format and reduce outputs accordingly
-            training_step_output_for_epoch_end = training_step_output
-            is_result_obj = isinstance(training_step_output, Result)
-
-            # track batch size for weighted average
-            if is_result_obj:
-                training_step_output.track_batch_size(len(split_batch))
-
-            # don't allow EvalResult in the training_step
-            if isinstance(training_step_output, EvalResult):
-                raise MisconfigurationException('training_step cannot return EvalResult, '
-                                                'use a dict or TrainResult instead')
-
-            # handle regular dicts
-            if not is_result_obj:
-                training_step_output = self.process_output(training_step_output, train=True)
-
-                training_step_output = AttributeDict(
-                    batch_loss=training_step_output[0],
-                    pbar_on_batch_end=training_step_output[1],
-                    log_metrics=training_step_output[2],
-                    callback_metrics=training_step_output[3],
-                    hiddens=training_step_output[4],
-                )
-
-            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
-                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
-            elif is_result_obj:
-                training_step_output_for_epoch_end = copy(training_step_output)
-                training_step_output_for_epoch_end.detach()
-            else:
-                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
-
-            # accumulate loss
-            # (if accumulate_grad_batches = 1 no effect)
-            closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
-            closure_loss = closure_loss / self.accumulate_grad_batches
-
-            # the loss will get scaled for amp. avoid any modifications to it
-            untouched_loss = closure_loss.detach().clone()
+        # lightning module hook
+        result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)

         # backward pass
-        with self.profiler.profile('model_backward'):
-            closure_loss = self.accelerator_backend.backward(closure_loss, optimizer, opt_idx)
+        self.train_loop.backward(result, optimizer, opt_idx)

         # hook
-        self.train_loop.on_after_backward(training_step_output, batch_idx, untouched_loss)
+        self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.loss)

-        # result
-        result = AttributeDict(
-            loss=untouched_loss,
-            training_step_output=training_step_output,
-            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
-            hiddens=training_step_output.hiddens,
-        )
         return result

     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
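For context, the closure the docstring refers to is the standard PyTorch pattern: second-order optimizers such as LBFGS re-evaluate the model several times per step, so they are handed a callable that recomputes the forward pass and the loss. A minimal sketch of that pattern follows; it is plain PyTorch, not Lightning code, and the model, data, and learning rate are arbitrary illustration values.

import torch

# Plain-PyTorch sketch of the closure pattern (illustrative values only):
# LBFGS may call the closure several times within a single optimizer step,
# so the forward and backward passes live inside it.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

optimizer.step(closure)  # the optimizer invokes `closure` as often as it needs

This is why the refactor keeps a single `optimizer_closure` entry point and only moves the forward/backward internals into `TrainLoop`: the closure boundary itself must stay intact for such optimizers.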

‎pytorch_lightning/trainer/training_loop_temp.py

+71 -1

@@ -5,8 +5,12 @@
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.supporters import Accumulator
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.core.step_result import Result
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.utilities.parsing import AttributeDict
+from copy import copy


 class TrainLoop:
@@ -130,6 +134,11 @@ def get_optimizers_iterable(self):
         opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
         return [(opt_idx, self.trainer.optimizers[opt_idx])]

+    def backward(self, result, optimizer, opt_idx):
+        # backward pass
+        with self.trainer.profiler.profile('model_backward'):
+            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+
     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)

@@ -143,3 +152,64 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):

         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
+
+    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
+        with self.trainer.profiler.profile('model_forward'):
+            args = self.trainer.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
+            training_step_output = self.trainer.accelerator_backend.training_step(args)
+            training_step_output = self.trainer.call_hook('training_step_end', training_step_output)
+
+            # ----------------------------
+            # PROCESS THE RESULT
+            # ----------------------------
+            # format and reduce outputs accordingly
+            training_step_output_for_epoch_end = training_step_output
+            is_result_obj = isinstance(training_step_output, Result)
+
+            # track batch size for weighted average
+            if is_result_obj:
+                training_step_output.track_batch_size(len(split_batch))
+
+            # don't allow EvalResult in the training_step
+            if isinstance(training_step_output, EvalResult):
+                raise MisconfigurationException('training_step cannot return EvalResult, '
+                                                'use a dict or TrainResult instead')
+
+            # handle regular dicts
+            if not is_result_obj:
+                training_step_output = self.trainer.process_output(training_step_output, train=True)
+
+                training_step_output = AttributeDict(
+                    batch_loss=training_step_output[0],
+                    pbar_on_batch_end=training_step_output[1],
+                    log_metrics=training_step_output[2],
+                    callback_metrics=training_step_output[3],
+                    hiddens=training_step_output[4],
+                )
+
+            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
+            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+            elif is_result_obj:
+                training_step_output_for_epoch_end = copy(training_step_output)
+                training_step_output_for_epoch_end.detach()
+            else:
+                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+
+            # accumulate loss
+            # (if accumulate_grad_batches = 1 no effect)
+            closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
+            closure_loss = closure_loss / self.trainer.accumulate_grad_batches
+
+            # the loss will get scaled for amp. avoid any modifications to it
+            untouched_loss = closure_loss.detach().clone()
+
+        # result
+        result = AttributeDict(
+            closure_loss=closure_loss,
+            loss=untouched_loss,
+            training_step_output=training_step_output,
+            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
+            hiddens=training_step_output.hiddens,
+        )
+        return result
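The net effect of the refactor is that `optimizer_closure` only sees the `AttributeDict` returned by `TrainLoop.training_step`. A rough sketch of that contract follows, using the field names from the diff (`closure_loss`, `loss`, `training_step_output`, `training_step_output_for_epoch_end`, `hiddens`); the tensor values are made up for illustration and the step-output fields are stubbed out, so this is not Lightning's real object construction, just the shape of it.

import torch
from pytorch_lightning.utilities.parsing import AttributeDict

# Illustration only: a fake training-step loss with a live autograd graph.
w = torch.tensor(2.0, requires_grad=True)
closure_loss = (w * 3.0) / 1  # divided by accumulate_grad_batches (here 1)

result = AttributeDict(
    closure_loss=closure_loss,            # keeps the graph; consumed by TrainLoop.backward
    loss=closure_loss.detach().clone(),   # "untouched" detached copy used for hooks/logging
    training_step_output=None,            # stubbed for this sketch
    training_step_output_for_epoch_end=None,
    hiddens=None,
)

result.closure_loss.backward()   # what the backward pass ultimately triggers
print(w.grad)                    # tensor(3.)
print(result.loss.requires_grad) # False, so it can be stored without retaining the graph

Keeping the graph-carrying loss (`closure_loss`) separate from the detached copy (`loss`) is what lets `on_after_backward` and the logging paths receive a tensor that is safe to hold onto across batches.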
