get_device_context() allocates a new tensor on every call #466

@aamarnat

Description

Bug

In iris/iris.py, get_device_context() calls torch.tensor(context_data, dtype=torch.int64, device=self.device) every time it is invoked. The context data (rank, world_size, heap_bases) never changes after initialization.

This causes two problems:

  1. Performance: Unnecessary GPU memory allocation on every kernel call.
  2. CUDAGraph/HIPGraph incompatibility: torch.tensor() allocates GPU memory, which is forbidden during graph capture. Any kernel that calls get_device_context() in its launch path cannot be captured in a CUDAGraph.

Impact

Blocks CUDAGraph integration for any iris-backed fused kernel (e.g., matmul_all_reduce used via vLLM's torch.compile pipeline). Minor performance overhead otherwise.

Fix

Cache the tensor on first call:

def get_device_context(self):
    if self._cached_device_context is not None:
        return self._cached_device_context
    # ... existing construction logic ...
    self._cached_device_context = context_tensor
    return context_tensor

Add self._cached_device_context = None in __init__.
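The caching pattern above can be sketched as a self-contained example. This is a minimal illustration, not the actual iris implementation: the class and its constructor arguments are hypothetical, and a plain list stands in for the `torch.tensor` allocation so the lazy-init logic is the focus.

```python
class DeviceContext:
    """Hypothetical stand-in for the iris context holder.

    A plain list substitutes for the torch tensor; the point is that
    construction happens at most once, on the first call.
    """

    def __init__(self, rank, world_size, heap_bases):
        self.rank = rank
        self.world_size = world_size
        self.heap_bases = heap_bases
        # Populated lazily by get_device_context(); mirrors the proposed
        # self._cached_device_context = None in __init__.
        self._cached_device_context = None

    def get_device_context(self):
        # Fast path: return the cached object, allocating nothing.
        if self._cached_device_context is not None:
            return self._cached_device_context
        # Slow path (first call only): build the context once.
        # In iris this would be torch.tensor(context_data, dtype=torch.int64,
        # device=self.device).
        context_data = [self.rank, self.world_size, *self.heap_bases]
        self._cached_device_context = context_data
        return context_data
```

Because the cached object is identical on every subsequent call, no allocation occurs inside the kernel launch path, which is exactly what CUDAGraph/HIPGraph capture requires.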

Component

iris/iris.py

Metadata

Labels

bug (Something isn't working), iris (Iris project issue)
