From 28a87570b44c546acf4dd29d7365a4d6b0784705 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Jan 2025 18:46:33 -0500 Subject: [PATCH 1/6] chore: bigggg jll --- deps/ReactantExtra/API.cpp | 138 ++++++++++++++++++++++++++++------- deps/ReactantExtra/BUILD | 15 +++- deps/ReactantExtra/WORKSPACE | 2 +- 3 files changed, 125 insertions(+), 30 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index be59893666..a7c0a49eb8 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -530,6 +530,18 @@ extern "C" void BufferToHost(PjRtBuffer *buffer, void *data) { extern "C" void FreeClient(PjRtClient *client) { delete client; } +extern "C" int64_t PjRtDeviceGetLocalDeviceId(PjRtDevice *device) { + return device->local_device_id().value(); +} + +extern "C" int64_t PjRtDeviceGetGlobalDeviceId(PjRtDevice *device) { + return device->global_device_id().value(); +} + +extern "C" int64_t PjRtDeviceGetLocalHardwareId(PjRtDevice *device) { + return device->local_hardware_id().value(); +} + #include "xla/service/custom_call_target_registry.h" extern "C" void RegisterCustomCallTarget(const char *name, void *address, const char *platform) { @@ -579,22 +591,30 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) { } /* Note that this */ -extern "C" xla::PjRtLoadedExecutable * -ClientCompile(PjRtClient *client, MlirModule cmod, int device_ordinal, - int num_replicas, int num_partitions, - bool use_shardy_partitioner) { +extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client, + MlirModule cmod, + int *global_ordinals, + int num_global_ordinals) { auto program = std::make_unique(cast(*unwrap(cmod))); CompileOptions options; - if (device_ordinal >= 0) { - options.executable_build_options.set_device_ordinal(device_ordinal); + // https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601 + int device_count = client->addressable_device_count(); + + options.executable_build_options.set_num_replicas(device_count); + options.executable_build_options.set_num_partitions(1); + + xla::DeviceAssignment device_assignment(device_count, 1); + for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) { + int ordinal = global_ordinals[device_id]; + if (ordinal < 0) { + continue; + } + device_assignment(ordinal, 0) = device_id; } - options.executable_build_options.set_num_replicas(num_replicas); - options.executable_build_options.set_num_partitions(num_partitions); - options.executable_build_options.set_use_shardy_partitioner( - use_shardy_partitioner); + options.executable_build_options.set_device_assignment(device_assignment); auto addressable_devices = client->addressable_devices(); if (!addressable_devices.empty()) { @@ -605,8 +625,7 @@ ClientCompile(PjRtClient *client, MlirModule cmod, int device_ordinal, assert(device_ordinal < addressable_devices.size()); auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); if (stats.ok() && stats->bytes_limit) { - options.executable_build_options.set_device_memory_size( - *stats->bytes_limit); + options.executable_build_options.set_device_memory_size(*stats->bytes_limit); } } auto exec = @@ -623,12 +642,69 @@ extern "C" uint8_t FutureIsReady(FutureType *Future) { extern "C" void FutureAwait(FutureType *Future) { Future->Await(); } -extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args, +extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args, + PjRtBuffer **op_args, PjRtDevice *device, + uint8_t *is_arg_donatable, int num_results, + PjRtBuffer **op_results, uint8_t *futures, + FutureType **future_results) { + // Create a vector of PjRtBuffer* from the input array. + std::vector argument_handles(op_args, op_args + num_args); + + // Set up execution options. + ExecuteOptions options; + for (size_t i = 0; i < num_args; i++) { + if (!is_arg_donatable[i]) { + options.non_donatable_input_indices.insert(static_cast(i)); + } + } + options.untuple_result = true; + + // Optional future to hold asynchronous execution results. + std::optional> returned_future; + + auto results = MyValueOrThrow( + exec->ExecuteSharded(argument_handles, + device, options, returned_future, /*fill_future=*/true)); + + // Validate the number of results. + if (results.size() != num_results) { + llvm::errs() << "Error: results.size()=" << results.size() + << " does not match num_results=" << num_results << "\n"; + std::abort(); // Terminate if the number of results is incorrect. + } + + // Handle futures if they are returned. + if (returned_future.has_value()) { + *futures = true; + for (size_t i = 0; i < num_results; i++) { + future_results[i] = new FutureType(*returned_future); + } + } else { + *futures = false; + } + + // Release the results into the output array. + for (size_t i = 0; i < num_results; i++) { + op_results[i] = results[i].release(); + } +} + +extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_devices, int num_args, PjRtBuffer **op_args, uint8_t *is_arg_donatable, int num_results, PjRtBuffer **op_results, uint8_t *futures, FutureType **future_results) { - std::vector> argument_handles; - argument_handles.emplace_back(op_args, op_args + num_args); + // Ensure argument_handles is structured as num_devices x num_args + std::vector> argument_handles(num_devices); + + // Distribute arguments across devices + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + argument_handles[device_idx].reserve(num_args); + for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) { + // Assuming op_args is a flat array of size num_devices * num_args + // where arguments for each device are contiguous + argument_handles[device_idx].push_back(op_args[device_idx * num_args + arg_idx]); + } + } ExecuteOptions options; @@ -637,31 +713,43 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args, options.non_donatable_input_indices.insert((int)i); } options.untuple_result = true; + std::optional> returned_futures; auto results = MyValueOrThrow( exec->Execute(static_cast>>( argument_handles), options, returned_futures)); - assert(results.size() == 1); + assert(results.size() == num_devices); - if (results[0].size() != num_results) { - llvm::errs() << " results.size()=" << results.size() - << " num_results=" << num_results << "\n"; + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + if (results[device_idx].size() != num_results) { + llvm::errs() << " results[" << device_idx << "].size()=" << results[device_idx].size() + << " num_results=" << num_results << "\n"; + } + assert(results[device_idx].size() == num_results); } - assert(results[0].size() == num_results); + + // Handle returned futures if (returned_futures) { *futures = true; - assert(returned_futures->size() == num_results); - for (size_t i = 0; i < num_results; i++) { - future_results[i] = new FutureType((*returned_futures)[i]); + assert(returned_futures->size() == num_devices * num_results); + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + for (int result_idx = 0; result_idx < num_results; ++result_idx) { + int flat_index = device_idx * num_results + result_idx; + future_results[flat_index] = new FutureType((*returned_futures)[flat_index]); + } } } else { *futures = false; } - for (size_t i = 0; i < num_results; i++) { - op_results[i] = results[0][i].release(); + // Copy results into the output buffers + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + for (int result_idx = 0; result_idx < num_results; ++result_idx) { + int flat_index = device_idx * num_results + result_idx; + op_results[flat_index] = results[device_idx][result_idx].release(); + } } } diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 62c647cdbc..2fce65da28 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -360,8 +360,9 @@ cc_library( ], ) + [ - "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", - "@enzyme_ad//src/enzyme_ad/jax:gpu.cc", + "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", + "@enzyme_ad//src/enzyme_ad/jax:gpu.cc", + "@enzyme_ad//src/enzyme_ad/jax:cpu.cc", # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", # "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", @@ -448,7 +449,13 @@ cc_library( "-Wl,-exported_symbol,_ProfilerActivityEnd", "-Wl,-exported_symbol,_ReactantFuncSetArgAttr", "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", -"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions" +"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", +"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", +"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId", +"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId", +"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId", +"-Wl,-exported_symbol,_XLAExecuteSharded", +"-Wl,-exported_symbol,_ClientGetPlatformName", ]}), deps = [ "@enzyme//:EnzymeMLIR", @@ -566,9 +573,9 @@ cc_library( "@xla//xla/service/gpu:gpu_transfer_manager", "@xla//xla/stream_executor:kernel", "@xla//xla/backends/profiler/gpu:device_tracer", + "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", ], "//conditions:default": [ - "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", ], }) + if_rocm([ "@xla//xla/service/gpu:amdgpu_compiler", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index cce06d553f..935efdbf42 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "c38ca3f187ef11de6b2292f3cc55c5eb60530d15" +ENZYMEXLA_COMMIT = "d89468ed883ca18c04346eec10f784bbe2b754fc" ENZYMEXLA_SHA256 = "" http_archive( From b892263e0047facd69541015112c1036b83b7c26 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Jan 2025 18:49:37 -0500 Subject: [PATCH 2/6] fix: restore the old API --- deps/ReactantExtra/API.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index a7c0a49eb8..7cce17753a 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -689,10 +689,13 @@ extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args, } } -extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_devices, int num_args, +extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args, PjRtBuffer **op_args, uint8_t *is_arg_donatable, int num_results, PjRtBuffer **op_results, uint8_t *futures, FutureType **future_results) { + auto client = exec->client(); + int num_devices = client->addressable_device_count(); + // Ensure argument_handles is structured as num_devices x num_args std::vector> argument_handles(num_devices); From fe33f20e3b40d845674d16a6f005fed0858a3ac4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Jan 2025 19:02:39 -0500 Subject: [PATCH 3/6] fix: add shardy c headers --- deps/ReactantExtra/API.cpp | 5 ++++- deps/ReactantExtra/BUILD | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 7cce17753a..adf9e8c80c 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -33,7 +33,6 @@ #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/Passes.h" -#include "shardy/dialect/sdy/ir/dialect.h" #include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" #include "src/enzyme_ad/jax/Passes/Passes.h" @@ -69,6 +68,10 @@ #include "llvm-c/TargetMachine.h" +// shardy +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/integrations/c/attributes.h" + // IFRT #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 2fce65da28..cc37246d51 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -557,6 +557,7 @@ cc_library( "@xla//xla/mlir/utils:type_util", "@stablehlo//:stablehlo_capi_objects", "@stablehlo//:chlo_capi_objects", + "@shardy//shardy/integrations/c:sdy_capi_objects", "@com_google_absl//absl/hash:hash", "@com_google_absl//absl/log:initialize", "@com_google_absl//absl/log:globals", From fd2680e4864c748b9c16ad737a4eb455a59131bd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Jan 2025 19:16:43 -0500 Subject: [PATCH 4/6] fix: section for import --- deps/ReactantExtra/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index cc37246d51..2dffd18474 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -565,6 +565,7 @@ cc_library( "@llvm-project//mlir:CAPILLVMObjects", "@jax//jaxlib/mosaic:tpu_dialect_capi_objects", "@jax//jaxlib/triton:triton_dialect_capi_objects", + "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", ] + select({ "@xla//xla/tsl:is_cuda_enabled_and_oss":[ "@xla//xla/stream_executor/cuda:all_runtime", @@ -574,7 +575,6 @@ cc_library( "@xla//xla/service/gpu:gpu_transfer_manager", "@xla//xla/stream_executor:kernel", "@xla//xla/backends/profiler/gpu:device_tracer", - "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", ], "//conditions:default": [ ], From 7d1c0310ec190d45f9d8a1ca1d1df78fdbbedb95 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Jan 2025 19:18:16 -0500 Subject: [PATCH 5/6] Update deps/ReactantExtra/BUILD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> --- deps/ReactantExtra/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 2dffd18474..561b69f556 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -456,6 +456,7 @@ cc_library( "-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId", "-Wl,-exported_symbol,_XLAExecuteSharded", "-Wl,-exported_symbol,_ClientGetPlatformName", +"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler", ]}), deps = [ "@enzyme//:EnzymeMLIR", From 9336cfcfef0edecf6ba1f2b49413b06f6938bbc5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Jan 2025 19:38:02 -0500 Subject: [PATCH 6/6] Update deps/ReactantExtra/BUILD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> --- deps/ReactantExtra/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 561b69f556..72bdcb7ff8 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -450,7 +450,6 @@ cc_library( "-Wl,-exported_symbol,_ReactantFuncSetArgAttr", "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", -"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", "-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",