From aa6f7110af20f7cdd1a811c1cf89603cc05874d1 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 7 May 2021 11:46:03 +0200 Subject: [PATCH 1/7] Fix DeepSpeedPlugin with IterableDataset (#7362) * deepspeed add train_micro_batch_size_per_gpu argument * Update naming and doc * Modify to use auto naming convention, add test * Add iterable tests * Fix tests, attempt by mocking * Import correct package * Fix comparison * Set as special test * Remove import * Add Changelog Co-authored-by: SeanNaren (cherry picked from commit 98b94b810c89d5e51a0ad0a2e6a87747aee6fbe9) --- CHANGELOG.md | 20 +++++++++ .../plugins/training_type/deepspeed.py | 27 ++++++++++-- tests/plugins/test_deepspeed_plugin.py | 41 ++++++++++++++++++- 3 files changed, 84 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbcdba1c9bc16..f615ed070edac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [1.3.x] - 2021-MM-DD + +### Added + + +### Changed + + +### Deprecated + + +### Removed + + +### Fixed + + +- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362)) + + ## [1.3.0] - 2021-05-06 ### Added diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 54974739c1746..fe3f51fa99390 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -88,6 +88,7 @@ def __init__( allgather_bucket_size: int = 2e8, reduce_bucket_size: int = 2e8, zero_allow_untested_optimizer: bool = True, + logging_batch_size_per_gpu: Union[str, int] = "auto", config: Optional[Union[Path, str, dict]] = None, logging_level: int = logging.WARN, num_nodes: int = 1, @@ -148,6 +149,13 @@ def __init__( zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a DeepSpeed supported optimizer when using ZeRO (default: True) + logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging + on a per sample per second basis (only displayed if logging=logging.INFO). + If set to "auto", the plugin tries to infer this from + the train DataLoader's BatchSampler, else defaults to 1. + To obtain accurate logs when using datasets that do not support batch samplers, + set this to the actual per gpu batch size (trainer.batch_size). + config: Pass in a deepspeed formatted config dict, or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json. All defaults will be ignored if a config is passed in. (Default: ``None``) @@ -182,6 +190,7 @@ def __init__( when using ZeRO Stage 3. This allows a single weight file to contain the entire model, rather than individual sharded weight files. Disable to save sharded states individually. (Default: True) + """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -197,6 +206,7 @@ def __init__( self.config = self._create_default_config( zero_optimization, zero_allow_untested_optimizer, + logging_batch_size_per_gpu, partition_activations=partition_activations, cpu_checkpointing=cpu_checkpointing, contiguous_memory_optimization=contiguous_memory_optimization, @@ -409,14 +419,22 @@ def _format_batch_size_and_grad_accum_config(self): " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." 
) if "train_micro_batch_size_per_gpu" not in self.config: - # train_micro_batch_size_per_gpu is used for throughput logging purposes - # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed - batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size + batch_size = self._auto_select_batch_size() self.config["train_micro_batch_size_per_gpu"] = batch_size self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches if "gradient_clipping" not in self.config: self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val + def _auto_select_batch_size(self): + # train_micro_batch_size_per_gpu is used for throughput logging purposes + # by default we try to use the batch size of the loader + batch_size = 1 + if hasattr(self.lightning_module, 'train_dataloader'): + train_dataloader = self.lightning_module.train_dataloader() + if hasattr(train_dataloader, 'batch_sampler'): + batch_size = train_dataloader.batch_sampler.batch_size + return batch_size + def _format_precision_config(self): amp_type = self.lightning_module.trainer.accelerator_connector.amp_type amp_level = self.lightning_module.trainer.accelerator_connector.amp_level @@ -446,6 +464,7 @@ def _create_default_config( self, zero_optimization: bool, zero_allow_untested_optimizer: bool, + logging_batch_size_per_gpu: Union[str, int], partition_activations: bool, cpu_checkpointing: bool, contiguous_memory_optimization: bool, @@ -466,6 +485,8 @@ def _create_default_config( "zero_optimization": zero_kwargs, **cfg } + if logging_batch_size_per_gpu != 'auto': + cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} return cfg def _filepath_to_dir(self, filepath: str) -> str: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c768a9aabf8fb..056c28ffa2309 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from torch import nn, Tensor from torch.optim import Optimizer +from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint @@ -14,7 +15,7 @@ from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel +from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf @@ -234,6 +235,44 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args trainer.fit(model) +@RunIf(min_gpus=1, deepspeed=True, special=True) +@pytest.mark.parametrize(['dataset_cls', 'value'], [(RandomDataset, "auto"), (RandomDataset, 10), + (RandomIterableDataset, "auto"), (RandomIterableDataset, 10)]) +def test_deepspeed_auto_batch_size_config_select(tmpdir, dataset_cls, value): + """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes.""" + + class TestModel(BoringModel): + + def train_dataloader(self): + return DataLoader(dataset_cls(32, 64)) + + class AssertCallback(Callback): + + def on_train_start(self, trainer, pl_module) -> None: + assert 
isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) + config = trainer.accelerator.training_type_plugin.config + + # int value overrides auto mode + expected_value = value if isinstance(value, int) else 1 + if dataset_cls == RandomDataset: + expected_value = pl_module.train_dataloader().batch_size if value == "auto" else value + + assert config['train_micro_batch_size_per_gpu'] == expected_value + raise SystemExit + + ck = AssertCallback() + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + callbacks=ck, + gpus=1, + plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=value, zero_optimization=False), + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_run_configure_optimizers(tmpdir): """ From 1ae191c507ca6c58c830f7c2d96e344f1e6bea74 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 7 May 2021 13:13:54 +0200 Subject: [PATCH 2/7] update ngc for 1.3 (#7414) (cherry picked from commit 1a27c12b26ef1fc94ddd2cfed6b00238c73af88a) --- dockers/nvidia/Dockerfile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile index 41e5ddea8fc0b..2027c4f3d5d7d 100644 --- a/dockers/nvidia/Dockerfile +++ b/dockers/nvidia/Dockerfile @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel_21-03.html#rel_21-03 -FROM nvcr.io/nvidia/pytorch:21.03-py3 +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes +FROM nvcr.io/nvidia/pytorch:21.04-py3 MAINTAINER PyTorchLightning @@ -46,6 +46,8 @@ RUN \ rm -rf pytorch-lightning && \ pip list +RUN pip install lightning-grid -U + ENV PYTHONPATH="/workspace" RUN \ From d59ef15df2bdf85660975592630076f80429fb0d Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 8 May 2021 14:15:52 +0900 Subject: [PATCH 3/7] Restore `trainer.current_epoch` after tuning (#7434) * Add a test * Save and restore current_epoch * Update CHANGELOG * alphabetical order (cherry picked from commit 710b144b9ba7d5e1cbc4bc3817ba2d30d6e8968d) --- CHANGELOG.md | 3 +++ pytorch_lightning/tuner/lr_finder.py | 2 ++ tests/tuner/test_lr_finder.py | 8 +++++++- tests/tuner/test_scale_batch_size.py | 8 ++++---- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f615ed070edac..ca6323fd62e1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362)) +- Fixed `Trainer.current_epoch` not getting restored after tuning ([#7434](https://github.com/PyTorchLightning/pytorch-lightning/pull/7434)) + + ## [1.3.0] - 2021-05-06 ### Added diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 01f48c66ad201..601f45f171ae4 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -288,6 +288,7 @@ def __lr_finder_dump_params(trainer, model): 'logger': trainer.logger, 'max_steps': trainer.max_steps, 'checkpoint_callback': trainer.checkpoint_callback, + 'current_epoch': trainer.current_epoch, 'configure_optimizers': model.configure_optimizers, } @@ -297,6 +298,7 @@ def __lr_finder_restore_params(trainer, model): trainer.logger = trainer.__dumped_params['logger'] trainer.callbacks = trainer.__dumped_params['callbacks'] trainer.max_steps = trainer.__dumped_params['max_steps'] + trainer.current_epoch = trainer.__dumped_params['current_epoch'] model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] del trainer.__dumped_params diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index 641196eda466f..608cb8c6778bf 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -77,7 +77,13 @@ def test_trainer_reset_correctly(tmpdir): ) changed_attributes = [ - 'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback' + 'accumulate_grad_batches', + 'auto_lr_find', + 'callbacks', + 'checkpoint_callback', + 'current_epoch', + 'logger', + 'max_steps', ] expected = {ca: getattr(trainer, ca) for ca in changed_attributes} trainer.tuner.lr_find(model, num_training=5) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index 3c9a38ac0aee2..7d4e05000d5da 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -111,13 +111,13 @@ def test_trainer_reset_correctly(tmpdir): ) changed_attributes = [ - 'max_steps', - 'weights_summary', - 'logger', 'callbacks', 'checkpoint_callback', - 'limit_train_batches', 'current_epoch', + 'limit_train_batches', + 'logger', + 'max_steps', + 'weights_summary', ] expected = {ca: getattr(trainer, ca) for ca in changed_attributes} trainer.tuner.scale_batch_size(model, max_trials=5) From 1577bb5141cda00ab64e44035f3da7d497b3f737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 8 May 2021 20:03:51 +0200 Subject: [PATCH 4/7] fix 1.9 test (#7441) (cherry picked from commit 1af42d7d1e3f02d78063e02b5a603e3dd59837eb) --- tests/trainer/test_dataloaders.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d988943c06088..6f78f125754b5 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -803,8 +803,12 @@ def _user_worker_init_fn(_): pass +@RunIf(max_torch="1.8.9") def test_missing_worker_init_fn(): - """ Test that naive worker seed initialization leads to undesired random state in subprocesses. """ + """ + Test that naive worker seed initialization leads to undesired random state in subprocesses. + PyTorch 1.9+ does not have this issue. 
+ """ dataset = NumpyRandomDataset() seed_everything(0) From a71838b626d93fef3c44a1985f4ad9754aa7fb8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 10 May 2021 05:26:15 +0200 Subject: [PATCH 5/7] fix display bug (#7395) (cherry picked from commit 6bc616d78f13c9921f3a08f7c71229b81be8b5ca) --- pytorch_lightning/accelerators/gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index d14b7cbeb9db6..03303edfc5ad2 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -36,7 +36,7 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None: """ if "cuda" not in str(self.root_device): raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") - self.set_nvidia_flags() + self.set_nvidia_flags(trainer.local_rank) torch.cuda.set_device(self.root_device) return super().setup(trainer, model) @@ -55,12 +55,12 @@ def teardown(self) -> None: torch.cuda.empty_cache() @staticmethod - def set_nvidia_flags() -> None: + def set_nvidia_flags(local_rank: int) -> None: # set the correct cuda visible devices (using pci order) os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) - _log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]") + _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") def to_device(self, batch: Any) -> Any: # no need to transfer batch to device in DP mode From 1ef805242fcef4b151c49f009677dbe101b7460a Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 10 May 2021 17:27:37 +0900 Subject: [PATCH 6/7] Pin `Sphinx<4.0` (#7456) * Dont use sphinx 4.0.0 * Dont use sphinx 4.0.0 * Update comment * Simple There is no other release between 3.5 and 4.0 Co-authored-by: Jirka Borovec (cherry picked from commit 6d82dc832b066a9d6a5fa63a449b8ac372d72b6a) --- requirements/docs.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/docs.txt b/requirements/docs.txt index 7444287ed401a..b5056fc2dacd9 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,4 @@ -sphinx>=3.0, !=3.5 # fails with sphinx.ext.viewcode +sphinx>=3.0, <3.5 # fails with sphinx.ext.viewcode # fails with sphinx_paramlinks with 4.0.0 recommonmark # fails with badges m2r # fails with multi-line text nbsphinx>=0.8 From 4eaae47cf386c52107e9da58e61b7528117d7d2f Mon Sep 17 00:00:00 2001 From: jirka Date: Tue, 11 May 2021 09:19:26 +0200 Subject: [PATCH 7/7] v1.3.1 --- CHANGELOG.md | 18 ++---------------- docs/source/governance.rst | 2 -- pytorch_lightning/__about__.py | 2 +- 3 files changed, 3 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca6323fd62e1c..2bd9dd13a6a01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,27 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
-## [1.3.x] - 2021-MM-DD - -### Added - - -### Changed - - -### Deprecated - - -### Removed - +## [1.3.1] - 2021-05-11 ### Fixed - - Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362)) - - - Fixed `Trainer.current_epoch` not getting restored after tuning ([#7434](https://github.com/PyTorchLightning/pytorch-lightning/pull/7434)) +- Fixed local rank displayed in console log ([#7395](https://github.com/PyTorchLightning/pytorch-lightning/pull/7395)) ## [1.3.0] - 2021-05-06 diff --git a/docs/source/governance.rst b/docs/source/governance.rst index fac8b68e1df53..5b1f9bd1916c1 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -38,5 +38,3 @@ Alumni - Jeff Ling (`jeffling `_) - Teddy Koker (`teddykoker `_) - Nate Raw (`nateraw `_) - - diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py index a333cdb3e6ecd..67ca6d9e8d167 100644 --- a/pytorch_lightning/__about__.py +++ b/pytorch_lightning/__about__.py @@ -1,7 +1,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.3.0' +__version__ = '1.3.1' __author__ = 'William Falcon et al.' __author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0'
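For illustration only, outside the patch series itself: the logging_batch_size_per_gpu docstring added in PATCH 1/7 describes an "auto" mode in which the plugin tries to read the micro batch size from the train DataLoader's BatchSampler and otherwise defaults to 1. The standalone sketch below shows that lookup with plain PyTorch. It assumes only that torch is installed; the helper name infer_logging_batch_size and the toy datasets are hypothetical stand-ins, not Lightning or DeepSpeed APIs, and the no-sampler case is handled a little more defensively than in the patch's _auto_select_batch_size helper.

# Minimal sketch of the "auto" batch-size lookup described in PATCH 1/7.
# Assumes only torch; the names below are illustrative, not Lightning APIs.
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset


class ToyMapDataset(Dataset):
    """Map-style dataset: DataLoader builds a BatchSampler for it."""

    def __len__(self) -> int:
        return 64

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.randn(32)


class ToyStreamDataset(IterableDataset):
    """Iterable-style dataset that yields ready-made batches itself."""

    def __iter__(self):
        for _ in range(8):
            yield torch.randn(8, 32)


def infer_logging_batch_size(loader: DataLoader) -> int:
    # Read the micro batch size from the loader's batch sampler when one
    # exists; otherwise fall back to 1, as the "auto" mode documents.
    batch_sampler = getattr(loader, "batch_sampler", None)
    if batch_sampler is not None and hasattr(batch_sampler, "batch_size"):
        return batch_sampler.batch_size
    return 1


map_loader = DataLoader(ToyMapDataset(), batch_size=8)
# batch_size=None disables automatic batching, so no BatchSampler is created;
# this is the kind of loader for which the automatic lookup cannot work.
stream_loader = DataLoader(ToyStreamDataset(), batch_size=None)

print(infer_logging_batch_size(map_loader))     # 8, taken from the BatchSampler
print(stream_loader.batch_sampler)              # None
print(infer_logging_batch_size(stream_loader))  # 1, the documented fallback

As the docstring advises, a loader of the second kind is better served by passing the real per-GPU batch size explicitly, e.g. DeepSpeedPlugin(logging_batch_size_per_gpu=8); per _create_default_config in the same patch, any non-"auto" value simply pre-populates train_micro_batch_size_per_gpu in the generated DeepSpeed config so that the throughput logs stay accurate.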