Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions paddlenlp/transformers/deepseek_v2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2355,10 +2355,11 @@ def compute_loss(preds, labels):
masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
)
count = paddle.sum(binary_sequence)
if count == 0:
loss = paddle.sum(masked_lm_loss * binary_sequence)
else:
loss = paddle.sum(masked_lm_loss * binary_sequence) / count
loss = paddle.where(
count == 0,
paddle.sum(masked_lm_loss * binary_sequence),
paddle.sum(masked_lm_loss * binary_sequence) / count,
)
return loss

def add_loss(main_loss, loss):
Expand Down
4 changes: 1 addition & 3 deletions paddlenlp/transformers/fused_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,9 @@ def fused_dispatch_forward_func(
allocate_on_comm_stream=False,
)

tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list)

states = dict()
states["dispatched_indices"] = recv_token_indices
states["tokens_per_expert"] = tokens_per_expert
states["tokens_per_expert"] = num_recv_tokens_per_expert_list
states["handle"] = handle

return recv_x, recv_token_probs, states, event
Expand Down
7 changes: 3 additions & 4 deletions paddlenlp/transformers/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, m

def expert_forward(self, dispatched_input, tokens_per_expert):
outputs = []
tokens_per_expert = tokens_per_expert.tolist()
# print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}")
chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0)
for chunk, expert in zip(chunks, self.experts):
Expand All @@ -369,13 +368,13 @@ def forward(self, hidden_states: paddle.Tensor):
(
dispatched_input,
tokens_per_expert,
reversed_mapping_for_combine,
dispatched_routing_map,
token_permuted_indices,
prob_permuted_indices,
dispatched_probs,
) = self.token_dispatcher.token_permutation(hidden_states, probs, routing_map)
expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
output, _ = self.token_dispatcher.token_unpermutation(
expert_output, reversed_mapping_for_combine, dispatched_routing_map, dispatched_probs, None
expert_output, token_permuted_indices, prob_permuted_indices, dispatched_probs, None
)
return output, l_aux, l_zloss

Expand Down
43 changes: 20 additions & 23 deletions paddlenlp/transformers/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,21 @@
import paddle


def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk):
x = paddle.flatten(x)
prob_permuted_indices = paddle.concat(
[
paddle.tensor.search._restrict_nonzero(x == i, total_true_num)
for i, total_true_num in enumerate(num_tokens_per_expert_list)
]
).flatten()
token_permuted_indices = prob_permuted_indices // topk
return token_permuted_indices, prob_permuted_indices


def permute(
tokens,
routing_map,
token_permuted_indices,
drop_and_pad: bool = False,
):
"""Permute the tokens and probs based on the mask.
Expand All @@ -29,33 +41,21 @@ def permute(

Args:
tokens (paddle.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (paddle.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
"""
assert not drop_and_pad, "token-drop and pads is not supported"
num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]

# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.cast(paddle.bool).T.contiguous()

# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = paddle.arange(num_tokens).unsqueeze(0).expand([num_experts, -1])
sorted_indices = token_indices.masked_select(routing_map)

# use the mapping to permute the tokens
permuted_input = tokens.index_select(axis=0, index=sorted_indices)
permuted_input = tokens.index_select(axis=0, index=token_permuted_indices)

return permuted_input, sorted_indices
return permuted_input


