From 43ff610257d6caae9dac91ac7ce591a1a24fffd7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 13 Sep 2021 16:01:45 +0100 Subject: [PATCH 1/5] Remove the module wrap, forcing all the module to be partitioned outside of deepspeed.initialize --- pytorch_lightning/plugins/training_type/deepspeed.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index ca10b47bd9fd2..5903468be97c2 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -370,13 +370,6 @@ def init_deepspeed(self): precision = self.lightning_module.trainer.accelerator.precision model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) - if self.zero_stage_3: - # Ensure the entire model has been moved to the appropriate device - dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 - deepspeed.zero.Init( - module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype - ) - if self.lightning_module.trainer and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) else: From 58897b25b4a5b70dd1460b4a1a5098101915a7b0 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 13 Sep 2021 16:51:30 +0100 Subject: [PATCH 2/5] Fix test that seems to still be broken --- tests/plugins/test_deepspeed_plugin.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index de4bb3ea987f9..e9d3e88232a57 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -1,5 +1,6 @@ import json import os +from contextlib import ExitStack from typing import Any, Dict, Optional from unittest import mock @@ -409,13 +410,15 @@ def test_deepspeed_stage_3_save_warning(tmpdir): ) trainer.fit(model) checkpoint_path = os.path.join(tmpdir, "model.pt") - with pytest.warns(UserWarning) as record: - # both ranks need to call save checkpoint + + # both ranks need to call save checkpoint, however only rank 0 needs to check the warning + context_manager = ( + pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory.") + if trainer.is_global_zero + else ExitStack() + ) + with context_manager: trainer.save_checkpoint(checkpoint_path) - if trainer.is_global_zero: - assert len(record) == 1 - match = "each worker will save a shard of the checkpoint within a directory." 
- assert match in str(record[0].message) @RunIf(min_gpus=1, deepspeed=True, special=True) From 3e1e67b269ceecb537e4c1a60d295541b5e5e40c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 14 Sep 2021 11:20:26 +0100 Subject: [PATCH 3/5] Add tests, solution for RNN layers --- .../plugins/training_type/deepspeed.py | 15 ++++++ tests/plugins/test_deepspeed_plugin.py | 47 ++++++++++++++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 5903468be97c2..7b04632a0f619 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -124,6 +124,7 @@ def __init__( contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, + partition_module: bool = False, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -253,6 +254,12 @@ def __init__( load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards per worker. + + partition_module: When True, partitions the ``LightningModule`` across devices when using ZeRO Stage 3. + This is the default behaviour to ensure that the entire module is appropriately initialized + for DeepSpeed. When False we do not explicitly convert the model, which is fine if NO layers + or ALL layers are defined in ``configure_sharded_model``. This is useful for layers such as + ``torch.nn.RNN`` which do internal logic when moving to device. """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -305,6 +312,7 @@ def __init__( self.remote_device = remote_device self.load_full_weights = load_full_weights + self.partition_module = partition_module # default FP16 parameters. 
self.loss_scale = loss_scale @@ -370,6 +378,13 @@ def init_deepspeed(self): precision = self.lightning_module.trainer.accelerator.precision model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) + if self.zero_stage_3 and self.partition_module: + # Ensure the entire model has been moved to the appropriate device + dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 + deepspeed.zero.Init( + module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype + ) + if self.lightning_module.trainer and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) else: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index e9d3e88232a57..75da449afa559 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -733,7 +733,7 @@ def on_train_batch_start( @RunIf(min_gpus=2, deepspeed=True, special=True) -def test_deepspeed_multigpu_test(tmpdir, deepspeed_config): +def test_deepspeed_multigpu_test(tmpdir): """Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3.""" model = ModelParallelBoringModel() trainer = Trainer( @@ -742,6 +742,51 @@ def test_deepspeed_multigpu_test(tmpdir, deepspeed_config): trainer.test(model) +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_multigpu_partial_partition_parameters(tmpdir): + """Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_sharded_model`` + correctly converts all parameters to float16 when ``precision=16`` and runs successfully.""" + + class TestModel(ModelParallelBoringModel): + def __init__(self): + super().__init__() + self.layer_2 = torch.nn.Linear(32, 32) + + def forward(self, x): + x = self.layer_2(x) + return self.layer(x) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 + ) + trainer.fit(model) + + +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_multigpu_test_rnn(tmpdir): + """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when + training with certain layers which will crash with explicit partitioning.""" + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.rnn = torch.nn.GRU(32, 32) + + def on_train_epoch_start(self) -> None: + assert all([x.dtype == torch.float16 for x in self.parameters()]) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + plugins=[DeepSpeedPlugin(stage=3, partition_module=False)], + gpus=2, + fast_dev_run=True, + precision=16, + ) + trainer.fit(model) + + @RunIf(deepspeed=True) @mock.patch("deepspeed.init_distributed", autospec=True) @pytest.mark.parametrize("platform", ["Linux", "Windows"]) From 25f6080fcfeb2c27c88dd4b69f6ad2f8348b0ef1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 14 Sep 2021 11:22:05 +0100 Subject: [PATCH 4/5] Add CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b8340849c9a4..a773d0448b08d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -355,6 +355,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) +- Fix DeepSpeed crash for RNNs ([#9489](https://github.com/PyTorchLightning/pytorch-lightning/pull/9489)) + + ## [1.4.5] - 2021-08-31 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) From 6bd2386f0050a36b72d68db1389ad6aee9f40458 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 14 Sep 2021 15:33:56 +0100 Subject: [PATCH 5/5] Address reviews --- .../plugins/training_type/deepspeed.py | 2 +- tests/plugins/test_deepspeed_plugin.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 7b04632a0f619..422f8c2148f0e 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -124,7 +124,7 @@ def __init__( contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, - partition_module: bool = False, + partition_module: bool = True, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 75da449afa559..6b81870013366 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -1,6 +1,6 @@ +import contextlib import json import os -from contextlib import ExitStack from typing import Any, Dict, Optional from unittest import mock @@ -415,7 +415,7 @@ def test_deepspeed_stage_3_save_warning(tmpdir): context_manager = ( pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory.") if trainer.is_global_zero - else ExitStack() + else contextlib.suppress() ) with context_manager: trainer.save_checkpoint(checkpoint_path) @@ -742,7 +742,7 @@ def test_deepspeed_multigpu_test(tmpdir): trainer.test(model) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_partial_partition_parameters(tmpdir): """Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_sharded_model`` correctly converts all parameters to float16 when ``precision=16`` and runs successfully.""" @@ -752,18 +752,24 @@ def __init__(self): super().__init__() self.layer_2 = torch.nn.Linear(32, 32) + def configure_sharded_model(self) -> None: + self.layer = torch.nn.Linear(32, 2) + def forward(self, x): x = self.layer_2(x) return self.layer(x) + def on_train_epoch_start(self) -> None: + assert all([x.dtype == torch.float16 for x in self.parameters()]) + model = TestModel() trainer = Trainer( - default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16 ) trainer.fit(model) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_test_rnn(tmpdir): """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when training with certain layers which will crash with explicit partitioning.""" @@ -780,7 
+786,7 @@ def on_train_epoch_start(self) -> None: trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3, partition_module=False)], - gpus=2, + gpus=1, fast_dev_run=True, precision=16, )
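
A minimal usage sketch of the new ``partition_module`` flag added in this series: training an RNN-based ``LightningModule`` under ZeRO Stage 3 while leaving partitioning to ``deepspeed.initialize``. The model, dataset, and hyper-parameters below are illustrative assumptions; only ``DeepSpeedPlugin(stage=3, partition_module=False)``, ``precision=16``, and ``gpus`` come from the diffs above.

    import torch
    from torch.utils.data import DataLoader, Dataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin


    class RandomDataset(Dataset):
        """Illustrative random data; stands in for a real dataset."""

        def __init__(self, size: int = 64, length: int = 256):
            self.data = torch.randn(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)


    class GRUModel(LightningModule):
        def __init__(self):
            super().__init__()
            # Defined in __init__ on purpose: RNN layers perform device-specific setup
            # when moved, which is why explicit Stage 3 partitioning is skipped below.
            self.rnn = torch.nn.GRU(64, 64)
            self.head = torch.nn.Linear(64, 2)

        def training_step(self, batch, batch_idx):
            # GRU expects (seq_len, batch, features); treat each batch as one step.
            out, _ = self.rnn(batch.unsqueeze(0))
            return self.head(out).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        trainer = Trainer(
            gpus=1,
            precision=16,
            max_epochs=1,
            # partition_module=False leaves partitioning to deepspeed.initialize,
            # avoiding the crash this series addresses for layers such as torch.nn.RNN/GRU.
            plugins=[DeepSpeedPlugin(stage=3, partition_module=False)],
        )
        trainer.fit(GRUModel(), DataLoader(RandomDataset(), batch_size=8))

With ``partition_module=False`` the module is no longer wrapped in ``deepspeed.zero.Init`` up front; parameters are converted and partitioned when ``deepspeed.initialize`` runs, which keeps layers that perform internal work on ``.to(device)`` (such as ``torch.nn.GRU``) functional under ZeRO Stage 3.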