Update HPUPrecisionPlugin fp8 training with Transformer engine #195

Merged · 3 commits · Jun 21, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated to common `hpu_backend` interface for compile support. ([#183](https://github.com/Lightning-AI/lightning-Habana/pull/183))
- Updated to Intel Gaudi software Release 1.16.0 ([#191](https://github.com/Lightning-AI/lightning-Habana/pull/191))
- Updated HQT APIs to be in accordance with Intel Gaudi software Release 1.16.0 ([#192](https://github.com/Lightning-AI/lightning-Habana/pull/192))
- Updated HPUPrecisionPlugin for fp8 based on Intel Gaudi software Release 1.16.0. ([#195](https://github.com/Lightning-AI/lightning-Habana/pull/195))

### Fixed

7 changes: 3 additions & 4 deletions docs/source/intermediate.rst
@@ -119,10 +119,9 @@ Lightning supports fp8 training using HPUPrecisionPlugin, :class:`~lightning_hab

fp8 training is only available on Gaudi2 and above. Output from fp8 supported modules is in `torch.bfloat16`.

The plugin accepts following args for the fp8 training:
For fp8 training, call `plugin.convert_modules()`. The function accepts the following args:

1. `replace_layers` : Set `True` to let the plugin replace `torch.nn.Modules` with `transformer_engine` equivalent modules. You can directly import and use modules from `transformer_engine` as well.

2. `recipe` : fp8 recipe used in training.

.. code-block:: python
@@ -135,10 +134,10 @@ The plugin accepts following args for the fp8 training:
model = BoringModel()

# init the precision plugin for fp8 training.
plugin = HPUPrecisionPlugin(precision="fp8", replace_layers=True, recipe=recipe.DelayedScaling())
plugin = HPUPrecisionPlugin(precision="fp8")

# Replace torch.nn.Modules with transformer engine equivalent modules
plugin.convert_modules(model)
plugin.convert_modules(model, replace_layers=True, recipe=recipe.DelayedScaling())

# Initialize a trainer with HPUPrecisionPlugin
trainer = Trainer(
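Taken together, the updated docs imply the following end-to-end flow. A minimal sketch, assuming a `BoringModel`-style LightningModule and top-level `lightning_habana` imports (neither is shown in this hunk):

```python
# fp8 training flow per the updated API: the plugin takes only the precision,
# while the fp8 knobs (replace_layers, recipe) move to convert_modules().
# Requires Gaudi2 or above and Intel Gaudi software >= 1.16.0.
from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel  # stand-in model
from lightning_habana import HPUAccelerator, HPUPrecisionPlugin, SingleHPUStrategy  # assumed import path

model = BoringModel()

# init the precision plugin for fp8 training
plugin = HPUPrecisionPlugin(precision="fp8")

# replace torch.nn.Modules with transformer engine equivalents and set the recipe
plugin.convert_modules(model, replace_layers=True, recipe=recipe.DelayedScaling())

trainer = Trainer(
    accelerator=HPUAccelerator(),
    devices=1,
    strategy=SingleHPUStrategy(),
    plugins=plugin,
    fast_dev_run=True,
)
trainer.fit(model)
```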
2 changes: 1 addition & 1 deletion examples/pytorch/mnist_trainer.py
@@ -209,7 +209,7 @@ def run_training(run_type, options, model, data_module, plugin):
model, data_module = get_model(_run_type)
plugin = get_plugins(_run_type)
if _run_type == "fp8_training":
plugin.convert_modules(model)
plugin.convert_modules(model, replace_layers=True, inference=False)

if _run_type == "fp8_inference_measure":
plugin.convert_modules(model, inference=True, quant=False)
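The example now routes every fp8 mode through the same `convert_modules()` call. A sketch of the dispatch, assuming the run-type names used in the example (the quantized-inference branch is not visible in this hunk, so its name and `quant=True` call are assumptions):

```python
# plugin is an HPUPrecisionPlugin(precision="fp8") returned by get_plugins()
if _run_type == "fp8_training":
    # training: swap supported torch.nn layers for transformer engine equivalents
    plugin.convert_modules(model, replace_layers=True, inference=False)
elif _run_type == "fp8_inference_measure":
    # first inference pass: measurement mode, collect calibration data
    plugin.convert_modules(model, inference=True, quant=False)
elif _run_type == "fp8_inference_quant":  # assumed name for the quantized pass
    # second inference pass: quantized fp8 inference using the measured data
    plugin.convert_modules(model, inference=True, quant=True)
```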
19 changes: 8 additions & 11 deletions src/lightning_habana/pytorch/plugins/deepspeed_precision.py
@@ -62,30 +62,25 @@ class HPUDeepSpeedPrecisionPlugin(HPUPrecisionPlugin):

Args:
precision (_PRECISION_INPUT, optional): Precision input. Defaults to "32-true".
recipe (Optional[Union[Mapping[str, Any], "DelayedScaling"]], optional):
recipe for fp8 training. Defaults to None.
replace_layers (bool, optional): Replace module with transformer engine equivalent. Defaults to False.

Raises:
OSError: Unsupported Synapse version.
ValueError: Invalid precision value(s).
NotImplementedError: fp8 not available.
ValueError: Invalid precision value.
NotImplementedError: fp8 / fp16 not available.

"""

def __init__(
self,
precision: _PRECISION_INPUT = "32-true",
device: str = "hpu",
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
replace_layers: bool = False,
) -> None:
if not _HPU_DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
"To use the `HPUDeepSpeedPrecisionPlugin`, you must have hpu DeepSpeed installed."
" Install it by running `pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0`."
)
super().__init__(precision=precision, recipe=recipe, replace_layers=replace_layers)
super().__init__(device=device, precision=precision)

def backward(
self,
@@ -170,13 +165,15 @@ def convert_modules(
self,
module: torch.nn.Module,
inference: bool = False,
replace_layers: bool = False,
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
quant: bool = True,
fp8_data_path: Optional[str] = None,
ds_inference_kwargs: Optional[dict] = None,
) -> torch.nn.Module:
"""Enable support for fp8."""
if inference is True and self.fp8_inference_available:
if inference and self.fp8_inference_available:
self._enable_fp8_inference(module, quant, fp8_data_path, ds_inference_kwargs)
if self.fp8_train_available is True and self.replace_layers is True and inference is False:
self._enable_fp8_training(module)
if not inference and self.fp8_train_available:
self._enable_fp8_training(module, replace_layers, recipe)
return module
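For reference, a hypothetical call pattern for the reworked `convert_modules()` (argument names come from the diff; the import path is an assumption):

```python
from lightning_habana import HPUDeepSpeedPrecisionPlugin  # assumed import path

plugin = HPUDeepSpeedPrecisionPlugin(precision="fp8")

# fp8 training: replace_layers and recipe are now per-call arguments
# instead of plugin-level state set in __init__.
model = plugin.convert_modules(model, inference=False, replace_layers=True)

# fp8 inference via DeepSpeed would instead use (typically in separate runs):
#   plugin.convert_modules(model, inference=True, quant=False)  # measurement pass
#   plugin.convert_modules(model, inference=True, quant=True)   # quantized pass
```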
85 changes: 47 additions & 38 deletions src/lightning_habana/pytorch/plugins/precision.py
@@ -30,11 +30,11 @@
)

if module_available("lightning"):
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.pytorch.plugins.precision import Precision
elif module_available("pytorch_lightning"):
from pytorch_lightning.plugins.precision import Precision
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info
else:
raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")

@@ -71,23 +71,18 @@ class HPUPrecisionPlugin(Precision):

Args:
precision (_PRECISION_INPUT, optional): Precision input. Defaults to "32-true".
recipe (Optional[Union[Mapping[str, Any], "DelayedScaling"]], optional):
recipe for fp8 training. Defaults to None.
replace_layers (bool, optional): Replace module with transformer engine equivalent. Defaults to False.

Raises:
OSError: Unsupported Synapse version.
ValueError: Invalid precision value(s).
NotImplementedError: fp8 not available.
NotImplementedError: fp8 / fp16 not available.

"""

def __init__(
self,
precision: _PRECISION_INPUT = "32-true",
device: str = "hpu",
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
replace_layers: bool = False,
) -> None:
if not _HPU_SYNAPSE_GREATER_EQUAL_1_11_0:
raise OSError("HPU precision plugin requires `Synapse AI release >= 1.11.0`.")
@@ -97,14 +92,11 @@ def __init__(
f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = precision
self.replace_layers = False
self.device = device
self.precision = precision

if any([recipe, replace_layers]) and self.precision != "fp8":
rank_zero_warn(f"Precision is not 'fp8'. Params {recipe=} and {replace_layers=} will not be set.")

self.recipe = None
self.recipe: Union[Mapping[str, Any], "DelayedScaling"] = None
self.replace_layers = False
self.fp8_train_available = False
self.fp8_inference_available = False

@@ -117,10 +109,8 @@ def __init__(
fp8_available, reason_no_fp8 = is_fp8_available()
if not fp8_available:
raise NotImplementedError(f"fp8 not supported: {reason_no_fp8}.")
self.recipe = recipe
self.fp8_train_available = fp8_available
self.fp8_inference_available = fp8_available and _HABANA_QUANTIZATION_TOOLKIT_AVAILABLE
self.replace_layers = replace_layers

rank_zero_info(
f"fp8 training available: {self.fp8_train_available}."
@@ -177,44 +167,63 @@ def _setup_fp8_inference_modules(
print("quantization_toolkit not found. Please install it using `pip install habana_quantization_toolkit`.")
raise e

def _enable_fp8_training(self, module: torch.nn.Module) -> None:
def _enable_fp8_training(
self,
module: torch.nn.Module,
replace_layers: bool = False,
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
) -> None:
"""Convert module for fp8 training."""
# In case model already contains a transformer engine modules,
# assume user responsibility for conversion of required layers.
if any(
"habana_frameworks.torch.hpex.experimental.transformer_engine" in m.__module__ for m in module.modules()
):
rank_zero_info(
f"Module {module} already contains transformer engine equivalent modules. Skipping conversion"
)
else:
_replace_layers(module)
self.recipe = recipe
if replace_layers:
# In case model already contains a transformer engine modules,
# assume user responsibility for conversion of required layers.
if any(
"habana_frameworks.torch.hpex.experimental.transformer_engine" in m.__module__ for m in module.modules()
):
rank_zero_info(
f"Module {module} already contains transformer engine equivalent modules. Skipping conversion"
)
else:
_replace_layers(module)

def convert_modules(
self,
module: torch.nn.Module,
inference: bool = False,
replace_layers: bool = False,
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
quant: bool = True,
fp8_data_path: Optional[str] = None,
) -> torch.nn.Module:
"""Convert modules for FP8 precision.
"""Convert modules for fp8 precision.

Args:
module (torch.nn.Module): Module to convert
inference (bool, optional): Convert module for inference (True) / Training (False). Defaults to False.
quant (bool, optional): Convert module for measurement (False) / Quantization (True) during inference.
Defaults to True.
fp8_data_path (Optional[str], optional): Path to dump fp8 inference measurement data. Defaults to None.
module (torch.nn.Module): module to convert
inference (bool, optional): prepare modules for inference (True) / training (False). Defaults to False.
replace_layers (bool, optional): Replace layers with transformer engine equivalent layers for fp8 training.
Defaults to False.
recipe (Optional[Union[Mapping[str, Any], "DelayedScaling"]], optional): Recipe for fp8 training.
Defaults to None.
quant (bool, optional): Run fp8 inference in measurement (False) or Quant (True) mode. Defaults to True.
fp8_data_path (Optional[str], optional): path to dump fp8 inference data in measurement mode.
Defaults to None.

Returns:
torch.nn.Module: FP8 enabled module
torch.nn.Module: fp8 enabled module

"""
assert self.precision == "fp8", "HPUPrecisionPlugin.convert_modules() should only be used with precision=`fp8`."
if inference and self.fp8_inference_available:
self._enable_fp8_inference(module, quant, fp8_data_path)
if self.fp8_train_available and self.replace_layers and not inference:
self._enable_fp8_training(module)
if inference:
if self.fp8_inference_available:
self._enable_fp8_inference(module, quant, fp8_data_path)
else:
raise ModuleNotFoundError(
"habana_quantization_toolkit not found. "
"Install it using `pip install habana_quantization_toolkit`"
)
if not inference and self.fp8_train_available:
self._enable_fp8_training(module, replace_layers, recipe)
return module

def autocast_context_manager(self) -> Union[ContextManager[Any], torch.autocast]:
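The inference branch above implies a two-step flow: a measurement run first, then a quantized run. A minimal sketch, assuming `habana_quantization_toolkit` is installed and each step runs as a separate process (the path is a placeholder):

```python
plugin = HPUPrecisionPlugin(precision="fp8")

# Step 1 (separate run): measurement mode; dump calibration data while
# running representative inputs through the model.
model = plugin.convert_modules(
    model, inference=True, quant=False, fp8_data_path="./fp8_measurements"  # placeholder path
)
# ... run a few inference batches to record measurements ...

# Step 2 (separate run): quantization mode; execute fp8 inference using
# the previously measured data.
model = plugin.convert_modules(
    model, inference=True, quant=True, fp8_data_path="./fp8_measurements"  # placeholder path
)
```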
36 changes: 14 additions & 22 deletions tests/test_pytorch/test_deepspeed.py
@@ -19,7 +19,6 @@
import habana_frameworks.torch.hpex.experimental.transformer_engine as tengine
import pytest
import torch
from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe
from lightning_utilities import module_available
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
@@ -781,15 +780,15 @@ def test_lightning_deepspeed_inference_config(get_device_count, dtype):


@pytest.mark.parametrize("stage", [1, 2, 3])
@pytest.mark.skipif(HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above.")
def test_hpu_deepspeed_fp8_training_accuracy(tmpdir, get_device_count, stage):
@pytest.mark.skipif(HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 / fp16 supported on Gaudi2 and above.")
def test_hpu_deepspeed_training_accuracy(tmpdir, get_device_count, stage):
"""Test compare training accuracy between bf16 and fp8 precision for deepspeed."""

class TestModel(BoringModel):
"""Test model."""

def __init__(self):
"""init."""
"""Init."""
super().__init__()
self.layer = tengine.Linear(32, 2)

@@ -824,28 +823,21 @@ def run_training(tmpdir, model, plugin, strategy):
trainer.fit(model)
return trainer.callback_metrics["val_loss"], trainer.callback_metrics["train_loss"]

precision_plugin_params_list = [
({"precision": "bf16-mixed"}),
pytest.param(
{"precision": "16-mixed"},
marks=pytest.mark.skipif(
HPUAccelerator.get_device_name() == "GAUDI", reason="fp16 supported on Gaudi2 and above"
),
),
pytest.param(
{"precision": "fp8", "replace_layers": True, "recipe": recipe.DelayedScaling()},
marks=pytest.mark.skipif(
HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above"
),
),
precision_list = [
"32-true",
"bf16-mixed",
"16-mixed",
"fp8",
]

loss_list = []

for params in precision_plugin_params_list:
for precision in precision_list:
seed_everything(42)
model = TestModel()
_plugin = HPUDeepSpeedPrecisionPlugin(**params)
_plugin = HPUDeepSpeedPrecisionPlugin(precision=precision)
if precision == "fp8":
_plugin.convert_modules(model)
_strategy = HPUDeepSpeedStrategy(stage=stage)
loss_list.append(run_training(tmpdir, model, _plugin, _strategy))

@@ -861,7 +853,7 @@ class TestModel(BoringModel):
"""Test model."""

def __init__(self):
"""init."""
"""Init."""
super().__init__()
self.layer = torch.nn.Linear(32, 2)

@@ -906,7 +898,7 @@ class TestModel(BoringModel):
"""Test model."""

def __init__(self):
"""init."""
"""Init."""
super().__init__()
self.layer = torch.nn.Linear(32, 2)

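One detail worth noting in the updated test: the fp8 branch calls `convert_modules(model)` with the new defaults (`replace_layers=False`, `recipe=None`). That is sufficient here because `TestModel` builds its layer from transformer engine directly, so no layer replacement is needed. The branch in isolation, using only names from this test file:

```python
# fp8 branch of the precision comparison loop, isolated for clarity
seed_everything(42)
model = TestModel()  # already uses tengine.Linear(32, 2); no replacement needed
_plugin = HPUDeepSpeedPrecisionPlugin(precision="fp8")
_plugin.convert_modules(model)  # defaults: replace_layers=False, recipe=None
_strategy = HPUDeepSpeedStrategy(stage=stage)
loss_list.append(run_training(tmpdir, model, _plugin, _strategy))
```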