Skip to content

Commit

Permalink
refactor!: Add CalibrationContext to calibrator signature (#2354)
Browse files Browse the repository at this point in the history
This is currently unused by the calibrators, but should allow the calibrator to become
conditions-aware.

Blocked by:
- #2352

Closes #2274
  • Loading branch information
paulgessinger committed Aug 25, 2023
1 parent 32e0119 commit 77ca0e8
Show file tree
Hide file tree
Showing 19 changed files with 87 additions and 35 deletions.
6 changes: 5 additions & 1 deletion Core/include/Acts/TrackFinding/CombinatorialKalmanFilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ class CombinatorialKalmanFilter {
/// Whether to run smoothing to get fitted parameter
bool smoothing = true;

/// Calibration context for the finding run
const CalibrationContext* calibrationContext{nullptr};

/// @brief CombinatorialKalmanFilter actor operation
///
/// @tparam propagator_state_t Type of the Propagagor state
Expand Down Expand Up @@ -884,7 +887,7 @@ class CombinatorialKalmanFilter {
ts.setReferenceSurface(boundParams.referenceSurface().getSharedPtr());

// now calibrate the track state
m_extensions.calibrator(gctx, sourceLink, ts);
m_extensions.calibrator(gctx, calibrationContext, sourceLink, ts);

result.trackStateCandidates.push_back(ts);
}
Expand Down Expand Up @@ -1356,6 +1359,7 @@ class CombinatorialKalmanFilter {
combKalmanActor.actorLogger = m_actorLogger.get();
combKalmanActor.updaterLogger = m_updaterLogger.get();
combKalmanActor.smootherLogger = m_smootherLogger.get();
combKalmanActor.calibrationContext = &tfOptions.calibrationContext.get();

// copy source link accessor, calibrator and measurement selector
combKalmanActor.m_sourcelinkAccessor = tfOptions.sourcelinkAccessor;
Expand Down
13 changes: 9 additions & 4 deletions Core/include/Acts/TrackFitting/GlobalChiSquareFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ struct Gx2FitterExtensions {
typename MultiTrajectory<traj_t>::ConstTrackStateProxy;
using Parameters = typename TrackStateProxy::Parameters;

using Calibrator = Delegate<void(const GeometryContext&, const SourceLink&,
TrackStateProxy)>;
using Calibrator =
Delegate<void(const GeometryContext&, const CalibrationContext&,
const SourceLink&, TrackStateProxy)>;

using Updater = Delegate<Result<void>(const GeometryContext&, TrackStateProxy,
Direction, const Logger&)>;
Expand Down Expand Up @@ -287,6 +288,9 @@ class Gx2Fitter {
/// The Surface being
SurfaceReached targetReached;

/// Calibration context for the fit
const CalibrationContext* calibrationContext{nullptr};

/// @brief Gx2f actor operation
///
/// @tparam propagator_state_t is the type of Propagator state
Expand Down Expand Up @@ -363,8 +367,8 @@ class Gx2Fitter {

// We have predicted parameters, so calibrate the uncalibrated input
// measurement
extensions.calibrator(state.geoContext, sourcelink_it->second,
trackStateProxy);
extensions.calibrator(state.geoContext, *calibrationContext,
sourcelink_it->second, trackStateProxy);

const size_t measdimPlaceholder = 2;
auto measurement =
Expand Down Expand Up @@ -504,6 +508,7 @@ class Gx2Fitter {
auto& gx2fActor = propagatorOptions.actionList.template get<GX2FActor>();
gx2fActor.inputMeasurements = &inputMeasurements;
gx2fActor.extensions = gx2fOptions.extensions;
gx2fActor.calibrationContext = &gx2fOptions.calibrationContext.get();
gx2fActor.actorLogger = m_actorLogger.get();

typename propagator_t::template action_list_t_result_t<
Expand Down
5 changes: 3 additions & 2 deletions Core/include/Acts/TrackFitting/GsfOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ struct GsfExtensions {
using TrackStateProxy = typename traj_t::TrackStateProxy;
using ConstTrackStateProxy = typename traj_t::ConstTrackStateProxy;

using Calibrator = Delegate<void(const GeometryContext&, const SourceLink&,
TrackStateProxy)>;
using Calibrator =
Delegate<void(const GeometryContext&, const CalibrationContext&,
const SourceLink&, TrackStateProxy)>;

using Updater = Delegate<Result<void>(const GeometryContext&, TrackStateProxy,
Direction, const Logger&)>;
Expand Down
18 changes: 12 additions & 6 deletions Core/include/Acts/TrackFitting/KalmanFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ struct KalmanFitterExtensions {
using ConstTrackStateProxy = typename traj_t::ConstTrackStateProxy;
using Parameters = typename TrackStateProxy::Parameters;

using Calibrator = Delegate<void(const GeometryContext&, const SourceLink&,
TrackStateProxy)>;
using Calibrator =
Delegate<void(const GeometryContext&, const CalibrationContext&,
const SourceLink&, TrackStateProxy)>;

using Smoother = Delegate<Result<void>(const GeometryContext&, traj_t&,
size_t, const Logger&)>;
Expand Down Expand Up @@ -324,6 +325,9 @@ class KalmanFitter {
/// The Surface being
SurfaceReached targetReached;

/// Calibration context for the fit
const CalibrationContext* calibrationContext{nullptr};

/// @brief Kalman actor operation
///
/// @tparam propagator_state_t is the type of Propagagor state
Expand Down Expand Up @@ -587,8 +591,9 @@ class KalmanFitter {
// do the kalman update (no need to perform covTransport here, hence no
// point in performing globalToLocal correction)
auto trackStateProxyRes = detail::kalmanHandleMeasurement(
state, stepper, extensions, *surface, sourcelink_it->second,
*result.fittedStates, result.lastTrackIndex, false, logger());
*calibrationContext, state, stepper, extensions, *surface,
sourcelink_it->second, *result.fittedStates, result.lastTrackIndex,
false, logger());

if (!trackStateProxyRes.ok()) {
return trackStateProxyRes.error();
Expand Down Expand Up @@ -730,8 +735,8 @@ class KalmanFitter {

// We have predicted parameters, so calibrate the uncalibrated input
// measuerement
extensions.calibrator(state.geoContext, sourcelink_it->second,
trackStateProxy);
extensions.calibrator(state.geoContext, *calibrationContext,
sourcelink_it->second, trackStateProxy);

// If the update is successful, set covariance and
auto updateRes = extensions.updater(state.geoContext, trackStateProxy,
Expand Down Expand Up @@ -1075,6 +1080,7 @@ class KalmanFitter {
kalmanActor.reversedFilteringCovarianceScaling =
kfOptions.reversedFilteringCovarianceScaling;
kalmanActor.freeToBoundCorrection = kfOptions.freeToBoundCorrection;
kalmanActor.calibrationContext = &kfOptions.calibrationContext.get();
kalmanActor.extensions = kfOptions.extensions;
kalmanActor.actorLogger = m_actorLogger.get();

Expand Down
10 changes: 8 additions & 2 deletions Core/include/Acts/TrackFitting/detail/GsfActor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ struct GsfActor {
MixtureReductionMethod reductionMethod = MixtureReductionMethod::eMaxWeight;

const Logger* logger{nullptr};

/// Calibration context for the fit
const CalibrationContext* calibrationContext{nullptr};

} m_cfg;

const Logger& logger() const { return *m_cfg.logger; }
Expand Down Expand Up @@ -567,8 +571,9 @@ struct GsfActor {
const auto& singleStepper = cmp.singleStepper(stepper);

auto trackStateProxyRes = detail::kalmanHandleMeasurement(
singleState, singleStepper, m_cfg.extensions, surface, source_link,
tmpStates.traj, MultiTrajectoryTraits::kInvalid, false, logger());
*m_cfg.calibrationContext, singleState, singleStepper,
m_cfg.extensions, surface, source_link, tmpStates.traj,
MultiTrajectoryTraits::kInvalid, false, logger());

if (!trackStateProxyRes.ok()) {
return trackStateProxyRes.error();
Expand Down Expand Up @@ -799,6 +804,7 @@ struct GsfActor {
m_cfg.disableAllMaterialHandling = options.disableAllMaterialHandling;
m_cfg.weightCutoff = options.weightCutoff;
m_cfg.reductionMethod = options.stateReductionMethod;
m_cfg.calibrationContext = &options.calibrationContext.get();
}
};

Expand Down
10 changes: 6 additions & 4 deletions Core/include/Acts/TrackFitting/detail/KalmanUpdateHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
#include "Acts/Surfaces/Surface.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
#include "Acts/Utilities/Result.hpp"

namespace Acts {
Expand All @@ -35,9 +36,9 @@ namespace detail {
template <typename propagator_state_t, typename stepper_t,
typename extensions_t, typename traj_t>
auto kalmanHandleMeasurement(
propagator_state_t &state, const stepper_t &stepper,
const extensions_t &extensions, const Surface &surface,
const SourceLink &source_link, traj_t &fittedStates,
const CalibrationContext &calibrationContext, propagator_state_t &state,
const stepper_t &stepper, const extensions_t &extensions,
const Surface &surface, const SourceLink &source_link, traj_t &fittedStates,
const size_t lastTrackIndex, bool doCovTransport, const Logger &logger,
const FreeToBoundCorrection &freeToBoundCorrection = FreeToBoundCorrection(
false)) -> Result<typename traj_t::TrackStateProxy> {
Expand Down Expand Up @@ -70,7 +71,8 @@ auto kalmanHandleMeasurement(

// We have predicted parameters, so calibrate the uncalibrated input
// measuerement
extensions.calibrator(state.geoContext, source_link, trackStateProxy);
extensions.calibrator(state.geoContext, calibrationContext, source_link,
trackStateProxy);

// Get and set the type flags
auto typeFlags = trackStateProxy.typeFlags();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

#pragma once

#include "Acts/Definitions/Direction.hpp"
#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/Geometry/GeometryContext.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "Acts/Utilities/Result.hpp"
#include "Acts/Utilities/TypeTraits.hpp"
Expand All @@ -19,6 +21,7 @@ namespace Acts {

template <typename traj_t>
void voidKalmanCalibrator(const GeometryContext& /*gctx*/,
const CalibrationContext& /*cctx*/,
const SourceLink& /*sourceLink*/,
typename traj_t::TrackStateProxy /*trackState*/) {
throw std::runtime_error{"VoidKalmanCalibrator should not ever execute"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "Acts/Geometry/GeometryContext.hpp"
#include "Acts/Geometry/GeometryIdentifier.hpp"
#include "Acts/Surfaces/Surface.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"

namespace Acts {
class ConstVectorMultiTrajectory;
Expand All @@ -35,6 +36,7 @@ struct RefittingCalibrator {
};

void calibrate(const Acts::GeometryContext& gctx,
const Acts::CalibrationContext& cctx,
const Acts::SourceLink& sourceLink, Proxy trackState) const;
};

Expand Down
2 changes: 2 additions & 0 deletions Examples/Algorithms/TrackFitting/src/RefittingCalibrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
#include "Acts/Definitions/Algebra.hpp"
#include "Acts/EventData/MeasurementHelpers.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"

namespace ActsExamples {

void RefittingCalibrator::calibrate(const Acts::GeometryContext& /*gctx*/,
const Acts::CalibrationContext& /*cctx*/,
const Acts::SourceLink& sourceLink,
Proxy trackState) const {
const auto sl = sourceLink.get<RefittingSourceLink>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include "Acts/EventData/SourceLink.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
#include <Acts/Plugins/Onnx/OnnxRuntimeBase.hpp>
#include <ActsExamples/EventData/MeasurementCalibration.hpp>

Expand Down Expand Up @@ -52,7 +53,7 @@ class NeuralCalibrator : public MeasurementCalibrator {
void calibrate(
const MeasurementContainer& measurements,
const ClusterContainer* clusters, const Acts::GeometryContext& gctx,
const Acts::SourceLink& sourceLink,
const Acts::CalibrationContext& cctx, const Acts::SourceLink& sourceLink,
Acts::MultiTrajectory<Acts::VectorMultiTrajectory>::TrackStateProxy&
trackState) const override;

Expand Down
7 changes: 5 additions & 2 deletions Examples/Framework/ML/src/NeuralCalibrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/EventData/SourceLink.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
#include <ActsExamples/EventData/NeuralCalibrator.hpp>

#include <TFile.h>
Expand Down Expand Up @@ -68,7 +69,8 @@ ActsExamples::NeuralCalibrator::NeuralCalibrator(

void ActsExamples::NeuralCalibrator::calibrate(
const MeasurementContainer& measurements, const ClusterContainer* clusters,
const Acts::GeometryContext& gctx, const Acts::SourceLink& sourceLink,
const Acts::GeometryContext& gctx, const Acts::CalibrationContext& cctx,
const Acts::SourceLink& sourceLink,
Acts::MultiTrajectory<Acts::VectorMultiTrajectory>::TrackStateProxy&
trackState) const {
trackState.setUncalibratedSourceLink(sourceLink);
Expand All @@ -78,7 +80,8 @@ void ActsExamples::NeuralCalibrator::calibrate(

if (std::find(m_volumeIds.begin(), m_volumeIds.end(),
idxSourceLink.geometryId().volume()) == m_volumeIds.end()) {
m_fallback.calibrate(measurements, clusters, gctx, sourceLink, trackState);
m_fallback.calibrate(measurements, clusters, gctx, cctx, sourceLink,
trackState);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/EventData/VectorMultiTrajectory.hpp"
#include "Acts/Geometry/GeometryContext.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include <ActsExamples/EventData/Measurement.hpp>
Expand All @@ -30,7 +31,7 @@ class MeasurementCalibrator {
virtual void calibrate(
const MeasurementContainer& measurements,
const ClusterContainer* clusters, const Acts::GeometryContext& gctx,
const Acts::SourceLink& sourceLink,
const Acts::CalibrationContext& cctx, const Acts::SourceLink& sourceLink,
Acts::VectorMultiTrajectory::TrackStateProxy& trackState) const = 0;

virtual ~MeasurementCalibrator() = default;
Expand All @@ -48,7 +49,7 @@ class PassThroughCalibrator : public MeasurementCalibrator {
void calibrate(
const MeasurementContainer& measurements,
const ClusterContainer* clusters, const Acts::GeometryContext& gctx,
const Acts::SourceLink& sourceLink,
const Acts::CalibrationContext& cctx, const Acts::SourceLink& sourceLink,
Acts::VectorMultiTrajectory::TrackStateProxy& trackState) const override;
};

Expand All @@ -63,6 +64,7 @@ class MeasurementCalibratorAdapter {
MeasurementCalibratorAdapter() = delete;

void calibrate(const Acts::GeometryContext& gctx,
const Acts::CalibrationContext& cctx,
const Acts::SourceLink& sourceLink,
Acts::VectorMultiTrajectory::TrackStateProxy trackState) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class ScalingCalibrator : public MeasurementCalibrator {
void calibrate(
const MeasurementContainer& measurements,
const ClusterContainer* clusters, const Acts::GeometryContext& gctx,
const Acts::SourceLink& sourceLink,
const Acts::CalibrationContext& cctx, const Acts::SourceLink& sourceLink,
Acts::VectorMultiTrajectory::TrackStateProxy& trackState) const override;

bool needsClusters() const override { return true; }
Expand Down
8 changes: 5 additions & 3 deletions Examples/Framework/src/EventData/MeasurementCalibration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class VectorMultiTrajectory;
void ActsExamples::PassThroughCalibrator::calibrate(
const MeasurementContainer& measurements,
const ClusterContainer* /*clusters*/, const Acts::GeometryContext& /*gctx*/,
const Acts::CalibrationContext& /*cctx*/,
const Acts::SourceLink& sourceLink,
Acts::VectorMultiTrajectory::TrackStateProxy& trackState) const {
trackState.setUncalibratedSourceLink(sourceLink);
Expand All @@ -45,8 +46,9 @@ ActsExamples::MeasurementCalibratorAdapter::MeasurementCalibratorAdapter(
m_clusters{clusters} {}

void ActsExamples::MeasurementCalibratorAdapter::calibrate(
const Acts::GeometryContext& gctx, const Acts::SourceLink& sourceLink,
const Acts::GeometryContext& gctx, const Acts::CalibrationContext& cctx,
const Acts::SourceLink& sourceLink,
Acts::VectorMultiTrajectory::TrackStateProxy trackState) const {
return m_calibrator.calibrate(m_measurements, m_clusters, gctx, sourceLink,
trackState);
return m_calibrator.calibrate(m_measurements, m_clusters, gctx, cctx,
sourceLink, trackState);
}
5 changes: 4 additions & 1 deletion Examples/Framework/src/EventData/ScalingCalibrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/Geometry/GeometryContext.hpp"
#include "Acts/Geometry/GeometryIdentifier.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/Measurement.hpp"
Expand Down Expand Up @@ -129,7 +130,9 @@ ActsExamples::ScalingCalibrator::ScalingCalibrator(

void ActsExamples::ScalingCalibrator::calibrate(
const MeasurementContainer& measurements, const ClusterContainer* clusters,
const Acts::GeometryContext& /*gctx*/, const Acts::SourceLink& sourceLink,
const Acts::GeometryContext& /*gctx*/,
const Acts::CalibrationContext& /*cctx*/,
const Acts::SourceLink& sourceLink,
Acts::VectorMultiTrajectory::TrackStateProxy& trackState) const {
trackState.setUncalibratedSourceLink(sourceLink);
const IndexSourceLink& idxSourceLink = sourceLink.get<IndexSourceLink>();
Expand Down

0 comments on commit 77ca0e8

Please sign in to comment.