15 changes: 15 additions & 0 deletions iris/ccl/all_gather.py
@@ -6,9 +6,12 @@
Gathers tensors from all ranks and concatenates them along the last dimension.
"""

import logging

import triton
import triton.language as tl
import iris
from iris.host.logging.logging import _log_rank
from iris.host.tracing.kernel_artifacts import iris_launch
from .config import Config
from .utils import extract_group_info
@@ -484,6 +487,18 @@ def all_gather(
rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)

M, N = input_tensor.shape[:2]
_log_rank(
logging.DEBUG,
"all_gather: shape=(%d,%d) dtype=%s rank=%d/%d async_op=%s",
M,
N,
input_tensor.dtype,
rank_global,
world_size,
async_op,
rank=rank_global,
num_ranks=world_size,
)
expected_output_shape = (world_size * M, N)

if output_tensor.shape[:2] != expected_output_shape:
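With the IrisFormatter changes in iris/host/logging/logging.py below, this call would render roughly as follows (values illustrative, assuming rank 0 of 8 and a 1024x1024 fp16 input; the all_reduce, all_to_all, and reduce_scatter calls in this PR emit the same pattern):

    14:03:07 DEBUG [Iris] [0/8] [all_gather] all_gather: shape=(1024,1024) dtype=torch.float16 rank=0/8 async_op=False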
25 changes: 25 additions & 0 deletions iris/ccl/all_reduce.py
@@ -6,13 +6,15 @@
Supports multiple variants: atomic, spinlock, ring, two-shot, and one-shot.
"""

import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import triton
import triton.language as tl
import torch
import iris
from iris.host.logging.logging import _log_rank
from iris.host.tracing.kernel_artifacts import iris_launch
from .config import Config
from .utils import chiplet_transform_chunked, ReduceOp, extract_group_info
@@ -75,6 +77,16 @@ def all_reduce_preamble(

M, N = input_tensor.shape[:2]
dtype = input_tensor.dtype
_log_rank(
logging.DEBUG,
"all_reduce_preamble: variant=%s shape=(%d,%d) dtype=%s",
variant,
M,
N,
dtype,
rank=shmem.cur_rank,
num_ranks=shmem.num_ranks,
)

if workspace is None:
workspace = AllReduceWorkspace()
@@ -767,6 +779,19 @@ def all_reduce(
stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1)

variant = config.all_reduce_variant.lower()
_log_rank(
logging.DEBUG,
"all_reduce: variant=%s shape=(%d,%d) dtype=%s rank=%d/%d async_op=%s",
variant,
M,
N,
input_tensor.dtype,
rank_global,
world_size,
async_op,
rank=rank_global,
num_ranks=world_size,
)
if variant not in [
VARIANT_ATOMIC,
VARIANT_SPINLOCK,
16 changes: 16 additions & 0 deletions iris/ccl/all_to_all.py
@@ -6,9 +6,12 @@
Supports both Triton and Gluon implementations based on config.
"""

import logging

import triton
import triton.language as tl
import iris
from iris.host.logging.logging import _log_rank
from iris.host.tracing.kernel_artifacts import iris_launch
from .config import Config
from .utils import chiplet_transform_chunked, extract_group_info
@@ -368,6 +371,19 @@ def all_to_all(

M, total_N = input_tensor.shape[:2]
N = total_N // world_size
_log_rank(
logging.DEBUG,
"all_to_all: shape=(%d,%d) N_per_rank=%d dtype=%s rank=%d/%d async_op=%s",
M,
total_N,
N,
input_tensor.dtype,
rank_global,
world_size,
async_op,
rank=rank_global,
num_ranks=world_size,
)

stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1)
stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1)
15 changes: 15 additions & 0 deletions iris/ccl/reduce_scatter.py
@@ -6,9 +6,12 @@
Uses the two-shot approach: reduce assigned tiles and store only to own rank.
"""

import logging

import triton
import triton.language as tl
import iris
from iris.host.logging.logging import _log_rank
from iris.host.tracing.kernel_artifacts import iris_launch
from .config import Config
from .utils import chiplet_transform_chunked, ReduceOp, extract_group_info
@@ -214,6 +217,18 @@ def reduce_scatter(
# rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations
rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)
M, N = input_tensor.shape[:2]
_log_rank(
logging.DEBUG,
"reduce_scatter: shape=(%d,%d) dtype=%s rank=%d/%d async_op=%s",
M,
N,
input_tensor.dtype,
rank_global,
world_size,
async_op,
rank=rank_global,
num_ranks=world_size,
)

# Validate output shape matches input shape
if output_tensor.shape[:2] != (M, N):
12 changes: 12 additions & 0 deletions iris/host/distributed/fd_passing.py
@@ -156,6 +156,18 @@ def setup_fd_infrastructure(cur_rank: int, num_ranks: int):
if num_ranks <= 1:
return None

import logging
from iris.host.logging.logging import _log_rank

_log_rank(
logging.DEBUG,
"setup_fd_infrastructure: rank=%d num_ranks=%d",
cur_rank,
num_ranks,
rank=cur_rank,
num_ranks=num_ranks,
)

import torch.distributed as dist
from iris.host.distributed.helpers import distributed_barrier

27 changes: 27 additions & 0 deletions iris/host/distributed/helpers.py
@@ -2,11 +2,14 @@
# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.


import logging

import torch
import torch.distributed as dist
import numpy as np
import triton
import triton.language as tl
from iris.host.logging.logging import _log_rank
from iris.host.tracing.kernel_artifacts import iris_launch


@@ -58,6 +61,15 @@ def distributed_allgather(data):
world_size = dist.get_world_size()
device = _infer_device()
backend = str(dist.get_backend()).lower()
_log_rank(
logging.DEBUG,
"distributed_allgather: shape=%s backend=%s world_size=%d",
data.shape,
backend,
world_size,
rank=dist.get_rank(),
num_ranks=world_size,
)

# Fast path: tensor all_gather if dtype is NCCL-supported or backend != nccl
data_tensor = torch.from_numpy(data)
@@ -180,6 +192,14 @@ def distributed_broadcast_tensor(value_to_broadcast=None, root=0):
rank = dist.get_rank()
device = _infer_device()
backend = str(dist.get_backend()).lower()
_log_rank(
logging.DEBUG,
"distributed_broadcast_tensor: src=%d rank=%d",
root,
rank,
rank=rank,
num_ranks=dist.get_world_size(),
)

if rank == root:
if value_to_broadcast is None:
@@ -291,6 +311,13 @@ def distributed_barrier(group=None):
"""
if not dist.is_initialized():
raise RuntimeError("PyTorch distributed is not initialized")
_log_rank(
logging.DEBUG,
"distributed_barrier: group=%s",
group,
rank=dist.get_rank(),
num_ranks=dist.get_world_size(),
)
dist.barrier(group=group)


7 changes: 7 additions & 0 deletions iris/host/iris.py
@@ -105,6 +105,12 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"):
self.gpu_id = gpu_id
self.heap_size = heap_size

if logger.isEnabledFor(logging.INFO):
self._log_with_rank(
logging.INFO,
f"init: heap_size={heap_size / (1 << 30):.1f}GB rank={cur_rank}/{num_ranks} allocator={allocator_type}",
)

# Initialize symmetric heap with specified allocator
self.heap = SymmetricHeap(heap_size, gpu_id, cur_rank, num_ranks, allocator_type)
self.device = f"cuda:{gpu_id}"
@@ -997,6 +1003,7 @@ def barrier(self, stream=None, group=None):
>>> ctx.barrier() # Synchronize all ranks
>>> ctx.barrier(group=my_group) # Synchronize only ranks in my_group
"""
self._log_with_rank(logging.DEBUG, "barrier: start")
# Wait for all GPUs to finish work
if stream is None:
torch.cuda.synchronize()
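A note on the isEnabledFor guard in __init__ above: the message there is an f-string, which is built eagerly, so the guard skips that formatting work when INFO is disabled, while the %-style calls used elsewhere in this PR defer string formatting until a handler actually accepts the record. A minimal standalone sketch of the difference (illustrative only, not iris code):

    import logging

    logger = logging.getLogger("example")
    logger.setLevel(logging.WARNING)

    heap_size = 1 << 30
    # f-string: the message string is built even though INFO is disabled.
    logger.info(f"init: heap_size={heap_size / (1 << 30):.1f}GB")
    # %-style: the division still runs, but string formatting is skipped
    # because no handler will emit the record.
    logger.info("init: heap_size=%.1fGB", heap_size / (1 << 30))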
50 changes: 41 additions & 9 deletions iris/host/logging/logging.py
@@ -6,6 +6,8 @@
"""

import logging
import os
import sys

# Logging constants (compatible with Python logging levels)
DEBUG = logging.DEBUG
@@ -15,20 +17,23 @@


class IrisFormatter(logging.Formatter):
"""Custom formatter that automatically includes rank information when available."""
"""Custom formatter that includes timestamp, level, rank, and module information."""

def __init__(self):
super().__init__()

def format(self, record):
# Check if rank information is available in the record
if hasattr(record, "iris_rank") and hasattr(record, "iris_num_ranks"):
prefix = f"[Iris] [{record.iris_rank}/{record.iris_num_ranks}]"
else:
prefix = "[Iris]"

# Format the message with the appropriate prefix
return f"{prefix} {record.getMessage()}"
rank = getattr(record, "iris_rank", "?")
num_ranks = getattr(record, "iris_num_ranks", "?")
ts = self.formatTime(record, "%H:%M:%S")
level = record.levelname
# Only show [module] for internal iris logs (set by _log_rank),
# not for user-facing ctx.info()/ctx.debug() etc.
iris_internal = getattr(record, "iris_internal", False)
if iris_internal:
module = record.module
return f"{ts} {level:<5s} [Iris] [{rank}/{num_ranks}] [{module}] {record.getMessage()}"
return f"{ts} {level:<5s} [Iris] [{rank}/{num_ranks}] {record.getMessage()}"


# Logger instance that can be accessed as iris.logger
@@ -37,6 +42,11 @@ def format(self, record):
# Set up iris logger
logger.setLevel(logging.INFO) # Default level

# Override from environment
_env_level = os.environ.get("IRIS_LOG_LEVEL", "").upper()
if _env_level in ("DEBUG", "INFO", "WARNING", "ERROR"):
logger.setLevel(getattr(logging, _env_level))

# Add a console handler if none exists
if not logger.handlers:
_console_handler = logging.StreamHandler()
@@ -45,6 +55,28 @@ def format(self, record):
logger.addHandler(_console_handler)


def _log_rank(level, msg, *args, rank=None, num_ranks=None):
"""Log with optional rank injection. Captures caller's module automatically."""
if logger.isEnabledFor(level):
# Capture caller's file/line so the formatter can show [module]
frame = sys._getframe(1)
record = logging.LogRecord(
name=logger.name,
level=level,
pathname=frame.f_code.co_filename,
lineno=frame.f_lineno,
msg=msg,
args=args,
exc_info=None,
)
record.iris_internal = True
if rank is not None:
record.iris_rank = rank
if num_ranks is not None:
record.iris_num_ranks = num_ranks
logger.handle(record)


def set_logger_level(level):
"""
Set the logging level for the iris logger.
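Taken together, a minimal usage sketch of the pieces added above (values illustrative; _log_rank is an internal helper, called directly here only to show the mechanics):

    import logging
    from iris.host.logging.logging import _log_rank, set_logger_level

    # Raise the level programmatically, or set IRIS_LOG_LEVEL=DEBUG in the
    # environment before the process starts.
    set_logger_level(logging.DEBUG)

    # %-style args keep formatting lazy; rank/num_ranks are attached to the
    # record so IrisFormatter can render the [rank/num_ranks] prefix, and the
    # caller's frame supplies the [module] tag.
    _log_rank(logging.DEBUG, "my_op: shape=(%d,%d)", 128, 256, rank=0, num_ranks=8)

which, if run from a file named my_script.py, would print something like:

    14:03:07 DEBUG [Iris] [0/8] [my_script] my_op: shape=(128,256)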
29 changes: 29 additions & 0 deletions iris/host/memory/allocators/torch_allocator.py
@@ -8,13 +8,15 @@
sub-allocations within it using bump allocation.
"""

import logging
import math
import numpy as np
import torch
from typing import Optional, Dict
import struct

from .base import BaseAllocator
from iris.host.logging.logging import _log_rank
from iris.host.platform.hip import export_dmabuf_handle, import_dmabuf_handle, destroy_external_memory
from iris.host.distributed.fd_passing import send_fd, recv_fd, managed_fd
from iris.host.platform.utils import is_simulation_env
@@ -41,6 +43,14 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int
super().__init__(heap_size, device_id, cur_rank, num_ranks)

self.device = f"cuda:{device_id}"
_log_rank(
logging.INFO,
"TorchAllocator: init heap_size=%.1fGB device=%d",
heap_size / (1 << 30),
device_id,
rank=cur_rank,
num_ranks=num_ranks,
)
if is_simulation_env():
import json

@@ -93,7 +103,26 @@ def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024)
size_in_bytes = num_elements * element_size
aligned_size = math.ceil(size_in_bytes / alignment) * alignment

_log_rank(
logging.DEBUG,
"TorchAllocator.allocate: num_elements=%d dtype=%s size_bytes=%d offset=%d",
num_elements,
dtype,
size_in_bytes,
self.heap_offset,
rank=self.cur_rank,
num_ranks=self.num_ranks,
)

if self.heap_offset + aligned_size > self.heap_size:
_log_rank(
logging.ERROR,
"TorchAllocator: OOM requested=%d available=%d",
aligned_size,
self.heap_size - self.heap_offset,
rank=self.cur_rank,
num_ranks=self.num_ranks,
)
raise MemoryError("Heap out of memory")

start = self.heap_offset
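The allocate path above is a bump allocator: every request is rounded up to the alignment and a single offset advances. A standalone sketch of that arithmetic (simplified; the dtype, dmabuf, and simulation plumbing are omitted):

    import math

    def bump_alloc(heap_offset, heap_size, size_in_bytes, alignment=1024):
        # Round the request up to the next multiple of `alignment`.
        aligned_size = math.ceil(size_in_bytes / alignment) * alignment
        if heap_offset + aligned_size > heap_size:
            raise MemoryError("Heap out of memory")
        start = heap_offset
        # Return the allocation offset and the new bump pointer.
        return start, heap_offset + aligned_size

    # A 1000-byte request consumes a full 1024-byte slot:
    start, new_offset = bump_alloc(0, 1 << 30, 1000)
    assert (start, new_offset) == (0, 1024)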