Labels: Bug

Description
In iris/iris.py, get_device_context() calls torch.tensor(context_data, dtype=torch.int64, device=self.device) every time it is invoked. The context data (rank, world_size, heap_bases) never changes after initialization.
This causes two problems:
- Performance: Unnecessary GPU memory allocation on every kernel call.
- CUDAGraph/HIPGraph incompatibility: `torch.tensor()` allocates GPU memory, which is forbidden during graph capture. Any kernel that calls `get_device_context()` in its launch path therefore cannot be captured in a CUDAGraph.
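A torch-free sketch of the problematic pattern (hypothetical names, not the real iris API): every call rebuilds the same context object, so a counter stands in for the GPU allocation that graph capture would reject.

```python
# Minimal sketch of the uncached pattern. `UncachedIris` is a hypothetical
# stand-in for the iris class; `allocations` counts where torch.tensor(...)
# would allocate GPU memory.

class UncachedIris:
    def __init__(self, rank, world_size, heap_bases):
        # Static after initialization, just like rank/world_size/heap_bases.
        self.context_data = [rank, world_size, *heap_bases]
        self.allocations = 0

    def get_device_context(self):
        self.allocations += 1           # torch.tensor(...) would run here
        return list(self.context_data)  # fresh object on every call

iris = UncachedIris(rank=0, world_size=2, heap_bases=[0x1000, 0x2000])
for _ in range(3):
    iris.get_device_context()
print(iris.allocations)  # 3 allocations for identical, never-changing data
```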
Impact
Blocks CUDAGraph integration for any iris-backed fused kernel (e.g., matmul_all_reduce used via vLLM's torch.compile pipeline). Minor performance overhead otherwise.
Fix
Cache the tensor on first call:
```python
def get_device_context(self):
    if self._cached_device_context is not None:
        return self._cached_device_context
    # ... existing construction logic ...
    self._cached_device_context = context_tensor
    return context_tensor
```

Add `self._cached_device_context = None` in `__init__`.
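A self-contained sketch of the cached version under the same hypothetical names: construction runs exactly once, and every later call returns the same object, so nothing allocates inside a capture.

```python
# Sketch of the caching fix. `CachedIris` is hypothetical; `allocations`
# again stands in for GPU allocations that would break graph capture.

class CachedIris:
    def __init__(self, rank, world_size, heap_bases):
        self.context_data = [rank, world_size, *heap_bases]
        self._cached_device_context = None  # filled lazily on first call
        self.allocations = 0

    def get_device_context(self):
        if self._cached_device_context is not None:
            return self._cached_device_context
        self.allocations += 1  # construction happens exactly once
        self._cached_device_context = tuple(self.context_data)
        return self._cached_device_context

iris = CachedIris(rank=0, world_size=2, heap_bases=[0x1000, 0x2000])
a = iris.get_device_context()
b = iris.get_device_context()
assert a is b and iris.allocations == 1  # one allocation, shared object
```

If the heap bases could ever change after initialization, the cache would need invalidation, but per the description above the context data is fixed once set up.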
Component
iris/iris.py