Skip to content

Commit

Permalink
refactor: rearrange seed confirmation to avoid unnecessary checks (#1577
Browse files Browse the repository at this point in the history
)
  • Loading branch information
LuisFelipeCoelho committed Oct 24, 2022
1 parent 0237412 commit f20738e
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 82 deletions.
20 changes: 15 additions & 5 deletions Core/include/Acts/Seeding/SeedFilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@
#include <vector>

namespace Acts {
struct SeedFilterState {
// longitudinal impact parameter as defined by bottom and middle space point
float zOrigin;
// number of minimum top SPs in seed confirmation
size_t nTopSeedConf = 0;
// number of high quality seeds in seed confirmation
int numQualitySeeds = 0;
// number of seeds that did not pass the quality confirmation but were still
// accepted, if quality confirmation is not used this is the total number of
// seeds
int numSeeds = 0;
};

/// Filter seeds at various stages with the currently
/// available information.
template <typename external_spacepoint_t>
Expand All @@ -38,17 +51,14 @@ class SeedFilter {
/// with both bottom and middle space point
/// @param invHelixDiameterVec vector containing 1/(2*r) values where r is the helix radius
/// @param impactParametersVec vector containing the impact parameters
/// @param zOrigin on the z axis as defined by bottom and middle space point
/// @param numQualitySeeds number of high quality seeds in seed confirmation
/// @param numSeeds number of seeds that did not pass the quality confirmation but were still accepted, if quality confirmation is not used this is the total number of seeds
/// @param seedFilterState holds quantities used in seed filter
/// @param outCont Output container for the seeds
virtual void filterSeeds_2SpFixed(
InternalSpacePoint<external_spacepoint_t>& bottomSP,
InternalSpacePoint<external_spacepoint_t>& middleSP,
std::vector<InternalSpacePoint<external_spacepoint_t>*>& topSpVec,
std::vector<float>& invHelixDiameterVec,
std::vector<float>& impactParametersVec, float zOrigin,
int& numQualitySeeds, int& numSeeds,
std::vector<float>& impactParametersVec, SeedFilterState& seedFilterState,
std::vector<std::pair<
float, std::unique_ptr<const InternalSeed<external_spacepoint_t>>>>&
outCont) const;
Expand Down
40 changes: 13 additions & 27 deletions Core/include/Acts/Seeding/SeedFilter.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,14 @@ void SeedFilter<external_spacepoint_t>::filterSeeds_2SpFixed(
InternalSpacePoint<external_spacepoint_t>& middleSP,
std::vector<InternalSpacePoint<external_spacepoint_t>*>& topSpVec,
std::vector<float>& invHelixDiameterVec,
std::vector<float>& impactParametersVec, float zOrigin,
int& numQualitySeeds, int& numSeeds,
std::vector<float>& impactParametersVec, SeedFilterState& seedFilterState,
std::vector<std::pair<
float, std::unique_ptr<const InternalSeed<external_spacepoint_t>>>>&
outCont) const {
// seed confirmation
int nTopSeedConf = 0;
if (m_cfg.seedConfirmation) {
// check if bottom SP is in the central or forward region
SeedConfirmationRangeConfig seedConfRange =
(bottomSP.z() > m_cfg.centralSeedConfirmationRange.zMaxSeedConf ||
bottomSP.z() < m_cfg.centralSeedConfirmationRange.zMinSeedConf)
? m_cfg.forwardSeedConfirmationRange
: m_cfg.centralSeedConfirmationRange;
// set the minimum number of top SP depending on whether the bottom SP is
// in the central or forward region
nTopSeedConf = bottomSP.radius() > seedConfRange.rMaxSeedConf
? seedConfRange.nTopForLargeR
: seedConfRange.nTopForSmallR;
}

size_t maxWeightSeedIndex = 0;
bool maxWeightSeed = false;
float weightMax = -std::numeric_limits<float>::max();
float zOrigin = seedFilterState.zOrigin;

// initialize original index locations
std::vector<size_t> idx(topSpVec.size());
Expand Down Expand Up @@ -144,8 +128,9 @@ void SeedFilter<external_spacepoint_t>::filterSeeds_2SpFixed(
// impact parameter, z-origin and number of compatible seeds inside a
// pre-defined range that also depends on the region of the detector (i.e.
// forward or central region) defined by SeedConfirmationRange
deltaSeedConf = compatibleSeedR.size() + 1 - nTopSeedConf;
if (deltaSeedConf < 0 || (numQualitySeeds and deltaSeedConf == 0)) {
deltaSeedConf = compatibleSeedR.size() + 1 - seedFilterState.nTopSeedConf;
if (deltaSeedConf < 0 ||
(seedFilterState.numQualitySeeds and deltaSeedConf == 0)) {
continue;
}
bool seedRangeCuts = bottomSP.radius() < m_cfg.seedConfMinBottomRadius ||
Expand All @@ -168,9 +153,9 @@ void SeedFilter<external_spacepoint_t>::filterSeeds_2SpFixed(
if (deltaSeedConf > 0) {
// if we have not yet reached our max number of quality seeds we add the
// new seed to outCont
if (numQualitySeeds < m_cfg.maxQualitySeedsPerSpMConf) {
if (seedFilterState.numQualitySeeds < m_cfg.maxQualitySeedsPerSpMConf) {
// fill high quality seed
++numQualitySeeds;
seedFilterState.numQualitySeeds++;
outCont.push_back(std::make_pair(
weight,
std::make_unique<const InternalSeed<external_spacepoint_t>>(
Expand All @@ -191,9 +176,9 @@ void SeedFilter<external_spacepoint_t>::filterSeeds_2SpFixed(
// keep the normal behavior without seed quality confirmation
// if we have not yet reached our max number of seeds we add the new seed
// to outCont
if (numSeeds < m_cfg.maxSeedsPerSpMConf) {
if (seedFilterState.numSeeds < m_cfg.maxSeedsPerSpMConf) {
// fill seed
++numSeeds;
seedFilterState.numSeeds++;
outCont.push_back(std::make_pair(
weight, std::make_unique<const InternalSeed<external_spacepoint_t>>(
bottomSP, middleSP, *topSpVec[i], zOrigin, false)));
Expand All @@ -206,12 +191,13 @@ void SeedFilter<external_spacepoint_t>::filterSeeds_2SpFixed(
}
// if no high quality seed was found for a certain middle+bottom SP pair,
// lower quality seeds can be accepted
if (m_cfg.seedConfirmation and maxWeightSeed and !numQualitySeeds) {
if (m_cfg.seedConfirmation and maxWeightSeed and
!seedFilterState.numQualitySeeds) {
// if we have not yet reached our max number of seeds we add the new seed to
// outCont
if (numSeeds < m_cfg.maxSeedsPerSpMConf) {
if (seedFilterState.numSeeds < m_cfg.maxSeedsPerSpMConf) {
// fill seed
++numSeeds;
seedFilterState.numSeeds++;
outCont.push_back(std::make_pair(
weightMax,
std::make_unique<const InternalSeed<external_spacepoint_t>>(
Expand Down
4 changes: 2 additions & 2 deletions Core/include/Acts/Seeding/SeedFinder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SeedFinder {
///////////////////////////////////////////////////////////////////

public:
struct State {
struct SeedingState {
// bottom space point
std::vector<InternalSpacePoint<external_spacepoint_t>*> compatBottomSP;
std::vector<InternalSpacePoint<external_spacepoint_t>*> compatTopSP;
Expand Down Expand Up @@ -78,7 +78,7 @@ class SeedFinder {
/// @note Ranges must be separate objects for each parallel call.
template <template <typename...> typename container_t, typename sp_range_t>
void createSeedsForGroup(
State& state,
SeedingState& state,
std::back_insert_iterator<container_t<Seed<external_spacepoint_t>>> outIt,
sp_range_t bottomSPs, sp_range_t middleSPs, sp_range_t topSPs,
Extent rRangeSPExtent) const;
Expand Down
54 changes: 24 additions & 30 deletions Core/include/Acts/Seeding/SeedFinder.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ SeedFinder<external_spacepoint_t, platform_t>::SeedFinder(
template <typename external_spacepoint_t, typename platform_t>
template <template <typename...> typename container_t, typename sp_range_t>
void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
State& state,
SeedingState& state,
std::back_insert_iterator<container_t<Seed<external_spacepoint_t>>> outIt,
sp_range_t bottomSPs, sp_range_t middleSPs, sp_range_t topSPs,
Extent rRangeSPExtent) const {
Expand Down Expand Up @@ -71,21 +71,6 @@ void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
}
}

size_t nTopSeedConf = 0;
if (m_config.seedConfirmation == true) {
// check if middle SP is in the central or forward region
SeedConfirmationRangeConfig seedConfRange =
(zM > m_config.centralSeedConfirmationRange.zMaxSeedConf ||
zM < m_config.centralSeedConfirmationRange.zMinSeedConf)
? m_config.forwardSeedConfirmationRange
: m_config.centralSeedConfirmationRange;
// set the minimum number of top SP depending on whether the middle SP is
// in the central or forward region
nTopSeedConf = rM > seedConfRange.rMaxSeedConf
? seedConfRange.nTopForLargeR
: seedConfRange.nTopForSmallR;
}

state.compatTopSP.clear();

for (auto topSP : topSPs) {
Expand Down Expand Up @@ -154,12 +139,25 @@ void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
}
state.compatTopSP.push_back(topSP);
}
if (state.compatTopSP.empty()) {
continue;
}
// apply cut on the number of top SP if seedConfirmation is true
if (m_config.seedConfirmation == true &&
state.compatTopSP.size() < nTopSeedConf) {
SeedFilterState seedFilterState;
if (m_config.seedConfirmation == true) {
// check if middle SP is in the central or forward region
SeedConfirmationRangeConfig seedConfRange =
(zM > m_config.centralSeedConfirmationRange.zMaxSeedConf ||
zM < m_config.centralSeedConfirmationRange.zMinSeedConf)
? m_config.forwardSeedConfirmationRange
: m_config.centralSeedConfirmationRange;
// set the minimum number of top SP depending on whether the middle SP is
// in the central or forward region
seedFilterState.nTopSeedConf = rM > seedConfRange.rMaxSeedConf
? seedConfRange.nTopForLargeR
: seedConfRange.nTopForSmallR;
if (state.compatTopSP.size() < seedFilterState.nTopSeedConf) {
continue;
}
}
if (state.compatTopSP.empty()) {
continue;
}

Expand Down Expand Up @@ -250,9 +248,6 @@ void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
size_t numBotSP = state.compatBottomSP.size();
size_t numTopSP = state.compatTopSP.size();

int numQualitySeeds = 0;
int numSeeds = 0;

size_t t0 = 0;

for (size_t b = 0; b < numBotSP; b++) {
Expand All @@ -262,7 +257,7 @@ void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
}

auto lb = state.linCircleBottom[b];
float Zob = lb.Zo;
seedFilterState.zOrigin = lb.Zo;
float cotThetaB = lb.cotTheta;
float Vb = lb.V;
float Ub = lb.U;
Expand Down Expand Up @@ -515,12 +510,11 @@ void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
if (!state.topSpVec.empty()) {
m_config.seedFilter->filterSeeds_2SpFixed(
*state.compatBottomSP[b], *spM, state.topSpVec, state.curvatures,
state.impactParameters, Zob, numQualitySeeds, numSeeds,
state.seedsPerSpM);
state.impactParameters, seedFilterState, state.seedsPerSpM);
}
}
m_config.seedFilter->filterSeeds_1SpFixed(state.seedsPerSpM,
numQualitySeeds, outIt);
m_config.seedFilter->filterSeeds_1SpFixed(
state.seedsPerSpM, seedFilterState.numQualitySeeds, outIt);
}
}

Expand All @@ -529,7 +523,7 @@ template <typename sp_range_t>
std::vector<Seed<external_spacepoint_t>>
SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
sp_range_t bottomSPs, sp_range_t middleSPs, sp_range_t topSPs) const {
State state;
SeedingState state;
Extent extent;
std::vector<Seed<external_spacepoint_t>> ret;

Expand Down
5 changes: 3 additions & 2 deletions Core/include/Acts/Seeding/SeedFinderOrthogonal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ class SeedFinderOrthogonal {
* @param middle The (singular) middle spacepoint.
* @param bottom The (vector of) candidate bottom spacepoints.
* @param top The (vector of) candidate top spacepoints.
* @param numQualitySeeds number of high quality seeds in seed confirmation.
* @param seedFilterState holds quantities used in seed filter
* @param cont The container to write the resulting seeds to.
*/
template <typename output_container_t>
void filterCandidates(internal_sp_t &middle,
std::vector<internal_sp_t *> &bottom,
std::vector<internal_sp_t *> &top, int numQualitySeeds,
std::vector<internal_sp_t *> &top,
SeedFilterState seedFilterState,
output_container_t &cont) const;

/**
Expand Down
22 changes: 11 additions & 11 deletions Core/include/Acts/Seeding/SeedFinderOrthogonal.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ template <typename external_spacepoint_t>
template <typename output_container_t>
void SeedFinderOrthogonal<external_spacepoint_t>::filterCandidates(
internal_sp_t &middle, std::vector<internal_sp_t *> &bottom,
std::vector<internal_sp_t *> &top, int numQualitySeeds,
std::vector<internal_sp_t *> &top, SeedFilterState seedFilterState,
output_container_t &cont) const {
float rM = middle.radius();
float varianceRM = middle.varianceR();
Expand Down Expand Up @@ -284,11 +284,9 @@ void SeedFinderOrthogonal<external_spacepoint_t>::filterCandidates(
top[t]->z() - middle.z()));
}

int numSeeds = 0;

for (size_t b = 0; b < numBotSP; b++) {
auto lb = linCircleBottom[b];
float Zob = lb.Zo;
seedFilterState.zOrigin = lb.Zo;
float cotThetaB = lb.cotTheta;
float Vb = lb.V;
float Ub = lb.U;
Expand Down Expand Up @@ -391,9 +389,9 @@ void SeedFinderOrthogonal<external_spacepoint_t>::filterCandidates(
}
}
if (!top_valid.empty()) {
m_config.seedFilter->filterSeeds_2SpFixed(
*bottom[b], middle, top_valid, curvatures, impactParameters, Zob,
numQualitySeeds, numSeeds, cont);
m_config.seedFilter->filterSeeds_2SpFixed(*bottom[b], middle, top_valid,
curvatures, impactParameters,
seedFilterState, cont);
}
}
}
Expand Down Expand Up @@ -551,28 +549,30 @@ void SeedFinderOrthogonal<external_spacepoint_t>::processFromMiddleSP(
float, std::unique_ptr<const InternalSeed<external_spacepoint_t>>>>
protoseeds;

int numQualitySeeds = 0;
// TODO: add seed confirmation
SeedFilterState seedFilterState;

/*
* If we have candidates for increasing z tracks, we try to combine them.
*/
if (!bottom_lh_v.empty() && !top_lh_v.empty()) {
filterCandidates(middle, bottom_lh_v, top_lh_v, numQualitySeeds,
filterCandidates(middle, bottom_lh_v, top_lh_v, seedFilterState,
protoseeds);
}

/*
* Try to combine candidates for decreasing z tracks.
*/
if (!bottom_hl_v.empty() && !top_hl_v.empty()) {
filterCandidates(middle, bottom_hl_v, top_hl_v, numQualitySeeds,
filterCandidates(middle, bottom_hl_v, top_hl_v, seedFilterState,
protoseeds);
}

/*
* Run a seed filter, just like in other seeding algorithms.
*/
m_config.seedFilter->filterSeeds_1SpFixed(protoseeds, numQualitySeeds,
m_config.seedFilter->filterSeeds_1SpFixed(protoseeds,
seedFilterState.numQualitySeeds,
std::back_inserter(out_cont));
}

Expand Down
2 changes: 1 addition & 1 deletion Examples/Algorithms/TrackFinding/src/SeedingAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ ActsExamples::ProcessCode ActsExamples::SeedingAlgorithm::execute(
// run the seeding
static thread_local SimSeedContainer seeds;
seeds.clear();
static thread_local decltype(finder)::State state;
static thread_local decltype(finder)::SeedingState state;

auto group = spacePointsGrouping.begin();
auto groupEnd = spacePointsGrouping.end();
Expand Down
2 changes: 1 addition & 1 deletion Tests/UnitTests/Core/Seeding/SeedFinderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ int main(int argc, char** argv) {
std::move(grid), rRangeSPExtent, config);

std::vector<std::vector<Acts::Seed<SpacePoint>>> seedVector;
decltype(a)::State state;
decltype(a)::SeedingState state;
auto start = std::chrono::system_clock::now();
auto groupIt = spGroup.begin();
auto endOfGroups = spGroup.end();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ int main(int argc, char** argv) {
groupIt = spGroup.begin();

if (do_cpu) {
decltype(seedFinder_cpu)::State state;
decltype(seedFinder_cpu)::SeedingState state;
for (int i_s = 0; i_s < skip; i_s++)
++groupIt;
for (; !(groupIt == spGroup.end()); ++groupIt) {
Expand Down
2 changes: 1 addition & 1 deletion Tests/UnitTests/Plugins/Cuda/Seeding2/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ int main(int argc, char* argv[]) {
// Perform the seed finding.
if (!cmdl.onlyGPU) {
auto spGroup_itr = spGroup.begin();
decltype(seedFinder_host)::State state;
decltype(seedFinder_host)::SeedingState state;
for (std::size_t i = 0;
spGroup_itr != spGroup_end && i < cmdl.groupsToIterate;
++i, ++spGroup_itr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ auto main(int argc, char** argv) -> int {
std::vector<std::vector<Acts::Seed<SpacePoint>>> seedVector_cpu;

if (!cmdlTool.onlyGpu) {
decltype(normalSeedFinder)::State state;
decltype(normalSeedFinder)::SeedingState state;
for (auto groupIt = spGroup.begin(); !(groupIt == spGroup.end());
++groupIt) {
normalSeedFinder.createSeedsForGroup(
Expand Down

0 comments on commit f20738e

Please sign in to comment.