From e19bd79732717fde76b7db2334f6f2629ee8af2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 04:16:15 -0600 Subject: [PATCH 01/13] Remove old IFRT bindings --- deps/ReactantExtra/API.cpp | 1008 ------------------------------------ 1 file changed, 1008 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index bfb0e9c239..f219fc3141 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -991,1011 +991,3 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, newMod.getBody()->getOperations()); return wrap(entryFn); } - -#pragma region xla::ifrt - -#pragma region xla::ifrt::Value -extern "C" ifrt::Client *ifrt_value_client(ifrt::Value *value) { - return value->client(); -} - -extern "C" ifrt::Future<> ifrt_value_get_ready_future(ifrt::Value *value) { - return value->GetReadyFuture(); -} - -extern "C" ifrt::Future<> ifrt_value_delete(ifrt::Value *value) { - return value->Delete(); -} - -extern "C" bool ifrt_value_is_deleted(ifrt::Value *value) { - return value->IsDeleted(); -} - -extern "C" const char *ifrt_value_debug_string(ifrt::Value *value) { - return cstr_from_string(value->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::Tuple -extern "C" int ifrt_tuple_arity(ifrt::Tuple *tuple) { return tuple->Arity(); } - -// TODO ifrt::Tuple::Unpack -#pragma endregion - -#pragma region xla::ifrt::PjRtTuple -extern "C" ifrt::PjRtTuple * -ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient *client, ifrt::Value *values, - int nvalues) { - auto values_ptr = new tsl::RCReference[nvalues]; - for (int i = 0; i < nvalues; i++) { - values_ptr[i] = tsl::RCReference(); - values_ptr[i].reset(&values[i]); - } - auto span = absl::Span>(values_ptr, nvalues); - return MyValueOrThrow(ifrt::PjRtTuple::Create(client, span)).release(); -} - -extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple *tuple) { delete tuple; } -#pragma endregion - -#pragma region xla::ifrt::DType -extern "C" ifrt::DType *ifrt_dtype_ctor(ifrt::DType::Kind kind) { - return new ifrt::DType(kind); -} - -extern "C" void ifrt_dtype_free(ifrt::DType *dtype) { delete dtype; } - -extern "C" ifrt::DType::Kind ifrt_dtype_kind(ifrt::DType *dtype) { - return dtype->kind(); -} - -extern "C" bool ifrt_dtype_eq(ifrt::DType *dtype1, ifrt::DType *dtype2) { - return *dtype1 == *dtype2; -} - -extern "C" bool ifrt_dtype_ne(ifrt::DType *dtype1, ifrt::DType *dtype2) { - return *dtype1 != *dtype2; -} - -// Returns -1 if not aligned to a byte boundary or there is no fixed size -extern "C" int ifrt_dtype_byte_size(ifrt::DType *dtype) { - auto byte_size = dtype->byte_size(); - if (byte_size.has_value()) { - return byte_size.value(); - } - return -1; -} - -// Returns -1 if there is no fixed size -extern "C" int ifrt_dtype_bit_size(ifrt::DType *dtype) { - auto bit_size = dtype->bit_size(); - if (bit_size.has_value()) { - return bit_size.value(); - } - return -1; -} - -extern "C" const char *ifrt_dtype_debug_string(ifrt::DType *dtype) { - return cstr_from_string(dtype->DebugString()); -} - -// xla::PrimitiveType is a enum, so we use int to represent it on Julia side -extern "C" xla::PrimitiveType ifrt_to_primitive_type(ifrt::DType *dtype) { - return MyValueOrThrow(ifrt::ToPrimitiveType(*dtype)); -} - -// xla::PrimitiveType is a enum, so we use int to represent it on Julia side -extern "C" ifrt::DType *ifrt_to_dtype(xla::PrimitiveType primitive_type) { - auto dtype = MyValueOrThrow(ifrt::ToDType(primitive_type)); - return new ifrt::DType(dtype.kind()); -} -#pragma endregion - -#pragma region xla::ifrt::Shape -extern "C" ifrt::Shape *ifrt_shape_ctor(const int64_t *dims, size_t dims_size) { - return new ifrt::Shape(absl::Span(dims, dims_size)); -} - -extern "C" void ifrt_shape_free(ifrt::Shape *shape) { delete shape; } - -extern "C" const int64_t *ifrt_shape_dims(ifrt::Shape *shape) { - return shape->dims().data(); -} - -extern "C" int64_t ifrt_shape_dims_num_elements(ifrt::Shape *shape) { - return shape->num_elements(); -} - -extern "C" const char *ifrt_shape_debug_string(ifrt::Shape *shape) { - return cstr_from_string(shape->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::DynamicShape -extern "C" ifrt::DynamicShape * -ifrt_dynamicshape_ctor(ifrt::Shape *shape, const bool *dynamic_dims_mask) { - auto tag = ifrt::BoundedDynamicShapeTag( - absl::Span(dynamic_dims_mask, shape->dims().size())); - auto dynshape = MyValueOrThrow(ifrt::DynamicShape::Create(*shape, tag)); - return new ifrt::DynamicShape(dynshape); -} - -extern "C" void ifrt_dynamicshape_free(ifrt::DynamicShape *shape) { - delete shape; -} - -// TODO ifrt::DynamicShape::GetTag - -extern "C" bool ifrt_dynamicshape_eq(ifrt::DynamicShape *shape1, - ifrt::DynamicShape *shape2) { - return *shape1 == *shape2; -} - -extern "C" bool ifrt_dynamicshape_ne(ifrt::DynamicShape *shape1, - ifrt::DynamicShape *shape2) { - return *shape1 != *shape2; -} - -extern "C" ifrt::Shape * -ifrt_dynamicshape_get_padded_shape(ifrt::DynamicShape *shape) { - auto padshape = MyValueOrThrow(shape->GetPaddedShape()); - return new ifrt::Shape(padshape); -} - -extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape *shape, - int dimension) { - return shape->IsDynamicDim(dimension); -} - -extern "C" const char * -ifrt_dynamicshape_debug_string(ifrt::DynamicShape *shape) { - return cstr_from_string(shape->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::Index -extern "C" ifrt::Index *ifrt_index_ctor(const int64_t *elements, - size_t elements_size) { - return new ifrt::Index(absl::Span(elements, elements_size)); -} - -extern "C" ifrt::Index *ifrt_index_zeros(int num_elements) { - return new ifrt::Index(ifrt::Index::Zeros(num_elements)); -} - -extern "C" void ifrt_index_free(ifrt::Index *index) { delete index; } - -extern "C" const int64_t *ifrt_index_elements(ifrt::Index *index) { - return index->elements().data(); -} - -extern "C" int ifrt_index_count(ifrt::Index *index) { - return index->elements().size(); -} - -extern "C" bool ifrt_index_eq(ifrt::Index *index1, ifrt::Index *index2) { - return *index1 == *index2; -} - -extern "C" bool ifrt_index_ne(ifrt::Index *index1, ifrt::Index *index2) { - return *index1 != *index2; -} - -extern "C" ifrt::Index *ifrt_index_add(ifrt::Index *index, - ifrt::Index *offset) { - return new ifrt::Index(*index + *offset); -} - -extern "C" ifrt::Index *ifrt_index_sub(ifrt::Index *index, - ifrt::Index *offset) { - return new ifrt::Index(*index - *offset); -} - -// WARN we're not checking if the multiplier has the same size as the index -extern "C" ifrt::Index *ifrt_index_mul(ifrt::Index *index, - const int64_t *multiplier) { - return new ifrt::Index( - *index * absl::Span(multiplier, ifrt_index_count(index))); -} - -extern "C" void ifrt_index_add_inplace(ifrt::Index *index, - ifrt::Index *offset) { - *index += *offset; -} - -extern "C" void ifrt_index_sub_inplace(ifrt::Index *index, - ifrt::Index *offset) { - *index -= *offset; -} - -extern "C" void ifrt_index_mul_inplace(ifrt::Index *index, - const int64_t *multiplier) { - *index *= absl::Span(multiplier, ifrt_index_count(index)); -} - -extern "C" const char *ifrt_index_debug_string(ifrt::Index *index) { - return cstr_from_string(index->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::IndexDomain -extern "C" ifrt::IndexDomain *ifrt_indexdomain_ctor(ifrt::Shape *shape) { - return new ifrt::IndexDomain(*shape); -} - -extern "C" ifrt::IndexDomain * -ifrt_indexdomain_ctor_with_origin(ifrt::Index *origin, ifrt::Shape *shape) { - return new ifrt::IndexDomain(*origin, *shape); -} - -extern "C" void ifrt_indexdomain_free(ifrt::IndexDomain *index_domain) { - delete index_domain; -} - -extern "C" const ifrt::Index * -ifrt_indexdomain_origin(ifrt::IndexDomain *index_domain) { - return &index_domain->origin(); -} - -extern "C" const ifrt::Shape * -ifrt_indexdomain_shape(ifrt::IndexDomain *index_domain) { - return &index_domain->shape(); -} - -extern "C" bool ifrt_indexdomain_eq(ifrt::IndexDomain *index_domain1, - ifrt::IndexDomain *index_domain2) { - return *index_domain1 == *index_domain2; -} - -extern "C" bool ifrt_indexdomain_ne(ifrt::IndexDomain *index_domain1, - ifrt::IndexDomain *index_domain2) { - return *index_domain1 != *index_domain2; -} - -extern "C" ifrt::IndexDomain * -ifrt_indexdomain_add(ifrt::IndexDomain *index_domain, ifrt::Index *offset) { - return new ifrt::IndexDomain(*index_domain + *offset); -} - -extern "C" ifrt::IndexDomain * -ifrt_indexdomain_sub(ifrt::IndexDomain *index_domain, ifrt::Index *offset) { - return new ifrt::IndexDomain(*index_domain - *offset); -} - -extern "C" void ifrt_indexdomain_add_inplace(ifrt::IndexDomain *index_domain, - ifrt::Index *offset) { - *index_domain += *offset; -} - -extern "C" void ifrt_indexdomain_sub_inplace(ifrt::IndexDomain *index_domain, - ifrt::Index *offset) { - *index_domain -= *offset; -} - -extern "C" const char * -ifrt_indexdomain_debug_string(ifrt::IndexDomain *index_domain) { - return cstr_from_string(index_domain->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::MemoryKind -// Pass a nullptr to create a `MemoryKind` with no memory chosen. -extern "C" ifrt::MemoryKind *ifrt_memorykind_ctor(const char *memory_kind) { - if (memory_kind == nullptr) - return new ifrt::MemoryKind(); - return new ifrt::MemoryKind(std::string(memory_kind)); -} - -extern "C" void ifrt_memorykind_free(ifrt::MemoryKind *memory_kind) { - delete memory_kind; -} - -extern "C" bool ifrt_memorykind_eq(ifrt::MemoryKind *mk1, - ifrt::MemoryKind *mk2) { - return *mk1 == *mk2; -} - -extern "C" bool ifrt_memorykind_ne(ifrt::MemoryKind *mk1, - ifrt::MemoryKind *mk2) { - return *mk1 != *mk2; -} - -extern "C" const char *ifrt_memorykind_string(ifrt::MemoryKind *memory_kind) { - if (memory_kind->memory_kind().has_value()) - return cstr_from_string(memory_kind->memory_kind().value()); - else - return nullptr; -} - -extern "C" ifrt::MemoryKind * -ifrt_memorykind_canonicalize(ifrt::MemoryKind *memory_kind, - ifrt::Device *device) { - return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); -} -#pragma endregion - -#pragma region xla::ifrt::Memory -// MemoryId is a struct with a single int32_t field --> check out -// xla/python/ifrt/memory.h -extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory *memory) { - return memory->Id(); -} - -extern "C" const ifrt::MemoryKind *ifrt_memory_kind(ifrt::Memory *memory) { - return &(memory->Kind()); -} - -extern "C" const char *ifrt_memory_to_string(ifrt::Memory *memory) { - return cstr_from_string(memory->ToString()); -} - -extern "C" const char *ifrt_memory_debug_string(ifrt::Memory *memory) { - return cstr_from_string(memory->DebugString()); -} - -extern "C" std::tuple -ifrt_memory_devices(ifrt::Memory *memory) { - auto devices = memory->Devices(); - return std::make_tuple(devices.size(), - devices.data()); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtMemory -extern "C" ifrt::PjRtMemory * -ifrt_pjrt_memory_ctor(ifrt::PjRtClient *client, - xla::PjRtMemorySpace *memory_space) { - return new ifrt::PjRtMemory(client, memory_space); -} - -extern "C" void ifrt_pjrt_memory_free(ifrt::PjRtMemory *memory) { - delete memory; -} - -extern "C" ifrt::PjRtClient *ifrt_pjrt_memory_client(ifrt::PjRtMemory *memory) { - return memory->client(); -} - -extern "C" xla::PjRtMemorySpace * -ifrt_pjrt_memory_space(ifrt::PjRtMemory *memory) { - return memory->pjrt_memory(); -} -#pragma endregion - -#pragma region xla::ifrt::Device -extern "C" ifrt::Client *ifrt_device_client(ifrt::Device *device) { - return device->client(); -} - -// DeviceId is a struct with a single int32_t field --> check out -// xla/pjrt/pjrt_common.h -extern "C" ifrt::DeviceId ifrt_device_id(ifrt::Device *device) { - return device->Id(); -} - -// TODO ifrt_device_attributes - -extern "C" const char *ifrt_device_kind(ifrt::Device *device) { - return cstr_from_string(device->Kind()); -} - -extern "C" const char *ifrt_device_to_string(ifrt::Device *device) { - return cstr_from_string(device->ToString()); -} - -extern "C" const char *ifrt_device_debug_string(ifrt::Device *device) { - return cstr_from_string(device->DebugString()); -} - -extern "C" ifrt::Memory *ifrt_device_default_memory(ifrt::Device *device) { - return MyValueOrThrow(device->DefaultMemory()); -} - -// TODO ifrt_device_memories - -extern "C" bool ifrt_device_is_addressable(ifrt::Device *device) { - return device->IsAddressable(); -} - -extern "C" int ifrt_device_process_index(ifrt::Device *device) { - return device->ProcessIndex(); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtDevice -// DeviceId is a struct with a single int32_t field --> check out -// xla/pjrt/pjrt_common.h -// TODO support `attributes` parameter -extern "C" ifrt::PjRtDevice * -ifrt_pjrt_device_ctor(ifrt::PjRtClient *client, ifrt::DeviceId device_id, - const char *kind, const char *to_string, - const char *debug_string, int process_index, - xla::PjRtDevice *pjrt_device) { - return new ifrt::PjRtDevice( - client, device_id, kind, to_string, debug_string, process_index, - absl::flat_hash_map(), pjrt_device); -} - -extern "C" void ifrt_pjrt_device_free(ifrt::PjRtDevice *device) { - delete device; -} - -extern "C" xla::PjRtDevice * -ifrt_pjrt_device_pjrt_device(ifrt::PjRtDevice *device) { - return device->pjrt_device(); -} -#pragma endregion - -#pragma region xla::ifrt::Sharding -// TODO ifrt_sharding_devices -// TODO ifrt_sharding_memory_kind - -// extern "C" void ifrt_sharding_disassemble(ifrt::Sharding* sharding, -// ifrt::Shape* shape, char** error) { -// auto status = sharding->Disassemble(*shape); -// if (!status.ok()) { -// auto str = status.message(); -// char* err = (char*)malloc(str.size()+1); -// memcpy(err, str.data(), str.size()+1); -// *error = err; -// } -// } - -// TODO ifrt_sharding_disassemble_dynamic_shape -// TODO ifrt_sharding_index_domains - -extern "C" const char *ifrt_sharding_debug_string(ifrt::Sharding *sharding) { - return cstr_from_string(sharding->DebugString()); -} -#pragma endregion - -#pragma region xla::ifrt::Array -extern "C" ifrt::DType *ifrt_array_dtype(ifrt::Array *array) { - return new ifrt::DType(array->dtype()); -} - -extern "C" const ifrt::Shape *ifrt_array_shape(ifrt::Array *array) { - return &(array->shape()); -} - -extern "C" const ifrt::Sharding *ifrt_array_sharding(ifrt::Array *array) { - return &(array->sharding()); -} - -// @mofeng this is now a shared ptr, will let you fix -// extern "C" PjRtLayout *ifrt_array_layout(ifrt::Array *array) { -// return MyValueOrThrow(array->layout()).release(); -// } - -// TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays -// TODO xla::ifrt::Array::FullyReplicatedShard - -extern "C" ifrt::Future<> -ifrt_array_copy_to_host_buffer(ifrt::Array *array, void *data, - const int64_t *byte_strides, int semantics) { - return array->CopyToHostBuffer( - data, - absl::Span(byte_strides, array->shape().num_elements()), - ifrt::ArrayCopySemantics(semantics)); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtArray -// TODO constructors / `Create` - -extern "C" std::tuple -ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray *array) { - auto buffers = array->pjrt_buffers(); - auto buffers_ptr = new xla::PjRtBuffer *[buffers.size()]; - for (int i = 0; i < buffers.size(); i++) { - buffers_ptr[i] = buffers[i].get(); - } - return std::make_tuple(buffers.size(), buffers_ptr); -} -#pragma endregion - -#pragma region xla::ifrt::Topology -extern "C" const char *ifrt_topology_platform_name(ifrt::Topology *topology) { - return cstr_from_string(topology->platform_name()); -} - -extern "C" const char * -ifrt_topology_platform_version(ifrt::Topology *topology) { - return cstr_from_string(topology->platform_version()); -} - -// returns PjRtPlatformId which is a type alias for uint64_t -extern "C" uint64_t ifrt_topology_platform_id(ifrt::Topology *topology) { - return topology->platform_id(); -} - -extern "C" std::tuple -ifrt_topology_device_descriptions(ifrt::Topology *topology) { - auto descriptions = topology->DeviceDescriptions(); - auto descriptions_ptr = - new const xla::PjRtDeviceDescription *[descriptions.size()]; - for (int i = 0; i < descriptions.size(); i++) { - descriptions_ptr[i] = descriptions[i].release(); - } - return std::make_tuple(descriptions.size(), descriptions_ptr); -} - -// TODO xla::ifrt::Topology::GetDefaultLayout - -extern "C" const char *ifrt_topology_serialize(ifrt::Topology *topology) { - return cstr_from_string(MyValueOrThrow(topology->Serialize())); -} - -// TODO xla::ifrt::Topology::Attributes - -#pragma endregion - -#pragma region xla::ifrt::PjRtTopology -extern "C" ifrt::PjRtTopology * -ifrt_pjrt_topology_ctor(const xla::PjRtTopologyDescription *description) { - return new ifrt::PjRtTopology( - std::shared_ptr{description}); -} - -extern "C" const xla::PjRtTopologyDescription * -ifrt_pjrt_topology_description(ifrt::PjRtTopology *topology) { - return topology->description().get(); -} -#pragma endregion - -#pragma region xla::ifrt::Client -extern "C" int ifrt_client_device_count(ifrt::Client *client) { - return client->device_count(); -} - -extern "C" int ifrt_client_addressable_device_count(ifrt::Client *client) { - return client->addressable_device_count(); -} - -extern "C" ifrt::Device *const *ifrt_client_devices(ifrt::Client *client) { - return client->devices().data(); -} - -extern "C" ifrt::Device *const * -ifrt_client_addressable_devices(ifrt::Client *client) { - return client->addressable_devices().data(); -} - -extern "C" int ifrt_client_process_index(ifrt::Client *client) { - return client->process_index(); -} - -// TODO xla::ifrt::Client::GetDefaultDeviceAssignment - -extern "C" ifrt::Device *ifrt_client_lookup_device(ifrt::Client *client, - int device_id) { - return MyValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); -} - -extern "C" ifrt::Device * -ifrt_client_lookup_addressable_device(ifrt::Client *client, int device_id) { - return MyValueOrThrow(client->LookupAddressableDevice(device_id)); -} - -extern "C" ifrt::Compiler *ifrt_client_default_compiler(ifrt::Client *client) { - return client->GetDefaultCompiler(); -} - -// TODO ifrt_client_topology_for_devices -// TODO ifrt_client_default_layout_for_device -#pragma endregion - -#pragma region xla::ifrt::PjRtClient -// TODO support more parameters of `PjRtClient::CreateOptions` -extern "C" ifrt::PjRtClient * -ifrt_pjrt_client_ctor(xla::PjRtClient *pjrt_client) { - return MyValueOrThrow( - ifrt::PjRtClient::Create(ifrt::PjRtClient::CreateOptions{ - std::shared_ptr{pjrt_client}})) - .release(); -} - -extern "C" void ifrt_pjrt_client_free(ifrt::PjRtClient *client) { - delete client; -} - -extern "C" xla::PjRtClient * -ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient *client) { - return client->pjrt_client(); -} - -// TODO there are problems with using `make_shared -// extern "C" ifrt::PjRtCompatibleArray* -// ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* -// pjrt_buffer) { -// auto buffer_ptr = std::make_shared(*pjrt_buffer); -// return MyValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); -// } - -// TODO extern "C" ifrt::PjRtCompatibleArray* -// ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, -// ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} - -extern "C" ifrt::PjRtCompatibleDevice * -ifrt_pjrt_client_lookup_pjrt_device(ifrt::PjRtClient *client, - xla::PjRtDevice *pjrt_device) { - return MyValueOrThrow(client->LookupPjRtDevice(pjrt_device)); -} - -extern "C" ifrt::PjRtCompatibleMemory * -ifrt_pjrt_client_lookup_pjrt_memory(ifrt::PjRtClient *client, - xla::PjRtMemorySpace *pjrt_memory_space) { - return MyValueOrThrow(client->LookupPjRtMemory(pjrt_memory_space)); -} -#pragma endregion - -#pragma region xla::ifrt::HostCallback -extern "C" const char * -ifrt_hostcallback_serialize(ifrt::HostCallback *host_callback) { - return cstr_from_string(host_callback->Serialize()); -} -#pragma endregion - -#pragma region xla::ifrt::LoadedHostCallback -extern "C" ifrt::Client * -ifrt_loadedhostcallback_client(ifrt::LoadedHostCallback *host_callback) { - return host_callback->client(); -} - -extern "C" const char * -ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallback *host_callback) { - // auto msg = ; - return cstr_from_string(MyValueOrThrow(host_callback->Serialize())); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback -extern "C" ifrt::PjRtHostSendAndRecvLoadedHostCallback * -ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor( - ifrt::PjRtClient *client, xla::HostCallback *host_callback) { - auto xla_callback_ptr = std::make_unique(*host_callback); - return new ifrt::PjRtHostSendAndRecvLoadedHostCallback( - client, std::move(xla_callback_ptr)); -} - -extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free( - ifrt::PjRtHostSendAndRecvLoadedHostCallback *host_callback) { - delete host_callback; -} - -extern "C" xla::HostCallback * -ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback( - ifrt::PjRtHostSendAndRecvLoadedHostCallback *host_callback) { - return new xla::HostCallback(host_callback->host_callback()); -} -#pragma endregion - -#pragma region xla::ifrt::Executable -extern "C" const char *ifrt_executable_name(ifrt::Executable *executable) { - return cstr_from_string(executable->name()); -} - -extern "C" const char * -ifrt_executable_fingerprint(ifrt::Executable *executable) { - auto result = MyValueOrThrow(executable->Fingerprint()); - if (!result.has_value()) - return ""; - return cstr_from_string(result.value()); -} - -extern "C" const char *ifrt_executable_serialize(ifrt::Executable *executable) { - return cstr_from_string(MyValueOrThrow(executable->Serialize())); -} - -extern "C" int ifrt_executable_num_devices(ifrt::Executable *executable) { - return executable->num_devices(); -} - -extern "C" int64_t ifrt_executable_size(ifrt::Executable *executable) { - return executable->SizeOfGeneratedCodeInBytes(); -} - -// TODO xla::ifrt::Executable::GetCompiledMemoryStats - -extern "C" std::tuple -ifrt_executable_parameter_shardings(ifrt::Executable *executable) { - auto shardings = executable->GetParameterShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -extern "C" std::tuple -ifrt_executable_output_shardings(ifrt::Executable *executable) { - auto shardings = executable->GetOutputShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -// @mofeng this is now a shared ptr, will let you fix -// extern "C" std::tuple -// ifrt_executable_parameter_layouts(ifrt::Executable *executable) { -// auto layouts = MyValueOrThrow(executable->GetParameterLayouts()); -// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; -// for (int i = 0; i < layouts.size(); i++) { -// layouts_ptr[i] = layouts[i].release(); -// } -// return std::make_tuple(layouts.size(), layouts_ptr); -// } - -// @mofeng this is now a shared ptr, will let you fix -// extern "C" std::tuple -// ifrt_executable_output_layouts(ifrt::Executable *executable) { -// auto layouts = MyValueOrThrow(executable->GetOutputLayouts()); -// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; -// for (int i = 0; i < layouts.size(); i++) { -// layouts_ptr[i] = layouts[i].release(); -// } -// return std::make_tuple(layouts.size(), layouts_ptr); -// } - -extern "C" std::tuple -ifrt_executable_hlo_modules(ifrt::Executable *executable) { - auto modules = MyValueOrThrow(executable->GetHloModules()); - auto modules_ptr = new xla::HloModule *[modules.size()]; - for (int i = 0; i < modules.size(); i++) { - modules_ptr[i] = modules[i].get(); - } - return std::make_tuple(modules.size(), modules_ptr); -} - -// TODO xla::ifrt::Executable::GetCostAnalysis -#pragma endregion - -#pragma region xla::ifrt::PjRtExecutable -// TODO there are problems with using `make_shared -// extern "C" ifrt::Executable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* -// pjrt_executable, ifrt::XlaCompileOptions* compile_options) { -// auto pjrt_executable_shared = -// std::make_shared(*pjrt_executable); auto options = -// std::make_unique(*compile_options); return -// MyValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, -// std::move(options))).release(); -// } - -extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable *executable) { - delete executable; -} - -extern "C" xla::PjRtExecutable * -ifrt_pjrt_executable_pjrt_executable(ifrt::PjRtExecutable *executable) { - return executable->pjrt_executable(); -} -#pragma endregion - -#pragma region xla::ifrt::LoadedExecutable -extern "C" ifrt::Client * -ifrt_loadedexecutable_client(ifrt::LoadedExecutable *executable) { - return executable->client(); -} - -extern "C" const char * -ifrt_loadedexecutable_name(ifrt::LoadedExecutable *executable) { - return cstr_from_string(executable->name()); -} - -extern "C" const char * -ifrt_loadedexecutable_fingerprint(ifrt::LoadedExecutable *executable) { - auto result = MyValueOrThrow(executable->Fingerprint()); - if (!result.has_value()) - return ""; - return cstr_from_string(result.value()); -} - -extern "C" const char * -ifrt_loadedexecutable_serialize(ifrt::LoadedExecutable *executable) { - return cstr_from_string(MyValueOrThrow(executable->Serialize())); -} - -extern "C" ifrt::Future<> -ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable *executable) { - return executable->GetReadyFuture(); -} - -extern "C" int -ifrt_loadedexecutable_num_devices(ifrt::LoadedExecutable *executable) { - return executable->num_devices(); -} - -extern "C" int64_t -ifrt_loadedexecutable_size(ifrt::LoadedExecutable *executable) { - return executable->SizeOfGeneratedCodeInBytes(); -} - -// TODO xla::ifrt::GetCompiledMemoryStats - -extern "C" std::tuple -ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable *executable) { - auto shardings = executable->GetParameterShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -extern "C" std::tuple -ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable *executable) { - auto shardings = executable->GetOutputShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); -} - -// @mofeng this is now a shared ptr, will let you fix -// extern "C" std::tuple -// ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable *executable) { -// auto layouts = MyValueOrThrow(executable->GetParameterLayouts()); -// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; -// for (int i = 0; i < layouts.size(); i++) { -// layouts_ptr[i] = layouts[i].release(); -// } -// return std::make_tuple(layouts.size(), layouts_ptr); -// } - -// @mofeng this is now a shared ptr, will let you fix -// extern "C" std::tuple -// ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable *executable) { -// auto layouts = MyValueOrThrow(executable->GetOutputLayouts()); -// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; -// for (int i = 0; i < layouts.size(); i++) { -// layouts_ptr[i] = layouts[i].release(); -// } -// return std::make_tuple(layouts.size(), layouts_ptr); -// } - -extern "C" std::tuple -ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable *executable) { - auto modules = MyValueOrThrow(executable->GetHloModules()); - auto modules_ptr = new xla::HloModule *[modules.size()]; - for (int i = 0; i < modules.size(); i++) { - modules_ptr[i] = modules[i].get(); - } - return std::make_tuple(modules.size(), modules_ptr); -} - -// TODO xla::ifrt::LoadedExecutable::GetOutputMemoryKinds -// TODO xla::ifrt::LoadedExecutable::GetCostAnalysis - -// extern "C" ifrt::LoadedExecutable::ExecuteResult* -// ifrt_loadedexecutable_execute(ifrt::LoadedExecutable* executable, -// ifrt::Array** args, size_t args_size, ifrt::Array** results, size_t -// results_size, ifrt::Future<*>** futures, size_t futures_size) { -// std::vector arguments(args, args + args_size); -// std::vector result(results, results + results_size); -// std::vector*> future(futures, futures + futures_size); -// return MyValueOrThrow(executable->Execute(arguments, result, future)); -// } - -extern "C" ifrt::Future<> -ifrt_loadedexecutable_delete(ifrt::LoadedExecutable *executable) { - return executable->Delete(); -} - -extern "C" bool -ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable *executable) { - return executable->IsDeleted(); -} - -extern "C" std::tuple -ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable *executable) { - auto devices = executable->addressable_devices(); - return std::make_tuple(devices.size(), devices.data()); -} - -// TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult -#pragma endregion - -#pragma region xla::ifrt::PjRtLoadedExecutable -// TODO add support for LoadedHostCallback -// TODO there are problems with using `make_shared -// extern "C" ifrt::LoadedExecutable* -// ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, -// xla::PjRtLoadedExecutable* pjrt_loaded_executable) { -// auto pjrt_loaded_executable_ptr = -// std::make_shared(*pjrt_loaded_executable); -// return MyValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, -// pjrt_loaded_executable_ptr, -// std::vector>())).release(); -// } - -// TODO add support for LoadedHostCallback -extern "C" ifrt::LoadedExecutable * -ifrt_pjrt_loadedexecutable_ctor_from_mlir_module( - ifrt::PjRtCompatibleClient *client, mlir::ModuleOp *module, - xla::CompileOptions *compile_options) { - return MyValueOrThrow( - ifrt::PjRtLoadedExecutable::Create( - client, *module, *compile_options, - std::vector>())) - .release(); -} - -extern "C" void -ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable *executable) { - delete executable; -} - -extern "C" xla::PjRtLoadedExecutable * -ifrt_pjrt_loadedexecutable_pjrt_loadedexecutable( - ifrt::PjRtLoadedExecutable *executable) { - return executable->pjrt_loaded_executable(); -} -#pragma endregion - -#pragma region xla::ifrt::CustomCallProgram -#pragma endregion - -#pragma region xla::ifrt::HloProgram -extern "C" ifrt::HloProgram *ifrt_hloprogram_ctor() { - return new ifrt::HloProgram(); -} - -extern "C" ifrt::HloProgram * -ifrt_hloprogram_ctor_with_module(mlir::ModuleOp *module) { - return new ifrt::HloProgram(*module); -} - -// extern "C" ifrt::HloProgram* -// ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, -// mlir::ModuleOp* module) { -// auto context_ptr = std::make_unique(*context); -// return new ifrt::HloProgram(std::move(context_ptr), *module); -// } -#pragma endregion - -#pragma region xla::ifrt::Compiler -extern "C" ifrt::LoadedExecutable * -ifrt_compiler_compile(ifrt::Compiler *compiler, ifrt::Program *program) { - // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and - // set directly to the default - auto program_ptr = std::make_unique(*program); - auto options = std::make_unique(); - return MyValueOrThrow( - compiler->Compile(std::move(program_ptr), std::move(options))) - .release(); -} - -extern "C" ifrt::Executable * -ifrt_compiler_compile_with_topology(ifrt::Compiler *compiler, - ifrt::Program *program, - const ifrt::Topology *topology) { - // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and - // set directly to the default - auto options = std::make_unique(); - auto program_ptr = std::make_unique(*program); - auto exec_ptr = - MyValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, - std::move(options))) - .release(); - return exec_ptr; -} - -extern "C" ifrt::LoadedExecutable * -ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler *compiler, - const char *data) { - // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we - // don't use it and set directly to the default - auto options = std::make_unique(); - return MyValueOrThrow(compiler->DeserializeLoadedExecutable( - std::string(data), std::move(options))) - .release(); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtCompiler -extern "C" ifrt::PjRtCompiler * -ifrt_pjrt_compiler_ctor(ifrt::PjRtClient *client) { - return new ifrt::PjRtCompiler(client); -} - -extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler *compiler) { - delete compiler; -} -#pragma endregion - -#pragma endregion From 7981e190c5f483b98b5a55c5a82776b24cd9db43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 06:19:07 -0600 Subject: [PATCH 02/13] implement minimum IFRT-PjRt methods --- deps/ReactantExtra/API.cpp | 114 +++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f219fc3141..c2e1721837 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -991,3 +991,117 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, newMod.getBody()->getOperations()); return wrap(entryFn); } + +namespace reactant { + +template struct unwrap_type { typedef T type; }; +template struct unwrap_type> { typedef T type; }; +template struct unwrap_type> { typedef T type; }; + +template using unwrap_type_t = typename unwrap_type::type; + +template +struct Holded { + public: + Holded(T& obj) : holded(obj) {} + ~Holded() = default; + + unwrap_type_t* ptr() const { + return holded.get(); + } + + T obj() const { + return holded; + } + + T value() const { + return holded; + } + + unwrap_type_t* operator->() const { + return ptr(); + } + + private: + T holded; +}; + +template +Holded* capture(T obj) { + return new Holded(obj); +} + +} // namespace reactant + +using reactant::Holded; + +extern "C" Holded>* reactant_hold_pjrtclient(xla::PjRtClient* client) { + return reactant::capture(std::shared_ptr(client)); +} + +extern "C" void reactant_release_pjrtclient(Holded>* client) { delete client; } + +extern "C" Holded>* reactant_hold_pjrtbuffer(xla::PjRtBuffer* buffer) { + return reactant::capture(std::shared_ptr(buffer)); +} + +extern "C" void reactant_release_pjrtbuffer(Holded>* buffer) { delete buffer; } + +extern "C" ifrt::PjRtClient* MakeIFRTPJRTClient(Holded>* pjrt_client) { + xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()}; + return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); +} + +extern "C" void FreeIFRTPJRTClient(ifrt::PjRtClient* client) { delete client; } + +extern "C" xla::ifrt::LoadedExecutable* IFRTPJRT_ClientCompile(ifrt::PjRtClient* client, MlirModule mlir_mod) { + mlir::ModuleOp mlir_mod_op = cast(*unwrap(mlir_mod)); + // TODO import sharding config from `ClientCompile`? + xla::CompileOptions compile_options; + // TODO can't create LoadedExecutable from mlir::ModuleOp on IFRT-proxy backend + return MyValueOrThrow(xla::ifrt::PjRtLoadedExecutable::Create(client, mlir_mod_op, compile_options, std::vector>())).release(); +} + +extern "C" void FreeLoadedExecutableIFRTPJRT(xla::ifrt::PjRtLoadedExecutable* exec) { delete exec; } + +extern "C" Holded>* ArrayFromHostBufferIFRTPJRT(ifrt::PjRtClient* client, Holded>* buffer) { + return reactant::capture(MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj()))); +} + +extern "C" void reactant_release_ifrt_pjrt_array(Holded>* array) { delete array; } + +extern "C" void IFRT_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded>** op_args, uint8_t* is_arg_donatable, int num_results, Holded>** op_results, uint8_t *futures, FutureType** status) { + std::vector> args; + for (int i = 0; i < num_args; i++) { + args.emplace_back(op_args[i]->obj()); + } + + ifrt::ExecuteOptions options; + for (size_t i = 0; i < num_args; i++) { + if (!is_arg_donatable[i]) { + options.non_donatable_input_indices.insert(static_cast(i)); + } + } + options.fill_status = true; + + auto result = MyValueOrThrow(exec->Execute(static_cast>>(args), options, /* devices */ std::nullopt)); + + if (result.outputs.size() != num_results) { + llvm::errs() << "Error: results.size()=" << result.outputs.size() + << " does not match num_results=" << num_results << "\n"; + std::abort(); // Terminate if the number of results is incorrect. + } + + // there is only 1 status and is valid because we set `options.fill_status = true` + *futures = true; + *status = new FutureType(result.status); + + for (int i = 0; i < num_results; i++) { + op_results[i] = reactant::capture(result.outputs[i]); + } +} + +// in principle, use ArrayCopySemantics::kAlwaysCopy (=0) +extern "C" FutureType* IFRT_Array_CopyToHostBuffer(Holded>* array, void* data, ifrt::ArrayCopySemantics semantics) { + (*array)->CopyToHostBuffer(data, std::nullopt, semantics); +} From 7a83cc1077e4323f44ebf3e31ddfb1546447a2ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 08:46:07 -0600 Subject: [PATCH 03/13] rename exported bindings --- deps/ReactantExtra/API.cpp | 14 +++++++------- deps/ReactantExtra/BUILD | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index c2e1721837..d5ef43f5cd 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1047,14 +1047,14 @@ extern "C" Holded>* reactant_hold_pjrtbuffer(xl extern "C" void reactant_release_pjrtbuffer(Holded>* buffer) { delete buffer; } -extern "C" ifrt::PjRtClient* MakeIFRTPJRTClient(Holded>* pjrt_client) { +extern "C" ifrt::PjRtClient* ifrt_pjrt_MakeClient(Holded>* pjrt_client) { xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()}; return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); } -extern "C" void FreeIFRTPJRTClient(ifrt::PjRtClient* client) { delete client; } +extern "C" void ifrt_pjrt_FreeClient(ifrt::PjRtClient* client) { delete client; } -extern "C" xla::ifrt::LoadedExecutable* IFRTPJRT_ClientCompile(ifrt::PjRtClient* client, MlirModule mlir_mod) { +extern "C" xla::ifrt::LoadedExecutable* ifrt_pjrt_ClientCompile(ifrt::PjRtClient* client, MlirModule mlir_mod) { mlir::ModuleOp mlir_mod_op = cast(*unwrap(mlir_mod)); // TODO import sharding config from `ClientCompile`? xla::CompileOptions compile_options; @@ -1062,15 +1062,15 @@ extern "C" xla::ifrt::LoadedExecutable* IFRTPJRT_ClientCompile(ifrt::PjRtClient* return MyValueOrThrow(xla::ifrt::PjRtLoadedExecutable::Create(client, mlir_mod_op, compile_options, std::vector>())).release(); } -extern "C" void FreeLoadedExecutableIFRTPJRT(xla::ifrt::PjRtLoadedExecutable* exec) { delete exec; } +extern "C" void ifrt_pjrt_FreeLoadedExecutable(xla::ifrt::PjRtLoadedExecutable* exec) { delete exec; } -extern "C" Holded>* ArrayFromHostBufferIFRTPJRT(ifrt::PjRtClient* client, Holded>* buffer) { +extern "C" Holded>* ifrt_pjrt_ArrayFromHostBuffer(ifrt::PjRtClient* client, Holded>* buffer) { return reactant::capture(MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj()))); } extern "C" void reactant_release_ifrt_pjrt_array(Holded>* array) { delete array; } -extern "C" void IFRT_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded>** op_args, uint8_t* is_arg_donatable, int num_results, Holded>** op_results, uint8_t *futures, FutureType** status) { +extern "C" void ifrt_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded>** op_args, uint8_t* is_arg_donatable, int num_results, Holded>** op_results, uint8_t *futures, FutureType** status) { std::vector> args; for (int i = 0; i < num_args; i++) { args.emplace_back(op_args[i]->obj()); @@ -1102,6 +1102,6 @@ extern "C" void IFRT_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded< } // in principle, use ArrayCopySemantics::kAlwaysCopy (=0) -extern "C" FutureType* IFRT_Array_CopyToHostBuffer(Holded>* array, void* data, ifrt::ArrayCopySemantics semantics) { +extern "C" FutureType* ifrt_CopyArrayToHostBuffer(Holded>* array, void* data, ifrt::ArrayCopySemantics semantics) { (*array)->CopyToHostBuffer(data, std::nullopt, semantics); } diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 3339973a2b..0708922af7 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -437,7 +437,6 @@ cc_library( "-Wl,-exported_symbol,_RegisterDialects", "-Wl,-exported_symbol,_InitializeRegistry", "-Wl,-exported_symbol,_InitializePasses", -"-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", "-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler", @@ -465,6 +464,8 @@ cc_library( "-Wl,-exported_symbol,_BufferShape", "-Wl,-exported_symbol,_BufferNDimensions", "-Wl,-exported_symbol,_BufferPrimitiveType", +"-Wl,-exported_symbol,_ifrt_*", +"-Wl,-exported_symbol,_reactant_*", ]}), deps = [ "@enzyme//:EnzymeMLIR", From d8e2a1e50821c738fcf1f925f86185642e7cacaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 09:44:25 -0600 Subject: [PATCH 04/13] Prototype Julia-side of the bindings --- src/xla/Buffer.jl | 20 ++++++++++++++++---- src/xla/Client.jl | 16 ++++++++++++++-- src/xla/IFRT/Array.jl | 32 ++++++++++++++++++++++++++++++++ src/xla/IFRT/Client.jl | 23 +++++++++++++++++++++++ src/xla/IFRT/IFRT.jl | 13 +++++++++++++ src/xla/IFRT/LoadedExecutable.jl | 15 +++++++++++++++ src/xla/XLA.jl | 2 ++ 7 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 src/xla/IFRT/Array.jl create mode 100644 src/xla/IFRT/Client.jl create mode 100644 src/xla/IFRT/IFRT.jl create mode 100644 src/xla/IFRT/LoadedExecutable.jl diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 52a0718655..f0820d1113 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -1,16 +1,28 @@ # Buffer @inline function free_buffer(buffer) - sbuffer = buffer.buffer - if sbuffer != C_NULL - @ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid + if buffer.holded == C_NULL + if buffer.buffer != C_NULL + @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid + end + else + @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(buffer.holded::Ptr{Cvoid})::Cvoid end end mutable struct Buffer buffer::Ptr{Cvoid} + holded::Ptr{Cvoid} function Buffer(buffer::Ptr{Cvoid}) - return finalizer(free_buffer, new(buffer)) + return finalizer(free_buffer, new(buffer, C_NULL)) + end +end + +function hold!(buffer::Buffer) + if buffer.holded == C_NULL + sbuffer = buffer.buffer + buffer.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer(sbuffer::Ptr{Cvoid})::Ptr{Cvoid} end + return buffer end function Base.ndims(buffer::Buffer) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index a9ac15e318..de59f04df1 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -1,12 +1,13 @@ mutable struct Client client::Ptr{Cvoid} global_ordinals::Vector{Cint} + holded::Ptr{Cvoid} function Client(client::Ptr{Cvoid}) @assert client != C_NULL global_ordinals = Cint[] - client = new(client, global_ordinals) + client = new(client, global_ordinals, C_NULL) # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127 devices = [ @@ -29,7 +30,18 @@ end Base.:(==)(a::Client, b::Client) = a.client == b.client @inline function free_client(client::Client) - @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + if client.holded == C_NULL + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + else + @ccall MLIR.API.mlir_c.reactant_release_pjrtclient(client.holded::Ptr{Cvoid})::Cvoid + end +end + +function hold!(client::Client) + if client.holded == C_NULL + client.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtclient(client.client::Ptr{Cvoid})::Ptr{Cvoid} + end + return client end function ClientNumDevices(client::Client) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl new file mode 100644 index 0000000000..e25353ddc2 --- /dev/null +++ b/src/xla/IFRT/Array.jl @@ -0,0 +1,32 @@ +@cenum ArrayCopySemantics::UInt32 begin + AlwaysCopy = 0 + ReuseInput = 1 + DonateInput = 2 +end + +# currently, only supports IFRT-PjRt +mutable struct Array + ptr::Ptr{Cvoid} + + function Array(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_array, new(ptr)) + end +end + +function free_array(array) + @ccall MLIR.API.mlir_c.reactant_release_ifrt_pjrt_array(array.ptr::Ptr{Cvoid})::Cvoid +end + +function Array(client::Client, buffer::XLA.Buffer) + hold!(buffer) + GC.@preserve client buffer begin + return Array(@ccall MLIR.API.mlir_c.ifrt_pjrt_ArrayFromHostBuffer(client.ptr::Ptr{Cvoid}, buffer.holded::Ptr{Cvoid})::Ptr{Cvoid}) + end +end + +function CopyArrayToHostBuffer(array::Array, data) + GC.@preserve array begin + @ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer(array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint)::Cvoid + end +end diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl new file mode 100644 index 0000000000..e6fa9708a0 --- /dev/null +++ b/src/xla/IFRT/Client.jl @@ -0,0 +1,23 @@ +# currently, only supports IFRT-PjRt +mutable struct Client + ptr::Ptr{Cvoid} + + function Client(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_client, new(ptr)) + end +end + +function Client(pjrt_client::XLA.Client) + # it needs a `std::shared_ptr` + hold!(pjrt_client) + return Client(@ccall MLIR.API.mlir_c.ifrt_pjrt_MakeClient(pjrt_client.holded::Ptr{Cvoid})::Ptr{Cvoid}) +end + +function free_client(client) + @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeClient(client.ptr::Ptr{Cvoid})::Cvoid +end + +function compile(client::Client, code::MLIR.IR.Module) + return LoadedExecutable(@ccall MLIR.API.mlir_c.ifrt_pjrt_ClientCompile(client.ptr::Ptr{Cvoid}, mod.module_::MLIR.API.MlirModule)::Ptr{Cvoid}) +end diff --git a/src/xla/IFRT/IFRT.jl b/src/xla/IFRT/IFRT.jl new file mode 100644 index 0000000000..def0304cd6 --- /dev/null +++ b/src/xla/IFRT/IFRT.jl @@ -0,0 +1,13 @@ +module IFRT + +using CEnum + +import ..XLA +import .XLA: hold! +import ..MLIR + +include("LoadedExecutable.jl") +include("Client.jl") +include("Array.jl") + +end diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl new file mode 100644 index 0000000000..29b72edd26 --- /dev/null +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -0,0 +1,15 @@ +# currently, only supports IFRT-PjRt +mutable struct LoadedExecutable + ptr::Ptr{Cvoid} + + function LoadedExecutable(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_exec, new(ptr)) + end +end + +@inline function free_exec(exec) + @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeLoadedExecutable(exec.ptr::Ptr{Cvoid})::Cvoid +end + +# TODO execute diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index fe00c69f65..d2b820d61f 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -95,4 +95,6 @@ function __init__() return nothing end +include("IFRT/IFRT.jl") + end From aa2599543adda37b775a2b2e8be175d3d27f76ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sun, 9 Feb 2025 16:48:12 +0100 Subject: [PATCH 05/13] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/xla/Buffer.jl | 4 +++- src/xla/Client.jl | 4 +++- src/xla/IFRT/Array.jl | 10 ++++++++-- src/xla/IFRT/Client.jl | 12 ++++++++++-- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index f0820d1113..52035de10a 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -20,7 +20,9 @@ end function hold!(buffer::Buffer) if buffer.holded == C_NULL sbuffer = buffer.buffer - buffer.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer(sbuffer::Ptr{Cvoid})::Ptr{Cvoid} + buffer.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer( + sbuffer::Ptr{Cvoid} + )::Ptr{Cvoid} end return buffer end diff --git a/src/xla/Client.jl b/src/xla/Client.jl index de59f04df1..862cb3e876 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -39,7 +39,9 @@ end function hold!(client::Client) if client.holded == C_NULL - client.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtclient(client.client::Ptr{Cvoid})::Ptr{Cvoid} + client.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtclient( + client.client::Ptr{Cvoid} + )::Ptr{Cvoid} end return client end diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index e25353ddc2..bd76a213fa 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -21,12 +21,18 @@ end function Array(client::Client, buffer::XLA.Buffer) hold!(buffer) GC.@preserve client buffer begin - return Array(@ccall MLIR.API.mlir_c.ifrt_pjrt_ArrayFromHostBuffer(client.ptr::Ptr{Cvoid}, buffer.holded::Ptr{Cvoid})::Ptr{Cvoid}) + return Array( + @ccall MLIR.API.mlir_c.ifrt_pjrt_ArrayFromHostBuffer( + client.ptr::Ptr{Cvoid}, buffer.holded::Ptr{Cvoid} + )::Ptr{Cvoid} + ) end end function CopyArrayToHostBuffer(array::Array, data) GC.@preserve array begin - @ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer(array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint)::Cvoid + @ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer( + array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint + )::Cvoid end end diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl index e6fa9708a0..de553be306 100644 --- a/src/xla/IFRT/Client.jl +++ b/src/xla/IFRT/Client.jl @@ -11,7 +11,11 @@ end function Client(pjrt_client::XLA.Client) # it needs a `std::shared_ptr` hold!(pjrt_client) - return Client(@ccall MLIR.API.mlir_c.ifrt_pjrt_MakeClient(pjrt_client.holded::Ptr{Cvoid})::Ptr{Cvoid}) + return Client( + @ccall MLIR.API.mlir_c.ifrt_pjrt_MakeClient( + pjrt_client.holded::Ptr{Cvoid} + )::Ptr{Cvoid} + ) end function free_client(client) @@ -19,5 +23,9 @@ function free_client(client) end function compile(client::Client, code::MLIR.IR.Module) - return LoadedExecutable(@ccall MLIR.API.mlir_c.ifrt_pjrt_ClientCompile(client.ptr::Ptr{Cvoid}, mod.module_::MLIR.API.MlirModule)::Ptr{Cvoid}) + return LoadedExecutable( + @ccall MLIR.API.mlir_c.ifrt_pjrt_ClientCompile( + client.ptr::Ptr{Cvoid}, mod.module_::MLIR.API.MlirModule + )::Ptr{Cvoid} + ) end From 5f29d05d4f9366a9216b00f687c505f705d976ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 10:37:28 -0600 Subject: [PATCH 06/13] implement execute --- src/xla/IFRT/LoadedExecutable.jl | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl index 29b72edd26..bf63b2bd0b 100644 --- a/src/xla/IFRT/LoadedExecutable.jl +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -12,4 +12,32 @@ end @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeLoadedExecutable(exec.ptr::Ptr{Cvoid})::Cvoid end -# TODO execute +function execute(exec::LoadedExecutable, args::NTuple{N,Ptr{Cvoid}}, donated_mask::NTuple{N,UInt8}, ::Val{n_results}) where {N,n_results} + results = Ref{NTuple{n_results, Ptr{Cvoid}}}() + has_future = Ref{UInt8}() + status = Ref{NTuple{1, Ptr{Cvoid}}}() # unused right now + + args = Base.RefValue(args) + donated_mask = Base.RefValue(donated_mask) + + GC.@preserve exec args donated_mask results has_future status begin + @ccall MLIR.API.mlir_c.ifrt_Execute( + exec.ptr::Ptr{Cvoid}, + N::Cint, + args::Ptr{Cvoid}, + donated_mask::Ptr{Cvoid}, + n_results::Cint, + Base.unsafe_convert(Ptr{Cvoid}, results)::Ptr{Cvoid}, + has_future::Ptr{Cvoid}, + status::Ptr{Cvoid}, + )::Cvoid + end + + @assert has_future[] == true + + results = results[] + + return ntuple(Val(n_results)) do i + return Array(results[i]) + end +end From eb6f636823d01b9ac22412d82a13f3e83d92a20e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 10:37:37 -0600 Subject: [PATCH 07/13] fixes --- src/xla/IFRT/Array.jl | 2 +- src/xla/IFRT/Client.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index bd76a213fa..afc69bee10 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -30,7 +30,7 @@ function Array(client::Client, buffer::XLA.Buffer) end function CopyArrayToHostBuffer(array::Array, data) - GC.@preserve array begin + GC.@preserve array data begin @ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer( array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint )::Cvoid diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl index de553be306..0bd2502d3b 100644 --- a/src/xla/IFRT/Client.jl +++ b/src/xla/IFRT/Client.jl @@ -25,7 +25,7 @@ end function compile(client::Client, code::MLIR.IR.Module) return LoadedExecutable( @ccall MLIR.API.mlir_c.ifrt_pjrt_ClientCompile( - client.ptr::Ptr{Cvoid}, mod.module_::MLIR.API.MlirModule + client.ptr::Ptr{Cvoid}, code.module_::MLIR.API.MlirModule )::Ptr{Cvoid} ) end From 98d02a97a6a0bb51ec909a8346113fb3e58509f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sun, 9 Feb 2025 17:45:42 +0100 Subject: [PATCH 08/13] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/xla/IFRT/LoadedExecutable.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl index bf63b2bd0b..f376947042 100644 --- a/src/xla/IFRT/LoadedExecutable.jl +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -12,10 +12,15 @@ end @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeLoadedExecutable(exec.ptr::Ptr{Cvoid})::Cvoid end -function execute(exec::LoadedExecutable, args::NTuple{N,Ptr{Cvoid}}, donated_mask::NTuple{N,UInt8}, ::Val{n_results}) where {N,n_results} - results = Ref{NTuple{n_results, Ptr{Cvoid}}}() +function execute( + exec::LoadedExecutable, + args::NTuple{N,Ptr{Cvoid}}, + donated_mask::NTuple{N,UInt8}, + ::Val{n_results}, +) where {N,n_results} + results = Ref{NTuple{n_results,Ptr{Cvoid}}}() has_future = Ref{UInt8}() - status = Ref{NTuple{1, Ptr{Cvoid}}}() # unused right now + status = Ref{NTuple{1,Ptr{Cvoid}}}() # unused right now args = Base.RefValue(args) donated_mask = Base.RefValue(donated_mask) From 0cb4ac4e893f03739910648c36457b37f1dc8b0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 12:32:19 -0600 Subject: [PATCH 09/13] last fixes --- deps/ReactantExtra/API.cpp | 23 +++++++++++++++-------- src/xla/IFRT/Array.jl | 2 +- src/xla/IFRT/Client.jl | 4 ++-- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index d5ef43f5cd..279b1ecc86 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -106,6 +106,8 @@ #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/ExtensibleRTTI.h" + using namespace mlir; using namespace llvm; using namespace xla; @@ -1047,14 +1049,14 @@ extern "C" Holded>* reactant_hold_pjrtbuffer(xl extern "C" void reactant_release_pjrtbuffer(Holded>* buffer) { delete buffer; } -extern "C" ifrt::PjRtClient* ifrt_pjrt_MakeClient(Holded>* pjrt_client) { +extern "C" ifrt::Client* ifrt_pjrt_MakeClient(Holded>* pjrt_client) { xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()}; return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); } -extern "C" void ifrt_pjrt_FreeClient(ifrt::PjRtClient* client) { delete client; } +extern "C" void ifrt_FreeClient(ifrt::Client* client) { delete client; } -extern "C" xla::ifrt::LoadedExecutable* ifrt_pjrt_ClientCompile(ifrt::PjRtClient* client, MlirModule mlir_mod) { +extern "C" xla::ifrt::LoadedExecutable* ifrt_ClientCompile(ifrt::PjRtClient* client, MlirModule mlir_mod) { mlir::ModuleOp mlir_mod_op = cast(*unwrap(mlir_mod)); // TODO import sharding config from `ClientCompile`? xla::CompileOptions compile_options; @@ -1064,11 +1066,12 @@ extern "C" xla::ifrt::LoadedExecutable* ifrt_pjrt_ClientCompile(ifrt::PjRtClient extern "C" void ifrt_pjrt_FreeLoadedExecutable(xla::ifrt::PjRtLoadedExecutable* exec) { delete exec; } -extern "C" Holded>* ifrt_pjrt_ArrayFromHostBuffer(ifrt::PjRtClient* client, Holded>* buffer) { - return reactant::capture(MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj()))); +// TODO replace with `Client::MakeArrayFromHostBuffer` and generalize to `ifrt::Client` +extern "C" Holded>* ifrt_pjrt_ArrayFromHostBuffer(ifrt::PjRtClient* client, Holded>* buffer) { + return reactant::capture(tsl::RCReference(MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj())))); } -extern "C" void reactant_release_ifrt_pjrt_array(Holded>* array) { delete array; } +extern "C" void reactant_release_ifrt_array(Holded>* array) { delete array; } extern "C" void ifrt_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded>** op_args, uint8_t* is_arg_donatable, int num_results, Holded>** op_results, uint8_t *futures, FutureType** status) { std::vector> args; @@ -1102,6 +1105,10 @@ extern "C" void ifrt_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded< } // in principle, use ArrayCopySemantics::kAlwaysCopy (=0) -extern "C" FutureType* ifrt_CopyArrayToHostBuffer(Holded>* array, void* data, ifrt::ArrayCopySemantics semantics) { - (*array)->CopyToHostBuffer(data, std::nullopt, semantics); +extern "C" FutureType* ifrt_CopyArrayToHostBuffer(Holded>* array, void* data, ifrt::ArrayCopySemantics semantics) { + return new FutureType((*array)->CopyToHostBuffer(data, std::nullopt, semantics)); +} + +extern "C" void reactant_generic_llvm_rtti_root_dtor(llvm::RTTIRoot* root) { + delete root; } diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index afc69bee10..ef2ffa409f 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -15,7 +15,7 @@ mutable struct Array end function free_array(array) - @ccall MLIR.API.mlir_c.reactant_release_ifrt_pjrt_array(array.ptr::Ptr{Cvoid})::Cvoid + @ccall MLIR.API.mlir_c.reactant_release_ifrt_array(array.ptr::Ptr{Cvoid})::Cvoid end function Array(client::Client, buffer::XLA.Buffer) diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl index 0bd2502d3b..d75d1c5fcf 100644 --- a/src/xla/IFRT/Client.jl +++ b/src/xla/IFRT/Client.jl @@ -19,12 +19,12 @@ function Client(pjrt_client::XLA.Client) end function free_client(client) - @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeClient(client.ptr::Ptr{Cvoid})::Cvoid + @ccall MLIR.API.mlir_c.ifrt_FreeClient(client.ptr::Ptr{Cvoid})::Cvoid end function compile(client::Client, code::MLIR.IR.Module) return LoadedExecutable( - @ccall MLIR.API.mlir_c.ifrt_pjrt_ClientCompile( + @ccall MLIR.API.mlir_c.ifrt_ClientCompile( client.ptr::Ptr{Cvoid}, code.module_::MLIR.API.MlirModule )::Ptr{Cvoid} ) From c76a9002b0c36a29b553e8bbb5083d2d41319b23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 12:36:50 -0600 Subject: [PATCH 10/13] Remove Julia code --- src/xla/Buffer.jl | 21 ++------------ src/xla/Client.jl | 18 ++---------- src/xla/IFRT/Array.jl | 38 ------------------------- src/xla/IFRT/Client.jl | 31 --------------------- src/xla/IFRT/IFRT.jl | 13 --------- src/xla/IFRT/LoadedExecutable.jl | 48 -------------------------------- src/xla/XLA.jl | 2 -- 7 files changed, 5 insertions(+), 166 deletions(-) delete mode 100644 src/xla/IFRT/Array.jl delete mode 100644 src/xla/IFRT/Client.jl delete mode 100644 src/xla/IFRT/IFRT.jl delete mode 100644 src/xla/IFRT/LoadedExecutable.jl diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 52035de10a..06bf54e074 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -1,30 +1,15 @@ # Buffer @inline function free_buffer(buffer) - if buffer.holded == C_NULL - if buffer.buffer != C_NULL - @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid - end - else - @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(buffer.holded::Ptr{Cvoid})::Cvoid + if buffer.buffer != C_NULL + @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid end end mutable struct Buffer buffer::Ptr{Cvoid} - holded::Ptr{Cvoid} function Buffer(buffer::Ptr{Cvoid}) - return finalizer(free_buffer, new(buffer, C_NULL)) - end -end - -function hold!(buffer::Buffer) - if buffer.holded == C_NULL - sbuffer = buffer.buffer - buffer.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer( - sbuffer::Ptr{Cvoid} - )::Ptr{Cvoid} + return finalizer(free_buffer, new(buffer)) end - return buffer end function Base.ndims(buffer::Buffer) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index 862cb3e876..a9ac15e318 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -1,13 +1,12 @@ mutable struct Client client::Ptr{Cvoid} global_ordinals::Vector{Cint} - holded::Ptr{Cvoid} function Client(client::Ptr{Cvoid}) @assert client != C_NULL global_ordinals = Cint[] - client = new(client, global_ordinals, C_NULL) + client = new(client, global_ordinals) # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127 devices = [ @@ -30,20 +29,7 @@ end Base.:(==)(a::Client, b::Client) = a.client == b.client @inline function free_client(client::Client) - if client.holded == C_NULL - @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - else - @ccall MLIR.API.mlir_c.reactant_release_pjrtclient(client.holded::Ptr{Cvoid})::Cvoid - end -end - -function hold!(client::Client) - if client.holded == C_NULL - client.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtclient( - client.client::Ptr{Cvoid} - )::Ptr{Cvoid} - end - return client + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid end function ClientNumDevices(client::Client) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl deleted file mode 100644 index ef2ffa409f..0000000000 --- a/src/xla/IFRT/Array.jl +++ /dev/null @@ -1,38 +0,0 @@ -@cenum ArrayCopySemantics::UInt32 begin - AlwaysCopy = 0 - ReuseInput = 1 - DonateInput = 2 -end - -# currently, only supports IFRT-PjRt -mutable struct Array - ptr::Ptr{Cvoid} - - function Array(ptr::Ptr{Cvoid}) - @assert ptr != C_NULL - return finalizer(free_array, new(ptr)) - end -end - -function free_array(array) - @ccall MLIR.API.mlir_c.reactant_release_ifrt_array(array.ptr::Ptr{Cvoid})::Cvoid -end - -function Array(client::Client, buffer::XLA.Buffer) - hold!(buffer) - GC.@preserve client buffer begin - return Array( - @ccall MLIR.API.mlir_c.ifrt_pjrt_ArrayFromHostBuffer( - client.ptr::Ptr{Cvoid}, buffer.holded::Ptr{Cvoid} - )::Ptr{Cvoid} - ) - end -end - -function CopyArrayToHostBuffer(array::Array, data) - GC.@preserve array data begin - @ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer( - array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint - )::Cvoid - end -end diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl deleted file mode 100644 index d75d1c5fcf..0000000000 --- a/src/xla/IFRT/Client.jl +++ /dev/null @@ -1,31 +0,0 @@ -# currently, only supports IFRT-PjRt -mutable struct Client - ptr::Ptr{Cvoid} - - function Client(ptr::Ptr{Cvoid}) - @assert ptr != C_NULL - return finalizer(free_client, new(ptr)) - end -end - -function Client(pjrt_client::XLA.Client) - # it needs a `std::shared_ptr` - hold!(pjrt_client) - return Client( - @ccall MLIR.API.mlir_c.ifrt_pjrt_MakeClient( - pjrt_client.holded::Ptr{Cvoid} - )::Ptr{Cvoid} - ) -end - -function free_client(client) - @ccall MLIR.API.mlir_c.ifrt_FreeClient(client.ptr::Ptr{Cvoid})::Cvoid -end - -function compile(client::Client, code::MLIR.IR.Module) - return LoadedExecutable( - @ccall MLIR.API.mlir_c.ifrt_ClientCompile( - client.ptr::Ptr{Cvoid}, code.module_::MLIR.API.MlirModule - )::Ptr{Cvoid} - ) -end diff --git a/src/xla/IFRT/IFRT.jl b/src/xla/IFRT/IFRT.jl deleted file mode 100644 index def0304cd6..0000000000 --- a/src/xla/IFRT/IFRT.jl +++ /dev/null @@ -1,13 +0,0 @@ -module IFRT - -using CEnum - -import ..XLA -import .XLA: hold! -import ..MLIR - -include("LoadedExecutable.jl") -include("Client.jl") -include("Array.jl") - -end diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl deleted file mode 100644 index f376947042..0000000000 --- a/src/xla/IFRT/LoadedExecutable.jl +++ /dev/null @@ -1,48 +0,0 @@ -# currently, only supports IFRT-PjRt -mutable struct LoadedExecutable - ptr::Ptr{Cvoid} - - function LoadedExecutable(ptr::Ptr{Cvoid}) - @assert ptr != C_NULL - return finalizer(free_exec, new(ptr)) - end -end - -@inline function free_exec(exec) - @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeLoadedExecutable(exec.ptr::Ptr{Cvoid})::Cvoid -end - -function execute( - exec::LoadedExecutable, - args::NTuple{N,Ptr{Cvoid}}, - donated_mask::NTuple{N,UInt8}, - ::Val{n_results}, -) where {N,n_results} - results = Ref{NTuple{n_results,Ptr{Cvoid}}}() - has_future = Ref{UInt8}() - status = Ref{NTuple{1,Ptr{Cvoid}}}() # unused right now - - args = Base.RefValue(args) - donated_mask = Base.RefValue(donated_mask) - - GC.@preserve exec args donated_mask results has_future status begin - @ccall MLIR.API.mlir_c.ifrt_Execute( - exec.ptr::Ptr{Cvoid}, - N::Cint, - args::Ptr{Cvoid}, - donated_mask::Ptr{Cvoid}, - n_results::Cint, - Base.unsafe_convert(Ptr{Cvoid}, results)::Ptr{Cvoid}, - has_future::Ptr{Cvoid}, - status::Ptr{Cvoid}, - )::Cvoid - end - - @assert has_future[] == true - - results = results[] - - return ntuple(Val(n_results)) do i - return Array(results[i]) - end -end diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index d2b820d61f..fe00c69f65 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -95,6 +95,4 @@ function __init__() return nothing end -include("IFRT/IFRT.jl") - end From ea601cd53e8148187e3c876d0f4912544a7f88fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 9 Feb 2025 12:37:47 -0600 Subject: [PATCH 11/13] fix --- src/xla/Buffer.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 06bf54e074..52a0718655 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -1,7 +1,8 @@ # Buffer @inline function free_buffer(buffer) - if buffer.buffer != C_NULL - @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid + sbuffer = buffer.buffer + if sbuffer != C_NULL + @ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid end end From ac5605ae0dbf47451e64d0fc6b558910b755e8c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 10 Feb 2025 08:45:20 +0100 Subject: [PATCH 12/13] Remove `reactant_generic_llvm_rtti_root_dtor` --- deps/ReactantExtra/API.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 279b1ecc86..d8fedb3eee 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1108,7 +1108,3 @@ extern "C" void ifrt_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded< extern "C" FutureType* ifrt_CopyArrayToHostBuffer(Holded>* array, void* data, ifrt::ArrayCopySemantics semantics) { return new FutureType((*array)->CopyToHostBuffer(data, std::nullopt, semantics)); } - -extern "C" void reactant_generic_llvm_rtti_root_dtor(llvm::RTTIRoot* root) { - delete root; -} From cad339f1fa1ec9f91b7a49a2671b75050653f37d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 10 Feb 2025 16:39:51 +0100 Subject: [PATCH 13/13] rename `Holded` to `HeldValue` --- deps/ReactantExtra/API.cpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8695e0e2f9..de2bc52043 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1111,10 +1111,10 @@ template struct unwrap_type> { typedef T type; template using unwrap_type_t = typename unwrap_type::type; template -struct Holded { +struct HeldValue { public: - Holded(T& obj) : holded(obj) {} - ~Holded() = default; + HeldValue(T& obj) : holded(obj) {} + ~HeldValue() = default; unwrap_type_t* ptr() const { return holded.get(); @@ -1137,27 +1137,27 @@ struct Holded { }; template -Holded* capture(T obj) { - return new Holded(obj); +HeldValue* capture(T obj) { + return new HeldValue(obj); } } // namespace reactant -using reactant::Holded; +using reactant::HeldValue; -extern "C" Holded>* reactant_hold_pjrtclient(xla::PjRtClient* client) { +extern "C" HeldValue>* reactant_hold_pjrtclient(xla::PjRtClient* client) { return reactant::capture(std::shared_ptr(client)); } -extern "C" void reactant_release_pjrtclient(Holded>* client) { delete client; } +extern "C" void reactant_release_pjrtclient(HeldValue>* client) { delete client; } -extern "C" Holded>* reactant_hold_pjrtbuffer(xla::PjRtBuffer* buffer) { +extern "C" HeldValue>* reactant_hold_pjrtbuffer(xla::PjRtBuffer* buffer) { return reactant::capture(std::shared_ptr(buffer)); } -extern "C" void reactant_release_pjrtbuffer(Holded>* buffer) { delete buffer; } +extern "C" void reactant_release_pjrtbuffer(HeldValue>* buffer) { delete buffer; } -extern "C" ifrt::Client* ifrt_pjrt_MakeClient(Holded>* pjrt_client) { +extern "C" ifrt::Client* ifrt_pjrt_MakeClient(HeldValue>* pjrt_client) { xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()}; return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); } @@ -1175,13 +1175,13 @@ extern "C" xla::ifrt::LoadedExecutable* ifrt_ClientCompile(ifrt::PjRtClient* cli extern "C" void ifrt_pjrt_FreeLoadedExecutable(xla::ifrt::PjRtLoadedExecutable* exec) { delete exec; } // TODO replace with `Client::MakeArrayFromHostBuffer` and generalize to `ifrt::Client` -extern "C" Holded>* ifrt_pjrt_ArrayFromHostBuffer(ifrt::PjRtClient* client, Holded>* buffer) { +extern "C" HeldValue>* ifrt_pjrt_ArrayFromHostBuffer(ifrt::PjRtClient* client, HeldValue>* buffer) { return reactant::capture(tsl::RCReference(MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj())))); } -extern "C" void reactant_release_ifrt_array(Holded>* array) { delete array; } +extern "C" void reactant_release_ifrt_array(HeldValue>* array) { delete array; } -extern "C" void ifrt_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded>** op_args, uint8_t* is_arg_donatable, int num_results, Holded>** op_results, uint8_t *futures, FutureType** status) { +extern "C" void ifrt_Execute(ifrt::LoadedExecutable* exec, int num_args, HeldValue>** op_args, uint8_t* is_arg_donatable, int num_results, HeldValue>** op_results, uint8_t *futures, FutureType** status) { std::vector> args; for (int i = 0; i < num_args; i++) { args.emplace_back(op_args[i]->obj()); @@ -1213,6 +1213,6 @@ extern "C" void ifrt_Execute(ifrt::LoadedExecutable* exec, int num_args, Holded< } // in principle, use ArrayCopySemantics::kAlwaysCopy (=0) -extern "C" FutureType* ifrt_CopyArrayToHostBuffer(Holded>* array, void* data, ifrt::ArrayCopySemantics semantics) { +extern "C" FutureType* ifrt_CopyArrayToHostBuffer(HeldValue>* array, void* data, ifrt::ArrayCopySemantics semantics) { return new FutureType((*array)->CopyToHostBuffer(data, std::nullopt, semantics)); }