Skip to content

Commit

Permalink
use Rockova prior, refactor prior leaf prob computation
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Oct 16, 2020
1 parent 6700a74 commit ac96b1a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 67 deletions.
25 changes: 7 additions & 18 deletions pymc3/distributions/bart.py
Expand Up @@ -10,7 +10,7 @@ class BARTParamsError(Exception):


class BaseBART(NoDistribution):
def __init__(self, X, Y, m=200, alpha=0.95, beta=2.0, cache_size=5000, *args, **kwargs):
def __init__(self, X, Y, m=200, alpha=0.25, cache_size=5000, *args, **kwargs):

self.Y_shared = Y
self.X = X
Expand Down Expand Up @@ -45,26 +45,16 @@ def __init__(self, X, Y, m=200, alpha=0.95, beta=2.0, cache_size=5000, *args, **
raise BARTParamsError(
"The type for the alpha parameter for the tree structure must be float"
)
if alpha <= 0 or 1 <= alpha:
if alpha <= 0 or 0.5 <= alpha:
raise BARTParamsError(
"The value for the alpha parameter for the tree structure "
"must be in the interval (0, 1)"
)
if not isinstance(beta, float):
raise BARTParamsError(
"The type for the beta parameter for the tree structure must be float"
)
if beta < 0:
raise BARTParamsError(
"The value for the beta parameter for the tree structure "
'must be in the interval [0, float("inf"))'
"must be in the interval (0, 0.5)"
)

self.num_observations = X.shape[0]
self.number_variates = X.shape[1]
self.m = m
self.alpha = alpha
self.beta = beta
self._normal_dist_sampler = NormalDistributionSampler(cache_size)
self._disc_uniform_dist_sampler = DiscreteUniformDistributionSampler(cache_size)
self.trees = self.init_list_of_trees()
Expand Down Expand Up @@ -231,8 +221,8 @@ def refresh_cache(self):


class BART(BaseBART):
def __init__(self, X, Y, m=200, alpha=0.95, beta=2.0):
super().__init__(X, Y, m, alpha, beta)
def __init__(self, X, Y, m=200, alpha=0.25):
super().__init__(X, Y, m, alpha)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
Expand All @@ -241,11 +231,10 @@ def _repr_latex_(self, name=None, dist=None):
Y = (type(self.Y),)
m = (self.m,)
alpha = self.alpha
beta = self.beta
m = self.m
name = r"\text{%s}" % name
return r"""${} \sim \text{{BART}}(\mathit{{alpha}}={},~\mathit{{beta}}={},~\mathit{{m}}={})$""".format(
name, alpha, beta, m
return r"""${} \sim \text{{BART}}(\mathit{{alpha}}={},~\mathit{{m}}={})$""".format(
name, alpha, m
)

def draw_leaf_value(self, tree, idx_data_points):
Expand Down
44 changes: 0 additions & 44 deletions pymc3/distributions/tree.py
Expand Up @@ -381,28 +381,6 @@ def evaluate_splitting_rule(self, x):
else:
return x[self.idx_split_variable] <= self.split_value

def prior_log_probability_node(self, alpha, beta):
"""
Calculate the log probability of the node being a SplitNode.
Taken from equation 7 in [Chipman2010].
Parameters
----------
alpha : float
beta : float
Returns
-------
float
References
----------
.. [Chipman2010] Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian
additive regression trees. The Annals of Applied Statistics, 4(1), 266-298.,
`link <https://projecteuclid.org/download/pdfview_1/euclid.aoas/1273584455>`__
"""
return np.log(alpha * np.power(1.0 + self.depth, -beta))


class LeafNode(BaseNode):
def __init__(self, index, value, idx_data_points):
Expand All @@ -427,25 +405,3 @@ def __eq__(self, other):
)
else:
return NotImplemented

def prior_log_probability_node(self, alpha, beta):
"""
Calculate the log probability of the node being a LeafNode (1 - p(being SplitNode)).
Taken from equation 7 in [Chipman2010].
Parameters
----------
alpha : float
beta : float
Returns
-------
float
References
----------
.. [Chipman2010] Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian
additive regression trees. The Annals of Applied Statistics, 4(1), 266-298.,
`link <https://projecteuclid.org/download/pdfview_1/euclid.aoas/1273584455>`__
"""
return np.log(1.0 - alpha * np.power(1.0 + self.depth, -beta))
43 changes: 38 additions & 5 deletions pymc3/step_methods/pgbart.py
Expand Up @@ -35,13 +35,14 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, model=None):
model = modelcontext(model)
vars = inputvars(vars)
self.bart = vars[0].distribution
self.prior_prob_leaf_node = _compute_prior_probability(self.bart.alpha)

self.num_particles = num_particles
self.max_stages = max_stages
self.first_iteration = True
self.previous_trees_particles_list = []
for i in range(self.bart.m):
p = Particle(self.bart.trees[i])
p = Particle(self.bart.trees[i], self.prior_prob_leaf_node)
self.previous_trees_particles_list.append(p)

shared = make_shared_replacements(vars, model)
Expand Down Expand Up @@ -185,7 +186,7 @@ def init_particles(self, tree_id, R_j, num_observations):
idx_data_points=initial_idx_data_points_leaf_nodes,
)
for _ in range(self.num_particles):
new_particle = Particle(new_tree)
new_particle = Particle(new_tree, self.prior_prob_leaf_node)
list_of_particles.append(new_particle)
return list_of_particles

Expand All @@ -197,20 +198,24 @@ def resample(self, list_of_particles, normalized_weights):


class Particle:
def __init__(self, tree):
def __init__(self, tree, prior_prob_leaf_node):
self.tree = tree.copy() # Mantiene el arbol que nos interesa en este momento
self.expansion_nodes = self.tree.idx_leaf_nodes.copy() # This should be the array [0]
self.tree_history = [self.tree.copy()]
self.expansion_nodes_history = [self.expansion_nodes.copy()]
self.log_weight = 0.0
self.prior_prob_leaf_node = prior_prob_leaf_node

def sample_tree_sequential(self, bart):
if self.expansion_nodes:
index_leaf_node = self.expansion_nodes.pop(0)
# Probability that this node will remain a leaf node
log_prob = self.tree[index_leaf_node].prior_log_probability_node(bart.alpha, bart.beta)
try:
prob_leaf = self.prior_prob_leaf_node[self.tree[index_leaf_node].depth]
except IndexError:
prob_leaf = 1

if np.exp(log_prob) < np.random.random():
if prob_leaf < np.random.random():
self.grow_successful = bart.grow_tree(self.tree, index_leaf_node)
# TODO: in case the grow_tree fails, should we try to sample the tree from another leaf node?
if self.grow_successful:
Expand All @@ -237,6 +242,34 @@ def update_weight(self):
pass


def _compute_prior_probability(alpha):
"""
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
Taken from equation 19 in [On Theory for BART]. XXX FIX reference below!
Parameters
----------
alpha : float
Returns
-------
float
References
----------
.. [Chipman2010] Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian
additive regression trees. The Annals of Applied Statistics, 4(1), 266-298.,
`link <https://projecteuclid.org/download/pdfview_1/euclid.aoas/1273584455>`__
"""
prior_leaf_prob = [0]
depth = 1
prob = 1
while prob < 1:
prob = 1 - alpha ** depth
depth += 1
return prior_leaf_prob


from theano import function as theano_function
from ..theanof import join_nonshared_inputs

Expand Down

0 comments on commit ac96b1a

Please sign in to comment.