Skip to content

Commit

Permalink
refactor!: Simplify CKF SourceLinkAccessor and make fewer assumptions (
Browse files Browse the repository at this point in the history
…#1203)

This PR migrates the CKF's source link accessor away from a class with 3 necessary methods to a single delegate that is given as part of the options. At the same time it changes the lookup key from the raw geometry identifier to the current surface, which will allow navigating to experiment specific detector elements without the need to go through a map lookup.

The CKF options move from being templated on the source link accessor itself to the source link accessor **iterator type**. The delegate is then defined as returning a range (a pair of) source link iterators, and those are used. Previously, source link accessors needed to implement three methods:

```cpp
size_t count(GeometryIdentifier id) const;
std::pair<Iterator, Iterator> range(GeometryIdentifier id) const;
const SourceLink& at(Iterator it) const;
```

which is now changed to a callable like

```cpp
std::pair<Iterator, Iterator> range(const Surface&) const;
```

which is accessed through the delegate. This PR also removes the obsolete methods. In addition the accessor was required to have a public *container* member which the `findTracks` method would accept and then set on the accessor. This is now assumed to happen before calling, the delegate can contain pointer to the accessor instance with the container already set. We believe this is better suited to environments with complex access call chains (e.g. ATLAS).

BREAKING CHANGE: This changes the way source links are passed to the CKF `findTracks` method. Instead of taking a container of source links, it now assumed the CKF options have a source link accessor connected which is configured to be able to access that container.
  • Loading branch information
paulgessinger committed Apr 8, 2022
1 parent 61e4021 commit 65d1e09
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 80 deletions.
75 changes: 37 additions & 38 deletions Core/include/Acts/TrackFinding/CombinatorialKalmanFilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,21 @@ struct CombinatorialKalmanFilterExtensions {
}
};

/// Delegate type that retrieves a range of source links to for a given surface
/// to be processed by the CKF
template <typename source_link_iterator_t>
using SourceLinkAccessorDelegate =
Delegate<std::pair<source_link_iterator_t, source_link_iterator_t>(
const Surface&)>;

