Skip to content

Commit

Permalink
refactor: Restructure the ONNX and mlpack plugins (#2025)
Browse files Browse the repository at this point in the history
Fixes the CMake configuration, python bindings, naming conventions, build dependencies etc. Should not change anything about the execution itself.
  • Loading branch information
paulgessinger committed Apr 12, 2023
1 parent 79fe828 commit 15aec1a
Show file tree
Hide file tree
Showing 24 changed files with 132 additions and 168 deletions.
11 changes: 5 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ set_option_if(
ACTS_BUILD_EXAMPLES OR ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_PLUGIN_LEGACY ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_PLUGIN_AUTODIFF ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_EXAMPLES_EXATRKX ACTS_BUILD_PLUGIN_EXATRKX)
set_option_if(ACTS_BUILD_PLUGIN_EXATRKX ACTS_BUILD_EXAMPLES_EXATRKX)

# feature tests
include(CheckCXXSourceCompiles)
Expand Down Expand Up @@ -291,12 +291,8 @@ if(ACTS_BUILD_PLUGIN_JSON)
add_subdirectory(thirdparty/nlohmann_json)
endif()
endif()
if(ACTS_BUILD_PLUGIN_ONNX)
find_package(OnnxRuntime ${_acts_onnxruntime_version} REQUIRED)
endif()
if(ACTS_BUILD_PLUGIN_MLPACK)
find_package(mlpack ${_acts_mlpack_version} REQUIRED)
include_directories(SYSTEM ${mlpack_INCLUDE_DIR})
endif()
if(ACTS_BUILD_PLUGIN_SYCL)
find_package(SYCL REQUIRED)
Expand All @@ -321,14 +317,17 @@ if(ACTS_BUILD_PLUGIN_EXATRKX)
)
endif()
if(ACTS_EXATRKX_ENABLE_ONNX)
find_package(OnnxRuntime REQUIRED)
find_package(cugraph REQUIRED)
endif()
if(ACTS_EXATRKX_ENABLE_TORCH)
find_package(TorchScatter REQUIRED)
endif()
endif()

if(ACTS_BUILD_PLUGIN_ONNX OR ACTS_EXATRKX_ENABLE_ONNX)
find_package(OnnxRuntime ${_acts_onnxruntime_version} REQUIRED)
endif()

# examples dependencies
if(ACTS_BUILD_EXAMPLES)
set(THREADS_PREFER_PTHREAD_FLAG ON)
Expand Down
2 changes: 1 addition & 1 deletion Examples/Algorithms/TrackFindingML/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ target_link_libraries(

if(ACTS_BUILD_PLUGIN_MLPACK)
target_link_libraries(
ActsExamplesTrackFindingML PUBLIC ActsPluginmlpack)
ActsExamplesTrackFindingML PUBLIC ActsPluginMlpack)
endif()

