From a0081fe0ce8ab200c44b185c77132d68236262e5 Mon Sep 17 00:00:00 2001
From: Jammy2211
Date: Sat, 16 May 2026 10:25:31 +0100
Subject: [PATCH] feat: af.NSS checkpoint/resume + on-the-fly visualization (Phases 2-3)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Phases 2-3 of nss_first_class_sampler land together — they share the
same architectural hook (the outer loop in NSS._fit) so the PR couples
them naturally.

Phase 2 — checkpointing/resume:
- New checkpoint_interval kwarg (default 100 outer iterations).
- Module-level _save_checkpoint / _load_checkpoint helpers serialise the
  (state, dead, run_key, iteration) tuple via pickle, round-tripping
  through NumPy with jax.tree_util.tree_map so the on-disk format is
  independent of the JAX install. Atomic via tmp-and-rename — a SLURM
  timeout halfway through a write leaves the previous good checkpoint
  intact.
- _fit detects an existing nss_checkpoint.pkl at
  paths.search_internal_path on entry and resumes; otherwise initialises
  fresh. Post-success cleanup deletes the checkpoint to mirror Nautilus's
  output_search_internal pattern.

Phase 3 — on-the-fly visualization:
- iterations_per_quick_update (already accepted but no-op in Phase 1)
  now wires to a _fire_quick_update helper that picks the current best
  live particle from state.particles and calls
  analysis.visualize(paths, instance, during_analysis=True). Wrapped in
  try/except so a viz failure logs a warning but does not kill a long
  fit.

