Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bfloat16 support to Lightning Trainer #9049

Merged
merged 15 commits into from
Aug 24, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))


- Added bfloat16 support for Lightning Trainer ([#9049](https://github.com/PyTorchLightning/pytorch-lightning/pull/9049))


- Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047))


Expand Down
61 changes: 49 additions & 12 deletions pytorch_lightning/plugins/precision/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,60 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator
from typing import Any, Callable, Dict, Generator, Union

import torch
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Plugin for native mixed precision training with :mod:`torch.cuda.amp`."""
"""
Plugin for native mixed precision training with :mod:`torch.cuda.amp`.

Args:
precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16).
"""

def __init__(self) -> None:
def __init__(self, precision: Union[int, str] = 16) -> None:
super().__init__()

if not _NATIVE_AMP_AVAILABLE:
raise MisconfigurationException(
"You have asked for native AMP but your PyTorch version does not support it."
" Consider upgrading with `pip install torch>=1.6`."
)

self._fast_dtype = self._select_precision_dtype(precision)
self.backend = AMPType.NATIVE
self.scaler = torch.cuda.amp.GradScaler()
if not self.is_bfloat16:
self.scaler = torch.cuda.amp.GradScaler()

def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
if precision == "bf16":
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
if not _TORCH_GREATER_EQUAL_1_10:
raise MisconfigurationException(
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
return torch.bfloat16
return torch.float16

@property
def is_bfloat16(self) -> bool:
return self._fast_dtype == torch.bfloat16
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved

def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor:
if self.is_bfloat16:
warning_cache.warn(
"Skipping torch.cuda.amp.GradScaler in NativeMixedPrecisionPlugin as torch.bfloat16 is used."
)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
return super().pre_backward(model, closure_loss)
closure_loss = self.scaler.scale(closure_loss)
return super().pre_backward(model, closure_loss)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -49,6 +77,9 @@ def pre_optimizer_step(
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
if self.is_bfloat16:
# skip scaler logic, as bfloat16 does not require scaler
return super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
Expand All @@ -65,33 +96,39 @@ def pre_optimizer_step(
self.scaler.update()
return False

def autocast_context_manager(self) -> torch.cuda.amp.autocast:
if self.is_bfloat16:
return torch.cuda.amp.autocast(fast_dtype=self._fast_dtype)
return torch.cuda.amp.autocast()

@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
with self.autocast_context_manager():
yield

@contextmanager
def val_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
with self.autocast_context_manager():
yield

@contextmanager
def test_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
with self.autocast_context_manager():
yield

@contextmanager
def predict_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
with self.autocast_context_manager():
yield

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if "native_amp_scaling_state" in checkpoint:
if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16:
self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
if not self.is_bfloat16:
checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/sharded_native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Mixed Precision for Sharded Training"""

def __init__(self) -> None:
super().__init__()
def __init__(self, precision: Union[int, str] = 16) -> None:
super().__init__(precision)
self.scaler = ShardedGradScaler()

