Error for unsupported precision types with ModelParallelStrategy #19902

Merged: 6 commits, May 23, 2024
7 changes: 7 additions & 0 deletions src/lightning/fabric/connector.py
@@ -62,6 +62,7 @@
)
from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES
from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
@@ -460,6 +461,12 @@ def _check_and_init_precision(self) -> Precision:
            return DeepSpeedPrecision(self._precision_input)  # type: ignore
        if isinstance(self.strategy, FSDPStrategy):
            return FSDPPrecision(precision=self._precision_input)  # type: ignore[arg-type]
        mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
        if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
            raise ValueError(
                f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_input!r})`."
                f" Choose a different precision among: {', '.join(mp_precision_supported)}."
            )
        if self._precision_input in ("16-true", "bf16-true"):
            return HalfPrecision(self._precision_input)  # type: ignore
        if self._precision_input == "32-true":
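For reference, this is roughly how the new Fabric-side guard surfaces to a user. A minimal sketch, not part of the diff: the `parallelize` function below is a hypothetical no-op stand-in for a real TP/FSDP2 parallelization function, and it assumes a PyTorch version recent enough for `ModelParallelStrategy` (the tests in this PR gate on torch >= 2.3).

from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(module, device_mesh):
    # Hypothetical no-op stand-in; a real function would apply tensor parallelism
    # and/or FSDP2 to `module` using `device_mesh` before returning it.
    return module


# Supported precisions pass the new check when the connector is built.
fabric = Fabric(strategy=ModelParallelStrategy(parallelize), precision="bf16-true", devices=2)

# Unsupported precisions now fail fast at construction time, e.g.:
#   ValueError: The `ModelParallelStrategy` does not support `Fabric(..., precision='16-mixed')`.
#   Choose a different precision among: 32-true, bf16-mixed, bf16-true, 16-true.
try:
    Fabric(strategy=ModelParallelStrategy(parallelize), precision="16-mixed", devices=2)
except ValueError as err:
    print(err)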
@@ -529,6 +529,16 @@ def _validate_precision_choice(self) -> None:
            self.accelerator, CUDAAccelerator
        ):
            raise RuntimeError("Bitsandbytes is only supported on CUDA GPUs.")
        mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
        if (
            isinstance(self._strategy_flag, ModelParallelStrategy)
            and self._precision_flag not in mp_precision_supported
        ):
            raise ValueError(
                f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_flag!r})`."
                f" Choose a different precision among: {', '.join(mp_precision_supported)}."
            )

        if _habana_available_and_importable():
            from lightning_habana import HPUAccelerator

@@ -241,9 +241,7 @@ def _train(fabric, model=None, optimizer=None):
@pytest.mark.parametrize(
    "precision",
    [
        pytest.param(
            "16-mixed", marks=pytest.mark.xfail(reason="Precision plugin does not implement ShardedGradScaler yet")
        ),
        pytest.param("32-true"),
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@@ -548,26 +546,17 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
    "precision",
    [
        "32-true",
        pytest.param("16-mixed"),
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@pytest.mark.parametrize(
    "clip_type",
    [
        pytest.param("norm", marks=pytest.mark.skip("Gradient clipping by norm is not correct.")),
        pytest.param(
            "val",
            marks=pytest.mark.xfail(
                raises=RecursionError, strict=False, reason="Recursion error when clipping DTensor"
            ),
        ),
        "val",
    ],
)
def test_clip_gradients(clip_type, precision):
    if clip_type == "norm" and precision == "16-mixed":
        pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")

    strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2)
    fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
    fabric.launch()
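The clip-gradients test above now expects clipping by value to work with DTensor parameters, while clipping by norm remains skipped. As a rough illustration of the Fabric call path the test exercises (a sketch with made-up model, sizes, and data, not the actual test code; the `parallelize` function stands in for `_parallelize_single_linear_tp_fsdp2`):

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(model, device_mesh):
    # Hypothetical stand-in for `_parallelize_single_linear_tp_fsdp2` in the test,
    # which shards a single linear layer with tensor parallelism + FSDP2.
    return model


fabric = Fabric(accelerator="auto", devices=2, precision="32-true",
                strategy=ModelParallelStrategy(parallelize))
fabric.launch()

# Set up the module first, then the optimizer over the (possibly sharded) parameters.
model = fabric.setup_module(torch.nn.Linear(32, 2))
optimizer = fabric.setup_optimizers(torch.optim.SGD(model.parameters(), lr=0.1))

loss = model(torch.randn(4, 32, device=fabric.device)).sum()
fabric.backward(loss)

# The "val" case in the parametrization above: clip gradients by value.
# Clipping by norm (the `max_norm=` argument) is still skipped for this strategy.
fabric.clip_gradients(model, optimizer, clip_val=1.0)
optimizer.step()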
14 changes: 14 additions & 0 deletions tests/tests_fabric/test_connector.py
@@ -14,6 +14,7 @@
import inspect
import os
import sys
from contextlib import nullcontext
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -53,6 +54,7 @@
    DDPStrategy,
    DeepSpeedStrategy,
    FSDPStrategy,
    ModelParallelStrategy,
    SingleDeviceStrategy,
    SingleDeviceXLAStrategy,
    XLAFSDPStrategy,
@@ -866,6 +868,18 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
    assert isinstance(connector.precision, plugin_cls)


@RunIf(min_torch="2.3")
@pytest.mark.parametrize(
    ("precision", "raises"),
    [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
)
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_precision_selection_model_parallel(_, precision, raises):
    error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
    with error_context:
        _Connector(precision=precision, strategy=ModelParallelStrategy(lambda x, _: x))


def test_bitsandbytes_precision_cuda_required(monkeypatch):
    monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
    monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
@@ -14,6 +14,7 @@
import inspect
import os
import sys
from contextlib import nullcontext
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -48,6 +49,7 @@
    DDPStrategy,
    DeepSpeedStrategy,
    FSDPStrategy,
    ModelParallelStrategy,
    SingleDeviceStrategy,
    SingleDeviceXLAStrategy,
    XLAStrategy,
@@ -1063,3 +1065,14 @@ def test_bitsandbytes_precision_cuda_required(monkeypatch):
    monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
    with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
        _AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))


@RunIf(min_torch="2.3")
@pytest.mark.parametrize(
    ("precision", "raises"),
    [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
)
def test_precision_selection_model_parallel(precision, raises, mps_count_0):
    error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
    with error_context:
        _AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
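The same validation applies on the PyTorch Lightning side when the Trainer builds its accelerator connector, so a misconfigured run fails while the Trainer is being constructed. A minimal sketch of the user-facing behavior, assuming the merged change and torch >= 2.3 as the tests require:

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import ModelParallelStrategy

# Supported: "32-true", "bf16-mixed", "bf16-true", "16-true"
trainer = Trainer(strategy=ModelParallelStrategy(), precision="bf16-true", devices=2)

# Unsupported precisions raise the new ValueError during Trainer construction.
try:
    Trainer(strategy=ModelParallelStrategy(), precision="16-mixed", devices=2)
except ValueError as err:
    print(err)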