From 9748c14b4a8236b8fd4d27cf2dde4a09c4ee58e1 Mon Sep 17 00:00:00 2001 From: jcozar87 Date: Tue, 20 Aug 2019 11:07:25 +0200 Subject: [PATCH] fix: posterior predictive queries now intercept hidden variables --- inferpy/models/prob_model.py | 28 ++++++---------------------- inferpy/queries/query.py | 17 ++++++++++------- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/inferpy/models/prob_model.py b/inferpy/models/prob_model.py index 8f0ffb9..554de49 100644 --- a/inferpy/models/prob_model.py +++ b/inferpy/models/prob_model.py @@ -94,9 +94,7 @@ def posterior(self, target_names=None, data={}): raise ValueError("target_names must correspond to not observed variables during the inference: \ {}".format([v for v in self.vars.keys() if v not in self.observed_vars])) - prior_data = self._create_hidden_observations(target_names, data) - - return Query(self.inference_method.expanded_variables["q"], target_names, {**data, **prior_data}) + return Query(self.inference_method.expanded_variables["q"], target_names, {**data}) def posterior_predictive(self, target_names=None, data={}): if self.inference_method is None: @@ -110,26 +108,12 @@ def posterior_predictive(self, target_names=None, data={}): raise ValueError("target_names must correspond to observed variables during the inference: \ {}".format(self.observed_vars)) - prior_data = self._create_hidden_observations(target_names, data) - - return Query(self.inference_method.expanded_variables["p"], target_names, {**data, **prior_data}) - - def _create_hidden_observations(self, target_names, data={}): - # TODO: This code must be implemented independent of the inference method. Right now we are using the p and q - # expanded variables, which belongs only to variational inference methods. When a different VI is implemented - # think about a better way to implement this function and access to the correct dict of random variables - - # NOTE: implementation trick. As p model variables are intercepted with q model variables, - # compute prior observations for local hidden variables which are not targets, - # expanding a new model using plate_size and then sampling - hidden_variable_names = [k for k in self.vars.keys() if k not in target_names and k not in data] - if hidden_variable_names: - expanded_vars, _ = self.expand_model(self.inference_method.plate_size) - prior_data = Query(expanded_vars, hidden_variable_names, data).sample(simplify_result=False) - else: - prior_data = {} + # posterior_predictive uses pmodel variables, but intercepted with qmodel variables. + # TODO: local hidden variables should not be intercepted. See issue #185 + return Query(self.inference_method.expanded_variables["p"], target_names, {**data}, + enable_interceptor_variable=self.inference_method.get_interceptable_condition_variable()) - return prior_data + return result def _build_graph(self): with contextmanager.randvar_registry.init(): diff --git a/inferpy/queries/query.py b/inferpy/queries/query.py index 9d85a19..8a8a33d 100644 --- a/inferpy/queries/query.py +++ b/inferpy/queries/query.py @@ -18,7 +18,7 @@ def wrapper(*args, **kwargs): class Query: - def __init__(self, variables, target_names=None, data={}): + def __init__(self, variables, target_names=None, data={}, enable_interceptor_variable=None): # if provided a single name, create a list with only one item if isinstance(target_names, str): target_names = [target_names] @@ -32,13 +32,15 @@ def __init__(self, variables, target_names=None, data={}): self.observed_variables = variables self.data = data + self.enable_interceptor_variable = enable_interceptor_variable @flatten_result @util.tf_run_ignored def log_prob(self): """ Computes the log probabilities of a (set of) sample(s)""" - with contextmanager.observe(self.observed_variables, self.data): - result = util.runtime.try_run({k: v.log_prob(v.value) for k, v in self.target_variables.items()}) + with util.interceptor.enable_interceptor(self.enable_interceptor_variable): + with contextmanager.observe(self.observed_variables, self.data): + result = util.runtime.try_run({k: v.log_prob(v.value) for k, v in self.target_variables.items()}) return result @@ -51,10 +53,11 @@ def sum_log_prob(self): @util.tf_run_ignored def sample(self, size=1): """ Generates a sample for eache variable in the model """ - with contextmanager.observe(self.observed_variables, self.data): - # each iteration for `size` run the dict in the session, so if there are dependencies among random variables - # they are computed in the same graph operations, and reflected in the results - samples = [util.runtime.try_run(self.target_variables) for _ in range(size)] + with util.interceptor.enable_interceptor(self.enable_interceptor_variable): + with contextmanager.observe(self.observed_variables, self.data): + # each iteration for `size` run the dict in the session, so if there are dependencies among random variables + # they are computed in the same graph operations, and reflected in the results + samples = [util.runtime.try_run(self.target_variables) for _ in range(size)] if size == 1: result = samples[0]