diff --git a/src/forest.cpp b/src/forest.cpp index 968fe95c..ddd247a0 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -770,8 +770,9 @@ cpp11::writable::integers get_overall_split_counts_active_forest_cpp(cpp11::exte StochTree::Tree* tree = active_forest->GetTree(i); std::vector split_nodes = tree->GetInternalNodes(); for (int j = 0; j < split_nodes.size(); j++) { - auto split_feature = split_nodes.at(j); - output.at(split_feature)++; + auto node_id = split_nodes.at(j); + auto feature_split = tree->SplitIndex(node_id); + output.at(feature_split)++; } } return output; @@ -786,8 +787,9 @@ cpp11::writable::integers get_granular_split_count_array_active_forest_cpp(cpp11 StochTree::Tree* tree = active_forest->GetTree(i); std::vector split_nodes = tree->GetInternalNodes(); for (int j = 0; j < split_nodes.size(); j++) { - auto split_feature = split_nodes.at(j); - output.at(split_feature*num_trees + i)++; + auto node_id = split_nodes.at(j); + auto feature_split = tree->SplitIndex(node_id); + output.at(feature_split*num_trees + i)++; } } return output; diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 950caeb8..66621d52 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -501,7 +501,8 @@ class ForestContainerCpp { StochTree::Tree* tree = ensemble->GetTree(tree_num); std::vector split_nodes = tree->GetInternalNodes(); for (int i = 0; i < split_nodes.size(); i++) { - auto split_feature = split_nodes.at(i); + auto node_id = split_nodes.at(i); + auto split_feature = tree->SplitIndex(node_id); accessor(split_feature)++; } return result;