I first did a quick non-graphical model test of JAX with the likelihood function using `grad`:

```python
import autofit as af

# `Analysis`, `dataset_list` and `model` are defined earlier in start_here.py.
analysis = Analysis(dataset=dataset_list[0])

search = af.DynestyStatic(
    name="cancer_example_0",
    nlive=100,
    # sample="rwalk",
    sample="hslice",  # Hamiltonian slice sampling, which uses the gradient
    use_gradient=True,
    iterations_per_update=10000,
    number_of_cores=1,
)

result = search.fit(model=model, analysis=analysis)
```
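Internally, the dynesty search sets up the gradient as follows:

```python
if self.use_gradient:
    gradient = GradWrapper(fitness)  # wraps the fitness function with jax.grad
else:
    gradient = None
```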
And dynesty uses the gradient for sampling.
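The following sketch shows what a wrapper like this does. It is a hypothetical stand-in (`LazyGradient`), not PyAutoFit's actual `GradWrapper`, and it rests on two assumptions:

- The wrapper builds `jax.jit(jax.grad(...))` once, on first use. The "Compiling gradient" message in the log below suggests a one-off step like this.
- It hands NumPy arrays back to dynesty, which passes NumPy arrays in (see `np.asarray(x).copy()` in the traceback).

The `log_likelihood` here is also hypothetical.

```python
from functools import cached_property

import jax
import jax.numpy as jnp
import numpy as np


class LazyGradient:
    """Hypothetical stand-in for a gradient wrapper like GradWrapper.

    Builds jax.jit(jax.grad(function)) the first time it is needed and
    converts the result to a NumPy array for the sampler.
    """

    def __init__(self, function):
        self.function = function
        self.compile_count = 0

    @cached_property
    def grad(self):
        # cached_property stores the result on the instance, so this body
        # runs only on first access.
        self.compile_count += 1
        print("Compiling gradient")
        return jax.jit(jax.grad(self.function))

    def __call__(self, parameters):
        return np.asarray(self.grad(jnp.asarray(parameters)))


def log_likelihood(parameters):
    # Hypothetical log likelihood: an isotropic unit Gaussian centred on zero,
    # whose gradient is simply -parameters.
    return -0.5 * jnp.sum(parameters ** 2)


gradient = LazyGradient(log_likelihood)
print(gradient(np.array([1.0, -2.0, 0.5])))
```

Deferring the `jax.grad`/`jax.jit` setup to the first call means that searches run with `use_gradient=False` never pay for it.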
Running the fit fails with the following output:

```
2025-03-14 10:57:10,294 - autofit.non_linear.initializer - INFO - Generating initial samples of model using 1 cores
2025-03-14 10:57:10,936 - autofit.non_linear.initializer - INFO - Initial samples generated, starting non-linear search
/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/dynesty.py:191: UserWarning: Specifying walks option while using slice sampler does not make sense
  warnings.warn('Specifying walks option while using slice sampler'
343it [00:05, 20.82it/s, bound: 0 | nc: 60 | ncall: 3408 | eff(%): 10.065 | loglstar: -inf < -501.684 < inf | logz: -509.835 +/- 0.269 | dlogz: 466.491 > 0.109]Compiling gradient
Exception while calling gradient function:
params: [ 1.63037407 3.17313047 -1.08876391 -1.5558232 3.98010488 -4.19731134
1.03857219 -1.27849697 -0.9893693 ]
args: []
kwargs: {}
exception:
Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 153, in _fit
    raise RuntimeError
RuntimeError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 720, in __float__
    return self.aval._float(self)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 1481, in error
    raise ConcretizationTypeError(arg, fname_context)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/dynesty.py", line 913, in __call__
    return self.func(np.asarray(x).copy(), *self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/static.py", line 34, in __call__
    return self.grad(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/timeout_decorator/timeout_decorator.py", line 79, in new_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/fitness.py", line 161, in __call__
    except exc.FitException:
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 786, in instance_from_vector
    return self.instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/collection.py", line 231, in _instance_for_arguments
    value = value.instance_for_arguments(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/array.py", line 90, in _instance_for_arguments
    array[index] = value
    ~~~~~^^^^^^^
ValueError: setting an array element with a sequence.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
345it [00:05, 58.27it/s, bound: 0 | nc: 35 | ncall: 3459 | eff(%): 9.974 | loglstar: -inf < -494.610 < inf | logz: -503.187 +/- 0.286 | dlogz: 461.312 > 0.109]
Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 153, in _fit
    raise RuntimeError
RuntimeError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 720, in __float__
    return self.aval._float(self)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 1481, in error
    raise ConcretizationTypeError(arg, fname_context)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAuto/concr_cosmology/start_here.py", line 451, in <module>
    result = search.fit(model=model, analysis=analysis)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 599, in fit
    result = self.start_resume_fit(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 120, in decorated
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 750, in start_resume_fit
    search_internal = self._fit(
                      ^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 192, in _fit
    finished = self.run_search_internal(search_internal=search_internal)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 345, in run_search_internal
    search_internal.run_nested(
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 1025, in run_nested
    for it, results in enumerate(
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 848, in sample
    u, v, logl, nc = self._new_point(loglstar_new)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 421, in _new_point
    u, v, logl, nc, blob = self._get_point_value(loglstar)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 404, in _get_point_value
    self._fill_queue(loglstar)
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 397, in _fill_queue
    self.queue = list(mapper(evolve_point, args))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampling.py", line 1028, in sample_hslice
    h = grad(v_l)
        ^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/dynesty.py", line 913, in __call__
    return self.func(np.asarray(x).copy(), *self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/static.py", line 34, in __call__
    return self.grad(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/timeout_decorator/timeout_decorator.py", line 79, in new_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/fitness.py", line 155, in __call__
    instance = self.model.instance_from_vector(vector=parameters)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 786, in instance_from_vector
    return self.instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/collection.py", line 231, in _instance_for_arguments
    value = value.instance_for_arguments(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/array.py", line 90, in _instance_for_arguments
    array[index] = value
    ~~~~~^^^^^^^
ValueError: setting an array element with a sequence.
--------------------
```
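The innermost frame is `array[index] = value` in `autofit/mapper/prior_model/array.py`. The sketch below reproduces that failure mode outside PyAutoFit. It assumes that `array` there is a NumPy array and that `value` is a JAX tracer while the likelihood is being differentiated. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_likelihood_numpy_fill(parameters):
    # Hypothetical model step: copy each parameter into a NumPy array, like
    # `array[index] = value` in the traceback (an assumption about af.Array).
    array = np.zeros(2)
    for index in range(2):
        array[index] = parameters[index]  # needs a concrete float
    return -0.5 * jnp.sum(jnp.asarray(array) ** 2)


# Without tracing, the parameters are concrete, so the NumPy write works.
print(log_likelihood_numpy_fill(np.array([1.0, 2.0])))  # -2.5

# Under jax.jit(jax.grad(...)) the parameters are abstract tracers, so NumPy
# cannot convert them to floats and an exception is raised instead.
try:
    jax.jit(jax.grad(log_likelihood_numpy_fill))(jnp.array([1.0, 2.0]))
except (TypeError, ValueError) as error:  # ConcretizationTypeError is a TypeError
    print(type(error).__name__)
```

The exact exception class depends on the JAX and NumPy versions, which is why both `TypeError` and `ValueError` are caught. The traceback above ends in the `ValueError`, chained from the `ConcretizationTypeError`.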
I have updated the `concr_cosmology` repo to test JAX graphical models. The non-graphical test above is in:

https://github.com/Jammy2211/concr_cosmology/blob/main/start_here.py

The `use_gradient=True` and `sample="hslice"` settings mean that the `GradWrapper` shown above is used, so `jax.grad` is called on the fitness function. However, after ~3000 samples, the exception above is raised.

I am working on `feature/graphical_pytrees`, but I don't think this is related to graphical models. Rather, it looks like an issue with using `af.Array` together with a gradded likelihood function.
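One possible direction is sketched below, under the assumption that the array can be built with `jax.numpy` instead of NumPy. JAX arrays are immutable, so element writes use the functional `.at[index].set(value)` update. This returns a new array and keeps the values traceable under `jax.grad`. The sketch does not reflect how `af.Array` is currently implemented, and the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def log_likelihood_jax_fill(parameters):
    # Hypothetical JAX-compatible variant: each element is written with the
    # functional .at[index].set(value) update, which returns a new array
    # rather than converting the (possibly traced) value to a Python float.
    array = jnp.zeros(2)
    for index in range(2):
        array = array.at[index].set(parameters[index])
    return -0.5 * jnp.sum(array ** 2)


grad_log_likelihood = jax.jit(jax.grad(log_likelihood_jax_fill))
print(grad_log_likelihood(jnp.array([1.0, 2.0])))  # gradient is -parameters
```

Building the whole array in one step, for example with `jnp.stack` over the element values, would avoid the per-element copies in the loop.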