Issue with JAX Grad of Array #1117

@Jammy2211

I have updated the concr_cosmology repo to test JAX graphical models.

However, I first ran a quick non-graphical test of JAX, applying grad to the likelihood (LH) function:

https://github.com/Jammy2211/concr_cosmology/blob/main/start_here.py

analysis = Analysis(dataset=dataset_list[0])

search = af.DynestyStatic(
    name="cancer_example_0",
    nlive=100,
#    sample="rwalk",
    sample="hslice",
    use_gradient=True,
    iterations_per_update=10000,
    number_of_cores=1,
)

result = search.fit(model=model, analysis=analysis)

Setting use_gradient=True together with sample="hslice" means that jax.grad is invoked through the following code:

        if self.use_gradient:
            gradient = GradWrapper(fitness)
        else:
            gradient = None

dynesty then uses this gradient for sampling.
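
For context, a wrapper of this kind could look roughly like the sketch below. This is not the actual GradWrapper from PyAutoFit: the class name LazyGradWrapper, the toy likelihood, and the lazy jax.jit(jax.grad(...)) construction are all assumptions made for illustration, loosely based on the "Compiling gradient" message that appears in the log.

```python
import jax
import jax.numpy as jnp
import numpy as np


class LazyGradWrapper:
    """Hypothetical stand-in for GradWrapper: builds the gradient on first call."""

    def __init__(self, function):
        self.function = function
        self._grad = None

    def __call__(self, parameters):
        if self._grad is None:
            # Defer tracing/compilation until a gradient is first requested.
            self._grad = jax.jit(jax.grad(self.function))
        # The sampler passes NumPy arrays, so convert in and back out.
        return np.asarray(self._grad(jnp.asarray(parameters)))


def toy_log_likelihood(parameters):
    # Toy Gaussian log likelihood centred on zero, written with jnp only.
    return -0.5 * jnp.sum(parameters ** 2)


gradient = LazyGradWrapper(toy_log_likelihood)
grad_value = gradient(np.array([1.0, -2.0, 3.0]))
print(grad_value)
```

With a likelihood written purely in jax.numpy, as in this toy case, the gradient is simply minus the parameter vector and no tracer ever reaches NumPy code.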

However, after ~3000 samples, the following exception is raised:

2025-03-14 10:57:10,294 - autofit.non_linear.initializer - INFO - Generating initial samples of model using 1 cores
2025-03-14 10:57:10,936 - autofit.non_linear.initializer - INFO - Initial samples generated, starting non-linear search
/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/dynesty.py:191: UserWarning: Specifying walks option while using slice sampler does not make sense
  warnings.warn('Specifying walks option while using slice sampler'
343it [00:05, 20.82it/s, bound: 0 | nc: 60 | ncall: 3408 | eff(%): 10.065 | loglstar:   -inf < -501.684 <    inf | logz: -509.835 +/-  0.269 | dlogz: 466.491 >  0.109]Compiling gradient
Exception while calling gradient function:
  params: [ 1.63037407  3.17313047 -1.08876391 -1.5558232   3.98010488 -4.19731134
  1.03857219 -1.27849697 -0.9893693 ]
  args: []
  kwargs: {}
  exception:
Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 153, in _fit
    raise RuntimeError
RuntimeError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 720, in __float__
    return self.aval._float(self)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 1481, in error
    raise ConcretizationTypeError(arg, fname_context)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/dynesty.py", line 913, in __call__
    return self.func(np.asarray(x).copy(), *self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/static.py", line 34, in __call__
    return self.grad(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/timeout_decorator/timeout_decorator.py", line 79, in new_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/fitness.py", line 161, in __call__
    except exc.FitException:
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 786, in instance_from_vector
    return self.instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/collection.py", line 231, in _instance_for_arguments
    value = value.instance_for_arguments(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/array.py", line 90, in _instance_for_arguments
    array[index] = value
    ~~~~~^^^^^^^
ValueError: setting an array element with a sequence.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
345it [00:05, 58.27it/s, bound: 0 | nc: 35 | ncall: 3459 | eff(%):  9.974 | loglstar:   -inf < -494.610 <    inf | logz: -503.187 +/-  0.286 | dlogz: 461.312 >  0.109]
Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 153, in _fit
    raise RuntimeError
RuntimeError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 720, in __float__
    return self.aval._float(self)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/core.py", line 1481, in error
    raise ConcretizationTypeError(arg, fname_context)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/c/Users/Jammy/Code/PyAuto/concr_cosmology/start_here.py", line 451, in <module>
    result = search.fit(model=model, analysis=analysis)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 599, in fit
    result = self.start_resume_fit(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 120, in decorated
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/abstract_search.py", line 750, in start_resume_fit
    search_internal = self._fit(
                      ^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 192, in _fit
    finished = self.run_search_internal(search_internal=search_internal)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/abstract.py", line 345, in run_search_internal
    search_internal.run_nested(
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 1025, in run_nested
    for it, results in enumerate(
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 848, in sample
    u, v, logl, nc = self._new_point(loglstar_new)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 421, in _new_point
    u, v, logl, nc, blob = self._get_point_value(loglstar)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 404, in _get_point_value
    self._fill_queue(loglstar)
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampler.py", line 397, in _fill_queue
    self.queue = list(mapper(evolve_point, args))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/sampling.py", line 1028, in sample_hslice
    h = grad(v_l)
        ^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/dynesty/dynesty.py", line 913, in __call__
    return self.func(np.asarray(x).copy(), *self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/search/nest/dynesty/search/static.py", line 34, in __call__
    return self.grad(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/timeout_decorator/timeout_decorator.py", line 79, in new_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/non_linear/fitness.py", line 155, in __call__
    instance = self.model.instance_from_vector(vector=parameters)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 786, in instance_from_vector
    return self.instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/collection.py", line 231, in _instance_for_arguments
    value = value.instance_for_arguments(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/abstract.py", line 1309, in instance_for_arguments
    return self._instance_for_arguments(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoFit/autofit/mapper/prior_model/array.py", line 90, in _instance_for_arguments
    array[index] = value
    ~~~~~^^^^^^^
ValueError: setting an array element with a sequence.
--------------------
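
The last frame (array[index] = value in array.py) looks like an in-place NumPy assignment receiving a JAX tracer. Below is a minimal sketch of that failure mode, assuming the array being filled is a plain NumPy array. That is an assumption about af.Array, not something confirmed from its source, and the function name is hypothetical. The exact exception type may differ between JAX and NumPy versions, so the sketch catches both ValueError and TypeError.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_likelihood_numpy_buffer(parameters):
    # Mimics `array[index] = value`: a NumPy buffer filled element by element.
    array = np.zeros(3)
    for index in range(3):
        # Under jax.grad, parameters[index] is an abstract tracer, not a float.
        array[index] = parameters[index]
    return -0.5 * jnp.sum(jnp.asarray(array) ** 2)


parameters = jnp.array([1.0, 2.0, 3.0])

# A plain call works because every value is concrete.
plain_value = log_likelihood_numpy_buffer(parameters)

# Differentiating fails: NumPy cannot turn a tracer into a concrete float.
try:
    jax.grad(log_likelihood_numpy_buffer)(parameters)
    grad_failed = False
except (ValueError, TypeError) as error:
    grad_failed = True
    print(type(error).__name__, error)
```

This would also be consistent with the likelihood evaluating without problems during ordinary sampling and only failing once the gradient is requested.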

I am working on the feature/graphical_pytrees branch, but I don't think this is related to graphical models. It looks like an issue with using af.Array inside a likelihood function that is differentiated with jax.grad.
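
If the in-place assignment in the last traceback frame is the cause, one possible direction is to build the array with JAX's functional updates (array.at[index].set(value)) rather than writing into a NumPy buffer. The sketch below is only an illustration under that assumption: build_array and the index layout are hypothetical and do not reflect the real af.Array implementation.

```python
import jax
import jax.numpy as jnp


def build_array(shape, values_by_index):
    # Each .at[...].set(...) returns a new array, so the update stays traceable.
    array = jnp.zeros(shape)
    for index, value in values_by_index.items():
        array = array.at[index].set(value)
    return array


def log_likelihood(parameters):
    # Hypothetical layout: three free parameters placed into a 2x2 array.
    values_by_index = {
        (0, 0): parameters[0],
        (0, 1): parameters[1],
        (1, 1): parameters[2],
    }
    array = build_array((2, 2), values_by_index)
    return -0.5 * jnp.sum(array ** 2)


parameters = jnp.array([1.0, 2.0, 3.0])
gradient = jax.grad(log_likelihood)(parameters)
print(gradient)
```

Here jax.grad runs through the array construction without error, and the gradient equals minus the parameter vector, as expected for this toy likelihood.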
