Add tests for gradients when using HPUPrecisionPlugin (#196)
* Add docs, test
ankitgola005 committed Jun 24, 2024
1 parent c73fa7a commit e24995f
Showing 2 changed files with 159 additions and 4 deletions.
42 changes: 41 additions & 1 deletion docs/source/intermediate.rst
@@ -70,7 +70,7 @@ The following is an excerpt from an MNIST example implemented on a single HPU.
For more granular control over mixed precision training, one can use `torch.autocast` from native PyTorch.

Instances of autocast serve as context managers or decorators that allow regions of your script to run in mixed precision.
These also allow fine-grained control via the `enabled` argument, for enabling or disabling mixed precision in selected parts of the code.


.. code-block:: python
@@ -106,6 +106,46 @@ These also allow fine-grained control via the `enabled` argument
# Train the model ⚡
trainer.fit(model, datamodule=dm)
The `torch.autocast` context manager allows fine-tuning of mixed precision training via its `enabled` parameter.
It can be used alongside `HPUPrecisionPlugin`, which enables mixed precision globally, while local `torch.autocast` contexts disable it for particular parts of the model.
Alternatively, users can forgo `HPUPrecisionPlugin` and use `torch.autocast` alone to control the precision of every op.
With nested contexts, the innermost enclosing context and its `enabled` parameter determine whether mixed precision is active in that region.


.. code-block:: python

    # Granular autocast control without HPUPrecisionPlugin
    def forward(self, x):
        """Forward."""
        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
            torch.hpu.is_autocast_hpu_enabled()  # Returns True
            with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=False):
                torch.hpu.is_autocast_hpu_enabled()  # Returns False
            # Re-entering autocast enabled region
            torch.hpu.is_autocast_hpu_enabled()  # Returns True
        return super().forward(x)


    # Granular autocast control with HPUPrecisionPlugin
    def forward(self, x):
        """Forward."""
        # HPUPrecisionPlugin wraps a forward_context around the train / val / predict / test steps,
        # which makes torch.autocast(enabled=True) as used in the previous example redundant.
        torch.hpu.is_autocast_hpu_enabled()  # Returns True
        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=False):
            torch.hpu.is_autocast_hpu_enabled()  # Returns False
            with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
                torch.hpu.is_autocast_hpu_enabled()  # Returns True
            torch.hpu.is_autocast_hpu_enabled()  # Returns False
        # Re-entering autocast enabled region
        torch.hpu.is_autocast_hpu_enabled()  # Returns True
        return super().forward(x)
For more details, please refer to
`Native PyTorch Autocast <https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/Autocast.html>`__
and `Automatic Mixed Precision Package: torch.autocast <https://pytorch.org/docs/stable/amp.html#autocasting>`__.
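
While the examples above use `torch.autocast` as a context manager, it can also be applied as a decorator, as noted earlier. The following is a minimal sketch, not part of this changeset, assuming the same HPU bfloat16 setup as in the examples above:

.. code-block:: python

    # Sketch only: torch.autocast applied as a decorator is equivalent to
    # wrapping the entire function body in the corresponding context manager.
    @torch.autocast(device_type="hpu", dtype=torch.bfloat16)
    def forward(self, x):
        """Forward."""
        torch.hpu.is_autocast_hpu_enabled()  # Returns True
        return super().forward(x)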
121 changes: 118 additions & 3 deletions tests/test_pytorch/test_precision.py
@@ -124,8 +124,9 @@ def forward(self, x):
        return super().forward(x)


def test_autocast_enable_disable(tmpdir):
    """Tests autocast control with enabled arg."""
@pytest.mark.parametrize("precision_plugin", [False, True])
def test_autocast_enable_disable(tmpdir, precision_plugin):
    """Tests autocast granular control with HPUPrecisionPlugin."""

    class BMAutocastGranularControl(BaseBM):
        """Tests autocast control with enabled arg."""
@@ -156,7 +157,8 @@ def forward(self, x):
            assert x.dtype == torch.bfloat16
            return self.layer(x)

    assert run_training(tmpdir, BMAutocastGranularControl(), None) is not None
    precision_plugin = HPUPrecisionPlugin(precision="bf16-mixed") if precision_plugin else None
    assert run_training(tmpdir, BMAutocastGranularControl(), precision_plugin) is not None

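For context, `run_training` is a helper defined elsewhere in this test suite and is not shown in this diff. A minimal sketch of what such a helper might look like, assuming it runs a single `fast_dev_run` fit with the given plugin and returns the trained model; the actual helper in this repository may differ in signature and behavior:

# Hypothetical sketch of the run_training helper used above (assumption).
def run_training(tmpdir, model, plugin):
    """Fit the model for one fast_dev_run iteration and return it."""
    dm = BoringDataModule()
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=HPUAccelerator(),
        devices=1,
        strategy=SingleHPUStrategy(),
        plugins=plugin,
        fast_dev_run=True,
    )
    trainer.fit(model, dm)
    return model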

@pytest.mark.xfail(strict=False, reason="Env needs to be set")
@@ -801,3 +803,116 @@ def forward(self, x):
    # Compare metrics between cpu and hpu
    assert torch.isclose(metrics[0].get("train_loss"), metrics[1].get("train_loss"), atol=1e-5, rtol=1e-5)
    assert torch.isclose(metrics[0].get("val_loss"), metrics[1].get("val_loss"), atol=1e-5, rtol=1e-5)


def test_hpu_precision_plugin_grads_dtype(tmpdir):
    """Tests that gradient dtypes on hpu match those on cpu when using HPUPrecisionPlugin."""

    class TestModel(BoringModel):
        """Test model."""

        def __init__(self):
            """Init."""
            super().__init__()
            self.linear_hook_handle = self.layer.register_full_backward_hook(self.layer_backward_hook)
            self.grad_dict: dict = {}

        def back_hook(self, layer_name, grad_input, grad_output):
            """Back hook."""
            if layer_name not in self.grad_dict:
                self.grad_dict[layer_name] = {}
                self.grad_dict[layer_name]["grad_input"] = []
                self.grad_dict[layer_name]["grad_output"] = []
            self.grad_dict[layer_name]["grad_input"].append(grad_input)
            self.grad_dict[layer_name]["grad_output"].append(grad_output)

        def layer_backward_hook(self, module, grad_input, grad_output):
            """Layer backward hook."""
            assert isinstance(module, torch.nn.Linear)
            self.back_hook("Linear", grad_input, grad_output)

        def forward(self, x):
            """Forward."""
            x.requires_grad_(True)
            return super().forward(x)

    grad_dict = {}
    for accelerator, strategy, precision_plugin in [
        ("cpu", "auto", MixedPrecision(device="cpu", precision="bf16-mixed")),
        (HPUAccelerator(), SingleHPUStrategy(), HPUPrecisionPlugin(precision="bf16-mixed")),
    ]:
        seed_everything(42)
        model = TestModel()
        dm = BoringDataModule()
        trainer = Trainer(
            default_root_dir=tmpdir,
            accelerator=accelerator,
            devices=1,
            strategy=strategy,
            plugins=precision_plugin,
            fast_dev_run=1,
        )

        trainer.fit(model, dm)
        accelerator_str = "hpu" if isinstance(accelerator, HPUAccelerator) else accelerator
        grad_dict[accelerator_str] = model.grad_dict

    for (kcpu, vcpu), (khpu, vhpu) in zip(grad_dict["cpu"]["Linear"].items(), grad_dict["hpu"]["Linear"].items()):
        # Ensure we compare the same grad type (grad_input / grad_output) for both devices
        assert kcpu == khpu
        for (grad_cpu,), (grad_hpu,) in zip(vcpu, vhpu):
            # Check grad dtype
            assert grad_cpu.dtype == grad_hpu.dtype
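
The test above (and the fp8 test below) rely on `register_full_backward_hook` to capture gradient dtypes. A minimal standalone sketch of that pattern in plain PyTorch, not part of this changeset, to show what the hooks record:

# Standalone sketch of the backward-hook pattern used by the tests (assumption: plain PyTorch, no Lightning).
import torch

layer = torch.nn.Linear(32, 2)
captured = {}

def hook(module, grad_input, grad_output):
    # Record the dtypes of incoming and outgoing gradients; tuple entries may be None.
    captured["grad_input"] = [g.dtype for g in grad_input if g is not None]
    captured["grad_output"] = [g.dtype for g in grad_output if g is not None]

handle = layer.register_full_backward_hook(hook)
layer(torch.randn(4, 32, requires_grad=True)).sum().backward()
handle.remove()
# Without autocast, both lists contain torch.float32.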


@pytest.mark.skipif(HPUAccelerator.get_device_name() == "GAUDI", reason="fp8 supported on Gaudi2 and above.")
def test_hpu_precision_plugin_grads_dtype_fp8(tmpdir):
    """Test dtype of gradients when using fp8 training."""

    class TestModel(BoringModel):
        """Test model."""

        def __init__(self):
            """Init."""
            super().__init__()
            self.layer = tengine.Linear(32, 2)
            self.linear_hook_handle = self.layer.register_full_backward_hook(self.layer_backward_hook)
            self.grad_dict: dict = {}

        def back_hook(self, layer_name, grad_input, grad_output):
            """Back hook."""
            if layer_name not in self.grad_dict:
                self.grad_dict[layer_name] = {}
                self.grad_dict[layer_name]["grad_input"] = []
                self.grad_dict[layer_name]["grad_output"] = []
            self.grad_dict[layer_name]["grad_input"].append(grad_input)
            self.grad_dict[layer_name]["grad_output"].append(grad_output)

        def layer_backward_hook(self, module, grad_input, grad_output):
            """Layer backward hook."""
            assert isinstance(module, tengine.Linear)
            self.back_hook("Linear", grad_input, grad_output)

        def forward(self, x):
            """Forward."""
            x.requires_grad_(True)
            return super().forward(x)

    seed_everything(42)
    model = TestModel()
    dm = BoringDataModule()
    plugin = HPUPrecisionPlugin(precision="fp8")
    plugin.convert_modules(model, replace_layers=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=HPUAccelerator(),
        devices=1,
        strategy=SingleHPUStrategy(),
        plugins=plugin,
        fast_dev_run=1,
    )

    trainer.fit(model, dm)
    for _, v_grad in model.grad_dict["Linear"].items():
        for (grad_tensor,) in v_grad:
            assert grad_tensor.dtype == torch.float32
