Skip to content

Commit

Permalink
feat: Introduce Exa.TrkX plugin (#1151)
Browse files Browse the repository at this point in the history
This PR introduces the Exa.TrkX plugin which implements the Exa.TrkX GNN-based track finding alongside with some examples.

I have already added a section to the CI-workflow-description, which won't work until the required docker image has been merged to the machines-repository (see acts-project/machines#43)

Points which may need some discussion are 
* the cmake configuration for CUDA, which I moved to a macro since it is now needed for two plugins.
* wether we want to put ONNX model files into the repository (at the moment there are no working ones due to a bug in ONNX I think)
  • Loading branch information
benjaminhuth committed Mar 28, 2022
1 parent 8e168c8 commit 23677b7
Show file tree
Hide file tree
Showing 40 changed files with 4,502 additions and 13 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/builds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,22 @@ jobs:
-DACTS_BUILD_UNITTESTS=ON
- name: Build
run: cmake --build build --
exatrkx:
runs-on: ubuntu-latest
container: ghcr.io/acts-project/ubuntu2004_exatrkx:v17
steps:
- uses: actions/checkout@v2
- name: Configure
run: >
cmake -B build -S .
-GNinja
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_CXX_FLAGS=-Werror
-DACTS_BUILD_PLUGIN_EXATRKX=ON
-DACTS_BUILD_EXAMPLES_EXATRKX=ON
-DACTS_BUILD_EXAMPLES_PYTHON_BINDINGS=ON
- name: Build
run: cmake --build build --
sycl:
runs-on: ubuntu-latest
container: ghcr.io/acts-project/ubuntu2004_oneapi:v9
Expand Down
36 changes: 29 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ option(ACTS_BUILD_PLUGIN_AUTODIFF "Build the autodiff plugin" OFF)
option(ACTS_USE_SYSTEM_AUTODIFF "Use autodiff provided by the system instead of the bundled version" OFF)
option(ACTS_BUILD_PLUGIN_CUDA "Build CUDA plugin" OFF)
option(ACTS_BUILD_PLUGIN_DD4HEP "Build DD4hep plugin" OFF)
option(ACTS_BUILD_PLUGIN_EXATRKX "Build the Exa.TrkX plugin" OFF)
option(ACTS_BUILD_PLUGIN_IDENTIFICATION "Build Identification plugin" OFF)
option(ACTS_BUILD_PLUGIN_JSON "Build json plugin" OFF)
option(ACTS_USE_SYSTEM_NLOHMANN_JSON "Use nlohmann::json provided by the system instead of the bundled version" OFF)
Expand All @@ -48,6 +49,7 @@ option(ACTS_BUILD_ALIGNMENT "Build Alignment package" OFF)
# examples related options
option(ACTS_BUILD_EXAMPLES "Build standalone examples" OFF)
option(ACTS_BUILD_EXAMPLES_DD4HEP "Build DD4hep-based code in the examples" OFF)
option(ACTS_BUILD_EXAMPLES_EXATRKX "Build the Exa.TrkX example code" OFF)
option(ACTS_BUILD_EXAMPLES_GEANT4 "Build Geant4-based code in the examples" OFF)
option(ACTS_BUILD_EXAMPLES_HEPMC3 "Build HepMC3-based code in the examples" OFF)
option(ACTS_BUILD_EXAMPLES_PYTHIA8 "Build Pythia8-based code in the examples" OFF)
Expand Down Expand Up @@ -79,13 +81,15 @@ set_option_if(ACTS_BUILD_EXAMPLES_PYTHIA8 ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_FATRAS_GEANT4 ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_FATRAS ACTS_BUILD_FATRAS_GEANT4)
set_option_if(ACTS_BUILD_ALIGNMENT ACTS_BUILD_EVERYTHING)

# any examples component activates the general examples option
set_option_if(
ACTS_BUILD_EXAMPLES
ACTS_BUILD_EXAMPLES_DD4HEP
OR ACTS_BUILD_EXAMPLES_GEANT4
OR ACTS_BUILD_EXAMPLES_HEPMC3
OR ACTS_BUILD_EXAMPLES_PYTHIA8
OR ACTS_BUILD_EXAMPLES_EXATRKX
OR ACTS_BUILD_EVERYTHING)
# core plugins might be required by examples or depend on each other
set_option_if(
Expand All @@ -105,6 +109,7 @@ set_option_if(
ACTS_BUILD_EXAMPLES OR ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_PLUGIN_LEGACY ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_PLUGIN_AUTODIFF ACTS_BUILD_EVERYTHING)
set_option_if(ACTS_BUILD_EXAMPLES_EXATRKX ACTS_BUILD_PLUGIN_EXATRKX)

# feature tests
include(CheckCXXSourceCompiles)
Expand Down Expand Up @@ -187,6 +192,20 @@ macro(project)
set(${PROJECT_NAME}_VERSION "${${PROJECT_NAME}_VERSION}" CACHE INTERNAL "")
endmacro()

# CUDA settings are collected here in a macro, so that they can be reused by different plugins
macro(enable_cuda)
enable_language(CUDA)
set(CMAKE_CUDA_STANDARD 14 CACHE STRING "CUDA C++ standard to use")
set(CMAKE_CUDA_STANDARD_REQUIRED ON CACHE BOOL
"Force the C++ standard requirement")
if(NOT CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES "35;52;75" CACHE STRING
"CUDA architectures to generate code for")
endif()
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda")
endmacro()

# optional packages
#
# find packages explicitly for each component even if this means searching for
Expand All @@ -210,13 +229,7 @@ if(ACTS_BUILD_PLUGIN_AUTODIFF)
endif()
endif()
if(ACTS_BUILD_PLUGIN_CUDA)
enable_language(CUDA)
set(CMAKE_CUDA_STANDARD 14 CACHE STRING "CUDA C++ standard to use")
set(CMAKE_CUDA_STANDARD_REQUIRED ON CACHE BOOL
"Force the C++ standard requirement")
set(CMAKE_CUDA_ARCHITECTURES "35;52;75" CACHE STRING
"CUDA architectures to generate code for")
set(CMAKE_CUDA_FLAGS_DEBUG "-g -G")
enable_cuda()
endif()
if(ACTS_BUILD_PLUGIN_DD4HEP)
find_package(DD4hep ${_acts_dd4hep_version} REQUIRED CONFIG COMPONENTS DDCore DDDetectors)
Expand All @@ -239,6 +252,15 @@ if(ACTS_BUILD_PLUGIN_TGEO)
find_package(ROOT ${_acts_root_version} REQUIRED CONFIG COMPONENTS Geom)
check_root_compatibility()
endif()
if(ACTS_BUILD_PLUGIN_EXATRKX)
enable_cuda()
find_package(CUDAToolkit REQUIRED)
find_package(Torch REQUIRED)
find_package(OnnxRuntime REQUIRED)
find_package(cugraph REQUIRED)
add_subdirectory(thirdparty/libFRNN)
endif()

# examples dependencies
if(ACTS_BUILD_EXAMPLES)
set(THREADS_PREFER_PTHREAD_FLAG ON)
Expand Down
1 change: 1 addition & 0 deletions Examples/Algorithms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_subdirectory(MaterialMapping)
add_subdirectory(Printers)
add_subdirectory(Propagation)
add_subdirectory(TrackFinding)
add_subdirectory_if(TrackFindingExaTrkX ACTS_BUILD_EXAMPLES_EXATRKX)
add_subdirectory(TrackFitting)
add_subdirectory(TruthTracking)
add_subdirectory(Vertexing)
Expand Down
23 changes: 23 additions & 0 deletions Examples/Algorithms/TrackFindingExaTrkX/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
add_library(
ActsExamplesTrackFindingExaTrkX SHARED
src/TrackFindingAlgorithmExaTrkX.cpp
)


target_include_directories(
ActsExamplesTrackFindingExaTrkX
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
)


target_link_libraries(
ActsExamplesTrackFindingExaTrkX
PUBLIC
ActsPluginExaTrkX
ActsExamplesFramework
)

install(
TARGETS ActsExamplesTrackFindingExaTrkX
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFinding.hpp"
#include "ActsExamples/Framework/BareAlgorithm.hpp"

#include <string>
#include <vector>

namespace ActsExamples {

class TrackFindingAlgorithmExaTrkX final : public BareAlgorithm {
public:
struct Config {
/// Input spacepoints collection.
std::string inputSpacePoints;

/// Output protoTracks collection.
std::string outputProtoTracks;

/// ML based track finder
std::shared_ptr<Acts::ExaTrkXTrackFinding> trackFinderML;

// NOTE the other config parameters for the Exa.TrkX class for now are just
// initialized as the defaults
};

/// Constructor of the track finding algorithm
///
/// @param cfg is the config struct to configure the algorithm
/// @param level is the logging level
TrackFindingAlgorithmExaTrkX(Config cfg, Acts::Logging::Level lvl);

virtual ~TrackFindingAlgorithmExaTrkX() {}

/// Framework execute method of the track finding algorithm
///
/// @param ctx is the algorithm context that holds event-wise information
/// @return a process code to steer the algorithm flow
ActsExamples::ProcessCode execute(
const ActsExamples::AlgorithmContext& ctx) const final;

const Config& config() const { return m_cfg; }

private:
// configuration
Config m_cfg;
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"

#include "ActsExamples/EventData/Index.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimSpacePoint.hpp"
#include "ActsExamples/Framework/WhiteBoard.hpp"

ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(
Config config, Acts::Logging::Level level)
: ActsExamples::BareAlgorithm("TrackFindingMLBasedAlgorithm", level),
m_cfg(std::move(config)) {
if (m_cfg.inputSpacePoints.empty()) {
throw std::invalid_argument("Missing spacepoint input collection");
}
if (m_cfg.outputProtoTracks.empty()) {
throw std::invalid_argument("Missing protoTrack output collection");
}
if (!m_cfg.trackFinderML) {
throw std::invalid_argument("Missing track finder");
}
}

ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
const ActsExamples::AlgorithmContext& ctx) const {
// Read input data
const auto& spacepoints =
ctx.eventStore.get<SimSpacePointContainer>(m_cfg.inputSpacePoints);

// Convert Input data to a list of size [num_measurements x
// measurement_features]
size_t num_spacepoints = spacepoints.size();
ACTS_INFO("Received " << num_spacepoints << " spacepoints");

std::vector<float> inputValues;
std::vector<uint32_t> spacepointIDs;
inputValues.reserve(spacepoints.size() * 3);
spacepointIDs.reserve(spacepoints.size());
for (const auto& sp : spacepoints) {
float x = sp.x();
float y = sp.y();
float z = sp.z();
float r = sp.r();
float phi = std::atan2(y, x);
inputValues.push_back(r);
inputValues.push_back(phi);
inputValues.push_back(z);

spacepointIDs.push_back(sp.measurementIndex());
}

// ProtoTrackContainer protoTracks;
std::vector<std::vector<uint32_t> > trackCandidates;
m_cfg.trackFinderML->getTracks(inputValues, spacepointIDs, trackCandidates);

std::vector<ProtoTrack> protoTracks;
protoTracks.reserve(trackCandidates.size());
for (auto& x : trackCandidates) {
ProtoTrack onetrack;
std::copy(x.begin(), x.end(), std::back_inserter(onetrack));
protoTracks.push_back(std::move(onetrack));
}

ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
ctx.eventStore.add(m_cfg.outputProtoTracks, std::move(protoTracks));

return ActsExamples::ProcessCode::SUCCESS;
}
6 changes: 6 additions & 0 deletions Examples/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ else()
target_sources(ActsPythonBindings PRIVATE src/HepMC3Stub.cpp)
endif()

if(ACTS_BUILD_EXAMPLES_EXATRKX)
target_link_libraries(ActsPythonBindings PUBLIC ActsExamplesTrackFindingExaTrkX)
target_sources(ActsPythonBindings PRIVATE src/ExaTrkXTrackFinding.cpp)
else()
target_sources(ActsPythonBindings PRIVATE src/ExaTrkXTrackFindingStub.cpp)
endif()

add_custom_target(ActsPythonGlueCode)
configure_file(setup.sh.in ${_python_dir}/setup.sh COPYONLY)
Expand Down
72 changes: 72 additions & 0 deletions Examples/Python/src/ExaTrkXTrackFinding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// This file is part of the Acts project.
//
// Copyright (C) 2021 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFinding.hpp"

#include "Acts/Plugins/Python/Utilities.hpp"
#include "Acts/TrackFinding/MeasurementSelector.hpp"
#include "ActsExamples/TrackFinding/SeedingAlgorithm.hpp"
#include "ActsExamples/TrackFinding/SpacePointMaker.hpp"
#include "ActsExamples/TrackFinding/TrackFindingAlgorithm.hpp"
#include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"

#include <memory>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

using namespace ActsExamples;
using namespace Acts;

namespace Acts::Python {

void addExaTrkXTrackFinding(Context& ctx) {
auto [m, mex] = ctx.get("main", "examples");

{
using Alg = Acts::ExaTrkXTrackFinding;
using Config = Acts::ExaTrkXTrackFinding::Config;

auto alg = py::class_<Alg, std::shared_ptr<Alg>>(mex, "ExaTrkXTrackFinding")
.def(py::init<const Config&>(), py::arg("config"))
.def_property_readonly("config", &Alg::config);

auto c = py::class_<Config>(alg, "Config").def(py::init<>());
ACTS_PYTHON_STRUCT_BEGIN(c, Config);
ACTS_PYTHON_MEMBER(inputMLModuleDir);
ACTS_PYTHON_MEMBER(spacepointFeatures);
ACTS_PYTHON_MEMBER(embeddingDim);
ACTS_PYTHON_MEMBER(rVal);
ACTS_PYTHON_MEMBER(knnVal);
ACTS_PYTHON_MEMBER(filterCut);
ACTS_PYTHON_STRUCT_END();
}

{
using Alg = ActsExamples::TrackFindingAlgorithmExaTrkX;
using Config = Alg::Config;

auto alg =
py::class_<Alg, ActsExamples::BareAlgorithm, std::shared_ptr<Alg>>(
mex, "TrackFindingAlgorithmExaTrkX")
.def(py::init<const Config&, Acts::Logging::Level>(),
py::arg("config"), py::arg("level"))
.def_property_readonly("config", &Alg::config);

auto c = py::class_<Config>(alg, "Config").def(py::init<>());
ACTS_PYTHON_STRUCT_BEGIN(c, Config);
ACTS_PYTHON_MEMBER(inputSpacePoints);
ACTS_PYTHON_MEMBER(outputProtoTracks);
ACTS_PYTHON_MEMBER(trackFinderML);
ACTS_PYTHON_STRUCT_END();
}
}

} // namespace Acts::Python
18 changes: 18 additions & 0 deletions Examples/Python/src/ExaTrkXTrackFindingStub.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/Python/Utilities.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace Acts::Python {
void addExaTrkXTrackFinding(Context&) {
// dummy function
}
} // namespace Acts::Python
4 changes: 2 additions & 2 deletions Examples/Python/src/HepMC3Stub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
#include "Acts/Plugins/Python/Utilities.hpp"

namespace Acts::Python {
void addHepMC3(Context& ctx) {}
} // namespace Acts::Python
void addHepMC3(Context&) {}
} // namespace Acts::Python

0 comments on commit 23677b7

Please sign in to comment.