144 changes: 119 additions & 25 deletions deps/ReactantExtra/API.cpp
@@ -33,7 +33,6 @@
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
@@ -69,6 +68,10 @@

#include "llvm-c/TargetMachine.h"

// shardy
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/integrations/c/attributes.h"

// IFRT
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/client.h"
@@ -530,6 +533,18 @@ extern "C" void BufferToHost(PjRtBuffer *buffer, void *data) {

extern "C" void FreeClient(PjRtClient *client) { delete client; }

extern "C" int64_t PjRtDeviceGetLocalDeviceId(PjRtDevice *device) {
return device->local_device_id().value();
}

extern "C" int64_t PjRtDeviceGetGlobalDeviceId(PjRtDevice *device) {
return device->global_device_id().value();
}

extern "C" int64_t PjRtDeviceGetLocalHardwareId(PjRtDevice *device) {
return device->local_hardware_id().value();
}
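
These accessors expose PJRT's three device identifiers: the local device id (the device's index within this process), the global device id (unique across all hosts of a distributed run), and the local hardware id (the physical ordinal, e.g. the CUDA device number). A minimal caller-side sketch; ClientGetDevice and the surrounding setup are assumptions for illustration, not part of this diff:

// Hypothetical usage; ClientGetDevice is assumed to return the i-th
// addressable device of an already-constructed client.
PjRtDevice *dev = ClientGetDevice(client, /*device_idx=*/0);
int64_t local_id = PjRtDeviceGetLocalDeviceId(dev);   // index within this process
int64_t global_id = PjRtDeviceGetGlobalDeviceId(dev); // unique across hosts
int64_t hw_id = PjRtDeviceGetLocalHardwareId(dev);    // e.g. CUDA ordinal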

#include "xla/service/custom_call_target_registry.h"
extern "C" void RegisterCustomCallTarget(const char *name, void *address,
const char *platform) {
@@ -579,22 +594,30 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
}

/* Note that this */
extern "C" xla::PjRtLoadedExecutable *
ClientCompile(PjRtClient *client, MlirModule cmod, int device_ordinal,
int num_replicas, int num_partitions,
bool use_shardy_partitioner) {
extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
MlirModule cmod,
int *global_ordinals,
int num_global_ordinals) {
auto program =
std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap(cmod)));

CompileOptions options;

if (device_ordinal >= 0) {
options.executable_build_options.set_device_ordinal(device_ordinal);
// https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
int device_count = client->addressable_device_count();

options.executable_build_options.set_num_replicas(device_count);
options.executable_build_options.set_num_partitions(1);

xla::DeviceAssignment device_assignment(device_count, 1);
for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) {
int ordinal = global_ordinals[device_id];
if (ordinal < 0) {
continue;
}
device_assignment(ordinal, 0) = device_id;
}
options.executable_build_options.set_num_replicas(num_replicas);
options.executable_build_options.set_num_partitions(num_partitions);
options.executable_build_options.set_use_shardy_partitioner(
use_shardy_partitioner);
options.executable_build_options.set_device_assignment(device_assignment);
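// Worked example (illustrative values, not from this PR): with
// num_global_ordinals = 4 and global_ordinals = {3, 1, 0, 2}, the loop
// above yields
//   device_assignment(3, 0) = 0   // replica 3 runs on addressable device 0
//   device_assignment(1, 0) = 1
//   device_assignment(0, 0) = 2
//   device_assignment(2, 0) = 3
// Row index = replica (the global ordinal), column = partition (always 0
// here), value = index of the device hosting that replica. Negative
// ordinals are skipped and keep the default assignment.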

auto addressable_devices = client->addressable_devices();
if (!addressable_devices.empty()) {
@@ -605,8 +628,7 @@ ClientCompile(PjRtClient *client, MlirModule cmod, int device_ordinal,
assert(device_ordinal < addressable_devices.size());
auto stats = addressable_devices[device_ordinal]->GetAllocatorStats();
if (stats.ok() && stats->bytes_limit) {
options.executable_build_options.set_device_memory_size(
*stats->bytes_limit);
options.executable_build_options.set_device_memory_size(*stats->bytes_limit);
}
}
auto exec =
@@ -623,12 +645,72 @@ extern "C" uint8_t FutureIsReady(FutureType *Future) {

extern "C" void FutureAwait(FutureType *Future) { Future->Await(); }

extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
PjRtBuffer **op_args, PjRtDevice *device,
uint8_t *is_arg_donatable, int num_results,
PjRtBuffer **op_results, uint8_t *futures,
FutureType **future_results) {
// Create a vector of PjRtBuffer* from the input array.
std::vector<PjRtBuffer *> argument_handles(op_args, op_args + num_args);

// Set up execution options.
ExecuteOptions options;
for (size_t i = 0; i < num_args; i++) {
if (!is_arg_donatable[i]) {
options.non_donatable_input_indices.insert(static_cast<int>(i));
}
}
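// Example (illustrative): with is_arg_donatable = {1, 0, 1}, only index 1
// lands in non_donatable_input_indices, so XLA may reuse the buffers of
// arguments 0 and 2 for outputs but must leave argument 1 intact.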
options.untuple_result = true;

// Optional future to hold asynchronous execution results.
std::optional<PjRtFuture<>> returned_future;

auto results = MyValueOrThrow(
exec->ExecuteSharded(argument_handles,
device, options, returned_future, /*fill_future=*/true));

// Validate the number of results.
if (results.size() != num_results) {
llvm::errs() << "Error: results.size()=" << results.size()
<< " does not match num_results=" << num_results << "\n";
std::abort(); // Terminate if the number of results is incorrect.
}

// Handle futures if they are returned.
if (returned_future.has_value()) {
*futures = true;
for (size_t i = 0; i < num_results; i++) {
future_results[i] = new FutureType(*returned_future);
}
} else {
*futures = false;
}

// Release the results into the output array.
for (size_t i = 0; i < num_results; i++) {
op_results[i] = results[i].release();
}
}
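
A minimal sketch of a call through this entry point; exec, device, and the input buffers x and y are placeholders assumed to have been created earlier via the C API:

// Hypothetical caller-side sketch: two inputs on one device, one result.
PjRtBuffer *args[2] = {x, y};
uint8_t donate[2] = {1, 0};     // argument 0 may be donated, argument 1 not
PjRtBuffer *results[1];
FutureType *futures_out[1];
uint8_t has_future = 0;
XLAExecuteSharded(exec, /*num_args=*/2, args, device, donate,
                  /*num_results=*/1, results, &has_future, futures_out);
if (has_future)
  FutureAwait(futures_out[0]);  // the heap-allocated future is owned by the caller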

extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
PjRtBuffer **op_args, uint8_t *is_arg_donatable,
int num_results, PjRtBuffer **op_results,
uint8_t *futures, FutureType **future_results) {
std::vector<std::vector<PjRtBuffer *>> argument_handles;
argument_handles.emplace_back(op_args, op_args + num_args);
auto client = exec->client();
int num_devices = client->addressable_device_count();

// Ensure argument_handles is structured as num_devices x num_args
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);

// Distribute arguments across devices
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
argument_handles[device_idx].reserve(num_args);
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
// Assuming op_args is a flat array of size num_devices * num_args
// where arguments for each device are contiguous
argument_handles[device_idx].push_back(op_args[device_idx * num_args + arg_idx]);
}
}

ExecuteOptions options;

@@ -637,31 +719,43 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
options.non_donatable_input_indices.insert((int)i);
}
options.untuple_result = true;

std::optional<std::vector<FutureType>> returned_futures;
auto results = MyValueOrThrow(
exec->Execute(static_cast<absl::Span<const std::vector<PjRtBuffer *>>>(
argument_handles),
options, returned_futures));

assert(results.size() == 1);
assert(results.size() == num_devices);

if (results[0].size() != num_results) {
llvm::errs() << " results.size()=" << results.size()
<< " num_results=" << num_results << "\n";
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
if (results[device_idx].size() != num_results) {
llvm::errs() << " results[" << device_idx << "].size()=" << results[device_idx].size()
<< " num_results=" << num_results << "\n";
}
assert(results[device_idx].size() == num_results);
}
assert(results[0].size() == num_results);

// Handle returned futures
if (returned_futures) {
*futures = true;
assert(returned_futures->size() == num_results);
for (size_t i = 0; i < num_results; i++) {
future_results[i] = new FutureType((*returned_futures)[i]);
assert(returned_futures->size() == num_devices * num_results);
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
int flat_index = device_idx * num_results + result_idx;
future_results[flat_index] = new FutureType((*returned_futures)[flat_index]);
}
}
} else {
*futures = false;
}

for (size_t i = 0; i < num_results; i++) {
op_results[i] = results[0][i].release();
// Copy results into the output buffers
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
int flat_index = device_idx * num_results + result_idx;
op_results[flat_index] = results[device_idx][result_idx].release();
}
}
}
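
With this change callers pass num_devices * num_args input buffers and receive num_devices * num_results outputs, both flattened with the per-device entries contiguous. A sketch for two devices, two arguments, and one result; all buffer names are placeholders:

// Hypothetical caller-side layout for num_devices = 2, num_args = 2,
// num_results = 1: op_args = {d0_x, d0_y, d1_x, d1_y}.
PjRtBuffer *op_args[4] = {d0_x, d0_y, d1_x, d1_y};
uint8_t donate[2] = {0, 0};    // one flag per argument, shared by all devices
PjRtBuffer *op_results[2];     // filled as {d0_r0, d1_r0}
FutureType *future_results[2]; // same flat device-major indexing
uint8_t has_futures = 0;
XLAExecute(exec, /*num_args=*/2, op_args, donate,
           /*num_results=*/1, op_results, &has_futures, future_results);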

16 changes: 12 additions & 4 deletions deps/ReactantExtra/BUILD
@@ -360,8 +360,9 @@ cc_library(
],

) + [
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
# "@xla//xla:xla.pb.cc",
"@xla//xla:xla_data.pb.cc",
@@ -448,7 +449,13 @@ cc_library(
"-Wl,-exported_symbol,_ProfilerActivityEnd",
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions"
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",
"-Wl,-exported_symbol,_XLAExecuteSharded",
"-Wl,-exported_symbol,_ClientGetPlatformName",
"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
]}),
deps = [
"@enzyme//:EnzymeMLIR",
@@ -550,13 +557,15 @@ cc_library(
"@xla//xla/mlir/utils:type_util",
"@stablehlo//:stablehlo_capi_objects",
"@stablehlo//:chlo_capi_objects",
"@shardy//shardy/integrations/c:sdy_capi_objects",
"@com_google_absl//absl/hash:hash",
"@com_google_absl//absl/log:initialize",
"@com_google_absl//absl/log:globals",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPILLVMObjects",
"@jax//jaxlib/mosaic:tpu_dialect_capi_objects",
"@jax//jaxlib/triton:triton_dialect_capi_objects",
"@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
] + select({
"@xla//xla/tsl:is_cuda_enabled_and_oss":[
"@xla//xla/stream_executor/cuda:all_runtime",
@@ -568,7 +577,6 @@ cc_library(
"@xla//xla/backends/profiler/gpu:device_tracer",
],
"//conditions:default": [
"@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
],
}) + if_rocm([
"@xla//xla/service/gpu:amdgpu_compiler",
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "c38ca3f187ef11de6b2292f3cc55c5eb60530d15"
ENZYMEXLA_COMMIT = "d89468ed883ca18c04346eec10f784bbe2b754fc"
ENZYMEXLA_SHA256 = ""

http_archive(