Skip to content

Sparse mode performance and SparseHist input dispatch#129

Open
bendavid wants to merge 12 commits intoWMass:mainfrom
bendavid:sparsedev4
Open

Sparse mode performance and SparseHist input dispatch#129
bendavid wants to merge 12 commits intoWMass:mainfrom
bendavid:sparsedev4

Conversation

@bendavid
Copy link
Copy Markdown
Collaborator

@bendavid bendavid commented Apr 9, 2026

Adds support for sparse histogram input in the TensorWriter.

Significant performance optimizations for sparse mode in
both the TensorWriter and in the Fitter.

Performance optimizations should also give a factor of ~2
improvement for dense mode in the Fitter for large models.

Note that this depends on WMass/wums#25

Two related groups of commits:

Group A — TensorWriter sparse dispatch (7 commits)

  • add option to treat input systematic histograms as difference with respect to nominal
  • add test for sparse mode
  • Support scipy sparse array inputs in TensorWriter and add as_difference option
  • Add multi-systematic dispatch in add_systematic and use wums.SparseHist
  • Add external likelihood term (gradient + hessian) support
  • Add efficient SparseHist multi-systematic dispatch in TensorWriter
  • Speed up TensorWriter for large multi-systematic SparseHist workloads

Group B — Sparse fast path performance (5 commits)

Up to ~20× HVP speedup on the jpsi calibration tensor (76800 bins,
108334 params, 62M-nnz logk, 329M-nnz external sparse Hessian):
HVP 6380 → 320 ms, loss+grad 3010 → 160 ms.

  • inputdata, parsing: prep for sparse fast path with CSR matvec
    — canonicalize sparse index ordering at load time, pre-build a
    CSRSparseMatrix view of logk, add --hvpMethod and
    --noJitCompile CLI options.

  • fitter: dynamic loss/grad/HVP wrappers with jit_compile +
    hvpMethod
    — replace class-level @tf.function decorators with
    instance-level wrappers built dynamically in _make_tf_functions,
    so jit and HVP autodiff mode can be controlled per-fit. Note that
    fwdrev HVP is intentionally never jit-compiled because
    tf.autodiff.ForwardAccumulator does not propagate JVPs through
    XLA-compiled subgraphs.

  • fitter: sparse fast path uses CSR matmul, no dense
    [nbins, nproc]
    — reformulate the sparse branch of
    _compute_yields_noBBB to use tf_sparse_csr.matmul for the
    inner contraction logk @ theta (~8× faster per call than the
    equivalent gather + segment_sum) and never materialize the dense
    [nbins, nproc] grid in the NLL/grad/HVP path. Also forces
    jit_compile=False in sparse mode (CSR matmul has no XLA kernel)
    and falls back to revrev when fwdrev is requested in sparse mode.

  • fitter: external sparse Hessian via CSR matmul — switch the
    external sparse-Hessian likelihood term to use CSR matmul. The
    registered gradient of sm.matmul is itself a single
    sm.matmul, so reverse-over-reverse autodiff no longer
    rematerializes a 2D gather/scatter chain in the second-order
    tape. On the jpsi 329M-nnz prefit Hessian this was the dominant
    HVP cost.

  • rabbit_fit, setup.sh: enable XLA multi-threaded Eigen on CPU
    set XLA_FLAGS=--xla_cpu_multi_thread_eigen=true so XLA's CPU
    emitter uses Eigen's multi-threaded routines for the dense
    matmuls jit_compile=True generates. ~1.3× speedup on dense
    large-model HVP/loss+grad on a many-core system. Set both in
    setup.sh (for sourced shells) and at the very top of
    bin/rabbit_fit.py before any TF import (for direct invocation).

bendavid and others added 12 commits April 9, 2026 08:16
…ce option

Add `as_difference` parameter to `add_systematic` to interpret input histograms
as differences from nominal. Add full scipy sparse array support for `add_process`
and `add_systematic`: in sparse mode, norm is stored as flat CSR and logk is
computed only at nonzero positions, avoiding full-size dense intermediates.
Extend test_sparse_fit.py to cover all modes including scipy sparse inputs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
add_systematic now detects extra axes in the input histogram beyond the
channel axes (or via an explicit syst_axes argument) and books one
systematic per bin combination on those extra axes, with auto-generated
names from the bin labels. Works for hist inputs as well as for SparseHist
inputs from wums, in both dense and sparse TensorWriter modes.

The local SparseHist implementation has been moved to wums.sparse_hist and
is re-exported here for convenience. SparseHist now always uses the with-flow
layout internally, and the writer extracts either the with-flow or no-flow
representation depending on the channel's flow setting.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TensorWriter.add_external_likelihood_term accepts a 1D hist for the
gradient and a 2D hist (or wums.SparseHist) for the hessian, both indexed
by hist.axis.StrCategory axes whose bin labels identify the parameters.
Both grad and hess (when provided together) must use the same parameter
list in the same order; the matrix is indexed by a single parameter list.
Multiple terms can be added with distinct names. Sparse hessians via
SparseHist preserve sparsity through the writer and the fit.

