From 52e44f5af7ce50d162804e64744a85c4c0c16103 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 16:07:14 +0000 Subject: [PATCH 1/3] =?UTF-8?q?feat(codec):=20PR-X12=20v1=20=E2=80=94=20?= =?UTF-8?q?=CE=BB-RDO=20mode=20selection=20+=20rANS=20entropy=20coder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Completes the x265-shaped cognitive codec to a usable v1 by adding the two stages the board tracked as missing (RDO + ANS): - `codec::rdo` (A6) — λ-rate-distortion mode selection: scores all four modes by `(rate << 8) + λ_q8·distortion` and picks the minimum. The soft, cost-weighted generalization of `predict_intra`'s hard tree. Uses an integer fixed-point λ (λ × 256) rather than the design's f32, for deterministic, cross-platform-bit-exact decisions consistent with the substrate's no-float discipline. - `codec::ans` (A7) — static-table rANS over the 4-symbol mode alphabet (Skip/Merge/Delta/Escape). Self-contained stream (count + normalized freq table + payload); bit-exact round-trip. Chosen over CABAC per the design's open-Q1 ruling. Grounded against the existing integer-only LeafCu/CellMode/MergeDir and the 2/3/3/6-byte wire format; transform (A4) stays deferred to v2 per design Q2, stream (A8) is the remaining follow-on. 81 lib tests + 20 doctests green; clippy -D warnings clean. 
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- .claude/blackboard.md | 6 +- src/hpc/codec/ans.rs | 527 ++++++++++++++++++++++++++++++++++++++++++ src/hpc/codec/mod.rs | 15 +- src/hpc/codec/rdo.rs | 464 +++++++++++++++++++++++++++++++++++++ 4 files changed, 1007 insertions(+), 5 deletions(-) create mode 100644 src/hpc/codec/ans.rs create mode 100644 src/hpc/codec/rdo.rs diff --git a/.claude/blackboard.md b/.claude/blackboard.md index 250894b1..8e1a182b 100644 --- a/.claude/blackboard.md +++ b/.claude/blackboard.md @@ -34,8 +34,10 @@ ## Consolidation-sprint debt (PR-X program; ground-truthed `ls src/hpc/` 2026-05-27) > Shipped-state vs `pr-master-consolidation.md`. Landed: ✅ **PR-X10** `linalg/`, > ✅ **PR-X11** `pillar/`, ✅ **PR-X13** `ogit_bridge/`, ✅ **PR-X3** `blocked_grid/`. -- **PR-X12 codec ⚠️ PARTIAL** — `src/hpc/codec/` has `ctu/mode/predict/mod` only; - **RDO + ANS entropy stages missing**. Doc-canon merged (#198/#205), kernel half-built. +- **PR-X12 codec ⚠️ v1 NEAR-COMPLETE** — `ctu/mode/predict` + now **`rdo` (A6, λ-RDO, + integer fixed-point λ_q8 — no float) + `ans` (A7, static-table rANS over the 4-symbol + mode alphabet, bit-exact round-trip)**. Remaining: `transform` (A4, deferred to v2 per + design Q2) + `stream` (A8, framing over `ans`). 81 lib + 20 doctests green, clippy clean. - **PR-X4 splat4d ❌ OUTSTANDING** — no `src/hpc/splat4d/`. Unbuilt. - **PR-X9 cognitive ❌ OUTSTANDING** — no `src/hpc/cognitive/`. Unbuilt; must **consume** `lance-graph-contract::splat::CamPlaneSplat` (q8), never redefine it (contract is sacred). diff --git a/src/hpc/codec/ans.rs b/src/hpc/codec/ans.rs new file mode 100644 index 00000000..0ae57c9b --- /dev/null +++ b/src/hpc/codec/ans.rs @@ -0,0 +1,527 @@ +//! rANS entropy coder over the 4-symbol mode alphabet (PR-X12 A7). +//! +//! The codec's mode-decision pass ([`super::predict`] / [`super::rdo`]) +//! emits one [`CellMode`] tag per cell. Those tags are *not* uniformly +//! 
distributed — coherent cognitive state is dominated by `Skip` +//! (~70%) with a long tail of `Merge`/`Delta`/`Escape` (see the design +//! doc's compression target). A 2-bit fixed encoding wastes the +//! redundancy; an entropy coder spends `-log2(p)` bits per tag instead. +//! +//! This module ships a **range Asymmetric Numeral Systems** (rANS) +//! coder, chosen over CABAC per the design's open-Q1 ruling: a single +//! multiply + table lookup per symbol (cache-friendly), no per-bit +//! context-state branches, and within ~0.5% of CABAC's ratio on typical +//! streams (same ANS family as zstd's FSE, which is the table-driven tANS variant). It encodes/decodes the **mode-tag +//! stream only** — the per-mode payloads (`basin_idx`, `delta`, +//! `escape_idx`) are packed separately by [`super::mode`]. +//! +//! # Static per-block frequency table (not backward-adaptive) +//! +//! v1 uses a **static** frequency table computed from the block's own +//! mode histogram and written into the stream header. This is bit-exact +//! and free of the adaptation-replay hazard that backward-adaptive rANS +//! has (the encoder processes symbols in reverse while the decoder +//! processes them forward, so an adaptive model would have to be +//! replayed in two directions). "Per-CTU" granularity (design open-Q5) +//! is achieved by calling [`encode_modes`] once per CTU's tag stream. +//! Backward-adaptive coding is a documented follow-on. +//! +//! # Stream layout (`encode_modes` / `decode_modes`) +//! +//! ```text +//! bytes [0..4) symbol count n (u32, little-endian) +//! bytes [4..12) normalized freq table (4 × u16, little-endian) +//! bytes [12..) rANS payload (state-flushed bytes) +//! ``` +//! +//! The freq table entries sum to [`RANS_M`] (= 4096); a symbol that +//! never appears has frequency 0 and is never encoded. The payload is +//! self-delimiting given `n` and the table — the decoder reads exactly +//! the bytes the encoder produced. +//! +//! # Numerics +//! +//!
State is `u32`, kept in the normalized interval `[RANS_BYTE_L, +//! RANS_BYTE_L << 8)`. With [`RANS_SCALE_BITS`] = 12 and +//! [`RANS_BYTE_L`] = `1 << 23`, every intermediate (`x / f`, +//! `(x / f) << 12`, `x_max = (RANS_BYTE_L >> 12 << 8) * f`) stays within +//! `u32` for all `f ∈ [1, 4096]`. No `unsafe`, no float. +//! +//! # What this module does NOT do +//! +//! - **Payload coding** — `basin_idx` / `delta` / `escape_idx` bytes are +//! [`super::mode`]'s responsibility; A7 codes only the mode tags. +//! - **Framing** (frame headers, CTU markers) — PR-X12 A8 `stream.rs`. +//! - **Backward-adaptive frequency models** — documented follow-on. + +use super::ctu::CellMode; + +// ════════════════════════════════════════════════════════════════════ +// Constants +// ════════════════════════════════════════════════════════════════════ + +/// Frequency-table precision in bits. The table's frequencies sum to +/// `1 << RANS_SCALE_BITS`. 12 bits (4096) is ample for a 4-symbol +/// alphabet and keeps every rANS intermediate inside `u32`. +pub const RANS_SCALE_BITS: u32 = 12; + +/// Total of all symbol frequencies (`1 << RANS_SCALE_BITS` = 4096). +pub const RANS_M: u32 = 1 << RANS_SCALE_BITS; + +/// Lower bound of the normalized state interval. The state `x` is kept +/// in `[RANS_BYTE_L, RANS_BYTE_L << 8)` via byte-wise renormalization. +pub const RANS_BYTE_L: u32 = 1 << 23; + +/// Fixed 4-symbol alphabet size (`Skip`/`Merge`/`Delta`/`Escape`). +const ALPHABET: usize = 4; + +/// Header byte length: `u32` count + 4 × `u16` frequencies. +const HEADER_LEN: usize = 4 + ALPHABET * 2; + +// ════════════════════════════════════════════════════════════════════ +// Frequency table +// ════════════════════════════════════════════════════════════════════ + +/// Normalized symbol-frequency table for one block of mode tags. 
+/// +/// `freq[s]` is symbol `s`'s frequency (summing to [`RANS_M`]); `cum[s]` +/// is the exclusive prefix sum (the symbol's start of its sub-interval +/// in `[0, RANS_M)`). A symbol absent from the source histogram has +/// `freq = 0` and is never emitted. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RansFreqTable { + /// Per-symbol frequency, indexed by [`CellMode`] discriminant. + /// Sums to [`RANS_M`] for any non-empty histogram. + pub freq: [u16; ALPHABET], + /// Exclusive prefix sum of `freq` — `cum[s] = Σ freq[0..s]`. + pub cum: [u16; ALPHABET], +} + +impl RansFreqTable { + /// Build a normalized table from a raw mode histogram. + /// + /// Each non-zero count is scaled to `count * RANS_M / total`, floored + /// but bumped to a minimum of 1 so any present symbol stays encodable. + /// The rounding drift is absorbed by the most-frequent symbol so the + /// frequencies sum to exactly [`RANS_M`]. + /// + /// An all-zero histogram (no symbols) yields an all-zero table; it is + /// only valid for an empty stream (`n == 0`), which never encodes a + /// symbol. + /// + /// ``` + /// use ndarray::hpc::codec::ans::{RansFreqTable, RANS_M}; + /// // 3:1 Skip:Delta split → frequencies sum to RANS_M. 
+ /// let t = RansFreqTable::from_histogram([3, 0, 1, 0]); + /// assert_eq!(t.freq.iter().map(|&f| f as u32).sum::(), RANS_M); + /// assert!(t.freq[0] > t.freq[2]); // Skip more probable than Delta + /// ``` + pub fn from_histogram(counts: [u32; ALPHABET]) -> Self { + let total: u64 = counts.iter().map(|&c| c as u64).sum(); + let mut freq = [0u16; ALPHABET]; + if total == 0 { + return Self { + freq, + cum: [0; ALPHABET], + }; + } + + let mut sum: u32 = 0; + let mut max_idx = 0usize; + let mut max_count = 0u32; + for (i, &count) in counts.iter().enumerate() { + if count == 0 { + continue; + } + let scaled = (count as u64 * RANS_M as u64) / total; + let f = (scaled as u32).max(1); + freq[i] = f as u16; + sum += f; + if count > max_count { + max_count = count; + max_idx = i; + } + } + + // Absorb floor/clamp drift on the most-frequent symbol so the + // table sums to exactly RANS_M. The max bucket has the largest + // frequency, so a small adjustment can never drive it below 1. + let adjusted = freq[max_idx] as i64 + (RANS_M as i64 - sum as i64); + debug_assert!((1..=RANS_M as i64).contains(&adjusted), "freq adjustment out of range: {adjusted}"); + freq[max_idx] = adjusted as u16; + + Self { + freq, + cum: prefix_sum(&freq), + } + } + + /// Compute the table from a slice of mode tags. + /// + /// ``` + /// use ndarray::hpc::codec::ans::RansFreqTable; + /// use ndarray::hpc::codec::CellMode; + /// let modes = [CellMode::Skip, CellMode::Skip, CellMode::Delta]; + /// let t = RansFreqTable::from_symbols(&modes); + /// assert!(t.freq[CellMode::Skip as usize] > t.freq[CellMode::Delta as usize]); + /// ``` + pub fn from_symbols(symbols: &[CellMode]) -> Self { + let mut counts = [0u32; ALPHABET]; + for &s in symbols { + counts[s as usize] += 1; + } + Self::from_histogram(counts) + } + + /// Reconstruct a table from the 4 little-endian frequencies stored in + /// a stream header. 
Returns `None` if the frequencies don't sum to + /// [`RANS_M`] (a non-empty table) and aren't all zero (an empty one). + fn from_freqs(freq: [u16; ALPHABET]) -> Option { + let sum: u32 = freq.iter().map(|&f| f as u32).sum(); + if sum != RANS_M && sum != 0 { + return None; + } + Some(Self { + freq, + cum: prefix_sum(&freq), + }) + } +} + +/// Exclusive prefix sum of a 4-entry frequency array. +#[inline] +fn prefix_sum(freq: &[u16; ALPHABET]) -> [u16; ALPHABET] { + let mut cum = [0u16; ALPHABET]; + let mut acc = 0u16; + for i in 0..ALPHABET { + cum[i] = acc; + acc += freq[i]; + } + cum +} + +/// Resolve the slot value `slot ∈ [0, RANS_M)` to its owning symbol. +/// +/// Exactly one symbol with non-zero frequency owns each slot, because +/// the `cum`/`freq` pairs partition `[0, RANS_M)`. +#[inline] +fn symbol_from_slot(slot: u16, table: &RansFreqTable) -> CellMode { + let mut s = 0usize; + while s < ALPHABET { + let f = table.freq[s]; + if f != 0 && slot >= table.cum[s] && slot < table.cum[s] + f { + break; + } + s += 1; + } + // The partition guarantees a hit for any slot < RANS_M; clamp to the + // last symbol defensively so malformed input can't index out of range. + let s = s.min(ALPHABET - 1); + match s { + 0 => CellMode::Skip, + 1 => CellMode::Merge, + 2 => CellMode::Delta, + _ => CellMode::Escape, + } +} + +// ════════════════════════════════════════════════════════════════════ +// Core rANS encode / decode +// ════════════════════════════════════════════════════════════════════ + +/// rANS-encode the mode-tag stream against `table`, returning the +/// state-flushed payload bytes (no header). +/// +/// Symbols are processed in reverse so the decoder recovers them in +/// forward order. The renormalization bytes are emitted LSB-first during +/// processing and the 4-byte final state is appended; the whole buffer +/// is reversed once at the end so a forward read reconstructs the state +/// and then consumes renorm bytes in the right order. 
+fn rans_encode(symbols: &[CellMode], table: &RansFreqTable) -> Vec { + let mut x: u32 = RANS_BYTE_L; + let mut out: Vec = Vec::new(); + for &s in symbols.iter().rev() { + let si = s as usize; + let f = table.freq[si] as u32; + debug_assert!(f > 0, "encoding a symbol with zero frequency"); + let start = table.cum[si] as u32; + let x_max = ((RANS_BYTE_L >> RANS_SCALE_BITS) << 8) * f; + while x >= x_max { + out.push((x & 0xFF) as u8); + x >>= 8; + } + x = ((x / f) << RANS_SCALE_BITS) + (x % f) + start; + } + // Flush the 4-byte state MSB-first; after the final `reverse()` these + // become the first 4 bytes read, in little-endian order for the + // decoder's init. + out.push((x >> 24) as u8); + out.push((x >> 16) as u8); + out.push((x >> 8) as u8); + out.push((x & 0xFF) as u8); + out.reverse(); + out +} + +/// rANS-decode `n` mode tags from `payload` against `table`. +/// +/// `payload` is the byte slice produced by [`rans_encode`]. Reading is +/// forward: the first 4 bytes initialize the state, the rest feed +/// renormalization. Defensive on truncated input (feeds zero bytes past +/// the end rather than panicking); valid round-trip input never reads +/// past the produced bytes. 
+fn rans_decode(payload: &[u8], table: &RansFreqTable, n: usize) -> Vec { + let mut out = Vec::with_capacity(n); + if n == 0 { + return out; + } + let mut x = + (payload[0] as u32) | ((payload[1] as u32) << 8) | ((payload[2] as u32) << 16) | ((payload[3] as u32) << 24); + let mut pos = 4usize; + for _ in 0..n { + let slot = (x & (RANS_M - 1)) as u16; + let s = symbol_from_slot(slot, table); + out.push(s); + let si = s as usize; + let f = table.freq[si] as u32; + let start = table.cum[si] as u32; + x = f * (x >> RANS_SCALE_BITS) + slot as u32 - start; + while x < RANS_BYTE_L { + let b = payload.get(pos).copied().unwrap_or(0); + pos += 1; + x = (x << 8) | b as u32; + } + } + out +} + +// ════════════════════════════════════════════════════════════════════ +// Self-contained stream API +// ════════════════════════════════════════════════════════════════════ + +/// Entropy-code a mode-tag stream into a self-contained byte stream +/// (header + payload, see the module-level layout). +/// +/// The returned bytes embed the symbol count and the normalized +/// frequency table, so [`decode_modes`] needs no side channel. An empty +/// input yields a 12-byte all-zero-table header plus the flushed initial +/// state. +/// +/// ``` +/// use ndarray::hpc::codec::ans::{encode_modes, decode_modes}; +/// use ndarray::hpc::codec::CellMode; +/// let modes = [CellMode::Skip, CellMode::Skip, CellMode::Delta, CellMode::Skip]; +/// let stream = encode_modes(&modes); +/// assert_eq!(decode_modes(&stream).unwrap(), modes); +/// ``` +pub fn encode_modes(symbols: &[CellMode]) -> Vec { + let table = RansFreqTable::from_symbols(symbols); + let payload = rans_encode(symbols, &table); + let mut out = Vec::with_capacity(HEADER_LEN + payload.len()); + out.extend_from_slice(&(symbols.len() as u32).to_le_bytes()); + for f in table.freq { + out.extend_from_slice(&f.to_le_bytes()); + } + out.extend_from_slice(&payload); + out +} + +/// Decode a stream produced by [`encode_modes`] back into mode tags. 
+/// +/// Returns `None` if the stream is shorter than the header, if the +/// embedded frequency table is malformed (doesn't sum to [`RANS_M`] for +/// a non-empty stream), or if a non-empty stream lacks the 4 payload +/// bytes needed to initialize the rANS state. +/// +/// ``` +/// use ndarray::hpc::codec::ans::{encode_modes, decode_modes}; +/// use ndarray::hpc::codec::CellMode; +/// // Empty input round-trips to an empty Vec. +/// assert_eq!(decode_modes(&encode_modes(&[])).unwrap(), Vec::::new()); +/// ``` +pub fn decode_modes(stream: &[u8]) -> Option> { + if stream.len() < HEADER_LEN { + return None; + } + let n = u32::from_le_bytes([stream[0], stream[1], stream[2], stream[3]]) as usize; + let mut freq = [0u16; ALPHABET]; + for (i, slot) in freq.iter_mut().enumerate() { + let lo = stream[4 + 2 * i]; + let hi = stream[5 + 2 * i]; + *slot = u16::from_le_bytes([lo, hi]); + } + let table = RansFreqTable::from_freqs(freq)?; + let payload = &stream[HEADER_LEN..]; + if n > 0 && payload.len() < 4 { + return None; + } + Some(rans_decode(payload, &table, n)) +} + +// ════════════════════════════════════════════════════════════════════ +// Tests +// ════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + const ALL: [CellMode; 4] = [CellMode::Skip, CellMode::Merge, CellMode::Delta, CellMode::Escape]; + + /// Deterministic xorshift so tests need no rand dependency. 
+ fn xorshift(state: &mut u64) -> u64 { + let mut x = *state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + *state = x; + x + } + + fn mode_of(v: u64) -> CellMode { + ALL[(v % 4) as usize] + } + + #[test] + fn freq_table_sums_to_m_for_nonempty() { + for counts in [[1, 0, 0, 0], [3, 0, 1, 0], [10, 5, 3, 1], [1, 1, 1, 1], [1000, 1, 1, 1]] { + let t = RansFreqTable::from_histogram(counts); + let sum: u32 = t.freq.iter().map(|&f| f as u32).sum(); + assert_eq!(sum, RANS_M, "counts={counts:?}"); + } + } + + #[test] + fn freq_table_present_symbols_get_nonzero_freq() { + let t = RansFreqTable::from_histogram([1_000_000, 1, 1, 1]); + for s in 0..4 { + assert!(t.freq[s] >= 1, "present symbol {s} must have freq >= 1"); + } + } + + #[test] + fn freq_table_absent_symbols_stay_zero() { + let t = RansFreqTable::from_histogram([5, 0, 5, 0]); + assert_eq!(t.freq[CellMode::Merge as usize], 0); + assert_eq!(t.freq[CellMode::Escape as usize], 0); + } + + #[test] + fn empty_stream_roundtrips() { + let stream = encode_modes(&[]); + assert_eq!(decode_modes(&stream).unwrap(), Vec::::new()); + } + + #[test] + fn single_symbol_roundtrips_all_modes() { + for &m in &ALL { + let stream = encode_modes(&[m]); + assert_eq!(decode_modes(&stream).unwrap(), vec![m], "mode={m:?}"); + } + } + + #[test] + fn all_same_symbol_roundtrips() { + for &m in &ALL { + let symbols = vec![m; 257]; + let stream = encode_modes(&symbols); + assert_eq!(decode_modes(&stream).unwrap(), symbols, "mode={m:?}"); + } + } + + #[test] + fn mixed_stream_roundtrips() { + let symbols = [ + CellMode::Skip, + CellMode::Delta, + CellMode::Skip, + CellMode::Merge, + CellMode::Escape, + CellMode::Skip, + CellMode::Skip, + CellMode::Delta, + ]; + let stream = encode_modes(&symbols); + assert_eq!(decode_modes(&stream).unwrap(), symbols); + } + + #[test] + fn large_random_streams_roundtrip() { + let mut state = 0x1234_5678_9abc_def0u64; + for &len in &[1usize, 2, 16, 100, 1000, 4096] { + let symbols: Vec = (0..len).map(|_| 
mode_of(xorshift(&mut state))).collect(); + let stream = encode_modes(&symbols); + let decoded = decode_modes(&stream).expect("decode"); + assert_eq!(decoded, symbols, "len={len}"); + } + } + + #[test] + fn skewed_skip_dominant_stream_roundtrips_and_compresses() { + // 70% Skip, 25% Merge, 4.5% Delta, 0.5% Escape — the design's + // coherent-state target distribution. + let mut state = 0xdead_beef_cafe_babeu64; + let symbols: Vec = (0..4000) + .map(|_| { + let r = xorshift(&mut state) % 1000; + if r < 700 { + CellMode::Skip + } else if r < 950 { + CellMode::Merge + } else if r < 995 { + CellMode::Delta + } else { + CellMode::Escape + } + }) + .collect(); + let stream = encode_modes(&symbols); + assert_eq!(decode_modes(&stream).unwrap(), symbols); + // Skewed distribution → well under 2 bits/symbol (vs 8 dense). + let payload_bits = (stream.len() - HEADER_LEN) * 8; + assert!( + payload_bits < symbols.len() * 2, + "expected < 2 bits/symbol on skewed input, got {} bits for {} symbols", + payload_bits, + symbols.len() + ); + } + + #[test] + fn decode_rejects_short_header() { + assert!(decode_modes(&[0u8; 11]).is_none()); + } + + #[test] + fn decode_rejects_malformed_freq_table() { + // n = 1 but freqs sum to something != RANS_M and != 0. + let mut stream = vec![0u8; HEADER_LEN + 4]; + stream[0] = 1; // n = 1 + stream[4] = 1; // freq[0] = 1 → sum 1, neither 0 nor RANS_M + assert!(decode_modes(&stream).is_none()); + } + + #[test] + fn decode_rejects_missing_payload_state() { + // n = 1, valid single-symbol table, but no 4 payload bytes. 
+ let table = RansFreqTable::from_histogram([1, 0, 0, 0]); + let mut stream = Vec::new(); + stream.extend_from_slice(&1u32.to_le_bytes()); + for f in table.freq { + stream.extend_from_slice(&f.to_le_bytes()); + } + stream.extend_from_slice(&[0u8; 2]); // only 2 payload bytes + assert!(decode_modes(&stream).is_none()); + } + + #[test] + fn header_layout_is_count_then_freqs() { + let symbols = [CellMode::Skip, CellMode::Skip, CellMode::Delta]; + let stream = encode_modes(&symbols); + let n = u32::from_le_bytes([stream[0], stream[1], stream[2], stream[3]]); + assert_eq!(n, 3); + let f_skip = u16::from_le_bytes([stream[4], stream[5]]); + let f_delta = u16::from_le_bytes([stream[8], stream[9]]); + assert!(f_skip > f_delta); + } +} diff --git a/src/hpc/codec/mod.rs b/src/hpc/codec/mod.rs index 9a46a24a..bc3e8699 100644 --- a/src/hpc/codec/mod.rs +++ b/src/hpc/codec/mod.rs @@ -12,9 +12,14 @@ //! - [`mode`] — A2: bit-pack / unpack helpers for the on-wire 16-bit //! header + per-mode tail (Skip/Merge/Delta/Escape). //! - [`predict`] — A3-intra: encoder-side mode-decision kernel that -//! picks the cheapest `LeafCu` from a cell + NESW neighbours. -//! - `transform`, `quantize`, `rdo`, `ans`, `stream` — A4-A8, queued as -//! follow-up sprints. +//! picks the cheapest `LeafCu` from a cell + NESW neighbours (hard +//! decision tree). +//! - [`rdo`] — A6: λ-rate-distortion mode selection — the soft, +//! cost-weighted generalization of [`predict`] (integer fixed-point λ). +//! - [`ans`] — A7: rANS entropy coder over the 4-symbol mode alphabet. +//! - `transform` (A4, deferred to v2 per design Q2 — 1-D DCT doesn't +//! help an 8-bit scalar residual) and `stream` (A8, byte-stream +//! framing over [`ans`]) remain as follow-ups. //! //! # Feature gate //! @@ -25,10 +30,13 @@ //! //! `.claude/knowledge/pr-x12-codec-x265-design.md` — master design doc. 
+pub mod ans; pub mod ctu; pub mod mode; pub mod predict; +pub mod rdo; +pub use ans::{decode_modes, encode_modes, RansFreqTable, RANS_BYTE_L, RANS_M, RANS_SCALE_BITS}; pub use ctu::{CellMode, MergeDir, MAX_QUAD_TREE_NODES, MAX_SPLIT_DEPTH}; pub use ctu::{Ctu, CtuArena, CtuPartition, LeafCu, MaxSplitDepthReached, MergeError, NodeIdx}; pub use mode::{ @@ -36,3 +44,4 @@ pub use mode::{ MAX_BASIN_IDX, }; pub use predict::{predict_intra, IntraConfig, IntraContext}; +pub use rdo::{rdo_select, RdoChoice, RdoConfig, RdoContext}; diff --git a/src/hpc/codec/rdo.rs b/src/hpc/codec/rdo.rs new file mode 100644 index 00000000..56ec9ca2 --- /dev/null +++ b/src/hpc/codec/rdo.rs @@ -0,0 +1,464 @@ +//! λ-rate-distortion mode selection (PR-X12 A6). +//! +//! [`super::predict`] makes a *hard* mode decision: Skip iff δ = 0, +//! Merge iff a neighbour matches exactly, else Delta/Escape. RDO is the +//! *soft* generalization — for each cell it scores all four modes by +//! +//! ```text +//! cost = rate + λ · distortion +//! ``` +//! +//! and picks the minimum. At λ = 0 it minimizes pure rate (always the +//! cheapest mode that exists); as λ → ∞ it minimizes pure distortion +//! (lossless Escape when available, else the tightest Delta). This is +//! the lever that lets high-confidence cells spend bits for fidelity and +//! low-confidence cells tolerate lossy compression — the design's +//! "λ calibrated via NARS confidence". +//! +//! # Integer λ (no float) +//! +//! The design doc sketches `lambda: f32`. This implementation uses a +//! **fixed-point** λ instead — [`RdoConfig::lambda_q8`] is λ × 256 as a +//! `u32`, and the cost is computed in `u64`: +//! +//! ```text +//! cost_q8 = (rate_bytes << 8) + lambda_q8 · distortion +//! ``` +//! +//! This keeps mode decisions deterministic and bit-identical across +//! platforms (no float rounding divergence), consistent with the +//! substrate's no-float discipline. `f32` λ is a cross-platform +//! 
reproducibility hazard the codec doesn't need: distortion is an +//! integer delta magnitude and rate is an integer byte count, so the +//! whole R + λD lattice is exactly representable in fixed point. +//! +//! # Rate model +//! +//! Rate is the leaf's on-wire byte length +//! ([`super::mode::packed_byte_len`]): Skip = 2, Merge = 3, Delta = 3, +//! Escape = 6. This is the *pre-entropy* size — the bytes a leaf +//! actually occupies before the A7 rANS pass folds the mode tag down to +//! `-log2(p)` bits. An entropy-aware rate (mode tag weighted by its +//! frequency) is a documented follow-on; the pre-entropy byte rate is +//! exact, deterministic, and never under-counts the payload. +//! +//! # Distortion model +//! +//! Distortion is the absolute error of the reconstructed δ against the +//! true δ, both in the basin's u8-quantization space (matching +//! [`super::predict::IntraContext`]'s `delta_i32`): +//! +//! | Mode | Reconstructed δ | Distortion | +//! |--------|----------------------------|-------------------------| +//! | Skip | 0 (basin exactly) | `|δ|` | +//! | Merge | neighbour's δ (as `i8`) | `|δ − δ_neighbour|` | +//! | Delta | `clamp(δ, −128, 127)` | `|δ − clamp(δ)|` (0 in range) | +//! | Escape | δ (full value preserved) | 0 (lossless) | +//! +//! Escape requires a caller-supplied escape-vector cursor (same contract +//! as [`super::predict::predict_intra`]); without one, Escape is not a +//! candidate and the selector falls back to the best lossy mode. + +use super::ctu::{CellMode, LeafCu, MergeDir}; +use super::mode::packed_byte_len; + +// ════════════════════════════════════════════════════════════════════ +// Configuration +// ════════════════════════════════════════════════════════════════════ + +/// RDO tuning. `lambda_q8` is the rate-distortion tradeoff λ in Q8 fixed +/// point (λ × 256): `0` minimizes rate, large values minimize distortion. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RdoConfig { + /// λ × 256. 
Cost = `(rate_bytes << 8) + lambda_q8 · distortion`. + pub lambda_q8: u32, +} + +impl RdoConfig { + /// λ = 0 — pure rate minimization (always the cheapest feasible mode, + /// regardless of fidelity loss). + pub const RATE_ONLY: Self = Self { lambda_q8: 0 }; + + /// A λ large enough that any non-zero distortion outweighs the 6-byte + /// Escape rate, i.e. effectively lossless mode selection. `1 << 20` + /// (λ = 4096) dominates the maximum 6-byte rate (`6 << 8 = 1536`) for + /// any distortion ≥ 1. + pub const LOSSLESS: Self = Self { lambda_q8: 1 << 20 }; + + /// Construct from a fixed-point λ. + /// + /// ``` + /// use ndarray::hpc::codec::rdo::RdoConfig; + /// // λ = 2.5 → lambda_q8 = 640 + /// let cfg = RdoConfig::from_lambda_q8(640); + /// assert_eq!(cfg.lambda_q8, 640); + /// ``` + pub const fn from_lambda_q8(lambda_q8: u32) -> Self { + Self { lambda_q8 } + } +} + +impl Default for RdoConfig { + /// A middle λ (= 16.0) biased toward fidelity but not strictly + /// lossless — Skip/Merge a cell only when the distortion is small. + fn default() -> Self { + Self { lambda_q8: 16 << 8 } + } +} + +// ════════════════════════════════════════════════════════════════════ +// Context + result +// ════════════════════════════════════════════════════════════════════ + +/// Per-cell context for the RDO decision. Mirrors +/// [`super::predict::IntraContext`]: a pre-resolved basin index, the +/// signed δ from basin to cell (in u8-quantization space), and the four +/// NEWS neighbour leaves (indexed by [`MergeDir`] discriminant: +/// `North=0, East=1, West=2, South=3`). +#[derive(Debug, Clone, Copy)] +pub struct RdoContext<'a> { + /// Pre-resolved basin index (12-bit max, not re-validated here). + pub basin_idx: u16, + /// Signed δ from basin → cell, in the basin's u8 quantization space. + pub delta_i32: i32, + /// NEWS neighbour leaves, `None` at block boundaries. 
+ pub neighbours: [Option<&'a LeafCu>; 4], +} + +/// The selected mode plus its scored cost and reconstruction distortion. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RdoChoice { + /// The chosen leaf, ready to pack via [`super::mode::pack_leaf`]. + pub leaf: LeafCu, + /// `(rate_bytes << 8) + lambda_q8 · distortion` for the chosen mode. + pub cost_q8: u64, + /// Reconstruction error |δ − δ̂| in u8-quantization units (0 = exact). + pub distortion: u32, +} + +// ════════════════════════════════════════════════════════════════════ +// The selector +// ════════════════════════════════════════════════════════════════════ + +/// Score all feasible modes and return the minimum-cost +/// [`RdoChoice`]. +/// +/// Feasible modes are scored in the order Skip → Merge → Delta → Escape; +/// ties are resolved in favour of the earlier (lower-rate) mode via a +/// strict `<` comparison. Skip and Delta are always feasible; Merge is +/// feasible only when a NEWS neighbour is a `Delta`-mode leaf on the same +/// basin; Escape is feasible only when `escape_next` is supplied. +/// +/// When Escape wins, the caller's escape-vector cursor is post-incremented +/// (the chosen leaf references the pre-increment value) so batched calls +/// don't collide — identical to [`super::predict::predict_intra`]. The +/// cursor is left untouched if any other mode wins. 
+/// +/// # Examples +/// +/// λ = 0 picks the cheapest mode regardless of distortion — Skip (2 +/// bytes) even when δ ≠ 0: +/// +/// ``` +/// use ndarray::hpc::codec::rdo::{rdo_select, RdoConfig, RdoContext}; +/// use ndarray::hpc::codec::CellMode; +/// let ctx = RdoContext { basin_idx: 7, delta_i32: 40, neighbours: [None; 4] }; +/// let choice = rdo_select(&ctx, &RdoConfig::RATE_ONLY, None); +/// assert_eq!(choice.leaf.mode, CellMode::Skip); +/// assert_eq!(choice.distortion, 40); // lossy: reconstructs as basin +/// ``` +/// +/// A lossless λ keeps the same in-range δ exact via Delta (0 distortion): +/// +/// ``` +/// use ndarray::hpc::codec::rdo::{rdo_select, RdoConfig, RdoContext}; +/// use ndarray::hpc::codec::CellMode; +/// let ctx = RdoContext { basin_idx: 7, delta_i32: 40, neighbours: [None; 4] }; +/// let choice = rdo_select(&ctx, &RdoConfig::LOSSLESS, None); +/// assert_eq!(choice.leaf.mode, CellMode::Delta); +/// assert_eq!(choice.distortion, 0); +/// ``` +pub fn rdo_select(ctx: &RdoContext, cfg: &RdoConfig, escape_next: Option<&mut u32>) -> RdoChoice { + let lambda = cfg.lambda_q8 as u64; + + // ── Skip (always feasible) ─────────────────────────────────────── + // Reconstructs the cell as the basin exactly; distortion = |δ|. + let mut best = Candidate::new(LeafCu::skip(ctx.basin_idx), ctx.delta_i32.unsigned_abs(), lambda); + + // ── Merge (feasible if a Delta-mode same-basin neighbour exists) ── + // The reconstructed δ is the neighbour's stored δ read back as i8. + // Among qualifying neighbours, pick the one with the smallest + // distortion (closest δ); ties keep the earlier NEWS slot. + if let Some((dir, nb_dist)) = best_merge_neighbour(ctx) { + consider(&mut best, Candidate::new(LeafCu::merge(ctx.basin_idx, dir), nb_dist, lambda)); + } + + // ── Delta (always feasible) ────────────────────────────────────── + // δ clamped into i8 range; distortion = |δ − clamp(δ)| (0 in range). 
+ let clamped = ctx.delta_i32.clamp(-128, 127); + let delta_dist = ctx.delta_i32.abs_diff(clamped); + consider(&mut best, Candidate::new(LeafCu::delta(ctx.basin_idx, clamped as u8), delta_dist, lambda)); + + // ── Escape (feasible only with a cursor) ───────────────────────── + // Lossless: the full cell value is preserved in the escape vector, so + // distortion = 0. Build the candidate against the *current* cursor + // value, but only advance the cursor (post-increment) if Escape wins — + // a losing Escape leaves the caller's allocator untouched. + if let Some(next) = escape_next { + let cand = Candidate::new(LeafCu::escape(ctx.basin_idx, *next), 0, lambda); + if cand.cost < best.cost { + *next = next.wrapping_add(1); + best = cand; + } + } + + RdoChoice { + leaf: best.leaf, + cost_q8: best.cost, + distortion: best.distortion, + } +} + +// ════════════════════════════════════════════════════════════════════ +// Internals +// ════════════════════════════════════════════════════════════════════ + +/// A scored mode candidate during selection. `cost` is computed at +/// construction from the leaf's wire rate, the distortion, and λ. +struct Candidate { + leaf: LeafCu, + distortion: u32, + cost: u64, +} + +impl Candidate { + /// Score `leaf` (rate derived from its mode) at the given distortion + /// and fixed-point λ. + #[inline] + fn new(leaf: LeafCu, distortion: u32, lambda_q8: u64) -> Self { + let rate = packed_byte_len(leaf.mode) as u64; + Self { + leaf, + distortion, + cost: cost_q8(rate, distortion, lambda_q8), + } + } +} + +/// `(rate << 8) + λ_q8 · distortion`, saturating so an extreme λ can't +/// wrap. `u64` headroom: rate ≤ 6, distortion ≤ ~2³¹, λ ≤ 2³² → the +/// product is ≤ 2⁶³, well inside `u64`, but saturating keeps it total. +#[inline] +fn cost_q8(rate_bytes: u64, distortion: u32, lambda_q8: u64) -> u64 { + (rate_bytes << 8).saturating_add(lambda_q8.saturating_mul(distortion as u64)) +} + +/// Replace `best` with `cand` if `cand` is strictly cheaper. 
Strict `<` +/// keeps the earlier (lower-rate) mode on ties, giving deterministic +/// selection. +#[inline] +fn consider(best: &mut Candidate, cand: Candidate) { + if cand.cost < best.cost { + *best = cand; + } +} + +/// Find the NEWS neighbour that yields the smallest Merge distortion. +/// +/// A neighbour qualifies iff it is a `Delta`-mode leaf on the same basin +/// (Merge inherits a δ, and only across the same reference basin). The +/// reconstructed δ is the neighbour's stored byte read back as `i8`. +/// Returns `(dir, distortion)` for the best qualifying neighbour, or +/// `None` if none qualify. +fn best_merge_neighbour(ctx: &RdoContext) -> Option<(MergeDir, u32)> { + let mut best: Option<(MergeDir, u32)> = None; + for (i, slot) in ctx.neighbours.iter().enumerate() { + let Some(nb) = slot else { continue }; + if nb.mode != CellMode::Delta || nb.basin_idx != ctx.basin_idx { + continue; + } + let Some(nb_delta) = nb.delta else { continue }; + let recon = (nb_delta as i8) as i32; + let dist = ctx.delta_i32.abs_diff(recon); + let dir = merge_dir_from_index(i); + match best { + Some((_, best_dist)) if best_dist <= dist => {} + _ => best = Some((dir, dist)), + } + } + best +} + +#[inline] +fn merge_dir_from_index(i: usize) -> MergeDir { + match i { + 0 => MergeDir::North, + 1 => MergeDir::East, + 2 => MergeDir::West, + _ => MergeDir::South, + } +} + +// ════════════════════════════════════════════════════════════════════ +// Tests +// ════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + fn ctx<'a>(basin: u16, delta: i32, neighbours: [Option<&'a LeafCu>; 4]) -> RdoContext<'a> { + RdoContext { + basin_idx: basin, + delta_i32: delta, + neighbours, + } + } + + #[test] + fn skip_when_delta_zero_at_any_lambda() { + for cfg in [RdoConfig::RATE_ONLY, RdoConfig::default(), RdoConfig::LOSSLESS] { + let choice = rdo_select(&ctx(7, 0, [None; 4]), &cfg, None); + assert_eq!(choice.leaf.mode, CellMode::Skip); + 
assert_eq!(choice.distortion, 0); + } + } + + #[test] + fn rate_only_prefers_skip_even_when_lossy() { + let choice = rdo_select(&ctx(7, 50, [None; 4]), &RdoConfig::RATE_ONLY, None); + assert_eq!(choice.leaf.mode, CellMode::Skip); + assert_eq!(choice.distortion, 50); + // λ = 0 → cost is pure rate = 2 bytes << 8. + assert_eq!(choice.cost_q8, 2 << 8); + } + + #[test] + fn lossless_lambda_keeps_in_range_delta_exact() { + let choice = rdo_select(&ctx(7, 40, [None; 4]), &RdoConfig::LOSSLESS, None); + assert_eq!(choice.leaf.mode, CellMode::Delta); + assert_eq!(choice.leaf.delta, Some(40)); + assert_eq!(choice.distortion, 0); + } + + #[test] + fn negative_in_range_delta_packs_wrapping() { + let choice = rdo_select(&ctx(7, -40, [None; 4]), &RdoConfig::LOSSLESS, None); + assert_eq!(choice.leaf.mode, CellMode::Delta); + assert_eq!(choice.leaf.delta, Some((-40_i32) as u8)); + assert_eq!(choice.distortion, 0); + } + + #[test] + fn out_of_range_delta_chooses_escape_when_lossless_and_allocator() { + let mut next = 3u32; + let choice = rdo_select(&ctx(7, 1000, [None; 4]), &RdoConfig::LOSSLESS, Some(&mut next)); + assert_eq!(choice.leaf.mode, CellMode::Escape); + assert_eq!(choice.leaf.escape_idx, Some(3)); + assert_eq!(choice.distortion, 0); + assert_eq!(next, 4, "cursor advances only when Escape wins"); + } + + #[test] + fn escape_cursor_untouched_when_escape_loses() { + // δ = 0 → Skip wins (distortion 0, cheapest rate); the cursor must + // not advance even though an allocator was supplied. + let mut next = 9u32; + let choice = rdo_select(&ctx(7, 0, [None; 4]), &RdoConfig::LOSSLESS, Some(&mut next)); + assert_eq!(choice.leaf.mode, CellMode::Skip); + assert_eq!(next, 9, "cursor must not advance when Escape loses"); + } + + #[test] + fn out_of_range_delta_without_allocator_falls_back_to_clamped_delta() { + // No escape cursor → Escape infeasible; the best lossless-ish + // option is a clamped Delta (lossy). 
+ let choice = rdo_select(&ctx(7, 1000, [None; 4]), &RdoConfig::LOSSLESS, None); + assert_eq!(choice.leaf.mode, CellMode::Delta); + assert_eq!(choice.leaf.delta, Some(127)); + assert_eq!(choice.distortion, 1000 - 127); + } + + #[test] + fn merge_chosen_when_neighbour_matches_and_lambda_rewards_it() { + // Northern neighbour is Delta(basin=7, δ=20). True δ = 20 → Merge + // reconstructs exactly (distortion 0) at 3 bytes, beating Delta's + // 3 bytes only on the tie-break? Both 3 bytes, both distortion 0. + // Merge is scored after Skip; Skip here has distortion 20 so loses + // at lossless λ. Merge (scored before Delta) wins the 3-byte tie. + let nb = LeafCu::delta(7, 20); + let neighbours = [Some(&nb), None, None, None]; + let choice = rdo_select(&ctx(7, 20, neighbours), &RdoConfig::LOSSLESS, None); + assert_eq!(choice.leaf.mode, CellMode::Merge); + assert_eq!(choice.leaf.merge_dir, Some(MergeDir::North)); + assert_eq!(choice.distortion, 0); + } + + #[test] + fn merge_picks_closest_neighbour() { + // Among qualifying neighbours, `best_merge_neighbour` must pick the + // closest δ. An exact-match Merge ties Delta (both 3 bytes, dist 0) + // and wins on scoring order; a non-exact neighbour never wins + // (Delta represents any in-range δ optimally), so the selector must + // surface the EXACT neighbour, not the off-by-5 one. 
+ let east = LeafCu::delta(7, 20); // exact: |20 − 20| = 0 + let west = LeafCu::delta(7, 25); // off: |20 − 25| = 5 + let neighbours = [None, Some(&east), Some(&west), None]; + let choice = rdo_select(&ctx(7, 20, neighbours), &RdoConfig::default(), None); + assert_eq!(choice.leaf.mode, CellMode::Merge); + assert_eq!(choice.leaf.merge_dir, Some(MergeDir::East)); + assert_eq!(choice.distortion, 0); + } + + #[test] + fn merge_rejected_when_basin_differs() { + let nb = LeafCu::delta(99, 20); // different basin + let neighbours = [Some(&nb), None, None, None]; + let choice = rdo_select(&ctx(7, 20, neighbours), &RdoConfig::LOSSLESS, None); + // Can't Merge across basins → exact Delta(20) instead. + assert_eq!(choice.leaf.mode, CellMode::Delta); + assert_eq!(choice.leaf.delta, Some(20)); + } + + #[test] + fn merge_rejected_when_neighbour_not_delta_mode() { + let nb_skip = LeafCu::skip(7); + let nb_merge = LeafCu::merge(7, MergeDir::North); + let neighbours = [Some(&nb_skip), Some(&nb_merge), None, None]; + let choice = rdo_select(&ctx(7, 20, neighbours), &RdoConfig::LOSSLESS, None); + assert_eq!(choice.leaf.mode, CellMode::Delta); + } + + #[test] + fn cost_is_rate_plus_lambda_distortion() { + // λ_q8 = 256 (λ = 1.0), δ = 10, no neighbours, no escape. + // Skip: rate 2, dist 10 → (2<<8) + 256*10 = 512 + 2560 = 3072 + // Delta: rate 3, dist 0 → (3<<8) + 0 = 768 + // Delta wins. 
+ let cfg = RdoConfig::from_lambda_q8(256); + let choice = rdo_select(&ctx(7, 10, [None; 4]), &cfg, None); + assert_eq!(choice.leaf.mode, CellMode::Delta); + assert_eq!(choice.cost_q8, 768); + } + + #[test] + fn chosen_leaf_round_trips_through_mode_pack() { + use super::super::mode::{pack_leaf, unpack_leaf}; + let nb = LeafCu::delta(7, 20); + let neighbours = [None, Some(&nb), None, None]; + let choice = rdo_select(&ctx(7, 20, neighbours), &RdoConfig::LOSSLESS, None); + let mut buf = [0u8; 6]; + let n = pack_leaf(&choice.leaf, &mut buf).unwrap(); + let (decoded, consumed) = unpack_leaf(&buf).unwrap(); + assert_eq!(n, consumed); + assert_eq!(decoded, choice.leaf); + } + + #[test] + fn candidate_new_matches_free_cost_fn() { + // `Candidate::new` derives rate from the leaf mode (Delta = 3) and + // must agree with the free `cost_q8` used for the Escape compare. + let c = Candidate::new(LeafCu::delta(7, 5), 7, 256); + assert_eq!(c.cost, cost_q8(3, 7, 256)); + } +} From 347f36d6f45ed4238b54508fd29aa4fe0f1959bf Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 16:10:40 +0000 Subject: [PATCH 2/3] docs(codec): clarify Merge-vs-Delta tie comment per P2 review https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/codec/rdo.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/hpc/codec/rdo.rs b/src/hpc/codec/rdo.rs index 56ec9ca2..c6bc6adf 100644 --- a/src/hpc/codec/rdo.rs +++ b/src/hpc/codec/rdo.rs @@ -381,11 +381,10 @@ mod tests { #[test] fn merge_chosen_when_neighbour_matches_and_lambda_rewards_it() { - // Northern neighbour is Delta(basin=7, δ=20). True δ = 20 → Merge - // reconstructs exactly (distortion 0) at 3 bytes, beating Delta's - // 3 bytes only on the tie-break? Both 3 bytes, both distortion 0. - // Merge is scored after Skip; Skip here has distortion 20 so loses - // at lossless λ. Merge (scored before Delta) wins the 3-byte tie. + // Northern neighbour is Delta(basin=7, δ=20). 
True δ = 20, so both + // Merge and Delta reconstruct exactly (distortion 0) at 3 bytes — + // an exact cost tie. Skip (2 bytes) has distortion 20 and loses at + // lossless λ. Merge is scored before Delta, so it wins the tie. let nb = LeafCu::delta(7, 20); let neighbours = [Some(&nb), None, None, None]; let choice = rdo_select(&ctx(7, 20, neighbours), &RdoConfig::LOSSLESS, None); From ed6f3ba595b6093595d1a483bdc755777c82c821 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 16:16:50 +0000 Subject: [PATCH 3/3] fix(codec): reject ANS streams with n>0 but an empty freq table Codex P2: decode_modes accepted a stream whose header claimed n > 0 while all four stored frequencies were zero (an "empty model"). from_freqs permits an all-zero table for the n == 0 empty-stream case, so a corrupt or non-encode_modes stream slipped through and rans_decode fabricated Escape tags from the freq-0 table instead of reporting malformed input. Add an explicit n > 0 / all-zero-table guard + regression test. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/codec/ans.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/hpc/codec/ans.rs b/src/hpc/codec/ans.rs index 0ae57c9b..6a433c35 100644 --- a/src/hpc/codec/ans.rs +++ b/src/hpc/codec/ans.rs @@ -351,6 +351,13 @@ pub fn decode_modes(stream: &[u8]) -> Option> { *slot = u16::from_le_bytes([lo, hi]); } let table = RansFreqTable::from_freqs(freq)?; + // A non-empty stream needs a non-empty model. An all-zero freq table is + // only valid for n == 0 (the empty stream); with n > 0 it cannot encode + // any symbol, so reject rather than letting `rans_decode` fabricate + // Escape tags from the empty (freq-0) table on a corrupt input. 
+ if n > 0 && freq.iter().all(|&f| f == 0) { + return None; + } let payload = &stream[HEADER_LEN..]; if n > 0 && payload.len() < 4 { return None; @@ -487,6 +494,17 @@ mod tests { ); } + #[test] + fn decode_rejects_nonempty_count_with_empty_freq_table() { + // Codex P2: header claims n > 0 but the stored freq table is all-zero + // (an "empty model"). `from_freqs` accepts an all-zero table for the + // n == 0 case, so the n > 0 guard must reject here rather than let + // rans_decode fabricate Escape tags from the freq-0 table. + let mut stream = vec![0u8; HEADER_LEN + 8]; + stream[0] = 5; // n = 5, but all four u16 freqs stay 0 + assert!(decode_modes(&stream).is_none()); + } + #[test] fn decode_rejects_short_header() { assert!(decode_modes(&[0u8; 11]).is_none());