Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demo/notebooks/prototype_interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "stochtree-dev",
"display_name": "venv",
"language": "python",
"name": "python3"
},
Expand All @@ -780,7 +780,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.8.17"
}
},
"nbformat": 4,
Expand Down
449 changes: 449 additions & 0 deletions demo/notebooks/tree_inspection.ipynb

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions include/stochtree/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,22 @@ class Tree {
return node_type_[nid];
}

/*!
 * \brief Check whether a node's recorded type is the numeric split type
 * \param nid ID of node being queried (no range check is done here)
 * \return true when node_type_[nid] equals TreeNodeType::kNumericalSplitNode
 */
bool IsNumericSplitNode(std::int32_t nid) const {
  const bool is_numeric_split = (node_type_[nid] == TreeNodeType::kNumericalSplitNode);
  return is_numeric_split;
}

/*!
 * \brief Check whether a node's recorded type is the categorical split type
 * \param nid ID of node being queried (no range check is done here)
 * \return true when node_type_[nid] equals TreeNodeType::kCategoricalSplitNode
 */
bool IsCategoricalSplitNode(std::int32_t nid) const {
  const bool is_categorical_split = (node_type_[nid] == TreeNodeType::kCategoricalSplitNode);
  return is_categorical_split;
}

/*!
* \brief Query whether this tree contains any categorical splits
*/
Expand Down Expand Up @@ -500,18 +516,35 @@ class Tree {
[[nodiscard]] std::vector<std::int32_t> const& GetInternalNodes() const {
return internal_nodes_;
}

/*!
* \brief Get indices of all leaf nodes.
*/
[[nodiscard]] std::vector<std::int32_t> const& GetLeaves() const {
return leaves_;
}

/*!
* \brief Get indices of all leaf parent nodes.
*/
[[nodiscard]] std::vector<std::int32_t> const& GetLeafParents() const {
return leaf_parents_;
}

/*!
* \brief Get indices of all valid (non-deleted) nodes.
*/
[[nodiscard]] std::vector<std::int32_t> GetNodes() {
std::vector<std::int32_t> output;
auto const& self = *this;
this->WalkTree([&output, &self](std::int32_t nidx) {
if (!self.IsDeleted(nidx)) {
output.push_back(nidx);
}
return true;
});
return output;
}

/*!
* \brief Get the depth of a node
Expand Down
14 changes: 9 additions & 5 deletions src/forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ cpp11::writable::integers get_tree_split_counts_forest_container_cpp(cpp11::exte
StochTree::Tree* tree = ensemble->GetTree(tree_num);
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
for (int i = 0; i < split_nodes.size(); i++) {
auto split_feature = split_nodes.at(i);
auto node_id = split_nodes.at(i);
auto split_feature = tree->SplitIndex(node_id);
output.at(split_feature)++;
}
return output;
Expand All @@ -243,7 +244,8 @@ cpp11::writable::integers get_forest_split_counts_forest_container_cpp(cpp11::ex
StochTree::Tree* tree = ensemble->GetTree(i);
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
for (int j = 0; j < split_nodes.size(); j++) {
auto split_feature = split_nodes.at(j);
auto node_id = split_nodes.at(j);
auto split_feature = tree->SplitIndex(node_id);
output.at(split_feature)++;
}
}
Expand All @@ -262,7 +264,8 @@ cpp11::writable::integers get_overall_split_counts_forest_container_cpp(cpp11::e
StochTree::Tree* tree = ensemble->GetTree(j);
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
for (int k = 0; k < split_nodes.size(); k++) {
auto split_feature = split_nodes.at(k);
auto node_id = split_nodes.at(k);
auto split_feature = tree->SplitIndex(node_id);
output.at(split_feature)++;
}
}
Expand All @@ -282,8 +285,9 @@ cpp11::writable::integers get_granular_split_count_array_forest_container_cpp(cp
StochTree::Tree* tree = ensemble->GetTree(j);
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
for (int k = 0; k < split_nodes.size(); k++) {
auto split_feature = split_nodes.at(k);
output.at(num_features*num_trees*i + split_feature*num_trees + j)++;
auto node_id = split_nodes.at(k);
auto split_feature = tree->SplitIndex(node_id);
output.at(split_feature*num_samples*num_trees + j*num_samples + i)++;
}
}
}
Expand Down
169 changes: 159 additions & 10 deletions src/py_stochtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class ForestContainerCpp {
return forest_samples_->NumSamples();
}

int NumLeaves(int forest_num) {
int NumLeavesForest(int forest_num) {
StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble(forest_num);
return forest->NumLeaves();
}
Expand Down Expand Up @@ -428,7 +428,8 @@ class ForestContainerCpp {
StochTree::Tree* tree = ensemble->GetTree(i);
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
for (int j = 0; j < split_nodes.size(); j++) {
auto split_feature = split_nodes.at(j);
auto node_id = split_nodes.at(j);
auto split_feature = tree->SplitIndex(node_id);
accessor(split_feature)++;
}
}
Expand All @@ -449,7 +450,8 @@ class ForestContainerCpp {
StochTree::Tree* tree = ensemble->GetTree(j);
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
for (int k = 0; k < split_nodes.size(); k++) {
auto split_feature = split_nodes.at(k);
auto node_id = split_nodes.at(k);
auto split_feature = tree->SplitIndex(node_id);
accessor(split_feature)++;
}
}
Expand All @@ -460,11 +462,11 @@ class ForestContainerCpp {
py::array_t<int> GetGranularSplitCounts(int num_features) {
int num_samples = forest_samples_->NumSamples();
int num_trees = forest_samples_->NumTrees();
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_trees,num_features,num_samples}));
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_samples,num_trees,num_features}));
auto accessor = result.mutable_unchecked<3>();
for (int i = 0; i < num_trees; i++) {
for (int j = 0; j < num_features; j++) {
for (int k = 0; k < num_samples; k++) {
for (int i = 0; i < num_samples; i++) {
for (int j = 0; j < num_trees; j++) {
for (int k = 0; k < num_features; k++) {
accessor(i,j,k) = 0;
}
}
Expand All @@ -475,14 +477,144 @@ class ForestContainerCpp {
StochTree::Tree* tree = ensemble->GetTree(j);
std::vector<int32_t> split_nodes = tree->GetInternalNodes();
for (int k = 0; k < split_nodes.size(); k++) {
auto split_feature = split_nodes.at(k);
accessor(j,split_feature,i)++;
auto node_id = split_nodes.at(k);
auto split_feature = tree->SplitIndex(node_id);
accessor(i,j,split_feature)++;
}
}
}
return result;
}

/*!
 * \brief Report whether a node is a leaf by forwarding to Tree::IsLeaf
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node being queried
 * \return Result of IsLeaf(node_id) on the selected tree; indices are not validated here
 */
bool IsLeafNode(int forest_id, int tree_id, int node_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->IsLeaf(node_id);
}

/*!
 * \brief Report whether a node is a numeric split node by forwarding to Tree::IsNumericSplitNode
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node being queried
 * \return Result of IsNumericSplitNode(node_id) on the selected tree; indices are not validated here
 */
bool IsNumericSplitNode(int forest_id, int tree_id, int node_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->IsNumericSplitNode(node_id);
}

/*!
 * \brief Report whether a node is a categorical split node by forwarding to Tree::IsCategoricalSplitNode
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node being queried
 * \return Result of IsCategoricalSplitNode(node_id) on the selected tree; indices are not validated here
 */
bool IsCategoricalSplitNode(int forest_id, int tree_id, int node_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->IsCategoricalSplitNode(node_id);
}

int ParentNode(int forest_id, int tree_id, int node_id) {
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
StochTree::Tree* tree = ensemble->GetTree(tree_id);
return tree->Parent(node_id);
}

int LeftChildNode(int forest_id, int tree_id, int node_id) {
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
StochTree::Tree* tree = ensemble->GetTree(tree_id);
return tree->LeftChild(node_id);
}

int RightChildNode(int forest_id, int tree_id, int node_id) {
StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
StochTree::Tree* tree = ensemble->GetTree(tree_id);
return tree->RightChild(node_id);
}

/*!
 * \brief Fetch the split index recorded for a node by forwarding to Tree::SplitIndex
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node being queried
 * \return Value of SplitIndex(node_id) on the selected tree; indices are not validated here
 */
int SplitIndex(int forest_id, int tree_id, int node_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->SplitIndex(node_id);
}

/*!
 * \brief Fetch the depth of a node by forwarding to Tree::GetDepth
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node being queried
 * \return Value of GetDepth(node_id) on the selected tree; indices are not validated here
 */
int NodeDepth(int forest_id, int tree_id, int node_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->GetDepth(node_id);
}

/*!
 * \brief Fetch the threshold recorded for a node by forwarding to Tree::Threshold
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node being queried
 * \return Value of Threshold(node_id) on the selected tree; indices are not validated here
 */
double SplitThreshold(int forest_id, int tree_id, int node_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->Threshold(node_id);
}

/*!
 * \brief Copy the category list of a node into a 1-D integer numpy array
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node whose category list is requested
 * \return Array holding one entry per element of Tree::CategoryList(node_id), in the same order
 */
py::array_t<int> SplitCategories(int forest_id, int tree_id, int node_id) {
  StochTree::Tree* selected_tree = forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id);
  std::vector<std::uint32_t> categories = selected_tree->CategoryList(node_id);
  int category_count = categories.size();
  auto category_array = py::array_t<int>(py::detail::any_container<py::ssize_t>({category_count}));
  auto writer = category_array.mutable_unchecked<1>();
  // Element-wise copy; each std::uint32_t category is converted to int on assignment.
  int slot = 0;
  while (slot < category_count) {
    writer(slot) = categories.at(slot);
    ++slot;
  }
  return category_array;
}

