Communication opt#1286
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a high-performance all-reduce stack incorporating FlashInfer and PyTorch symmetric memory (SymmMem), updating the communication dispatch logic to prioritize these backends for compatible tensors. It also adjusts the default decode attention backend priority to favor FlashInfer. Feedback focuses on several critical issues: the use of non-existent group_name attributes on ProcessGroup objects which will cause runtime errors, hardcoded data types and missing device checks in the SymmMem backend that limit its applicability and performance, and the problematic use of global random seeding in the FlashInfer workspace initialization.
|
|
||
| try: | ||
| self.buffer = torch_symm_mem.empty(self.max_size // dtype.itemsize, device=device, dtype=dtype) | ||
| handle = torch_symm_mem.rendezvous(self.buffer, group.group_name) |
There was a problem hiding this comment.
Standard PyTorch ProcessGroup objects do not have a group_name attribute. Accessing group.group_name directly will likely raise an AttributeError at runtime. You should use torch.distributed.distributed_c10d._get_group_name(group) to retrieve the internal group name string required by the symm_mem APIs. This issue also affects lines 89 and 91.
| self._workspace = None | ||
| rng_state = random.getstate() | ||
| try: | ||
| random.seed(int.from_bytes(os.urandom(16), byteorder="big")) |
There was a problem hiding this comment.
Seeding the global random module with os.urandom inside a library function is generally discouraged as it can have side effects on other parts of the application. Furthermore, in a distributed setting, if flashinfer_comm.create_allreduce_fusion_workspace relies on this randomness for any coordination or resource naming, having different seeds across ranks (which is guaranteed here by os.urandom) might lead to hangs or inconsistent states. If a unique identifier is required, consider using uuid or a rank-based approach that doesn't mutate global state.
| class SymmMemAllreduce: | ||
| """In-place all-reduce via torch symmetric memory (NVLink SHARP / NVLS).""" | ||
|
|
||
| def __init__(self, group: ProcessGroup, device, dtype: torch.dtype = torch.bfloat16) -> None: |
There was a problem hiding this comment.
The dtype is hardcoded to torch.bfloat16 in the constructor and is not updated when initialized in CustomProcessGroup. This will cause the SymmMemAllreduce backend to be skipped for models using float16 because should_use (line 76) strictly checks for dtype equality. It should be initialized with the model's actual compute/KV dtype (e.g., by passing it from CustomProcessGroup which can access it via get_env_start_args()).
| ) | ||
|
|
||
| def should_use(self, inp: torch.Tensor) -> bool: | ||
| if self.disabled or inp.dtype != self.dtype or not inp.is_contiguous(): |
There was a problem hiding this comment.
The should_use method should verify that the input tensor is on the same CUDA device as the pre-allocated self.buffer. If inp.device differs from self.device, the copy_ operation on line 87 will trigger a cross-device transfer, which is significantly slower and defeats the purpose of using a high-performance communication backend.
d378b74 to
cfd82c9
Compare
No description provided.