41 changes: 25 additions & 16 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -33,8 +33,8 @@
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.communication.reduce import ReduceOp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -799,7 +799,7 @@ def __init__(self, config: DeepseekV2Config):

for p in self.experts.parameters():
setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})

setattr(p, "is_moe_param", True)
self.alpha = config.aux_loss_alpha
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -851,6 +851,7 @@ def __init__(self, config: DeepseekV2Config):

for p in self.experts.parameters():
setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
setattr(p, "is_moe_param", True)

self.alpha = config.aux_loss_alpha
if config.n_shared_experts is not None:
@@ -895,7 +896,9 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
self.num_heads = config.num_attention_heads
self.num_local_heads = self.num_heads
if config.tensor_parallel_degree > 1:
assert self.num_heads % config.tensor_parallel_degree == 0, f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
assert (
self.num_heads % config.tensor_parallel_degree == 0
), f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
self.num_local_heads = self.num_heads // config.tensor_parallel_degree

self.max_position_embeddings = config.max_position_embeddings
@@ -1067,7 +1070,12 @@ def forward(

if self.sequence_parallel:
target_query_shape = [bsz, self.seq_length, self.num_local_heads, self.q_head_dim]
target_key_value_shape = [bsz, self.seq_length, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim]
target_key_value_shape = [
bsz,
self.seq_length,
self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim,
]
else:
target_query_shape = [0, 0, self.num_heads, self.q_head_dim]
target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim]
@@ -1153,7 +1161,6 @@ def forward(
if attn_output.shape != ori_shape:
attn_output = attn_output.reshape(ori_shape)


if not output_attentions:
attn_weights = None

@@ -1511,7 +1518,7 @@ def forward(
hidden_states = self.hnorm(hidden_states)
nextn_hidden_state = self.enorm(nextn_hidden_state)

hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1))

layer_outputs = super(DeepseekV2MTPLayer, self).forward(
hidden_states,
@@ -1711,10 +1718,13 @@ def get_tensor_parallel_split_mappings(num_layers):

return final_actions

mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers + 2)

return mappings

def get_tensor_parallel_mappings(self, is_split=True):
return type(self)._get_tensor_parallel_mappings(self.config, is_split)

def _init_weights(self, layer):
return
if self.config.tensor_parallel_degree > 1:
@@ -1988,7 +1998,7 @@ def forward(
if self.config.sequence_parallel:
# [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
bs, seq_len, hidden_size = inputs_embeds.shape
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H]
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H]
# inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size])
# [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
inputs_embeds = ScatterOp.apply(inputs_embeds)
@@ -2071,7 +2081,7 @@ def forward(

if self.config.sequence_parallel:
hidden_states = GatherOp.apply(hidden_states)
hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H]
hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H]
# hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])

inputs_embeds_cur_depth = paddle.concat(
@@ -2173,7 +2183,7 @@ def add_loss(main_loss, loss):
seq_length = masked_lm_labels.shape[1]

if self.config.sequence_parallel:
masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B]
masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B]
masked_lm_labels = ScatterOp.apply(masked_lm_labels)

loss = compute_loss(prediction_scores, masked_lm_labels)
@@ -2188,16 +2198,15 @@ def add_loss(main_loss, loss):
masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]

if self.config.sequence_parallel:
masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B]
masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B]
masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)

res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)

if self.config.sequence_parallel:
res_cur_depth = res_cur_depth * self.seq_para_scale
dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)


mtp_loss_res.append(res_cur_depth)
loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res)) # fmt: skip
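
Note: the block above averages the per-depth MTP losses and adds them to the main loss with weight num_nextn_predict_lambda; under sequence parallelism each per-depth loss is first scaled by seq_para_scale and summed across the model-parallel group. A minimal sketch of the final combination step (hypothetical helper, not part of the PR):

def combine_mtp_losses(main_loss, mtp_losses, num_nextn_predict_lambda):
    # Average the per-depth multi-token-prediction losses and add them to the
    # main LM loss with a configurable weight, mirroring the add_loss path above.
    if not mtp_losses:
        return main_loss
    return main_loss + num_nextn_predict_lambda * sum(mtp_losses) / len(mtp_losses)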

@@ -2245,9 +2254,9 @@ def __init__(self, config: DeepseekV2Config):
def forward(self, hidden_states, tensor_parallel_output=None):

# if self.config.sequence_parallel:
# hidden_states = GatherOp.apply(hidden_states)
# hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
# hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])
# hidden_states = GatherOp.apply(hidden_states)
# hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
# hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])

if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output
14 changes: 11 additions & 3 deletions paddlenlp/transformers/moe_gate.py
@@ -578,11 +578,19 @@ def topkgating_nodrop(self, gates: paddle.Tensor):
# get topk mask
mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1)

# hongyu fix start
gates_masked = gates * mask
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)

if self.norm_topk_prob:
gates_masked = gates_masked / denom_s
gates_masked *= self.routed_scaling_factor
# hongyu fix end
if hasattr(self.config, "seq_aux") and self.config.seq_aux:
l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx)
else:
l_aux = self._cal_aux_loss(gates, mask)

exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
return topk_masked_gates, mask, exp_counts, l_aux, l_zloss
# topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
return gates_masked, mask, exp_counts, l_aux, l_zloss
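
Note: with this change topkgating_nodrop returns the renormalized gates_masked instead of topk_masked_gates: when norm_topk_prob is set, the selected probabilities of each token are renormalized to sum to one (with the denominator clipped at the dtype eps) and then scaled by routed_scaling_factor. A standalone sketch of that renormalization, assuming gates holds softmax probabilities of shape [num_tokens, num_experts] (the helper name is illustrative, not the library API):

import paddle

def topk_renormalize(gates, top_k, routed_scaling_factor=1.0, norm_topk_prob=True):
    # Keep only the top-k experts per token.
    _, top_idx = paddle.topk(gates, k=top_k, axis=-1)
    mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1)
    gates_masked = gates * mask
    if norm_topk_prob:
        # Renormalize so the kept probabilities sum to 1 for every token.
        denom = paddle.clip(
            paddle.sum(gates_masked, axis=-1, keepdim=True),
            min=paddle.finfo(gates_masked.dtype).eps,
        )
        gates_masked = gates_masked / denom
    gates_masked *= routed_scaling_factor
    return gates_masked, mask

# Example: 4 tokens routed to 2 of 8 experts with a routed scaling factor of 2.5.
probs = paddle.nn.functional.softmax(paddle.randn([4, 8]), axis=-1)
weights, mask = topk_renormalize(probs, top_k=2, routed_scaling_factor=2.5)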
29 changes: 25 additions & 4 deletions paddlenlp/transformers/moe_layer.py
@@ -175,7 +175,6 @@ def __init__(
is_fleet_init = True
except AttributeError:
is_fleet_init = False

if is_fleet_init and dist.get_world_size() > 1:
if moe_group == "data":
self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
@@ -198,7 +197,6 @@ def __init__(
self.expert_parallel_degree = 1
self.moe_num_experts_per_device = self.moe_num_experts
self.is_dummy_moe = True

self.all_to_all_dropout = all_to_all_dropout
self.enable_recompute = False

@@ -348,21 +346,34 @@ def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, m
self.moe_router_topk = gate.top_k
self.moe_num_experts = moe_num_experts
self.num_local_experts = moe_num_experts // self.ep_size
self.moe_rank = dist.get_rank(self.moe_group)
self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank
self.token_dispatcher = MoEFlexTokenDispatcher(
self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group
)
self.experts = nn.LayerList([expert_class(**expert_kwargs) for _ in range(self.num_local_experts)])
self.expert_parallel_degree = 1 if self.ep_size < 0 else self.ep_size
self.moe_num_experts_per_device = self._parse_moe_expert_parallel(
self.moe_num_experts, self.expert_parallel_degree
)
self.experts = nn.LayerList([])
for i in range(self.moe_num_experts):
if i // self.moe_num_experts_per_device == self.moe_rank:
self.experts.append(expert_class(**expert_kwargs))
else:
self.experts.append(None)
self.router = gate

def expert_forward(self, dispatched_input, tokens_per_expert):
outputs = []
tokens_per_expert = tokens_per_expert.tolist()
# print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}")
chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0)
for chunk, expert in zip(chunks, self.experts):
for i, chunk in enumerate(chunks):
chunk = chunk.contiguous()
# assert chunk.shape[0] != 0, "Cannot dispatch empty input"
# print("expert token:", chunk.shape, flush=True)
# assert chunk.shape[0] != 0, "Cannot dispatch empty input"
expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device]
outputs += [expert(chunk)]

return paddle.concat(outputs, axis=0)
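
Note: since the expert list now stores None for experts hosted on other ranks, the loop above looks experts up by i + moe_rank * moe_num_experts_per_device instead of zipping over the full list. A simplified sketch of that dispatch (assumes dispatched_input is already grouped by local expert and tokens_per_expert is a per-local-expert token count; names are illustrative):

import paddle

def run_local_experts(dispatched_input, tokens_per_expert, experts, moe_rank, experts_per_device):
    # One contiguous chunk of tokens per local expert.
    chunks = paddle.split(dispatched_input, num_or_sections=list(tokens_per_expert), axis=0)
    outputs = []
    for i, chunk in enumerate(chunks):
        # Global slot of the i-th local expert; remote slots in `experts` are None.
        expert = experts[i + moe_rank * experts_per_device]
        outputs.append(expert(chunk.contiguous()))
    return paddle.concat(outputs, axis=0)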
@@ -377,3 +388,13 @@ def forward(self, hidden_states: paddle.Tensor):
expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
return output, l_aux, l_zloss

def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree):
assert (
moe_num_experts >= expert_parallel_degree
), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}"
assert (
moe_num_experts % expert_parallel_degree == 0
), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0"
moe_num_experts_per_device = moe_num_experts // expert_parallel_degree
return moe_num_experts_per_device
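
Note: together with the constructor change above, _parse_moe_expert_parallel implies that global expert i is hosted on expert-parallel rank i // moe_num_experts_per_device. A small pure-Python illustration (function name is illustrative):

def local_expert_ids(moe_num_experts, expert_parallel_degree, moe_rank):
    # Same divisibility checks as _parse_moe_expert_parallel.
    assert moe_num_experts % expert_parallel_degree == 0
    per_device = moe_num_experts // expert_parallel_degree
    return [i for i in range(moe_num_experts) if i // per_device == moe_rank]

# 8 experts over an expert-parallel group of 4: rank 2 hosts global experts [4, 5].
print(local_expert_ids(8, 4, 2))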