From 2292ba9038e316f27e942d36434c256d491b19dd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 14:28:49 -0500 Subject: [PATCH 1/2] feat: API to get device properties for cuda --- deps/ReactantExtra/API.cpp | 59 ++++++++++++++++++++++++++++++++++++++ deps/ReactantExtra/BUILD | 1 + 2 files changed, 60 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 5369bed1ef..c4c7fcf97c 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -730,9 +730,29 @@ std::vector row_major(int64_t dim) { } static void noop() {} +struct DeviceProperties { + size_t totalGlobalMem; + size_t sharedMemPerBlock; + int regsPerBlock; + int warpSize; + int maxThreadsPerBlock; + int maxThreadsDim[3]; + int maxGridSize[3]; + int clockRate; + size_t totalConstMem; + int major; + int minor; + int multiProcessorCount; + int canMapHostMemory; + int computeMode; + int l2CacheSize; + int maxThreadsPerMultiProcessor; +}; + #ifdef REACTANT_CUDA #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime.h" REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { int32_t data; @@ -769,6 +789,33 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return warpSize; } +REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops, + int32_t device_id) { + cudaDeviceProp props; + ReactantHandleCuResult(cudaGetDeviceProperties(&props, device_id)); + + jlprops->totalGlobalMem = props.totalGlobalMem; + jlprops->sharedMemPerBlock = props.sharedMemPerBlock; + jlprops->regsPerBlock = props.regsPerBlock; + jlprops->warpSize = props.warpSize; + jlprops->maxThreadsPerBlock = props.maxThreadsPerBlock; + jlprops->maxThreadsDim[0] = props.maxThreadsDim[0]; + jlprops->maxThreadsDim[1] = props.maxThreadsDim[1]; + jlprops->maxThreadsDim[2] = props.maxThreadsDim[2]; + jlprops->maxGridSize[0] = props.maxGridSize[0]; + jlprops->maxGridSize[1] = props.maxGridSize[1]; + jlprops->maxGridSize[2] = props.maxGridSize[2]; + jlprops->clockRate = props.clockRate; + jlprops->totalConstMem = props.totalConstMem; + jlprops->major = props.major; + jlprops->minor = props.minor; + jlprops->multiProcessorCount = props.multiProcessorCount; + jlprops->canMapHostMemory = props.canMapHostMemory; + jlprops->computeMode = props.computeMode; + jlprops->l2CacheSize = props.l2CacheSize; + jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor; +} + #else REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; } @@ -781,6 +828,9 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() { return 0; } REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; } +REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops, + int32_t device_id) {} + #endif REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) { @@ -1955,6 +2005,15 @@ REACTANT_ABI bool ifrt_DeviceIsAddressable(ifrt::Device *device) { return device->IsAddressable(); } +REACTANT_ABI int64_t ifrt_DeviceGetLocalHardwareId(ifrt::Device *device) { + if (!llvm::isa(device)) { + ReactantThrowError( + "ifrt_device_get_allocator_stats: only supported for ifrt-pjrt."); + } + auto ifrt_pjrt_device = llvm::dyn_cast(device); + return ifrt_pjrt_device->pjrt_device()->local_hardware_id().value(); +} + static xla::ifrt::RCReferenceWrapper ifrt_CreateDeviceListFromDevices(ifrt::Client *client, ifrt::Device **device_list, diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 36134e154f..4f3ce35cfe 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -983,6 +983,7 @@ cc_library( "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor", "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads", + "-Wl,-exported_symbol,_ReactantCudaDeviceGetProperties", "-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId", From 7b881aadbbcadf8bcf62170508acd10c7d350fd1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 14:31:53 -0500 Subject: [PATCH 2/2] fix: incorrect error msg --- deps/ReactantExtra/API.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index c4c7fcf97c..80d43090a5 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -2008,7 +2008,7 @@ REACTANT_ABI bool ifrt_DeviceIsAddressable(ifrt::Device *device) { REACTANT_ABI int64_t ifrt_DeviceGetLocalHardwareId(ifrt::Device *device) { if (!llvm::isa(device)) { ReactantThrowError( - "ifrt_device_get_allocator_stats: only supported for ifrt-pjrt."); + "ifrt_DeviceGetLocalHardwareId: only supported for ifrt-pjrt."); } auto ifrt_pjrt_device = llvm::dyn_cast(device); return ifrt_pjrt_device->pjrt_device()->local_hardware_id().value();