
Commit f747cb6

authored Sep 2, 2020
ref: moving train loop to own object 2/n (intermediate steps) (Lightning-AI#3314)
* ref: moving train loop to own object 2/n (intermediate steps)

* ref: moving train loop to own object 2/n (intermediate steps)
1 parent 0d90d53 commit f747cb6

2 files changed: +84 −67 lines changed
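
The theme of the refactor is composition: bookkeeping that previously lived inline in Trainer.train() moves into a TrainLoop object that the Trainer delegates to. Below is a minimal, self-contained sketch of that delegation pattern. The class and method names mirror the diff, but the Trainer here is a stand-in for illustration, not the real pytorch_lightning.Trainer, and the self.train_loop = TrainLoop(self) wiring is assumed rather than shown in this commit.

# Minimal sketch of the delegation pattern (stand-in Trainer, not the real Lightning API).

class TrainLoop:
    def __init__(self, trainer):
        self.trainer = trainer
        self._teardown_already_run = False

    def on_train_start(self):
        self.trainer.call_hook('on_train_start')

    def on_train_end(self):
        # the guard makes this safe to call from every exit path (max_steps, early stop, Ctrl+C)
        if self._teardown_already_run:
            return
        self._teardown_already_run = True
        self.trainer.call_hook('on_train_end')


class Trainer:
    def __init__(self):
        # assumed wiring; this commit only shows the TrainLoop side
        self.train_loop = TrainLoop(self)

    def call_hook(self, name, *args):
        print(f'hook: {name}')

    def train(self):
        self.train_loop.on_train_start()
        # ... the epoch loop runs here; every exit path ends with on_train_end() ...
        self.train_loop.on_train_end()


if __name__ == '__main__':
    Trainer().train()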
 

‎pytorch_lightning/trainer/training_loop.py

+12 −66
@@ -341,33 +341,16 @@ def run_sanity_check(self, *args):
     def train(self):
         self.run_sanity_check(self.get_model())
 
-        # TODO: shrink
-        # clear cache before training
-        if self.on_gpu and self.root_gpu is not None:
-            # use context because of:
-            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-            with torch.cuda.device(f'cuda:{self.root_gpu}'):
-                torch.cuda.empty_cache()
-
-        # get model
-        model = self.get_model()
-
         # enable train mode
+        model = self.get_model()
         model.train()
-
-        # enable gradients
         torch.set_grad_enabled(True)
 
-        # load data
-        # if reload_dataloaders_every_epoch, this is moved to the epoch loop
-        if not self.reload_dataloaders_every_epoch:
-            self.reset_train_dataloader(model)
-
-        if self.val_dataloaders is None and not self.reload_dataloaders_every_epoch:
-            self.reset_val_dataloader(model)
+        # reload data when needed
+        self.train_loop.reset_train_val_dataloaders(model)
 
         # hook
-        self.call_hook('on_train_start')
+        self.train_loop.on_train_start()
 
         try:
             # run all epochs
@@ -399,7 +382,9 @@ def train(self):
                 self.run_training_epoch()
 
                 if self.max_steps and self.max_steps <= self.global_step:
-                    self.run_training_teardown()
+
+                    # hook
+                    self.train_loop.on_train_end()
                     return
 
                 # update LR schedulers
@@ -411,14 +396,15 @@ def train(self):
 
                 if self.should_stop:
                     if (met_min_epochs and met_min_steps):
-                        self.run_training_teardown()
+                        self.train_loop.on_train_end()
                         return
                     else:
                         log.info('Trainer was signaled to stop but required minimum epochs'
                                  f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                                  ' not been met. Training will continue...')
 
-            self.run_training_teardown()
+            # hook
+            self.train_loop.on_train_end()
 
         except KeyboardInterrupt:
             rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
@@ -429,7 +415,8 @@ def train(self):
                 self._state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()
 
-            self.run_training_teardown()
+            # hook
+            self.train_loop.on_train_end()
 
     def run_training_epoch(self):
 
@@ -1053,47 +1040,6 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         )
         return result
 
-
-    # @atexit.register
-    def run_training_teardown(self):
-        if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
-            return
-
-        self._teardown_already_run = True
-
-        # Save latest checkpoint
-        log.info('Saving latest checkpoint..')
-        self.check_checkpoint_callback(should_check_val=False)
-
-        # Train end events
-        with self.profiler.profile('on_train_end'):
-            # callbacks
-            self.on_train_end()
-            # model hooks
-            if self.is_function_implemented('on_train_end'):
-                self.get_model().on_train_end()
-
-        if self.logger is not None:
-            self.logger.finalize("success")
-
-        # summarize profile results
-        if self.global_rank == 0:
-            self.profiler.describe()
-
-        if self.global_rank == 0:
-            for proc in self.interactive_ddp_procs:
-                subprocess.Popen.kill(proc)
-
-        # clean up dist group
-        if self.use_ddp or self.use_ddp2:
-            torch_distrib.destroy_process_group()
-
-        # clear mem
-        if self.on_gpu:
-            model = self.get_model()
-            model.cpu()
-            torch.cuda.empty_cache()
-
     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
         args = [batch, batch_idx]
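
Worth noting in the diff above: train() now reaches self.train_loop.on_train_end() from several exit paths (max_steps reached, early stopping, the normal end of the epoch loop, and KeyboardInterrupt), so the teardown has to tolerate repeated invocation. A rough, Lightning-free sketch of that run-once guard:

# Rough sketch of the run-once guard behind on_train_end (plain Python, not Lightning code).

class RunOnce:
    def __init__(self, fn):
        self.fn = fn
        self._already_run = False

    def __call__(self, *args, **kwargs):
        if self._already_run:
            return None          # later calls are no-ops
        self._already_run = True
        return self.fn(*args, **kwargs)


teardown = RunOnce(lambda: print('tearing down'))
teardown()  # runs the teardown
teardown()  # does nothing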

‎pytorch_lightning/trainer/training_loop_temp.py

+72 −1
@@ -1,6 +1,12 @@
-from pytorch_lightning.trainer.supporters import Accumulator
+import subprocess
 import numpy as np
+import torch
+import torch.distributed as torch_distrib
+from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.trainer.supporters import Accumulator
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning import _logger as log
 
 
 class TrainLoop:
@@ -10,12 +16,69 @@ def __init__(self, trainer):
         self.should_check_val = False
         self.early_stopping_accumulator = None
         self.checkpoint_accumulator = None
+        self._teardown_already_run = False
 
     @property
     def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
         return num_optimizers
 
+    def on_train_start(self):
+        # clear cache before training
+        if self.trainer.on_gpu and self.trainer.root_gpu is not None:
+            # use context because of:
+            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
+            with torch.cuda.device(f'cuda:{self.trainer.root_gpu}'):
+                torch.cuda.empty_cache()
+
+        # hook
+        self.trainer.call_hook('on_train_start')
+
+    def on_train_end(self):
+        if self._teardown_already_run:
+            return
+
+        self._teardown_already_run = True
+
+        # Save latest checkpoint
+        log.info('Saving latest checkpoint..')
+        self.check_checkpoint_callback(should_check_val=False)
+
+        # hook
+        self.trainer.call_hook('on_train_end')
+
+        # kill loggers
+        if self.trainer.logger is not None:
+            self.trainer.logger.finalize("success")
+
+        # summarize profile results
+        if self.trainer.global_rank == 0:
+            self.trainer.profiler.describe()
+
+        if self.trainer.global_rank == 0:
+            for proc in self.trainer.interactive_ddp_procs:
+                subprocess.Popen.kill(proc)
+
+        # clean up dist group
+        if self.trainer.use_ddp or self.trainer.use_ddp2:
+            torch_distrib.destroy_process_group()
+
+        # clear mem
+        if self.trainer.on_gpu:
+            model = self.trainer.get_model()
+            model.cpu()
+            torch.cuda.empty_cache()
+
+    def check_checkpoint_callback(self, should_check_val):
+        model = self.trainer.get_model()
+
+        # when no val loop is present or fast-dev-run still need to call checkpoints
+        # TODO bake this logic into the checkpoint callback
+        should_activate = not is_overridden('validation_step', model) and not should_check_val
+        if should_activate:
+            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+            [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
+
     def on_train_epoch_start(self):
         # hook
         self.trainer.call_hook('on_epoch_start')
@@ -28,6 +91,7 @@ def on_train_epoch_start(self):
         self.early_stopping_accumulator = Accumulator()
         self.checkpoint_accumulator = Accumulator()
 
+
     def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
         # figure out what to track for epoch end
         self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
@@ -36,6 +100,13 @@ def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx,
         self.trainer.call_hook('on_batch_end')
         self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)
 
+    def reset_train_val_dataloaders(self, model):
+        if not self.trainer.reload_dataloaders_every_epoch:
+            self.trainer.reset_train_dataloader(model)
+
+        if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
+            self.trainer.reset_val_dataloader(model)
+
     def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
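
The new check_checkpoint_callback covers models that never override validation_step: ModelCheckpoint normally saves from its on_validation_end hook, so without a validation loop that hook has to be triggered manually at the end of training. A rough sketch of the same decision in plain Python; should_force_checkpoint is a hypothetical helper, not part of Lightning:

# Hypothetical helper mirroring the condition in check_checkpoint_callback (not Lightning API).

def should_force_checkpoint(overrides_validation_step: bool, should_check_val: bool) -> bool:
    # only force a save when there is no val loop to do it and no val check is about to run
    return (not overrides_validation_step) and (not should_check_val)


assert should_force_checkpoint(False, False) is True    # no val loop at all -> save now
assert should_force_checkpoint(True, False) is False    # val loop exists -> it saves on validation end
assert should_force_checkpoint(False, True) is False    # a val check is about to run anyway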
