-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into unit-test-geometric-digitization
- Loading branch information
Showing
29 changed files
with
624 additions
and
149 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
add_library( | ||
ActsExamplesFrameworkML SHARED | ||
src/NeuralCalibrator.cpp | ||
) | ||
|
||
target_include_directories( | ||
ActsExamplesFrameworkML | ||
PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) | ||
|
||
target_link_libraries( | ||
ActsExamplesFrameworkML | ||
PUBLIC ActsExamplesFramework ActsPluginOnnx | ||
) | ||
|
||
install( | ||
TARGETS ActsExamplesFrameworkML | ||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) | ||
|
||
install( | ||
DIRECTORY include/ActsExamples | ||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) |
72 changes: 72 additions & 0 deletions
72
Examples/Framework/ML/include/ActsExamples/EventData/NeuralCalibrator.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2023 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#pragma once | ||
|
||
#include <Acts/Plugins/Onnx/OnnxRuntimeBase.hpp> | ||
#include <ActsExamples/EventData/MeasurementCalibration.hpp> | ||
|
||
#include <filesystem> | ||
|
||
namespace ActsExamples { | ||
|
||
class NeuralCalibrator : public MeasurementCalibrator { | ||
public: | ||
/// Measurement position calibration based on mixture density network | ||
/// (MDN) model. The model takes as input: | ||
/// | ||
/// - A 7x7 charge matrix centered on the center pixel of the cluster; | ||
/// - The volume and layer identifiers from | ||
/// the GeometryIdentifier of the containing surface; | ||
/// - The bound phi and theta angles of the predicted track state; | ||
/// - The initial estimated position | ||
/// - The initial estimated variance | ||
/// | ||
/// Given these inputs, a mixture density network estimates | ||
/// the parameters of a gaussian mixture model: | ||
/// | ||
/// P(Y|X) = \sum_i P(Prior_i) N(Y|Mean_i(X), Variance_i(X)) | ||
/// | ||
/// These are translated to single position + variance estimate by | ||
/// taking the most probable value based on the estimated priors. | ||
/// The measurements are assumed to be 2-dimensional. | ||
/// | ||
/// This class implements the MeasurementCalibrator interface, and | ||
/// therefore internally computes the network input and runs the | ||
/// inference engine itself. | ||
/// | ||
/// @param [in] modelPath The path to the .onnx model file | ||
/// @param [in] nComponent The number of components in the gaussian mixture | ||
/// @param [in] volumes The volume ids for which to apply the calibration | ||
NeuralCalibrator(const std::filesystem::path& modelPath, | ||
size_t nComponents = 1, | ||
std::vector<size_t> volumeIds = {7, 8, 9}); | ||
|
||
/// The MeasurementCalibrator interface methods | ||
void calibrate( | ||
const MeasurementContainer& measurements, | ||
const ClusterContainer* clusters, const Acts::GeometryContext& /*gctx*/, | ||
Acts::MultiTrajectory<Acts::VectorMultiTrajectory>::TrackStateProxy& | ||
trackState) const override; | ||
|
||
bool needsClusters() const override { return true; } | ||
|
||
private: | ||
Ort::Env m_env; | ||
Acts::OnnxRuntimeBase m_model; | ||
size_t m_nComponents; | ||
size_t m_nInputs = | ||
57; // TODO make this configurable? e.g. for changing matrix size? | ||
|
||
// TODO: this should probably be handled outside of the calibrator, | ||
// by setting up a GeometryHierarchyMap<MeasurementCalibrator> | ||
std::vector<size_t> m_volumeIds; | ||
PassThroughCalibrator m_fallback; | ||
}; | ||
|
||
} // namespace ActsExamples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2023 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#include <ActsExamples/EventData/NeuralCalibrator.hpp> | ||
|
||
#include <TFile.h> | ||
|
||
namespace detail { | ||
|
||
template <typename Array> | ||
size_t fillChargeMatrix(Array& arr, const ActsExamples::Cluster& cluster, | ||
size_t size0 = 7u, size_t size1 = 7u) { | ||
// First, rescale the activations to sum to unity. This promotes | ||
// numerical stability in the index computation | ||
double totalAct = 0; | ||
for (const ActsExamples::Cluster::Cell& cell : cluster.channels) { | ||
totalAct += cell.activation; | ||
} | ||
std::vector<double> weights; | ||
for (const ActsExamples::Cluster::Cell& cell : cluster.channels) { | ||
weights.push_back(cell.activation / totalAct); | ||
} | ||
|
||
double acc0 = 0; | ||
double acc1 = 0; | ||
for (size_t i = 0; i < cluster.channels.size(); i++) { | ||
acc0 += cluster.channels.at(i).bin[0] * weights.at(i); | ||
acc1 += cluster.channels.at(i).bin[1] * weights.at(i); | ||
} | ||
|
||
// By convention, put the center pixel in the middle cell. | ||
// Achieved by translating the cluster --> compute the offsets | ||
int offset0 = static_cast<int>(acc0) - size0 / 2; | ||
int offset1 = static_cast<int>(acc1) - size1 / 2; | ||
|
||
// Zero the charge matrix first, to guard against leftovers | ||
arr = Eigen::ArrayXXf::Zero(1, size0 * size1); | ||
// Fill the matrix | ||
for (const ActsExamples::Cluster::Cell& cell : cluster.channels) { | ||
// Translate each pixel | ||
int iMat = cell.bin[0] - offset0; | ||
int jMat = cell.bin[1] - offset1; | ||
if (iMat >= 0 && iMat < (int)size0 && jMat >= 0 && jMat < (int)size1) { | ||
typename Array::Index index = iMat * size0 + jMat; | ||
if (index < arr.size()) { | ||
arr(index) = cell.activation; | ||
} | ||
} | ||
} | ||
return size0 * size1; | ||
} | ||
|
||
} // namespace detail | ||
|
||
ActsExamples::NeuralCalibrator::NeuralCalibrator( | ||
const std::filesystem::path& modelPath, size_t nComponents, | ||
std::vector<size_t> volumeIds) | ||
: m_env(ORT_LOGGING_LEVEL_WARNING, "NeuralCalibrator"), | ||
m_model(m_env, modelPath.c_str()), | ||
m_nComponents{nComponents}, | ||
m_volumeIds{std::move(volumeIds)} {} | ||
|
||
void ActsExamples::NeuralCalibrator::calibrate( | ||
const MeasurementContainer& measurements, const ClusterContainer* clusters, | ||
const Acts::GeometryContext& gctx, | ||
Acts::MultiTrajectory<Acts::VectorMultiTrajectory>::TrackStateProxy& | ||
trackState) const { | ||
Acts::SourceLink usl = trackState.getUncalibratedSourceLink(); | ||
const IndexSourceLink& sourceLink = usl.get<IndexSourceLink>(); | ||
assert((sourceLink.index() < measurements.size()) and | ||
"Source link index is outside the container bounds"); | ||
|
||
if (std::find(m_volumeIds.begin(), m_volumeIds.end(), | ||
sourceLink.geometryId().volume()) == m_volumeIds.end()) { | ||
m_fallback.calibrate(measurements, clusters, gctx, trackState); | ||
return; | ||
} | ||
|
||
Acts::NetworkBatchInput inputBatch(1, m_nInputs); | ||
auto input = inputBatch(0, Eigen::all); | ||
|
||
// TODO: Matrix size should be configurable perhaps? | ||
size_t matSize0 = 7u; | ||
size_t matSize1 = 7u; | ||
size_t iInput = ::detail::fillChargeMatrix( | ||
input, (*clusters)[sourceLink.index()], matSize0, matSize1); | ||
|
||
input[iInput++] = sourceLink.geometryId().volume(); | ||
input[iInput++] = sourceLink.geometryId().layer(); | ||
input[iInput++] = trackState.parameters()[Acts::eBoundPhi]; | ||
input[iInput++] = trackState.parameters()[Acts::eBoundTheta]; | ||
|
||
std::visit( | ||
[&](const auto& measurement) { | ||
auto E = measurement.expander(); | ||
auto P = measurement.projector(); | ||
Acts::ActsVector<Acts::eBoundSize> fpar = E * measurement.parameters(); | ||
Acts::ActsSymMatrix<Acts::eBoundSize> fcov = | ||
E * measurement.covariance() * E.transpose(); | ||
|
||
input[iInput++] = fpar[Acts::eBoundLoc0]; | ||
input[iInput++] = fpar[Acts::eBoundLoc1]; | ||
input[iInput++] = fcov(Acts::eBoundLoc0, Acts::eBoundLoc0); | ||
input[iInput++] = fcov(Acts::eBoundLoc1, Acts::eBoundLoc1); | ||
if (iInput != m_nInputs) { | ||
throw std::runtime_error("Expected input size of " + | ||
std::to_string(m_nInputs) + | ||
", got: " + std::to_string(iInput)); | ||
} | ||
|
||
// Input is a single row, hence .front() | ||
std::vector<float> output = | ||
m_model.runONNXInference(inputBatch).front(); | ||
// Assuming 2-D measurements, the expected params structure is: | ||
// [ 0, nComponent[ --> priors | ||
// [ nComponent, 3*nComponent[ --> means | ||
// [3*nComponent, 5*nComponent[ --> variances | ||
size_t nParams = 5 * m_nComponents; | ||
if (output.size() != nParams) { | ||
throw std::runtime_error( | ||
"Got output vector of size " + std::to_string(output.size()) + | ||
", expected size " + std::to_string(nParams)); | ||
} | ||
|
||
// Most probable value computation of mixture density | ||
size_t iMax = 0; | ||
if (m_nComponents > 1) { | ||
iMax = std::distance( | ||
output.begin(), | ||
std::max_element(output.begin(), output.begin() + m_nComponents)); | ||
} | ||
size_t iLoc0 = m_nComponents + iMax * 2; | ||
size_t iVar0 = 3 * m_nComponents + iMax * 2; | ||
|
||
fpar[Acts::eBoundLoc0] = output[iLoc0]; | ||
fpar[Acts::eBoundLoc1] = output[iLoc0 + 1]; | ||
fcov(Acts::eBoundLoc0, Acts::eBoundLoc0) = output[iVar0]; | ||
fcov(Acts::eBoundLoc1, Acts::eBoundLoc1) = output[iVar0 + 1]; | ||
|
||
constexpr size_t kSize = | ||
std::remove_reference_t<decltype(measurement)>::size(); | ||
std::array<Acts::BoundIndices, kSize> indices = measurement.indices(); | ||
Acts::ActsVector<kSize> cpar = P * fpar; | ||
Acts::ActsSymMatrix<kSize> ccov = P * fcov * P.transpose(); | ||
|
||
Acts::SourceLink sl{sourceLink.geometryId(), sourceLink}; | ||
|
||
Acts::Measurement<Acts::BoundIndices, kSize> calibrated( | ||
std::move(sl), indices, cpar, ccov); | ||
|
||
trackState.allocateCalibrated(calibrated.size()); | ||
trackState.setCalibrated(calibrated); | ||
}, | ||
(measurements)[sourceLink.index()]); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.