Commit

docs: Exa.TrkX (#1517)
Depends on #1473

Co-authored-by: Andreas Stefl <487211+andiwand@users.noreply.github.com>
benjaminhuth and andiwand committed Oct 3, 2022
1 parent fecfe46 commit 2c40516
Showing 6 changed files with 91 additions and 27 deletions.
@@ -11,16 +11,16 @@
#include "Acts/Plugins/ExaTrkX/ExaTrkXTiming.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <optional>
#include <string>
#include <vector>

namespace Acts {

struct ExaTrkXTime;

/// @class ExaTrkXTrackFindingBase
///
/// @brief Base class for all implementations of the Exa.TrkX pipeline
///
class ExaTrkXTrackFindingBase {
public:
/// Constructor
@@ -37,11 +37,15 @@ class ExaTrkXTrackFindingBase {

/// Run the inference
///
/// @param inputValues Packed spacepoints in the form
/// [ r1, phi1, z1, r2, phi2, z2, ... ]
/// @param spacepointIDs The corresponding spacepoint IDs
/// @param trackCandidates This vector is filled with the tracks as vectors of spacepoint IDs
/// @param inputValues Spacepoint data as a flattened NxD array, where D is
/// the dimensionality of a spacepoint (usually 3, but additional information
/// like cell information can be provided).
/// @param spacepointIDs The corresponding spacepoint IDs
/// @param trackCandidates This vector is filled with the tracks as vectors
/// of spacepoint IDs
/// @param logger If provided, logging is enabled
/// @param recordTiming If enabled, returns a @ref ExaTrkXTime object with
/// measured timings
/// @note The input values are not const, because the ONNX API
/// takes only non-const pointers.
virtual std::optional<ExaTrkXTime> getTracks(
@@ -11,8 +11,6 @@
#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFindingBase.hpp"

#include <memory>
#include <string>
#include <vector>

namespace Ort {
class Env;
@@ -22,9 +20,9 @@ class Value;

namespace Acts {

/// @class ExaTrkXTrackFindingOnnx
/// @brief Implementation of the Exa.TrkX track finding algorithm based on ONNX.
/// Uses cugraph as graph library.
///
/// Class implementing the Exa.TrkX track finding algorithm based on ONNX
class ExaTrkXTrackFindingOnnx final : public ExaTrkXTrackFindingBase {
public:
/// Configuration struct for the track finding.
@@ -50,12 +48,15 @@ class ExaTrkXTrackFindingOnnx final : public ExaTrkXTrackFindingBase {

/// Run the inference
///
/// @param inputValues Packed spacepoints in the form
/// [ r1, phi1, z1, r2, phi2, z2, ... ]
/// @param spacepointIDs The corresponding spacepoint IDs
/// @param trackCandidates This vector is filled with the tracks as vectors of spacepoint IDs
/// @param inputValues Spacepoint data as a flattened NxD array, where D is
/// the dimensionality of a spacepoint (usually 3, but additional information
/// like cell information can be provided).
/// @param spacepointIDs The corresponding spacepoint IDs
/// @param trackCandidates This vector is filled with the tracks as vectors
/// of spacepoint IDs
/// @param logger If provided, logging is enabled
/// @param recordTiming If enabled, returns a @class ExaTrkXTime object with measured timings
/// @param recordTiming If enabled, returns a @ref ExaTrkXTime object with
/// measured timings
/// @note The input values are not const, because the ONNX API
/// takes only non-const pointers.
std::optional<ExaTrkXTime> getTracks(
@@ -11,18 +11,16 @@
#include "Acts/Plugins/ExaTrkX/ExaTrkXTrackFindingBase.hpp"

#include <memory>
#include <string>
#include <vector>

namespace torch::jit {
class Module;
}

namespace Acts {

/// @class ExaTrkXTrackFindingTorch
/// @brief Class implementing the Exa.TrkX track finding algorithm based on
/// libtorch. Uses Boost.Graph as graph library.
///
/// Class implementing the Exa.TrkX track finding algorithm based on libtorch
class ExaTrkXTrackFindingTorch final : public ExaTrkXTrackFindingBase {
public:
/// Configuration struct for the track finding
@@ -49,12 +47,13 @@ class ExaTrkXTrackFindingTorch final : public ExaTrkXTrackFindingBase {

/// Run the inference
///
/// @param inputValues Packed spacepoints in the form
/// [ r1, phi1, z1, r2, phi2, z2, ... ]
/// @param spacepointIDs The corresponding spacepoint IDs
/// @param trackCandidates This vector is filled with the tracks as vectors of spacepoint IDs
/// @param inputValues Spacepoint data as a flattened NxD array, where D is
/// the dimensionality of a spacepoint (usually 3, but additional information
/// like cell information can be provided).
/// @param spacepointIDs The corresponding spacepoint IDs
/// @param trackCandidates This vector is filled with the tracks as vectors of spacepoint IDs
/// @param logger If provided, logging is enabled
/// @param recordTiming If enabled, returns a @class ExaTrkXTime object with measured timings
/// @param recordTiming If enabled, returns an ExaTrkXTime object with measured timings
/// @note The input values are not const, because the ONNX API
/// takes only non-const pointers.
std::optional<ExaTrkXTime> getTracks(
6 changes: 4 additions & 2 deletions docs/Doxyfile
@@ -168,7 +168,8 @@ STRIP_FROM_INC_PATH = ../Core/include \
../Plugins/Identification/include \
../Plugins/Json/include \
../Plugins/Legacy/include \
../Plugins/TGeo/include
../Plugins/TGeo/include \
../Plugins/ExaTrkX/include

# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but
# less readable) file names. This can be useful is your file systems doesn't
@@ -780,7 +781,8 @@ INPUT = ../Core/include \
../Plugins/Json/include \
../Plugins/Legacy/include \
../Plugins/Onnx/include \
../Plugins/TGeo/include
../Plugins/TGeo/include \
../Plugins/ExaTrkX/include

# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
3 changes: 3 additions & 0 deletions docs/core/track_finding.md
@@ -7,3 +7,6 @@ This is a stub!
(ckf_core)=
## Combinatorial Kalman Filter

## Machine-Learning based Track Finding

There is a lot of ongoing research into machine-learning based approaches to track finding. Because these are not yet stable and bulletproof, no such algorithms are distributed with the core library. However, there is a [plugin](exatrkxplugin) that implements the *Exa.TrkX* algorithm in ACTS.
57 changes: 56 additions & 1 deletion docs/plugins/exatrkx.md
@@ -1 +1,56 @@
# ExaTrkX
(exatrkxplugin)=
# Exa.TrkX plugin

This plugin contains a track finding module based on Graph Neural Networks (GNNs), developed by the [Exa.TrkX](https://exatrkx.github.io/) team. Build instructions and dependencies can be found in the [README](https://github.com/acts-project/acts/blob/main/Plugins/ExaTrkX/README.md) of the plugin.

## Model

The Exa.TrkX pipeline is a three-stage GNN-based algorithm:

1) **Graph building** (*Embedding stage*): This is done via a metric learning approach. A neural network tries to learn a mapping that minimizes the distance between points of the same track in the embedded space. A graph is then built in this embedded space using a fixed nearest-neighbor search.
2) **Graph size reduction** (*Filter stage*): The resulting graph is typically too large to be processed by a GNN. Therefore, a binary classifier reduces the graph size further by removing non-matching edges that are easy to identify.
3) **Edge classification** (*GNN stage*): Finally, a GNN is used to find the edges in the graph that belong to tracks. This is done by scoring each edge with a number between 0 and 1 and applying a corresponding *edge cut*.
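
The graph-building step of the embedding stage can be illustrated with a brute-force neighbor search. This is only a sketch under assumptions: it interprets the "fixed" search as a fixed search radius, and the function name `buildRadiusGraph`, the flattened `N x dim` layout of the embedded points, and the strict `<` comparison are all hypothetical choices, not the plugin's implementation (which relies on the trained embedding network and dedicated libraries).

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: connect every pair of points whose Euclidean distance
// in the embedded space is smaller than `radius`.
// `embedded` is assumed to be a flattened N x dim array of embedded points.
inline std::vector<std::pair<int, int>> buildRadiusGraph(
    const std::vector<float>& embedded, std::size_t dim, float radius) {
  assert(dim > 0 && embedded.size() % dim == 0);
  const std::size_t n = embedded.size() / dim;
  const float radius2 = radius * radius;

  std::vector<std::pair<int, int>> edges;
  // Brute force O(N^2) loop over all unordered pairs (i, j) with i < j.
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = i + 1; j < n; ++j) {
      float dist2 = 0.f;
      for (std::size_t k = 0; k < dim; ++k) {
        const float diff = embedded[i * dim + k] - embedded[j * dim + k];
        dist2 += diff * diff;
      }
      if (dist2 < radius2) {
        edges.emplace_back(static_cast<int>(i), static_cast<int>(j));
      }
    }
  }
  return edges;
}
```

A production implementation would replace the quadratic loop with a spatial index or a GPU-based search; the sketch only shows what kind of graph the later stages receive.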

Finally, the track candidates must be extracted by an algorithm that finds the *weakly connected components* in the graph depending on the *edge cut*.
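
As an illustration of this last step, the following sketch groups nodes into weakly connected components with a union-find structure. It is not the plugin's code: the plugin delegates this to a graph library, and the names `ScoredEdge` and `extractTrackCandidates`, as well as the choice to keep only edges with a score strictly above the cut, are assumptions made for the example.

```cpp
#include <cassert>
#include <map>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical edge type: the two node (spacepoint) indices and the score in
// [0, 1] assigned by the GNN stage.
struct ScoredEdge {
  int source;
  int target;
  float score;
};

// Find the representative of the component containing node i (union-find
// with path halving).
inline int findRoot(std::vector<int>& parent, int i) {
  while (parent[i] != i) {
    parent[i] = parent[parent[i]];
    i = parent[i];
  }
  return i;
}

// Return the weakly connected components of the graph formed by all edges
// whose score exceeds `edgeCut` (assumption: strict comparison). Edge
// direction is ignored. Nodes without any surviving edge are dropped.
inline std::vector<std::vector<int>> extractTrackCandidates(
    int numNodes, const std::vector<ScoredEdge>& edges, float edgeCut) {
  std::vector<int> parent(numNodes);
  std::iota(parent.begin(), parent.end(), 0);
  std::vector<bool> connected(numNodes, false);

  for (const ScoredEdge& e : edges) {
    assert(e.source >= 0 && e.source < numNodes);
    assert(e.target >= 0 && e.target < numNodes);
    if (e.score <= edgeCut) {
      continue;  // edge removed by the edge cut
    }
    const int rootSource = findRoot(parent, e.source);
    const int rootTarget = findRoot(parent, e.target);
    parent[rootSource] = rootTarget;  // merge the two components
    connected[e.source] = true;
    connected[e.target] = true;
  }

  // Collect the members of each component, keyed by its representative.
  std::map<int, std::vector<int>> components;
  for (int i = 0; i < numNodes; ++i) {
    if (connected[i]) {
      components[findRoot(parent, i)].push_back(i);
    }
  }

  std::vector<std::vector<int>> candidates;
  for (auto& entry : components) {
    candidates.push_back(std::move(entry.second));
  }
  return candidates;
}
```

The returned vectors contain node indices; translating them into spacepoint IDs would be a lookup into the ID vector supplied by the caller.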

## Implementation

### Neural network backend

Currently there are two backends available:

- The **TorchScript** backend requires, in addition to *libtorch*, *TorchScatter*, which implements the scatter operators used in the graph neural network. The corresponding class is {class}`Acts::ExaTrkXTrackFindingTorch`.
- The **ONNX** backend is currently not as well maintained as the TorchScript backend. The main reason is that some scatter operators are not available in ONNX, so the graph neural network cannot be exported correctly to the `.onnx` format. The corresponding class is {class}`Acts::ExaTrkXTrackFindingOnnx`.

### Graph building

The two backends currently use different libraries for graph building.

- The TorchScript backend uses *Boost.Graph* for graph building.
- The ONNX backend uses the *cugraph* library for graph building.

## API and examples integration

The interface of the backends is defined by {class}`Acts::ExaTrkXTrackFindingBase` which is also used by `ActsExamples::TrackFindingAlgorithmExaTrkX`. The inference can be called with the `getTracks` function:

```{doxygenclass} Acts::ExaTrkXTrackFindingBase
---
outline:
members: getTracks
---
```

This function takes the input data as a `std::vector<double>` (e.g., a flattened $N \times 3$ array in cylindrical coordinates like $[r_0, \varphi_0, z_0, r_1, \dots, \varphi_{N-1}, z_{N-1}]$), together with the corresponding spacepoint IDs as a `std::vector<int>`. It then fills a `std::vector<std::vector<int>>` with the found track candidates, expressed in terms of the provided spacepoint IDs. Logging and timing measurements can be enabled with the remaining arguments. The hyperparameters of the models are defined in the `Config` member structs.

:::{note}
Any kind of preprocessing (scaling, ...) of the input values must be done before passing them to the inference.
:::
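
A minimal sketch of such a preprocessing step is shown below: it packs spacepoints into the flattened $N \times 3$ layout described above and scales each feature. The types `CylindricalPoint` and `FeatureScales` and the function `flattenAndScale` are hypothetical and not part of the plugin, and the appropriate scale factors depend on how the models were trained.

```cpp
#include <cassert>
#include <vector>

// Hypothetical spacepoint in cylindrical coordinates.
struct CylindricalPoint {
  double r;
  double phi;
  double z;
};

// Hypothetical per-feature scale factors; each feature is divided by its
// scale before being handed to the inference.
struct FeatureScales {
  double r;
  double phi;
  double z;
};

// Pack the points into a flattened N x 3 array
// [r_0, phi_0, z_0, r_1, phi_1, z_1, ...] and apply the scaling.
inline std::vector<double> flattenAndScale(
    const std::vector<CylindricalPoint>& points, const FeatureScales& scales) {
  assert(scales.r != 0. && scales.phi != 0. && scales.z != 0.);

  std::vector<double> flat;
  flat.reserve(3 * points.size());
  for (const CylindricalPoint& p : points) {
    flat.push_back(p.r / scales.r);
    flat.push_back(p.phi / scales.phi);
    flat.push_back(p.z / scales.z);
  }
  return flat;
}
```

The resulting vector has the layout expected for the input values, and the spacepoint ID vector passed alongside it should list the IDs in the same order as `points`.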

See [here](https://github.com/acts-project/acts/blob/main/Examples/Scripts/Python/exatrkx.py) for the corresponding Python example.

## Resources

* Talk by *Daniel Murnane* at the [Connecting the Dots 2020](https://indico.cern.ch/event/831165/contributions/3717124/attachments/2024241/3385587/GNNs_for_Track_Finding.pdf)
* Talk by *Daniel Murnane* at the [vCHEP 2021](https://indico.cern.ch/event/948465/contributions/4323753/attachments/2246789/3810686/Physics%20and%20Computing%20Performance%20of%20the%20ExaTrkX%20TrackML%20Pipeline.pdf)
* Talk by *Alina Lazar* at the [ACAT 2021](https://indico.cern.ch/event/855454/contributions/4605079/attachments/2357191/4022841/ExaTrkX%20Inference%20-%20ACAT21%20v7.pdf)
* Talk by *Benjamin Huth* at the [ICHEP 2022](https://agenda.infn.it/event/28874/contributions/169199/attachments/94163/128944/slides_benjamin_huth_exatkrkx_acts.pdf)
