Potential bug using scan to create random variable in PyMC #173

Closed
junpenglao opened this issue Sep 4, 2022 · 6 comments

@junpenglao

Currently, a simple random walk example fails for me:

import numpy as np
import pymc as pm
import matplotlib.pyplot as plt

import aesara
import aesara.tensor as at

import aeppl

num_timesteps = 100
data = np.random.normal(0, 2.5, size=num_timesteps).cumsum()
plt.plot(data);

with pm.Model() as m:
    sigma = pm.HalfNormal("sigma", 5.)
    mu = pm.Normal("mu", 0., 1.)
    X_rv, updates = aesara.scan(
        fn=lambda x_tm1: at.random.normal(x_tm1, sigma),
        outputs_info=[{"initial": mu}],
        n_steps=num_timesteps,
    )
    m.register_rv(X_rv, name="X_rv", data=data)
    # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
    idata = pm.sample()

with:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/var/folders/7p/srk5qjp563l5f9mrjtp44bh800jqsw/T/ipykernel_33625/1955773877.py in <module>
      9     m.register_rv(X_rv, name="X_rv", data=data)
     10     # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
---> 11     idata = pm.sample()

~/Documents/OSS/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    528 
    529     initial_points = None
--> 530     step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    531 
    532     if isinstance(step, list):

~/Documents/OSS/pymc/pymc/sampling.py in assign_step_methods(model, step, methods, step_kwargs)
    204     # variables
    205     selected_steps = defaultdict(list)
--> 206     model_logp = model.logp()
    207 
    208     for var in model.value_vars:

~/Documents/OSS/pymc/pymc/model.py in logp(self, vars, jacobian, sum)
    756         rv_logps: List[TensorVariable] = []
    757         if rv_values:
--> 758             rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    759             assert isinstance(rv_logps, list)
    760 

~/Documents/OSS/pymc/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
    269     logp_var_dict = {}
    270     for value_var in rv_values.values():
--> 271         logp_var_dict[value_var] = temp_logp_var_dict[value_var]
    272 
    273     if scaling:

KeyError: X_rv{[-1.502085..27194e+01]}

Inspecting with aeppl suggests that it does not recognize the RV produced by the scan:

x_vv = at.constant(data)
mu_vv = mu.clone()
sigma_vv = sigma.clone()

logp_dict = aeppl.factorized_joint_logprob({X_rv: x_vv, mu: mu_vv, sigma: sigma_vv})
logp_dict
# ==> {mu: mu_logprob, sigma: sigma_logprob}

cc @ricardoV94

@ricardoV94 (Contributor)

Sounds like the problem is simply not passing sigma via non_sequences. Aeppl probably requires this information: https://colab.research.google.com/drive/1Yvh0VpZE4Bhtu5-mL4qXnlK0-j45ARBF#scrollTo=WPVbMVisw8lM

@junpenglao (Author)

Thanks @ricardoV94! Besides passing sigma as non_sequence, it is also important to have {x.owner.inputs[0]: x.owner.outputs[0]} in the return.

@ricardoV94 (Contributor)

> Thanks @ricardoV94! Besides passing sigma as non_sequence, it is also important to have {x.owner.inputs[0]: x.owner.outputs[0]} in the return.

Only important for forward sampling I think. If you use RandomStream to create the RVs inside scan it all happens behind the scenes.

@ricardoV94 (Contributor) commented Sep 4, 2022

Ah, the gradient (but not the logp) raises an error if you have a scan without the explicit updates... strange.

sigma_rv = at.random.halfnormal(0, 5.0, name="sigma")
mu_rv = at.random.normal(0, 1.0, name="mu")

def step(x_tm1, sigma_rv):
    x = at.random.normal(x_tm1, sigma_rv)
    return x #, {x.owner.inputs[0]: x.owner.outputs[0]}

scan_rv, updates = aesara.scan(
    fn=step,
    outputs_info=[{"initial": mu_rv}],
    n_steps=num_timesteps,
    non_sequences=[sigma_rv],
)
scan_rv.name = "scan"

sigma_vv = sigma_rv.clone()
mu_vv = mu_rv.clone()
scan_vv = scan_rv.clone()

logp_dict = aeppl.factorized_joint_logprob({
    sigma_rv: sigma_vv,
    mu_rv: mu_vv,
    scan_rv: scan_vv,
})

# The next line raises
at.grad(logp_dict[scan_vv].sum(), wrt=sigma_vv)

TypeError: Tensor type field must be a TensorType; found <class 'aesara.tensor.random.type.RandomGeneratorType'>.

@rlouf (Member) commented Nov 2, 2022

@junpenglao Does this still fail? If it works by modifying the original example would you mind sharing the modified version?

@junpenglao (Author)

Yes, this is the working implementation:

with pm.Model(coords={"timestep": np.arange(num_timesteps)}) as m:
    sigma = pm.HalfNormal("sigma", 5.0)
    mu = pm.Normal("mu", 0.0, 1.0)

    def step(x_tm1, sigma):
        x = pm.Normal.dist(x_tm1, sigma)
        # x = at.random.normal(x_tm1, sigma)
        # Return the new variable and the RNG update expression
        return x, {x.owner.inputs[0]: x.owner.outputs[0]}

    X_rv, updates = aesara.scan(
        fn=step, outputs_info=[mu], non_sequences=[sigma], n_steps=num_timesteps
    )
    m.register_rv(X_rv, name="X_rv", data=data)
    # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
    idata = pm.sample()
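For reference, the generative process this scan encodes, x_t ~ Normal(x_{t-1}, sigma), is just the start value plus a cumulative sum of i.i.d. normal increments, which is the same way `data` was simulated at the top. A pure-NumPy sketch of that equivalence (the mu and sigma values here are placeholders):

```python
import numpy as np

num_timesteps = 100
mu, sigma = 0.0, 2.5  # placeholder start point and step scale

# Step-by-step form, mirroring the scan: x_t = Normal(x_{t-1}, sigma)
rng = np.random.default_rng(0)
x = np.empty(num_timesteps)
prev = mu
for t in range(num_timesteps):
    prev = rng.normal(prev, sigma)
    x[t] = prev

# Equivalent vectorized form: start value + cumulative sum of increments
rng2 = np.random.default_rng(0)
x_vec = mu + rng2.normal(0.0, sigma, size=num_timesteps).cumsum()

assert np.allclose(x, x_vec)
```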
