diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py
new file mode 100644
index 0000000000000..2ecdc6c50e709
--- /dev/null
+++ b/benchmarks/test_sharded_parity.py
@@ -0,0 +1,317 @@
+import os
+import platform
+import time
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
+from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
+from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
+from tests.backends.launcher import DDPLauncher
+from tests.base.boring_model import BoringModel, RandomDataset
+
+
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_correctness_one_device():
+    plugin_parity_test(
+        accelerator='ddp_cpu',
+        max_percent_speed_diff=0.15,  # slower speed due to one CPU doing additional sequential memory saving calls
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_correctness_one_gpu():
+    plugin_parity_test(
+        gpus=1,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
+
+
+@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_correctness_amp_one_gpu():
+    plugin_parity_test(
+        gpus=1,
+        precision=16,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
+
+
+@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_correctness_multi_gpu():
+    plugin_parity_test(
+        gpus=2,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel,
+        max_percent_speed_diff=0.25
+    )
+
+
+@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
+    plugin_parity_test(
+        gpus=2,
+        precision=16,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel,
+        max_percent_speed_diff=0.25
+    )
+
+
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32")
+def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
+    plugin_parity_test(
+        gpus=args.gpus,
+        precision=args.precision,
+        accelerator=args.distributed_backend,
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
+
+
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16")
+def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
+    plugin_parity_test(
+        gpus=args.gpus,
+        precision=args.precision,
+        accelerator=args.distributed_backend,
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
+    """
+    Ensures same results when using multiple optimizers across multiple GPUs
+    """
+    plugin_parity_test(
+        plugin=DDPShardedPlugin(),
+        gpus=2,
+        accelerator='ddp_spawn',
+        model_cls=SeedTrainLoaderMultipleOptimizersModel,
+        max_percent_speed_diff=0.2  # Increase speed diff since only 2 GPUs sharding 2 optimizers
+    )
+
+
+@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
+    """
+    Ensures same results when using multiple optimizers across multiple GPUs with manual optimization
+    """
+    plugin_parity_test(
+        plugin=DDPShardedPlugin(),
+        gpus=2,
+        accelerator='ddp_spawn',
+        model_cls=SeedTrainLoaderManualModel,
+    )
+
+
+class SeedTrainLoaderModel(BoringModel):
+    """
+    Overrides training loader to ensure we enforce the same seed for all DDP processes.
+ """ + + def train_dataloader(self): + seed_everything(42) + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + +class SeedTrainLoaderManualModel(SeedTrainLoaderModel): + def training_step(self, batch, batch_idx, optimizer_idx): + # manual + (opt_a, opt_b) = self.optimizers() + loss_1 = self.step(batch) + + self.manual_backward(loss_1, opt_a) + self.manual_optimizer_step(opt_a) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + self.manual_optimizer_step(opt_b) + + assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + @property + def automatic_optimization(self) -> bool: + return False + + +class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel): + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + +def record_ddp_fit_model_stats(trainer, model, use_cuda): + """ + Helper to calculate wall clock time for fit + max allocated memory. + + Args: + trainer: The trainer object. + model: The model to fit. + use_cuda: Whether to sync CUDA kernels. + + Returns: + Max Memory if using GPUs, and total wall clock time. + """ + max_memory = None + + time_start = time.perf_counter() + if use_cuda: + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + trainer.fit(model) + + if use_cuda: + torch.cuda.synchronize() + max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 + + total_time = time.perf_counter() - time_start + + return max_memory, total_time + + +def plugin_parity_test( + model_cls: SeedTrainLoaderModel, + plugin: DDPPlugin, + seed: int = 42, + accelerator: str = 'ddp_spawn', + gpus: int = 0, + precision: int = 32, + max_percent_speed_diff: float = 0.1): + """ + Ensures that the trained model is identical to the standard DDP implementation. + Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + + Args: + model_cls: Model class to use for test. + plugin: Plugin to parity test. + seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process. + accelerator: Accelerator type for test. + gpus: Number of GPUS to enable. + precision: Whether to use AMP or normal FP32 training. + max_percent_speed_diff: The maximum speed difference compared to normal DDP training. + This is more a safety net for variability in CI which can vary in speed, not for benchmarking. 
+ + """ + + # Train normal DDP + seed_everything(seed) + ddp_model = model_cls() + use_cuda = gpus > 0 + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + ) + + max_memory_ddp, ddp_time = record_ddp_fit_model_stats( + trainer=trainer, + model=ddp_model, + use_cuda=use_cuda + ) + + # Reset and train Custom DDP + seed_everything(seed) + custom_plugin_model = model_cls() + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + plugins=[plugin], + ) + + max_memory_custom, custom_model_time = record_ddp_fit_model_stats( + trainer=trainer, + model=custom_plugin_model, + use_cuda=use_cuda + ) + + # Assert model parameters are identical after fit + for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()): + assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin' + + # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold + percent_diff = (custom_model_time - ddp_time) / custom_model_time + + assert percent_diff <= max_percent_speed_diff, \ + f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}' + + if use_cuda: + # Assert CUDA memory parity + assert max_memory_custom <= max_memory_ddp, \ + f'Custom plugin used too much memory compared to DDP,' \ + f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}' diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py new file mode 100644 index 0000000000000..c1d5769df6446 --- /dev/null +++ b/pytorch_lightning/overrides/fairscale.py @@ -0,0 +1,32 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE + +LightningShardedDataParallel = None +if FAIRSCALE_AVAILABLE: + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel + + class LightningShardedDataParallel(ShardedDataParallel): + + def forward(self, *inputs, **kwargs): + if self.enable_broadcast_buffers: + self.sync_buffers() + + if self.module.training: + outputs = self.module.training_step(*inputs, **kwargs) + elif self.module.testing: + outputs = self.module.test_step(*inputs, **kwargs) + else: + outputs = self.module.validation_step(*inputs, **kwargs) + return outputs diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py new file mode 100644 index 0000000000000..a66b118da4b2b --- /dev/null +++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import cast
+
+from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, FAIRSCALE_AVAILABLE
+from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
+
+if NATIVE_AMP_AVAILABLE and FAIRSCALE_AVAILABLE:
+    from fairscale.optim import OSS
+    from fairscale.optim.grad_scaler import ShardedGradScaler
+
+
+class ShardedNativeAMPPlugin(NativeAMPPlugin):
+    @property
+    def scaler(self):
+        return ShardedGradScaler()
+
+    def clip_gradients(self, grad_clip_val, model, optimizer):
+        max_norm = grad_clip_val
+        norm_type = float(2.0)
+        optimizer = cast(OSS, optimizer)
+        optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py
new file mode 100644
index 0000000000000..84b23c3230059
--- /dev/null
+++ b/pytorch_lightning/plugins/sharded_plugin.py
@@ -0,0 +1,81 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Union
+
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
+from pytorch_lightning.utilities import rank_zero_only, FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+if FAIRSCALE_AVAILABLE:
+    from fairscale.optim import OSS
+    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
+
+
+class DDPShardedPlugin(DDPPlugin):
+
+    def __init__(self, **kwargs):
+        self._check_fairscale()
+        super().__init__(**kwargs)
+
+    def configure_ddp(
+            self, model: LightningModule, device_ids: List[int]
+    ):
+        self._wrap_optimizers(model)
+        return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers)
+
+    def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
+        optimizer.consolidate_state_dict()
+        return self._optim_state_dict(optimizer)
+
+    def on_before_forward(self, model: LightningModule, *args):
+        return model.transfer_batch_to_device(args, model.trainer.root_gpu)
+
+    def _check_fairscale(self):
+        if not FAIRSCALE_AVAILABLE:
+            raise MisconfigurationException(
+                'Sharded DDP Plugin requires Fairscale to be installed.'
+            )
+
+    @rank_zero_only
+    def _optim_state_dict(self, optimizer):
+        return optimizer.state_dict()
+
+    def _wrap_optimizers(self, model):
+        trainer = model.trainer
+        if trainer.testing is True:
+            return
+
+        self._reinit_with_fairscale_oss(trainer)
+
+    def _reinit_with_fairscale_oss(self, trainer):
+        optimizers = trainer.optimizers
+        for x, optimizer in enumerate(optimizers):
+            if not isinstance(optimizer, OSS):
+                optim_class = type(optimizer)
+                zero_optimizer = OSS(
+                    params=optimizer.param_groups,
+                    optim=optim_class,
+                    **optimizer.defaults
+                )
+                optimizers[x] = zero_optimizer
+                del optimizer
+
+    def get_model_from_plugin(
+            self,
+            model: Union['LightningShardedDataParallel', LightningModule]
+    ) -> LightningModule:
+        if isinstance(model, LightningShardedDataParallel):
+            return model.module
+        return model
diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py
index 015627f20dcf4..8866607dc678c 100644
--- a/pytorch_lightning/trainer/connectors/precision_connector.py
+++ b/pytorch_lightning/trainer/connectors/precision_connector.py
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.plugins.apex import ApexPlugin
 from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
+from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
+from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
 from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class PrecisionConnector:
@@ -24,7 +28,7 @@ def __init__(self, trainer):
         self.trainer = trainer
         self.backend = None
 
-    def on_trainer_init(self, precision, amp_level, amp_backend):
+    def on_trainer_init(self, precision, amp_level, amp_backend, plugins):
         # AMP init
         # These are the only lines needed after v0.8.0
         # we wrap the user's forward with autocast and give it back at the end of fit
@@ -33,18 +37,19 @@ def on_trainer_init(self, precision, amp_level, amp_backend):
         self.trainer.scaler = None
 
         self.trainer.amp_level = amp_level
-        self.init_amp(amp_backend)
+        self.init_amp(amp_backend, plugins)
 
-    def init_amp(self, amp_type: str):
+    def init_amp(self, amp_type: str, plugins: Optional[list]):
         assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
         self.trainer.amp_backend = None
-        self._setup_amp_backend(amp_type)
+        self._setup_amp_backend(amp_type, plugins)
 
-    def _setup_amp_backend(self, amp_type: str):
+    def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]):
         if self.trainer.precision != 16:
             # no AMP requested, so we can leave now
             return
+        using_sharded_plugin = self._check_using_sharded_plugin(plugins)
 
         amp_type = amp_type.lower()
         assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
         if amp_type == 'native':
@@ -54,14 +59,22 @@ def _setup_amp_backend(self, amp_type: str):
                               ' We will attempt to use NVIDIA Apex for this session.')
                 amp_type = 'apex'
             else:
-                log.info('Using native 16bit precision.')
                 self.trainer.amp_backend = AMPType.NATIVE
-                self.backend = NativeAMPPlugin(self.trainer)
+                if using_sharded_plugin:
+                    log.info('Using sharded 16bit precision.')
+                    self.backend = ShardedNativeAMPPlugin(self.trainer)
+                else:
+                    log.info('Using native 16bit precision.')
+                    self.backend = NativeAMPPlugin(self.trainer)
 
         if amp_type == 'apex':
             if not APEX_AVAILABLE:
                 rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                                ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
+            elif using_sharded_plugin:
+                raise MisconfigurationException(
+                    'Sharded Plugin is not supported with Apex AMP, please use native AMP for 16-bit precision.'
+                )
             else:
                 log.info('Using APEX 16bit precision.')
                 self.trainer.amp_backend = AMPType.APEX
@@ -79,3 +92,10 @@ def connect(self, model):
             self.trainer.optimizers = optimizers
 
         return model
+
+    def _check_using_sharded_plugin(self, plugins: Optional[list]):
+        if plugins is not None:
+            for plugin in plugins:
+                if isinstance(plugin, DDPShardedPlugin):
+                    return True
+        return False
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 251cf6cf1e78f..1f8a2341b9889 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -394,7 +394,7 @@ def __init__(
         )
 
         # set precision
-        self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)
+        self.precision_connector.on_trainer_init(precision, amp_level, amp_backend, plugins)
 
         # last thing are the plugins which override whatever the trainer used by default
         self.plugin_connector.on_trainer_init(plugins)
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index be229ac1dfea3..916e434e5ff06 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """General utilities"""
 import importlib
+import platform
 from enum import Enum
 
 import numpy
@@ -48,6 +49,7 @@ def _module_available(module_path: str) -> bool:
 
 HYDRA_AVAILABLE = _module_available("hydra")
 TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
 
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
diff --git a/requirements/extra.txt b/requirements/extra.txt
index b5cbde0c15485..76ce46ae7fe6b 100644
--- a/requirements/extra.txt
+++ b/requirements/extra.txt
@@ -9,3 +9,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
 onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
+https://github.com/facebookresearch/fairscale/archive/8e85ce8c93569017521d92ceb78dba2c57c955a0.zip # TODO temporary fix till release version
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
new file mode 100644
index 0000000000000..5010c39de7a80
--- /dev/null
+++ b/tests/plugins/test_sharded_plugin.py
@@ -0,0 +1,334 @@
+import os
+import platform
+from unittest import mock
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
+from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, APEX_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.base.boring_model import BoringModel
+
+
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
+    """
+    Test to ensure that the plugin is correctly chosen
+    """
+
+    class CB(Callback):
+        def on_fit_start(self, trainer, pl_module):
+            assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin)
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=gpus,
+        num_processes=num_processes,
+        distributed_backend=ddp_backend,
+        plugins=[DDPShardedPlugin()],
+        callbacks=[CB()],
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_invalid_apex_sharded(tmpdir):
+    """
+    Test to ensure that we raise an error when we try to use apex and sharded
+    """
+
+    model = BoringModel()
+    with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'):
+        trainer = Trainer(
+            fast_dev_run=True,
+            distributed_backend='ddp_spawn',
+            plugins=[DDPShardedPlugin()],
+            precision=16,
+            amp_backend='apex'
+        )
+
+        trainer.fit(model)
+
+
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes):
+    """
+    Test to ensure that the sharded native AMP plugin is correctly chosen when using sharded with 16-bit precision
+    """
+
+    class CB(Callback):
+        def on_fit_start(self, trainer, pl_module):
+            assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin)
+            assert isinstance(trainer.precision_connector.backend, ShardedNativeAMPPlugin)
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=gpus,
+        precision=16,
+        num_processes=num_processes,
+        distributed_backend=ddp_backend,
+        plugins=[DDPShardedPlugin()],
+        callbacks=[CB()],
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir):
+    """
+    Test to ensure that checkpoint is saved correctly
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_cpu',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, 'model.pt')
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(ddp_param, shard_param)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir):
+    """
+    Test to ensure that checkpoint is saved correctly when using multiple GPUs
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        gpus=2,
+        accelerator='ddp_spawn',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, 'model.pt')
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(ddp_param, shard_param)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_finetune(tmpdir):
+    """
+    Test to ensure that we can save and restart training (simulate fine-tuning)
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        gpus=2,
+        accelerator='ddp_spawn',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, 'model.pt')
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    trainer = Trainer(
+        fast_dev_run=True,
+    )
+    trainer.fit(saved_model)
+
+
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir):
+    """
+    Test to ensure that resuming from checkpoint works
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_cpu',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, 'model.pt')
+    trainer.save_checkpoint(checkpoint_path)
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        accelerator='ddp_cpu',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+        resume_from_checkpoint=checkpoint_path
+    )
+
+    trainer.fit(model)
+
+
+@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
+@pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir):
+    """
+    Test to ensure that resuming from checkpoint works when downsizing the number of GPUs
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_spawn',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+        gpus=2,
+    )
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, 'model.pt')
+    trainer.save_checkpoint(checkpoint_path)
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        accelerator='ddp_spawn',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+        gpus=1,
+        resume_from_checkpoint=checkpoint_path
+    )
+
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir):
+    """
+    Test to ensure that resuming from checkpoint works when going from GPUs -> CPU
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_spawn',
+        plugins=[DDPShardedPlugin()],
+        gpus=1,
+        fast_dev_run=True
+    )
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, 'model.pt')
+    trainer.save_checkpoint(checkpoint_path)
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        plugins=[DDPShardedPlugin()],
+        accelerator='ddp_cpu',
+        fast_dev_run=True,
+        resume_from_checkpoint=checkpoint_path
+    )
+
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_test(tmpdir):
+    """
+    Test to ensure we can use test without fit
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_cpu',
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_ddp_sharded_plugin_test_multigpu(tmpdir):
+    """
+    Test to ensure we can use test without fit on multiple GPUs
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_spawn',
+        gpus=2,
+        plugins=[DDPShardedPlugin()],
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)