Communication opt #1286

Merged: hiworldwzj merged 10 commits into main from communication_opt on May 7, 2026
Conversation

@blueswhen

Collaborator

No description provided.

@gemini-code-assist (bot) left a comment

Contributor

Code Review

This pull request introduces a high-performance all-reduce stack incorporating FlashInfer and PyTorch symmetric memory (SymmMem), updating the communication dispatch logic to prioritize these backends for compatible tensors. It also adjusts the default decode attention backend priority to favor FlashInfer. Feedback focuses on several critical issues: the use of non-existent group_name attributes on ProcessGroup objects which will cause runtime errors, hardcoded data types and missing device checks in the SymmMem backend that limit its applicability and performance, and the problematic use of global random seeding in the FlashInfer workspace initialization.


try:
    self.buffer = torch_symm_mem.empty(self.max_size // dtype.itemsize, device=device, dtype=dtype)
    handle = torch_symm_mem.rendezvous(self.buffer, group.group_name)

Contributor

high

Standard PyTorch ProcessGroup objects do not have a group_name attribute. Accessing group.group_name directly will likely raise an AttributeError at runtime. You should use torch.distributed.distributed_c10d._get_group_name(group) to retrieve the internal group name string required by the symm_mem APIs. This issue also affects lines 89 and 91.
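
Whether group_name is available can depend on the installed PyTorch build, so one cautious option is to probe for the attribute instead of assuming either way. The sketch below is illustrative only: resolve_group_name and its fallback parameter are hypothetical names, not part of this PR or of PyTorch, and fallback stands for whatever lookup function the project settles on.

```python
from typing import Any, Callable, Optional


def resolve_group_name(group: Any, fallback: Optional[Callable[[Any], str]] = None) -> str:
    """Return the name string for `group` (hypothetical helper).

    Prefers a `group_name` attribute when the object exposes one, otherwise
    defers to a caller-supplied lookup function.
    """
    name = getattr(group, "group_name", None)
    if isinstance(name, str) and name:
        return name
    if fallback is not None:
        return fallback(group)
    raise AttributeError(
        f"{type(group).__name__} exposes no group_name and no fallback was given"
    )
```

Failing loudly when neither path works keeps a missing attribute from surfacing later as an opaque rendezvous error.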

self._workspace = None
rng_state = random.getstate()
try:
    random.seed(int.from_bytes(os.urandom(16), byteorder="big"))

Contributor

medium

Seeding the global random module with os.urandom inside a library function is generally discouraged as it can have side effects on other parts of the application. Furthermore, in a distributed setting, if flashinfer_comm.create_allreduce_fusion_workspace relies on this randomness for any coordination or resource naming, having different seeds across ranks (which is guaranteed here by os.urandom) might lead to hangs or inconsistent states. If a unique identifier is required, consider using uuid or a rank-based approach that doesn't mutate global state.
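
A minimal sketch of the two alternatives mentioned here, using only the Python standard library. The names preserved_global_rng and unique_token are hypothetical and not taken from the PR. The first helper only helps if the callee really does read the global random module, which this review does not establish.

```python
import random
import uuid
from contextlib import contextmanager


@contextmanager
def preserved_global_rng():
    """Restore the global `random` state on exit, even if the body raises."""
    state = random.getstate()
    try:
        yield
    finally:
        random.setstate(state)


def unique_token() -> str:
    """Return a unique hex identifier without touching the global RNG.

    uuid4 draws from the OS entropy source, not from the `random` module.
    """
    return uuid.uuid4().hex
```

If every rank needs the same identifier, one rank would have to generate the token and share it with the others; that step is outside the scope of this sketch.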

class SymmMemAllreduce:
    """In-place all-reduce via torch symmetric memory (NVLink SHARP / NVLS)."""

    def __init__(self, group: ProcessGroup, device, dtype: torch.dtype = torch.bfloat16) -> None:

Contributor

medium

The dtype is hardcoded to torch.bfloat16 in the constructor and is not updated when initialized in CustomProcessGroup. This will cause the SymmMemAllreduce backend to be skipped for models using float16 because should_use (line 76) strictly checks for dtype equality. It should be initialized with the model's actual compute/KV dtype (e.g., by passing it from CustomProcessGroup which can access it via get_env_start_args()).

)

def should_use(self, inp: torch.Tensor) -> bool:
    if self.disabled or inp.dtype != self.dtype or not inp.is_contiguous():

Contributor

medium

The should_use method should verify that the input tensor is on the same CUDA device as the pre-allocated self.buffer. If inp.device differs from self.device, the copy_ operation on line 87 will trigger a cross-device transfer, which is significantly slower and defeats the purpose of using a high-performance communication backend.
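
The suggested check could look roughly like the sketch below. EligibilityCheck is a hypothetical stand-in for the state that should_use reads, not the PR's class; it is duck-typed so it runs without torch, relying only on the dtype, device, and is_contiguous() members that torch.Tensor provides. Any further conditions in the real method (its body is truncated in the snippet above) are not reproduced.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class EligibilityCheck:
    """Hypothetical stand-in for the fields SymmMemAllreduce.should_use reads."""

    disabled: bool
    dtype: Any   # e.g. torch.bfloat16
    device: Any  # e.g. torch.device("cuda", 0)

    def should_use(self, inp: Any) -> bool:
        # The three conditions visible in the PR snippet.
        if self.disabled or inp.dtype != self.dtype or not inp.is_contiguous():
            return False
        # The extra condition requested in review: reject tensors living on
        # a different device than the pre-allocated buffer.
        return inp.device == self.device
```

Returning False here lets the dispatcher fall through to another backend instead of paying for an implicit cross-device copy.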

@blueswhen force-pushed the communication_opt branch from d378b74 to cfd82c9 on May 6, 2026 09:12
@hiworldwzj merged commit 28254a9 into main on May 7, 2026
1 check passed
@hiworldwzj deleted the communication_opt branch on May 7, 2026 08:35