17 changes: 8 additions & 9 deletions .pre-commit-config.yaml
@@ -4,7 +4,7 @@

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-added-large-files
- id: check-ast
@@ -25,28 +25,28 @@ repos:
exclude_types: [jupyter]

- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.24
rev: v0.24.1
hooks:
- id: validate-pyproject
additional_dependencies: ["validate-pyproject-schema-store[all]"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.2
rev: v0.13.0
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

# Mypy: static type checking
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.15.0"
rev: "v1.18.1"
hooks:
- id: mypy
# Enforce only one source of configuration.
args: ["--config-file", "pyproject.toml"]
additional_dependencies:
- cuda-core
- cuda-bindings>=12.9.1,<13
- cuda-bindings>=12.9.2,<13
- cupy-cuda12x
- mpi4py>=4.1.0
- numba
@@ -56,7 +56,6 @@ repos:
- scipy
- torch
- types-cffi
- types-pywin32
- invoke
- cython>=3.0.4,!=3.1.0,!=3.1.1
- tomli
@@ -78,7 +77,7 @@ repos:

# Security: secrets
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.0
rev: v8.28.0
hooks:
- id: gitleaks

@@ -91,13 +90,13 @@ repos:

# Shell script linter
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: "v0.10.0.1"
rev: "v0.11.0.1"
hooks:
- id: shellcheck

# Lint: Markdown
- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.44.0
rev: v0.45.0
hooks:
- id: markdownlint
# Setting up node version explicitly
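(For reference, pre-commit hook revisions like the ones bumped above are typically refreshed with `pre-commit autoupdate` and then checked across the repository with `pre-commit run --all-files`.)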
183 changes: 130 additions & 53 deletions README.md
@@ -6,17 +6,24 @@

# nvmath-python: NVIDIA Math Libraries for the Python Ecosystem

nvmath-python brings the power of the NVIDIA math libraries to the Python ecosystem. The
package aims to provide intuitive pythonic APIs that provide users full access to all the
nvmath-python brings the power of the NVIDIA math libraries to the Python ecosystem.
The package aims to provide intuitive pythonic APIs giving users full access to all
features offered by NVIDIA's libraries in a variety of execution spaces. nvmath-python works
seamlessly with existing Python array/tensor frameworks and focuses on providing
functionality that is missing from those frameworks.

## Some Examples

Using the nvmath-python API allows access to all parameters of the underlying NVIDIA
cuBLASLt library. Some of these parameters are unavailable in other wrappings of NVIDIA's
C-API libraries.
Below are a few representative examples showcasing the three main categories of
features nvmath-python offers: host, device, and distributed APIs.

### Host APIs

Host APIs are called from host code but can execute in any supported execution
space (CPU or GPU). The following example shows how to compute a matrix multiplication
on CuPy matrices. Using the nvmath-python API allows access to *all* parameters
of the underlying NVIDIA cuBLASLt library, a feature that distinguishes nvmath-python
from other wrappings of NVIDIA's C-API libraries.

```python
import cupy as cp
@@ -61,6 +68,42 @@ print(f"Input types = {type(a), type(b)}, device = {a.device, b.device}")
print(f"Result type = {type(result)}, device = {result.device}")
```

nvmath-python provides the ability to write custom prologs and epilogs for FFT functions as
Python functions and compile them to LTO-IR. For example, to have unitary scaling for an
FFT, we can define an epilog which rescales the output by `1/sqrt(N)`.

```python
import cupy as cp
import nvmath
import math

# Create the data for the batched 1-D FFT.
B, N = 256, 1024
a = cp.random.rand(B, N, dtype=cp.float64) + 1j * cp.random.rand(B, N, dtype=cp.float64)

# Compute the normalization factor for unitary transforms
norm_factor = 1.0 / math.sqrt(N)

# Define the epilog function for the FFT.
def rescale(data_out, offset, data, user_info, unused):
    data_out[offset] = data * norm_factor

# Compile the epilog to LTO-IR.
with cp.cuda.Device():
    epilog = nvmath.fft.compile_epilog(rescale, "complex128", "complex128")

# Perform the forward FFT, applying the rescaling as an epilog...
r = nvmath.fft.fft(a, axes=[-1], epilog={"ltoir": epilog})

# Finally, we can verify that the result of the fused FFT matches a reference
# computed with CuPy's normalized ("ortho") FFT.
s = cp.fft.fftn(a, axes=[-1], norm="ortho")

assert cp.allclose(r, s)
```
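A custom prolog can be handled analogously, compiled with `nvmath.fft.compile_prolog` and passed via the `prolog` argument, though that is not shown in this example.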

### Device-side APIs

nvmath-python exposes NVIDIA's device-side (Dx) APIs. This allows developers to call NVIDIA
library functions inside their custom device kernels. For example, a numba jit function can
call cuFFT in order to implement FFT-based convolution.
@@ -91,7 +134,6 @@ def main():
        ffts_per_block=ffts_per_block,
        elements_per_thread=2,
        execution="Block",
        compiler="numba",
    )
    FFT_inv = fft(
        fft_type="c2c",
@@ -101,41 +143,35 @@
        ffts_per_block=ffts_per_block,
        elements_per_thread=2,
        execution="Block",
        compiler="numba",
    )

    value_type = FFT_fwd.value_type
    storage_size = FFT_fwd.storage_size
    shared_memory_size = FFT_fwd.shared_memory_size
    fft_stride = FFT_fwd.stride
    ept = FFT_fwd.elements_per_thread
    block_dim = FFT_fwd.block_dim

    # Define a numba jit function targeting CUDA devices
    @cuda.jit(link=FFT_fwd.files + FFT_inv.files)
    @cuda.jit
    def f(signal, filter):

        thread_data = cuda.local.array(shape=(storage_size,), dtype=value_type)
        shared_mem = cuda.shared.array(shape=(0,), dtype=value_type)
        thread_data = cuda.local.array(
            shape=(FFT_fwd.storage_size,), dtype=FFT_fwd.value_type,
        )
        shared_mem = cuda.shared.array(shape=(0,), dtype=FFT_fwd.value_type)

        fft_id = (cuda.blockIdx.x * ffts_per_block) + cuda.threadIdx.y
        if(fft_id >= batch_size):
            return
        offset = cuda.threadIdx.x

        for i in range(ept):
            thread_data[i] = signal[fft_id, offset + i * fft_stride]
        for i in range(FFT_fwd.elements_per_thread):
            thread_data[i] = signal[fft_id, offset + i * FFT_fwd.stride]

        # Call the cuFFTDx FFT function from *inside* your custom function
        FFT_fwd(thread_data, shared_mem)

        for i in range(ept):
            thread_data[i] = thread_data[i] * filter[fft_id, offset + i * fft_stride]
        for i in range(FFT_fwd.elements_per_thread):
            thread_data[i] *= filter[fft_id, offset + i * FFT_fwd.stride]

        FFT_inv(thread_data, shared_mem)

        for i in range(ept):
            signal[fft_id, offset + i * fft_stride] = thread_data[i]
        for i in range(FFT_fwd.elements_per_thread):
            signal[fft_id, offset + i * FFT_fwd.stride] = thread_data[i]


    data = random_complex((ffts_per_block, size), np.float32)
@@ -144,7 +180,7 @@ def main():
    data_d = cuda.to_device(data)
    filter_d = cuda.to_device(filter)

    f[1, block_dim, 0, shared_memory_size](data_d, filter_d)
    f[1, FFT_fwd.block_dim, 0, FFT_fwd.shared_memory_size](data_d, filter_d)
    cuda.synchronize()

    data_test = data_d.copy_to_host()
@@ -159,38 +195,79 @@ if __name__ == "__main__":
    main()
```

nvmath-python provides the ability to write custom prologs and epilogs for FFT functions as
a Python functions and compiled them LTO-IR. For example, to have unitary scaling for an
FFT, we can define an epilog which rescales the output by 1/sqrt(N).
### Distributed APIs

Distributed APIs are called from host code but execute on a distributed
(multi-node multi-GPU) system. The following example shows the use of the
function-form distributed FFT with CuPy ndarrays:

```python
import cupy as cp
import nvmath
import math

# Create the data for the batched 1-D FFT.
B, N = 256, 1024
a = cp.random.rand(B, N, dtype=cp.float64) + 1j * cp.random.rand(B, N, dtype=cp.float64)

# Compute the normalization factor for unitary transforms
norm_factor = 1.0 / math.sqrt(N)

# Define the epilog function for the FFT.
def rescale(data_out, offset, data, user_info, unused):
    data_out[offset] = data * norm_factor

# Compile the epilog to LTO-IR.
with cp.cuda.Device():
    epilog = nvmath.fft.compile_epilog(rescale, "complex128", "complex128")

# Perform the forward FFT, applying the filter as a epilog...
r = nvmath.fft.fft(a, axes=[-1], epilog={"ltoir": epilog})

# Finally, we can test that the fused FFT run result matches the result of separate
# calls
s = cp.fft.fftn(a, axes=[-1], norm="ortho")

assert cp.allclose(r, s)
from mpi4py import MPI

import nvmath.distributed
from nvmath.distributed.distribution import Slab

# Initialize nvmath.distributed.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nranks = comm.Get_size()
device_id = rank % cp.cuda.runtime.getDeviceCount()
nvmath.distributed.initialize(device_id, comm, backends=["nvshmem"])

# The global 3-D FFT size is (512, 256, 512).
# In this example, the input data is distributed across processes according to
# the cuFFTMp Slab distribution on the X axis.
shape = 512 // nranks, 256, 512

# cuFFTMp uses the NVSHMEM PGAS model for distributed computation, which requires GPU
# operands to be on the symmetric heap.
a = nvmath.distributed.allocate_symmetric_memory(shape, cp, dtype=cp.complex128)
# a is a cupy ndarray and can be operated on using in-place cupy operations.
with cp.cuda.Device(device_id):
    a[:] = cp.random.rand(*shape, dtype=cp.float64) + 1j * cp.random.rand(
        *shape, dtype=cp.float64
    )

# Forward FFT.
# In this example, the forward FFT operand is distributed according
# to Slab.X distribution. With reshape=False, the FFT result will be
# distributed according to Slab.Y distribution.
b = nvmath.distributed.fft.fft(a, distribution=Slab.X, options={"reshape": False})

# Distributed FFT performs computations in-place. The result is stored in the same
# buffer as operand a. Note, however, that operand b has a different shape (due
# to Slab.Y distribution).
if rank == 0:
    print(f"Shape of a on rank {rank} is {a.shape}")
    print(f"Shape of b on rank {rank} is {b.shape}")

# Inverse FFT.
# Recall from the previous transform that the inverse FFT operand is distributed according
# to Slab.Y. With reshape=False, the inverse FFT result will be distributed according
# to Slab.X distribution.
c = nvmath.distributed.fft.ifft(b, distribution=Slab.Y, options={"reshape": False})

# The shape of c is the same as a (due to Slab.X distribution). Once again, note that
# a, b and c are sharing the same symmetric memory buffer (distributed FFT operations
# are in-place).
if rank == 0:
    print(f"Shape of c on rank {rank} is {c.shape}")

# Synchronize the default stream
with cp.cuda.Device(device_id):
    cp.cuda.get_current_stream().synchronize()

if rank == 0:
    print(f"Input type = {type(a)}, device = {a.device}")
    print(f"FFT output type = {type(b)}, device = {b.device}")
    print(f"IFFT output type = {type(c)}, device = {c.device}")

# GPU operands on the symmetric heap are not garbage-collected and the user is
# responsible for freeing any that they own (this deallocation is a collective
# operation that must be called by all processes at the same point in the execution).
# All cuFFTMp operations are in-place (a, b, and c share the same memory buffer), so
# we take care to only free the buffer once.
nvmath.distributed.free_symmetric_memory(a)
```
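As with other MPI programs, a script containing this example is started through an MPI launcher, for instance `mpirun -n <nranks> python fft_distributed.py` (the script name here is a placeholder; the exact launcher and flags depend on your MPI installation and cluster setup).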

## License
16 changes: 15 additions & 1 deletion builder/utils.py
@@ -48,7 +48,21 @@ def check_path(header):

def decide_lib_name(ext_name):
    # TODO: move the record of the supported lib list elsewhere?
    for lib in ("cublas", "cusolver", "cufftMp", "cufft", "cusparse", "curand", "nvpl", "nvshmem", "mathdx", "cudss"):
    for lib in (
        "cublasMp",
        "cublas",
        "cusolver",
        "cufftMp",
        "cufft",
        "cusparse",
        "curand",
        "nvpl",
        "nvshmem",
        "nccl",
        "mathdx",
        "cudss",
        "cutensor",
    ):
        if lib in ext_name:
            return lib
    else:
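Note that `decide_lib_name` matches by substring, so the order of the tuple matters: the multi-node entries such as `"cublasMp"` and `"cufftMp"` must appear before their single-node counterparts `"cublas"` and `"cufft"`, otherwise an Mp extension would be classified as the single-node library. A minimal sketch of the intended behavior, using hypothetical extension names (the `for`/`else` fallback for unknown names is cut off in this view):

```python
# Hypothetical extension names, used only to illustrate the substring-matching
# order; the real extension names built by the project are not shown in this diff.
assert decide_lib_name("nvmath.bindings._internal.cublasMp") == "cublasMp"
assert decide_lib_name("nvmath.bindings._internal.cublas") == "cublas"
assert decide_lib_name("nvmath.bindings._internal.cufft") == "cufft"
```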
4 changes: 4 additions & 0 deletions docs/sphinx/_static/switcher.json
@@ -3,6 +3,10 @@
"version": "latest",
"url": "https://docs.nvidia.com/cuda/nvmath-python/latest"
},
{
"version": "0.7.0",
"url": "https://docs.nvidia.com/cuda/nvmath-python/0.7.0"
},
{
"version": "0.6.0",
"url": "https://docs.nvidia.com/cuda/nvmath-python/0.6.0"