Skip to content

Commit

Permalink
feat: Exa.TrkX abstract pipeline & metric hook (#2392)
Browse files Browse the repository at this point in the history
This Does the following:
* Abstracts the Exa.TrkX pipeline from the examples algorithm to the plugin. Exposes the pipeline to python, so it is possible to run and test the pipeline without the examples framework (e.g., directly with `*.pyg` or `*.csv` files):

```python
spacepoints = list(np.random.uniform((100,3)).flatten())
spacepointIDs = list(np.arange(100))

pipeline = acts.examples.Pipeline(emb, [flt,gnn], trk, acts.logging.VERBOSE)
pipeline.run(spacepoints, spacepointIDs)
```

* Adds a interface for a pipeline hook, that is invoked after each stage with `nodes` and `edges` as arguments.
* Adds a implementation of the hook that prints the edge-based metrics of the graph with respect to the truth graph
* Extends the examples algorithm so it can generate the truth graph and a metric-hook:

```
08:52:52    EdgeClassifi   VERBOSE   Memory (used / total) [in MB]: 22155.4 / 40377.2
08:52:52    MetricsHook    INFO      Metrics for total graph:
08:52:52    MetricsHook    INFO      Efficiency=0.579324, purity=0.319214
08:52:52    MetricsHook    INFO      Metrics for target graph (pT > 0.5 GeV, nHits >= 3):
08:52:52    MetricsHook    INFO      Efficiency=0.968275, purity=0.138659
08:52:52    EdgeClassifi   DEBUG     Start edge classification
```
  • Loading branch information
benjaminhuth committed Aug 30, 2023
1 parent 241c1e1 commit 5e57328
Show file tree
Hide file tree
Showing 11 changed files with 579 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@

#pragma once

#include "Acts/Definitions/Units.hpp"
#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
#include "ActsExamples/EventData/SimHit.hpp"
#include "ActsExamples/EventData/SimParticle.hpp"
#include "ActsExamples/EventData/SimSpacePoint.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"
Expand All @@ -35,6 +38,13 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
/// * cluster size in local y
std::string inputClusters;

/// Input simhits (Optional).
std::string inputSimHits;
/// Input measurement simhit map (Optional).
std::string inputParticles;
/// Input measurement simhit map (Optional).
std::string inputMeasurementSimhitsMap;

/// Output protoTracks collection.
std::string outputProtoTracks;

Expand All @@ -52,6 +62,10 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
float cellSumScale = 1.f;
float clusterXScale = 1.f;
float clusterYScale = 1.f;

/// Target graph properties
std::size_t targetMinHits = 3;
double targetMinPT = 500 * Acts::UnitConstants::MeV;
};

