Navigation Menu

Skip to content

Commit

Permalink
fix: posterior predictive queries now intercept hidden variables
Browse files Browse the repository at this point in the history
  • Loading branch information
jcozar87 committed Aug 20, 2019
1 parent 713bdee commit 9748c14
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 29 deletions.
28 changes: 6 additions & 22 deletions inferpy/models/prob_model.py
Expand Up @@ -94,9 +94,7 @@ def posterior(self, target_names=None, data={}):
raise ValueError("target_names must correspond to not observed variables during the inference: \
{}".format([v for v in self.vars.keys() if v not in self.observed_vars]))

prior_data = self._create_hidden_observations(target_names, data)

return Query(self.inference_method.expanded_variables["q"], target_names, {**data, **prior_data})
return Query(self.inference_method.expanded_variables["q"], target_names, {**data})

def posterior_predictive(self, target_names=None, data={}):
if self.inference_method is None:
Expand All @@ -110,26 +108,12 @@ def posterior_predictive(self, target_names=None, data={}):
raise ValueError("target_names must correspond to observed variables during the inference: \
{}".format(self.observed_vars))

prior_data = self._create_hidden_observations(target_names, data)

return Query(self.inference_method.expanded_variables["p"], target_names, {**data, **prior_data})

def _create_hidden_observations(self, target_names, data={}):
# TODO: This code must be implemented independent of the inference method. Right now we are using the p and q
# expanded variables, which belongs only to variational inference methods. When a different VI is implemented
# think about a better way to implement this function and access to the correct dict of random variables

# NOTE: implementation trick. As p model variables are intercepted with q model variables,
# compute prior observations for local hidden variables which are not targets,
# expanding a new model using plate_size and then sampling
hidden_variable_names = [k for k in self.vars.keys() if k not in target_names and k not in data]
if hidden_variable_names:
expanded_vars, _ = self.expand_model(self.inference_method.plate_size)
prior_data = Query(expanded_vars, hidden_variable_names, data).sample(simplify_result=False)
else:
prior_data = {}
# posterior_predictive uses pmodel variables, but intercepted with qmodel variables.
# TODO: local hidden variables should not be intercepted. See issue #185
return Query(self.inference_method.expanded_variables["p"], target_names, {**data},
enable_interceptor_variable=self.inference_method.get_interceptable_condition_variable())

return prior_data
return result

def _build_graph(self):
with contextmanager.randvar_registry.init():
Expand Down
17 changes: 10 additions & 7 deletions inferpy/queries/query.py
Expand Up @@ -18,7 +18,7 @@ def wrapper(*args, **kwargs):


class Query:
def __init__(self, variables, target_names=None, data={}):
def __init__(self, variables, target_names=None, data={}, enable_interceptor_variable=None):
# if provided a single name, create a list with only one item
if isinstance(target_names, str):
target_names = [target_names]
Expand All @@ -32,13 +32,15 @@ def __init__(self, variables, target_names=None, data={}):

self.observed_variables = variables
self.data = data
self.enable_interceptor_variable = enable_interceptor_variable

@flatten_result
@util.tf_run_ignored
def log_prob(self):
""" Computes the log probabilities of a (set of) sample(s)"""
with contextmanager.observe(self.observed_variables, self.data):
result = util.runtime.try_run({k: v.log_prob(v.value) for k, v in self.target_variables.items()})
with util.interceptor.enable_interceptor(self.enable_interceptor_variable):
with contextmanager.observe(self.observed_variables, self.data):
result = util.runtime.try_run({k: v.log_prob(v.value) for k, v in self.target_variables.items()})

return result

Expand All @@ -51,10 +53,11 @@ def sum_log_prob(self):
@util.tf_run_ignored
def sample(self, size=1):
""" Generates a sample for eache variable in the model """
with contextmanager.observe(self.observed_variables, self.data):
# each iteration for `size` run the dict in the session, so if there are dependencies among random variables
# they are computed in the same graph operations, and reflected in the results
samples = [util.runtime.try_run(self.target_variables) for _ in range(size)]
with util.interceptor.enable_interceptor(self.enable_interceptor_variable):
with contextmanager.observe(self.observed_variables, self.data):
# each iteration for `size` run the dict in the session, so if there are dependencies among random variables
# they are computed in the same graph operations, and reflected in the results
samples = [util.runtime.try_run(self.target_variables) for _ in range(size)]

if size == 1:
result = samples[0]
Expand Down

0 comments on commit 9748c14

Please sign in to comment.