Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mixed precision training example #54

Merged
merged 5 commits into from Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion .azure/hpu-tests.yml
Expand Up @@ -136,7 +136,9 @@ jobs:

- bash: |
export PYTHONPATH="${PYTHONPATH}:$(pwd)"
python mnist_sample.py
python mnist_trainer.py
LOWER_LIST=ops_fp32_mnist.txt FP32_LIST=ops_bf16_mnist.txt \
python mnist_trainer.py -r autocast
workingDirectory: examples/pytorch/
displayName: 'Testing HPU examples'

Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -12,12 +12,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added tests for mixed precision training ([#36](https://github.com/Lightning-AI/lightning-Habana/pull/36))
- Example to include mixed precision training ([#54](https://github.com/Lightning-AI/lightning-Habana/pull/54))
-

### Changed

- Enabled skipped tests based on registered strategy, accelerator ([#46](https://github.com/Lightning-AI/lightning-Habana/pull/46))
-

### Fixed

Expand Down
60 changes: 36 additions & 24 deletions examples/pytorch/mnist_sample.py
Expand Up @@ -12,24 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
from lightning_utilities import module_available
from torch.nn import functional as F # noqa: N812

if module_available("lightning"):
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch import LightningModule
elif module_available("pytorch_lightning"):
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule

from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.strategies import HPUParallelStrategy, SingleHPUStrategy
from pytorch_lightning import LightningModule


class LitClassifier(LightningModule):
"""Base model."""

def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)
Expand Down Expand Up @@ -61,20 +56,37 @@ def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MNIST on HPU", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--hpus", default=1, type=int, help="Number of hpus to be used for training")
parser.add_argument("-b", "--batch-size", default=32, type=int)
args = parser.parse_args()
dm = MNISTDataModule(batch_size=args.batch_size)
model = LitClassifier()
class LitAutocastClassifier(LitClassifier):
"""Base Model with torch.autocast CM."""

hpus = args.hpus
_strategy = SingleHPUStrategy()
if hpus > 1:
parallel_hpus = [torch.device("hpu")] * hpus
_strategy = HPUParallelStrategy(parallel_devices=parallel_hpus)
trainer = Trainer(fast_dev_run=True, accelerator=HPUAccelerator(), devices=hpus, strategy=_strategy)
def __init__(self, op_override=False):
super().__init__()
self.op_override = op_override

def forward(self, x):
if self.op_override:
self.check_override(x)
return super().forward(x)

def check_override(self, x):
"""Checks for op override."""
identity = torch.eye(x.shape[1], device=x.device, dtype=x.dtype)
y = torch.mm(x, identity)
z = torch.tan(x)
assert y.dtype == torch.float32
assert z.dtype == torch.bfloat16

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
def training_step(self, batch, batch_idx):
"""Training step."""
with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
"""Validation step."""
with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
return super().validation_step(batch, batch_idx)

def test_step(self, batch, batch_idx):
"""Test step."""
with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
return super().test_step(batch, batch_idx)
112 changes: 112 additions & 0 deletions examples/pytorch/mnist_trainer.py
@@ -0,0 +1,112 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import warnings

from lightning_utilities import module_available
from mnist_sample import LitAutocastClassifier, LitClassifier

if module_available("lightning"):
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.plugins.precision import MixedPrecisionPlugin
elif module_available("pytorch_lightning"):
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin

from lightning_habana import HPUAccelerator, SingleHPUStrategy

RUN_TYPE = ["basic", "autocast"]


def run_trainer(model, plugin):
"""Run trainer.fit and trainer.test with given parameters."""
_data_module = MNISTDataModule(batch_size=32)
trainer = Trainer(
accelerator=HPUAccelerator(),
devices=1,
strategy=SingleHPUStrategy(),
plugins=plugin,
fast_dev_run=True,
)
trainer.fit(model, _data_module)
trainer.test(model, _data_module)


def check_and_init_plugins(plugins, run_type, verbose):
"""Initialise plugins with appropriate checks."""
_plugins = []
for plugin in plugins:
if verbose:
print(f"Initializing {plugin}")
if plugin == "MixedPrecisionPlugin":
warnings.warn("Operator overriding is not supported with MixedPrecisionPlugin on Habana devices.")
if run_type != "autocast":
_plugins.append(MixedPrecisionPlugin(device="hpu", precision="bf16-mixed"))
else:
warnings.warn("Skipping MixedPrecisionPlugin. Redundant with autocast run.")
else:
print(f"Unsupported or invalid plugin: {plugin}")
return _plugins


def run_model(run_type, plugins, verbose):
"""Picks appropriate model and plugins."""
# Initialise plugins
_plugins = check_and_init_plugins(plugins, run_type, verbose)
if run_type == "basic":
_model = LitClassifier()
elif run_type == "autocast":
if "LOWER_LIST" in os.environ or "FP32_LIST" in os.environ:
_model = LitAutocastClassifier(op_override=True)
else:
_model = LitAutocastClassifier()
warnings.warn(
"To override operators with autocast, set LOWER_LIST and FP32_LIST file paths as env variables."
"Example: LOWER_LIST=<path_to_bf16_ops> python example.py"
"https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/Autocast.html#override-options"
)

if verbose:
print(f"With run type: {run_type}, running model: {_model} with plugin: {_plugins}")
return run_trainer(_model, _plugins)


def parse_args():
"""Cmdline arguments parser."""
parser = argparse.ArgumentParser(description="Example to showcase mixed precision training with HPU.")

parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbosity")
parser.add_argument(
"-r", "--run_types", nargs="+", choices=RUN_TYPE, default=RUN_TYPE, help="Select run type for example"
)
parser.add_argument(
"-p", "--plugins", nargs="+", default=[], choices=["MixedPrecisionPlugin"], help="Plugins for use in training"
)
return parser.parse_args()


if __name__ == "__main__":
# Get options
options = parse_args()
if options.verbose:
print(f"Running MNIST mixed precision training with options: {options}")

# Run model and print accuracy
for run_type in options.run_types:
seed_everything(42)
run_model(run_type, options.plugins, options.verbose)
jerome-habana marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions examples/pytorch/ops_bf16_mnist.txt
@@ -1,2 +1,3 @@
linear
relu
mm
1 change: 1 addition & 0 deletions examples/pytorch/ops_fp32_mnist.txt
@@ -1 +1,2 @@
cross_entropy
tan