Error for unsupported precision types with ModelParallelStrategy (#19902)
awaelchli committed May 23, 2024
1 parent c09356d commit 896c2a6
Showing 5 changed files with 46 additions and 13 deletions.
7 changes: 7 additions & 0 deletions src/lightning/fabric/connector.py
@@ -62,6 +62,7 @@
)
from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES
from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
@@ -460,6 +461,12 @@ def _check_and_init_precision(self) -> Precision:
return DeepSpeedPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type]
mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
raise ValueError(
f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_input!r})`."
f" Choose a different precision among: {', '.join(mp_precision_supported)}."
)
if self._precision_input in ("16-true", "bf16-true"):
return HalfPrecision(self._precision_input) # type: ignore
if self._precision_input == "32-true":
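For reference, a minimal sketch (not part of the diff) of how this new check surfaces to a Fabric user. The identity parallelize function is a placeholder modelled on the tests further down, and ModelParallelStrategy requires torch >= 2.3:

from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy

# The parallelize function receives (module, device_mesh); an identity is enough here.
strategy = ModelParallelStrategy(lambda module, device_mesh: module)

# "16-mixed" is outside the supported set, so the connector now raises immediately:
#   Fabric(strategy=strategy, precision="16-mixed")  # ValueError
# Any of "32-true", "bf16-mixed", "bf16-true", "16-true" passes validation:
fabric = Fabric(strategy=strategy, precision="bf16-true")
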
10 changes: 10 additions & 0 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -529,6 +529,16 @@ def _validate_precision_choice(self) -> None:
self.accelerator, CUDAAccelerator
):
raise RuntimeError("Bitsandbytes is only supported on CUDA GPUs.")
mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
if (
isinstance(self._strategy_flag, ModelParallelStrategy)
and self._precision_flag not in mp_precision_supported
):
raise ValueError(
f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_flag!r})`."
f" Choose a different precision among: {', '.join(mp_precision_supported)}."
)

if _habana_available_and_importable():
from lightning_habana import HPUAccelerator

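The Trainer connector gets the same guard; a minimal sketch of the resulting behavior, assuming the public import path for the strategy and using the no-argument constructor seen in the new Trainer test at the end of this diff:

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import ModelParallelStrategy

# An unsupported precision is rejected while the Trainer is being configured:
#   Trainer(strategy=ModelParallelStrategy(), precision="16-mixed")  # ValueError
# A supported precision passes the new validation:
trainer = Trainer(strategy=ModelParallelStrategy(), precision="bf16-mixed")
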
15 changes: 2 additions & 13 deletions tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -241,9 +241,7 @@ def _train(fabric, model=None, optimizer=None):
@pytest.mark.parametrize(
"precision",
[
pytest.param(
"16-mixed", marks=pytest.mark.xfail(reason="Precision plugin does not implement ShardedGradScaler yet")
),
pytest.param("32-true"),
pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
],
)
@@ -548,26 +546,17 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
"precision",
[
"32-true",
pytest.param("16-mixed"),
pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
],
)
@pytest.mark.parametrize(
"clip_type",
[
pytest.param("norm", marks=pytest.mark.skip("Gradient clipping by norm is not correct.")),
pytest.param(
"val",
marks=pytest.mark.xfail(
raises=RecursionError, strict=False, reason="Recursion error when clipping DTensor"
),
),
"val",
],
)
def test_clip_gradients(clip_type, precision):
if clip_type == "norm" and precision == "16-mixed":
pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")

strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2)
fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
fabric.launch()
14 changes: 14 additions & 0 deletions tests/tests_fabric/test_connector.py
@@ -14,6 +14,7 @@
import inspect
import os
import sys
from contextlib import nullcontext
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -53,6 +54,7 @@
DDPStrategy,
DeepSpeedStrategy,
FSDPStrategy,
ModelParallelStrategy,
SingleDeviceStrategy,
SingleDeviceXLAStrategy,
XLAFSDPStrategy,
@@ -866,6 +868,18 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
assert isinstance(connector.precision, plugin_cls)


@RunIf(min_torch="2.3")
@pytest.mark.parametrize(
("precision", "raises"),
[("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
)
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_precision_selection_model_parallel(_, precision, raises):
error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
with error_context:
_Connector(precision=precision, strategy=ModelParallelStrategy(lambda x, _: x))


def test_bitsandbytes_precision_cuda_required(monkeypatch):
monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
13 changes: 13 additions & 0 deletions tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -14,6 +14,7 @@
import inspect
import os
import sys
from contextlib import nullcontext
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -48,6 +49,7 @@
DDPStrategy,
DeepSpeedStrategy,
FSDPStrategy,
ModelParallelStrategy,
SingleDeviceStrategy,
SingleDeviceXLAStrategy,
XLAStrategy,
@@ -1063,3 +1065,14 @@ def test_bitsandbytes_precision_cuda_required(monkeypatch):
monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
_AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))


@RunIf(min_torch="2.3")
@pytest.mark.parametrize(
("precision", "raises"),
[("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
)
def test_precision_selection_model_parallel(precision, raises, mps_count_0):
error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
with error_context:
_AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
