diff --git a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu index 8a8fb111697..369a92ee2eb 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu @@ -964,7 +964,7 @@ std::vector EPMoeExpertDispatchFP8( } else { token_rows = input_dims[0]; } - const int num_rows = token_rows; + const int hidden_size = input.dims()[input_dims.size() - 1]; const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0]; @@ -988,9 +988,9 @@ std::vector EPMoeExpertDispatchFP8( auto dst_weights = GetEmptyTensor( {token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place); auto dst_indices = GetEmptyTensor( - {num_rows, num_experts_per_rank}, paddle::DataType::INT32, place); + {token_rows, num_experts_per_rank}, paddle::DataType::INT32, place); auto permute_indices_per_token = paddle::full( - {num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place); + {num_experts_per_rank, token_rows}, -1, paddle::DataType::INT32, place); auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place); @@ -1001,7 +1001,7 @@ std::vector EPMoeExpertDispatchFP8( num_experts_per_rank_tensor, num_experts_per_rank_padded_tensor, moe_topk, - num_rows, + token_rows, -1, -1, hidden_size, diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu index 77fa719f781..bd783df817a 100644 --- a/custom_ops/gpu_ops/per_token_quant_fp8.cu +++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu @@ -232,6 +232,11 @@ std::vector PerTokenQuantPadding(paddle::Tensor &input, auto input_dim = input.dims(); const int token_num = input_dim[0]; const int hidden_size = input_dim[1]; + + PADDLE_ENFORCE(block_size == 128, "now only support block_size = 128"); + PADDLE_ENFORCE(hidden_size % 128 == 0, + "hidden_size must be divisible by 128"); + const int hidden_size_scale = hidden_size / block_size; auto quanted_x = GetEmptyTensor( {token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place());