 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.supporters import Accumulator
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.core.step_result import Result
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.utilities.parsing import AttributeDict
+from copy import copy


 class TrainLoop:
@@ -130,6 +134,11 @@ def get_optimizers_iterable(self):
         opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
         return [(opt_idx, self.trainer.optimizers[opt_idx])]
 
+    def backward(self, result, optimizer, opt_idx):
+        # backward pass
+        with self.trainer.profiler.profile('model_backward'):
+            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+
     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)
@@ -143,3 +152,64 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
 
         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
+
+    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
+        with self.trainer.profiler.profile('model_forward'):
+            args = self.trainer.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
+            training_step_output = self.trainer.accelerator_backend.training_step(args)
+            training_step_output = self.trainer.call_hook('training_step_end', training_step_output)
+
+            # ----------------------------
+            # PROCESS THE RESULT
+            # ----------------------------
+            # format and reduce outputs accordingly
+            training_step_output_for_epoch_end = training_step_output
+            is_result_obj = isinstance(training_step_output, Result)
+
+            # track batch size for weighted average
+            if is_result_obj:
+                training_step_output.track_batch_size(len(split_batch))
+
+            # don't allow EvalResult in the training_step
+            if isinstance(training_step_output, EvalResult):
+                raise MisconfigurationException('training_step cannot return EvalResult, '
+                                                'use a dict or TrainResult instead')
+
+            # handle regular dicts
+            if not is_result_obj:
+                training_step_output = self.trainer.process_output(training_step_output, train=True)
+
+                training_step_output = AttributeDict(
+                    batch_loss=training_step_output[0],
+                    pbar_on_batch_end=training_step_output[1],
+                    log_metrics=training_step_output[2],
+                    callback_metrics=training_step_output[3],
+                    hiddens=training_step_output[4],
+                )
+
+            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
+            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+            elif is_result_obj:
+                training_step_output_for_epoch_end = copy(training_step_output)
+                training_step_output_for_epoch_end.detach()
+            else:
+                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+
+        # accumulate loss
+        # (if accumulate_grad_batches = 1 no effect)
+        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
+        closure_loss = closure_loss / self.trainer.accumulate_grad_batches
+
+        # the loss will get scaled for amp. avoid any modifications to it
+        untouched_loss = closure_loss.detach().clone()
+
+        # result
+        result = AttributeDict(
+            closure_loss=closure_loss,
+            loss=untouched_loss,
+            training_step_output=training_step_output,
+            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
+            hiddens=training_step_output.hiddens,
+        )
+        return result
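
For context, the two branches inside the new `training_step` correspond to the two return formats a user-facing `training_step` could produce at this point in the API: a plain dict, which `trainer.process_output` flattens into the `AttributeDict` fields (`batch_loss`, `pbar_on_batch_end`, `log_metrics`, `callback_metrics`, `hiddens`), and a `TrainResult`/`Result` object, whose `minimize` tensor becomes `closure_loss` and is later consumed by the new `backward()` method. The snippet below is a minimal, illustrative sketch of the user side only, assuming the 0.9-era API; `LitModel` and its single linear layer are hypothetical and not part of this diff.

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.core.step_result import TrainResult


class LitModel(pl.LightningModule):
    """Hypothetical module, only to illustrate the two return formats handled above."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)

        # Option 1: plain dict -> handled by the `not is_result_obj` branch,
        # where trainer.process_output() unpacks it into the AttributeDict fields.
        return {'loss': loss, 'log': {'train_loss': loss}}

        # Option 2: 0.9-style Result object -> handled by the `is_result_obj` branch,
        # where `minimize` becomes the closure loss fed to backward():
        # result = TrainResult(minimize=loss)
        # result.log('train_loss', loss)
        # return result

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

Either form ends up in the `result` AttributeDict returned above, with `closure_loss` already divided by `accumulate_grad_batches` and `loss` kept as a detached copy for logging.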