Skip to content

Commit

Permalink
refactor: Fitting and Finding algorithm function interface (#922)
Browse files Browse the repository at this point in the history
Changes both the fitting and the finding algorithm to use a regular
interface with virtual methods rather than std::function to decouple the
actual fitter/finder instances. The decoupling is only for compile-time
memory usage reduction. The reason is that pybind11 does a round trip
throught python with std::function, which we want to avoid. This is
functionally equivalent, but uses a different mechanism.

Also adds factory functions for these decoupled objects. We don't
foresee these to be switchable, this decoupling is done for purely
technical reasons anyway.
  • Loading branch information
paulgessinger committed Aug 6, 2021
1 parent cc52918 commit 6523cba
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,23 @@ class TrackFindingAlgorithm final : public BareAlgorithm {
Acts::MeasurementSelector>;
using TrackFinderResult = std::vector<
Acts::Result<Acts::CombinatorialKalmanFilterResult<IndexSourceLink>>>;
using TrackFinderFunction = std::function<TrackFinderResult(
const IndexSourceLinkContainer&, const TrackParametersContainer&,
const TrackFinderOptions&)>;

/// Find function that takes the above parameters
/// @note This is separated into a virtual interface to keep compilation units
/// small
class TrackFinderFunction {
public:
virtual ~TrackFinderFunction() = default;
virtual TrackFinderResult operator()(const IndexSourceLinkContainer&,
const TrackParametersContainer&,
const TrackFinderOptions&) const = 0;
};

/// Create the track finder function implementation.
///
/// The magnetic field is intentionally given by-value since the variant
/// contains shared_ptr anyways.
static TrackFinderFunction makeTrackFinderFunction(
static std::shared_ptr<TrackFinderFunction> makeTrackFinderFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField);

Expand All @@ -53,7 +61,7 @@ class TrackFindingAlgorithm final : public BareAlgorithm {
/// Output find trajectories collection.
std::string outputTrajectories;
/// Type erased track finder function.
TrackFinderFunction findTracks;
std::shared_ptr<TrackFinderFunction> findTracks;
/// CKF measurement selector config
Acts::MeasurementSelector::Config measurementSelectorCfg;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(
// Perform the track finding for all initial parameters
ACTS_DEBUG("Invoke track finding with " << initialParameters.size()
<< " seeds.");
auto results = m_cfg.findTracks(sourceLinks, initialParameters, options);
auto results = (*m_cfg.findTracks)(sourceLinks, initialParameters, options);
// Loop over the track finding results for all initial parameters
for (std::size_t iseed = 0; iseed < initialParameters.size(); ++iseed) {
// The result for this seed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ using Navigator = Acts::Navigator;
using Propagator = Acts::Propagator<Stepper, Navigator>;
using CKF = Acts::CombinatorialKalmanFilter<Propagator, Updater, Smoother>;

struct TrackFinderFunctionImpl {
struct TrackFinderFunctionImpl
: public ActsExamples::TrackFindingAlgorithm::TrackFinderFunction {
CKF trackFinder;

TrackFinderFunctionImpl(CKF&& f) : trackFinder(std::move(f)) {}
Expand All @@ -36,14 +37,14 @@ struct TrackFinderFunctionImpl {
const ActsExamples::IndexSourceLinkContainer& sourcelinks,
const ActsExamples::TrackParametersContainer& initialParameters,
const ActsExamples::TrackFindingAlgorithm::TrackFinderOptions& options)
const {
const override {
return trackFinder.findTracks(sourcelinks, initialParameters, options);
};
};

} // namespace

ActsExamples::TrackFindingAlgorithm::TrackFinderFunction
std::shared_ptr<ActsExamples::TrackFindingAlgorithm::TrackFinderFunction>
ActsExamples::TrackFindingAlgorithm::makeTrackFinderFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField) {
Expand All @@ -57,5 +58,5 @@ ActsExamples::TrackFindingAlgorithm::makeTrackFinderFunction(
CKF trackFinder(std::move(propagator));

// build the track finder functions. owns the track finder object.
return TrackFinderFunctionImpl(std::move(trackFinder));
return std::make_shared<TrackFinderFunctionImpl>(std::move(trackFinder));
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,40 @@ class TrackFittingAlgorithm final : public BareAlgorithm {
Acts::KalmanFitterOptions<MeasurementCalibrator, Acts::VoidOutlierFinder>;
using TrackFitterResult =
Acts::Result<Acts::KalmanFitterResult<IndexSourceLink>>;
using TrackFitterFunction = std::function<TrackFitterResult(
const std::vector<IndexSourceLink>&, const TrackParameters&,
const TrackFitterOptions&)>;

/// Fit function that takes the above parameters and runs a fit
/// @note This is separated into a virtual interface to keep compilation units
/// small
class TrackFitterFunction {
public:
virtual ~TrackFitterFunction() = default;
virtual TrackFitterResult operator()(const std::vector<IndexSourceLink>&,
const TrackParameters&,
const TrackFitterOptions&) const = 0;
};

/// Fit function that takes the above parameters plus a sorted surface
/// sequence for the DirectNavigator to follow
using DirectedTrackFitterFunction = std::function<TrackFitterResult(
const std::vector<IndexSourceLink>&, const TrackParameters&,
const TrackFitterOptions&, const std::vector<const Acts::Surface*>&)>;
/// @note This is separated into a virtual interface to keep compilation units
/// small
class DirectedTrackFitterFunction {
public:
virtual ~DirectedTrackFitterFunction() = default;
virtual TrackFitterResult operator()(
const std::vector<IndexSourceLink>&, const TrackParameters&,
const TrackFitterOptions&,
const std::vector<const Acts::Surface*>&) const = 0;
};

/// Create the track fitter function implementation.
///
/// The magnetic field is intentionally given by-value since the variant
/// contains shared_ptr anyways.
static TrackFitterFunction makeTrackFitterFunction(
static std::shared_ptr<TrackFitterFunction> makeTrackFitterFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField);

static DirectedTrackFitterFunction makeTrackFitterFunction(
static std::shared_ptr<DirectedTrackFitterFunction> makeTrackFitterFunction(
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField);

struct Config {
Expand All @@ -69,9 +84,9 @@ class TrackFittingAlgorithm final : public BareAlgorithm {
/// Output fitted trajectories collection.
std::string outputTrajectories;
/// Type erased fitter function.
TrackFitterFunction fit;
std::shared_ptr<TrackFitterFunction> fit;
/// Type erased direct navigation fitter function
DirectedTrackFitterFunction dFit;
std::shared_ptr<DirectedTrackFitterFunction> dFit;
/// Tracking geometry for surface lookup
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry;
/// Some more detailed steering - mainly for debugging, correct for MCS
Expand Down Expand Up @@ -113,10 +128,10 @@ ActsExamples::TrackFittingAlgorithm::fitTrack(
Acts::VoidOutlierFinder>& options,
const std::vector<const Acts::Surface*>& surfSequence) const {
if (m_cfg.directNavigation) {
return m_cfg.dFit(sourceLinks, initialParameters, options, surfSequence);
return (*m_cfg.dFit)(sourceLinks, initialParameters, options, surfSequence);
}

return m_cfg.fit(sourceLinks, initialParameters, options);
return (*m_cfg.fit)(sourceLinks, initialParameters, options);
}

} // namespace ActsExamples
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ using Fitter = Acts::KalmanFitter<Propagator, Updater, Smoother>;
using DirectPropagator = Acts::Propagator<Stepper, Acts::DirectNavigator>;
using DirectFitter = Acts::KalmanFitter<DirectPropagator, Updater, Smoother>;

struct TrackFitterFunctionImpl {
struct TrackFitterFunctionImpl
: public ActsExamples::TrackFittingAlgorithm::TrackFitterFunction {
Fitter trackFitter;

TrackFitterFunctionImpl(Fitter&& f) : trackFitter(std::move(f)) {}
Expand All @@ -39,27 +40,28 @@ struct TrackFitterFunctionImpl {
const std::vector<ActsExamples::IndexSourceLink>& sourceLinks,
const ActsExamples::TrackParameters& initialParameters,
const ActsExamples::TrackFittingAlgorithm::TrackFitterOptions& options)
const {
const override {
return trackFitter.fit(sourceLinks, initialParameters, options);
};
};

struct DirectedFitterFunctionImpl {
struct DirectedFitterFunctionImpl
: public ActsExamples::TrackFittingAlgorithm::DirectedTrackFitterFunction {
DirectFitter fitter;
DirectedFitterFunctionImpl(DirectFitter&& f) : fitter(std::move(f)) {}

ActsExamples::TrackFittingAlgorithm::TrackFitterResult operator()(
const std::vector<ActsExamples::IndexSourceLink>& sourceLinks,
const ActsExamples::TrackParameters& initialParameters,
const ActsExamples::TrackFittingAlgorithm::TrackFitterOptions& options,
const std::vector<const Acts::Surface*>& sSequence) const {
const std::vector<const Acts::Surface*>& sSequence) const override {
return fitter.fit(sourceLinks, initialParameters, options, sSequence);
};
};

} // namespace

ActsExamples::TrackFittingAlgorithm::TrackFitterFunction
std::shared_ptr<ActsExamples::TrackFittingAlgorithm::TrackFitterFunction>
ActsExamples::TrackFittingAlgorithm::makeTrackFitterFunction(
std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField) {
Expand All @@ -73,10 +75,11 @@ ActsExamples::TrackFittingAlgorithm::makeTrackFitterFunction(
Fitter trackFitter(std::move(propagator));

// build the fitter functions. owns the fitter object.
return TrackFitterFunctionImpl(std::move(trackFitter));
return std::make_shared<TrackFitterFunctionImpl>(std::move(trackFitter));
}

ActsExamples::TrackFittingAlgorithm::DirectedTrackFitterFunction
std::shared_ptr<
ActsExamples::TrackFittingAlgorithm::DirectedTrackFitterFunction>
ActsExamples::TrackFittingAlgorithm::makeTrackFitterFunction(
std::shared_ptr<const Acts::MagneticFieldProvider> magneticField) {
// construct all components for the fitter
Expand All @@ -86,5 +89,5 @@ ActsExamples::TrackFittingAlgorithm::makeTrackFitterFunction(
DirectFitter fitter(std::move(propagator));

// build the fitter functions. owns the fitter object.
return DirectedFitterFunctionImpl(std::move(fitter));
return std::make_shared<DirectedFitterFunctionImpl>(std::move(fitter));
}

0 comments on commit 6523cba

Please sign in to comment.