diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 2042a313c41e6..58460fcf9064b 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -82,6 +82,7 @@ message PpConfig { optional bool sharding_comm_overlap = 4 [ default = false ]; optional bool profiling = 5 [ default = false ]; optional bool release_gradients = 6 [ default = false ]; + optional bool overlap_p2p_comm = 7 [default = false]; } message DygraphShardingConfig { diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 384d89b4d9c12..e5233c87a199b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -216,6 +216,12 @@ def __init__(self, layers, hcg, strategy): "sharding_configs" ].split_param + self._overlap_p2p_comm = self._strategy.hybrid_configs[ + "pp_configs" + ].overlap_p2p_comm + + self._batch_p2p_comm = not self._overlap_p2p_comm + logger.info( f"dp_comm_overlap {self._dp_comm_overlap}; \ sharding_comm_overlap {self._sharding_comm_overlap}; \ @@ -1229,12 +1235,21 @@ def _process_bwd_buffer(step_id, tensor): if not static_scheduler: self.input_tensors[0].append( self._p2p_helper.recv_forward( - self.is_pipeline_first_stage(), sync_recv=False + self.is_pipeline_first_stage(), + sync_recv=False, + batch_p2p_comm=self._batch_p2p_comm, ) ) + fwd_wait_handles = None + bwd_wait_handles = None + # run startup steps for micro_step in range(startup_steps): + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + if static_scheduler: virtual_pp_rank = self._get_virtual_pp_rank( micro_step, forward=True @@ -1270,39 +1285,77 @@ def _process_bwd_buffer(step_id, tensor): if self.is_pipeline_last_stage(ignore_virtual=True): output_tensor = _process_fwd_buffer(micro_step, output_tensor) - # prepare for the first steady step - if ( - micro_step == (startup_steps - 1) - and (not forward_only) - and steady_steps - ): - input_tensor_grad = None - recv_next = True - if self.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False + if not self._overlap_p2p_comm: + # prepare for the first steady step + if ( + micro_step == (startup_steps - 1) + and (not forward_only) + and steady_steps + ): + input_tensor_grad = None + recv_next = True + if self.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False - # the last startup step needs on four direction comm to set up for steady 1f1b + # the last startup step needs on four direction comm to set up for steady 1f1b + ( + input_tensor, + output_tensor_grad, + ) = self._p2p_helper.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + batch_p2p_comm=self._batch_p2p_comm, + ) + # output_tensor_grad is not none if recv_next + # append output_tensor_grad no matter none or not + self.output_tensor_grads[self.num_model_chunks - 1].append( + output_tensor_grad + ) + else: + input_tensor = self._p2p_helper.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + batch_p2p_comm=self._batch_p2p_comm, + ) + # append input_tensor no matter none or not + self.input_tensors[next_virtual_pp_rank].append(input_tensor) + else: ( input_tensor, - output_tensor_grad, - ) = self._p2p_helper.send_forward_backward_recv_forward_backward( + fwd_wait_handles, + ) = 
self._p2p_helper.send_forward_recv_forward( output_tensor, - input_tensor_grad, recv_prev=recv_prev, - recv_next=recv_next, + batch_p2p_comm=self._batch_p2p_comm, + overlap_p2p_comm=True, ) - # output_tensor_grad is not none if recv_next - # append output_tensor_grad no matter none or not - self.output_tensor_grads[self.num_model_chunks - 1].append( - output_tensor_grad - ) - else: - input_tensor = self._p2p_helper.send_forward_recv_forward( - output_tensor, recv_prev=recv_prev - ) - # append input_tensor no matter none or not - self.input_tensors[next_virtual_pp_rank].append(input_tensor) + if ( + micro_step == (startup_steps - 1) + and (not forward_only) + and steady_steps + ): + input_tensor_grad = None + recv_next = True + if self.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + ( + output_tensor_grad, + bwd_wait_handles, + ) = self._p2p_helper.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + batch_p2p_comm=self._batch_p2p_comm, + overlap_p2p_comm=True, + ) + self.output_tensor_grads[self.num_model_chunks - 1].append( + output_tensor_grad + ) + + # append input_tensor no matter none or not + self.input_tensors[next_virtual_pp_rank].append(input_tensor) self._release_output(output_tensor) # run 1f1b steady steps @@ -1339,85 +1392,186 @@ def _process_bwd_buffer(step_id, tensor): continue # forward forward_micro_step_id = micro_step + startup_steps - self._record_stamp("F", forward_micro_step_id, '"B"', forward=True) - output_tensor = self._forward_step_helper( - micro_dataset, forward_micro_step_id - ) - self._record_stamp("F", forward_micro_step_id, '"E"', forward=True) - # backward - backward_micro_step_id = micro_step - self._record_stamp( - "B", backward_micro_step_id, '"B"', forward=False - ) - input_tensor_grad = self._backward_step_helper( - backward_micro_step_id - ) - self._record_stamp( - "B", backward_micro_step_id, '"E"', forward=False - ) + if self._overlap_p2p_comm: + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() - # four directions comm - # send output tensor to downstream - # send input tensor grad to upstream - # recv input tensor from upstream - # recv output tensor grad from downstream + self._release_output(output_tensor) + output_tensor = self._forward_step_helper( + micro_dataset, forward_micro_step_id + ) - # last stage doesn't send rst to downstream - forward_virtual_pp_rank = self._get_virtual_pp_rank( - forward_micro_step_id, forward=True - ) - self.set_virtual_pipeline_rank(forward_virtual_pp_rank) - if self.is_pipeline_last_stage(ignore_virtual=True): - output_tensor = _process_fwd_buffer( - forward_micro_step_id, output_tensor + forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id, forward=True ) + self.set_virtual_pipeline_rank(forward_virtual_pp_rank) + if self.is_pipeline_last_stage(ignore_virtual=True): + output_tensor = _process_fwd_buffer( + forward_micro_step_id, output_tensor + ) - # first stage doesn't send grad to upstream - backward_virtual_pp_rank = self._get_virtual_pp_rank( - backward_micro_step_id, forward=False - ) - self.set_virtual_pipeline_rank(backward_virtual_pp_rank) - if self.is_pipeline_first_stage(ignore_virtual=True): - input_tensor_grad = _process_bwd_buffer( - backward_micro_step_id, input_tensor_grad + # determine whether to recv input tensor from upstream + recv_prev = True + if self.is_pipeline_first_stage(ignore_virtual=True): + next_forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id + 1, forward=True + ) 
+ if next_forward_virtual_pp_rank == 0: + # next chunk is the first chunk, not need to pre recv an input tensor + recv_prev = False + else: + next_forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id + 1, forward=True + ) + + # last iteration doesn't need recv from upstream + if micro_step == (steady_steps - 1): + recv_prev = False + + # Send activation tensor to the next stage and receive activation tensor from the + # previous stage + ( + input_tensor, + fwd_wait_handles, + ) = self._p2p_helper.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + batch_p2p_comm=self._batch_p2p_comm, + overlap_p2p_comm=True, ) - # determine whether to recv input tensor from upstream - recv_prev = True - next_forward_virtual_pp_rank = self._get_virtual_pp_rank( - forward_micro_step_id + 1, forward=True - ) - if self.is_pipeline_first_stage(ignore_virtual=True) and ( - next_forward_virtual_pp_rank == 0 - ): - # first pp stage and first virtual stage - recv_prev = False + if bwd_wait_handles is not None: + for req in bwd_wait_handles: + req.wait() - # last iteration doesn't need recv from upstream - if micro_step == (steady_steps - 1): - recv_prev = False + # backward pass + backward_micro_step_id = micro_step + input_tensor_grad = self._backward_step_helper( + backward_micro_step_id + ) - # determine whether to recv grad from downstream - recv_next = True - next_backward_virtual_pp_rank = self._get_virtual_pp_rank( - backward_micro_step_id + 1, forward=False - ) - if self.is_pipeline_last_stage(ignore_virtual=True) and ( - next_backward_virtual_pp_rank == (self.num_model_chunks - 1) - ): - # last pp stage and last virtual stage - recv_next = False - - ( - input_tensor, - output_tensor_grad, - ) = self._p2p_helper.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - ) + # first stage doesn't send grad to upstream + backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id, forward=False + ) + self.set_virtual_pipeline_rank(backward_virtual_pp_rank) + if self.is_pipeline_first_stage(ignore_virtual=True): + input_tensor_grad = _process_bwd_buffer( + backward_micro_step_id, input_tensor_grad + ) + + recv_next = True + if self.is_pipeline_last_stage(ignore_virtual=True): + next_backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id + 1, + forward=False, + ) + if next_backward_virtual_pp_rank == ( + self.num_model_chunks - 1 + ): + # next chunk is the last chunk, not need to pre recv an output tensor grad + recv_next = False + else: + next_backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id + 1, forward=False + ) + + ( + output_tensor_grad, + bwd_wait_handles, + ) = self._p2p_helper.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + batch_p2p_comm=self._batch_p2p_comm, + overlap_p2p_comm=True, + ) + else: + self._record_stamp( + "F", forward_micro_step_id, '"B"', forward=True + ) + output_tensor = self._forward_step_helper( + micro_dataset, forward_micro_step_id + ) + self._record_stamp( + "F", forward_micro_step_id, '"E"', forward=True + ) + + # backward + backward_micro_step_id = micro_step + self._record_stamp( + "B", backward_micro_step_id, '"B"', forward=False + ) + input_tensor_grad = self._backward_step_helper( + backward_micro_step_id + ) + self._record_stamp( + "B", backward_micro_step_id, '"E"', forward=False + ) + + # four directions comm + # send output tensor to downstream + # send input 
tensor grad to upstream + # recv input tensor from upstream + # recv output tensor grad from downstream + + # last stage doesn't send rst to downstream + forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id, forward=True + ) + self.set_virtual_pipeline_rank(forward_virtual_pp_rank) + if self.is_pipeline_last_stage(ignore_virtual=True): + output_tensor = _process_fwd_buffer( + forward_micro_step_id, output_tensor + ) + + # first stage doesn't send grad to upstream + backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id, forward=False + ) + self.set_virtual_pipeline_rank(backward_virtual_pp_rank) + if self.is_pipeline_first_stage(ignore_virtual=True): + input_tensor_grad = _process_bwd_buffer( + backward_micro_step_id, input_tensor_grad + ) + + # determine whether to recv input tensor from upstream + recv_prev = True + next_forward_virtual_pp_rank = self._get_virtual_pp_rank( + forward_micro_step_id + 1, forward=True + ) + if self.is_pipeline_first_stage(ignore_virtual=True) and ( + next_forward_virtual_pp_rank == 0 + ): + # first pp stage and first virtual stage + recv_prev = False + + # last iteration doesn't need recv from upstream + if micro_step == (steady_steps - 1): + recv_prev = False + + # determine whether to recv grad from downstream + recv_next = True + next_backward_virtual_pp_rank = self._get_virtual_pp_rank( + backward_micro_step_id + 1, forward=False + ) + if self.is_pipeline_last_stage(ignore_virtual=True) and ( + next_backward_virtual_pp_rank == (self.num_model_chunks - 1) + ): + # last pp stage and last virtual stage + recv_next = False + + ( + input_tensor, + output_tensor_grad, + ) = self._p2p_helper.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + batch_p2p_comm=self._batch_p2p_comm, + ) # append input_tensor no matter none or not self.input_tensors[next_forward_virtual_pp_rank].append( input_tensor @@ -1434,10 +1588,15 @@ def _process_bwd_buffer(step_id, tensor): # remaining backward steps if not forward_only: + if self._overlap_p2p_comm and bwd_wait_handles is not None: + for wait_handles in bwd_wait_handles: + wait_handles.wait() + # no steady steps, which only occurs when accumulate_step == num_stage if not steady_steps: output_tensor_grad = p2p.recv_backward( - self.is_pipeline_last_stage() + self.is_pipeline_last_stage(), + batch_p2p_comm=self._batch_p2p_comm, ) self.output_tensor_grads[self.num_model_chunks - 1].append( output_tensor_grad @@ -1482,7 +1641,9 @@ def _process_bwd_buffer(step_id, tensor): # append output_tensor_grad no matter none or not self.output_tensor_grads[next_backward_virtual_pp_rank].append( self._p2p_helper.send_backward_recv_backward( - input_tensor_grad, recv_next=recv_next + input_tensor_grad, + recv_next=recv_next, + batch_p2p_comm=self._batch_p2p_comm, ) ) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 4566f89290fc0..6d470d541f66b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -292,91 +292,33 @@ def batch_send_recv_on_calc_stream(p2p_op_list): op(tensor, comm_group, peer, nranks, rank_id) -def _process_p2p_tuple_or_tensor( +def _batch_p2p_tuple_or_tensor( tensors, p2p_func, pp_rank, pp_group, mp_degree=1, mp_rank=0 ): - ops = [] - if isinstance(tensors, 
tuple): - for tensor in tensors: - op = P2PonCalcStream( - p2p_func, tensor, pp_rank, pp_group, mp_degree, mp_rank - ) - ops.append(op) - else: - op = P2PonCalcStream( - p2p_func, tensors, pp_rank, pp_group, mp_degree, mp_rank - ) - ops.append(op) + if not isinstance(tensors, tuple): + tensors = (tensors,) + ops = [ + P2PonCalcStream(p2p_func, tensor, pp_rank, pp_group, mp_degree, mp_rank) + for tensor in tensors + ] return ops -def _p2p_helper( - tensor_send_next, - tensor_send_prev, - recv_prev, - recv_next, - sync_recv=True, - send_recv_meta=None, +def _batched_p2p_ops( + tensor_send_prev, tensor_recv_prev, tensor_send_next, tensor_recv_next, hcg ): - global _hcg - - tensor_recv_prev = None - tensor_recv_next = None - - # send / recv message - assert send_recv_meta is not None, "send_recv_meta should not be None" - recv_shape_msg = send_recv_meta.recv_shape_message - recv_dtype_msg = send_recv_meta.recv_dtype_message - recv_stop_gradient = send_recv_meta.recv_stop_gradient - - send_shape_msg = send_recv_meta.send_shape_message - send_dtype_msg = send_recv_meta.send_dtype_message - - # model parallel message - mp_group = _hcg.get_model_parallel_group() - mp_degree = _hcg.get_model_parallel_world_size() - mp_rank = _hcg.get_model_parallel_rank() - - if recv_prev: - if isinstance(recv_shape_msg, tuple): - tensor_recv_prev = [] - for idx, shape in enumerate(recv_shape_msg): - tmp = paddle.empty( - shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]) - ) - tmp.stop_gradient = recv_stop_gradient[idx] - tensor_recv_prev.append(tmp) - tensor_recv_prev = tuple(tensor_recv_prev) - else: - tensor_recv_prev = paddle.empty( - shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg) - ) - tensor_recv_prev.stop_gradient = recv_stop_gradient - - if recv_next: - if isinstance(send_shape_msg, tuple): - tensor_recv_next = [] - for idx, shape in enumerate(send_shape_msg): - tensor_recv_next.append( - paddle.empty( - shape=shape, dtype=number_2_dtype(send_dtype_msg[idx]) - ) - ) - tensor_recv_next = tuple(tensor_recv_next) - else: - tensor_recv_next = paddle.empty( - shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg) - ) - ops = [] - pipe_group = _hcg.get_pipe_parallel_group() + pipe_group = hcg.get_pipe_parallel_group() + mp_degree = hcg.get_model_parallel_world_size() + mp_rank = hcg.get_model_parallel_rank() + mp_group = hcg.get_model_parallel_group() # start to p2p communicate if not _sync_send: if tensor_send_prev is not None: - src_rank = _hcg._get_p2p_prev_rank() + src_rank = hcg._get_p2p_prev_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_send_prev, _send_on_calc_stream, src_rank, @@ -386,9 +328,9 @@ def _p2p_helper( ) ) if tensor_recv_prev is not None: - dst_rank = _hcg._get_p2p_prev_rank() + dst_rank = hcg._get_p2p_prev_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_recv_prev, _recv_on_calc_stream, dst_rank, @@ -398,9 +340,9 @@ def _p2p_helper( ) ) if tensor_send_next is not None: - src_rank = _hcg._get_p2p_next_rank() + src_rank = hcg._get_p2p_next_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_send_next, _send_on_calc_stream, src_rank, @@ -410,9 +352,9 @@ def _p2p_helper( ) ) if tensor_recv_next is not None: - dst_rank = _hcg._get_p2p_next_rank() + dst_rank = hcg._get_p2p_next_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_recv_next, _recv_on_calc_stream, dst_rank, @@ -427,9 +369,9 @@ def _p2p_helper( # When using this 
order, the environment variable # 'PADDLE_P2P_SYNC_SEND' should be set True if tensor_recv_prev is not None: - dst_rank = _hcg._get_p2p_prev_rank() + dst_rank = hcg._get_p2p_prev_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_recv_prev, _recv_on_calc_stream, dst_rank, @@ -439,9 +381,9 @@ def _p2p_helper( ) ) if tensor_send_next is not None: - src_rank = _hcg._get_p2p_next_rank() + src_rank = hcg._get_p2p_next_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_send_next, _send_on_calc_stream, src_rank, @@ -451,9 +393,9 @@ def _p2p_helper( ) ) if tensor_recv_next is not None: - dst_rank = _hcg._get_p2p_next_rank() + dst_rank = hcg._get_p2p_next_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_recv_next, _recv_on_calc_stream, dst_rank, @@ -463,9 +405,9 @@ def _p2p_helper( ) ) if tensor_send_prev is not None: - src_rank = _hcg._get_p2p_prev_rank() + src_rank = hcg._get_p2p_prev_rank() ops.extend( - _process_p2p_tuple_or_tensor( + _batch_p2p_tuple_or_tensor( tensor_send_prev, _send_on_calc_stream, src_rank, @@ -477,7 +419,6 @@ def _p2p_helper( if len(ops) > 0: batch_send_recv_on_calc_stream(ops) - if distutils.util.strtobool( os.getenv('FLAGS_p2p_device_synchronize', '0') ): @@ -506,7 +447,176 @@ def _p2p_helper( use_calc_stream=True, ) - return tensor_recv_prev, tensor_recv_next + +def _p2p_ops_tuple_or_tensor(tensors, p2p_func, pp_rank, pp_group): + if not isinstance(tensors, tuple): + tensors = (tensors,) + reqs = [] + for tensor in tensors: + reqs.append(p2p_func(tensor, pp_rank, pp_group)) + return reqs + + +def _p2p_ops( + tensor_send_prev, tensor_recv_prev, tensor_send_next, tensor_recv_next, hcg +): + reqs = [] + group = hcg.get_pipe_parallel_group() + if hcg.get_stage_id() % 2 == 0: + if tensor_send_next is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_send_next, + paddle.distributed.isend, + hcg._get_p2p_next_rank(), + group, + ) + ) + if tensor_recv_prev is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_recv_prev, + paddle.distributed.irecv, + hcg._get_p2p_prev_rank(), + group, + ) + ) + + if tensor_send_prev is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_send_prev, + paddle.distributed.isend, + _hcg._get_p2p_prev_rank(), + group, + ) + ) + + if tensor_recv_next is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_recv_next, + paddle.distributed.irecv, + hcg._get_p2p_next_rank(), + group, + ) + ) + else: + if tensor_recv_prev is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_recv_prev, + paddle.distributed.irecv, + hcg._get_p2p_prev_rank(), + group, + ) + ) + if tensor_send_next is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_send_next, + paddle.distributed.isend, + hcg._get_p2p_next_rank(), + group, + ) + ) + if tensor_recv_next is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_recv_next, + paddle.distributed.irecv, + hcg._get_p2p_next_rank(), + group, + ) + ) + if tensor_send_prev is not None: + reqs.extend( + _p2p_ops_tuple_or_tensor( + tensor_send_prev, + paddle.distributed.isend, + hcg._get_p2p_prev_rank(), + group, + ) + ) + return reqs + + +def _p2p_helper( + tensor_send_next, + tensor_send_prev, + recv_prev, + recv_next, + sync_recv=True, + send_recv_meta=None, + batch_p2p_comm=True, + wait_on_reqs=True, +): + global _hcg + + tensor_recv_prev = None + tensor_recv_next = None + + # send / recv message + assert send_recv_meta is not None, "send_recv_meta 
should not be None" + recv_shape_msg = send_recv_meta.recv_shape_message + recv_dtype_msg = send_recv_meta.recv_dtype_message + recv_stop_gradient = send_recv_meta.recv_stop_gradient + + send_shape_msg = send_recv_meta.send_shape_message + send_dtype_msg = send_recv_meta.send_dtype_message + + # model parallel message + mp_group = _hcg.get_model_parallel_group() + mp_degree = _hcg.get_model_parallel_world_size() + mp_rank = _hcg.get_model_parallel_rank() + + if recv_prev: + if isinstance(recv_shape_msg, tuple): + tensor_recv_prev = [] + for idx, shape in enumerate(recv_shape_msg): + tmp = paddle.empty( + shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]) + ) + tmp.stop_gradient = recv_stop_gradient[idx] + tensor_recv_prev.append(tmp) + tensor_recv_prev = tuple(tensor_recv_prev) + else: + tensor_recv_prev = paddle.empty( + shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg) + ) + tensor_recv_prev.stop_gradient = recv_stop_gradient + + if recv_next: + if isinstance(send_shape_msg, tuple): + tensor_recv_next = [] + for idx, shape in enumerate(send_shape_msg): + tensor_recv_next.append( + paddle.empty( + shape=shape, dtype=number_2_dtype(send_dtype_msg[idx]) + ) + ) + tensor_recv_next = tuple(tensor_recv_next) + else: + tensor_recv_next = paddle.empty( + shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg) + ) + + p2p_func = _batched_p2p_ops if batch_p2p_comm else _p2p_ops + reqs = p2p_func( + tensor_send_prev, + tensor_recv_prev, + tensor_send_next, + tensor_recv_next, + _hcg, + ) + + # NOTE(shenliang03): batch_p2p_comm no need wait because of using calculate stream + if wait_on_reqs and not batch_p2p_comm and len(reqs) > 0: + for req in reqs: + req.wait() + reqs = None + + return tensor_recv_prev, tensor_recv_next, reqs class P2pHelper: @@ -527,7 +637,7 @@ def _recv_meta(self): self._send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group()) self._send_recv_meta.has_recv_meta = self._use_cache - def recv_forward(self, pp_first_stage, sync_recv=True): + def recv_forward(self, pp_first_stage, sync_recv=True, batch_p2p_comm=True): global _timers if _timers is not None: _timers("recv_forward").start() @@ -536,38 +646,40 @@ def recv_forward(self, pp_first_stage, sync_recv=True): else: self._recv_meta() - input_tensor, _ = _p2p_helper( + input_tensor, _, _ = _p2p_helper( tensor_send_next=None, tensor_send_prev=None, recv_prev=True, recv_next=False, sync_recv=sync_recv, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, ) if _timers is not None: _timers("recv_forward").stop() return input_tensor - def recv_backward(self, pp_last_stage, sync_recv=True): + def recv_backward(self, pp_last_stage, sync_recv=True, batch_p2p_comm=True): global _timers if _timers is not None: _timers("recv_backward").start() if pp_last_stage: output_tensor_grad = None else: - _, output_tensor_grad = _p2p_helper( + _, output_tensor_grad, _ = _p2p_helper( tensor_send_next=None, tensor_send_prev=None, recv_prev=False, recv_next=True, sync_recv=sync_recv, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, ) if _timers is not None: _timers("recv_backward").stop() return output_tensor_grad - def send_forward(self, output_tensor, pp_last_stage): + def send_forward(self, output_tensor, pp_last_stage, batch_p2p_comm=True): global _timers if _timers is not None: _timers("send_forward").start() @@ -580,11 +692,14 @@ def send_forward(self, output_tensor, pp_last_stage): recv_prev=False, recv_next=False, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, ) if 
_timers is not None: _timers("send_forward").stop() - def send_backward(self, input_tensor_grad, pp_first_stage): + def send_backward( + self, input_tensor_grad, pp_first_stage, batch_p2p_comm=True + ): global _timers if _timers is not None: _timers("send_backward").start() @@ -595,48 +710,60 @@ def send_backward(self, input_tensor_grad, pp_first_stage): recv_prev=False, recv_next=False, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, ) if _timers is not None: _timers("send_backward").stop() - def send_forward_recv_backward(self, output_tensor, pp_last_stage): + def send_forward_recv_backward( + self, output_tensor, pp_last_stage, batch_p2p_comm=True + ): global _timers if _timers is not None: _timers("send_forward_recv_backward").start() if pp_last_stage: output_tensor_grad = None else: - _, output_tensor_grad = _p2p_helper( + _, output_tensor_grad, _ = _p2p_helper( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, recv_next=True, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, ) if _timers is not None: _timers("send_forward_recv_backward").stop() return output_tensor_grad - def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage): + def send_backward_recv_forward( + self, input_tensor_grad, pp_first_stage, batch_p2p_comm=True + ): global _timers if _timers is not None: _timers("send_backward_recv_forward").start() if pp_first_stage: input_tensor = None else: - input_tensor, _ = _p2p_helper( + input_tensor, _, _ = _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=True, recv_next=False, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, ) if _timers is not None: _timers("send_backward_recv_forward").stop() return input_tensor def send_forward_backward_recv_forward_backward( - self, output_tensor, input_tensor_grad, recv_prev, recv_next + self, + output_tensor, + input_tensor_grad, + recv_prev, + recv_next, + batch_p2p_comm=True, ): # always have to send dtype info to downstream global _timers @@ -648,19 +775,26 @@ def send_forward_backward_recv_forward_backward( if recv_prev: self._recv_meta() - input_tensor, output_tensor_grad = _p2p_helper( + input_tensor, output_tensor_grad, _ = _p2p_helper( tensor_send_next=output_tensor, tensor_send_prev=input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, sync_recv=False, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, ) if _timers is not None: _timers("send_forward_backward_recv_forward_backward").stop() return input_tensor, output_tensor_grad - def send_forward_recv_forward(self, output_tensor, recv_prev): + def send_forward_recv_forward( + self, + output_tensor, + recv_prev, + batch_p2p_comm=True, + overlap_p2p_comm=False, + ): # always have to send dtype info to downstream global _timers if _timers is not None: @@ -672,32 +806,48 @@ def send_forward_recv_forward(self, output_tensor, recv_prev): if recv_prev: self._recv_meta() - input_tensor, _ = _p2p_helper( + input_tensor, _, wait_handles = _p2p_helper( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=recv_prev, recv_next=False, sync_recv=False, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, + wait_on_reqs=(not overlap_p2p_comm), ) if _timers is not None: _timers("send_forward_recv_forward").stop() + + if overlap_p2p_comm: + return input_tensor, wait_handles return input_tensor - def send_backward_recv_backward(self, input_tensor_grad, recv_next): + def send_backward_recv_backward( + self, + 
input_tensor_grad, + recv_next, + batch_p2p_comm=True, + overlap_p2p_comm=False, + ): global _timers if _timers is not None: _timers("send_backward_recv_backward").start() - _, output_tensor_grad = _p2p_helper( + _, output_tensor_grad, wait_handles = _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, recv_next=recv_next, sync_recv=False, send_recv_meta=self._send_recv_meta, + batch_p2p_comm=batch_p2p_comm, + wait_on_reqs=(not overlap_p2p_comm), ) if _timers is not None: _timers("send_backward_recv_backward").stop() + + if overlap_p2p_comm: + return output_tensor_grad, wait_handles return output_tensor_grad def __repr__(self):
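
---

Usage note (not part of the patch): the new `overlap_p2p_comm` switch is read from `hybrid_configs["pp_configs"]` in `PipelineParallel.__init__` and takes effect in the interleaved (virtual-pipeline) 1F1B schedule, where it disables batched p2p (`self._batch_p2p_comm = not self._overlap_p2p_comm`) and lets point-to-point sends/receives run concurrently with compute. A minimal sketch of enabling it from user code, assuming the usual fleet hybrid-parallel setup; degrees and `accumulate_steps` below are illustrative placeholders:

```python
# Sketch only: enabling the new flag via fleet.DistributedStrategy.
# Everything except `overlap_p2p_comm` follows the existing hybrid-parallel
# configuration; the degrees and accumulate_steps are placeholders.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 4,  # pipeline-parallel degree
}
strategy.pipeline_configs = {
    "accumulate_steps": 8,
    "micro_batch_size": 1,
}
# The option added by this change: overlap p2p communication with computation
# in the interleaved 1F1B schedule (implies non-batched p2p).
strategy.hybrid_configs["pp_configs"].overlap_p2p_comm = True

fleet.init(is_collective=True, strategy=strategy)
```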
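For reference, `_p2p_helper` now returns a 3-tuple whose last element is the list of outstanding requests when `wait_on_reqs=False`, and the scheduler calls `req.wait()` only right before the corresponding tensor is consumed. The same non-blocking pattern in isolation, as a self-contained toy built on `paddle.distributed.isend`/`irecv` (two ranks, illustration only, not the PR's scheduler):

```python
# Toy illustration of the overlap pattern (launch with two ranks, e.g.
# `python -m paddle.distributed.launch --nproc_per_node=2 toy_overlap.py`).
# It only shows "issue p2p, compute, wait before use", analogous to
# wait_on_reqs=False in _p2p_helper.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()
peer = 1 - rank

send_buf = paddle.full([4], float(rank))
recv_buf = paddle.empty([4], dtype=send_buf.dtype)

# Non-blocking point-to-point ops; both return waitable tasks.
reqs = [dist.isend(send_buf, dst=peer), dist.irecv(recv_buf, src=peer)]

# Overlap: unrelated computation runs while the transfers are in flight.
out = paddle.matmul(paddle.rand([256, 256]), paddle.rand([256, 256]))

# Wait only when the received tensor is actually needed.
for req in reqs:
    req.wait()
print(f"rank {rank}: recv={recv_buf.numpy().tolist()}, out shape={out.shape}")
```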