Update additional dtype support (#194)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ankitgola005 and pre-commit-ci[bot] committed Jun 19, 2024
1 parent d35879f commit 61f9069
Showing 11 changed files with 351 additions and 133 deletions.
9 changes: 5 additions & 4 deletions .azure/hpu-tests.yml
@@ -120,8 +120,8 @@ jobs:
displayName: 'HPU General tests'
- bash: |
python -m pytest -sv tests/test_pytorch/test_compile.py \
--hpus 1 --junitxml=hpu_compile_test-results.xml
python -m pytest -sv tests/test_pytorch/test_compile.py \
--hpus 1 --junitxml=hpu_compile_test-results.xml
env:
PT_HPU_LAZY_MODE: 0
displayName: 'HPU torch compile tests'
@@ -142,15 +142,16 @@ jobs:
- bash: |
bash tests/run_standalone_tests.sh --hpus 1 -m standalone_only -f \
tests/test_pytorch/test_precision.py
tests/test_pytorch/test_precision.py \
tests/test_pytorch/test_dynamic_shapes.py
displayName: Standalone-only single card tests
- bash: |
# TODO: Revert to default mode of execution once Lightning 2.3.x is released
bash tests/run_standalone_tests.sh --hpus 2 -f \
tests/test_pytorch/test_accelerator.py \
tests/test_pytorch/test_compile.py \
tests/test_pytorch/test_profiler.py \
tests/test_pytorch/test_profiler.py
displayName: 'Multi card(2) HPU test'
- bash: pip install ".[examples]"
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added support for additional dtypes ([#194](https://github.com/Lightning-AI/lightning-Habana/pull/194))
-

### Changed
5 changes: 3 additions & 2 deletions docs/source/intermediate.rst
@@ -33,6 +33,7 @@ HPUPrecisionPlugin, :class:`~lightning_habana.pytorch.plugins.precision.HPUPreci

In addition to the default settings, you can override these defaults by providing your own BF16 (`LOWER_LIST`) and FP32 (`FP32_LIST`) lists.
The `LOWER_LIST` and `FP32_LIST` environment variables must be set before any instances begin.
HPUPrecisionPlugin supports `bf16-mixed` and `16-mixed` for mixed precision training. It is advised to use `bf16-mixed` over `16-mixed` where possible.
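For illustration, a minimal sketch of selecting one of these modes through the plugin (import paths and Trainer wiring are assumptions based on this repository's layout, distinct from the MNIST excerpt below):

.. code-block:: python

    # Minimal sketch: choose a mixed-precision mode for HPU training.
    # "bf16-mixed" is the recommended default; "16-mixed" requires Gaudi2 or newer.
    from lightning import Trainer
    from lightning_habana import HPUAccelerator, SingleHPUStrategy
    from lightning_habana.pytorch.plugins.precision import HPUPrecisionPlugin

    plugin = HPUPrecisionPlugin(precision="bf16-mixed")

    trainer = Trainer(
        accelerator=HPUAccelerator(),
        strategy=SingleHPUStrategy(),
        devices=1,
        plugins=[plugin],
    )
    # trainer.fit(model, datamodule)  # model and datamodule defined elsewhere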

The following is an excerpt from an MNIST example implemented on a single HPU.

@@ -156,7 +157,7 @@ The plugin accepts following args for the fp8 training:
1. Import `transformer_engine` and replace your modules with `transformer_engine` modules in the model.
2. Wrap the forward pass of the training with `fp8_autocast`.

Users may still use `HPUPrecisionPlugin` to train in `bf16-mixed` precision for modules not supported by `transformer_engine`.
Users may still use `HPUPrecisionPlugin` to train in mixed precision for modules not supported by `transformer_engine`.
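As a rough sketch of those two steps (the `transformer_engine` import path is an assumption based on current `habana_frameworks` releases; the recipe and autocast names mirror the plugin code in this commit):

.. code-block:: python

    # Sketch: manual fp8 training without the plugin.
    import torch
    # Assumed import path for the Intel Gaudi transformer engine.
    import habana_frameworks.torch.hpex.experimental.transformer_engine as te
    from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe

    # 1. Replace supported modules with their transformer_engine equivalents.
    model = torch.nn.Sequential(
        te.Linear(64, 64),   # fp8-capable replacement for torch.nn.Linear
        torch.nn.ReLU(),     # ops without fp8 support keep their regular precision
        te.Linear(64, 10),
    ).to("hpu")

    fp8_recipe = recipe.DelayedScaling()

    # 2. Wrap the forward pass of training with fp8_autocast.
    def forward_step(batch: torch.Tensor) -> torch.Tensor:
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            return model(batch)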


.. note::
@@ -254,9 +255,9 @@ Refer to `Supported JSON Config File Options <https://docs.habana.ai/en/latest/P
**Limitations**

1. Measurement mode and Quantization mode cannot be run in a single process. Please run in measurement mode first, followed by quantization mode. Measurement data may be reused for inference in quantization mode for the given model (see the sketch after this list).
2. Only single card inference is currently supported. Support for multiple cards will be enabled in a future release.
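As an example, a sketch of that two-pass flow using `convert_modules` from this commit (the model and data path are placeholders; each pass is intended to run in its own process):

.. code-block:: python

    # Sketch: fp8 inference as a measurement pass followed by a quantization pass.
    import torch
    from lightning_habana.pytorch.plugins.precision import HPUPrecisionPlugin

    model = torch.nn.Linear(64, 10)  # placeholder model
    plugin = HPUPrecisionPlugin(precision="fp8")

    # Pass 1 (measurement): quant=False collects calibration statistics.
    model = plugin.convert_modules(
        model, inference=True, quant=False, fp8_data_path="./fp8_measurements"
    )
    # ... run representative inference batches, then exit this process ...

    # Pass 2 (quantization, separate process): quant=True reuses the measurement data.
    model = plugin.convert_modules(
        model, inference=True, quant=True, fp8_data_path="./fp8_measurements"
    )
    # ... run fp8 inference ...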

For more details, refer to `Inference Using FP8 <https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html>`__.
For a list of data types supported with HPU, refer to `PyTorch Support Matrix <https://docs.habana.ai/en/v1.15.1/PyTorch/Reference/PyTorch_Support_Matrix.html>`__.

----

19 changes: 13 additions & 6 deletions src/lightning_habana/pytorch/plugins/deepspeed_precision.py
@@ -58,17 +58,24 @@


class HPUDeepSpeedPrecisionPlugin(HPUPrecisionPlugin):
"""Plugin that enables bfloat support on HPUs.
"""Plugin that enables mixed precision support on HPUs.
Args:
precision: to enable ``torch.bfloat16`` (``'bf16-mixed'``).
device: The device for ``torch.autocast``.
precision (_PRECISION_INPUT, optional): Precision input. Defaults to "32-true".
recipe (Optional[Union[Mapping[str, Any], "DelayedScaling"]], optional):
recipe for fp8 training. Defaults to None.
replace_layers (bool, optional): Replace module with transformer engine equivalent. Defaults to False.
Raises:
OSError: Unsupported Synapse version.
ValueError: Invalid precision value(s).
NotImplementedError: fp8 not available.
"""

def __init__(
self,
precision: _PRECISION_INPUT,
precision: _PRECISION_INPUT = "32-true",
device: str = "hpu",
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
replace_layers: bool = False,
@@ -78,7 +85,7 @@ def __init__(
"To use the `HPUDeepSpeedPrecisionPlugin`, you must have hpu DeepSpeed installed."
" Install it by running `pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0`."
)
super().__init__(device=device, precision=precision, recipe=recipe, replace_layers=replace_layers)
super().__init__(precision=precision, recipe=recipe, replace_layers=replace_layers)

def backward(
self,
@@ -153,7 +160,7 @@ def _enable_fp8_inference(
ds_inference_kwargs["dtype"] = torch.bfloat16
assert ds_inference_kwargs["dtype"] in (torch.bfloat16, torch.float)

htcore.hpu_set_env()
htcore.quantization.hpu_set_inference_env()
module = module.to("hpu")

module = deepspeed.init_inference(module, **ds_inference_kwargs)
72 changes: 45 additions & 27 deletions src/lightning_habana/pytorch/plugins/precision.py
@@ -25,6 +25,7 @@
_HABANA_FRAMEWORK_AVAILABLE,
_HABANA_QUANTIZATION_TOOLKIT_AVAILABLE,
is_fp8_available,
is_fp16_available,
modify_fp8_json,
)

@@ -37,7 +38,15 @@
else:
raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")

_PRECISION_INPUT = Literal["32", "32-true", "bf16", "bf16-mixed", "fp8"]
_PRECISION_INPUT = Literal["32", "32-true", "bf16", "bf16-mixed", "fp8", "16-mixed"]

_AMP_DICT = {
"32": torch.float32,
"32-true": torch.float32,
"bf16": torch.bfloat16,
"bf16-mixed": torch.bfloat16,
"16-mixed": torch.float16,
}

if _HPU_SYNAPSE_GREATER_EQUAL_1_14_0 and _HABANA_FRAMEWORK_AVAILABLE:
# Required for training in fp8 using habana transformer engine
@@ -61,14 +70,21 @@ class HPUPrecisionPlugin(Precision):
"""Plugin that enables mixed precision support on HPUs.
Args:
precision: to enable ``torch.bfloat16`` (``'bf16-mixed'``).
device: The device for ``torch.autocast``.
precision (_PRECISION_INPUT, optional): Precision input. Defaults to "32-true".
recipe (Optional[Union[Mapping[str, Any], "DelayedScaling"]], optional):
recipe for fp8 training. Defaults to None.
replace_layers (bool, optional): Replace module with transformer engine equivalent. Defaults to False.
Raises:
OSError: Unsupported Synapse version.
ValueError: Invalid precision value(s).
NotImplementedError: fp8 not available.
"""

def __init__(
self,
precision: _PRECISION_INPUT,
precision: _PRECISION_INPUT = "32-true",
device: str = "hpu",
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
replace_layers: bool = False,
@@ -85,13 +101,18 @@ def __init__(
self.replace_layers = False
self.device = device

if any([recipe, replace_layers]) and precision != "fp8":
if any([recipe, replace_layers]) and self.precision != "fp8":
rank_zero_warn(f"Precision is not 'fp8'. Params {recipe=} and {replace_layers=} will not be set.")

self.recipe = None
self.fp8_train_available = False
self.fp8_inference_available = False

if self.precision == "16-mixed":
fp16_available, reason_no_fp16 = is_fp16_available()
if not fp16_available:
raise NotImplementedError(f"fp16 not supported: {reason_no_fp16}.")

if self.precision == "fp8":
fp8_available, reason_no_fp8 = is_fp8_available()
if not fp8_available:
@@ -102,7 +123,7 @@ def __init__(
self.replace_layers = replace_layers

rank_zero_info(
f"fp8 training available: {self.fp8_train_available}. "
f"fp8 training available: {self.fp8_train_available}."
f"fp8 inference available: {self.fp8_inference_available}."
)

@@ -119,7 +140,6 @@ def _setup_fp8_quant_config(self, quant: bool = True, fp8_data_path: Optional[st
file_path=fp8_json,
patch={
"dump_stats_path": os.path.join(fp8_data_path, "hqt"),
"dump_stats_xlsx_path": os.path.join(fp8_data_path, "hqt", "fp8stats.xlsx"),
},
)
os.environ["QUANT_CONFIG"] = fp8_json
@@ -177,18 +197,31 @@ def convert_modules(
quant: bool = True,
fp8_data_path: Optional[str] = None,
) -> torch.nn.Module:
"""Enable support for fp8."""
if inference is True and self.fp8_inference_available:
"""Convert modules for FP8 precision.
Args:
module (torch.nn.Module): Module to convert
inference (bool, optional): Convert module for inference (True) / Training (False). Defaults to False.
quant (bool, optional): Convert module for measurement (False) / Quantization (True) during inference.
Defaults to True.
fp8_data_path (Optional[str], optional): Path to dump fp8 inference measurement data. Defaults to None.
Returns:
torch.nn.Module: FP8 enabled module
"""
assert self.precision == "fp8", "HPUPrecisionPlugin.convert_modules() should only be used with precision=`fp8`."
if inference and self.fp8_inference_available:
self._enable_fp8_inference(module, quant, fp8_data_path)
if self.fp8_train_available is True and self.replace_layers is True and inference is False:
if self.fp8_train_available and self.replace_layers and not inference:
self._enable_fp8_training(module)
return module

def autocast_context_manager(self) -> Union[ContextManager[Any], torch.autocast]:
"""Return Autocast context manager."""
if self.fp8_train_available:
return _nested_precision_cm(fp8_enabled=(self.precision == "fp8"), recipe=self.recipe)
return torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True)
return tengine.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
return torch.autocast(device_type="hpu", dtype=_AMP_DICT[self.precision], enabled=True)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
Expand All @@ -215,18 +248,3 @@ def _replace_layers(module: torch.nn.Module) -> None:
module.__setattr__(name, replacement)
else:
_replace_layers(child)


@contextmanager
def _nested_precision_cm(
fp8_enabled: bool, recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]]
) -> Generator[Any, Any, Any]:
"""CM to nest fp8 precision with torch.autocast.
This enables the ops that do not support fp8 to run with torch autocast.
"""
with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True), tengine.fp8_autocast(
enabled=fp8_enabled, fp8_recipe=recipe
):
yield
4 changes: 2 additions & 2 deletions src/lightning_habana/pytorch/strategies/deepspeed.py
@@ -933,15 +933,15 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:

def on_test_end(self) -> None:
if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available:
from quantization_toolkit import habana_quantization_toolkit # noqa
import habana_quantization_toolkit

habana_quantization_toolkit.finish_measurements(self.model)
htcore.quantization.hpu_teardown_inference_env()
return super().on_test_end()

def on_predict_end(self) -> None:
if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available:
from quantization_toolkit import habana_quantization_toolkit # noqa
import habana_quantization_toolkit

habana_quantization_toolkit.finish_measurements(self.model)
htcore.quantization.hpu_teardown_inference_env()
4 changes: 2 additions & 2 deletions src/lightning_habana/pytorch/strategies/single.py
@@ -126,15 +126,15 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Any:

def on_test_end(self) -> None:
if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available:
from quantization_toolkit import habana_quantization_toolkit # noqa
import habana_quantization_toolkit

habana_quantization_toolkit.finish_measurements(self.model)
htcore.quantization.hpu_teardown_inference_env()
return super().on_test_end()

def on_predict_end(self) -> None:
if self.precision_plugin.precision == "fp8" and self.precision_plugin.fp8_inference_available:
from quantization_toolkit import habana_quantization_toolkit # noqa
import habana_quantization_toolkit

habana_quantization_toolkit.finish_measurements(self.model)
htcore.quantization.hpu_teardown_inference_env()
10 changes: 10 additions & 0 deletions src/lightning_habana/utils/resources.py
@@ -142,6 +142,16 @@ def is_fp8_available() -> Tuple[bool, str]:
return tengine.fp8.is_fp8_available()


@lru_cache
def is_fp16_available() -> Tuple[bool, str]:
"""Returns a bool indicating if fp16 is available."""
if not _HABANA_FRAMEWORK_AVAILABLE:
raise OSError("Habana Frameworks required for training on Habana devices.")
if torch_hpu.get_device_name() == "GAUDI":
return False, "FP16 not supported on Gaudi, Gaudi2 or higher required."
return True, ""


def modify_fp8_json(file_path: str, patch: dict) -> None:
"""Edit a specific entry in a JSON file.
15 changes: 13 additions & 2 deletions tests/test_pytorch/test_deepspeed.py
@@ -825,8 +825,19 @@ def run_training(tmpdir, model, plugin, strategy):
return trainer.callback_metrics["val_loss"], trainer.callback_metrics["train_loss"]

precision_plugin_params_list = [
({"device": "hpu", "precision": "bf16-mixed"}),
({"device": "hpu", "precision": "fp8", "replace_layers": True, "recipe": recipe.DelayedScaling()}),
({"precision": "bf16-mixed"}),
pytest.param(
{"precision": "16-mixed"},
marks=pytest.mark.skipif(
HPUAccelerator.get_device_name() == "GAUDI", reason="fp16 supported on Gaudi2 and above"
),
),
pytest.param(
{"precision": "fp8", "replace_layers": True, "recipe": recipe.DelayedScaling()},
marks=pytest.mark.skipif(
HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above"
),
),
]

loss_list = []
6 changes: 2 additions & 4 deletions tests/test_pytorch/test_dynamic_shapes.py
@@ -15,6 +15,7 @@
import csv
import os

import pytest
import torch
from habana_frameworks.torch.hpu.metrics import metric_global
from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
@@ -78,6 +79,7 @@ def test_dynamic_shapes_graph_compiler(tmpdir, hpus, monkeypatch):
assert cached_compiles[0] <= default_compiles[0]


@pytest.mark.standalone_only()
def test_dynamic_shapes_auto_detect_recompilations(tmpdir):
"""Test auto_detect_recompilations tool."""

@@ -95,10 +97,6 @@ def calculate_auto_detect_total_recompile_counts(csv_file_path):
print(f"Error: CSV file not found: {csv_file_path}")
return None

# Close dist pg if initialized.
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

seed_everything(42)
model = DynamicOpsBoringModel()
net = detect_recompilation_auto_model(model, csv_out=os.path.join(tmpdir, "out.csv"))