
fix sharding_stage1 amp O2 decorate bug (#48960)
sneaxiy committed Dec 10, 2022
1 parent fd37357 commit c40122d
Showing 1 changed file with 36 additions and 16 deletions.
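Context for the fix: under hybrid parallelism with sharding stage 1, the optimizer passed to paddle.amp.decorate(..., level='O2') is a DygraphShardingOptimizer wrapper, which the pure-fp16 type check rejected. A minimal sketch of the triggering call path, assuming a multi-card job started with paddle.distributed.launch; the model, parallel degrees, and optimizer below are illustrative assumptions, not part of the commit:

import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,  # sharding stage 1: optimizer states split across ranks
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(64, 64)
opt = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())
# With sharding enabled, fleet wraps the optimizer into a DygraphShardingOptimizer.
opt = fleet.distributed_optimizer(opt)

# Before this commit, the O2 decorate call below rejected the wrapped optimizer with
# "optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer ...".
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')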
52 changes: 36 additions & 16 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -252,14 +252,26 @@ def check_models(models):
             )
 
 
+def _is_valid_optimizer(optimizer):
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+
+    return isinstance(
+        optimizer,
+        (
+            paddle.optimizer.Optimizer,
+            paddle.fluid.optimizer.Optimizer,
+            DygraphShardingOptimizer,
+        ),
+    )
+
+
 def check_optimizers(optimizers):
     for optimizer in optimizers:
-        if not isinstance(
-            optimizer,
-            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
-        ):
+        if not _is_valid_optimizer(optimizer):
             raise RuntimeError(
-                "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
+                "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format(
                     type(optimizer)
                 )
             )
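With the new helper, check_optimizers accepts a sharding-wrapped optimizer as well as plain ones. A rough sketch of the behavior; the instances are assumptions, where sharded_opt stands for an optimizer returned by fleet.distributed_optimizer under a sharding strategy:

adam = paddle.optimizer.Adam(parameters=model.parameters())

check_optimizers([adam])          # accepted, as before
check_optimizers([sharded_opt])   # DygraphShardingOptimizer: accepted after this commit
check_optimizers(["not an opt"])  # still raises RuntimeError with the updated message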
@@ -477,6 +489,20 @@ def __call__(self, state_dict):
                 state_dict[key] = param_applied
 
 
+def _set_multi_precision(optimizer, multi_precision):
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+
+    optimizer = (
+        optimizer._inner_optimizer
+        if isinstance(optimizer, DygraphShardingOptimizer)
+        else optimizer
+    )
+    if hasattr(optimizer, "_multi_precision"):
+        optimizer._multi_precision = multi_precision
+
+
 @dygraph_only
 def amp_decorate(
     models,
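_set_multi_precision routes the master-weight flag to the optimizer that actually holds the parameter states: for a DygraphShardingOptimizer it unwraps _inner_optimizer before setting _multi_precision, and it silently skips optimizers that do not expose the attribute. A rough illustration; the wrapper construction is elided, assuming sharded._inner_optimizer is inner:

inner = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())

_set_multi_precision(inner, True)      # plain optimizer: flag set directly
assert inner._multi_precision is True

_set_multi_precision(sharded, False)   # sharding wrapper: flag lands on _inner_optimizer
assert inner._multi_precision is False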
@@ -582,10 +608,7 @@ def amp_decorate(
     if optimizers is not None:
         # check optimizers
         optimizers_is_list = False
-        if isinstance(
-            optimizers,
-            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
-        ):
+        if _is_valid_optimizer(optimizers):
             optimizers_is_list = False
             optimizers = [optimizers]
             check_optimizers(optimizers)
@@ -596,13 +619,10 @@
             raise TypeError(
                 "optimizers must be either a single optimizer or a list of optimizers."
             )
-        # supprot master_weight
-        for idx_opt in range(len(optimizers)):
-            if hasattr(optimizers[idx_opt], '_multi_precision'):
-                if master_weight is False:
-                    optimizers[idx_opt]._multi_precision = False
-                else:
-                    optimizers[idx_opt]._multi_precision = True
+        # support master_weight
+        use_multi_precision = not (master_weight is False)
+        for opt in optimizers:
+            _set_multi_precision(opt, use_multi_precision)
 
     if save_dtype is not None:
         if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
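The rewritten loop keeps the old master_weight semantics: only an explicit master_weight=False disables FP32 master weights, while the default None and True both enable them. A quick check of the expression used above:

for mw in (None, True, False):
    print(mw, "->", not (mw is False))
# None -> True
# True -> True
# False -> False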
