
Error about NumPyro

ykawashima edited this page Nov 5, 2021 · 2 revisions

Cannot perform an MCMC run with NumPyro

The following error occurs when you try to perform an MCMC run with NumPyro. It has been observed at least with the combination of JAX 0.2.16, jaxlib 0.1.68+cuda110, and NumPyro 0.6.0.
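To check whether your environment matches that combination, a small sketch like the following prints the installed versions of jax, jaxlib and numpyro. It uses only the standard library's importlib.metadata (available from Python 3.8, which matches the python3.8 paths in the traceback below). The function name installed_versions is hypothetical, not part of any library.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "numpyro")):
    """Return {package: installed version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```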

Traceback (most recent call last):
  File "mcmc.py", line 162, in <module>
    mcmc.run(rng_key_, y1=nflux)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 498, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 333, in _single_chain_mcmc
    init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 505, in init
    init_state = hmc_init_fn(init_params, rng_key)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 488, in <lambda>
    hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 211, in init_kernel
    trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float))
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 425, in convert_element_type
    return _convert_element_type(operand, new_dtype, weak_type=False)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 454, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/core.py", line 603, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 248, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 240, in arg_spec
    aval = abstractify(x)
  File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 186, in abstractify
    raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
TypeError: Argument 'None' of type '<class 'NoneType'>' is not a valid JAX type

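Cause: The last frames show that NumPyro's HMC initialization (init_kernel in hmc.py) passes trajectory_length to lax.convert_element_type while its value is None, and JAX raises a TypeError because None is not a valid JAX type.

Solution: Update NumPyro to 0.7.0 (see also this website).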
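If you would rather have a script stop early with a clear message than hit the TypeError above, a guard along these lines could be run before mcmc.run(...). It is a sketch under assumptions: only NumPyro 0.6.0 is confirmed to fail here, so treating every release below 0.7.0 as affected is a conservative guess, and the names FIXED_NUMPYRO, release_tuple, needs_upgrade and check_numpyro are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption: NumPyro 0.6.0 is the only version confirmed to fail with this
# TypeError; treating every release below 0.7.0 as affected is a guess.
FIXED_NUMPYRO = (0, 7, 0)


def release_tuple(ver):
    """Return the leading (major, minor, patch) integers of a version string.

    Local tags such as '+cuda110' are dropped and pre-release suffixes such as
    'rc1' are ignored, so '0.7.0rc1' compares equal to '0.7.0'.
    """
    parts = []
    for piece in ver.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad so that '0.7' compares like '0.7.0'.
    return tuple((parts + [0, 0, 0])[:3])


def needs_upgrade(ver):
    """True if this NumPyro version is older than the fixed release."""
    return release_tuple(ver) < FIXED_NUMPYRO


def check_numpyro():
    """Raise early if the installed NumPyro is older than 0.7.0."""
    try:
        ver = version("numpyro")
    except PackageNotFoundError:
        return  # NumPyro is not installed; nothing to check.
    if needs_upgrade(ver):
        raise RuntimeError(
            f"NumPyro {ver} is installed; upgrade with "
            "pip install --upgrade 'numpyro>=0.7.0'"
        )
```

Calling check_numpyro() near the top of a script such as mcmc.py, before mcmc.run(...), replaces the long traceback with a one-line message naming the installed version.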