The terms are serialized under an external_terms HDF5 group, loaded back
in FitInputData, and resolved against the full fit parameter list (POIs +
systs) at Fitter init. Fitter._compute_external_nll adds an additive
g^T x_sub + 0.5 x_sub^T H x_sub contribution to the NLL, fully
differentiable through TF autodiff so all existing loss_val_grad and
hessian methods pick it up automatically.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The generic _get_systematic_slices loop calls h[slice_dict] once per
combination on the extra (systematic) axes, which for SparseHist input
is O(nnz) per slice and prohibitively slow when there are many extra
bins (e.g. ~108k corparms over a ~31M nnz SparseHist would take hours).

Add a fast path that pre-extracts the with-flow flat representation
once, computes a linear systematic index from the extra-axis
coordinates, sorts globally, and then yields contiguous per-bin runs.
Empty combinations yield an empty SparseHist over the kept axes so the
caller can still book the corresponding systematic name (allowing it
to be constrained by an external term even when the template variation
is identically zero). This is O(nnz log nnz) total instead of O(nnz)
per slice, and supports both single and asymmetric (up/down) inputs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Several independent optimizations to the writer + write() path. On a
realistic 2-channel jpsi calibration tensor with ~108k corparm
systematics and a 330M-nnz external hessian, total wall time drops
from ~4m30s to ~1m13s.

1. Vectorized SparseHist multi-syst dispatch in add_systematic.
   New _add_systematics_sparsehist_batched does all per-entry math
   (channel flat index, norm lookup, sign-flip-protected logk) once
   over the full ~25M-entry array, partitions by linear systematic
   index via a single argsort + searchsorted, and bulk-inserts
   per-syst (indices, values) directly into dict_logkavg /
   dict_logkavg_indices. Empty bin combinations still get an entry
   and a corresponding book_systematic call so they appear in the
   fit parameter list and can be constrained externally. Triggered
   when the input is a single SparseHist with extra axes plus
   mirror=True, as_difference=True, no add_to_data_covariance.
   Per-channel booking goes from ~93s to ~9s.

2. Pre-allocate sparse assembly buffers in write(). The previous
   loop grew norm_sparse_* and logk_sparse_* via np.ndarray.resize
   once per (channel, process, syst), which is O(N^2) total because
   each resize allocates a new buffer and copies all elements. A
   quick first pass over the dict structures now computes the total
   nnz so the buffers can be allocated once and filled in place.

3. Replace list.index() with a dict in get_groups,
   get_constraintweights, get_noiidxs. The old code did
   systs.index(name) once per group member, giving O(nsysts*nmembers)
   behaviour: with 108k systs all in a single corparms group this was
   the dominant cost of write(), eating ~75 seconds.

4. Skip the unnecessary to_flat_csr sort in
   add_external_likelihood_term. For SparseHist hess input, access
   _flat_indices/_values directly and recover (rows, cols) via
   np.divmod, instead of going through to_flat_csr(flow=False) which
   sorts ~330M entries we then never read in order. ~30s saved.

5. Switch h5py compression from gzip to Blosc2 LZ4 in
   h5pyutils_write. ~5x faster on integer arrays at slightly better
   compression ratios. h5pyutils_read imports hdf5plugin so the
   filter is registered for read-back.

6. Add a compress=True parameter to writeFlatInChunks and have
   writeSparse pass compress=False for the values payload of an
   explicitly sparse tensor. Densely packed nonzero floats from real
   physics tensors compress only ~4% at 5x the write cost, so the
   compression is pure overhead there. Index buffers continue to
   compress (~10x ratio with negligible overhead).

Also adds a regression test in test_multi_systematic.py that
constructs a multi-syst SparseHist and asserts the batched fast path
produces bit-identical hnorm/hlogk to per-syst manual booking, with
log_normal + as_difference=True and entries that exercise the
logkepsilon sign-flip fallback.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Three preparatory changes that the fitter changes in following
commits will rely on:

  * inputdata.py: in sparse mode, call tf.sparse.reorder on norm and
    logk at load time to canonicalize their indices into row-major
    order. The fitter sparse fast path reduces nonzero entries via
    row-keyed reductions, which want coalesced memory access on the
    sorted indices.

  * inputdata.py: pre-build a tf.linalg.sparse.CSRSparseMatrix view
    of logk so the fitter can use sm.matmul (a multi-threaded CSR
    kernel) for the inner contraction logk @ theta. SparseMatrixMatMul
    has no XLA kernel, so any tf.function calling it must be built
    with jit_compile=False; the fitter handles this in sparse mode.

  * parsing.py: add --hvpMethod {revrev,fwdrev} to choose the
    autodiff mode for the Hessian-vector product, and --noJitCompile
    to disable XLA jit_compile (on by default in dense mode).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace the class-level @tf.function decorators on loss_val,
loss_val_grad, and loss_val_grad_hessp_{fwdrev,revrev} with
instance-level wrappers built dynamically in _make_tf_functions()
at construction time. This lets jit_compile and the HVP autodiff
mode be controlled per-fit via --jitCompile / --hvpMethod without
class-level redefinition.

  * --jitCompile (on by default): wraps loss/grad and revrev HVP
    with tf.function(jit_compile=True). The fwdrev HVP wrapper is
    intentionally NOT jit-compiled because tf.autodiff.Forward-
    Accumulator does not propagate JVPs through XLA-compiled
    subgraphs (the JVP comes back as zero), regardless of inner/
    outer placement of jit_compile.

  * --hvpMethod {revrev,fwdrev}: selects which underlying HVP
    wrapper is bound to self.loss_val_grad_hessp.

The dynamic wrappers are also stripped and rebuilt in __deepcopy__,
since the FuncGraph state held by an already-traced tf.function
cannot be deepcopy'd. _compute_loss is collapsed to a one-liner
since its only job is to dispatch to _compute_nll.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Reformulate the sparse branch of _compute_yields_noBBB so that the
NLL/grad/HVP path never materializes the dense [nbinsfull, nproc]
intermediate, and uses tf.linalg.sparse's CSR SparseMatrixMatMul
for the dominant inner contraction logk @ theta. The CSR kernel is
multi-threaded and ~8x faster per call than the equivalent
gather + unsorted_segment_sum that the previous form lowered to
under TF on CPU.

Changes:

  * _compute_yields_noBBB takes a new compute_norm flag. The dense
    [nbinsfull, nproc] normcentral grid is only built when an
    external caller actually wants per-process yields, or when
    binByBinStat "full" mode needs them for the analytic beta
    solution. The NLL/grad/HVP path passes compute_norm=False.

  * Sparse branch: replace tf.sparse.sparse_dense_matmul(logk, ...)
    with tf_sparse_csr.matmul(logk_csr, ...) on the pre-built CSR
    view from inputdata.py.

  * Sparse branch: collapse to per-bin yields via
    tf.math.unsorted_segment_sum on the modified sparse values
    keyed by bin index, equivalent to but cheaper than
    tf.sparse.reduce_sum at this scale.

  * _compute_yields_with_beta plumbs need_norm correctly so the
    bbb-lite path doesn't pay for the dense materialization.

  * _expected_yield_noBBB explicitly passes compute_norm=False.

  * _make_tf_functions: SparseMatrixMatMul has no XLA kernel, so
    force jit_compile=False on all wrappers in sparse mode
    regardless of the user's --jitCompile setting.

  * _make_tf_functions: tf.autodiff.ForwardAccumulator cannot
    trace tangents through SparseMatrixMatMul (no JVP rule for the
    CSR variant), so when --hvpMethod=fwdrev is requested in sparse
    mode, fall back to revrev with a warning.

Profile on the jpsi calibration tensor (76800 bins, 108334 params,
62M-nnz logk): HVP per call drops from ~6400 ms to ~320 ms (~20x
speedup), loss+grad from ~3000 ms to ~160 ms.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Switch the external sparse-Hessian likelihood term to use
tf.linalg.sparse's CSR SparseMatrixMatMul instead of an element-wise
gather-based 0.5 x^T H x form. The CSR matmul kernel is multi-
threaded, and crucially its registered gradient is itself a single
sm.matmul call, so reverse-over-reverse autodiff no longer
rematerializes a 2D gather/scatter chain in the second-order tape.
On large external-Hessian problems this was the dominant HVP cost.

Changes:

  * Fitter.__init__ external_terms loop: replace the "hess_sparse"
    (rows, cols, vals) tuple with a "hess_csr" CSRSparseMatrix view
    of the canonically-sorted SparseTensor, built once per term.

  * _compute_external_nll: dispatch on "hess_csr" instead of
    "hess_sparse" and compute 0.5 * x_sub^T (H @ x_sub) via
    tf_sparse_csr.matmul.

Profile on the jpsi calibration tensor (329M-nnz prefit external
Hessian on 108332 of the 108334 fit parameters): the closed-form
external HVP path that previously dominated the second-order tape
collapses to a single CSR matvec per HVP call, contributing
negligibly to the per-call cost.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Set XLA_FLAGS=--xla_cpu_multi_thread_eigen=true so XLA's CPU emitter
uses Eigen's multi-threaded routines for the dense linear-algebra
ops generated by jit_compile=True. This is a free win on dense
fits with no downside on sparse mode (where the dominant ops have
no parallel CPU kernel anyway). Measured ~1.3x speedup on dense
large-model HVP and loss+grad on a many-core system:

  default                                 HVP 51.1 ms  L+G 31.2 ms
  --xla_cpu_multi_thread_eigen=true       HVP 39.1 ms  L+G 23.0 ms

The flag is set in two places:

  * setup.sh: exported when users source the rabbit setup script.
    Append-only so any user-set XLA_FLAGS survive.

  * bin/rabbit_fit.py: also set programmatically at the very top
    of the script (before any TF import) so users who launch
    rabbit_fit.py directly without sourcing setup.sh still get
    the speedup. Same append-only logic.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant