[FSDP] Re-support model dtype change after FSDP init
ghstack-source-id: 3815e5ae8eac082490112724bbc3e847161b4397
Pull Request resolved: pytorch#91192
awgu committed Dec 20, 2022
1 parent 37cb3b1 commit bd4d4b8
Showing 2 changed files with 132 additions and 11 deletions.
112 changes: 105 additions & 7 deletions test/distributed/fsdp/test_fsdp_pure_fp16.py
@@ -2,8 +2,14 @@

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import (
CPUOffload,
FullyShardedDataParallel as FSDP,
MixedPrecision,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
CUDAInitMode,
@@ -13,7 +19,6 @@
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
@@ -37,13 +42,20 @@ def world_size(self):
return min(4, super().world_size)

@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
)
def test_pure_fp16(self, cpu_offload: CPUOffload):
def test_pure_fp16_training(self):
"""Tests pure FP16 training, including when the parameter's dtype is
changed after FSDP initialization and before training."""
self.run_subtests(
{
"cpu_offload": [
CPUOffload(offload_params=True),
CPUOffload(offload_params=False),
]
},
self._test_pure_fp16_training,
)

def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
self._test_fsdp_parity(
NestedWrappedModule,
FSDPInitMode.RECURSIVE,
@@ -54,6 +66,92 @@ def test_pure_fp16(self, cpu_offload: CPUOffload):
use_pure_fp16=True,
)

@skip_if_lt_x_gpu(2)
def test_fp16_dtypes(self):
"""
Tests that both user-facing parameter/gradient dtypes and internal
saved dtype attributes are as expected when using an FP16 model
possibly with explicit mixed precision enabled.
"""
self.run_subtests(
{
"to_half_before_fsdp_init": [False, True],
"use_orig_params": [False, True],
"mixed_precision": [
MixedPrecision(),
MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float32,
),
MixedPrecision(
param_dtype=torch.float32,
),
],
},
self._test_fp16_dtypes,
)

def _test_fp16_dtypes(
self,
to_half_before_fsdp_init: bool,
use_orig_params: bool,
mixed_precision: MixedPrecision,
):
model = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.NO_FSDP,
CUDAInitMode.CUDA_NEVER,
{},
)
fsdp_kwargs = {
"use_orig_params": use_orig_params,
"device_id": torch.cuda.current_device(),
"mixed_precision": mixed_precision,
}
if to_half_before_fsdp_init:
model = model.half()
fsdp_model = FSDP(model, **fsdp_kwargs)
if not to_half_before_fsdp_init:
fsdp_model = fsdp_model.half()
for param in fsdp_model.parameters():
self.assertEqual(param.dtype, torch.float16)
inp = tuple(
t.half() if torch.is_tensor(t) else t
for t in fsdp_model.module.get_input(torch.device("cuda"))
)
out = fsdp_model(*inp)
out.sum().backward()

# Check handle dtype attributes
for handle in traversal_utils._get_fsdp_handles(fsdp_model):
self.assertEqual(handle.flat_param.dtype, torch.float16)
self.assertEqual(handle.flat_param.grad.dtype, torch.float16)
self.assertEqual(handle._orig_param_dtype, torch.float16)
# Specifying `mixed_precision` takes precedence over the model
# dtype for both `param_dtype` and `reduce_dtype`
if mixed_precision.param_dtype is not None:
self.assertEqual(
handle._fwd_bwd_param_dtype, mixed_precision.param_dtype
)
else:
self.assertEqual(handle._fwd_bwd_param_dtype, torch.float16)
if mixed_precision.reduce_dtype is not None:
self.assertEqual(handle._reduce_dtype, mixed_precision.reduce_dtype)
elif (
mixed_precision.reduce_dtype is None
and mixed_precision.param_dtype is not None
):
# Special case: infer reduce dtype from parameter dtype
self.assertEqual(handle._reduce_dtype, mixed_precision.param_dtype)
else:
self.assertEqual(handle._reduce_dtype, torch.float16)

# Check parameter/gradient dtypes
for param in fsdp_model.parameters():
self.assertEqual(param.dtype, torch.float16)
if param.grad is not None:
self.assertEqual(param.grad.dtype, torch.float16)


instantiate_parametrized_tests(TestPureFP16)

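For orientation, the pattern this commit re-supports and the new test exercises is converting the model's dtype after wrapping with FSDP. A minimal sketch, assuming a hypothetical toy nn.Linear model and that the default process group and CUDA device are already set up (none of this is part of the diff):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def fp16_step_after_fsdp_init() -> None:
    # Hypothetical toy model; assumes torch.distributed is already initialized.
    model = nn.Linear(8, 8).cuda()
    fsdp_model = FSDP(model, device_id=torch.cuda.current_device())
    # Change the parameter dtype *after* FSDP init; the saved dtype
    # attributes are refreshed lazily before the first forward.
    fsdp_model = fsdp_model.half()
    inp = torch.randn(4, 8, device="cuda", dtype=torch.float16)
    fsdp_model(inp).sum().backward()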
31 changes: 27 additions & 4 deletions torch/distributed/fsdp/flat_param.py
@@ -522,9 +522,14 @@ def _init_param_reduce_dtypes(
is ``None``, in which case we assume the gradient reduction dtype
matches the forward/backward parameter dtype.
"""
low_prec_param_dtype_specified = mp_param_dtype is not None
low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
if low_prec_param_dtype_specified and not low_prec_reduce_dtype_specified:
# Save whether these dtypes were specified so that we permit the
# parameter dtype to change up until the lazy initialization
self._low_prec_param_dtype_specified = mp_param_dtype is not None
self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
if (
self._low_prec_param_dtype_specified
and not self._low_prec_reduce_dtype_specified
):
# Special case: infer gradient reduction mixed precision
self._fwd_bwd_param_dtype = mp_param_dtype
self._reduce_dtype = self._fwd_bwd_param_dtype
@@ -770,6 +775,24 @@ def init_flat_param_attributes(self) -> None:
reshard methods in this class for the allocation and free pattern.
"""
flat_param = self.flat_param
if flat_param.dtype != self._orig_param_dtype:
# Entering this branch means that the user changed the parameter
# dtype after FSDP initialization, in which case we may need to
# refresh some saved dtype attributes (dtypes specified as a part
# of mixed precision take precedence).
if not self._low_prec_param_dtype_specified:
self._fwd_bwd_param_dtype = flat_param.dtype
# Only refresh `reduce_dtype` if `param_dtype` was also unspecified,
# since otherwise `reduce_dtype` is inferred from the specified `param_dtype`
if (
not self._low_prec_reduce_dtype_specified
and not self._low_prec_param_dtype_specified
):
self._reduce_dtype = flat_param.dtype
self._orig_param_dtype = flat_param.dtype
# Delete since they are no longer needed
delattr(self, "_low_prec_param_dtype_specified")
delattr(self, "_low_prec_reduce_dtype_specified")
cpu_device = torch.device("cpu")
if self._offload_params:
p_assert(
@@ -1552,7 +1575,7 @@ def _use_sharded_views(self) -> None:
# Allow the original data to be freed via garbage collection
param.data = torch.empty(
0,
dtype=param.dtype,
dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
device=self.flat_param.device,
requires_grad=False,
)
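To make the precedence rules in `_init_param_reduce_dtypes` and `init_flat_param_attributes` above easier to scan, here is a standalone sketch (a hypothetical helper, not FSDP API) of how the forward/backward parameter dtype and the gradient reduction dtype resolve once the flat parameter's dtype may have changed after init:

from typing import Optional, Tuple

import torch

def resolve_dtypes(
    flat_param_dtype: torch.dtype,
    mp_param_dtype: Optional[torch.dtype],
    mp_reduce_dtype: Optional[torch.dtype],
) -> Tuple[torch.dtype, torch.dtype]:
    # Explicitly specified mixed precision dtypes win; otherwise fall back
    # to the (possibly changed) flat parameter dtype. An unspecified reduce
    # dtype is inferred from a specified param dtype.
    fwd_bwd_dtype = (
        mp_param_dtype if mp_param_dtype is not None else flat_param_dtype
    )
    if mp_reduce_dtype is not None:
        reduce_dtype = mp_reduce_dtype
    elif mp_param_dtype is not None:
        reduce_dtype = mp_param_dtype  # special case: infer from param dtype
    else:
        reduce_dtype = flat_param_dtype
    return fwd_bwd_dtype, reduce_dtype

# E.g., model converted to FP16 after init with MixedPrecision(param_dtype=float32):
# resolve_dtypes(torch.float16, torch.float32, None) -> (torch.float32, torch.float32)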
