From b9949a79fb98dfd419468cbb51fb3fc795a60e1e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Oct 2025 16:27:56 -0500 Subject: [PATCH 1/2] Fixed split count extraction bugs --- src/forest.cpp | 6 ++++-- src/py_stochtree.cpp | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/forest.cpp b/src/forest.cpp index 968fe95c..997afe83 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -770,7 +770,8 @@ cpp11::writable::integers get_overall_split_counts_active_forest_cpp(cpp11::exte StochTree::Tree* tree = active_forest->GetTree(i); std::vector split_nodes = tree->GetInternalNodes(); for (int j = 0; j < split_nodes.size(); j++) { - auto split_feature = split_nodes.at(j); + auto node_id = split_nodes.at(j); + auto feature_split = tree->SplitIndex(node_id); output.at(split_feature)++; } } @@ -786,7 +787,8 @@ cpp11::writable::integers get_granular_split_count_array_active_forest_cpp(cpp11 StochTree::Tree* tree = active_forest->GetTree(i); std::vector split_nodes = tree->GetInternalNodes(); for (int j = 0; j < split_nodes.size(); j++) { - auto split_feature = split_nodes.at(j); + auto node_id = split_nodes.at(j); + auto feature_split = tree->SplitIndex(node_id); output.at(split_feature*num_trees + i)++; } } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 950caeb8..66621d52 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -501,7 +501,8 @@ class ForestContainerCpp { StochTree::Tree* tree = ensemble->GetTree(tree_num); std::vector split_nodes = tree->GetInternalNodes(); for (int i = 0; i < split_nodes.size(); i++) { - auto split_feature = split_nodes.at(i); + auto node_id = split_nodes.at(i); + auto split_feature = tree->SplitIndex(node_id); accessor(split_feature)++; } return result; From 2100c1bd92f67f7f73af26464410a6c9e9f839c4 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Oct 2025 16:32:06 -0500 Subject: [PATCH 2/2] Made variable names match --- src/forest.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/forest.cpp b/src/forest.cpp index 997afe83..ddd247a0 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -772,7 +772,7 @@ cpp11::writable::integers get_overall_split_counts_active_forest_cpp(cpp11::exte for (int j = 0; j < split_nodes.size(); j++) { auto node_id = split_nodes.at(j); auto feature_split = tree->SplitIndex(node_id); - output.at(split_feature)++; + output.at(feature_split)++; } } return output; @@ -789,7 +789,7 @@ cpp11::writable::integers get_granular_split_count_array_active_forest_cpp(cpp11 for (int j = 0; j < split_nodes.size(); j++) { auto node_id = split_nodes.at(j); auto feature_split = tree->SplitIndex(node_id); - output.at(split_feature*num_trees + i)++; + output.at(feature_split*num_trees + i)++; } } return output;