diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index adf9e8c80c..f270e8ef23 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -597,7 +597,8 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) { extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client, MlirModule cmod, int *global_ordinals, - int num_global_ordinals) { + int num_global_ordinals, + const char* xla_gpu_cuda_data_dir) { auto program = std::make_unique(cast(*unwrap(cmod))); @@ -608,6 +609,7 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client, options.executable_build_options.set_num_replicas(device_count); options.executable_build_options.set_num_partitions(1); + options.executable_build_options.mutable_debug_options()->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir); xla::DeviceAssignment device_assignment(device_count, 1); for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) {