Jax 0.2 does not support jax.numpy.reshape with non-constant values in omnistaging mode #43

Open · Tracked by #1425
twiecki opened this issue Sep 25, 2020 · 34 comments
Labels: enhancement, JAX, question

Comments

twiecki (Contributor) commented Sep 25, 2020

import numpy as np
import pymc3 as pm
import theano
import theano.sandbox.jax

# Register the sandbox JAX linker and build a compilation mode that uses it
theano.compile.mode.predefined_linkers["jax"] = theano.sandbox.jax.JaxLinker()
jax_mode = theano.compile.Mode(linker="jax")

# Synthetic data for a simple linear regression
x = np.linspace(0, 1, 10)
y = x * 4.0 + 1.4 + np.random.randn(10)

with pm.Model() as model:
    beta = pm.Normal("beta", 0.0, 5.0, shape=2)
    sigma = pm.HalfNormal("sigma", 2.5)
    obs = pm.Normal("obs", beta[0] + beta[1] * x, sigma, observed=y)
    pm.sample(mode=jax_mode)

Traceback:

Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-20-21adaeaad34c> in <module>
     21 with model:
---> 22     pm.sample(mode=jax_mode)

~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    480             _log.info("Auto-assigning NUTS sampler...")
--> 481             start_, step = init_nuts(
    482                 init=init,

~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
   2133 
-> 2134     step = pm.NUTS(potential=potential, model=model, **kwargs)
   2135 

~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
    167         """
--> 168         super().__init__(vars, **kwargs)
    169 

~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
     92 
---> 93         super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
     94 

~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
    253             q = func.dict_to_array(model.test_point)
--> 254             logp, dlogp = func(q)
    255         except ValueError:

~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
    738 
--> 739         output = self._theano_function(array)
    740         if grad_out is None:

~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    978             outputs = (
--> 979                 self.fn()
    980                 if output_subset is None

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    653                 ):
--> 654                     outputs = [
    655                         jax_impl_jit(*[x[0] for x in thunk_inputs])

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    654                     outputs = [
--> 655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    126             func_args = [fn(*inputs) for fn in input_funcs]
--> 127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    545     def reshape(x, shape):
--> 546         return jnp.reshape(x, shape)
    547 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")

FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

ConcretizationTypeError                   Traceback (most recent call last)
~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    653                 ):
--> 654                     outputs = [
    655                         jax_impl_jit(*[x[0] for x in thunk_inputs])

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    654                     outputs = [
--> 655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    136     try:
--> 137       return fun(*args, **kwargs)
    138     except Exception as e:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    208     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209     out = xla.xla_call(
    210         flat_fun,

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1143   def bind(self, fun, *args, **params):
-> 1144     return call_bind(self, fun, *args, **params)
   1145 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1134   with maybe_new_sublevel(top_trace):
-> 1135     outs = primitive.process(top_trace, fun, tracers, params)
   1136   return map(full_lower, apply_todos(env_trace_todo(), outs))

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1146   def process(self, trace, fun, tracers, params):
-> 1147     return trace.process_call(self, fun, tracers, params)
   1148 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    576   def process_call(self, primitive, f, tracers, params):
--> 577     return primitive.impl(f, *tracers, **params)
    578   process_map = process_call

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    530                                *unsafe_map(arg_spec, args))

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    233     else:
--> 234       ans = call(fun, *args)
    235       cache[key] = (ans, fun.stores)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    594   if config.omnistaging_enabled:
--> 595     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    596     if any(isinstance(c, core.Tracer) for c in consts):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1022     main.jaxpr_stack = ()  # type: ignore
-> 1023     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1024     del main

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1003     in_tracers = map(trace.new_arg, in_avals)
-> 1004     ans = fun.call_wrapped(*in_tracers)
   1005     out_tracers = map(trace.full_raise, ans)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     try:
--> 151       ans = self.f(*args, **dict(self.params, **kwargs))
    152     except:

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    126             func_args = [fn(*inputs) for fn in input_funcs]
--> 127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    545     def reshape(x, shape):
--> 546         return jnp.reshape(x, shape)
    547 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
    873     else:
--> 874       raise_concretization_error(val, context)
    875   else:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
    852           f"Encountered tracer value: {val}")
--> 853   raise ConcretizationTypeError(msg)
    854 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>

During handling of the above exception, another exception occurred:

ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-20-21adaeaad34c> in <module>
     20 
     21 with model:
---> 22     pm.sample(mode=jax_mode)

~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    479             # By default, try to use NUTS
    480             _log.info("Auto-assigning NUTS sampler...")
--> 481             start_, step = init_nuts(
    482                 init=init,
    483                 chains=chains,

~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
   2132         raise ValueError(f"Unknown initializer: {init}.")
   2133 
-> 2134     step = pm.NUTS(potential=potential, model=model, **kwargs)
   2135 
   2136     return start, step

~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
    166         `pm.sample` to the desired number of tuning steps.
    167         """
--> 168         super().__init__(vars, **kwargs)
    169 
    170         self.max_treedepth = max_treedepth

~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
     91         vars = inputvars(vars)
     92 
---> 93         super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
     94 
     95         self.adapt_step_size = adapt_step_size

~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
    252             func.set_extra_values(model.test_point)
    253             q = func.dict_to_array(model.test_point)
--> 254             logp, dlogp = func(q)
    255         except ValueError:
    256             if logp_dlogp_func is not None:

~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
    737             out = grad_out
    738 
--> 739         output = self._theano_function(array)
    740         if grad_out is None:
    741             return output

~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    977         try:
    978             outputs = (
--> 979                 self.fn()
    980                 if output_subset is None
    981                 else self.fn(output_subset=output_subset)

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    705                         old_s[0] = None
    706             except Exception:
--> 707                 raise_with_op(node, thunk)
    708 
    709         f = streamline_default_f

~/projects/Theano-PyMC/theano/gof/link.py in raise_with_op(node, thunk, exc_info, storage_map)
    346         # extra long error message in that case.
    347         pass
--> 348     reraise(exc_type, exc_value, exc_trace)
    349 
    350 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/six.py in reraise(tp, value, tb)
    700                 value = tp()
    701             if value.__traceback__ is not tb:
--> 702                 raise value.with_traceback(tb)
    703             raise value
    704         finally:

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    701                     thunks, order, post_thunk_old_storage
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:
    705                         old_s[0] = None

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    652                     node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
    653                 ):
--> 654                     outputs = [
    655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    653                 ):
    654                     outputs = [
--> 655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits
    657                     ]

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    135   def reraise_with_filtered_traceback(*args, **kwargs):
    136     try:
--> 137       return fun(*args, **kwargs)
    138     except Exception as e:
    139       if not is_under_reraiser(e):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    207       _check_arg(arg)
    208     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209     out = xla.xla_call(
    210         flat_fun,
    211         *args_flat,

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1142 
   1143   def bind(self, fun, *args, **params):
-> 1144     return call_bind(self, fun, *args, **params)
   1145 
   1146   def process(self, trace, fun, tracers, params):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1133   tracers = map(top_trace.full_raise, args)
   1134   with maybe_new_sublevel(top_trace):
-> 1135     outs = primitive.process(top_trace, fun, tracers, params)
   1136   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1137 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1145 
   1146   def process(self, trace, fun, tracers, params):
-> 1147     return trace.process_call(self, fun, tracers, params)
   1148 
   1149   def post_process(self, trace, out_tracers, params):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    575 
    576   def process_call(self, primitive, f, tracers, params):
--> 577     return primitive.impl(f, *tracers, **params)
    578   process_map = process_call
    579 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    527 
    528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    530                                *unsafe_map(arg_spec, args))
    531   try:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    232       fun.populate_stores(stores)
    233     else:
--> 234       ans = call(fun, *args)
    235       cache[key] = (ans, fun.stores)
    236     return ans

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    593   abstract_args, arg_devices = unzip2(arg_specs)
    594   if config.omnistaging_enabled:
--> 595     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    596     if any(isinstance(c, core.Tracer) for c in consts):
    597       raise core.UnexpectedTracerError("Encountered an unexpected tracer.")

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1021     main.source_info = fun_sourceinfo(fun.f)  # type: ignore
   1022     main.jaxpr_stack = ()  # type: ignore
-> 1023     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1024     del main
   1025   return jaxpr, out_avals, consts

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1002     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1003     in_tracers = map(trace.new_arg, in_avals)
-> 1004     ans = fun.call_wrapped(*in_tracers)
   1005     out_tracers = map(trace.full_raise, ans)
   1006   jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    149 
    150     try:
--> 151       ans = self.f(*args, **dict(self.params, **kwargs))
    152     except:
    153       # Some transformations yield from inside context managers, so we have to

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    124 
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    124 
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    125         def jax_func(*inputs):
    126             func_args = [fn(*inputs) for fn in input_funcs]
--> 127             return return_func(*func_args)
    128 
    129         jax_funcs.append(update_wrapper(jax_func, return_func))

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    544 def jax_funcify_Reshape(op):
    545     def reshape(x, shape):
--> 546         return jnp.reshape(x, shape)
    547 
    548     return reshape

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1144 def reshape(a, newshape, order="C"):
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:
   1148     return _reshape(a, newshape, order=order)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1190           type(newshape[0]) is not Poly):
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 
   1194 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1166 
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":
   1170     return lax.reshape(a, computed_newshape, None)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
   1161   if newsize < 0:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
   1161   if newsize < 0:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1155   else: iterable = True
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
   1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
    872       return force(val.aval.val)
    873     else:
--> 874       raise_concretization_error(val, context)
    875   else:
    876     return force(val)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
    851          "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
    852           f"Encountered tracer value: {val}")
--> 853   raise ConcretizationTypeError(msg)
    854 
    855 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
Apply node that caused the error: Sum{acc_dtype=float64}(MakeVector{dtype='float64'}.0)
Toposort index: 46
Inputs types: [TensorType(float64, vector)]
Inputs shapes: [(3,)]
Inputs strides: [(8,)]
Inputs values: [array([0.69049938, 0.        , 0.        ])]
Outputs clients: [['output']]

Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
  File "<ipython-input-20-21adaeaad34c>", line 22, in <module>
    pm.sample(mode=jax_mode)
  File "/Users/twiecki/projects/pymc/pymc3/sampling.py", line 481, in sample
    start_, step = init_nuts(
  File "/Users/twiecki/projects/pymc/pymc3/sampling.py", line 2134, in init_nuts
    step = pm.NUTS(potential=potential, model=model, **kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/hmc/nuts.py", line 168, in __init__
    super().__init__(vars, **kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/hmc/base_hmc.py", line 93, in __init__
    super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/arraystep.py", line 245, in __init__
    func = model.logp_dlogp_function(
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1005, in logp_dlogp_function
    costs = [self.logpt]
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1015, in logpt
    logp = tt.sum([tt.sum(factor) for factor in factors])

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
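
At its core, the traceback shows JAX's omnistaging tracer rejecting a non-constant shape: under jit every non-static argument becomes an abstract tracer, and jnp.reshape requires concrete sizes. Below is a minimal standalone sketch of the failure and of the static_argnums remedy that the error message suggests (the function names are illustrative, not from the issue):

from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def bad_reshape(x, n):
    # Under jit, `n` arrives as an abstract tracer, so jnp.reshape
    # raises ConcretizationTypeError, as in the traceback above.
    return jnp.reshape(x, (n,))

# Marking the shape argument static keeps it a concrete Python int:
@partial(jax.jit, static_argnums=(1,))
def good_reshape(x, n):
    return jnp.reshape(x, (n,))

good_reshape(jnp.arange(6.0), 6)   # works
# bad_reshape(jnp.arange(6.0), 6)  # raises ConcretizationTypeError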

twiecki added the JAX label Sep 25, 2020
brandonwillard self-assigned this Sep 25, 2020
brandonwillard added the bug label Sep 25, 2020
brandonwillard (Member) commented Sep 25, 2020

I just pushed a fix, and with that this model appears to work when chains=1. The sampler just sits there when chains is greater than one, though. I'm looking into that now.
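
For illustration only, here is a sketch of the kind of shape-concretization guard such a fix could add to the Reshape translation seen in the traceback's jax_funcify_Reshape frame. This is an assumption about one possible approach, not the commit that was actually pushed:

import jax.numpy as jnp

def jax_funcify_Reshape(op):
    def reshape(x, shape):
        # Hypothetical guard: when the shape values are concrete (the
        # common case when the thunk runs eagerly), coerce them to plain
        # Python ints so jnp.reshape sees static sizes rather than
        # traced int64 scalars.
        try:
            shape = tuple(int(s) for s in shape)
        except Exception:
            pass  # shape is traced; jnp.reshape will raise as before
        return jnp.reshape(x, shape)

    return reshape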

brandonwillard (Member) commented

Looks like the problem could be related to a JAX + fork multiprocessing issue.
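
For reference, a minimal sketch of how the spawn context used in the next comment can be constructed (the log below uses a ctx variable without showing its creation; multiprocessing.get_context is the standard-library call, mp_ctx is the pm.sample keyword visible in the log, and model is the reproducer model from the top of this issue):

import multiprocessing

# Use the "spawn" start method instead of the default "fork"; JAX's
# internal state generally does not survive fork-started workers.
ctx = multiprocessing.get_context("spawn")

with model:
    trace = pm.sample(1000, chains=2, mp_ctx=ctx)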

brandonwillard (Member) commented

I tried passing a multiprocessing context initialized with the "spawn" start method and that went straight to a pickling error:

In [5]: with model:
    trace = pm.sample(1000, chains=2, mp_ctx=ctx)
Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (2 chains in 4 jobs)
NUTS: [sigma, beta]
INFO:pymc3:NUTS: [sigma, beta]
~/projects/code/python/Theano/theano/sandbox/jax.py:198: UserWarning: `jnp.copy` is not implemented yet. Using the object's `copy` method.
  warn("`jnp.copy` is not implemented yet. " "Using the object's `copy` method.")
~/projects/code/python/Theano/theano/sandbox/jax.py:202: UserWarning: Object has no `copy` method: Traced<ShapedArray(float64[2]):JaxprTrace(level=-1/1)>
  warn("Object has no `copy` method: {}".format(x))
~/projects/code/python/Theano/theano/sandbox/jax.py:202: UserWarning: Object has no `copy` method: Traced<ShapedArray(float64[]):JaxprTrace(level=-1/1)>
  warn("Object has no `copy` method: {}".format(x))
---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "~/projects/code/python/pymc3/pymc3/parallel_sampling.py", line 114, in _unpickle_step_method
    self._step_method = pickle.loads(self._step_method)
  File "~/projects/code/python/Theano/theano/compile/mode.py", line 305, in __setstate__
    linker = predefined_linkers[linker]
KeyError: 'jax'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "~/projects/code/python/pymc3/pymc3/parallel_sampling.py", line 135, in run
    self._unpickle_step_method()
  File "~/projects/code/python/pymc3/pymc3/parallel_sampling.py", line 116, in _unpickle_step_method
    raise ValueError(unpickle_error)
ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver.
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver.

The pickling error arises due to the absence of a JAXLinker entry in theano.compile.mode.predefined_linkers by default (i.e. we currently have to load the theano.sandbox.jax module in order to add the entry).

I'll try adding the JAXLinker as a default entry in theano.compile.mode.predefined_linkers, but so far the import dependencies are preventing that (see #45).
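
As a stopgap, the registration from the reproducer at the top of this issue would need to run in whatever process performs the unpickling, so that Mode.__setstate__ can resolve the "jax" key. A workaround sketch, not a committed fix:

import theano
import theano.sandbox.jax

# Importing theano.sandbox.jax and re-registering the linker makes the
# "jax" entry available in predefined_linkers, which is exactly the
# lookup that fails in Mode.__setstate__ during unpickling.
theano.compile.mode.predefined_linkers["jax"] = theano.sandbox.jax.JaxLinker()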

twiecki (Contributor, Author) commented Sep 26, 2020

I tried again with the updated branch and cores=1, chains=1, and I'm now getting:

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-1-9b664886d863> in <module>
     22 with model:
---> 23     pm.sample(mode=jax_mode, chains=1, cores=1)

~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    480             _log.info("Auto-assigning NUTS sampler...")
--> 481             start_, step = init_nuts(
    482                 init=init,

~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
   2133 
-> 2134     step = pm.NUTS(potential=potential, model=model, **kwargs)
   2135 

~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
    167         """
--> 168         super().__init__(vars, **kwargs)
    169 

~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
     92 
---> 93         super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
     94 

~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
    253             q = func.dict_to_array(model.test_point)
--> 254             logp, dlogp = func(q)
    255         except ValueError:

~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
    738 
--> 739         output = self._theano_function(array)
    740         if grad_out is None:

~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    978             outputs = (
--> 979                 self.fn()
    980                 if output_subset is None

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    660             ):
--> 661                 outputs = [
    662                     jax_impl_jit(*[x[0] for x in thunk_inputs])

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    661                 outputs = [
--> 662                     jax_impl_jit(*[x[0] for x in thunk_inputs])
    663                     for jax_impl_jit in jax_impl_jits

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    123         def jax_func(*inputs):
--> 124             func_args = [fn(*inputs) for fn in input_funcs]
    125             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    123         def jax_func(*inputs):
--> 124             func_args = [fn(*inputs) for fn in input_funcs]
    125             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    124             func_args = [fn(*inputs) for fn in input_funcs]
--> 125             return return_func(*func_args)
    126 

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    543     def reshape(x, shape):
--> 544         return jnp.reshape(x, shape)
    545 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")

FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:123, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

ConcretizationTypeError                   Traceback (most recent call last)
~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    660             ):
--> 661                 outputs = [
    662                     jax_impl_jit(*[x[0] for x in thunk_inputs])

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    661                 outputs = [
--> 662                     jax_impl_jit(*[x[0] for x in thunk_inputs])
    663                     for jax_impl_jit in jax_impl_jits

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    136     try:
--> 137       return fun(*args, **kwargs)
    138     except Exception as e:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    208     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209     out = xla.xla_call(
    210         flat_fun,

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1143   def bind(self, fun, *args, **params):
-> 1144     return call_bind(self, fun, *args, **params)
   1145 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1134   with maybe_new_sublevel(top_trace):
-> 1135     outs = primitive.process(top_trace, fun, tracers, params)
   1136   return map(full_lower, apply_todos(env_trace_todo(), outs))

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1146   def process(self, trace, fun, tracers, params):
-> 1147     return trace.process_call(self, fun, tracers, params)
   1148 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    576   def process_call(self, primitive, f, tracers, params):
--> 577     return primitive.impl(f, *tracers, **params)
    578   process_map = process_call

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    530                                *unsafe_map(arg_spec, args))

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    233     else:
--> 234       ans = call(fun, *args)
    235       cache[key] = (ans, fun.stores)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    594   if config.omnistaging_enabled:
--> 595     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    596     if any(isinstance(c, core.Tracer) for c in consts):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1022     main.jaxpr_stack = ()  # type: ignore
-> 1023     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1024     del main

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1003     in_tracers = map(trace.new_arg, in_avals)
-> 1004     ans = fun.call_wrapped(*in_tracers)
   1005     out_tracers = map(trace.full_raise, ans)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     try:
--> 151       ans = self.f(*args, **dict(self.params, **kwargs))
    152     except:

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    123         def jax_func(*inputs):
--> 124             func_args = [fn(*inputs) for fn in input_funcs]
    125             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    123         def jax_func(*inputs):
--> 124             func_args = [fn(*inputs) for fn in input_funcs]
    125             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    124             func_args = [fn(*inputs) for fn in input_funcs]
--> 125             return return_func(*func_args)
    126 

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    543     def reshape(x, shape):
--> 544         return jnp.reshape(x, shape)
    545 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
    873     else:
--> 874       raise_concretization_error(val, context)
    875   else:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
    852           f"Encountered tracer value: {val}")
--> 853   raise ConcretizationTypeError(msg)
    854 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:123, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>

During handling of the above exception, another exception occurred:

ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-1-9b664886d863> in <module>
     21 
     22 with model:
---> 23     pm.sample(mode=jax_mode, chains=1, cores=1)

~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    479             # By default, try to use NUTS
    480             _log.info("Auto-assigning NUTS sampler...")
--> 481             start_, step = init_nuts(
    482                 init=init,
    483                 chains=chains,

~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
   2132         raise ValueError(f"Unknown initializer: {init}.")
   2133 
-> 2134     step = pm.NUTS(potential=potential, model=model, **kwargs)
   2135 
   2136     return start, step

~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
    166         `pm.sample` to the desired number of tuning steps.
    167         """
--> 168         super().__init__(vars, **kwargs)
    169 
    170         self.max_treedepth = max_treedepth

~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
     91         vars = inputvars(vars)
     92 
---> 93         super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
     94 
     95         self.adapt_step_size = adapt_step_size

~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
    252             func.set_extra_values(model.test_point)
    253             q = func.dict_to_array(model.test_point)
--> 254             logp, dlogp = func(q)
    255         except ValueError:
    256             if logp_dlogp_func is not None:

~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
    737             out = grad_out
    738 
--> 739         output = self._theano_function(array)
    740         if grad_out is None:
    741             return output

~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    977         try:
    978             outputs = (
--> 979                 self.fn()
    980                 if output_subset is None
    981                 else self.fn(output_subset=output_subset)

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    705                         old_s[0] = None
    706             except Exception:
--> 707                 raise_with_op(node, thunk)
    708 
    709         f = streamline_default_f

~/projects/Theano-PyMC/theano/gof/link.py in raise_with_op(node, thunk, exc_info, storage_map)
    346         # extra long error message in that case.
    347         pass
--> 348     reraise(exc_type, exc_value, exc_trace)
    349 
    350 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/six.py in reraise(tp, value, tb)
    700                 value = tp()
    701             if value.__traceback__ is not tb:
--> 702                 raise value.with_traceback(tb)
    703             raise value
    704         finally:

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    701                     thunks, order, post_thunk_old_storage
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:
    705                         old_s[0] = None

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    659                 node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
    660             ):
--> 661                 outputs = [
    662                     jax_impl_jit(*[x[0] for x in thunk_inputs])
    663                     for jax_impl_jit in jax_impl_jits

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    660             ):
    661                 outputs = [
--> 662                     jax_impl_jit(*[x[0] for x in thunk_inputs])
    663                     for jax_impl_jit in jax_impl_jits
    664                 ]

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    135   def reraise_with_filtered_traceback(*args, **kwargs):
    136     try:
--> 137       return fun(*args, **kwargs)
    138     except Exception as e:
    139       if not is_under_reraiser(e):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    207       _check_arg(arg)
    208     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209     out = xla.xla_call(
    210         flat_fun,
    211         *args_flat,

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1142 
   1143   def bind(self, fun, *args, **params):
-> 1144     return call_bind(self, fun, *args, **params)
   1145 
   1146   def process(self, trace, fun, tracers, params):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1133   tracers = map(top_trace.full_raise, args)
   1134   with maybe_new_sublevel(top_trace):
-> 1135     outs = primitive.process(top_trace, fun, tracers, params)
   1136   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1137 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1145 
   1146   def process(self, trace, fun, tracers, params):
-> 1147     return trace.process_call(self, fun, tracers, params)
   1148 
   1149   def post_process(self, trace, out_tracers, params):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    575 
    576   def process_call(self, primitive, f, tracers, params):
--> 577     return primitive.impl(f, *tracers, **params)
    578   process_map = process_call
    579 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    527 
    528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    530                                *unsafe_map(arg_spec, args))
    531   try:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    232       fun.populate_stores(stores)
    233     else:
--> 234       ans = call(fun, *args)
    235       cache[key] = (ans, fun.stores)
    236     return ans

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    593   abstract_args, arg_devices = unzip2(arg_specs)
    594   if config.omnistaging_enabled:
--> 595     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    596     if any(isinstance(c, core.Tracer) for c in consts):
    597       raise core.UnexpectedTracerError("Encountered an unexpected tracer.")

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1021     main.source_info = fun_sourceinfo(fun.f)  # type: ignore
   1022     main.jaxpr_stack = ()  # type: ignore
-> 1023     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1024     del main
   1025   return jaxpr, out_avals, consts

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1002     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1003     in_tracers = map(trace.new_arg, in_avals)
-> 1004     ans = fun.call_wrapped(*in_tracers)
   1005     out_tracers = map(trace.full_raise, ans)
   1006   jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    149 
    150     try:
--> 151       ans = self.f(*args, **dict(self.params, **kwargs))
    152     except:
    153       # Some transformations yield from inside context managers, so we have to

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    122 
    123         def jax_func(*inputs):
--> 124             func_args = [fn(*inputs) for fn in input_funcs]
    125             return return_func(*func_args)
    126 

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    122 
    123         def jax_func(*inputs):
--> 124             func_args = [fn(*inputs) for fn in input_funcs]
    125             return return_func(*func_args)
    126 

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    123         def jax_func(*inputs):
    124             func_args = [fn(*inputs) for fn in input_funcs]
--> 125             return return_func(*func_args)
    126 
    127         jax_funcs.append(update_wrapper(jax_func, return_func))

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    542 def jax_funcify_Reshape(op):
    543     def reshape(x, shape):
--> 544         return jnp.reshape(x, shape)
    545 
    546     return reshape

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1144 def reshape(a, newshape, order="C"):
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:
   1148     return _reshape(a, newshape, order=order)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1190           type(newshape[0]) is not Poly):
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 
   1194 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1166 
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":
   1170     return lax.reshape(a, computed_newshape, None)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
   1161   if newsize < 0:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
   1161   if newsize < 0:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1155   else: iterable = True
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
   1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
    872       return force(val.aval.val)
    873     else:
--> 874       raise_concretization_error(val, context)
    875   else:
    876     return force(val)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
    851          "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
    852           f"Encountered tracer value: {val}")
--> 853   raise ConcretizationTypeError(msg)
    854 
    855 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:123, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
Apply node that caused the error: IncSubtensor{InplaceInc;int64:int64:}(IncSubtensor{InplaceInc;int64:int64:}.0, Reshape{1}.0, Constant{0}, Constant{1})
Toposort index: 45
Inputs types: [TensorType(float64, vector), TensorType(float64, vector), Scalar(int64), Scalar(int64)]
Inputs shapes: [(3,)]
Inputs strides: [(8,)]
Inputs values: [array([0.69049938, 0.        , 0.        ])]
Outputs clients: [['output']]

Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/arraystep.py", line 245, in __init__
    func = model.logp_dlogp_function(
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1008, in logp_dlogp_function
    return ValueGradFunction(costs, grad_vars, extra_vars, **kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 690, in __init__
    grad = tt.grad(self._cost_joined, self._vars_joined)
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 649, in grad
    rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1469, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1469, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1456, in access_grad_cache
    grad_dict[var] = reduce(lambda x, y: x + y, terms)
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1456, in <lambda>
    grad_dict[var] = reduce(lambda x, y: x + y, terms)
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/arraystep.py", line 245, in __init__
    func = model.logp_dlogp_function(
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1008, in logp_dlogp_function
    return ValueGradFunction(costs, grad_vars, extra_vars, **kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 690, in __init__
    grad = tt.grad(self._cost_joined, self._vars_joined)
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 649, in grad
    rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1469, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1469, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1421, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/twiecki/projects/Theano-PyMC/theano/gradient.py", line 1231, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.

@brandonwillard
Member

brandonwillard commented Sep 26, 2020

Looks like the same error.

Here's the exact code I'm using on e4043ce0b and its output:

import pymc3 as pm
import theano
import numpy as np
import theano.sandbox.jax

theano.compile.mode.predefined_linkers["jax"] = theano.sandbox.jax.JaxLinker()
jax_mode = theano.compile.Mode(linker="jax")

x = np.linspace(0, 1, 10)
y = x * 4. + 1.4 + np.random.randn(10)

with pm.Model() as model:
    beta = pm.Normal("beta", 0., 5., shape=2)
    sigma = pm.HalfNormal("sigma", 2.5)
    obs = pm.Normal("obs", beta[0] + beta[1] * x, sigma, observed=y)
    pm.sample(mode=jax_mode, chains=1, cores=1)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
/home/bwillard/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Sequential sampling (1 chains in 1 job)
INFO:pymc3:Sequential sampling (1 chains in 1 job)
NUTS: [sigma, beta]
INFO:pymc3:NUTS: [sigma, beta]
Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 22 seconds.
INFO:pymc3:Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 22 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
Only one chain was sampled, this makes it impossible to run some convergence checks
INFO:pymc3:Only one chain was sampled, this makes it impossible to run some convergence checks
Out[2]: <MultiTrace: 1 chains, 1000 iterations, 3 variables>

Perhaps it's a difference in the jax and jaxlib versions; this is what I have:

jax                           0.1.75
jaxlib                        0.1.52

@twiecki
Contributor Author

twiecki commented Sep 26, 2020

jax 0.2.0
jaxlib 0.1.55

@brandonwillard
Member

Ah, I'll try those versions.

Otherwise, I just pushed an update that puts the JAX Linker in the predefined_linkers. With that, the following works:

import multiprocessing as mp

import theano

import numpy as np

import pymc3 as pm

ctx = mp.get_context('spawn')

jax_mode = theano.compile.Mode(linker="jax")

x = np.linspace(0, 1, 10)
y = x * 4. + 1.4 + np.random.randn(10)

with pm.Model() as model:
    beta = pm.Normal("beta", 0., 5., shape=2)
    sigma = pm.HalfNormal("sigma", 2.5)
    obs = pm.Normal("obs", beta[0] + beta[1] * x, sigma, observed=y)

with model:
    pm.sample(mode=jax_mode, mp_ctx=ctx)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, beta]
INFO:pymc3:NUTS: [sigma, beta]
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 67 seconds.
INFO:pymc3:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 67 seconds.
There were 16 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 16 divergences after tuning. Increase `target_accept` or reparameterize.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 25% for some parameters.
INFO:pymc3:The number of effective samples is smaller than 25% for some parameters.
Out[3]: <MultiTrace: 4 chains, 1000 iterations, 3 variables>

@twiecki
Contributor Author

twiecki commented Sep 26, 2020

That's awesome!

@brandonwillard
Member

brandonwillard commented Sep 26, 2020

OK, I got the same error as you using jax 0.2.0. It looks like the change that introduced the error might've been in this part of this commit. First, let me see if jax.numpy.reshape was previously being given the same inputs.

@brandonwillard
Member

For anyone who's interested and more familiar with JAX, here's an MWE of the problem under jax 0.2.0 and jaxlib 0.1.55:

import numpy as np

import jax
import jax.numpy as jnp


x = np.zeros((2 * 3))
z = (2, 3)

expected_res = np.reshape(x, np.array(z, dtype=np.int))


def b(z):
    return jnp.array(z, dtype=np.int)

def a(x, z):
    return jnp.reshape(x, b(z))


jax_res = jax.jit(a)(x, z)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function a at <ipython-input-26-182028b110f7>:10, this value became a tracer due to JAX operations on these lines:

  operation d:int64[1] = broadcast_in_dim[ broadcast_dimensions=(  )
                               shape=(1,) ] b:int64[]
    from line <ipython-input-26-182028b110f7>:8 (b)

  operation e:int64[1] = broadcast_in_dim[ broadcast_dimensions=(  )
                               shape=(1,) ] c:int64[]
    from line <ipython-input-26-182028b110f7>:8 (b)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>

@twiecki
Contributor Author

twiecki commented Sep 26, 2020 via email

@junpenglao
Contributor

If I understand correctly, JAX does not support jitting reshape when the shape is dynamic.

@junpenglao
Contributor

junpenglao commented Sep 26, 2020

Oh, so they used to support it. Hmm, maybe marking the shape as static will help?

i.e., this works:

jax_res = jax.jit(a, static_argnums=(1,))(x, z)
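(If I understand static_argnums correctly, jit then treats z as a compile-time constant and recompiles for each distinct shape value, so no tracer ever reaches jnp.reshape.)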

@brandonwillard
Member

maybe marking the shape as static will help?

The actual graph is likely much more complex, and I don't know if there's an argument in the JITed function that's actually static and maps directly to the shape like that example does.

More importantly, why isn't it supported anymore?

@junpenglao
Contributor

I think it is related to a change called omnistaging: google/jax#3370

@junpenglao
Contributor

TFP also had to make quite a few changes due to omnistaging in JAX: https://github.com/tensorflow/probability/search?q=omnistaging&type=commits
I think the general strategy is to push static computation to NumPy as much as possible.
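To make that strategy concrete, here is a minimal sketch (hypothetical names, not taken from TFP): shape arithmetic done eagerly with plain NumPy stays a compile-time constant, whereas the same arithmetic written with jax.numpy would be staged into tracers under omnistaging.

import numpy as np

import jax
import jax.numpy as jnp

x = np.zeros(6)

# Computed eagerly in NumPy/Python: `new_shape` is a tuple of plain ints,
# so it is a static constant as far as `jit` is concerned.
new_shape = (np.shape(x)[0] // 2, 2)

@jax.jit
def f(y):
    # jnp.reshape only ever sees concrete integers here.
    return jnp.reshape(y, new_shape)

f(x)  # works; writing the same shape arithmetic with jnp inside f would not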

@twiecki
Contributor Author

twiecki commented Sep 26, 2020 via email

@junpenglao
Contributor

For reshape specifically: tensorflow/probability@782d0c6

@twiecki twiecki changed the title JAX: pm.sample() throws error for simple model Jax 0.2 does not support jitting reshape Sep 26, 2020
@brandonwillard
Member

Here's the sub-graph causing the problem:

from theano.printing import debugprint as tt_dprint


dlogp_fn = model.logp_dlogp_function(mode=jax_mode)
dlogp_fgraph = dlogp_fn._theano_function.maker.fgraph

tt_dprint(dlogp_fgraph.outputs[1].owner.inputs[1])
Reshape{1} [id A] ''   
 |Elemwise{Composite{(Switch(i0, (i1 * i2 * i2), i3) + i4 + (i5 * (((i6 * i7 * Composite{inv(Composite{(sqr(i0) * i0)}(i0))}(i2)) / i8) - (i9 * Composite{inv(Composite{(sqr(i0) * i0)}(i0))}(i2))) * i2))}}[(0, 7)] [id B] '(d__logp/dsigma_log__)'   
 | |Elemwise{Composite{Cast{int8}(GE(i0, i1))}} [id C] ''   
 | | |Elemwise{exp,no_inplace} [id D] 'sigma'   
 | | | ...
 |MakeVector{dtype='int64'} [id BL] ''   
   |Elemwise{Composite{(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i0, i1), i2) - Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i3, i1), i2), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i0, i1), i2)), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i3, i1), i2), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i0, i1), i2)))}}[(0, 1)] [id BM] ''   
     |TensorConstant{3} [id BN]
     |Shape_i{0} [id BO] ''   
     | |__args_joined [id G]
     |TensorConstant{0} [id K]
     |TensorConstant{2} [id BP]

The graph tells us that the shape parameter to jax.numpy.reshape comes from a MakeVector Op, which is itself the result of a series of conditional statements over some constants and a Shape_i Op. The Shape_i simply gets the 0th shape value of __args_joined, the only input variable of the FunctionGraph we extracted, dlogp_fgraph.

This means that a more representative MWE would probably be as follows:

import numpy as np

import jax
import jax.numpy as jnp

def d(y, z):
    return jnp.shape(y)[z]

def c(y, z):
    return jnp.where(d(y, 0) / z > 0, d(y, 0) / z, 0)

def b(y):
    return jnp.array([c(y, 2), c(y, 3)], dtype=np.int)

def a(y):
    return jnp.reshape(y, b(y))

jax_res = jax.jit(a)(x)  # x as in the MWE above

Unfortunately, I don't think we can use static_argnums in this case, since the shape is computed from the traced input y itself; there's no separate argument we could mark as static.

@brandonwillard
Member

brandonwillard commented Sep 26, 2020

I just pushed a commit that disables omnistaging by default. That should allow jax >= 0.2.0.

I also found this example, which is essentially our problem, and it says that the solution is to use NumPy to compute the shape. @junpenglao, is there a straightforward way to force a NumPy computation of those DynamicJaxprTracers, or is this solution not relevant to our situation?

@twiecki
Contributor Author

twiecki commented Sep 27, 2020

It works for me with that workaround now 👍.

@junpenglao
Contributor

Let's keep this open, as we should aim to fix the omnistaging issue.

@junpenglao junpenglao reopened this Sep 28, 2020
@twiecki twiecki changed the title Jax 0.2 does not support jitting reshape Jax 0.2 does not support jitting reshape in omnistaging mode Sep 28, 2020
@brandonwillard brandonwillard changed the title Jax 0.2 does not support jitting reshape in omnistaging mode Jax 0.2 does not support jax.numpy.reshape with non-constant values in omnistaging mode Sep 28, 2020
@brandonwillard
Member

@junpenglao, does the solution to this particular issue involve jax.lax.reshape? Is that the level at which symbolic inputs are viable?

@junpenglao
Contributor

I think there are a few levels of fix we could think about:

  • theano, where we enforce static shapes and rewrite the shape computations (@brandonwillard, you mentioned to me the idea of always carrying static shapes around)
  • pymc3, where we get rid of the reshape in model.logp_dlogp_function

In the short term, I need to understand better how the reshape is done in theano. IIUC, when jax.jit is applied, jnp.reshape is lowered to jax.lax.reshape.
For the example above, we need to make sure all shape-related computations are done in NumPy:

import numpy as np

import jax
import jax.numpy as jnp

def d(y, z):
    return np.shape(y)[z]

def c(y, z):
    return np.where(d(y, 0) / z > 0, d(y, 0) / z, 0)

def b(y):
    return [c(y, 2), c(y, 3)]

def a(y):
    return jnp.reshape(y, b(y))

jax_res = jax.jit(a)(x)  # x as in the MWE above
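(Note that this relies on shapes being static even under tracing: np.shape(y) on a traced y returns a tuple of plain Python ints, so everything b computes stays concrete, and only the reshape itself is staged. And if I read the lax API right, jax.lax.reshape likewise takes its new_sizes as static integers, which is why all of the shape arithmetic has to be concrete by the time it is reached.)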

@brandonwillard
Member

brandonwillard commented Sep 29, 2020

  • theano, where we enforce static shapes

We can't do this for all Theano graphs; that would remove a wide array of valuable Theano capabilities!

@junpenglao
Contributor

I see. Could we add a static-reshape mode for pymc3 instead? JAX does not really support dynamic shapes anyway, so making the JAX backend do dynamic-graph things would be pretty difficult.

@brandonwillard brandonwillard added the question Further information is requested label Oct 13, 2020
@brandonwillard brandonwillard removed their assignment Oct 13, 2020
@brandonwillard brandonwillard added enhancement New feature or request and removed bug Something isn't working labels Oct 13, 2020
@brandonwillard
Member

It seems like we should dig into the lower-level aspects of jax and see if we can take a more direct approach from our end. One that doesn't go through these omnistaging changes, for instance.

This seems like the correct approach if only because it might also provide solutions to similar symbolic limitations we've encountered (e.g. #68).

@twiecki
Contributor Author

twiecki commented Oct 14, 2020 via email

@brandonwillard
Member

Yeah, that's what I was thinking.

@lucianopaz
Contributor

From a JAX post explaining omnistaging, it looks like using numpy.reshape instead of jax.numpy.reshape is the correct way to solve this reshape problem.

@twiecki
Contributor Author

twiecki commented Apr 3, 2021

but that would incur Python call overhead, no? I think that's just a poor workaround.

@brandonwillard
Member

brandonwillard commented Apr 3, 2021

From a jax post explaining omnistaging it looks like using numpy.reshape instead of jax.numpy.reshape is the correct way to solve this reshape problem.

but that would incur Python call overhead, no? I think that's just a poor workaround.

The last I recall from reading and experimenting with all that, they're essentially saying such graphs are no longer possible, so everything involving np.reshape has to happen before JAX gets involved.

We can handle that by (re)allowing a mix of Python and JAX thunks, of course, but that's much less ideal than a single JAX compiled/JITed function, especially when it comes to interactions with other JAX code (e.g. JAX-based sampler functions).
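A minimal sketch of that mixed evaluation (hypothetical part1/part2 pieces, just to illustrate the overhead concern): the reshape runs as a plain Python/NumPy thunk between two jitted pieces, so every call crosses the Python boundary.

import numpy as np

import jax
import jax.numpy as jnp

@jax.jit
def part1(x):
    # first JAX-compiled thunk
    return x * 2.0

@jax.jit
def part2(x):
    # second JAX-compiled thunk
    return x.sum(axis=1)

def mixed(x, shape):
    h = part1(x)
    # plain NumPy thunk: pulls the value back to the host to reshape it
    h = np.reshape(np.asarray(h), shape)
    return part2(h)

mixed(np.ones(6), (3, 2))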

@twiecki
Contributor Author

twiecki commented Apr 3, 2021 via email

@brandonwillard
Member

brandonwillard commented Apr 3, 2021

Can numba do that?

Can it compile a function that uses np.reshape? Yes, and so can Cython.

Here's an example:

import numpy as np
import numba


@numba.njit
def testfn(x, shape):
    return np.reshape(x, shape)
>>> testfn(np.ones(10), (5, 2))
array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.]])
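The more fundamental difference is that Numba specializes on the concrete values it is called with, so the shape tuple arrives as ordinary runtime data, while JAX traces functions with abstract shaped values in order to emit XLA, and XLA requires every output shape to be fixed at compile time.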

As I understand it, JAX isn't meant to be an all-purpose JITer; it's constrained by its connections to XLA, a specific domain of work/relevance, etc.
