-
Notifications
You must be signed in to change notification settings - Fork 157
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Introduce Exa.TrkX plugin (#1151)
This PR introduces the Exa.TrkX plugin which implements the Exa.TrkX GNN-based track finding alongside with some examples. I have already added a section to the CI-workflow-description, which won't work until the required docker image has been merged to the machines-repository (see acts-project/machines#43) Points which may need some discussion are * the cmake configuration for CUDA, which I moved to a macro since it is now needed for two plugins. * wether we want to put ONNX model files into the repository (at the moment there are no working ones due to a bug in ONNX I think)
- Loading branch information
1 parent
8e168c8
commit 23677b7
Showing
40 changed files
with
4,502 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
add_library( | ||
ActsExamplesTrackFindingExaTrkX SHARED | ||
src/TrackFindingAlgorithmExaTrkX.cpp | ||
) | ||
|
||
|
||
target_include_directories( | ||
ActsExamplesTrackFindingExaTrkX | ||
PUBLIC | ||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include> | ||
) | ||
|
||
|
||
target_link_libraries( | ||
ActsExamplesTrackFindingExaTrkX | ||
PUBLIC | ||
ActsPluginExaTrkX | ||
ActsExamplesFramework | ||
) | ||
|
||
install( | ||
TARGETS ActsExamplesTrackFindingExaTrkX | ||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) |
57 changes: 57 additions & 0 deletions
57
...kFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2022 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#pragma once | ||
|
||
#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFinding.hpp" | ||
#include "ActsExamples/Framework/BareAlgorithm.hpp" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
namespace ActsExamples { | ||
|
||
class TrackFindingAlgorithmExaTrkX final : public BareAlgorithm { | ||
public: | ||
struct Config { | ||
/// Input spacepoints collection. | ||
std::string inputSpacePoints; | ||
|
||
/// Output protoTracks collection. | ||
std::string outputProtoTracks; | ||
|
||
/// ML based track finder | ||
std::shared_ptr<Acts::ExaTrkXTrackFinding> trackFinderML; | ||
|
||
// NOTE the other config parameters for the Exa.TrkX class for now are just | ||
// initialized as the defaults | ||
}; | ||
|
||
/// Constructor of the track finding algorithm | ||
/// | ||
/// @param cfg is the config struct to configure the algorithm | ||
/// @param level is the logging level | ||
TrackFindingAlgorithmExaTrkX(Config cfg, Acts::Logging::Level lvl); | ||
|
||
virtual ~TrackFindingAlgorithmExaTrkX() {} | ||
|
||
/// Framework execute method of the track finding algorithm | ||
/// | ||
/// @param ctx is the algorithm context that holds event-wise information | ||
/// @return a process code to steer the algorithm flow | ||
ActsExamples::ProcessCode execute( | ||
const ActsExamples::AlgorithmContext& ctx) const final; | ||
|
||
const Config& config() const { return m_cfg; } | ||
|
||
private: | ||
// configuration | ||
Config m_cfg; | ||
}; | ||
|
||
} // namespace ActsExamples |
75 changes: 75 additions & 0 deletions
75
Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2022 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" | ||
|
||
#include "ActsExamples/EventData/Index.hpp" | ||
#include "ActsExamples/EventData/ProtoTrack.hpp" | ||
#include "ActsExamples/EventData/SimSpacePoint.hpp" | ||
#include "ActsExamples/Framework/WhiteBoard.hpp" | ||
|
||
ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( | ||
Config config, Acts::Logging::Level level) | ||
: ActsExamples::BareAlgorithm("TrackFindingMLBasedAlgorithm", level), | ||
m_cfg(std::move(config)) { | ||
if (m_cfg.inputSpacePoints.empty()) { | ||
throw std::invalid_argument("Missing spacepoint input collection"); | ||
} | ||
if (m_cfg.outputProtoTracks.empty()) { | ||
throw std::invalid_argument("Missing protoTrack output collection"); | ||
} | ||
if (!m_cfg.trackFinderML) { | ||
throw std::invalid_argument("Missing track finder"); | ||
} | ||
} | ||
|
||
ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( | ||
const ActsExamples::AlgorithmContext& ctx) const { | ||
// Read input data | ||
const auto& spacepoints = | ||
ctx.eventStore.get<SimSpacePointContainer>(m_cfg.inputSpacePoints); | ||
|
||
// Convert Input data to a list of size [num_measurements x | ||
// measurement_features] | ||
size_t num_spacepoints = spacepoints.size(); | ||
ACTS_INFO("Received " << num_spacepoints << " spacepoints"); | ||
|
||
std::vector<float> inputValues; | ||
std::vector<uint32_t> spacepointIDs; | ||
inputValues.reserve(spacepoints.size() * 3); | ||
spacepointIDs.reserve(spacepoints.size()); | ||
for (const auto& sp : spacepoints) { | ||
float x = sp.x(); | ||
float y = sp.y(); | ||
float z = sp.z(); | ||
float r = sp.r(); | ||
float phi = std::atan2(y, x); | ||
inputValues.push_back(r); | ||
inputValues.push_back(phi); | ||
inputValues.push_back(z); | ||
|
||
spacepointIDs.push_back(sp.measurementIndex()); | ||
} | ||
|
||
// ProtoTrackContainer protoTracks; | ||
std::vector<std::vector<uint32_t> > trackCandidates; | ||
m_cfg.trackFinderML->getTracks(inputValues, spacepointIDs, trackCandidates); | ||
|
||
std::vector<ProtoTrack> protoTracks; | ||
protoTracks.reserve(trackCandidates.size()); | ||
for (auto& x : trackCandidates) { | ||
ProtoTrack onetrack; | ||
std::copy(x.begin(), x.end(), std::back_inserter(onetrack)); | ||
protoTracks.push_back(std::move(onetrack)); | ||
} | ||
|
||
ACTS_INFO("Created " << protoTracks.size() << " proto tracks"); | ||
ctx.eventStore.add(m_cfg.outputProtoTracks, std::move(protoTracks)); | ||
|
||
return ActsExamples::ProcessCode::SUCCESS; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2021 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFinding.hpp" | ||
|
||
#include "Acts/Plugins/Python/Utilities.hpp" | ||
#include "Acts/TrackFinding/MeasurementSelector.hpp" | ||
#include "ActsExamples/TrackFinding/SeedingAlgorithm.hpp" | ||
#include "ActsExamples/TrackFinding/SpacePointMaker.hpp" | ||
#include "ActsExamples/TrackFinding/TrackFindingAlgorithm.hpp" | ||
#include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" | ||
|
||
#include <memory> | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
namespace py = pybind11; | ||
|
||
using namespace ActsExamples; | ||
using namespace Acts; | ||
|
||
namespace Acts::Python { | ||
|
||
void addExaTrkXTrackFinding(Context& ctx) { | ||
auto [m, mex] = ctx.get("main", "examples"); | ||
|
||
{ | ||
using Alg = Acts::ExaTrkXTrackFinding; | ||
using Config = Acts::ExaTrkXTrackFinding::Config; | ||
|
||
auto alg = py::class_<Alg, std::shared_ptr<Alg>>(mex, "ExaTrkXTrackFinding") | ||
.def(py::init<const Config&>(), py::arg("config")) | ||
.def_property_readonly("config", &Alg::config); | ||
|
||
auto c = py::class_<Config>(alg, "Config").def(py::init<>()); | ||
ACTS_PYTHON_STRUCT_BEGIN(c, Config); | ||
ACTS_PYTHON_MEMBER(inputMLModuleDir); | ||
ACTS_PYTHON_MEMBER(spacepointFeatures); | ||
ACTS_PYTHON_MEMBER(embeddingDim); | ||
ACTS_PYTHON_MEMBER(rVal); | ||
ACTS_PYTHON_MEMBER(knnVal); | ||
ACTS_PYTHON_MEMBER(filterCut); | ||
ACTS_PYTHON_STRUCT_END(); | ||
} | ||
|
||
{ | ||
using Alg = ActsExamples::TrackFindingAlgorithmExaTrkX; | ||
using Config = Alg::Config; | ||
|
||
auto alg = | ||
py::class_<Alg, ActsExamples::BareAlgorithm, std::shared_ptr<Alg>>( | ||
mex, "TrackFindingAlgorithmExaTrkX") | ||
.def(py::init<const Config&, Acts::Logging::Level>(), | ||
py::arg("config"), py::arg("level")) | ||
.def_property_readonly("config", &Alg::config); | ||
|
||
auto c = py::class_<Config>(alg, "Config").def(py::init<>()); | ||
ACTS_PYTHON_STRUCT_BEGIN(c, Config); | ||
ACTS_PYTHON_MEMBER(inputSpacePoints); | ||
ACTS_PYTHON_MEMBER(outputProtoTracks); | ||
ACTS_PYTHON_MEMBER(trackFinderML); | ||
ACTS_PYTHON_STRUCT_END(); | ||
} | ||
} | ||
|
||
} // namespace Acts::Python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2022 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#include "Acts/Plugins/Python/Utilities.hpp" | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
namespace Acts::Python { | ||
void addExaTrkXTrackFinding(Context&) { | ||
// dummy function | ||
} | ||
} // namespace Acts::Python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.