Sparse mode performance and SparseHist input dispatch #129
Open
bendavid wants to merge 12 commits into WMass:main
Conversation
Support scipy sparse array inputs in TensorWriter and add as_difference option

Add `as_difference` parameter to `add_systematic` to interpret input histograms as differences from nominal. Add full scipy sparse array support for `add_process` and `add_systematic`: in sparse mode, norm is stored as flat CSR and logk is computed only at nonzero positions, avoiding full-size dense intermediates. Extend test_sparse_fit.py to cover all modes including scipy sparse inputs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
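As a rough illustration of the nonzero-only logk computation, here is a minimal scipy sketch. The helper name `logk_nonzero` and the assumption that the varied histogram shares the nominal's sparsity pattern are mine, not the PR's actual code, which handles the general case.

```python
import numpy as np
import scipy.sparse as sp

def logk_nonzero(nom: sp.csr_matrix, syst: sp.csr_matrix) -> sp.csr_matrix:
    """Compute logk = log(syst/nom) only at nominal's nonzero positions.

    Illustrative sketch: assumes syst shares nom's sparsity pattern, so the
    ratio can be taken directly on the .data arrays without ever building a
    full-size dense intermediate.
    """
    out = nom.copy()
    out.data = np.log(syst.data / nom.data)
    return out

nom = sp.csr_matrix(np.array([[2.0, 0.0], [0.0, 4.0]]))
syst = sp.csr_matrix(np.array([[4.0, 0.0], [0.0, 4.0]]))
logk = logk_nonzero(nom, syst)
# logk keeps nom's sparsity pattern: log(2) at (0, 0), 0 at (1, 1)
```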
Add multi-systematic dispatch in add_systematic and use wums.SparseHist

add_systematic now detects extra axes in the input histogram beyond the channel axes (or via an explicit syst_axes argument) and books one systematic per bin combination on those extra axes, with auto-generated names from the bin labels. This works for hist inputs as well as for SparseHist inputs from wums, in both dense and sparse TensorWriter modes.

The local SparseHist implementation has been moved to wums.sparse_hist and is re-exported here for convenience. SparseHist now always uses the with-flow layout internally, and the writer extracts either the with-flow or no-flow representation depending on the channel's flow setting.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
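The "one systematic per bin combination with auto-generated names" booking can be sketched with itertools; the helper `syst_names` and the `(axis_name, bin_labels)` pair representation are hypothetical, standing in for the actual extra-axis handling in add_systematic.

```python
import itertools

def syst_names(base, extra_axes):
    """Hypothetical sketch: yield one systematic name per bin combination
    on the extra (non-channel) axes, built from the bin labels.
    extra_axes is a list of (axis_name, bin_labels) pairs."""
    for combo in itertools.product(*(labels for _, labels in extra_axes)):
        yield "_".join([base, *map(str, combo)])

names = list(syst_names("corparm", [("eta", ["etaLow", "etaHigh"]), ("idx", [0, 1])]))
# four names, one per (eta, idx) bin combination
```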
Add external likelihood term (gradient + hessian) support

TensorWriter.add_external_likelihood_term accepts a 1D hist for the gradient and a 2D hist (or wums.SparseHist) for the hessian, both indexed by hist.axis.StrCategory axes whose bin labels identify the parameters. Both grad and hess (when provided together) must use the same parameter list in the same order; the matrix is indexed by a single parameter list. Multiple terms can be added with distinct names. Sparse hessians via SparseHist preserve sparsity through the writer and the fit.

The terms are serialized under an external_terms HDF5 group, loaded back in FitInputData, and resolved against the full fit parameter list (POIs + systs) at Fitter init. Fitter._compute_external_nll adds an additive g^T x_sub + 0.5 x_sub^T H x_sub contribution to the NLL, fully differentiable through TF autodiff so all existing loss_val_grad and hessian methods pick it up automatically.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
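The additive term itself is just a quadratic form. A minimal numpy/scipy sketch (the real implementation evaluates this in TF so it is differentiable through autodiff; `external_nll` is an illustrative name):

```python
import numpy as np
import scipy.sparse as sp

def external_nll(x_sub, grad, hess):
    """Additive external likelihood term g^T x + 0.5 x^T H x over the
    subset of fit parameters the term constrains. Sketch only: the actual
    code runs this inside the TF graph with a sparse CSR hessian."""
    return float(grad @ x_sub + 0.5 * x_sub @ (hess @ x_sub))

g = np.array([1.0, -2.0])
H = sp.csr_matrix(np.array([[2.0, 0.0], [0.0, 4.0]]))
x = np.array([1.0, 1.0])
val = external_nll(x, g, H)
# g.x = -1, 0.5 x.H.x = 3, total 2.0
```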
Add efficient SparseHist multi-systematic dispatch in TensorWriter

The generic _get_systematic_slices loop calls h[slice_dict] once per combination on the extra (systematic) axes, which for SparseHist input is O(nnz) per slice and prohibitively slow when there are many extra bins (e.g. ~108k corparms over a ~31M-nnz SparseHist would take hours).

Add a fast path that pre-extracts the with-flow flat representation once, computes a linear systematic index from the extra-axis coordinates, sorts globally, and then yields contiguous per-bin runs. Empty combinations yield an empty SparseHist over the kept axes so the caller can still book the corresponding systematic name (allowing it to be constrained by an external term even when the template variation is identically zero).

This is O(nnz log nnz) total instead of O(nnz) per slice, and supports both single and asymmetric (up/down) inputs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
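The sort-then-run-split fast path can be sketched in a few lines of numpy. Function name and signature here are illustrative, not the actual _get_systematic_slices interface:

```python
import numpy as np

def iter_syst_runs(syst_idx, flat_idx, values, nsyst):
    """Sketch of the O(nnz log nnz) fast path: one global stable argsort on
    the linear systematic index, then contiguous per-systematic runs found
    via searchsorted. Empty combinations yield empty arrays so the caller
    can still book the corresponding systematic name."""
    order = np.argsort(syst_idx, kind="stable")
    sorted_syst = syst_idx[order]
    starts = np.searchsorted(sorted_syst, np.arange(nsyst), side="left")
    ends = np.searchsorted(sorted_syst, np.arange(nsyst), side="right")
    for k in range(nsyst):
        sel = order[starts[k]:ends[k]]
        yield k, flat_idx[sel], values[sel]

syst = np.array([2, 0, 2, 0])
flat = np.array([10, 11, 12, 13])
vals = np.array([1.0, 2.0, 3.0, 4.0])
runs = {k: (f.tolist(), v.tolist()) for k, f, v in iter_syst_runs(syst, flat, vals, 3)}
# systematic 1 has no entries, so it yields an empty run
```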
Speed up TensorWriter for large multi-systematic SparseHist workloads

Several independent optimizations to the writer + write() path. On a realistic 2-channel jpsi calibration tensor with ~108k corparm systematics and a 330M-nnz external hessian, total wall time drops from ~4m30s to ~1m13s.

1. Vectorized SparseHist multi-syst dispatch in add_systematic. The new _add_systematics_sparsehist_batched does all per-entry math (channel flat index, norm lookup, sign-flip-protected logk) once over the full ~25M-entry array, partitions by linear systematic index via a single argsort + searchsorted, and bulk-inserts per-syst (indices, values) directly into dict_logkavg / dict_logkavg_indices. Empty bin combinations still get an entry and a corresponding book_systematic call so they appear in the fit parameter list and can be constrained externally. Triggered when the input is a single SparseHist with extra axes plus mirror=True, as_difference=True, and no add_to_data_covariance. Per-channel booking goes from ~93s to ~9s.

2. Pre-allocate sparse assembly buffers in write(). The previous loop grew norm_sparse_* and logk_sparse_* via np.ndarray.resize once per (channel, process, syst), which is O(N^2) total because each resize allocates a new buffer and copies all elements. A quick first pass over the dict structures now computes the total nnz so the buffers can be allocated once and filled in place.

3. Replace list.index() with a dict in get_groups, get_constraintweights, get_noiidxs. The old code called systs.index(name) once per group member, giving O(nsysts*nmembers) behaviour: with 108k systs all in a single corparms group this was the dominant cost of write(), eating ~75 seconds.

4. Skip the unnecessary to_flat_csr sort in add_external_likelihood_term. For SparseHist hess input, access _flat_indices/_values directly and recover (rows, cols) via np.divmod, instead of going through to_flat_csr(flow=False), which sorts ~330M entries we then never read in order. ~30s saved.

5. Switch h5py compression from gzip to Blosc2 LZ4 in h5pyutils_write. ~5x faster on integer arrays at slightly better compression ratios. h5pyutils_read imports hdf5plugin so the filter is registered for read-back.

6. Add a compress=True parameter to writeFlatInChunks and have writeSparse pass compress=False for the values payload of an explicitly sparse tensor. Densely packed nonzero floats from real physics tensors compress only ~4% at 5x the write cost, so the compression is pure overhead there. Index buffers continue to compress (~10x ratio with negligible overhead).

Also adds a regression test in test_multi_systematic.py that constructs a multi-syst SparseHist and asserts the batched fast path produces bit-identical hnorm/hlogk to per-syst manual booking, with log_normal + as_difference=True and entries that exercise the logkepsilon sign-flip fallback.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
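Optimizations 2 and 3 above are generic patterns worth a tiny sketch: a first pass to size the buffer once instead of resize-per-append, and a name→index dict built once instead of list.index() per lookup. All names here are illustrative.

```python
import numpy as np

# (3) O(1) name lookup: build the dict once instead of calling
# systs.index(name) per group member, which is O(nsysts) each time
systs = [f"corparm_{i}" for i in range(5)]
syst_pos = {name: i for i, name in enumerate(systs)}

# (2) size the assembly buffer with a quick first pass, then fill in place,
# instead of growing via np.ndarray.resize (O(N^2) total copies)
chunks = [np.arange(3), np.arange(2), np.arange(4)]
total = sum(len(c) for c in chunks)
buf = np.empty(total, dtype=np.int64)
pos = 0
for c in chunks:
    buf[pos:pos + len(c)] = c
    pos += len(c)
```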
Three preparatory changes that the fitter changes in following
commits will rely on:
* inputdata.py: in sparse mode, call tf.sparse.reorder on norm and
logk at load time to canonicalize their indices into row-major
order. The fitter sparse fast path reduces nonzero entries via
row-keyed reductions, which want coalesced memory access on the
sorted indices.
* inputdata.py: pre-build a tf.linalg.sparse.CSRSparseMatrix view
of logk so the fitter can use sm.matmul (a multi-threaded CSR
kernel) for the inner contraction logk @ theta. SparseMatrixMatMul
has no XLA kernel, so any tf.function calling it must be built
with jit_compile=False; the fitter handles this in sparse mode.
* parsing.py: add --hvpMethod {revrev,fwdrev} to choose the
autodiff mode for the Hessian-vector product, and --noJitCompile
to disable XLA jit_compile (on by default in dense mode).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
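The canonicalization step has a direct scipy analogue, which may make the intent clearer: unsorted COO entries are coalesced into row-major CSR, after which matvec traverses memory contiguously. This is only an analogue of tf.sparse.reorder plus the CSRSparseMatrix view, not the TF code itself.

```python
import numpy as np
import scipy.sparse as sp

# unsorted COO entries, as they might arrive from the writer
rows = np.array([2, 0, 1, 0])
cols = np.array([0, 1, 0, 0])
vals = np.array([3.0, 2.0, 4.0, 1.0])
coo = sp.coo_matrix((vals, (rows, cols)), shape=(3, 2))

# tocsr() canonicalizes into row-major sorted order -- the scipy analogue
# of tf.sparse.reorder followed by building the CSR view
csr = coo.tocsr()
theta = np.array([10.0, 1.0])
y = csr @ theta  # CSR matvec now reads indices coalesced
```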
Replace the class-level @tf.function decorators on loss_val,
loss_val_grad, and loss_val_grad_hessp_{fwdrev,revrev} with
instance-level wrappers built dynamically in _make_tf_functions()
at construction time. This lets jit_compile and the HVP autodiff
mode be controlled per-fit via --jitCompile / --hvpMethod without
class-level redefinition.
* --jitCompile (on by default): wraps loss/grad and revrev HVP
with tf.function(jit_compile=True). The fwdrev HVP wrapper is
intentionally NOT jit-compiled because tf.autodiff.ForwardAccumulator
does not propagate JVPs through XLA-compiled
subgraphs (the JVP comes back as zero), regardless of inner/
outer placement of jit_compile.
* --hvpMethod {revrev,fwdrev}: selects which underlying HVP
wrapper is bound to self.loss_val_grad_hessp.
The dynamic wrappers are also stripped and rebuilt in __deepcopy__,
since the FuncGraph state held by an already-traced tf.function
cannot be deepcopy'd. _compute_loss is collapsed to a one-liner
since its only job is to dispatch to _compute_nll.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
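The strip-and-rebuild-in-__deepcopy__ pattern can be shown with a pure-Python stand-in (a plain lambda replaces tf.function here; class and attribute names are illustrative, not the actual Fitter code):

```python
import copy

class Fitter:
    """Pure-Python stand-in: wrappers are built per-instance in
    _make_tf_functions and rebuilt in __deepcopy__, since the FuncGraph
    state of an already-traced tf.function cannot be deepcopy'd."""

    _wrapper_names = ("loss_val",)

    def __init__(self, jit_compile=True):
        self.jit_compile = jit_compile
        self._make_tf_functions()

    def _make_tf_functions(self):
        # real code: tf.function(self._compute_loss, jit_compile=self.jit_compile)
        self.loss_val = lambda x: self._compute_loss(x)

    def _compute_loss(self, x):
        return x * x

    def __deepcopy__(self, memo):
        new = object.__new__(type(self))
        memo[id(self)] = new
        # strip the un-copyable wrappers, deepcopy the rest, then rebuild
        state = {k: v for k, v in self.__dict__.items() if k not in self._wrapper_names}
        new.__dict__.update(copy.deepcopy(state, memo))
        new._make_tf_functions()
        return new

f = Fitter()
g = copy.deepcopy(f)
```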
Reformulate the sparse branch of _compute_yields_noBBB so that the
NLL/grad/HVP path never materializes the dense [nbinsfull, nproc]
intermediate, and uses tf.linalg.sparse's CSR SparseMatrixMatMul
for the dominant inner contraction logk @ theta. The CSR kernel is
multi-threaded and ~8x faster per call than the equivalent
gather + unsorted_segment_sum that the previous form lowered to
under TF on CPU.
Changes:
* _compute_yields_noBBB takes a new compute_norm flag. The dense
[nbinsfull, nproc] normcentral grid is only built when an
external caller actually wants per-process yields, or when
binByBinStat "full" mode needs them for the analytic beta
solution. The NLL/grad/HVP path passes compute_norm=False.
* Sparse branch: replace tf.sparse.sparse_dense_matmul(logk, ...)
with tf_sparse_csr.matmul(logk_csr, ...) on the pre-built CSR
view from inputdata.py.
* Sparse branch: collapse to per-bin yields via
tf.math.unsorted_segment_sum on the modified sparse values
keyed by bin index, equivalent to but cheaper than
tf.sparse.reduce_sum at this scale.
* _compute_yields_with_beta plumbs need_norm correctly so the
bbb-lite path doesn't pay for the dense materialization.
* _expected_yield_noBBB explicitly passes compute_norm=False.
* _make_tf_functions: SparseMatrixMatMul has no XLA kernel, so
force jit_compile=False on all wrappers in sparse mode
regardless of the user's --jitCompile setting.
* _make_tf_functions: tf.autodiff.ForwardAccumulator cannot
trace tangents through SparseMatrixMatMul (no JVP rule for the
CSR variant), so when --hvpMethod=fwdrev is requested in sparse
mode, fall back to revrev with a warning.
Profile on the jpsi calibration tensor (76800 bins, 108334 params,
62M-nnz logk): HVP per call drops from ~6400 ms to ~320 ms (~20x
speedup), loss+grad from ~3000 ms to ~160 ms.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
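The reformulated sparse branch can be sketched in numpy: logk is CSR with one row per nonzero (bin, process) entry and one column per systematic, the inner contraction logk @ theta modifies each nonzero norm value, and per-bin yields come from a segment sum keyed by bin index (np.bincount standing in for tf.math.unsorted_segment_sum). Function and variable names are illustrative; no dense [nbinsfull, nproc] grid is built.

```python
import numpy as np
import scipy.sparse as sp

def sparse_yields(norm_vals, bin_idx, logk, theta, nbins):
    """Sketch of the NLL-path yields: modify the flat nonzero norm entries
    by exp(logk @ theta), then collapse to per-bin yields with a segment
    sum keyed by bin index."""
    snorm = norm_vals * np.exp(logk @ theta)
    return np.bincount(bin_idx, weights=snorm, minlength=nbins)

# two bins, three nonzero (bin, proc) entries, one systematic
norm_vals = np.array([1.0, 2.0, 3.0])
bin_idx = np.array([0, 0, 1])
logk = sp.csr_matrix(np.array([[0.5], [0.0], [0.0]]))
y0 = sparse_yields(norm_vals, bin_idx, logk, np.array([0.0]), 2)
# at theta = 0 the yields are just the per-bin norm sums: [3, 3]
y1 = sparse_yields(norm_vals, bin_idx, logk, np.array([2.0]), 2)
```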
Switch the external sparse-Hessian likelihood term to use
tf.linalg.sparse's CSR SparseMatrixMatMul instead of an element-wise
gather-based 0.5 x^T H x form. The CSR matmul kernel is multi-
threaded, and crucially its registered gradient is itself a single
sm.matmul call, so reverse-over-reverse autodiff no longer
rematerializes a 2D gather/scatter chain in the second-order tape.
On large external-Hessian problems this was the dominant HVP cost.
Changes:
* Fitter.__init__ external_terms loop: replace the "hess_sparse"
(rows, cols, vals) tuple with a "hess_csr" CSRSparseMatrix view
of the canonically-sorted SparseTensor, built once per term.
* _compute_external_nll: dispatch on "hess_csr" instead of
"hess_sparse" and compute 0.5 * x_sub^T (H @ x_sub) via
tf_sparse_csr.matmul.
Profile on the jpsi calibration tensor (329M-nnz prefit external
Hessian on 108332 of the 108334 fit parameters): the closed-form
external HVP path that previously dominated the second-order tape
collapses to a single CSR matvec per HVP call, contributing
negligibly to the per-call cost.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Set XLA_FLAGS=--xla_cpu_multi_thread_eigen=true so XLA's CPU emitter
uses Eigen's multi-threaded routines for the dense linear-algebra
ops generated by jit_compile=True. This is a free win on dense
fits with no downside on sparse mode (where the dominant ops have
no parallel CPU kernel anyway). Measured ~1.3x speedup on dense
large-model HVP and loss+grad on a many-core system:
default: HVP 51.1 ms, loss+grad 31.2 ms
--xla_cpu_multi_thread_eigen=true: HVP 39.1 ms, loss+grad 23.0 ms
The flag is set in two places:
* setup.sh: exported when users source the rabbit setup script.
Append-only so any user-set XLA_FLAGS survive.
* bin/rabbit_fit.py: also set programmatically at the very top
of the script (before any TF import) so users who launch
rabbit_fit.py directly without sourcing setup.sh still get
the speedup. Same append-only logic.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
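The append-only logic for rabbit_fit.py is simple but easy to get wrong; a sketch (the helper name `append_xla_flag` is mine), which must run before any TF import:

```python
import os

def append_xla_flag(flag="--xla_cpu_multi_thread_eigen=true"):
    """Append-only: preserve any XLA_FLAGS the user already set, and avoid
    duplicating the flag on repeated calls."""
    existing = os.environ.get("XLA_FLAGS", "")
    if flag not in existing.split():
        os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
append_xla_flag()
append_xla_flag()  # idempotent: second call leaves XLA_FLAGS unchanged
```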
Adds support for sparse histogram input in the TensorWriter.
Significant performance optimizations for sparse mode in
both the TensorWriter and in the Fitter.
Performance optimizations should also give a factor of ~2
improvement for dense mode in the Fitter for large models.
Note that this depends on WMass/wums#25
Two related groups of commits:
Group A — TensorWriter sparse dispatch (7 commits)
* add option to treat input systematic histograms as difference with respect to nominal
* add test for sparse mode
* Support scipy sparse array inputs in TensorWriter and add as_difference option
* Add multi-systematic dispatch in add_systematic and use wums.SparseHist
* Add external likelihood term (gradient + hessian) support
* Add efficient SparseHist multi-systematic dispatch in TensorWriter
* Speed up TensorWriter for large multi-systematic SparseHist workloads

Group B — Sparse fast path performance (5 commits)
Up to ~20× HVP speedup on the jpsi calibration tensor (76800 bins,
108334 params, 62M-nnz logk, 329M-nnz external sparse Hessian):
HVP 6380 → 320 ms, loss+grad 3010 → 160 ms.
* inputdata, parsing: prep for sparse fast path with CSR matvec — canonicalize sparse index ordering at load time, pre-build a CSRSparseMatrix view of logk, add --hvpMethod and --noJitCompile CLI options.
* fitter: dynamic loss/grad/HVP wrappers with jit_compile + hvpMethod — replace class-level @tf.function decorators with instance-level wrappers built dynamically in _make_tf_functions, so jit and HVP autodiff mode can be controlled per-fit. Note that fwdrev HVP is intentionally never jit-compiled because tf.autodiff.ForwardAccumulator does not propagate JVPs through XLA-compiled subgraphs.
* fitter: sparse fast path uses CSR matmul, no dense [nbins, nproc] — reformulate the sparse branch of _compute_yields_noBBB to use tf_sparse_csr.matmul for the inner contraction logk @ theta (~8× faster per call than the equivalent gather + segment_sum) and never materialize the dense [nbins, nproc] grid in the NLL/grad/HVP path. Also forces jit_compile=False in sparse mode (CSR matmul has no XLA kernel) and falls back to revrev when fwdrev is requested in sparse mode.
* fitter: external sparse Hessian via CSR matmul — switch the external sparse-Hessian likelihood term to use CSR matmul. The registered gradient of sm.matmul is itself a single sm.matmul, so reverse-over-reverse autodiff no longer rematerializes a 2D gather/scatter chain in the second-order tape. On the jpsi 329M-nnz prefit Hessian this was the dominant HVP cost.
* rabbit_fit, setup.sh: enable XLA multi-threaded Eigen on CPU — set XLA_FLAGS=--xla_cpu_multi_thread_eigen=true so XLA's CPU emitter uses Eigen's multi-threaded routines for the dense matmuls jit_compile=True generates. ~1.3× speedup on dense large-model HVP/loss+grad on a many-core system. Set both in setup.sh (for sourced shells) and at the very top of bin/rabbit_fit.py before any TF import (for direct invocation).