install(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include "ActsExamples/TrackFindingML/AmbiguityResolutionMLDBScanAlgorithm.hpp"

#include "Acts/Plugins/mlpack/AmbiguityDBScanClustering.hpp"
#include "Acts/Plugins/Mlpack/AmbiguityDBScanClustering.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Framework/WhiteBoard.hpp"

Expand Down
19 changes: 12 additions & 7 deletions Examples/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,19 @@ endif()

if(ACTS_BUILD_PLUGIN_ONNX)
target_link_libraries(ActsPythonBindings PUBLIC ActsExamplesTrackFindingML)
target_sources(ActsPythonBindings PRIVATE src/MLTrackFinding.cpp)
else()
target_sources(ActsPythonBindings PRIVATE src/MLTrackFindingStub.cpp)
endif()
target_sources(ActsPythonBindings PRIVATE src/Onnx.cpp)
list(APPEND py_files examples/onnx/__init__.py)

if(ACTS_BUILD_PLUGIN_MLPACK)
target_sources(ActsPythonBindings PRIVATE src/OnnxMlpack.cpp)
list(APPEND py_files examples/onnx/mlpack.py)
else()
target_sources(ActsPythonBindings PRIVATE src/OnnxMlpackStub.cpp)
endif()

if(ACTS_BUILD_PLUGIN_MLPACK)
target_compile_definitions(
ActsPythonBindings PUBLIC ACTS_PLUGIN_MLPACK)
else()
target_sources(ActsPythonBindings PRIVATE src/OnnxStub.cpp)
target_sources(ActsPythonBindings PRIVATE src/OnnxMlpackStub.cpp)
endif()

add_custom_target(ActsPythonGlueCode)
Expand Down
6 changes: 3 additions & 3 deletions Examples/Python/include/Acts/Plugins/Python/Utilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
namespace Acts::Python {

struct Context {
std::unordered_map<std::string, pybind11::module_*> modules;
std::unordered_map<std::string, pybind11::module_> modules;

pybind11::module_& get(const std::string& name) { return *modules.at(name); }
pybind11::module_& get(const std::string& name) { return modules.at(name); }

template <typename... Args, typename = std::enable_if_t<sizeof...(Args) >= 2>>
auto get(Args&&... args) {
return std::make_tuple((*modules.at(args))...);
return std::make_tuple((modules.at(args))...);
}
};

Expand Down
9 changes: 9 additions & 0 deletions Examples/Python/python/acts/examples/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from acts._adapter import _patch_config
from acts import ActsPythonBindings

if not hasattr(ActsPythonBindings._examples, "_onnx"):
raise ImportError("ActsPythonBindings._examples._onnx not found")

_patch_config(ActsPythonBindings._examples._onnx)

from acts.ActsPythonBindings._examples._onnx import *
9 changes: 9 additions & 0 deletions Examples/Python/python/acts/examples/onnx/mlpack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from acts._adapter import _patch_config
from acts import ActsPythonBindings

if not hasattr(ActsPythonBindings._examples, "_mlpack"):
raise ImportError("ActsPythonBindings._examples._mlpack not found")

_patch_config(ActsPythonBindings._examples._mlpack)

from acts.ActsPythonBindings._examples._mlpack import *
2 changes: 1 addition & 1 deletion Examples/Python/python/acts/examples/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ def addAmbiguityResolutionML(
writeTrajectories: bool = True,
logLevel: Optional[acts.logging.Level] = None,
) -> None:
from acts.examples import AmbiguityResolutionMLAlgorithm
from acts.examples.onnx import AmbiguityResolutionMLAlgorithm

customLogLevel = acts.examples.defaultLogging(s, logLevel)

Expand Down
2 changes: 1 addition & 1 deletion Examples/Python/src/Geant4Component.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ PYBIND11_MODULE(ActsPythonBindingsGeant4, mod) {
}

Acts::Python::Context ctx;
ctx.modules["geant4"] = &mod;
ctx.modules["geant4"] = mod;

addGeant4HepMC3(ctx);
}
12 changes: 7 additions & 5 deletions Examples/Python/src/ModuleEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,20 @@ void addHepMC3(Context& ctx);
void addExaTrkXTrackFinding(Context& ctx);
void addEDM4hep(Context& ctx);
void addSvg(Context& ctx);
void addMLTrackFinding(Context& ctx);
void addOnnx(Context& ctx);
void addOnnxMlpack(Context& ctx);

} // namespace Acts::Python

using namespace Acts::Python;