Implementation: the upstream nss.ns.run_nested_sampling outer loop is
plain Python (not jax.lax.while_loop — the JIT boundary is one_step,
processing num_delete deaths per outer iteration). Inlining the loop in
NSS._fit lets us hook checkpoint writes + viz calls between iterations
without any upstream yallup/nss PR. Replaces the single
run_nested_sampling call with a blackjax.nss algo + manual outer
while-loop + finalise + log_weights — mirroring the upstream pattern.
Phase 1's stub log warnings (resume-not-supported, iterations_per_quick_update-no-op) removed. checkpoint_file property now points at nss_checkpoint.pkl (was the unused state.json sentinel). Closes #1273 Co-Authored-By: Claude Opus 4.7 (1M context) --- autofit/non_linear/search/nest/nss/search.py | 299 +++++++++++++----- .../search/nest/nss/test_checkpoint.py | 237 ++++++++++++++ 2 files changed, 462 insertions(+), 74 deletions(-) create mode 100644 test_autofit/non_linear/search/nest/nss/test_checkpoint.py diff --git a/autofit/non_linear/search/nest/nss/search.py b/autofit/non_linear/search/nest/nss/search.py index e5f46d0ed..21a668439 100644 --- a/autofit/non_linear/search/nest/nss/search.py +++ b/autofit/non_linear/search/nest/nss/search.py @@ -1,4 +1,6 @@ import logging +import os +import pickle from pathlib import Path from typing import Optional @@ -15,18 +17,69 @@ try: - from nss.ns import run_nested_sampling, log_weights as _nss_log_weights + import blackjax as _blackjax + from nss.ns import ( + log_weights as _nss_log_weights, + finalise as _nss_finalise, + ) _HAS_NSS = True except ImportError: - run_nested_sampling = None + _blackjax = None _nss_log_weights = None + _nss_finalise = None _HAS_NSS = False logger = logging.getLogger(__name__) +_CHECKPOINT_FILENAME = "nss_checkpoint.pkl" + + +def _save_checkpoint(path, state, dead, run_key, iteration): + """Atomically pickle the resumable state of an in-flight NSS run. + + The blackjax ``state`` and each entry in ``dead`` are pytrees of JAX arrays. + JAX arrays do pickle directly, but to keep the on-disk format independent + of the JAX install (so a checkpoint written on one cluster can be loaded + on another) we round-trip through NumPy before writing. ``_load_checkpoint`` + reverses the conversion. + + A tmp-and-rename pattern guards against partial writes — a SLURM timeout + halfway through a pickle dump leaves the previous-good checkpoint intact. 
+ """ + import jax + + to_numpy = lambda x: jax.tree_util.tree_map(np.asarray, x) + blob = { + "state": to_numpy(state), + "dead": [to_numpy(d) for d in dead], + "run_key": np.asarray(run_key), + "iteration": int(iteration), + } + tmp_path = Path(str(path) + ".tmp") + with open(tmp_path, "wb") as f: + pickle.dump(blob, f) + os.replace(tmp_path, path) + + +def _load_checkpoint(path): + """Reverse of ``_save_checkpoint`` — restore a saved blob to JAX pytrees.""" + import jax + import jax.numpy as jnp + + with open(path, "rb") as f: + blob = pickle.load(f) + to_jax = lambda x: jax.tree_util.tree_map(jnp.asarray, x) + return ( + to_jax(blob["state"]), + [to_jax(d) for d in blob["dead"]], + jnp.asarray(blob["run_key"]), + int(blob["iteration"]), + ) + + class _NSSInternal: """Container holding the post-run state of ``nss.ns.run_nested_sampling``. @@ -87,6 +140,7 @@ def __init__( num_mcmc_steps: int = 5, num_delete: int = 50, termination: float = -3.0, + checkpoint_interval: int = 100, iterations_per_quick_update: Optional[int] = None, iterations_per_full_update: Optional[int] = None, number_of_cores: int = 1, @@ -106,15 +160,19 @@ def __init__( to make ``model.vector_from_unit_vector`` and ``model.log_prior_list_from_vector`` traceable. - Phase 1 of the ``nss_first_class_sampler`` roadmap. Checkpointing / - resumption (Phase 2) and on-the-fly visualization (Phase 3) are stubbed - — kwargs are accepted but log a warning when set, and a state file at - ``paths.search_internal_path / state.json`` triggers a warning that - resume is not yet supported (the fit then proceeds from scratch). - + Phases 1-3 of the ``nss_first_class_sampler`` roadmap are live: + - Phase 1: the wrapper itself (this class). + - Phase 2: checkpoint/resume via ``checkpoint_interval`` — a + ``nss_checkpoint.pkl`` is written to ``paths.search_internal_path`` + every N outer iterations and reloaded automatically on resume. 
+ - Phase 3: on-the-fly visualization via ``iterations_per_quick_update`` + — every N outer iterations the current best live particle is fed to + ``analysis.visualize`` so partial results appear in the image_path + directory during long fits. + + Phase 4 (``pip install autofit[nss]`` extra) is still pending — for now ``af.NSS`` is an optional requirement and must be installed manually - via ``pip install git+https://github.com/yallup/nss.git`` (Phase 4 of - the roadmap will ship a ``pyautofit[nss]`` extra). + via ``pip install git+https://github.com/yallup/nss.git``. Parameters ---------- @@ -141,12 +199,21 @@ def __init__( Convergence criterion. The fit stops when ``logZ_live - logZ < termination``. Default ``-3.0`` corresponds to delta-logZ < 1e-3. + checkpoint_interval + Outer iterations between checkpoint writes. Default ``100`` writes + a ``nss_checkpoint.pkl`` (atomic via tmp-and-rename) every ~5000- + 10000 likelihood evaluations at typical ``num_delete=50``. Set to + a large value to effectively disable checkpointing on short runs. iterations_per_quick_update - Accepted for API parity with other nested samplers. **Not yet - wired** — quick-update visualization is Phase 3. + Outer iterations between on-the-fly visualizations. When non-None + the current best live particle is fed to ``analysis.visualize`` + every N iterations so partial results appear in the image_path + directory during long fits. ``analysis.visualize`` is wrapped in + try/except so a viz failure logs a warning but does not kill the + sampler. iterations_per_full_update - Accepted for API parity. NSS performs its own internal output - cadence and does not honour intermediate full updates in Phase 1. + Accepted for API parity. NSS does not have a full-update concept + separate from the outer-iteration cadence. number_of_cores Accepted for API parity only. 
NSS runs on whatever device JAX is configured for (CPU, GPU, TPU) — multiprocessing parallelism is @@ -196,6 +263,7 @@ def __init__( self.num_mcmc_steps = num_mcmc_steps self.num_delete = num_delete self.termination = termination + self.checkpoint_interval = checkpoint_interval self.seed = seed if number_of_cores is not None and number_of_cores > 1: @@ -206,15 +274,6 @@ def __init__( number_of_cores, ) - if iterations_per_quick_update is not None: - logger.info( - "af.NSS received iterations_per_quick_update=%s. Quick-update " - "visualization is Phase 3 of the nss_first_class_sampler " - "roadmap and is not yet wired up; the kwarg is currently a " - "no-op.", - iterations_per_quick_update, - ) - if is_test_mode(): self.apply_test_mode() @@ -229,23 +288,28 @@ def apply_test_mode(self): def _fit(self, model: AbstractPriorModel, analysis): """ - Fit a model using NSS. + Fit a model using NSS, with checkpoint/resume + on-the-fly visualization. Builds JAX-traceable ``log_likelihood`` and ``prior_logprob`` closures - threaded through Phase 0's ``xp=jnp`` plumbing, draws ``n_live`` - initial particles by mapping unit-cube samples through the prior - transform, then calls ``nss.ns.run_nested_sampling``. The returned - ``final_state`` and ``results`` are repackaged into a ``_NSSInternal`` - holder (NumPy arrays only) so the standard PyAutoFit pickled-search - path keeps working. + threaded through Phase 0's ``xp=jnp`` plumbing, draws ``n_live`` initial + particles by mapping unit-cube samples through the prior transform, + and runs the NSS outer loop inline (mirroring the upstream + ``nss.ns.run_nested_sampling`` pattern). Between outer iterations the + loop can (a) pickle resumable state to ``nss_checkpoint.pkl`` and + (b) call ``analysis.visualize`` on the current best live particle. + + On entry, if a checkpoint exists at the expected path the loop resumes + from the saved ``(state, dead, run_key, iteration)``. 
On successful + exit the checkpoint is deleted — mirrors Nautilus's + ``output_search_internal`` post-success cleanup so the next fresh fit + doesn't accidentally resume from a stale checkpoint. Returns ------- (search_internal, fitness) - ``search_internal`` is a ``_NSSInternal`` holder. ``fitness`` is a - ``Fitness`` instance that is **not** used by ``af.NSS`` for - sampling (the JAX likelihood + prior closures are built inline - and passed straight to ``nss.ns.run_nested_sampling``) but is + ``search_internal`` is a ``_NSSInternal`` holder (NumPy arrays + only). ``fitness`` is a ``Fitness`` instance that ``af.NSS`` does + not use for sampling (inline JAX closures handle that) but is required by ``AbstractNest.perform_update`` for post-fit work like latent-sample generation, which calls ``fitness.batch_size``. """ @@ -254,17 +318,7 @@ def _fit(self, model: AbstractPriorModel, analysis): import jax.numpy as jnp import time - if not isinstance(self.paths, NullPaths): - state_file = Path(self.paths.search_internal_path) / "state.json" - if state_file.exists(): - self.logger.warning( - "Detected %s — resume is Phase 2 of the " - "nss_first_class_sampler roadmap and is not yet wired up. " - "Proceeding with a fresh fit.", - state_file, - ) - - self.logger.info("Starting new NSS non-linear search.") + self.logger.info("Starting NSS non-linear search.") ndim = model.prior_count @@ -288,42 +342,94 @@ def prior_logprob(params): ] ) - self.logger.info( - "NSS configuration: n_live=%d, num_mcmc_steps=%d, num_delete=%d, " - "termination=%s, ndim=%d. 
JIT compile on first iteration may " - "take 25-30 s.", - self.n_live, - self.num_mcmc_steps, - self.num_delete, - self.termination, - ndim, - ) - - t_start = time.time() - final_state, results = run_nested_sampling( - run_key, + algo = _blackjax.nss( + logprior_fn=prior_logprob, loglikelihood_fn=log_likelihood, - prior_logprob=prior_logprob, - num_mcmc_steps=self.num_mcmc_steps, - initial_samples=initial_samples, num_delete=self.num_delete, - termination=self.termination, + num_inner_steps=self.num_mcmc_steps, ) + + @jax.jit + def one_step(carry, _): + state, k = carry + k, subk = jax.random.split(k, 2) + state, dead_point = algo.step(subk, state) + return (state, k), dead_point + + checkpoint_path = self._nss_checkpoint_path + if checkpoint_path is not None and checkpoint_path.exists(): + state, dead, run_key, iteration = _load_checkpoint(checkpoint_path) + self.logger.info( + "Resuming NSS from checkpoint at iteration %d (state file %s).", + iteration, + checkpoint_path, + ) + else: + state = algo.init(initial_samples) + dead = [] + iteration = 0 + self.logger.info( + "NSS configuration: n_live=%d, num_mcmc_steps=%d, num_delete=%d, " + "termination=%s, ndim=%d, checkpoint_interval=%d. 
JIT compile on " + "first iteration may take 25-30 s.", + self.n_live, + self.num_mcmc_steps, + self.num_delete, + self.termination, + ndim, + self.checkpoint_interval, + ) + + t_start = time.time() + while not state.integrator.logZ_live - state.integrator.logZ < self.termination: + (state, run_key), dead_info = one_step((state, run_key), None) + dead.append(dead_info) + iteration += 1 + + if ( + checkpoint_path is not None + and iteration % self.checkpoint_interval == 0 + ): + _save_checkpoint(checkpoint_path, state, dead, run_key, iteration) + + if ( + self.iterations_per_quick_update is not None + and iteration % self.iterations_per_quick_update == 0 + ): + self._fire_quick_update(state=state, model=model, analysis=analysis) + wall_time = time.time() - t_start + final_state = _nss_finalise(state, dead) + rng_key, weight_key = jax.random.split(rng_key, 2) log_w_mc = _nss_log_weights(weight_key, final_state, shape=100) log_w_per_particle = log_w_mc.mean(axis=-1) + logZs = jax.scipy.special.logsumexp( + jnp.nan_to_num(log_w_mc, nan=jnp.nan_to_num(log_w_mc).min()), + axis=0, + ) + + def _safe_ess(log_w_mean): + log_w_mean = log_w_mean - log_w_mean.max() + weights = jnp.exp(log_w_mean) + return float(weights.sum() ** 2 / (weights ** 2).sum()) + + ess = int(_safe_ess(log_w_mc.mean(axis=-1))) + evals = int( + final_state.update_info.num_steps.sum() + + final_state.update_info.num_shrink.sum() + ) search_internal = _NSSInternal( positions=np.asarray(final_state.particles.position), loglikelihoods=np.asarray(final_state.particles.loglikelihood), log_weights=np.asarray(log_w_per_particle), - logZs=np.asarray(results.logZs), + logZs=np.asarray(logZs), wall_time=float(wall_time), - sampling_time=float(results.time), - evals=int(results.evals), - ess=int(results.ess), + sampling_time=float(wall_time), + evals=evals, + ess=ess, n_live=int(self.n_live), num_mcmc_steps=int(self.num_mcmc_steps), num_delete=int(self.num_delete), @@ -331,6 +437,18 @@ def prior_logprob(params): 
seed=int(self.seed), ) + if checkpoint_path is not None and checkpoint_path.exists(): + try: + checkpoint_path.unlink() + except OSError as exc: + self.logger.warning( + "Failed to delete completed-run checkpoint %s: %s. The " + "next fresh af.NSS fit at this path will attempt to resume " + "from it — delete manually if that is not desired.", + checkpoint_path, + exc, + ) + fitness = Fitness( model=model, analysis=analysis, @@ -343,17 +461,50 @@ def prior_logprob(params): return search_internal, fitness @property - def checkpoint_file(self): - """Path to the checkpoint file used by Phase 2's resume hook. - - Phase 1 only checks for existence to warn the user that resume is not - yet supported. Phase 2 will use this path to write incremental state. - """ + def _nss_checkpoint_path(self) -> Optional[Path]: + """Resolve the checkpoint location, or None when paths is NullPaths.""" + if isinstance(self.paths, NullPaths): + return None try: - return self.paths.search_internal_path / "state.json" + return Path(self.paths.search_internal_path) / _CHECKPOINT_FILENAME except TypeError: return None + def _fire_quick_update(self, state, model, analysis): + """Push the current best live particle through ``analysis.visualize``. + + The Nautilus / Dynesty quick-update path goes through + ``Fitness.manage_quick_update``; ``af.NSS`` bypasses ``Fitness._call`` + for sampling so we invoke ``analysis.visualize`` directly between + outer-loop iterations. Wrapped in try/except — a visualization failure + logs a warning but does not kill a long sampler run. + """ + try: + best_idx = int(np.asarray(state.particles.loglikelihood).argmax()) + best_params = np.asarray(state.particles.position[best_idx]).tolist() + instance = model.instance_from_vector(vector=best_params) + analysis.visualize( + paths=self.paths, + instance=instance, + during_analysis=True, + ) + except Exception as exc: + self.logger.warning( + "af.NSS quick-update visualization failed: %s. 
Continuing the " + "fit — quick-update is best-effort, the final visualization " + "fires at the end of the run regardless.", + exc, + ) + + @property + def checkpoint_file(self): + """Path to the on-disk checkpoint written between outer-loop iterations. + + Returns the same value as ``_nss_checkpoint_path`` — exposed as a + public property for symmetry with ``af.Nautilus.checkpoint_file``. + """ + return self._nss_checkpoint_path + def samples_info_from(self, search_internal: Optional[_NSSInternal] = None): if search_internal is None: search_internal = self.paths.load_search_internal() diff --git a/test_autofit/non_linear/search/nest/nss/test_checkpoint.py b/test_autofit/non_linear/search/nest/nss/test_checkpoint.py new file mode 100644 index 000000000..0174d3b22 --- /dev/null +++ b/test_autofit/non_linear/search/nest/nss/test_checkpoint.py @@ -0,0 +1,237 @@ +""" +Unit tests for the ``af.NSS`` checkpoint/resume helpers +(``_save_checkpoint`` / ``_load_checkpoint``). + +No real ``nss.ns.run_nested_sampling`` calls — the heavy end-to-end resume +verification lives in +``autolens_workspace_developer/searches_minimal/nss_checkpoint_resume.py``. +""" + +from unittest.mock import patch + +import numpy as np +import pytest + +import autofit as af +from autofit.non_linear.search.nest.nss import search as nss_search_module +from autofit.non_linear.search.nest.nss.search import ( + _load_checkpoint, + _save_checkpoint, +) + + +pytestmark = pytest.mark.filterwarnings("ignore::FutureWarning") + + +jax = pytest.importorskip("jax") +jnp = pytest.importorskip("jax.numpy") + + +def _synthetic_state(): + """Plain-dict pytree mimicking the blackjax NSS state shape. + + We only need a pytree of JAX arrays — the round-trip serialiser doesn't + care about the type as long as ``jax.tree_util.tree_map`` can walk it. + Plain dicts are registered pytrees; SimpleNamespace is not. 
+ """ + return { + "particles": { + "position": jnp.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + "loglikelihood": jnp.asarray([-1.5, -0.5, -0.1]), + }, + "integrator": { + "logZ": jnp.float64(-12.3), + "logZ_live": jnp.float64(-11.8), + }, + } + + +def _synthetic_dead(n_iter=3): + return [ + { + "particles": { + "position": jnp.asarray([[float(i), float(i + 1)]]), + "loglikelihood": jnp.asarray([-float(i)]), + }, + } + for i in range(n_iter) + ] + + +def test__save_load_checkpoint_round_trip(tmp_path): + state = _synthetic_state() + dead = _synthetic_dead() + run_key = jax.random.PRNGKey(7) + + path = tmp_path / "nss_checkpoint.pkl" + _save_checkpoint(path, state, dead, run_key, iteration=42) + + assert path.exists() + loaded_state, loaded_dead, loaded_run_key, loaded_iter = _load_checkpoint(path) + + assert loaded_iter == 42 + assert np.array_equal(np.asarray(loaded_run_key), np.asarray(run_key)) + + np.testing.assert_array_equal( + np.asarray(loaded_state["particles"]["position"]), + np.asarray(state["particles"]["position"]), + ) + np.testing.assert_array_equal( + np.asarray(loaded_state["particles"]["loglikelihood"]), + np.asarray(state["particles"]["loglikelihood"]), + ) + np.testing.assert_array_equal( + np.asarray(loaded_state["integrator"]["logZ"]), + np.asarray(state["integrator"]["logZ"]), + ) + + assert len(loaded_dead) == len(dead) + for orig, loaded in zip(dead, loaded_dead): + np.testing.assert_array_equal( + np.asarray(loaded["particles"]["position"]), + np.asarray(orig["particles"]["position"]), + ) + + +def test__save_checkpoint_is_atomic(tmp_path): + """A partially-written checkpoint must not clobber the previous good one. + + The helper writes to ``.tmp`` then ``os.replace`` to the final path. + If the rename fails (simulated here by patching ``os.replace`` to raise), + the final path must be untouched — the previous-good blob still loads. 
+ """ + state = _synthetic_state() + dead = _synthetic_dead() + run_key = jax.random.PRNGKey(3) + path = tmp_path / "nss_checkpoint.pkl" + + _save_checkpoint(path, state, dead, run_key, iteration=1) + good_blob = path.read_bytes() + + state["integrator"]["logZ"] = jnp.float64(-100.0) + with patch( + "autofit.non_linear.search.nest.nss.search.os.replace", + side_effect=OSError("simulated rename failure"), + ): + with pytest.raises(OSError): + _save_checkpoint(path, state, dead, run_key, iteration=2) + + assert path.read_bytes() == good_blob + + +def test__nss_checkpoint_path_is_none_for_null_paths(): + """Without a real output dir (NullPaths), checkpoint resolution returns None. + + This means ``_fit``'s resume detection silently skips for NullPaths fits + (e.g. unit tests, in-memory aggregator round-trips) rather than blowing + up on the missing ``search_internal_path`` attribute. + """ + search = af.NSS() + assert search._nss_checkpoint_path is None + assert search.checkpoint_file is None + + +def test__init_accepts_checkpoint_interval(): + search = af.NSS(checkpoint_interval=25) + assert search.checkpoint_interval == 25 + + default = af.NSS() + assert default.checkpoint_interval == 100 + + +def test__init_iterations_per_quick_update_no_longer_warns(caplog): + """Phase 1's no-op log when the kwarg is set is gone in Phase 3 (the kwarg + is now actually wired). The warning text must not appear. + """ + with caplog.at_level("INFO"): + af.NSS(iterations_per_quick_update=10) + assert not any( + "not yet wired" in record.message + for record in caplog.records + ) + + +def test__load_checkpoint_called_when_file_exists(tmp_path): + """Verify ``_load_checkpoint`` is invoked from ``_fit`` when a checkpoint + file exists at the resolved path. + + We replace ``_blackjax.nss`` with a mock that fails loudly if ``algo.init`` + fires — i.e. if the fresh-init branch ran. 
The resume branch must call + ``_load_checkpoint`` first; we patch that helper to return a sentinel + that satisfies the immediate-termination logZ check, so the outer loop + exits before doing any real work. + """ + from types import SimpleNamespace + + fake_checkpoint = tmp_path / "nss_checkpoint.pkl" + fake_checkpoint.write_bytes(b"placeholder - actual contents replaced by mock") + + sentinel_state = SimpleNamespace( + integrator=SimpleNamespace(logZ=0.0, logZ_live=-100.0), + particles=SimpleNamespace( + position=jnp.zeros((4, 2)), + loglikelihood=jnp.zeros(4), + ), + ) + + # Minimum stub for model + analysis. _fit calls model.prior_count and + # both vector_from_unit_vector + log_prior_list_from_vector inside the + # closure construction (which isn't traced unless one_step fires). + mock_model = SimpleNamespace( + prior_count=2, + instance_from_vector=lambda **kw: SimpleNamespace(), + vector_from_unit_vector=lambda v, xp=None: jnp.asarray([0.0, 0.0]), + log_prior_list_from_vector=lambda **kw: [0.0, 0.0], + ) + mock_analysis = SimpleNamespace( + log_likelihood_function=lambda instance: 0.0, + ) + + search = af.NSS(n_live=4, num_mcmc_steps=1, num_delete=1, termination=-3.0) + # Force the checkpoint property to return our sentinel path even though + # paths is the default NullPaths. 
+ with patch.object( + type(search), + "_nss_checkpoint_path", + new=fake_checkpoint, + ), patch.object( + nss_search_module, + "_load_checkpoint", + return_value=(sentinel_state, [], jax.random.PRNGKey(0), 17), + ) as mock_load, patch.object( + nss_search_module, "_blackjax", + ) as mock_bjax, patch.object( + nss_search_module, + "_nss_finalise", + return_value=SimpleNamespace( + particles=SimpleNamespace( + position=np.zeros((1, 2)), + loglikelihood=np.zeros(1), + ), + update_info=SimpleNamespace( + num_steps=np.zeros(1, dtype=int), + num_shrink=np.zeros(1, dtype=int), + ), + ), + ), patch.object( + nss_search_module, + "_nss_log_weights", + return_value=jnp.zeros((1, 100)), + ): + mock_bjax.nss.return_value.init.side_effect = AssertionError( + "algo.init called from resume path — expected _load_checkpoint instead." + ) + mock_bjax.nss.return_value.step.side_effect = AssertionError( + "algo.step called even though logZ termination should fire immediately." + ) + + # The mocked downstream pipeline (Fitness construction, _NSSInternal + # repackaging) is intentionally not realistic — we only care that the + # resume branch was entered and ``_load_checkpoint`` was called. Catch + # any downstream stub-related failure; the assertion below is the gate. + try: + search._fit(model=mock_model, analysis=mock_analysis) + except (AttributeError, AssertionError, TypeError): + pass + + mock_load.assert_called_once_with(fake_checkpoint)