Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,23 @@ class Analysis(ABC):
def __init__(
self,
use_jax: bool = False,
use_jax_for_visualization: bool = False,
use_jax_for_visualization: Optional[bool] = None,
**kwargs,
):
"""
Parameters
----------
use_jax
Run the likelihood through ``jax.jit`` for the fast path. When JAX
is unavailable this silently falls back to numpy with a warning.
use_jax_for_visualization
Whether ``fit_for_visualization`` should dispatch through the
``jax.jit``-cached path. ``None`` (default) follows ``use_jax`` —
users who set ``use_jax=True`` automatically get JIT visualization.
Pass ``False`` to force the eager NumPy plotter even when
``use_jax=True``; pass ``True`` to opt in explicitly. Passing
``True`` while ``use_jax=False`` logs a warning and disables it.
"""
import os
if os.environ.get("PYAUTO_DISABLE_JAX") == "1":
use_jax = False
Expand Down Expand Up @@ -68,6 +82,9 @@ def __init__(
use_jax = False
use_jax_for_visualization = False

if use_jax_for_visualization is None:
use_jax_for_visualization = use_jax

if use_jax_for_visualization and not use_jax:
logger.warning(
"use_jax_for_visualization=True requires use_jax=True; "
Expand All @@ -83,15 +100,21 @@ def fit_for_visualization(self, instance):
"""
Build the fit used by the visualizer.

Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path:
Dispatch over ``self.fit_from`` with a ``jax.jit`` fast path that
follows ``use_jax`` by default:

* ``use_jax_for_visualization=False`` (default) — plain
``self.fit_from(instance)``. Untouched by JAX.
* ``use_jax_for_visualization=True`` — lazily construct
* ``self._use_jax_for_visualization`` is ``False`` — plain
``self.fit_from(instance)``. Untouched by JAX. This is the
resolved state when ``use_jax=False`` (the parameter default),
or when the user explicitly passed
``use_jax_for_visualization=False`` to opt out.
* ``self._use_jax_for_visualization`` is ``True`` — lazily construct
``jax.jit(self.fit_from)`` on the first call and cache it on the
instance as ``_jitted_fit_from``, then call that for every
subsequent visualization. The first call pays the compile cost;
subsequent calls reuse the cached compiled function.
subsequent calls reuse the cached compiled function. This is the
resolved state when ``use_jax=True`` (the sentinel default
``use_jax_for_visualization=None`` follows ``use_jax``).

Caching is per-``Analysis`` instance so each analysis gets its own
compiled function keyed off that instance's closed-over state
Expand All @@ -108,8 +131,9 @@ def fit_for_visualization(self, instance):
nested autoarray / galaxy / lens type it carries) must be pytree-
registered. That wiring lives in each analysis subclass (see
``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens).
Variants that have not yet been pytree-audited must leave
``use_jax_for_visualization`` at its default of ``False``.
Variants that have not yet been pytree-audited must pass
``use_jax_for_visualization=False`` explicitly when constructing
the analysis (or simply leave ``use_jax=False``).
"""
if not self._use_jax_for_visualization:
return self.fit_from(instance=instance)
Expand Down
11 changes: 11 additions & 0 deletions test_autofit/analysis/test_use_jax_for_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch):
assert analysis._use_jax_for_visualization is False


def test_pyauto_disable_jax_overrides_sentinel_default(monkeypatch):
"""PYAUTO_DISABLE_JAX=1 must still force both off even when the user
constructs Analysis(use_jax=True) and lets the sentinel resolve. This is
a numpy-only check — JAX-conditional sentinel-resolution assertions live
in autofit_workspace_test/scripts/jax_assertions/fitness_dispatch.py."""
monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1")
analysis = af.Analysis(use_jax=True)
assert analysis._use_jax is False
assert analysis._use_jax_for_visualization is False


def test_fit_for_visualization_works_without_flag():
analysis = _FittableAnalysis()
result = analysis.fit_for_visualization(instance="sentinel")
Expand Down
Loading