diff --git a/nodedb-vector/src/collection/checkpoint.rs b/nodedb-vector/src/collection/checkpoint.rs index a92a9d57..af086d07 100644 --- a/nodedb-vector/src/collection/checkpoint.rs +++ b/nodedb-vector/src/collection/checkpoint.rs @@ -7,6 +7,7 @@ use crate::collection::tier::StorageTier; use crate::distance::DistanceMetric; use crate::flat::FlatIndex; use crate::hnsw::{HnswIndex, HnswParams}; +use crate::quantize::pq::PqCodec; use super::lifecycle::VectorCollection; @@ -33,12 +34,18 @@ pub(crate) struct CollectionSnapshot { pub(crate) struct SealedSnapshot { pub base_id: u32, pub hnsw_bytes: Vec, + #[serde(default)] + pub pq_bytes: Option>, + #[serde(default)] + pub pq_codes: Option>, } #[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)] pub(crate) struct BuildingSnapshot { pub base_id: u32, pub vectors: Vec>, + #[serde(default)] + pub deleted: Vec, } impl VectorCollection { @@ -53,17 +60,27 @@ impl VectorCollection { next_id: self.next_id, growing_base_id: self.growing_base_id, growing_vectors: (0..self.growing.len() as u32) - .filter_map(|i| self.growing.get_vector(i).map(|v| v.to_vec())) + .filter_map(|i| self.growing.get_vector_raw(i).map(|v| v.to_vec())) .collect(), growing_deleted: (0..self.growing.len() as u32) - .map(|i| self.growing.get_vector(i).is_none()) + .map(|i| self.growing.is_deleted(i)) .collect(), sealed_segments: self .sealed .iter() - .map(|s| SealedSnapshot { - base_id: s.base_id, - hnsw_bytes: s.index.checkpoint_to_bytes(), + .map(|s| { + let (pq_bytes, pq_codes) = match &s.pq { + Some((codec, codes)) => { + (zerompk::to_msgpack_vec(codec).ok(), Some(codes.clone())) + } + None => (None, None), + }; + SealedSnapshot { + base_id: s.base_id, + hnsw_bytes: s.index.checkpoint_to_bytes(), + pq_bytes, + pq_codes, + } }) .collect(), building_segments: self @@ -72,7 +89,10 @@ impl VectorCollection { .map(|b| BuildingSnapshot { base_id: b.base_id, vectors: (0..b.flat.len() as u32) - .filter_map(|i| b.flat.get_vector(i).map(|v| v.to_vec())) + .filter_map(|i| b.flat.get_vector_raw(i).map(|v| v.to_vec())) + .collect(), + deleted: (0..b.flat.len() as u32) + .map(|i| b.flat.is_deleted(i)) .collect(), }) .collect(), @@ -118,18 +138,35 @@ impl VectorCollection { }; let mut growing = FlatIndex::new(snap.dim, metric); - for v in &snap.growing_vectors { - growing.insert(v.clone()); + for (i, v) in snap.growing_vectors.iter().enumerate() { + let deleted = snap.growing_deleted.get(i).copied().unwrap_or(false); + if deleted { + growing.insert_tombstoned(v.clone()); + } else { + growing.insert(v.clone()); + } } let mut sealed = Vec::with_capacity(snap.sealed_segments.len()); for ss in &snap.sealed_segments { if let Some(index) = HnswIndex::from_checkpoint(&ss.hnsw_bytes) { - let sq8 = VectorCollection::build_sq8_for_index(&index); + let pq = match (&ss.pq_bytes, &ss.pq_codes) { + (Some(bytes), Some(codes)) => zerompk::from_msgpack::(bytes) + .ok() + .map(|codec| (codec, codes.clone())), + _ => None, + }; + // Only train SQ8 when PQ isn't present — a segment never carries both. + let sq8 = if pq.is_some() { + None + } else { + VectorCollection::build_sq8_for_index(&index) + }; sealed.push(SealedSegment { index, base_id: ss.base_id, sq8, + pq, tier: StorageTier::L0Ram, mmap_vectors: None, }); @@ -143,11 +180,18 @@ impl VectorCollection { .insert(v.clone()) .expect("dimension guaranteed by checkpoint"); } + // Replay building-segment tombstones onto the HNSW index. + for (i, &dead) in bs.deleted.iter().enumerate() { + if dead { + index.delete(i as u32); + } + } let sq8 = VectorCollection::build_sq8_for_index(&index); sealed.push(SealedSegment { index, base_id: bs.base_id, sq8, + pq: None, tier: StorageTier::L0Ram, mmap_vectors: None, }); @@ -155,6 +199,10 @@ impl VectorCollection { let next_segment_id = (sealed.len() + 1) as u32; + let index_config = crate::index_config::IndexConfig { + hnsw: params.clone(), + ..crate::index_config::IndexConfig::default() + }; Some(Self { growing, growing_base_id: snap.growing_base_id, @@ -171,6 +219,7 @@ impl VectorCollection { doc_id_map: snap.doc_id_map.into_iter().collect(), multi_doc_map: snap.multi_doc_map.into_iter().collect(), seal_threshold: DEFAULT_SEAL_THRESHOLD, + index_config, }) } } diff --git a/nodedb-vector/src/collection/lifecycle.rs b/nodedb-vector/src/collection/lifecycle.rs index c3311573..7eee6113 100644 --- a/nodedb-vector/src/collection/lifecycle.rs +++ b/nodedb-vector/src/collection/lifecycle.rs @@ -2,7 +2,7 @@ use crate::flat::FlatIndex; use crate::hnsw::{HnswIndex, HnswParams}; -use crate::quantize::sq8::Sq8Codec; +use crate::index_config::{IndexConfig, IndexType}; use super::segment::{BuildRequest, BuildingSegment, DEFAULT_SEAL_THRESHOLD, SealedSegment}; @@ -40,6 +40,8 @@ pub struct VectorCollection { pub multi_doc_map: std::collections::HashMap>, /// Number of vectors in the growing segment before sealing. pub(crate) seal_threshold: usize, + /// Full index configuration (index type, PQ params, IVF params). + pub(crate) index_config: IndexConfig, } impl VectorCollection { @@ -50,6 +52,25 @@ impl VectorCollection { /// Create an empty collection with an explicit seal threshold. pub fn with_seal_threshold(dim: usize, params: HnswParams, seal_threshold: usize) -> Self { + let index_config = IndexConfig { + hnsw: params.clone(), + ..IndexConfig::default() + }; + Self::with_seal_threshold_and_config(dim, index_config, seal_threshold) + } + + /// Create an empty collection with a full index configuration. + pub fn with_index_config(dim: usize, config: IndexConfig) -> Self { + Self::with_seal_threshold_and_config(dim, config, DEFAULT_SEAL_THRESHOLD) + } + + /// Create an empty collection with a full index config and custom seal threshold. + pub fn with_seal_threshold_and_config( + dim: usize, + config: IndexConfig, + seal_threshold: usize, + ) -> Self { + let params = config.hnsw.clone(); Self { growing: FlatIndex::new(dim, params.metric), growing_base_id: 0, @@ -66,6 +87,7 @@ impl VectorCollection { doc_id_map: std::collections::HashMap::new(), multi_doc_map: std::collections::HashMap::new(), seal_threshold, + index_config: config, } } @@ -213,53 +235,28 @@ impl VectorCollection { .position(|b| b.segment_id == segment_id) { let building = self.building.remove(pos); - let sq8 = Self::build_sq8_for_index(&index); + let use_pq = self.index_config.index_type == IndexType::HnswPq; + let (sq8, pq) = if use_pq { + ( + None, + Self::build_pq_for_index(&index, self.index_config.pq_m), + ) + } else { + (Self::build_sq8_for_index(&index), None) + }; let (tier, mmap_vectors) = self.resolve_tier_for_build(segment_id, &index); self.sealed.push(SealedSegment { index, base_id: building.base_id, sq8, + pq, tier, mmap_vectors, }); } } - /// Build SQ8 quantized data for an HNSW index. - pub fn build_sq8_for_index(index: &HnswIndex) -> Option<(Sq8Codec, Vec)> { - if index.live_count() < 1000 { - return None; - } - let dim = index.dim(); - let n = index.len(); - - let mut refs: Vec<&[f32]> = Vec::with_capacity(n); - for i in 0..n { - if !index.is_deleted(i as u32) - && let Some(v) = index.get_vector(i as u32) - { - refs.push(v); - } - } - if refs.is_empty() { - return None; - } - - let codec = Sq8Codec::calibrate(&refs, dim); - - let mut data = Vec::with_capacity(dim * n); - for i in 0..n { - if let Some(v) = index.get_vector(i as u32) { - data.extend(codec.quantize(v)); - } else { - data.extend(vec![0u8; dim]); - } - } - - Some((codec, data)) - } - /// Access sealed segments (read-only). pub fn sealed_segments(&self) -> &[SealedSegment] { &self.sealed @@ -276,10 +273,64 @@ impl VectorCollection { } /// Compact sealed segments by removing tombstoned nodes. + /// + /// Rewrites `doc_id_map` and `multi_doc_map` for every sealed segment + /// so that global ids continue to resolve to the correct document + /// strings after local-id renumbering. pub fn compact(&mut self) -> usize { let mut total_removed = 0; for seg in &mut self.sealed { - total_removed += seg.index.compact(); + let base_id = seg.base_id; + let (removed, id_map) = seg.index.compact_with_map(); + total_removed += removed; + if removed == 0 { + continue; + } + + // Rebuild doc_id_map for entries in [base_id, base_id + id_map.len()). + let segment_end = base_id as u64 + id_map.len() as u64; + let doc_keys: Vec = self + .doc_id_map + .keys() + .copied() + .filter(|&k| (k as u64) >= base_id as u64 && (k as u64) < segment_end) + .collect(); + // Two-phase: remove all old entries first, then insert new ones so + // we don't clobber a freshly-remapped entry with a later tombstone + // removal. + let mut new_entries: Vec<(u32, String)> = Vec::with_capacity(doc_keys.len()); + for old_global in &doc_keys { + let doc = self.doc_id_map.remove(old_global); + let old_local = (old_global - base_id) as usize; + let new_local = id_map[old_local]; + if new_local != u32::MAX + && let Some(doc) = doc + { + new_entries.push((base_id + new_local, doc)); + } + } + for (k, v) in new_entries { + self.doc_id_map.insert(k, v); + } + + // Rewrite multi_doc_map entries for this segment. + for ids in self.multi_doc_map.values_mut() { + ids.retain_mut(|vid| { + let v = *vid; + if (v as u64) >= base_id as u64 && (v as u64) < segment_end { + let old_local = (v - base_id) as usize; + let new_local = id_map[old_local]; + if new_local == u32::MAX { + false + } else { + *vid = base_id + new_local; + true + } + } else { + true + } + }); + } } total_removed } @@ -382,12 +433,20 @@ impl VectorCollection { 0.0 }; - let quantization = if self.sealed.iter().any(|s| s.sq8.is_some()) { + let quantization = if self.sealed.iter().any(|s| s.pq.is_some()) { + nodedb_types::VectorIndexQuantization::Pq + } else if self.sealed.iter().any(|s| s.sq8.is_some()) { nodedb_types::VectorIndexQuantization::Sq8 } else { nodedb_types::VectorIndexQuantization::None }; + let index_type = match self.index_config.index_type { + IndexType::HnswPq => nodedb_types::VectorIndexType::HnswPq, + IndexType::IvfPq => nodedb_types::VectorIndexType::IvfPq, + IndexType::Hnsw => nodedb_types::VectorIndexType::Hnsw, + }; + let hnsw_mem: usize = self .sealed .iter() @@ -422,7 +481,7 @@ impl VectorCollection { memory_bytes, disk_bytes, build_in_progress: !self.building.is_empty(), - index_type: nodedb_types::VectorIndexType::Hnsw, + index_type, hnsw_m: self.params.m, hnsw_m0: self.params.m0, hnsw_ef_construction: self.params.ef_construction, diff --git a/nodedb-vector/src/collection/mod.rs b/nodedb-vector/src/collection/mod.rs index 07c118ab..d316ba17 100644 --- a/nodedb-vector/src/collection/mod.rs +++ b/nodedb-vector/src/collection/mod.rs @@ -1,6 +1,7 @@ pub mod budget; pub mod checkpoint; pub mod lifecycle; +pub mod quantize; pub mod search; pub mod segment; pub mod tier; diff --git a/nodedb-vector/src/collection/quantize.rs b/nodedb-vector/src/collection/quantize.rs new file mode 100644 index 00000000..d7ec5a89 --- /dev/null +++ b/nodedb-vector/src/collection/quantize.rs @@ -0,0 +1,113 @@ +//! Quantizer training helpers for `VectorCollection`. +//! +//! Split from `lifecycle.rs` to keep that file under the 500-line cap. +//! All methods here are `impl VectorCollection` blocks — Rust allows a +//! type's impl to be split across files. + +use crate::hnsw::{HnswIndex, HnswParams}; +use crate::index_config::{IndexConfig, IndexType}; +use crate::quantize::pq::PqCodec; +use crate::quantize::sq8::Sq8Codec; + +use super::lifecycle::VectorCollection; +use super::segment::DEFAULT_SEAL_THRESHOLD; + +impl VectorCollection { + /// Convenience constructor for PQ-configured collections. + /// + /// Equivalent to building a full `IndexConfig` with + /// `index_type = HnswPq` and the given `pq_m`. + pub fn with_pq_config(dim: usize, hnsw: HnswParams, pq_m: usize) -> Self { + let config = IndexConfig { + hnsw, + index_type: IndexType::HnswPq, + pq_m, + ..IndexConfig::default() + }; + Self::with_index_config(dim, config) + } + + /// Convenience constructor for PQ-configured collections with a custom + /// seal threshold. + pub fn with_seal_threshold_and_pq_config( + dim: usize, + hnsw: HnswParams, + pq_m: usize, + seal_threshold: usize, + ) -> Self { + let config = IndexConfig { + hnsw, + index_type: IndexType::HnswPq, + pq_m, + ..IndexConfig::default() + }; + Self::with_seal_threshold_and_config(dim, config, seal_threshold) + } + + /// Build SQ8 quantized data for an HNSW index. + /// + /// Returns `None` when there are too few live vectors for stable + /// min/max calibration. + pub fn build_sq8_for_index(index: &HnswIndex) -> Option<(Sq8Codec, Vec)> { + if index.live_count() < 1000 { + return None; + } + let dim = index.dim(); + let n = index.len(); + + let mut refs: Vec<&[f32]> = Vec::with_capacity(n); + for i in 0..n { + if !index.is_deleted(i as u32) + && let Some(v) = index.get_vector(i as u32) + { + refs.push(v); + } + } + if refs.is_empty() { + return None; + } + + let codec = Sq8Codec::calibrate(&refs, dim); + + let mut data = Vec::with_capacity(dim * n); + for i in 0..n { + if let Some(v) = index.get_vector(i as u32) { + data.extend(codec.quantize(v)); + } else { + data.extend(vec![0u8; dim]); + } + } + + Some((codec, data)) + } + + /// Train a PQ codec from a built HNSW index's live vectors. + pub fn build_pq_for_index(index: &HnswIndex, pq_m: usize) -> Option<(PqCodec, Vec)> { + let dim = index.dim(); + if pq_m == 0 || !dim.is_multiple_of(pq_m) { + return None; + } + let n = index.len(); + let mut refs: Vec> = Vec::with_capacity(n); + for i in 0..n { + if !index.is_deleted(i as u32) + && let Some(v) = index.get_vector(i as u32) + { + refs.push(v.to_vec()); + } + } + if refs.is_empty() { + return None; + } + let refs_slices: Vec<&[f32]> = refs.iter().map(|v| v.as_slice()).collect(); + let k = 256usize.min(refs.len()); + let codec = PqCodec::train(&refs_slices, dim, pq_m, k, 20); + let codes = codec.encode_batch(&refs_slices); + Some((codec, codes)) + } +} + +// Keep the DEFAULT_SEAL_THRESHOLD import live when future refactors move +// additional ctors into this file; explicitly referenced to suppress +// an otherwise-unused warning. +const _: usize = DEFAULT_SEAL_THRESHOLD; diff --git a/nodedb-vector/src/collection/search.rs b/nodedb-vector/src/collection/search.rs index 060f5b31..eb216bcd 100644 --- a/nodedb-vector/src/collection/search.rs +++ b/nodedb-vector/src/collection/search.rs @@ -1,9 +1,111 @@ //! VectorCollection search: multi-segment merging with SQ8 reranking. -use crate::distance::distance; +use crate::distance::{DistanceMetric, distance}; use crate::hnsw::SearchResult; use super::lifecycle::VectorCollection; +use super::segment::SealedSegment; + +/// Score a single candidate via the SQ8 codec, using the metric-appropriate +/// asymmetric distance. +#[inline] +fn sq8_score( + codec: &crate::quantize::sq8::Sq8Codec, + query: &[f32], + encoded: &[u8], + metric: DistanceMetric, +) -> f32 { + match metric { + DistanceMetric::Cosine => codec.asymmetric_cosine(query, encoded), + DistanceMetric::InnerProduct => codec.asymmetric_ip(query, encoded), + // L2 (and all other metrics that don't have a specialized asymmetric + // form yet) fall back to squared L2 — correct for ordering when the + // metric is L2 and a reasonable proxy otherwise since we rerank with + // exact FP32 below. + _ => codec.asymmetric_l2(query, encoded), + } +} + +/// Candidate-generation + rerank for a sealed segment that has a quantized +/// codec attached. Generates a widened candidate pool via HNSW, re-scores +/// candidates using the quantized codec (this is where SQ8/PQ actually pay +/// off — the FP32 vectors need not be resident), and reranks the top +/// `top_k` via exact FP32 distance from mmap or index storage. +fn quantized_search( + seg: &SealedSegment, + query: &[f32], + top_k: usize, + ef: usize, + metric: DistanceMetric, +) -> Vec { + let rerank_k = top_k.saturating_mul(3).max(20); + let hnsw_candidates = seg.index.search(query, rerank_k, ef); + + // Phase 1: rank candidates by quantized distance. + let mut scored: Vec<(u32, f32)> = if let Some((codec, codes)) = &seg.pq { + let table = codec.build_distance_table(query); + let m = codec.m; + hnsw_candidates + .into_iter() + .filter_map(|r| { + let start = (r.id as usize).checked_mul(m)?; + let end = start.checked_add(m)?; + let slice = codes.get(start..end)?; + Some((r.id, codec.asymmetric_distance(&table, slice))) + }) + .collect() + } else if let Some((codec, data)) = &seg.sq8 { + let dim = codec.dim(); + hnsw_candidates + .into_iter() + .filter_map(|r| { + let start = (r.id as usize).checked_mul(dim)?; + let end = start.checked_add(dim)?; + let slice = data.get(start..end)?; + Some((r.id, sq8_score(codec, query, slice, metric))) + }) + .collect() + } else { + hnsw_candidates + .into_iter() + .map(|r| (r.id, r.distance)) + .collect() + }; + scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + // Keep only the most promising candidates for FP32 rerank. + let keep = rerank_k.min(scored.len()); + scored.truncate(keep); + + // Prefetch FP32 vectors for reranking. + if let Some(mmap) = &seg.mmap_vectors { + let ids: Vec = scored.iter().map(|&(id, _)| id).collect(); + mmap.prefetch_batch(&ids); + } + + // Phase 2: rerank with exact FP32. + let mut reranked: Vec = scored + .into_iter() + .filter_map(|(id, _)| { + let v = if let Some(mmap) = &seg.mmap_vectors { + mmap.get_vector(id)? + } else { + seg.index.get_vector(id)? + }; + Some(SearchResult { + id, + distance: distance(query, v, metric), + }) + }) + .collect(); + reranked.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + reranked.truncate(top_k); + reranked +} impl VectorCollection { /// Search across all segments, merging results by distance. @@ -19,44 +121,8 @@ impl VectorCollection { // Search sealed segments. for seg in &self.sealed { - let results = if let Some(_sq8) = &seg.sq8 { - // Quantized two-phase search: use HNSW graph for O(log N) candidate - // generation, then rerank with exact FP32 distance. - let rerank_k = top_k.saturating_mul(3).max(20); - let hnsw_candidates = seg.index.search(query, rerank_k, ef); - let candidates: Vec<(u32, f32)> = hnsw_candidates - .into_iter() - .map(|r| (r.id, r.distance)) - .collect(); - - // Prefetch FP32 vectors for reranking candidates. - if let Some(mmap) = &seg.mmap_vectors { - let ids: Vec = candidates.iter().map(|&(id, _)| id).collect(); - mmap.prefetch_batch(&ids); - } - - // Phase 2: Rerank with exact FP32 distance. - let mut reranked: Vec = candidates - .iter() - .filter_map(|&(id, _)| { - let v = if let Some(mmap) = &seg.mmap_vectors { - mmap.get_vector(id)? - } else { - seg.index.get_vector(id)? - }; - Some(SearchResult { - id, - distance: distance(query, v, self.params.metric), - }) - }) - .collect(); - reranked.sort_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - reranked.truncate(top_k); - reranked + let results = if seg.pq.is_some() || seg.sq8.is_some() { + quantized_search(seg, query, top_k, ef, self.params.metric) } else { seg.index.search(query, top_k, ef) }; @@ -94,14 +160,18 @@ impl VectorCollection { ) -> Vec { let mut all: Vec = Vec::new(); - let growing_results = self.growing.search_filtered(query, top_k, bitmap); + let growing_results = + self.growing + .search_filtered_offset(query, top_k, bitmap, self.growing_base_id); for mut r in growing_results { r.id += self.growing_base_id; all.push(r); } for seg in &self.sealed { - let results = seg.index.search_with_bitmap_bytes(query, top_k, ef, bitmap); + let results = + seg.index + .search_with_bitmap_bytes_offset(query, top_k, ef, bitmap, seg.base_id); for mut r in results { r.id += seg.base_id; all.push(r); @@ -109,7 +179,9 @@ impl VectorCollection { } for seg in &self.building { - let results = seg.flat.search_filtered(query, top_k, bitmap); + let results = seg + .flat + .search_filtered_offset(query, top_k, bitmap, seg.base_id); for mut r in results { r.id += seg.base_id; all.push(r); diff --git a/nodedb-vector/src/collection/segment.rs b/nodedb-vector/src/collection/segment.rs index 52de6ae0..ded68ce7 100644 --- a/nodedb-vector/src/collection/segment.rs +++ b/nodedb-vector/src/collection/segment.rs @@ -4,6 +4,7 @@ use crate::collection::tier::StorageTier; use crate::flat::FlatIndex; use crate::hnsw::{HnswIndex, HnswParams}; use crate::mmap_segment::MmapVectorSegment; +use crate::quantize::pq::PqCodec; use crate::quantize::sq8::Sq8Codec; /// Default threshold for sealing the growing segment. @@ -44,6 +45,8 @@ pub struct SealedSegment { pub base_id: u32, /// Optional SQ8 quantized vectors for accelerated traversal. pub sq8: Option<(Sq8Codec, Vec)>, + /// Optional PQ-compressed codes (for HnswPq-configured indexes). + pub pq: Option<(PqCodec, Vec)>, /// Storage tier: L0Ram = FP32 in HNSW nodes, L1Nvme = FP32 in mmap file. pub tier: StorageTier, /// mmap-backed vector segment for L1 NVMe tier. diff --git a/nodedb-vector/src/distance/mod.rs b/nodedb-vector/src/distance/mod.rs index 81806b21..8c11911f 100644 --- a/nodedb-vector/src/distance/mod.rs +++ b/nodedb-vector/src/distance/mod.rs @@ -13,6 +13,13 @@ pub use scalar::*; /// feature is enabled; otherwise uses scalar implementations. #[inline] pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 { + assert_eq!( + a.len(), + b.len(), + "distance: length mismatch (a.len()={}, b.len()={})", + a.len(), + b.len() + ); #[cfg(feature = "simd")] { let rt = simd::runtime(); diff --git a/nodedb-vector/src/distance/simd.rs b/nodedb-vector/src/distance/simd.rs deleted file mode 100644 index 9af97898..00000000 --- a/nodedb-vector/src/distance/simd.rs +++ /dev/null @@ -1,504 +0,0 @@ -//! Runtime SIMD dispatch for vector distance and bitmap operations. -//! -//! Detects CPU features at startup and selects the fastest available -//! kernel for each operation. A single binary supports all targets: -//! -//! - AVX-512 (512-bit, 16 floats/op) — Intel Xeon, AMD Zen 4+ -//! - AVX2+FMA (256-bit, 8 floats/op) — most x86_64 since 2013 -//! - NEON (128-bit, 4 floats/op) — ARM64 (Graviton, Apple Silicon) -//! - Scalar fallback — auto-vectorized loops - -/// Selected SIMD runtime — function pointers to the best available kernels. -pub struct SimdRuntime { - pub l2_squared: fn(&[f32], &[f32]) -> f32, - pub cosine_distance: fn(&[f32], &[f32]) -> f32, - pub neg_inner_product: fn(&[f32], &[f32]) -> f32, - pub hamming: fn(&[u8], &[u8]) -> u32, - pub name: &'static str, -} - -impl SimdRuntime { - /// Detect CPU features and select the best kernels. - pub fn detect() -> Self { - #[cfg(target_arch = "x86_64")] - { - if std::is_x86_feature_detected!("avx512f") { - return Self { - l2_squared: avx512::l2_squared, - cosine_distance: avx512::cosine_distance, - neg_inner_product: avx512::neg_inner_product, - hamming: fast_hamming, - name: "avx512", - }; - } - if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") { - return Self { - l2_squared: avx2::l2_squared, - cosine_distance: avx2::cosine_distance, - neg_inner_product: avx2::neg_inner_product, - hamming: fast_hamming, - name: "avx2+fma", - }; - } - } - #[cfg(target_arch = "aarch64")] - { - return Self { - l2_squared: neon::l2_squared, - cosine_distance: neon::cosine_distance, - neg_inner_product: neon::neg_inner_product, - hamming: fast_hamming, - name: "neon", - }; - } - #[allow(unreachable_code)] - Self { - l2_squared: scalar_l2, - cosine_distance: scalar_cosine, - neg_inner_product: scalar_ip, - hamming: fast_hamming, - name: "scalar", - } - } -} - -/// Global SIMD runtime — initialized once, used everywhere. -static RUNTIME: std::sync::OnceLock = std::sync::OnceLock::new(); - -/// Get the global SIMD runtime (auto-detects on first call). -pub fn runtime() -> &'static SimdRuntime { - RUNTIME.get_or_init(SimdRuntime::detect) -} - -// ── Scalar fallback ── - -fn scalar_l2(a: &[f32], b: &[f32]) -> f32 { - let mut sum = 0.0f32; - for i in 0..a.len() { - let d = a[i] - b[i]; - sum += d * d; - } - sum -} - -fn scalar_cosine(a: &[f32], b: &[f32]) -> f32 { - let mut dot = 0.0f32; - let mut na = 0.0f32; - let mut nb = 0.0f32; - for i in 0..a.len() { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } -} - -fn scalar_ip(a: &[f32], b: &[f32]) -> f32 { - let mut dot = 0.0f32; - for i in 0..a.len() { - dot += a[i] * b[i]; - } - -dot -} - -/// Fast Hamming distance using u64 POPCNT (available on all modern CPUs). -fn fast_hamming(a: &[u8], b: &[u8]) -> u32 { - let mut dist = 0u32; - let chunks = a.len() / 8; - for i in 0..chunks { - let off = i * 8; - let xa = u64::from_le_bytes([ - a[off], - a[off + 1], - a[off + 2], - a[off + 3], - a[off + 4], - a[off + 5], - a[off + 6], - a[off + 7], - ]); - let xb = u64::from_le_bytes([ - b[off], - b[off + 1], - b[off + 2], - b[off + 3], - b[off + 4], - b[off + 5], - b[off + 6], - b[off + 7], - ]); - dist += (xa ^ xb).count_ones(); - } - for i in (chunks * 8)..a.len() { - dist += (a[i] ^ b[i]).count_ones(); - } - dist -} - -// ── AVX2+FMA kernels ── - -#[cfg(target_arch = "x86_64")] -mod avx2 { - pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { - // SAFETY: caller verified avx2+fma via is_x86_feature_detected. - unsafe { l2_squared_impl(a, b) } - } - - #[target_feature(enable = "avx2,fma")] - unsafe fn l2_squared_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut sum = _mm256_setzero_ps(); - let chunks = n / 8; - for i in 0..chunks { - let off = i * 8; - let va = _mm256_loadu_ps(a.as_ptr().add(off)); - let vb = _mm256_loadu_ps(b.as_ptr().add(off)); - let diff = _mm256_sub_ps(va, vb); - sum = _mm256_fmadd_ps(diff, diff, sum); - } - let mut result = hsum256(sum); - for i in (chunks * 8)..n { - let d = a[i] - b[i]; - result += d * d; - } - result - } - } - - pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { - unsafe { cosine_impl(a, b) } - } - - #[target_feature(enable = "avx2,fma")] - unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm256_setzero_ps(); - let mut vna = _mm256_setzero_ps(); - let mut vnb = _mm256_setzero_ps(); - let chunks = n / 8; - for i in 0..chunks { - let off = i * 8; - let va = _mm256_loadu_ps(a.as_ptr().add(off)); - let vb = _mm256_loadu_ps(b.as_ptr().add(off)); - vdot = _mm256_fmadd_ps(va, vb, vdot); - vna = _mm256_fmadd_ps(va, va, vna); - vnb = _mm256_fmadd_ps(vb, vb, vnb); - } - let mut dot = hsum256(vdot); - let mut na = hsum256(vna); - let mut nb = hsum256(vnb); - for i in (chunks * 8)..n { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } - } - } - - pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { - unsafe { ip_impl(a, b) } - } - - #[target_feature(enable = "avx2,fma")] - unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm256_setzero_ps(); - let chunks = n / 8; - for i in 0..chunks { - let off = i * 8; - let va = _mm256_loadu_ps(a.as_ptr().add(off)); - let vb = _mm256_loadu_ps(b.as_ptr().add(off)); - vdot = _mm256_fmadd_ps(va, vb, vdot); - } - let mut dot = hsum256(vdot); - for i in (chunks * 8)..n { - dot += a[i] * b[i]; - } - -dot - } - } - - /// Horizontal sum of 8 × f32 in a __m256. - #[target_feature(enable = "avx2")] - unsafe fn hsum256(v: std::arch::x86_64::__m256) -> f32 { - use std::arch::x86_64::*; - let hi = _mm256_extractf128_ps(v, 1); - let lo = _mm256_castps256_ps128(v); - let sum128 = _mm_add_ps(lo, hi); - let shuf = _mm_movehdup_ps(sum128); - let sums = _mm_add_ps(sum128, shuf); - let shuf2 = _mm_movehl_ps(sums, sums); - let sums2 = _mm_add_ss(sums, shuf2); - _mm_cvtss_f32(sums2) - } -} - -// ── AVX-512 kernels ── - -#[cfg(target_arch = "x86_64")] -mod avx512 { - pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { - unsafe { l2_impl(a, b) } - } - - #[target_feature(enable = "avx512f")] - unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut sum = _mm512_setzero_ps(); - let chunks = n / 16; - for i in 0..chunks { - let off = i * 16; - let va = _mm512_loadu_ps(a.as_ptr().add(off)); - let vb = _mm512_loadu_ps(b.as_ptr().add(off)); - let diff = _mm512_sub_ps(va, vb); - sum = _mm512_fmadd_ps(diff, diff, sum); - } - let mut result = _mm512_reduce_add_ps(sum); - for i in (chunks * 16)..n { - let d = a[i] - b[i]; - result += d * d; - } - result - } - } - - pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { - unsafe { cosine_impl(a, b) } - } - - #[target_feature(enable = "avx512f")] - unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm512_setzero_ps(); - let mut vna = _mm512_setzero_ps(); - let mut vnb = _mm512_setzero_ps(); - let chunks = n / 16; - for i in 0..chunks { - let off = i * 16; - let va = _mm512_loadu_ps(a.as_ptr().add(off)); - let vb = _mm512_loadu_ps(b.as_ptr().add(off)); - vdot = _mm512_fmadd_ps(va, vb, vdot); - vna = _mm512_fmadd_ps(va, va, vna); - vnb = _mm512_fmadd_ps(vb, vb, vnb); - } - let mut dot = _mm512_reduce_add_ps(vdot); - let mut na = _mm512_reduce_add_ps(vna); - let mut nb = _mm512_reduce_add_ps(vnb); - for i in (chunks * 16)..n { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } - } - } - - pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { - unsafe { ip_impl(a, b) } - } - - #[target_feature(enable = "avx512f")] - unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm512_setzero_ps(); - let chunks = n / 16; - for i in 0..chunks { - let off = i * 16; - let va = _mm512_loadu_ps(a.as_ptr().add(off)); - let vb = _mm512_loadu_ps(b.as_ptr().add(off)); - vdot = _mm512_fmadd_ps(va, vb, vdot); - } - let mut dot = _mm512_reduce_add_ps(vdot); - for i in (chunks * 16)..n { - dot += a[i] * b[i]; - } - -dot - } - } -} - -// ── NEON kernels (ARM64) ── - -#[cfg(target_arch = "aarch64")] -mod neon { - pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { - unsafe { l2_impl(a, b) } - } - - unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { - use std::arch::aarch64::*; - let n = a.len(); - let mut sum = vdupq_n_f32(0.0); - let chunks = n / 4; - for i in 0..chunks { - let off = i * 4; - let va = vld1q_f32(a.as_ptr().add(off)); - let vb = vld1q_f32(b.as_ptr().add(off)); - let diff = vsubq_f32(va, vb); - sum = vfmaq_f32(sum, diff, diff); - } - let mut result = vaddvq_f32(sum); - for i in (chunks * 4)..n { - let d = a[i] - b[i]; - result += d * d; - } - result - } - - pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { - unsafe { cosine_impl(a, b) } - } - - unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { - use std::arch::aarch64::*; - let n = a.len(); - let mut vdot = vdupq_n_f32(0.0); - let mut vna = vdupq_n_f32(0.0); - let mut vnb = vdupq_n_f32(0.0); - let chunks = n / 4; - for i in 0..chunks { - let off = i * 4; - let va = vld1q_f32(a.as_ptr().add(off)); - let vb = vld1q_f32(b.as_ptr().add(off)); - vdot = vfmaq_f32(vdot, va, vb); - vna = vfmaq_f32(vna, va, va); - vnb = vfmaq_f32(vnb, vb, vb); - } - let mut dot = vaddvq_f32(vdot); - let mut na = vaddvq_f32(vna); - let mut nb = vaddvq_f32(vnb); - for i in (chunks * 4)..n { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } - } - - pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { - unsafe { ip_impl(a, b) } - } - - unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { - use std::arch::aarch64::*; - let n = a.len(); - let mut vdot = vdupq_n_f32(0.0); - let chunks = n / 4; - for i in 0..chunks { - let off = i * 4; - let va = vld1q_f32(a.as_ptr().add(off)); - let vb = vld1q_f32(b.as_ptr().add(off)); - vdot = vfmaq_f32(vdot, va, vb); - } - let mut dot = vaddvq_f32(vdot); - for i in (chunks * 4)..n { - dot += a[i] * b[i]; - } - -dot - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn runtime_detects_features() { - let rt = SimdRuntime::detect(); - // Should detect at least scalar on any platform. - assert!(!rt.name.is_empty()); - tracing::info!("SIMD runtime: {}", rt.name); - } - - #[test] - fn l2_matches_scalar() { - let rt = runtime(); - let a: Vec = (0..768).map(|i| (i as f32) * 0.01).collect(); - let b: Vec = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect(); - - let simd_result = (rt.l2_squared)(&a, &b); - let scalar_result = scalar_l2(&a, &b); - assert!( - (simd_result - scalar_result).abs() < 0.01, - "simd={simd_result}, scalar={scalar_result}" - ); - } - - #[test] - fn cosine_matches_scalar() { - let rt = runtime(); - let a: Vec = (0..768).map(|i| (i as f32).sin()).collect(); - let b: Vec = (0..768).map(|i| (i as f32).cos()).collect(); - - let simd_result = (rt.cosine_distance)(&a, &b); - let scalar_result = scalar_cosine(&a, &b); - assert!( - (simd_result - scalar_result).abs() < 0.001, - "simd={simd_result}, scalar={scalar_result}" - ); - } - - #[test] - fn ip_matches_scalar() { - let rt = runtime(); - let a: Vec = (0..128).map(|i| (i as f32) * 0.1).collect(); - let b: Vec = (0..128).map(|i| (i as f32) * 0.2).collect(); - - let simd_result = (rt.neg_inner_product)(&a, &b); - let scalar_result = scalar_ip(&a, &b); - assert!( - (simd_result - scalar_result).abs() < 0.1, - "simd={simd_result}, scalar={scalar_result}" - ); - } - - #[test] - fn hamming_matches() { - let a = vec![0b10101010u8; 16]; - let b = vec![0b01010101u8; 16]; - assert_eq!(fast_hamming(&a, &b), 128); // all 128 bits differ - } - - #[test] - fn small_vectors() { - let rt = runtime(); - // Vectors smaller than SIMD width — tests remainder handling. - let a = [1.0f32, 2.0, 3.0]; - let b = [4.0f32, 5.0, 6.0]; - let l2 = (rt.l2_squared)(&a, &b); - assert!((l2 - 27.0).abs() < 0.01); // (3² + 3² + 3²) = 27 - } -} diff --git a/nodedb-vector/src/distance/simd/avx2.rs b/nodedb-vector/src/distance/simd/avx2.rs new file mode 100644 index 00000000..ad530afa --- /dev/null +++ b/nodedb-vector/src/distance/simd/avx2.rs @@ -0,0 +1,114 @@ +//! AVX2+FMA kernels for x86_64. + +#![cfg(target_arch = "x86_64")] + +pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 l2: length mismatch"); + // SAFETY: caller verified avx2+fma via is_x86_feature_detected. + unsafe { l2_squared_impl(a, b) } +} + +#[target_feature(enable = "avx2,fma")] +unsafe fn l2_squared_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 l2_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut sum = _mm256_setzero_ps(); + let chunks = n / 8; + for i in 0..chunks { + let off = i * 8; + let va = _mm256_loadu_ps(a.as_ptr().add(off)); + let vb = _mm256_loadu_ps(b.as_ptr().add(off)); + let diff = _mm256_sub_ps(va, vb); + sum = _mm256_fmadd_ps(diff, diff, sum); + } + let mut result = hsum256(sum); + for i in (chunks * 8)..n { + let d = a[i] - b[i]; + result += d * d; + } + result + } +} + +pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 cosine: length mismatch"); + unsafe { cosine_impl(a, b) } +} + +#[target_feature(enable = "avx2,fma")] +unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 cosine_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm256_setzero_ps(); + let mut vna = _mm256_setzero_ps(); + let mut vnb = _mm256_setzero_ps(); + let chunks = n / 8; + for i in 0..chunks { + let off = i * 8; + let va = _mm256_loadu_ps(a.as_ptr().add(off)); + let vb = _mm256_loadu_ps(b.as_ptr().add(off)); + vdot = _mm256_fmadd_ps(va, vb, vdot); + vna = _mm256_fmadd_ps(va, va, vna); + vnb = _mm256_fmadd_ps(vb, vb, vnb); + } + let mut dot = hsum256(vdot); + let mut na = hsum256(vna); + let mut nb = hsum256(vnb); + for i in (chunks * 8)..n { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } + } +} + +pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 ip: length mismatch"); + unsafe { ip_impl(a, b) } +} + +#[target_feature(enable = "avx2,fma")] +unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 ip_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm256_setzero_ps(); + let chunks = n / 8; + for i in 0..chunks { + let off = i * 8; + let va = _mm256_loadu_ps(a.as_ptr().add(off)); + let vb = _mm256_loadu_ps(b.as_ptr().add(off)); + vdot = _mm256_fmadd_ps(va, vb, vdot); + } + let mut dot = hsum256(vdot); + for i in (chunks * 8)..n { + dot += a[i] * b[i]; + } + -dot + } +} + +/// Horizontal sum of 8 × f32 in a __m256. +#[target_feature(enable = "avx2")] +unsafe fn hsum256(v: std::arch::x86_64::__m256) -> f32 { + use std::arch::x86_64::*; + let hi = _mm256_extractf128_ps(v, 1); + let lo = _mm256_castps256_ps128(v); + let sum128 = _mm_add_ps(lo, hi); + let shuf = _mm_movehdup_ps(sum128); + let sums = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sums, sums); + let sums2 = _mm_add_ss(sums, shuf2); + _mm_cvtss_f32(sums2) +} diff --git a/nodedb-vector/src/distance/simd/avx512.rs b/nodedb-vector/src/distance/simd/avx512.rs new file mode 100644 index 00000000..037ae6fd --- /dev/null +++ b/nodedb-vector/src/distance/simd/avx512.rs @@ -0,0 +1,99 @@ +//! AVX-512 kernels for x86_64. + +#![cfg(target_arch = "x86_64")] + +pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 l2: length mismatch"); + unsafe { l2_impl(a, b) } +} + +#[target_feature(enable = "avx512f")] +unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 l2_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut sum = _mm512_setzero_ps(); + let chunks = n / 16; + for i in 0..chunks { + let off = i * 16; + let va = _mm512_loadu_ps(a.as_ptr().add(off)); + let vb = _mm512_loadu_ps(b.as_ptr().add(off)); + let diff = _mm512_sub_ps(va, vb); + sum = _mm512_fmadd_ps(diff, diff, sum); + } + let mut result = _mm512_reduce_add_ps(sum); + for i in (chunks * 16)..n { + let d = a[i] - b[i]; + result += d * d; + } + result + } +} + +pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 cosine: length mismatch"); + unsafe { cosine_impl(a, b) } +} + +#[target_feature(enable = "avx512f")] +unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 cosine_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm512_setzero_ps(); + let mut vna = _mm512_setzero_ps(); + let mut vnb = _mm512_setzero_ps(); + let chunks = n / 16; + for i in 0..chunks { + let off = i * 16; + let va = _mm512_loadu_ps(a.as_ptr().add(off)); + let vb = _mm512_loadu_ps(b.as_ptr().add(off)); + vdot = _mm512_fmadd_ps(va, vb, vdot); + vna = _mm512_fmadd_ps(va, va, vna); + vnb = _mm512_fmadd_ps(vb, vb, vnb); + } + let mut dot = _mm512_reduce_add_ps(vdot); + let mut na = _mm512_reduce_add_ps(vna); + let mut nb = _mm512_reduce_add_ps(vnb); + for i in (chunks * 16)..n { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } + } +} + +pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 ip: length mismatch"); + unsafe { ip_impl(a, b) } +} + +#[target_feature(enable = "avx512f")] +unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 ip_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm512_setzero_ps(); + let chunks = n / 16; + for i in 0..chunks { + let off = i * 16; + let va = _mm512_loadu_ps(a.as_ptr().add(off)); + let vb = _mm512_loadu_ps(b.as_ptr().add(off)); + vdot = _mm512_fmadd_ps(va, vb, vdot); + } + let mut dot = _mm512_reduce_add_ps(vdot); + for i in (chunks * 16)..n { + dot += a[i] * b[i]; + } + -dot + } +} diff --git a/nodedb-vector/src/distance/simd/hamming.rs b/nodedb-vector/src/distance/simd/hamming.rs new file mode 100644 index 00000000..0c8ef558 --- /dev/null +++ b/nodedb-vector/src/distance/simd/hamming.rs @@ -0,0 +1,35 @@ +//! Fast Hamming distance using u64 POPCNT. + +pub fn fast_hamming(a: &[u8], b: &[u8]) -> u32 { + assert_eq!(a.len(), b.len(), "fast_hamming: length mismatch"); + let mut dist = 0u32; + let chunks = a.len() / 8; + for i in 0..chunks { + let off = i * 8; + let xa = u64::from_le_bytes([ + a[off], + a[off + 1], + a[off + 2], + a[off + 3], + a[off + 4], + a[off + 5], + a[off + 6], + a[off + 7], + ]); + let xb = u64::from_le_bytes([ + b[off], + b[off + 1], + b[off + 2], + b[off + 3], + b[off + 4], + b[off + 5], + b[off + 6], + b[off + 7], + ]); + dist += (xa ^ xb).count_ones(); + } + for i in (chunks * 8)..a.len() { + dist += (a[i] ^ b[i]).count_ones(); + } + dist +} diff --git a/nodedb-vector/src/distance/simd/mod.rs b/nodedb-vector/src/distance/simd/mod.rs new file mode 100644 index 00000000..c5b1766b --- /dev/null +++ b/nodedb-vector/src/distance/simd/mod.rs @@ -0,0 +1,14 @@ +//! Runtime SIMD dispatch for vector distance and bitmap operations. + +pub mod hamming; +pub mod runtime; +pub mod scalar; + +#[cfg(target_arch = "x86_64")] +pub mod avx2; +#[cfg(target_arch = "x86_64")] +pub mod avx512; +#[cfg(target_arch = "aarch64")] +pub mod neon; + +pub use runtime::{SimdRuntime, runtime}; diff --git a/nodedb-vector/src/distance/simd/neon.rs b/nodedb-vector/src/distance/simd/neon.rs new file mode 100644 index 00000000..6600f779 --- /dev/null +++ b/nodedb-vector/src/distance/simd/neon.rs @@ -0,0 +1,96 @@ +//! NEON kernels for ARM64. + +#![cfg(target_arch = "aarch64")] + +pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon l2: length mismatch"); + unsafe { l2_impl(a, b) } +} + +unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon l2_impl: length mismatch"); + unsafe { + use std::arch::aarch64::*; + let n = a.len(); + let mut sum = vdupq_n_f32(0.0); + let chunks = n / 4; + for i in 0..chunks { + let off = i * 4; + let va = vld1q_f32(a.as_ptr().add(off)); + let vb = vld1q_f32(b.as_ptr().add(off)); + let diff = vsubq_f32(va, vb); + sum = vfmaq_f32(sum, diff, diff); + } + let mut result = vaddvq_f32(sum); + for i in (chunks * 4)..n { + let d = a[i] - b[i]; + result += d * d; + } + result + } +} + +pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon cosine: length mismatch"); + unsafe { cosine_impl(a, b) } +} + +unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon cosine_impl: length mismatch"); + unsafe { + use std::arch::aarch64::*; + let n = a.len(); + let mut vdot = vdupq_n_f32(0.0); + let mut vna = vdupq_n_f32(0.0); + let mut vnb = vdupq_n_f32(0.0); + let chunks = n / 4; + for i in 0..chunks { + let off = i * 4; + let va = vld1q_f32(a.as_ptr().add(off)); + let vb = vld1q_f32(b.as_ptr().add(off)); + vdot = vfmaq_f32(vdot, va, vb); + vna = vfmaq_f32(vna, va, va); + vnb = vfmaq_f32(vnb, vb, vb); + } + let mut dot = vaddvq_f32(vdot); + let mut na = vaddvq_f32(vna); + let mut nb = vaddvq_f32(vnb); + for i in (chunks * 4)..n { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } + } +} + +pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon ip: length mismatch"); + unsafe { ip_impl(a, b) } +} + +unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon ip_impl: length mismatch"); + unsafe { + use std::arch::aarch64::*; + let n = a.len(); + let mut vdot = vdupq_n_f32(0.0); + let chunks = n / 4; + for i in 0..chunks { + let off = i * 4; + let va = vld1q_f32(a.as_ptr().add(off)); + let vb = vld1q_f32(b.as_ptr().add(off)); + vdot = vfmaq_f32(vdot, va, vb); + } + let mut dot = vaddvq_f32(vdot); + for i in (chunks * 4)..n { + dot += a[i] * b[i]; + } + -dot + } +} diff --git a/nodedb-vector/src/distance/simd/runtime.rs b/nodedb-vector/src/distance/simd/runtime.rs new file mode 100644 index 00000000..a9fd0404 --- /dev/null +++ b/nodedb-vector/src/distance/simd/runtime.rs @@ -0,0 +1,144 @@ +//! Runtime SIMD detection and dispatch table. + +use super::hamming::fast_hamming; +use super::scalar::{scalar_cosine, scalar_ip, scalar_l2}; + +#[cfg(target_arch = "x86_64")] +use super::{avx2, avx512}; + +#[cfg(target_arch = "aarch64")] +use super::neon; + +/// Selected SIMD runtime — function pointers to the best available kernels. +pub struct SimdRuntime { + pub l2_squared: fn(&[f32], &[f32]) -> f32, + pub cosine_distance: fn(&[f32], &[f32]) -> f32, + pub neg_inner_product: fn(&[f32], &[f32]) -> f32, + pub hamming: fn(&[u8], &[u8]) -> u32, + pub name: &'static str, +} + +impl SimdRuntime { + /// Detect CPU features and select the best kernels. + pub fn detect() -> Self { + #[cfg(target_arch = "x86_64")] + { + if std::is_x86_feature_detected!("avx512f") { + return Self { + l2_squared: avx512::l2_squared, + cosine_distance: avx512::cosine_distance, + neg_inner_product: avx512::neg_inner_product, + hamming: fast_hamming, + name: "avx512", + }; + } + if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") { + return Self { + l2_squared: avx2::l2_squared, + cosine_distance: avx2::cosine_distance, + neg_inner_product: avx2::neg_inner_product, + hamming: fast_hamming, + name: "avx2+fma", + }; + } + } + #[cfg(target_arch = "aarch64")] + { + return Self { + l2_squared: neon::l2_squared, + cosine_distance: neon::cosine_distance, + neg_inner_product: neon::neg_inner_product, + hamming: fast_hamming, + name: "neon", + }; + } + #[allow(unreachable_code)] + Self { + l2_squared: scalar_l2, + cosine_distance: scalar_cosine, + neg_inner_product: scalar_ip, + hamming: fast_hamming, + name: "scalar", + } + } +} + +/// Global SIMD runtime — initialized once, used everywhere. +static RUNTIME: std::sync::OnceLock = std::sync::OnceLock::new(); + +/// Get the global SIMD runtime (auto-detects on first call). +pub fn runtime() -> &'static SimdRuntime { + RUNTIME.get_or_init(SimdRuntime::detect) +} + +#[cfg(test)] +mod tests { + use super::super::hamming::fast_hamming; + use super::super::scalar::{scalar_cosine, scalar_ip, scalar_l2}; + use super::*; + + #[test] + fn runtime_detects_features() { + let rt = SimdRuntime::detect(); + assert!(!rt.name.is_empty()); + tracing::info!("SIMD runtime: {}", rt.name); + } + + #[test] + fn l2_matches_scalar() { + let rt = runtime(); + let a: Vec = (0..768).map(|i| (i as f32) * 0.01).collect(); + let b: Vec = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect(); + + let simd_result = (rt.l2_squared)(&a, &b); + let scalar_result = scalar_l2(&a, &b); + assert!( + (simd_result - scalar_result).abs() < 0.01, + "simd={simd_result}, scalar={scalar_result}" + ); + } + + #[test] + fn cosine_matches_scalar() { + let rt = runtime(); + let a: Vec = (0..768).map(|i| (i as f32).sin()).collect(); + let b: Vec = (0..768).map(|i| (i as f32).cos()).collect(); + + let simd_result = (rt.cosine_distance)(&a, &b); + let scalar_result = scalar_cosine(&a, &b); + assert!( + (simd_result - scalar_result).abs() < 0.001, + "simd={simd_result}, scalar={scalar_result}" + ); + } + + #[test] + fn ip_matches_scalar() { + let rt = runtime(); + let a: Vec = (0..128).map(|i| (i as f32) * 0.1).collect(); + let b: Vec = (0..128).map(|i| (i as f32) * 0.2).collect(); + + let simd_result = (rt.neg_inner_product)(&a, &b); + let scalar_result = scalar_ip(&a, &b); + assert!( + (simd_result - scalar_result).abs() < 0.1, + "simd={simd_result}, scalar={scalar_result}" + ); + } + + #[test] + fn hamming_matches() { + let a = vec![0b10101010u8; 16]; + let b = vec![0b01010101u8; 16]; + assert_eq!(fast_hamming(&a, &b), 128); + } + + #[test] + fn small_vectors() { + let rt = runtime(); + let a = [1.0f32, 2.0, 3.0]; + let b = [4.0f32, 5.0, 6.0]; + let l2 = (rt.l2_squared)(&a, &b); + assert!((l2 - 27.0).abs() < 0.01); + } +} diff --git a/nodedb-vector/src/distance/simd/scalar.rs b/nodedb-vector/src/distance/simd/scalar.rs new file mode 100644 index 00000000..c3e52028 --- /dev/null +++ b/nodedb-vector/src/distance/simd/scalar.rs @@ -0,0 +1,38 @@ +//! Scalar fallback kernels for L2, cosine, and inner product. + +pub fn scalar_l2(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "scalar_l2: length mismatch"); + let mut sum = 0.0f32; + for i in 0..a.len() { + let d = a[i] - b[i]; + sum += d * d; + } + sum +} + +pub fn scalar_cosine(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "scalar_cosine: length mismatch"); + let mut dot = 0.0f32; + let mut na = 0.0f32; + let mut nb = 0.0f32; + for i in 0..a.len() { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } +} + +pub fn scalar_ip(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "scalar_ip: length mismatch"); + let mut dot = 0.0f32; + for i in 0..a.len() { + dot += a[i] * b[i]; + } + -dot +} diff --git a/nodedb-vector/src/flat.rs b/nodedb-vector/src/flat.rs index 81fb01d5..f75ebf21 100644 --- a/nodedb-vector/src/flat.rs +++ b/nodedb-vector/src/flat.rs @@ -105,6 +105,21 @@ impl FlatIndex { /// Search with a pre-filter bitmap (byte-array format). pub fn search_filtered(&self, query: &[f32], top_k: usize, bitmap: &[u8]) -> Vec { + self.search_filtered_offset(query, top_k, bitmap, 0) + } + + /// Search with a pre-filter bitmap applying a global id offset. + /// + /// The bitmap is interpreted in a shifted id space: bit `i + id_offset` + /// tests local id `i`. Used by multi-segment collections where the + /// bitmap holds GLOBAL vector ids. + pub fn search_filtered_offset( + &self, + query: &[f32], + top_k: usize, + bitmap: &[u8], + id_offset: u32, + ) -> Vec { assert_eq!(query.len(), self.dim); let n = self.len(); if n == 0 || top_k == 0 { @@ -116,8 +131,9 @@ impl FlatIndex { if self.deleted[i] { continue; } - let byte_idx = i / 8; - let bit_idx = i % 8; + let global = i + id_offset as usize; + let byte_idx = global / 8; + let bit_idx = global % 8; if byte_idx >= bitmap.len() || (bitmap[byte_idx] & (1 << bit_idx)) == 0 { continue; } @@ -159,6 +175,17 @@ impl FlatIndex { } pub fn get_vector(&self, id: u32) -> Option<&[f32]> { + let idx = id as usize; + if idx < self.deleted.len() && !self.deleted[idx] { + let start = idx * self.dim; + Some(&self.data[start..start + self.dim]) + } else { + None + } + } + + /// Raw access bypassing tombstone filter — used by snapshot/restore. + pub fn get_vector_raw(&self, id: u32) -> Option<&[f32]> { let idx = id as usize; if idx < self.deleted.len() { let start = idx * self.dim; @@ -168,6 +195,28 @@ impl FlatIndex { } } + /// Whether the given local id has been tombstoned. + pub fn is_deleted(&self, id: u32) -> bool { + let idx = id as usize; + idx < self.deleted.len() && self.deleted[idx] + } + + /// Insert a vector that is already tombstoned (for checkpoint restore). + pub fn insert_tombstoned(&mut self, vector: Vec) -> u32 { + assert_eq!( + vector.len(), + self.dim, + "dimension mismatch: expected {}, got {}", + self.dim, + vector.len() + ); + let id = self.len() as u32; + self.data.extend_from_slice(&vector); + self.deleted.push(true); + // No live_count increment — it's dead on arrival. + id + } + pub fn dim(&self) -> usize { self.dim } diff --git a/nodedb-vector/src/hnsw/build.rs b/nodedb-vector/src/hnsw/build.rs index 503953a4..140c2d5c 100644 --- a/nodedb-vector/src/hnsw/build.rs +++ b/nodedb-vector/src/hnsw/build.rs @@ -47,7 +47,7 @@ impl HnswIndex { // Phase 1: Greedy descent from top layer to new_layer + 1. if self.max_layer > new_layer { for layer in (new_layer + 1..=self.max_layer).rev() { - let results = search_layer(self, &query, current_ep, 1, layer, None); + let results = search_layer(self, &query, current_ep, 1, layer, None, 0); if let Some(nearest) = results.first() { current_ep = nearest.id; } @@ -58,7 +58,7 @@ impl HnswIndex { let insert_top = new_layer.min(self.max_layer); for layer in (0..=insert_top).rev() { let ef = self.params.ef_construction; - let candidates = search_layer(self, &query, current_ep, ef, layer, None); + let candidates = search_layer(self, &query, current_ep, ef, layer, None, 0); let m = self.max_neighbors(layer); let selected = select_neighbors_heuristic(self, &candidates, m); diff --git a/nodedb-vector/src/hnsw/graph.rs b/nodedb-vector/src/hnsw/graph.rs index 00a2b99e..72f9b4d0 100644 --- a/nodedb-vector/src/hnsw/graph.rs +++ b/nodedb-vector/src/hnsw/graph.rs @@ -8,6 +8,11 @@ use crate::distance::distance; // Re-export shared params from nodedb-types. pub use nodedb_types::hnsw::HnswParams; +/// Hard cap on the layer assigned to any node during insertion. +/// Standard HNSW practice — prevents pathological RNG draws from inflating +/// `max_layer` and slowing every subsequent search. +pub const MAX_LAYER_CAP: usize = 16; + /// Result of a k-NN search. #[derive(Debug, Clone)] pub struct SearchResult { @@ -254,10 +259,15 @@ impl HnswIndex { } /// Assign a random layer using the exponential distribution. + /// + /// Capped at `MAX_LAYER_CAP` to prevent pathological RNG draws from + /// promoting the index's `max_layer` to hundreds or thousands, which + /// would make every search's Phase-1 greedy descent O(max_layer). pub(crate) fn random_layer(&mut self) -> usize { let ml = 1.0 / (self.params.m as f64).ln(); let r = self.rng.next_f64().max(f64::MIN_POSITIVE); - (-r.ln() * ml).floor() as usize + let layer = (-r.ln() * ml).floor() as usize; + layer.min(MAX_LAYER_CAP) } /// Compute distance between a query vector and a stored node. @@ -279,10 +289,22 @@ impl HnswIndex { } /// Compact the index by removing all tombstoned nodes. + /// + /// Returns the number of removed nodes. See `compact_with_map` for the + /// variant that also returns the old→new id remapping. pub fn compact(&mut self) -> usize { + self.compact_with_map().0 + } + + /// Compact and return both the removed count and the old→new id map. + /// + /// `id_map[old_local]` = new_local, or `u32::MAX` if the node was + /// tombstoned (removed). + pub fn compact_with_map(&mut self) -> (usize, Vec) { let tombstone_count = self.tombstone_count(); if tombstone_count == 0 { - return 0; + let identity: Vec = (0..self.nodes.len() as u32).collect(); + return (0, identity); } self.ensure_mutable_neighbors(); @@ -348,7 +370,7 @@ impl HnswIndex { .unwrap_or(0); self.nodes = new_nodes; - tombstone_count + (tombstone_count, id_map) } } diff --git a/nodedb-vector/src/hnsw/search.rs b/nodedb-vector/src/hnsw/search.rs index c917943f..64ca15b2 100644 --- a/nodedb-vector/src/hnsw/search.rs +++ b/nodedb-vector/src/hnsw/search.rs @@ -31,14 +31,14 @@ impl HnswIndex { // Phase 1: Greedy descent from top layer to layer 1. let mut current_ep = ep; for layer in (1..=self.max_layer).rev() { - let results = search_layer(self, query, current_ep, 1, layer, None); + let results = search_layer(self, query, current_ep, 1, layer, None, 0); if let Some(nearest) = results.first() { current_ep = nearest.id; } } // Phase 2: Beam search at layer 0. - let results = search_layer(self, query, current_ep, ef, 0, None); + let results = search_layer(self, query, current_ep, ef, 0, None, 0); results .into_iter() @@ -51,16 +51,28 @@ impl HnswIndex { } /// Filtered K-NN search with Roaring bitmap pre-filtering. - /// - /// Only nodes whose ID is present in `filter` are included in results. - /// All nodes are still used for graph navigation — this prevents accuracy - /// degradation for selective filters. pub fn search_filtered( &self, query: &[f32], k: usize, ef: usize, filter: &RoaringBitmap, + ) -> Vec { + self.search_filtered_offset(query, k, ef, filter, 0) + } + + /// Filtered K-NN search where the bitmap is keyed in a shifted ID space. + /// + /// `id_offset` is added to local node IDs before testing `filter.contains`. + /// Used by multi-segment collections where the bitmap holds GLOBAL ids + /// and each segment's HNSW nodes are numbered starting at `base_id`. + pub fn search_filtered_offset( + &self, + query: &[f32], + k: usize, + ef: usize, + filter: &RoaringBitmap, + id_offset: u32, ) -> Vec { assert_eq!(query.len(), self.dim, "query dimension mismatch"); if self.is_empty() { @@ -74,13 +86,13 @@ impl HnswIndex { let mut current_ep = ep; for layer in (1..=self.max_layer).rev() { - let results = search_layer(self, query, current_ep, 1, layer, None); + let results = search_layer(self, query, current_ep, 1, layer, None, 0); if let Some(nearest) = results.first() { current_ep = nearest.id; } } - let results = search_layer(self, query, current_ep, ef, 0, Some(filter)); + let results = search_layer(self, query, current_ep, ef, 0, Some(filter), id_offset); results .into_iter() @@ -99,9 +111,22 @@ impl HnswIndex { k: usize, ef: usize, bitmap_bytes: &[u8], + ) -> Vec { + self.search_with_bitmap_bytes_offset(query, k, ef, bitmap_bytes, 0) + } + + /// Deserialize a Roaring bitmap and search with an ID offset applied + /// before testing membership. See `search_filtered_offset` for rationale. + pub fn search_with_bitmap_bytes_offset( + &self, + query: &[f32], + k: usize, + ef: usize, + bitmap_bytes: &[u8], + id_offset: u32, ) -> Vec { match RoaringBitmap::deserialize_from(bitmap_bytes) { - Ok(bitmap) => self.search_filtered(query, k, ef, &bitmap), + Ok(bitmap) => self.search_filtered_offset(query, k, ef, &bitmap, id_offset), Err(_) => self.search(query, k, ef), } } @@ -119,6 +144,7 @@ pub(crate) fn search_layer( ef: usize, layer: usize, filter: Option<&RoaringBitmap>, + id_offset: u32, ) -> Vec { let mut visited: HashSet = HashSet::new(); visited.insert(entry_point); @@ -139,7 +165,7 @@ pub(crate) fn search_layer( return false; } match filter { - Some(f) => f.contains(id), + Some(f) => f.contains(id + id_offset), None => true, } }; diff --git a/nodedb-vector/src/ivf.rs b/nodedb-vector/src/ivf.rs index 37601407..16071f15 100644 --- a/nodedb-vector/src/ivf.rs +++ b/nodedb-vector/src/ivf.rs @@ -220,21 +220,40 @@ fn kmeans_centroids(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> V let mut centroids: Vec> = vec![data[0].to_vec()]; let mut min_dists = vec![f32::MAX; n]; + // Initialize min_dists against the first centroid. + for (i, point) in data.iter().enumerate() { + let d = distance(point, ¢roids[0], DistanceMetric::L2); + if d < min_dists[i] { + min_dists[i] = d; + } + } + + let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42); for _ in 1..k { - let Some(last) = centroids.last() else { break }; + let total: f64 = min_dists.iter().map(|&d| d as f64).sum(); + let next_idx = if total < f64::EPSILON { + 0 + } else { + let target = rng.next_f64() * total; + let mut acc = 0.0f64; + let mut chosen = n - 1; + for (i, &d) in min_dists.iter().enumerate() { + acc += d as f64; + if acc >= target { + chosen = i; + break; + } + } + chosen + }; + centroids.push(data[next_idx].to_vec()); + let last = centroids.last().expect("just pushed"); for (i, point) in data.iter().enumerate() { let d = distance(point, last, DistanceMetric::L2); if d < min_dists[i] { min_dists[i] = d; } } - let best = min_dists - .iter() - .enumerate() - .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(i, _)| i) - .unwrap_or(0); - centroids.push(data[best].to_vec()); } let mut assignments = vec![0usize; n]; diff --git a/nodedb-vector/src/quantize/pq.rs b/nodedb-vector/src/quantize/pq.rs index e7f2ad99..62cafe00 100644 --- a/nodedb-vector/src/quantize/pq.rs +++ b/nodedb-vector/src/quantize/pq.rs @@ -15,7 +15,7 @@ use serde::{Deserialize, Serialize}; /// PQ codec with trained codebooks. -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)] pub struct PqCodec { /// Original vector dimensionality. pub dim: usize, @@ -161,7 +161,8 @@ fn l2_sub(a: &[f32], b: &[f32]) -> f32 { /// Simple k-means clustering for PQ codebook training. /// -/// Uses k-means++ initialization for stable convergence. +/// Uses proper k-means++ initialization (weighted d² sampling) with a +/// deterministic seed so training is reproducible across runs. fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec> { let n = data.len(); if n == 0 || k == 0 { @@ -169,35 +170,48 @@ fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec> = Vec::with_capacity(k); - // First centroid: pick the first data point (deterministic). centroids.push(data[0].to_vec()); let mut min_dists = vec![f32::MAX; n]; - for c in 1..k { - // Update min distances to nearest centroid. + // Update against the first centroid. + for (i, point) in data.iter().enumerate() { + let d = l2_sub(point, ¢roids[0]); + if d < min_dists[i] { + min_dists[i] = d; + } + } + + for _ in 1..k { + let total: f64 = min_dists.iter().map(|&d| d as f64).sum(); + let next_idx = if total < f64::EPSILON { + // All points coincide with existing centroids. + 0 + } else { + let target = rng.next_f64() * total; + let mut acc = 0.0f64; + let mut chosen = n - 1; + for (i, &d) in min_dists.iter().enumerate() { + acc += d as f64; + if acc >= target { + chosen = i; + break; + } + } + chosen + }; + centroids.push(data[next_idx].to_vec()); + // Incrementally update min_dists against the new centroid. + let last = centroids.last().expect("just pushed"); for (i, point) in data.iter().enumerate() { - let d = l2_sub(point, ¢roids[c - 1]); + let d = l2_sub(point, last); if d < min_dists[i] { min_dists[i] = d; } } - // Pick next centroid proportional to d². - let total: f64 = min_dists.iter().map(|&d| d as f64).sum(); - if total < f64::EPSILON { - // All points coincide — duplicate the first centroid. - centroids.push(data[0].to_vec()); - continue; - } - // Deterministic selection: pick the point with max min_dist. - let best_idx = min_dists - .iter() - .enumerate() - .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(i, _)| i) - .unwrap_or(0); - centroids.push(data[best_idx].to_vec()); } // K-means iterations. diff --git a/nodedb-vector/tests/collection_bitmap_filter.rs b/nodedb-vector/tests/collection_bitmap_filter.rs new file mode 100644 index 00000000..8e406a53 --- /dev/null +++ b/nodedb-vector/tests/collection_bitmap_filter.rs @@ -0,0 +1,112 @@ +//! Roaring bitmap pre-filter must use the same ID space across segments. +//! +//! Spec: the query planner builds a Roaring bitmap from GLOBAL vector IDs. +//! `search_with_bitmap_bytes` walks each sealed segment and the segment's +//! HNSW index tests `filter.contains(id)` against the segment-LOCAL id. +//! The collection MUST reconcile the two — either by rewriting the bitmap +//! per-segment (subtract `seg.base_id`) or by applying the offset before +//! `f.contains(id)`. Without that, every segment beyond the first silently +//! drops all filtered candidates because global ≠ local. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; +use roaring::RoaringBitmap; + +fn params() -> HnswParams { + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + } +} + +/// Fill a collection's growing segment, seal it, complete the build, +/// so the next inserts land at `base_id == seal_count`. +fn seal_one(coll: &mut VectorCollection, count: usize) { + for i in 0..count { + coll.insert(vec![i as f32, 0.0]); + } + let req = coll.seal("k").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); +} + +fn bitmap_bytes(ids: impl IntoIterator) -> Vec { + let mut bm = RoaringBitmap::new(); + for id in ids { + bm.insert(id); + } + let mut bytes = Vec::new(); + bm.serialize_into(&mut bytes).unwrap(); + bytes +} + +#[test] +fn bitmap_filter_targets_second_segment_global_ids() { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 50); + seal_one(&mut coll, 50); // segment 0: ids 0..50, base_id = 0 + seal_one(&mut coll, 50); // segment 1: ids 50..100, base_id = 50 + + // Query for a point near id=75 (in segment 1). Filter to only global + // id 75. Correct behavior: returns id=75. Buggy behavior: the second + // segment's bitmap lookup tests local id 25 against a bitmap that + // contains global 75 → zero matches. + let bytes = bitmap_bytes([75u32]); + let results = coll.search_with_bitmap_bytes(&[75.0, 0.0], 1, 64, &bytes); + + assert_eq!( + results.len(), + 1, + "global-id bitmap filter dropped all candidates in segment 1" + ); + assert_eq!(results[0].id, 75); +} + +#[test] +fn bitmap_filter_recovers_many_globals_across_segments() { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 50); + seal_one(&mut coll, 50); + seal_one(&mut coll, 50); + + // Select globals from the second segment only. + let wanted: Vec = (60..70).collect(); + let bytes = bitmap_bytes(wanted.iter().copied()); + + let results = coll.search_with_bitmap_bytes(&[65.0, 0.0], 10, 128, &bytes); + + assert_eq!( + results.len(), + wanted.len(), + "expected all {} second-segment globals to match; got {}", + wanted.len(), + results.len() + ); + let got: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + for id in &wanted { + assert!( + got.contains(id), + "missing expected id {id} from filtered results" + ); + } +} + +#[test] +fn bitmap_filter_first_segment_still_works() { + // Regression guard for the partial-accident: segment 0 has base_id=0 so + // local==global and filtering appears to work. This test pins that down + // so a fix to the second-segment path doesn't regress segment 0. + let mut coll = VectorCollection::with_seal_threshold(2, params(), 50); + seal_one(&mut coll, 50); + seal_one(&mut coll, 50); + + let bytes = bitmap_bytes([10u32, 20, 30]); + let results = coll.search_with_bitmap_bytes(&[20.0, 0.0], 3, 64, &bytes); + let got: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + let expected: std::collections::HashSet = [10u32, 20, 30].into_iter().collect(); + assert_eq!(got, expected); +} diff --git a/nodedb-vector/tests/collection_checkpoint_tombstones.rs b/nodedb-vector/tests/collection_checkpoint_tombstones.rs new file mode 100644 index 00000000..de3b42f9 --- /dev/null +++ b/nodedb-vector/tests/collection_checkpoint_tombstones.rs @@ -0,0 +1,98 @@ +//! Soft-deletes in growing / building segments must survive checkpoint restore. +//! +//! Spec: `delete(id)` on a vector in the growing segment (or a building +//! segment awaiting HNSW completion) tombstones the vector. Checkpoints +//! MUST serialize that tombstone, and `from_checkpoint` MUST apply it so +//! the restored collection reports the same `live_count()` and excludes +//! the deleted vector from `search()` results. +//! +//! Today: +//! - `FlatIndex::get_vector` returns `Some(..)` even for tombstoned +//! slots, so `growing_deleted` is serialized as all-false. +//! - `from_checkpoint` ignores the `growing_deleted` field entirely and +//! re-inserts every vector as live. +//! +//! Result: crash recovery silently resurrects soft-deleted rows — a +//! correctness regression for any workflow using `valid_until` deletes. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::HnswParams; + +fn params() -> HnswParams { + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + } +} + +#[test] +fn growing_segment_tombstones_survive_checkpoint_roundtrip() { + let mut coll = VectorCollection::new(2, params()); + for i in 0..10u32 { + coll.insert(vec![i as f32, 0.0]); + } + assert!(coll.delete(3), "delete on live growing vector must succeed"); + assert!(coll.delete(7), "delete on live growing vector must succeed"); + let live_before = coll.live_count(); + assert_eq!(live_before, 8); + + let bytes = coll.checkpoint_to_bytes(); + let restored = VectorCollection::from_checkpoint(&bytes).expect("checkpoint deserializes"); + + assert_eq!( + restored.live_count(), + live_before, + "tombstoned growing-segment vectors resurrected on restore" + ); + + let results = restored.search(&[3.0, 0.0], 10, 64); + let ids: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + assert!( + !ids.contains(&3), + "soft-deleted id=3 reappeared in search after restore" + ); + assert!( + !ids.contains(&7), + "soft-deleted id=7 reappeared in search after restore" + ); +} + +#[test] +fn building_segment_tombstones_survive_checkpoint_roundtrip() { + // Force a seal so the deleted rows live in a building segment at + // snapshot time, exercising the `building_segments` encode path. + let mut coll = VectorCollection::with_seal_threshold(2, params(), 20); + for i in 0..20u32 { + coll.insert(vec![i as f32, 0.0]); + } + let _req = coll.seal("k").expect("seal produced request"); + // Intentionally do NOT complete the build — vectors now sit in the + // building segment as a FlatIndex. + assert!(coll.delete(5), "delete on building vector must succeed"); + assert!(coll.delete(15), "delete on building vector must succeed"); + let live_before = coll.live_count(); + assert_eq!(live_before, 18); + + let bytes = coll.checkpoint_to_bytes(); + let restored = VectorCollection::from_checkpoint(&bytes).expect("checkpoint deserializes"); + + assert_eq!( + restored.live_count(), + live_before, + "tombstoned building-segment vectors resurrected on restore" + ); + + let results = restored.search(&[5.0, 0.0], 20, 64); + let ids: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + assert!( + !ids.contains(&5), + "soft-deleted id=5 reappeared after restore" + ); + assert!( + !ids.contains(&15), + "soft-deleted id=15 reappeared after restore" + ); +} diff --git a/nodedb-vector/tests/collection_compact_doc_map.rs b/nodedb-vector/tests/collection_compact_doc_map.rs new file mode 100644 index 00000000..a2cad9e1 --- /dev/null +++ b/nodedb-vector/tests/collection_compact_doc_map.rs @@ -0,0 +1,122 @@ +//! `compact()` must keep `doc_id_map` / `multi_doc_map` consistent with the +//! renumbered HNSW local IDs. +//! +//! Spec: `HnswIndex::compact()` removes tombstoned nodes and renumbers +//! surviving local node ids. The collection stores `doc_id_map` and +//! `multi_doc_map` keyed on GLOBAL ids (`seg.base_id + local`). After +//! compaction those globals shift too — the collection MUST walk both +//! maps and rewrite every entry for the compacted segment to the new +//! `(seg.base_id + new_local)` globals. Without the rewrite, +//! `get_doc_id(vid)` and `delete_multi_vector(doc)` point at stale or +//! wrong vectors. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; + +fn params() -> HnswParams { + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + } +} + +fn build_collection_with_docs() -> VectorCollection { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 6); + // Six docs, one vector each. Global ids 0..6. + for i in 0..6u32 { + coll.insert_with_doc_id(vec![i as f32, 0.0], format!("doc_{i}")); + } + // Seal + complete → sealed segment with base_id=0, local ids 0..6. + let req = coll.seal("k").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + coll +} + +#[test] +fn doc_id_map_stays_correct_after_compact() { + let mut coll = build_collection_with_docs(); + + // Tombstone two vectors in the middle of the sealed segment. + assert!(coll.delete(1)); + assert!(coll.delete(3)); + + // Sanity: pre-compact, the surviving doc mapping still resolves. + assert_eq!(coll.get_doc_id(0), Some("doc_0")); + assert_eq!(coll.get_doc_id(5), Some("doc_5")); + + let removed = coll.compact(); + assert_eq!(removed, 2, "compact should remove 2 tombstoned nodes"); + + // Spec: the search results (identified by renumbered global ids) still + // resolve to the original doc strings. For the surviving vectors + // {0, 2, 4, 5} post-compact globals become {0, 1, 2, 3}. `get_doc_id` + // MUST map those new globals to "doc_0", "doc_2", "doc_4", "doc_5". + let results = coll.search(&[0.0, 0.0], 4, 64); + let ids: Vec = results.iter().map(|r| r.id).collect(); + assert_eq!(ids.len(), 4, "expected 4 live vectors post-compact"); + + let observed_docs: std::collections::HashSet = ids + .iter() + .filter_map(|id| coll.get_doc_id(*id).map(|s| s.to_string())) + .collect(); + let expected_docs: std::collections::HashSet = ["doc_0", "doc_2", "doc_4", "doc_5"] + .into_iter() + .map(String::from) + .collect(); + + assert_eq!( + observed_docs, expected_docs, + "doc_id_map was not rewritten after compact — globals shifted but the map did not" + ); +} + +#[test] +fn multi_doc_map_stays_correct_after_compact() { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 6); + + // Two multi-vector docs: doc_a owns globals 0,1,2; doc_b owns 3,4,5. + let a_vecs: Vec> = (0..3u32).map(|i| vec![i as f32, 0.0]).collect(); + let a_refs: Vec<&[f32]> = a_vecs.iter().map(|v| v.as_slice()).collect(); + let a_ids = coll.insert_multi_vector(&a_refs, "doc_a".to_string()); + assert_eq!(a_ids, vec![0, 1, 2]); + + let b_vecs: Vec> = (3..6u32).map(|i| vec![i as f32, 0.0]).collect(); + let b_refs: Vec<&[f32]> = b_vecs.iter().map(|v| v.as_slice()).collect(); + let b_ids = coll.insert_multi_vector(&b_refs, "doc_b".to_string()); + assert_eq!(b_ids, vec![3, 4, 5]); + + let req = coll.seal("k").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + + // Tombstone one vector from each doc (middle of each group). + assert!(coll.delete(1)); + assert!(coll.delete(4)); + + coll.compact(); + + // Spec: `delete_multi_vector("doc_a")` must reach the two remaining + // vectors that originally belonged to doc_a, regardless of the local + // id renumbering performed by HnswIndex::compact. + let deleted_a = coll.delete_multi_vector("doc_a"); + assert_eq!( + deleted_a, 2, + "delete_multi_vector(doc_a) must find its 2 remaining vectors after compact" + ); + + let live_after = coll.live_count(); + assert_eq!( + live_after, 2, + "post-compact + doc_a delete: only doc_b's 2 remaining vectors survive" + ); +} diff --git a/nodedb-vector/tests/collection_pq_config.rs b/nodedb-vector/tests/collection_pq_config.rs new file mode 100644 index 00000000..da1e87af --- /dev/null +++ b/nodedb-vector/tests/collection_pq_config.rs @@ -0,0 +1,88 @@ +//! `index_type='hnsw_pq'` must produce PQ-compressed segments. +//! +//! Spec: when a collection is configured for HNSW+PQ (advertised via the +//! SQL DDL `CREATE INDEX ... WITH (index_type='hnsw_pq', pq_m=...)`), +//! `complete_build` MUST train a `PqCodec` on the finished segment and +//! surface `VectorIndexQuantization::Pq` in `stats()`. Today, the config +//! is accepted at the DDL layer, stored, and then ignored — +//! `complete_build` unconditionally calls `build_sq8_for_index`, so +//! operators who asked for 8-16× memory reduction silently receive 4× +//! SQ8 and have no signal from `stats()` that the request was dropped. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; + +fn params() -> HnswParams { + HnswParams { + m: 16, + m0: 32, + ef_construction: 100, + metric: DistanceMetric::L2, + } +} + +/// Build a collection with 1024 vectors of dim=8 and complete one segment +/// build. The `>= 1000` vector threshold in `build_sq8_for_index` means a +/// quantizer WILL be attached — so `stats().quantization` is either `Sq8` +/// (buggy fallback) or `Pq` (spec-correct for a HnswPq-configured index). +fn make_built_collection_with_pq_config() -> VectorCollection { + // Uses the convenience constructor `with_seal_threshold_and_pq_config` so + // callers don't have to hand-build a full `IndexConfig` just to request PQ. + let mut coll = VectorCollection::with_seal_threshold_and_pq_config(8, params(), 2, 1024); + for i in 0..1024u32 { + let mut v = vec![0.0f32; 8]; + for (d, slot) in v.iter_mut().enumerate() { + *slot = ((i as f32) * 0.01 + (d as f32) * 0.1).sin(); + } + coll.insert(v); + } + let req = coll.seal("pq").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + coll +} + +#[test] +fn hnsw_pq_config_produces_pq_quantization() { + let coll = make_built_collection_with_pq_config(); + let stats = coll.stats(); + assert_eq!( + stats.quantization, + nodedb_types::VectorIndexQuantization::Pq, + "index_type='hnsw_pq' must produce PQ-compressed segments and \ + report VectorIndexQuantization::Pq; got {:?}", + stats.quantization + ); +} + +#[test] +fn hnsw_pq_config_stats_index_type_reports_hnsw_pq() { + let coll = make_built_collection_with_pq_config(); + let stats = coll.stats(); + assert_eq!( + stats.index_type, + nodedb_types::VectorIndexType::HnswPq, + "stats().index_type must reflect the configured HnswPq index" + ); +} + +#[test] +fn hnsw_pq_config_survives_checkpoint_roundtrip() { + let coll = make_built_collection_with_pq_config(); + let bytes = coll.checkpoint_to_bytes(); + let restored = VectorCollection::from_checkpoint(&bytes) + .expect("checkpoint must deserialize for PQ-configured collection"); + let stats = restored.stats(); + assert_eq!( + stats.quantization, + nodedb_types::VectorIndexQuantization::Pq, + "PQ codec must survive checkpoint roundtrip; got {:?}", + stats.quantization + ); +} diff --git a/nodedb-vector/tests/hnsw_layer_cap.rs b/nodedb-vector/tests/hnsw_layer_cap.rs new file mode 100644 index 00000000..bef3b0fa --- /dev/null +++ b/nodedb-vector/tests/hnsw_layer_cap.rs @@ -0,0 +1,73 @@ +//! HNSW `random_layer` must be capped at a reasonable maximum. +//! +//! Spec: standard HNSW caps the assigned layer at ~16. The current +//! `random_layer` implementation has no cap — with an unlucky xorshift +//! draw (`r ≈ 2.2e-308`), `-ln(r) * (1/ln(m))` can return a layer in +//! the hundreds or thousands. One outlier insert then promotes the +//! index's `max_layer`, and every subsequent search's Phase-1 greedy +//! descent iterates `(1..=max_layer).rev()` — converting constant-time +//! descent into O(max_layer) per query. + +use nodedb_vector::DistanceMetric; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; + +/// Hard cap enforced by `HnswIndex::random_layer`. Standard HNSW uses ~16 +/// and the implementation clamps at `MAX_LAYER_CAP = 16`. +const LAYER_CAP: usize = 16; + +#[test] +fn random_layer_never_exceeds_cap_under_normal_inserts() { + let mut idx = HnswIndex::with_seed( + 4, + HnswParams { + m: 16, + m0: 32, + ef_construction: 64, + metric: DistanceMetric::L2, + }, + 1, + ); + for i in 0..5_000u32 { + let v = vec![ + (i as f32).sin(), + (i as f32).cos(), + ((i * 3) as f32).sin(), + ((i * 7) as f32).cos(), + ]; + idx.insert(v).unwrap(); + } + assert!( + idx.max_layer() <= LAYER_CAP, + "max_layer grew to {} (cap = {LAYER_CAP}); one pathological random_layer \ + draw promoted the index and will slow every subsequent search", + idx.max_layer() + ); +} + +#[test] +fn random_layer_capped_with_adversarial_seed() { + // Seeds chosen to exercise xorshift states that produce very small + // `next_f64()` outputs early in the sequence. A correct implementation + // clamps the resulting layer regardless of the RNG draw. + for seed in [1u64, 2, 3, 7, 13, 42, 123, 9_999, 1_000_003] { + let mut idx = HnswIndex::with_seed( + 2, + HnswParams { + m: 2, // small m amplifies -ln(r) * (1/ln(m)) + m0: 4, + ef_construction: 32, + metric: DistanceMetric::L2, + }, + seed, + ); + for i in 0..2_000u32 { + idx.insert(vec![i as f32, 0.0]).unwrap(); + } + assert!( + idx.max_layer() <= LAYER_CAP, + "seed={seed}: max_layer reached {} (cap = {LAYER_CAP}) — \ + random_layer has no upper bound", + idx.max_layer() + ); + } +} diff --git a/nodedb-vector/tests/quantize_kmeans_distribution.rs b/nodedb-vector/tests/quantize_kmeans_distribution.rs new file mode 100644 index 00000000..4ed71611 --- /dev/null +++ b/nodedb-vector/tests/quantize_kmeans_distribution.rs @@ -0,0 +1,137 @@ +//! PQ and IVF-PQ codebook training must distribute centroids across the +//! data even when many input vectors are near-duplicates. +//! +//! Spec: k-means initialization selects centroids spread across the data +//! distribution. The current implementation has two compounding bugs: +//! +//! 1. `min_dists[i]` is only updated against `centroids[c - 1]` (the +//! last centroid), not against the full centroid set. Once two +//! centroids coincide, `min_dists` stops reflecting "distance to the +//! nearest centroid," so every subsequent deterministic-argmax pick +//! lands on the same outlier. +//! 2. The comment says "k-means++" but the selection is deterministic +//! farthest-point, so outliers dominate rather than being sampled +//! proportionally to d². +//! +//! Effect: on workloads with repeated prefixes/suffixes (templated chat, +//! shared headers/footers), most of the 256 centroids alias to one or two +//! points and PQ recall collapses. + +use nodedb_vector::quantize::pq::PqCodec; + +/// Training set of 200 vectors: 190 near-duplicates at the origin plus +/// 10 outliers scattered across a single subspace. A correct k-means++ +/// spreads centroids across both clusters; the current farthest-point- +/// with-broken-min-distance-update collapses to ~2 distinct centroids. +fn clustered_with_duplicates() -> Vec> { + let mut vecs: Vec> = Vec::with_capacity(200); + // Cluster A: 190 near-identical vectors near origin. + for i in 0..190 { + let eps = (i as f32) * 1e-5; + vecs.push(vec![eps, -eps, eps * 0.5, -eps * 0.5]); + } + // Cluster B: 10 outliers at distinct coordinates. + for j in 0..10 { + let x = 100.0 + (j as f32) * 10.0; + vecs.push(vec![x, -x, x * 0.5, -x * 0.5]); + } + vecs +} + +fn unique_centroid_count(codec: &PqCodec, vectors: &[Vec]) -> usize { + let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect(); + let codes = codec.encode_batch(&refs); + let m = codec.m; + // Per-subspace unique centroid indices used across the batch. + let mut min_unique = usize::MAX; + for sub in 0..m { + let mut seen = std::collections::HashSet::new(); + for row in 0..vectors.len() { + seen.insert(codes[row * m + sub]); + } + if seen.len() < min_unique { + min_unique = seen.len(); + } + } + min_unique +} + +#[test] +fn pq_kmeans_produces_diverse_centroids_on_duplicate_heavy_data() { + let vecs = clustered_with_duplicates(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let codec = PqCodec::train(&refs, 4, 2, 16, 20); + + let unique = unique_centroid_count(&codec, &vecs); + assert!( + unique >= 4, + "k-means collapsed to {unique} unique centroids per subspace on \ + duplicate-heavy input; a correct k-means++ should pick at least \ + 4 distinct cluster representatives for k=16" + ); +} + +#[test] +fn pq_distance_table_separates_duplicates_from_outliers() { + // Spec test: after training, the PQ distance from a duplicate-cluster + // query to a duplicate vector must be meaningfully smaller than the + // distance to an outlier vector. Under the collapse bug, most + // codebook entries alias to one point so all distances look similar. + let vecs = clustered_with_duplicates(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let codec = PqCodec::train(&refs, 4, 2, 16, 20); + + let query = [0.0f32, 0.0, 0.0, 0.0]; + let table = codec.build_distance_table(&query); + + let dup_code = codec.encode(&vecs[0]); // duplicate cluster + let outlier_code = codec.encode(&vecs[195]); // outlier cluster + + let dup_dist = codec.asymmetric_distance(&table, &dup_code); + let outlier_dist = codec.asymmetric_distance(&table, &outlier_code); + + assert!( + outlier_dist > dup_dist * 10.0, + "PQ failed to distinguish duplicate (d={dup_dist}) from outlier \ + (d={outlier_dist}) — codebook collapsed and the two codes encode \ + to near-identical table entries" + ); +} + +#[cfg(feature = "ivf")] +#[test] +fn ivf_pq_training_does_not_collapse_on_duplicate_heavy_data() { + use nodedb_vector::DistanceMetric; + use nodedb_vector::{IvfPqIndex, IvfPqParams}; + + let vecs = clustered_with_duplicates(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let mut idx = IvfPqIndex::new( + 4, + IvfPqParams { + n_cells: 8, + pq_m: 2, + pq_k: 16, + nprobe: 4, + metric: DistanceMetric::L2, + }, + ); + idx.train(&refs); + for v in &vecs { + idx.add(v); + } + + // Query at the origin. Correct training assigns near-duplicates to + // one cell and outliers to another; the nearest result must come + // from the duplicate cluster (original indices 0..190). + let results = idx.search(&[0.0, 0.0, 0.0, 0.0], 5); + assert!(!results.is_empty(), "IVF-PQ returned no results"); + for r in &results { + assert!( + r.id < 190, + "IVF-PQ k-means collapse: query at origin returned outlier id={} \ + instead of a near-duplicate cluster member", + r.id + ); + } +} diff --git a/nodedb-vector/tests/simd_length_safety.rs b/nodedb-vector/tests/simd_length_safety.rs new file mode 100644 index 00000000..a8e01003 --- /dev/null +++ b/nodedb-vector/tests/simd_length_safety.rs @@ -0,0 +1,60 @@ +//! Length-parity safety for SIMD distance kernels. +//! +//! Spec: the public `distance(a, b, metric)` dispatcher MUST NOT invoke a +//! SIMD kernel when `a.len() != b.len()`. The AVX2/AVX-512/NEON kernels +//! iterate with `a.len()` and read from `b.as_ptr().add(off)` via +//! `loadu_ps` — reading past `b`'s allocation is undefined behavior. +//! +//! A deterministic panic at the dispatcher boundary is the contract. Either +//! length validation or length-bounded iteration keeps the kernel safe. + +#![cfg(feature = "simd")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::distance::distance; + +fn assert_rejects_mismatch(metric: DistanceMetric) { + // a.len() = 9 forces one 8-wide SIMD chunk + remainder. b.len() = 1 + // means any unchecked 256-bit load from b is a buffer overrun. A correct + // dispatcher either rejects the call (panic) or bounds iteration by + // `min(a.len(), b.len())`; both surface as a deterministic panic today + // because the scalar remainder loop indexes `b[i]` out of bounds. + let a = vec![1.0f32; 9]; + let b = vec![1.0f32; 1]; + + let result = std::panic::catch_unwind(|| distance(&a, &b, metric)); + assert!( + result.is_err(), + "distance({metric:?}) must reject length mismatch (a.len()=9, b.len()=1) \ + instead of reading past the shorter buffer" + ); +} + +#[test] +fn l2_rejects_length_mismatch() { + assert_rejects_mismatch(DistanceMetric::L2); +} + +#[test] +fn cosine_rejects_length_mismatch() { + assert_rejects_mismatch(DistanceMetric::Cosine); +} + +#[test] +fn inner_product_rejects_length_mismatch() { + assert_rejects_mismatch(DistanceMetric::InnerProduct); +} + +#[test] +fn l2_rejects_swapped_mismatch() { + // Swap order: shorter slice first. The kernels use a.len() as the loop + // bound, so a.len()=1, b.len()=9 exits early — but the dispatcher + // contract is symmetric: any mismatch is invalid input. + let a = vec![1.0f32; 1]; + let b = vec![1.0f32; 9]; + let result = std::panic::catch_unwind(|| distance(&a, &b, DistanceMetric::L2)); + assert!( + result.is_err(), + "distance() must reject length mismatch in either argument order" + ); +}