Skip to content

Commit

Permalink
feat: Exa.TrkX with torchscript backend (#1473)
Browse files Browse the repository at this point in the history
This PR is a big update to the Exa.TrkX plugin:

* It now provides both a TorchScript-based and a ONNX-based backend (the ONNX one doesn't really work at the moment due to missing operators in Onnx, however, we should'nt throw away the code for this I think)
* Partial integration of Acts logging instead of `std::cout` based things
* The `TrajectoriesToProtoTracks` algorithm to be able to compare results of the CKF and the Exa.TrkkX track finding
* Update python to support all of this
* The dependency to FRNN is now managed by CMake's `FetchContent` mechanism
* The CI routines are restrucutured
  * no build in the github workflows anymore
  * Build and check both backends in the CI bridge

Co-authored-by: Andreas Stefl <487211+andiwand@users.noreply.github.com>
  • Loading branch information
benjaminhuth and andiwand committed Sep 16, 2022
1 parent ff772f1 commit b409627
Show file tree
Hide file tree
Showing 51 changed files with 1,613 additions and 3,653 deletions.
30 changes: 0 additions & 30 deletions .github/workflows/builds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -566,36 +566,6 @@ jobs:
- name: ccache stats
run: ccache -s

exatrkx:
runs-on: ubuntu-latest
container: ghcr.io/acts-project/ubuntu2004_exatrkx:v29
steps:
- uses: actions/checkout@v2

- name: Cache build
uses: actions/cache@v3
with:
path: ${{ github.workspace }}/ccache
key: ${{ runner.os }}-ccache-exatrkx_${{ env.CCACHE_KEY_SUFFIX }}_${{ github.sha }}
restore-keys: |
${{ runner.os }}-ccache-exatrkx_${{ env.CCACHE_KEY_SUFFIX }}_
- name: Configure
run: >
ccache -z &&
cmake -B build -S .
-GNinja
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_CXX_FLAGS=-Werror
-DACTS_BUILD_PLUGIN_EXATRKX=ON
-DACTS_BUILD_EXAMPLES_EXATRKX=ON
-DACTS_BUILD_EXAMPLES_PYTHON_BINDINGS=ON
- name: Build
run: cmake --build build
- name: ccache stats
run: ccache -s

sycl:
runs-on: ubuntu-latest
container: ghcr.io/acts-project/ubuntu2004_oneapi:v29
Expand Down
29 changes: 17 additions & 12 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ clang_tidy:

build:
stage: build
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v29
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v30
tags:
- docker
variables:
Expand Down Expand Up @@ -71,30 +71,35 @@ build:
- cd ..
- mkdir build
- >
cmake -B build -S src
cmake -B build -S src
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-GNinja
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_CXX_FLAGS=-w
-DCMAKE_CUDA_FLAGS=-w
-GNinja
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_CXX_FLAGS=-w
-DCMAKE_CUDA_FLAGS=-w
-DCMAKE_CUDA_ARCHITECTURES="75;86"
-DACTS_BUILD_PLUGIN_EXATRKX=ON
-DACTS_BUILD_EXAMPLES_EXATRKX=ON
-DACTS_BUILD_PLUGIN_EXATRKX=ON
-DACTS_BUILD_EXAMPLES_EXATRKX=ON
-DACTS_EXATRKX_ENABLE_TORCH=ON
-DACTS_EXATRKX_ENABLE_ONNX=ON
-DACTS_BUILD_EXAMPLES_PYTHON_BINDINGS=ON
- cmake --build build --

test:
stage: test
needs:
- build
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v29
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v30
tags:
- docker-gpu-nvidia
script:
- apt-get update -y
- apt-get install -y python3 libxxhash0
- curl https://acts.web.cern.ch/ci/exatrkx/onnx_models_v01.tar --output models.tar
- tar -xf models.tar
- curl https://acts.web.cern.ch/ci/exatrkx/onnx_models_v01.tar --output onnx_models.tar
- curl https://bhuth.webo.family/index.php/s/oyFg8WF2cDPrJYz/download --output torchscript_models.tar
- tar -xf onnx_models.tar
- tar -xf torchscript_models.tar
- source build/python/setup.sh
- nvidia-smi
- python3 src/Examples/Scripts/Python/exatrkx.py
- python3 src/Examples/Scripts/Python/exatrkx.py onnx
- python3 src/Examples/Scripts/Python/exatrkx.py torch
19 changes: 16 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ option(ACTS_BUILD_PLUGIN_ACTSVG "Build SVG display plugin" OFF)
option(ACTS_BUILD_PLUGIN_CUDA "Build CUDA plugin" OFF)
option(ACTS_BUILD_PLUGIN_DD4HEP "Build DD4hep plugin" OFF)
option(ACTS_BUILD_PLUGIN_EXATRKX "Build the Exa.TrkX plugin" OFF)
option(ACTS_EXATRKX_ENABLE_ONNX "Build the Onnx backend for the exatrkx plugin" OFF)
option(ACTS_EXATRKX_ENABLE_TORCH "Build the torchscript backend for the exatrkx plugin" ON)
option(ACTS_USE_SYSTEM_ACTSDD4HEP "Use the ActsDD4hep glue library provided by the system instead of building it" OFF)
option(ACTS_BUILD_PLUGIN_IDENTIFICATION "Build Identification plugin" OFF)
option(ACTS_BUILD_PLUGIN_JSON "Build json plugin" OFF)
Expand Down Expand Up @@ -289,9 +291,20 @@ if(ACTS_BUILD_PLUGIN_EXATRKX)
enable_cuda()
find_package(CUDAToolkit REQUIRED)
find_package(Torch REQUIRED)
find_package(OnnxRuntime REQUIRED)
find_package(cugraph REQUIRED)
add_subdirectory(thirdparty/libFRNN)
add_subdirectory(thirdparty/FRNN)
if(NOT (ACTS_EXATRKX_ENABLE_ONNX OR ACTS_EXATRKX_ENABLE_TORCH))
message(FATAL_ERROR
"When building the Exa.TrkX plugin, at least one of ACTS_EXATRKX_ENABLE_ONNX \
and ACTS_EXATRKX_ENABLE_TORCHSCRIPT must be enabled."
)
endif()
if(ACTS_EXATRKX_ENABLE_ONNX)
find_package(OnnxRuntime REQUIRED)
find_package(cugraph REQUIRED)
endif()
if(ACTS_EXATRKX_ENABLE_TORCH)
find_package(TorchScatter REQUIRED)
endif()
endif()

# examples dependencies
Expand Down
3 changes: 2 additions & 1 deletion Examples/Algorithms/TrackFinding/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ add_library(
src/TrackFindingAlgorithm.cpp
src/TrackFindingAlgorithmFunction.cpp
src/TrackFindingOptions.cpp
src/TrackParamsEstimationAlgorithm.cpp)
src/TrackParamsEstimationAlgorithm.cpp
src/TrajectoriesToPrototracks.cpp)
target_include_directories(
ActsExamplesTrackFinding
PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "ActsExamples/Framework/BareAlgorithm.hpp"

namespace ActsExamples {

class TrajectoriesToPrototracks final : public BareAlgorithm {
public:
struct Config {
std::string inputTrajectories = "trajectories";
std::string outputPrototracks = "tracks-from-trajectories";
};

/// Construct the algorithm.
///
/// @param cfg is the algorithm configuration
/// @param lvl is the logging level
TrajectoriesToPrototracks(Config cfg, Acts::Logging::Level lvl)
: BareAlgorithm("TrajectoriesToPrototracks", lvl), m_cfg(cfg) {}

/// Run the algorithm.
///
/// @param ctx is the algorithm context with event information
/// @return a process code indication success or failure
ProcessCode execute(const AlgorithmContext& ctx) const final override;

/// Const access to the config
const Config& config() const { return m_cfg; }

private:
Config m_cfg;
};

} // namespace ActsExamples
49 changes: 49 additions & 0 deletions Examples/Algorithms/TrackFinding/src/TrajectoriesToPrototracks.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "ActsExamples/TrackFinding/TrajectoriesToPrototracks.hpp"

#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/Trajectories.hpp"
#include "ActsExamples/Framework/WhiteBoard.hpp"

namespace ActsExamples {

ProcessCode TrajectoriesToPrototracks::execute(
const AlgorithmContext& ctx) const {
const auto trajectories =
ctx.eventStore.get<TrajectoriesContainer>(m_cfg.inputTrajectories);

ProtoTrackContainer tracks;

for (const auto& trajectory : trajectories) {
for (const auto tip : trajectory.tips()) {
ProtoTrack track;

trajectory.multiTrajectory().visitBackwards(tip, [&](const auto& state) {
if (not state.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
return true;
}

const auto& source_link =
static_cast<const IndexSourceLink&>(state.uncalibrated());
track.push_back(source_link.index());

return true;
});

tracks.push_back(track);
}
}

ctx.eventStore.add(m_cfg.outputPrototracks, std::move(tracks));

return ProcessCode::SUCCESS;
}
} // namespace ActsExamples
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#pragma once

#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFinding.hpp"
#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFindingBase.hpp"
#include "ActsExamples/Framework/BareAlgorithm.hpp"

#include <string>
Expand All @@ -26,10 +26,12 @@ class TrackFindingAlgorithmExaTrkX final : public BareAlgorithm {
std::string outputProtoTracks;

/// ML based track finder
std::shared_ptr<Acts::ExaTrkXTrackFinding> trackFinderML;
std::shared_ptr<Acts::ExaTrkXTrackFindingBase> trackFinderML;

// NOTE the other config parameters for the Exa.TrkX class for now are just
// initialized as the defaults
/// Scaling of the input features
float rScale = 1.f;
float phiScale = 1.f;
float zScale = 1.f;
};

/// Constructor of the track finding algorithm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
ACTS_INFO("Received " << num_spacepoints << " spacepoints");

std::vector<float> inputValues;
std::vector<uint32_t> spacepointIDs;
std::vector<int> spacepointIDs;
inputValues.reserve(spacepoints.size() * 3);
spacepointIDs.reserve(spacepoints.size());
for (const auto& sp : spacepoints) {
Expand All @@ -50,18 +50,22 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
float z = sp.z();
float r = sp.r();
float phi = std::atan2(y, x);
inputValues.push_back(r);
inputValues.push_back(phi);
inputValues.push_back(z);
for (const auto slink : sp.sourceLinks()) {
const auto islink = static_cast<const IndexSourceLink&>(*slink);
spacepointIDs.push_back(islink.index());
}

inputValues.push_back(r / m_cfg.rScale);
inputValues.push_back(phi / m_cfg.phiScale);
inputValues.push_back(z / m_cfg.zScale);

// For now just take the first index since does require one single index per
// spacepoint
const auto islink =
static_cast<const IndexSourceLink&>(*sp.sourceLinks().front());
spacepointIDs.push_back(islink.index());
}

// ProtoTrackContainer protoTracks;
std::vector<std::vector<uint32_t> > trackCandidates;
m_cfg.trackFinderML->getTracks(inputValues, spacepointIDs, trackCandidates);
std::vector<std::vector<int> > trackCandidates;
m_cfg.trackFinderML->getTracks(inputValues, spacepointIDs, trackCandidates,
Acts::LoggerWrapper{logger()});

std::vector<ProtoTrack> protoTracks;
protoTracks.reserve(trackCandidates.size());
Expand Down
22 changes: 15 additions & 7 deletions Examples/Python/python/acts/examples/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,12 +794,16 @@ def addCKFTracks(
return s


def addExaTrkx(
ExaTrkXBackend = Enum("ExaTrkXBackend", "Torch Onnx")


def addExaTrkX(
s: acts.examples.Sequencer,
trackingGeometry: acts.TrackingGeometry,
geometrySelection: Union[Path, str],
onnxModelDir: Union[Path, str],
modelDir: Union[Path, str],
outputDirRoot: Optional[Union[Path, str]] = None,
backend: Optional[ExaTrkXBackend] = ExaTrkXBackend.Torch,
logLevel: Optional[acts.logging.Level] = None,
) -> None:

Expand Down Expand Up @@ -834,11 +838,15 @@ def addExaTrkx(
)
)

# Setup the track finding algorithm with ExaTrkX
# It takes all the source links created from truth hit smearing, seeds from
# truth particle smearing and source link selection config
exaTrkxFinding = acts.examples.ExaTrkXTrackFinding(
inputMLModuleDir=str(onnxModelDir),
# For now we don't configure only the common options so this works
exaTrkxModule = (
acts.examples.ExaTrkXTrackFindingTorch
if backend == ExaTrkXBackend.Torch
else acts.examples.ExaTrkXTrackFindingOnnx
)

exaTrkxFinding = exaTrkxModule(
modelDir=str(modelDir),
spacepointFeatures=3,
embeddingDim=8,
rVal=1.6,
Expand Down

0 comments on commit b409627

Please sign in to comment.