# NVSHMEM4Py Device APIs

## Introduction

This notebook documents the NVSHMEM4Py device APIs for use in Numba-CUDA kernels. It continues the `Nvshmem4py` notebook.

## Environment

NVSHMEM4Py's Numba-CUDA device APIs require the additional dependency `numba-cuda`, built for the CUDA version installed on your machine. Assuming you are using CUDA 12:

```bash
pip install mpi4py nvshmem4py-cu12 cupy-cuda12x "numba-cuda[cu12]==0.20.1" cuda-core
```
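The CUDA tags in these package names have to agree with each other and with your CUDA installation. As an optional sanity check, the helper below (a hypothetical convenience written for this notebook, not part of NVSHMEM4Py) scans installed distributions with `importlib.metadata` and reports the CUDA major version each tagged package targets. The `-cuNN` and `cupy-cudaNNx` naming patterns it matches are an assumption based on the package names used here.

```python
import re
from importlib import metadata

# Assumed naming convention: CUDA-specific wheels end in "-cuNN"
# (e.g. nvshmem4py-cu12) or follow CuPy's "cupy-cudaNNx" pattern.
_CUDA_TAG = re.compile(r"(?:-cu|^cupy-cuda)(\d+)x?$")


def cuda_majors(names):
    """Return {package_name: cuda_major} for names that carry a CUDA tag."""
    majors = {}
    for name in names:
        match = _CUDA_TAG.search(name.lower())
        if match:
            majors[name] = int(match.group(1))
    return majors


def installed_cuda_majors():
    """Apply cuda_majors() to every distribution installed in this environment."""
    names = []
    for dist in metadata.distributions():
        meta = dist.metadata
        name = meta.get("Name") if meta is not None else None
        if name:
            names.append(name)
    return cuda_majors(names)


majors = installed_cuda_majors()
if len(set(majors.values())) > 1:
    print(f"Warning: mixed CUDA major versions among installed packages: {majors}")
```

A warning here, for example a `cupy-cuda13x` wheel next to `nvshmem4py-cu12`, usually means one package should be reinstalled for the matching CUDA major version.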

## NVSHMEM4Py Numba-CUDA Device API Overview

Device APIs let developers issue GPU-initiated, one-sided operations from `@numba.cuda.jit` kernels. They are intended for users who need fine-grained, low-latency inter-GPU communication driven entirely from device code. These APIs are available in the `nvshmem.core.device.numba` namespace.

### Features

- Querying
- Remote Memory Access (RMA)
- Signal Operations
- Atomic Memory Operations (AMO)
- Collectives
- Synchronization
- Memory Mapping (direct device loads/stores)

### Pythonic Interface

Unlike the C/C++ APIs, the Numba device APIs are Pythonic: they accept `numba.types.Array` arguments. Certain APIs (such as RMA) omit the element-count argument, because it is deduced from the size of the input array. To transfer fewer elements, slice the input array to create a view and pass the view instead. The APIs are also data-type aware, so the only requirement is that the operand arrays share the same data type.
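The two rules above (size comes from the view, dtypes must match) can be pictured without a GPU. The sketch below is a host-side model only: `simulated_put` is a hypothetical stand-in, not an NVSHMEM4Py function, and Python `memoryview` slices play the role of Numba array views. On the device, the ring-allreduce kernel later in this notebook slices `src[start:end]` in the same way before passing it to `put_signal_nbi`.

```python
from array import array


def simulated_put(dst, src):
    """Hypothetical host-side stand-in for a size-deducing RMA put.

    The element count is taken from the source view, and both views
    must have the same element type.
    """
    if dst.format != src.format:
        raise TypeError(f"dtype mismatch: {dst.format!r} vs {src.format!r}")
    nelems = len(src)  # deduced from the view, never passed explicitly
    if len(dst) < nelems:
        raise ValueError("destination view is smaller than the source view")
    dst[:nelems] = src  # writes through the view into the underlying buffer
    return nelems


peer_buffer = array("i", [0] * 8)  # stands in for the peer's symmetric array
local = array("i", range(1, 9))    # 1, 2, ..., 8

# To move only 3 elements, slice views instead of passing a count.
n = simulated_put(memoryview(peer_buffer)[:3], memoryview(local)[:3])
print(n, peer_buffer.tolist())  # 3 [1, 2, 3, 0, 0, 0, 0, 0]
```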

### Thread Scope Variants

Most of these APIs also come in `_warp` and `_block` variants, which perform the operation at warp or block (CTA) scope instead of per thread. For example, the `put` API has `put_warp` and `put_block` variants. All threads in the designated scope must call the variant, and they must pass the same arguments. Common uses include:

- Using `put_block` instead of `put` lets all threads of a block copy data in parallel when the two PEs are connected by a peer-to-peer (P2P) link. When they are connected only through a remote (network) transport, a single GPU thread issues the transfer.
- The `_block` and `_warp` variants of collectives let a whole block or warp cooperate on a reduction across PEs.
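The way a block-scoped copy can spread work across threads is sketched below as a pure-Python model. It is an illustration only, under the assumption of a simple strided split: `simulated_put_block` and `strided_indices` are hypothetical names, and NVSHMEM's actual `put_block` internals are not described in this notebook. The strided pattern is the same `range(thread_id, n, num_threads)` loop the ring-allreduce kernel uses for its local addition.

```python
def strided_indices(thread_id, nelems, num_threads):
    """Indices one simulated thread handles in a block-strided loop."""
    return range(thread_id, nelems, num_threads)


def simulated_put_block(dst, src, num_threads):
    """Hypothetical model of a block-scoped copy over a P2P link.

    Every simulated thread receives the same (dst, src) arguments, as the
    block-scope rule requires, and copies a disjoint strided subset.
    Returns how many elements each thread copied.
    """
    per_thread = []
    for tid in range(num_threads):  # on a GPU these threads run concurrently
        count = 0
        for i in strided_indices(tid, len(src), num_threads):
            dst[i] = src[i]
            count += 1
        per_thread.append(count)
    return per_thread


src = list(range(10))
dst = [0] * 10
work = simulated_put_block(dst, src, num_threads=4)
print(work)  # [3, 3, 2, 2]
```

Because the index sets are disjoint and together cover every element, no two threads write the same location and no synchronization is needed between them during the copy.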

## Example: Ring-Allreduce in Python

The ring-allreduce example performs an allreduce using a ring algorithm in two phases: a reduction phase and a broadcast phase. It demonstrates device-side APIs such as `put_signal_nbi` and `signal_wait`.
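Before looking at the kernel, it can help to trace the data flow on the host. The function below is a sequential, pure-Python model written for this notebook (not part of NVSHMEM4Py) for a single CTA. It serializes the PEs' work into one valid order, models each `put_signal_nbi` as "copy into the peer's `dst`, then add 1 to its signal", and turns each `signal_wait` into an assertion on the value the kernel waits for.

```python
def ring_allreduce_model(npes, nelems, num_chunks):
    """Sequential, single-CTA model of the two-phase ring allreduce.

    PE p contributes p + 1 to every element.
    """
    assert npes >= 2 and nelems % num_chunks == 0
    size = nelems // num_chunks
    src = [[p + 1] * nelems for p in range(npes)]
    dst = [[0] * nelems for _ in range(npes)]
    signal = [0] * npes

    def put_signal(peer, start, values):
        dst[peer][start:start + len(values)] = values  # data first...
        signal[peer] += 1                              # ...then SIGNAL_ADD 1

    # Reduction phase: partial sums travel PE0 -> PE1 -> ... -> PE(npes-1) -> PE0
    for c in range(num_chunks):
        s, e = c * size, (c + 1) * size
        put_signal(1, s, src[0][s:e])  # PE0 injects its own src
        for p in range(1, npes):
            assert signal[p] >= c + 1  # signal_wait(CMP_GE, c + 1)
            for i in range(s, e):
                dst[p][i] += src[p][i]
            put_signal((p + 1) % npes, s, dst[p][s:e])

    # Broadcast phase: PE0 -> PE1 -> ... -> PE(npes-2); the last PE is already done
    for c in range(num_chunks):
        s, e = c * size, (c + 1) * size
        for p in range(npes - 1):
            expected = c + 1 if p == 0 else num_chunks + c + 1
            assert signal[p] >= expected  # signal_wait(CMP_GE, expected)
            if p < npes - 2:
                put_signal(p + 1, s, dst[p][s:e])
    return dst, signal


dst_model, sig_model = ring_allreduce_model(npes=4, nelems=8, num_chunks=4)
print(dst_model[0], sig_model)  # [10, 10, 10, 10, 10, 10, 10, 10] [4, 8, 8, 4]
```

On real hardware the PEs run concurrently, so the interleaving differs, but each wait in the kernel guards exactly the dependency that the assertions check here.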

### Problem Setup

Each PE has a local `src` array whose elements are initialized to `my_pe() + 1`, so every PE's contribution is distinct. It also has a `dst` array of the same size as `src` that receives the reduced data, and a `uint64` `signal` array, with one entry per CTA, that holds the signals sent by the previous PE in the ring.

The following image shows the initial setup of elements on 4 PEs. Signals and chunking are omitted for simplicity.

![Initial setup on 4 PEs](assets/1.png)
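Because PE $p$ contributes `my_pe() + 1` $= p + 1$ to every element, the value every PE should hold after the allreduce follows from an arithmetic series. This is the value the host-side check at the end of the full example compares against (`sum(range(1, npes + 1))`), with $N$ = `npes`:

$$
\sum_{p=0}^{N-1} (p + 1) = \frac{N(N+1)}{2}, \qquad N = 4 \;\Rightarrow\; 1 + 2 + 3 + 4 = 10.
$$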


```python
from mpi4py import MPI
from cuda.core.experimental import Device, system, Stream

from numba import cuda, uint64

import nvshmem
import nvshmem.bindings
from nvshmem.core import SignalOp, ComparisonType
from nvshmem.core.device.numba import put_signal_nbi, signal_wait, my_pe, n_pes

@cuda.jit
def ring_reduce(dst, src, nreduce, signal, chunk_size):
    # Numba-CUDA constructs to set up per-PE and per-thread variables
    mype = my_pe()
    npes = n_pes()
    peer = (mype + 1) % npes  # next PE in the ring

    thread_id = cuda.threadIdx.x
    num_threads = cuda.blockDim.x
    num_blocks = cuda.gridDim.x
    block_idx = cuda.blockIdx.x
    elems_per_block = nreduce // num_blocks

    if elems_per_block * (block_idx + 1) > nreduce:
        return
```

### Reduction Phase

PE0 starts by sending its local `src` data to PE1's `dst`. The same `put_signal_nbi` call then increments PE1's signal by 1, after the data has been delivered.

![Reduction-1](images/chapter-nvshmem4py-device/2.png)

Meanwhile, PE1 has been waiting on its signal. Once the signal is updated, PE0's data has arrived, so PE1 performs its local compute: it adds its own `src` to the received values in `dst`.

![Reduction-2](images/chapter-nvshmem4py-device/3.png)
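The PE0/PE1 handshake is a put-with-signal followed by a signal wait. The pure-Python model below mimics it with a thread and a condition variable; `SignalWord`, `put_signal`, and `pe1` are hypothetical names for this illustration and are not NVSHMEM4Py APIs. The ordering it encodes (payload first, then signal) reflects the put-with-signal guarantee that a PE observing the updated signal can safely read the delivered data.

```python
import threading


class SignalWord:
    """Hypothetical host-side model of one signal slot: SIGNAL_ADD + CMP_GE wait."""

    def __init__(self):
        self.value = 0
        self._cond = threading.Condition()

    def signal_add(self, delta):
        with self._cond:
            self.value += delta
            self._cond.notify_all()

    def wait_ge(self, target):
        # Block until value >= target, like signal_wait(..., CMP_GE, target)
        with self._cond:
            self._cond.wait_for(lambda: self.value >= target)
            return self.value


def put_signal(dst, src, sig):
    """Copy the payload first, then bump the signal."""
    dst[:] = src
    sig.signal_add(1)


pe1_dst = [0, 0, 0, 0]  # stands in for PE1's slice of dst
pe1_signal = SignalWord()
received = []


def pe1():
    pe1_signal.wait_ge(1)                    # wait for PE0's first chunk
    received.extend(x + 2 for x in pe1_dst)  # add PE1's own src value (2)


waiter = threading.Thread(target=pe1)
waiter.start()
put_signal(pe1_dst, [1, 1, 1, 1], pe1_signal)  # PE0 sends src = my_pe() + 1 = 1
waiter.join()
print(received)  # [3, 3, 3, 3]
```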

When the compute finishes, PE1 forwards the partial sum to the next PE, which waits on its own signal and then performs the same local compute. This repeats until the last PE.

![Reduction-3](images/chapter-nvshmem4py-device/4.png)

When the last PE finishes its compute, it sends the result to PE0. This time the data is already the final reduced result. PE0 waits on its signal for this result, which starts the broadcast phase.

![Reduction-4](images/chapter-nvshmem4py-device/5.png)

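Each Cooperative Thread Array (CTA, i.e. thread block) owns a contiguous range of `elems_per_block` elements and processes it one chunk of `chunk_size` elements per loop iteration. Every CTA has its own signal slot, `signal[block_idx]`, so CTAs progress independently of each other. Within a CTA, chunks are pipelined through a single counter: in the reduction phase, chunk `c` is ready once the signal reaches `c + 1`.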
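The chunk boundaries come from two `range` objects zipped together. The host-side sketch below reuses the kernel's exact range arithmetic to list each CTA's `(start, end)` pairs; `block_chunks` and `check_partition` are hypothetical helpers written for this notebook, not library functions.

```python
def block_chunks(nreduce, num_blocks, chunk_size, block_idx):
    """(start, end) pairs one CTA visits, using the same ranges as the kernel."""
    elems_per_block = nreduce // num_blocks
    init_offset = block_idx * elems_per_block
    starts = range(init_offset, init_offset + elems_per_block, chunk_size)
    ends = range(init_offset + chunk_size,
                 init_offset + elems_per_block + chunk_size, chunk_size)
    return list(zip(starts, ends))


def check_partition(nreduce, num_blocks, chunk_size):
    """True if the CTAs' chunks cover [0, nreduce) exactly once."""
    covered = [0] * nreduce
    for b in range(num_blocks):
        for start, end in block_chunks(nreduce, num_blocks, chunk_size, b):
            for i in range(start, min(end, nreduce)):
                covered[i] += 1
    return all(c == 1 for c in covered)


print(block_chunks(1024, 32, 8, 0))  # [(0, 8), (8, 16), (16, 24), (24, 32)]
print(check_partition(1024, 32, 8))  # True
print(check_partition(1000, 32, 8))  # False
```

This suggests choosing `nreduce` divisible by `num_blocks` and `elems_per_block` divisible by `chunk_size`, as the full example does (1024 / 32 = 32 elements per block, 32 / 4 = 8 elements per chunk). With 1000 elements, each block gets 31, its last chunk spills into the next block, the trailing 8 elements are never visited, and the loop runs 4 times while `num_chunks` is 3, so the signal values awaited in the broadcast phase would no longer match.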

```python
    init_offset = block_idx * elems_per_block
    signal_block = signal[block_idx:block_idx+1]
    num_chunks = elems_per_block // chunk_size

    starts = range(init_offset, init_offset+elems_per_block, chunk_size)
    ends = range(init_offset+chunk_size, init_offset+elems_per_block+chunk_size, chunk_size)
    # Reduce phase
    for chunk, (start, end) in enumerate(zip(starts, ends)):
        src_block = src[start:end]
        dst_block = dst[start:end]
        if mype != 0:
            if thread_id == 0:
                signal_wait(signal_block, ComparisonType.CMP_GE, chunk + 1)

            cuda.syncthreads()
            for i in range(thread_id, chunk_size, num_threads):
                dst_block[i] = dst_block[i] + src_block[i]
            cuda.syncthreads()

        if thread_id == 0:
            src_data = src_block if mype == 0 else dst_block
            put_signal_nbi(dst_block, src_data,
                           signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)
```

### Broadcast Phase

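The broadcast phase is kicked off by the last PE's `put` to PE0 at the end of the reduction phase. Once PE0 holds the final result for a chunk, a chain of `put`s forwards it in PE order, PE0 → PE1 → … → PE(npes-2). The last PE is skipped because it already holds the final result. Afterward, all PEs possess the final reduced result.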

![Broadcast-1](images/chapter-nvshmem4py-device/6.png)
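The wait values in the broadcast code follow from counting signal increments. Let $C$ be `num_chunks` and $N$ be `npes`. In the reduction phase, each PE receives exactly one `SIGNAL_ADD` of 1 per chunk from its predecessor, so that phase contributes exactly $C$ increments to every PE's signal; for PE0 those increments come from PE $N-1$ and already carry the final sums. Chunk $c$ ($0 \le c < C$) of the broadcast is therefore ready on PE $p$ when its signal reaches

$$
\text{expected}(p, c) =
\begin{cases}
c + 1, & p = 0,\\
C + c + 1, & 1 \le p \le N - 2,
\end{cases}
$$

and the signals end at $C$ on PE0 and PE $N-1$, and at $2C$ on every other PE. With the full example's 4 chunks per block, PEs $1, \dots, N-2$ wait for 5, 6, 7, and 8.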

```python
    # Broadcast phase
    if thread_id == 0:
        for chunk, (start, end) in enumerate(zip(starts, ends)):
            dst_block = dst[start:end]
            if mype < npes - 1:  # Last pe already has the final result
                expected_val = (chunk + 1) if mype == 0 else (num_chunks + chunk + 1)
                signal_wait(signal_block, ComparisonType.CMP_GE, expected_val)

            if mype < npes - 2:
                put_signal_nbi(dst_block, dst_block,
                               signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)
```

## Full Code Example

```python
from mpi4py import MPI
from cuda.core.experimental import Device, system, Stream

from numba import cuda, uint64

import nvshmem
import nvshmem.bindings
from nvshmem.core import SignalOp, ComparisonType
from nvshmem.core.device.numba import put_signal_nbi, signal_wait, my_pe, n_pes

@cuda.jit
def ring_reduce(dst, src, nreduce, signal, chunk_size):
    mype = my_pe()
    npes = n_pes()
    peer = (mype + 1) % npes

    thread_id = cuda.threadIdx.x
    num_threads = cuda.blockDim.x
    num_blocks = cuda.gridDim.x
    block_idx = cuda.blockIdx.x
    elems_per_block = nreduce // num_blocks

    # Each CTA processes its own contiguous range of elems_per_block elements,
    # independently of the other CTAs. A block whose range would extend past
    # nreduce has no work and exits early.
    if elems_per_block * (block_idx + 1) > nreduce:
        return

    # Offset of this block's range, and this block's private signal slot
    init_offset = block_idx * elems_per_block
    signal_block = signal[block_idx:block_idx+1]

    num_chunks = elems_per_block // chunk_size

    starts = range(init_offset, init_offset+elems_per_block, chunk_size)
    ends = range(init_offset+chunk_size, init_offset+elems_per_block+chunk_size, chunk_size)
    # Reduce phase
    for chunk, (start, end) in enumerate(zip(starts, ends)):
        src_block = src[start:end]
        dst_block = dst[start:end]
        if mype != 0:
            # Wait until the previous PE has delivered this chunk
            if thread_id == 0:
                signal_wait(signal_block, ComparisonType.CMP_GE, chunk + 1)

            cuda.syncthreads()
            for i in range(thread_id, chunk_size, num_threads):
                dst_block[i] = dst_block[i] + src_block[i]
            cuda.syncthreads()

        # Forward the (partial) sum to the next PE and increment its signal
        if thread_id == 0:
            src_data = src_block if mype == 0 else dst_block
            put_signal_nbi(dst_block, src_data,
                           signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)

    # After the reduce phase, every PE except PE0 has waited until its signal
    # reached num_chunks. PE0 waits on its own signal in the broadcast phase.

    # Broadcast phase
    if thread_id == 0:
        for chunk, (start, end) in enumerate(zip(starts, ends)):
            dst_block = dst[start:end]
            if mype < npes - 1:  # Last PE already has the final result
                expected_val = (chunk + 1) if mype == 0 else (num_chunks + chunk + 1)
                signal_wait(signal_block, ComparisonType.CMP_GE, expected_val)

            if mype < npes - 2:  # PE npes-2 does not forward to the last PE
                put_signal_nbi(dst_block, dst_block,
                               signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)


# Initialize MPI and NVSHMEM: one GPU per MPI rank on each node
local_rank_per_node = MPI.COMM_WORLD.Get_rank() % system.num_devices
dev = Device(local_rank_per_node)
dev.set_current()

# Workaround: Numba-CUDA kernel launches take a Numba stream object (or an int
# handle), so create the stream with Numba and wrap the same handle for cuda.core
nb_stream = cuda.stream()
cu_stream_ref = Stream.from_handle(nb_stream.handle.value)

nvshmem.core.init(
    device=dev,
    uid=None,
    rank=None,
    nranks=None,
    mpi_comm=MPI.COMM_WORLD,
    initializer_method="mpi",
)

mype = nvshmem.bindings.my_pe()
npes = nvshmem.bindings.n_pes()

# Test parameters
nreduce = 1024

num_blocks = 32
elems_per_block = nreduce // num_blocks
num_chunk_per_block = 4
chunk_size = elems_per_block // num_chunk_per_block

threads_per_block = 512

# Allocate symmetric arrays
src = nvshmem.core.array((nreduce,), dtype="int32")
dst = nvshmem.core.array((nreduce,), dtype="int32")
signal = nvshmem.core.array((num_blocks,), dtype="uint64")

# Initialize data and signals
src[:] = mype + 1
dst[:] = 0
signal[:] = 0

# Make sure every PE has finished initializing before any PE starts writing
# into its peer's dst and signal arrays; otherwise a late reset could
# overwrite data or signals that have already arrived
dev.sync()
nvshmem.core.barrier(nvshmem.core.Teams.TEAM_WORLD, stream=cu_stream_ref)

# Launch kernel
ring_reduce[num_blocks, threads_per_block, nb_stream, 0](dst, src, nreduce, signal, chunk_size)

nvshmem.core.barrier(nvshmem.core.Teams.TEAM_WORLD, stream=cu_stream_ref)
dev.sync()

# Check results
expected_result = sum(range(1, npes + 1))
for i in range(nreduce):
    assert dst[i] == expected_result, f"PE {mype}: Mismatch at index {i}: got {dst[i]}, expected {expected_result}"
print(f"PE {mype}: Ring allreduce test passed")

# Clean up
nvshmem.core.free_array(src)
nvshmem.core.free_array(dst)
nvshmem.core.free_array(signal)
nvshmem.core.finalize()
```
