diff --git a/.github/scripts/cleanup_ports.sh b/.github/scripts/cleanup_ports.sh new file mode 100755 index 00000000..a89433c1 --- /dev/null +++ b/.github/scripts/cleanup_ports.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +set -e + +# Script to clean up any lingering test processes and ports +# This is useful when tests segfault and leave processes/ports open + +echo "========================================" +echo "Port Cleanup Script - Starting" +echo "========================================" + +# Show initial state of listening ports +echo "" +echo "Initial state - Listening TCP ports:" +echo "------------------------------------" +ss -tulpn 2>/dev/null | grep LISTEN | grep -E "python|pt_main_thread" || echo "No Python/PyTorch processes listening on ports" +echo "" + +echo "Cleaning up lingering test processes and ports..." + +# Clean up Python test processes that might be stuck +# Look for processes related to run_tests_distributed.py, pytest, and torch distributed tests +echo "Checking for lingering Python test processes..." +PYTHON_TEST_PIDS=$(pgrep -f "run_tests_distributed.py|pytest.*test_|torch.distributed" 2>/dev/null || true) + +if [ -n "$PYTHON_TEST_PIDS" ]; then + echo "Found Python test processes: $PYTHON_TEST_PIDS" + echo "Killing Python test processes..." + echo "$PYTHON_TEST_PIDS" | xargs kill -9 2>/dev/null || true + echo "Cleaned up Python test processes" +fi + +# Clean up pt_main_thread processes (PyTorch multiprocessing spawned processes) +echo "Checking for lingering PyTorch processes (multiprocessing.spawn)..." +PT_PIDS=$(pgrep -f "multiprocessing.spawn" 2>/dev/null || true) + +if [ -n "$PT_PIDS" ]; then + echo "Found PyTorch processes: $PT_PIDS" + echo "Killing PyTorch processes..." + echo "$PT_PIDS" | xargs kill -9 2>/dev/null || true + echo "Cleaned up PyTorch processes" +fi + +# Clean up any processes listening on TCP ports in the common test range +# PyTorch distributed typically uses ports in the 29500+ range, but can use any available port +echo "Checking for processes using TCP ports..." +LISTENING_PIDS=$(lsof -ti tcp -sTCP:LISTEN 2>/dev/null | sort -u || true) + +if [ -n "$LISTENING_PIDS" ]; then + # Filter to only Python/PyTorch processes to avoid killing system services + for PID in $LISTENING_PIDS; do + PROCESS_NAME=$(ps -p $PID -o comm= 2>/dev/null || true) + # Check for python or pt_main_thread processes + if [[ "$PROCESS_NAME" == *"python"* ]] || [[ "$PROCESS_NAME" == *"pt_main_thread"* ]]; then + PORT=$(lsof -Pan -p $PID -i tcp -sTCP:LISTEN 2>/dev/null | awk 'NR>1 {print $9}' | cut -d':' -f2 | head -1) + if [ -n "$PORT" ]; then + echo "Found process $PROCESS_NAME (PID $PID) listening on port $PORT" + kill -9 $PID 2>/dev/null || true + echo "Cleaned up process $PID on port $PORT" + fi + fi + done +fi + +echo "" +echo "========================================" +echo "Port Cleanup Script - Completed" +echo "========================================" + +# Show final state of listening ports +echo "" +echo "Final state - Listening TCP ports:" +echo "------------------------------------" +ss -tulpn 2>/dev/null | grep LISTEN | grep -E "python|pt_main_thread" || echo "No Python/PyTorch processes listening on ports" +echo "" +echo "Port cleanup complete." 
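Reviewer note: for reference, the filter-and-kill logic in `cleanup_ports.sh` can be sketched in Python as below. This is only an illustration and assumes `psutil` is available on the runner, which this PR does not add; the function name is illustrative.

```python
# Illustrative Python equivalent of the "kill lingering Python/PyTorch listeners"
# logic in cleanup_ports.sh. Assumes psutil is installed (not part of this PR).
import os
import signal

import psutil

TARGET_NAMES = ("python", "pt_main_thread")


def cleanup_lingering_listeners() -> None:
    # Collect PIDs that own listening TCP sockets.
    listening_pids = {
        conn.pid
        for conn in psutil.net_connections(kind="tcp")
        if conn.status == psutil.CONN_LISTEN and conn.pid is not None
    }
    for pid in listening_pids:
        try:
            name = psutil.Process(pid).name()
        except psutil.Error:
            continue
        # Only touch Python/PyTorch processes, never system services.
        if any(target in name for target in TARGET_NAMES):
            os.kill(pid, signal.SIGKILL)


if __name__ == "__main__":
    cleanup_lingering_listeners()
```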
diff --git a/.github/scripts/run_perf_benchmark.sh b/.github/scripts/run_perf_benchmark.sh new file mode 100755 index 00000000..7be18d84 --- /dev/null +++ b/.github/scripts/run_perf_benchmark.sh @@ -0,0 +1,63 @@ +#!/bin/bash +set -e + +# Arguments +EXAMPLE_PATH=$1 +TFLOPS_THRESHOLD=$2 +shift 2 +BENCHMARK_ARGS="$@" + +# Create overlay image in workspace (will be auto-cleaned by GitHub Actions) +OVERLAY="iris_overlay_perf_${EXAMPLE_PATH//\//_}.img" + +echo "::group::Creating overlay image" +apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY}" +echo "::endgroup::" + +echo "::group::Running performance benchmark" +apptainer exec --overlay "${OVERLAY}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install -e . + python examples/${EXAMPLE_PATH}/benchmark.py \ + --benchmark \ + --validate \ + -r 8 \ + ${BENCHMARK_ARGS} \ + --output_file perf_result.json + " +echo "::endgroup::" + +# Parse JSON and check performance +echo "::group::Validating performance" + +# Check if benchmark succeeded +SUCCESS=$(jq -r '.success' perf_result.json) +if [ "$SUCCESS" != "true" ]; then + echo "::error::Benchmark failed (success: $SUCCESS)" + jq '.' perf_result.json + exit 1 +fi + +TFLOPS=$(jq -r '.tflops' perf_result.json) + +if [ -z "$TFLOPS" ] || [ "$TFLOPS" = "null" ]; then + echo "::error::Failed to extract tflops from benchmark output" + jq '.' perf_result.json + exit 1 +fi + +echo "::notice::Achieved TFLOPs: $TFLOPS" + +# Convert to integer for comparison +TFLOPS_INT=${TFLOPS%.*} +if (( TFLOPS_INT < TFLOPS_THRESHOLD )); then + echo "::error::Performance regression detected! TFLOPs ($TFLOPS) is below threshold ($TFLOPS_THRESHOLD)" + jq '.' perf_result.json + exit 1 +fi + +echo "✅ Performance test passed! TFLOPs: $TFLOPS (threshold: >$TFLOPS_THRESHOLD)" +echo "::endgroup::" + diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh new file mode 100755 index 00000000..fd7b9388 --- /dev/null +++ b/.github/scripts/run_tests.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +set -e # Exit on any error + +# Get num_ranks from command line argument +NUM_RANKS=$1 + +if [ -z "$NUM_RANKS" ]; then + echo "Error: NUM_RANKS not provided" + echo "Usage: $0 " + exit 1 +fi + +# Run examples tests one at a time using distributed wrapper +echo 'Running examples tests one at a time...' +for test_file in tests/examples/test_*.py; do + echo "Testing: $test_file with $NUM_RANKS ranks" + python tests/run_tests_distributed.py --num_ranks $NUM_RANKS "$test_file" -v --tb=short --durations=10 +done + +# Run unit tests one at a time using distributed wrapper +echo 'Running unit tests one at a time...' 
+for test_file in tests/unittests/test_*.py; do + echo "Testing: $test_file with $NUM_RANKS ranks" + python tests/run_tests_distributed.py --num_ranks $NUM_RANKS "$test_file" -v --tb=short --durations=10 +done diff --git a/.github/workflows/iris-external-validation-test.yml b/.github/workflows/iris-external-validation-test.yml new file mode 100644 index 00000000..bbf547a5 --- /dev/null +++ b/.github/workflows/iris-external-validation-test.yml @@ -0,0 +1,83 @@ +name: Iris External Validation Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + build-apptainer-image: + runs-on: [self-hosted, mi3008x] + timeout-minutes: 90 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Apptainer + run: | + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + + - name: Build Iris Apptainer container + run: | + # Create persistent Apptainer directory + mkdir -p ~/apptainer + + # Build Apptainer image from definition file (only if it doesn't exist) + if [ ! -f ~/apptainer/iris-dev.sif ]; then + echo "Building new Apptainer image..." + apptainer build ~/apptainer/iris-dev.sif apptainer/iris.def + else + echo "Using existing Apptainer image" + fi + + external-validation-test: + name: External Validation Test + needs: build-apptainer-image + runs-on: [self-hosted, mi3008x] + timeout-minutes: 30 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup lingering ports before tests + run: | + bash .github/scripts/cleanup_ports.sh + + - name: Run External Validation Test with Apptainer + run: | + set -e + + # Create unique overlay image for isolation + OVERLAY="/tmp/iris_overlay_$(whoami)_external_$(date +%s%N).img" + + echo "::group::Creating overlay image" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY}" + echo "::endgroup::" + + echo "::group::Running external validation test" + apptainer exec --overlay "${OVERLAY}" --no-home --cleanenv \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} + wget -O test_iris_distributed.py https://gist.githubusercontent.com/mawad-amd/6375dc078e39e256828f379e03310ec7/raw/a527c3192bee4615292769e340b1c73676f6945a/test_iris_distributed.py + python test_iris_distributed.py + " + echo "::endgroup::" + + # Cleanup overlay image + echo "::group::Cleaning up overlay image" + rm -f "${OVERLAY}" + echo "::endgroup::" + + echo "✅ External validation test passed!" 
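Reviewer note: the gist-based external test itself is not part of this diff. For orientation, a smoke test of a pip-installed Iris could look roughly like the sketch below; this is a hypothetical script, not the gist's content, and it only uses APIs that appear elsewhere in this PR.

```python
# Hypothetical shape of an external smoke test for a pip-installed Iris.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import iris


def worker(rank: int, world_size: int, init_url: str):
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=rank,
        device_id=torch.device(f"cuda:{rank}"),
    )
    torch.cuda.set_device(rank)

    shmem = iris.iris()  # default heap size
    buf = shmem.zeros((1024,), dtype=torch.float32)
    buf.fill_(1.0)
    shmem.barrier()  # every rank allocated and wrote to its symmetric heap
    assert buf.sum().item() == 1024.0

    shmem.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    num_ranks = torch.cuda.device_count()
    mp.spawn(worker, args=(num_ranks, "tcp://127.0.0.1:29500"), nprocs=num_ranks, join=True)
```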
\ No newline at end of file diff --git a/.github/workflows/iris-performance-regression-test.yml b/.github/workflows/iris-performance-regression-test.yml new file mode 100644 index 00000000..4e5c1e8a --- /dev/null +++ b/.github/workflows/iris-performance-regression-test.yml @@ -0,0 +1,86 @@ +name: Iris Performance Regression Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + build-apptainer-image: + runs-on: [self-hosted, mi3008x] + timeout-minutes: 20 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Apptainer + run: | + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + + - name: Build Iris Apptainer container + run: | + # Create persistent Apptainer directory + mkdir -p ~/apptainer + + # Build Apptainer image from definition file (only if it doesn't exist) + if [ ! -f ~/apptainer/iris-dev.sif ]; then + echo "Building new Apptainer image..." + apptainer build ~/apptainer/iris-dev.sif apptainer/iris.def + else + echo "Using existing Apptainer image" + fi + + performance-test: + name: ${{ matrix.example_name }} + needs: build-apptainer-image + runs-on: [self-hosted, mi3008x] + timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + # Performance baselines measured on AMD Instinct MI325X (8 GPUs) + include: + - example_name: "GEMM All-Scatter WG Specialization" + example_path: "10_gemm_all_scatter_wg_specialization" + tflops_threshold: 1600 # Actual: ~2182 TFLOPs + benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256" + + - example_name: "GEMM All-Scatter" + example_path: "07_gemm_all_scatter" + tflops_threshold: 1000 # Actual: ~1407 TFLOPs + benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 256 --BLK_N 64 --BLK_K 64 --gsize_m 6 --gemm_sms 256" + + - example_name: "GEMM All-Scatter Producer-Consumer" + example_path: "11_gemm_all_scatter_producer_consumer" + tflops_threshold: 1600 # Actual: ~2190 TFLOPs + benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256 --comm_sms 48" + + - example_name: "GEMM All-Scatter Bulk Synchronous" + example_path: "12_gemm_all_scatter_bulk_synchronous" + tflops_threshold: 900 # Actual: ~1262 TFLOPs + benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256" + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup lingering ports before tests + run: | + bash .github/scripts/cleanup_ports.sh + + - name: Run ${{ matrix.example_name }} Benchmark (8 ranks) + run: | + bash .github/scripts/run_perf_benchmark.sh \ + "${{ matrix.example_path }}" \ + "${{ matrix.tflops_threshold }}" \ + ${{ matrix.benchmark_args }} + diff --git a/.github/workflows/iris-pip-install-test.yml b/.github/workflows/iris-pip-install-test.yml new file mode 100644 index 00000000..d88cfd3f --- /dev/null +++ b/.github/workflows/iris-pip-install-test.yml @@ -0,0 +1,173 @@ +name: Iris Pip Install Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + build-apptainer-image: 
+ runs-on: [self-hosted, mi3008x] + timeout-minutes: 90 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Apptainer + run: | + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + + - name: Build Iris Apptainer container + run: | + # Create persistent Apptainer directory + mkdir -p ~/apptainer + + # Build Apptainer image from definition file (only if it doesn't exist) + if [ ! -f ~/apptainer/iris-dev.sif ]; then + echo "Building new Apptainer image..." + apptainer build ~/apptainer/iris-dev.sif apptainer/iris.def + else + echo "Using existing Apptainer image" + fi + test-1-2-4-ranks: + name: Pip Install Test 1/2/4 Ranks (Parallel) + needs: build-apptainer-image + runs-on: [self-hosted, mi3008x] + timeout-minutes: 30 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cleanup lingering ports before tests + run: | + bash .github/scripts/cleanup_ports.sh + + - name: Run pip install tests for 1, 2, 4 ranks in parallel + run: | + set -e + + # Run tests in parallel with different GPU assignments + # Note: Each test gets 2+ GPUs even if it only uses some of them. + # This allows tests like test_empty_device_handling to verify that + # allocating on a different device correctly raises an error. + + # Create unique overlay images for isolation + OVERLAY_1="/tmp/iris_overlay_$(whoami)_1rank_$(date +%s%N).img" + OVERLAY_2="/tmp/iris_overlay_$(whoami)_2rank_$(date +%s%N).img" + OVERLAY_4="/tmp/iris_overlay_$(whoami)_4rank_$(date +%s%N).img" + + echo "::group::Creating overlay images" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_1}" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_2}" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_4}" + echo "::endgroup::" + + echo "::group::Starting parallel tests" + echo "Starting 1-rank test on GPUs 0,1..." + apptainer exec --overlay "${OVERLAY_1}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="0,1" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} + bash .github/scripts/run_tests.sh 1 + " & + PID1=$! + + echo "Starting 2-rank test on GPUs 2,3..." + apptainer exec --overlay "${OVERLAY_2}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="2,3" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} + bash .github/scripts/run_tests.sh 2 + " & + PID2=$! + + echo "Starting 4-rank test on GPUs 4,5,6,7..." + apptainer exec --overlay "${OVERLAY_4}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="4,5,6,7" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} + bash .github/scripts/run_tests.sh 4 + " & + PID4=$! 
+ echo "::endgroup::" + + # Wait for all parallel tests and track failures + echo "::group::Waiting for parallel tests to complete" + FAIL=0 + FAILED_TESTS="" + + wait $PID1 || { echo "::error::1-rank test FAILED"; FAILED_TESTS="$FAILED_TESTS 1-rank"; FAIL=1; } + wait $PID2 || { echo "::error::2-rank test FAILED"; FAILED_TESTS="$FAILED_TESTS 2-rank"; FAIL=1; } + wait $PID4 || { echo "::error::4-rank test FAILED"; FAILED_TESTS="$FAILED_TESTS 4-rank"; FAIL=1; } + echo "::endgroup::" + + # Cleanup overlay images + echo "::group::Cleaning up overlay images" + rm -f "${OVERLAY_1}" "${OVERLAY_2}" "${OVERLAY_4}" + echo "::endgroup::" + + if [ $FAIL -eq 1 ]; then + echo "::error::Parallel tests failed:$FAILED_TESTS" + exit 1 + fi + + echo "✅ All parallel tests (1, 2, 4 ranks) passed!" + + test-8-ranks: + name: Pip Install Test 8 Ranks + needs: build-apptainer-image + runs-on: [self-hosted, mi3008x] + timeout-minutes: 30 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cleanup lingering ports before tests + run: | + bash .github/scripts/cleanup_ports.sh + + - name: Run 8-rank pip install test + run: | + set -e + + # Create unique overlay image for isolation + OVERLAY_8="/tmp/iris_overlay_$(whoami)_8rank_$(date +%s%N).img" + + echo "::group::Creating overlay image" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_8}" + echo "::endgroup::" + + echo "::group::Running 8-rank test on all GPUs" + apptainer exec --overlay "${OVERLAY_8}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} + bash .github/scripts/run_tests.sh 8 + " + echo "::endgroup::" + + # Cleanup overlay image + echo "::group::Cleaning up overlay image" + rm -f "${OVERLAY_8}" + echo "::endgroup::" + + echo "✅ 8-rank test passed!" diff --git a/.github/workflows/iris-tests-apptainer.yml b/.github/workflows/iris-tests-apptainer.yml index 5e2d9a85..cb58f164 100644 --- a/.github/workflows/iris-tests-apptainer.yml +++ b/.github/workflows/iris-tests-apptainer.yml @@ -38,39 +38,132 @@ jobs: else echo "Using existing Apptainer image" fi - run-tests: - name: ${{ matrix.ranks }}-rank Iris Test + test-1-2-4-ranks: + name: Test 1/2/4 Ranks (Parallel) needs: build-apptainer-image runs-on: [self-hosted, mi3008x] timeout-minutes: 20 - strategy: - matrix: - ranks: [1, 2, 4, 8] - max-parallel: 1 steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Run Iris Tests with ${{ matrix.ranks }} ranks + - name: Cleanup lingering ports before tests run: | - apptainer exec ~/apptainer/iris-dev.sif bash -c " - set -e # Exit on any error - - # Install iris first - pip install -e . - - # Run examples tests one at a time using distributed wrapper - echo 'Running examples tests one at a time...' - for test_file in tests/examples/test_*.py; do - echo \"Testing: \$test_file with ${{ matrix.ranks }} ranks\" - python tests/run_tests_distributed.py --num_ranks ${{ matrix.ranks }} \"\$test_file\" -v --tb=short - done - - # Run unit tests one at a time using distributed wrapper - echo 'Running unit tests one at a time...' 
- for test_file in tests/unittests/test_*.py; do - echo \"Testing: \$test_file with ${{ matrix.ranks }} ranks\" - python tests/run_tests_distributed.py --num_ranks ${{ matrix.ranks }} \"\$test_file\" -v --tb=short - done - " \ No newline at end of file + bash .github/scripts/cleanup_ports.sh + + - name: Run 1, 2, 4 rank tests in parallel + run: | + set -e + + # Run tests in parallel with different GPU assignments + # Note: Each test gets 2+ GPUs even if it only uses some of them. + # This allows tests like test_empty_device_handling to verify that + # allocating on a different device correctly raises an error. + + # Create unique overlay images for isolation + OVERLAY_1="/tmp/iris_overlay_$(whoami)_1rank_$(date +%s%N).img" + OVERLAY_2="/tmp/iris_overlay_$(whoami)_2rank_$(date +%s%N).img" + OVERLAY_4="/tmp/iris_overlay_$(whoami)_4rank_$(date +%s%N).img" + + echo "::group::Creating overlay images" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_1}" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_2}" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_4}" + echo "::endgroup::" + + echo "::group::Starting parallel tests" + echo "Starting 1-rank test on GPUs 0,1..." + apptainer exec --overlay "${OVERLAY_1}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="0,1" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install -e . + bash .github/scripts/run_tests.sh 1 + " & + PID1=$! + + echo "Starting 2-rank test on GPUs 2,3..." + apptainer exec --overlay "${OVERLAY_2}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="2,3" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install -e . + bash .github/scripts/run_tests.sh 2 + " & + PID2=$! + + echo "Starting 4-rank test on GPUs 4,5,6,7..." + apptainer exec --overlay "${OVERLAY_4}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="4,5,6,7" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install -e . + bash .github/scripts/run_tests.sh 4 + " & + PID4=$! + echo "::endgroup::" + + # Wait for all parallel tests and track failures + echo "::group::Waiting for parallel tests to complete" + FAIL=0 + FAILED_TESTS="" + + wait $PID1 || { echo "::error::1-rank test FAILED"; FAILED_TESTS="$FAILED_TESTS 1-rank"; FAIL=1; } + wait $PID2 || { echo "::error::2-rank test FAILED"; FAILED_TESTS="$FAILED_TESTS 2-rank"; FAIL=1; } + wait $PID4 || { echo "::error::4-rank test FAILED"; FAILED_TESTS="$FAILED_TESTS 4-rank"; FAIL=1; } + echo "::endgroup::" + + # Cleanup overlay images + echo "::group::Cleaning up overlay images" + rm -f "${OVERLAY_1}" "${OVERLAY_2}" "${OVERLAY_4}" + echo "::endgroup::" + + if [ $FAIL -eq 1 ]; then + echo "::error::Parallel tests failed:$FAILED_TESTS" + exit 1 + fi + + echo "✅ All parallel tests (1, 2, 4 ranks) passed!" 
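Reviewer note: the launch-in-background, then wait-and-collect-failures pattern in this step is easy to get subtly wrong in YAML. As a cross-check, the same fan-out/fan-in flow is sketched in Python below; the command line is abbreviated (the real flags are in the step above) and the snippet is not part of the PR.

```python
# Sketch of the fan-out/fan-in pattern used above: launch one child per rank
# configuration, wait for all of them, and fail if any child failed.
import subprocess

CONFIGS = {1: "0,1", 2: "2,3", 4: "4,5,6,7"}  # ranks -> HIP_VISIBLE_DEVICES

procs = {
    ranks: subprocess.Popen(
        [
            "apptainer", "exec", "--no-home", "--cleanenv",
            "--env", f"HIP_VISIBLE_DEVICES={gpus}",
            "iris-dev.sif",  # path abbreviated
            "bash", "-c", f"set -e; pip install -e .; bash .github/scripts/run_tests.sh {ranks}",
        ]
    )
    for ranks, gpus in CONFIGS.items()
}

failed = [ranks for ranks, proc in procs.items() if proc.wait() != 0]
if failed:
    raise SystemExit(f"Parallel tests failed for ranks: {failed}")
print("All parallel tests (1, 2, 4 ranks) passed!")
```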
+ + test-8-ranks: + name: Test 8 Ranks + needs: build-apptainer-image + runs-on: [self-hosted, mi3008x] + timeout-minutes: 15 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup lingering ports before tests + run: | + bash .github/scripts/cleanup_ports.sh + + - name: Run 8-rank test + run: | + set -e + + # Create unique overlay image for isolation + OVERLAY_8="/tmp/iris_overlay_$(whoami)_8rank_$(date +%s%N).img" + + echo "::group::Creating overlay image" + apptainer overlay create --size 1024 --create-dir /var/cache/iris "${OVERLAY_8}" + echo "::endgroup::" + + echo "::group::Running 8-rank test on all GPUs" + apptainer exec --overlay "${OVERLAY_8}" --no-home --cleanenv --env HIP_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" \ + --bind "${PWD}:/iris_workspace" --cwd /iris_workspace \ + ~/apptainer/iris-dev.sif bash -c " + set -e + pip install -e . + bash .github/scripts/run_tests.sh 8 + " + echo "::endgroup::" + + # Cleanup overlay image + echo "::group::Cleaning up overlay image" + rm -f "${OVERLAY_8}" + echo "::endgroup::" + + echo "✅ 8-rank test passed!" diff --git a/MANIFEST.in b/MANIFEST.in index 083b7f42..2c255da1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,12 +1,7 @@ -# Include C++ source files -include csrc/finegrained_alloc/*.hip -include csrc/finegrained_alloc/build.sh - # Include documentation include README.md include LICENSE include iris/README.md # Include build configuration -include pyproject.toml -include setup.py \ No newline at end of file +include pyproject.toml \ No newline at end of file diff --git a/README.md b/README.md index dbb7d40a..27eed85e 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,7 @@ Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/ROCm/iris/blob/main/.github/workflows/lint.yml) [![Iris Tests](https://github.com/ROCm/iris/actions/workflows/iris-tests-apptainer.yml/badge.svg)](https://github.com/ROCm/iris/actions/workflows/iris-tests-apptainer.yml) -> [!IMPORTANT] -> This project is intended for research purposes only and is provided by AMD Research and Advanced Development team. This is not a product. Use it at your own risk and discretion. - -Iris is a Triton-based framework for Remote Memory Access (RMA) operations. Iris provides SHMEM-like APIs within Triton for Multi-GPU programming. Iris' goal is to make Multi-GPU programming a first-class citizen in Triton while retaining Triton's programmability and performance. +Iris is a Triton-based framework for Remote Memory Access (RMA) operations developed by AMD's Research and Advanced Development team. Iris provides SHMEM-like APIs within Triton for Multi-GPU programming. Iris' goal is to make Multi-GPU programming a first-class citizen in Triton while retaining Triton's programmability and performance. 
## Key Features @@ -106,7 +103,7 @@ if __name__ == "__main__": ### Quick Installation > [!NOTE] -> **Requirements**: Python 3.10+, PyTorch 2.0+ (ROCm version), ROCm 6.3.1+ HIP runtime, and Triton +> **Requirements**: Python 3.10+, PyTorch 2.0+ (ROCm version), ROCm 6.3.1+ HIP runtime, Triton, and setuptools>=61 For a quick installation directly from the repository: diff --git a/benchmark/examples/benchmark_all_gather_gemm_pull.py b/benchmark/examples/benchmark_all_gather_gemm_pull.py new file mode 100644 index 00000000..6f0d06f7 --- /dev/null +++ b/benchmark/examples/benchmark_all_gather_gemm_pull.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json + +from examples.common.utils import JSONWriter +from examples.common.validation import validate_gemm +import importlib.util +from pathlib import Path +import iris + +current_dir = Path(__file__).parent +file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_pull.py").resolve() +module_name = "all_gather_gemm_pull" + +spec = importlib.util.spec_from_file_location(module_name, file_path) +ag_gemm_kernels_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(ag_gemm_kernels_module) +persistent_ag_gemm = ag_gemm_kernels_module.persistent_ag_gemm + + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run a sweep of All-Gather GEMM benchmarks from a config file.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode.") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode.") + parser.add_argument( + "--config_file", + type=str, + default="dataset/ag_gemm.json", + help="Path to the JSON file with benchmark configurations.", + ) + parser.add_argument("--output_file", type=str, default="ag_gemm_pull.json", help="Base name for output files") + parser.add_argument( + "--output_dir", type=str, default="results/all_gather_gemm_pull", help="Name of the output directory" + ) + + parser.add_argument("-m", type=int, default=1024, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=3584, help="Total number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=8192, help="Common dimension between matrices A and B (K)") + + parser.add_argument( + "--datatype", type=str, default="fp16", choices=["fp16", "bf16", "fp32"], help="Datatype of computation" + ) + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size in bytes") + + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M for the kernel") + parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for the kernel") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for the kernel") + parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension") + parser.add_argument( + "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)" + ) + + parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") + + return parser.parse_args() + + +def worker(rank: int, world_size: int, 
init_url: str, args: argparse.Namespace): + """ + This function will be executed by each spawned process. + """ + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") + ) + + shmem = iris.iris(args.heap_size) + torch.cuda.set_device(rank) + world_size = shmem.get_num_ranks() + torch.cuda.set_device(rank) + + output_dir = args.output_dir + + if rank == 0: + os.makedirs(output_dir, exist_ok=True) + shmem.barrier() + + with open(args.config_file, "r") as f: + configs_to_run = json.load(f) + + shmem.info(f"Loaded {len(configs_to_run)} configurations from {args.config_file}") + + for config in configs_to_run: + run_args = vars(args).copy() + run_args.update(config) + + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + datatype = dtype_map.get(run_args["datatype"]) + + M, N, K = run_args["m"], run_args["n"], run_args["k"] + shmem.info(f"\n--- Running Benchmark for M={M}, N={N}, K={K} ---") + + base_name, extension = os.path.splitext(args.output_file) + unique_filename = f"{base_name}_m_{M}{extension}" + full_output_path = os.path.join(output_dir, unique_filename) + + json_writer = JSONWriter(full_output_path) + json_writer.add_field("world_size", world_size) + for key, value in run_args.items(): + json_writer.add_field(key, value) + + K_local = K // world_size + + if rank == 0: + A_global = torch.randn((M, K), dtype=datatype, device="cuda") + else: + A_global = torch.empty((M, K), dtype=datatype, device="cuda") + + A_global_broadcasted = ( + torch.from_numpy(shmem.broadcast(A_global.cpu().numpy(), source_rank=0)).to(datatype).to("cuda") + ) + shmem.barrier() + + A_local = A_global_broadcasted[:, rank * K_local : (rank + 1) * K_local].contiguous() + + if rank == 0: + B = torch.randn((K, N), device="cuda", dtype=datatype) + else: + B = torch.empty((K, N), device="cuda", dtype=datatype) + + B = torch.from_numpy(shmem.broadcast(B.cpu().numpy(), source_rank=0)).to(datatype).to("cuda") + shmem.barrier() + + C = torch.empty((M, N), device="cuda", dtype=datatype) + A_local_iris = shmem.empty((M, K_local), dtype=datatype) + A_local_iris.copy_(A_local) + + # Use provided num_sms or auto-detect + if run_args["num_sms"] is None: + num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + run_args["num_sms"] = num_sms + else: + num_sms = run_args["num_sms"] + + main_stream = torch.cuda.Stream() + kernel_timing = { + "fused_ag_gemm": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + } + } + + def run_experiment(): + nonlocal kernel_timing + with torch.cuda.stream(main_stream): + kernel_timing["fused_ag_gemm"]["start_event"].record() + persistent_ag_gemm[(num_sms,)]( + A_local_iris, + B, + C, + M, + N, + K, + A_local_iris.stride(0), + A_local_iris.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + run_args["BLK_M"], + run_args["BLK_N"], + run_args["BLK_K"], + run_args["gsize_m"], + run_args["num_sms"], + 1, # NUM_XCDs + (K % run_args["BLK_K"] == 0), + shmem.get_heap_bases(), + rank, + world_size, + ) + kernel_timing["fused_ag_gemm"]["end_event"].record() + kernel_timing["fused_ag_gemm"]["experiments"] += 1 + + shmem.barrier() + + ms = kernel_timing["fused_ag_gemm"]["start_event"].elapsed_time(kernel_timing["fused_ag_gemm"]["end_event"]) + kernel_timing["fused_ag_gemm"]["ms"] += ms + + run_experiment() + 
shmem.barrier() + kernel_timing["fused_ag_gemm"]["ms"] = 0 + kernel_timing["fused_ag_gemm"]["experiments"] = 0 + + if args.benchmark: + triton_ms = iris.do_bench(run_experiment, barrier_fn=shmem.barrier) + tflops = 2 * M * N * K * 1e-12 / (triton_ms * 1e-3) + + shmem.info(f"Result (iris.do_bench): {triton_ms:.3f} ms, {tflops:.3f} TFLOPS") + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("tflops", tflops) + + key = "fused_ag_gemm" + avg_kernel_ms = kernel_timing[key]["ms"] / kernel_timing[key]["experiments"] + json_writer.add_field(key + "_ms", avg_kernel_ms) + shmem.info(f"Result (CUDA Events): {avg_kernel_ms:.3f} ms for the kernel") + + if args.validate: + if not args.benchmark: + run_experiment() + shmem.barrier() + + success = validate_gemm(A_global_broadcasted, B, C, shmem, atol=1.0) + + passed_str = "passed" if success else "failed" + shmem.info(f"Final C validation {passed_str}.") + json_writer.add_field("validation_passed", success) + + if rank == 0: + json_writer.flush() + shmem.info(f"Saved results to {full_output_path}") + + shmem.info("\nBenchmark sweep complete.") + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + if not args.validate and not args.benchmark: + print("Error: You must specify a mode to run.") + print("Please use -v for validation or -b for benchmarking.") + sys.exit(1) + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29501" + mp.spawn( + fn=worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/examples/benchmark_all_gather_gemm_push.py b/benchmark/examples/benchmark_all_gather_gemm_push.py new file mode 100644 index 00000000..d3e6ac32 --- /dev/null +++ b/benchmark/examples/benchmark_all_gather_gemm_push.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+ +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json + +from examples.common.utils import JSONWriter +from examples.common.validation import validate_gemm +import importlib.util +from pathlib import Path +import iris + +current_dir = Path(__file__).parent +file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_push.py").resolve() +module_name = "all_gather_gemm_push" + +spec = importlib.util.spec_from_file_location(module_name, file_path) +ag_gemm_kernels_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(ag_gemm_kernels_module) +gemm_push_kernel = ag_gemm_kernels_module.gemm_push_kernel +push_shards_kernel = ag_gemm_kernels_module.push_shards_kernel + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run a sweep of Iris Push All-Gather GEMM benchmarks from a config file.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode.") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode.") + parser.add_argument( + "--config_file", + type=str, + default="dataset/ag_gemm.json", + help="Path to the JSON file with benchmark configurations.", + ) + parser.add_argument("--output_file", type=str, default="ag_gemm_push.json", help="Base name for output files") + parser.add_argument( + "--output_dir", type=str, default="results/all_gather_gemm_push", help="Name of the output directory" + ) + + parser.add_argument("-m", type=int, default=1024, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=3584, help="Total number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=8192, help="Common dimension between matrices A and B (K)") + + parser.add_argument( + "--datatype", type=str, default="fp16", choices=["fp16", "bf16", "fp32"], help="Datatype of computation" + ) + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size in bytes") + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M for tiling") + parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for GEMM computation") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for tiling") + parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension") + parser.add_argument( + "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)" + ) + + parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") + + return parser.parse_args() + + +def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace): + """ + This function will be executed by each spawned process. 
+ """ + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") + ) + + shmem = iris.iris(args.heap_size) + torch.cuda.set_device(rank) + world_size = shmem.get_num_ranks() + torch.cuda.set_device(rank) + + output_dir = args.output_dir + if rank == 0: + os.makedirs(output_dir, exist_ok=True) + shmem.barrier() + + with open(args.config_file, "r") as f: + configs_to_run = json.load(f) + + print(f"Loaded {len(configs_to_run)} configurations from {args.config_file}") + + for config in configs_to_run: + run_args = vars(args).copy() + run_args.update(config) + + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + datatype = dtype_map.get(run_args["datatype"]) + + M, N, K = run_args["m"], run_args["n"], run_args["k"] + shmem.info(f"\n--- Running Benchmark for M={M}, N={N}, K={K} ---") + + base_name, extension = os.path.splitext(args.output_file) + unique_filename = f"{base_name}_m_{M}{extension}" + full_output_path = os.path.join(output_dir, unique_filename) + + json_writer = JSONWriter(full_output_path) + json_writer.add_field("world_size", world_size) + for key, value in run_args.items(): + json_writer.add_field(key, value) + + K_local = K // world_size + + if rank == 0: + A_global = torch.randn((M, K), dtype=datatype, device="cuda") + else: + A_global = torch.empty((M, K), dtype=datatype, device="cuda") + + A_global_broadcasted = ( + torch.from_numpy(shmem.broadcast(A_global.cpu().numpy(), source_rank=0)).to(datatype).to("cuda") + ) + shmem.barrier() + + A_local = A_global_broadcasted[:, rank * K_local : (rank + 1) * K_local].contiguous() + + if rank == 0: + B = torch.randn((K, N), device="cuda", dtype=datatype) + else: + B = torch.empty((K, N), device="cuda", dtype=datatype) + + B = torch.from_numpy(shmem.broadcast(B.cpu().numpy(), source_rank=0)).to(datatype).to("cuda") + shmem.barrier() + + C = torch.empty((M, N), device="cuda", dtype=datatype) + + A_local_iris = shmem.empty((M, K_local), dtype=datatype) + A_local_iris.copy_(A_local) + A_inbox_iris = shmem.empty((world_size, M, K_local), dtype=datatype) + + num_m_tiles = (M + run_args["BLK_M"] - 1) // run_args["BLK_M"] + num_k_tiles = (K_local + run_args["BLK_K"] - 1) // run_args["BLK_K"] + signal_flags_iris = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32) + + # Use provided num_sms or auto-detect + if run_args["num_sms"] is None: + num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + run_args["num_sms"] = num_sms + else: + num_sms = run_args["num_sms"] + + main_stream = torch.cuda.Stream() + kernel_timing = { + "push_kernel": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + "compute_kernel": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + def run_experiment(): + nonlocal kernel_timing + signal_flags_iris.zero_() + shmem.barrier() + + with torch.cuda.stream(main_stream): + push_grid = (num_m_tiles, num_k_tiles) + + kernel_timing["push_kernel"]["start_event"].record() + push_shards_kernel[push_grid]( + A_local_iris, + A_inbox_iris, + signal_flags_iris, + M, + K_local, + A_local_iris.stride(0), + A_local_iris.stride(1), + A_inbox_iris.stride(0), + A_inbox_iris.stride(1), + A_inbox_iris.stride(2), + 
signal_flags_iris.stride(0), + signal_flags_iris.stride(1), + signal_flags_iris.stride(2), + signal_flags_iris.stride(3), + run_args["BLK_M"], + run_args["BLK_K"], + rank, + world_size, + shmem.get_heap_bases(), + ) + kernel_timing["push_kernel"]["end_event"].record() + + kernel_timing["compute_kernel"]["start_event"].record() + gemm_push_kernel[(num_sms,)]( + A_inbox_iris, + B, + C, + M, + N, + K, + signal_flags_iris, + A_inbox_iris.stride(0), + A_inbox_iris.stride(1), + A_inbox_iris.stride(2), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + signal_flags_iris.stride(0), + signal_flags_iris.stride(1), + signal_flags_iris.stride(2), + signal_flags_iris.stride(3), + run_args["BLK_M"], + run_args["BLK_N"], + run_args["BLK_K"], + run_args["gsize_m"], + num_sms, + 1, # NUM_XCDs + (K_local % run_args["BLK_K"] == 0), + rank, + world_size, + ) + kernel_timing["compute_kernel"]["end_event"].record() + + torch.cuda.synchronize() + kernel_timing["push_kernel"]["ms"] += kernel_timing["push_kernel"]["start_event"].elapsed_time( + kernel_timing["push_kernel"]["end_event"] + ) + kernel_timing["push_kernel"]["experiments"] += 1 + kernel_timing["compute_kernel"]["ms"] += kernel_timing["compute_kernel"]["start_event"].elapsed_time( + kernel_timing["compute_kernel"]["end_event"] + ) + kernel_timing["compute_kernel"]["experiments"] += 1 + + run_experiment() + shmem.barrier() + + for key in kernel_timing: + kernel_timing[key]["ms"] = 0 + kernel_timing[key]["experiments"] = 0 + + if args.benchmark: + triton_ms = iris.do_bench(run_experiment, barrier_fn=shmem.barrier) + tflops = 2 * M * N * K * 1e-12 / (triton_ms * 1e-3) + + shmem.info(f"Result (iris.do_bench): {triton_ms:.3f} ms, {tflops:.3f} TFLOPS") + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("tflops", tflops) + + for key in kernel_timing: + if kernel_timing[key]["experiments"] > 0: + avg_kernel_ms = kernel_timing[key]["ms"] / kernel_timing[key]["experiments"] + json_writer.add_field(key + "_ms", avg_kernel_ms) + shmem.info(f"Result (CUDA Events) - {key}: {avg_kernel_ms:.3f} ms") + + if args.validate: + if not args.benchmark: + run_experiment() + shmem.barrier() + + success = validate_gemm(A_global_broadcasted, B, C, shmem, atol=1.0) + + passed_str = "passed" if success else "failed" + shmem.info(f"Final C validation {passed_str}.") + json_writer.add_field("validation_passed", success) + + if rank == 0: + json_writer.flush() + shmem.info(f"Saved results to {full_output_path}") + + shmem.info("\nBenchmark sweep complete.") + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + if not args.validate and not args.benchmark: + print("Error: You must specify a mode to run.") + print("Please use -v for validation or -b for benchmarking.") + sys.exit(1) + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29501" + mp.spawn( + fn=worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/examples/benchmark_flash_decode.py b/benchmark/examples/benchmark_flash_decode.py new file mode 100644 index 00000000..b18259d1 --- /dev/null +++ b/benchmark/examples/benchmark_flash_decode.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+ +import sys +import json +import itertools +from pathlib import Path +import argparse +import torch +import iris +import os +import torch.multiprocessing as mp +import torch.distributed as dist + +project_root = Path(__file__).resolve() +while not (project_root / "tests").is_dir() or not (project_root / "examples").is_dir(): + if project_root == project_root.parent: + raise FileNotFoundError( + "Could not find project root. Make sure your 'tests' and 'examples' " + "directories are siblings in the project structure." + ) + project_root = project_root.parent + +module_dir = project_root / "examples" / "13_flash_decode" +if module_dir.is_dir(): + sys.path.insert(0, str(module_dir)) +else: + raise FileNotFoundError(f"Target directory not found: {module_dir}") + +from flash_decode_fused_layer import flash_decode_fused_layer # noqa: E402 + + +def parse_args(): + """ + Arguments for the benchmark + The default parameters are in dataset/flash_decode_config_iris.json + A different config file can be set with the --config flag + """ + parser = argparse.ArgumentParser( + description="Run Flash Decode benchmark with parameters from a config file.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "-c", + "--config", + type=str, + default="dataset/flash_decode_config_iris.json", + help="Path to the JSON configuration file", + ) + + config_args, _ = parser.parse_known_args() + + config_defaults = {} + if os.path.exists(config_args.config): + try: + with open(config_args.config, "r") as f: + config_from_file = json.load(f) + if config_from_file: + print(f"Configuration successfully loaded from '{config_args.config}'") + config_defaults = {**config_from_file, **config_from_file.get("sweep_parameters", {})} + if "sweep_parameters" in config_defaults: + del config_defaults["sweep_parameters"] + except json.JSONDecodeError: + print(f"Error: Config file '{config_args.config}' is not valid JSON.") + else: + print(f"Warning: Config file '{config_args.config}' not found.") + + parser.set_defaults(**config_defaults) + + parser.add_argument("--output_dir", type=str, help="Directory to save results") + parser.add_argument("--data_type", type=str, choices=["float16", "bfloat16", "float32"], help="PyTorch data type") + parser.add_argument("--warmup_iterations", type=int, help="Number of warmup iterations") + parser.add_argument("--repeat_iterations", type=int, help="Number of benchmark iterations") + parser.add_argument("--kv_len", type=int, nargs="+", help="Override KV_LEN_SWEEP") + parser.add_argument("--num_heads", type=int, nargs="+", help="Override NUM_HEADS_SWEEP") + parser.add_argument("--head_dim", type=int, nargs="+", help="Override HEAD_DIM_SWEEP") + parser.add_argument("--num_seqs", type=int, nargs="+", help="Override NUM_SEQS_SWEEP") + parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run on") + + final_args = parser.parse_args() + return final_args + + +def prepare_perf_data(cfg, num_query_heads, num_kv_heads): + """Prepares local data for the performance test on the current rank.""" + num_blocks_per_rank = (cfg["kv_len"] + cfg["block_size"] - 1) // cfg["block_size"] + + query = torch.randn(cfg["num_seqs"], num_query_heads, cfg["head_dim"], dtype=cfg["dtype"]).cuda() + key_cache_this_rank = torch.randn( + num_blocks_per_rank, cfg["block_size"], num_kv_heads, cfg["head_dim"], dtype=cfg["dtype"] + ).cuda() + value_cache_this_rank = torch.randn( + num_blocks_per_rank, cfg["block_size"], num_kv_heads, cfg["head_dim"], dtype=cfg["dtype"] 
+ ).cuda() + block_tables_this_rank = torch.arange(num_blocks_per_rank, dtype=torch.int32).repeat(cfg["num_seqs"], 1).cuda() + + return { + "query": query, + "key_cache_this_rank": key_cache_this_rank, + "value_cache_this_rank": value_cache_this_rank, + "block_tables_this_rank": block_tables_this_rank, + } + + +def run_benchmark(rank, world_size, init_url, args): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") + ) + # Set the correct GPU for this specific process + torch.cuda.set_device(rank) + + torch.manual_seed(42 + rank) + # Iris setup + shmem = iris.iris() + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + output_dir = args.output_dir + datatype = getattr(torch, args.data_type) + + if rank == 0: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print(f"Created output directory: '{output_dir}'") + + config_sweep = [] + param_product = itertools.product(args.kv_len, args.num_heads, args.head_dim, args.num_seqs) + for kv_len, num_heads, head_dim, num_seqs in param_product: + config_sweep.append({"kv_len": kv_len, "num_heads": num_heads, "head_dim": head_dim, "num_seqs": num_seqs}) + + # Loop through configs + for i, config in enumerate(config_sweep): + if rank == 0: + print(f"\n--- Running Config {i + 1}/{len(config_sweep)}: {config} ---") + + cfg = {"block_size": 1, "soft_cap": 0.0, "dtype": datatype, **config} + num_query_heads = cfg["num_heads"] + num_kv_heads = num_query_heads // 8 if num_query_heads >= 8 else 1 + scale = cfg["head_dim"] ** -0.5 + + common_params = { + "num_q_heads": num_query_heads, + "num_kv_heads": num_kv_heads, + "q_head_dim": cfg["head_dim"], + "v_head_dim": cfg["head_dim"], + "page_size": cfg["block_size"], + "scale": scale, + "soft_cap": cfg["soft_cap"], + "max_allowed_batch": cfg["num_seqs"], + } + + fd_layer = flash_decode_fused_layer(shmem, rank, rank, world_size, world_size, **common_params) + + tensor_data = prepare_perf_data(cfg, num_query_heads, num_kv_heads) + kv_lens_per_rank = [config["kv_len"]] * config["num_seqs"] + kv_lens_tensor = torch.tensor(kv_lens_per_rank, dtype=torch.int32).cuda() + global_kv_lens_tensor = kv_lens_tensor.unsqueeze(0).repeat(world_size, 1) + + def run_experiment(): + return fd_layer( + tensor_data["query"], + tensor_data["key_cache_this_rank"], + tensor_data["value_cache_this_rank"], + global_kv_lens_tensor, + tensor_data["block_tables_this_rank"], + ) + + time_ms = iris.do_bench( + fn=run_experiment, + barrier_fn=shmem.barrier, + preamble_fn=getattr(fd_layer, "clear_flags", None), + n_warmup=args.warmup_iterations, + n_repeat=args.repeat_iterations, + return_mode="mean", + ) + + shmem.barrier() + + if rank == 0: + global_kv_len = cfg["kv_len"] * world_size + print(f"Result -> Global KV Length: {global_kv_len}, Avg. 
Time: {time_ms:.3f} ms") + + result_entry = config.copy() + result_entry["global_kv_len"] = global_kv_len + result_entry["avg_time_ms"] = time_ms + + filename = f"h{config['num_heads']}_d{config['head_dim']}_s{config['num_seqs']}_kv{config['kv_len']}.json" + output_path = os.path.join(output_dir, filename) + + with open(output_path, "w") as f: + json.dump(result_entry, f, indent=4) + print(f"Saved result to '{output_path}'") + + if rank == 0: + print("\nBenchmark sweep complete.") + + shmem.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_args() + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29500" + + mp.spawn( + fn=run_benchmark, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) diff --git a/dataset/ag_gemm.json b/dataset/ag_gemm.json new file mode 100644 index 00000000..db38d40f --- /dev/null +++ b/dataset/ag_gemm.json @@ -0,0 +1,13 @@ +[ + { "m": 1, "k": 8192, "n": 3584 }, + { "m": 2, "k": 8192, "n": 3584 }, + { "m": 4, "k": 8192, "n": 3584 }, + { "m": 8, "k": 8192, "n": 3584 }, + { "m": 16, "k": 8192, "n": 3584 }, + { "m": 32, "k": 8192, "n": 3584 }, + { "m": 64, "k": 8192, "n": 3584 }, + { "m": 128, "k": 8192, "n": 3584 }, + { "m": 256, "k": 8192, "n": 3584 }, + { "m": 512, "k": 8192, "n": 3584 }, + { "m": 1024, "k": 8192, "n": 3584 } +] \ No newline at end of file diff --git a/dataset/flash_decode_config_iris.json b/dataset/flash_decode_config_iris.json new file mode 100644 index 00000000..60895775 --- /dev/null +++ b/dataset/flash_decode_config_iris.json @@ -0,0 +1,12 @@ +{ + "output_dir": "results/flash_decode_iris", + "data_type": "float16", + "warmup_iterations": 100, + "repeat_iterations": 1000, + "sweep_parameters": { + "kv_len": [8192, 16384, 32768, 65536, 131072, 262144, 524288], + "num_heads": [96], + "head_dim": [128], + "num_seqs": [1, 4, 8, 16] + } +} \ No newline at end of file diff --git a/dataset/flash_decode_config_rccl.json b/dataset/flash_decode_config_rccl.json new file mode 100644 index 00000000..958b39c4 --- /dev/null +++ b/dataset/flash_decode_config_rccl.json @@ -0,0 +1,12 @@ +{ + "output_dir": "results/flash_decode_rccl", + "data_type": "float16", + "warmup_iterations": 100, + "repeat_iterations": 1000, + "sweep_parameters": { + "kv_len": [8192, 16384, 32768, 65536, 131072, 262144, 524288], + "num_heads": [96], + "head_dim": [128], + "num_seqs": [1, 4, 8, 16] + } +} \ No newline at end of file diff --git a/docs/reference/api-device-functions.md b/docs/reference/api-device-functions.md index 04608ad6..d815b19c 100644 --- a/docs/reference/api-device-functions.md +++ b/docs/reference/api-device-functions.md @@ -24,6 +24,11 @@ Device-side functions provided by Iris for remote memory operations and atomics. .. autofunction:: iris.iris.store ``` +### copy +```{eval-rst} +.. autofunction:: iris.iris.copy +``` + ### get ```{eval-rst} .. autofunction:: iris.iris.get diff --git a/docs/reference/api-iris-class.md b/docs/reference/api-iris-class.md index a14fb680..4b2cb34a 100644 --- a/docs/reference/api-iris-class.md +++ b/docs/reference/api-iris-class.md @@ -40,7 +40,7 @@ Use Iris-aware logging that automatically annotates each message with the curren ## Broadcast Helper -Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal Torch Distributed helper. +Broadcast data from a source rank to all ranks. This method automatically detects whether the value is a tensor/array or a scalar and uses the appropriate broadcast mechanism. 
```{eval-rst} .. automethod:: iris.iris.Iris.broadcast diff --git a/docs/reference/examples.md b/docs/reference/examples.md index 06020ac3..1d54c490 100644 --- a/docs/reference/examples.md +++ b/docs/reference/examples.md @@ -22,6 +22,8 @@ We've curated a growing collection of practical examples that showcase the power - **[10_gemm_all_scatter_wg_specialization](https://github.com/ROCm/iris/tree/main/examples/10_gemm_all_scatter_wg_specialization)**: Matrix multiplication with all-scatter using workgroup specialization - **[11_gemm_all_scatter_producer_consumer](https://github.com/ROCm/iris/tree/main/examples/11_gemm_all_scatter_producer_consumer)**: Matrix multiplication with all-scatter using producer-consumer concurrent kernels - **[12_gemm_all_scatter_bulk_synchronous](https://github.com/ROCm/iris/tree/main/examples/12_gemm_all_scatter_bulk_synchronous)**: Matrix multiplication with all-scatter using the bulk synchronous parallel approach +- **[13_flash_decode](https://github.com/ROCm/iris/tree/main/examples/13_flash_decode)**: Fused Flash Decode Attention for accelerating LLM inference +- **[14_all_gather_gemm](https://github.com/ROCm/iris/tree/main/examples/14_all_gather_gemm)**: Fused All-Gather + GEMM with Pull and Push models ### Utilities - **[benchmark](https://github.com/ROCm/iris/tree/main/examples/benchmark)**: Benchmarking utilities and performance testing tools diff --git a/examples/00_load/load_bench.py b/examples/00_load/load_bench.py index fc8e4148..119c9653 100755 --- a/examples/00_load/load_bench.py +++ b/examples/00_load/load_bench.py @@ -235,7 +235,13 @@ def print_bandwidth_matrix(matrix, label="Unidirectional LOAD bandwidth GiB/s [R def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) diff --git a/examples/01_store/store_bench.py b/examples/01_store/store_bench.py index 80e7a7e0..e1edbcc7 100755 --- a/examples/01_store/store_bench.py +++ b/examples/01_store/store_bench.py @@ -208,7 +208,13 @@ def print_bandwidth_matrix(matrix, label="Unidirectional STORE bandwidth GiB/s [ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) diff --git a/examples/02_all_load/all_load_bench.py b/examples/02_all_load/all_load_bench.py index 6fb65c79..4a6f822f 100755 --- a/examples/02_all_load/all_load_bench.py +++ b/examples/02_all_load/all_load_bench.py @@ -36,7 +36,7 @@ def store_kernel( # Simple data to store (similar to what we accumulate) data = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - tl.store(target_buffer + offsets, data, mask=mask) + tl.store(target_buffer + offsets, data, mask=mask, cache_modifier=".wt") @triton.jit @@ -61,19 +61,19 @@ def all_read_kernel( # 
Initialize accumulator in registers if world_size == 1: data = iris.load(source_buffer + offsets, cur_rank, 0, heap_bases_ptr, mask=mask) - tl.store(target_buffer + offsets, data, mask=mask) + tl.store(target_buffer + offsets, data, mask=mask, cache_modifier=".wt") elif world_size == 2: data_0 = iris.load(source_buffer + offsets, cur_rank, 0, heap_bases_ptr, mask=mask) data_1 = iris.load(source_buffer + offsets, cur_rank, 1, heap_bases_ptr, mask=mask) sum = data_0 + data_1 - tl.store(target_buffer + offsets, sum, mask=mask) + tl.store(target_buffer + offsets, sum, mask=mask, cache_modifier=".wt") elif world_size == 4: data_0 = iris.load(source_buffer + offsets, cur_rank, 0, heap_bases_ptr, mask=mask) data_1 = iris.load(source_buffer + offsets, cur_rank, 1, heap_bases_ptr, mask=mask) data_2 = iris.load(source_buffer + offsets, cur_rank, 2, heap_bases_ptr, mask=mask) data_3 = iris.load(source_buffer + offsets, cur_rank, 3, heap_bases_ptr, mask=mask) sum = data_0 + data_1 + data_2 + data_3 - tl.store(target_buffer + offsets, sum, mask=mask) + tl.store(target_buffer + offsets, sum, mask=mask, cache_modifier=".wt") else: data_0 = iris.load(source_buffer + offsets, cur_rank, 0, heap_bases_ptr, mask=mask) data_1 = iris.load(source_buffer + offsets, cur_rank, 1, heap_bases_ptr, mask=mask) @@ -84,7 +84,7 @@ def all_read_kernel( data_6 = iris.load(source_buffer + offsets, cur_rank, 6, heap_bases_ptr, mask=mask) data_7 = iris.load(source_buffer + offsets, cur_rank, 7, heap_bases_ptr, mask=mask) sum = data_0 + data_1 + data_2 + data_3 + data_4 + data_5 + data_6 + data_7 - tl.store(target_buffer + offsets, sum, mask=mask) + tl.store(target_buffer + offsets, sum, mask=mask, cache_modifier=".wt") def torch_dtype_from_str(datatype: str) -> torch.dtype: @@ -316,7 +316,13 @@ def print_bandwidth_matrix( def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic heap_size = args["heap_size"] diff --git a/examples/03_all_store/all_store_bench.py b/examples/03_all_store/all_store_bench.py index eac5dd5d..b211639e 100755 --- a/examples/03_all_store/all_store_bench.py +++ b/examples/03_all_store/all_store_bench.py @@ -245,7 +245,13 @@ def print_bandwidth_matrix( def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic heap_size = args["heap_size"] diff --git a/examples/04_atomic_add/atomic_add_bench.py b/examples/04_atomic_add/atomic_add_bench.py index 9b6dfb4f..3bf3f328 100755 --- a/examples/04_atomic_add/atomic_add_bench.py +++ b/examples/04_atomic_add/atomic_add_bench.py @@ -12,8 +12,10 @@ import torch.multiprocessing as mp import triton import triton.language as tl +import sys import iris +from examples.common.utils import torch_dtype_from_str torch.manual_seed(123) 
random.seed(123) @@ -22,7 +24,6 @@ @triton.jit def atomic_add_kernel( source_buffer, # tl.tensor: pointer to source data - result_buffer, # tl.tensor: pointer to result data buffer_size, # int32: total number of elements source_rank: tl.constexpr, destination_rank: tl.constexpr, @@ -43,20 +44,6 @@ def atomic_add_kernel( ) -def torch_dtype_from_str(datatype: str) -> torch.dtype: - dtype_map = { - "fp16": torch.float16, - "fp32": torch.float32, - "int8": torch.int8, - "bf16": torch.bfloat16, - } - try: - return dtype_map[datatype] - except KeyError: - print(f"Unknown datatype: {datatype}") - exit(1) - - def parse_args(): parser = argparse.ArgumentParser( description="Parse Message Passing configuration.", @@ -67,14 +54,13 @@ def parse_args(): "--datatype", type=str, default="fp16", - choices=["fp16", "fp32", "int8", "bf16"], + choices=["fp16", "fp32", "bf16", "int32", "int64"], help="Datatype of computation", ) parser.add_argument("-z", "--buffer_size", type=int, default=1 << 32, help="Buffer Size") parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output") parser.add_argument("-d", "--validate", action="store_true", help="Enable validation output") - parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") @@ -85,7 +71,7 @@ def parse_args(): return vars(parser.parse_args()) -def run_experiment(shmem, args, source_rank, destination_rank, source_buffer, result_buffer): +def run_experiment(shmem, args, source_rank, destination_rank, source_buffer): dtype = torch_dtype_from_str(args["datatype"]) cur_rank = shmem.get_rank() world_size = shmem.get_num_ranks() @@ -108,7 +94,6 @@ def run_atomic_add(): if cur_rank == source_rank: atomic_add_kernel[grid]( source_buffer, - result_buffer, n_elements, source_rank, destination_rank, @@ -116,11 +101,18 @@ def run_atomic_add(): shmem.get_heap_bases(), ) + def preamble(): + source_buffer.fill_(0) + # Warmup run_atomic_add() shmem.barrier() atomic_add_ms = iris.do_bench( - run_atomic_add, shmem.barrier, n_repeat=args["num_experiments"], n_warmup=args["num_warmup"] + run_atomic_add, + barrier_fn=shmem.barrier, + preamble_fn=preamble, + n_repeat=args["num_experiments"], + n_warmup=args["num_warmup"], ) # Subtract overhead @@ -143,28 +135,34 @@ def run_atomic_add(): if args["verbose"]: shmem.info("Validating output...") - expected = torch.arange(n_elements, dtype=dtype, device="cuda") - diff_mask = ~torch.isclose(result_buffer, expected, atol=1) - breaking_indices = torch.nonzero(diff_mask, as_tuple=False) + expected = torch.ones(n_elements, dtype=dtype, device="cuda") + + diff_mask = ~torch.isclose(source_buffer, expected) - if not torch.allclose(result_buffer, expected, atol=1): - max_diff = (result_buffer - expected).abs().max().item() + if torch.any(diff_mask): + max_diff = (source_buffer - expected).abs().max().item() shmem.info(f"Max absolute difference: {max_diff}") - for idx in breaking_indices: - idx = tuple(idx.tolist()) - computed_val = result_buffer[idx] - expected_val = expected[idx] - shmem.error(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}") - success = False - break + + first_mismatch_idx = torch.argmax(diff_mask.float()).item() + computed_val = source_buffer[first_mismatch_idx] + expected_val = expected[first_mismatch_idx] + shmem.error(f"First mismatch at index {first_mismatch_idx}: 
C={computed_val}, expected={expected_val}") + success = False if success and args["verbose"]: shmem.info("Validation successful.") if not success and args["verbose"]: shmem.error("Validation failed.") + success = shmem.broadcast(success, source_rank) + shmem.barrier() - return bandwidth_gbps + + if not success: + dist.destroy_process_group() + sys.exit(1) + + return bandwidth_gbps, source_buffer.clone() def print_bandwidth_matrix( @@ -208,7 +206,13 @@ def print_bandwidth_matrix( def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) @@ -218,11 +222,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): dtype = torch_dtype_from_str(args["datatype"]) element_size_bytes = torch.tensor([], dtype=dtype).element_size() source_buffer = shmem.arange(args["buffer_size"] // element_size_bytes, device="cuda", dtype=dtype) - result_buffer = shmem.zeros_like(source_buffer) for source_rank in range(num_ranks): for destination_rank in range(num_ranks): - bandwidth_gbps = run_experiment(shmem, args, source_rank, destination_rank, source_buffer, result_buffer) + bandwidth_gbps, _ = run_experiment(shmem, args, source_rank, destination_rank, source_buffer) bandwidth_matrix[source_rank, destination_rank] = bandwidth_gbps shmem.barrier() diff --git a/examples/05_atomic_xchg/atomic_xchg_bench.py b/examples/05_atomic_xchg/atomic_xchg_bench.py index 89d7792f..07fdc4a6 100755 --- a/examples/05_atomic_xchg/atomic_xchg_bench.py +++ b/examples/05_atomic_xchg/atomic_xchg_bench.py @@ -212,7 +212,13 @@ def print_bandwidth_matrix( def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) diff --git a/examples/06_message_passing/message_passing_load_store.py b/examples/06_message_passing/message_passing_load_store.py index 37db8bcb..3964bfa8 100755 --- a/examples/06_message_passing/message_passing_load_store.py +++ b/examples/06_message_passing/message_passing_load_store.py @@ -48,7 +48,7 @@ def producer_kernel( ) # Set flag to signal completion - tl.store(flag + pid, 1) + iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") @triton.jit @@ -67,9 +67,11 @@ def consumer_kernel( mask = offsets < buffer_size # Spin-wait until writer sets flag[pid] = 1 - done = tl.load(flag + pid) + done = 0 while done == 0: - done = tl.load(flag + pid) + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) # Read from the target buffer (written by producer) values = iris.load(buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask) @@ -135,7 +137,13 @@ def parse_args(): 
def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) @@ -145,7 +153,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate source and destination buffers on the symmetric heap source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) - destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) if world_size != 2: raise ValueError("This example requires exactly two processes.") diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index b4d54064..54abe255 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -39,7 +39,7 @@ def producer_kernel( iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, mask=mask) # Set flag to signal completion - tl.store(flag + pid, 1) + iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") @triton.jit @@ -58,9 +58,11 @@ def consumer_kernel( mask = offsets < buffer_size # Spin-wait until writer sets flag[pid] = 1 - done = tl.load(flag + pid) + done = 0 while done == 0: - done = tl.load(flag + pid) + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) # Read from the target buffer (written by producer) values = tl.load(buffer + offsets, mask=mask) @@ -123,7 +125,13 @@ def parse_args(): def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) @@ -133,7 +141,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate source and destination buffers on the symmetric heap source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) - destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) if world_size != 2: raise ValueError("This example requires exactly two processes.") diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 6d872e61..96518c49 100755 --- 
a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -52,7 +52,12 @@ def parse_args(): parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for persistent GEMM algorithm") + parser.add_argument( + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for persistent GEMM algorithm (default: auto-detected)", + ) parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -61,13 +66,24 @@ def parse_args(): def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + if args["gemm_sms"] is None: + # For all_scatter: use total CU count + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + args["gemm_sms"] = cu_count # GEMM datatype = torch.float32 diff --git a/examples/07_gemm_all_scatter/matmul_wrapper.py b/examples/07_gemm_all_scatter/matmul_wrapper.py index b9047df1..597713c9 100644 --- a/examples/07_gemm_all_scatter/matmul_wrapper.py +++ b/examples/07_gemm_all_scatter/matmul_wrapper.py @@ -7,6 +7,7 @@ # from streamk_kernel import streamk_gemm from gemm_all_scatter import persistent_gemm_all_scatter from examples.common.utils import is_triton_interpret_set +import iris gemm_kernel = persistent_gemm_all_scatter @@ -16,6 +17,8 @@ class matmul(torch.autograd.Function): _registers = None _spills = None + _num_xcds = iris.hip.get_num_xcc() + @staticmethod def set_debug(debug: bool): matmul._debug = debug @@ -59,9 +62,7 @@ def _call( M, K = a.shape _, N = b.shape - num_xcds = 1 - if arch == "gfx942" or arch == "gfx950": - num_xcds = 8 + num_xcds = matmul._num_xcds # TODO: Use arch-specific values. 
num_stages = 2 diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index 31de4fa3..d1590bc5 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import ( JSONWriter, @@ -68,10 +69,8 @@ def parse_args(): parser.add_argument("--kpack", type=int, default=2, help="K packing size") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - # For All Scatter, use: 288 - # For One Shot, use: 256 - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") - parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)") + parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -80,13 +79,26 @@ def parse_args(): def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) # Main benchmark logic shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + if args["total_sms"] is None: + args["total_sms"] = cu_count + if args["gemm_sms"] is None: + # For all_reduce: use next smaller power of 2, rest for communication + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py b/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py index ab014faf..ba55286e 100644 --- a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py +++ b/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py @@ -12,6 +12,7 @@ from gemm_atomics_all_reduce import persistent_gemm_all_reduce from examples.common.utils import is_triton_interpret_set +import iris gemm_kernel = persistent_gemm_all_reduce @@ -19,6 +20,8 @@ class matmul(torch.autograd.Function): _debug = True + _num_xcds = iris.hip.get_num_xcc() + @staticmethod def set_debug(debug: bool): matmul._debug = debug @@ -49,7 +52,7 @@ def _call( mfmaInstrSize: int, kpack: int, heap_bases_ptr: torch.Tensor = None, - cu_count: int = 304, + cu_count: int = None, COLLECT_TIMESTAMPS: bool = False, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, @@ -60,9 +63,7 @@ def _call( M, K = a.shape _, N = b.shape - num_xcds = 1 - if cu_count == 304: - num_xcds = 8 + num_xcds = matmul._num_xcds total_blocks_M = triton.cdiv(M, BLK_M) total_blocks_N = triton.cdiv(N, BLK_N) @@ -183,7 +184,7 @@ def forward( mfmaInstrSize=16, kpack=1, heap_bases_ptr: torch.Tensor = None, - cu_count: int = 304, + cu_count: int = None, COLLECT_TIMESTAMPS: bool = False, mm_begin_timestamp: torch.Tensor = 
None, mm_end_timestamp: torch.Tensor = None, diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index 212bc857..4badd746 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import ( JSONWriter, @@ -68,8 +69,8 @@ def parse_args(): parser.add_argument("--kpack", type=int, default=2, help="K packing size") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") - parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)") + parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -77,12 +78,25 @@ def parse_args(): def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + if args["total_sms"] is None: + args["total_sms"] = cu_count + if args["gemm_sms"] is None: + # For all_reduce: use next smaller power of 2, rest for communication + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py index 83c9326f..49e53c0d 100644 --- a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py +++ b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py @@ -12,6 +12,7 @@ from gemm_one_shot_all_reduce import persistent_gemm_all_reduce from examples.common.utils import is_triton_interpret_set +import iris gemm_kernel = persistent_gemm_all_reduce @@ -19,6 +20,8 @@ class matmul(torch.autograd.Function): _debug = True + _num_xcds = iris.hip.get_num_xcc() + @staticmethod def set_debug(debug: bool): matmul._debug = debug @@ -49,7 +52,7 @@ def _call( mfmaInstrSize: int, kpack: int, heap_bases_ptr: torch.Tensor = None, - cu_count: int = 304, + cu_count: int = None, COLLECT_TIMESTAMPS: bool = False, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, @@ -60,9 +63,7 @@ def _call( M, K = a.shape _, N = b.shape - num_xcds = 1 - if cu_count == 304: - num_xcds = 8 + num_xcds = matmul._num_xcds total_blocks_M = triton.cdiv(M, BLK_M) total_blocks_N = triton.cdiv(N, BLK_N) @@ -183,7 +184,7 @@ def forward( mfmaInstrSize=16, kpack=1, heap_bases_ptr: torch.Tensor = None, - cu_count: int = 304, + cu_count: int = None, COLLECT_TIMESTAMPS: bool = False, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, 
diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index bb49bacb..59d14565 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -54,9 +55,17 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--num_sms", + type=int, + default=None, + help="Number of total SMs for gemm + scatter kernel (default: auto-detected)", ) - parser.add_argument("--num_sms", type=int, default=304, help="Number of total SMs for gemm + scatter kernel") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -65,12 +74,25 @@ def parse_args(): def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + if args["num_sms"] is None: + args["num_sms"] = cu_count + if args["gemm_sms"] is None: + # For wg_specialized: use next smaller power of 2 + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py index 0587a8f0..ce186561 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py @@ -5,8 +5,11 @@ import triton # from streamk_kernel import streamk_gemm -from gemm_all_scatter_wg_specialization import persistent_gemm_all_scatter_wg_specialization +from gemm_all_scatter_wg_specialization import ( + persistent_gemm_all_scatter_wg_specialization, +) from examples.common.utils import is_triton_interpret_set +import iris gemm_kernel = persistent_gemm_all_scatter_wg_specialization @@ -16,6 +19,8 @@ class matmul(torch.autograd.Function): _registers = None _spills = None + _num_xcds = iris.hip.get_num_xcc() + @staticmethod def set_debug(debug: bool): matmul._debug = debug @@ -61,9 +66,7 @@ def _call( M, K = a.shape _, N = b.shape - num_xcds = 1 - if arch == "gfx942" or arch == "gfx950": - num_xcds = 8 + num_xcds = matmul._num_xcds # TODO: Use arch-specific values. 
num_stages = 2 diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 41c164a8..4849b053 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -55,9 +56,14 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -66,12 +72,28 @@ def parse_args(): def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + + if args["gemm_sms"] is None: + # For wg_specialized: use next smaller power of 2 + args["gemm_sms"] = next_pow2 + if args["comm_sms"] is None: + # comm_sms is the leftover: total - next_power_of_2 + args["comm_sms"] = cu_count - next_pow2 # GEMM datatype = torch.float32 @@ -118,10 +140,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): bias = None - num_xcds = 1 - arch = "gfx942" - if arch == "gfx942" or arch == "gfx950": - num_xcds = 8 + num_xcds = iris.hip.get_num_xcc() gemm_stream = torch.cuda.Stream() comm_stream = torch.cuda.Stream() diff --git a/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py b/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py index 4c83cbbf..9b041d90 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py +++ b/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py @@ -7,6 +7,7 @@ # from streamk_kernel import streamk_gemm from gemm_all_scatter_producer_consumer import persistent_gemm from examples.common.utils import is_triton_interpret_set +import iris gemm_kernel = persistent_gemm @@ -16,6 +17,8 @@ class matmul(torch.autograd.Function): _registers = None _spills = None + _num_xcds = iris.hip.get_num_xcc() + @staticmethod def set_debug(debug: bool): matmul._debug = debug @@ -60,9 +63,7 @@ def _call( M, K = a.shape _, N = b.shape - 
num_xcds = 1 - if arch == "gfx942" or arch == "gfx950": - num_xcds = 8 + num_xcds = matmul._num_xcds # TODO: Use arch-specific values. num_stages = 2 diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index 5cdc3819..9a242dee 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -55,9 +56,14 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=256, help="Number of SMs for All-Scatter kernel") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -66,12 +72,28 @@ def parse_args(): def _worker(local_rank: int, world_size: int, init_url: str, args: dict): """Worker function for PyTorch distributed execution.""" backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + + if args["gemm_sms"] is None: + # For wg_specialized: use next smaller power of 2 + args["gemm_sms"] = next_pow2 + if args["comm_sms"] is None: + # For bulk synchronous, use same as gemm_sms + args["comm_sms"] = next_pow2 # GEMM datatype = torch.float32 @@ -116,10 +138,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): bias = None - num_xcds = 1 - arch = "gfx942" - if arch == "gfx942" or arch == "gfx950": - num_xcds = 8 + num_xcds = iris.hip.get_num_xcc() # This is one after another. 
main_stream = torch.cuda.Stream() diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py b/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py index bf48c4e0..e621112e 100644 --- a/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py @@ -7,6 +7,7 @@ # from streamk_kernel import streamk_gemm from gemm_all_scatter_bulk_synchronous import persistent_gemm from examples.common.utils import is_triton_interpret_set +import iris gemm_kernel = persistent_gemm @@ -15,6 +16,7 @@ class matmul(torch.autograd.Function): _debug = False _registers = None _spills = None + _num_xcds = iris.hip.get_num_xcc() @staticmethod def set_debug(debug: bool): @@ -58,9 +60,7 @@ def _call( M, K = a.shape _, N = b.shape - num_xcds = 1 - if arch == "gfx942" or arch == "gfx950": - num_xcds = 8 + num_xcds = matmul._num_xcds # TODO: Use arch-specific values. num_stages = 2 diff --git a/examples/13_flash_decode/README.md b/examples/13_flash_decode/README.md new file mode 100644 index 00000000..e71c69cb --- /dev/null +++ b/examples/13_flash_decode/README.md @@ -0,0 +1,60 @@ + + +# Fused Flash Decode Attention + +This is an example of a distributed Flash Decode kernel designed to accelerate LLM inference. Part of the code is adapted from [Triton-distributed](https://github.com/ByteDance-Seed/Triton-distributed). + +This is a novel implementation that fuses communication and computation, diminishing the collective kernel launch latencies and the associated waits. + +The core layer implementation is in `examples/13_flash_decode/flash_decode_fused_layer.py`, while the Triton fused kernels are defined in `examples/13_flash_decode/decode_kernels.py`. + +We perform comparisons against the RCCL baseline. + +--- + +## Usage + +### Simple Example + +To do a simple test run of the code, run: +```terminal +python examples/13_flash_decode/example_run.py +``` +This example runs on 8 GPUs by default. Use the `--num_ranks` flag to select the number of GPUs. + +### Validation + +These scripts use `pytest` to verify the numerical correctness of each implementation against a standard PyTorch reference. + +**Iris** + +```terminal +python tests/run_tests_distributed.py tests/examples/test_flash_decode.py --num_ranks 8 +``` + +**RCCL** + +```terminal +python tests/run_tests_distributed.py examples/benchmark/reference/flash_decode_rccl/validate_flash_decode_rccl.py --num_ranks 8 +``` + +### Benchmarking + +These scripts run a sweep of configurations and save performance results as `.json` files into the `results/` directory. + +**Iris** + +```terminal +python benchmark/examples/benchmark_flash_decode.py --num_ranks 8 +``` + +**RCCL** + +```terminal +torchrun --nproc_per_node=8 examples/benchmark/reference/flash_decode_rccl/benchmark_flash_decode_rccl.py +``` + + diff --git a/examples/13_flash_decode/decode_kernels.py b/examples/13_flash_decode/decode_kernels.py new file mode 100644 index 00000000..2b0e6c1b --- /dev/null +++ b/examples/13_flash_decode/decode_kernels.py @@ -0,0 +1,416 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# +# Part of the code adapted from +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/python/triton_dist/kernels/nvidia/flash_decode.py +# +# Copyright (c) 2025 ByteDance Ltd. 
and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +################################################################################ + +import torch +import triton +import math +import os +import triton.language as tl +from triton.language.extra import libdevice +import iris + + +# This kernel will do the attention computation on each local GPU +# Each GPU will get a shard of the KV and will produce the outputs for their part +def gqa_local_kernels_fused( + q, + k_cache, + v_cache, + gathered_buffer, + signal_flags, + shmem, + q_lens, + kv_lens, + block_table, + scale, + soft_cap=0.0, + output_split=None, + kv_split=-1, +): + batch, q_heads, q_head_dim = q.shape + _, page_size, kv_heads, k_head_dim = k_cache.shape + v_head_dim = v_cache.shape[-1] + rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + + BLOCK_N = 64 + BLOCK_HEAD_DIM = 2 ** int(math.log2(q_head_dim)) + BLOCK_DPE = q_head_dim - BLOCK_HEAD_DIM + BLOCK_DV = triton.next_power_of_2(v_head_dim) + kv_group_num = q_heads // kv_heads + BLOCK_H = 16 + NUM_KV_SPLITS = 32 if kv_split == -1 else kv_split + + # Step 1: Split-K calculation (will produce partial attention outputs) + grid_split_kv = (batch, triton.cdiv(q_heads, min(BLOCK_H, kv_group_num)), NUM_KV_SPLITS) + if output_split is None: + output_split = torch.empty([batch, q_heads, NUM_KV_SPLITS, v_head_dim + 1], dtype=q.dtype, device=q.device) + + # Kernel-Split-K + gqa_local_decode_split_k[grid_split_kv]( + q, + k_cache, + v_cache, + output_split, + scale, + block_table, + kv_lens, + batch, + q.stride(0), + q.stride(1), + k_cache.stride(-3), + k_cache.stride(-2), + v_cache.stride(-3), + v_cache.stride(-2), + output_split.stride(0), + output_split.stride(1), + output_split.stride(2), + block_table.stride(0), + kv_group_num, + q_heads, + BLOCK_HEAD_DIM, + BLOCK_DPE, + BLOCK_DV, + BLOCK_N, + BLOCK_H, + NUM_KV_SPLITS, + page_size, + soft_cap, + k_head_dim, + v_head_dim, + num_warps=4, + num_stages=2, + ) + + # Step 2: Fused Intra-Rank Combine and Inter-Rank Push with tile-level signaling + # The communication happens inside the kernel through Iris Stores + grid_combine_push = (batch, q_heads) + gqa_local_reduce_fused[grid_combine_push]( + output_split, + kv_lens, + gathered_buffer, + signal_flags, + signal_flags.stride(0), + signal_flags.stride(1), + signal_flags.stride(2), + signal_flags.stride(3), + shmem.get_heap_bases(), + output_split.stride(0), + output_split.stride(1), + output_split.stride(2), + gathered_buffer.stride(0), + 
gathered_buffer.stride(1), + gathered_buffer.stride(2), + rank, + num_ranks, + q_heads, + NUM_KV_SPLITS, + BLOCK_DV, + v_head_dim, + ) + + +@triton.jit +def gqa_local_decode_split_k( + q_ptr, + k_cache_ptr, + v_cache_ptr, + output_ptr, + sm_scale, + block_table_ptr, + kv_length_ptr, + # shape + batch, + # strides + stride_q_bs, + stride_q_h, + stride_k_cache_bs, + stride_k_cache_h, + stride_v_cache_bs, + stride_v_cache_h, + stride_o_bs, + stride_o_h, + stride_o_split, + stride_table_bs, + # constants + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_HEAD_DIM: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + soft_cap: tl.constexpr, + K_DIM: tl.constexpr, + V_DIM: tl.constexpr, +): + bid = tl.program_id(0) + hid = tl.program_id(1) + kv_hid = hid // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + + cur_head = hid * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = (cur_head < (hid + 1) * VALID_BLOCK_H) & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_HEAD_DIM) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < K_DIM + mask_dv = offs_dv < V_DIM + cur_kv_seq_len = tl.load(kv_length_ptr + bid) + + offs_q = bid * stride_q_bs + cur_head[:, None] * stride_q_h + offs_d[None, :] + q = tl.load(q_ptr + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_HEAD_DIM + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < K_DIM + offs_qpe = bid * stride_q_bs + cur_head[:, None] * stride_q_h + offs_dpe[:, None] + qpe = tl.load(q_ptr + offs_qpe, mask=mask_h[:, None] & mask_dpe[None, :], other=0.0) + + kv_len_per_split = tl.cdiv(cur_kv_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_kv_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + block_table_ptr + bid * stride_table_bs + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0 + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_cache_k = kv_loc[None, :] * stride_k_cache_bs + kv_hid * stride_k_cache_h + offs_d[:, None] + k = tl.load(k_cache_ptr + offs_cache_k, mask=(offs_n[None, :] < split_kv_end) & mask_d[:, None], other=0.0) + qk = tl.dot(q, k.to(q.dtype)) + + if BLOCK_DPE > 0: + offs_cache_kpe = kv_loc[None, :] * stride_k_cache_bs + kv_hid * stride_k_cache_h + offs_dpe[:, None] + kpe = tl.load( + k_cache_ptr + offs_cache_kpe, mask=(offs_n[None, :] < split_kv_end) & mask_dpe[:, None], other=0.0 + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + + qk *= sm_scale + + if soft_cap > 0: + qk = soft_cap * libdevice.tanh(qk / soft_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")) + + offs_cache_v = kv_loc[:, None] * stride_v_cache_bs + kv_hid * stride_v_cache_h + offs_dv[None, :] + v = tl.load(v_cache_ptr + offs_cache_v, mask=(offs_n[:, None] < split_kv_end) & mask_dv[None, :], other=0.0) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = libdevice.fast_expf(e_max - n_e_max) + p = libdevice.fast_expf(qk - 
n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_out = bid * stride_o_bs + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_split + offs_dv[None, :] + tl.store(output_ptr + offs_out, acc / e_sum[:, None], mask=mask_h[:, None] & mask_dv[None, :]) + + offs_log = bid * stride_o_bs + cur_head * stride_o_h + split_kv_id * stride_o_split + V_DIM + tl.store(output_ptr + offs_log, e_max + tl.log(e_sum), mask=mask_h) + + +@triton.jit +def gqa_local_reduce_fused( + # Input + Mid_O, + B_Seqlen, + gathered_output_ptr, + signal_flags_ptr, + stride_signal_dest, + stride_signal_src, + stride_signal_bs, + stride_signal_h, + heap_bases_ptr, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_gathered_rank, + stride_gathered_bs, + stride_gathered_h, + my_rank: tl.constexpr, + world_size: tl.constexpr, + q_head_num: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + # Standard softmax combination logic + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v_base = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + for split_kv_id in range(0, NUM_KV_SPLITS): + split_kv_start = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) * split_kv_id + if split_kv_start < cur_batch_seq_len: + offs_v = offs_v_base + split_kv_id * stride_mid_os + offs_d + offs_logic = offs_v_base + split_kv_id * stride_mid_os + Lv + + tv = tl.load(Mid_O + offs_v, mask=mask_d, other=0.0) + tlogic = tl.load(Mid_O + offs_logic) + + n_e_max = tl.maximum(tlogic, e_max) + old_scale = libdevice.fast_expf(e_max - n_e_max) + exp_logic = libdevice.fast_expf(tlogic - n_e_max) + + acc = acc * old_scale + exp_logic * tv + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + final_v = acc / e_sum + final_logic = e_max + tl.log(e_sum) + final_v = tl.where(e_sum == 0.0, 0.0, final_v) + + # Write tile result to all other GPUs and signal completion for this specific tile + base_write_ptr = ( + gathered_output_ptr + + my_rank * stride_gathered_rank + + cur_batch * stride_gathered_bs + + cur_head * stride_gathered_h + ) + + for dest_rank_id in range(0, world_size): + # Write output vector and log-sum-exp value + iris.store(base_write_ptr + offs_d, final_v, my_rank, dest_rank_id, heap_bases_ptr, mask=mask_d) + iris.store(base_write_ptr + Lv, final_logic, my_rank, dest_rank_id, heap_bases_ptr) + + # Signal the destination rank that this specific tile is ready + flag_ptr = ( + signal_flags_ptr + + dest_rank_id * stride_signal_dest + + my_rank * stride_signal_src + + cur_batch * stride_signal_bs + + cur_head * stride_signal_h + ) + iris.atomic_xchg(flag_ptr, 1, my_rank, dest_rank_id, heap_bases_ptr, sem="release", scope="sys") + + +@triton.jit +def gqa_global_reduce_fused( + All_Ranks_Mid_O, + o, + B_Seqlens, + signal_flags_ptr, + stride_signal_dest, + stride_signal_src, + stride_signal_bs, + stride_signal_h, + batch, + q_heads, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + my_rank: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, # This is used as num_ranks + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + cur_batch_seq_len_ptr = 
B_Seqlens + cur_batch + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + # Iterate through all source ranks to gather partial results + for source_rank_id in range(0, NUM_KV_SPLITS): + # Wait for the specific tile from the source rank to be ready + flag_ptr = ( + signal_flags_ptr + + my_rank * stride_signal_dest + + source_rank_id * stride_signal_src + + cur_batch * stride_signal_bs + + cur_head * stride_signal_h + ) + + while tl.atomic_cas(flag_ptr, 0, 0, sem="acquire", scope="sys") == 0: + pass + + effective_kv_len = tl.load(cur_batch_seq_len_ptr + source_rank_id * batch) + + if effective_kv_len > 0: + # Load the data for the tile from the source rank + base_ptr = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + source_rank_id * stride_mid_os + offs_v = base_ptr + offs_d + offs_logic = base_ptr + Lv + + tv = tl.load(All_Ranks_Mid_O + offs_v, mask=mask_d, other=0.0) + tlogic = tl.load(All_Ranks_Mid_O + offs_logic) + + # Combine the partial result using softmax reduction + n_e_max = tl.maximum(tlogic, e_max) + old_scale = libdevice.fast_expf(e_max - n_e_max) + acc *= old_scale + + exp_logic = libdevice.fast_expf(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + final_out = acc / e_sum + final_out = tl.where(e_sum == 0, 0.0, final_out) + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + final_out, + mask=mask_d, + ) diff --git a/examples/13_flash_decode/example_run.py b/examples/13_flash_decode/example_run.py new file mode 100644 index 00000000..34ab1fc3 --- /dev/null +++ b/examples/13_flash_decode/example_run.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +A simple, minimal example demonstrating how to use the flash_decode_fused_layer. + +This script initializes the necessary distributed components with Iris, +creates sample input tensors, instantiates the layer, and calls its +forward pass once. It then prints the shape and a slice of the output +tensor to show that the operation completed successfully. + +The layer is defined in the flash_decode_fused_layer.py file. +All the triton kernels are defined in decode_kernels.py +""" + +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +import iris +import argparse +from flash_decode_fused_layer import flash_decode_fused_layer + + +def parse_args(): + """Parses command-line arguments for the example.""" + parser = argparse.ArgumentParser(description="A minimal example for flash_decode_fused_layer.") + parser.add_argument("--kv_len_per_rank", type=int, default=32768, help="KV sequence length per rank.") + parser.add_argument("--num_heads", type=int, default=96, help="Number of attention heads.") + parser.add_argument("--head_dim", type=int, default=128, help="Dimension of each attention head.") + parser.add_argument("--num_seqs", type=int, default=4, help="Number of sequences in the batch.") + parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") + parser.add_argument( + "--dtype", type=str, default="float16", choices=["float16", "bfloat16"], help="PyTorch data type to use." 
+ ) + return parser.parse_args() + + +def setup_example_data(rank, world_size, args, dtype): + """Creates a set of random tensors to serve as inputs for the layer.""" + + num_query_heads = args.num_heads + # Assume an 8:1 Grouped-Query Attention ratio for this example + num_kv_heads = max(1, args.num_heads // 8) + block_size = 1 # PagedAttention works with blocks of tokens + + # Number of blocks needed on this rank to store the KV cache for all sequences + num_blocks_per_rank = (args.kv_len_per_rank + block_size - 1) // block_size + + print(f"[Rank {rank}] Creating example tensors...") + + # 1. Query tensor: The new tokens for which we are calculating attention. + query = torch.randn(args.num_seqs, num_query_heads, args.head_dim, dtype=dtype).cuda() + + # 2. Key/Value Caches: Tensors representing the keys and values + # The KV is split across ranks + key_cache_this_rank = torch.randn(num_blocks_per_rank, block_size, num_kv_heads, args.head_dim, dtype=dtype).cuda() + value_cache_this_rank = torch.randn( + num_blocks_per_rank, block_size, num_kv_heads, args.head_dim, dtype=dtype + ).cuda() + + # 3. Block Tables: A mapping that tells the kernel where to find the blocks for each sequence in the KV cache. + # Here, we create a simple identity mapping for demonstration. + block_tables_this_rank = torch.arange(num_blocks_per_rank, dtype=torch.int32).repeat(args.num_seqs, 1).cuda() + + # 4. Global KV Lengths Tensor: The layer needs to know the sequence length on all ranks. + # Create a list of lengths for each sequence in the batch on this rank. + kv_lens_per_rank = [args.kv_len_per_rank] * args.num_seqs + # Create a 1D tensor from this list. Shape: (NUM_SEQS,) + kv_lens_tensor_this_rank = torch.tensor(kv_lens_per_rank, dtype=torch.int32).cuda() + # Reshape to (1, NUM_SEQS) and repeat for all ranks to get shape (world_size, NUM_SEQS) + global_kv_lens_tensor = kv_lens_tensor_this_rank.unsqueeze(0).repeat(world_size, 1) + + return { + "query": query, + "key_cache_this_rank": key_cache_this_rank, + "value_cache_this_rank": value_cache_this_rank, + "block_tables_this_rank": block_tables_this_rank, + "global_kv_lens_tensor": global_kv_lens_tensor, + } + + +def example_run(rank: int, world_size: int, init_url: str, args: dict): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") + ) + + # 1. Initialize Iris for distributed communication + shmem = iris.iris() + + torch.manual_seed(42) + torch.set_default_device("cuda") + dtype = getattr(torch, args.dtype) + + if rank == 0: + print("--- flash_decode_fused_layer Minimal Example ---") + print(f"Running with {world_size} rank(s).") + + # 2. Set up the example input tensors + tensor_data = setup_example_data(rank, world_size, args, dtype) + shmem.barrier() + + # 3. Define the layer's parameters + num_kv_heads = max(1, args.num_heads // 8) + scale = args.head_dim**-0.5 + common_params = { + "num_q_heads": args.num_heads, + "num_kv_heads": num_kv_heads, + "q_head_dim": args.head_dim, + "v_head_dim": args.head_dim, + "page_size": 1, + "scale": scale, + "soft_cap": 0.0, + "max_allowed_batch": args.num_seqs, + } + + # 4. Instantiate the layer + if rank == 0: + print("\nInstantiating flash_decode_fused_layer...") + fd_layer = flash_decode_fused_layer(shmem, rank, rank, world_size, world_size, **common_params) + + # 5. 
Call the forward pass of the layer + if rank == 0: + print("Calling the forward pass...") + output = fd_layer( + tensor_data["query"], + tensor_data["key_cache_this_rank"], + tensor_data["value_cache_this_rank"], + tensor_data["global_kv_lens_tensor"], + tensor_data["block_tables_this_rank"], + ) + + # Ensure the computation is finished before printing + torch.cuda.synchronize() + shmem.barrier() + + # 6. Print a summary of the output tensor on the main rank + if rank == 0: + print("\n--- Example Run Finished ---") + print(f"Output tensor shape: {output.shape}") + print("Output tensor values (first 5 elements of the first sequence):") + print(output[0, 0, :5]) + print("--------------------------") + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=example_run, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/13_flash_decode/flash_decode_fused_layer.py b/examples/13_flash_decode/flash_decode_fused_layer.py new file mode 100644 index 00000000..ad0015ca --- /dev/null +++ b/examples/13_flash_decode/flash_decode_fused_layer.py @@ -0,0 +1,148 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# +# Part of the code adapted from +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/python/triton_dist/layers/nvidia/sp_flash_decode_layer.py################################################################################ +# +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+# +################################################################################ + +import torch +import triton + +from decode_kernels import gqa_local_kernels_fused, gqa_global_reduce_fused + + +class flash_decode_fused_layer(torch.nn.Module): + def __init__( + self, + shmem, + rank, + node, + num_ranks, + num_nodes, + num_q_heads, + num_kv_heads, + q_head_dim, + v_head_dim, + page_size=1, + scale=1, + soft_cap=0, + max_allowed_batch=1, + thrink_buffer_threshold=500, + stages=20, + ): + super().__init__() + self.shmem = shmem + self.rank = rank + self.num_ranks = num_ranks + self.node = node + self.num_nodes = num_nodes + + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.q_head_dim = q_head_dim + self.v_head_dim = v_head_dim + self.page_size = page_size + self.soft_cap = soft_cap + self.scale = scale + self.kv_split = 32 + self.max_allowed_batch = max_allowed_batch + + self.BLOCK_DV = triton.next_power_of_2(self.v_head_dim) + + self.gathered_buffer = self.shmem.empty( + (self.num_ranks, self.max_allowed_batch, self.num_q_heads, self.v_head_dim + 1), dtype=torch.float16 + ) + + # Use per-tile signaling for finer-grained synchronization + # This will tell which rank sent the data to which rank, for each batch item and head + self.signal_flags = self.shmem.zeros( + (self.num_ranks, self.num_ranks, self.max_allowed_batch, self.num_q_heads), dtype=torch.int32 + ) + + # self.producer_stream = torch.cuda.Stream() + # self.consumer_stream = torch.cuda.Stream() + + def clear_flags(self): + """Resets synchronization flags for the next iteration.""" + self.signal_flags.zero_() + self.shmem.barrier() + + def forward(self, q, k_cache, v_cache, global_kv_lens, block_table): + batch = q.shape[0] + assert global_kv_lens.shape[0] == self.num_ranks + assert global_kv_lens.shape[1] == batch + assert batch <= self.max_allowed_batch + + output_split = torch.empty( + [batch, self.num_q_heads, self.kv_split, self.v_head_dim + 1], dtype=q.dtype, device=q.device + ) + final_output = torch.empty([batch, self.num_q_heads, self.v_head_dim], dtype=q.dtype, device=q.device) + + # with torch.cuda.stream(self.producer_stream): + gqa_local_kernels_fused( + q, + k_cache, + v_cache, + self.gathered_buffer, + self.signal_flags, + self.shmem, + [1] * batch, + global_kv_lens[self.rank], + block_table, + self.scale, + soft_cap=self.soft_cap, + output_split=output_split, + kv_split=self.kv_split, + ) + + # with torch.cuda.stream(self.consumer_stream): + kk3 = gqa_global_reduce_fused[(batch, self.num_q_heads)]( + self.gathered_buffer, + final_output, + global_kv_lens, + self.signal_flags, + self.signal_flags.stride(0), # stride_signal_dest + self.signal_flags.stride(1), # stride_signal_src + self.signal_flags.stride(2), # stride_signal_bs + self.signal_flags.stride(3), # stride_signal_h + batch, + self.num_q_heads, + self.gathered_buffer.stride(1), # stride_mid_ob + self.gathered_buffer.stride(2), # stride_mid_oh + self.gathered_buffer.stride(0), # stride_mid_os (now rank stride) + final_output.stride(0), # stride_obs + final_output.stride(1), # stride_oh + self.rank, + self.num_ranks, # NUM_KV_SPLITS becomes num_ranks + self.BLOCK_DV, + self.v_head_dim, + ) + + # print(f"{kk3.n_regs} registers used third, {kk3.n_spills} spills") + # self.clear_flags() + + return final_output diff --git a/examples/13_flash_decode/utils.py b/examples/13_flash_decode/utils.py new file mode 100644 index 00000000..642cd6c2 --- /dev/null +++ b/examples/13_flash_decode/utils.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# 
SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import sys +from typing import List, Dict, Any, Optional + +import torch +import os +import json + + +def dist_print(message: str, rank: int, is_error: bool = False): + """Prints a message only from rank 0.""" + if rank == 0: + if is_error: + print(f"❌ ERROR: {message}", file=sys.stderr) + else: + print(message) + + +def print_correctness_report( + rank: int, computed: torch.Tensor, reference: torch.Tensor, error: Optional[Exception] = None +): + """ + Prints a detailed report from rank 0 and the final status from all ranks. + """ + if rank == 0: + print("\n<<<<<<<<<< Correctness Test Report (Impl: FUSED_FULL) >>>>>>>>>>") + print(f"--- Detailed Validation on Rank {rank} ---") + header = f"{'Index':<8} | {'Computed':<15} | {'Reference':<15} | {'Abs. Diff':<15}" + print("--- Comparison of First 16 Values (Head 0) ---") + print(header) + print("-" * len(header)) + + comp_slice = computed[0, 0, :16].cpu().float() + ref_slice = reference[0, 0, :16].cpu().float() + diff_slice = torch.abs(comp_slice - ref_slice) + + for i in range(len(comp_slice)): + print(f"{i:<8} | {comp_slice[i]:<15.6f} | {ref_slice[i]:<15.6f} | {diff_slice[i]:<15.6f}") + print("-" * len(header)) + + # This final status prints from ALL ranks + if error: + print(f"❌ TEST FAILED for Rank {rank}:\n{error}") + else: + max_diff = torch.max(torch.abs(computed - reference)) + print(f"✅ TEST PASSED for Rank {rank}. Max absolute difference: {max_diff:.6f}") diff --git a/examples/14_all_gather_gemm/README.md b/examples/14_all_gather_gemm/README.md new file mode 100644 index 00000000..50e11a71 --- /dev/null +++ b/examples/14_all_gather_gemm/README.md @@ -0,0 +1,87 @@ +# Fused All-Gather + GEMM + +This folder provides an example of a distributed All-Gather + GEMM kernel. It explores two distinct patterns for fusing communication and computation: a **Pull model** and a **Push model**. + +The core kernel implementations are located in `examples/14_all_gather_gemm/`. + +Comparisons are performed against a baseline using the RCCL All-Gather collective and `torch.matmul`. + +----- + +## Architectural Patterns: Pull vs. Push + +The two main patterns explored are: + +### 1\. Pull Model + +In the **Pull model**, the consumer (GEMM kernel) takes full control. It actively "pulls" data from remote GPUs as it is needed using an `iris.load` instruction. The communication is fused directly into a single, persistent compute kernel. + +### 2\. Push Model + +The **Push model** decouples communication and computation. A dedicated producer kernel "pushes" data to a remote inbox using `iris.store`, and the consumer (GEMM kernel) waits for a synchronization signal before performing a fast local load from that inbox. + +----- + +## Usage + +### Simple Example Run + +To run a minimal, standalone example that demonstrates the kernel's functionality and validates its output for a single configuration, use the `example_run` scripts. + +**Pull Model:** + +```terminal +python examples/14_all_gather_gemm/example_run_pull.py --num_ranks 8 +``` + +**Push Model:** + +```terminal +python examples/14_all_gather_gemm/example_run_push.py --num_ranks 8 +``` + +### Validation and Benchmarking + +For more comprehensive testing, dedicated scripts in the `benchmark/examples/` directory handle both correctness validation and performance benchmarking across a range of configurations. The behavior of these scripts is controlled by flags. 
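+
+For reference, a hypothetical sweep entry could look like the following (the field names simply mirror the benchmarks' command-line options such as `m`, `n`, `k`, and `datatype`; the JSON file shipped in `dataset/` is authoritative):
+
+```json
+[
+  { "m": 1024, "n": 3584, "k": 8192, "datatype": "fp16" },
+  { "m": 4096, "n": 3584, "k": 8192, "datatype": "fp16" }
+]
+```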
+ +The scripts run a sweep of configurations defined in the JSON file at `dataset/ag_gemm.json`. + +#### Validation (-v) + +To verify the numerical correctness of an implementation against a PyTorch reference, run its benchmark script with the `-v` or `--validate` flag. + +**Pull Model:** + +```terminal +python benchmark/examples/benchmark_all_gather_gemm_pull.py --num_ranks 8 -v +``` + +**Push Model:** + +```terminal +python benchmark/examples/benchmark_all_gather_gemm_push.py --num_ranks 8 -v +``` + +#### Benchmarking (-b) + +To run the full performance benchmark sweep and save the results as `.json` files into the `results/` directory, use the `-b` or `--benchmark` flag. + +**Pull Model:** + +```terminal +python benchmark/examples/benchmark_all_gather_gemm_pull.py --num_ranks 8 -b +``` + +**Push Model:** + +```terminal +python benchmark/examples/benchmark_all_gather_gemm_push.py --num_ranks 8 -b +``` + +#### RCCL + Torch + +To validate and benchmark the RCCL + `torch.matmul` implementation, follow the same steps as the pull/push versions. + +```terminal +python examples/benchmark/reference/all_gather_gemm/benchmark_rccl_torch.py --num_ranks 8 -b +``` \ No newline at end of file diff --git a/examples/14_all_gather_gemm/all_gather_gemm_pull.py b/examples/14_all_gather_gemm/all_gather_gemm_pull.py new file mode 100644 index 00000000..c710c8a1 --- /dev/null +++ b/examples/14_all_gather_gemm/all_gather_gemm_pull.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +import torch +import iris + + +@triton.jit +def persistent_ag_gemm( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + EVEN_K: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, +): + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, NUM_SMS): + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + K_local = K // world_size + + for source_rank_id in range(world_size): + loop_k_local = tl.cdiv(K_local, BLOCK_SIZE_K) + if not EVEN_K: + loop_k_local -= 1 + + for k_block_idx in range(0, loop_k_local): + k_offset = k_block_idx * BLOCK_SIZE_K + rk_local = k_offset + tl.arange(0, 
BLOCK_SIZE_K) + A_ptr = A + rm[:, None] * stride_am + rk_local[None, :] * stride_ak + a = iris.load(tl.multiple_of(A_ptr, (1, 16)), cur_rank, source_rank_id, heap_bases) + + rk_global = (source_rank_id * K_local) + rk_local + B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1))) + + acc += tl.dot(a, b) + + if not EVEN_K: + k_offset = loop_k_local * BLOCK_SIZE_K + rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) + rk_local_mask = rk_local < K_local + A_ptr = A + rm[:, None] * stride_am + rk_local[None, :] * stride_ak + a = iris.load( + tl.multiple_of(A_ptr, (1, 16)), + cur_rank, + source_rank_id, + heap_bases, + mask=rk_local_mask[None, :], + other=0.0, + ) + + rk_global = (source_rank_id * K_local) + rk_local + rk_global_mask = rk_global < K + B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_global_mask[:, None], other=0.0) + + acc += tl.dot(a, b) + + c = acc.to(C.type.element_ty) + C_BASE = ( + C + + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm + + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn + ) + mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N + ) + tl.store(C_BASE, c, mask=mask) diff --git a/examples/14_all_gather_gemm/all_gather_gemm_push.py b/examples/14_all_gather_gemm/all_gather_gemm_push.py new file mode 100644 index 00000000..7cb4fe4b --- /dev/null +++ b/examples/14_all_gather_gemm/all_gather_gemm_push.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +import iris +import torch + + +@triton.jit +def push_shards_kernel( + A_local, + A_inbox, + signal_flags, + M, + K_local, + stride_al_m, + stride_al_k, + stride_ai_rank, + stride_ai_m, + stride_ai_k, + stride_sf_d, + stride_sf_s, + stride_sf_m, + stride_sf_k, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + heap_bases: tl.tensor, +): + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + tl.assume(stride_al_m > 0) + tl.assume(stride_al_k > 0) + tl.assume(stride_ai_rank > 0) + tl.assume(stride_ai_m > 0) + tl.assume(stride_ai_k > 0) + tl.assume(stride_sf_d > 0) + tl.assume(stride_sf_s > 0) + tl.assume(stride_sf_m > 0) + tl.assume(stride_sf_k > 0) + + offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offsets_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + offsets_m = tl.max_contiguous(tl.multiple_of(offsets_m, BLOCK_SIZE_M), BLOCK_SIZE_M) + offsets_k = tl.max_contiguous(tl.multiple_of(offsets_k, BLOCK_SIZE_K), BLOCK_SIZE_K) + mask = (offsets_m[:, None] < M) & (offsets_k[None, :] < K_local) + + A_ptr = A_local + offsets_m[:, None] * stride_al_m + offsets_k[None, :] * stride_al_k + a_tile = tl.load(tl.multiple_of(A_ptr, (1, 16)), mask=mask, other=0.0) + + for dest_rank_id in range(world_size): + dest_ptr = ( + A_inbox + cur_rank * stride_ai_rank + offsets_m[:, None] * stride_ai_m + offsets_k[None, :] * stride_ai_k + ) + iris.store(dest_ptr, a_tile, cur_rank, dest_rank_id, heap_bases, mask=mask) + + flag_ptr = ( + signal_flags + + dest_rank_id * stride_sf_d + + cur_rank * stride_sf_s + + pid_m * stride_sf_m + + pid_k * stride_sf_k + ) + iris.atomic_add(flag_ptr, 1, cur_rank, dest_rank_id, heap_bases, sem="release", 
scope="sys") + + +@triton.jit +def gemm_push_kernel( + A_inbox, + B, + C, + M, + N, + K, + signal_flags, + stride_ai_rank, + stride_ai_m, + stride_ai_k, + stride_b_k, + stride_b_n, + stride_c_m, + stride_c_n, + stride_sf_d, + stride_sf_s, + stride_sf_m, + stride_sf_k, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + EVEN_K: tl.constexpr, + cur_rank: tl.constexpr, + world_size: tl.constexpr, +): + pid = tl.program_id(0) + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_ai_rank > 0) + tl.assume(stride_ai_m > 0) + tl.assume(stride_ai_k > 0) + tl.assume(stride_b_k > 0) + tl.assume(stride_b_n > 0) + tl.assume(stride_c_m > 0) + tl.assume(stride_c_n > 0) + tl.assume(stride_sf_d > 0) + tl.assume(stride_sf_s > 0) + tl.assume(stride_sf_m > 0) + tl.assume(stride_sf_k > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, NUM_SMS): + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + K_local = K // world_size + + for source_rank_id in range(world_size): + num_k_tiles = tl.cdiv(K_local, BLOCK_SIZE_K) + loop_k_tiles = num_k_tiles + if not EVEN_K: + loop_k_tiles -= 1 + + for k_tile_idx in range(loop_k_tiles): + flag_ptr = ( + signal_flags + + cur_rank * stride_sf_d + + source_rank_id * stride_sf_s + + pid_m * stride_sf_m + + k_tile_idx * stride_sf_k + ) + while tl.load(flag_ptr, cache_modifier=".ca") == 0: + pass + + k_offset = k_tile_idx * BLOCK_SIZE_K + rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) + A_ptr = ( + A_inbox + + source_rank_id * stride_ai_rank + + rm[:, None] * stride_ai_m + + rk_local[None, :] * stride_ai_k + ) + a = tl.load(tl.multiple_of(A_ptr, (1, 16))) + rk_global = (source_rank_id * K_local) + rk_local + B_ptr = B + rk_global[:, None] * stride_b_k + rn[None, :] * stride_b_n + b = tl.load(tl.multiple_of(B_ptr, (16, 1))) + acc += tl.dot(a, b) + + if not EVEN_K: + k_tile_idx = loop_k_tiles + flag_ptr = ( + signal_flags + + cur_rank * stride_sf_d + + source_rank_id * stride_sf_s + + pid_m * stride_sf_m + + k_tile_idx * stride_sf_k + ) + while tl.load(flag_ptr, cache_modifier=".ca") == 0: + pass + + k_offset = k_tile_idx * BLOCK_SIZE_K + rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) + A_ptr = ( + A_inbox + + source_rank_id * stride_ai_rank + + rm[:, None] * stride_ai_m + + rk_local[None, :] * stride_ai_k + ) + a = tl.load(tl.multiple_of(A_ptr, (1, 16)), mask=(rk_local[None, :] < K_local), other=0.0) + rk_global = (source_rank_id * K_local) + rk_local + B_ptr = B + rk_global[:, None] * stride_b_k + rn[None, :] * stride_b_n + b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=(rk_global[:, None] < 
K), other=0.0) + acc += tl.dot(a, b) + + rm_store = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn_store = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + C_BASE = C + rm_store[:, None] * stride_c_m + rn_store[None, :] * stride_c_n + c = acc.to(C.type.element_ty) + mask = (rm_store[:, None] < M) & (rn_store[None, :] < N) + tl.store(C_BASE, c, mask=mask) diff --git a/examples/14_all_gather_gemm/example_run_pull.py b/examples/14_all_gather_gemm/example_run_pull.py new file mode 100644 index 00000000..3dfe9733 --- /dev/null +++ b/examples/14_all_gather_gemm/example_run_pull.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +A simple, minimal example demonstrating how to use the persistent_ag_gemm +'Pull' model kernel for a distributed All-Gather + GEMM operation. + +This script initializes Iris and Torch Distributed, creates sample input +tensors, launches the fused Triton kernel, and validates the +output against a standard PyTorch reference implementation. + +The kernel is defined in the all_gather_gemm_pull.py file. +""" + +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +import iris +import argparse +from all_gather_gemm_pull import persistent_ag_gemm + + +def parse_args(): + """Parses command-line arguments for the example.""" + parser = argparse.ArgumentParser(description="A minimal example for a fused All-Gather + GEMM (Pull Model).") + parser.add_argument("--M", type=int, default=128, help="M dimension of the GEMM.") + parser.add_argument("--N", type=int, default=256, help="N dimension of the GEMM.") + parser.add_argument("--K", type=int, default=8192, help="Total K dimension of the GEMM (will be sharded).") + parser.add_argument("--BLOCK_SIZE_M", type=int, default=256, help="Triton kernel tile size for M dimension.") + parser.add_argument("--BLOCK_SIZE_N", type=int, default=64, help="Triton kernel tile size for N dimension.") + parser.add_argument("--BLOCK_SIZE_K", type=int, default=64, help="Triton kernel tile size for K dimension.") + parser.add_argument("--GROUP_SIZE_M", type=int, default=6, help="Triton kernel group size for M dimension.") + + parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") + parser.add_argument( + "--dtype", type=str, default="float16", choices=["float16", "bfloat16"], help="PyTorch data type to use." + ) + + return parser.parse_args() + + +def setup_example_data(rank, world_size, args, dtype): + """Creates a set of random tensors to serve as inputs for the kernel.""" + print(f"[Rank {rank}] Creating example tensors...") + + # The total K dimension is sharded across all ranks. 
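+    # For example, with K=8192 and world_size=8, each rank holds K_local=1024
+    # columns of A: its shard is A[:, rank*1024 : (rank+1)*1024].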
+ K_total = args.K + if K_total % world_size != 0: + raise ValueError("K dimension must be divisible by the world size for this example.") + K_local = K_total // world_size + + # Create the full A and B matrices on rank 0 + if rank == 0: + A_global = torch.randn(args.M, K_total, dtype=dtype, device="cuda") + B_global = torch.randn(K_total, args.N, dtype=dtype, device="cuda") + else: + A_global = torch.empty(args.M, K_total, dtype=dtype, device="cuda") + B_global = torch.empty(K_total, args.N, dtype=dtype, device="cuda") + + # Broadcast the full matrices to all ranks to ensure data consistency + dist.broadcast(A_global, src=0) + dist.broadcast(B_global, src=0) + + # Each rank takes its local, vertical slice of A + A_local_shard = A_global[:, rank * K_local : (rank + 1) * K_local].contiguous() + + return { + "A_local_shard": A_local_shard, + "B_global": B_global, # B remains replicated + "A_global_for_validation": A_global, # Keep the full A for the reference calculation + } + + +def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namespace): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") + ) + + # Initialize Iris for distributed communication + shmem = iris.iris() + + torch.manual_seed(42) # Use a fixed seed for consistent random data + torch.cuda.set_device(rank) + dtype = getattr(torch, args.dtype) + + if rank == 0: + print("--- Fused All-Gather + GEMM (Pull Model) Minimal Example ---") + print(f"Running with {world_size} rank(s).") + + # Set up the example input tensors + tensor_data = setup_example_data(rank, world_size, args, dtype) + shmem.barrier() + + # Prepare for the kernel launch + A_original = tensor_data["A_local_shard"] + B = tensor_data["B_global"] + + # Allocate a tensor in Iris's shared memory heap for remote access + A_iris = shmem.empty(A_original.shape, dtype=A_original.dtype) + A_iris.copy_(A_original) + + C_fused = torch.empty(args.M, args.N, dtype=dtype).cuda() # Output tensor for our kernel + + NUM_SMS = torch.cuda.get_device_properties(rank).multi_processor_count + grid = (NUM_SMS,) + + # Launch the fused Triton kernel + if rank == 0: + print("\nLaunching persistent_ag_gemm kernel...") + persistent_ag_gemm[grid]( + A_iris, + B, + C_fused, + args.M, + args.N, + args.K, + A_iris.stride(0), + A_iris.stride(1), + B.stride(0), + B.stride(1), + C_fused.stride(0), + C_fused.stride(1), + BLOCK_SIZE_M=args.BLOCK_SIZE_M, + BLOCK_SIZE_N=args.BLOCK_SIZE_N, + BLOCK_SIZE_K=args.BLOCK_SIZE_K, + GROUP_SIZE_M=args.GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + NUM_XCDS=1, + EVEN_K=(args.K % args.BLOCK_SIZE_K == 0), + heap_bases=shmem.get_heap_bases(), + cur_rank=rank, + world_size=world_size, + ) + + torch.cuda.synchronize() + shmem.barrier() + dist.barrier() + + # Print a summary and perform validation + if rank == 0: + print("\n--- Example Run Finished ---") + print(f"Output tensor C shape: {C_fused.shape}") + print("Output tensor C values (first 5 elements of the first row):") + print(C_fused[0, :5]) + print("--------------------------") + + print("\n--- Validation ---") + # Calculate the reference solution using torch.matmul + C_ref = torch.matmul(tensor_data["A_global_for_validation"], B) + + # Compare the results + try: + torch.testing.assert_close(C_fused, C_ref, atol=1.0, rtol=0.1) + print("✅ Validation PASSED") + except AssertionError as e: + print("❌ Validation FAILED") + print(e) + print("------------------") + 
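+    # Final cross-rank synchronization before tearing down the process group.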
+ shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29504" + mp.spawn( + fn=example_run, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/14_all_gather_gemm/example_run_push.py b/examples/14_all_gather_gemm/example_run_push.py new file mode 100644 index 00000000..e70bdf15 --- /dev/null +++ b/examples/14_all_gather_gemm/example_run_push.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +A simple, minimal example demonstrating how to use the 'Push' model kernels +for a distributed All-Gather + GEMM operation. + +This script initializes Iris and Torch Distributed, creates sample input +tensors, and then launches the two-kernel pipeline: +1. A 'push' kernel broadcasts local shards of matrix 'A' to all GPUs. +2. A 'wait-and-compute' kernel waits for data and performs the GEMM. +Finally, it validates the output against a PyTorch reference. + +The kernels are defined in the all_gather_gemm_push.py file. +""" + +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +import iris +import argparse + +# Assume the kernels are in a file named all_gather_gemm_push.py +from all_gather_gemm_push import push_shards_kernel, gemm_push_kernel + + +def parse_args(): + """Parses command-line arguments for the example.""" + parser = argparse.ArgumentParser(description="A minimal example for a fused All-Gather + GEMM (Push Model).") + parser.add_argument("--M", type=int, default=128, help="M dimension of the GEMM.") + parser.add_argument("--N", type=int, default=256, help="N dimension of the GEMM.") + parser.add_argument("--K", type=int, default=8192, help="Total K dimension of the GEMM (will be sharded).") + parser.add_argument("--BLOCK_SIZE_M", type=int, default=256, help="Triton kernel tile size for M dimension.") + parser.add_argument("--BLOCK_SIZE_N", type=int, default=64, help="Triton kernel tile size for N dimension.") + parser.add_argument("--BLOCK_SIZE_K", type=int, default=64, help="Triton kernel tile size for K dimension.") + parser.add_argument("--GROUP_SIZE_M", type=int, default=6, help="Triton kernel group size for M dimension.") + parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") + parser.add_argument( + "--dtype", type=str, default="float16", choices=["float16", "bfloat16"], help="PyTorch data type to use." 
+ ) + return parser.parse_args() + + +def setup_example_data(rank, world_size, args, dtype): + """Creates a set of random tensors to serve as inputs for the kernel.""" + print(f"[Rank {rank}] Creating example tensors...") + + K_total = args.K + if K_total % world_size != 0: + raise ValueError("K dimension must be divisible by the world size for this example.") + K_local = K_total // world_size + + if rank == 0: + A_global = torch.randn(args.M, K_total, dtype=dtype, device="cuda") + B_global = torch.randn(K_total, args.N, dtype=dtype, device="cuda") + else: + A_global = torch.empty(args.M, K_total, dtype=dtype, device="cuda") + B_global = torch.empty(K_total, args.N, dtype=dtype, device="cuda") + + dist.broadcast(A_global, src=0) + dist.broadcast(B_global, src=0) + + A_local_shard = A_global[:, rank * K_local : (rank + 1) * K_local].contiguous() + + return { + "A_local_shard": A_local_shard, + "B_global": B_global, + "A_global_for_validation": A_global, + } + + +def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namespace): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") + ) + + shmem = iris.iris() + torch.manual_seed(42) + torch.cuda.set_device(rank) + dtype = getattr(torch, args.dtype) + + if rank == 0: + print("--- Fused All-Gather + GEMM (Push Model) Minimal Example ---") + print(f"Running with {world_size} rank(s).") + + tensor_data = setup_example_data(rank, world_size, args, dtype) + shmem.barrier() + + # Prepare for the kernel launch + A_original = tensor_data["A_local_shard"] + B = tensor_data["B_global"] + C_fused = torch.empty(args.M, args.N, dtype=dtype).cuda() + + # Allocate tensors in Iris shared memory + A_local_iris = shmem.empty(A_original.shape, dtype=A_original.dtype) + A_local_iris.copy_(A_original) + + # Create an "inbox" on each rank to receive data from others + A_inbox_iris = shmem.empty((world_size, args.M, A_original.shape[1]), dtype=A_original.dtype) + + # Create flags for synchronization + num_m_tiles = (args.M + args.BLOCK_SIZE_M - 1) // args.BLOCK_SIZE_M + num_k_tiles = (A_original.shape[1] + args.BLOCK_SIZE_K - 1) // args.BLOCK_SIZE_K + signal_flags = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32) + + NUM_SMS = torch.cuda.get_device_properties(rank).multi_processor_count + + # Launch the two-kernel Push pipeline + if rank == 0: + print("\nLaunching push_shards_kernel and gemm_push_kernel...") + + # Define grid for the push kernel + push_grid = (num_m_tiles, num_k_tiles) + push_shards_kernel[push_grid]( + A_local_iris, + A_inbox_iris, + signal_flags, + args.M, + A_local_iris.shape[1], + A_local_iris.stride(0), + A_local_iris.stride(1), + A_inbox_iris.stride(0), + A_inbox_iris.stride(1), + A_inbox_iris.stride(2), + signal_flags.stride(0), + signal_flags.stride(1), + signal_flags.stride(2), + signal_flags.stride(3), + BLOCK_SIZE_M=args.BLOCK_SIZE_M, + BLOCK_SIZE_K=args.BLOCK_SIZE_K, + cur_rank=rank, + world_size=world_size, + heap_bases=shmem.get_heap_bases(), + ) + + # Define grid for the GEMM kernel + gemm_grid = (NUM_SMS,) + gemm_push_kernel[gemm_grid]( + A_inbox_iris, + B, + C_fused, + args.M, + args.N, + args.K, + signal_flags, + A_inbox_iris.stride(0), + A_inbox_iris.stride(1), + A_inbox_iris.stride(2), + B.stride(0), + B.stride(1), + C_fused.stride(0), + C_fused.stride(1), + signal_flags.stride(0), + signal_flags.stride(1), + 
signal_flags.stride(2), + signal_flags.stride(3), + BLOCK_SIZE_M=args.BLOCK_SIZE_M, + BLOCK_SIZE_N=args.BLOCK_SIZE_N, + BLOCK_SIZE_K=args.BLOCK_SIZE_K, + GROUP_SIZE_M=args.GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + NUM_XCDS=1, + EVEN_K=(A_local_iris.shape[1] % args.BLOCK_SIZE_K == 0), + cur_rank=rank, + world_size=world_size, + ) + + torch.cuda.synchronize() + shmem.barrier() + dist.barrier() + + # Print a summary and perform validation + if rank == 0: + print("\n--- Example Run Finished ---") + print(f"Output tensor C shape: {C_fused.shape}") + print("Output tensor C values (first 5 elements of the first row):") + print(C_fused[0, :5]) + print("--------------------------") + + print("\n--- Validation ---") + C_ref = torch.matmul(tensor_data["A_global_for_validation"], B) + try: + torch.testing.assert_close(C_fused, C_ref, atol=1.0, rtol=0.1) + print("✅ Validation PASSED") + except AssertionError as e: + print("❌ Validation FAILED") + print(e) + print("------------------") + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29504" + mp.spawn( + fn=example_run, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/README.md b/examples/README.md index 0794d70f..414afaae 100644 --- a/examples/README.md +++ b/examples/README.md @@ -24,6 +24,8 @@ This directory contains various algorithm implementations for distributed comput - [`10_gemm_all_scatter_wg_specialization`](10_gemm_all_scatter_wg_specialization): Matrix multiplication with all-scatter using workgroup specialization - [`11_gemm_all_scatter_producer_consumer`](11_gemm_all_scatter_producer_consumer): Matrix multiplication with all-scatter using producer-consumer concurrent kernels - [`12_gemm_all_scatter_bulk_synchronous`](12_gemm_all_scatter_bulk_synchronous): Matrix multiplication with all-scatter using the bulk synchronous parallel approach +- [`13_flash_decode`](13_flash_decode): Fused Flash Decode Attention for accelerating LLM inference +- [`14_all_gather_gemm`](14_all_gather_gemm): Fused All-Gather + GEMM with Pull and Push models ### Utilities - [`benchmark`](benchmark): Benchmarking utilities and performance testing tools @@ -69,4 +71,13 @@ python examples/11_gemm_all_scatter_producer_consumer/benchmark.py --benchmark - # Example command to run benchmark with all-scatter bulk synchronous approach python examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py --benchmark --validate --num_ranks 8 + +# Flash Decode Attention - simple example run +python examples/13_flash_decode/example_run.py --num_ranks 8 + +# All-Gather + GEMM - Pull model +python examples/14_all_gather_gemm/example_run_pull.py --num_ranks 8 + +# All-Gather + GEMM - Push model +python examples/14_all_gather_gemm/example_run_push.py --num_ranks 8 ``` diff --git a/examples/benchmark/bench_all_shapes.py b/examples/benchmark/bench_all_shapes.py index 0f5ce5b3..adb15657 100644 --- a/examples/benchmark/bench_all_shapes.py +++ b/examples/benchmark/bench_all_shapes.py @@ -6,6 +6,7 @@ from datetime import datetime import argparse import json +import torch def launch_sbatch( @@ -110,16 +111,28 @@ def main(hashes, config, sbatch_script_content, input_json, tiling_json, dry_run if mkn not in mkn_gemm_tiles: mkn_gemm_tiles[mkn] = {key: entry[key] for key in optional_keys if key in entry} - if config["partition"] is not None: - if "mi300" in config["partition"]: - print("Running on MI300") + # Determine 
gemm_sms based on available GPU or partition name + try: + if torch.cuda.is_available(): + gemm_sms = torch.cuda.get_device_properties(0).multi_processor_count + print(f"Auto-detected CU count: {gemm_sms}") + else: + gemm_sms = None + except Exception: + # Fall back to partition-based detection + gemm_sms = None + + if gemm_sms is None: + if config["partition"] is not None: + if "mi300" in config["partition"]: + print("Running on MI300 (partition-based)") + gemm_sms = 304 + elif "mi250" in config["partition"]: + print("Running on MI250 (partition-based)") + gemm_sms = 104 + else: + print("Assuming MI300 (default)") gemm_sms = 304 - elif "mi250" in config["partition"]: - print("Running on MI250") - gemm_sms = 104 - else: - print("Assuming MI300") - gemm_sms = 304 enable_algorithms = False enable_mkn = True diff --git a/examples/benchmark/reference/all_gather.py b/examples/benchmark/reference/all_gather.py index 083e7219..9f204530 100755 --- a/examples/benchmark/reference/all_gather.py +++ b/examples/benchmark/reference/all_gather.py @@ -8,6 +8,7 @@ import random import iris import argparse +import os from examples.common.utils import JSONWriter @@ -53,7 +54,8 @@ def main(): validate = args["validate"] benchmark = args["benchmark"] - dist.init_process_group("nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group("nccl", device_id=torch.device(f"cuda:{local_rank}")) rank = dist.get_rank() world_size = dist.get_world_size() diff --git a/examples/benchmark/reference/all_gather_gemm/benchmark_rccl_torch.py b/examples/benchmark/reference/all_gather_gemm/benchmark_rccl_torch.py new file mode 100644 index 00000000..fd574527 --- /dev/null +++ b/examples/benchmark/reference/all_gather_gemm/benchmark_rccl_torch.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+ +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import sys +import os +import argparse +import json + +from examples.common.utils import JSONWriter +import iris + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run a sweep of RCCL All-Gather + torch.matmul benchmarks from a config file.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode.") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode.") + parser.add_argument( + "--config_file", + type=str, + default="dataset/ag_gemm.json", + help="Path to the JSON file with benchmark configurations.", + ) + parser.add_argument( + "--output_file", type=str, default="rccl_torch_matmul_log.json", help="Base name for output files" + ) + parser.add_argument( + "--output_dir", type=str, default="results/all_gather_gemm_rccl", help="Name of the output directory" + ) + parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run on.") + parser.add_argument("-m", type=int, default=1024) + parser.add_argument("-n", type=int, default=3584) + parser.add_argument("-k", type=int, default=8192) + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "bf16", "fp32"]) + return parser.parse_args() + + +def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace): + dist.init_process_group( + backend="nccl", init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") + ) + torch.cuda.set_device(rank) + + output_dir = args.output_dir + if rank == 0: + os.makedirs(output_dir, exist_ok=True) + dist.barrier() + + with open(args.config_file, "r") as f: + configs_to_run = json.load(f) + + if rank == 0: + print(f"Loaded {len(configs_to_run)} configurations from {args.config_file}") + + for config in configs_to_run: + run_args = vars(args).copy() + run_args.update(config) + + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + datatype = dtype_map.get(run_args["datatype"]) + + M, N, K = run_args["m"], run_args["n"], run_args["k"] + if rank == 0: + print(f"\n--- Running Benchmark for M={M}, N={N}, K={K} ---") + sys.stdout.flush() + + base_name, extension = os.path.splitext(args.output_file) + unique_filename = f"{base_name}_m_{M}{extension}" + full_output_path = os.path.join(output_dir, unique_filename) + + json_writer = JSONWriter(full_output_path) + json_writer.add_field("world_size", world_size) + for key, value in run_args.items(): + json_writer.add_field(key, value) + + K_local = K // world_size + + if rank == 0: + A_global = torch.randn((M, K), dtype=datatype, device="cuda") + else: + A_global = torch.empty((M, K), dtype=datatype, device="cuda") + dist.broadcast(A_global, src=0) + + A_local = A_global[:, rank * K_local : (rank + 1) * K_local].contiguous() + + if rank == 0: + B = torch.randn((K, N), device="cuda", dtype=datatype) + else: + B = torch.empty((K, N), device="cuda", dtype=datatype) + dist.broadcast(B, src=0) + + C = torch.empty((M, N), device="cuda", dtype=datatype) + all_a_shards = [torch.empty_like(A_local) for _ in range(world_size)] + + main_stream = torch.cuda.Stream() + kernel_timing = { + "rccl_all_gather": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + 
"torch_matmul": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + def run_experiment(): + nonlocal kernel_timing + with torch.cuda.stream(main_stream): + kernel_timing["rccl_all_gather"]["start_event"].record() + dist.all_gather(all_a_shards, A_local) + A_gathered = torch.cat(all_a_shards, dim=1) + kernel_timing["rccl_all_gather"]["end_event"].record() + + kernel_timing["torch_matmul"]["start_event"].record() + torch.matmul(A_gathered, B, out=C) + kernel_timing["torch_matmul"]["end_event"].record() + + torch.cuda.synchronize() + kernel_timing["rccl_all_gather"]["ms"] += kernel_timing["rccl_all_gather"]["start_event"].elapsed_time( + kernel_timing["rccl_all_gather"]["end_event"] + ) + kernel_timing["rccl_all_gather"]["experiments"] += 1 + kernel_timing["torch_matmul"]["ms"] += kernel_timing["torch_matmul"]["start_event"].elapsed_time( + kernel_timing["torch_matmul"]["end_event"] + ) + kernel_timing["torch_matmul"]["experiments"] += 1 + + run_experiment() + dist.barrier() + + for key in kernel_timing: + kernel_timing[key]["ms"] = 0 + kernel_timing[key]["experiments"] = 0 + + if args.benchmark: + total_ms = iris.do_bench(run_experiment, barrier_fn=dist.barrier) + tflops = 2 * M * N * K * 1e-12 / (total_ms * 1e-3) + if rank == 0: + print(f"Result (iris.do_bench): {total_ms:.3f} ms, {tflops:.3f} TFLOPS") + json_writer.add_field("total_ms", total_ms) + json_writer.add_field("tflops", tflops) + + for key in kernel_timing: + if kernel_timing[key]["experiments"] > 0: + avg_kernel_ms = kernel_timing[key]["ms"] / kernel_timing[key]["experiments"] + json_writer.add_field(key + "_ms", avg_kernel_ms) + if rank == 0: + print(f"Result (CUDA Events) - {key}: {avg_kernel_ms:.3f} ms") + + if args.validate: + if not args.benchmark: + run_experiment() + dist.barrier() + + if rank == 0: + print("Validating...") + + C_ref = torch.matmul(A_global, B) + success = torch.allclose(C, C_ref, atol=1.0, rtol=0.05) + passed_str = "passed" if success else "failed" + print(f"Final C validation for rank {rank} is {passed_str}.") + json_writer.add_field("validation_passed", success) + + if rank == 0: + json_writer.flush() + print(f"Saved results to {full_output_path}") + sys.stdout.flush() + + if rank == 0: + print("\nBenchmark sweep complete.") + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + if not args.validate and not args.benchmark: + print("Error: You must specify a mode to run. 
Use -v or -b.", file=sys.stderr) + sys.exit(1) + + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29505" + mp.spawn(fn=worker, args=(num_ranks, init_url, args), nprocs=num_ranks, join=True) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark/reference/all_reduce.py b/examples/benchmark/reference/all_reduce.py index ab4f61ef..369d0619 100755 --- a/examples/benchmark/reference/all_reduce.py +++ b/examples/benchmark/reference/all_reduce.py @@ -8,6 +8,7 @@ import random import iris import argparse +import os from examples.common.utils import JSONWriter @@ -52,7 +53,8 @@ def main(): validate = args["validate"] benchmark = args["benchmark"] - dist.init_process_group("nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group("nccl", device_id=torch.device(f"cuda:{local_rank}")) rank = dist.get_rank() world_size = dist.get_world_size() diff --git a/examples/benchmark/reference/flash_decode_rccl/benchmark_flash_decode_rccl.py b/examples/benchmark/reference/flash_decode_rccl/benchmark_flash_decode_rccl.py new file mode 100644 index 00000000..aa607b64 --- /dev/null +++ b/examples/benchmark/reference/flash_decode_rccl/benchmark_flash_decode_rccl.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import sys +import json +import itertools +import os +from pathlib import Path +import argparse + +import torch +import torch.distributed as dist + +import iris +from examples.benchmark.reference.flash_decode_rccl.flash_decode_layer_rccl import flash_decode_layer_rccl + + +def parse_args(): + """ + Arguments for the benchmark + The default parameters are in dataset/flash_decode_config_rccl.json + A different config file can be set with the --config flag + """ + parser = argparse.ArgumentParser( + description="Run Flash Decode RCCL benchmark with parameters from a config file.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "-c", + "--config", + type=str, + default="dataset/flash_decode_config_rccl.json", + help="Path to the JSON configuration file", + ) + + config_args, _ = parser.parse_known_args() + + config_defaults = {} + if os.path.exists(config_args.config): + try: + with open(config_args.config, "r") as f: + config_from_file = json.load(f) + if config_from_file: + print(f"Configuration successfully loaded from '{config_args.config}'") + config_defaults = {**config_from_file, **config_from_file.get("sweep_parameters", {})} + if "sweep_parameters" in config_defaults: + del config_defaults["sweep_parameters"] + except json.JSONDecodeError: + print(f"Error: Config file '{config_args.config}' is not valid JSON.") + else: + print(f"Warning: Config file '{config_args.config}' not found.") + + parser.set_defaults(**config_defaults) + + parser.add_argument("--output_dir", type=str, help="Directory to save results") + parser.add_argument("--data_type", type=str, choices=["float16", "bfloat16", "float32"], help="PyTorch data type") + parser.add_argument("--warmup_iterations", type=int, help="Number of warmup iterations") + parser.add_argument("--repeat_iterations", type=int, help="Number of benchmark iterations") + parser.add_argument("--page_size", type=int, help="Page size for KV cache", default=1) + + parser.add_argument("--kv_len", type=int, nargs="+", help="Override KV_LEN_SWEEP") + parser.add_argument("--num_heads", type=int, nargs="+", help="Override NUM_HEADS_SWEEP") + parser.add_argument("--head_dim", 
type=int, nargs="+", help="Override HEAD_DIM_SWEEP") + parser.add_argument("--num_seqs", type=int, nargs="+", help="Override NUM_SEQS_SWEEP") + + final_args = parser.parse_args() + return final_args + + +def prepare_perf_data(config, num_query_heads, num_kv_heads, page_size, datatype): + """Prepares local data for the performance test on the current rank.""" + num_blocks_per_rank = (config["kv_len"] + page_size - 1) // page_size + + query = torch.randn(config["num_seqs"], num_query_heads, config["head_dim"], dtype=datatype).cuda() + key_cache_this_rank = torch.randn( + num_blocks_per_rank, page_size, num_kv_heads, config["head_dim"], dtype=datatype + ).cuda() + value_cache_this_rank = torch.randn( + num_blocks_per_rank, page_size, num_kv_heads, config["head_dim"], dtype=datatype + ).cuda() + block_tables_this_rank = torch.arange(num_blocks_per_rank, dtype=torch.int32).repeat(config["num_seqs"], 1).cuda() + + return { + "query": query, + "key_cache_this_rank": key_cache_this_rank, + "value_cache_this_rank": value_cache_this_rank, + "block_tables_this_rank": block_tables_this_rank, + } + + +def run_benchmark(args): + local_rank = int(os.environ["LOCAL_RANK"]) + dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{local_rank}")) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + output_dir = args.output_dir + datatype = getattr(torch, args.data_type) + page_size = args.page_size + + if rank == 0: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print(f"Created output directory: '{output_dir}'") + + tp_group = dist.new_group(ranks=range(world_size)) + torch.manual_seed(42) + + config_sweep = [] + param_product = itertools.product(args.kv_len, args.num_heads, args.head_dim, args.num_seqs) + for kv_len, num_heads, head_dim, num_seqs in param_product: + config_sweep.append( + { + "kv_len": kv_len, + "num_heads": num_heads, + "head_dim": head_dim, + "num_seqs": num_seqs, + } + ) + + # Loop through configs + for i, config in enumerate(config_sweep): + if rank == 0: + print(f"\n--- Running Config {i + 1}/{len(config_sweep)}: {config} ---") + + num_query_heads = config["num_heads"] + num_kv_heads = num_query_heads // 8 if num_query_heads >= 8 else 1 + scale = config["head_dim"] ** -0.5 + + keyword_params = { + "page_size": page_size, + "scale": scale, + "soft_cap": 0.0, + "max_allowed_batch": config["num_seqs"], + } + + fd_layer = flash_decode_layer_rccl( + rank, + world_size, + num_query_heads, + num_kv_heads, + config["head_dim"], + config["head_dim"], + tp_group, + **keyword_params, + ) + + tensor_data = prepare_perf_data(config, num_query_heads, num_kv_heads, page_size, datatype) + + kv_lens_per_rank = [config["kv_len"]] * config["num_seqs"] + kv_lens_tensor = torch.tensor(kv_lens_per_rank, dtype=torch.int32).cuda() + global_kv_lens_tensor = kv_lens_tensor.unsqueeze(0).repeat(world_size, 1) + + def run_experiment(): + return fd_layer( + tensor_data["query"], + tensor_data["key_cache_this_rank"], + tensor_data["value_cache_this_rank"], + global_kv_lens_tensor, + tensor_data["block_tables_this_rank"], + ) + + time_ms = iris.do_bench( + fn=run_experiment, + barrier_fn=dist.barrier, + n_warmup=args.warmup_iterations, + n_repeat=args.repeat_iterations, + return_mode="mean", + ) + dist.barrier() + + if rank == 0: + global_kv_len = config["kv_len"] * world_size + print(f"Result -> Global KV Length: {global_kv_len}, Avg. 
Time: {time_ms:.3f} ms") + + result_entry = config.copy() + result_entry["global_kv_len"] = global_kv_len + result_entry["avg_time_ms"] = time_ms + + filename = f"h{config['num_heads']}_d{config['head_dim']}_s{config['num_seqs']}_kv{config['kv_len']}.json" + output_path = os.path.join(output_dir, filename) + + with open(output_path, "w") as f: + json.dump(result_entry, f, indent=4) + print(f"Saved result to '{output_path}'") + + if rank == 0: + print("\nBenchmark sweep complete.") + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_args() + run_benchmark(args) diff --git a/examples/benchmark/reference/flash_decode_rccl/decode_kernels.py b/examples/benchmark/reference/flash_decode_rccl/decode_kernels.py new file mode 100644 index 00000000..a8e37431 --- /dev/null +++ b/examples/benchmark/reference/flash_decode_rccl/decode_kernels.py @@ -0,0 +1,369 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# +# Part of the code adapted from +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/python/triton_dist/kernels/nvidia/flash_decode.py +# +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+# +################################################################################ + +import torch +import triton +import math +import os +import triton.language as tl +from triton.language.extra import libdevice +import iris + + +def gqa_local_kernels( + q, + k_cache, + v_cache, + workspace, + q_lens, + kv_lens, + block_table, + scale, + soft_cap=0.0, + output_split=None, + output_combine=None, + kv_split=-1, +): + batch, q_heads, q_head_dim = q.shape + _, page_size, kv_heads, k_head_dim = k_cache.shape + assert page_size == v_cache.shape[1] and kv_heads == v_cache.shape[2] and k_head_dim == q_head_dim + v_head_dim = v_cache.shape[-1] + + BLOCK_N = 64 + BLOCK_HEAD_DIM = 2 ** int(math.log2(q_head_dim)) + BLOCK_DPE = q_head_dim - BLOCK_HEAD_DIM + BLOCK_DV = triton.next_power_of_2(v_head_dim) + + kv_group_num = q_heads // kv_heads + assert q_heads % kv_heads == 0 + + BLOCK_H = 16 + NUM_KV_SPLITS = 32 if kv_split == -1 else kv_split + + grid_split_kv = (batch, triton.cdiv(q_heads, min(BLOCK_H, kv_group_num)), NUM_KV_SPLITS) + + output_split = ( + torch.empty([batch, q_heads, NUM_KV_SPLITS, v_head_dim + 1], dtype=q.dtype, device=q.device) + if output_split is None + else output_split + ) + output_combine = ( + torch.empty([batch, q_heads, v_head_dim + 1], dtype=q.dtype, device=q.device) + if output_combine is None + else output_combine + ) + + gqa_local_decode_split_k[grid_split_kv]( + q, + k_cache, + v_cache, + output_split, + scale, + block_table, + kv_lens, + # shape + batch, + # strides + q.stride(0), + q.stride(1), + k_cache.stride(-3), + k_cache.stride(-2), + v_cache.stride(-3), + v_cache.stride(-2), + output_split.stride(0), + output_split.stride(1), + output_split.stride(2), + block_table.stride(0), + # constants + kv_group_num, + q_heads, + BLOCK_HEAD_DIM, + BLOCK_DPE, + BLOCK_DV, + BLOCK_N, + BLOCK_H, + NUM_KV_SPLITS, + page_size, + soft_cap, + k_head_dim, + v_head_dim, + num_warps=4, + num_stages=2, + ) + + gqa_reduce_local[(batch, q_heads)]( + output_split, + output_combine, + kv_lens, + batch, + q_heads, + output_split.stride(0), + output_split.stride(1), + output_split.stride(2), + output_combine.stride(0), + output_combine.stride(1), + NUM_KV_SPLITS, + BLOCK_DV, + v_head_dim, + num_warps=4, + num_stages=2, + ) + + return output_combine + + +@triton.jit +def gqa_local_decode_split_k( + q_ptr, + k_cache_ptr, + v_cache_ptr, + output_ptr, + sm_scale, + block_table_ptr, + kv_length_ptr, + # shape + batch, + # strides + stride_q_bs, + stride_q_h, + stride_k_cache_bs, + stride_k_cache_h, + stride_v_cache_bs, + stride_v_cache_h, + stride_o_bs, + stride_o_h, + stride_o_split, + stride_table_bs, + # constants + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_HEAD_DIM: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + soft_cap: tl.constexpr, + K_DIM: tl.constexpr, + V_DIM: tl.constexpr, +): + bid = tl.program_id(0) + hid = tl.program_id(1) + kv_hid = hid // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + + cur_head = hid * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = (cur_head < (hid + 1) * VALID_BLOCK_H) & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_HEAD_DIM) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < K_DIM + mask_dv = offs_dv < V_DIM + cur_kv_seq_len = tl.load(kv_length_ptr + 
bid) + + offs_q = bid * stride_q_bs + cur_head[:, None] * stride_q_h + offs_d[None, :] + q = tl.load(q_ptr + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_HEAD_DIM + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < K_DIM + offs_qpe = bid * stride_q_bs + cur_head[:, None] * stride_q_h + offs_dpe[:, None] + qpe = tl.load(q_ptr + offs_qpe, mask=mask_h[:, None] & mask_dpe[None, :], other=0.0) + + kv_len_per_split = tl.cdiv(cur_kv_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_kv_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + block_table_ptr + bid * stride_table_bs + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0 + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_cache_k = kv_loc[None, :] * stride_k_cache_bs + kv_hid * stride_k_cache_h + offs_d[:, None] + k = tl.load(k_cache_ptr + offs_cache_k, mask=(offs_n[None, :] < split_kv_end) & mask_d[:, None], other=0.0) + qk = tl.dot(q, k.to(q.dtype)) + + if BLOCK_DPE > 0: + offs_cache_kpe = kv_loc[None, :] * stride_k_cache_bs + kv_hid * stride_k_cache_h + offs_dpe[:, None] + kpe = tl.load( + k_cache_ptr + offs_cache_kpe, mask=(offs_n[None, :] < split_kv_end) & mask_dpe[:, None], other=0.0 + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + + qk *= sm_scale + + if soft_cap > 0: + qk = soft_cap * libdevice.tanh(qk / soft_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")) + + offs_cache_v = kv_loc[:, None] * stride_v_cache_bs + kv_hid * stride_v_cache_h + offs_dv[None, :] + v = tl.load(v_cache_ptr + offs_cache_v, mask=(offs_n[:, None] < split_kv_end) & mask_dv[None, :], other=0.0) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = libdevice.fast_expf(e_max - n_e_max) + p = libdevice.fast_expf(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_out = bid * stride_o_bs + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_split + offs_dv[None, :] + tl.store(output_ptr + offs_out, acc / e_sum[:, None], mask=mask_h[:, None] & mask_dv[None, :]) + + offs_log = bid * stride_o_bs + cur_head * stride_o_h + split_kv_id * stride_o_split + V_DIM + tl.store(output_ptr + offs_log, e_max + tl.log(e_sum), mask=mask_h) + + +@triton.jit +def gqa_reduce_local( + Mid_O, + o, + B_Seqlen, + batch, + q_heads, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, 
cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = libdevice.fast_expf(e_max - n_e_max) + acc *= old_scale + exp_logic = libdevice.fast_expf(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + Lv, + e_max + tl.log(e_sum), + ) + + +@triton.jit +def gqa_reduce_global( + Mid_O, + o, + B_Seqlens, + batch, + q_heads, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len_ptr = B_Seqlens + cur_batch + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + effective_kv_len = tl.load(cur_batch_seq_len_ptr + split_kv_id * batch) + + if effective_kv_len > 0: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = libdevice.fast_expf(e_max - n_e_max) + acc *= old_scale + exp_logic = libdevice.fast_expf(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) diff --git a/examples/benchmark/reference/flash_decode_rccl/flash_decode_layer_rccl.py b/examples/benchmark/reference/flash_decode_rccl/flash_decode_layer_rccl.py new file mode 100644 index 00000000..5dd28eae --- /dev/null +++ b/examples/benchmark/reference/flash_decode_rccl/flash_decode_layer_rccl.py @@ -0,0 +1,136 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# +# Part of the code adapted from +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/python/triton_dist/layers/nvidia/sp_flash_decode_layer.py################################################################################ +# +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. 
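The two reduce kernels above (`gqa_reduce_local`, and `gqa_reduce_global` used later by the RCCL layer) combine per-split partial attention results by carrying a running maximum over the stored log-sum-exp values. The sketch below restates that combine step in plain PyTorch for one (batch, head) pair. It is illustrative only: `partial_out` and `partial_lse` are hypothetical host-side stand-ins for the `[..., :V_DIM]` values and the trailing log-sum-exp slot written by `gqa_local_decode_split_k`.

```python
# Host-side sketch of the split-K combine rule (not part of the kernels above).
import torch

def merge_splits(partial_out: torch.Tensor, partial_lse: torch.Tensor) -> torch.Tensor:
    """partial_out: [num_splits, v_dim] normalized per-split outputs;
    partial_lse: [num_splits] per-split log-sum-exp values."""
    e_max = torch.tensor(float("-inf"))
    e_sum = torch.tensor(0.0)
    acc = torch.zeros(partial_out.shape[-1])
    for out_s, lse_s in zip(partial_out, partial_lse):
        new_max = torch.maximum(lse_s, e_max)
        old_scale = torch.exp(e_max - new_max)  # rescale what was accumulated so far
        weight = torch.exp(lse_s - new_max)     # weight of the incoming split
        acc = acc * old_scale + weight * out_s
        e_sum = e_sum * old_scale + weight
        e_max = new_max
    return acc / e_sum

# Example: merging 4 splits of a 128-wide value head.
merged = merge_splits(torch.randn(4, 128).abs(), torch.randn(4))
```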
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +################################################################################ + + +import torch +import torch.distributed as dist +from decode_kernels import gqa_local_kernels, gqa_reduce_global + + +class flash_decode_layer_rccl(torch.nn.Module): + def __init__( + self, + rank: int, + num_ranks: int, + num_q_heads: int, + num_kv_heads: int, + q_head_dim: int, + v_head_dim: int, + process_group, + page_size: int = 1, + scale: float = 1.0, + soft_cap: float = 0.0, + max_allowed_batch: int = 1, + ): + super().__init__() + self.rank = rank + self.num_ranks = num_ranks + self.process_group = process_group + + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.q_head_dim = q_head_dim + self.v_head_dim = v_head_dim + self.page_size = page_size + self.soft_cap = soft_cap + self.scale = scale + + self.kv_split = 32 + self.max_allowed_batch = max_allowed_batch + + def forward( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + global_kv_lens: torch.Tensor, + block_table: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + q: Query tensor, identical across all ranks. Shape: [batch, num_q_heads, head_size] + k_cache: This rank's shard of the key cache. + v_cache: This rank's shard of the value cache. + global_kv_lens: A tensor containing the sequence lengths of the K/V cache shards on all ranks. Shape: [num_ranks, batch] + block_table: The block table for this rank's K/V cache. + Returns: + The final attention output tensor. Shape: [batch, num_q_heads, head_size] + """ + batch_size = q.shape[0] + + assert global_kv_lens.shape[0] == self.num_ranks, "global_kv_lens must have a dimension for each rank." + assert global_kv_lens.shape[1] == batch_size, "global_kv_lens batch dimension mismatch." + assert batch_size <= self.max_allowed_batch, ( + f"Input batch size {batch_size} exceeds max allowed {self.max_allowed_batch}." 
+ ) + + output_combine = torch.empty( + [batch_size, self.num_q_heads, self.v_head_dim + 1], dtype=q.dtype, device=q.device + ) + final_output = torch.empty([batch_size, self.num_q_heads, self.v_head_dim], dtype=q.dtype, device=q.device) + + all_ranks_output_combine = torch.empty( + [self.num_ranks, batch_size, self.num_q_heads, self.v_head_dim + 1], dtype=q.dtype, device=q.device + ) + + gqa_local_kernels( + q, + k_cache, + v_cache, + workspace=None, + q_lens=[1] * batch_size, + kv_lens=global_kv_lens[self.rank], + block_table=block_table, + scale=self.scale, + soft_cap=self.soft_cap, + output_combine=output_combine, + kv_split=self.kv_split, + ) + + dist.all_gather_into_tensor(all_ranks_output_combine, output_combine, group=self.process_group) + + gqa_reduce_global[(batch_size, self.num_q_heads, 1)]( + all_ranks_output_combine, + final_output, + global_kv_lens, + batch_size, + self.num_q_heads, + all_ranks_output_combine.stride(1), # stride_mid_ob + all_ranks_output_combine.stride(2), # stride_mid_oh + all_ranks_output_combine.stride(0), # stride_mid_os + final_output.stride(0), # stride_obs + final_output.stride(1), # stride_oh + self.num_ranks, # NUM_KV_SPLITS + 512, # BLOCK_DV + self.v_head_dim, # Lv + ) + + return final_output diff --git a/examples/benchmark/reference/flash_decode_rccl/validate_flash_decode_rccl.py b/examples/benchmark/reference/flash_decode_rccl/validate_flash_decode_rccl.py new file mode 100644 index 00000000..8aedf1ff --- /dev/null +++ b/examples/benchmark/reference/flash_decode_rccl/validate_flash_decode_rccl.py @@ -0,0 +1,217 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# +# Part of the code adapted from +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/python/triton_dist/test/nvidia/test_sp_decode_attn.py +# +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
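For context, `flash_decode_layer_rccl.forward()` above runs in three stages: each rank attends over its own KV shard with `gqa_local_kernels`, the per-rank `[batch, heads, v_head_dim + 1]` combine buffers are exchanged with `all_gather_into_tensor`, and `gqa_reduce_global` merges them with the same log-sum-exp rule, treating ranks as splits. A hypothetical usage sketch under torchrun follows; every shape and name below (world size, head counts, the dummy caches) is made up for illustration, and the import assumes the repository layout is importable from the working directory.

```python
# Illustrative driver for the RCCL flash-decode layer (not part of the diff).
import os
import torch
import torch.distributed as dist
from examples.benchmark.reference.flash_decode_rccl.flash_decode_layer_rccl import flash_decode_layer_rccl

def demo():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    dist.init_process_group("nccl", device_id=torch.device(f"cuda:{local_rank}"))
    torch.cuda.set_device(local_rank)
    rank, world = dist.get_rank(), dist.get_world_size()

    batch, q_heads, kv_heads, head_dim, kv_len = 4, 32, 4, 128, 4096
    layer = flash_decode_layer_rccl(
        rank, world, q_heads, kv_heads, head_dim, head_dim,
        process_group=dist.group.WORLD, page_size=1,
        scale=head_dim ** -0.5, max_allowed_batch=batch,
    )

    q = torch.randn(batch, q_heads, head_dim, dtype=torch.float16, device="cuda")
    # Per-rank KV shard with page_size == 1: [num_pages, 1, kv_heads, head_dim].
    k_cache = torch.randn(kv_len, 1, kv_heads, head_dim, dtype=torch.float16, device="cuda")
    v_cache = torch.randn_like(k_cache)
    block_table = torch.arange(kv_len, dtype=torch.int32, device="cuda").repeat(batch, 1)
    global_kv_lens = torch.full((world, batch), kv_len, dtype=torch.int32, device="cuda")

    out = layer(q, k_cache, v_cache, global_kv_lens, block_table)  # [batch, q_heads, head_dim]
    return out
```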
+# +################################################################################ + +import sys +import os +from pathlib import Path +import pytest +from typing import List, Optional +from argparse import Namespace + +import torch +import torch.distributed as dist +from examples.benchmark.reference.flash_decode_rccl.flash_decode_layer_rccl import flash_decode_layer_rccl + +project_root = Path(__file__).resolve() +while not (project_root / "tests").is_dir() or not (project_root / "examples").is_dir(): + if project_root == project_root.parent: + raise FileNotFoundError( + "Could not find project root. Make sure your 'tests' and 'examples' " + "directories are siblings in the project structure." + ) + project_root = project_root.parent +print(f"Discovered Project Root: {project_root}") + +module_dir = project_root / "examples" / "13_flash_decode" +print(f"Target Module Directory: {module_dir}") + +target_file = module_dir / "fd_layer_rccl.py" +if module_dir.exists(): + sys.path.insert(0, str(module_dir)) + print(f"'{module_dir}' was added to sys.path.") +else: + print("ERROR: Target directory not found. Not modifying sys.path.") + +from utils import print_correctness_report # noqa: E402 + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens_per_rank: List[int], + block_tables: torch.Tensor, + scale: float, + soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables_cpu = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + outputs: List[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len, kv_len = query_lens[i], kv_lens_per_rank[i] + q = query[start_idx : start_idx + query_len] + q *= scale + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables_cpu[i, :num_kv_blocks] + k = key_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + if q.shape[1] != k.shape[1]: + gqa_ratio = q.shape[1] // k.shape[1] + k = torch.repeat_interleave(k, gqa_ratio, dim=1) + v = torch.repeat_interleave(v, gqa_ratio, dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len, device=query.device) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if soft_cap is not None and soft_cap > 0.0: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + outputs.append(out) + start_idx += query_len + return torch.cat(outputs, dim=0) + + +def prepare_correctness_data(cfg, args, num_query_heads, num_kv_heads, num_blocks_total): + """Creates data on Rank 0 and broadcasts it using torch.distributed.""" + head_dim = cfg["head_dim"] + if args.rank == 0: + query = torch.randn(cfg["num_seqs"], num_query_heads, head_dim, dtype=cfg["dtype"], device="cuda") / 10 + key_value_cache = ( + torch.randn( + num_blocks_total, 2, cfg["block_size"], num_kv_heads, head_dim, dtype=cfg["dtype"], device="cuda" + ) + / 10 + ) + else: + query = torch.empty(cfg["num_seqs"], num_query_heads, head_dim, dtype=cfg["dtype"], device="cuda") + key_value_cache = torch.empty( + num_blocks_total, 2, cfg["block_size"], num_kv_heads, head_dim, dtype=cfg["dtype"], device="cuda" + ) + + dist.broadcast(query, src=0, group=args.tp_group) + dist.broadcast(key_value_cache, src=0, 
group=args.tp_group) + + return {"query": query, "key_value_cache": key_value_cache} + + +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("num_seqs", [1, 8]) +@pytest.mark.parametrize("num_heads", [48, 96]) +@pytest.mark.parametrize("kv_len", [4096, 65536]) +def test_correctness_rccl_fused_full(kv_len, num_heads, num_seqs, head_dim): + """ + Tests the correctness of the RCCL Fused implementation against the Torch reference. + """ + rank = dist.get_rank() + torch.cuda.set_device(rank) + + args = Namespace() + args.rank = dist.get_rank() + args.world_size = dist.get_world_size() + args.tp_group = dist.new_group(ranks=range(args.world_size)) + + config = { + "kv_len": kv_len, + "num_heads": num_heads, + "num_seqs": num_seqs, + "head_dim": head_dim, + "dtype": torch.float16, + "block_size": 1, + "soft_cap": 0.0, + } + + # torch.manual_seed(42) + + num_query_heads = num_heads + num_kv_heads = num_query_heads // 8 if num_query_heads >= 8 else 1 + scale = head_dim**-0.5 + num_blocks_per_rank = (config["kv_len"] + config["block_size"] - 1) // config["block_size"] + num_blocks_total = num_blocks_per_rank * args.world_size + + tensor_data = prepare_correctness_data(config, args, num_query_heads, num_kv_heads, num_blocks_total) + query = tensor_data["query"] + key_value_cache = tensor_data["key_value_cache"] + + key_cache = key_value_cache[:, 0].contiguous() + value_cache = key_value_cache[:, 1].contiguous() + key_cache_this_rank = key_cache[args.rank * num_blocks_per_rank : (args.rank + 1) * num_blocks_per_rank] + value_cache_this_rank = value_cache[args.rank * num_blocks_per_rank : (args.rank + 1) * num_blocks_per_rank] + + block_tables_this_rank = torch.arange(num_blocks_per_rank, dtype=torch.int32).repeat(num_seqs, 1).cuda() + + gathered_tables_list = [torch.empty_like(block_tables_this_rank) for _ in range(args.world_size)] + dist.all_gather(gathered_tables_list, block_tables_this_rank, group=args.tp_group) + ref_block_tables = torch.cat([tbl + r * num_blocks_per_rank for r, tbl in enumerate(gathered_tables_list)], dim=-1) + + keyword_params = { + "page_size": config["block_size"], + "scale": scale, + "soft_cap": config["soft_cap"], + "max_allowed_batch": config["num_seqs"], + } + fd_layer = flash_decode_layer_rccl( + args.rank, args.world_size, num_query_heads, num_kv_heads, head_dim, head_dim, args.tp_group, **keyword_params + ) + dist.barrier(group=args.tp_group) + + kv_lens_per_rank = [config["kv_len"]] * num_seqs + kv_lens_tensor = torch.tensor(kv_lens_per_rank, dtype=torch.int32).cuda() + global_kv_lens_tensor = kv_lens_tensor.unsqueeze(0).repeat(args.world_size, 1) + + output = fd_layer(query, key_cache_this_rank, value_cache_this_rank, global_kv_lens_tensor, block_tables_this_rank) + torch.cuda.synchronize() + + ref_output = ref_paged_attn( + query=query.clone(), + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens_per_rank=[config["kv_len"] * args.world_size] * num_seqs, + block_tables=ref_block_tables, + scale=scale, + soft_cap=config["soft_cap"], + ) + dist.barrier(group=args.tp_group) + + error = None + try: + torch.testing.assert_close(output, ref_output, atol=1e-4, rtol=1e-4) + except AssertionError as e: + error = e + + print_correctness_report(args.rank, output, ref_output, error) + + if error: + raise error diff --git a/examples/benchmark/reference/reduce_scatter.py b/examples/benchmark/reference/reduce_scatter.py index ed05cbe3..3cdf3603 100755 --- a/examples/benchmark/reference/reduce_scatter.py +++ 
b/examples/benchmark/reference/reduce_scatter.py @@ -8,6 +8,7 @@ import random import iris import argparse +import os from examples.common.utils import JSONWriter @@ -43,7 +44,8 @@ def main(): m, n, k = args["m"], args["n"], args["k"] validate, benchmark = args["validate"], args["benchmark"] - dist.init_process_group("nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group("nccl", device_id=torch.device(f"cuda:{local_rank}")) rank = dist.get_rank() world_size = dist.get_world_size() torch.cuda.set_device(rank) diff --git a/examples/common/utils.py b/examples/common/utils.py index 5d96e891..0e6ea948 100644 --- a/examples/common/utils.py +++ b/examples/common/utils.py @@ -19,6 +19,28 @@ ALL_GATHER = tl.constexpr(6) +dtype_map = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "int8": torch.int8, + "int32": torch.int32, + "int64": torch.int64, +} + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + try: + return dtype_map[datatype] + except KeyError: + print(f"Unknown datatype: {datatype}") + exit(1) + + +def torch_dtype_to_str(dtype: torch.dtype) -> str: + return list(dtype_map.keys())[list(dtype_map.values()).index(dtype)] + + class JSONWriter: def __init__(self, file_path): self.file_path = file_path diff --git a/iris/README.md b/iris/README.md index 6cbb605b..ceb69355 100644 --- a/iris/README.md +++ b/iris/README.md @@ -5,6 +5,4 @@ Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. # Iris library -Python- and Triton-based library facilitating RDMAs for intra-node communication via IPC conduit. - -The `csrc/finegrained_alloc` directory contains a C library interface for fine-grained allocation. The plugin is required to redirect PyTorch allocation to fine-grained memory. \ No newline at end of file +Python- and Triton-based library facilitating RDMAs for intra-node communication via IPC conduit. \ No newline at end of file diff --git a/iris/__init__.py b/iris/__init__.py index 560d157b..ce96b149 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -11,7 +11,7 @@ This package provides: - Iris: Main class for multi-GPU operations - Atomic operations: add, sub, cas, xchg, xor, and, or, min, max -- Memory operations: load, store, get, put +- Memory operations: load, store, copy, get, put - Utility functions: do_bench - HIP integration for AMD GPU support - Logging utilities with rank information @@ -24,14 +24,12 @@ # __init__.py -import os -import torch - from .iris import ( Iris, iris, load, store, + copy, get, put, atomic_add, @@ -62,30 +60,12 @@ # Launcher functionality is now user code - see examples and documentation -# Pipe allocations via finegrained allocator -current_dir = os.path.dirname(__file__) -# Look for the library in the installed package location -finegrained_alloc_path = os.path.join(current_dir, "csrc", "finegrained_alloc", "libfinegrained_allocator.so") - -# Check if the library exists (should be built during pip install) -if not os.path.exists(finegrained_alloc_path): - raise RuntimeError( - f"Fine-grained allocator library not found at {finegrained_alloc_path}. " - "Please ensure the package was installed correctly." 
- ) - -finegrained_allocator = torch.cuda.memory.CUDAPluggableAllocator( - finegrained_alloc_path, - "finegrained_hipMalloc", - "finegrained_hipFree", -) -torch.cuda.memory.change_current_allocator(finegrained_allocator) - __all__ = [ "Iris", "iris", "load", "store", + "copy", "get", "put", "atomic_add", diff --git a/iris/_distributed_helpers.py b/iris/_distributed_helpers.py index ce656f93..287cda48 100644 --- a/iris/_distributed_helpers.py +++ b/iris/_distributed_helpers.py @@ -73,6 +73,28 @@ def distributed_allgather(data): return np.stack(obj_list, axis=0) +def distributed_allgather_multidim(data): + """ + All-gather operation for multi-dimensional tensors using PyTorch distributed. + """ + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed is not initialized") + + world_size = dist.get_world_size() + device = _infer_device() + + input_tensor = torch.as_tensor(data).to(device) + + tensor_list = [torch.empty_like(input_tensor) for _ in range(world_size)] + + dist.all_gather(tensor_list, input_tensor) + + stacked_tensor = torch.stack(tensor_list, dim=0) + reshaped_tensor = stacked_tensor.view(world_size, -1) + + return reshaped_tensor.cpu().numpy() + + def distributed_broadcast_scalar(value=None, root=0): """ Broadcast a scalar value from root to all ranks. @@ -108,9 +130,15 @@ def distributed_broadcast_scalar(value=None, root=0): # If NCCL can't handle this dtype, just broadcast the object directly. if backend == "nccl": # Try a quick check using a tiny tensor of the dtype - torch_dtype = torch.from_numpy(np.array(0, dtype=dtype)).dtype - dummy = torch.empty((), dtype=torch_dtype) - if not _nccl_dtype_supported(dummy): + try: + torch_dtype = torch.from_numpy(np.array(0, dtype=dtype)).dtype + dummy = torch.empty((), dtype=torch_dtype) + if not _nccl_dtype_supported(dummy): + obj = [value if rank == root else None] + dist.broadcast_object_list(obj, src=root) + return obj[0] + except (TypeError, ValueError): + # Dtype not supported by torch (e.g., str, object), use object broadcast obj = [value if rank == root else None] dist.broadcast_object_list(obj, src=root) return obj[0] @@ -123,6 +151,54 @@ def distributed_broadcast_scalar(value=None, root=0): return val_t.to("cpu").item() +def distributed_broadcast_tensor(value_to_broadcast=None, root=0): + """ + Broadcast a tensor/array from root to all ranks. 
+ + Args: + value_to_broadcast: Tensor or array to broadcast (only used on root rank) + root: Root rank to broadcast from + + Returns: + Broadcasted numpy array + """ + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed is not initialized") + + rank = dist.get_rank() + device = _infer_device() + backend = str(dist.get_backend()).lower() + + if rank == root: + if value_to_broadcast is None: + raise ValueError("Root must provide a value to broadcast.") + tensor = torch.as_tensor(value_to_broadcast) + metadata = [tensor.shape, tensor.dtype] + else: + metadata = [None, None] + tensor = None + + dist.broadcast_object_list(metadata, src=root) + shape, dtype = metadata + + if rank != root: + tensor = torch.empty(shape, dtype=dtype) + + use_tensor_collective = backend != "nccl" or _nccl_dtype_supported(tensor) + + if use_tensor_collective: + tensor = tensor.to(device) + dist.broadcast(tensor, src=root) + return tensor.to("cpu").numpy() + else: + if rank == root: + obj = [np.asarray(value_to_broadcast)] + else: + obj = [None] + dist.broadcast_object_list(obj, src=root) + return obj[0] + + def distributed_barrier(): """ Synchronization barrier using PyTorch distributed. diff --git a/iris/hip.py b/iris/hip.py index 2a03c397..89807860 100644 --- a/iris/hip.py +++ b/iris/hip.py @@ -5,74 +5,129 @@ import numpy as np import sys import torch - -rt_path = "libamdhip64.so" -hip_runtime = ctypes.cdll.LoadLibrary(rt_path) +import subprocess +import os + +# Auto-detect backend +_is_amd_backend = True +try: + rt_path = "libamdhip64.so" + gpu_runtime = ctypes.cdll.LoadLibrary(rt_path) +except OSError: + try: + rt_path = "libcudart.so" + gpu_runtime = ctypes.cdll.LoadLibrary(rt_path) + _is_amd_backend = False + except OSError: + rt_path = "libamdhip64.so" + gpu_runtime = ctypes.cdll.LoadLibrary(rt_path) + + +def gpu_try(err): + if err != 0: + if _is_amd_backend: + gpu_runtime.hipGetErrorString.restype = ctypes.c_char_p + error_string = gpu_runtime.hipGetErrorString(ctypes.c_int(err)).decode("utf-8") + raise RuntimeError(f"HIP error code {err}: {error_string}") + else: + gpu_runtime.cudaGetErrorString.restype = ctypes.c_char_p + error_string = gpu_runtime.cudaGetErrorString(ctypes.c_int(err)).decode("utf-8") + raise RuntimeError(f"CUDA error code {err}: {error_string}") -def hip_try(err): - if err != 0: - hip_runtime.hipGetErrorString.restype = ctypes.c_char_p - error_string = hip_runtime.hipGetErrorString(ctypes.c_int(err)).decode("utf-8") - raise RuntimeError(f"HIP error code {err}: {error_string}") +def get_ipc_handle_size(): + """Return the IPC handle size for the current backend.""" + return 64 if _is_amd_backend else 128 -class hipIpcMemHandle_t(ctypes.Structure): - _fields_ = [("reserved", ctypes.c_char * 64)] +class gpuIpcMemHandle_t(ctypes.Structure): + _fields_ = [("reserved", ctypes.c_char * get_ipc_handle_size())] def open_ipc_handle(ipc_handle_data, rank): ptr = ctypes.c_void_p() - hipIpcMemLazyEnablePeerAccess = ctypes.c_uint(1) - hip_runtime.hipIpcOpenMemHandle.argtypes = [ - ctypes.POINTER(ctypes.c_void_p), - hipIpcMemHandle_t, - ctypes.c_uint, - ] + handle_size = get_ipc_handle_size() + + if _is_amd_backend: + hipIpcMemLazyEnablePeerAccess = ctypes.c_uint(1) + gpu_runtime.hipIpcOpenMemHandle.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + gpuIpcMemHandle_t, + ctypes.c_uint, + ] + else: + gpu_runtime.cudaIpcOpenMemHandle.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + gpuIpcMemHandle_t, + ctypes.c_uint, + ] + cudaIpcMemLazyEnablePeerAccess = ctypes.c_uint(1) + if 
isinstance(ipc_handle_data, np.ndarray): - if ipc_handle_data.dtype != np.uint8 or ipc_handle_data.size != 64: - raise ValueError("ipc_handle_data must be a 64-element uint8 numpy array") + if ipc_handle_data.dtype != np.uint8 or ipc_handle_data.size != handle_size: + raise ValueError(f"ipc_handle_data must be a {handle_size}-element uint8 numpy array") ipc_handle_bytes = ipc_handle_data.tobytes() - ipc_handle_data = (ctypes.c_char * 64).from_buffer_copy(ipc_handle_bytes) + ipc_handle_data = (ctypes.c_char * handle_size).from_buffer_copy(ipc_handle_bytes) else: - raise TypeError("ipc_handle_data must be a numpy.ndarray of dtype uint8 with 64 elements") + raise TypeError(f"ipc_handle_data must be a numpy.ndarray of dtype uint8 with {handle_size} elements") - raw_memory = ctypes.create_string_buffer(64) - ctypes.memset(raw_memory, 0x00, 64) - ipc_handle_struct = hipIpcMemHandle_t.from_buffer(raw_memory) + raw_memory = ctypes.create_string_buffer(handle_size) + ctypes.memset(raw_memory, 0x00, handle_size) + ipc_handle_struct = gpuIpcMemHandle_t.from_buffer(raw_memory) ipc_handle_data_bytes = bytes(ipc_handle_data) - ctypes.memmove(raw_memory, ipc_handle_data_bytes, 64) - - hip_try( - hip_runtime.hipIpcOpenMemHandle( - ctypes.byref(ptr), - ipc_handle_struct, - hipIpcMemLazyEnablePeerAccess, + ctypes.memmove(raw_memory, ipc_handle_data_bytes, handle_size) + + if _is_amd_backend: + gpu_try( + gpu_runtime.hipIpcOpenMemHandle( + ctypes.byref(ptr), + ipc_handle_struct, + hipIpcMemLazyEnablePeerAccess, + ) + ) + else: + gpu_try( + gpu_runtime.cudaIpcOpenMemHandle( + ctypes.byref(ptr), + ipc_handle_struct, + cudaIpcMemLazyEnablePeerAccess, + ) ) - ) return ptr.value def get_ipc_handle(ptr, rank): - ipc_handle = hipIpcMemHandle_t() - hip_try(hip_runtime.hipIpcGetMemHandle(ctypes.byref(ipc_handle), ptr)) + ipc_handle = gpuIpcMemHandle_t() + if _is_amd_backend: + gpu_try(gpu_runtime.hipIpcGetMemHandle(ctypes.byref(ipc_handle), ptr)) + else: + gpu_try(gpu_runtime.cudaIpcGetMemHandle(ctypes.byref(ipc_handle), ptr)) return ipc_handle def count_devices(): device_count = ctypes.c_int() - hip_try(hip_runtime.hipGetDeviceCount(ctypes.byref(device_count))) + if _is_amd_backend: + gpu_try(gpu_runtime.hipGetDeviceCount(ctypes.byref(device_count))) + else: + gpu_try(gpu_runtime.cudaGetDeviceCount(ctypes.byref(device_count))) return device_count.value def set_device(gpu_id): - hip_try(hip_runtime.hipSetDevice(gpu_id)) + if _is_amd_backend: + gpu_try(gpu_runtime.hipSetDevice(gpu_id)) + else: + gpu_try(gpu_runtime.cudaSetDevice(gpu_id)) def get_device_id(): device_id = ctypes.c_int() - hip_try(hip_runtime.hipGetDevice(ctypes.byref(device_id))) + if _is_amd_backend: + gpu_try(gpu_runtime.hipGetDevice(ctypes.byref(device_id))) + else: + gpu_try(gpu_runtime.cudaGetDevice(ctypes.byref(device_id))) return device_id.value @@ -80,65 +135,123 @@ def get_cu_count(device_id=None): if device_id is None: device_id = get_device_id() - hipDeviceAttributeMultiprocessorCount = 63 cu_count = ctypes.c_int() - hip_try(hip_runtime.hipDeviceGetAttribute(ctypes.byref(cu_count), hipDeviceAttributeMultiprocessorCount, device_id)) + if _is_amd_backend: + hipDeviceAttributeMultiprocessorCount = 63 + gpu_try( + gpu_runtime.hipDeviceGetAttribute(ctypes.byref(cu_count), hipDeviceAttributeMultiprocessorCount, device_id) + ) + else: + cudaDevAttrMultiProcessorCount = 16 + gpu_try(gpu_runtime.cudaDeviceGetAttribute(ctypes.byref(cu_count), cudaDevAttrMultiProcessorCount, device_id)) return cu_count.value def get_rocm_version(): + if not 
_is_amd_backend: + # Not applicable for CUDA + return (-1, -1) + major, minor = -1, -1 - with open("/opt/rocm/.info/version", "r") as version_file: - version = version_file.readline().strip() - major = int(version.split(".")[0]) - minor = int(version.split(".")[1]) + + # Try hipconfig --path first + try: + result = subprocess.run(["hipconfig", "--path"], capture_output=True, text=True, check=True) + rocm_path = result.stdout.strip() + except (subprocess.CalledProcessError, FileNotFoundError): + # Then look for $ROCM_PATH environment variable + rocm_path = os.environ.get("ROCM_PATH") + if not rocm_path: + # Finally, try default location + rocm_path = "/opt/rocm" + + # Try to read version from .info/version file + try: + version_file_path = os.path.join(rocm_path, ".info", "version") + with open(version_file_path, "r") as version_file: + version = version_file.readline().strip() + major = int(version.split(".")[0]) + minor = int(version.split(".")[1]) + except (FileNotFoundError, IOError, ValueError, IndexError): + # If we can't read the version file, return -1, -1 + pass + return (major, minor) def get_wall_clock_rate(device_id): - hipDeviceAttributeWallClockRate = 10017 wall_clock_rate = ctypes.c_int() - status = hip_runtime.hipDeviceGetAttribute( - ctypes.byref(wall_clock_rate), hipDeviceAttributeWallClockRate, device_id - ) - hip_try(status) + + if _is_amd_backend: + hipDeviceAttributeWallClockRate = 10017 + status = gpu_runtime.hipDeviceGetAttribute( + ctypes.byref(wall_clock_rate), hipDeviceAttributeWallClockRate, device_id + ) + else: + cudaDevAttrClockRate = 13 + status = gpu_runtime.cudaDeviceGetAttribute(ctypes.byref(wall_clock_rate), cudaDevAttrClockRate, device_id) + + gpu_try(status) return wall_clock_rate.value def get_arch_string(device_id=None): if device_id is None: device_id = get_device_id() - arch_full = torch.cuda.get_device_properties(device_id).gcnArchName - arch_name = arch_full.split(":")[0] - return arch_name + + if _is_amd_backend: + arch_full = torch.cuda.get_device_properties(device_id).gcnArchName + arch_name = arch_full.split(":")[0] + return arch_name + else: + # For CUDA, return compute capability + props = torch.cuda.get_device_properties(device_id) + return f"sm_{props.major}{props.minor}" def get_num_xcc(device_id=None): if device_id is None: device_id = get_device_id() + + if not _is_amd_backend: + # XCC is AMD-specific, return 1 for CUDA + return 1 + rocm_major, _ = get_rocm_version() if rocm_major < 7: return 8 hipDeviceAttributeNumberOfXccs = 10018 xcc_count = ctypes.c_int() - hip_try(hip_runtime.hipDeviceGetAttribute(ctypes.byref(xcc_count), hipDeviceAttributeNumberOfXccs, device_id)) + gpu_try(gpu_runtime.hipDeviceGetAttribute(ctypes.byref(xcc_count), hipDeviceAttributeNumberOfXccs, device_id)) return xcc_count.value def malloc_fine_grained(size): - hipDeviceMallocFinegrained = 0x1 ptr = ctypes.c_void_p() - hip_try(hip_runtime.hipExtMallocWithFlags(ctypes.byref(ptr), size, hipDeviceMallocFinegrained)) + + if _is_amd_backend: + hipDeviceMallocFinegrained = 0x1 + gpu_try(gpu_runtime.hipExtMallocWithFlags(ctypes.byref(ptr), size, hipDeviceMallocFinegrained)) + else: + # CUDA doesn't have direct equivalent, use regular malloc + gpu_try(gpu_runtime.cudaMalloc(ctypes.byref(ptr), size)) + return ptr def hip_malloc(size): ptr = ctypes.c_void_p() - hip_try(hip_runtime.hipMalloc(ctypes.byref(ptr), size)) + if _is_amd_backend: + gpu_try(gpu_runtime.hipMalloc(ctypes.byref(ptr), size)) + else: + gpu_try(gpu_runtime.cudaMalloc(ctypes.byref(ptr), size)) return ptr 
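The rewritten `iris/hip.py` above chooses the GPU runtime once at import time: it tries to load `libamdhip64.so` with `ctypes` and falls back to `libcudart.so`, then routes every call through the matching HIP or CUDA entry point. A minimal standalone sketch of that detection pattern, where the helper name `load_gpu_runtime` is hypothetical:

```python
# Sketch of the HIP/CUDA runtime auto-detection pattern; prefers HIP and
# falls back to the CUDA runtime, mirroring the module-level probe above.
import ctypes

def load_gpu_runtime():
    """Return (runtime, is_amd)."""
    try:
        return ctypes.cdll.LoadLibrary("libamdhip64.so"), True
    except OSError:
        return ctypes.cdll.LoadLibrary("libcudart.so"), False

runtime, is_amd = load_gpu_runtime()
count = ctypes.c_int()
# Both runtimes expose an equivalent device-count entry point.
if is_amd:
    status = runtime.hipGetDeviceCount(ctypes.byref(count))
else:
    status = runtime.cudaGetDeviceCount(ctypes.byref(count))
print(f"backend={'HIP' if is_amd else 'CUDA'}, devices={count.value}, status={status}")
```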
def hip_free(ptr): - hip_try(hip_runtime.hipFree(ptr)) + if _is_amd_backend: + gpu_try(gpu_runtime.hipFree(ptr)) + else: + gpu_try(gpu_runtime.cudaFree(ptr)) diff --git a/iris/iris.py b/iris/iris.py index d8859077..8a7a7aa1 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -30,6 +30,7 @@ distributed_allgather, distributed_barrier, distributed_broadcast_scalar, + distributed_broadcast_tensor, ) from iris.hip import ( set_device, @@ -38,6 +39,7 @@ get_ipc_handle, open_ipc_handle, get_wall_clock_rate, + get_ipc_handle_size, ) import numpy as np import math @@ -88,13 +90,16 @@ def __init__(self, heap_size=1 << 30): heap_bases = np.zeros(num_ranks, dtype=np.uint64) heap_bases[cur_rank] = heap_base - ipc_handles = np.zeros((num_ranks, 64), dtype=np.uint8) + ipc_handle_size = get_ipc_handle_size() + ipc_handles = np.zeros((num_ranks, ipc_handle_size), dtype=np.uint8) ipc_handle = get_ipc_handle(heap_base_ptr, cur_rank) distributed_barrier() - all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8)) - all_heap_bases = distributed_allgather(np.array([heap_bases[cur_rank]], dtype=np.uint64)) + all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8).copy()) + heap_base_bytes = np.array([heap_bases[cur_rank]], dtype=np.uint64).tobytes() + all_heap_bases_bytes = distributed_allgather(np.frombuffer(heap_base_bytes, dtype=np.uint8).copy()) + all_heap_bases = np.frombuffer(all_heap_bases_bytes.tobytes(), dtype=np.uint64).reshape(num_ranks, -1) distributed_barrier() @@ -184,22 +189,74 @@ def error(self, message): def broadcast(self, value, source_rank): """ - Broadcast a Python scalar or small picklable object from one rank to all ranks. + Broadcast a value from one rank to all ranks. + + This method automatically detects the type of value and uses the appropriate + broadcast mechanism: + - For tensors and arrays: uses efficient PyTorch distributed tensor collectives + - For scalars and other objects: uses object broadcast Args: - value (Any): The value to broadcast. Only the ``source_rank`` value is used; + value (Any): The value to broadcast. Can be a scalar, tensor, numpy array, + or any picklable object. Only the ``source_rank`` value is used; other ranks should pass a placeholder (e.g., ``None``). source_rank (int): Rank id that holds the authoritative value. Returns: - Any: The value broadcast to all ranks. + Any: The value broadcast to all ranks. Tensors and arrays are returned as + numpy arrays; scalars and objects are returned in their original type. 
- Example: + Examples: >>> ctx = iris.iris() + >>> # Broadcasting a scalar >>> value = 42 if ctx.cur_rank == 0 else None >>> value = ctx.broadcast(value, source_rank=0) # All ranks get 42 + >>> + >>> # Broadcasting a tensor + >>> if ctx.cur_rank == 0: + >>> data = torch.randn(10, 10) + >>> else: + >>> data = None + >>> data = ctx.broadcast(data, source_rank=0) # All ranks get the same array """ - return distributed_broadcast_scalar(value, source_rank) + # Check if the value on source_rank is a tensor or array-like + if self.cur_rank == source_rank and value is not None: + # Explicitly exclude strings and non-numeric types + if isinstance(value, (str, dict, bool)): + is_tensor = False + elif isinstance(value, torch.Tensor): + is_tensor = True + elif isinstance(value, np.ndarray): + is_tensor = True + elif isinstance(value, (list, tuple)): + # Try to convert list/tuple to tensor to check if it's numeric + try: + torch.as_tensor(value) + is_tensor = True + except (TypeError, ValueError): + is_tensor = False + else: + # For other types, try to convert and check + try: + test_array = np.asarray(value) + # Check if it's a numeric dtype that torch can handle + if np.issubdtype(test_array.dtype, np.number): + torch.as_tensor(test_array) + is_tensor = True + else: + is_tensor = False + except (TypeError, ValueError): + is_tensor = False + else: + is_tensor = False + + # Broadcast the type decision to all ranks + is_tensor = distributed_broadcast_scalar(is_tensor, source_rank) + + if is_tensor: + return distributed_broadcast_tensor(value, root=source_rank) + else: + return distributed_broadcast_scalar(value, source_rank) def __allocate(self, num_elements, dtype): self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") @@ -1559,6 +1616,80 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modif tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) +@triton.jit +def copy( + src_ptr, + dst_ptr, + from_rank, + to_rank, + cur_rank, + heap_bases, + mask=None, + load_cache_modifier=None, + store_cache_modifier=None, +): + """ + Copies data from the specified rank's memory into the destination rank's memory. + This function performs the transfer by translating src_ptr from the from_rank's address + space to the to_rank's address space, performing a masked load from the translated + source, and storing the loaded data to dst_ptr in the to_rank memory location. + If from_rank and to_rank are the same, this function performs a local copy operation. + It is undefined behaviour if neither from_rank nor to_rank is the cur_rank. + + Args: + src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's local memory from which to read data. + dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the to_rank's local memory where the data will be written. + from_rank (int): The rank ID that owns src_ptr (source rank). + to_rank (int): The rank ID that will receive the data (destination rank). + cur_rank (int): The rank ID issuing the copy operation. Must be either from_rank or to_rank. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". 
Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(remote_ptr, local_ptr, heap_bases): + >>> from_rank = 1 + >>> to_rank = 0 + >>> iris.copy(remote_ptr, local_ptr, from_rank, to_rank, to_rank, heap_bases) + """ + + cur_base = tl.load(heap_bases + cur_rank) + + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + + src_ptr_int = tl.cast(src_ptr, tl.uint64) + src_offset = src_ptr_int - cur_base + + dst_ptr_int = tl.cast(dst_ptr, tl.uint64) + dst_offset = dst_ptr_int - cur_base + + from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8)) + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + + translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) + translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + + data = tl.load(translated_src, mask=mask, cache_modifier=load_cache_modifier) + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) + + @triton.jit def get( from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None diff --git a/pyproject.toml b/pyproject.toml index c563645a..88f7b2e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,37 +2,32 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. [build-system] -requires = ["setuptools", "wheel"] +requires = ["setuptools>=61", "wheel", "setuptools-scm>=8"] build-backend = "setuptools.build_meta" -[tool.setuptools.dynamic] -version = {attr = "iris.__version__"} - -[tool.setuptools] -package-dir = { "" = "." } - -[tool.setuptools.packages.find] -include = ["iris"] - [project] -name = "Iris" -version = "0.1.0" -description = "Python- and Triton-based library that provide SHMEM-like RDMA support in Triton." +name = "iris" +dynamic = ["version"] +description = "Triton-based framework for Remote Memory Access (RMA) operations with SHMEM-like APIs for multi-GPU programming." authors = [ { name = "Muhammad Awad", email = "muhaawad@amd.com" }, { name = "Muhammad Osama", email = "Muhammad.Osama@amd.com" }, { name = "Brandon Potter", email = "Brandon.Potter@amd.com" } ] -license = "MIT" +license = { text = "MIT" } readme = "README.md" requires-python = ">=3.8" - dependencies = [ "numpy", "requests", "ruff", ] +[project.urls] +Homepage = "https://rocm.github.io/iris/" +Repository = "https://github.com/ROCm/iris" +Documentation = "https://rocm.github.io/iris/" + [project.optional-dependencies] dev = [ "pytest", @@ -40,10 +35,21 @@ dev = [ "mypy", ] +[tool.setuptools] +package-dir = { "" = "." 
} + +[tool.setuptools.packages.find] +include = ["iris"] + +# ---- setuptools-scm versioning ---- +[tool.setuptools_scm] +version_scheme = "post-release" # .postN after last tag +local_scheme = "node-and-date" # add commit hash (e.g. +gabc1234) and date (e.g. +20250914) +fallback_version = "0.0.0" # used if git metadata unavailable + [tool.ruff] line-length = 120 exclude = [ - "csrc/finegrained_alloc/**", # explicitly exclude all contents "**/*.ipynb" # match notebooks anywhere ] @@ -52,4 +58,4 @@ select = ["E", "F", "W"] ignore = ["E501", "E701", "E731", "E741", "F841", "F401"] [tool.ruff.format] -quote-style = "double" +quote-style = "double" \ No newline at end of file diff --git a/scripts/link_bandwidth.py b/scripts/link_bandwidth.py index 29faf5f1..1c79d250 100644 --- a/scripts/link_bandwidth.py +++ b/scripts/link_bandwidth.py @@ -2,6 +2,15 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import json +import torch + +try: + if torch.cuda.is_available(): + cu_count = torch.cuda.get_device_properties(0).multi_processor_count + else: + cu_count = 304 # Default for MI300 +except Exception: + cu_count = 304 # Default for MI300 # Sample input (replace with file read if needed) config = { @@ -26,7 +35,7 @@ "kpack": 2, "heap_size": 8589934592, "gemm_sms": 48, - "total_sms": 304, + "total_sms": cu_count, "communication_block_size": 256, "communication_sms_multiplier": 1, "M": 8192, diff --git a/setup.py b/setup.py index 97a99f66..69832461 100644 --- a/setup.py +++ b/setup.py @@ -1,95 +1,11 @@ -#!/usr/bin/env python3 # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. -import os -import subprocess -import sys -import shutil from setuptools import setup -from setuptools.command.build_py import build_py - -class HIPBuildPy(build_py): - """Custom build command that also builds the HIP library.""" - - def run(self): - # Build the HIP library first - self.build_hip_library() - # Then run the normal Python build - super().run() - - def build_hip_library(self): - """Build the finegrained allocator using hipcc.""" - # Get the project root directory (where setup.py is located) - project_root = os.path.dirname(os.path.abspath(__file__)) - src_dir = os.path.join(project_root, "csrc", "finegrained_alloc") - src_file = os.path.join(src_dir, "finegrained_allocator.hip") - output_file = os.path.join(src_dir, "libfinegrained_allocator.so") - - # Check if source file exists - if not os.path.exists(src_file): - raise FileNotFoundError( - f"Source file not found: {src_file}\n" - "This might happen if the repository is incomplete or if you're " - "installing from a source distribution that doesn't include the C++ source." 
- ) - - # Ensure the output directory exists - os.makedirs(os.path.dirname(output_file), exist_ok=True) - - basic_warnings = ["-Wall", "-Wextra", "-Werror"] - strict_warnings = [ - "-pedantic", - "-Wshadow", - "-Wnon-virtual-dtor", - "-Wold-style-cast", - "-Wcast-align", - "-Woverloaded-virtual", - "-Wconversion", - "-Wsign-conversion", - "-Wnull-dereference", - "-Wdouble-promotion", - "-Wformat=2", - ] - std_flags = ["-std=c++17"] - output_flags = ["-shared", "-fPIC", "-o", output_file] - - cmd = ["hipcc"] + basic_warnings + strict_warnings + std_flags + output_flags + [src_file] - - print(f"Building finegrained allocator: {' '.join(cmd)}") - - try: - subprocess.run(cmd, cwd=src_dir, check=True, capture_output=True, text=True) - print(f"Successfully built: {output_file}") - - # Copy the built library to the iris package directory for installation - iris_package_dir = os.path.join(project_root, "iris") - target_dir = os.path.join(iris_package_dir, "csrc", "finegrained_alloc") - os.makedirs(target_dir, exist_ok=True) - target_file = os.path.join(target_dir, "libfinegrained_allocator.so") - shutil.copy2(output_file, target_file) - print(f"Copied library to: {target_file}") - - except subprocess.CalledProcessError as e: - print(f"Build failed with return code {e.returncode}") - print(f"stdout: {e.stdout}") - print(f"stderr: {e.stderr}") - raise - except FileNotFoundError: - print("hipcc not found. Please ensure ROCm/HIP is installed.") - print( - "You can install ROCm following the instructions at: https://rocmdocs.amd.com/en/latest/Installation_Guide/Installation-Guide.html" - ) - raise - - -if __name__ == "__main__": - setup( - cmdclass={ - "build_py": HIPBuildPy, - }, - package_data={ - "iris": ["csrc/finegrained_alloc/libfinegrained_allocator.so"], - }, - ) +# This setup.py provides backward compatibility for legacy metadata fields +# that don't map directly from pyproject.toml's modern PEP 621 format. +setup( + url="https://rocm.github.io/iris/", + author="Muhammad Awad, Muhammad Osama, Brandon Potter", +) diff --git a/tests/examples/test_all_load_bench.py b/tests/examples/test_all_load_bench.py index 3d334947..80a0ca71 100644 --- a/tests/examples/test_all_load_bench.py +++ b/tests/examples/test_all_load_bench.py @@ -44,6 +44,8 @@ ], ) def test_all_load_bench(dtype, buffer_size, heap_size, block_size): + # TODO: Benchmark is not accurate. See: https://github.com/ROCm/iris/issues/119 + pytest.skip("Benchmark is not accurate. See: https://github.com/ROCm/iris/issues/119") shmem = iris.iris(heap_size) num_ranks = shmem.get_num_ranks() @@ -56,8 +58,8 @@ def test_all_load_bench(dtype, buffer_size, heap_size, block_size): "datatype": _torch_dtype_to_str(dtype), "block_size": block_size, "active_ranks": num_ranks, - "num_warmup": 1, - "num_experiments": 2, + "num_warmup": 4, + "num_experiments": 8, "verbose": False, "validate": False, } diff --git a/tests/examples/test_atomic_add_bench.py b/tests/examples/test_atomic_add_bench.py new file mode 100644 index 00000000..2406a04c --- /dev/null +++ b/tests/examples/test_atomic_add_bench.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
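With the `pyproject.toml` changes above, the package version is no longer a hard-coded `0.1.0` but is derived from git metadata by setuptools-scm (post-release scheme plus node-and-date local scheme). A hedged sketch of reading the resolved version at runtime, assuming the distribution is installed under the name `iris`:

```python
# Illustrative only: querying the setuptools-scm generated version of an
# installed distribution named "iris" (e.g. after `pip install -e .`).
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("iris"))  # e.g. "0.1.0.post3+gabc1234.d20250914" under the schemes above
except PackageNotFoundError:
    print("iris is not installed; at build time setuptools-scm falls back to 0.0.0")
```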
+ +import pytest +import torch +import triton +import triton.language as tl +import numpy as np +import iris + +import importlib.util +import sys +from pathlib import Path + +current_dir = Path(__file__).parent + +# Add examples directory to sys.path so that example files can import from examples.common +# Note: Examples use "from examples.common.utils import ..." which requires examples/ in sys.path +examples_dir = (current_dir / "../..").resolve() +if str(examples_dir) not in sys.path: + sys.path.insert(0, str(examples_dir)) + +# Load utils module from file path (not package import) +# Note: We use path-based imports instead of "from examples.common.utils import ..." +# because examples/ is not included in the installed package. This allows tests to +# work with both editable install (pip install -e .) and regular install (pip install git+...). +utils_path = (current_dir / "../../examples/common/utils.py").resolve() +utils_spec = importlib.util.spec_from_file_location("utils", utils_path) +utils_module = importlib.util.module_from_spec(utils_spec) +utils_spec.loader.exec_module(utils_module) +torch_dtype_to_str = utils_module.torch_dtype_to_str + +# Load benchmark module +file_path = (current_dir / "../../examples/04_atomic_add/atomic_add_bench.py").resolve() +module_name = "atomic_add_bench" +spec = importlib.util.spec_from_file_location(module_name, file_path) +module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(module) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "buffer_size, heap_size", + [ + (20480, (1 << 33)), + ], +) +@pytest.mark.parametrize( + "block_size", + [ + 512, + 1024, + ], +) +def test_atomic_bandwidth(dtype, buffer_size, heap_size, block_size): + """Test that atomic_add benchmark runs and produces positive bandwidth.""" + shmem = iris.iris(heap_size) + num_ranks = shmem.get_num_ranks() + + element_size_bytes = torch.tensor([], dtype=dtype).element_size() + n_elements = buffer_size // element_size_bytes + source_buffer = shmem.arange(n_elements, dtype=dtype) + + shmem.barrier() + + args = { + "datatype": torch_dtype_to_str(dtype), + "block_size": block_size, + "verbose": False, + "validate": False, + "num_experiments": 10, + "num_warmup": 5, + } + + source_rank = 0 + destination_rank = 1 if num_ranks > 1 else 0 + + bandwidth_gbps, _ = module.run_experiment(shmem, args, source_rank, destination_rank, source_buffer) + + assert bandwidth_gbps > 0, f"Bandwidth should be positive, got {bandwidth_gbps}" + + shmem.barrier() + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "buffer_size, heap_size", + [ + (20480, (1 << 33)), + ], +) +@pytest.mark.parametrize( + "block_size", + [ + 512, + 1024, + ], +) +def test_atomic_correctness(dtype, buffer_size, heap_size, block_size): + """Test that atomic_add benchmark runs and produces positive bandwidth.""" + shmem = iris.iris(heap_size) + num_ranks = shmem.get_num_ranks() + + element_size_bytes = torch.tensor([], dtype=dtype).element_size() + n_elements = buffer_size // element_size_bytes + source_buffer = shmem.arange(n_elements, dtype=dtype) + + shmem.barrier() + + args = { + "datatype": torch_dtype_to_str(dtype), + "block_size": block_size, + "verbose": False, + "validate": False, + "num_experiments": 1, + "num_warmup": 0, + } + + source_rank = 0 + destination_rank = 1 if num_ranks > 1 else 0 + + _, result_buffer = 
module.run_experiment(shmem, args, source_rank, destination_rank, source_buffer) + + if shmem.get_rank() == destination_rank: + expected = torch.ones(n_elements, dtype=dtype, device="cuda") + + assert torch.allclose(result_buffer, expected), "Result buffer should be equal to expected" + + shmem.barrier() diff --git a/tests/examples/test_flash_decode.py b/tests/examples/test_flash_decode.py new file mode 100644 index 00000000..40eaf523 --- /dev/null +++ b/tests/examples/test_flash_decode.py @@ -0,0 +1,238 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# +# Part of the code adapted from +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/python/triton_dist/test/nvidia/test_sp_decode_attn.py +# +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+# +################################################################################ + + +import sys +import os +from pathlib import Path +import pytest +from typing import List, Optional +from argparse import Namespace + +import numpy as np +import torch +import iris + +project_root = Path(__file__).resolve() +while not (project_root / "tests").is_dir() or not (project_root / "examples").is_dir(): + if project_root == project_root.parent: + raise FileNotFoundError("Could not find project root") + project_root = project_root.parent +print(f"Project Root: {project_root}") + +module_dir = project_root / "examples" / "13_flash_decode" +print(f"Module Directory: {module_dir}") + +target_file = module_dir / "flash_decode_fused_layer.py" +if module_dir.exists(): + sys.path.insert(0, str(module_dir)) + print(f"'{module_dir}' was added to sys.path.") +else: + print("ERROR: Target directory not found") + +from flash_decode_fused_layer import flash_decode_fused_layer # noqa: E402 +from utils import print_correctness_report # noqa: E402 + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens_per_rank: List[int], + block_tables: torch.Tensor, + scale: float, + soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables_cpu = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + outputs: List[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len, kv_len = query_lens[i], kv_lens_per_rank[i] + q = query[start_idx : start_idx + query_len] + q *= scale + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables_cpu[i, :num_kv_blocks] + k = key_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + if q.shape[1] != k.shape[1]: + gqa_ratio = q.shape[1] // k.shape[1] + k = torch.repeat_interleave(k, gqa_ratio, dim=1) + v = torch.repeat_interleave(v, gqa_ratio, dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len, device=query.device) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if soft_cap is not None and soft_cap > 0.0: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + outputs.append(out) + start_idx += query_len + return torch.cat(outputs, dim=0) + + +def prepare_correctness_data(cfg, args, num_query_heads, num_kv_heads, NUM_BLOCKS): + head_dim = cfg["head_dim"] + if args.rank == 0: + query = torch.randn(cfg["num_seqs"], num_query_heads, head_dim, dtype=cfg["dtype"]) / 10 + key_value_cache = torch.randn(NUM_BLOCKS, 2, cfg["block_size"], num_kv_heads, head_dim, dtype=cfg["dtype"]) / 10 + else: + query = torch.empty(cfg["num_seqs"], num_query_heads, head_dim, dtype=cfg["dtype"]) + key_value_cache = torch.empty(NUM_BLOCKS, 2, cfg["block_size"], num_kv_heads, head_dim, dtype=cfg["dtype"]) + + query = torch.from_numpy(args.shmem.broadcast(query.cpu().numpy(), source_rank=0)).to(query.device) + key_value_cache = torch.from_numpy(args.shmem.broadcast(key_value_cache.cpu().numpy(), source_rank=0)).to( + key_value_cache.device + ) + + return {"query": query, "key_value_cache": key_value_cache} + + +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("num_seqs", [1, 8]) 
+@pytest.mark.parametrize("num_heads", [48, 96]) +@pytest.mark.parametrize("kv_len", [4096, 65536]) +def test_correctness_fused_full(kv_len, num_heads, num_seqs, head_dim): + """ + Tests the correctness of the Iris Fused implementation against the Torch reference. + This test is parameterized to run all combinations of the parameters. + """ + shmem = iris.iris() + + args = Namespace() + args.rank = shmem.get_rank() + args.num_ranks = shmem.get_num_ranks() + args.local_num_ranks = shmem.get_num_ranks() + args.shmem = shmem + + config = { + "kv_len": kv_len, + "num_heads": num_heads, + "num_seqs": num_seqs, + "head_dim": head_dim, + "dtype": torch.float16, + "block_size": 1, + "soft_cap": 0, + } + + # torch.manual_seed(42) + torch.set_default_device("cuda") + + num_query_heads = num_heads + num_kv_heads = num_query_heads // 8 if num_query_heads >= 8 else 1 + scale = head_dim**-0.5 + NUM_BLOCKS_PER_RANK = config["kv_len"] + 1 + NUM_BLOCKS = NUM_BLOCKS_PER_RANK * args.num_ranks + + tensor_data = prepare_correctness_data(config, args, num_query_heads, num_kv_heads, NUM_BLOCKS) + query = tensor_data["query"] + key_value_cache = tensor_data["key_value_cache"] + + key_cache = key_value_cache[:, 0, :, :, :].contiguous() + value_cache = key_value_cache[:, 1, :, :, :].contiguous() + key_cache_this_rank = key_cache[ + args.rank * NUM_BLOCKS_PER_RANK : (args.rank + 1) * NUM_BLOCKS_PER_RANK + ].contiguous() + value_cache_this_rank = value_cache[ + args.rank * NUM_BLOCKS_PER_RANK : (args.rank + 1) * NUM_BLOCKS_PER_RANK + ].contiguous() + + block_tables_this_rank = torch.arange(NUM_BLOCKS_PER_RANK, dtype=torch.int32).repeat(num_seqs, 1) + all_block_tables_numpy = iris._distributed_helpers.distributed_allgather_multidim( + block_tables_this_rank.cpu().numpy() + ) + block_tables = torch.from_numpy(all_block_tables_numpy).view(args.num_ranks, num_seqs, -1) + ref_block_tables = torch.cat([block_tables[i] + i * NUM_BLOCKS_PER_RANK for i in range(args.num_ranks)], dim=-1) + + common_params = { + "num_q_heads": num_query_heads, + "num_kv_heads": num_kv_heads, + "q_head_dim": head_dim, + "v_head_dim": head_dim, + "page_size": config["block_size"], + "scale": scale, + "soft_cap": config["soft_cap"], + "max_allowed_batch": num_seqs, + } + + iris_fd_layer = flash_decode_fused_layer( + args.shmem, + args.rank, + args.rank // args.local_num_ranks, + args.num_ranks, + args.num_ranks // args.local_num_ranks, + **common_params, + ) + + args.shmem.barrier() + if hasattr(iris_fd_layer, "clear_flags"): + iris_fd_layer.clear_flags() + args.shmem.barrier() + + kv_lens_per_rank = [config["kv_len"]] * num_seqs + global_kv_lens = [kv_lens_per_rank[0] * args.num_ranks] * num_seqs + kv_lens_tensor = torch.tensor(kv_lens_per_rank, dtype=torch.int32, device=query.device) + global_kv_lens_tensor = kv_lens_tensor.unsqueeze(0).repeat(args.num_ranks, 1) + + output = iris_fd_layer( + query, key_cache_this_rank, value_cache_this_rank, global_kv_lens_tensor, block_tables_this_rank + ) + torch.cuda.synchronize() + + ref_output = ref_paged_attn( + query=query.clone(), + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens_per_rank=global_kv_lens, + block_tables=ref_block_tables, + scale=scale, + soft_cap=config["soft_cap"], + ) + args.shmem.barrier() + + error = None + try: + atol = 1e-4 + rtol = 1e-4 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + except AssertionError as e: + error = e + + print_correctness_report(args.rank, output, ref_output, error) + + if error: + raise error + 
+ args.shmem.barrier() diff --git a/tests/examples/test_load_bench.py b/tests/examples/test_load_bench.py index 16d6c403..e7c52b56 100644 --- a/tests/examples/test_load_bench.py +++ b/tests/examples/test_load_bench.py @@ -20,6 +20,7 @@ spec.loader.exec_module(module) +@pytest.mark.skip(reason="Test is inconsistent and needs debugging - tracked in issue") @pytest.mark.parametrize( "dtype", [ diff --git a/tests/examples/test_message_passing.py b/tests/examples/test_message_passing.py new file mode 100644 index 00000000..795fee61 --- /dev/null +++ b/tests/examples/test_message_passing.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch +import triton +import triton.language as tl +import numpy as np +import iris + +import importlib.util +from pathlib import Path + +current_dir = Path(__file__).parent + +# Import message_passing_load_store module +load_store_file_path = (current_dir / "../../examples/06_message_passing/message_passing_load_store.py").resolve() +load_store_module_name = "message_passing_load_store" +load_store_spec = importlib.util.spec_from_file_location(load_store_module_name, load_store_file_path) +load_store_module = importlib.util.module_from_spec(load_store_spec) +load_store_spec.loader.exec_module(load_store_module) + +# Import message_passing_put module +put_file_path = (current_dir / "../../examples/06_message_passing/message_passing_put.py").resolve() +put_module_name = "message_passing_put" +put_spec = importlib.util.spec_from_file_location(put_module_name, put_file_path) +put_module = importlib.util.module_from_spec(put_spec) +put_spec.loader.exec_module(put_module) + + +def create_test_args(dtype_str, buffer_size, heap_size, block_size): + """Create args dict that matches what parse_args() returns.""" + return {"datatype": dtype_str, "buffer_size": buffer_size, "heap_size": heap_size, "block_size": block_size} + + +def run_message_passing_kernels(module, args): + """Run the core message passing logic without command line argument parsing.""" + shmem = iris.iris(args["heap_size"]) + dtype = module.torch_dtype_from_str(args["datatype"]) + cur_rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Check that we have exactly 2 ranks as required by message passing examples + if world_size != 2: + pytest.skip("Message passing examples require exactly two processes.") + + # Allocate source and destination buffers on the symmetric heap - match original examples + source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + producer_rank = 0 + consumer_rank = 1 + + n_elements = source_buffer.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + num_blocks = triton.cdiv(n_elements, args["block_size"]) + + # Allocate flags on the symmetric heap + flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + if cur_rank == producer_rank: + # Run producer kernel + module.producer_kernel[grid]( + source_buffer, + destination_buffer, + flags, + n_elements, + producer_rank, + consumer_rank, + args["block_size"], + shmem.get_heap_bases(), + ) + else: + # Run consumer kernel + module.consumer_kernel[grid]( + destination_buffer, flags, 
n_elements, consumer_rank, args["block_size"], shmem.get_heap_bases() + ) + + shmem.barrier() + + # Validation - only consumer rank validates (matches original examples) + success = True + if cur_rank == consumer_rank: + expected = source_buffer * 2 + if not torch.allclose(destination_buffer, expected, atol=1): + success = False + + shmem.barrier() + return success + + +@pytest.mark.parametrize( + "dtype_str", + [ + "int8", + "fp16", + "bf16", + "fp32", + ], +) +@pytest.mark.parametrize( + "buffer_size, heap_size", + [ + (4096, 1 << 20), # Smaller sizes for testing + (8192, 1 << 21), + ], +) +@pytest.mark.parametrize( + "block_size", + [ + 512, + 1024, + ], +) +def test_message_passing_load_store(dtype_str, buffer_size, heap_size, block_size): + """Test message passing with load/store operations.""" + args = create_test_args(dtype_str, buffer_size, heap_size, block_size) + success = run_message_passing_kernels(load_store_module, args) + assert success, "Message passing load/store validation failed" + + +@pytest.mark.parametrize( + "dtype_str", + [ + "int8", + "fp16", + "bf16", + "fp32", + ], +) +@pytest.mark.parametrize( + "buffer_size, heap_size", + [ + (4096, 1 << 20), # Smaller sizes for testing + (8192, 1 << 21), + ], +) +@pytest.mark.parametrize( + "block_size", + [ + 512, + 1024, + ], +) +def test_message_passing_put(dtype_str, buffer_size, heap_size, block_size): + """Test message passing with put operations.""" + args = create_test_args(dtype_str, buffer_size, heap_size, block_size) + success = run_message_passing_kernels(put_module, args) + assert success, "Message passing put validation failed" diff --git a/tests/run_tests_distributed.py b/tests/run_tests_distributed.py index abf32966..e3254556 100755 --- a/tests/run_tests_distributed.py +++ b/tests/run_tests_distributed.py @@ -21,15 +21,22 @@ def _find_free_port(): return s.getsockname()[1] -def _distributed_worker(rank, world_size, test_file, pytest_args): +def _distributed_worker(rank, world_size, test_file, pytest_args, init_method): """Worker function that runs pytest within a distributed process group.""" + # Set the correct GPU for this specific process + # When ROCR_VISIBLE_DEVICES is set, devices are remapped, so rank 0 should use device 0, etc. + import torch + + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + # Initialize distributed once for all tests - init_method = "tcp://127.0.0.1:12355" dist.init_process_group( backend="nccl", init_method=init_method, rank=rank, world_size=world_size, + device_id=torch.device(f"cuda:{rank}"), ) try: @@ -84,11 +91,16 @@ def main(): print(f"Running {test_file} with {num_ranks} ranks") print(f"args={args}, test_file={test_file}, pytest_args={pytest_args}") + # Find a free port for this test run to avoid conflicts with parallel runs + free_port = _find_free_port() + init_method = f"tcp://127.0.0.1:{free_port}" + print(f"Using init_method: {init_method}") + # Run all tests within a single distributed process group try: mp.spawn( _distributed_worker, - args=(num_ranks, test_file, pytest_args), + args=(num_ranks, test_file, pytest_args, init_method), nprocs=num_ranks, join=True, ) diff --git a/tests/unittests/test_broadcast.py b/tests/unittests/test_broadcast.py new file mode 100644 index 00000000..e35b61e1 --- /dev/null +++ b/tests/unittests/test_broadcast.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
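+
+# Unit tests for Iris broadcast: scalar Python objects and torch tensors are
+# broadcast from rank 0 and the result is verified on every rank.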
+ +import torch +import numpy as np +import pytest +import iris + + +@pytest.mark.parametrize( + "value,expected", + [ + (42, 42), + (3.14159, 3.14159), + (True, True), + (False, False), + ("Hello, Iris!", "Hello, Iris!"), + ({"key": "value", "num": 42}, {"key": "value", "num": 42}), + ], +) +def test_broadcast_scalar(value, expected): + """Test broadcasting scalar values (int, float, bool, string, dict).""" + shmem = iris.iris(1 << 20) + rank = shmem.get_rank() + + val = value if rank == 0 else None + result = shmem.broadcast(val, source_rank=0) + + if isinstance(expected, float): + assert abs(result - expected) < 1e-6 + else: + assert result == expected + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float32, + torch.float16, + torch.int32, + torch.int64, + ], +) +def test_broadcast_tensor_dtype(dtype): + """Test broadcasting tensors with different dtypes.""" + shmem = iris.iris(1 << 20) + rank = shmem.get_rank() + + value = torch.arange(10, dtype=dtype) if rank == 0 else None + result = shmem.broadcast(value, source_rank=0) + + assert isinstance(result, np.ndarray) + np.testing.assert_array_equal(result, np.arange(10)) + + +@pytest.mark.parametrize( + "shape", + [ + (10,), + (10, 20), + (5, 10, 15), + ], +) +def test_broadcast_tensor_shape(shape): + """Test broadcasting tensors with different shapes.""" + shmem = iris.iris(1 << 25) + rank = shmem.get_rank() + + value = torch.randn(shape) if rank == 0 else None + result = shmem.broadcast(value, source_rank=0) + + assert isinstance(result, np.ndarray) + assert result.shape == shape diff --git a/tests/unittests/test_copy.py b/tests/unittests/test_copy.py new file mode 100644 index 00000000..3f4b8ec0 --- /dev/null +++ b/tests/unittests/test_copy.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
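+
+# Unit tests for iris.copy covering three addressing modes: GET (the current
+# rank pulls from a remote rank), PUT (the current rank pushes to a remote
+# rank), and LOCAL (source and destination both live on the current rank).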
+ +import torch +import triton +import triton.language as tl +import pytest +import iris + + +@triton.jit +def copy_get_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + """GET: cur_rank == to_rank (pull from remote)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * target_rank + iris.copy(src_data + offsets, dest_data + offsets, target_rank, cur_rank, cur_rank, heap_bases, mask) + + +@triton.jit +def copy_put_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + """PUT: cur_rank == from_rank (push to remote)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask) + + +@triton.jit +def copy_local_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + """LOCAL: from_rank == to_rank == cur_rank""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + for i in range(num_ranks): + src_data = data + BLOCK_SIZE * i + dest_data = results + BLOCK_SIZE * i + iris.copy(src_data + offsets, dest_data + offsets, cur_rank, cur_rank, cur_rank, heap_bases, mask) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.int8, + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "BLOCK_SIZE", + [ + 1, + 8, + 16, + 32, + ], +) +def test_copy_get(dtype, BLOCK_SIZE): + """Test GET operation: cur_rank == to_rank""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype) + grid = lambda meta: (1,) + copy_get_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases) + shmem.barrier() + + expected = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype) + for rank_id in range(num_ranks): + expected[rank_id, :] = (rank_id + num_ranks) * (cur_rank + 1) + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + + +@pytest.mark.parametrize( + "dtype", + [ + torch.int8, + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "BLOCK_SIZE", + [ + 1, + 8, + 16, + 32, + ], +) +def test_copy_put(dtype, BLOCK_SIZE): + """Test PUT operation: cur_rank == from_rank""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, 
BLOCK_SIZE), dtype=dtype) + grid = lambda meta: (1,) + copy_put_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases) + shmem.barrier() + + # Each rank writes to results[cur_rank] on all targets + # After barrier, results[rank_id] contains data from rank_id + expected = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype) + for rank_id in range(num_ranks): + expected[rank_id, :] = (rank_id + num_ranks) * (rank_id + 1) + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + + +@pytest.mark.parametrize( + "dtype", + [ + torch.int8, + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "BLOCK_SIZE", + [ + 1, + 8, + 16, + 32, + ], +) +def test_copy_local(dtype, BLOCK_SIZE): + """Test LOCAL operation: from_rank == to_rank == cur_rank""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype) + grid = lambda meta: (1,) + copy_local_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases) + shmem.barrier() + + # Local copy: results should match data + expected = data + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_copy_cache_modifiers.py b/tests/unittests/test_copy_cache_modifiers.py new file mode 100644 index 00000000..23437bb6 --- /dev/null +++ b/tests/unittests/test_copy_cache_modifiers.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
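+
+# Unit tests for iris.copy with explicit load/store cache modifiers; every
+# combination of the modifier lists below is exercised, where None means the
+# corresponding keyword argument is omitted from the call.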
+ +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def copy_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Test copy with cache modifiers - copy from current rank to other ranks + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * target_rank + if load_cache_modifier is None and store_cache_modifier is None: + iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_copy_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test copy operation with various cache modifiers""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + grid = lambda meta: (1,) + copy_kernel[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # Verify results - each rank should have copied its data to all ranks + for i in range(num_ranks): + expected_value = base * (cur_rank + 1) + assert torch.allclose(results[i], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32)), ( + f"Mismatch at rank {cur_rank}, target {i} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) diff --git a/tests/unittests/test_empty.py b/tests/unittests/test_empty.py index 573f5467..e51fb4c2 100644 --- a/tests/unittests/test_empty.py +++ b/tests/unittests/test_empty.py @@ -104,7 +104,7 @@ def test_empty_device_handling(): shmem.empty(3, 3, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError): diff --git 
a/tests/unittests/test_full.py b/tests/unittests/test_full.py index 9301d125..a42d4ddb 100644 --- a/tests/unittests/test_full.py +++ b/tests/unittests/test_full.py @@ -122,7 +122,7 @@ def test_full_device_handling(): shmem.full((3, 3), 2.5, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError): diff --git a/tests/unittests/test_iris_helpers.py b/tests/unittests/test_iris_helpers.py index 145b67fd..91cbf7b6 100644 --- a/tests/unittests/test_iris_helpers.py +++ b/tests/unittests/test_iris_helpers.py @@ -35,7 +35,7 @@ def test_device_validation(): assert not shmem._Iris__is_valid_device("mps") # MPS is always invalid # Test that different CUDA device indices are rejected - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU assert not shmem._Iris__is_valid_device(different_cuda) diff --git a/tests/unittests/test_linspace.py b/tests/unittests/test_linspace.py index bf09fb93..02d26b24 100644 --- a/tests/unittests/test_linspace.py +++ b/tests/unittests/test_linspace.py @@ -103,7 +103,7 @@ def test_linspace_device_handling(): shmem.linspace(0.0, 1.0, 5, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError): diff --git a/tests/unittests/test_ones.py b/tests/unittests/test_ones.py index b17fce21..e70c63f8 100644 --- a/tests/unittests/test_ones.py +++ b/tests/unittests/test_ones.py @@ -111,7 +111,7 @@ def test_ones_device_handling(): shmem.ones(3, 3, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError): diff --git a/tests/unittests/test_rand.py b/tests/unittests/test_rand.py index dc31aa18..75b6968b 100644 --- a/tests/unittests/test_rand.py +++ b/tests/unittests/test_rand.py @@ -101,7 +101,7 @@ def test_rand_device_handling(): shmem.rand(3, 3, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError): diff --git a/tests/unittests/test_randint.py b/tests/unittests/test_randint.py index 06c6f18b..a636be38 100644 --- a/tests/unittests/test_randint.py +++ b/tests/unittests/test_randint.py @@ -102,7 +102,7 @@ def test_randint_device_handling(): shmem.randint(0, 10, (3, 3), device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if 
shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError): diff --git a/tests/unittests/test_randn.py b/tests/unittests/test_randn.py index 692d90b3..90eee1c7 100644 --- a/tests/unittests/test_randn.py +++ b/tests/unittests/test_randn.py @@ -98,7 +98,7 @@ def test_randn_device_handling(): shmem.randn(3, 3, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) num_devices = torch.cuda.device_count() different_cuda = f"cuda:{(current_device.index + 1) % num_devices}" # Use next GPU diff --git a/tests/unittests/test_zeros.py b/tests/unittests/test_zeros.py index a1a96e73..51126fed 100644 --- a/tests/unittests/test_zeros.py +++ b/tests/unittests/test_zeros.py @@ -111,7 +111,7 @@ def test_zeros_device_handling(): shmem.zeros(3, 3, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError): diff --git a/tests/unittests/test_zeros_like.py b/tests/unittests/test_zeros_like.py index 3a3b0a42..b7a0ff0c 100644 --- a/tests/unittests/test_zeros_like.py +++ b/tests/unittests/test_zeros_like.py @@ -127,7 +127,7 @@ def test_zeros_like_device_override(): shmem.zeros_like(input_tensor, device=different_device) # Test that different CUDA device throws error - if shmem.device.startswith("cuda:"): + if shmem.device.startswith("cuda:") and torch.cuda.device_count() >= 2: current_device = torch.device(shmem.device) different_cuda = f"cuda:{(current_device.index + 1) % torch.cuda.device_count()}" # Use next GPU with pytest.raises(RuntimeError):