From 61f9069e640ba15eddc8378555e06a48f20bc440 Mon Sep 17 00:00:00 2001 From: Ankit Gola Date: Wed, 19 Jun 2024 11:09:31 +0530 Subject: [PATCH] Update additional dtype support (#194) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .azure/hpu-tests.yml | 9 +- CHANGELOG.md | 1 + docs/source/intermediate.rst | 5 +- .../pytorch/plugins/deepspeed_precision.py | 19 +- .../pytorch/plugins/precision.py | 72 ++-- .../pytorch/strategies/deepspeed.py | 4 +- .../pytorch/strategies/single.py | 4 +- src/lightning_habana/utils/resources.py | 10 + tests/test_pytorch/test_deepspeed.py | 15 +- tests/test_pytorch/test_dynamic_shapes.py | 6 +- tests/test_pytorch/test_precision.py | 339 +++++++++++++----- 11 files changed, 351 insertions(+), 133 deletions(-) diff --git a/.azure/hpu-tests.yml b/.azure/hpu-tests.yml index 367385fb..0c0e618b 100644 --- a/.azure/hpu-tests.yml +++ b/.azure/hpu-tests.yml @@ -120,8 +120,8 @@ jobs: displayName: 'HPU General tests' - bash: | - python -m pytest -sv tests/test_pytorch/test_compile.py \ - --hpus 1 --junitxml=hpu_compile_test-results.xml + python -m pytest -sv tests/test_pytorch/test_compile.py \ + --hpus 1 --junitxml=hpu_compile_test-results.xml env: PT_HPU_LAZY_MODE: 0 displayName: 'HPU torch compile tests' @@ -142,7 +142,8 @@ jobs: - bash: | bash tests/run_standalone_tests.sh --hpus 1 -m standalone_only -f \ - tests/test_pytorch/test_precision.py + tests/test_pytorch/test_precision.py \ + tests/test_pytorch/test_dynamic_shapes.py displayName: Standalone-only single card tests - bash: | @@ -150,7 +151,7 @@ jobs: bash tests/run_standalone_tests.sh --hpus 2 -f \ tests/test_pytorch/test_accelerator.py \ tests/test_pytorch/test_compile.py \ - tests/test_pytorch/test_profiler.py \ + tests/test_pytorch/test_profiler.py displayName: 'Multi card(2) HPU test' - bash: pip install ".[examples]" diff --git a/CHANGELOG.md b/CHANGELOG.md index cfa1d336..22e67139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added support for additional dtypes ([#194](https://github.com/Lightning-AI/lightning-Habana/pull/194)) - ### Changed diff --git a/docs/source/intermediate.rst b/docs/source/intermediate.rst index 13eb634c..c5999e25 100644 --- a/docs/source/intermediate.rst +++ b/docs/source/intermediate.rst @@ -33,6 +33,7 @@ HPUPrecisionPlugin, :class:`~lightning_habana.pytorch.plugins.precision.HPUPreci In addition to the default settings, you can choose to override these defaults and provide your own BF16 (LOWER_LIST) and FP32 (FP32_LIST) The `LOWER_LIST` and `FP32_LIST` environment variables must be set before any instances begin. +HPUPrecisionPlugin supports `bf16-mixed` and `16-mixed` for mixed precision training. It is advised to use `bf16-mixed` over `16-mixed` where possible. The following is an excerpt from an MNIST example implemented on a single HPU. @@ -156,7 +157,7 @@ The plugin accepts following args for the fp8 training: 1. Import `transformer_engine` and replace your modules with `transformer_engine` modules in the model. 2. Wrap the forward pass of the training with `fp8_autocast`. - Users may still use `HPUPrecisionPlugin` to train in `bf16-mixed` precision for modules not supported by `transformer_engine`. + Users may still use `HPUPrecisionPlugin` to train in mixed precision for modules not supported by `transformer_engine`. .. note:: @@ -254,9 +255,9 @@ Refer to `Supported JSON Config File Options `__. 
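Selecting one of the supported precision inputs is a one-line change when constructing the plugin. A minimal sketch (illustrative only; assumes a standard ``LightningModule`` instance named ``model`` and the import paths of the current release):

.. code-block:: python

    from lightning.pytorch import Trainer  # or: from pytorch_lightning import Trainer

    from lightning_habana.pytorch.accelerator import HPUAccelerator
    from lightning_habana.pytorch.plugins import HPUPrecisionPlugin
    from lightning_habana.pytorch.strategies import SingleHPUStrategy

    # "16-mixed" (fp16) requires Gaudi2 or newer; "bf16-mixed" also runs on first-gen Gaudi.
    plugin = HPUPrecisionPlugin(precision="16-mixed")

    trainer = Trainer(
        accelerator=HPUAccelerator(),
        devices=1,
        strategy=SingleHPUStrategy(),
        plugins=plugin,
        fast_dev_run=True,
    )
    trainer.fit(model)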
+For a list of data types supported with HPU, refer to `PyTorch Support Matrix `__. ---- diff --git a/src/lightning_habana/pytorch/plugins/deepspeed_precision.py b/src/lightning_habana/pytorch/plugins/deepspeed_precision.py index e76d142e..49e84089 100644 --- a/src/lightning_habana/pytorch/plugins/deepspeed_precision.py +++ b/src/lightning_habana/pytorch/plugins/deepspeed_precision.py @@ -58,17 +58,24 @@ class HPUDeepSpeedPrecisionPlugin(HPUPrecisionPlugin): - """Plugin that enables bfloat support on HPUs. + """Plugin that enables mixed precision support on HPUs. Args: - precision: to enable ``torch.bfloat16`` (``'bf16-mixed'``). - device: The device for ``torch.autocast``. + precision (_PRECISION_INPUT, optional): Precision input. Defaults to "32-true". + recipe (Optional[Union[Mapping[str, Any], "DelayedScaling"]], optional): + recipe for fp8 training. Defaults to None. + replace_layers (bool, optional): Replace module with transformer engine equivalent. Defaults to False. + + Raises: + OSError: Unsupported Synapse version. + ValueError: Invalid precision value(s). + NotImplementedError: fp8 not available. """ def __init__( self, - precision: _PRECISION_INPUT, + precision: _PRECISION_INPUT = "32-true", device: str = "hpu", recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None, replace_layers: bool = False, @@ -78,7 +85,7 @@ def __init__( "To use the `HPUDeepSpeedPrecisionPlugin`, you must have hpu DeepSpeed installed." " Install it by running `pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0`." ) - super().__init__(device=device, precision=precision, recipe=recipe, replace_layers=replace_layers) + super().__init__(precision=precision, recipe=recipe, replace_layers=replace_layers) def backward( self, @@ -153,7 +160,7 @@ def _enable_fp8_inference( ds_inference_kwargs["dtype"] = torch.bfloat16 assert ds_inference_kwargs["dtype"] in (torch.bfloat16, torch.float) - htcore.hpu_set_env() + htcore.quantization.hpu_set_inference_env() module = module.to("hpu") module = deepspeed.init_inference(module, **ds_inference_kwargs) diff --git a/src/lightning_habana/pytorch/plugins/precision.py b/src/lightning_habana/pytorch/plugins/precision.py index 6b78fe6f..3cc929e2 100644 --- a/src/lightning_habana/pytorch/plugins/precision.py +++ b/src/lightning_habana/pytorch/plugins/precision.py @@ -25,6 +25,7 @@ _HABANA_FRAMEWORK_AVAILABLE, _HABANA_QUANTIZATION_TOOLKIT_AVAILABLE, is_fp8_available, + is_fp16_available, modify_fp8_json, ) @@ -37,7 +38,15 @@ else: raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.") -_PRECISION_INPUT = Literal["32", "32-true", "bf16", "bf16-mixed", "fp8"] +_PRECISION_INPUT = Literal["32", "32-true", "bf16", "bf16-mixed", "fp8", "16-mixed"] + +_AMP_DICT = { + "32": torch.float32, + "32-true": torch.float32, + "bf16": torch.bfloat16, + "bf16-mixed": torch.bfloat16, + "16-mixed": torch.float16, +} if _HPU_SYNAPSE_GREATER_EQUAL_1_14_0 and _HABANA_FRAMEWORK_AVAILABLE: # Required for training in fp8 using habana transformer engine @@ -61,14 +70,21 @@ class HPUPrecisionPlugin(Precision): """Plugin that enables mixed precision support on HPUs. Args: - precision: to enable ``torch.bfloat16`` (``'bf16-mixed'``). - device: The device for ``torch.autocast``. + precision (_PRECISION_INPUT, optional): Precision input. Defaults to "32-true". + recipe (Optional[Union[Mapping[str, Any], "DelayedScaling"]], optional): + recipe for fp8 training. Defaults to None. 
+ replace_layers (bool, optional): Replace module with transformer engine equivalent. Defaults to False. + + Raises: + OSError: Unsupported Synapse version. + ValueError: Invalid precision value(s). + NotImplementedError: fp8 not available. """ def __init__( self, - precision: _PRECISION_INPUT, + precision: _PRECISION_INPUT = "32-true", device: str = "hpu", recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None, replace_layers: bool = False, @@ -85,13 +101,18 @@ def __init__( self.replace_layers = False self.device = device - if any([recipe, replace_layers]) and precision != "fp8": + if any([recipe, replace_layers]) and self.precision != "fp8": rank_zero_warn(f"Precision is not 'fp8'. Params {recipe=} and {replace_layers=} will not be set.") self.recipe = None self.fp8_train_available = False self.fp8_inference_available = False + if self.precision == "16-mixed": + fp16_available, reason_no_fp16 = is_fp16_available() + if not fp16_available: + raise NotImplementedError(f"fp16 not supported: {reason_no_fp16}.") + if self.precision == "fp8": fp8_available, reason_no_fp8 = is_fp8_available() if not fp8_available: @@ -102,7 +123,7 @@ def __init__( self.replace_layers = replace_layers rank_zero_info( - f"fp8 training available: {self.fp8_train_available}. " + f"fp8 training available: {self.fp8_train_available}." f"fp8 inference available: {self.fp8_inference_available}." ) @@ -119,7 +140,6 @@ def _setup_fp8_quant_config(self, quant: bool = True, fp8_data_path: Optional[st file_path=fp8_json, patch={ "dump_stats_path": os.path.join(fp8_data_path, "hqt"), - "dump_stats_xlsx_path": os.path.join(fp8_data_path, "hqt", "fp8stats.xlsx"), }, ) os.environ["QUANT_CONFIG"] = fp8_json @@ -177,18 +197,31 @@ def convert_modules( quant: bool = True, fp8_data_path: Optional[str] = None, ) -> torch.nn.Module: - """Enable support for fp8.""" - if inference is True and self.fp8_inference_available: + """Convert modules for FP8 precision. + + Args: + module (torch.nn.Module): Module to convert + inference (bool, optional): Convert module for inference (True) / Training (False). Defaults to False. + quant (bool, optional): Convert module for measurement (False) / Quantization (True) during inference. + Defaults to True. + fp8_data_path (Optional[str], optional): Path to dump fp8 inference measurement data. Defaults to None. + + Returns: + torch.nn.Module: FP8 enabled module + + """ + assert self.precision == "fp8", "HPUPrecisionPlugin.convert_modules() should only be used with precision=`fp8`." 
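+        # Typical flow (illustrative sketch; ``measurement_path`` is a placeholder):
+        #     plugin = HPUPrecisionPlugin(precision="fp8")
+        #     module = plugin.convert_modules(module, inference=True, quant=False, fp8_data_path=measurement_path)
+        # Run a measurement pass, then call convert_modules again with quant=True for quantized inference.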
+ if inference and self.fp8_inference_available: self._enable_fp8_inference(module, quant, fp8_data_path) - if self.fp8_train_available is True and self.replace_layers is True and inference is False: + if self.fp8_train_available and self.replace_layers and not inference: self._enable_fp8_training(module) return module def autocast_context_manager(self) -> Union[ContextManager[Any], torch.autocast]: """Return Autocast context manager.""" if self.fp8_train_available: - return _nested_precision_cm(fp8_enabled=(self.precision == "fp8"), recipe=self.recipe) - return torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True) + return tengine.fp8_autocast(enabled=True, fp8_recipe=self.recipe) + return torch.autocast(device_type="hpu", dtype=_AMP_DICT[self.precision], enabled=True) @contextmanager def forward_context(self) -> Generator[None, None, None]: @@ -215,18 +248,3 @@ def _replace_layers(module: torch.nn.Module) -> None: module.__setattr__(name, replacement) else: _replace_layers(child) - - -@contextmanager -def _nested_precision_cm( - fp8_enabled: bool, recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] -) -> Generator[Any, Any, Any]: - """CM to nest fp8 precision with torch.autocast. - - This enables the ops that do not support fp8 to run with torch autocast. - - """ - with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True), tengine.fp8_autocast( - enabled=fp8_enabled, fp8_recipe=recipe - ): - yield diff --git a/src/lightning_habana/pytorch/strategies/deepspeed.py b/src/lightning_habana/pytorch/strategies/deepspeed.py index dba13bbb..2bc27c49 100644 --- a/src/lightning_habana/pytorch/strategies/deepspeed.py +++ b/src/lightning_habana/pytorch/strategies/deepspeed.py @@ -933,7 +933,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: def on_test_end(self) -> None: if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available: - from quantization_toolkit import habana_quantization_toolkit # noqa + import habana_quantization_toolkit habana_quantization_toolkit.finish_measurements(self.model) htcore.quantization.hpu_teardown_inference_env() @@ -941,7 +941,7 @@ def on_test_end(self) -> None: def on_predict_end(self) -> None: if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available: - from quantization_toolkit import habana_quantization_toolkit # noqa + import habana_quantization_toolkit habana_quantization_toolkit.finish_measurements(self.model) htcore.quantization.hpu_teardown_inference_env() diff --git a/src/lightning_habana/pytorch/strategies/single.py b/src/lightning_habana/pytorch/strategies/single.py index d12d2b79..d7dbf8c1 100644 --- a/src/lightning_habana/pytorch/strategies/single.py +++ b/src/lightning_habana/pytorch/strategies/single.py @@ -126,7 +126,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Any: def on_test_end(self) -> None: if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available: - from quantization_toolkit import habana_quantization_toolkit # noqa + import habana_quantization_toolkit habana_quantization_toolkit.finish_measurements(self.model) htcore.quantization.hpu_teardown_inference_env() @@ -134,7 +134,7 @@ def on_test_end(self) -> None: def on_predict_end(self) -> None: if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available: - from quantization_toolkit import habana_quantization_toolkit # noqa + import habana_quantization_toolkit 
habana_quantization_toolkit.finish_measurements(self.model) htcore.quantization.hpu_teardown_inference_env() diff --git a/src/lightning_habana/utils/resources.py b/src/lightning_habana/utils/resources.py index 51c0afa0..966580bd 100644 --- a/src/lightning_habana/utils/resources.py +++ b/src/lightning_habana/utils/resources.py @@ -142,6 +142,16 @@ def is_fp8_available() -> Tuple[bool, str]: return tengine.fp8.is_fp8_available() +@lru_cache +def is_fp16_available() -> Tuple[bool, str]: + """Returns a bool indicating if fp16 is available.""" + if not _HABANA_FRAMEWORK_AVAILABLE: + raise OSError("Habana Frameworks required for training on Habana devices.") + if torch_hpu.get_device_name() == "GAUDI": + return False, "FP16 not supported on Gaudi, Gaudi2 or higher required." + return True, "" + + def modify_fp8_json(file_path: str, patch: dict) -> None: """Edit a specific entry in a JSON file. diff --git a/tests/test_pytorch/test_deepspeed.py b/tests/test_pytorch/test_deepspeed.py index a6d2e53f..ba7885c2 100644 --- a/tests/test_pytorch/test_deepspeed.py +++ b/tests/test_pytorch/test_deepspeed.py @@ -825,8 +825,19 @@ def run_training(tmpdir, model, plugin, strategy): return trainer.callback_metrics["val_loss"], trainer.callback_metrics["train_loss"] precision_plugin_params_list = [ - ({"device": "hpu", "precision": "bf16-mixed"}), - ({"device": "hpu", "precision": "fp8", "replace_layers": True, "recipe": recipe.DelayedScaling()}), + ({"precision": "bf16-mixed"}), + pytest.param( + {"precision": "16-mixed"}, + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp16 supported on Gaudi2 and above" + ), + ), + pytest.param( + {"precision": "fp8", "replace_layers": True, "recipe": recipe.DelayedScaling()}, + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above" + ), + ), ] loss_list = [] diff --git a/tests/test_pytorch/test_dynamic_shapes.py b/tests/test_pytorch/test_dynamic_shapes.py index d5118264..f6b91987 100644 --- a/tests/test_pytorch/test_dynamic_shapes.py +++ b/tests/test_pytorch/test_dynamic_shapes.py @@ -15,6 +15,7 @@ import csv import os +import pytest import torch from habana_frameworks.torch.hpu.metrics import metric_global from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model @@ -78,6 +79,7 @@ def test_dynamic_shapes_graph_compiler(tmpdir, hpus, monkeypatch): assert cached_compiles[0] <= default_compiles[0] +@pytest.mark.standalone_only() def test_dynamic_shapes_auto_detect_recompilations(tmpdir): """Test auto_detect_recompilations tool.""" @@ -95,10 +97,6 @@ def calculate_auto_detect_total_recompile_counts(csv_file_path): print(f"Error: CSV file not found: {csv_file_path}") return None - # Close dist pg if initialized. 
- if torch.distributed.is_available() and torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - seed_everything(42) model = DynamicOpsBoringModel() net = detect_recompilation_auto_model(model, csv_out=os.path.join(tmpdir, "out.csv")) diff --git a/tests/test_pytorch/test_precision.py b/tests/test_pytorch/test_precision.py index d21e6800..018b618f 100644 --- a/tests/test_pytorch/test_precision.py +++ b/tests/test_pytorch/test_precision.py @@ -15,6 +15,7 @@ import importlib import json import os +import re from contextlib import nullcontext import habana_frameworks.torch.hpex.experimental.transformer_engine as tengine @@ -32,7 +33,6 @@ from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel from pytorch_lightning.plugins import MixedPrecision -import re from lightning_habana.pytorch.accelerator import HPUAccelerator from lightning_habana.pytorch.plugins import HPUPrecisionPlugin @@ -116,12 +116,11 @@ class BMPluginActive(BaseBM): def forward(self, x): """Forward.""" if self.trainer.precision == "fp8": - # Tests fp8 is enabled for supported modules. assert tengine.fp8.is_fp8_enabled() + assert not torch.hpu.is_autocast_hpu_enabled() else: assert not tengine.fp8.is_fp8_enabled() - # Test bf16 enabled. - assert torch.hpu.is_autocast_hpu_enabled() + assert torch.hpu.is_autocast_hpu_enabled() return super().forward(x) @@ -193,7 +192,7 @@ def test_hpu_precision_fp8_synapse_version(monkeypatch): monkeypatch.setattr(lightning_habana.utils.imports, "_HPU_SYNAPSE_GREATER_EQUAL_1_14_0", False) with pytest.raises(OSError, match="fp8 training requires `Synapse AI release >= 1.14.0`."): - HPUPrecisionPlugin(device="hpu", precision="fp8") + HPUPrecisionPlugin(precision="fp8") @pytest.mark.skipif(HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above.") @@ -201,7 +200,7 @@ def test_hpu_precision_fp8_synapse_version(monkeypatch): def test_hpu_precision_replace_layerse(replace_layers): """Tests plugin init with replcae_layers.""" model = BaseBM() - plugin = HPUPrecisionPlugin(device="hpu", precision="fp8", replace_layers=replace_layers) + plugin = HPUPrecisionPlugin(precision="fp8", replace_layers=replace_layers) plugin.convert_modules(model) assert replace_layers == any( "habana_frameworks.torch.hpex.experimental.transformer_engine" in m.__module__ for m in model.modules() @@ -222,7 +221,7 @@ def test_hpu_precision_replace_layerse(replace_layers): def test_hpu_precision_convert_modules(inference, quant, expectation, tmpdir): """Test HPUPrecisionPlugin.convert_modules.""" model = BaseBM() - plugin = HPUPrecisionPlugin(device="hpu", precision="fp8") + plugin = HPUPrecisionPlugin(precision="fp8") with expectation: plugin.convert_modules(module=model, inference=inference, quant=quant, fp8_data_path=tmpdir) @@ -233,7 +232,7 @@ def test_hpu_precision_convert_modules(inference, quant, expectation, tmpdir): def test_hpu_precision_fp8_patch(patch_path, tmpdir): """Tests fp8 jsons are patched correctly.""" model = BaseBM() - plugin = HPUPrecisionPlugin(device="hpu", precision="fp8") + plugin = HPUPrecisionPlugin(precision="fp8") patch_path = patch_path if patch_path is None else tmpdir plugin.convert_modules(module=model, inference=True, quant=False, fp8_data_path=patch_path) @@ -277,7 +276,7 @@ def get_fp8_measurement_files(path): seed_everything(42) model = BaseBM() - plugin = HPUPrecisionPlugin(device="hpu", precision="fp8") + plugin = HPUPrecisionPlugin(precision="fp8") plugin.convert_modules(module=model, inference=True, 
quant=False) trainer = Trainer( @@ -305,7 +304,7 @@ def test_hpu_precision_fp8_inference_quantization(tmpdir): for precision in ["bf16", "fp8"]: seed_everything(42) model = BaseBM() - plugin = HPUPrecisionPlugin(device="hpu", precision=precision) + plugin = HPUPrecisionPlugin(precision=precision) if precision == "fp8": plugin.convert_modules(module=model, inference=True, quant=True) @@ -331,7 +330,7 @@ def test_hpu_precision_fp8_with_ddp_strategy(tmpdir, hpus): """Negative test for fp8 inference not supported with HPUDDPStrategy.""" model = BoringModel() dm = BoringDataModule() - plugin = HPUPrecisionPlugin(device="hpu", precision="fp8") + plugin = HPUPrecisionPlugin(precision="fp8") plugin.convert_modules(module=model, inference=True, quant=False) trainer = Trainer( @@ -351,28 +350,44 @@ def test_hpu_precision_fp8_output(tmpdir): """Test HPUPrecisionPlugin with module containing both bf16 and fp8 operations.""" class FP8InOutDtype(BaseBM): + def __init__(self): + super().__init__() + self.linear = tengine.Linear(32, 2) + def forward(self, x): - # for a module that supports fp8, - # input is downcasted internally to bf16 - # output is in bf16 x = self.layer(x) - assert x.dtype == torch.bfloat16 + assert x.dtype == torch.float32 return x - plugin = HPUPrecisionPlugin(device="hpu", precision="fp8") + plugin = HPUPrecisionPlugin(precision="fp8", replace_layers=True) model = FP8InOutDtype() - model = plugin.convert_modules(model) + plugin.convert_modules(model, inference=False) run_training(tmpdir, model, plugin) @pytest.mark.skipif(HPUAccelerator.get_device_name() != "GAUDI", reason="Negative test for fp8 on Gaudi") -def test_hpu_precision_fp8_on_gaudi(): +@pytest.mark.parametrize( + ("precision", "expectation"), + [ + ( + "fp8", + pytest.raises( + NotImplementedError, match="fp8 not supported: FP8 not supported on Gaudi, Gaudi2 or higher required." + ), + ), + ( + "16-mixed", + pytest.raises( + NotImplementedError, match="fp16 not supported: FP16 not supported on Gaudi, Gaudi2 or higher required." + ), + ), + ], +) +def test_hpu_precision_not_supported_on_gaudi(precision, expectation): """Test fp8 with unsupported Habana device.""" - with pytest.raises( - NotImplementedError, match="fp8 not supported: FP8 not supported on Gaudi, Gaudi2 or higher required." 
- ): - HPUPrecisionPlugin(device="hpu", precision="fp8") + with expectation: + HPUPrecisionPlugin(precision=precision) def test_hpu_precision_synapse_version(monkeypatch): @@ -381,52 +396,78 @@ def test_hpu_precision_synapse_version(monkeypatch): monkeypatch.setattr(lightning_habana.pytorch.plugins.precision, "_HPU_SYNAPSE_GREATER_EQUAL_1_11_0", False) with pytest.raises(OSError, match="HPU precision plugin requires `Synapse AI release >= 1.11.0`."): - HPUPrecisionPlugin(device="hpu", precision="bf16-mixed") + HPUPrecisionPlugin(precision="bf16-mixed") @pytest.mark.parametrize( ("plugin", "params"), [ - (MixedPrecision, {"device": "hpu", "precision": "bf16-mixed"}), - (HPUPrecisionPlugin, {"device": "hpu", "precision": "bf16-mixed"}), - (HPUPrecisionPlugin, {"device": "hpu", "precision": "bf16"}), - (HPUPrecisionPlugin, {"device": "hpu", "precision": "32-true"}), - (HPUPrecisionPlugin, {"device": "hpu", "precision": "32"}), + ( + MixedPrecision, + {"device": "hpu", "precision": "bf16-mixed"}, + ), + ( + HPUPrecisionPlugin, + {}, + ), + ( + HPUPrecisionPlugin, + {"precision": "bf16-mixed"}, + ), ( HPUPrecisionPlugin, - {"device": "hpu", "precision": "bf16-mixed", "replace_layers": "True", "recipe": "DelayedScaling"}, + {"precision": "bf16"}, ), pytest.param( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8"}, + {"precision": "16-mixed"}, + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp16 supported on Gaudi2 and above" + ), + ), + ( + HPUPrecisionPlugin, + {"precision": "32-true"}, + ), + ( + HPUPrecisionPlugin, + {"precision": "32"}, + ), + ( + HPUPrecisionPlugin, + {"precision": "bf16-mixed", "replace_layers": "True", "recipe": "DelayedScaling"}, + ), + pytest.param( + HPUPrecisionPlugin, + {"precision": "fp8"}, marks=pytest.mark.skipif( HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." ), ), pytest.param( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8", "replace_layers": "False"}, + {"precision": "fp8", "replace_layers": "False"}, marks=pytest.mark.skipif( HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." ), ), pytest.param( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8", "replace_layers": "True"}, + {"precision": "fp8", "replace_layers": "True"}, marks=pytest.mark.skipif( HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." ), ), pytest.param( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8", "recipe": "DelayedScaling"}, + {"precision": "fp8", "recipe": "DelayedScaling"}, marks=pytest.mark.skipif( HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." ), ), pytest.param( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8", "replace_layers": "True", "recipe": "DelayedScaling"}, + {"precision": "fp8", "replace_layers": "True", "recipe": "DelayedScaling"}, marks=pytest.mark.skipif( HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." 
), @@ -439,7 +480,7 @@ def test_precision_plugin_init(plugin, params): # Common params assert _plugin.device == "hpu" - assert _plugin.precision == params.get("precision") + assert _plugin.precision == params.get("precision", "32-true") # HPUPrecision specific params if isinstance(_plugin, HPUPrecisionPlugin): @@ -453,60 +494,66 @@ def test_precision_plugin_init(plugin, params): assert _plugin.recipe is None +def test_precision_plugin_invalid_precision_init(): + """Tests precision plugins are instantiated correctly.""" + with pytest.raises( + ValueError, + match=re.escape( + "`Trainer(accelerator='hpu', precision='f16-mixed')` is not supported. " + f"`precision` must be one of: {supported_precision}." + ), + ): + HPUPrecisionPlugin(precision="f16-mixed") + + @pytest.mark.parametrize( - ("precision", "expectation"), + ("precision"), [ - ("32", nullcontext()), - ("32-true", nullcontext()), - ("bf16", nullcontext()), - ("bf16-mixed", nullcontext()), + "32", + "32-true", + "bf16", + "bf16-mixed", pytest.param( - "fp8", - nullcontext(), + "fp16", marks=pytest.mark.skipif( - HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." + HPUAccelerator.get_device_name() == "GAUDI", reason="fp16 supported on Gaudi2 and above." ), ), - ( - "fp16", - pytest.raises( - ValueError, - match=re.escape( - f"`Trainer(accelerator='hpu', precision='fp16')` is not supported. " - f"`precision` must be one of: {supported_precision}." - ), + pytest.param( + "fp8", + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." ), ), ], ) -def test_hpu_precision_supported_precision(precision, expectation): +def test_hpu_precision_supported_precision(precision): """Tests supported precisions with HPU Precision Plugin.""" - with expectation: - HPUPrecisionPlugin(device="hpu", precision=precision) + with nullcontext(): + HPUPrecisionPlugin(precision=precision) @pytest.mark.parametrize( ("plugin", "params"), [ - (MixedPrecision, {"device": "hpu", "precision": "bf16-mixed"}), - (HPUPrecisionPlugin, {"device": "hpu", "precision": "bf16-mixed"}), - pytest.param( + ( + MixedPrecision, + {"device": "hpu", "precision": "bf16-mixed"}, + ), + ( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8"}, - marks=pytest.mark.skipif( - HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." - ), + {"precision": "bf16-mixed"}, ), pytest.param( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8", "replace_layers": "False"}, + {"precision": "fp16"}, marks=pytest.mark.skipif( - HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." + HPUAccelerator.get_device_name() == "GAUDI", reason="fp16 supported on Gaudi2 and above." ), ), pytest.param( HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8", "replace_layers": "True", "recipe": "DelayedScaling"}, + {"precision": "fp8"}, marks=pytest.mark.skipif( HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." 
), @@ -518,7 +565,7 @@ def test_precision_plugin_fit(tmpdir, plugin, params): class TestCallback(Callback): def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: - assert trainer.precision == params.get("precision") + assert trainer.precision == params.get("precision", "32-true") raise SystemExit seed_everything(42) @@ -537,11 +584,19 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non (BMAutocastCM, None, None), (BMAutocastDecorator, None, None), (BMPluginActive, MixedPrecision, {"device": "hpu", "precision": "bf16-mixed"}), - (BMPluginActive, HPUPrecisionPlugin, {"device": "hpu", "precision": "bf16-mixed"}), + (BMPluginActive, HPUPrecisionPlugin, {"precision": "bf16-mixed"}), + pytest.param( + BMPluginActive, + HPUPrecisionPlugin, + {"precision": "16-mixed"}, + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." + ), + ), pytest.param( BMPluginActive, HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8"}, + {"precision": "fp8"}, marks=pytest.mark.skipif( HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above." ), @@ -552,6 +607,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non "TorchAutocast_Decorator", "MixedPrecision", "HPUPrecisionPlugin_bf16", + "HPUPrecisionPlugin_fp16", "HPUPrecisionPlugin_fp8", ], ) @@ -568,42 +624,41 @@ def test_mixed_precision_autocast_to_precision_active(tmpdir, model, plugin, par def test_mixed_precision_compare_accuracy(tmpdir): """Test and compare accuracy for mixed precision training methods.""" model_plugin_list = [ + (BaseBM, None, None), # float32 baseline (BMAutocastCM, None, None), (BMAutocastDecorator, None, None), (BaseBM, MixedPrecision, {"device": "hpu", "precision": "bf16-mixed"}), - (BaseBM, HPUPrecisionPlugin, {"device": "hpu", "precision": "bf16-mixed"}), + (BaseBM, HPUPrecisionPlugin, {"precision": "bf16-mixed"}), ] is_gaudi = HPUAccelerator().get_device_name() == "GAUDI" if not is_gaudi: model_plugin_list.append( + (BaseBM, HPUPrecisionPlugin, {"precision": "16-mixed"}), ( BaseBM, HPUPrecisionPlugin, - {"device": "hpu", "precision": "fp8", "replace_layers": "True", "recipe": "DelayedScaling"}, - ) + { + "precision": "fp8", + "replace_layers": True, + }, + ), ) loss_list = [] for item in model_plugin_list: seed_everything(42) model, plugin, params = item + model = model() _plugin = plugin(**params) if plugin and params else None - BoringDataModule() if isinstance(_plugin, HPUPrecisionPlugin) and params.get("precision") == "fp8": - model = _plugin.convert_modules(model()) - else: - model = model() - loss_list.append(run_training(tmpdir, model, _plugin)) + model = _plugin.convert_modules(model) + loss_list.append(torch.tensor(run_training(tmpdir, model, _plugin))) - # Assert loss is same for all instances except fp8 - assert all(x == loss_list[0] for x in loss_list[:-1]), list(zip(model_plugin_list, loss_list)) - if not is_gaudi: - # Assert loss is close between baseline and fp8 - assert torch.allclose(torch.tensor(loss_list[0]), torch.tensor(loss_list[-1]), rtol=0.1, atol=0.1) + assert all(torch.allclose(loss_list[0], loss_tensor, rtol=1e-2, atol=1e-2) for loss_tensor in loss_list[1:]) @pytest.mark.skipif(HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above.") -@pytest.mark.parametrize("precision", ["bf16-mixed", "fp8"]) +@pytest.mark.parametrize("precision", ["32-true", "fp8"]) def 
test_hpu_precision_active_with_te_module(tmpdir, precision): """Tests that fp8 precision is only active when HPUPrecision plugin is init with fp8, even if module from. @@ -621,8 +676,6 @@ def __init__(self): def training_step(self, batch, batch_idx): """Training step.""" - # torch.autocast is enabled for both bf16 and fp8 - assert torch.hpu.is_autocast_hpu_enabled() # fp8 training is only enabled when precision is fp8, # even if module used is from transformer engine. if precision == "fp8": @@ -639,7 +692,7 @@ def configure_optimizers(self): seed_everything(42) model = TestModel() - _plugin = HPUPrecisionPlugin(device="hpu", precision=precision) + _plugin = HPUPrecisionPlugin(precision=precision) # HPUPrecisionPlugin.convert_modules not reqiored as self.layer is already a transformer engine module trainer = Trainer( default_root_dir=tmpdir, @@ -650,3 +703,121 @@ def configure_optimizers(self): plugins=_plugin, ) trainer.fit(model) + + +@pytest.mark.skipif(HPUAccelerator.get_device_name() == "GAUDI", reason="Native int64 supported on Gaudi2 and above.") +@pytest.mark.standalone_only() +@pytest.mark.parametrize( + ("int64_support", "expectation"), + [ + ("False", pytest.raises(RuntimeError, match="Error when trying to cast Long to Int")), + ("True", nullcontext()), + ], +) +def test_hpu_precision_long_type(int64_support, expectation): + """Tests native support for long tensor on G2.""" + os.environ["PT_HPU_LAZY_MODE"] = "0" + os.environ["PT_ENABLE_INT64_SUPPORT"] = int64_support + with expectation: + torch.tensor(torch.iinfo(torch.int64).max, dtype=torch.int64, device=torch.device("hpu")) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + pytest.param( + torch.float8_e5m2, + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above" + ), + ), + pytest.param( + torch.float8_e4m3fn, + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above" + ), + ), + pytest.param( + torch.float16, + marks=pytest.mark.skipif( + HPUAccelerator.get_device_name() == "GAUDI", reason="fp16 supported on Gaudi2 and above" + ), + ), + torch.float32, + torch.bfloat16, + torch.bool, + ], +) +def test_hpu_supported_dtypes_tensor_creation(dtype): + """Tests tensors with supported dtypes can be created on hpu.""" + with nullcontext(): + torch.tensor(42, dtype=dtype, device=torch.device("hpu")) + + +@pytest.mark.parametrize("intype", [torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float32]) +def test_hpu_dtypes_op_output_dtype(intype): + """Test dtypes type promotion.""" + t1 = torch.tensor([[1, 2], [2, 1]], dtype=intype, device=torch.device("hpu")) + t2 = torch.tensor([[2, 1], [1, 2]], dtype=intype, device=torch.device("hpu")) + + # Operands are promoted as per torch.promote_types + t3 = t1.mm(t2) + t4 = t1.add(t2) + t5 = t1.div(t2) + assert t3.dtype == torch.promote_types(t1.dtype, t2.dtype) + assert t4.dtype == torch.promote_types(t1.dtype, t2.dtype) + # integer div always promoted to float32. 
+    assert t5.dtype == (
+        torch.promote_types(t1.dtype, t2.dtype)
+        if t1.is_floating_point() or t2.is_floating_point()
+        else torch.float32
+    )
+
+    # torch.autocast only affects torch.float16, torch.bfloat16, torch.float32
+    with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+        # Computes in lower precision if operands in (bf16, fp32) else operand dtype
+        t3 = t1.mm(t2)
+        # Promoted to highest dtype between operands
+        t4 = t1.add(t2)
+        # Runs in fp32
+        t5 = t1.div(t2)
+
+    assert t3.dtype == (intype if intype not in (torch.bfloat16, torch.float32) else torch.bfloat16)
+    assert t4.dtype == intype
+    assert t5.dtype == torch.float32
+
+
+@pytest.mark.parametrize("intype", [torch.int8, torch.int16, torch.int32, torch.int64])
+def test_hpu_dtypes_compare_cpu_accuracy(intype, tmpdir):
+    """Compare training metrics between HPU and CPU for integer input dtypes."""
+
+    class TestModel(BaseBM):
+        def forward(self, x):
+            # Perform some operations in given dtype
+            x = x.to(intype)
+            identity = torch.eye(x.shape[1], device=x.device, dtype=intype)
+            x = torch.addmm(x, x, identity)
+
+            return super().forward(x.to(torch.float32))
+
+    metrics = []
+    for accelerator in [HPUAccelerator(), "cpu"]:
+        seed_everything(42)
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            accelerator=accelerator,
+            devices=1,
+            strategy=SingleHPUStrategy() if isinstance(accelerator, HPUAccelerator) else "auto",
+            fast_dev_run=1,
+        )
+
+        trainer.fit(TestModel())
+        metrics.append(trainer.logged_metrics)
+
+    # Compare metrics between cpu and hpu
+    assert torch.equal(metrics[0].get("train_loss"), metrics[1].get("train_loss"))
+    assert torch.equal(metrics[0].get("val_loss"), metrics[1].get("val_loss"))
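The dtype checks in these tests follow PyTorch's standard type-promotion rules; a minimal, device-free sketch of what is being asserted (plain CPU tensors, illustrative only, no HPU required):

import torch

# Binary ops pick their result dtype via torch.promote_types.
assert torch.promote_types(torch.int32, torch.int64) == torch.int64
assert torch.promote_types(torch.int32, torch.bfloat16) == torch.bfloat16
assert torch.promote_types(torch.bfloat16, torch.float32) == torch.float32

# True division of two integer tensors is the exception: it always yields float32.
assert (torch.tensor([4, 2]) / torch.tensor([2, 2])).dtype == torch.float32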