PYBIND11_MODULE(ActsPythonBindings, m) {
Acts::Python::Context ctx;
ctx.modules["main"] = &m;
ctx.modules["main"] = m;
auto mex = m.def_submodule("_examples");
ctx.modules["examples"] = &mex;
ctx.modules["examples"] = mex;
auto prop = m.def_submodule("_propagator");
ctx.modules["propagation"] = &prop;
ctx.modules["propagation"] = prop;
m.doc() = "Acts";

m.attr("__version__") =
Expand Down Expand Up @@ -277,5 +278,6 @@ PYBIND11_MODULE(ActsPythonBindings, m) {
addExaTrkXTrackFinding(ctx);
addEDM4hep(ctx);
addSvg(ctx);
addMLTrackFinding(ctx);
addOnnx(ctx);
addOnnxMlpack(ctx);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

#include "Acts/Plugins/Python/Utilities.hpp"
#include "ActsExamples/TrackFindingML/AmbiguityResolutionMLAlgorithm.hpp"
#ifdef ACTS_PLUGIN_MLPACK
#include "ActsExamples/TrackFindingML/AmbiguityResolutionMLDBScanAlgorithm.hpp"
#endif

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand All @@ -22,19 +19,14 @@ using namespace Acts;

namespace Acts::Python {

void addMLTrackFinding(Context& ctx) {
void addOnnx(Context& ctx) {
auto [m, mex] = ctx.get("main", "examples");
auto onnx = mex.def_submodule("_onnx");
ctx.modules["onnx"] = onnx;

ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::AmbiguityResolutionMLAlgorithm,
mex, "AmbiguityResolutionMLAlgorithm",
onnx, "AmbiguityResolutionMLAlgorithm",
inputTracks, inputDuplicateNN, outputTracks,
nMeasurementsMin);

#ifdef ACTS_PLUGIN_MLPACK
ACTS_PYTHON_DECLARE_ALGORITHM(
ActsExamples::AmbiguityResolutionMLDBScanAlgorithm, mex,
"AmbiguityResolutionMLDBScanAlgorithm", inputTracks, inputDuplicateNN,
outputTracks, nMeasurementsMin, epsilonDBScan, minPointsDBScan);
#endif
}
} // namespace Acts::Python
31 changes: 31 additions & 0 deletions Examples/Python/src/OnnxMlpack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/Python/Utilities.hpp"
#include "ActsExamples/TrackFindingML/AmbiguityResolutionMLDBScanAlgorithm.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

using namespace ActsExamples;
using namespace Acts;

namespace Acts::Python {

void addOnnxMlpack(Context& ctx) {
auto [m, mex, onnx] = ctx.get("main", "examples", "onnx");
auto mlpack = mex.def_submodule("_mlpack");

ACTS_PYTHON_DECLARE_ALGORITHM(
ActsExamples::AmbiguityResolutionMLDBScanAlgorithm, mlpack,
"AmbiguityResolutionMLDBScanAlgorithm", inputTracks, inputDuplicateNN,
outputTracks, nMeasurementsMin, epsilonDBScan, minPointsDBScan);
}
} // namespace Acts::Python
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <pybind11/stl.h>

namespace Acts::Python {
void addMLTrackFinding(Context& /*unused*/) {
void addOnnxMlpack(Context& /*unused*/) {
// dummy function
}
} // namespace Acts::Python
18 changes: 18 additions & 0 deletions Examples/Python/src/OnnxStub.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/Python/Utilities.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace Acts::Python {
void addOnnx(Context& /*unused*/) {
// dummy function
}
} // namespace Acts::Python
2 changes: 1 addition & 1 deletion Examples/Python/src/Pythia8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void addPythia8(Context& ctx) {
auto mex = ctx.get("examples");

auto p8 = mex.def_submodule("pythia8");
ctx.modules["pythia8"] = &p8;
ctx.modules["pythia8"] = p8;

using Gen = ActsExamples::Pythia8Generator;
auto gen = py::class_<Gen, ActsExamples::EventGenerator::ParticlesGenerator,
Expand Down
7 changes: 7 additions & 0 deletions Examples/Python/tests/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@
except ImportError:
edm4hepEnabled = False

try:
import acts.examples.onnx

onnxEnabled = True
except ImportError:
onnxEnabled = False


try:
import acts.examples
Expand Down

0 comments on commit 15aec1a

Please sign in to comment.