From 40937bc4f148f313d3d934485683235d199e404e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 08:15:45 +0200 Subject: [PATCH 01/25] Init prototype of IFRT C-API --- deps/ReactantExtra/API.cpp | 258 ++++++++++++++++++++++++++++++++++++- deps/ReactantExtra/BUILD | 1 + 2 files changed, 258 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 08619e5de3..f4c4daab33 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -50,7 +50,6 @@ #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/python/ifrt/executable.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/python/ifrt/hlo/hlo_program.h" @@ -61,6 +60,16 @@ #include "llvm-c/TargetMachine.h" +// IFRT +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/executable.h" + using namespace mlir; using namespace llvm; using namespace xla; @@ -471,3 +480,250 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::enzyme::registerRemoveTransformPass(); mlir::enzyme::registerEnzymeJaxTransformExtension(registry); } + +/* xla::ifrt::DType */ +extern "C" ifrt::DType* ifrt_dtype(ifrt::DType::Kind kind) { + return new ifrt::DType(kind); +} + +// extern "C" void ifrt_dtype_free(ifrt::DType* dtype) { +// delete dtype; +// } + +extern "C" ifrt::DType::Kind ifrt_dtype_kind(ifrt::DType* dtype) { + return dtype->kind(); +} + +extern "C" bool ifrt_dtype_eq(ifrt::DType* dtype1, ifrt::DType* dtype2) { + return *dtype1 == *dtype2; +} + +extern "C" bool ifrt_dtype_ne(ifrt::DType* dtype1, ifrt::DType* dtype2) { + return *dtype1 != *dtype2; +} + +// Returns -1 if not aligned to a byte boundary or there is no fixed size +extern "C" int ifrt_dtype_byte_size(ifrt::DType* dtype) { + auto byte_size = dtype->byte_size(); + if (byte_size.has_value()) { + return byte_size.value(); + } + return -1; +} + +// Returns -1 if there is no fixed size +extern "C" int ifrt_dtype_bit_size(ifrt::DType* dtype) { + auto bit_size = dtype->bit_size(); + if (bit_size.has_value()) { + return bit_size.value(); + } + return -1; +} + +extern "C" const char* ifrt_dtype_debug_string(ifrt::DType* dtype) { + return cstr_from_string(dtype->DebugString()); +} + +/* xla::ifrt::Shape */ +extern "C" ifrt::Shape* ifrt_shape(const int64_t* dims, size_t dims_size) { + return new ifrt::Shape(absl::Span(dims, dims_size)); +} + +// extern "C" void ifrt_shape_free(ifrt::Shape* shape) { +// delete shape; +// } + +extern "C" const int64_t* ifrt_shape_dims(ifrt::Shape* shape) { + return shape->dims().data(); +} + +extern "C" int64_t ifrt_shape_dims_num_elements(ifrt::Shape* shape) { + return shape->num_elements(); +} + +extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { + return cstr_from_string(shape->DebugString()); +} + +// TODO xla::ifrt::DynamicShape + +/* xla::ifrt::MemoryKind */ +// extern "C" ifrt::MemoryKind* ifrt_memorykind() { +// return new ifrt::MemoryKind(); +// } + +// extern "C" ifrt::MemoryKind* ifrt_memorykind_from_string(const char* kind) { +// return new ifrt::MemoryKind(std::optional{kind}); +// } + +// extern "C" void ifrt_memorykind_free(ifrt::MemoryKind* memory_kind) { +// delete memory_kind; +// } + +/* xla::ifrt::Memory */ +extern "C" ifrt::Memory* ifrt_memory() { + return new ifrt::Memory(); +} + +// extern "C" void ifrt_memory_free(ifrt::Memory* memory) { +// delete memory; +// } + +// MemoryId is a struct with a single int32_t field --> check out xla/python/ifrt/memory.h +extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory* memory) { + return memory->Id(); +} + +// TODO ifrt_memory_kind + +extern "C" const char* ifrt_memory_to_string(ifrt::Memory* memory) { + return cstr_from_string(memory->ToString()); +} + +extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { + return cstr_from_string(memory->DebugString()); +} + +// TODO ifrt_memory_devices + +/* xla::ifrt::Device */ +extern "C" ifrt::Device* ifrt_device() { + return new ifrt::Device(); +} + +// extern "C" void ifrt_device_free(ifrt::Device* device) { +// delete device; +// } + +extern "C" ifrt::Client* ifrt_device_client(ifrt::Device* device) { + return device->client(); +} + +// DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h +extern "C" ifrt::DeviceId ifrt_device_id(ifrt::Device* device) { + return device->id(); +} + +// TODO ifrt_device_attributes + +extern "C" const char* ifrt_device_kind(ifrt::Device* device) { + return cstr_from_string(device->kind()); +} + +extern "C" const char* ifrt_device_to_string(ifrt::Device* device) { + return cstr_from_string(device->ToString()); +} + +extern "C" const char* ifrt_device_from_string(ifrt::Client* client) { + return cstr_from_string(client->DebugString()); +} + +extern "C" ifrt::Memory* ifrt_device_default_memory(ifrt::Device* device, char** error) { + return unwrap_absl_statusor(device->DefaultMemory(), error); +} + +// TODO ifrt_device_memories + +extern "C" bool ifrt_device_is_addressable(ifrt::Device* device) { + return device->IsAddressable(); +} + +extern "C" int ifrt_device_process_index(ifrt::Device* device) { + return device->process_index(); +} + +/* xla::ifrt::Sharding */ +// TODO ifrt_sharding_devices +// TODO ifrt_sharding_memory_kind + +// extern "C" void ifrt_sharding_disassemble(ifrt::Sharding* sharding, ifrt::Shape* shape, char** error) { +// auto status = sharding->Disassemble(*shape); +// if (!status.ok()) { +// auto str = status.message(); +// char* err = (char*)malloc(str.size()+1); +// memcpy(err, str.data(), str.size()+1); +// *error = err; +// } +// } + +// TODO ifrt_sharding_disassemble_dynamic_shape +// TODO ifrt_sharding_index_domains + +extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { + return cstr_from_string(sharding->DebugString()); +} + +/* xla::ifrt::Array */ +extern "C" ifrt::Array* ifrt_array() { + return new ifrt::Array(); +} + +// extern "C" void ifrt_array_free(ifrt::Array* array) { +// delete array; +// } + +extern "C" ifrt::DType ifrt_array_dtype(ifrt::Array* array) { + return array->dtype(); +} + +// ... + +/* xla::ifrt::Client */ +extern "C" int ifrt_client_device_count(ifrt::Client* client) { + return client->device_count(); +} + +extern "C" int ifrt_client_addressable_device_count(ifrt::Client* client) { + return client->addressable_device_count(); +} + +extern "C" ifrt::Device* ifrt_client_devices(ifrt::Client* client) { + return client->devices().data(); +} + +extern "C" ifrt::Device* ifrt_client_addressable_devices(ifrt::Client* client) { + return client->addressable_devices().data(); +} + +extern "C" int ifrt_client_process_index(ifrt::Client* client) { + return client->process_index(); +} + +// TODO Client::GetDefaultDeviceAssignment + +extern "C" ifrt::Device* ifrt_client_lookup_device(ifrt::Client* client, int device_id, **) { + return xla::ValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); +} + +extern "C" ifrt::Device* ifrt_client_lookup_addressable_device(ifrt::Client* client, int device_id, **) { + return xla::ValueOrThrow(client->LookupAddressableDevice(ifrt::DeviceId(device_id))); +} + +extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { + return client->GetDefaultCompiler(); +} + +// TODO ifrt_client_topology_for_devices +// TODO ifrt_client_default_layout_for_device + +// auxiliar functions +template +const char* cstr_from_string(T text) { + char* cstr = (char*)malloc(text.size() + 1); + memcpy(cstr, text.data(), text.size()); + cstr[text.size()] = '\0'; + return cstr; +} + +template +T* unwrap_absl_statusor(absl::StatusOr status, char** error_msg) { + *error_msg = nullptr; + if (!status.ok()) { + auto str = pluginLoad.status().message(); + char* err = (char*)malloc(str.size()+1); + memcpy(err, str.data(), str.size()+1); + *error_msg = err; + return nullptr; + } + return status.value(); +} diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 61257a421b..c7f7960815 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -301,6 +301,7 @@ cc_library( "-Wl,-exported_symbol,_XLAExecute", "-Wl,-exported_symbol,_RegisterDialects", "-Wl,-exported_symbol,_InitializeRegistryAndPasses", +"-Wl,-exported_symbol,_ifrt_*", ]}), deps = [ "@enzyme//:EnzymeMLIR", From 84470161f6baeaeedb7a996a4d8e724ca6d7a41f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 14:56:29 +0200 Subject: [PATCH 02/25] Add some C-API for `Executable`, `LoadedExecutable` --- deps/ReactantExtra/API.cpp | 90 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f4c4daab33..26164bc1d2 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -706,6 +706,96 @@ extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { // TODO ifrt_client_topology_for_devices // TODO ifrt_client_default_layout_for_device +/* xla::ifrt::Executable */ +extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { + return cstr_from_string(executable->name()); +} + +extern "C" const char* ifrt_executable_fingerprint(ifrt::Executable* executable) { + auto result = xla::ValueOrThrow(executable->fingerprint()); + if (!result.has_value()) return ""; + return cstr_from_string(result.value()); +} + +extern "C" const char* ifrt_executable_serialize(ifrt::Executable* executable) { + return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); +} + +extern "C" int ifrt_executable_num_devices(ifrt::Executable* executable) { + return executable->num_devices(); +} + +extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { + return executable->SizeOfGeneratedCodeInBytes(); +} + +// TODO xla::ifrt::GetCompiledMemoryStats +// TODO xla::ifrt::GetParameterShardings +// TODO xla::ifrt::GetOutputShardings +// TODO xla::ifrt::GetParameterLayouts +// TODO xla::ifrt::GetOutputLayouts +// TODO xla::ifrt::GetHloModules +// TODO xla::ifrt::GetCostAnalysis + +extern "C" ifrt::Client* ifrt_loadedexecutable_client(ifrt::LoadedExecutable* executable) { + return executable->client(); +} + +extern "C" const char* ifrt_loadedexecutable_name(ifrt::LoadedExecutable* executable) { + return cstr_from_string(executable->name()); +} + +extern "C" const char* ifrt_loadedexecutable_fingerprint(ifrt::LoadedExecutable* executable) { + auto result = xla::ValueOrThrow(executable->fingerprint()); + if (!result.has_value()) return ""; + return cstr_from_string(result.value()); +} + +extern "C" const char* ifrt_loadedexecutable_serialize(ifrt::LoadedExecutable* executable) { + return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); +} + +extern "C" ifrt::Future<>* ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable* executable) { + return executable->GetReadyFuture(); +} + +extern "C" int ifrt_loadedexecutable_num_devices(ifrt::LoadedExecutable* executable) { + return executable->num_devices(); +} + +extern "C" int64_t ifrt_loadedexecutable_size(ifrt::LoadedExecutable* executable) { + return executable->SizeOfGeneratedCodeInBytes(); +} + +// TODO xla::ifrt::GetCompiledMemoryStats +// TODO xla::ifrt::GetParameterShardings +// TODO xla::ifrt::GetOutputShardings +// TODO xla::ifrt::GetParameterLayouts +// TODO xla::ifrt::GetOutputLayouts +// TODO xla::ifrt::GetHloModules +// TODO xla::ifrt::GetOutputMemoryKinds +// TODO xla::ifrt::GetCostAnalysis + +// extern "C" ifrt::LoadedExecutable::ExecuteResult* ifrt_loadedexecutable_execute(ifrt::LoadedExecutable* executable, ifrt::Array** args, size_t args_size, ifrt::Array** results, size_t results_size, ifrt::Future<*>** futures, size_t futures_size) { +// std::vector arguments(args, args + args_size); +// std::vector result(results, results + results_size); +// std::vector*> future(futures, futures + futures_size); +// return xla::ValueOrThrow(executable->Execute(arguments, result, future)); +// } + +extern "C" ifrt::Future<> ifrt_loadedexecutable_delete(ifrt::LoadedExecutable* executable) { + return executable->Delete(); +} + +extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executable) { + return executable->IsDeleted(); +} + +// TODO ifrt::LoadedExecutable::addressable_device_logical_ids +// TODO ifrt::LoadedExecutable::addressable_devices + +// TODO auxiliary functions for ifrt::LoadedExecutable::ExecuteResult + // auxiliar functions template const char* cstr_from_string(T text) { From e04b4b312405e216fca331fb34de508d8e9e7220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 15:35:33 +0200 Subject: [PATCH 03/25] Add C-API for `DynamicShape` --- deps/ReactantExtra/API.cpp | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 26164bc1d2..8e14f823ad 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -545,7 +545,33 @@ extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { return cstr_from_string(shape->DebugString()); } -// TODO xla::ifrt::DynamicShape +extern "C" ifrt::DynamicShape* ifrt_dynamicshape_create(ifrt::Shape* shape, bool dynamic_dims_mask) { + std::vector bool_vector(dynamic_dims_mask, dynamic_dims_mask + shape->dims().size()); + auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(bool_vector)); + return new ifrt::DynamicShape(*shape, tag); +} + +// TODO ifrt::DynamicShape::GetTag + +extern "C" bool ifrt_dynamicshape_eq(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { + return *shape1 == *shape2; +} + +extern "C" bool ifrt_dynamicshape_ne(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { + return *shape1 != *shape2; +} + +extern "C" ifrt::Shape* ifrt_dynamicshape_get_padded_shape(ifrt::DynamicShape* shape) { + return xla::ValueOrThrow(shape->GetPaddedShape()).release(); +} + +extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape* shape, int dimension) { + return shape->IsDynamicDim(dimension); +} + +extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) { + return cstr_from_string(shape->DebugString()); +} /* xla::ifrt::MemoryKind */ // extern "C" ifrt::MemoryKind* ifrt_memorykind() { From 64c088c805b4d14c6c83e391e0c6b475a0d471f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 16:40:59 +0200 Subject: [PATCH 04/25] Add C-API bindings to `Index`, `IndexDomain` --- deps/ReactantExtra/API.cpp | 107 +++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8e14f823ad..41d7ba96f4 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -573,6 +573,113 @@ extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) return cstr_from_string(shape->DebugString()); } +/* xla::ifrt::Index */ +extern "C" ifrt::Index* ifrt_index(const int64_t* elements, size_t elements_size) { + return new ifrt::Index(absl::Span(elements, elements_size)); +} + +extern "C" void ifrt_index_free(ifrt::Index* index) { + delete index; +} + +extern "C" ifrt::Index* ifrt_index_zeros(int num_elements) { + return new ifrt::Index(ifrt::Index::Zeros(num_elements)); +} + +extern "C" const int64_t* ifrt_index_elements(ifrt::Index* index) { + return index->elements().data(); +} + +extern "C" int ifrt_index_count(ifrt::Index* index) { + return index->elements().size(); +} + +extern "C" bool ifrt_index_eq(ifrt::Index* index1, ifrt::Index* index2) { + return *index1 == *index2; +} + +extern "C" bool ifrt_index_ne(ifrt::Index* index1, ifrt::Index* index2) { + return *index1 != *index2; +} + +extern "C" ifrt::Index* ifrt_index_add(ifrt::Index* index, ifrt::Index* offset) { + return new ifrt::Index(*index + *offset); +} + +extern "C" ifrt::Index* ifrt_index_sub(ifrt::Index* index, ifrt::Index* offset) { + return new ifrt::Index(*index - *offset); +} + +// WARN we're not checking if the multiplier has the same size as the index +extern "C" ifrt::Index* ifrt_index_mul(ifrt::Index* index, const int64_t* multiplier) { + return new ifrt::Index(*index * absl::Span(multiplier, ifrt_index_count(index))); +} + +extern "C" void ifrt_index_add_inplace(ifrt::Index* index, ifrt::Index* offset) { + *index += *offset; +} + +extern "C" void ifrt_index_sub_inplace(ifrt::Index* index, ifrt::Index* offset) { + *index -= *offset; +} + +extern "C" void ifrt_index_mul_inplace(ifrt::Index* index, const int64_t* multiplier) { + *index *= absl::Span(multiplier, ifrt_index_count(index)); +} + +extern "C" const char* ifrt_index_debug_string(ifrt::Index* index) { + return cstr_from_string(index->DebugString()); +} + +/* xla::ifrt::IndexDomain */ +extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor(ifrt::Shape* shape) { + return new ifrt::IndexDomain(*shape); +} + +extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor_with_origin(ifrt::Index* origin, ifrt::Shape* shape) { + return new ifrt::IndexDomain(*origin, *shape); +} + +extern "C" void ifrt_indexdomain_free(ifrt::IndexDomain* index_domain) { + delete index_domain; +} + +extern "C" const ifrt::Index* ifrt_indexdomain_origin(ifrt::IndexDomain* index_domain) { + return &index_domain->origin(); +} + +extern "C" const ifrt::Shape* ifrt_indexdomain_shape(ifrt::IndexDomain* index_domain) { + return &index_domain->shape(); +} + +extern "C" bool ifrt_indexdomain_eq(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { + return *index_domain1 == *index_domain2; +} + +extern "C" bool ifrt_indexdomain_ne(ifrt::IndexDomain* index_domain1, ifrt::IndexDomain* index_domain2) { + return *index_domain1 != *index_domain2; +} + +extern "C" ifrt::IndexDomain* ifrt_indexdomain_add(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { + return new ifrt::IndexDomain(*index_domain + *offset); +} + +extern "C" ifrt::IndexDomain* ifrt_indexdomain_sub(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { + return new ifrt::IndexDomain(*index_domain - *offset); +} + +extern "C" void ifrt_indexdomain_add_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { + *index_domain += *offset; +} + +extern "C" void ifrt_indexdomain_sub_inplace(ifrt::IndexDomain* index_domain, ifrt::Index* offset) { + *index_domain -= *offset; +} + +extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_domain) { + return cstr_from_string(index_domain->DebugString()); +} + /* xla::ifrt::MemoryKind */ // extern "C" ifrt::MemoryKind* ifrt_memorykind() { // return new ifrt::MemoryKind(); From fe9edf8080b169348d85fbc8ebe0d018a491adaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 16:46:34 +0200 Subject: [PATCH 05/25] small fixes --- deps/ReactantExtra/API.cpp | 51 ++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 41d7ba96f4..d9b4f2b5f1 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -63,6 +63,8 @@ // IFRT #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/index.h" +#include "xla/python/ifrt/index_domain.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/sharding.h" @@ -481,7 +483,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::enzyme::registerEnzymeJaxTransformExtension(registry); } -/* xla::ifrt::DType */ +#pragma mark xla::ifrt::DType extern "C" ifrt::DType* ifrt_dtype(ifrt::DType::Kind kind) { return new ifrt::DType(kind); } @@ -524,7 +526,7 @@ extern "C" const char* ifrt_dtype_debug_string(ifrt::DType* dtype) { return cstr_from_string(dtype->DebugString()); } -/* xla::ifrt::Shape */ +#pragma mark xla::ifrt::Shape extern "C" ifrt::Shape* ifrt_shape(const int64_t* dims, size_t dims_size) { return new ifrt::Shape(absl::Span(dims, dims_size)); } @@ -545,6 +547,7 @@ extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { return cstr_from_string(shape->DebugString()); } +#pragma mark xla::ifrt::DynamicShape extern "C" ifrt::DynamicShape* ifrt_dynamicshape_create(ifrt::Shape* shape, bool dynamic_dims_mask) { std::vector bool_vector(dynamic_dims_mask, dynamic_dims_mask + shape->dims().size()); auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(bool_vector)); @@ -573,7 +576,7 @@ extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) return cstr_from_string(shape->DebugString()); } -/* xla::ifrt::Index */ +#pragma mark xla::ifrt::Index extern "C" ifrt::Index* ifrt_index(const int64_t* elements, size_t elements_size) { return new ifrt::Index(absl::Span(elements, elements_size)); } @@ -631,7 +634,7 @@ extern "C" const char* ifrt_index_debug_string(ifrt::Index* index) { return cstr_from_string(index->DebugString()); } -/* xla::ifrt::IndexDomain */ +#pragma mark xla::ifrt::IndexDomain extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor(ifrt::Shape* shape) { return new ifrt::IndexDomain(*shape); } @@ -680,7 +683,7 @@ extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_do return cstr_from_string(index_domain->DebugString()); } -/* xla::ifrt::MemoryKind */ +#pragma mark xla::ifrt::MemoryKind // extern "C" ifrt::MemoryKind* ifrt_memorykind() { // return new ifrt::MemoryKind(); // } @@ -693,7 +696,7 @@ extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_do // delete memory_kind; // } -/* xla::ifrt::Memory */ +#pragma mark xla::ifrt::Memory extern "C" ifrt::Memory* ifrt_memory() { return new ifrt::Memory(); } @@ -719,7 +722,7 @@ extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { // TODO ifrt_memory_devices -/* xla::ifrt::Device */ +#pragma mark xla::ifrt::Device extern "C" ifrt::Device* ifrt_device() { return new ifrt::Device(); } @@ -765,7 +768,7 @@ extern "C" int ifrt_device_process_index(ifrt::Device* device) { return device->process_index(); } -/* xla::ifrt::Sharding */ +#pragma mark xla::ifrt::Sharding // TODO ifrt_sharding_devices // TODO ifrt_sharding_memory_kind @@ -786,7 +789,7 @@ extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { return cstr_from_string(sharding->DebugString()); } -/* xla::ifrt::Array */ +#pragma mark xla::ifrt::Array extern "C" ifrt::Array* ifrt_array() { return new ifrt::Array(); } @@ -801,7 +804,7 @@ extern "C" ifrt::DType ifrt_array_dtype(ifrt::Array* array) { // ... -/* xla::ifrt::Client */ +#pragma mark xla::ifrt::Client extern "C" int ifrt_client_device_count(ifrt::Client* client) { return client->device_count(); } @@ -822,7 +825,7 @@ extern "C" int ifrt_client_process_index(ifrt::Client* client) { return client->process_index(); } -// TODO Client::GetDefaultDeviceAssignment +// TODO xla::ifrt::Client::GetDefaultDeviceAssignment extern "C" ifrt::Device* ifrt_client_lookup_device(ifrt::Client* client, int device_id, **) { return xla::ValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); @@ -839,7 +842,7 @@ extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { // TODO ifrt_client_topology_for_devices // TODO ifrt_client_default_layout_for_device -/* xla::ifrt::Executable */ +#pragma mark xla::ifrt::Executable extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { return cstr_from_string(executable->name()); } @@ -862,14 +865,15 @@ extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { return executable->SizeOfGeneratedCodeInBytes(); } -// TODO xla::ifrt::GetCompiledMemoryStats -// TODO xla::ifrt::GetParameterShardings -// TODO xla::ifrt::GetOutputShardings -// TODO xla::ifrt::GetParameterLayouts -// TODO xla::ifrt::GetOutputLayouts -// TODO xla::ifrt::GetHloModules -// TODO xla::ifrt::GetCostAnalysis +// TODO xla::ifrt::Executable::GetCompiledMemoryStats +// TODO xla::ifrt::Executable::GetParameterShardings +// TODO xla::ifrt::Executable::GetOutputShardings +// TODO xla::ifrt::Executable::GetParameterLayouts +// TODO xla::ifrt::Executable::GetOutputLayouts +// TODO xla::ifrt::Executable::GetHloModules +// TODO xla::ifrt::Executable::GetCostAnalysis +#pragma mark xla::ifrt::LoadedExecutable extern "C" ifrt::Client* ifrt_loadedexecutable_client(ifrt::LoadedExecutable* executable) { return executable->client(); } @@ -924,12 +928,15 @@ extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executa return executable->IsDeleted(); } -// TODO ifrt::LoadedExecutable::addressable_device_logical_ids -// TODO ifrt::LoadedExecutable::addressable_devices +// TODO xla::ifrt::LoadedExecutable::addressable_device_logical_ids +// TODO xla::ifrt::LoadedExecutable::addressable_devices + +// TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult -// TODO auxiliary functions for ifrt::LoadedExecutable::ExecuteResult +#pragma mark xla::ifrt::CustomCallProgram // auxiliar functions +#pragma mark - template const char* cstr_from_string(T text) { char* cstr = (char*)malloc(text.size() + 1); From 0876a78085b1cb50f97880325665699547ea8d58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 16:56:23 +0200 Subject: [PATCH 06/25] small changes --- deps/ReactantExtra/API.cpp | 51 +++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index d9b4f2b5f1..1e3d2e21f6 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -71,6 +71,7 @@ #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/compiler.h" using namespace mlir; using namespace llvm; @@ -484,13 +485,13 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { } #pragma mark xla::ifrt::DType -extern "C" ifrt::DType* ifrt_dtype(ifrt::DType::Kind kind) { +extern "C" ifrt::DType* ifrt_dtype_ctor(ifrt::DType::Kind kind) { return new ifrt::DType(kind); } -// extern "C" void ifrt_dtype_free(ifrt::DType* dtype) { -// delete dtype; -// } +extern "C" void ifrt_dtype_free(ifrt::DType* dtype) { + delete dtype; +} extern "C" ifrt::DType::Kind ifrt_dtype_kind(ifrt::DType* dtype) { return dtype->kind(); @@ -527,13 +528,13 @@ extern "C" const char* ifrt_dtype_debug_string(ifrt::DType* dtype) { } #pragma mark xla::ifrt::Shape -extern "C" ifrt::Shape* ifrt_shape(const int64_t* dims, size_t dims_size) { +extern "C" ifrt::Shape* ifrt_shape_ctor(const int64_t* dims, size_t dims_size) { return new ifrt::Shape(absl::Span(dims, dims_size)); } -// extern "C" void ifrt_shape_free(ifrt::Shape* shape) { -// delete shape; -// } +extern "C" void ifrt_shape_free(ifrt::Shape* shape) { + delete shape; +} extern "C" const int64_t* ifrt_shape_dims(ifrt::Shape* shape) { return shape->dims().data(); @@ -548,12 +549,16 @@ extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { } #pragma mark xla::ifrt::DynamicShape -extern "C" ifrt::DynamicShape* ifrt_dynamicshape_create(ifrt::Shape* shape, bool dynamic_dims_mask) { +extern "C" ifrt::DynamicShape* ifrt_dynamicshape_ctor(ifrt::Shape* shape, bool dynamic_dims_mask) { std::vector bool_vector(dynamic_dims_mask, dynamic_dims_mask + shape->dims().size()); auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(bool_vector)); return new ifrt::DynamicShape(*shape, tag); } +extern "C" void ifrt_dynamicshape_free(ifrt::DynamicShape* shape) { + delete shape; +} + // TODO ifrt::DynamicShape::GetTag extern "C" bool ifrt_dynamicshape_eq(ifrt::DynamicShape* shape1, ifrt::DynamicShape* shape2) { @@ -577,18 +582,18 @@ extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) } #pragma mark xla::ifrt::Index -extern "C" ifrt::Index* ifrt_index(const int64_t* elements, size_t elements_size) { +extern "C" ifrt::Index* ifrt_index_ctor(const int64_t* elements, size_t elements_size) { return new ifrt::Index(absl::Span(elements, elements_size)); } -extern "C" void ifrt_index_free(ifrt::Index* index) { - delete index; -} - extern "C" ifrt::Index* ifrt_index_zeros(int num_elements) { return new ifrt::Index(ifrt::Index::Zeros(num_elements)); } +extern "C" void ifrt_index_free(ifrt::Index* index) { + delete index; +} + extern "C" const int64_t* ifrt_index_elements(ifrt::Index* index) { return index->elements().data(); } @@ -697,13 +702,13 @@ extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_do // } #pragma mark xla::ifrt::Memory -extern "C" ifrt::Memory* ifrt_memory() { +extern "C" ifrt::Memory* ifrt_memory_ctor() { return new ifrt::Memory(); } -// extern "C" void ifrt_memory_free(ifrt::Memory* memory) { -// delete memory; -// } +extern "C" void ifrt_memory_free(ifrt::Memory* memory) { + delete memory; +} // MemoryId is a struct with a single int32_t field --> check out xla/python/ifrt/memory.h extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory* memory) { @@ -723,13 +728,13 @@ extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { // TODO ifrt_memory_devices #pragma mark xla::ifrt::Device -extern "C" ifrt::Device* ifrt_device() { +extern "C" ifrt::Device* ifrt_device_ctor() { return new ifrt::Device(); } -// extern "C" void ifrt_device_free(ifrt::Device* device) { -// delete device; -// } +extern "C" void ifrt_device_free(ifrt::Device* device) { + delete device; +} extern "C" ifrt::Client* ifrt_device_client(ifrt::Device* device) { return device->client(); @@ -790,7 +795,7 @@ extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { } #pragma mark xla::ifrt::Array -extern "C" ifrt::Array* ifrt_array() { +extern "C" ifrt::Array* ifrt_array_ctor() { return new ifrt::Array(); } From 89f73fae399fc0a7659df3d1d9e1081ad499ca7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 17:14:37 +0200 Subject: [PATCH 07/25] Add C-API for `Compiler` --- deps/ReactantExtra/API.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 1e3d2e21f6..b423734e47 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -940,6 +940,17 @@ extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executa #pragma mark xla::ifrt::CustomCallProgram +#pragma mark xla::ifrt::Compiler +extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program, char** error) { + // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default + return unwrap_absl_statusor(compiler->Compile(*program, *options, ifrt::CompileOptions()), error); +} + +extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler* compiler, const char* data, size_t size, char** error) { + // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we don't use it and set directly to the default + return unwrap_absl_statusor(compiler->DeserializeLoadedExecutable(data, size, ifrt::DeserializeExecutableOptions()), error); +} + // auxiliar functions #pragma mark - template From a2b871fe7962a380e031307cc600915c3d8c866f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 15 Oct 2024 17:37:37 +0200 Subject: [PATCH 08/25] Add C-API for `MemoryKind` --- deps/ReactantExtra/API.cpp | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index b423734e47..0a0c057b40 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -689,17 +689,37 @@ extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_do } #pragma mark xla::ifrt::MemoryKind -// extern "C" ifrt::MemoryKind* ifrt_memorykind() { -// return new ifrt::MemoryKind(); -// } +// Pass a nullptr to create a `MemoryKind` with no memory chosen. +extern "C" ifrt::MemoryKind* ifrt_memorykind_ctor(const char* memory_kind) { + ifrt::MemoryKind tmp{}; + if (memory_kind != nullptr) + tmp = ifrt::MemoryKind(std::string(memory_kind)); + return new ifrt::MemoryKind(tmp); +} -// extern "C" ifrt::MemoryKind* ifrt_memorykind_from_string(const char* kind) { -// return new ifrt::MemoryKind(std::optional{kind}); -// } +extern "C" void ifrt_memorykind_free(ifrt::MemoryKind* memory_kind) { + delete memory_kind; +} -// extern "C" void ifrt_memorykind_free(ifrt::MemoryKind* memory_kind) { -// delete memory_kind; -// } +extern "C" bool ifrt_memorykind_eq(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { + return *mk1 == *mk2; +} + +extern "C" bool ifrt_memorykind_ne(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) { + return *mk1 != *mk2; +} + +extern "C" const char* ifrt_memorykind_string(ifrt::MemoryKind* memory_kind) { + return cstr_from_string(memory_kind->memory_kind()); +} + +extern "C" const char* ifrt_memorykind_debug_string(ifrt::MemoryKind* memory_kind) { + return cstr_from_string(memory_kind->DebugString()); +} + +extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) { + return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); +} #pragma mark xla::ifrt::Memory extern "C" ifrt::Memory* ifrt_memory_ctor() { From 77b9e8fb9f56a61749d7fde6853fd2c22db8f15f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 16 Oct 2024 10:30:24 +0200 Subject: [PATCH 09/25] more implementations --- deps/ReactantExtra/API.cpp | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 0a0c057b40..fe7324b42d 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -71,6 +71,7 @@ #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/compiler.h" using namespace mlir; @@ -735,7 +736,9 @@ extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory* memory) { return memory->Id(); } -// TODO ifrt_memory_kind +extern "C" const ifrt::MemoryKind* ifrt_memory_kind(ifrt::Memory* memory) { + return &(memory->Kind()); +} extern "C" const char* ifrt_memory_to_string(ifrt::Memory* memory) { return cstr_from_string(memory->ToString()); @@ -745,7 +748,10 @@ extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { return cstr_from_string(memory->DebugString()); } -// TODO ifrt_memory_devices +extern "C" std::tuple ifrt_memory_devices(ifrt::Memory* memory) { + auto devices = memory->Devices(); + return std::make_tuple; +} #pragma mark xla::ifrt::Device extern "C" ifrt::Device* ifrt_device_ctor() { @@ -960,6 +966,19 @@ extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executa #pragma mark xla::ifrt::CustomCallProgram +#pragma mark xla::ifrt::HloProgram +extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor() { + return new ifrt::HloProgram(); +} + +extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_module(mlir::ModuleOp* module) { + return new ifrt::HloProgram(module); +} + +extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { + return new ifrt::HloProgram(context, module); +} + #pragma mark xla::ifrt::Compiler extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program, char** error) { // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default From 070933579a4dc813b21367b4d3b694716c3952f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 16 Oct 2024 19:13:31 +0200 Subject: [PATCH 10/25] Refactor `#pragma region`s --- deps/ReactantExtra/API.cpp | 55 ++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index fe7324b42d..5117e47157 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -485,7 +485,9 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::enzyme::registerEnzymeJaxTransformExtension(registry); } -#pragma mark xla::ifrt::DType +#pragma region xla::ifrt + +#pragma region xla::ifrt::DType extern "C" ifrt::DType* ifrt_dtype_ctor(ifrt::DType::Kind kind) { return new ifrt::DType(kind); } @@ -527,8 +529,9 @@ extern "C" int ifrt_dtype_bit_size(ifrt::DType* dtype) { extern "C" const char* ifrt_dtype_debug_string(ifrt::DType* dtype) { return cstr_from_string(dtype->DebugString()); } +#pragma endregion -#pragma mark xla::ifrt::Shape +#pragma region xla::ifrt::Shape extern "C" ifrt::Shape* ifrt_shape_ctor(const int64_t* dims, size_t dims_size) { return new ifrt::Shape(absl::Span(dims, dims_size)); } @@ -548,8 +551,9 @@ extern "C" int64_t ifrt_shape_dims_num_elements(ifrt::Shape* shape) { extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { return cstr_from_string(shape->DebugString()); } +#pragma endregion -#pragma mark xla::ifrt::DynamicShape +#pragma region xla::ifrt::DynamicShape extern "C" ifrt::DynamicShape* ifrt_dynamicshape_ctor(ifrt::Shape* shape, bool dynamic_dims_mask) { std::vector bool_vector(dynamic_dims_mask, dynamic_dims_mask + shape->dims().size()); auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(bool_vector)); @@ -581,8 +585,9 @@ extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape* shape, int extern "C" const char* ifrt_dynamicshape_debug_string(ifrt::DynamicShape* shape) { return cstr_from_string(shape->DebugString()); } +#pragma endregion -#pragma mark xla::ifrt::Index +#pragma region xla::ifrt::Index extern "C" ifrt::Index* ifrt_index_ctor(const int64_t* elements, size_t elements_size) { return new ifrt::Index(absl::Span(elements, elements_size)); } @@ -639,8 +644,9 @@ extern "C" void ifrt_index_mul_inplace(ifrt::Index* index, const int64_t* multip extern "C" const char* ifrt_index_debug_string(ifrt::Index* index) { return cstr_from_string(index->DebugString()); } +#pragma endregion -#pragma mark xla::ifrt::IndexDomain +#pragma region xla::ifrt::IndexDomain extern "C" ifrt::IndexDomain* ifrt_indexdomain_ctor(ifrt::Shape* shape) { return new ifrt::IndexDomain(*shape); } @@ -688,8 +694,9 @@ extern "C" void ifrt_indexdomain_sub_inplace(ifrt::IndexDomain* index_domain, if extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_domain) { return cstr_from_string(index_domain->DebugString()); } +#pragma endregion -#pragma mark xla::ifrt::MemoryKind +#pragma region xla::ifrt::MemoryKind // Pass a nullptr to create a `MemoryKind` with no memory chosen. extern "C" ifrt::MemoryKind* ifrt_memorykind_ctor(const char* memory_kind) { ifrt::MemoryKind tmp{}; @@ -721,8 +728,9 @@ extern "C" const char* ifrt_memorykind_debug_string(ifrt::MemoryKind* memory_kin extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) { return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); } +#pragma endregion -#pragma mark xla::ifrt::Memory +#pragma region xla::ifrt::Memory extern "C" ifrt::Memory* ifrt_memory_ctor() { return new ifrt::Memory(); } @@ -752,8 +760,9 @@ extern "C" std::tuple ifrt_memory_devices(ifrt::Me auto devices = memory->Devices(); return std::make_tuple; } +#pragma endregion -#pragma mark xla::ifrt::Device +#pragma region xla::ifrt::Device extern "C" ifrt::Device* ifrt_device_ctor() { return new ifrt::Device(); } @@ -798,8 +807,9 @@ extern "C" bool ifrt_device_is_addressable(ifrt::Device* device) { extern "C" int ifrt_device_process_index(ifrt::Device* device) { return device->process_index(); } +#pragma endregion -#pragma mark xla::ifrt::Sharding +#pragma region xla::ifrt::Sharding // TODO ifrt_sharding_devices // TODO ifrt_sharding_memory_kind @@ -819,8 +829,9 @@ extern "C" int ifrt_device_process_index(ifrt::Device* device) { extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { return cstr_from_string(sharding->DebugString()); } +#pragma endregion -#pragma mark xla::ifrt::Array +#pragma region xla::ifrt::Array extern "C" ifrt::Array* ifrt_array_ctor() { return new ifrt::Array(); } @@ -834,8 +845,9 @@ extern "C" ifrt::DType ifrt_array_dtype(ifrt::Array* array) { } // ... +#pragma endregion -#pragma mark xla::ifrt::Client +#pragma region xla::ifrt::Client extern "C" int ifrt_client_device_count(ifrt::Client* client) { return client->device_count(); } @@ -872,8 +884,9 @@ extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { // TODO ifrt_client_topology_for_devices // TODO ifrt_client_default_layout_for_device +#pragma endregion -#pragma mark xla::ifrt::Executable +#pragma region xla::ifrt::Executable extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { return cstr_from_string(executable->name()); } @@ -903,8 +916,9 @@ extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { // TODO xla::ifrt::Executable::GetOutputLayouts // TODO xla::ifrt::Executable::GetHloModules // TODO xla::ifrt::Executable::GetCostAnalysis +#pragma endregion -#pragma mark xla::ifrt::LoadedExecutable +#pragma region xla::ifrt::LoadedExecutable extern "C" ifrt::Client* ifrt_loadedexecutable_client(ifrt::LoadedExecutable* executable) { return executable->client(); } @@ -963,10 +977,12 @@ extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executa // TODO xla::ifrt::LoadedExecutable::addressable_devices // TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult +#pragma endregion -#pragma mark xla::ifrt::CustomCallProgram +#pragma region xla::ifrt::CustomCallProgram +#pragma endregion -#pragma mark xla::ifrt::HloProgram +#pragma region xla::ifrt::HloProgram extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor() { return new ifrt::HloProgram(); } @@ -978,8 +994,9 @@ extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_module(mlir::ModuleOp* mo extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { return new ifrt::HloProgram(context, module); } +#pragma endregion -#pragma mark xla::ifrt::Compiler +#pragma region xla::ifrt::Compiler extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program, char** error) { // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default return unwrap_absl_statusor(compiler->Compile(*program, *options, ifrt::CompileOptions()), error); @@ -989,9 +1006,12 @@ extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(if // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we don't use it and set directly to the default return unwrap_absl_statusor(compiler->DeserializeLoadedExecutable(data, size, ifrt::DeserializeExecutableOptions()), error); } +#pragma endregion + +#pragma endregion // auxiliar functions -#pragma mark - +#pragma region utils template const char* cstr_from_string(T text) { char* cstr = (char*)malloc(text.size() + 1); @@ -1012,3 +1032,4 @@ T* unwrap_absl_statusor(absl::StatusOr status, char** error_msg) { } return status.value(); } +#pragma endregion From 49be8cfee1f577646b6b6d4be9f70de07ecd4dc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 20 Oct 2024 16:39:52 +0200 Subject: [PATCH 11/25] more methods --- deps/ReactantExtra/API.cpp | 98 +++++++++++++++++++++++++++++++++----- 1 file changed, 86 insertions(+), 12 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 5117e47157..5da565c383 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -69,6 +69,7 @@ #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/topology.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/hlo/hlo_program.h" @@ -847,6 +848,35 @@ extern "C" ifrt::DType ifrt_array_dtype(ifrt::Array* array) { // ... #pragma endregion +#pragma region xla::ifrt::Topology +extern "C" const char* ifrt_topology_platform_name(ifrt::Topology* topology) { + return cstr_from_string(topology->platform_name()); +} + +extern "C" const char* ifrt_topology_platform_version(ifrt::Topology* topology) { + return cstr_from_string(topology->platform_version()); +} + +// returns PjRtPlatformId which is a type alias for uint64_t +extern "C" uint64_t ifrt_topology_platform_id(ifrt::Topology* topology) { + return topology->platform_id(); +} + +extern "C" std::tuple ifrt_topology_device_descriptions(ifrt::Topology* topology) { + auto descriptions = topology->device_descriptions(); + return std::make_tuple(descriptions.size(), descriptions.data()); +} + +// TODO xla::ifrt::Topology::GetDefaultLayout + +extern "C" const char* ifrt_topology_serialize(ifrt::Topology* topology) { + return cstr_from_string(xla::ValueOrThrow(topology->Serialize())); +} + +// TODO xla::ifrt::Topology::Attributes + +#pragma endregion + #pragma region xla::ifrt::Client extern "C" int ifrt_client_device_count(ifrt::Client* client) { return client->device_count(); @@ -910,11 +940,32 @@ extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { } // TODO xla::ifrt::Executable::GetCompiledMemoryStats -// TODO xla::ifrt::Executable::GetParameterShardings -// TODO xla::ifrt::Executable::GetOutputShardings -// TODO xla::ifrt::Executable::GetParameterLayouts -// TODO xla::ifrt::Executable::GetOutputLayouts -// TODO xla::ifrt::Executable::GetHloModules + +extern "C" std::tuple ifrt_executable_parameter_shardings(ifrt::Executable* executable) { + auto shardings = unwrap_absl_statusor(executable->GetParameterShardings()); + return std::make_tuple(shardings.size(), shardings.data()); +} + +extern "C" std::tuple ifrt_executable_output_shardings(ifrt::Executable* executable) { + auto shardings = unwrap_absl_statusor(executable->GetOutputShardings()); + return std::make_tuple(shardings.size(), shardings.data()); +} + +extern "C" std::tuple ifrt_executable_parameter_layouts(ifrt::Executable* executable) { + auto layouts = unwrap_absl_statusor(executable->GetParameterLayouts()); + return std::make_tuple(layouts.size(), layouts.data()); +} + +extern "C" std::tuple ifrt_executable_output_layouts(ifrt::Executable* executable) { + auto layouts = unwrap_absl_statusor(executable->GetOutputLayouts()); + return std::make_tuple(layouts.size(), layouts.data()); +} + +extern "C" std::tuple ifrt_executable_hlo_modules(ifrt::Executable* executable) { + auto modules = unwrap_absl_statusor(executable->GetHloModules()); + return std::make_tuple(modules.size(), modules.data()); +} + // TODO xla::ifrt::Executable::GetCostAnalysis #pragma endregion @@ -950,11 +1001,32 @@ extern "C" int64_t ifrt_loadedexecutable_size(ifrt::LoadedExecutable* executable } // TODO xla::ifrt::GetCompiledMemoryStats -// TODO xla::ifrt::GetParameterShardings -// TODO xla::ifrt::GetOutputShardings -// TODO xla::ifrt::GetParameterLayouts -// TODO xla::ifrt::GetOutputLayouts -// TODO xla::ifrt::GetHloModules + +extern "C" std::tuple ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable* executable) { + auto shardings = unwrap_absl_statusor(executable->GetParameterShardings()); + return std::make_tuple(shardings.size(), shardings.data()); +} + +extern "C" std::tuple ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable* executable) { + auto shardings = unwrap_absl_statusor(executable->GetOutputShardings()); + return std::make_tuple(shardings.size(), shardings.data()); +} + +extern "C" std::tuple ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable* executable) { + auto layouts = unwrap_absl_statusor(executable->GetParameterLayouts()); + return std::make_tuple(layouts.size(), layouts.data()); +} + +extern "C" std::tuple ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable* executable) { + auto layouts = unwrap_absl_statusor(executable->GetOutputLayouts()); + return std::make_tuple(layouts.size(), layouts.data()); +} + +extern "C" std::tuple ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable* executable) { + auto modules = unwrap_absl_statusor(executable->GetHloModules()); + return std::make_tuple(modules.size(), modules.data()); +} + // TODO xla::ifrt::GetOutputMemoryKinds // TODO xla::ifrt::GetCostAnalysis @@ -973,8 +1045,10 @@ extern "C" bool ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable* executa return executable->IsDeleted(); } -// TODO xla::ifrt::LoadedExecutable::addressable_device_logical_ids -// TODO xla::ifrt::LoadedExecutable::addressable_devices +extern "C" std::tuple ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable* executable) { + auto devices = executable->addressable_devices(); + return std::make_tuple(devices.size(), devices.data()); +} // TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult #pragma endregion From 78eeaa69d112236780dbf27474cf40722f15d16e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 20 Oct 2024 17:21:36 +0200 Subject: [PATCH 12/25] more c methods --- deps/ReactantExtra/API.cpp | 102 ++++++++++++++++++++++++++++++++++--- 1 file changed, 96 insertions(+), 6 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 5da565c383..2ed380a1b7 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -61,6 +61,8 @@ #include "llvm-c/TargetMachine.h" // IFRT +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt/tuple.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/index.h" @@ -75,6 +77,11 @@ #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/compiler.h" +// IFRT - PJRT +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" + using namespace mlir; using namespace llvm; using namespace xla; @@ -488,6 +495,52 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { #pragma region xla::ifrt +#pragma region xla::ifrt::Value +extern "C" ifrt::Value* ifrt_value_ctor() { + return new ifrt::Value(); +} + +extern "C" void ifrt_value_free(ifrt::Value* value) { + delete value; +} + +extern "C" ifrt::Client* ifrt_value_client(ifrt::Value* value) { + return value->client(); +} + +extern "C" ifrt::Future<> ifrt_value_get_ready_future(ifrt::Value* value) { + return value->GetReadyFuture(); +} + +extern "C" ifrt::Future<> ifrt_value_delete(ifrt::Value* value) { + return value->Delete(); +} + +extern "C" bool ifrt_value_is_deleted(ifrt::Value* value) { + return value->IsDeleted(); +} + +extern "C" const char* ifrt_value_debug_string(ifrt::Value* value) { + return cstr_from_string(value->DebugString()); +} +#pragma endregion + +#pragma region xla::ifrt::Tuple +extern "C" ifrt::Tuple* ifrt_tuple_ctor() { + return new ifrt::Tuple(); +} + +extern "C" void ifrt_tuple_free(ifrt::Tuple* tuple) { + delete tuple; +} + +extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { + return tuple->Arity(); +} + +// TODO ifrt::Tuple::Unpack +#pragma endregion + #pragma region xla::ifrt::DType extern "C" ifrt::DType* ifrt_dtype_ctor(ifrt::DType::Kind kind) { return new ifrt::DType(kind); @@ -530,6 +583,16 @@ extern "C" int ifrt_dtype_bit_size(ifrt::DType* dtype) { extern "C" const char* ifrt_dtype_debug_string(ifrt::DType* dtype) { return cstr_from_string(dtype->DebugString()); } + +// xla::PrimitiveType is a enum, so we use int to represent it on Julia side +extern "C" xla::PrimitiveType ifrt_to_primitive_type(ifrt::DType* dtype) { + return xla::ValueOrThrow(ifrt::ToPrimitiveType(*dtype)); +} + +// xla::PrimitiveType is a enum, so we use int to represent it on Julia side +extern "C" ifrt::DType* ifrt_to_dtype(xla::PrimitiveType primitive_type) { + return new xla::ValueOrThrow(ifrt::ToDType(primitive_type)); +} #pragma endregion #pragma region xla::ifrt::Shape @@ -700,10 +763,9 @@ extern "C" const char* ifrt_indexdomain_debug_string(ifrt::IndexDomain* index_do #pragma region xla::ifrt::MemoryKind // Pass a nullptr to create a `MemoryKind` with no memory chosen. extern "C" ifrt::MemoryKind* ifrt_memorykind_ctor(const char* memory_kind) { - ifrt::MemoryKind tmp{}; - if (memory_kind != nullptr) - tmp = ifrt::MemoryKind(std::string(memory_kind)); - return new ifrt::MemoryKind(tmp); + if (memory_kind == nullptr) + return new ifrt::MemoryKind(); + return new ifrt::MemoryKind(std::string(memory_kind)); } extern "C" void ifrt_memorykind_free(ifrt::MemoryKind* memory_kind) { @@ -763,6 +825,24 @@ extern "C" std::tuple ifrt_memory_devices(ifrt::Me } #pragma endregion +#pragma region xla::ifrt::PjRtMemory +extern "C" ifrt::PjRtMemory* ifrt_pjrt_memory_ctor(ifrt::PjRtClient* client, xla::PjRtMemorySpace* memory_space) { + return new ifrt::PjRtMemory(client, memory_space); +} + +extern "C" void ifrt_pjrt_memory_free(ifrt::PjRtMemory* memory) { + delete memory; +} + +extern "C" ifrt::PjRtClient* ifrt_pjrt_memory_client(ifrt::PjRtMemory* memory) { + return memory->client(); +} + +extern "C" xla::PjRtMemorySpace* ifrt_pjrt_memory_space(ifrt::PjRtMemory* memory) { + return memory->memory_space(); +} +#pragma endregion + #pragma region xla::ifrt::Device extern "C" ifrt::Device* ifrt_device_ctor() { return new ifrt::Device(); @@ -877,6 +957,16 @@ extern "C" const char* ifrt_topology_serialize(ifrt::Topology* topology) { #pragma endregion +#pragma region xla::ifrt::PjRtTopology +extern "C" ifrt::PjRtTopology* ifrt_pjrt_topology_ctor(const xla::PjRtTopologyDescription* description) { + return new ifrt::PjRtTopology(description); +} + +extern "C" const xla::PjRtTopologyDescription* ifrt_pjrt_topology_description(ifrt::PjRtTopology* topology) { + return topology->description(); +} +#pragma endregion + #pragma region xla::ifrt::Client extern "C" int ifrt_client_device_count(ifrt::Client* client) { return client->device_count(); @@ -1027,8 +1117,8 @@ extern "C" std::tuple ifrt_loadedexecutable_hlo_module return std::make_tuple(modules.size(), modules.data()); } -// TODO xla::ifrt::GetOutputMemoryKinds -// TODO xla::ifrt::GetCostAnalysis +// TODO xla::ifrt::LoadedExecutable::GetOutputMemoryKinds +// TODO xla::ifrt::LoadedExecutable::GetCostAnalysis // extern "C" ifrt::LoadedExecutable::ExecuteResult* ifrt_loadedexecutable_execute(ifrt::LoadedExecutable* executable, ifrt::Array** args, size_t args_size, ifrt::Array** results, size_t results_size, ifrt::Future<*>** futures, size_t futures_size) { // std::vector arguments(args, args + args_size); From 65ed193b995bc064716b06c0fea814f227f28826 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 20 Oct 2024 20:00:02 +0200 Subject: [PATCH 13/25] more methods --- deps/ReactantExtra/API.cpp | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 2ed380a1b7..05859cfe60 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -541,6 +541,16 @@ extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { // TODO ifrt::Tuple::Unpack #pragma endregion +#pragma region xla::ifrt::PjRtTuple +extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* client, ifrt::Value* values, int nvalues) { + return new ifrt::PjRtTuple(client, values); +} + +extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple* tuple) { + delete tuple; +} +#pragma endregion + #pragma region xla::ifrt::DType extern "C" ifrt::DType* ifrt_dtype_ctor(ifrt::DType::Kind kind) { return new ifrt::DType(kind); @@ -1059,6 +1069,20 @@ extern "C" std::tuple ifrt_executable_hlo_modules(ifrt // TODO xla::ifrt::Executable::GetCostAnalysis #pragma endregion +#pragma region xla::ifrt::PjRtExecutable +extern "C" ifrt::PjRtExecutable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* pjrt_executable, ifrt::XlaCompileOptions* compile_options) { + return new xla::ValueOrThrow(ifrt::PjRtExecutable(pjrt_executable, compile_options)); +} + +extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable* executable) { + delete executable; +} + +extern "C" xla::PjRtExecutable* ifrt_pjrt_executable_pjrt_executable(ifrt::PjRtExecutable* executable) { + return executable->pjrt_executable(); +} +#pragma endregion + #pragma region xla::ifrt::LoadedExecutable extern "C" ifrt::Client* ifrt_loadedexecutable_client(ifrt::LoadedExecutable* executable) { return executable->client(); @@ -1143,6 +1167,26 @@ extern "C" std::tuple ifrt_loadedexecutable_addressable_d // TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult #pragma endregion +#pragma region xla::ifrt::PjRtLoadedExecutable +// TODO add support for LoadedHostCallback +extern "C" ifrt::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { + return new xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, pjrt_loaded_executable, std::vector>())); +} + +// TODO add support for LoadedHostCallback +extern "C" ifrt::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { + return new xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, *module, *compile_options, std::vector>())); +} + +extern "C" void ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable* executable) { + delete executable; +} + +extern "C" xla::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_pjrt_loadedexecutable(ifrt::PjRtLoadedExecutable* executable) { + return executable->pjrt_loadedexecutable(); +} +#pragma endregion + #pragma region xla::ifrt::CustomCallProgram #pragma endregion From eb264013dd8f04d2755b4cb252340e56b3cb4fea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 20 Oct 2024 20:25:22 +0200 Subject: [PATCH 14/25] more stuff --- deps/ReactantExtra/API.cpp | 64 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 05859cfe60..89f5becded 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -900,6 +900,22 @@ extern "C" int ifrt_device_process_index(ifrt::Device* device) { } #pragma endregion +#pragma region xla::ifrt::PjRtDevice +// DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h +// TODO support `attributes` parameter +extern "C" ifrt::PjRtDevice* ifrt_pjrt_device_ctor(ifrt::PjRtClient* client, DeviceId device_id, const char* kind, const char* to_string, const char* debug_string, int process_index, xla::PjRtDevice* pjrt_device) { + return new ifrt::PjRtDevice(client, device_id, kind, to_string, debug_string, process_index, absl::flat_hash_map(), pjrt_device); +} + +extern "C" void ifrt_pjrt_device_free(ifrt::PjRtDevice* device) { + delete device; +} + +extern "C" xla::PjRtDevice* ifrt_pjrt_device_pjrt_device(ifrt::PjRtDevice* device) { + return device->pjrt_device(); +} +#pragma endregion + #pragma region xla::ifrt::Sharding // TODO ifrt_sharding_devices // TODO ifrt_sharding_memory_kind @@ -938,6 +954,15 @@ extern "C" ifrt::DType ifrt_array_dtype(ifrt::Array* array) { // ... #pragma endregion +#pragma region xla::ifrt::PjRtArray +// TODO constructors / `Create` + +extern "C" std::tuple ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray* array) { + auto buffers = array->pjrt_buffers(); + return std::make_tuple(buffers.size(), buffers.data()); +} +#pragma endregion + #pragma region xla::ifrt::Topology extern "C" const char* ifrt_topology_platform_name(ifrt::Topology* topology) { return cstr_from_string(topology->platform_name()); @@ -1016,6 +1041,35 @@ extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { // TODO ifrt_client_default_layout_for_device #pragma endregion +#pragma region xla::ifrt::PjRtClient +// TODO support more parameters of `PjRtClient::CreateOptions` +extern "C" ifrt::PjRtClient* ifrt_pjrt_client_ctor(xla::PjRtClient* pjrt_client) { + return new xla::ValueOrThrow(ifrt::PjRtClient(ifrt::PjRtClient::CreateOptions{pjrt_client})); +} + +extern "C" void ifrt_pjrt_client_free(ifrt::PjRtClient* client) { + delete client; +} + +extern "C" xla::PjRtClient* ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient* client) { + return client->pjrt_client(); +} + +extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, ifrt::PjRtBuffer* pjrt_buffer) { + return new xla::ValueOrThrow(client->Create(pjrt_buffer)); +} + +// TODO extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} + +extern "C" ifrt::PjRtCompatibleDevice* ifrt_pjrt_client_lookup_pjrt_device(ifrt::PjRtClient* client, xla::PjRtDevice* pjrt_device) { + return xla::ValueOrThrow(client->LookupPjRtDevice(pjrt_device)); +} + +extern "C" ifrt::PjRtCompatibleMemory* ifrt_pjrt_client_lookup_pjrt_memory(ifrt::PjRtClient* client, xla::PjRtMemorySpace* pjrt_memory_space) { + return xla::ValueOrThrow(client->LookupPjRtMemory(pjrt_memory_space)); +} +#pragma endregion + #pragma region xla::ifrt::Executable extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { return cstr_from_string(executable->name()); @@ -1216,6 +1270,16 @@ extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(if } #pragma endregion +#pragma region xla::ifrt::PjRtCompiler +extern "C" ifrt::PjRtCompiler* ifrt_pjrt_compiler_ctor(ifrt::PjRtClient* client) { + return new ifrt::PjRtCompiler(client); +} + +extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler* compiler) { + delete compiler; +} +#pragma endregion + #pragma endregion // auxiliar functions From df60112dfedd3650e2e8f0c845819882bf1dd85e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 20 Oct 2024 20:46:08 +0200 Subject: [PATCH 15/25] more changes --- deps/ReactantExtra/API.cpp | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 89f5becded..e290f80858 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -73,14 +73,21 @@ #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/topology.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/compiler.h" // IFRT - PJRT #include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_tuple.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_compiler.h" using namespace mlir; using namespace llvm; @@ -1070,6 +1077,36 @@ extern "C" ifrt::PjRtCompatibleMemory* ifrt_pjrt_client_lookup_pjrt_memory(ifrt: } #pragma endregion +#pragma region xla::ifrt::HostCallback +extern "C" ifrt::HostCallback* ifrt_hostcallback_serialize(ifrt::HostCallback* host_callback) { + return cstr_from_string(host_callback->Serialize()); +} +#pragma endregion + +#pragma region xla::ifrt::LoadedHostCallback +extern "C" ifrt::Client* ifrt_loadedhostcallback_client(ifrt::LoadedHostCallback* host_callback) { + return host_callback->client(); +} + +extern "C" const char* ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallback* host_callback) { + return cstr_from_string(host_callback->Serialize()); +} +#pragma endregion + +#pragma region xla::ifrt::PjRtHostSendAndRecvLoadHostCallback +extern "C" ifrt::PjRtHostSendAndRecvLoadHostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor(ifrt::PjRtClient* client, xla::HostCallback* host_callback) { + return new xla::ValueOrThrow(ifrt::PjRtHostSendAndRecvLoadHostCallback(client, host_callback)); +} + +extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free(ifrt::PjRtHostSendAndRecvLoadHostCallback* host_callback) { + delete host_callback; +} + +extern "C" xla::HostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback(ifrt::PjRtHostSendAndRecvLoadHostCallback* host_callback) { + return host_callback->host_callback(); +} +#pragma endregion + #pragma region xla::ifrt::Executable extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { return cstr_from_string(executable->name()); From 77924503eba873934ad55d878a90204e12e94c9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 23 Oct 2024 10:44:18 +0200 Subject: [PATCH 16/25] more changes to `ifrt::Array` --- deps/ReactantExtra/API.cpp | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index e290f80858..dc3e4dfe4b 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -950,15 +950,32 @@ extern "C" ifrt::Array* ifrt_array_ctor() { return new ifrt::Array(); } -// extern "C" void ifrt_array_free(ifrt::Array* array) { -// delete array; -// } +extern "C" void ifrt_array_free(ifrt::Array* array) { + delete array; +} + +extern "C" ifrt::DType* ifrt_array_dtype(ifrt::Array* array) { + return new ifrt::DTypep(array->dtype()); +} + +extern "C" ifrt::Shape* ifrt_array_shape(ifrt::Array* array) { + return &array->shape(); +} + +extern "C" ifrt::Sharding* ifrt_array_sharding(ifrt::Array* array) { + return &array->sharding(); +} -extern "C" ifrt::DType ifrt_array_dtype(ifrt::Array* array) { - return array->dtype(); +extern "C" ifrt::PjRtLayout* ifrt_array_layout(ifrt::Array* array) { + return array->layout().release(); } -// ... +// TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays +// TODO xla::ifrt::Array::FullyReplicatedShard + +extern "C" Future<> ifrt_array_copy_to_host_buffer(ifrt::Array* array, void* data, const int64_t* byte_strides, int semantics) { + return array->CopyToHostBuffer(data, absl::Span(byte_strides, array->shape().num_elements()), semantics); +} #pragma endregion #pragma region xla::ifrt::PjRtArray From f351548b3be56c77cd374427be840d452d769469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 29 Oct 2024 18:51:21 +0100 Subject: [PATCH 17/25] Fix PJRT-backed IFRT-backend lib dependency --- deps/ReactantExtra/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index c7f7960815..fc77619593 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -351,7 +351,7 @@ cc_library( "@xla//xla/pjrt:status_casters", "@xla//xla/python/ifrt:ifrt", - "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_ifrt", "@xla//xla/python/ifrt/hlo:hlo_program", "@xla//xla/ffi:call_frame", "@com_google_protobuf//:protobuf", From cbc47873e2d69c889e6459d0dbd9e2d6c58d06ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 30 Oct 2024 18:43:13 +0100 Subject: [PATCH 18/25] fixes --- deps/ReactantExtra/API.cpp | 340 ++++++++++++++++++++----------------- 1 file changed, 180 insertions(+), 160 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index dc3e4dfe4b..09d9d4c5fb 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -83,6 +83,7 @@ #include "xla/python/pjrt_ifrt/pjrt_tuple.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" @@ -118,6 +119,30 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation lo // TODO MLIR_CAPI_EXPORTED MlirTypeID mlirComplexAttrGetTypeID(void) { return wrap(complex::NumberAttr::getTypeID()); } #pragma endregion +// auxiliar functions +#pragma region utils +template +const char* cstr_from_string(T text) { + char* cstr = (char*)malloc(text.size() + 1); + memcpy(cstr, text.data(), text.size()); + cstr[text.size()] = '\0'; + return cstr; +} + +template +T* unwrap_absl_statusor(absl::StatusOr status, char** error_msg) { + *error_msg = nullptr; + if (!status.ok()) { + auto str = status.message(); + char* err = (char*)malloc(str.size()+1); + memcpy(err, str.data(), str.size()+1); + *error_msg = err; + return nullptr; + } + return status.value(); +} +#pragma endregion + // int google::protobuf::io::CodedInputStream::default_recursion_limit_ = 100; // int xla::_LayoutProto_default_instance_; @@ -503,14 +528,6 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { #pragma region xla::ifrt #pragma region xla::ifrt::Value -extern "C" ifrt::Value* ifrt_value_ctor() { - return new ifrt::Value(); -} - -extern "C" void ifrt_value_free(ifrt::Value* value) { - delete value; -} - extern "C" ifrt::Client* ifrt_value_client(ifrt::Value* value) { return value->client(); } @@ -533,14 +550,6 @@ extern "C" const char* ifrt_value_debug_string(ifrt::Value* value) { #pragma endregion #pragma region xla::ifrt::Tuple -extern "C" ifrt::Tuple* ifrt_tuple_ctor() { - return new ifrt::Tuple(); -} - -extern "C" void ifrt_tuple_free(ifrt::Tuple* tuple) { - delete tuple; -} - extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { return tuple->Arity(); } @@ -550,7 +559,12 @@ extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { #pragma region xla::ifrt::PjRtTuple extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* client, ifrt::Value* values, int nvalues) { - return new ifrt::PjRtTuple(client, values); + auto values_ptr = new tsl::RCReference[nvalues]; + for (int i=0; i(values[i]); + } + auto span = absl::Span>(values_ptr, nvalues); + return new ifrt::PjRtTuple(client, span); } extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple* tuple) { @@ -608,7 +622,8 @@ extern "C" xla::PrimitiveType ifrt_to_primitive_type(ifrt::DType* dtype) { // xla::PrimitiveType is a enum, so we use int to represent it on Julia side extern "C" ifrt::DType* ifrt_to_dtype(xla::PrimitiveType primitive_type) { - return new xla::ValueOrThrow(ifrt::ToDType(primitive_type)); + auto dtype = xla::ValueOrThrow(ifrt::ToDType(primitive_type)); + return new ifrt::DType(dtype.kind()); } #pragma endregion @@ -635,10 +650,10 @@ extern "C" const char* ifrt_shape_debug_string(ifrt::Shape* shape) { #pragma endregion #pragma region xla::ifrt::DynamicShape -extern "C" ifrt::DynamicShape* ifrt_dynamicshape_ctor(ifrt::Shape* shape, bool dynamic_dims_mask) { - std::vector bool_vector(dynamic_dims_mask, dynamic_dims_mask + shape->dims().size()); - auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(bool_vector)); - return new ifrt::DynamicShape(*shape, tag); +extern "C" ifrt::DynamicShape* ifrt_dynamicshape_ctor(ifrt::Shape* shape, const bool* dynamic_dims_mask) { + auto tag = ifrt::BoundedDynamicShapeTag(absl::Span(dynamic_dims_mask, shape->dims().size())); + auto dynshape = xla::ValueOrThrow(ifrt::DynamicShape::Create(*shape, tag)); + return new ifrt::DynamicShape(dynshape); } extern "C" void ifrt_dynamicshape_free(ifrt::DynamicShape* shape) { @@ -656,7 +671,8 @@ extern "C" bool ifrt_dynamicshape_ne(ifrt::DynamicShape* shape1, ifrt::DynamicSh } extern "C" ifrt::Shape* ifrt_dynamicshape_get_padded_shape(ifrt::DynamicShape* shape) { - return xla::ValueOrThrow(shape->GetPaddedShape()).release(); + auto padshape = xla::ValueOrThrow(shape->GetPaddedShape()); + return new ifrt::Shape(padshape); } extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape* shape, int dimension) { @@ -801,24 +817,12 @@ extern "C" const char* ifrt_memorykind_string(ifrt::MemoryKind* memory_kind) { return cstr_from_string(memory_kind->memory_kind()); } -extern "C" const char* ifrt_memorykind_debug_string(ifrt::MemoryKind* memory_kind) { - return cstr_from_string(memory_kind->DebugString()); -} - extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) { return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); } #pragma endregion #pragma region xla::ifrt::Memory -extern "C" ifrt::Memory* ifrt_memory_ctor() { - return new ifrt::Memory(); -} - -extern "C" void ifrt_memory_free(ifrt::Memory* memory) { - delete memory; -} - // MemoryId is a struct with a single int32_t field --> check out xla/python/ifrt/memory.h extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory* memory) { return memory->Id(); @@ -836,9 +840,9 @@ extern "C" const char* ifrt_memory_debug_string(ifrt::Memory* memory) { return cstr_from_string(memory->DebugString()); } -extern "C" std::tuple ifrt_memory_devices(ifrt::Memory* memory) { +extern "C" std::tuple ifrt_memory_devices(ifrt::Memory* memory) { auto devices = memory->Devices(); - return std::make_tuple; + return std::make_tuple(devices.size(), devices.data()); } #pragma endregion @@ -856,44 +860,36 @@ extern "C" ifrt::PjRtClient* ifrt_pjrt_memory_client(ifrt::PjRtMemory* memory) { } extern "C" xla::PjRtMemorySpace* ifrt_pjrt_memory_space(ifrt::PjRtMemory* memory) { - return memory->memory_space(); + return memory->pjrt_memory(); } #pragma endregion #pragma region xla::ifrt::Device -extern "C" ifrt::Device* ifrt_device_ctor() { - return new ifrt::Device(); -} - -extern "C" void ifrt_device_free(ifrt::Device* device) { - delete device; -} - extern "C" ifrt::Client* ifrt_device_client(ifrt::Device* device) { return device->client(); } // DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h extern "C" ifrt::DeviceId ifrt_device_id(ifrt::Device* device) { - return device->id(); + return device->Id(); } // TODO ifrt_device_attributes extern "C" const char* ifrt_device_kind(ifrt::Device* device) { - return cstr_from_string(device->kind()); + return cstr_from_string(device->Kind()); } extern "C" const char* ifrt_device_to_string(ifrt::Device* device) { return cstr_from_string(device->ToString()); } -extern "C" const char* ifrt_device_from_string(ifrt::Client* client) { - return cstr_from_string(client->DebugString()); +extern "C" const char* ifrt_device_debug_string(ifrt::Device* device) { + return cstr_from_string(device->DebugString()); } -extern "C" ifrt::Memory* ifrt_device_default_memory(ifrt::Device* device, char** error) { - return unwrap_absl_statusor(device->DefaultMemory(), error); +extern "C" ifrt::Memory* ifrt_device_default_memory(ifrt::Device* device) { + return xla::ValueOrThrow(device->DefaultMemory()); } // TODO ifrt_device_memories @@ -903,14 +899,14 @@ extern "C" bool ifrt_device_is_addressable(ifrt::Device* device) { } extern "C" int ifrt_device_process_index(ifrt::Device* device) { - return device->process_index(); + return device->ProcessIndex(); } #pragma endregion #pragma region xla::ifrt::PjRtDevice // DeviceId is a struct with a single int32_t field --> check out xla/pjrt/pjrt_common.h // TODO support `attributes` parameter -extern "C" ifrt::PjRtDevice* ifrt_pjrt_device_ctor(ifrt::PjRtClient* client, DeviceId device_id, const char* kind, const char* to_string, const char* debug_string, int process_index, xla::PjRtDevice* pjrt_device) { +extern "C" ifrt::PjRtDevice* ifrt_pjrt_device_ctor(ifrt::PjRtClient* client, ifrt::DeviceId device_id, const char* kind, const char* to_string, const char* debug_string, int process_index, xla::PjRtDevice* pjrt_device) { return new ifrt::PjRtDevice(client, device_id, kind, to_string, debug_string, process_index, absl::flat_hash_map(), pjrt_device); } @@ -946,44 +942,40 @@ extern "C" const char* ifrt_sharding_debug_string(ifrt::Sharding* sharding) { #pragma endregion #pragma region xla::ifrt::Array -extern "C" ifrt::Array* ifrt_array_ctor() { - return new ifrt::Array(); -} - -extern "C" void ifrt_array_free(ifrt::Array* array) { - delete array; -} - extern "C" ifrt::DType* ifrt_array_dtype(ifrt::Array* array) { - return new ifrt::DTypep(array->dtype()); + return new ifrt::DType(array->dtype()); } -extern "C" ifrt::Shape* ifrt_array_shape(ifrt::Array* array) { - return &array->shape(); +extern "C" const ifrt::Shape* ifrt_array_shape(ifrt::Array* array) { + return &(array->shape()); } -extern "C" ifrt::Sharding* ifrt_array_sharding(ifrt::Array* array) { - return &array->sharding(); +extern "C" const ifrt::Sharding* ifrt_array_sharding(ifrt::Array* array) { + return &(array->sharding()); } -extern "C" ifrt::PjRtLayout* ifrt_array_layout(ifrt::Array* array) { - return array->layout().release(); +extern "C" PjRtLayout* ifrt_array_layout(ifrt::Array* array) { + return xla::ValueOrThrow(array->layout()).release(); } // TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays // TODO xla::ifrt::Array::FullyReplicatedShard -extern "C" Future<> ifrt_array_copy_to_host_buffer(ifrt::Array* array, void* data, const int64_t* byte_strides, int semantics) { - return array->CopyToHostBuffer(data, absl::Span(byte_strides, array->shape().num_elements()), semantics); +extern "C" ifrt::Future<> ifrt_array_copy_to_host_buffer(ifrt::Array* array, void* data, const int64_t* byte_strides, int semantics) { + return array->CopyToHostBuffer(data, absl::Span(byte_strides, array->shape().num_elements()), ifrt::ArrayCopySemantics(semantics)); } #pragma endregion #pragma region xla::ifrt::PjRtArray // TODO constructors / `Create` -extern "C" std::tuple ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray* array) { +extern "C" std::tuple ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray* array) { auto buffers = array->pjrt_buffers(); - return std::make_tuple(buffers.size(), buffers.data()); + auto buffers_ptr = new xla::PjRtBuffer*[buffers.size()]; + for (int i=0; iplatform_id(); } -extern "C" std::tuple ifrt_topology_device_descriptions(ifrt::Topology* topology) { - auto descriptions = topology->device_descriptions(); - return std::make_tuple(descriptions.size(), descriptions.data()); +extern "C" std::tuple ifrt_topology_device_descriptions(ifrt::Topology* topology) { + auto descriptions = topology->DeviceDescriptions(); + auto descriptions_ptr = new const xla::PjRtDeviceDescription*[descriptions.size()]; + for (int i=0; i{description}); } extern "C" const xla::PjRtTopologyDescription* ifrt_pjrt_topology_description(ifrt::PjRtTopology* topology) { - return topology->description(); + return topology->description().get(); } #pragma endregion @@ -1035,11 +1031,11 @@ extern "C" int ifrt_client_addressable_device_count(ifrt::Client* client) { return client->addressable_device_count(); } -extern "C" ifrt::Device* ifrt_client_devices(ifrt::Client* client) { +extern "C" ifrt::Device* const* ifrt_client_devices(ifrt::Client* client) { return client->devices().data(); } -extern "C" ifrt::Device* ifrt_client_addressable_devices(ifrt::Client* client) { +extern "C" ifrt::Device* const* ifrt_client_addressable_devices(ifrt::Client* client) { return client->addressable_devices().data(); } @@ -1049,12 +1045,12 @@ extern "C" int ifrt_client_process_index(ifrt::Client* client) { // TODO xla::ifrt::Client::GetDefaultDeviceAssignment -extern "C" ifrt::Device* ifrt_client_lookup_device(ifrt::Client* client, int device_id, **) { +extern "C" ifrt::Device* ifrt_client_lookup_device(ifrt::Client* client, int device_id) { return xla::ValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); } -extern "C" ifrt::Device* ifrt_client_lookup_addressable_device(ifrt::Client* client, int device_id, **) { - return xla::ValueOrThrow(client->LookupAddressableDevice(ifrt::DeviceId(device_id))); +extern "C" ifrt::Device* ifrt_client_lookup_addressable_device(ifrt::Client* client, int device_id) { + return xla::ValueOrThrow(client->LookupAddressableDevice(device_id)); } extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { @@ -1068,7 +1064,7 @@ extern "C" ifrt::Compiler* ifrt_client_default_compiler(ifrt::Client* client) { #pragma region xla::ifrt::PjRtClient // TODO support more parameters of `PjRtClient::CreateOptions` extern "C" ifrt::PjRtClient* ifrt_pjrt_client_ctor(xla::PjRtClient* pjrt_client) { - return new xla::ValueOrThrow(ifrt::PjRtClient(ifrt::PjRtClient::CreateOptions{pjrt_client})); + return xla::ValueOrThrow(ifrt::PjRtClient::Create(ifrt::PjRtClient::CreateOptions{std::shared_ptr{pjrt_client}})).release(); } extern "C" void ifrt_pjrt_client_free(ifrt::PjRtClient* client) { @@ -1079,8 +1075,9 @@ extern "C" xla::PjRtClient* ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient* clien return client->pjrt_client(); } -extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, ifrt::PjRtBuffer* pjrt_buffer) { - return new xla::ValueOrThrow(client->Create(pjrt_buffer)); +extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* pjrt_buffer) { + auto buffer_ptr = std::make_shared(*pjrt_buffer); + return xla::ValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); } // TODO extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} @@ -1095,7 +1092,7 @@ extern "C" ifrt::PjRtCompatibleMemory* ifrt_pjrt_client_lookup_pjrt_memory(ifrt: #pragma endregion #pragma region xla::ifrt::HostCallback -extern "C" ifrt::HostCallback* ifrt_hostcallback_serialize(ifrt::HostCallback* host_callback) { +extern "C" const char* ifrt_hostcallback_serialize(ifrt::HostCallback* host_callback) { return cstr_from_string(host_callback->Serialize()); } #pragma endregion @@ -1106,21 +1103,25 @@ extern "C" ifrt::Client* ifrt_loadedhostcallback_client(ifrt::LoadedHostCallback } extern "C" const char* ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallback* host_callback) { - return cstr_from_string(host_callback->Serialize()); + // auto msg = ; + return cstr_from_string(xla::ValueOrThrow(host_callback->Serialize())); } #pragma endregion -#pragma region xla::ifrt::PjRtHostSendAndRecvLoadHostCallback -extern "C" ifrt::PjRtHostSendAndRecvLoadHostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor(ifrt::PjRtClient* client, xla::HostCallback* host_callback) { - return new xla::ValueOrThrow(ifrt::PjRtHostSendAndRecvLoadHostCallback(client, host_callback)); +#pragma region xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback +extern "C" ifrt::PjRtHostSendAndRecvLoadedHostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor(ifrt::PjRtClient* client, xla::HostCallback* host_callback) { + auto xla_callback_ptr = std::unique_ptr(host_callback); + auto callback = xla::ValueOrThrow(ifrt::PjRtHostSendAndRecvLoadedHostCallback(client, xla_callback_ptr)); + xla_callback_ptr.release(); + return new ifrt::PjRtHostSendAndRecvLoadedHostCallback(callback); } -extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free(ifrt::PjRtHostSendAndRecvLoadHostCallback* host_callback) { +extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { delete host_callback; } -extern "C" xla::HostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback(ifrt::PjRtHostSendAndRecvLoadHostCallback* host_callback) { - return host_callback->host_callback(); +extern "C" xla::HostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { + return new xla::HostCallback(host_callback->host_callback()); } #pragma endregion @@ -1130,7 +1131,7 @@ extern "C" const char* ifrt_executable_name(ifrt::Executable* executable) { } extern "C" const char* ifrt_executable_fingerprint(ifrt::Executable* executable) { - auto result = xla::ValueOrThrow(executable->fingerprint()); + auto result = xla::ValueOrThrow(executable->Fingerprint()); if (!result.has_value()) return ""; return cstr_from_string(result.value()); } @@ -1149,29 +1150,43 @@ extern "C" int64_t ifrt_executable_size(ifrt::Executable* executable) { // TODO xla::ifrt::Executable::GetCompiledMemoryStats -extern "C" std::tuple ifrt_executable_parameter_shardings(ifrt::Executable* executable) { - auto shardings = unwrap_absl_statusor(executable->GetParameterShardings()); - return std::make_tuple(shardings.size(), shardings.data()); +extern "C" std::tuple ifrt_executable_parameter_shardings(ifrt::Executable* executable) { + auto shardings = executable->GetParameterShardings(); + if (!shardings.has_value()) return std::make_tuple(0, nullptr); + return std::make_tuple(shardings.value().size(), shardings.value().data()); } -extern "C" std::tuple ifrt_executable_output_shardings(ifrt::Executable* executable) { - auto shardings = unwrap_absl_statusor(executable->GetOutputShardings()); - return std::make_tuple(shardings.size(), shardings.data()); +extern "C" std::tuple ifrt_executable_output_shardings(ifrt::Executable* executable) { + auto shardings = executable->GetOutputShardings(); + if (!shardings.has_value()) return std::make_tuple(0, nullptr); + return std::make_tuple(shardings.value().size(), shardings.value().data()); } -extern "C" std::tuple ifrt_executable_parameter_layouts(ifrt::Executable* executable) { - auto layouts = unwrap_absl_statusor(executable->GetParameterLayouts()); - return std::make_tuple(layouts.size(), layouts.data()); +extern "C" std::tuple ifrt_executable_parameter_layouts(ifrt::Executable* executable) { + auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); + auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; + for (int i=0; i ifrt_executable_output_layouts(ifrt::Executable* executable) { - auto layouts = unwrap_absl_statusor(executable->GetOutputLayouts()); - return std::make_tuple(layouts.size(), layouts.data()); +extern "C" std::tuple ifrt_executable_output_layouts(ifrt::Executable* executable) { + auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); + auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; + for (int i=0; i ifrt_executable_hlo_modules(ifrt::Executable* executable) { - auto modules = unwrap_absl_statusor(executable->GetHloModules()); - return std::make_tuple(modules.size(), modules.data()); +extern "C" std::tuple ifrt_executable_hlo_modules(ifrt::Executable* executable) { + auto modules = xla::ValueOrThrow(executable->GetHloModules()); + auto modules_ptr = new xla::HloModule*[modules.size()]; + for (int i=0; i ifrt_executable_hlo_modules(ifrt #pragma region xla::ifrt::PjRtExecutable extern "C" ifrt::PjRtExecutable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* pjrt_executable, ifrt::XlaCompileOptions* compile_options) { - return new xla::ValueOrThrow(ifrt::PjRtExecutable(pjrt_executable, compile_options)); + auto pjrt_executable_shared = std::make_shared(*pjrt_executable); + auto options = std::make_unique(*compile_options); + auto executable = xla::ValueOrThrow(ifrt::PjRtExecutable(pjrt_executable_shared, options)); + return new ifrt::PjRtExecutable(executable); } extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable* executable) { @@ -1201,7 +1219,7 @@ extern "C" const char* ifrt_loadedexecutable_name(ifrt::LoadedExecutable* execut } extern "C" const char* ifrt_loadedexecutable_fingerprint(ifrt::LoadedExecutable* executable) { - auto result = xla::ValueOrThrow(executable->fingerprint()); + auto result = xla::ValueOrThrow(executable->Fingerprint()); if (!result.has_value()) return ""; return cstr_from_string(result.value()); } @@ -1210,7 +1228,7 @@ extern "C" const char* ifrt_loadedexecutable_serialize(ifrt::LoadedExecutable* e return cstr_from_string(xla::ValueOrThrow(executable->Serialize())); } -extern "C" ifrt::Future<>* ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable* executable) { +extern "C" ifrt::Future<> ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable* executable) { return executable->GetReadyFuture(); } @@ -1224,29 +1242,43 @@ extern "C" int64_t ifrt_loadedexecutable_size(ifrt::LoadedExecutable* executable // TODO xla::ifrt::GetCompiledMemoryStats -extern "C" std::tuple ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable* executable) { - auto shardings = unwrap_absl_statusor(executable->GetParameterShardings()); - return std::make_tuple(shardings.size(), shardings.data()); +extern "C" std::tuple ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable* executable) { + auto shardings = executable->GetParameterShardings(); + if (!shardings.has_value()) return std::make_tuple(0, nullptr); + return std::make_tuple(shardings.value().size(), shardings.value().data()); } -extern "C" std::tuple ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable* executable) { - auto shardings = unwrap_absl_statusor(executable->GetOutputShardings()); - return std::make_tuple(shardings.size(), shardings.data()); +extern "C" std::tuple ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable* executable) { + auto shardings = executable->GetOutputShardings(); + if (!shardings.has_value()) return std::make_tuple(0, nullptr); + return std::make_tuple(shardings.value().size(), shardings.value().data()); } -extern "C" std::tuple ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable* executable) { - auto layouts = unwrap_absl_statusor(executable->GetParameterLayouts()); - return std::make_tuple(layouts.size(), layouts.data()); +extern "C" std::tuple ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable* executable) { + auto layouts = xla::ValueOrThrow(executable->GetParameterLayouts()); + auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; + for (int i=0; i ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable* executable) { - auto layouts = unwrap_absl_statusor(executable->GetOutputLayouts()); - return std::make_tuple(layouts.size(), layouts.data()); +extern "C" std::tuple ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable* executable) { + auto layouts = xla::ValueOrThrow(executable->GetOutputLayouts()); + auto layouts_ptr = new xla::PjRtLayout*[layouts.size()]; + for (int i=0; i ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable* executable) { - auto modules = unwrap_absl_statusor(executable->GetHloModules()); - return std::make_tuple(modules.size(), modules.data()); +extern "C" std::tuple ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable* executable) { + auto modules = xla::ValueOrThrow(executable->GetHloModules()); + auto modules_ptr = new xla::HloModule*[modules.size()]; + for (int i=0; iIsDeleted(); } -extern "C" std::tuple ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable* executable) { +extern "C" std::tuple ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable* executable) { auto devices = executable->addressable_devices(); return std::make_tuple(devices.size(), devices.data()); } @@ -1278,12 +1310,14 @@ extern "C" std::tuple ifrt_loadedexecutable_addressable_d #pragma region xla::ifrt::PjRtLoadedExecutable // TODO add support for LoadedHostCallback extern "C" ifrt::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { - return new xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, pjrt_loaded_executable, std::vector>())); + auto executable = xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, pjrt_loaded_executable, std::vector>())); + return new ifrt::PjRtLoadedExecutable(executable); } // TODO add support for LoadedHostCallback extern "C" ifrt::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { - return new xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, *module, *compile_options, std::vector>())); + auto executable = xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, *module, *compile_options, std::vector>())); + return new ifrt::PjRtLoadedExecutable(executable); } extern "C" void ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable* executable) { @@ -1291,7 +1325,7 @@ extern "C" void ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable* exec } extern "C" xla::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_pjrt_loadedexecutable(ifrt::PjRtLoadedExecutable* executable) { - return executable->pjrt_loadedexecutable(); + return executable->pjrt_loaded_executable(); } #pragma endregion @@ -1313,14 +1347,24 @@ extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir:: #pragma endregion #pragma region xla::ifrt::Compiler -extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program, char** error) { +extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program) { + // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default + return xla::ValueOrThrow(compiler->Compile(*program, ifrt::CompileOptions())).release(); +} + +extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile_with_topology(ifrt::Compiler* compiler, ifrt::Program* program, const ifrt::Topology* topology) { // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default - return unwrap_absl_statusor(compiler->Compile(*program, *options, ifrt::CompileOptions()), error); + auto options = std::make_unique(); + auto program_ptr = std::unique_ptr(program); + auto exec_ptr = xla::ValueOrThrow(compiler->Compile(program, *topology, options)).release(); + program_ptr.release(); + return exec_ptr; } -extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler* compiler, const char* data, size_t size, char** error) { +extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler* compiler, const char* data) { // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we don't use it and set directly to the default - return unwrap_absl_statusor(compiler->DeserializeLoadedExecutable(data, size, ifrt::DeserializeExecutableOptions()), error); + auto options = std::make_unique(); + return xla::ValueOrThrow(compiler->DeserializeLoadedExecutable(std::string(data), options)).release(); } #pragma endregion @@ -1335,27 +1379,3 @@ extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler* compiler) { #pragma endregion #pragma endregion - -// auxiliar functions -#pragma region utils -template -const char* cstr_from_string(T text) { - char* cstr = (char*)malloc(text.size() + 1); - memcpy(cstr, text.data(), text.size()); - cstr[text.size()] = '\0'; - return cstr; -} - -template -T* unwrap_absl_statusor(absl::StatusOr status, char** error_msg) { - *error_msg = nullptr; - if (!status.ok()) { - auto str = pluginLoad.status().message(); - char* err = (char*)malloc(str.size()+1); - memcpy(err, str.data(), str.size()+1); - *error_msg = err; - return nullptr; - } - return status.value(); -} -#pragma endregion From 58b2c53dad9ab0339ac962e4b400cdcf1c8d8175 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 4 Nov 2024 09:03:17 +0100 Subject: [PATCH 19/25] fixes --- deps/ReactantExtra/API.cpp | 45 +++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 09d9d4c5fb..d2b02f650b 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -564,7 +564,7 @@ extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* cli values_ptr[i] = tsl::RCReference(values[i]); } auto span = absl::Span>(values_ptr, nvalues); - return new ifrt::PjRtTuple(client, span); + return xla::ValueOrThrow(ifrt::PjRtTuple::Create(client, span)).release(); } extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple* tuple) { @@ -814,7 +814,10 @@ extern "C" bool ifrt_memorykind_ne(ifrt::MemoryKind* mk1, ifrt::MemoryKind* mk2) } extern "C" const char* ifrt_memorykind_string(ifrt::MemoryKind* memory_kind) { - return cstr_from_string(memory_kind->memory_kind()); + if (memory_kind->memory_kind().has_value()) + return cstr_from_string(memory_kind->memory_kind().value()); + else + return ""; } extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) { @@ -1110,10 +1113,8 @@ extern "C" const char* ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallbac #pragma region xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback extern "C" ifrt::PjRtHostSendAndRecvLoadedHostCallback* ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor(ifrt::PjRtClient* client, xla::HostCallback* host_callback) { - auto xla_callback_ptr = std::unique_ptr(host_callback); - auto callback = xla::ValueOrThrow(ifrt::PjRtHostSendAndRecvLoadedHostCallback(client, xla_callback_ptr)); - xla_callback_ptr.release(); - return new ifrt::PjRtHostSendAndRecvLoadedHostCallback(callback); + auto xla_callback_ptr = std::make_unique(*host_callback); + return new ifrt::PjRtHostSendAndRecvLoadedHostCallback(client, std::move(xla_callback_ptr)); } extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free(ifrt::PjRtHostSendAndRecvLoadedHostCallback* host_callback) { @@ -1193,11 +1194,10 @@ extern "C" std::tuple ifrt_executable_hlo_modules(ifrt #pragma endregion #pragma region xla::ifrt::PjRtExecutable -extern "C" ifrt::PjRtExecutable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* pjrt_executable, ifrt::XlaCompileOptions* compile_options) { +extern "C" ifrt::Executable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* pjrt_executable, ifrt::XlaCompileOptions* compile_options) { auto pjrt_executable_shared = std::make_shared(*pjrt_executable); auto options = std::make_unique(*compile_options); - auto executable = xla::ValueOrThrow(ifrt::PjRtExecutable(pjrt_executable_shared, options)); - return new ifrt::PjRtExecutable(executable); + return xla::ValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, std::move(options))).release(); } extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable* executable) { @@ -1309,15 +1309,14 @@ extern "C" std::tuple ifrt_loadedexecutable_addres #pragma region xla::ifrt::PjRtLoadedExecutable // TODO add support for LoadedHostCallback -extern "C" ifrt::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { - auto executable = xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, pjrt_loaded_executable, std::vector>())); - return new ifrt::PjRtLoadedExecutable(executable); +extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { + auto pjrt_loaded_executable_ptr = std::make_shared(*pjrt_loaded_executable); + return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, pjrt_loaded_executable_ptr, std::vector>())).release(); } // TODO add support for LoadedHostCallback -extern "C" ifrt::PjRtLoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { - auto executable = xla::ValueOrThrow(ifrt::PjRtLoadedExecutable(client, *module, *compile_options, std::vector>())); - return new ifrt::PjRtLoadedExecutable(executable); +extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { + return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, *module, *compile_options, std::vector>())).release(); } extern "C" void ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable* executable) { @@ -1338,33 +1337,35 @@ extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor() { } extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_module(mlir::ModuleOp* module) { - return new ifrt::HloProgram(module); + return new ifrt::HloProgram(*module); } extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { - return new ifrt::HloProgram(context, module); + auto context_ptr = std::make_unique(*context); + return new ifrt::HloProgram(std::move(context_ptr), *module); } #pragma endregion #pragma region xla::ifrt::Compiler extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compiler, ifrt::Program* program) { // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default - return xla::ValueOrThrow(compiler->Compile(*program, ifrt::CompileOptions())).release(); + auto program_ptr = std::make_unique(*program); + auto options = std::make_unique(); + return xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), std::move(options))).release(); } extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile_with_topology(ifrt::Compiler* compiler, ifrt::Program* program, const ifrt::Topology* topology) { // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default auto options = std::make_unique(); - auto program_ptr = std::unique_ptr(program); - auto exec_ptr = xla::ValueOrThrow(compiler->Compile(program, *topology, options)).release(); - program_ptr.release(); + auto program_ptr = std::make_unique(*program); + auto exec_ptr = xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, options)).release(); return exec_ptr; } extern "C" ifrt::LoadedExecutable* ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler* compiler, const char* data) { // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we don't use it and set directly to the default auto options = std::make_unique(); - return xla::ValueOrThrow(compiler->DeserializeLoadedExecutable(std::string(data), options)).release(); + return xla::ValueOrThrow(compiler->DeserializeLoadedExecutable(std::string(data), std::move(options))).release(); } #pragma endregion From 4651c8f3d1eeea62e18161e137ff902a7974b16b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 4 Nov 2024 09:36:03 +0100 Subject: [PATCH 20/25] small fix --- deps/ReactantExtra/API.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index d2b02f650b..86b0b58483 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1358,7 +1358,7 @@ extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile_with_topology(ifrt::Com // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default auto options = std::make_unique(); auto program_ptr = std::make_unique(*program); - auto exec_ptr = xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, options)).release(); + auto exec_ptr = xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, std::move(options))).release(); return exec_ptr; } From 3ab7e88367a0f8df6f61f0e9c7ad9fc97071304e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 1 Dec 2024 17:42:27 +0100 Subject: [PATCH 21/25] disable functions who need `shared_ptr` --- deps/ReactantExtra/API.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 86b0b58483..88b9b40a8f 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1078,10 +1078,11 @@ extern "C" xla::PjRtClient* ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient* clien return client->pjrt_client(); } -extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* pjrt_buffer) { - auto buffer_ptr = std::make_shared(*pjrt_buffer); - return xla::ValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); -} +// TODO there are problems with using `make_shared +// extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* pjrt_buffer) { +// auto buffer_ptr = std::make_shared(*pjrt_buffer); +// return xla::ValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); +// } // TODO extern "C" ifrt::PjRtCompatibleArray* ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} @@ -1194,11 +1195,12 @@ extern "C" std::tuple ifrt_executable_hlo_modules(ifrt #pragma endregion #pragma region xla::ifrt::PjRtExecutable -extern "C" ifrt::Executable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* pjrt_executable, ifrt::XlaCompileOptions* compile_options) { - auto pjrt_executable_shared = std::make_shared(*pjrt_executable); - auto options = std::make_unique(*compile_options); - return xla::ValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, std::move(options))).release(); -} +// TODO there are problems with using `make_shared +// extern "C" ifrt::Executable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* pjrt_executable, ifrt::XlaCompileOptions* compile_options) { +// auto pjrt_executable_shared = std::make_shared(*pjrt_executable); +// auto options = std::make_unique(*compile_options); +// return xla::ValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, std::move(options))).release(); +// } extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable* executable) { delete executable; @@ -1309,10 +1311,11 @@ extern "C" std::tuple ifrt_loadedexecutable_addres #pragma region xla::ifrt::PjRtLoadedExecutable // TODO add support for LoadedHostCallback -extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { - auto pjrt_loaded_executable_ptr = std::make_shared(*pjrt_loaded_executable); - return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, pjrt_loaded_executable_ptr, std::vector>())).release(); -} +// TODO there are problems with using `make_shared +// extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, xla::PjRtLoadedExecutable* pjrt_loaded_executable) { +// auto pjrt_loaded_executable_ptr = std::make_shared(*pjrt_loaded_executable); +// return xla::ValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, pjrt_loaded_executable_ptr, std::vector>())).release(); +// } // TODO add support for LoadedHostCallback extern "C" ifrt::LoadedExecutable* ifrt_pjrt_loadedexecutable_ctor_from_mlir_module(ifrt::PjRtCompatibleClient* client, mlir::ModuleOp* module, xla::CompileOptions* compile_options) { From 3d0ea13420dd8b72ffc41c95acf312e458fbedf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 1 Dec 2024 17:42:50 +0100 Subject: [PATCH 22/25] fix return type of `ifrt_compiler_compile_with_topology` --- deps/ReactantExtra/API.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 88b9b40a8f..f38b9cfc0e 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1357,7 +1357,7 @@ extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile(ifrt::Compiler* compile return xla::ValueOrThrow(compiler->Compile(std::move(program_ptr), std::move(options))).release(); } -extern "C" ifrt::LoadedExecutable* ifrt_compiler_compile_with_topology(ifrt::Compiler* compiler, ifrt::Program* program, const ifrt::Topology* topology) { +extern "C" ifrt::Executable* ifrt_compiler_compile_with_topology(ifrt::Compiler* compiler, ifrt::Program* program, const ifrt::Topology* topology) { // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and set directly to the default auto options = std::make_unique(); auto program_ptr = std::make_unique(*program); From 19167b261f8ab1ef9a9bb12bf76b1c47b9c68127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 1 Dec 2024 17:43:20 +0100 Subject: [PATCH 23/25] comment `ifrt_hloprogram_ctor_with_context_and_module` due to problems with `MLIRContext` constructor --- deps/ReactantExtra/API.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f38b9cfc0e..6ef434f33a 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1343,10 +1343,10 @@ extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_module(mlir::ModuleOp* mo return new ifrt::HloProgram(*module); } -extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { - auto context_ptr = std::make_unique(*context); - return new ifrt::HloProgram(std::move(context_ptr), *module); -} +// extern "C" ifrt::HloProgram* ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, mlir::ModuleOp* module) { +// auto context_ptr = std::make_unique(*context); +// return new ifrt::HloProgram(std::move(context_ptr), *module); +// } #pragma endregion #pragma region xla::ifrt::Compiler From 8bbe74bbc892b6effe58057774ba0ae612df70b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 1 Dec 2024 17:43:51 +0100 Subject: [PATCH 24/25] fix `tsl::RCReference` construction --- deps/ReactantExtra/API.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 6ef434f33a..ac6f282008 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -561,7 +561,8 @@ extern "C" int ifrt_tuple_arity(ifrt::Tuple* tuple) { extern "C" ifrt::PjRtTuple* ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient* client, ifrt::Value* values, int nvalues) { auto values_ptr = new tsl::RCReference[nvalues]; for (int i=0; i(values[i]); + values_ptr[i] = tsl::RCReference(); + values_ptr[i].reset(&values[i]); } auto span = absl::Span>(values_ptr, nvalues); return xla::ValueOrThrow(ifrt::PjRtTuple::Create(client, span)).release(); From bb82fc7c40762037c95eebee74011afd61e741df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sun, 1 Dec 2024 19:27:35 +0100 Subject: [PATCH 25/25] Update deps/ReactantExtra/API.cpp --- deps/ReactantExtra/API.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index ac6f282008..b02a805360 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -818,7 +818,7 @@ extern "C" const char* ifrt_memorykind_string(ifrt::MemoryKind* memory_kind) { if (memory_kind->memory_kind().has_value()) return cstr_from_string(memory_kind->memory_kind().value()); else - return ""; + return nullptr; } extern "C" ifrt::MemoryKind* ifrt_memorykind_canonicalize(ifrt::MemoryKind* memory_kind, ifrt::Device* device) {