/*!
 * \brief Gather the leaf values stored at a node into a 1-D double numpy array
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \param node_id ID of the node whose values are requested
 * \return Array of length Tree::OutputDimension(), entry d holding LeafValue(node_id, d)
 */
py::array_t<double> NodeLeafValues(int forest_id, int tree_id, int node_id) {
  StochTree::Tree* selected_tree = forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id);
  int output_dim = selected_tree->OutputDimension();
  auto leaf_values = py::array_t<double>(py::detail::any_container<py::ssize_t>({output_dim}));
  auto writer = leaf_values.mutable_unchecked<1>();
  int dim = 0;
  while (dim < output_dim) {
    writer(dim) = selected_tree->LeafValue(node_id, dim);
    ++dim;
  }
  return leaf_values;
}

/*!
 * \brief Count the valid nodes of a tree by forwarding to Tree::NumValidNodes
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \return Value of NumValidNodes() on the selected tree; indices are not validated here
 */
int NumNodes(int forest_id, int tree_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->NumValidNodes();
}

/*!
 * \brief Count the leaves of a single tree by forwarding to Tree::NumLeaves
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \return Value of NumLeaves() on the selected tree; indices are not validated here
 */
int NumLeaves(int forest_id, int tree_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->NumLeaves();
}

/*!
 * \brief Count the leaf-parent nodes of a tree by forwarding to Tree::NumLeafParents
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \return Value of NumLeafParents() on the selected tree; indices are not validated here
 */
