Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,29 @@ std::vector<int64_t> row_major(int64_t dim) {
}
// Function that takes no arguments, performs no work, and returns nothing.
static void noop() {
  // Intentionally empty.
}

// Plain-data record holding a subset of CUDA device properties.
//
// Every field carries the same name as the cudaDeviceProp member that
// ReactantCudaDeviceGetProperties copies into it. The struct is filled through
// a caller-supplied pointer across the REACTANT_ABI boundary, so the order and
// types of the fields are part of that interface.
// NOTE(review): the `jlprops` parameter name suggests a mirrored struct
// definition on the Julia side -- confirm before reordering, retyping, adding,
// or removing any field.
// NOTE(review): units (bytes, kHz, ...) are whatever cudaDeviceProp reports;
// check the CUDA runtime documentation rather than assuming them here.
struct DeviceProperties {
// Memory and per-block resource limits.
size_t totalGlobalMem;
size_t sharedMemPerBlock;
int regsPerBlock;
int warpSize;
// Launch-configuration limits; the two arrays have one entry per dimension.
int maxThreadsPerBlock;
int maxThreadsDim[3];
int maxGridSize[3];
int clockRate;
size_t totalConstMem;
// Copied from cudaDeviceProp::major / cudaDeviceProp::minor.
int major;
int minor;
int multiProcessorCount;
int canMapHostMemory;
int computeMode;
int l2CacheSize;
int maxThreadsPerMultiProcessor;
};

#ifdef REACTANT_CUDA

#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime.h"

REACTANT_ABI int32_t ReactantCudaDriverGetVersion() {
int32_t data;
Expand Down Expand Up @@ -769,6 +789,33 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() {
return warpSize;
}

// Queries the CUDA runtime for the properties of device `device_id` and copies
// a fixed subset of them into `*jlprops`.
//
// Parameters:
//   jlprops   - destination record; if null the call is a no-op instead of
//               dereferencing a null pointer.
//   device_id - device ordinal forwarded to cudaGetDeviceProperties.
//
// Errors: the status of cudaGetDeviceProperties is handed to
// ReactantHandleCuResult. If that helper returns after a failed query,
// `*jlprops` is left untouched rather than being filled from a record the
// runtime never populated.
REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
                                                  int32_t device_id) {
  if (jlprops == nullptr) {
    return; // nowhere to write the result
  }

  // Value-initialized so that no field is ever indeterminate.
  cudaDeviceProp props{};
  const cudaError_t status = cudaGetDeviceProperties(&props, device_id);
  // NOTE(review): `status` is a runtime-API cudaError_t, while the helper's
  // name suggests it decodes driver-API CUresult values -- confirm the helper
  // produces the right message for runtime errors.
  ReactantHandleCuResult(status);
  if (status != cudaSuccess) {
    return; // do not publish properties from a failed query
  }

  jlprops->totalGlobalMem = props.totalGlobalMem;
  jlprops->sharedMemPerBlock = props.sharedMemPerBlock;
  jlprops->regsPerBlock = props.regsPerBlock;
  jlprops->warpSize = props.warpSize;
  jlprops->maxThreadsPerBlock = props.maxThreadsPerBlock;

  // Both limit arrays hold one entry per dimension (3 entries each).
  for (int axis = 0; axis < 3; ++axis) {
    jlprops->maxThreadsDim[axis] = props.maxThreadsDim[axis];
    jlprops->maxGridSize[axis] = props.maxGridSize[axis];
  }

  // NOTE(review): clockRate and computeMode are reportedly absent from
  // cudaDeviceProp in newer CUDA toolkits (13.x); if this file must build
  // there, read them via cudaDeviceGetAttribute instead -- confirm against the
  // toolkit version in use.
  jlprops->clockRate = props.clockRate;
  jlprops->totalConstMem = props.totalConstMem;
  jlprops->major = props.major;
  jlprops->minor = props.minor;
  jlprops->multiProcessorCount = props.multiProcessorCount;
  jlprops->canMapHostMemory = props.canMapHostMemory;
  jlprops->computeMode = props.computeMode;
  jlprops->l2CacheSize = props.l2CacheSize;
  jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor;
}

#else

// Variant compiled when REACTANT_CUDA is not defined: there is no driver to
// query, so the reported version is always 0.
REACTANT_ABI int32_t ReactantCudaDriverGetVersion() {
  const int32_t version_without_cuda = 0;
  return version_without_cuda;
}
Expand All @@ -781,6 +828,9 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() { return 0; }

// Variant compiled when REACTANT_CUDA is not defined: no device is available,
// so the reported warp size is always 0.
REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() {
  const int32_t warp_size_without_cuda = 0;
  return warp_size_without_cuda;
}

// Variant compiled when REACTANT_CUDA is not defined.
//
// There is no device to query, so every field of `*jlprops` is set to zero.
// This keeps the caller from reading indeterminate memory and matches the
// neighbouring non-CUDA stubs, which report 0.
//
// Parameters:
//   jlprops   - destination record; a null pointer is ignored.
//   device_id - unused in this build.
REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
                                                  int32_t device_id) {
  (void)device_id; // no CUDA runtime available to interpret the ordinal
  if (jlprops == nullptr) {
    return;
  }
  *jlprops = DeviceProperties{}; // value-initialization zeroes every field
}

#endif

REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
Expand Down Expand Up @@ -1955,6 +2005,15 @@ REACTANT_ABI bool ifrt_DeviceIsAddressable(ifrt::Device *device) {
return device->IsAddressable();
}

// Returns the local hardware id of the PjRt device behind `device`.
//
// `device` must be an ifrt::PjRtDevice; for any other device kind the error
// message below is passed to ReactantThrowError.
// NOTE(review): the code after the check assumes ReactantThrowError does not
// return normally -- confirm, since otherwise a null pointer is dereferenced.
REACTANT_ABI int64_t ifrt_DeviceGetLocalHardwareId(ifrt::Device *device) {
  // A single dyn_cast both tests the dynamic type and yields the typed pointer.
  auto *pjrt_backed = llvm::dyn_cast<ifrt::PjRtDevice>(device);
  if (pjrt_backed == nullptr) {
    ReactantThrowError(
        "ifrt_DeviceGetLocalHardwareId: only supported for ifrt-pjrt.");
  }

  auto *underlying = pjrt_backed->pjrt_device();
  return underlying->local_hardware_id().value();
}

static xla::ifrt::RCReferenceWrapper<ifrt::DeviceList>
ifrt_CreateDeviceListFromDevices(ifrt::Client *client,
ifrt::Device **device_list,
Expand Down
1 change: 1 addition & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ cc_library(
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetProperties",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",
Expand Down
Loading