/// Constructor of the track finding algorithm
Expand All @@ -72,18 +86,23 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm {
const Config& config() const { return m_cfg; }

private:
std::vector<std::vector<int>> runPipeline(
std::vector<float>& inputValues, std::vector<int>& spacepointIDs) const;

// configuration
Config m_cfg;

Acts::ExaTrkXPipeline m_pipeline;

ReadDataHandle<SimSpacePointContainer> m_inputSpacePoints{this,
"InputSpacePoints"};
ReadDataHandle<ClusterContainer> m_inputClusters{this, "InputClusters"};

WriteDataHandle<ProtoTrackContainer> m_outputProtoTracks{this,
"OutputProtoTracks"};

// for truth graph
ReadDataHandle<SimHitContainer> m_inputSimHits{this, "InputSimHits"};
ReadDataHandle<SimParticleContainer> m_inputParticles{this, "InputParticles"};
ReadDataHandle<IndexMultimap<Index>> m_inputMeasurementMap{
this, "InputMeasurementMap"};
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"

#include "Acts/Definitions/Units.hpp"
#include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
#include "ActsExamples/EventData/Index.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/ProtoTrack.hpp"
Expand All @@ -16,35 +18,120 @@

#include <numeric>

using namespace ActsExamples;
using namespace Acts::UnitLiterals;

namespace {

class ExamplesEdmHook : public Acts::ExaTrkXHook {
double m_targetPT = 0.5_GeV;
std::size_t m_targetSize = 3;

std::unique_ptr<const Acts::Logger> m_logger;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;

const Acts::Logger& logger() const { return *m_logger; }

struct HitInfo {
std::size_t spacePointIndex;
int32_t hitIndex;
};

public:
ExamplesEdmHook(const SimSpacePointContainer& spacepoints,
const IndexMultimap<Index>& measHitMap,
const SimHitContainer& simhits,
const SimParticleContainer& particles,
std::size_t targetMinHits, double targetMinPT,
const Acts::Logger& logger)
: m_targetPT(targetMinPT),
m_targetSize(targetMinHits),
m_logger(logger.clone("MetricsHook")) {
// Associate tracks to graph, collect momentum
std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;

for (auto i = 0ul; i < spacepoints.size(); ++i) {
const auto measId = spacepoints[i]
.sourceLinks()[0]
.template get<IndexSourceLink>()
.index();

auto [a, b] = measHitMap.equal_range(measId);
for (auto it = a; it != b; ++it) {
const auto& hit = *simhits.nth(it->second);

tracks[hit.particleId()].push_back({i, hit.index()});
}
}

// Collect edges for truth graph and target graph
std::vector<int64_t> truthGraph;
std::vector<int64_t> targetGraph;

for (auto& [pid, track] : tracks) {
// Sort by hit index, so the edges are connected correctly
std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) {
return a.hitIndex < b.hitIndex;
});

auto found = particles.find(pid);
if (found == particles.end()) {
ACTS_WARNING("Did not find " << pid << ", skip track");
continue;
}

for (auto i = 0ul; i < track.size() - 1; ++i) {
truthGraph.push_back(track[i].spacePointIndex);
truthGraph.push_back(track[i + 1].spacePointIndex);

if (found->transverseMomentum() > m_targetPT &&
track.size() >= m_targetSize) {
targetGraph.push_back(track[i].spacePointIndex);
targetGraph.push_back(track[i + 1].spacePointIndex);
}
}
}

m_truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
truthGraph, logger.clone());
m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
targetGraph, logger.clone());
}

~ExamplesEdmHook(){};

void operator()(const std::any& nodes, const std::any& edges) const override {
ACTS_INFO("Metrics for total graph:");
(*m_truthGraphHook)(nodes, edges);
ACTS_INFO("Metrics for target graph (pT > "
<< m_targetPT / Acts::UnitConstants::GeV
<< " GeV, nHits >= " << m_targetSize << "):");
(*m_targetGraphHook)(nodes, edges);
}
};

} // namespace

ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(
Config config, Acts::Logging::Level level)
: ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
m_cfg(std::move(config)) {
m_cfg(std::move(config)),
m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
m_cfg.trackBuilder, logger().clone()) {
if (m_cfg.inputSpacePoints.empty()) {
throw std::invalid_argument("Missing spacepoint input collection");
}
if (m_cfg.outputProtoTracks.empty()) {
throw std::invalid_argument("Missing protoTrack output collection");
}
if (!m_cfg.graphConstructor) {
throw std::invalid_argument("Missing graph construction module");
}
if (!m_cfg.trackBuilder) {
throw std::invalid_argument("Missing track building module");
}
if (m_cfg.edgeClassifiers.empty() or
not std::all_of(m_cfg.edgeClassifiers.begin(),
m_cfg.edgeClassifiers.end(),
[](const auto& a) { return static_cast<bool>(a); })) {
throw std::invalid_argument("Missing graph construction module");
}

// Sanitizer run with dummy input to detect configuration issues
// TODO This would be quite helpful I think, but currently it does not work in
// general because the stages do not expose the number of node features.
// TODO This would be quite helpful I think, but currently it does not work
// in general because the stages do not expose the number of node features.
// However, this must be addressed anyways when we also want to allow to
// configure this more flexible with e.g. cluster information as input. So for
// now, we disable this.
// configure this more flexible with e.g. cluster information as input. So
// for now, we disable this.
#if 0
if( m_cfg.sanitize ) {
Eigen::VectorXf dummyInput = Eigen::VectorXf::Random(3 * 15);
Expand All @@ -59,23 +146,10 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(

m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);
}

std::vector<std::vector<int>>
ActsExamples::TrackFindingAlgorithmExaTrkX::runPipeline(
std::vector<float>& inputValues, std::vector<int>& spacepointIDs) const {
auto [nodes, edges] =
(*m_cfg.graphConstructor)(inputValues, spacepointIDs.size());
std::any edge_weights;

for (auto edgeClassifier : m_cfg.edgeClassifiers) {
auto [newNodes, newEdges, newWeights] = (*edgeClassifier)(nodes, edges);
nodes = newNodes;
edges = newEdges;
edge_weights = newWeights;
}

return (*m_cfg.trackBuilder)(nodes, edges, edge_weights, spacepointIDs);
m_inputSimHits.maybeInitialize(m_cfg.inputSimHits);
m_inputParticles.maybeInitialize(m_cfg.inputParticles);
m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap);
}