int NumLeafParents(int forest_id, int tree_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->NumLeafParents();
}

/*!
 * \brief Count the split nodes of a tree by forwarding to Tree::NumSplitNodes
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \return Value of NumSplitNodes() on the selected tree; indices are not validated here
 */
int NumSplitNodes(int forest_id, int tree_id) {
  return forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id)->NumSplitNodes();
}

/*!
 * \brief Copy the node indices returned by Tree::GetNodes into a 1-D integer numpy array
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \return Array with one entry per element of GetNodes(), in the same order
 */
py::array_t<int> Nodes(int forest_id, int tree_id) {
  StochTree::Tree* selected_tree = forest_samples_->GetEnsemble(forest_id)->GetTree(tree_id);
  std::vector<std::int32_t> node_ids = selected_tree->GetNodes();
  int node_count = node_ids.size();
  auto node_array = py::array_t<int>(py::detail::any_container<py::ssize_t>({node_count}));
  auto writer = node_array.mutable_unchecked<1>();
  int slot = 0;
  while (slot < node_count) {
    writer(slot) = node_ids.at(slot);
    ++slot;
  }
  return node_array;
}

/*!
 * \brief Copy the leaf indices of a tree into a 1-D integer numpy array
 * \param forest_id Index handed to GetEnsemble() to pick a forest from forest_samples_
 * \param tree_id Index handed to GetTree() to pick a tree within that forest
 * \return Array with one entry per element of Tree::GetLeaves(), in the same order
 */
py::array_t<int> Leaves(int forest_id, int tree_id) {
  StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id);
  StochTree::Tree* tree = ensemble->GetTree(tree_id);
  // GetLeaves() returns a const reference; bind to it rather than copying the
  // whole vector just to read it once.
  const std::vector<std::int32_t>& leaves = tree->GetLeaves();
  // Explicit cast instead of a silent size_t -> int narrowing.
  const auto num_leaves = static_cast<py::ssize_t>(leaves.size());
  auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_leaves}));
  auto accessor = result.mutable_unchecked<1>();
  py::ssize_t pos = 0;
  for (const std::int32_t leaf_id : leaves) {
    accessor(pos) = leaf_id;
    ++pos;
  }
  return result;
}

private:
std::unique_ptr<StochTree::ForestContainer> forest_samples_;
int num_trees_;
Expand Down Expand Up @@ -1044,8 +1176,25 @@ PYBIND11_MODULE(stochtree_cpp, m) {
.def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts)
.def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts)
.def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts)
.def("NumLeavesForest", &ForestContainerCpp::NumLeavesForest)
.def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared)
.def("IsLeafNode", &ForestContainerCpp::IsLeafNode)
.def("IsNumericSplitNode", &ForestContainerCpp::IsNumericSplitNode)
.def("IsCategoricalSplitNode", &ForestContainerCpp::IsCategoricalSplitNode)
.def("ParentNode", &ForestContainerCpp::ParentNode)
.def("LeftChildNode", &ForestContainerCpp::LeftChildNode)
.def("RightChildNode", &ForestContainerCpp::RightChildNode)
.def("SplitIndex", &ForestContainerCpp::SplitIndex)
.def("NodeDepth", &ForestContainerCpp::NodeDepth)
.def("SplitThreshold", &ForestContainerCpp::SplitThreshold)
.def("SplitCategories", &ForestContainerCpp::SplitCategories)
.def("NodeLeafValues", &ForestContainerCpp::NodeLeafValues)
.def("NumNodes", &ForestContainerCpp::NumNodes)
.def("NumLeaves", &ForestContainerCpp::NumLeaves)
.def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared);
.def("NumLeafParents", &ForestContainerCpp::NumLeafParents)
.def("NumSplitNodes", &ForestContainerCpp::NumSplitNodes)
.def("Nodes", &ForestContainerCpp::Nodes)
.def("Leaves", &ForestContainerCpp::Leaves);

py::class_<ForestSamplerCpp>(m, "ForestSamplerCpp")
.def(py::init<ForestDatasetCpp&, py::array_t<int>, int, data_size_t, double, double, int, int>())
Expand Down
Loading