Switched to DirectML as GPU provider on Windows
Acly committed Jul 29, 2023
1 parent 0644e94 commit e87e353
Showing 7 changed files with 152 additions and 55 deletions.
24 changes: 15 additions & 9 deletions README.md
@@ -7,19 +7,19 @@
* Optimized for minimal copying and overhead
* Fully C++ based neural network inference (via [onnxruntime](https://onnxruntime.ai/))
* *Platforms:* Windows, Linux
-* *Backends:* CPU, CUDA
+* *Backends:* CPU, GPU via DirectML (Windows only), GPU via CUDA (Linux/NVIDIA only)

## Features

### Segmentation

-Identify objects in an image and mask them (based on [SegmentAnything](https://segment-anything.com))
+Identify objects in an image and generate masks for them (based on [SegmentAnything](https://segment-anything.com))

```cpp
// Load an image...
Image image = Image::load("example.png");
// ...or use existing image data:
-// ImageView image(pixel_data, {width, height}, Channels::rgba);
+ImageView image(pixel_data, {width, height}, Channels::rgba);

// Analyse the image
Environment env;
@@ -32,6 +32,8 @@ Image mask = segmentation.compute_mask(Point{220, 355});
Image mask = segmentation.compute_mask(Region(Point{140, 200}, Extent{300, 300}));
```
Performance is interactive: roughly 500 ms for `Segmentation::process` and 80 ms per mask on CPU. Running on GPU can be much faster: 50 ms and 12 ms respectively on an RTX 4070, with around 500 MB of VRAM used.
## Building
@@ -54,18 +56,22 @@ cmake --build . --config Release
```


-## Using
+## Documentation

-The library can be added to existing CMake projects either via `add_subdirectory(dlimgedit/src)` to build from source, or by adding the target from installed binaries with `find_package(dlimgedit)`. Packages should work out of the box on CPU. The `onnxruntime` shared library is installed as a required runtime dependency.
+The library can be added to existing CMake projects either via `add_subdirectory(dlimgedit/src)` to build from source, or by adding the target from installed binaries with `find_package(dlimgedit)`. Packages should work out of the box on CPU. The `onnxruntime` shared library is installed as a required runtime dependency. Execution on GPU may require additional libraries at runtime; see below.
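For illustration, a minimal consuming project might look like the sketch below (the `my_app` project and target names are hypothetical, and whether the imported target is plain `dlimgedit` or namespaced depends on the installed package config):

```cmake
# Hypothetical consumer CMakeLists.txt
cmake_minimum_required(VERSION 3.16)
project(my_app CXX)

# Use an installed binary package...
find_package(dlimgedit REQUIRED)
# ...or build from source instead:
# add_subdirectory(dlimgedit/src)

add_executable(my_app main.cpp)
target_link_libraries(my_app PRIVATE dlimgedit)
```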

-The official public API is C++14 compatible.
+The public API is C++14 compatible.

See the [public header](src/include/dlimgedit/dlimgedit.hpp) for API documentation.

-### CUDA Backend
+### GPU on Windows (DirectML)

Using `Backend::gpu` on Windows makes use of [DirectML](https://github.com/microsoft/DirectML) to run inference on GPU. A wide range of GPUs is supported. Deploying `DirectML.dll` ([nuget](https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.12.0)) next to the application is recommended; otherwise the version of the DLL that ships with Windows is used, which is usually too old.
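A post-build copy step is one way to deploy the DLL; a minimal sketch, assuming a hypothetical `my_app` target and a `DIRECTML_DLL` variable pointing at the DLL from the nuget package:

```cmake
# Hypothetical deployment step: place DirectML.dll next to the executable
# so it is found before the (usually older) system-wide copy.
add_custom_command(TARGET my_app POST_BUILD
  COMMAND ${CMAKE_COMMAND} -E copy_if_different
          "${DIRECTML_DLL}"              # e.g. bin/x64-win/DirectML.dll from the package
          "$<TARGET_FILE_DIR:my_app>")
```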

### GPU on Linux (CUDA)

-Using `Backend::gpu` makes use of CUDA to run inference on GPU. This requires the following additional libraries to be installed:
+Using `Backend::gpu` on Linux makes use of CUDA to run inference on GPU. This requires the following additional libraries to be installed:
* [NVIDIA CUDA Toolkit (Version 11.x)](https://developer.nvidia.com/cuda-11-8-0-download-archive)
* [NVIDIA cuDNN (Version 8.x for CUDA 11.x)](https://developer.nvidia.com/cudnn)

-On Windows the location of the DLL files must be added to the PATH environment variable or copied to the executable folder. Refer to [NVIDIA's installation instructions](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) for detailed instructions.
+Refer to [NVIDIA's installation guide](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) for detailed instructions.
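Since GPU availability now depends on the platform and the libraries installed, callers may want to probe before picking a backend. A minimal sketch, assuming the public `Environment` exposes the backend check seen in `environment.cpp` below (the `Options` struct and the `is_supported` spelling are assumptions, not confirmed public API):

```cpp
#include <dlimgedit/dlimgedit.hpp>

// Hypothetical fallback logic: prefer GPU (DirectML on Windows,
// CUDA on Linux) and fall back to CPU when the provider is missing.
dlimg::Environment create_environment() {
    dlimg::Options opts; // assumed options struct
    opts.backend = dlimg::Environment::is_supported(dlimg::Backend::gpu)
                       ? dlimg::Backend::gpu
                       : dlimg::Backend::cpu;
    return dlimg::Environment(opts);
}
```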
89 changes: 55 additions & 34 deletions depend/onnx/CMakeLists.txt
@@ -1,40 +1,61 @@
include(FetchContent)

if(WIN32)
-  set(OS win)
-  set(EXT zip)
-else()
-  set(OS linux)
-  set(EXT tgz)
-endif()
-set(PREFIX ${CMAKE_SHARED_LIBRARY_PREFIX})
-set(SUFFIX ${CMAKE_SHARED_LIBRARY_SUFFIX})
+  # Use onnxruntime with DirectML
+
+  # DirectML.dll is installed with Windows, but the version is typically too old.
+  FetchContent_Declare(
+    directml
+    URL https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.12.0
+    DOWNLOAD_EXTRACT_TIMESTAMP true
+  )
+  FetchContent_Declare(
+    onnxruntime
+    URL https://github.com/microsoft/onnxruntime/releases/download/v1.15.1/Microsoft.ML.OnnxRuntime.DirectML.1.15.1.zip
+    DOWNLOAD_EXTRACT_TIMESTAMP true
+  )
+  FetchContent_MakeAvailable(directml onnxruntime)
+
+  add_library(directml SHARED IMPORTED GLOBAL)
+  set_target_properties(directml PROPERTIES
+    IMPORTED_LOCATION ${directml_SOURCE_DIR}/bin/x64-win/DirectML.dll
+    IMPORTED_IMPLIB ${directml_SOURCE_DIR}/bin/x64-win/DirectML.lib
+  )
+  add_library(onnxruntime SHARED IMPORTED GLOBAL)
+  set_target_properties(onnxruntime PROPERTIES
+    IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.dll
+    IMPORTED_IMPLIB ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.lib
+    INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_SOURCE_DIR}/build/native/include
+  )
+  add_dependencies(onnxruntime directml)
+  set(ONNX_RUNTIME_DEPENDENCIES
+    $<TARGET_FILE:onnxruntime>
+    $<TARGET_FILE:directml>
+    PARENT_SCOPE)

-FetchContent_Declare(
-  onnxruntime
-  URL https://github.com/microsoft/onnxruntime/releases/download/v1.15.1/onnxruntime-${OS}-x64-gpu-1.15.1.${EXT}
-  DOWNLOAD_EXTRACT_TIMESTAMP true
-)
-FetchContent_MakeAvailable(onnxruntime)
+else() # Linux
+  # Use onnxruntime with CUDA

-add_library(onnxruntime_providers_shared SHARED IMPORTED GLOBAL)
-set_target_properties(onnxruntime_providers_shared PROPERTIES
-  IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/${PREFIX}onnxruntime_providers_shared${SUFFIX}
-)
-add_library(onnxruntime_providers_cuda SHARED IMPORTED GLOBAL)
-set_target_properties(onnxruntime_providers_cuda PROPERTIES
-  IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/${PREFIX}onnxruntime_providers_cuda${SUFFIX}
-)
-add_library(onnxruntime SHARED IMPORTED GLOBAL)
-set_target_properties(onnxruntime PROPERTIES
-  IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/${PREFIX}onnxruntime${SUFFIX}
-  IMPORTED_IMPLIB ${onnxruntime_SOURCE_DIR}/lib/onnxruntime.lib
-  INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_SOURCE_DIR}/include
-)
-add_dependencies(onnxruntime onnxruntime_providers_shared onnxruntime_providers_cuda)
+  FetchContent_Declare(
+    onnxruntime
+    URL https://github.com/microsoft/onnxruntime/releases/download/v1.15.1/onnxruntime-linux-x64-gpu-1.15.1.tgz
+    DOWNLOAD_EXTRACT_TIMESTAMP true
+  )
+  FetchContent_MakeAvailable(onnxruntime)

  set(ONNX_RUNTIME_DEPENDENCIES
    $<TARGET_FILE:onnxruntime>
    $<TARGET_FILE:onnxruntime_providers_shared>
    $<TARGET_FILE:onnxruntime_providers_cuda>
    PARENT_SCOPE)
+  add_library(onnxruntime_providers_shared SHARED IMPORTED GLOBAL)
+  set_target_properties(onnxruntime_providers_shared PROPERTIES
+    IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime_providers_shared.so
+  )
+  add_library(onnxruntime_providers_cuda SHARED IMPORTED GLOBAL)
+  set_target_properties(onnxruntime_providers_cuda PROPERTIES
+    IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime_providers_cuda.so
+  )
+  add_library(onnxruntime SHARED IMPORTED GLOBAL)
+  set_target_properties(onnxruntime PROPERTIES
+    IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so
+    INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_SOURCE_DIR}/include
+  )
+  add_dependencies(onnxruntime onnxruntime_providers_shared onnxruntime_providers_cuda)
+
+endif()
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
@@ -16,8 +16,9 @@ target_include_directories(dlimgedit PUBLIC
target_link_libraries(dlimgedit PRIVATE onnxruntime stb Eigen3::Eigen)

if (MSVC)
-  target_compile_options(dlimgedit PRIVATE /W4 /WX)
+  target_compile_options(dlimgedit PRIVATE /Zi /W4 /WX)
  target_compile_options(dlimgedit PRIVATE /wd4127 /wd5054) # for Eigen
+  target_link_options(dlimgedit PRIVATE /DEBUG)

foreach (file ${ONNX_RUNTIME_DEPENDENCIES})
add_custom_command(TARGET dlimgedit POST_BUILD
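The copy loop is truncated in this view. As a sketch, the post-build step that copies the runtime dependencies next to the built library plausibly looks like this (the exact command is an assumption, not the verbatim remainder of the file):

```cmake
# Hedged reconstruction: copy the onnxruntime / provider binaries listed in
# ONNX_RUNTIME_DEPENDENCIES into the output directory after each build.
foreach (file ${ONNX_RUNTIME_DEPENDENCIES})
  add_custom_command(TARGET dlimgedit POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_if_different ${file} $<TARGET_FILE_DIR:dlimgedit>)
endforeach()
```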
21 changes: 12 additions & 9 deletions src/environment.cpp
@@ -1,4 +1,5 @@
#include "environment.hpp"
#include "platform.hpp"
#include "segmentation.hpp"

#include <thread>
@@ -17,20 +18,22 @@ Path EnvironmentImpl::verify_path(std::string_view path) {
}

bool EnvironmentImpl::is_supported(Backend backend) {
+    constexpr char const* cpu_provider = "CPUExecutionProvider";
+    constexpr char const* gpu_provider =
+        is_windows ? "DmlExecutionProvider" : "CUDAExecutionProvider";
+
+    auto requested = backend == Backend::gpu ? gpu_provider : cpu_provider;
    auto providers = Ort::GetAvailableProviders();
-    switch (backend) {
-    case Backend::cpu:
-        return true;
-    case Backend::gpu:
-        return std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") !=
-               providers.end();
-    }
-    return false;
+    return std::find(providers.begin(), providers.end(), requested) != providers.end();
}

Ort::Env init_onnx() {
    if (OrtGetApiBase()->GetApi(ORT_API_VERSION) == nullptr) {
-        throw Exception("Could not load onnxruntime library, version mismatch");
+        if (is_windows) {
+            throw Exception("Could not load onnxruntime library, version mismatch. Make sure "
+                            "onnxruntime.dll is in the same directory as the executable.");
+        }
+        throw Exception("Could not load onnxruntime library, version mismatch.");
    }
    auto env = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, "dlimgedit");
    env.DisableTelemetryEvents();
12 changes: 12 additions & 0 deletions src/platform.hpp
@@ -0,0 +1,12 @@
#pragma once

namespace dlimg {

#ifdef _WIN32
constexpr bool is_windows = true;
#else
constexpr bool is_windows = false;
#endif
constexpr bool is_linux = !is_windows;

} // namespace dlimg
56 changes: 54 additions & 2 deletions src/session.cpp
@@ -1,14 +1,62 @@
#include "session.hpp"
#include "environment.hpp"
#include "platform.hpp"

#include <dml_provider_factory.h>
#include <onnxruntime_c_api.h>

#include <memory>
#include <optional>
#include <string>

namespace dlimg {
namespace {

void check(OrtStatusPtr res) {
    if (res != nullptr) {
        auto msg = std::string(Ort::GetApi().GetErrorMessage(res));
        Ort::GetApi().ReleaseStatus(res);
        throw Exception(msg);
    }
}

std::pair<Ort::SessionOptions, OrtSessionOptions*> create_session_options() {
    OrtSessionOptions* opts;
    check(Ort::GetApi().CreateSessionOptions(&opts));
    return {Ort::SessionOptions(opts), opts};
}

std::optional<std::unique_lock<std::mutex>> lock_session(std::mutex& m,
                                                         EnvironmentImpl const& env) {
    // Locking is required for Ort::Session::Run() when using DirectML provider.
    // For other providers calling Run() concurrently is safe.
    if (env.backend == Backend::gpu && is_windows) {
        return std::unique_lock<std::mutex>(m);
    }
    return {};
}

} // namespace

Ort::Session create_session(EnvironmentImpl& env, char const* kind, char const* model) {
-    Ort::SessionOptions opts;
+    auto [opts, opts_raw] = create_session_options();
    opts.SetIntraOpNumThreads(env.thread_count);
    opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

    if (env.backend == Backend::gpu) {
-        opts.AppendExecutionProvider_CUDA({});
+        if (is_windows) {
+            // Use DirectML. The following two options are required:
+            opts.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
+            opts.DisableMemPattern();
+
+            OrtDmlApi* dml_api = nullptr;
+            check(Ort::GetApi().GetExecutionProviderApi("DML", ORT_API_VERSION,
+                                                        (void const**)(&dml_api)));
+            check(dml_api->SessionOptionsAppendExecutionProvider_DML(opts_raw, 0));
+        } else if (is_linux) {
+            // Use CUDA.
+            opts.AppendExecutionProvider_CUDA({});
+        }
    }
    Path model_path = env.model_directory / kind / model;
    if (!exists(model_path)) {
@@ -40,13 +88,17 @@ Shape Session::output_shape(int index) const {
void Session::run(std::span<Ort::Value const> inputs, std::span<Ort::Value> outputs) {
    ASSERT(inputs.size() == input_names_.size());
    ASSERT(outputs.size() == output_names_.size());

    auto lock = lock_session(mutex_, env_);
    Ort::RunOptions opts{nullptr};
    session_.Run(opts, input_names_.data(), inputs.data(), inputs.size(), output_names_.data(),
                 outputs.data(), outputs.size());
}

std::vector<Ort::Value> Session::run(std::span<Ort::Value const> inputs) {
    ASSERT(inputs.size() == input_names_.size());

    auto lock = lock_session(mutex_, env_);
    Ort::RunOptions opts{nullptr};
    return session_.Run(opts, input_names_.data(), inputs.data(), inputs.size(),
                        output_names_.data(), output_names_.size());
2 changes: 2 additions & 0 deletions src/session.hpp
@@ -5,6 +5,7 @@
#include <onnxruntime_cxx_api.h>

#include <array>
#include <mutex>
#include <span>
#include <type_traits>

@@ -48,6 +49,7 @@ class Session {
    Ort::Session session_;
    std::span<char const* const> input_names_;
    std::span<char const* const> output_names_;
    std::mutex mutex_;
};

} // namespace dlimg
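The new mutex exists so that `Session::run` can serialize inference while the DirectML provider is active (see `lock_session` in `session.cpp` above); on other backends concurrent calls remain lock-free. A minimal sketch of the pattern this enables, using the internal `Session` class purely for illustration (the surrounding setup and the prepared `Ort::Value` inputs are assumed):

```cpp
#include <span>
#include <thread>
// Assumes session.hpp (and with it onnxruntime_cxx_api.h) is available.

// Illustrative only: two threads calling run() on the same Session.
// With Backend::gpu on Windows (DirectML), lock_session() serializes the
// two Run() calls via mutex_; on CPU or CUDA they may execute concurrently.
void run_twice(dlimg::Session& session,
               std::span<Ort::Value const> inputs_a,
               std::span<Ort::Value const> inputs_b) {
    std::thread t1([&] { auto outputs = session.run(inputs_a); });
    std::thread t2([&] { auto outputs = session.run(inputs_b); });
    t1.join();
    t2.join();
}
```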
