[AutoParallel] balancing the calculation of global_norm in data parallel #49510

Merged
157 changes: 119 additions & 38 deletions python/paddle/distributed/passes/auto_parallel_grad_clip.py
@@ -20,8 +20,14 @@
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
from ..auto_parallel.operators.common import SyncMode
from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.operators.common import (
SyncMode,
is_data_parallel_reduce_op,
)
from ..auto_parallel.process_group import (
get_all_process_groups,
get_world_process_group,
)
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import (
@@ -31,6 +37,7 @@
is_optimize_op,
use_standalone_executor,
)
from .auto_parallel_sharding import ShardingPass
from .pass_base import PassBase, register_pass


@@ -145,46 +152,65 @@ def _is_about_global_norm(


class ClipHelper:
def __init__(self, params_grads, rank_id, block, dist_context):
def __init__(
self, params_grads, rank_id, block, dist_context, pass_context
):
params, _ = zip(*params_grads)
self.params = list(params)
self.params_name = [p.name for p in self.params]
self.rank_id = rank_id
self.block = block
self.dist_context = dist_context
self.pass_context = pass_context
self.sharding_group = None
self.world_ranks = get_world_process_group().ranks
if hasattr(dist_context, '_sharding_group'):
self.sharding_group = dist_context._sharding_group

def _is_calcuate_norm(self, name):
if not self._is_local_param(name):
return False, []
self.world_nranks = len(self.world_ranks)
self.pure_data_parallel = self._is_pure_data_parallel()
self.rank_to_params = self._partition_parameters(params)

param = self.params[self.params_name.index(name)]
dist_attr = self._get_dist_attr(name)
topology = dist_attr.process_mesh.shape
processes = dist_attr.process_mesh.process_ids
dims_mapping = dist_attr.dims_mapping
return _is_about_global_norm(
self.rank_id,
param.shape,
topology,
processes,
dims_mapping,
self.sharding_group,
)
def is_calcuate_norm(self, name):
"""
whether param_name@GRAD participates in the calculation of global_norm
"""
if not self.is_local_param(name):
return False

def _get_dist_attr(self, name):
var = self.block.vars[name]
return self.dist_context.get_tensor_dist_attr_for_program(var)
param = self.params[self.params_name.index(name)]
if not self.pure_data_parallel:
dist_attr = self._get_dist_attr(name)
topology = dist_attr.process_mesh.shape
processes = dist_attr.process_mesh.process_ids
dims_mapping = dist_attr.dims_mapping
return _is_about_global_norm(
self.rank_id,
param.shape,
topology,
processes,
dims_mapping,
self.sharding_group,
)
else:
return param.name in self.rank_to_params[self.rank_id]

def _is_local_param(self, name):
def is_local_param(self, name):
"""
whether param_name is updated by the optimizer in cur_rank
"""
if name not in self.params_name:
return False
return True

def _is_local_var(self, name):
def _get_dist_attr(self, name):
var = self.block.vars[name]
return self.dist_context.get_tensor_dist_attr_for_program(var)

def is_local_var_with_dist_attr(self, name):
"""
whether var_name belongs to cur_rank
"""
dist_attr = self._get_dist_attr(name)
assert dist_attr is not None
return self.rank_id in dist_attr.process_mesh.process_ids
Expand Down Expand Up @@ -212,6 +238,50 @@ def _init_dist_attr(self, op):
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr)

def _is_pure_data_parallel(self):
for applied_pass in self.pass_context.passes:
if isinstance(applied_pass, ShardingPass):
return False

groups = get_all_process_groups()
for g in groups:
if g.nranks != self.world_nranks:
return False

for op in self.block.ops:
if op.type in [
"c_reduce_sum",
"c_allreduce_sum",
] and not is_data_parallel_reduce_op(op):
return False

return True

def _partition_parameters(self, params):
"""
build rank_to_params by each param's numel
to distribute params across the ranks of dp_group as evenly as possible.
"""
mapping = {}
if not self.pure_data_parallel:
for rank_ in range(self.world_nranks):
mapping[rank_] = [p.name for p in params]
else:
for rank_ in range(self.world_nranks):
mapping[rank_] = []
sizes = [0] * self.world_nranks
for param in params:
rank = sizes.index(min(sizes))
mapping[rank].append(param.name)
numel = reduce(lambda x, y: x * y, param.shape)
assert (
numel > 0
), "param [{}] should larger than 0, but it is [{}]".format(
param.name, numel
)
sizes[rank] += numel
return mapping


@register_pass("auto_parallel_grad_clip")
class ClipGradByGloblNormPass(PassBase):
@@ -248,14 +318,13 @@ def _apply_single_impl(self, main_program, startup_program, context):
# dist_params_grads = _get_params_grads(block)

self.clip_helper = ClipHelper(
dist_params_grads, rank_id, block, dist_context
dist_params_grads, rank_id, block, dist_context, context
)
self._remove_no_need_ops_vars(block)

def _remove_no_need_ops_vars(self, block):

removed_op_out_type = [
'clip_by_norm',
'squared_l2_norm',
'square',
'reduce_sum',
@@ -267,31 +336,40 @@ def _remove_no_need_ops_vars(self, block):
if not is_gradient_clip_op(op):
continue

if op.type in removed_op_out_type:
if op.type == 'clip_by_norm':
# remove the 'clip_by_norm' op if the param is not updated by the optimizer in the current rank
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
# 'clip_by_norm', 'squared_l2_norm', 'square'
param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name)
is_calculate = self.clip_helper._is_calcuate_norm(
param_name
)
if not is_local or (
not is_calculate and op.type != 'clip_by_norm'
):
is_local = self.clip_helper.is_local_param(param_name)
if not is_local:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))

elif op.type in removed_op_out_type:
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
# remove 'squared_l2_norm' and 'square' ops,
# if the param@GRAD in cur_rank does not participate in the calculation of global_norm
param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper.is_local_param(param_name)
is_calculate = self.clip_helper.is_calcuate_norm(param_name)
if not is_local or not is_calculate:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))
else:
# 'reduce_sum'
# 'reduce_sum' must come immediately after 'square'
if idx - 1 in removed_op_idx:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))

elif op.type == 'elementwise_mul':
# 'elementwise_mul' scales the param@GRAD with global_norm
# remove the 'elementwise_mul' op if the param is not updated by the optimizer in the current rank
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name)
is_local = self.clip_helper.is_local_param(param_name)
if not is_local:
removed_op_idx.add(idx)
if block.ops[idx - 1].type == 'cast':
@@ -301,11 +379,14 @@
)

elif op.type == 'sum':
# the 'sum' op is used to calculate global_norm and needs to filter out inputs that are not in cur_rank
reserved_vars = []
for input_name in op.input_arg_names:
if (
input_name not in removed_tmp_var
and self.clip_helper._is_local_var(input_name)
and self.clip_helper.is_local_var_with_dist_attr(
input_name
)
):
reserved_vars.append(input_name)
if not reserved_vars:
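Note on the partitioning logic above: `_partition_parameters` does greedy bucketing by parameter numel, assigning each parameter to the rank whose bucket is currently smallest, so that every data-parallel rank computes `squared_l2_norm` over a similar amount of data. A minimal standalone sketch of the same idea (the parameter names, shapes, and rank count below are illustrative only, not taken from this PR):

```python
from functools import reduce


def partition_parameters(param_shapes, nranks):
    """Assign each parameter to the rank whose bucket currently holds the fewest elements."""
    mapping = {rank: [] for rank in range(nranks)}
    sizes = [0] * nranks
    for name, shape in param_shapes.items():
        numel = reduce(lambda x, y: x * y, shape)
        assert numel > 0, f"numel of param [{name}] should be larger than 0, but it is [{numel}]"
        rank = sizes.index(min(sizes))  # greedily pick the least-loaded rank
        mapping[rank].append(name)
        sizes[rank] += numel
    return mapping


if __name__ == "__main__":
    # Hypothetical parameter shapes; each rank then computes squared_l2_norm
    # only for the parameters in its own bucket.
    shapes = {
        "embedding.w": [1024, 512],
        "fc1.w": [512, 512],
        "fc1.b": [512],
        "fc2.w": [512, 128],
    }
    print(partition_parameters(shapes, nranks=2))
```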
@@ -31,6 +31,7 @@ def apply_pass(use_sharding=False):
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = 2
return strategy
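For intuition, the effect of the pruning in the grad-clip pass can be checked with a small simulation: if each rank squares and sums only the gradients assigned to it, and the partial sums are then added across ranks (an allreduce in the real pass), the result equals the global norm computed over all gradients. A rough numpy sketch with invented gradient shapes and a hypothetical two-rank partition (the actual pass operates on static-graph ops and a collective allreduce, not Python sums):

```python
import numpy as np

# Made-up gradients standing in for param@GRAD tensors.
grads = {
    "embedding.w@GRAD": np.random.rand(1024, 512).astype("float32"),
    "fc1.w@GRAD": np.random.rand(512, 512).astype("float32"),
    "fc1.b@GRAD": np.random.rand(512).astype("float32"),
    "fc2.w@GRAD": np.random.rand(512, 128).astype("float32"),
}

# Hypothetical output of the numel-balanced partitioning for two ranks.
rank_to_params = {
    0: ["embedding.w@GRAD", "fc1.b@GRAD"],
    1: ["fc1.w@GRAD", "fc2.w@GRAD"],
}

# Each rank keeps only the squared_l2_norm work for its own bucket ...
partial_sums = [
    sum(float((grads[name] ** 2).sum()) for name in names)
    for names in rank_to_params.values()
]
# ... and adding the partial sums across ranks recovers the global norm.
balanced_global_norm = np.sqrt(sum(partial_sums))

reference_global_norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads.values()))
assert np.isclose(balanced_global_norm, reference_global_norm)
print(balanced_global_norm, reference_global_norm)
```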