Skip to content

Commit

Permalink
Merge pull request #20429 from makortel/cleanupTrackMVA
Browse files Browse the repository at this point in the history
Refactor track MVA classifier code
  • Loading branch information
cmsbuild committed Sep 20, 2017
2 parents 8162129 + ef0d666 commit 6bf2255
Show file tree
Hide file tree
Showing 18 changed files with 99 additions and 81 deletions.
Expand Up @@ -102,7 +102,6 @@
#### select tracks based on MaximumImpactParameter, MaximumZ, MinimumTotalLayers, MinimumPixelLayers and MaximumNormChi2
process.pixelTracksCutClassifier = cms.EDProducer( "TrackCutClassifier",
src = cms.InputTag( "pixelTracks" ),
GBRForestLabel = cms.string( "" ),
beamspot = cms.InputTag( "offlineBeamSpot" ),
# vertices = cms.InputTag( "pixelVertices" ),
vertices = cms.InputTag( "" ),
Expand Down
13 changes: 13 additions & 0 deletions HLTrigger/Configuration/python/customizeHLTforCMSSW.py
Expand Up @@ -51,6 +51,18 @@ def customiseFor20422(process):
producer.applyDCConstraint = m2Parameters.applyDCConstraint
return process

# Refactor track MVA classifiers
def customiseFor20429(process):
for producer in producers_by_type(process, "TrackMVAClassifierDetached", "TrackMVAClassifierPrompt"):
producer.mva.GBRForestLabel = producer.GBRForestLabel
producer.mva.GBRForestFileName = producer.GBRForestFileName
del producer.GBRForestLabel
del producer.GBRForestFileName
for producer in producers_by_type(process, "TrackCutClassifier"):
del producer.GBRForestLabel
del producer.GBRForestFileName
return process

# CMSSW version specific customizations
def customizeHLTforCMSSW(process, menuType="GRun"):

Expand All @@ -61,5 +73,6 @@ def customizeHLTforCMSSW(process, menuType="GRun"):
process = customiseFor20269(process)
process = customiseFor19989(process)
process = customiseFor20422(process)
process = customiseFor20429(process)

return process
26 changes: 13 additions & 13 deletions RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h
Expand Up @@ -23,7 +23,7 @@
class TrackMVAClassifierBase : public edm::stream::EDProducer<> {
public:
explicit TrackMVAClassifierBase( const edm::ParameterSet & cfg );
~TrackMVAClassifierBase();
~TrackMVAClassifierBase() override;
protected:

static void fill( edm::ParameterSetDescription& desc);
Expand All @@ -32,18 +32,15 @@ class TrackMVAClassifierBase : public edm::stream::EDProducer<> {
using MVACollection = std::vector<float>;
using QualityMaskCollection = std::vector<unsigned char>;

virtual void initEvent(const edm::EventSetup& es) = 0;

virtual void computeMVA(reco::TrackCollection const & tracks,
reco::BeamSpot const & beamSpot,
reco::VertexCollection const & vertices,
GBRForest const * forestP,
MVACollection & mvas) const = 0;


private:

void beginStream(edm::StreamID) override final;

void produce(edm::Event& evt, const edm::EventSetup& es ) override final;
void produce(edm::Event& evt, const edm::EventSetup& es ) final;

/// source collection label
edm::EDGetTokenT<reco::TrackCollection> src_;
Expand All @@ -53,10 +50,6 @@ class TrackMVAClassifierBase : public edm::stream::EDProducer<> {
bool ignoreVertices_;

// MVA
std::unique_ptr<GBRForest> forest_;
const std::string forestLabel_;
const std::string dbFileName_;
const bool useForestFromDB_;

// qualitycuts (loose, tight, hp)
float qualityCuts[3];
Expand All @@ -81,15 +74,22 @@ class TrackMVAClassifier : public TrackMVAClassifierBase {


private:
void beginStream(edm::StreamID) final {
mva.beginStream();
}

void initEvent(const edm::EventSetup& es) final {
mva.initEvent(es);
}

void computeMVA(reco::TrackCollection const & tracks,
reco::BeamSpot const & beamSpot,
reco::VertexCollection const & vertices,
GBRForest const * forestP,
MVACollection & mvas) const final {

size_t current = 0;
for (auto const & trk : tracks) {
mvas[current++]= mva(trk,beamSpot,vertices,forestP);
mvas[current++]= mva(trk,beamSpot,vertices);
}
}

Expand Down
@@ -1,23 +1,47 @@
#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"

#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/ESHandle.h"
#include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"

#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/VertexReco/interface/Vertex.h"
#include <limits>

#include "getBestVertex.h"

#include "TFile.h"

namespace {

template<bool PROMPT>
struct mva {
mva(const edm::ParameterSet &){}
mva(const edm::ParameterSet &cfg):
forestLabel_ ( cfg.getParameter<std::string>("GBRForestLabel") ),
dbFileName_ ( cfg.getParameter<std::string>("GBRForestFileName") ),
useForestFromDB_( (!forestLabel_.empty()) & dbFileName_.empty())
{}

void beginStream() {
if(!dbFileName_.empty()){
TFile gbrfile(dbFileName_.c_str());
forestFromFile_.reset((GBRForest*)gbrfile.Get(forestLabel_.c_str()));
}
}

void initEvent(const edm::EventSetup& es) {
forest_ = forestFromFile_.get();
if(useForestFromDB_){
edm::ESHandle<GBRForest> forestHandle;
es.get<GBRWrapperRcd>().get(forestLabel_,forestHandle);
forest_ = forestHandle.product();
}
}

float operator()(reco::Track const & trk,
reco::BeamSpot const & beamSpot,
reco::VertexCollection const & vertices,
GBRForest const * forestP) const {
reco::VertexCollection const & vertices) const {

auto const & forest = *forestP;
auto tmva_pt_ = trk.pt();
auto tmva_ndof_ = trk.ndof();
auto tmva_nlayers_ = trk.hitPattern().trackerLayersWithMeasurement();
Expand Down Expand Up @@ -77,17 +101,22 @@ struct mva {




return forest.GetClassifier(gbrVals_);
return forest_->GetClassifier(gbrVals_);

}

static const char * name();

static void fillDescriptions(edm::ParameterSetDescription & desc) {
desc.add<std::string>("GBRForestLabel",std::string());
desc.add<std::string>("GBRForestFileName",std::string());
}


std::unique_ptr<GBRForest> forestFromFile_;
const GBRForest *forest_ = nullptr; // owned by somebody else
const std::string forestLabel_;
const std::string dbFileName_;
const bool useForestFromDB_;
};

using TrackMVAClassifierDetached = TrackMVAClassifier<mva<false>>;
Expand Down
6 changes: 3 additions & 3 deletions RecoTracker/FinalTrackSelectors/plugins/TrackCutClassifier.cc
Expand Up @@ -175,12 +175,12 @@ namespace {
fillArrayF(drWPVerr_par, dr_par,"drWPVerr_par");
}


void beginStream() {}
void initEvent(const edm::EventSetup&) {}

float operator()(reco::Track const & trk,
reco::BeamSpot const & beamSpot,
reco::VertexCollection const & vertices,
GBRForest const *) const {
reco::VertexCollection const & vertices) const {

float ret = 1.f;
// minimum number of hits for by-passing the other checks
Expand Down
4 changes: 2 additions & 2 deletions RecoTracker/FinalTrackSelectors/python/classifierTest_cff.py
Expand Up @@ -6,7 +6,7 @@

testTrackClassifier1 = TrackMVAClassifierPrompt.clone()
testTrackClassifier1.src = 'initialStepTracks'
testTrackClassifier1.GBRForestLabel = 'MVASelectorIter0_13TeV'
testTrackClassifier1.mva.GBRForestLabel = 'MVASelectorIter0_13TeV'
testTrackClassifier1.qualityCuts = [-0.9,-0.8,-0.7]


Expand All @@ -27,7 +27,7 @@

testTrackClassifier3 = TrackMVAClassifierDetached.clone()
testTrackClassifier3.src = 'detachedTripletStepTracks'
testTrackClassifier3.GBRForestLabel = 'MVASelectorIter3_13TeV'
testTrackClassifier3.mva.GBRForestLabel = 'MVASelectorIter3_13TeV'
testTrackClassifier3.qualityCuts = [-0.5,0.0,0.5]


Expand Down
31 changes: 4 additions & 27 deletions RecoTracker/FinalTrackSelectors/src/TrackMVAClassifierBase.cc
@@ -1,8 +1,5 @@
#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"

#include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"
#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/ESHandle.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"

#include "DataFormats/TrackReco/interface/Track.h"
Expand All @@ -16,8 +13,6 @@ void TrackMVAClassifierBase::fill( edm::ParameterSetDescription& desc) {
desc.add<edm::InputTag>("beamspot",edm::InputTag("offlineBeamSpot"));
desc.add<edm::InputTag>("vertices",edm::InputTag("firstStepPrimaryVertices"));
desc.add<bool>("ignoreVertices",false);
desc.add<std::string>("GBRForestLabel",std::string());
desc.add<std::string>("GBRForestFileName",std::string());
// default cuts for "cut based classification"
std::vector<double> cuts = {-.7, 0.1, .7};
desc.add<std::vector<double>>("qualityCuts", cuts);
Expand All @@ -30,10 +25,7 @@ TrackMVAClassifierBase::TrackMVAClassifierBase( const edm::ParameterSet & cfg )
src_ ( consumes<reco::TrackCollection> (cfg.getParameter<edm::InputTag>( "src" )) ),
beamspot_( consumes<reco::BeamSpot> (cfg.getParameter<edm::InputTag>( "beamspot" )) ),
vertices_( mayConsume<reco::VertexCollection>(cfg.getParameter<edm::InputTag>( "vertices" )) ),
ignoreVertices_( cfg.getParameter<bool>( "ignoreVertices" ) ),
forestLabel_ ( cfg.getParameter<std::string>("GBRForestLabel") ),
dbFileName_ ( cfg.getParameter<std::string>("GBRForestFileName") ),
useForestFromDB_( (!forestLabel_.empty()) & dbFileName_.empty()) {
ignoreVertices_( cfg.getParameter<bool>( "ignoreVertices" ) ) {

auto const & qv = cfg.getParameter<std::vector<double>>("qualityCuts");
assert(qv.size()==3);
Expand All @@ -59,24 +51,19 @@ void TrackMVAClassifierBase::produce(edm::Event& evt, const edm::EventSetup& es
edm::Handle<reco::VertexCollection> hVtx;
evt.getByToken(vertices_, hVtx);

GBRForest const * forest = forest_.get();
if(useForestFromDB_){
edm::ESHandle<GBRForest> forestHandle;
es.get<GBRWrapperRcd>().get(forestLabel_,forestHandle);
forest = forestHandle.product();
}
initEvent(es);

// products
auto mvas = std::make_unique<MVACollection>(tracks.size(),-99.f);
auto quals = std::make_unique<QualityMaskCollection>(tracks.size(),0);

if ( hVtx.isValid() && !ignoreVertices_ ) {
computeMVA(tracks,*hBsp,*hVtx,forest,*mvas);
computeMVA(tracks,*hBsp,*hVtx,*mvas);
} else {
if ( !ignoreVertices_ )
edm::LogWarning("TrackMVAClassifierBase") << "ignoreVertices is set to False in the configuration, but the vertex collection is not valid";
std::vector<reco::Vertex> vertices;
computeMVA(tracks,*hBsp,vertices,forest,*mvas);
computeMVA(tracks,*hBsp,vertices,*mvas);
}
assert((*mvas).size()==tracks.size());

Expand All @@ -95,13 +82,3 @@ void TrackMVAClassifierBase::produce(edm::Event& evt, const edm::EventSetup& es
evt.put(std::move(quals),"QualityMasks");

}


#include <TFile.h>
void TrackMVAClassifierBase::beginStream(edm::StreamID) {
if(!dbFileName_.empty()){
TFile gbrfile(dbFileName_.c_str());
forest_.reset((GBRForest*)gbrfile.Get(forestLabel_.c_str()));
}
}

Expand Up @@ -190,7 +190,7 @@
from RecoTracker.FinalTrackSelectors.TrackMVAClassifierDetached_cfi import *
detachedQuadStep = TrackMVAClassifierDetached.clone(
src = 'detachedQuadStepTracks',
GBRForestLabel = 'MVASelectorDetachedQuadStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorDetachedQuadStep_Phase1'),
qualityCuts = [-0.5,0.0,0.5],
)

Expand Down
Expand Up @@ -182,23 +182,23 @@
from RecoTracker.FinalTrackSelectors.TrackMVAClassifierDetached_cfi import *
detachedTripletStepClassifier1 = TrackMVAClassifierDetached.clone()
detachedTripletStepClassifier1.src = 'detachedTripletStepTracks'
detachedTripletStepClassifier1.GBRForestLabel = 'MVASelectorIter3_13TeV'
detachedTripletStepClassifier1.mva.GBRForestLabel = 'MVASelectorIter3_13TeV'
detachedTripletStepClassifier1.qualityCuts = [-0.5,0.0,0.5]
detachedTripletStepClassifier2 = TrackMVAClassifierPrompt.clone()
detachedTripletStepClassifier2.src = 'detachedTripletStepTracks'
detachedTripletStepClassifier2.GBRForestLabel = 'MVASelectorIter0_13TeV'
detachedTripletStepClassifier2.mva.GBRForestLabel = 'MVASelectorIter0_13TeV'
detachedTripletStepClassifier2.qualityCuts = [-0.2,0.0,0.4]

from RecoTracker.FinalTrackSelectors.ClassifierMerger_cfi import *
detachedTripletStep = ClassifierMerger.clone()
detachedTripletStep.inputClassifiers=['detachedTripletStepClassifier1','detachedTripletStepClassifier2']

trackingPhase1.toReplaceWith(detachedTripletStep, detachedTripletStepClassifier1.clone(
GBRForestLabel = 'MVASelectorDetachedTripletStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorDetachedTripletStep_Phase1'),
qualityCuts = [-0.2,0.3,0.8],
))
trackingPhase1QuadProp.toReplaceWith(detachedTripletStep, detachedTripletStepClassifier1.clone(
GBRForestLabel = 'MVASelectorDetachedTripletStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorDetachedTripletStep_Phase1'),
qualityCuts = [-0.2,0.3,0.8],
))

Expand Down
Expand Up @@ -206,7 +206,7 @@
from RecoTracker.FinalTrackSelectors.TrackMVAClassifierPrompt_cfi import *
highPtTripletStep = TrackMVAClassifierPrompt.clone(
src = 'highPtTripletStepTracks',
GBRForestLabel = 'MVASelectorHighPtTripletStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorHighPtTripletStep_Phase1'),
qualityCuts = [0.2,0.3,0.4],
)

Expand Down
6 changes: 3 additions & 3 deletions RecoTracker/IterativeTracking/python/InitialStep_cff.py
Expand Up @@ -248,7 +248,7 @@

initialStepClassifier1 = TrackMVAClassifierPrompt.clone()
initialStepClassifier1.src = 'initialStepTracks'
initialStepClassifier1.GBRForestLabel = 'MVASelectorIter0_13TeV'
initialStepClassifier1.mva.GBRForestLabel = 'MVASelectorIter0_13TeV'
initialStepClassifier1.qualityCuts = [-0.9,-0.8,-0.7]

from RecoTracker.IterativeTracking.DetachedTripletStep_cff import detachedTripletStepClassifier1
Expand All @@ -263,11 +263,11 @@
initialStep.inputClassifiers=['initialStepClassifier1','initialStepClassifier2','initialStepClassifier3']

trackingPhase1.toReplaceWith(initialStep, initialStepClassifier1.clone(
GBRForestLabel = 'MVASelectorInitialStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorInitialStep_Phase1'),
qualityCuts = [-0.95,-0.85,-0.75],
))
trackingPhase1QuadProp.toReplaceWith(initialStep, initialStepClassifier1.clone(
GBRForestLabel = 'MVASelectorInitialStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorInitialStep_Phase1'),
qualityCuts = [-0.95,-0.85,-0.75],
))

Expand Down
Expand Up @@ -172,12 +172,12 @@
from RecoTracker.FinalTrackSelectors.TrackMVAClassifierPrompt_cfi import *
trackingPhase1.toReplaceWith(jetCoreRegionalStep, TrackMVAClassifierPrompt.clone(
src = 'jetCoreRegionalStepTracks',
GBRForestLabel = 'MVASelectorJetCoreRegionalStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorJetCoreRegionalStep_Phase1'),
qualityCuts = [-0.2,0.0,0.4],
))
trackingPhase1QuadProp.toReplaceWith(jetCoreRegionalStep, TrackMVAClassifierPrompt.clone(
src = 'jetCoreRegionalStepTracks',
GBRForestLabel = 'MVASelectorJetCoreRegionalStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorJetCoreRegionalStep_Phase1'),
qualityCuts = [-0.2,0.0,0.4],
))

Expand Down
2 changes: 1 addition & 1 deletion RecoTracker/IterativeTracking/python/LowPtQuadStep_cff.py
Expand Up @@ -181,7 +181,7 @@
from RecoTracker.FinalTrackSelectors.TrackMVAClassifierPrompt_cfi import *
lowPtQuadStep = TrackMVAClassifierPrompt.clone(
src = 'lowPtQuadStepTracks',
GBRForestLabel = 'MVASelectorLowPtQuadStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorLowPtQuadStep_Phase1'),
qualityCuts = [-0.65,-0.35,-0.15],
)

Expand Down
6 changes: 3 additions & 3 deletions RecoTracker/IterativeTracking/python/LowPtTripletStep_cff.py
Expand Up @@ -215,15 +215,15 @@
from RecoTracker.FinalTrackSelectors.TrackMVAClassifierPrompt_cfi import *
lowPtTripletStep = TrackMVAClassifierPrompt.clone()
lowPtTripletStep.src = 'lowPtTripletStepTracks'
lowPtTripletStep.GBRForestLabel = 'MVASelectorIter1_13TeV'
lowPtTripletStep.mva.GBRForestLabel = 'MVASelectorIter1_13TeV'
lowPtTripletStep.qualityCuts = [-0.6,-0.3,-0.1]

trackingPhase1.toReplaceWith(lowPtTripletStep, lowPtTripletStep.clone(
GBRForestLabel = 'MVASelectorLowPtTripletStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorLowPtTripletStep_Phase1'),
qualityCuts = [0.0,0.2,0.4],
))
trackingPhase1QuadProp.toReplaceWith(lowPtTripletStep, lowPtTripletStep.clone(
GBRForestLabel = 'MVASelectorLowPtTripletStep_Phase1',
mva = dict(GBRForestLabel = 'MVASelectorLowPtTripletStep_Phase1'),
qualityCuts = [0.0,0.2,0.4],
))

Expand Down

0 comments on commit 6bf2255

Please sign in to comment.