feat: af.NSS checkpointing + on-the-fly visualization (Phases 2-3) #1273

@Jammy2211

Overview

Wire up checkpointing (Phase 2) and on-the-fly visualization (Phase 3) for af.NSS. Both phases share the same architectural hook — the outer loop in af.NSS._fit — so they ship together in one PyAutoFit PR. Closes Phases 2-3 of the nss_first_class_sampler roadmap.

Critical finding during the Phase 1 follow-up audit: the roadmap's framing was wrong. nss.ns.run_nested_sampling is not a one-shot JIT'd jax.lax.while_loop requiring an upstream yallup/nss PR. The JIT boundary is one_step (one outer iteration); the outer loop is plain Python. This means both checkpointing and visualization can be implemented entirely inside af.NSS._fit — no upstream dependency.

Plan

  • Replace the single nss.ns.run_nested_sampling(...) call inside NSS._fit with an inlined equivalent: build blackjax.nss algo + JIT'd one_step closure, then run the outer loop locally.
  • Add a checkpoint_interval: int = 100 kwarg. Every checkpoint_interval outer iterations, pickle (state, dead, rng_key, iteration) to paths.search_internal_path / "nss_checkpoint.pkl". Standard JAX pytree → NumPy round-trip for serialisation.
  • On _fit entry, detect an existing checkpoint and resume from it (state + dead list + RNG); otherwise initialise fresh. The Phase 1 "resume not yet supported" warning goes away.
  • Wire the existing iterations_per_quick_update kwarg (Phase 1 accepted-but-no-op) to call analysis.visualize(paths=..., instance=..., during_analysis=True) on the current best live point between outer iterations.
  • On successful completion, delete the checkpoint file — mirror Nautilus's output_search_internal pattern.
  • Add unit tests for _save_checkpoint / _load_checkpoint pytree round-trip, checkpoint_interval kwarg acceptance, and resume detection (mocked to avoid running real nss).
  • Add an end-to-end resume integration smoke under autolens_workspace_developer/searches_minimal/ — run for N iterations, simulate interrupt, restart, confirm continuation to convergence with identical final state.
Detailed implementation plan

Affected Repositories

  • PyAutoFit (primary — library)
  • autolens_workspace_developer (secondary — resume smoke + viz smoke scripts)

Work Classification

Library (with dev-workspace smoke ridealong on the same branch)

Branch Survey

Repository                      Current branch  Dirty?
./PyAutoFit                     main            clean
./autolens_workspace_developer  main            dirty (pre-existing local work; not ours, untouched)

worktree_check_conflict (exit=0): no active task claims either repo.

Suggested branch: feature/nss-checkpointing-and-visualization
Worktree root: ~/Code/PyAutoLabs-wt/nss-checkpointing-and-visualization/ (created by /start_library)

Implementation Steps

  1. Inline the outer loop in NSS._fit (PyAutoFit/autofit/non_linear/search/nest/nss/search.py). Replace the existing run_nested_sampling(...) call with:

    import time
    
    import blackjax
    import jax
    from nss.ns import finalise, log_weights as _nss_log_weights, Results, safe_ess
    
    algo = blackjax.nss(
        logprior_fn=prior_logprob,
        loglikelihood_fn=log_likelihood,
        num_delete=self.num_delete,
        num_inner_steps=self.num_mcmc_steps,
    )
    
    # JIT boundary: one outer iteration (num_delete deaths). xs is unused; the
    # (carry, xs) signature mirrors upstream nss.ns.
    @jax.jit
    def one_step(carry, xs):
        state, k = carry
        k, subk = jax.random.split(k, 2)
        state, dead_point = algo.step(subk, state)
        return (state, k), dead_point
    
    checkpoint_path = self._nss_checkpoint_path
    if checkpoint_path is not None and checkpoint_path.exists():
        # Resume: restore sampler state, accumulated dead points and RNG key.
        state, dead, run_key, iteration = _load_checkpoint(checkpoint_path)
        self.logger.info("Resuming NSS from iteration %d", iteration)
    else:
        state = algo.init(initial_samples)
        dead = []
        # Assumes the PRNG key formerly passed to run_nested_sampling is bound
        # to rng_key earlier in _fit; without this, run_key is undefined here.
        run_key = rng_key
        iteration = 0
    
    t_start = time.time()
    # Plain-Python outer loop: host-side hooks can run between JIT'd steps.
    while not state.integrator.logZ_live - state.integrator.logZ < self.termination:
        (state, run_key), dead_info = one_step((state, run_key), None)
        dead.append(dead_info)
        iteration += 1
    
        # Phase 2: checkpoint every checkpoint_interval outer iterations
        if (checkpoint_path is not None
            and iteration % self.checkpoint_interval == 0):
            _save_checkpoint(checkpoint_path, state, dead, run_key, iteration)
    
        # Phase 3: quick-update visualization of the best live point
        if (self.iterations_per_quick_update is not None
            and iteration % self.iterations_per_quick_update == 0):
            self._fire_quick_update(state, model, analysis)
    
    # Wall time of this session only; time spent before a resume is not included.
    wall_time = time.time() - t_start
    final_state = finalise(state, dead)
    # ... existing _NSSInternal repackaging
  2. Add _save_checkpoint + _load_checkpoint module-level helpers — pickle-based round-trip via jax.tree_util.tree_map(np.asarray, ...) and inverse. Robust to interrupted writes via atomic rename (pickle.dump to *.tmp, then os.replace to final path).

  3. Add _fire_quick_update(state, model, analysis) helper — extracts the highest-loglikelihood live particle from state.particles, maps to a ModelInstance via model.instance_from_vector, and calls analysis.visualize(paths=self.paths, instance=instance, during_analysis=True). Wraps in try/except to log + continue on visualization errors (don't kill a long fit because a plot misfired).

  4. Add checkpoint_interval kwarg to NSS.__init__ (default 100). Update the docstring's "Stubbed / out of scope" section and remove the Phase 1 "resume not yet supported" + "quick-update visualization not yet wired" warnings.

  5. Delete checkpoint on success — at the end of _fit, after _NSSInternal is built, if checkpoint_path is not None and checkpoint_path.exists(): checkpoint_path.unlink(). Mirror Nautilus's output_search_internal post-success cleanup.

  6. Unit tests in PyAutoFit/test_autofit/non_linear/search/nest/nss/test_search.py (extend) and a new test_checkpoint.py:

    • _save_checkpoint / _load_checkpoint round-trip on a synthetic state pytree
    • NSS.__init__(checkpoint_interval=...) accepted; identifier_fields unchanged
    • Resume detection — monkeypatch Path.exists + the loader, confirm _fit enters the resume branch (without running nss)
    • Atomic-write semantics: an interrupt during _save_checkpoint leaves a .tmp file, never a partial .pkl
  7. Integration smoke in autolens_workspace_developer/searches_minimal/nss_checkpoint_resume.py:

    • Run af.NSS with checkpoint_interval=5, num_delete=10, n_live=40 (2-param Gaussian smoke model)
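    • Mid-run, once at least one checkpoint has been saved, kill the process via sys.exit inside a quick-update callback. SystemExit derives from BaseException, so the step 3 try/except must catch Exception (not a bare except) or it will swallow the exit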
    • Restart with the same paths; assert resume happens (log line) and final state matches a single-shot run
  8. Quick-update smoke — extend nss_first_class_gaussian.py (or create a sibling): set iterations_per_quick_update=3, run, assert paths.image_path contains PNGs written before final convergence.
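The quick-update hook from step 3 could look roughly like the sketch below. It is a hypothetical free-function version of NSS._fire_quick_update and assumes that state.particles exposes a position array of shape (n_live, n_dim) and a loglikelihood array of shape (n_live,); those field names and shapes are assumptions about the blackjax NSS state. It also assumes positions are in the model's physical parameter space (if the sampler works in the unit cube, model.instance_from_unit_vector would be needed instead). Only Exception is caught, so the sys.exit used by the step 7 smoke still stops the run.

```python
import logging
from typing import Any

import numpy as np

logger = logging.getLogger(__name__)


def fire_quick_update(state: Any, model: Any, analysis: Any, paths: Any) -> bool:
    """Hypothetical stand-in for NSS._fire_quick_update.

    Visualizes the highest-log-likelihood live particle and returns True on
    success; visualization errors are logged and the fit carries on.
    """
    try:
        # Assumed layout: position is (n_live, n_dim), loglikelihood is (n_live,).
        logl = np.asarray(state.particles.loglikelihood)
        positions = np.asarray(state.particles.position)
        best = int(np.argmax(logl))
        instance = model.instance_from_vector(positions[best].tolist())
        analysis.visualize(paths=paths, instance=instance, during_analysis=True)
        return True
    except Exception:
        # Exception, not BaseException: SystemExit and KeyboardInterrupt still
        # propagate, so a deliberate stop is never swallowed here.
        logger.exception("NSS quick-update visualization failed; continuing the fit")
        return False
```

Returning a boolean is only a convenience for testing; the method in _fit can ignore it.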

Key Files

  • PyAutoFit/autofit/non_linear/search/nest/nss/search.py — main _fit rewrite + checkpoint helpers + viz hook
  • PyAutoFit/test_autofit/non_linear/search/nest/nss/test_search.py — extended unit tests
  • PyAutoFit/test_autofit/non_linear/search/nest/nss/test_checkpoint.py — new, focused on serialisation round-trip
  • autolens_workspace_developer/searches_minimal/nss_checkpoint_resume.py — new integration smoke
  • autolens_workspace_developer/searches_minimal/nss_first_class_gaussian.py — extend with quick-update assertions

Out of scope

  • JIT persistent cache (separate follow-up — each cold + resumed fit pays 25-30 s while_loop compile)
  • Install simplification — Phase 4 (autofit/nss_install_simplification.md)
  • Workspace tutorial scripts — Phase 5
  • iterations_per_full_update activation — still API-parity-only for nss (no separate full-update concept)

Risks / open questions

  1. dead list memory growth — for a typical 5000-outer-iteration run, dead accumulates 5000 NSInfo pytrees in Python memory. Each pickle write re-serialises all of them — wasteful disk I/O if the list is long. Land the naïve pickle.dump(dead) first, measure, optimise to incremental-append later if needed. Open question for Phase 4 batch users running many sequential fits.

  2. Checkpoint after success — Nautilus deletes its checkpoint.hdf5 after completion. Mirror that pattern: delete nss_checkpoint.pkl on _fit exit so the next fresh fit doesn't accidentally resume from a stale checkpoint. Open: should we leave the file with a .completed suffix for forensic inspection? Probably not — the samples.csv + samples_summary.json capture everything users actually need.

  3. Resume reproducibility: verify that a single-shot 50-iteration run and a split run (25 iterations, save, resume, 25 more) produce byte-identical state.particles.position at iteration 50. Add as a parity test in the integration smoke.

  4. Visualization cost: analysis.visualize on the autolens HST MGE problem takes seconds per call (model plots, residuals, etc.). At iterations_per_quick_update=10 with num_delete=50, that is one visualization every 10 × 50 = 500 dead points (and many more likelihood evaluations, since each replacement runs num_mcmc_steps MCMC steps), which is fine. Document this so users don't set iterations_per_quick_update=1 and tank performance.
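The parity check in risk 3 hinges on the RNG key being part of the checkpoint: the resumed run must continue the exact same random stream. The toy below illustrates this with NumPy only; the random-walk toy_step standing in for one_step, and the numpy.random.Generator standing in for the JAX key, are assumptions made for illustration. Running 50 iterations in one go and running 25, pickling the full carry, then resuming for 25 more gives byte-identical state, while reseeding on restart does not.

```python
import pickle

import numpy as np


def toy_step(state: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Stand-in for one_step: the next state depends only on (state, rng)."""
    return state + rng.normal(size=state.shape)


def run(state: np.ndarray, rng: np.random.Generator, n_iter: int):
    for _ in range(n_iter):
        state = toy_step(state, rng)
    return state, rng


init = np.zeros((40, 2))  # 40 "live points" in 2 dimensions

# Single shot: 50 iterations.
single, _ = run(init.copy(), np.random.default_rng(0), 50)

# Interrupted: 25 iterations, checkpoint the full carry (state AND rng), resume.
half, rng = run(init.copy(), np.random.default_rng(0), 25)
blob = pickle.dumps((half, rng))  # what a checkpoint would persist
half, rng = pickle.loads(blob)    # what a resume would restore
resumed, _ = run(half, rng, 25)

# Broken resume: the RNG state is dropped and reseeded on restart.
half_bad, _ = run(init.copy(), np.random.default_rng(0), 25)
reseeded, _ = run(half_bad, np.random.default_rng(0), 25)
```

In the real loop the same argument applies to run_key; byte-identity may additionally depend on running on the same JAX backend before and after the restart.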

Original Prompt

Add checkpointing + on-the-fly visualization to af.NSS.

This is Phases 2 and 3 of z_features/nss_first_class_sampler.md. Both
phases share the same architectural hook — the outer loop in
af.NSS._fit — so they ship together in one PR.

Critical finding from Phase 1 follow-up audit

The z_features roadmap claimed nss.ns.run_nested_sampling is "a one-shot
JIT'd jax.lax.while_loop" requiring an upstream yallup/nss PR to add
a checkpoint hook. This is wrong. Inspecting the actual upstream
source (/home/jammy/venv/PyAuto/lib/python3.12/site-packages/nss/ns.py):

@jax.jit
def one_step(carry, xs):
    state, k = carry
    k, subk = jax.random.split(k, 2)
    state, dead_point = algo.step(subk, state)
    return (state, k), dead_point

dead = []

while not state.integrator.logZ_live - state.integrator.logZ < termination:
    (state, rng_key), dead_info = one_step((state, rng_key), None)
    dead.append(dead_info)

The JIT boundary is one_step (one outer iteration = num_delete deaths
processed in one batch). The outer while loop is plain Python. This
means both checkpointing and on-the-fly visualization can be implemented
entirely inside af.NSS._fit — no upstream PR needed.

[... full prompt as authored, truncated here for brevity in the GitHub-rendered issue. See PyAutoPrompt/issued/nss_checkpointing_and_visualization.md for the verbatim source.]
