Skip to content

Allocator and data transfer support for plugin EP API #25070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1835,10 +1835,9 @@ endif()
if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
NOT onnxruntime_MINIMAL_BUILD)
onnxruntime_add_shared_library_module(example_plugin_ep
${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.h
${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.cc
${TEST_SRC_DIR}/autoep/library/example_plugin_ep.cc)
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
"${TEST_SRC_DIR}/autoep/library/*.cc")
onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep PRIVATE onnxruntime)

Expand Down
21 changes: 21 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "core/common/common.h"
#include "core/framework/allocator_stats.h"
#include "core/session/abi_key_value_pairs.h"
// some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemory.h
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/ortdevice.h"
Expand Down Expand Up @@ -37,6 +38,26 @@ struct OrtArenaCfg {
int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default
int64_t max_power_of_two_extend_bytes; // use -1 to allow ORT to choose the default

// Reports whether the arena settings checked here are within their accepted ranges:
//   - arena_extend_strategy must lie in [-1, 1]
//   - every size/byte setting below must be >= -1
bool IsValid() {
  if (arena_extend_strategy < -1 || arena_extend_strategy > 1) {
    return false;
  }

  if (initial_chunk_size_bytes < -1) {
    return false;
  }

  if (max_dead_bytes_per_chunk < -1) {
    return false;
  }

  if (initial_growth_chunk_size_bytes < -1) {
    return false;
  }

  // last remaining constraint decides the overall result
  return max_power_of_two_extend_bytes >= -1;
}

// Names of the configuration keys that FromKeyValuePairs parses.
// Every key uses the "arena." prefix.
struct ConfigKeyNames {
  static constexpr auto ArenaExtendStrategy = "arena.extend_strategy";
  static constexpr auto InitialChunkSizeBytes = "arena.initial_chunk_size_bytes";
  static constexpr auto MaxDeadBytesPerChunk = "arena.max_dead_bytes_per_chunk";
  static constexpr auto InitialGrowthChunkSizeBytes = "arena.initial_growth_chunk_size_bytes";
  static constexpr auto MaxPowerOfTwoExtendBytes = "arena.max_power_of_two_extend_bytes";
  static constexpr auto MaxMem = "arena.max_mem";
};

// Fills in `cfg` from the entries in `kvps`; the recognized keys are the ones listed in ConfigKeyNames.
// Returns a Status describing the outcome.
// NOTE(review): handling of missing or malformed values is not visible from this header - confirm in the implementation.
static onnxruntime::common::Status FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg);
};

namespace onnxruntime {
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class GraphOptimizerRegistry;
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

struct OrtEpDevice;
struct OrtRunOptions;

namespace onnxruntime {
Expand Down
35 changes: 34 additions & 1 deletion include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
class InferenceSession;
struct IExecutionProviderFactory;
struct SessionOptions;
namespace plugin_ep {
class DataTransfer;
} // namespace plugin_ep

/**
Provides the runtime environment for onnxruntime.
Expand Down Expand Up @@ -85,7 +88,7 @@
* Registers an allocator for sharing between multiple sessions.
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
*/
Status RegisterAllocator(AllocatorPtr allocator);
Status RegisterAllocator(OrtAllocator* allocator);

/**
* Creates and registers an allocator for sharing between multiple sessions.
Expand Down Expand Up @@ -130,7 +133,16 @@
const std::vector<const OrtEpDevice*>& GetOrtEpDevices() const {
// hand back execution_devices_ by const reference; no copy is made
return execution_devices_;
}

Status CreateSharedAllocator(const OrtEpDevice& ep_device,
OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type,
const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator);
Status ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type);
#endif // !defined(ORT_MINIMAL_BUILD)

// return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator
OrtAllocator* GetSharedAllocator(const OrtMemoryInfo& mem_info);

~Environment();

private:
Expand All @@ -140,12 +152,33 @@
const OrtThreadingOptions* tp_options = nullptr,
bool create_global_thread_pools = false);

Status RegisterAllocatorImpl(AllocatorPtr allocator, bool replace_existing);
Status UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool error_if_not_found = true);
Status CreateSharedAllocatorImpl(const OrtEpDevice& ep_device,
const OrtMemoryInfo& memory_info, OrtAllocatorType allocator_type,
const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator,
bool replace_existing);

std::unique_ptr<logging::LoggingManager> logging_manager_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};

std::mutex mutex_;

// shared allocators from various sources.
// CreateAndRegisterAllocator[V2]: IAllocator allocators created by ORT
// RegisterAllocator: IAllocatorImplWrappingOrtAllocator custom allocators registered by the user.
// TODO: How can we detect registration of an allocator from an InferenceSession?

Check warning on line 172 in include/onnxruntime/core/session/environment.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: include/onnxruntime/core/session/environment.h:172: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// OrtEpDevice: We create a default shared IAllocatorImplWrappingOrtAllocator for each OrtEpDevice memory info.
std::vector<AllocatorPtr> shared_allocators_;

// RegisterAllocator and CreateSharedAllocator pointers. Used for GetSharedAllocator.
// Every instance here is also in shared_allocators_.
std::unordered_set<OrtAllocator*> shared_ort_allocators_;

using OrtAllocatorUniquePtr = std::unique_ptr<OrtAllocator, std::function<void(OrtAllocator*)>>;

#if !defined(ORT_MINIMAL_BUILD)
// register EPs that are built into the ORT binary so they can take part in AutoEP selection
// added to ep_libraries
Expand Down
90 changes: 88 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ extern "C" {
#define _In_opt_
#define _In_opt_z_
#define _Out_
#define _Outptr_
#define _Out_opt_
#define _Outptr_
#define _Outptr_opt_
#define _Inout_
#define _Inout_opt_
#define _Frees_ptr_opt_
Expand Down Expand Up @@ -4581,7 +4582,8 @@ struct OrtApi {
* \param[in] provider_options_values value of the provider options map
* \param[in] num_keys Length of the provider options map
*/
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type,
_In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

/** \brief Run the model asynchronously in a thread owned by intra op thread pool
Expand Down Expand Up @@ -5930,6 +5932,90 @@ struct OrtApi {
_In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value);

/// @}

/** \brief Get the OrtMemoryInfo for the device.
*
* \param[in] ep_device The OrtEpDevice instance to query.
* \return A pointer to the OrtMemoryInfo for the device.
*
* \since Version 1.23
*/
ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device);

/** \brief Create/replace a shared allocator for the OrtEpDevice in the OrtEnv.
*
* OrtEpDevice maps to the EP factory, and the factory provides the allocator implementation.
*
* Both OrtDeviceMemoryType_DEFAULT and OrtDeviceMemoryType_HOST_ACCESSIBLE are optional for an EP to provide.
* It is EP implementation dependent as to what is available.
*
* If a shared allocator already exists for the OrtEpDevice and OrtDeviceMemoryType, it is replaced. This allows
* changing the shared allocator configuration from the default. e.g. adding an arena.
*
* \param[in] env The OrtEnv instance to create the shared allocator in.
* \param[in] ep_device The OrtEpDevice instance to create the shared allocator for.
* \param[in] mem_type The memory type to use for the shared allocator.
* \param[in] allocator_type The type of allocator to create (e.g. OrtAllocatorType::OrtArenaAllocator).
* \param[in] allocator_options Optional key-value pairs to configure the allocator. If arena based, see
* include/onnxruntime/core/framework/allocator.h for the keys and values that can be
* used.
* \param[out] allocator Optional. If not nullptr, set to the created shared allocator. The allocator is owned by the OrtEnv instance.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(CreateSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device,
_In_ OrtDeviceMemoryType mem_type, _In_ OrtAllocatorType allocator_type,
_In_opt_ const OrtKeyValuePairs* allocator_options,
_Outptr_opt_ OrtAllocator** allocator);

/** \brief Get a shared allocator from the OrtEnv.
*
* By default there is a shared allocator created for all OrtEpDevice instances, so if you get the OrtMemoryInfo
* from the OrtEpDevice using EpDevice_MemoryInfo a shared allocator is guaranteed to exist.
*
* This will also match and return custom allocators added with RegisterAllocator.
*
* \param[in] env The OrtEnv instance to get the shared allocator from.
* \param[in] mem_info The OrtMemoryInfo instance to get the shared allocator for.
* \return A pointer to the shared allocator, or nullptr if no shared allocator exists for the given memory info.
*
* \since Version 1.23
*/
ORT_API_T(OrtAllocator*, GetSharedAllocator, _In_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info);

/** \brief Release a shared allocator from the OrtEnv for the OrtEpDevice and memory type.
*
* This will release the shared allocator for the given OrtEpDevice and memory type.
* If no shared allocator exists, this is a no-op.
*
* \param[in] env The OrtEnv instance to release the shared allocator from.
* \param[in] ep_device The OrtEpDevice instance to release the shared allocator for.
* \param[in] mem_type The memory type of the shared allocator to release.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device,
_In_ OrtDeviceMemoryType mem_type);

/** \brief Get a const pointer to the raw data inside a tensor
*
* Used to read the internal tensor data directly.
* \note The returned pointer is valid until the \p value is destroyed.
*
* \param[in] value A tensor type (string tensors are not supported)
* \param[out] out Filled in with a pointer to the internal storage
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out);
};

/*
Expand Down
Loading
Loading