Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancements to GBRForest/Tree for converting TMVA classifiers (74X) #10370

Merged
merged 1 commit into the base branch on Jul 30, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions CondFormats/EgammaObjects/interface/GBRForest.h
Expand Up @@ -36,7 +36,11 @@
virtual ~GBRForest();

double GetResponse(const float* vector) const;
double GetClassifier(const float* vector) const;
double GetGradBoostClassifier(const float* vector) const;
double GetAdaBoostClassifier(const float* vector) const { return GetResponse(vector); }

//for backwards-compatibility: GetClassifier historically assumed gradient
//boosting, so it remains an alias for GetGradBoostClassifier
double GetClassifier(const float* vector) const { return GetGradBoostClassifier(vector); }

void SetInitialResponse(double response) { fInitialResponse = response; }

Expand All @@ -61,7 +65,7 @@ inline double GBRForest::GetResponse(const float* vector) const {
}

//_______________________________________________________________________
inline double GBRForest::GetClassifier(const float* vector) const {
inline double GBRForest::GetGradBoostClassifier(const float* vector) const {
double response = GetResponse(vector);
return 2.0/(1.0+exp(-2.0*response))-1; //MVA output between -1 and 1
}
Expand Down
4 changes: 2 additions & 2 deletions CondFormats/EgammaObjects/interface/GBRTree.h
Expand Up @@ -38,7 +38,7 @@
public:

GBRTree();
explicit GBRTree(const TMVA::DecisionTree *tree);
explicit GBRTree(const TMVA::DecisionTree *tree, double scale, bool useyesnoleaf, bool adjustboundary);
virtual ~GBRTree();

double GetResponse(const float* vector) const;
Expand All @@ -65,7 +65,7 @@
unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node);

void AddNode(const TMVA::DecisionTreeNode *node);
void AddNode(const TMVA::DecisionTreeNode *node, double scale, bool isregression, bool useyesnoleaf, bool adjustboundary);

std::vector<unsigned char> fCutIndices;
std::vector<float> fCutVals;
Expand Down
23 changes: 20 additions & 3 deletions CondFormats/EgammaObjects/src/GBRForest.cxx
Expand Up @@ -21,17 +21,34 @@ GBRForest::~GBRForest()
GBRForest::GBRForest(const TMVA::MethodBDT *bdt)
{
  // Convert a trained TMVA BDT into the lightweight GBRForest representation.
  //
  // Non-gradient-boosted (i.e. AdaBoost) classifiers need their per-tree
  // responses renormalized by the boost weights for evaluation, so detect that
  // flavor up front from the training options.
  const bool isregression    = bdt->DoRegression();
  const bool isadaclassifier = !isregression && !bdt->GetOptions().Contains("~BoostType=Grad");
  const bool useyesnoleaf    = isadaclassifier && bdt->GetOptions().Contains("~UseYesNoLeaf=True");

  // Newer TMVA versions use >= instead of > in decision tree splits, so the
  // stored cut values must be adjusted to reproduce the training-time behaviour.
  const auto trainedversion = bdt->GetTrainingROOTVersionCode();
  const bool adjustboundaries =
      (trainedversion >= ROOT_VERSION(5,34,20) && trainedversion < ROOT_VERSION(6,0,0)) ||
      trainedversion >= ROOT_VERSION(6,2,0);

  // Regressions carry their offset in the first boost weight;
  // classifiers start from zero.
  fInitialResponse = isregression ? bdt->GetBoostWeights().front() : 0.;

  // Total boost weight, used below to normalize the per-tree scales
  // for AdaBoost classifiers.
  double norm = 0.;
  if (isadaclassifier) {
    for (double weight : bdt->GetBoostWeights()) {
      norm += weight;
    }
  }

  const std::vector<TMVA::DecisionTree*> &forest = bdt->GetForest();
  fTrees.reserve(forest.size());
  for (unsigned int itree = 0; itree < forest.size(); ++itree) {
    // AdaBoost trees are weighted by their normalized boost weight;
    // gradient-boosted and regression trees contribute unscaled.
    const double scale = isadaclassifier ? bdt->GetBoostWeights()[itree]/norm : 1.0;
    fTrees.push_back(GBRTree(forest[itree], scale, useyesnoleaf, adjustboundaries));
  }

}
Expand Down
33 changes: 26 additions & 7 deletions CondFormats/EgammaObjects/src/GBRTree.cxx
Expand Up @@ -13,7 +13,7 @@ GBRTree::GBRTree()
}

//_______________________________________________________________________
GBRTree::GBRTree(const TMVA::DecisionTree *tree)
GBRTree::GBRTree(const TMVA::DecisionTree *tree, double scale, bool useyesnoleaf, bool adjustboundary)
{

//printf("boostweights size = %i, forest size = %i\n",bdt->GetBoostWeights().size(),bdt->GetForest().size());
Expand All @@ -29,7 +29,7 @@ GBRTree::GBRTree(const TMVA::DecisionTree *tree)
fRightIndices.reserve(nIntermediate);
fResponses.reserve(nTerminal);

AddNode((TMVA::DecisionTreeNode*)tree->GetRoot());
AddNode((TMVA::DecisionTreeNode*)tree->GetRoot(), scale, tree->DoRegression(), useyesnoleaf, adjustboundary);

//special case, root node is terminal, create fake intermediate node at root
if (fCutIndices.size()==0) {
Expand Down Expand Up @@ -73,17 +73,36 @@ unsigned int GBRTree::CountTerminalNodes(const TMVA::DecisionTreeNode *node) {


//_______________________________________________________________________
void GBRTree::AddNode(const TMVA::DecisionTreeNode *node) {
void GBRTree::AddNode(const TMVA::DecisionTreeNode *node, double scale, bool isregression, bool useyesnoleaf, bool adjustboundary) {

if (!node->GetLeft() || !node->GetRight() || node->IsTerminal()) {
fResponses.push_back(node->GetResponse());
double response = 0.;
if (isregression) {
response = node->GetResponse();
}
else {
if (useyesnoleaf) {
response = double(node->GetNodeType());
}
else {
response = node->GetPurity();
}
}
response *= scale;
fResponses.push_back(response);
return;
}
else {
int thisidx = fCutIndices.size();

fCutIndices.push_back(node->GetSelector());
fCutVals.push_back(node->GetCutValue());
float cutval = node->GetCutValue();
//newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
//to reproduce the correct behaviour
if (adjustboundary) {
cutval = std::nextafter(cutval,std::numeric_limits<float>::lowest());
}
fCutVals.push_back(cutval);
fLeftIndices.push_back(0);
fRightIndices.push_back(0);

Expand All @@ -105,15 +124,15 @@ void GBRTree::AddNode(const TMVA::DecisionTreeNode *node) {
else {
fLeftIndices[thisidx] = fCutIndices.size();
}
AddNode(left);
AddNode(left, scale, isregression, useyesnoleaf, adjustboundary);

if (!right->GetLeft() || !right->GetRight() || right->IsTerminal()) {
fRightIndices[thisidx] = -fResponses.size();
}
else {
fRightIndices[thisidx] = fCutIndices.size();
}
AddNode(right);
AddNode(right, scale, isregression, useyesnoleaf, adjustboundary);

}

Expand Down