Allocator and data transfer support for plugin EP API #25070
@@ -37,6 +44,26 @@ struct OrtArenaCfg {
  int max_dead_bytes_per_chunk;  // use -1 to allow ORT to choose the default
  int initial_growth_chunk_size_bytes;  // use -1 to allow ORT to choose the default
  int64_t max_power_of_two_extend_bytes;  // use -1 to allow ORT to choose the default

  bool IsValid() {
Moving validity check that was scattered around to be here.
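A minimal sketch of what a centralized validity check could look like. `ArenaCfgSketch` is a hypothetical stand-in that models only the three fields visible in the hunk above, and the rule that every field must be either the `-1` sentinel or non-negative is an assumption, not the check this PR implements.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for OrtArenaCfg; only the fields shown in the diff are modeled.
struct ArenaCfgSketch {
  int max_dead_bytes_per_chunk = -1;           // -1 lets the runtime choose the default
  int initial_growth_chunk_size_bytes = -1;    // -1 lets the runtime choose the default
  int64_t max_power_of_two_extend_bytes = -1;  // -1 lets the runtime choose the default

  // Assumed rule: a field is valid if it is the -1 sentinel or a non-negative value.
  bool IsValid() const {
    return max_dead_bytes_per_chunk >= -1 &&
           initial_growth_chunk_size_bytes >= -1 &&
           max_power_of_two_extend_bytes >= -1;
  }
};
```

Keeping the check on the struct means every caller that builds a config can validate it the same way instead of repeating range checks at each call site.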
@@ -75,7 +78,7 @@ class Environment {
   * Registers an allocator for sharing between multiple sessions.
   * Return an error if an allocator with the same OrtMemoryInfo is already registered.
   */
  Status RegisterAllocator(AllocatorPtr allocator);
  Status RegisterAllocator(OrtAllocator* allocator);
Separating external from internal usage by making the public method take OrtAllocator*. This enables GetSharedAllocator to return any of the OrtAllocator based instances.
ORT_API2_STATUS(CopyTensors, _In_ void* this_ptr,
                _In_reads_(num_tensors) const OrtValue** src_tensors,
                _In_reads_(num_tensors) OrtValue** dst_tensors,
                _In_reads_(num_tensors) OrtSyncStream** streams,
                _In_ size_t num_tensors);
Design choice: Minimize the EP API by having a single function that can do asynchronous batched copies. The library is free to implement that as synchronous copies. This avoids having additional functions for copying a single tensor.
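To illustrate how a library might satisfy a batched, potentially asynchronous signature with plain synchronous copies, here is a rough sketch. `TensorSketch`, `StreamSketch`, and `CopyTensorsSketch` are hypothetical stand-ins for `OrtValue`, `OrtSyncStream`, and the real callback, and the integer error codes are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-ins for OrtValue and OrtSyncStream.
struct TensorSketch { std::vector<unsigned char> bytes; };
struct StreamSketch {};

// Copies each src[i] into dst[i]. Returns 0 on success, non-zero on failure.
// The streams are ignored: a synchronous implementation has nothing to enqueue.
int CopyTensorsSketch(const TensorSketch* const* src, TensorSketch* const* dst,
                      StreamSketch* const* streams, size_t num_tensors) {
  (void)streams;
  for (size_t i = 0; i < num_tensors; ++i) {
    if (src[i] == nullptr || dst[i] == nullptr) return 1;          // missing tensor
    if (src[i]->bytes.size() != dst[i]->bytes.size()) return 2;    // size mismatch
    if (!src[i]->bytes.empty()) {
      std::memcpy(dst[i]->bytes.data(), src[i]->bytes.data(), src[i]->bytes.size());
    }
  }
  return 0;
}
```

A single-tensor copy is then just a call with `num_tensors == 1`, which is why a separate single-copy entry point is not needed.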
@@ -6096,7 +6185,109 @@ struct OrtCompileApi {

ORT_RUNTIME_CLASS(Ep);
ORT_RUNTIME_CLASS(EpFactory);
ORT_RUNTIME_CLASS(MemoryDevice);  // opaque class to wrap onnxruntime::OrtDevice
I couldn't avoid adding OrtDevice to the API as IDataTransfer uses it. To disambiguate, I called it OrtMemoryDevice, since we already have OrtHardwareDevice and calling it OrtDevice in the API felt too vague.
#define ORT_API_T(RETURN_TYPE, NAME, ...) \
  RETURN_TYPE(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION
We were missing the NO_EXCEPTION on a number of functions in the initial plugin EP setup. This macro should help keep things consistent by adding NO_EXCEPTION like ORT_API2_STATUS does.
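As a rough illustration of what the macro produces, the sketch below expands it inside a hypothetical function table. The definitions of `ORT_API_CALL` and `NO_EXCEPTION` here are simplified stand-ins (the real headers define them per platform and language), and `ExampleApiSketch` and `AddImpl` are invented for the example.

```cpp
#include <cassert>

// Simplified stand-ins, assumed for illustration only.
#define ORT_API_CALL
#define NO_EXCEPTION noexcept

#define ORT_API_T(RETURN_TYPE, NAME, ...) \
  RETURN_TYPE(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION

// Hypothetical function table. The member below expands to:
//   int (*Add)(int a, int b) noexcept;
struct ExampleApiSketch {
  ORT_API_T(int, Add, int a, int b);
};

// An implementation has to be noexcept as well to be assignable in C++17 and later.
int AddImpl(int a, int b) noexcept { return a + b; }
```

Because the exception specification is attached by the macro, a new entry in the table cannot accidentally omit it.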
@@ -15,6 +17,61 @@

#include "core/framework/bfc_arena.h"

using Status = onnxruntime::common::Status;

Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg) {
Support creation from OrtKeyValuePairs to keep the public API minimal. We use OrtKeyValuePairs in general for configuration/options/metadata and this avoids having to add OrtArenaCfg as an additional parameter in some functions like CreateSharedAllocator.
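The idea of building a config from string key/value pairs can be sketched as follows. Everything here is hypothetical: `ArenaOptionsSketch` stands in for `OrtArenaCfg`, a `std::map` stands in for `OrtKeyValuePairs`, the key names are invented, and silently ignoring unknown keys is an assumption rather than the behavior of the real function.

```cpp
#include <cassert>
#include <charconv>
#include <cstdint>
#include <map>
#include <string>
#include <system_error>

// Hypothetical stand-in for OrtArenaCfg with two of its fields.
struct ArenaOptionsSketch {
  int32_t arena_extend_strategy = -1;  // -1 means "use the default"
  int max_dead_bytes_per_chunk = -1;   // -1 means "use the default"
};

// Parses the whole string as a base-10 integer; fails on leftovers or overflow.
template <typename T>
bool ParseInt(const std::string& s, T& out) {
  const char* first = s.data();
  const char* last = s.data() + s.size();
  auto [ptr, ec] = std::from_chars(first, last, out);
  return ec == std::errc() && ptr == last;
}

// Fills cfg from the pairs. Missing keys keep their defaults; unknown keys are
// ignored (an assumption). cfg is only modified if every value parses.
bool FromKeyValuePairsSketch(const std::map<std::string, std::string>& kvps,
                             ArenaOptionsSketch& cfg) {
  ArenaOptionsSketch parsed;
  for (const auto& [key, value] : kvps) {
    if (key == "arena.extend_strategy") {
      if (!ParseInt(value, parsed.arena_extend_strategy)) return false;
    } else if (key == "arena.max_dead_bytes_per_chunk") {
      if (!ParseInt(value, parsed.max_dead_bytes_per_chunk)) return false;
    }
  }
  cfg = parsed;
  return true;
}
```

With this shape, a function that already accepts key/value options does not need an extra config parameter; the arena settings travel in the same container.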
@@ -16,6 +16,20 @@ Status DataTransferManager::RegisterDataTransfer(std::unique_ptr<IDataTransfer>
  return Status::OK();
}

Status DataTransferManager::UnregisterDataTransfer(IDataTransfer* data_transfer) {
Need to unregister if the plugin EP library is unloaded.
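A minimal sketch of the register/unregister pairing, assuming the manager owns its data transfers through `std::unique_ptr` and identifies them by raw pointer when removing. `DataTransferSketch` and `DataTransferManagerSketch` are hypothetical stand-ins, and the `bool` return values replace the real `Status` type.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for IDataTransfer.
struct DataTransferSketch {
  virtual ~DataTransferSketch() = default;
};

class DataTransferManagerSketch {
 public:
  // Takes ownership of the data transfer.
  bool Register(std::unique_ptr<DataTransferSketch> dt) {
    if (!dt) return false;
    transfers_.push_back(std::move(dt));
    return true;
  }

  // Removes (and destroys) the entry matching the raw pointer, e.g. before the
  // library that provided it is unloaded. Returns false if it was not registered.
  bool Unregister(const DataTransferSketch* dt) {
    auto it = std::find_if(transfers_.begin(), transfers_.end(),
                           [dt](const std::unique_ptr<DataTransferSketch>& entry) {
                             return entry.get() == dt;
                           });
    if (it == transfers_.end()) return false;
    transfers_.erase(it);
    return true;
  }

  std::size_t size() const { return transfers_.size(); }

 private:
  std::vector<std::unique_ptr<DataTransferSketch>> transfers_;
};
```

Destroying the entry before unload matters because its code lives in the plugin library; leaving it registered would leave a pointer to unloaded code.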
auto st = env->CreateAndRegisterAllocator(*mem_info, arena_cfg);
auto& env = ort_env->GetEnvironment();
Removed the unnecessary call chain from here -> OrtEnv.CreateAndRegisterAllocator -> Environment.CreateAndRegisterAllocator.
As the Environment instance is available from OrtEnv, there's no need for OrtEnv to have a CreateAndRegisterAllocator method that only forwards.
Fix MIGraphX build.
// some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemory.h
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"

namespace onnxruntime {
namespace common {
class Status;
nit: common.h already includes status.h, so we don't need a forward declaration. we could explicitly include status.h too for clarity.
a user of this file would need a complete Status definition anyway to call FromKeyValuePairs().
 *
 * OrtEpDevice maps to the EP factory, and the factory provides the allocator implementation.
 *
 * OrtDeviceMemoryType_DEFAULT is always supported for non-CPU based devices.
is the device memory allocator always required? does this also apply to the QNN EP, which only has the HtpSharedMemoryAllocator?
ORT_RETURN_IF_ERROR(from_string(it->first, it->second, value));
cfg.arena_extend_strategy = onnxruntime::narrow<int32_t>(value);
consider using ParseStringWithClassicLocale() in parse_string.h. it handles conversions to specific int types directly.
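For readers unfamiliar with the idea, locale-independent parsing straight into the target integer type could look roughly like the sketch below. `ParseWithClassicLocaleSketch` is a hypothetical helper written for this example and is not the implementation of the ORT function named above; it is intended for non-character integer types such as `int32_t` and `int64_t`.

```cpp
#include <cassert>
#include <cstdint>
#include <locale>
#include <sstream>
#include <string>

// Parses the whole string into T using the classic "C" locale, so the result
// does not depend on the process-wide locale. Leaves `out` untouched on failure.
template <typename T>
bool ParseWithClassicLocaleSketch(const std::string& s, T& out) {
  std::istringstream is(s);
  is.imbue(std::locale::classic());
  T parsed{};
  is >> std::noskipws >> parsed;  // noskipws: leading whitespace is an error
  if (is.fail()) return false;    // not a number, or out of range for T
  if (is.peek() != std::char_traits<char>::eof()) return false;  // trailing characters
  out = parsed;
  return true;
}
```

Parsing directly into `int32_t` lets the stream report out-of-range input, so a separate narrowing step after parsing into a wider type is not required.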
Description
Add allocator and data transfer infrastructure for plugin EP API
Allocators are created via the OrtEpFactory using the OrtMemoryInfo that is added to the OrtEpDevice instances the factory returns. This allows allocators to be created outside of an inference session and shared.
When a library is loaded, a default instance of each allocator is added to the shared allocators if there is no existing allocator (e.g. a user-provided custom allocator).
CreateSharedAllocator can be used to replace this default instance with a user-configured one, e.g. to add an arena or provide other configuration options that are passed through to the OrtEpFactory's CreateAllocator function.
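The default-then-replace behavior described above can be sketched with a small registry. All names are hypothetical: `MemoryInfoKey` stands in for `OrtMemoryInfo`, `AllocatorSketch` for an allocator, and the two methods only model "add a default if none exists" and "replace with a configured instance"; the real API surface and error handling differ.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical key standing in for OrtMemoryInfo: (name, device id).
using MemoryInfoKey = std::pair<std::string, int>;

// Hypothetical allocator that only records whether it was configured with an arena.
struct AllocatorSketch {
  bool uses_arena = false;
};

class SharedAllocatorsSketch {
 public:
  // Models library load: add a default allocator only if none is registered
  // for this memory info. Returns true if a default was added.
  bool AddDefaultIfMissing(const MemoryInfoKey& key) {
    return allocators_.emplace(key, std::make_shared<AllocatorSketch>()).second;
  }

  // Models CreateSharedAllocator: replace whatever is registered with a
  // user-configured instance.
  void CreateShared(const MemoryInfoKey& key, bool use_arena) {
    auto allocator = std::make_shared<AllocatorSketch>();
    allocator->uses_arena = use_arena;
    allocators_[key] = std::move(allocator);
  }

  // Returns the registered allocator, or nullptr if there is none.
  std::shared_ptr<AllocatorSketch> Get(const MemoryInfoKey& key) const {
    auto it = allocators_.find(key);
    return it == allocators_.end() ? nullptr : it->second;
  }

 private:
  std::map<MemoryInfoKey, std::shared_ptr<AllocatorSketch>> allocators_;
};
```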
Similarly, IDataTransfer is supported by the factory implementing OrtDataTransferImpl, which will also enable data transfer outside of a session. That capability will be added in a future PR, as the synchronization requirements still need to be worked out and will affect the public API.
Motivation and Context