/// Combined options for the combinatorial Kalman filter.
///
/// @tparam source_link_accessor_t Source link accessor type, should be
/// semiregular.
template <typename source_link_accessor_t>
template <typename source_link_iterator_t>
struct CombinatorialKalmanFilterOptions {
using SourceLinkAccessor = source_link_accessor_t;
using SourceLinkIterator = source_link_iterator_t;
using SourceLinkAccessor = SourceLinkAccessorDelegate<source_link_iterator_t>;

/// PropagatorOptions with context
///
Expand Down Expand Up @@ -290,7 +298,6 @@ class CombinatorialKalmanFilter {
using CurvilinearState =
std::tuple<CurvilinearTrackParameters, BoundMatrix, double>;
// The source link container type
using SourceLinkContainer = typename source_link_accessor_t::Container;
/// Broadcast the result_type
using result_type = CombinatorialKalmanFilterResult;

Expand Down Expand Up @@ -545,8 +552,8 @@ class CombinatorialKalmanFilter {
size_t nBranchesOnSurface = 0;

// Count the number of source links on the surface
size_t nSourcelinks = m_sourcelinkAccessor.count(surface->geometryId());
if (nSourcelinks > 0) {
auto [slBegin, slEnd] = m_sourcelinkAccessor(*surface);
if (slBegin != slEnd) {
// Screen output message
ACTS_VERBOSE("Measurement surface " << surface->geometryId()
<< " detected.");
Expand Down Expand Up @@ -579,8 +586,8 @@ class CombinatorialKalmanFilter {

// Create trackstates for all source links (will be filtered later)
// Results are stored in result => no return value
createSourceLinkTrackStates(state.geoContext, surface, result,
boundState, prevTip);
createSourceLinkTrackStates(state.geoContext, result, boundState,
prevTip, slBegin, slEnd);

// Invoke the measurement selector to select compatible measurements
// with the predicted track parameter.
Expand Down Expand Up @@ -735,39 +742,43 @@ class CombinatorialKalmanFilter {

/// Create and fill track states for all source links
/// @param gctx The current geometry context
/// @param surface The surface currently being processed
/// @param result Reference to the result struct of the actor
/// @param boundState Bound state from the propagation on this surface
/// @param prevTip Index pointing at previous trajectory state (i.e. tip)
/// @param slBegin Begin iterator for sourcelinks
/// @param slEnd End iterator for sourcelinks
template <typename source_link_iterator_t>
void createSourceLinkTrackStates(const Acts::GeometryContext& gctx,
const Surface* surface,
result_type& result,
const BoundState& boundState,
size_t prevTip) const {
size_t prevTip,
source_link_iterator_t slBegin,
source_link_iterator_t slEnd) const {
const auto& [boundParams, jacobian, pathLength] = boundState;

// Get all source links on the surface
auto [lower_it, upper_it] =
m_sourcelinkAccessor.range(surface->geometryId());

result.trackStateCandidates.clear();
result.trackStateCandidates.reserve(std::distance(lower_it, upper_it));
if constexpr (std::is_same_v<
typename std::iterator_traits<
source_link_iterator_t>::iterator_category,
std::random_access_iterator_tag>) {
result.trackStateCandidates.reserve(std::distance(slBegin, slEnd));
}

result.stateBuffer.clear();

using PM = TrackStatePropMask;

// Calibrate all the source links on the surface since the selection has
// to be done based on calibrated measurement
for (auto it = lower_it; it != upper_it; ++it) {
for (auto it = slBegin; it != slEnd; ++it) {
// get the source link
const auto& sourceLink = m_sourcelinkAccessor.at(it);
const auto& sourceLink = *it;

// prepare the track state
PM mask =
PM::Predicted | PM::Jacobian | PM::Uncalibrated | PM::Calibrated;

if (it != lower_it) {
if (it != slBegin) {
// not the first TrackState, only need uncalibrated and calibrated
mask = PM::Uncalibrated | PM::Calibrated;
}
Expand All @@ -778,7 +789,7 @@ class CombinatorialKalmanFilter {
// fail!
auto ts = result.stateBuffer.getTrackState(tsi);

if (it == lower_it) {
if (it == slBegin) {
// only set these for first
ts.predicted() = boundParams.parameters();
if (boundParams.covariance()) {
Expand Down Expand Up @@ -1181,14 +1192,13 @@ class CombinatorialKalmanFilter {
/// Combinatorial Kalman Filter implementation, calls the the Kalman filter
/// and smoother
///
/// @tparam source_link_accessor_t Type of the source link accessor
/// @tparam source_link_iterator_t Type of the source link iterator
/// @tparam start_parameters_container_t Type of the initial parameters
/// container
/// @tparam calibrator_t Type of the source link calibrator
/// @tparam measurement_selector_t Type of the measurement selector
/// @tparam parameters_t Type of parameters used for local parameters
///
/// @param sourcelinks The fittable uncalibrated measurements
/// @param initialParameters The initial track parameters
/// @param tfOptions CombinatorialKalmanFilterOptions steering the track
/// finding
Expand All @@ -1199,32 +1209,23 @@ class CombinatorialKalmanFilter {
///
/// @return a container of track finding result for all the initial track
/// parameters
template <typename source_link_accessor_t,
template <typename source_link_iterator_t,
typename start_parameters_container_t,
typename parameters_t = BoundTrackParameters>
std::vector<Result<CombinatorialKalmanFilterResult>> findTracks(
const typename source_link_accessor_t::Container& sourcelinks,
const start_parameters_container_t& initialParameters,
const CombinatorialKalmanFilterOptions<source_link_accessor_t>& tfOptions)
const CombinatorialKalmanFilterOptions<source_link_iterator_t>& tfOptions)
const {
static_assert(
SourceLinkAccessorConcept<source_link_accessor_t>,
"The source link accessor does not fullfill SourceLinkAccessorConcept");
static_assert(
std::is_same_v<GeometryIdentifier,
typename source_link_accessor_t::Key>,
"The source link container does not have GeometryIdentifier as the key "
"type");

const auto& logger = tfOptions.logger;

ACTS_VERBOSE("Preparing " << sourcelinks.size() << " input measurements");
using SourceLinkAccessor =
SourceLinkAccessorDelegate<source_link_iterator_t>;

// Create the ActionList and AbortList
using CombinatorialKalmanFilterAborter =
Aborter<source_link_accessor_t, parameters_t>;
Aborter<SourceLinkAccessor, parameters_t>;
using CombinatorialKalmanFilterActor =
Actor<source_link_accessor_t, parameters_t>;
Actor<SourceLinkAccessor, parameters_t>;
using Actors = ActionList<CombinatorialKalmanFilterActor>;
using Aborters = AbortList<CombinatorialKalmanFilterAborter>;

Expand All @@ -1245,8 +1246,6 @@ class CombinatorialKalmanFilter {

// copy source link accessor, calibrator and measurement selector
combKalmanActor.m_sourcelinkAccessor = tfOptions.sourcelinkAccessor;
// set the pointer to the source links
combKalmanActor.m_sourcelinkAccessor.container = &sourcelinks;
combKalmanActor.m_extensions = tfOptions.extensions;

// Run the CombinatorialKalmanFilter.
Expand Down
8 changes: 5 additions & 3 deletions Core/include/Acts/TrackFinding/SourceLinkAccessorConcept.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

namespace Acts {

class Surface;

namespace Concepts {
namespace SourceLinkAccessor {

Expand Down Expand Up @@ -41,17 +43,17 @@ METHOD_TRAIT(at_t, at);
static_assert(value_exists, "Value type not found");
constexpr static bool iterator_exists = exists<iterator_t, S>;
static_assert(iterator_exists, "Iterator type not found");

constexpr static bool container_pointer_exists =
std::is_same_v<std::decay_t<decltype(*(std::declval<S>().container))>, container_t<S>>;
static_assert(container_pointer_exists, "Pointer to container not found");

constexpr static bool count_exists = has_method<const S,
size_t, count_t, const typename S::Key&>;
size_t, count_t, const Surface&>;
static_assert(count_exists, "count method not found");
constexpr static bool range_exists = has_method<const S,
std::pair<typename S::Iterator, typename S::Iterator>,
range_t, const typename S::Key&>;
range_t, const Surface&>;
static_assert(range_exists, "range method not found");
constexpr static bool at_exists = has_method<const S,
const typename S::Value&, at_t, const typename S::Iterator&>;
Expand Down
1 change: 1 addition & 0 deletions Core/include/Acts/Utilities/Delegate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Delegate<R(Args...)> {

m_payload = instance;
m_function = [](const void* payload, Args... args) -> return_type {
assert(payload != nullptr && "Payload is required, but not set");
const auto* concretePayload = static_cast<const Type*>(payload);
return std::invoke(Callable, concretePayload,
std::forward<Args>(args)...);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TrackFindingAlgorithm final : public BareAlgorithm {
/// Track finder function that takes input measurements, initial trackstate
/// and track finder options and returns some track-finder-specific result.
using TrackFinderOptions =
Acts::CombinatorialKalmanFilterOptions<IndexSourceLinkAccessor>;
Acts::CombinatorialKalmanFilterOptions<IndexSourceLinkAccessor::Iterator>;
using TrackFinderResult =
std::vector<Acts::Result<Acts::CombinatorialKalmanFilterResult>>;

Expand All @@ -37,8 +37,7 @@ class TrackFindingAlgorithm final : public BareAlgorithm {
class TrackFinderFunction {
public:
virtual ~TrackFinderFunction() = default;
virtual TrackFinderResult operator()(const IndexSourceLinkContainer&,
const TrackParametersContainer&,
virtual TrackFinderResult operator()(const TrackParametersContainer&,
const TrackFinderOptions&) const = 0;
};

Expand Down
13 changes: 9 additions & 4 deletions Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,21 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(
extensions.measurementSelector.connect<&Acts::MeasurementSelector::select>(
&measSel);

IndexSourceLinkAccessor slAccessor;
slAccessor.container = &sourceLinks;
Acts::SourceLinkAccessorDelegate<IndexSourceLinkAccessor::Iterator>
slAccessorDelegate;
slAccessorDelegate.connect<&IndexSourceLinkAccessor::range>(&slAccessor);

// Set the CombinatorialKalmanFilter options
ActsExamples::TrackFindingAlgorithm::TrackFinderOptions options(
ctx.geoContext, ctx.magFieldContext, ctx.calibContext,
IndexSourceLinkAccessor(), extensions, Acts::LoggerWrapper{logger()},
pOptions, &(*pSurface));
ctx.geoContext, ctx.magFieldContext, ctx.calibContext, slAccessorDelegate,
extensions, Acts::LoggerWrapper{logger()}, pOptions, &(*pSurface));

// Perform the track finding for all initial parameters
ACTS_DEBUG("Invoke track finding with " << initialParameters.size()
<< " seeds.");
auto results = (*m_cfg.findTracks)(sourceLinks, initialParameters, options);
auto results = (*m_cfg.findTracks)(initialParameters, options);

// Compute shared hits from all the reconstructed tracks
if (m_cfg.computeSharedHits) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ struct TrackFinderFunctionImpl
TrackFinderFunctionImpl(CKF&& f) : trackFinder(std::move(f)) {}

ActsExamples::TrackFindingAlgorithm::TrackFinderResult operator()(
const ActsExamples::IndexSourceLinkContainer& sourcelinks,
const ActsExamples::TrackParametersContainer& initialParameters,
const ActsExamples::TrackFindingAlgorithm::TrackFinderOptions& options)
const override {
return trackFinder.findTracks(sourcelinks, initialParameters, options);
return trackFinder.findTracks(initialParameters, options);
};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
#pragma once

#include "Acts/Geometry/GeometryIdentifier.hpp"
#include "Acts/Surfaces/Surface.hpp"
#include "ActsExamples/Utilities/GroupBy.hpp"
#include "ActsExamples/Utilities/Range.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <utility>

#include <boost/container/flat_map.hpp>
Expand Down Expand Up @@ -225,21 +227,11 @@ struct GeometryIdMultisetAccessor {
// pointer to the container
const Container* container = nullptr;

// count the number of elements with requested geoId
size_t count(const Acts::GeometryIdentifier& geoId) const {
assert(container != nullptr);
return container->count(geoId);
}

// get the range of elements with requested geoId
std::pair<Iterator, Iterator> range(
const Acts::GeometryIdentifier& geoId) const {
std::pair<Iterator, Iterator> range(const Acts::Surface& surface) const {
assert(container != nullptr);
return container->equal_range(geoId);
return container->equal_range(surface.geometryId());
}

// get the element using the iterator
const Value& at(const Iterator& it) const { return *it; }
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include "Acts/EventData/SourceLink.hpp"
#include "Acts/Surfaces/Surface.hpp"
#include "ActsExamples/EventData/GeometryContainers.hpp"
#include "ActsExamples/EventData/Index.hpp"

Expand Down

0 comments on commit 65d1e09

Please sign in to comment.