def unpermute(
permuted_tokens: paddle.Tensor,
sorted_indices: paddle.Tensor,
token_permuted_indices: paddle.Tensor,
prob_permuted_indices: paddle.Tensor,
restore_shape: paddle.shape,
probs: paddle.Tensor = None,
routing_map: paddle.Tensor = None,
drop_and_pad: bool = False,
):
"""
Expand All @@ -64,11 +64,9 @@ def unpermute(

Args:
permuted_tokens (paddle.Tensor): The permuted token tensor.
sorted_indices (paddle.Tensor): The indices used to sort the tokens.
token_permuted_indices (paddle.Tensor): The indices used to sort the tokens.
restore_shape (paddle.shape): The shape of the unpermuted tensor.
probs (paddle.Tensor, optional): The unpermuted probs tensor,
routing_map (paddle.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.

Expand All @@ -79,16 +77,15 @@ def unpermute(
_, hidden = restore_shape

if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
permuted_probs = probs.flatten().index_select(axis=0, index=prob_permuted_indices)
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)

# Create an output tensor filled with zeros
output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype)
# Scatter add the permuted_input back to the original positions
output_tokens.put_along_axis_(
axis=0,
indices=sorted_indices.unsqueeze(1).expand([-1, hidden]),
indices=token_permuted_indices.unsqueeze(1).expand([-1, hidden]),
values=permuted_tokens,
reduce="add",
include_self=True,
Expand Down
65 changes: 29 additions & 36 deletions paddlenlp/transformers/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from paddle.distributed.communication.group import Group

from .fused_a2a import fused_combine, fused_dispatch
from .moe_utils import permute, unpermute
from .moe_utils import permute, topk_to_permuted_indices, unpermute


class _DeepepManager:
Expand Down Expand Up @@ -104,27 +104,29 @@ def combine(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
return hidden_states

def get_permuted_hidden_states_by_experts(
self, hidden_states: paddle.Tensor, dispatched_indices: paddle.Tensor, dispatched_probs: paddle.Tensor
self, hidden_states: paddle.Tensor, dispatched_indices: paddle.Tensor, tokens_per_expert_list: list
) -> paddle.Tensor:
dispatched_routing_map, dispatched_probs = self._indices_to_multihot(dispatched_indices, dispatched_probs)
self.hidden_shape_before_permute = hidden_states.shape
hidden_states, reversed_mapping_for_combine = permute(hidden_states, dispatched_routing_map)
return hidden_states, dispatched_routing_map, dispatched_probs, reversed_mapping_for_combine
token_permuted_indices, prob_permuted_indices = topk_to_permuted_indices(
dispatched_indices, tokens_per_expert_list, self.router_topk
)
hidden_states = permute(hidden_states, token_permuted_indices)
return hidden_states, token_permuted_indices, prob_permuted_indices

def get_restored_hidden_states_by_experts(
self,
hidden_states: paddle.Tensor,
reversed_mapping_for_combine: paddle.Tensor,
dispatched_routing_map: paddle.Tensor,
token_permuted_indices: paddle.Tensor,
prob_permuted_indices: paddle.Tensor,
dispatched_probs: paddle.Tensor,
) -> paddle.Tensor:
input_dtype = hidden_states.dtype
assert dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs"
hidden_states = unpermute(
hidden_states,
reversed_mapping_for_combine,
permuted_tokens=hidden_states,
token_permuted_indices=token_permuted_indices,
prob_permuted_indices=prob_permuted_indices,
restore_shape=self.hidden_shape_before_permute,
routing_map=dispatched_routing_map,
probs=dispatched_probs,
)
return hidden_states.to(input_dtype)
Expand Down Expand Up @@ -167,25 +169,19 @@ def pre_dispatch(self, hidden_states, probs, routing_map):
token_probs, token_indices = paddle.topk(probs, self._comm_manager.router_topk, axis=-1)
return hidden_states, token_indices, token_probs

def post_dispatch(self, hidden_states, dispatched_indices, dispatched_probs):
def post_dispatch(self, hidden_states, dispatched_indices, tokens_per_expert_list):
(
global_input_tokens,
dispatched_routing_map,
dispatched_probs,
reversed_mapping_for_combine,
token_permuted_indices,
prob_permuted_indices,
) = self._comm_manager.get_permuted_hidden_states_by_experts(
hidden_states, dispatched_indices, dispatched_probs
)
return (
global_input_tokens,
reversed_mapping_for_combine,
dispatched_routing_map,
dispatched_probs,
hidden_states, dispatched_indices, tokens_per_expert_list
)
return (global_input_tokens, token_permuted_indices, prob_permuted_indices)

def pre_combine(self, hidden_states, reversed_mapping_for_combine, dispatched_routing_map, dispatched_probs):
def pre_combine(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs):
hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(
hidden_states, reversed_mapping_for_combine, dispatched_routing_map, dispatched_probs
hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs
)
return hidden_states

Expand All @@ -197,35 +193,32 @@ def token_permutation(
self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor]:
hidden_states, token_indices, token_probs = self.pre_dispatch(hidden_states, probs, routing_map)
hidden_states, tokens_per_expert, dispatched_indices, dispatched_probs = self._comm_manager.dispatch(
hidden_states, tokens_per_expert_list, dispatched_indices, dispatched_probs = self._comm_manager.dispatch(
hidden_states, token_indices, token_probs
)
(
global_input_tokens,
reversed_mapping_for_combine,
dispatched_routing_map,
dispatched_probs,
) = self.post_dispatch(hidden_states, dispatched_indices, dispatched_probs)
(global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.post_dispatch(
hidden_states, dispatched_indices, tokens_per_expert_list
)

return (
global_input_tokens,
tokens_per_expert,
reversed_mapping_for_combine,
dispatched_routing_map,
tokens_per_expert_list,
token_permuted_indices,
prob_permuted_indices,
dispatched_probs,
)

def token_unpermutation(
self,
hidden_states: paddle.Tensor,
reversed_mapping_for_combine,
dispatched_routing_map,
token_permuted_indices,
prob_permuted_indices,
dispatched_probs,
bias: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
hidden_states = self.pre_combine(
hidden_states, reversed_mapping_for_combine, dispatched_routing_map, dispatched_probs
hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs
)
hidden_states = self._comm_manager.combine(hidden_states)

Expand Down
Loading