93 changes: 53 additions & 40 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -36,6 +36,7 @@
     DeepseekV2DecoderLayer,
     DeepseekV2LMHead,
     DeepseekV2Model,
+    DeepseekV2MoE,
     DeepseekV2MTPLayer,
     DeepseekV2PretrainedModel,
     DeepseekV2PretrainingCriterion,
@@ -411,7 +412,7 @@ def attn_forward(self, inputs):
             token_probs,
         )
 
-    def dispatch_forward(self, inputs, previous_event=None):
+    def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
         (
             inputs_embeds_mtp,
             hidden_states,
@@ -429,7 +430,13 @@ def dispatch_forward(self, inputs, previous_event=None):
             dispatched_indices,
             dispatched_probs,
         ) = self.fp8_fusion_moe_node.dispatch_node.forward(
-            hs_fp8, hs_scale, token_indices, token_probs, previous_event=previous_event, async_finish=True
+            hs_fp8,
+            hs_scale,
+            token_indices,
+            token_probs,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
         return (
             inputs_embeds_mtp,
@@ -458,9 +465,9 @@ def mlp_forward(self, inputs):
         )
         return (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out)
 
-    def combine_forward(self, inputs):
+    def combine_forward(self, inputs, async_finish=False):
         (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs
-        output_combie = self.fp8_fusion_moe_node.combine_node.forward(hidden_states_out)
+        output_combie = self.fp8_fusion_moe_node.combine_node.forward(hidden_states_out, async_finish=async_finish)
         return (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combie)
 
     def post_process_forward(self, inputs):
@@ -490,7 +497,7 @@ def post_process_backward(self, output_grad):
             output_combie_grad_scale,
         )
 
-    def combine_backward(self, output_grad):
+    def combine_backward(self, output_grad, async_finish=False):
         (
             inputs_embeds_mtp_grad,
             hidden_states_grad,
@@ -500,7 +507,9 @@ def combine_backward(self, output_grad):
             output_combie_grad_scale,
         ) = output_grad
         hidden_states_out_grad, hidden_states_out_grad_scale = self.fp8_fusion_moe_node.combine_node.backward(
-            output_combie_grad_fp8, output_combie_grad_scale
+            output_combie_grad_fp8,
+            output_combie_grad_scale,
+            async_finish=async_finish,
         )
         return (
             inputs_embeds_mtp_grad,
@@ -533,7 +542,7 @@ def mlp_backward(self, output_grad):
             dispatched_probs_grad,
         )
 
-    def dispatch_backward(self, output_grad):
+    def dispatch_backward(self, output_grad, async_finish=False):
         (
             inputs_embeds_mtp_grad,
             hidden_states_grad,
@@ -543,7 +552,7 @@ def dispatch_backward(self, output_grad):
             dispatched_probs_grad,
         ) = output_grad
         hs_fp8_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward(
-            hs_fp8_dispatched_grad, dispatched_probs_grad
+            hs_fp8_dispatched_grad, dispatched_probs_grad, async_finish=async_finish
        )
         return (inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hs_fp8_grad, token_probs_grad)
 
@@ -605,32 +614,34 @@ def forward_backward(self, inputs, output_grad):
         paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("combine_backward")
-        output_grad = self.backward_node.combine_backward(output_grad)
+        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True)
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("attn_forward")
         inputs = self.forward_node.attn_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
 
         calc_stream_wait(self.backward_node.moe_group.id)
-        # attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_push("dispatch_forward")
-        inputs = self.forward_node.dispatch_forward(inputs, previous_event=None)
-        paddle.base.core.nvprof_nvtx_pop()
+        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward")
         output_grad = self.backward_node.mlp_backward(output_grad)
         paddle.base.core.nvprof_nvtx_pop()
+        paddle.base.core.nvprof_nvtx_push("dispatch_forward")
+        inputs = self.forward_node.dispatch_forward(
+            inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
+        )
+        paddle.base.core.nvprof_nvtx_pop()
 
         calc_stream_wait(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("dispatch_backward")
-        output_grad = self.backward_node.dispatch_backward(output_grad)
+        output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("mlp_forward")
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
 
         calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("combine_forward")
-        inputs = self.forward_node.combine_forward(inputs)
+        inputs = self.forward_node.combine_forward(inputs, async_finish=True)
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         output_grad = self.backward_node.attn_backward(output_grad)
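The hunk above is the core of the change: each communication phase (combine_backward, dispatch_forward, dispatch_backward, combine_forward) is now launched asynchronously and overlapped with a compute phase from the opposite direction, with deep_ep events and calc_stream_wait() enforcing ordering. Below is a minimal, self-contained sketch of the same event-based overlap pattern using plain Paddle CUDA streams; it is illustrative only (the PR itself relies on deep_ep events, not this code), and the tensor ops are stand-ins for the real kernels.

```python
import paddle

calc_stream = paddle.device.cuda.current_stream()
comm_stream = paddle.device.cuda.Stream()

x = paddle.randn([4096, 1024])
attn_out = paddle.nn.functional.relu(x)  # stand-in for attn_forward on the calc stream

attn_event = paddle.device.cuda.Event()
attn_event.record(calc_stream)           # analogue of deep_ep.get_event_from_calc_stream()

comm_stream.wait_event(attn_event)       # dispatch waits on attention only, not the whole stream
with paddle.device.cuda.stream_guard(comm_stream):
    dispatched = attn_out * 1.0          # stand-in for the all-to-all dispatch kernel

grad = x.sum()                           # stand-in for mlp_backward, overlapping on the calc stream

calc_stream.wait_stream(comm_stream)     # analogue of calc_stream_wait(moe_group.id)
out = dispatched.mean()                  # safe to consume the dispatched tensors now
```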
@@ -981,32 +992,34 @@ def post_process_compute_for_fusion(self, inputs):
         return return_args(hidden_states)
 
     def build_schedule_node(self):
-        if DSV3_USE_FP8_GEMM:
-            attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node")
-            fp8_fusion_moe_node = FusionMoeNode(
-                self.mlp.token_dispatcher, self.mlp.experts, name="fp8_fusion_moe_node"
-            )
-            post_process_node = ScheduleNode(self.post_process_compute_for_fusion, name="post_process_node")
-            return FusionFp8DecoderLayerNode(
-                attn_and_gate_node=attn_and_gate_node,
-                fp8_fusion_moe_node=fp8_fusion_moe_node,
-                post_process_node=post_process_node,
-                mlp_layer=self.mlp,
-                name="FusionFp8DecoderLayerNode",
-            )
-        else:
-            attn_node = ScheduleNode(self.attn_compute, name="attn_node")
-            mlp_node = ScheduleNode(self.mlp_compute, name="mlp_node")
-            post_process_node = ScheduleNode(self.post_process_compute, name="post_process_node")
-            return DecoderLayerNode(
-                attn_node=attn_node,
-                dispatch_node=None,
-                mlp_node=mlp_node,
-                combine_node=None,
-                post_process_node=post_process_node,
-                mlp_layer=self.mlp,
-                name="DecoderLayerNode",
-            )
+        self.mlp.update_flex_token()
+        if self.mlp.using_flex_token and isinstance(self.mlp, DeepseekV2MoE):
+            if DSV3_USE_FP8_GEMM:
+                attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node")
+                fp8_fusion_moe_node = FusionMoeNode(self.mlp, name="fp8_fusion_moe_node")
+                post_process_node = ScheduleNode(self.post_process_compute_for_fusion, name="post_process_node")
+                return FusionFp8DecoderLayerNode(
+                    attn_and_gate_node=attn_and_gate_node,
+                    fp8_fusion_moe_node=fp8_fusion_moe_node,
+                    post_process_node=post_process_node,
+                    mlp_layer=self.mlp,
+                    name="FusionFp8DecoderLayerNode",
+                )
+            else:
+                attn_node = ScheduleNode(self.attn_compute, name="attn_node")
+                mlp_node = ScheduleNode(self.mlp_compute, name="mlp_node")
+                post_process_node = ScheduleNode(self.post_process_compute, name="post_process_node")
+                return DecoderLayerNode(
+                    attn_node=attn_node,
+                    dispatch_node=None,
+                    mlp_node=mlp_node,
+                    combine_node=None,
+                    post_process_node=post_process_node,
+                    mlp_layer=self.mlp,
+                    name="DeepseekV2DecoderLayerPipe",
+                )
+        return ScheduleNode(self.forward, name="DeepseekV2DecoderLayerPipe")
 
 
 class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
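Note the constructor change inside this hunk: FusionMoeNode is now built from the MoE layer itself rather than from (token_dispatcher, experts), and node selection is gated on update_flex_token()/using_flex_token, with a plain ScheduleNode wrapping self.forward as the fallback. A hypothetical sketch of what the new single-argument constructor presumably unpacks — the diff only shows the call site, so the fields below are an assumption, not the actual implementation:

```python
class FusionMoeNode:
    # Hypothetical: the diff changes the call site from
    # FusionMoeNode(self.mlp.token_dispatcher, self.mlp.experts, ...) to
    # FusionMoeNode(self.mlp, ...); the unpacking here is an assumption.
    def __init__(self, mlp, name="fp8_fusion_moe_node"):
        self.token_dispatcher = mlp.token_dispatcher  # same objects as before,
        self.experts = mlp.experts                    # now reached through the layer
        self.mlp = mlp                                # keeps flex-token state visible to the node
        self.name = name
```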
23 changes: 20 additions & 3 deletions paddlenlp/transformers/fused_a2a.py
@@ -276,10 +276,27 @@ class DispatchNode:
     def __init__(self, name="dispatch"):
         self.name = name
 
-    def forward(self, x, token_indices, token_probs, num_experts, group, previous_event=None, async_finish=False):
+    def forward(
+        self,
+        x,
+        token_indices,
+        token_probs,
+        num_experts,
+        group,
+        previous_event=None,
+        async_finish=False,
+        allocate_on_comm_stream=False,
+    ):
         """Forward pass of fused dispatch."""
         recv_x, recv_token_probs, states, event = fused_dispatch_forward_func(
-            x, token_indices, token_probs, num_experts, group, previous_event=previous_event, async_finish=async_finish
+            x,
+            token_indices,
+            token_probs,
+            num_experts,
+            group,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
 
         self.group = group
@@ -321,5 +338,5 @@ def forward(self, x, group, handle, previous_event=None, async_finish=False):
     def backward(self, grad_output, previous_event=None, async_finish=False):
         """Backward pass of fused combine."""
         return fused_combine_backward_func(
-            grad_output, self.group, self.handle, previous_event=previous_event, async_finish=False
+            grad_output, self.group, self.handle, previous_event=previous_event, async_finish=async_finish
         )
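The last hunk is also a bug fix: backward() accepted async_finish but passed a hard-coded False to fused_combine_backward_func, so asynchronous combine-backward was silently disabled. With the fix, a caller can genuinely overlap the backward combine with compute. A hedged usage sketch, where `node`, `grad_output`, and `run_backward_compute` are placeholders for objects built elsewhere:

```python
# `node` is a CombineNode whose forward() already stored self.group / self.handle.
grad_in = node.backward(grad_output, async_finish=True)  # returns before the all-to-all completes
run_backward_compute()                                   # overlaps with the combine backward
calc_stream_wait(node.group.id)                          # re-join before consuming grad_in
```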
12 changes: 11 additions & 1 deletion paddlenlp/transformers/moe_layer.py
@@ -484,7 +484,16 @@ def __init__(self, token_dispatcher, name="fp8_dispatch_node"):
         self.name = name
 
     @paddle.no_grad()
-    def forward(self, hs_fp8, hs_scale, token_indices, token_probs, previous_event=None, async_finish=False):
+    def forward(
+        self,
+        hs_fp8,
+        hs_scale,
+        token_indices,
+        token_probs,
+        previous_event=None,
+        async_finish=False,
+        allocate_on_comm_stream=False,
+    ):
         # dispatch
         (hs_fp8_dispatched, hs_scale_dispatched), dispatched_probs, states = self.dispatch_act_node.forward(
             (hs_fp8, hs_scale),
@@ -494,6 +503,7 @@ def forward(self, hs_fp8, hs_scale, token_indices, token_probs, previous_event=None, async_finish=False):
             self.token_dispatcher._comm_manager.group,
             previous_event=previous_event,
             async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
         self.token_dispatcher._comm_manager.handle = states["handle"]
         self.token_dispatcher._comm_manager.tokens_per_expert = states["tokens_per_expert"]
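Here allocate_on_comm_stream is threaded through unchanged down to the DeepEP dispatch. When the dispatch kernel waits on a previous_event and finishes asynchronously, its output buffers should be allocated on the communication stream rather than the calc stream, otherwise the allocator can reuse memory that an in-flight kernel still owns. A hedged sketch of the fully asynchronous dispatch call this enables — all names below are placeholders for objects built elsewhere in the model:

```python
# `fp8_fusion_moe_node`, `hs_fp8`, `hs_scale`, `token_indices`,
# `token_probs`, and `moe_group` come from the surrounding code.
attn_event = deep_ep.get_event_from_calc_stream(moe_group.id)
outputs = fp8_fusion_moe_node.dispatch_node.forward(
    hs_fp8,
    hs_scale,
    token_indices,
    token_probs,
    previous_event=attn_event,        # wait on the attention output, not the whole stream
    async_finish=True,                # completion is signalled by an event
    allocate_on_comm_stream=True,     # outputs live on the comm stream while in flight
)
calc_stream_wait(moe_group.id)        # join before compute reads the dispatched tensors
```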