pymc3jax: AttributeError: 'Identity' object has no attribute 'nfunc_spec' #60

Closed
mschmidt87 opened this issue Sep 29, 2020 · 3 comments

@mschmidt87

Description of your problem

I ran into another problem with the experimental JAX-based sampler on the pymc3jax branch:

I am experimenting with a hierarchical model that simulates something like a hierarchical Gaussian mixture: there are 3 clusters, each with its own mean (cluster_means) and a standard deviation around it, and for each cluster I simulate a number of instances, where each instance has its own mean and a fixed std.
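In formulas, the simulation draws the following (my notation: `k(j)` is the cluster of instance `j`, $m_k$ = `cluster_means[k]`, $s_k$ = `cluster_means_std[k]`, $\sigma$ = `std` = 0.5, $J$ = `total_samples`, and the second argument of $\mathcal{N}$ is a standard deviation):

```latex
\begin{aligned}
\mu_j &\sim \mathcal{N}\big(m_{k(j)},\, s_{k(j)}\big), & j &= 1,\dots,J, \\
y_{j,n} &\sim \mathcal{N}\big(\mu_j,\, \sigma\big), & n &= 1,\dots,N.
\end{aligned}
```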

This is the code to simulate the data:

import numpy as np
import pymc3 as pm

np.random.seed(123)

N_clusters = 3  # Number of clusters
N_samples = [10, 5, 0]  # Number of samples per cluster
total_samples = sum(N_samples)
N = 100  # Number of observations per instance
cluster_means = [1., 2., 3.]  # Mean of means within cluster
cluster_means_std = [0.1, 0.1, 0.1]  # Std of means within cluster
std = 0.5

data = []
true_means = []
for i in range(N_clusters):
    if N_samples[i] > 0:
        means = np.random.normal(loc=cluster_means[i], scale=cluster_means_std[i], size=N_samples[i])
        true_means = np.append(true_means, means)
        data.append(np.array([np.random.normal(means[j], std, N) for j in range(N_samples[i])]))
data = np.vstack(data)
clusters = []
for i in range(N_clusters):
    clusters += [i] * N_samples[i]
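data = data.reshape(-1)  # Flatten to one vector of observations

c = np.repeat(clusters, N)  # Cluster index of each observation
sample = np.repeat(np.arange(total_samples), N)  # Instance index of each observation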
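For reference, the same arrays can be built without the Python loop. This is a sketch under my own choices: it uses `np.random.default_rng(123)`, which draws numbers in a different order than the loop above, so the values differ while the shapes and index arrays match.

```python
import numpy as np

# Same settings as the simulation above
N_clusters = 3
N_samples = [10, 5, 0]
N = 100
cluster_means = np.array([1., 2., 3.])
cluster_means_std = np.array([0.1, 0.1, 0.1])
std = 0.5

rng = np.random.default_rng(123)

# Cluster label of each instance: ten 0s, five 1s (cluster 2 has no instances)
clusters = np.repeat(np.arange(N_clusters), N_samples)
total_samples = clusters.size

# One mean per instance, drawn around its cluster's mean
true_means = rng.normal(cluster_means[clusters], cluster_means_std[clusters])

# N observations per instance (row j belongs to instance j), then flattened
data = rng.normal(true_means[:, None], std, size=(total_samples, N)).reshape(-1)

# Instance index and cluster index of every flattened observation
sample = np.repeat(np.arange(total_samples), N)
c = clusters[sample]
```

Passing a per-element count vector to `np.repeat` handles the empty third cluster without a special case.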

Using these data, I am creating this model:

with pm.Model() as model:
    # Cluster-level means and spreads
    a = pm.Normal('a', mu=0., sigma=3., shape=N_clusters)
    sigma_a = pm.Exponential('sigma_a', 1., shape=N_clusters)

    # Non-centered parametrization of the per-instance means
    mu_tilde = pm.Normal('mu_t', mu=0., sigma=1., shape=total_samples)
    mu = pm.Deterministic('mu', mu_tilde * sigma_a[clusters] + a[clusters])

    # Per-instance observation noise
    sigma = pm.Exponential('sigma', 1., shape=total_samples)

    data_obs = pm.Normal('data', mu=mu[sample], sigma=sigma[sample], observed=data)
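Written out, the model above reads as follows (my reading of the code; `k(j)` is the cluster of instance `j`, $\operatorname{Exp}(1)$ has rate 1, and the second argument of $\mathcal{N}$ is a standard deviation):

```latex
\begin{aligned}
a_k &\sim \mathcal{N}(0,\, 3), \qquad \sigma_{a,k} \sim \operatorname{Exp}(1), \\
\tilde{\mu}_j &\sim \mathcal{N}(0,\, 1), \qquad \sigma_j \sim \operatorname{Exp}(1), \\
\mu_j &= a_{k(j)} + \sigma_{a,k(j)}\,\tilde{\mu}_j
  \;\;\Longrightarrow\;\; \mu_j \mid a, \sigma_a \sim \mathcal{N}\big(a_{k(j)},\, \sigma_{a,k(j)}\big), \\
y_{j,n} &\sim \mathcal{N}\big(\mu_j,\, \sigma_j\big).
\end{aligned}
```

Sampling `mu_t` instead of `mu` directly is the usual non-centered trick for hierarchical models; it changes the parametrization, not the implied distribution of `mu`. Note that the model gives each instance its own noise `sigma_j`, whereas the simulation uses a single fixed `std`.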

Then I want to use the JAX-based sampler for inference:

import pymc3.sampling_jax

with model:
    trace_jax = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)

Please provide the full traceback.
Running this code, I am getting this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-6-7bb1ee15df7f> in <module>
      3 
      4 with model:
----> 5     trace_jax = pm.sampling_jax.sample_numpyro_nuts(
      6             2000, tune=2000, target_accept=.9)
      7     idata = trace_jax

/path/to/pymc3/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar)
    114 
    115     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
--> 116     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
    117     logp_fn_jax = fns[0]
    118 

/path/to/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify_FunctionGraph(fgraph)
    523 
    524     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 525     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    526 
    527     return jax_funcs

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in <listcomp>(.0)
    523 
    524     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 525     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    526 
    527     return jax_funcs

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
     90         return memo[out_node]
     91 
---> 92     jax_return_func = jax_funcify(out_node.op)
     93 
     94     input_funcs = []

/path/to/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify_Elemwise(op)
    320 def jax_funcify_Elemwise(op):
    321     scalar_op = op.scalar_op
--> 322     return jax_funcify(scalar_op)
    323 
    324 

/path/to/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify_ScalarOp(op)
    142 def jax_funcify_ScalarOp(op):
    143     print(op)
--> 144     func_name = op.nfunc_spec[0]
    145 
    146     if "." in func_name:

AttributeError: 'Identity' object has no attribute 'nfunc_spec'
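The last frames show where this fails: `jax_funcify` is a `functools.singledispatch` function, and since `Identity` has no handler of its own, dispatch falls back to the generic `jax_funcify_ScalarOp`, which reads `op.nfunc_spec[0]` to look up the matching array function by name. `Identity` does not define `nfunc_spec`, hence the `AttributeError`. Below is a minimal sketch of this dispatch pattern; the classes and `funcify` are hypothetical stand-ins rather than Theano's code, and NumPy stands in for `jax.numpy`. It shows one way to avoid the error (a more specific registration for the op class), not necessarily what #48 does.

```python
from functools import singledispatch

import numpy as np


# Hypothetical stand-ins for Theano's scalar Op classes (not the real ones)
class ScalarOp:
    pass


class Add(ScalarOp):
    nfunc_spec = ("add", 2, 1)  # Names the array function to use


class Identity(ScalarOp):
    pass  # No nfunc_spec, like the op in the traceback


class Unregistered(ScalarOp):
    pass  # No nfunc_spec and no dedicated handler: reproduces the error


@singledispatch
def funcify(op):
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


@funcify.register(ScalarOp)
def funcify_scalar_op(op):
    # Generic path: look up the function by name via nfunc_spec
    return getattr(np, op.nfunc_spec[0])


@funcify.register(Identity)
def funcify_identity(op):
    # More specific registration wins for Identity, so nfunc_spec is never read
    return lambda x: x


# funcify(Add())(1, 2) returns 3 via np.add
# funcify(Identity())(5) returns 5 without touching nfunc_spec
# funcify(Unregistered()) raises AttributeError, as in the traceback
```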

Versions and main components

  • PyMC3 Version: checkout of the pymc3jax branch
  • Theano Version: checkout of the Theano-PyMC master branch
  • Python Version: 3.8
  • Operating system: macOS
  • How did you install PyMC3: manual installation from the branch
@junpenglao
Contributor

Thanks, this will be fixed by #48.

@junpenglao
Contributor

Could you try your model again?

@brandonwillard transferred this issue from pymc-devs/pymc on Sep 29, 2020
@mschmidt87
Author

Yes, problem solved, thanks.
