From 8c6d08339c3ded5b7c1e3c4ecdc4069e2cb7affd Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 16 Oct 2024 15:07:44 -0500 Subject: [PATCH] Properly enforcing min_samples_leaf cutoff in MCMC sampler --- include/stochtree/tree_sampler.h | 83 ++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 37 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index c38e0310..9e0d0562 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -802,46 +802,55 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM double no_split_log_marginal_likelihood = std::get<1>(split_eval); int32_t left_n = std::get<2>(split_eval); int32_t right_n = std::get<3>(split_eval); - - // Determine probability of growing the split node and its two new left and right nodes - double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta()); - double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - - // Determine whether a "grow" move is possible from the newly formed tree - // in order to compute the probability of choosing "prune" from the new tree - // (which is always possible by construction) - bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen); - bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf(); - bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf(); - double prob_prune_new; - if (non_constant && (min_samples_left_check || min_samples_right_check)) { - prob_prune_new = 0.5; - } else { - prob_prune_new = 1.0; - } - // Determine the number of leaves in the current tree and leaf parents in the proposed tree - int num_leaf_parents = tree->NumLeafParents(); - double p_leaf = 1/static_cast<double>(num_leaves); - double p_leaf_parent = 
1/static_cast<double>(num_leaf_parents+1); + // Reject the split if either of the left and right nodes are smaller than tree_prior.GetMinSamplesLeaf() + bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf(); + bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf(); + if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) { + + // Determine probability of growing the split node and its two new left and right nodes + double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta()); + double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); + double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); + + // Determine whether a "grow" move is possible from the newly formed tree + // in order to compute the probability of choosing "prune" from the new tree + // (which is always possible by construction) + bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen); + bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf(); + bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf(); + double prob_prune_new; + if (non_constant && (min_samples_left_check || min_samples_right_check)) { + prob_prune_new = 0.5; + } else { + prob_prune_new = 1.0; + } + + // Determine the number of leaves in the current tree and leaf parents in the proposed tree + int num_leaf_parents = tree->NumLeafParents(); + double p_leaf = 1/static_cast<double>(num_leaves); + double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1); + + // Compute the final MH ratio + double log_mh_ratio = ( + std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) + + std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood + ); + // Threshold at 0 + if (log_mh_ratio > 0) { + log_mh_ratio = 0; + } - // Compute the final MH ratio 
- double log_mh_ratio = ( - std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) + - std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood - ); - // Threshold at 0 - if (log_mh_ratio > 0) { - log_mh_ratio = 0; - } + // Draw a uniform random variable and accept/reject the proposal on this basis + std::uniform_real_distribution<double> mh_accept(0.0, 1.0); + double log_acceptance_prob = std::log(mh_accept(gen)); + if (log_acceptance_prob <= log_mh_ratio) { + accept = true; + AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); + } else { + accept = false; + } - // Draw a uniform random variable and accept/reject the proposal on this basis - std::uniform_real_distribution<double> mh_accept(0.0, 1.0); - double log_acceptance_prob = std::log(mh_accept(gen)); - if (log_acceptance_prob <= log_mh_ratio) { - accept = true; - AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); } else { accept = false; }