-
Notifications
You must be signed in to change notification settings - Fork 157
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add CSV writer for MultiTrajectory (#809)
This PR is for adding a csv writer for MultiTrajectory. The code goes though the trajectories from the CFK step and extract some information (e.g. nMeasurements, nOutliers, nHoles, chi2Sum ... ). Selection criteria applied are taken from the CKF performance algorithm.
- Loading branch information
1 parent
7ba646a
commit c7438cf
Showing
4 changed files
with
286 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvMultiTrajectoryWriter.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2021 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#pragma once | ||
|
||
#include "Acts/EventData/MultiTrajectoryHelpers.hpp" | ||
#include "ActsExamples/EventData/Trajectories.hpp" | ||
#include "ActsExamples/Framework/WriterT.hpp" | ||
#include "ActsFatras/EventData/Barcode.hpp" | ||
|
||
#include <fstream> | ||
|
||
using namespace Acts::UnitLiterals; | ||
|
||
namespace ActsExamples { | ||
|
||
/// @class CsvMultiTrajectoryWriter | ||
/// | ||
/// Write out the tracks reconstructed using Combinatorial Kalman Filter in | ||
/// comma-separated-value format. | ||
/// | ||
/// This writes one file per event into the configured output directory. By | ||
/// default it writes to the current working directory. | ||
/// Files are named using the following schema | ||
/// | ||
/// event000000001-CKFtracks.csv | ||
/// event000000002-CKFtracks.csv | ||
/// | ||
/// and each line in the file corresponds to one track. | ||
class CsvMultiTrajectoryWriter : public WriterT<TrajectoriesContainer> { | ||
public: | ||
struct Config { | ||
std::string inputTrajectories; ///< Input trajectory collection | ||
std::string outputDir; ///< where to place output files | ||
std::string | ||
inputMeasurementParticlesMap; ///< Input hit-particles map collection | ||
size_t outputPrecision = 6; ///< floating point precision | ||
size_t nMeasurementsMin = 9; ///< Min number of measurements | ||
double truthMatchProbMin = 0.5; ///< Probability threshold for fake tracks | ||
double ptMin = 1_GeV; ///< Min pt of tracks | ||
}; | ||
|
||
/// constructor | ||
/// @param cfg is the configuration object | ||
/// @parm level is the output logging level | ||
CsvMultiTrajectoryWriter(const Config& cfg, | ||
Acts::Logging::Level level = Acts::Logging::INFO); | ||
|
||
protected: | ||
/// @brief Write method called by the base class | ||
/// @param [in] context is the algorithm context for consistency | ||
/// @param [in] tracks is the track collection | ||
ProcessCode writeT(const AlgorithmContext& context, | ||
const TrajectoriesContainer& trajectories) final override; | ||
|
||
private: | ||
Config m_cfg; //!< Nested configuration struct | ||
|
||
/// @brief Struct for brief trajectory summary info | ||
/// | ||
struct trackInfo : public Acts::MultiTrajectoryHelpers::TrajectoryState { | ||
size_t trackId; | ||
ActsFatras::Barcode particleId; | ||
size_t nMajorityHits; | ||
std::string trackType; | ||
double truthMatchProb; | ||
const TrackParameters* fittedParameters; | ||
}; // trackInfo struct | ||
}; | ||
|
||
} // namespace ActsExamples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2021 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#include "ActsExamples/Io/Csv/CsvMultiTrajectoryWriter.hpp" | ||
|
||
#include "ActsExamples/Utilities/Paths.hpp" | ||
#include "ActsExamples/Validation/TrackClassification.hpp" | ||
|
||
#include <ios> | ||
#include <iostream> | ||
#include <stdexcept> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
using namespace ActsExamples; | ||
|
||
CsvMultiTrajectoryWriter::CsvMultiTrajectoryWriter( | ||
const CsvMultiTrajectoryWriter::Config& cfg, Acts::Logging::Level level) | ||
: WriterT<TrajectoriesContainer>(cfg.inputTrajectories, | ||
"CsvMultiTrajectoryWriter", level), | ||
m_cfg(cfg) { | ||
if (m_cfg.inputTrajectories.empty()) { | ||
throw std::invalid_argument("Missing input trajectories collection"); | ||
} | ||
} | ||
|
||
ProcessCode CsvMultiTrajectoryWriter::writeT( | ||
const AlgorithmContext& context, | ||
const TrajectoriesContainer& trajectories) { | ||
// open per-event file | ||
std::string path = | ||
perEventFilepath(m_cfg.outputDir, "CKFtracks.csv", context.eventNumber); | ||
std::ofstream mos(path, std::ofstream::out | std::ofstream::trunc); | ||
if (!mos) { | ||
throw std::ios_base::failure("Could not open '" + path + "' to write"); | ||
} | ||
|
||
using HitParticlesMap = ActsExamples::IndexMultimap<ActsFatras::Barcode>; | ||
const auto& hitParticlesMap = context.eventStore.get<HitParticlesMap>( | ||
m_cfg.inputMeasurementParticlesMap); | ||
|
||
std::unordered_map<size_t, trackInfo> infoMap; | ||
|
||
// Counter of truth-matched reco tracks | ||
using RecoTrackInfo = std::pair<trackInfo, size_t>; | ||
std::map<ActsFatras::Barcode, std::vector<RecoTrackInfo>> matched; | ||
|
||
size_t trackId = 0; | ||
for (const auto& traj : trajectories) { | ||
// The trajectory entry indices and the multiTrajectory | ||
const auto& trackTips = traj.tips(); | ||
const auto& mj = traj.multiTrajectory(); | ||
if (trackTips.empty()) { | ||
ACTS_WARNING("Empty multiTrajectory."); | ||
continue; | ||
} | ||
|
||
// Loop over all trajectories in a multiTrajectory | ||
for (const size_t& trackTip : trackTips) { | ||
// Collect the trajectory summary info | ||
auto trajState = | ||
Acts::MultiTrajectoryHelpers::trajectoryState(mj, trackTip); | ||
// Reco track selection | ||
//@TODO: add interface for applying others cuts on reco tracks: | ||
// -> pT, d0, z0, detector-specific hits/holes number cut | ||
if (trajState.nMeasurements < m_cfg.nMeasurementsMin) { | ||
continue; | ||
} | ||
|
||
// Check if the reco track has fitted track parameters | ||
if (not traj.hasTrackParameters(trackTip)) { | ||
ACTS_WARNING( | ||
"No fitted track parameters for trajectory with entry index = " | ||
<< trackTip); | ||
continue; | ||
} | ||
|
||
// Get the majority truth particle to this track | ||
std::vector<ParticleHitCount> particleHitCount; | ||
identifyContributingParticles(hitParticlesMap, traj, trackTip, | ||
particleHitCount); | ||
if (particleHitCount.empty()) { | ||
ACTS_WARNING( | ||
"No truth particle associated with this trajectory with entry " | ||
"index = " | ||
<< trackTip); | ||
continue; | ||
} | ||
|
||
// Requirement on the pT of the track | ||
const auto& momentum = traj.trackParameters(trackTip).momentum(); | ||
const auto pT = Acts::VectorHelpers::perp(momentum); | ||
if (pT < m_cfg.ptMin) { | ||
continue; | ||
} | ||
|
||
// Get the majority particle counts | ||
ActsFatras::Barcode majorityParticleId = | ||
particleHitCount.front().particleId; | ||
// n Majority hits | ||
size_t nMajorityHits = particleHitCount.front().hitCount; | ||
|
||
// track info | ||
trackInfo toAdd; | ||
toAdd.trackId = trackId; | ||
toAdd.particleId = majorityParticleId; | ||
toAdd.nStates = trajState.nStates; | ||
toAdd.nMajorityHits = nMajorityHits; | ||
toAdd.nMeasurements = trajState.nMeasurements; | ||
toAdd.nOutliers = trajState.nOutliers; | ||
toAdd.nHoles = trajState.nHoles; | ||
toAdd.chi2Sum = trajState.chi2Sum; | ||
toAdd.NDF = trajState.NDF; | ||
toAdd.truthMatchProb = toAdd.nMajorityHits * 1. / trajState.nMeasurements; | ||
toAdd.fittedParameters = &traj.trackParameters(trackTip); | ||
toAdd.trackType = "unknown"; | ||
|
||
// Check if the trajectory is matched with truth. | ||
if (toAdd.truthMatchProb >= m_cfg.truthMatchProbMin) { | ||
matched[toAdd.particleId].push_back({toAdd, toAdd.trackId}); | ||
} else { | ||
toAdd.trackType = "fake"; | ||
} | ||
|
||
infoMap[toAdd.trackId] = toAdd; | ||
|
||
trackId++; | ||
} // end of one trajectory | ||
} // end of multi-trajectory | ||
|
||
// Find duplicates | ||
std::unordered_set<size_t> listGoodTracks; | ||
for (auto& [particleId, matchedTracks] : matched) { | ||
std::sort(matchedTracks.begin(), matchedTracks.end(), | ||
[](const RecoTrackInfo& lhs, const RecoTrackInfo& rhs) { | ||
// sort by nMajorityHits | ||
if (lhs.first.nMajorityHits > rhs.first.nMajorityHits) | ||
return true; | ||
if (lhs.first.nMajorityHits < rhs.first.nMajorityHits) | ||
return false; | ||
// sort by nOutliers | ||
return (lhs.first.nOutliers < rhs.first.nOutliers); | ||
}); | ||
|
||
listGoodTracks.insert(matchedTracks.front().first.trackId); | ||
} | ||
|
||
// write csv header | ||
mos << "track_id,particleId," | ||
<< "nStates,nMajorityHits,nMeasurements,nOutliers,nHoles," | ||
<< "chi2,ndf,chi2/ndf," | ||
<< "pT," | ||
<< "truthMatchProbability," | ||
<< "good/duplicate/fake"; | ||
|
||
mos << '\n'; | ||
mos << std::setprecision(m_cfg.outputPrecision); | ||
|
||
// good/duplicate/fake = 0/1/2 | ||
for (auto& [id, trajState] : infoMap) { | ||
if (listGoodTracks.find(id) != listGoodTracks.end()) { | ||
trajState.trackType = "good"; | ||
} else if (trajState.trackType != "fake") { | ||
trajState.trackType = "duplicate"; | ||
} | ||
|
||
// write the track info | ||
mos << trajState.trackId << ","; | ||
mos << trajState.particleId << ","; | ||
mos << trajState.nStates << ","; | ||
mos << trajState.nMajorityHits << ","; | ||
mos << trajState.nMeasurements << ","; | ||
mos << trajState.nOutliers << ","; | ||
mos << trajState.nHoles << ","; | ||
mos << trajState.chi2Sum << ","; | ||
mos << trajState.NDF << ","; | ||
mos << trajState.chi2Sum * 1.0 / trajState.NDF << ","; | ||
mos << Acts::VectorHelpers::perp(trajState.fittedParameters->momentum()) | ||
<< ","; | ||
mos << trajState.truthMatchProb << ","; | ||
mos << trajState.trackType; | ||
mos << '\n'; | ||
} | ||
|
||
return ProcessCode::SUCCESS; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters