diff --git a/demo/notebooks/prototype_interface.ipynb b/demo/notebooks/prototype_interface.ipynb index 854cd484..0d7a36cd 100644 --- a/demo/notebooks/prototype_interface.ipynb +++ b/demo/notebooks/prototype_interface.ipynb @@ -766,7 +766,7 @@ ], "metadata": { "kernelspec": { - "display_name": "stochtree-dev", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -780,7 +780,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.8.17" } }, "nbformat": 4, diff --git a/demo/notebooks/tree_inspection.ipynb b/demo/notebooks/tree_inspection.ipynb new file mode 100644 index 00000000..d87e4eb2 --- /dev/null +++ b/demo/notebooks/tree_inspection.ipynb @@ -0,0 +1,449 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Deeper Dive on Fitted Forests in StochTree\n", + "\n", + "While out of sample evaluation and MCMC diagnostics on parametric BART components (i.e. $\\sigma^2$, the global error variance) are helpful, it's important to be able to inspect the trees in a BART / BCF model (or a custom tree ensemble model). This vignette walks through some of the features `stochtree` provides to query and understand the forests / trees in a model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "from stochtree import BARTModel\n", + "from sklearn.model_selection import train_test_split" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo 1: Supervised Learning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate sample data where feature 10 (column index 9) is the only \"important\" feature."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# RNG\n", + "random_seed = 1234\n", + "rng = np.random.default_rng(random_seed)\n", + "\n", + "# Generate covariates and basis\n", + "n = 1000\n", + "p_X = 10\n", + "X = rng.uniform(0, 1, (n, p_X))\n", + "\n", + "# Define the outcome mean function\n", + "def outcome_mean(X):\n", + " return np.where(\n", + " (X[:,9] >= 0.0) & (X[:,9] < 0.25), -7.5, \n", + " np.where(\n", + " (X[:,9] >= 0.25) & (X[:,9] < 0.5), -2.5, \n", + " np.where(\n", + " (X[:,9] >= 0.5) & (X[:,9] < 0.75), 2.5, \n", + " 7.5\n", + " )\n", + " )\n", + " )\n", + "\n", + "# Generate outcome\n", + "epsilon = rng.normal(0, 1, n)\n", + "y = outcome_mean(X) + epsilon\n", + "\n", + "# Standardize outcome\n", + "y_bar = np.mean(y)\n", + "y_std = np.std(y)\n", + "resid = (y-y_bar)/y_std" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test-train split" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "sample_inds = np.arange(n)\n", + "train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)\n", + "X_train = X[train_inds,:]\n", + "X_test = X[test_inds,:]\n", + "y_train = y[train_inds]\n", + "y_test = y[test_inds]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run BART" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "bart_model = BARTModel()\n", + "param_dict = {\"keep_gfr\": True}\n", + "bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=10, num_mcmc=100, params=param_dict)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspect the MCMC (BART) samples" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:]\n", + "y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)\n", + "y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=[\"True outcome\", \"Average estimated outcome\"])\n", + "sns.scatterplot(data=y_df_mcmc, x=\"Average estimated outcome\", y=\"True outcome\")\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute the test set RMSE" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.2873862776376062" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.sqrt(np.mean(np.power(y_test - np.squeeze(y_avg_mcmc),2)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check the variable split count in the last \"GFR\" sample" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([28, 23, 23, 19, 18, 28, 35, 34, 40, 29], dtype=int32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bart_model.forest_container_mean.get_forest_split_counts(9, p_X)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([3199, 2738, 2626, 2031, 1976, 3262, 2024, 2764, 2891, 3670],\n", + " dtype=int32)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bart_model.forest_container_mean.get_overall_split_counts(p_X)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The split counts appear relatively uniform across features, so let's dig deeper and look at individual trees, starting with the first tree in the last \"grow-from-root\" sample." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "splits = bart_model.forest_container_mean.get_granular_split_counts(p_X)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "splits[9,0,:]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tree has a single split on the only \"important\" feature. Now, let's look at the second tree." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "splits[9,1,:]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tree also only splits on the important feature." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "splits[9,20,:]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int32)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "splits[9,30,:]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that \"later\" trees are splitting on other features, but we also note that these trees are fitting an outcome that is already residualized by many \"relevant splits\" made by trees 1 and 2." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's inspect the first tree for this last GFR sample in more depth, following [this scikit-learn vignette](https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "forest_num = 9\n", + "tree_num = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "node=0 is a split node, which tells us to go to node 1 if X[:, 9] <= 0.49971201341971494 else to node 2.\n", + "\tnode=1 is a leaf node with value=[-0.313].\n", + "\tnode=2 is a leaf node with value=[0.406].\n" + ] + } + ], + "source": [ + "nodes = np.sort(bart_model.forest_container_mean.nodes(forest_num,tree_num))\n", + "for nid in nodes:\n", + " if bart_model.forest_container_mean.is_leaf_node(forest_num,tree_num,nid):\n", + " print(\n", + " \"{space}node={node} is a leaf node with value={value}.\".format(\n", + " space=bart_model.forest_container_mean.node_depth(forest_num,tree_num,nid) * \"\\t\", \n", + " node=nid, value=np.around(bart_model.forest_container_mean.node_leaf_values(forest_num,tree_num,nid), 3)\n", + " )\n", + " )\n", + " else:\n", + " print(\n", + " \"{space}node={node} is a split node, which tells us to \"\n", + " \"go to node {left} if X[:, {feature}] <= {threshold} \"\n", + " \"else to node {right}.\".format(\n", + " space=bart_model.forest_container_mean.node_depth(forest_num,tree_num,nid) * \"\\t\",\n", + " node=nid,\n", + " left=bart_model.forest_container_mean.left_child_node(forest_num,tree_num,nid),\n", + " feature=bart_model.forest_container_mean.node_split_index(forest_num,tree_num,nid),\n", + " threshold=bart_model.forest_container_mean.node_split_threshold(forest_num,tree_num,nid),\n", + " right=bart_model.forest_container_mean.right_child_node(forest_num,tree_num,nid),\n", + " )\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + 
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index d027c4e6..b252f9cb 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -461,6 +461,22 @@ class Tree { return node_type_[nid]; } + /*! + * \brief Whether the node is a numeric split node + * \param nid ID of node being queried + */ + bool IsNumericSplitNode(std::int32_t nid) const { + return node_type_[nid] == TreeNodeType::kNumericalSplitNode; + } + + /*! + * \brief Whether the node is a categorical split node + * \param nid ID of node being queried + */ + bool IsCategoricalSplitNode(std::int32_t nid) const { + return node_type_[nid] == TreeNodeType::kCategoricalSplitNode; + } + /*! * \brief Query whether this tree contains any categorical splits */ @@ -500,18 +516,35 @@ class Tree { [[nodiscard]] std::vector const& GetInternalNodes() const { return internal_nodes_; } + /*! * \brief Get indices of all leaf nodes. */ [[nodiscard]] std::vector const& GetLeaves() const { return leaves_; } + /*! * \brief Get indices of all leaf parent nodes. */ [[nodiscard]] std::vector const& GetLeafParents() const { return leaf_parents_; } + + /*! + * \brief Get indices of all valid (non-deleted) nodes. + */ + [[nodiscard]] std::vector GetNodes() { + std::vector output; + auto const& self = *this; + this->WalkTree([&output, &self](std::int32_t nidx) { + if (!self.IsDeleted(nidx)) { + output.push_back(nidx); + } + return true; + }); + return output; + } /*!
* \brief Get the depth of a node diff --git a/src/forest.cpp b/src/forest.cpp index b0f47c96..d510a16f 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -227,7 +227,8 @@ cpp11::writable::integers get_tree_split_counts_forest_container_cpp(cpp11::exte StochTree::Tree* tree = ensemble->GetTree(tree_num); std::vector split_nodes = tree->GetInternalNodes(); for (int i = 0; i < split_nodes.size(); i++) { - auto split_feature = split_nodes.at(i); + auto node_id = split_nodes.at(i); + auto split_feature = tree->SplitIndex(node_id); output.at(split_feature)++; } return output; @@ -243,7 +244,8 @@ cpp11::writable::integers get_forest_split_counts_forest_container_cpp(cpp11::ex StochTree::Tree* tree = ensemble->GetTree(i); std::vector split_nodes = tree->GetInternalNodes(); for (int j = 0; j < split_nodes.size(); j++) { - auto split_feature = split_nodes.at(j); + auto node_id = split_nodes.at(j); + auto split_feature = tree->SplitIndex(node_id); output.at(split_feature)++; } } @@ -262,7 +264,8 @@ cpp11::writable::integers get_overall_split_counts_forest_container_cpp(cpp11::e StochTree::Tree* tree = ensemble->GetTree(j); std::vector split_nodes = tree->GetInternalNodes(); for (int k = 0; k < split_nodes.size(); k++) { - auto split_feature = split_nodes.at(k); + auto node_id = split_nodes.at(k); + auto split_feature = tree->SplitIndex(node_id); output.at(split_feature)++; } } @@ -282,8 +285,9 @@ cpp11::writable::integers get_granular_split_count_array_forest_container_cpp(cp StochTree::Tree* tree = ensemble->GetTree(j); std::vector split_nodes = tree->GetInternalNodes(); for (int k = 0; k < split_nodes.size(); k++) { - auto split_feature = split_nodes.at(k); - output.at(num_features*num_trees*i + split_feature*num_trees + j)++; + auto node_id = split_nodes.at(k); + auto split_feature = tree->SplitIndex(node_id); + output.at(split_feature*num_samples*num_trees + j*num_samples + i)++; } } } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 0420e396..45ff4127 
100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -164,7 +164,7 @@ class ForestContainerCpp { return forest_samples_->NumSamples(); } - int NumLeaves(int forest_num) { + int NumLeavesForest(int forest_num) { StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble(forest_num); return forest->NumLeaves(); } @@ -428,7 +428,8 @@ class ForestContainerCpp { StochTree::Tree* tree = ensemble->GetTree(i); std::vector split_nodes = tree->GetInternalNodes(); for (int j = 0; j < split_nodes.size(); j++) { - auto split_feature = split_nodes.at(j); + auto node_id = split_nodes.at(j); + auto split_feature = tree->SplitIndex(node_id); accessor(split_feature)++; } } @@ -449,7 +450,8 @@ class ForestContainerCpp { StochTree::Tree* tree = ensemble->GetTree(j); std::vector split_nodes = tree->GetInternalNodes(); for (int k = 0; k < split_nodes.size(); k++) { - auto split_feature = split_nodes.at(k); + auto node_id = split_nodes.at(k); + auto split_feature = tree->SplitIndex(node_id); accessor(split_feature)++; } } @@ -460,11 +462,11 @@ class ForestContainerCpp { py::array_t GetGranularSplitCounts(int num_features) { int num_samples = forest_samples_->NumSamples(); int num_trees = forest_samples_->NumTrees(); - auto result = py::array_t(py::detail::any_container({num_trees,num_features,num_samples})); + auto result = py::array_t(py::detail::any_container({num_samples,num_trees,num_features})); auto accessor = result.mutable_unchecked<3>(); - for (int i = 0; i < num_trees; i++) { - for (int j = 0; j < num_features; j++) { - for (int k = 0; k < num_samples; k++) { + for (int i = 0; i < num_samples; i++) { + for (int j = 0; j < num_trees; j++) { + for (int k = 0; k < num_features; k++) { accessor(i,j,k) = 0; } } @@ -475,14 +477,144 @@ class ForestContainerCpp { StochTree::Tree* tree = ensemble->GetTree(j); std::vector split_nodes = tree->GetInternalNodes(); for (int k = 0; k < split_nodes.size(); k++) { - auto split_feature = split_nodes.at(k); - 
accessor(j,split_feature,i)++; + auto node_id = split_nodes.at(k); + auto split_feature = tree->SplitIndex(node_id); + accessor(i,j,split_feature)++; } } } return result; } + bool IsLeafNode(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->IsLeaf(node_id); + } + + bool IsNumericSplitNode(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->IsNumericSplitNode(node_id); + } + + bool IsCategoricalSplitNode(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->IsCategoricalSplitNode(node_id); + } + + int ParentNode(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->Parent(node_id); + } + + int LeftChildNode(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->LeftChild(node_id); + } + + int RightChildNode(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->RightChild(node_id); + } + + int SplitIndex(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->SplitIndex(node_id); + } + + int NodeDepth(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + 
StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->GetDepth(node_id); + } + + double SplitThreshold(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->Threshold(node_id); + } + + py::array_t SplitCategories(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + std::vector raw_categories = tree->CategoryList(node_id); + int num_categories = raw_categories.size(); + auto result = py::array_t(py::detail::any_container({num_categories})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < num_categories; i++) { + accessor(i) = raw_categories.at(i); + } + return result; + } + + py::array_t NodeLeafValues(int forest_id, int tree_id, int node_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + int num_outputs = tree->OutputDimension(); + auto result = py::array_t(py::detail::any_container({num_outputs})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < num_outputs; i++) { + accessor(i) = tree->LeafValue(node_id, i); + } + return result; + } + + int NumNodes(int forest_id, int tree_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->NumValidNodes(); + } + + int NumLeaves(int forest_id, int tree_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->NumLeaves(); + } + + int NumLeafParents(int forest_id, int tree_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return 
tree->NumLeafParents(); + } + + int NumSplitNodes(int forest_id, int tree_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + return tree->NumSplitNodes(); + } + + py::array_t Nodes(int forest_id, int tree_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + std::vector nodes = tree->GetNodes(); + int num_nodes = nodes.size(); + auto result = py::array_t(py::detail::any_container({num_nodes})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < num_nodes; i++) { + accessor(i) = nodes.at(i); + } + return result; + } + + py::array_t Leaves(int forest_id, int tree_id) { + StochTree::TreeEnsemble* ensemble = forest_samples_->GetEnsemble(forest_id); + StochTree::Tree* tree = ensemble->GetTree(tree_id); + std::vector leaves = tree->GetLeaves(); + int num_leaves = leaves.size(); + auto result = py::array_t(py::detail::any_container({num_leaves})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < num_leaves; i++) { + accessor(i) = leaves.at(i); + } + return result; + } + private: std::unique_ptr forest_samples_; int num_trees_; @@ -1044,8 +1176,25 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts) .def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts) .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts) + .def("NumLeavesForest", &ForestContainerCpp::NumLeavesForest) + .def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared) + .def("IsLeafNode", &ForestContainerCpp::IsLeafNode) + .def("IsNumericSplitNode", &ForestContainerCpp::IsNumericSplitNode) + .def("IsCategoricalSplitNode", &ForestContainerCpp::IsCategoricalSplitNode) + .def("ParentNode", &ForestContainerCpp::ParentNode) + .def("LeftChildNode", &ForestContainerCpp::LeftChildNode) + 
.def("RightChildNode", &ForestContainerCpp::RightChildNode) + .def("SplitIndex", &ForestContainerCpp::SplitIndex) + .def("NodeDepth", &ForestContainerCpp::NodeDepth) + .def("SplitThreshold", &ForestContainerCpp::SplitThreshold) + .def("SplitCategories", &ForestContainerCpp::SplitCategories) + .def("NodeLeafValues", &ForestContainerCpp::NodeLeafValues) + .def("NumNodes", &ForestContainerCpp::NumNodes) .def("NumLeaves", &ForestContainerCpp::NumLeaves) - .def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared); + .def("NumLeafParents", &ForestContainerCpp::NumLeafParents) + .def("NumSplitNodes", &ForestContainerCpp::NumSplitNodes) + .def("Nodes", &ForestContainerCpp::Nodes) + .def("Leaves", &ForestContainerCpp::Leaves); py::class_(m, "ForestSamplerCpp") .def(py::init, int, data_size_t, double, double, int, int>()) diff --git a/stochtree/forest.py b/stochtree/forest.py index 866208ee..4eefae79 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -11,6 +11,10 @@ class ForestContainer: def __init__(self, num_trees: int, output_dimension: int, leaf_constant: bool, is_exponentiated: bool) -> None: # Initialize a ForestContainerCpp object self.forest_container_cpp = ForestContainerCpp(num_trees, output_dimension, leaf_constant, is_exponentiated) + self.num_trees = num_trees + self.output_dimension = output_dimension + self.leaf_constant = leaf_constant + self.is_exponentiated = is_exponentiated def predict(self, dataset: Dataset) -> np.array: # Predict samples from Dataset @@ -149,14 +153,14 @@ def get_granular_split_counts(self, num_features: int) -> np.array: """ return self.forest_container_cpp.GetGranularSplitCounts(num_features) - def num_leaves(self, forest_num: int) -> int: + def num_forest_leaves(self, forest_num: int) -> int: """ Return the total number of leaves for a given forest in the ``ForestContainer`` forest_num : :obj:`int` Index of the forest to be queried """ - return self.forest_container_cpp.NumLeaves(forest_num) + return 
self.forest_container_cpp.NumLeavesForest(forest_num) def sum_leaves_squared(self, forest_num: int) -> float: """ @@ -166,4 +170,227 @@ def sum_leaves_squared(self, forest_num: int) -> float: Index of the forest to be queried """ return self.forest_container_cpp.SumLeafSquared(forest_num) + + def is_leaf_node(self, forest_num: int, tree_num: int, node_id: int) -> bool: + """ + Whether or not a given node of a given tree in a given forest in the ``ForestContainer`` is a leaf + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.IsLeafNode(forest_num, tree_num, node_id) + + def is_numeric_split_node(self, forest_num: int, tree_num: int, node_id: int) -> bool: + """ + Whether or not a given node of a given tree in a given forest in the ``ForestContainer`` is a numeric split node + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.IsNumericSplitNode(forest_num, tree_num, node_id) + + def is_categorical_split_node(self, forest_num: int, tree_num: int, node_id: int) -> bool: + """ + Whether or not a given node of a given tree in a given forest in the ``ForestContainer`` is a categorical split node + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.IsCategoricalSplitNode(forest_num, tree_num, node_id) + + def parent_node(self, forest_num: int, tree_num: int, node_id: int) -> int: + """ + Parent node of given node of a given tree in a given forest in the ``ForestContainer`` + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree 
to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.ParentNode(forest_num, tree_num, node_id) + + def left_child_node(self, forest_num: int, tree_num: int, node_id: int) -> int: + """ + Left child node of given node of a given tree in a given forest in the ``ForestContainer`` + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.LeftChildNode(forest_num, tree_num, node_id) + + def right_child_node(self, forest_num: int, tree_num: int, node_id: int) -> int: + """ + Right child node of given node of a given tree in a given forest in the ``ForestContainer`` + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.RightChildNode(forest_num, tree_num, node_id) + + def node_depth(self, forest_num: int, tree_num: int, node_id: int) -> int: + """ + Depth of given node of a given tree in a given forest in the ``ForestContainer``. + Leaves are not special-cased: the depth is returned as computed for any node. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.NodeDepth(forest_num, tree_num, node_id) + + def node_split_index(self, forest_num: int, tree_num: int, node_id: int) -> int: + """ + Split index of given node of a given tree in a given forest in the ``ForestContainer``. + Returns ``-1`` if the node is a leaf.
+ + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + if self.is_leaf_node(forest_num, tree_num, node_id): + return -1 + else: + return self.forest_container_cpp.SplitIndex(forest_num, tree_num, node_id) + + def node_split_threshold(self, forest_num: int, tree_num: int, node_id: int) -> float: + """ + Threshold that defines a numeric split for a given node of a given tree in a given forest in the ``ForestContainer``. + Returns ``np.inf`` if the node is a leaf or a categorical split node. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + if self.is_leaf_node(forest_num, tree_num, node_id) or self.is_categorical_split_node(forest_num, tree_num, node_id): + return np.inf + else: + return self.forest_container_cpp.SplitThreshold(forest_num, tree_num, node_id) + + def node_split_categories(self, forest_num: int, tree_num: int, node_id: int) -> np.array: + """ + Array of category indices that define a categorical split for a given node of a given tree in a given forest in the ``ForestContainer``. + Returns ``np.array([np.inf])`` if the node is a leaf or a numeric split node. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + if self.is_leaf_node(forest_num, tree_num, node_id) or self.is_numeric_split_node(forest_num, tree_num, node_id): + return np.array([np.inf]) + else: + return self.forest_container_cpp.SplitCategories(forest_num, tree_num, node_id) + + def node_leaf_values(self, forest_num: int, tree_num: int, node_id: int) -> np.array: + """ + Leaf node value(s) for a given node of a given tree in a given forest in the ``ForestContainer``.
+ Values are stale if the node is a split node. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_container_cpp.NodeLeafValues(forest_num, tree_num, node_id) + + def num_nodes(self, forest_num: int, tree_num: int) -> int: + """ + Number of valid nodes in a given tree in a given forest in the ``ForestContainer``. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_container_cpp.NumNodes(forest_num, tree_num) + + def num_leaves(self, forest_num: int, tree_num: int) -> int: + """ + Number of leaves in a given tree in a given forest in the ``ForestContainer``. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_container_cpp.NumLeaves(forest_num, tree_num) + + def num_leaf_parents(self, forest_num: int, tree_num: int) -> int: + """ + Number of leaf parents in a given tree in a given forest in the ``ForestContainer``. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_container_cpp.NumLeafParents(forest_num, tree_num) + + def num_split_nodes(self, forest_num: int, tree_num: int) -> int: + """ + Number of split nodes in a given tree in a given forest in the ``ForestContainer``. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_container_cpp.NumSplitNodes(forest_num, tree_num) + + def nodes(self, forest_num: int, tree_num: int) -> np.array: + """ + Array of node indices in a given tree in a given forest in the ``ForestContainer``.
+ + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_container_cpp.Nodes(forest_num, tree_num) + + def leaves(self, forest_num: int, tree_num: int) -> np.array: + """ + Array of leaf indices in a given tree in a given forest in the ``ForestContainer``. + + forest_num : :obj:`int` + Index of the forest to be queried + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_container_cpp.Leaves(forest_num, tree_num) \ No newline at end of file