benchmark/examples/benchmark_all_gather_gemm_pull.py (11 changes: 9 additions & 2 deletions)

```diff
@@ -63,7 +63,9 @@ def parse_args():
     parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for the kernel")
     parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for the kernel")
     parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension")
-    parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel")
+    parser.add_argument(
+        "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)"
+    )

     parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.")

@@ -138,7 +140,12 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace):
     A_local_iris = shmem.empty((M, K_local), dtype=datatype)
     A_local_iris.copy_(A_local)

-    num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+    # Use provided num_sms or auto-detect
+    if run_args["num_sms"] is None:
+        num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+        run_args["num_sms"] = num_sms
+    else:
+        num_sms = run_args["num_sms"]

     main_stream = torch.cuda.Stream()
     kernel_timing = {
```
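Both all-gather benchmarks (pull and push, below) now resolve `num_sms` the same way: take the CLI value if one was given, otherwise query the device. Here is a minimal sketch of that fallback as a standalone helper; `resolve_num_sms` is hypothetical (not part of this diff), and the CU counts in the comments are commonly quoted values, not something the diff asserts.

```python
# Sketch only: mirrors the auto-detection fallback added above.
import torch

def resolve_num_sms(requested, device=0):
    """Return the user-supplied SM/CU count, or query the device."""
    if requested is not None:
        return requested
    # On ROCm builds of PyTorch, multi_processor_count reports the GPU's
    # compute-unit count (e.g. 304 on MI300X); on NVIDIA parts it is the
    # streaming-multiprocessor count.
    return torch.cuda.get_device_properties(device).multi_processor_count

# Usage mirroring the worker code above:
#   num_sms = resolve_num_sms(run_args["num_sms"], rank)
```

Writing the detected value back into `run_args` (as the diff does) presumably keeps downstream reporting consistent with what actually ran.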
benchmark/examples/benchmark_all_gather_gemm_push.py (11 changes: 9 additions & 2 deletions)

```diff
@@ -62,7 +62,9 @@ def parse_args():
     parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for GEMM computation")
     parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for tiling")
     parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension")
-    parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel")
+    parser.add_argument(
+        "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)"
+    )

     parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.")

@@ -142,7 +144,12 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace):
     num_k_tiles = (K_local + run_args["BLK_K"] - 1) // run_args["BLK_K"]
     signal_flags_iris = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32)

-    num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+    # Use provided num_sms or auto-detect
+    if run_args["num_sms"] is None:
+        num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+        run_args["num_sms"] = num_sms
+    else:
+        num_sms = run_args["num_sms"]

     main_stream = torch.cuda.Stream()
     kernel_timing = {
```
examples/07_gemm_all_scatter/benchmark.py (14 changes: 12 additions & 2 deletions)

```diff
@@ -52,7 +52,12 @@ def parse_args():
     parser.add_argument("--BLK_K", type=int, default=64, help="Block size K")
     parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
-    parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for persistent GEMM algorithm")
+    parser.add_argument(
+        "--gemm_sms",
+        type=int,
+        default=None,
+        help="Number of SMs for persistent GEMM algorithm (default: auto-detected)",
+    )
     parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

     return vars(parser.parse_args())
@@ -67,7 +72,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     shmem = iris.iris(args["heap_size"])
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-    cu_count = shmem.get_cu_count()
+
+    # Set default SM values if not provided
+    if args["gemm_sms"] is None:
+        # For all_scatter: use total CU count
+        cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+        args["gemm_sms"] = cu_count

     # GEMM
     datatype = torch.float32
```
examples/08_gemm_atomics_all_reduce/benchmark.py (16 changes: 11 additions & 5 deletions)

```diff
@@ -11,6 +11,7 @@
 import os
 import argparse
 import json
+import math

 from examples.common.utils import (
     JSONWriter,
@@ -68,10 +69,8 @@ def parse_args():
     parser.add_argument("--kpack", type=int, default=2, help="K packing size")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")

-    # For All Scatter, use: 288
-    # For One Shot, use: 256
-    parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM")
-    parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs")
+    parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)")
+    parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)")
     parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

     return vars(parser.parse_args())
@@ -86,7 +85,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     shmem = iris.iris(args["heap_size"])
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-    cu_count = shmem.get_cu_count()
+
+    # Set default SM values if not provided
+    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+    if args["total_sms"] is None:
+        args["total_sms"] = cu_count
+    if args["gemm_sms"] is None:
+        # For all_reduce: use next smaller power of 2, rest for communication
+        args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

     # GEMM
     datatype = torch.float32
```
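The two all-reduce benchmarks (08 above and 09 below) split the device between GEMM and communication: `total_sms` now defaults to the full CU count and `gemm_sms` to the largest power of two that fits, leaving the remainder for communication. A quick sketch of that arithmetic follows; the 304 and 104 CU counts are assumptions for illustration (MI300X and a single MI250 GCD), not values the diff asserts.

```python
# Sketch of the "next smaller power of 2" default introduced above.
import math

def default_gemm_sms(cu_count):
    # Largest power of two <= cu_count; the guard avoids log2(0).
    return 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

assert default_gemm_sms(304) == 256  # 48 CUs left for communication
assert default_gemm_sms(104) == 64   # 40 CUs left for communication
assert default_gemm_sms(256) == 256  # exact powers of two map to themselves
```

Note this is a behavior change on a 304-CU device: the old fixed default gave the GEMM 288 CUs, while the new formula gives it 256.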
14 changes: 11 additions & 3 deletions examples/09_gemm_one_shot_all_reduce/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import argparse
import json
import math

from examples.common.utils import (
JSONWriter,
Expand Down Expand Up @@ -68,8 +69,8 @@ def parse_args():
parser.add_argument("--kpack", type=int, default=2, help="K packing size")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")

parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM")
parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs")
parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)")
parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")
return vars(parser.parse_args())

Expand All @@ -82,7 +83,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
shmem = iris.iris(args["heap_size"])
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()

# Set default SM values if not provided
cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
if args["total_sms"] is None:
args["total_sms"] = cu_count
if args["gemm_sms"] is None:
# For all_reduce: use next smaller power of 2, rest for communication
args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

# GEMM
datatype = torch.float32
Expand Down
examples/10_gemm_all_scatter_wg_specialization/benchmark.py (22 changes: 19 additions & 3 deletions)

```diff
@@ -11,6 +11,7 @@
 import os
 import argparse
 import json
+import math

 from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
 from examples.common.validation import validate_gemm
@@ -54,9 +55,17 @@ def parse_args():
     parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
     parser.add_argument(
-        "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm"
+        "--gemm_sms",
+        type=int,
+        default=None,
+        help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)",
     )
-    parser.add_argument("--num_sms", type=int, default=304, help="Number of total SMs for gemm + scatter kernel")
+    parser.add_argument(
+        "--num_sms",
+        type=int,
+        default=None,
+        help="Number of total SMs for gemm + scatter kernel (default: auto-detected)",
+    )
     parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

     return vars(parser.parse_args())
@@ -70,7 +79,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     shmem = iris.iris(args["heap_size"])
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-    cu_count = shmem.get_cu_count()
+
+    # Set default SM values if not provided
+    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+    if args["num_sms"] is None:
+        args["num_sms"] = cu_count
+    if args["gemm_sms"] is None:
+        # For wg_specialized: use next smaller power of 2
+        args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

     # GEMM
     datatype = torch.float32
```
examples/11_gemm_all_scatter_producer_consumer/benchmark.py (22 changes: 19 additions & 3 deletions)

```diff
@@ -11,6 +11,7 @@
 import os
 import argparse
 import json
+import math

 from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
 from examples.common.validation import validate_gemm
@@ -55,9 +56,14 @@ def parse_args():
     parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
     parser.add_argument(
-        "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm"
+        "--gemm_sms",
+        type=int,
+        default=None,
+        help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)",
     )
-    parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel")
+    parser.add_argument(
+        "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)"
+    )
     parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

     return vars(parser.parse_args())
@@ -71,7 +77,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     shmem = iris.iris(args["heap_size"])
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-    cu_count = shmem.get_cu_count()
+
+    # Set default SM values if not provided
+    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+    next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
+
+    if args["gemm_sms"] is None:
+        # For wg_specialized: use next smaller power of 2
+        args["gemm_sms"] = next_pow2
+    if args["comm_sms"] is None:
+        # comm_sms is the leftover: total - next_power_of_2
+        args["comm_sms"] = cu_count - next_pow2

     # GEMM
     datatype = torch.float32
```
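In the producer-consumer example the two kernels run side by side, so the pools must partition the device: the GEMM takes the power-of-two share and the all-scatter kernel takes the leftover. A sketch of the split, assuming a 304-CU device; that assumption is apparently where the old fixed defaults (256 and 48) came from.

```python
# Sketch: how gemm_sms and comm_sms partition the device in the code above.
import math

cu_count = 304  # assumed MI300X-like CU count
next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

gemm_sms = next_pow2             # 256 -> producer (GEMM) workgroups
comm_sms = cu_count - next_pow2  # 48  -> consumer (all-scatter) workgroups

assert (gemm_sms, comm_sms) == (256, 48)  # matches the old fixed defaults
```

The bulk-synchronous variant (12, below) keeps the same `gemm_sms` default but sets `comm_sms = next_pow2` instead, presumably because its phases run back-to-back rather than concurrently and need not share the device.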
examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py (22 changes: 19 additions & 3 deletions)

```diff
@@ -11,6 +11,7 @@
 import os
 import argparse
 import json
+import math

 from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
 from examples.common.validation import validate_gemm
@@ -55,9 +56,14 @@ def parse_args():
     parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
     parser.add_argument(
-        "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm"
+        "--gemm_sms",
+        type=int,
+        default=None,
+        help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)",
     )
-    parser.add_argument("--comm_sms", type=int, default=256, help="Number of SMs for All-Scatter kernel")
+    parser.add_argument(
+        "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)"
+    )
     parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

     return vars(parser.parse_args())
@@ -71,7 +77,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     shmem = iris.iris(args["heap_size"])
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-    cu_count = shmem.get_cu_count()
+
+    # Set default SM values if not provided
+    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+    next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
+
+    if args["gemm_sms"] is None:
+        # For wg_specialized: use next smaller power of 2
+        args["gemm_sms"] = next_pow2
+    if args["comm_sms"] is None:
+        # For bulk synchronous, use same as gemm_sms
+        args["comm_sms"] = next_pow2

     # GEMM
     datatype = torch.float32
```
examples/benchmark/bench_all_shapes.py (31 changes: 22 additions & 9 deletions)

```diff
@@ -6,6 +6,7 @@
 from datetime import datetime
 import argparse
 import json
+import torch


 def launch_sbatch(
@@ -110,16 +111,28 @@ def main(hashes, config, sbatch_script_content, input_json, tiling_json, dry_run
         if mkn not in mkn_gemm_tiles:
             mkn_gemm_tiles[mkn] = {key: entry[key] for key in optional_keys if key in entry}

-    if config["partition"] is not None:
-        if "mi300" in config["partition"]:
-            print("Running on MI300")
+    # Determine gemm_sms based on available GPU or partition name
+    try:
+        if torch.cuda.is_available():
+            gemm_sms = torch.cuda.get_device_properties(0).multi_processor_count
+            print(f"Auto-detected CU count: {gemm_sms}")
+        else:
+            gemm_sms = None
+    except Exception:
+        # Fall back to partition-based detection
+        gemm_sms = None
+
+    if gemm_sms is None:
+        if config["partition"] is not None:
+            if "mi300" in config["partition"]:
+                print("Running on MI300 (partition-based)")
                 gemm_sms = 304
-        elif "mi250" in config["partition"]:
-            print("Running on MI250")
-            gemm_sms = 104
-        else:
-            print("Assuming MI300")
-            gemm_sms = 304
+            elif "mi250" in config["partition"]:
+                print("Running on MI250 (partition-based)")
+                gemm_sms = 104
+            else:
+                print("Assuming MI300 (default)")
+                gemm_sms = 304

     enable_algorithms = False
     enable_mkn = True
```
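Unlike the benchmarks above, `bench_all_shapes.py` may run on a host with no visible GPU (it appears to only submit sbatch jobs), hence the longer fallback chain: device query first, then the Slurm partition name, then an MI300-like default. A condensed sketch of that precedence follows; `detect_gemm_sms` is a hypothetical helper, and it also returns 304 when no partition is configured, a case the diff leaves with `gemm_sms` unset.

```python
# Sketch: the detection precedence used above, condensed into one helper.
import torch

def detect_gemm_sms(partition):
    try:
        if torch.cuda.is_available():
            # A GPU is visible on this host: trust the device query.
            return torch.cuda.get_device_properties(0).multi_processor_count
    except Exception:
        pass  # no usable CUDA/ROCm runtime; fall through to the partition name
    if partition is not None and "mi250" in partition:
        return 104
    return 304  # "mi300" partitions and the assumed default case
```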
scripts/link_bandwidth.py (11 changes: 10 additions & 1 deletion)

```diff
@@ -2,6 +2,15 @@
 # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

 import json
+import torch
+
+try:
+    if torch.cuda.is_available():
+        cu_count = torch.cuda.get_device_properties(0).multi_processor_count
+    else:
+        cu_count = 304  # Default for MI300
+except Exception:
+    cu_count = 304  # Default for MI300

 # Sample input (replace with file read if needed)
 config = {
@@ -26,7 +35,7 @@
     "kpack": 2,
     "heap_size": 8589934592,
     "gemm_sms": 48,
-    "total_sms": 304,
+    "total_sms": cu_count,
     "communication_block_size": 256,
     "communication_sms_multiplier": 1,
     "M": 8192,
```