# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import os
import sys
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error

EPSILON = 1e-6
EPSILON_FP16 = 1e-5


class TrainerTrainingTricksMixin(ABC):
    # This is just a summary of the variables used in this abstract class;
    # the proper values/initialisation should be done in the child class.
    gradient_clip_val: ...
    precision: int
    default_root_dir: str
    progress_bar_callback: ...
    on_gpu: bool

    @abstractmethod
    def get_model(self) -> LightningModule:
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def save_checkpoint(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def restore(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def fit(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    def clip_gradients(self):
        # this code is a modification of torch.nn.utils.clip_grad_norm_
        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
        if self.gradient_clip_val <= 0:
            return
        model = self.get_model()
        parameters = model.parameters()
        max_norm = float(self.gradient_clip_val)
        norm_type = float(2.0)
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        if len(parameters) == 0:
            return  # nothing to clip
        # hoisted out of the branch below so it is also defined for the inf-norm case
        device = parameters[0].device
        if norm_type == math.inf:
            total_norm = max(p.grad.data.abs().max() for p in parameters)
        else:
            out = torch.empty(len(parameters), device=device)
            for i, p in enumerate(parameters):
                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
            total_norm = torch.norm(out, norm_type)
        eps = EPSILON_FP16 if self.precision == 16 else EPSILON
        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
        for p in parameters:
            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
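
    # A minimal sketch (not from this file) of the same global-norm clipping
    # math on plain tensors, assuming the 2-norm and the fp32 epsilon; handy
    # for checking the formula outside of a Trainer:
    #
    #     import torch
    #
    #     grads = [torch.tensor([3.0, 4.0]), torch.tensor([0.0, 12.0])]
    #     max_norm = 5.0
    #     total_norm = torch.norm(torch.stack([g.norm(2.0) for g in grads]), 2.0)  # 13.0
    #     clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    #     clipped = [g * clip_coef for g in grads]  # every grad scaled by 5/13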

    def print_nan_gradients(self) -> None:
        model = self.get_model()
        for param in model.parameters():
            if (param.grad is not None) and torch.isnan(param.grad.float()).any():
                log.info('%s: %s', param, param.grad)

    def detect_nan_tensors(self, loss: Tensor) -> None:
        model = self.get_model()

        # check if loss is nan or inf
        if not torch.isfinite(loss).all():
            raise ValueError(
                'The loss returned in `training_step` is nan or inf.'
            )
        # check if any network weight is nan or inf
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                self.print_nan_gradients()
                raise ValueError(
                    f'Detected nan and/or inf values in `{name}`.'
                    ' Check your forward pass for numerically unstable operations.'
                )
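
    # For reference, an illustrative example (not from this file) of why a
    # single `torch.isfinite` check covers both failure modes: it is False
    # for nan as well as inf entries.
    #
    #     import torch
    #
    #     t = torch.tensor([1.0, float('nan'), float('inf')])
    #     torch.isfinite(t)        # tensor([ True, False, False])
    #     torch.isfinite(t).all()  # tensor(False) -> detect_nan_tensors raises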

    def configure_accumulated_gradients(self, accumulate_grad_batches):
        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {0: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")

    def scale_batch_size(self,
                         model: LightningModule,
                         mode: str = 'power',
                         steps_per_trial: int = 3,
                         init_val: int = 2,
                         max_trials: int = 25,
                         batch_arg_name: str = 'batch_size'):
        r"""
        Will iteratively try to find the largest batch size for a given model
        that does not give an out of memory (OOM) error.

        Args:
            model: Model to fit.

            mode: string setting the search mode. Either `power` or `binsearch`.
                If mode is `power`, we keep multiplying the batch size by 2 until
                we get an OOM error. If mode is `binsearch`, we will initially
                also keep multiplying by 2 and, after encountering an OOM error,
                do a binary search between the last successful batch size and the
                batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally 1 should be enough to test if an OOM error occurs,
                however in practice a few are needed.

            init_val: initial batch size to start the search with.

            max_trials: max number of batch-size increases before the
                algorithm is terminated.

            batch_arg_name: name of the field that stores the batch size;
                looked up on the model first, then on `model.hparams`.
        """
        if not hasattr(model, batch_arg_name):
            if not hasattr(model.hparams, batch_arg_name):
                raise MisconfigurationException(
                    f'Neither `model.{batch_arg_name}` nor `model.hparams.{batch_arg_name}` found.'
                )

        if hasattr(model.train_dataloader, 'patch_loader_code'):
            raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
                                            ' passed directly to `.fit()`. Please disable the feature or'
                                            ' incorporate the dataloader into the model.')

        # Arguments we adjust during the batch size finder, save for restoring
        self.__scale_batch_dump_params()

        # Set to values that are required by the algorithm
        self.__scale_batch_reset_params(model, steps_per_trial)

        # Save initial model, that is loaded after batch size is found
        save_path = os.path.join(self.default_root_dir, 'temp_model.ckpt')
        self.save_checkpoint(str(save_path))

        if self.progress_bar_callback:
            self.progress_bar_callback.disable()

        # Initially we just double in size until an OOM is encountered
        new_size = _adjust_batch_size(self, batch_arg_name, value=init_val)  # initially set to init_val
        if mode == 'power':
            new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
        elif mode == 'binsearch':
            new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
        else:
            raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch`')

        garbage_collection_cuda()
        log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

        # Restore initial state of model
        self.restore(str(save_path), on_gpu=self.on_gpu)
        os.remove(save_path)

        # Finish by resetting variables so trainer is ready to fit model
        self.__scale_batch_restore_params()
        if self.progress_bar_callback:
            self.progress_bar_callback.enable()

        return new_size
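
    # A hedged usage sketch, assuming a Trainer instance and a LightningModule
    # that reads its batch size from `self.hparams.batch_size` when building
    # its train dataloader:
    #
    #     trainer = Trainer(max_epochs=10)
    #     new_size = trainer.scale_batch_size(model, mode='binsearch', init_val=4)
    #     model.hparams.batch_size = new_size
    #     trainer.fit(model)
    #
    # Dataloaders passed directly to `.fit()` are rejected above, since the
    # finder must be able to rebuild the dataloader with each candidate size.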

    def __scale_batch_dump_params(self):
        # Save the trainer attributes that the finder temporarily overrides,
        # so they can be restored afterwards
        self.__dumped_params = {
            'max_steps': self.max_steps,
            'weights_summary': self.weights_summary,
            'logger': self.logger,
            'callbacks': self.callbacks,
            'checkpoint_callback': self.checkpoint_callback,
            'early_stop_callback': self.early_stop_callback,
            'auto_scale_batch_size': self.auto_scale_batch_size,
            'limit_train_batches': self.limit_train_batches,
            'model': self.model,
        }

    def __scale_batch_reset_params(self, model, steps_per_trial):
        self.auto_scale_batch_size = None  # prevent recursion
        self.max_steps = steps_per_trial  # take few steps
        self.weights_summary = None  # not needed before full run
        self.logger = DummyLogger()
        self.callbacks = []  # not needed before full run
        self.checkpoint_callback = False  # required for saving
        self.early_stop_callback = None
        self.limit_train_batches = 1.0
        self.optimizers, self.schedulers = [], []  # required for saving
        self.model = model  # required for saving

    def __scale_batch_restore_params(self):
        self.max_steps = self.__dumped_params['max_steps']
        self.weights_summary = self.__dumped_params['weights_summary']
        self.logger = self.__dumped_params['logger']
        self.callbacks = self.__dumped_params['callbacks']
        self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
        self.auto_scale_batch_size = self.__dumped_params['auto_scale_batch_size']
        self.early_stop_callback = self.__dumped_params['early_stop_callback']
        self.limit_train_batches = self.__dumped_params['limit_train_batches']
        self.model = self.__dumped_params['model']
        del self.__dumped_params


def _adjust_batch_size(trainer,
                       batch_arg_name: str = 'batch_size',
                       factor: float = 1.0,
                       value: Optional[int] = None,
                       desc: Optional[str] = None):
    """ Helper function for adjusting the batch size. It is expected that the
        model exposes the batch size either directly as an attribute or via
        hparams, i.e. `model.batch_size` or `model.hparams.batch_size` should
        exist.

    Args:
        trainer: instance of pytorch_lightning.Trainer

        batch_arg_name: field where the batch size is stored; looked up on the
            model first, then on `model.hparams`

        factor: value which the old batch size is multiplied by to get the
            new batch size

        value: if a value is given, will override the batch size with this value.
            Note that the value of `factor` will not have an effect in this case

        desc: either `succeeded` or `failed`. Used purely for logging
    """
    model = trainer.get_model()
    if hasattr(model, batch_arg_name):
        batch_size = getattr(model, batch_arg_name)
    else:
        batch_size = getattr(model.hparams, batch_arg_name)
    if value:
        if hasattr(model, batch_arg_name):
            setattr(model, batch_arg_name, value)
        else:
            setattr(model.hparams, batch_arg_name, value)
        new_size = value
        if desc:
            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
    else:
        new_size = int(batch_size * factor)
        if desc:
            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
        # write the new size back to wherever it was read from
        if hasattr(model, batch_arg_name):
            setattr(model, batch_arg_name, new_size)
        else:
            setattr(model.hparams, batch_arg_name, new_size)
    return new_size
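

# A worked example (illustrative; assumes `model.hparams.batch_size == 16`):
#
#     _adjust_batch_size(trainer, factor=2.0, desc='succeeded')  # 16 -> 32
#     _adjust_batch_size(trainer, factor=0.5, desc='failed')     # 32 -> 16
#     _adjust_batch_size(trainer, value=24)                      # set to exactly 24
#
# The `desc` string only affects the log line, e.g.
# "Batch size 16 succeeded, trying batch size 32".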


def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
    """ Batch scaling mode where the size is doubled at each iteration until an
        OOM error is encountered. """
    for _ in range(max_trials):
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model)
            # Double in size
            new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, halve the size and return
                garbage_collection_cuda()
                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
                break
            else:
                raise  # some other error not memory related
    return new_size
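

# An illustrative trace of power scaling from init_val=2, assuming the first
# OOM occurs at batch size 32:
#
#     2 -> 4 -> 8 -> 16 -> 32 (OOM) -> halve -> returns 16
#
# The returned size is simply half of the first failing size; this mode does
# no further refinement.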


def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials):
    """ Batch scaling mode where the size is initially doubled at each iteration
        until an OOM error is encountered. Thereafter, the batch size is further
        refined using a binary search. """
    high = None
    low = new_size  # last known-good size; also guards against OOM on the very first trial
    count = 0
    while True:
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model)
            count += 1
            if count > max_trials:
                break
            # Double in size
            low = new_size
            if high:
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
            else:
                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail, binary search between the last successful size
                # and the size that failed
                garbage_collection_cuda()
                high = new_size
                midval = (high + low) // 2
                new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
                if high - low <= 1:
                    break
            else:
                raise  # some other error not memory related
    return new_size
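

# An illustrative trace of binsearch scaling from init_val=2, again assuming
# the first OOM occurs at batch size 32:
#
#     2 -> 4 -> 8 -> 16 -> 32 (OOM: low=16, high=32) -> 24 -> 28 -> ...
#     until high - low <= 1 or max_trials is exhausted
#
# Compared to power mode, the binary search usually lands closer to the true
# memory limit at the cost of a few extra trials.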