Skip to content

Commit

Permalink
fix: Don't let the GSF remove all weights (#1933)
Browse files Browse the repository at this point in the history
For the unlikely case all components are below the configured weight
cutoff, the Actor would remove all components and thus let the GSF
crash.
Also refactors some assertions from `throw_assert` to `assert`.
  • Loading branch information
benjaminhuth committed Mar 13, 2023
1 parent 16ebb52 commit 250ea03
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 32 deletions.
12 changes: 10 additions & 2 deletions Core/include/Acts/TrackFitting/detail/GsfActor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,17 @@ struct GsfActor {
auto new_end = std::remove_if(cmps.begin(), cmps.end(), [&](auto& cmp) {
return proj(cmp) < m_cfg.weightCutoff;
});
cmps.erase(new_end, cmps.end());

detail::normalizeWeights(cmps, proj);
// In case we would remove all components, keep only the largest
if (std::distance(cmps.begin(), new_end) == 0) {
cmps = {*std::max_element(
cmps.begin(), cmps.end(),
[&](auto& a, auto& b) { return proj(a) < proj(b); })};
std::get<0>(cmps.front()).weight = 1.0;
} else {
cmps.erase(new_end, cmps.end());
detail::normalizeWeights(cmps, proj);
}
}

/// Function that updates the stepper from the MultiTrajectory
Expand Down
60 changes: 30 additions & 30 deletions Core/include/Acts/TrackFitting/detail/GsfUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,36 @@ constexpr static double s_normalizationTolerance = 1.e-4;

namespace detail {

template <typename component_range_t, typename projector_t,
typename print_flag_t = std::false_type>
template <typename component_range_t, typename projector_t>
bool weightsAreNormalized(const component_range_t &cmps,
const projector_t &proj,
double tol = s_normalizationTolerance,
print_flag_t print_flag = print_flag_t{}) {
double sum_of_weights = 0.0;
double tol = s_normalizationTolerance) {
double sumOfWeights = 0.0;

for (auto it = cmps.begin(); it != cmps.end(); ++it) {
sum_of_weights += proj(*it);
sumOfWeights += proj(*it);
}

if (std::abs(sum_of_weights - 1.0) < tol) {
return true;
} else {
if constexpr (print_flag) {
std::cout << std::setprecision(10)
<< "diff from 1: " << std::abs(sum_of_weights - 1.0) << "\n";
}

return false;
}
return std::abs(sumOfWeights - 1.0) < tol;
}

template <typename component_range_t, typename projector_t>
void normalizeWeights(component_range_t &cmps, const projector_t &proj) {
double sum_of_weights = 0.0;
double sumOfWeights = 0.0;

// we need decltype(auto) here to support proxy-types with reference
// semantics, otherwise there is a `cannot bind ... to ...` error
for (auto it = cmps.begin(); it != cmps.end(); ++it) {
decltype(auto) cmp = *it;
throw_assert(std::isfinite(proj(cmp)), "weight not finite:" << proj(cmp));
sum_of_weights += proj(cmp);
assert(std::isfinite(proj(cmp)) && "weight not finite in normalization");
sumOfWeights += proj(cmp);
}

assert(sumOfWeights > 0 && "sum of weights is not > 0");

for (auto it = cmps.begin(); it != cmps.end(); ++it) {
decltype(auto) cmp = *it;
proj(cmp) /= sum_of_weights;
proj(cmp) /= sumOfWeights;
}
}

Expand Down Expand Up @@ -92,16 +83,25 @@ class ScopedGsfInfoPrinterAndChecker {
}
}

void checks(const std::string_view &where) const {
void checks(bool onStart) const {
const auto cmps = m_stepper.constComponentIterable(m_state.stepping);
throw_assert(
[[maybe_unused]] const bool allFinite =
std::all_of(cmps.begin(), cmps.end(),
[](auto cmp) { return std::isfinite(cmp.weight()); }),
"some weights are not finite at " << where);

throw_assert(detail::weightsAreNormalized(
cmps, [](const auto &cmp) { return cmp.weight(); }),
"not normalized at " << where);
[](auto cmp) { return std::isfinite(cmp.weight()); });
[[maybe_unused]] const bool allNormalized = detail::weightsAreNormalized(
cmps, [](const auto &cmp) { return cmp.weight(); });
[[maybe_unused]] const bool zeroComponents =
m_stepper.numberComponents(m_state.stepping) == 0;

if (onStart) {
assert(not zeroComponents && "no cmps at the start");
assert(allFinite && "weights not finite at the start");
assert(allNormalized && "not normalized at the start");
} else {
assert(not zeroComponents && "no cmps at the end");
assert(allFinite && "weights not finite at the end");
assert(allNormalized && "not normalized at the end");
}
}

public:
Expand All @@ -112,7 +112,7 @@ class ScopedGsfInfoPrinterAndChecker {
m_p_initial(stepper.momentum(state.stepping)),
m_logger{logger} {
// Some initial printing
checks("start");
checks(true);
ACTS_VERBOSE("Gsf step "
<< state.stepping.steps << " at mean position "
<< stepper.position(state.stepping).transpose()
Expand All @@ -136,7 +136,7 @@ class ScopedGsfInfoPrinterAndChecker {
ACTS_VERBOSE("Delta Momentum = " << std::setprecision(5)
<< p_final - m_p_initial);
}
checks("end");
checks(false);
}
};

Expand Down

0 comments on commit 250ea03

Please sign in to comment.