feat(ep): support independent combine dtype for EP dispatch/combine benchmarks #239
Merged
Motivation
In real EP+MoE workloads, dispatch and combine often use different data types (e.g., FP4 dispatch with BF16 combine after MoE dequantization). The existing benchmarks assumed dispatch and combine share the same dtype, making it impossible to measure cross-type performance accurately. This PR adds independent combine dtype support to both intranode and internode benchmarks, along with a single-phase tuning mode.
Technical Details
Intranode benchmark (`tests/python/ops/bench_dispatch_combine.py`)
- Added a `--combine-dtype` CLI argument to specify a separate dtype for combine
- Derives `combine_hidden_dim` internally based on FP4 packing (FP4 dispatch: `hidden_dim/2`; BF16 combine: full `hidden_dim`)
- Sizes buffers to `max(dispatch_hidden_dim, combine_hidden_dim)` to satisfy C++ buffer assertions
- Converts between dtypes via `unpack_fp4x2()` (FP4 → any) or `.to()` (non-FP4)
- Latency (`lat`) printing
- Introduced a `LaunchConfig` namedtuple for readability

Internode benchmark (`examples/ops/dispatch_combine/test_dispatch_combine_internode.py`)
- Added a `--combine-dtype` argument
- Launch parameters (`block_num`, `rdma_block_num`, `warp_per_block`) are forwarded through `run_dispatch`/`run_combine`
- Fixed a `sweep_bench` parameter order bug

Shared (`tests/python/ops/test_dispatch_combine.py`)
- `unpack_fp4x2(tensor, dtype=bf16)`: LUT-based FP4 E2M1 unpacking to any float dtype
- `check_combine_result`: accepts `combine_data_type`, skips verification only when combine is FP4 (not dispatch), enabling verification for FP4 dispatch + non-FP4 combine

Kernel fix (`src/ops/dispatch_combine/low_latency_async.cpp`)

Test Plan
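To illustrate the LUT-based FP4 E2M1 unpacking described under the shared test changes, here is a minimal dependency-free sketch. The real `unpack_fp4x2` operates on tensors with a `dtype` argument; this pure-Python version shows only the lookup-table decoding, and the low-nibble-first byte layout is an assumption not stated in the PR text.

```python
# E2M1 has 1 sign bit, 2 exponent bits, 1 mantissa bit. Positive codes
# 0..7 decode to 0, 0.5, 1, 1.5, 2, 3, 4, 6; codes 8..15 (sign bit set)
# are the negated values.
FP4_E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def unpack_fp4x2(packed: bytes) -> list[float]:
    """Unpack two FP4 E2M1 values per byte via table lookup.

    Assumes the low nibble holds the first value and the high nibble the
    second; the actual benchmark helper may use the opposite order.
    """
    out = []
    for byte in packed:
        out.append(FP4_E2M1_LUT[byte & 0x0F])  # low nibble
        out.append(FP4_E2M1_LUT[byte >> 4])    # high nibble
    return out
```

Because one byte packs two FP4 values, the output has twice as many elements as the input, which is also why the benchmark halves `hidden_dim` for FP4 dispatch buffers.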