Skip to content

Commit 7b917de

Browse files
fix setting batch_size attribute in batch_size finder (finishing PR Lightning-AI#2523) (Lightning-AI#3043)
* lightning attr fix * revert refactor * create test * separate test * changelog update * tests * revert * Update pytorch_lightning/trainer/training_tricks.py Co-authored-by: William Falcon <waf2107@columbia.edu>
1 parent ee4eae8 commit 7b917de

File tree

3 files changed

+60
-27
lines changed

3 files changed

+60
-27
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
148148

149149
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
150150

151+
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
152+
151153
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
152154

155+
153156
## [0.8.5] - 2020-07-09
154157

155158
### Added

pytorch_lightning/trainer/training_tricks.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
from pytorch_lightning.callbacks import GradientAccumulationScheduler
2525
from pytorch_lightning.core.lightning import LightningModule
2626
from pytorch_lightning.loggers.base import DummyLogger
27-
from pytorch_lightning.utilities import AMPType
27+
from pytorch_lightning.utilities import AMPType, rank_zero_warn
2828
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2929
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
30+
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr
3031

3132
try:
3233
from apex import amp
@@ -158,11 +159,15 @@ def scale_batch_size(self,
158159
algorithm is terminated
159160
160161
"""
161-
if not hasattr(model, batch_arg_name):
162-
if not hasattr(model.hparams, batch_arg_name):
163-
raise MisconfigurationException(
164-
'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
165-
)
162+
if not lightning_hasattr(model, batch_arg_name):
163+
raise MisconfigurationException(
164+
f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
165+
if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
166+
rank_zero_warn(
167+
f'Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
168+
f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
169+
f' If this is not the intended behavior, please remove either one.'
170+
)
166171

167172
if hasattr(model.train_dataloader, 'patch_loader_code'):
168173
raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -268,23 +273,17 @@ def _adjust_batch_size(trainer,
268273
269274
"""
270275
model = trainer.get_model()
271-
if hasattr(model, batch_arg_name):
272-
batch_size = getattr(model, batch_arg_name)
273-
else:
274-
batch_size = getattr(model.hparams, batch_arg_name)
276+
batch_size = lightning_getattr(model, batch_arg_name)
275277
if value:
276-
if hasattr(model, batch_arg_name):
277-
setattr(model, batch_arg_name, value)
278-
else:
279-
setattr(model.hparams, batch_arg_name, value)
278+
lightning_setattr(model, batch_arg_name, value)
280279
new_size = value
281280
if desc:
282281
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
283282
else:
284283
new_size = int(batch_size * factor)
285284
if desc:
286285
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
287-
setattr(model.hparams, batch_arg_name, new_size)
286+
lightning_setattr(model, batch_arg_name, new_size)
288287
return new_size
289288

290289

tests/trainer/test_trainer_tricks.py

+43-12
Original file line numberDiff line numberDiff line change
@@ -196,28 +196,59 @@ def test_trainer_reset_correctly(tmpdir):
196196
f'Attribute {key} was not reset correctly after learning rate finder'
197197

198198

199-
@pytest.mark.parametrize('scale_arg', ['power', 'binsearch'])
200-
def test_trainer_arg(tmpdir, scale_arg):
201-
""" Check that trainer arg works with bool input. """
199+
@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True])
200+
def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
201+
""" Test possible values for 'batch size auto scaling' Trainer argument. """
202202
tutils.reset_seed()
203-
204203
hparams = EvalModelTemplate.get_default_hparams()
205204
model = EvalModelTemplate(**hparams)
206-
207205
before_batch_size = hparams.get('batch_size')
208-
# logger file to get meta
209-
trainer = Trainer(
210-
default_root_dir=tmpdir,
211-
max_epochs=1,
212-
auto_scale_batch_size=scale_arg,
213-
)
214-
206+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=scale_arg)
215207
trainer.fit(model)
216208
after_batch_size = model.batch_size
217209
assert before_batch_size != after_batch_size, \
218210
'Batch size was not altered after running auto scaling of batch size'
219211

220212

213+
@pytest.mark.parametrize('use_hparams', [True, False])
214+
def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
215+
""" Test that new batch size gets written to the correct hyperparameter attribute. """
216+
tutils.reset_seed()
217+
218+
hparams = EvalModelTemplate.get_default_hparams()
219+
before_batch_size = hparams.get('batch_size')
220+
221+
class HparamsEvalModelTemplate(EvalModelTemplate):
222+
223+
def dataloader(self, *args, **kwargs):
224+
# artificially set batch_size so we can get a dataloader
225+
# remove it immediately after, because we want only self.hparams.batch_size
226+
setattr(self, "batch_size", before_batch_size)
227+
dataloader = super().dataloader(*args, **kwargs)
228+
del self.batch_size
229+
return dataloader
230+
231+
model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
232+
model = model_class(**hparams)
233+
234+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
235+
trainer.fit(model)
236+
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
237+
assert before_batch_size != after_batch_size
238+
239+
240+
def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
241+
""" Test for a warning when model.batch_size and model.hparams.batch_size both present. """
242+
hparams = EvalModelTemplate.get_default_hparams()
243+
model = EvalModelTemplate(**hparams)
244+
model.hparams = hparams
245+
# now we have model.batch_size and model.hparams.batch_size
246+
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, auto_scale_batch_size=True)
247+
expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!"
248+
with pytest.warns(UserWarning, match=expected_message):
249+
trainer.fit(model)
250+
251+
221252
@pytest.mark.parametrize('scale_method', ['power', 'binsearch'])
222253
def test_call_to_trainer_method(tmpdir, scale_method):
223254
""" Test that calling the trainer method itself works. """

0 commit comments

Comments
 (0)