CHANGELOG.md (3 additions, 0 deletions)
@@ -366,6 +366,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))


- Fixed DeepSpeed crash for RNNs ([#9489](https://github.com/PyTorchLightning/pytorch-lightning/pull/9489))


- Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))


pytorch_lightning/plugins/training_type/deepspeed.py (9 additions, 1 deletion)
@@ -123,6 +123,7 @@ def __init__(
        contiguous_memory_optimization: bool = False,
        synchronize_checkpoint_boundary: bool = False,
        load_full_weights: bool = False,
        partition_module: bool = True,
    ) -> None:
        """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
        billion parameter models. `For more information: https://pytorch-
@@ -252,6 +253,12 @@ def __init__(
            load_full_weights: True when loading a single checkpoint file containing the model state dict
                when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
                per worker.

            partition_module: When True, partitions the ``LightningModule`` across devices when using ZeRO Stage 3.
                This is the default behaviour to ensure that the entire module is appropriately initialized
                for DeepSpeed. When False, we do not explicitly convert the model, which is fine if NO layers
                or ALL layers are defined in ``configure_sharded_model``. This is useful for layers such as
                ``torch.nn.RNN`` which perform internal logic when moved to a device.
        """
        if not _DEEPSPEED_AVAILABLE:
            raise MisconfigurationException(
@@ -304,6 +311,7 @@ def __init__(

        self.remote_device = remote_device
        self.load_full_weights = load_full_weights
        self.partition_module = partition_module

        # default FP16 parameters.
        self.loss_scale = loss_scale
@@ -374,7 +382,7 @@ def init_deepspeed(self):
        precision = self.lightning_module.trainer.accelerator.precision
        model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

        if self.zero_stage_3:
        if self.zero_stage_3 and self.partition_module:
            # Ensure the entire model has been moved to the appropriate device
            dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
            deepspeed.zero.Init(
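Taken together, the new flag simply skips the ``deepspeed.zero.Init`` wrapping shown above when ``partition_module=False``. A minimal, illustrative usage sketch (not part of this PR; it assumes a GPU machine with DeepSpeed installed, and ``model`` stands in for a user-defined ``LightningModule``):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Opt out of automatic ZeRO Stage 3 module partitioning, e.g. for a
# LightningModule that builds a torch.nn.RNN / GRU / LSTM in __init__.
trainer = Trainer(
    gpus=1,
    precision=16,
    plugins=[DeepSpeedPlugin(stage=3, partition_module=False)],
)
# trainer.fit(model)  # `model` is a hypothetical user-defined LightningModule
```

The tests added below exercise exactly this configuration.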
tests/plugins/test_deepspeed_plugin.py (61 additions, 7 deletions)
@@ -1,3 +1,4 @@
import contextlib
import json
import os
from typing import Any, Dict, Optional
@@ -409,13 +410,15 @@ def test_deepspeed_stage_3_save_warning(tmpdir):
    )
    trainer.fit(model)
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    with pytest.warns(UserWarning) as record:
        # both ranks need to call save checkpoint

    # both ranks need to call save checkpoint, however only rank 0 needs to check the warning
    context_manager = (
        pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory.")
        if trainer.is_global_zero
        else contextlib.suppress()
    )
    with context_manager:
        trainer.save_checkpoint(checkpoint_path)
    if trainer.is_global_zero:
        assert len(record) == 1
        match = "each worker will save a shard of the checkpoint within a directory."
        assert match in str(record[0].message)


@RunIf(min_gpus=1, deepspeed=True, special=True)
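As an aside, the rank-conditional warning check above relies on ``contextlib.suppress()`` with no arguments acting as a no-op context manager. A standalone sketch of the idiom (illustrative only; ``save_checkpoint`` and ``run`` are hypothetical stand-ins for the trainer calls in the test):

```python
import contextlib
import warnings

import pytest


def save_checkpoint(is_global_zero: bool) -> None:
    # Stand-in for trainer.save_checkpoint: only rank zero emits the sharding warning.
    if is_global_zero:
        warnings.warn("each worker will save a shard of the checkpoint within a directory.", UserWarning)


def run(is_global_zero: bool) -> None:
    # Every rank executes the call, but only rank zero asserts that the warning was raised.
    context_manager = (
        pytest.warns(UserWarning, match="each worker will save a shard")
        if is_global_zero
        else contextlib.suppress()  # no arguments: suppresses nothing, i.e. a no-op context
    )
    with context_manager:
        save_checkpoint(is_global_zero)


run(is_global_zero=True)   # passes: warning observed by pytest.warns
run(is_global_zero=False)  # passes: no warning expected, suppress() is a no-op
```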
@@ -735,7 +738,7 @@ def on_train_batch_start(


@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_test(tmpdir, deepspeed_config):
def test_deepspeed_multigpu_test(tmpdir):
"""Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3."""
model = ModelParallelBoringModel()
trainer = Trainer(
Expand All @@ -744,6 +747,57 @@ def test_deepspeed_multigpu_test(tmpdir, deepspeed_config):
trainer.test(model)


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_multigpu_partial_partition_parameters(tmpdir):
    """Test to ensure that a module that defines layers in both ``__init__`` and ``configure_sharded_model``
    correctly converts all parameters to float16 when ``precision=16`` and runs successfully."""

    class TestModel(ModelParallelBoringModel):
        def __init__(self):
            super().__init__()
            self.layer_2 = torch.nn.Linear(32, 32)

        def configure_sharded_model(self) -> None:
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            x = self.layer_2(x)
            return self.layer(x)

        def on_train_epoch_start(self) -> None:
            assert all([x.dtype == torch.float16 for x in self.parameters()])

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16
    )
    trainer.fit(model)


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_multigpu_test_rnn(tmpdir):
    """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when
    training with certain layers which will crash with explicit partitioning."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.GRU(32, 32)

        def on_train_epoch_start(self) -> None:
            assert all([x.dtype == torch.float16 for x in self.parameters()])

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        plugins=[DeepSpeedPlugin(stage=3, partition_module=False)],
        gpus=1,
        fast_dev_run=True,
        precision=16,
    )
    trainer.fit(model)


@RunIf(deepspeed=True)
@mock.patch("deepspeed.init_distributed", autospec=True)
@pytest.mark.parametrize("platform", ["Linux", "Windows"])