def clip_grad_by_norm(
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return PrecisionPlugin()
if self.precision == 64:
return DoublePrecisionPlugin()
if self.precision == 16:
if self.precision in (16, "bf16"):
if self.use_tpu:
return TPUHalfPrecisionPlugin()

Expand All @@ -581,12 +581,12 @@ def select_precision_plugin(self) -> PrecisionPlugin:
else:
raise MisconfigurationException(msg)
else:
log.info("Using native 16bit precision.")
log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
if self._is_sharded_training_type:
return ShardedNativeMixedPrecisionPlugin()
return ShardedNativeMixedPrecisionPlugin(self.precision)
if self._is_fully_sharded_training_type:
return FullyShardedNativeMixedPrecisionPlugin()
return NativeMixedPrecisionPlugin()
return FullyShardedNativeMixedPrecisionPlugin(self.precision)
return NativeMixedPrecisionPlugin(self.precision)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved

if self.amp_type == AMPType.APEX:
if not _APEX_AVAILABLE:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
log_every_n_steps: int = 50,
accelerator: Optional[Union[str, Accelerator]] = None,
sync_batchnorm: bool = False,
precision: int = 32,
precision: Union[int, str] = 32,
weights_summary: Optional[str] = "top",
weights_save_path: Optional[str] = None,
num_sanity_val_steps: int = 2,
Expand Down Expand Up @@ -260,8 +260,8 @@ def __init__(

plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or
TPUs.
precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
Can be used on CPU, GPU or TPUs.

max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
_TORCH_GREATER_EQUAL_1_10,
_TORCH_QUANTIZE_AVAILABLE,
_TORCH_SHARDED_TENSOR_AVAILABLE,
_TORCHTEXT_AVAILABLE,
Expand Down
11 changes: 11 additions & 0 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ def add_argparse_args(
if arg == "track_grad_norm":
use_type = float

# hack for precision
if arg == "precision":
use_type = _precision_allowed_type

parser.add_argument(
f"--{arg}", dest=arg, default=arg_default, type=use_type, help=args_help.get(arg), **arg_kwargs
)
Expand Down Expand Up @@ -302,3 +306,10 @@ def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
if "." in str(x):
return float(x)
return int(x)


def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
try:
return int(x)
except ValueError:
return x
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
_TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0dev")
Borda marked this conversation as resolved.
Show resolved Hide resolved


_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available("pl_bolts")
Expand Down
77 changes: 27 additions & 50 deletions tests/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
Expand All @@ -31,7 +32,8 @@ class AMPTestModel(BoringModel):
def _step(self, batch, batch_idx):
    """Run a forward pass under autocast and return the loss.

    Asserts that autocast is enabled and that the forward output has the dtype
    matching the precision plugin: ``torch.bfloat16`` when the plugin reports
    ``is_bfloat16``, ``torch.float16`` otherwise.
    """
    assert torch.is_autocast_enabled()
    output = self(batch)
    # Pick the expected dtype first. Writing
    # `assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16`
    # parses as `assert ((output.dtype == torch.float16) if ... else torch.bfloat16)`,
    # which asserts a truthy dtype object on the bf16 branch and never fails.
    is_bf16 = self.trainer.precision_plugin.is_bfloat16
    expected_dtype = torch.bfloat16 if is_bf16 else torch.float16
    assert output.dtype == expected_dtype
    loss = self.loss(batch, output)
    return loss

Expand All @@ -50,17 +52,35 @@ def test_step(self, batch, batch_idx):
def predict(self, batch, batch_idx, dataloader_idx=None):
    """Run a forward pass under autocast and return the raw output.

    Asserts that autocast is enabled and that the output dtype matches the
    precision plugin: ``torch.bfloat16`` when the plugin reports
    ``is_bfloat16``, ``torch.float16`` otherwise.
    """
    assert torch.is_autocast_enabled()
    output = self(batch)
    # Compare against an explicitly chosen dtype. The one-line form
    # `assert a == b if cond else c` binds as `assert ((a == b) if cond else c)`
    # and silently passes on the bf16 branch.
    is_bf16 = self.trainer.precision_plugin.is_bfloat16
    expected_dtype = torch.bfloat16 if is_bf16 else torch.float16
    assert output.dtype == expected_dtype
    return output


@pytest.mark.skip(reason="dp + amp not supported currently") # TODO
@RunIf(min_gpus=1)
def test_amp_single_gpu_dp(tmpdir):
"""Make sure DP/DDP + AMP work."""
@RunIf(min_gpus=2)
@pytest.mark.parametrize(
"accelerator",
[
pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")), # TODO
"ddp_spawn",
],
)
@pytest.mark.parametrize(
"precision",
[
16,
pytest.param(
"bf16",
marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"),
),
],
)
@pytest.mark.parametrize("gpus", [1, 2])
def test_amp_gpus(tmpdir, accelerator, precision, gpus):
"""Make sure combinations of AMP and training types work if supported."""
tutils.reset_seed()

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=1, accelerator="dp", precision=16)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=gpus, accelerator=accelerator, precision=precision)

model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
Expand All @@ -71,49 +91,6 @@ def test_amp_single_gpu_dp(tmpdir):
assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(min_gpus=1)
def test_amp_single_gpu_ddp_spawn(tmpdir):
"""Make sure DP/DDP + AMP work."""
tutils.reset_seed()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=1, accelerator="ddp_spawn", precision=16)

model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
assert trainer.state.finished, f"Training failed with {trainer.state}"


@pytest.mark.skip(reason="dp + amp not supported currently") # TODO
@RunIf(min_gpus=1)
def test_amp_multi_gpu_dp(tmpdir):
"""Make sure DP/DDP + AMP work."""
tutils.reset_seed()

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="dp", precision=16)

model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(min_gpus=2)
def test_amp_multi_gpu_ddp_spawn(tmpdir):
"""Make sure DP/DDP + AMP work."""
tutils.reset_seed()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="ddp_spawn", precision=16)

model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(min_gpus=2)
@mock.patch.dict(
os.environ,
Expand Down
31 changes: 31 additions & 0 deletions tests/plugins/test_amp_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -69,6 +70,8 @@ def test_amp_apex_ddp(
plugins=[plugin_cls()] if custom_plugin else None,
)
assert isinstance(trainer.precision_plugin, plugin_cls)
if amp == "native":
assert not trainer.precision_plugin.is_bfloat16


class GradientUnscaleBoringModel(BoringModel):
Expand Down Expand Up @@ -174,3 +177,31 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
model = BoringModel()
trainer.fit(model)


@RunIf(min_gpus=1, amp_native=True, min_torch="1.10.0dev")
Borda marked this conversation as resolved.
Show resolved Hide resolved
def test_amp_precision_bfloat_warning(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
precision="bf16",
gpus=1,
)
with pytest.warns(
UserWarning, match="Skipping torch.cuda.amp.GradScaler in NativeMixedPrecisionPlugin as torch.bfloat16 is used."
):
trainer.fit(model)


@RunIf(min_gpus=1, amp_native=True, max_torch="1.9")
def test_amp_precision_16_bfloat_throws_error(tmpdir):
with pytest.raises(
MisconfigurationException,
match="To use bfloat16 with native amp you must install torch greater or equal to 1.10",
):
Trainer(
default_root_dir=tmpdir,
precision="bf16",
gpus=1,
)
Loading