Skip to content

Commit

Permalink
feat: python bindings and truth tracking example for GX2F (#2512)
Browse files Browse the repository at this point in the history
This adds a basic python framework to the Global Chi Square Fitter (GX2F). It runs with the current GX2F-implementation and is created to test the GX2F further.

So far some of the pulls already go into the right direction. Note, that we cannot fit with B-Fields != 0.
![Canvas](https://github.com/acts-project/acts/assets/70842573/8a52aa7a-7fec-4137-ac8b-851b35cde6e3)
  • Loading branch information
AJPfleger committed Oct 23, 2023
1 parent 713ff52 commit 2ed5a93
Show file tree
Hide file tree
Showing 6 changed files with 415 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Examples/Algorithms/TrackFitting/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ add_library(
src/TrackFittingAlgorithm.cpp
src/KalmanFitterFunction.cpp
src/RefittingAlgorithm.cpp
src/GsfFitterFunction.cpp)
src/GsfFitterFunction.cpp
src/GlobalChiSquareFitterFunction.cpp)
target_include_directories(
ActsExamplesTrackFitting
PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,21 @@ std::shared_ptr<TrackFitterFunction> makeGsfFitterFunction(
bool abortOnError, bool disableAllMaterialHandling,
const Acts::Logger& logger);

/// Makes a fitter function object for the Global Chi Square Fitter (GX2F)
///
/// @param trackingGeometry the trackingGeometry for the propagator
/// @param magneticField the magnetic field for the propagator
/// @param multipleScattering bool
/// @param energyLoss bool
/// @param freeToBoundCorrection bool
/// @param logger a logger instance
std::shared_ptr<TrackFitterFunction> makeGlobalChiSquareFitterFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
bool multipleScattering = true, bool energyLoss = true,
Acts::FreeToBoundCorrection freeToBoundCorrection =
Acts::FreeToBoundCorrection(),
const Acts::Logger& logger = *Acts::getDefaultLogger("Gx2f",
Acts::Logging::INFO));

} // namespace ActsExamples
162 changes: 162 additions & 0 deletions Examples/Algorithms/TrackFitting/src/GlobalChiSquareFitterFunction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

// TODO We still use some Kalman Fitter functionalities. Check for replacement

#include "Acts/Definitions/Direction.hpp"
#include "Acts/Definitions/TrackParametrization.hpp"
#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackStatePropMask.hpp"
#include "Acts/EventData/VectorMultiTrajectory.hpp"
#include "Acts/EventData/VectorTrackContainer.hpp"
#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
#include "Acts/Geometry/GeometryIdentifier.hpp"
#include "Acts/Propagator/DirectNavigator.hpp"
#include "Acts/Propagator/EigenStepper.hpp"
#include "Acts/Propagator/Navigator.hpp"
#include "Acts/Propagator/Propagator.hpp"
#include "Acts/TrackFitting/GlobalChiSquareFitter.hpp"
#include "Acts/TrackFitting/KalmanFitter.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/MeasurementCalibration.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp"
#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"

#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

namespace Acts {
class MagneticFieldProvider;
class SourceLink;
class Surface;
class TrackingGeometry;
} // namespace Acts

namespace {

using Stepper = Acts::EigenStepper<>;
using Propagator = Acts::Propagator<Stepper, Acts::Navigator>;
using Fitter =
Acts::Experimental::Gx2Fitter<Propagator, Acts::VectorMultiTrajectory>;
using DirectPropagator = Acts::Propagator<Stepper, Acts::DirectNavigator>;
using DirectFitter =
Acts::KalmanFitter<DirectPropagator, Acts::VectorMultiTrajectory>;

using TrackContainer =
Acts::TrackContainer<Acts::VectorTrackContainer,
Acts::VectorMultiTrajectory, std::shared_ptr>;

using namespace ActsExamples;

struct GlobalChiSquareFitterFunctionImpl final : public TrackFitterFunction {
Fitter fitter;
DirectFitter directFitter;

bool multipleScattering = false;
bool energyLoss = false;
Acts::FreeToBoundCorrection freeToBoundCorrection;

IndexSourceLink::SurfaceAccessor m_slSurfaceAccessor;

GlobalChiSquareFitterFunctionImpl(Fitter&& f, DirectFitter&& df,
const Acts::TrackingGeometry& trkGeo)
: fitter(std::move(f)),
directFitter(std::move(df)),
m_slSurfaceAccessor{trkGeo} {}

template <typename calibrator_t>
auto makeGx2fOptions(const GeneralFitterOptions& options,
const calibrator_t& calibrator) const {
Acts::Experimental::Gx2FitterExtensions<Acts::VectorMultiTrajectory>
extensions;
extensions.calibrator.connect<&calibrator_t::calibrate>(&calibrator);

extensions.surfaceAccessor
.connect<&IndexSourceLink::SurfaceAccessor::operator()>(
&m_slSurfaceAccessor);

const Acts::Experimental::Gx2FitterOptions gx2fOptions(
options.geoContext, options.magFieldContext, options.calibrationContext,
extensions, options.propOptions, &(*options.referenceSurface),
multipleScattering, energyLoss, freeToBoundCorrection, 5);

return gx2fOptions;
}

TrackFitterResult operator()(const std::vector<Acts::SourceLink>& sourceLinks,
const TrackParameters& initialParameters,
const GeneralFitterOptions& options,
const MeasurementCalibratorAdapter& calibrator,
TrackContainer& tracks) const override {
const auto gx2fOptions = makeGx2fOptions(options, calibrator);
return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters,
gx2fOptions, tracks);
}

// We need a placeholder for the directNavigator overload. Otherwise, we would
// have an unimplemented pure virtual method in a final class.
TrackFitterResult operator()(
const std::vector<Acts::SourceLink>& /*sourceLinks*/,
const TrackParameters& /*initialParameters*/,
const GeneralFitterOptions& /*options*/,
const RefittingCalibrator& /*calibrator*/,
const std::vector<const Acts::Surface*>& /*surfaceSequence*/,
TrackContainer& /*tracks*/) const override {
throw std::runtime_error(
"direct navigation with GX2 fitter is not implemented");
}
};

} // namespace

