sc/kernels: apply Owen scramble at full stoc_len too#11
Conversation
When `stoc_len == 2**sc_prec`, `_prepare_rng_prefix` previously skipped
the per-dim XOR scramble. The comment claimed scrambling was "not
needed" at full length because the marginal `count(r<v)=v` over a
Sobol-N permutation is invariant under XOR.
That reasoning only covers the marginal. The enable-signal matmul reads
joint counts `|{t: rng_a[d,t]<ba AND rng_b[d,t]<bb}|`, which depend on
the per-d (rng_a, rng_b) *trajectory*, not just marginals. With the
default `make_sobol_simple_config` broadcasting the same Sobol-Q/Sobol-K
pair across all D dims, skipping the scramble at full length left every
dim with an identical joint trajectory. SC noise across D then
accumulated as a single biased estimator instead of averaging across
independent estimators.
Effect on Llama-3.1-8B-Instruct PPL (wikitext-2 test, ctx=1024 stride=512,
per_row, sc_prec=8):
FP16 6.7711 x1.000
INT8 per_row det. 6.9328 x1.024 (deterministic floor)
SC sl=128 (Owen) 7.1771 x1.060
SC sl=256 (no Owen) 7.9383 x1.172 <- worse than INT8 floor
After this change SC at sl=256 also goes through `_owen_scramble`, giving
each dim a distinct XOR mask and recovering cross-D averaging. The
prefix-vs-full distinction is no longer load-bearing, so the `is_prefix`
guard is removed entirely.
Owen scramble itself is unchanged; only the gate is widened.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
This PR updates the SC enable-table RNG preparation logic so that the per-dimension Owen XOR scramble is applied even when using the full-length Sobol stream (stoc_len == 2**sc_prec). This aims to reduce cross-dimension correlation in joint (rng_a, rng_b) trajectories, improving noise averaging behavior across D dimensions.
Changes:
- Always apply
_owen_scramble()in_prepare_rng_prefix()whengrid_levels == 2**sc_prec, including for full-length streams. - Remove the previous “no scramble needed at full length” special-casing and replace it with updated rationale in comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| prefix = rng[:, :stoc_len].contiguous() if stoc_len < rng.shape[1] else rng | ||
| if grid_levels == base_levels: | ||
| # Fixed-level path: if we're truncating a longer Sobol sequence, apply | ||
| # Owen scramble to break the prefix stratification artifact. When the | ||
| # sequence is used in full (non-truncated), no scramble is needed. | ||
| if is_prefix: | ||
| return _owen_scramble(prefix, base_levels) | ||
| return prefix | ||
| # Per-dim Owen XOR — even at full length, this decorrelates the joint | ||
| # (rng_a, rng_b) trajectory across dimensions; without it, all D dims | ||
| # share the same joint, and SC noise accumulates instead of averaging. | ||
| return _owen_scramble(prefix, base_levels) |
|
Per Allen's recommendation, dropping this implementation approach. Closing the PR; remote branch |
Summary
_prepare_rng_prefixpreviously skipped the per-dim XOR scramble whenstoc_len == 2**sc_prec(full-length Sobol), with the comment "no scramble is needed". That reasoning only holds for the marginal countcount(r<v)=v, which is invariant under XOR over a Sobol permutation.The enable-signal matmul actually reads joint counts
|{t : rng_a[d,t]<ba AND rng_b[d,t]<bb}|which depend on the per-d trajectory, not just marginals. Withmake_sobol_simple_configbroadcasting the same Sobol-Q/Sobol-K pair across allDdims, skipping the scramble at full length left every dim with an identical joint trajectory, so SC noise acrossDaccumulated instead of averaging out.Effect
Llama-3.1-8B-Instruct PPL on wikitext-2 test (ctx=1024, stride=512, per_row, sc_prec=8):
sl=128(Owen applied)sl=256(no Owen, prior)SC
sl=256was worse than its own deterministic INT8 floor. After this changesl=256also runs through Owen scramble; validation run pending.Why this is safe
_owen_scrambleonly depends onprefix.shape[0](D); it doesn't care whether the input is a prefix or the full sequence.[0, base_levels), so each per-d sequence is still a permutation of[0, base_levels). Marginals (and thereforek_table) are unchanged.grid_levels != base_levels) is untouched.Status
Proposed fix. Owen-at-full-length validation run (
SC sl=256, full wikitext-2) is currently in flight; results will be posted in a follow-up comment.Test plan
sl=256per_row PPL drops from x1.172 toward x1.024 (INT8 floor) or bettersl=128per_row PPL unchanged (it already went through Owen)sl>=96(sanity)🤖 Generated with Claude Code