Skip to content

Commit

Permalink
- RecoTauTag/RecoTau/test/runDeepTauIDsOnMiniAOD.py: Added changes ma…
Browse files Browse the repository at this point in the history
…de in the commit 194a1d5 from the PR cms-sw#25016

- RecoTauTag/RecoTau/plugins/DeepTauId.cc: code cleaning
  • Loading branch information
MRD2F committed Dec 3, 2018
1 parent 97f99a1 commit d533baf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
3 changes: 1 addition & 2 deletions RecoTauTag/RecoTau/plugins/DeepTauId.cc
Expand Up @@ -252,7 +252,6 @@ class DeepTauId : public deep_tau::DeepTauBase {
desc.add<std::string>("graph_file", "RecoTauTag/TrainingFiles/data/DeepTauId/deepTau_2017v1_20L1024N_quantized.pb");
desc.add<bool>("mem_mapped", false);


edm::ParameterSetDescription descWP;
descWP.add<std::string>("VVVLoose", "0");
descWP.add<std::string>("VVLoose", "0");
Expand Down Expand Up @@ -322,7 +321,7 @@ class DeepTauId : public deep_tau::DeepTauBase {
static constexpr bool check_all_set = false;
static constexpr float default_value_for_set_check = -42;
static const TauIdMVAAuxiliaries clusterVariables;

tensorflow::Tensor inputs(tensorflow::DT_FLOAT, { 1, dnn_inputs_2017v1::NumberOfInputs});
const auto& get = [&](int var_index) -> float& { return inputs.matrix<float>()(0, var_index); };
auto leadChargedHadrCand = dynamic_cast<const pat::PackedCandidate*>(tau.leadChargedHadrCand().get());
Expand Down
22 changes: 15 additions & 7 deletions RecoTauTag/RecoTau/python/tools/runTauIdMVA.py
Expand Up @@ -49,19 +49,27 @@ def __init__(self, process, cms, debug = False,
@staticmethod
def get_cmssw_version(debug = False):
"""returns 'CMSSW_X_Y_Z'"""
if debug: print "get_cmssw_version:", os.environ["CMSSW_RELEASE_BASE"].split('/')[-1]
return os.environ["CMSSW_RELEASE_BASE"].split('/')[-1]
cmssw_version = os.environ["CMSSW_VERSION"]
if debug: print "get_cmssw_version:", cmssw_version
return cmssw_version

@classmethod
def get_cmssw_version_number(klass, debug = False):
"""returns 'X_Y_Z' (without 'CMSSW_')"""
if debug: print "get_cmssw_version_number:", map(int, klass.get_cmssw_version().split("CMSSW_")[1].split("_")[0:3])
return map(int, klass.get_cmssw_version().split("CMSSW_")[1].split("_")[0:3])
"""returns '(release, subversion, patch)' (without 'CMSSW_')"""
v = klass.get_cmssw_version().split("CMSSW_")[1].split("_")[0:3]
if debug: print "get_cmssw_version_number:", v
if v[2] == "X":
patch = -1
else:
patch = int(v[2])
return int(v[0]), int(v[1]), patch

@staticmethod
def versionToInt(release=9, subversion=4, patch=0, debug = False):
if debug: print "versionToInt:", release * 10000 + subversion * 100 + patch
return release * 10000 + subversion * 100 + patch
version = release * 10000 + subversion * 100 + patch + 1 # shifted by one to account for pre-releases.
if debug: print "versionToInt:", version
return version


@classmethod
def is_above_cmssw_version(klass, release=9, subversion=4, patch=0, debug = False):
Expand Down

0 comments on commit d533baf

Please sign in to comment.