/// Allow access to features with nice names
Expand All @@ -92,7 +166,15 @@ enum feat : std::size_t {
ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
const ActsExamples::AlgorithmContext& ctx) const {
// Read input data
const auto& spacepoints = m_inputSpacePoints(ctx);
auto spacepoints = m_inputSpacePoints(ctx);

auto hook = std::make_unique<Acts::ExaTrkXHook>();
if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) {
hook = std::make_unique<ExamplesEdmHook>(
spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx),
m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT,
logger());
}

std::optional<ClusterContainer> clusters;
if (m_inputClusters.isInitialized()) {
Expand Down Expand Up @@ -151,7 +233,10 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
ACTS_DEBUG("Avg activation: " << sumActivation / sumCells);

// Run the pipeline
const auto trackCandidates = runPipeline(features, spacepointIDs);
const auto trackCandidates = m_pipeline.run(features, spacepointIDs, *hook);

ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
<< " candidates");

// Make the prototracks
std::vector<ProtoTrack> protoTracks;
Expand Down
48 changes: 44 additions & 4 deletions Examples/Python/src/ExaTrkXTrackFinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

#include "Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp"
#include "Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp"
#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
#include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
#include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
#include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
#include "Acts/Plugins/Python/Utilities.hpp"
#include "Acts/TrackFinding/MeasurementSelector.hpp"
#include "ActsExamples/TrackFinding/SeedingAlgorithm.hpp"
Expand All @@ -21,6 +23,7 @@

#include <memory>

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -166,10 +169,47 @@ void addExaTrkXTrackFinding(Context &ctx) {

ACTS_PYTHON_DECLARE_ALGORITHM(
ActsExamples::TrackFindingAlgorithmExaTrkX, mex,
"TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputClusters,
outputProtoTracks, graphConstructor, edgeClassifiers, trackBuilder,
rScale, phiScale, zScale, cellCountScale, cellSumScale, clusterXScale,
clusterYScale);
"TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits,
inputParticles, inputMeasurementSimhitsMap, outputProtoTracks,
graphConstructor, edgeClassifiers, trackBuilder, rScale, phiScale, zScale,
targetMinHits, targetMinPT);

{
auto cls =
py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
mex, "ExaTrkXHook");
}

{
using Class = Acts::TorchTruthGraphMetricsHook;

auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
mex, "TorchTruthGraphMetricsHook")
.def(py::init(
[](const std::vector<int64_t> &g, Logging::Level lvl) {
return std::make_shared<Class>(
g, getDefaultLogger("PipelineHook", lvl));
}));
}

{
using Class = Acts::ExaTrkXPipeline;

auto cls =
py::class_<Class, std::shared_ptr<Class>>(mex, "ExaTrkXPipeline")
.def(py::init(
[](std::shared_ptr<GraphConstructionBase> g,
std::vector<std::shared_ptr<EdgeClassificationBase>> e,
std::shared_ptr<TrackBuildingBase> t,
Logging::Level lvl) {
return std::make_shared<Class>(
g, e, t, getDefaultLogger("MetricLearning", lvl));
}),
py::arg("graphConstructor"), py::arg("edgeClassifiers"),
py::arg("trackBuilder"), py::arg("level"))
.def("run", &ExaTrkXPipeline::run, py::arg("features"),
py::arg("spacepoints"), py::arg("hook") = Acts::ExaTrkXHook{});
}
}

} // namespace Acts::Python
2 changes: 2 additions & 0 deletions Plugins/ExaTrkX/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(SOURCES
src/buildEdges.cpp
src/ExaTrkXPipeline.cpp
)

if(ACTS_EXATRKX_ENABLE_ONNX)
Expand All @@ -15,6 +16,7 @@ if(ACTS_EXATRKX_ENABLE_TORCH)
src/TorchEdgeClassifier.cpp
src/TorchMetricLearning.cpp
src/BoostTrackBuilding.cpp
src/TorchTruthGraphMetricsHook.cpp
)
endif()

Expand Down

0 comments on commit 5e57328

Please sign in to comment.