diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index a08899111b1b..6e9f17f39205 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -33,8 +33,8 @@ import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.communication.reduce import ReduceOp
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -799,7 +799,7 @@ def __init__(self, config: DeepseekV2Config):
 
         for p in self.experts.parameters():
             setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
-
+            setattr(p, "is_moe_param", True)
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -851,6 +851,7 @@ def __init__(self, config: DeepseekV2Config):
 
         for p in self.experts.parameters():
             setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+            setattr(p, "is_moe_param", True)
 
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
@@ -895,7 +896,9 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
         self.num_heads = config.num_attention_heads
         self.num_local_heads = self.num_heads
         if config.tensor_parallel_degree > 1:
-            assert self.num_heads % config.tensor_parallel_degree == 0, f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
+            assert (
+                self.num_heads % config.tensor_parallel_degree == 0
+            ), f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
             self.num_local_heads = self.num_heads // config.tensor_parallel_degree
 
         self.max_position_embeddings = config.max_position_embeddings
@@ -1067,7 +1070,12 @@ def forward(
 
         if self.sequence_parallel:
             target_query_shape = [bsz, self.seq_length, self.num_local_heads, self.q_head_dim]
-            target_key_value_shape = [bsz, self.seq_length, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim]
+            target_key_value_shape = [
+                bsz,
+                self.seq_length,
+                self.num_local_heads,
+                self.qk_nope_head_dim + self.v_head_dim,
+            ]
         else:
             target_query_shape = [0, 0, self.num_heads, self.q_head_dim]
             target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim]
@@ -1153,7 +1161,6 @@ def forward(
 
         if attn_output.shape != ori_shape:
             attn_output = attn_output.reshape(ori_shape)
-
         if not output_attentions:
             attn_weights = None
@@ -1511,7 +1518,7 @@ def forward(
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
 
-        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+        hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1))
 
         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
             hidden_states,
@@ -1711,10 +1718,13 @@ def get_tensor_parallel_split_mappings(num_layers):
 
             return final_actions
 
-        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
+        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers + 2)
 
         return mappings
 
+    def get_tensor_parallel_mappings(self, is_split=True):
+        return type(self)._get_tensor_parallel_mappings(self.config, is_split)
+
     def _init_weights(self, layer):
         return
         if self.config.tensor_parallel_degree > 1:
@@ -1988,7 +1998,7 @@ def forward(
         if self.config.sequence_parallel:
             # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
             bs, seq_len, hidden_size = inputs_embeds.shape
-            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
             # inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size])
             # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
             inputs_embeds = ScatterOp.apply(inputs_embeds)
@@ -2071,7 +2081,7 @@ def forward(
 
             if self.config.sequence_parallel:
                 hidden_states = GatherOp.apply(hidden_states)
-                hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H]
+                hidden_states = paddle.transpose(hidden_states, [1, 0, 2])  # [S, B, H] --> [B, S, H]
                 # hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])
 
             inputs_embeds_cur_depth = paddle.concat(
@@ -2173,7 +2183,7 @@ def add_loss(main_loss, loss):
             seq_length = masked_lm_labels.shape[1]
 
             if self.config.sequence_parallel:
-                masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B]
+                masked_lm_labels = masked_lm_labels.transpose([1, 0])  # [B, S] --> [S, B]
                 masked_lm_labels = ScatterOp.apply(masked_lm_labels)
 
             loss = compute_loss(prediction_scores, masked_lm_labels)
@@ -2188,16 +2198,15 @@ def add_loss(main_loss, loss):
                 masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]
 
                 if self.config.sequence_parallel:
-                    masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B]
+                    masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0])  # [B, S] --> [S, B]
                     masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)
 
                 res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)
-
+
                 if self.config.sequence_parallel:
                     res_cur_depth = res_cur_depth * self.seq_para_scale
                     dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)
-
                 mtp_loss_res.append(res_cur_depth)
 
             loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res))  # fmt: skip
@@ -2245,9 +2254,9 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, hidden_states, tensor_parallel_output=None):
         # if self.config.sequence_parallel:
-        #     hidden_states = GatherOp.apply(hidden_states)
-        #     hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
-        #     hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])
+        # hidden_states = GatherOp.apply(hidden_states)
+        # hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+        # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])
 
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py
index 0ccd96cc1618..4a526feb6acb 100644
--- a/paddlenlp/transformers/moe_gate.py
+++ b/paddlenlp/transformers/moe_gate.py
@@ -578,11 +578,19 @@ def topkgating_nodrop(self, gates: paddle.Tensor):
 
         # get topk mask
         mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1)
 
+        # hongyu fix start
+        gates_masked = gates * mask
+        gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
+        denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
+
+        if self.norm_topk_prob:
+            gates_masked = gates_masked / denom_s
+        gates_masked *= self.routed_scaling_factor
+        # hongyu fix end
         if hasattr(self.config, "seq_aux") and self.config.seq_aux:
             l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx)
         else:
             l_aux = self._cal_aux_loss(gates, mask)
-
         exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
-        topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
-        return topk_masked_gates, mask, exp_counts, l_aux, l_zloss
+        # topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
+        return gates_masked, mask, exp_counts, l_aux, l_zloss
diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py
index f65f702d2a3f..723bf525d6df 100644
--- a/paddlenlp/transformers/moe_layer.py
+++ b/paddlenlp/transformers/moe_layer.py
@@ -175,7 +175,6 @@ def __init__(
             is_fleet_init = True
         except AttributeError:
             is_fleet_init = False
-
         if is_fleet_init and dist.get_world_size() > 1:
             if moe_group == "data":
                 self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
@@ -198,7 +197,6 @@ def __init__(
             self.expert_parallel_degree = 1
             self.moe_num_experts_per_device = self.moe_num_experts
             self.is_dummy_moe = True
-
         self.all_to_all_dropout = all_to_all_dropout
         self.enable_recompute = False
@@ -348,10 +346,21 @@ def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, m
         self.moe_router_topk = gate.top_k
         self.moe_num_experts = moe_num_experts
         self.num_local_experts = moe_num_experts // self.ep_size
+        self.moe_rank = dist.get_rank(self.moe_group)
+        self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank
         self.token_dispatcher = MoEFlexTokenDispatcher(
             self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group
         )
-        self.experts = nn.LayerList([expert_class(**expert_kwargs) for _ in range(self.num_local_experts)])
+        self.expert_parallel_degree = 1 if self.ep_size < 0 else self.ep_size
+        self.moe_num_experts_per_device = self._parse_moe_expert_parallel(
+            self.moe_num_experts, self.expert_parallel_degree
+        )
+        self.experts = nn.LayerList([])
+        for i in range(self.moe_num_experts):
+            if i // self.moe_num_experts_per_device == self.moe_rank:
+                self.experts.append(expert_class(**expert_kwargs))
+            else:
+                self.experts.append(None)
         self.router = gate
@@ -359,10 +368,12 @@ def expert_forward(self, dispatched_input, tokens_per_expert):
         tokens_per_expert = tokens_per_expert.tolist()
         # print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}")
         chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0)
-        for chunk, expert in zip(chunks, self.experts):
+        for i, chunk in enumerate(chunks):
             chunk = chunk.contiguous()
             # assert chunk.shape[0] != 0, "Cannot dispatch empty input"
             # print("expert token:", chunk.shape, flush=True)
+            # assert chunk.shape[0] != 0, "Cannot dispatch empty input"
+            expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device]
             outputs += [expert(chunk)]
         return paddle.concat(outputs, axis=0)
@@ -377,3 +388,13 @@ def forward(self, hidden_states: paddle.Tensor):
         expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
         output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
         return output, l_aux, l_zloss
+
+    def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree):
+        assert (
+            moe_num_experts >= expert_parallel_degree
+        ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}"
+        assert (
+            moe_num_experts % expert_parallel_degree == 0
+        ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0"
+        moe_num_experts_per_device = moe_num_experts // expert_parallel_degree
+        return moe_num_experts_per_device
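The moe_gate.py hunk stops returning the raw top-k scatter (`topk_masked_gates`) and instead returns the masked gate probabilities, optionally renormalized and then scaled by `routed_scaling_factor`. Below is a minimal standalone sketch of that math; the function name `topk_gate_probs` and the plain scalar arguments are illustrative, not part of the patch.

```python
# Hedged sketch of the new topkgating_nodrop return value: combine weights are the
# masked router probabilities (optionally renormalized), not a scatter of top_gate.
import paddle


def topk_gate_probs(gates, top_k=2, norm_topk_prob=True, routed_scaling_factor=1.0):
    _, top_idx = paddle.topk(gates, k=top_k, axis=1)
    # 0/1 mask over the selected experts, same construction as the patched code
    mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1)
    gates_masked = gates * mask
    denom = paddle.clip(
        paddle.sum(gates_masked, axis=-1, keepdim=True),
        min=paddle.finfo(gates_masked.dtype).eps,
    )
    if norm_topk_prob:
        gates_masked = gates_masked / denom  # selected probs sum to 1 per token
    gates_masked *= routed_scaling_factor
    return gates_masked, mask


probs = paddle.nn.functional.softmax(paddle.rand([4, 8]), axis=-1)  # [tokens, experts]
combine_weights, mask = topk_gate_probs(probs, top_k=2, routed_scaling_factor=2.5)
```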
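The moe_layer.py change keeps `self.experts` at global length (`moe_num_experts` entries), materializes only the experts owned by the local EP rank, and leaves `None` placeholders elsewhere; `expert_forward` then maps local chunk `i` to global index `i + moe_rank * moe_num_experts_per_device`. A paddle-free sketch of that placement, with illustrative names (`place_experts`, `make_expert`) that are not part of the patch:

```python
# Hedged sketch of the per-rank expert placement: global expert i lives on rank
# i // experts_per_device, and a rank's i-th local chunk maps to global index
# i + rank * experts_per_device.
def place_experts(moe_num_experts, ep_size, moe_rank, make_expert):
    assert moe_num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
    experts_per_device = moe_num_experts // ep_size
    return [
        make_expert(i) if i // experts_per_device == moe_rank else None
        for i in range(moe_num_experts)
    ]


experts = place_experts(8, ep_size=4, moe_rank=2, make_expert=lambda i: f"expert_{i}")
print(experts)             # [None, None, None, None, 'expert_4', 'expert_5', None, None]
print(experts[0 + 2 * 2])  # local chunk 0 on rank 2 -> global expert 4
```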