Allocator and data transfer support for plugin EP API #25070

Open
wants to merge 13 commits into
base: main
from

skottmckay
Contributor

Description

Add allocator and data transfer infrastructure for plugin EP API

Allocators are created via the OrtEpFactory using the OrtMemoryInfo that is added to the OrtEpDevice instances the factory returns. This allows allocators to be created outside of an inference session and shared.

When a library is loaded, a default instance of each allocator is added to the shared allocators if there is no existing allocator (e.g. a user-provided custom allocator).
CreateSharedAllocator can be used to replace this default instance with a user-configured one, for example to add an arena or to provide other configuration options that are passed through to the OrtEpFactory's CreateAllocator function.
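The default-then-replace behavior described above can be sketched roughly as follows. This is a minimal illustration only: `MemoryInfo`, `Allocator`, and `SharedAllocators` are hypothetical stand-ins invented for the example, not the ONNX Runtime types, and the real bookkeeping may differ.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-in for OrtMemoryInfo: identifies a kind of memory on a device.
struct MemoryInfo {
  std::string name;
  int device_id = 0;
  bool operator<(const MemoryInfo& other) const {
    return std::tie(name, device_id) < std::tie(other.name, other.device_id);
  }
};

// Hypothetical stand-in for an allocator instance.
struct Allocator {
  MemoryInfo info;
  bool uses_arena = false;
};

class SharedAllocators {
 public:
  // Library load: add the factory's default allocator only if nothing is
  // registered for this memory info yet. Returns true if it was added.
  bool AddDefaultIfMissing(std::shared_ptr<Allocator> allocator) {
    assert(allocator != nullptr);
    if (allocators_.find(allocator->info) != allocators_.end()) {
      return false;  // keep the existing (e.g. user-provided) allocator
    }
    MemoryInfo key = allocator->info;
    allocators_.emplace(std::move(key), std::move(allocator));
    return true;
  }

  // Analogous to replacing the default with a user-configured instance.
  void Replace(std::shared_ptr<Allocator> allocator) {
    assert(allocator != nullptr);
    MemoryInfo key = allocator->info;
    allocators_[key] = std::move(allocator);
  }

  std::shared_ptr<Allocator> Get(const MemoryInfo& info) const {
    auto it = allocators_.find(info);
    return it == allocators_.end() ? nullptr : it->second;
  }

 private:
  std::map<MemoryInfo, std::shared_ptr<Allocator>> allocators_;
};
```

Keying the map on the memory info is an assumption made for the sketch; it mirrors the idea that one shared allocator exists per OrtMemoryInfo.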
 
Similarly, IDataTransfer is supported by the factory implementing OrtDataTransferImpl, which will also enable data transfer outside of a session. Support for data transfer outside of a session will be added in a future PR, as the synchronization requirements still need to be figured out and will affect the public API.

Motivation and Context

Contributor

@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.

@@ -37,6 +44,26 @@ struct OrtArenaCfg {
int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default
int64_t max_power_of_two_extend_bytes; // use -1 to allow ORT to choose the default

bool IsValid() {
Contributor Author

Moving validity check that was scattered around to be here.
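A centralized check of this kind might look like the sketch below. It models only the three fields visible in the hunk, and the rule (any value of -1 or greater is acceptable) is an assumption based on the "-1 means default" comments; the actual `IsValid()` body is not shown here and may check more fields. `ArenaCfgSketch` and `ResolveMaxDeadBytes` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the arena configuration struct.
struct ArenaCfgSketch {
  int max_dead_bytes_per_chunk = -1;                // -1: let the runtime choose
  int initial_growth_chunk_size_bytes = -1;         // -1: let the runtime choose
  std::int64_t max_power_of_two_extend_bytes = -1;  // -1: let the runtime choose

  // Assumed rule: -1 selects the default, any other negative value is invalid.
  bool IsValid() const {
    return max_dead_bytes_per_chunk >= -1 &&
           initial_growth_chunk_size_bytes >= -1 &&
           max_power_of_two_extend_bytes >= -1;
  }
};

// Callers can then validate once and resolve the -1 sentinel to a concrete value.
int ResolveMaxDeadBytes(const ArenaCfgSketch& cfg, int runtime_default) {
  assert(cfg.IsValid());
  return cfg.max_dead_bytes_per_chunk == -1 ? runtime_default
                                            : cfg.max_dead_bytes_per_chunk;
}
```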

@@ -75,7 +78,7 @@ class Environment {
* Registers an allocator for sharing between multiple sessions.
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
*/
-Status RegisterAllocator(AllocatorPtr allocator);
+Status RegisterAllocator(OrtAllocator* allocator);
Contributor Author

Separating external from internal usage by making the public method take OrtAllocator*. This enables GetSharedAllocator to return any of the OrtAllocator based instances.

Comment on lines +6237 to +6241
ORT_API2_STATUS(CopyTensors, _In_ void* this_ptr,
_In_reads_(num_tensors) const OrtValue** src_tensors,
_In_reads_(num_tensors) OrtValue** dst_tensors,
_In_reads_(num_tensors) OrtSyncStream** streams,
_In_ size_t num_tensors);
Contributor Author

Design choice: Minimize the EP API by having a single function that can do asynchronous batched copies. The library is free to implement that as synchronous copies. This avoids having additional functions for copying a single tensor.
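To illustrate how a library could satisfy a batched, potentially asynchronous copy entry point with plain synchronous copies, here is a rough sketch. `TensorSketch` and `StreamSketch` are hypothetical stand-ins for OrtValue and OrtSyncStream, the function returns a bool rather than a status object, and host-to-host `memcpy` stands in for whatever device copy a real EP would perform.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-ins; a real implementation works with opaque API handles.
struct TensorSketch {
  std::vector<unsigned char> bytes;
};
struct StreamSketch {};

// Copies src_tensors[i] into dst_tensors[i] for each i. The copies are blocking,
// so the streams argument is accepted but not used.
bool CopyTensorsSketch(const TensorSketch* const* src_tensors,
                       TensorSketch* const* dst_tensors,
                       StreamSketch* const* streams,
                       std::size_t num_tensors) {
  assert(num_tensors == 0 || (src_tensors != nullptr && dst_tensors != nullptr));
  (void)streams;  // synchronous implementation: nothing to enqueue

  for (std::size_t i = 0; i < num_tensors; ++i) {
    const TensorSketch* src = src_tensors[i];
    TensorSketch* dst = dst_tensors[i];
    if (src == nullptr || dst == nullptr ||
        src->bytes.size() != dst->bytes.size()) {
      return false;  // report failure for a missing or mismatched pair
    }
    if (!src->bytes.empty()) {
      std::memcpy(dst->bytes.data(), src->bytes.data(), src->bytes.size());
    }
  }
  return true;
}
```

Copying a single tensor is then just a call with `num_tensors` set to 1, which is why no separate single-tensor function is needed.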

@@ -6096,7 +6185,109 @@ struct OrtCompileApi {

ORT_RUNTIME_CLASS(Ep);
ORT_RUNTIME_CLASS(EpFactory);
ORT_RUNTIME_CLASS(MemoryDevice); // opaque class to wrap onnxruntime::OrtDevice
Contributor Author

I couldn't avoid adding OrtDevice to the API as IDataTransfer uses it. To try and disambiguate/clarify I called it OrtMemoryDevice as we have OrtHardwareDevice, and calling it OrtDevice in the API felt too vague.

Comment on lines +145 to +147
#define ORT_API_T(RETURN_TYPE, NAME, ...) \
RETURN_TYPE(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION

Contributor Author

We were missing the NO_EXCEPTION on a number of functions in the initial plugin EP setup. This macro should help keep things consistent by adding NO_EXCEPTION like ORT_API2_STATUS does.
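As a rough illustration of what the macro expands to, the sketch below declares two function-pointer members with it and checks that calls through them are treated as non-throwing. `ORT_API_CALL` and `NO_EXCEPTION` are given stand-in definitions here so the example compiles on its own (the real headers define them per platform and language), and `DemoFactory` is a hypothetical struct, not part of the API.

```cpp
#include <cassert>

// Stand-in definitions, assumed for this sketch only.
#define ORT_API_CALL
#define NO_EXCEPTION noexcept

#define ORT_API_T(RETURN_TYPE, NAME, ...) \
  RETURN_TYPE(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION

// Hypothetical function table in the style of the plugin EP structs.
struct DemoFactory {
  // Expands to: const char*(* GetName)(const DemoFactory* this_ptr) noexcept;
  ORT_API_T(const char*, GetName, const DemoFactory* this_ptr);
  ORT_API_T(int, GetVersion, const DemoFactory* this_ptr);
};

// Implementations must be noexcept to be assignable to the members.
const char* DemoGetName(const DemoFactory*) noexcept { return "demo"; }
int DemoGetVersion(const DemoFactory*) noexcept { return 1; }

DemoFactory MakeDemoFactory() {
  DemoFactory factory;
  factory.GetName = DemoGetName;
  factory.GetVersion = DemoGetVersion;
  return factory;
}

int CallGetVersion(const DemoFactory& factory) {
  assert(factory.GetVersion != nullptr);
  return factory.GetVersion(&factory);
}

// The exception specification is visible to callers at compile time.
static_assert(noexcept(DemoFactory{}.GetVersion(nullptr)),
              "calls through ORT_API_T members should be non-throwing");
```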

@@ -15,6 +17,61 @@

#include "core/framework/bfc_arena.h"

using Status = onnxruntime::common::Status;

Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg) {
Contributor Author

Support creation from OrtKeyValuePairs to keep the public API minimal. We use OrtKeyValuePairs in general for configuration/options/metadata and this avoids having to add OrtArenaCfg as an additional parameter in some functions like CreateSharedAllocator.
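A minimal sketch of building a configuration from string key/value pairs follows. It assumes a `std::map` in place of OrtKeyValuePairs, a bool plus error string in place of Status, and hypothetical key names; the parsing helper is a generic classic-locale stream parse written for this example, not the library's own utility.

```cpp
#include <cassert>
#include <cstdint>
#include <locale>
#include <map>
#include <sstream>
#include <string>

using KeyValuePairs = std::map<std::string, std::string>;  // stand-in type

struct ArenaCfgSketch {
  int max_dead_bytes_per_chunk = -1;
  int initial_growth_chunk_size_bytes = -1;
  std::int64_t max_power_of_two_extend_bytes = -1;
};

// Parses the whole string into T, independent of the global locale.
template <typename T>
bool TryParse(const std::string& text, T& value) {
  std::istringstream stream(text);
  stream.imbue(std::locale::classic());
  T parsed{};
  stream >> std::noskipws >> parsed;
  if (stream.fail() || !stream.eof()) return false;  // reject "12x", "", " 1"
  value = parsed;
  return true;
}

// Leaves the field at its default when the key is absent.
template <typename T>
bool ReadOptional(const KeyValuePairs& kvps, const std::string& key, T& field,
                  std::string& error) {
  auto it = kvps.find(key);
  if (it == kvps.end()) return true;
  if (!TryParse(it->second, field)) {
    error = "Invalid value for " + key + ": " + it->second;
    return false;
  }
  return true;
}

// Key names below are assumptions made for illustration.
bool ArenaCfgFromKeyValuePairs(const KeyValuePairs& kvps, ArenaCfgSketch& cfg,
                               std::string& error) {
  error.clear();
  bool ok =
      ReadOptional(kvps, "arena.max_dead_bytes_per_chunk",
                   cfg.max_dead_bytes_per_chunk, error) &&
      ReadOptional(kvps, "arena.initial_growth_chunk_size_bytes",
                   cfg.initial_growth_chunk_size_bytes, error) &&
      ReadOptional(kvps, "arena.max_power_of_two_extend_bytes",
                   cfg.max_power_of_two_extend_bytes, error);
  assert(!ok || error.empty());  // success never leaves an error message behind
  return ok;
}
```

Because each field is parsed directly into its own integer type, an out-of-range value fails at parse time and no separate narrowing step is needed.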

@@ -16,6 +16,20 @@ Status DataTransferManager::RegisterDataTransfer(std::unique_ptr<IDataTransfer>
return Status::OK();
}

Status DataTransferManager::UnregisterDataTransfer(IDataTransfer* data_transfer) {
Contributor Author

Need to unregister if the plugin EP library is unloaded.
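The register/unregister pairing can be sketched as below, assuming the manager owns each data transfer object and the library identifies its own entry by raw pointer. `DataTransferSketch` and `DataTransferRegistry` are hypothetical names for this example; the actual DataTransferManager may store and look up entries differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for IDataTransfer.
struct DataTransferSketch {
  virtual ~DataTransferSketch() = default;
};

class DataTransferRegistry {
 public:
  // Takes ownership and returns the raw pointer so the caller can unregister later.
  DataTransferSketch* Register(std::unique_ptr<DataTransferSketch> data_transfer) {
    assert(data_transfer != nullptr);
    entries_.push_back(std::move(data_transfer));
    return entries_.back().get();
  }

  // Removes and destroys the matching entry. Returns false if it was not found.
  bool Unregister(const DataTransferSketch* data_transfer) {
    auto it = std::find_if(
        entries_.begin(), entries_.end(),
        [data_transfer](const std::unique_ptr<DataTransferSketch>& entry) {
          return entry.get() == data_transfer;
        });
    if (it == entries_.end()) return false;
    entries_.erase(it);
    return true;
  }

  std::size_t size() const { return entries_.size(); }

 private:
  std::vector<std::unique_ptr<DataTransferSketch>> entries_;
};
```

Destroying the entry at unregistration matters here because the object's code lives in the plugin library; it must not outlive the library that provided it.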

Comment on lines -130 to +140
auto st = env->CreateAndRegisterAllocator(*mem_info, arena_cfg);
auto& env = ort_env->GetEnvironment();
Contributor Author

Removed unnecessary call from here -> OrtEnv.CreateAndRegisterAllocator -> Environment.CreateAndRegisterAllocator.

As the Environment instance is available from OrtEnv, there's no need for OrtEnv to have a CreateAndRegisterAllocator method that forwards.

// some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemory.h
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"

namespace onnxruntime {
namespace common {
class Status;
Contributor

nit: common.h already includes status.h, so we don't need a forward declaration. we could explicitly include status.h too for clarity.

a user of this file would need a complete Status definition anyway to call FromKeyValuePairs().

*
* OrtEpDevice maps to the EP factory, and the factory provides the allocator implementation.
*
* OrtDeviceMemoryType_DEFAULT is always supported for non-CPU based devices.
Contributor

is the device memory allocator always required? does this also apply to the QNN EP, which only has the HtpSharedMemoryAllocator?

Comment on lines +38 to +39
ORT_RETURN_IF_ERROR(from_string(it->first, it->second, value));
cfg.arena_extend_strategy = onnxruntime::narrow<int32_t>(value);
Contributor

consider using ParseStringWithClassicLocale() in parse_string.h. it handles conversions to specific int types directly.

2 participants