std::shared_ptr<ActsExamples::TrackFitterFunction>
ActsExamples::makeGlobalChiSquareFitterFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
bool multipleScattering, bool energyLoss,
Acts::FreeToBoundCorrection freeToBoundCorrection,
const Acts::Logger& logger) {
// Stepper should be copied into the fitters
const Stepper stepper(std::move(magneticField));

// Standard fitter
const auto& geo = *trackingGeometry;
Acts::Navigator::Config cfg{std::move(trackingGeometry)};
cfg.resolvePassive = false;
cfg.resolveMaterial = true;
cfg.resolveSensitive = true;
Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator"));
Propagator propagator(stepper, std::move(navigator),
logger.cloneWithSuffix("Propagator"));
Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter"));

// Direct fitter
Acts::DirectNavigator directNavigator{
logger.cloneWithSuffix("DirectNavigator")};
DirectPropagator directPropagator(stepper, std::move(directNavigator),
logger.cloneWithSuffix("DirectPropagator"));
DirectFitter directTrackFitter(std::move(directPropagator),
logger.cloneWithSuffix("DirectFitter"));

// build the fitter function. owns the fitter object.
auto fitterFunction = std::make_shared<GlobalChiSquareFitterFunctionImpl>(
std::move(trackFitter), std::move(directTrackFitter), geo);
fitterFunction->multipleScattering = multipleScattering;
fitterFunction->energyLoss = energyLoss;
fitterFunction->freeToBoundCorrection = freeToBoundCorrection;

return fitterFunction;
}
49 changes: 49 additions & 0 deletions Examples/Python/python/acts/examples/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,55 @@ def addCKFTracks(
return s


def addGx2fTracks(
s: acts.examples.Sequencer,
trackingGeometry: acts.TrackingGeometry,
field: acts.MagneticFieldProvider,
# directNavigation: bool = False,
inputProtoTracks: str = "truth_particle_tracks",
multipleScattering: bool = False,
energyLoss: bool = False,
clusters: str = None,
calibrator: acts.examples.MeasurementCalibrator = acts.examples.makePassThroughCalibrator(),
logLevel: Optional[acts.logging.Level] = None,
) -> None:
customLogLevel = acts.examples.defaultLogging(s, logLevel)

gx2fOptions = {
"multipleScattering": multipleScattering,
"energyLoss": energyLoss,
"freeToBoundCorrection": acts.examples.FreeToBoundCorrection(False),
"level": customLogLevel(),
}

fitAlg = acts.examples.TrackFittingAlgorithm(
level=customLogLevel(),
inputMeasurements="measurements",
inputSourceLinks="sourcelinks",
inputProtoTracks=inputProtoTracks,
inputInitialTrackParameters="estimatedparameters",
inputClusters=clusters if clusters is not None else "",
outputTracks="gx2fTracks",
pickTrack=-1,
fit=acts.examples.makeGlobalChiSquareFitterFunction(
trackingGeometry, field, **gx2fOptions
),
calibrator=calibrator,
)
s.addAlgorithm(fitAlg)
s.addWhiteboardAlias("tracks", fitAlg.config.outputTracks)

trackConverter = acts.examples.TracksToTrajectories(
level=customLogLevel(),
inputTracks=fitAlg.config.outputTracks,
outputTrajectories="gx2fTrajectories",
)
s.addAlgorithm(trackConverter)
s.addWhiteboardAlias("trajectories", trackConverter.config.outputTrajectories)

return s


def addTrajectoryWriters(
s: acts.examples.Sequencer,
name: str,
Expand Down
20 changes: 19 additions & 1 deletion Examples/Python/src/TrackFitting.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of the Acts project.
//
// Copyright (C) 2021 CERN for the benefit of the Acts project
// Copyright (C) 2021-2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand Down Expand Up @@ -127,6 +127,24 @@ void addTrackFitting(Context& ctx) {
py::arg("weightCutoff"), py::arg("finalReductionMethod"),
py::arg("abortOnError"), py::arg("disableAllMaterialHandling"),
py::arg("level"));

mex.def(
"makeGlobalChiSquareFitterFunction",
[](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
bool multipleScattering, bool energyLoss,
Acts::FreeToBoundCorrection freeToBoundCorrection,
Logging::Level level) {
return ActsExamples::makeGlobalChiSquareFitterFunction(
trackingGeometry, magneticField, multipleScattering, energyLoss,
freeToBoundCorrection, *Acts::getDefaultLogger("Gx2f", level));
},
py::arg("trackingGeometry"), py::arg("magneticField"),
py::arg("multipleScattering"), py::arg("energyLoss"),
py::arg("freeToBoundCorrection"), py::arg("level"));

// TODO add other important parameters like nUpdates
// TODO add also in trackfitterfunction
}

{
Expand Down

0 comments on commit 2ed5a93

Please sign in to comment.