-from pytorch_lightning.trainer.supporters import Accumulator
+import subprocess
 import numpy as np
+import torch
+import torch.distributed as torch_distrib
+from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.trainer.supporters import Accumulator
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning import _logger as log


 class TrainLoop:
@@ -10,12 +16,69 @@ def __init__(self, trainer):
         self.should_check_val = False
         self.early_stopping_accumulator = None
         self.checkpoint_accumulator = None
+        self._teardown_already_run = False

     @property
     def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
         return num_optimizers

+    def on_train_start(self):
+        # clear cache before training
+        if self.trainer.on_gpu and self.trainer.root_gpu is not None:
+            # use context because of:
+            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
+            with torch.cuda.device(f'cuda:{self.trainer.root_gpu}'):
+                torch.cuda.empty_cache()
+
+        # hook
+        self.trainer.call_hook('on_train_start')
+
+    def on_train_end(self):
+        if self._teardown_already_run:
+            return
+
+        self._teardown_already_run = True
+
+        # Save latest checkpoint
+        log.info('Saving latest checkpoint..')
+        self.check_checkpoint_callback(should_check_val=False)
+
+        # hook
+        self.trainer.call_hook('on_train_end')
+
+        # kill loggers
+        if self.trainer.logger is not None:
+            self.trainer.logger.finalize("success")
+
+        # summarize profile results
+        if self.trainer.global_rank == 0:
+            self.trainer.profiler.describe()
+
+        if self.trainer.global_rank == 0:
+            for proc in self.trainer.interactive_ddp_procs:
+                subprocess.Popen.kill(proc)
+
+        # clean up dist group
+        if self.trainer.use_ddp or self.trainer.use_ddp2:
+            torch_distrib.destroy_process_group()
+
+        # clear mem
+        if self.trainer.on_gpu:
+            model = self.trainer.get_model()
+            model.cpu()
+            torch.cuda.empty_cache()
+
+    def check_checkpoint_callback(self, should_check_val):
+        model = self.trainer.get_model()
+
+        # when no val loop is present or fast-dev-run still need to call checkpoints
+        # TODO bake this logic into the checkpoint callback
+        should_activate = not is_overridden('validation_step', model) and not should_check_val
+        if should_activate:
+            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+            [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
+
     def on_train_epoch_start(self):
         # hook
         self.trainer.call_hook('on_epoch_start')
@@ -28,6 +91,7 @@ def on_train_epoch_start(self):
         self.early_stopping_accumulator = Accumulator()
         self.checkpoint_accumulator = Accumulator()

+
     def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
         # figure out what to track for epoch end
         self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
@@ -36,6 +100,13 @@ def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx,
         self.trainer.call_hook('on_batch_end')
         self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)

+    def reset_train_val_dataloaders(self, model):
+        if not self.trainer.reload_dataloaders_every_epoch:
+            self.trainer.reset_train_dataloader(model)
+
+        if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
+            self.trainer.reset_val_dataloader(model)
+
     def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
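Note on `check_checkpoint_callback`: when `validation_step` is not overridden, the validation loop (and with it the point where `ModelCheckpoint.on_validation_end` normally fires) never runs, so the train loop filters the checkpoint callbacks out of `trainer.callbacks` and triggers them directly. A toy sketch of that filter-and-dispatch pattern with stand-in classes; none of the classes below are Lightning's real ones.

class Callback:
    def on_validation_end(self, trainer, model):
        pass


class CheckpointCallback(Callback):
    # stand-in for ModelCheckpoint
    def on_validation_end(self, trainer, model):
        print('saving checkpoint')


class OtherCallback(Callback):
    # unrelated callback that the isinstance() filter should skip
    pass


callbacks = [OtherCallback(), CheckpointCallback()]

# keep only the checkpoint-style callbacks, then invoke them manually,
# mirroring the isinstance() filter in check_checkpoint_callback
checkpoint_callbacks = [c for c in callbacks if isinstance(c, CheckpointCallback)]
for c in checkpoint_callbacks:
    c.on_validation_end(trainer=None, model=None)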