diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index 69eb4ee13..b459e377d 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -52,6 +52,9 @@ if (TARGET svs::svs) endif() if(VECSIM_BUILD_TESTS) - add_library(VectorSimilaritySerializer utils/serializer.cpp) + add_library(VectorSimilaritySerializer + algorithms/hnsw/hnsw_serializer.cpp + algorithms/svs/svs_serializer.cpp + ) target_link_libraries(VectorSimilarity VectorSimilaritySerializer) endif() diff --git a/src/VecSim/algorithms/hnsw/hnsw.h b/src/VecSim/algorithms/hnsw/hnsw.h index d87ddfcc9..ecba5921e 100644 --- a/src/VecSim/algorithms/hnsw/hnsw.h +++ b/src/VecSim/algorithms/hnsw/hnsw.h @@ -25,6 +25,7 @@ #ifdef BUILD_TESTS #include "hnsw_serialization_utils.h" #include "VecSim/utils/serializer.h" +#include "hnsw_serializer.h" #endif #include @@ -85,7 +86,7 @@ class HNSWIndex : public VecSimIndexAbstract, public VecSimIndexTombstone #ifdef BUILD_TESTS , - public Serializer + public HNSWSerializer #endif { protected: @@ -2355,5 +2356,5 @@ HNSWIndex::getHNSWElementNeighbors(size_t label, int ***neig } #ifdef BUILD_TESTS -#include "hnsw_serializer.h" +#include "hnsw_serializer_impl.h" #endif diff --git a/src/VecSim/algorithms/hnsw/hnsw_multi.h b/src/VecSim/algorithms/hnsw/hnsw_multi.h index 43cce821f..34bed712a 100644 --- a/src/VecSim/algorithms/hnsw/hnsw_multi.h +++ b/src/VecSim/algorithms/hnsw/hnsw_multi.h @@ -64,7 +64,7 @@ class HNSWIndex_Multi : public HNSWIndex { HNSWIndex_Multi(std::ifstream &input, const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams, const IndexComponents &components, - Serializer::EncodingVersion version) + HNSWSerializer::EncodingVersion version) : HNSWIndex(input, params, abstractInitParams, components, version), labelLookup(this->maxElements, this->allocator) {} diff --git a/src/VecSim/utils/serializer.cpp b/src/VecSim/algorithms/hnsw/hnsw_serializer.cpp similarity index 50% rename from src/VecSim/utils/serializer.cpp rename to src/VecSim/algorithms/hnsw/hnsw_serializer.cpp index 318608d2a..4a0fac9c6 100644 --- a/src/VecSim/utils/serializer.cpp +++ b/src/VecSim/algorithms/hnsw/hnsw_serializer.cpp @@ -7,38 +7,34 @@ * GNU Affero General Public License v3 (AGPLv3). */ -#include -#include +#include "hnsw_serializer.h" -#include "VecSim/utils/serializer.h" - -// Persist index into a file in the specified location. -void Serializer::saveIndex(const std::string &location) { - - // Serializing with the latest version. - EncodingVersion version = EncodingVersion_V4; - - std::ofstream output(location, std::ios::binary); - writeBinaryPOD(output, version); - saveIndexIMP(output); - output.close(); -} - -Serializer::EncodingVersion Serializer::ReadVersion(std::ifstream &input) { +HNSWSerializer::HNSWSerializer(EncodingVersion version) : m_version(version) {} +HNSWSerializer::EncodingVersion HNSWSerializer::ReadVersion(std::ifstream &input) { input.seekg(0, std::ifstream::beg); - // The version number is the first field that is serialized. - EncodingVersion version = EncodingVersion_INVALID; + EncodingVersion version = EncodingVersion::INVALID; readBinaryPOD(input, version); - if (version <= EncodingVersion_DEPRECATED) { + + if (version <= EncodingVersion::DEPRECATED) { input.close(); throw std::runtime_error("Cannot load index: deprecated encoding version: " + - std::to_string(version)); - } else if (version >= EncodingVersion_INVALID) { + std::to_string(static_cast(version))); + } else if (version >= EncodingVersion::INVALID) { input.close(); throw std::runtime_error("Cannot load index: bad encoding version: " + - std::to_string(version)); + std::to_string(static_cast(version))); } return version; } + +void HNSWSerializer::saveIndex(const std::string &location) { + EncodingVersion version = EncodingVersion::V4; + std::ofstream output(location, std::ios::binary); + writeBinaryPOD(output, version); + saveIndexIMP(output); + output.close(); +} + +HNSWSerializer::EncodingVersion HNSWSerializer::getVersion() const { return m_version; } diff --git a/src/VecSim/algorithms/hnsw/hnsw_serializer.h b/src/VecSim/algorithms/hnsw/hnsw_serializer.h index e0c3b201a..af1ae2871 100644 --- a/src/VecSim/algorithms/hnsw/hnsw_serializer.h +++ b/src/VecSim/algorithms/hnsw/hnsw_serializer.h @@ -9,309 +9,33 @@ #pragma once -template -HNSWIndex::HNSWIndex(std::ifstream &input, const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, - Serializer::EncodingVersion version) - : VecSimIndexAbstract(abstractInitParams, components), Serializer(version), - epsilon(params->epsilon), graphDataBlocks(this->allocator), idToMetaData(this->allocator), - visitedNodesHandlerPool(0, this->allocator) { +#include +#include +#include "VecSim/utils/serializer.h" - this->restoreIndexFields(input); - this->fieldsValidation(); +// Middle layer for HNSW serialization +// Abstract functions should be implemented by the templated HNSW index - // Since level generator is implementation-defined, we dont read its value from the file. - // We use seed = 200 and not the default value (100) to get different sequence of - // levels value than the loaded index. - levelGenerator.seed(200); +class HNSWSerializer : public Serializer { +public: + enum class EncodingVersion { + DEPRECATED = 2, // Last deprecated version + V3, + V4, + INVALID + }; - // Set the initial capacity based on the number of elements in the loaded index. - maxElements = RoundUpInitialCapacity(this->curElementCount, this->blockSize); - this->idToMetaData.resize(maxElements); - this->visitedNodesHandlerPool.resize(maxElements); + explicit HNSWSerializer(EncodingVersion version = EncodingVersion::V4); - size_t initial_vector_size = maxElements / this->blockSize; - graphDataBlocks.reserve(initial_vector_size); -} + static EncodingVersion ReadVersion(std::ifstream &input); -template -void HNSWIndex::saveIndexIMP(std::ofstream &output) { - this->saveIndexFields(output); - this->saveGraph(output); -} + void saveIndex(const std::string &location); -template -void HNSWIndex::fieldsValidation() const { - if (this->M > UINT16_MAX / 2) - throw std::runtime_error("HNSW index parameter M is too large: argument overflow"); - if (this->M <= 1) - throw std::runtime_error("HNSW index parameter M cannot be 1 or 0"); -} + EncodingVersion getVersion() const; -template -HNSWIndexMetaData HNSWIndex::checkIntegrity() const { - HNSWIndexMetaData res = {.valid_state = false, - .memory_usage = -1, - .double_connections = HNSW_INVALID_META_DATA, - .unidirectional_connections = HNSW_INVALID_META_DATA, - .min_in_degree = HNSW_INVALID_META_DATA, - .max_in_degree = HNSW_INVALID_META_DATA, - .connections_to_repair = 0}; +protected: + EncodingVersion m_version; - // Save the current memory usage (before we use additional memory for the integrity check). - res.memory_usage = this->getAllocationSize(); - size_t connections_checked = 0, double_connections = 0, num_deleted = 0, - min_in_degree = SIZE_MAX, max_in_degree = 0; - size_t max_level_in_graph = 0; // including marked deleted elements - for (size_t i = 0; i < this->curElementCount; i++) { - if (this->isMarkedDeleted(i)) { - num_deleted++; - } - if (getGraphDataByInternalId(i)->toplevel > max_level_in_graph) { - max_level_in_graph = getGraphDataByInternalId(i)->toplevel; - } - } - std::vector> inbound_connections_num( - this->curElementCount, std::vector(max_level_in_graph + 1, 0)); - size_t incoming_edges_sets_sizes = 0; - for (size_t i = 0; i < this->curElementCount; i++) { - for (size_t l = 0; l <= getGraphDataByInternalId(i)->toplevel; l++) { - ElementLevelData &cur = this->getElementLevelData(i, l); - std::set s; - for (unsigned int j = 0; j < cur.numLinks; j++) { - // Check if we found an invalid neighbor. - if (cur.links[j] >= this->curElementCount || cur.links[j] == i) { - return res; - } - // If the neighbor has deleted, then this connection should be repaired. - if (isMarkedDeleted(cur.links[j])) { - res.connections_to_repair++; - } - inbound_connections_num[cur.links[j]][l]++; - s.insert(cur.links[j]); - connections_checked++; - - // Check if this connection is bidirectional. - ElementLevelData &other = this->getElementLevelData(cur.links[j], l); - for (int r = 0; r < other.numLinks; r++) { - if (other.links[r] == (idType)i) { - double_connections++; - break; - } - } - } - // Check if a certain neighbor appeared more than once. - if (s.size() != cur.numLinks) { - return res; - } - incoming_edges_sets_sizes += cur.incomingUnidirectionalEdges->size(); - } - } - if (num_deleted != this->numMarkedDeleted) { - return res; - } - - // Validate that each node's in-degree is coherent with the in-degree observed by the - // outgoing edges. - for (size_t i = 0; i < this->curElementCount; i++) { - for (size_t l = 0; l <= getGraphDataByInternalId(i)->toplevel; l++) { - if (inbound_connections_num[i][l] > max_in_degree) { - max_in_degree = inbound_connections_num[i][l]; - } - if (inbound_connections_num[i][l] < min_in_degree) { - min_in_degree = inbound_connections_num[i][l]; - } - } - } - - res.double_connections = double_connections; - res.unidirectional_connections = incoming_edges_sets_sizes; - res.min_in_degree = max_in_degree; - res.max_in_degree = min_in_degree; - if (incoming_edges_sets_sizes + double_connections != connections_checked) { - return res; - } - - res.valid_state = true; - return res; -} - -template -void HNSWIndex::restoreIndexFields(std::ifstream &input) { - // Restore index build parameters - readBinaryPOD(input, this->M); - readBinaryPOD(input, this->M0); - readBinaryPOD(input, this->efConstruction); - - // Restore index search parameter - readBinaryPOD(input, this->ef); - readBinaryPOD(input, this->epsilon); - - // Restore index meta-data - this->elementGraphDataSize = sizeof(ElementGraphData) + sizeof(idType) * this->M0; - this->levelDataSize = sizeof(ElementLevelData) + sizeof(idType) * this->M; - readBinaryPOD(input, this->mult); - - // Restore index state - readBinaryPOD(input, this->curElementCount); - readBinaryPOD(input, this->numMarkedDeleted); - readBinaryPOD(input, this->maxLevel); - readBinaryPOD(input, this->entrypointNode); -} - -template -void HNSWIndex::restoreGraph(std::ifstream &input, EncodingVersion version) { - // Restore id to metadata vector - labelType label = 0; - elementFlags flags = 0; - for (idType id = 0; id < this->curElementCount; id++) { - readBinaryPOD(input, label); - readBinaryPOD(input, flags); - this->idToMetaData[id].label = label; - this->idToMetaData[id].flags = flags; - - // Restore label lookup by getting the label from data_level0_memory_ - setVectorId(label, id); - } - - // Todo: create vector data container and load the stored data based on the index storage params - // when other storage types will be available. - dynamic_cast(this->vectors) - ->restoreBlocks(input, this->curElementCount, m_version); - - // Get graph data blocks - ElementGraphData *cur_egt; - auto tmpData = this->getAllocator()->allocate_unique(this->elementGraphDataSize); - size_t toplevel = 0; - size_t num_blocks = dynamic_cast(this->vectors)->numBlocks(); - for (size_t i = 0; i < num_blocks; i++) { - this->graphDataBlocks.emplace_back(this->blockSize, this->elementGraphDataSize, - this->allocator); - unsigned int block_len = 0; - readBinaryPOD(input, block_len); - for (size_t j = 0; j < block_len; j++) { - // Reset tmpData - memset(tmpData.get(), 0, this->elementGraphDataSize); - // Read the current element top level - readBinaryPOD(input, toplevel); - // Allocate space and structs for the current element - try { - new (tmpData.get()) - ElementGraphData(toplevel, this->levelDataSize, this->allocator); - } catch (std::runtime_error &e) { - this->log(VecSimCommonStrings::LOG_WARNING_STRING, - "Error - allocating memory for new element failed due to low memory"); - throw e; - } - // Add the current element to the current block, and update cur_egt to point to it. - this->graphDataBlocks.back().addElement(tmpData.get()); - cur_egt = (ElementGraphData *)this->graphDataBlocks.back().getElement(j); - - // Restore the current element's graph data - for (size_t k = 0; k <= toplevel; k++) { - restoreLevel(input, getElementLevelData(cur_egt, k), version); - } - } - } -} - -template -void HNSWIndex::restoreLevel(std::ifstream &input, ElementLevelData &data, - EncodingVersion version) { - readBinaryPOD(input, data.numLinks); - for (size_t i = 0; i < data.numLinks; i++) { - readBinaryPOD(input, data.links[i]); - } - - // Restore the incoming edges of the current element - unsigned int size; - readBinaryPOD(input, size); - data.incomingUnidirectionalEdges->reserve(size); - idType id = INVALID_ID; - for (size_t i = 0; i < size; i++) { - readBinaryPOD(input, id); - data.incomingUnidirectionalEdges->push_back(id); - } -} - -template -void HNSWIndex::saveIndexFields(std::ofstream &output) const { - // Save index type - writeBinaryPOD(output, VecSimAlgo_HNSWLIB); - - // Save VecSimIndex fields - writeBinaryPOD(output, this->dim); - writeBinaryPOD(output, this->vecType); - writeBinaryPOD(output, this->metric); - writeBinaryPOD(output, this->blockSize); - writeBinaryPOD(output, this->isMulti); - writeBinaryPOD(output, this->maxElements); // This will be used to restore the index initial - // capacity - - // Save index build parameters - writeBinaryPOD(output, this->M); - writeBinaryPOD(output, this->M0); - writeBinaryPOD(output, this->efConstruction); - - // Save index search parameter - writeBinaryPOD(output, this->ef); - writeBinaryPOD(output, this->epsilon); - - // Save index meta-data - writeBinaryPOD(output, this->mult); - - // Save index state - writeBinaryPOD(output, this->curElementCount); - writeBinaryPOD(output, this->numMarkedDeleted); - writeBinaryPOD(output, this->maxLevel); - writeBinaryPOD(output, this->entrypointNode); -} - -template -void HNSWIndex::saveGraph(std::ofstream &output) const { - // Save id to metadata vector - for (idType id = 0; id < this->curElementCount; id++) { - labelType label = this->idToMetaData[id].label; - elementFlags flags = this->idToMetaData[id].flags; - writeBinaryPOD(output, label); - writeBinaryPOD(output, flags); - } - - this->vectors->saveVectorsData(output); - - // Save graph data blocks - for (size_t i = 0; i < this->graphDataBlocks.size(); i++) { - auto &block = this->graphDataBlocks[i]; - unsigned int block_len = block.getLength(); - writeBinaryPOD(output, block_len); - for (size_t j = 0; j < block_len; j++) { - ElementGraphData *cur_element = (ElementGraphData *)block.getElement(j); - writeBinaryPOD(output, cur_element->toplevel); - - // Save all the levels of the current element - for (size_t level = 0; level <= cur_element->toplevel; level++) { - saveLevel(output, getElementLevelData(cur_element, level)); - } - } - } -} - -template -void HNSWIndex::saveLevel(std::ofstream &output, ElementLevelData &data) const { - // Save the links of the current element - writeBinaryPOD(output, data.numLinks); - for (size_t i = 0; i < data.numLinks; i++) { - writeBinaryPOD(output, data.links[i]); - } - - // Save the incoming edges of the current element - unsigned int size = data.incomingUnidirectionalEdges->size(); - writeBinaryPOD(output, size); - for (idType id : *data.incomingUnidirectionalEdges) { - writeBinaryPOD(output, id); - } - - // Shrink the incoming edges vector for integrity check - data.incomingUnidirectionalEdges->shrink_to_fit(); -} +private: + void saveIndexFields(std::ofstream &output) const = 0; +}; diff --git a/src/VecSim/algorithms/hnsw/hnsw_serializer_declarations.h b/src/VecSim/algorithms/hnsw/hnsw_serializer_declarations.h index 767447319..9a86133a5 100644 --- a/src/VecSim/algorithms/hnsw/hnsw_serializer_declarations.h +++ b/src/VecSim/algorithms/hnsw/hnsw_serializer_declarations.h @@ -13,7 +13,8 @@ public: HNSWIndex(std::ifstream &input, const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, EncodingVersion version); + const IndexComponents &components, + HNSWSerializer::EncodingVersion version); // Validates the connections between vectors HNSWIndexMetaData checkIntegrity() const; @@ -22,16 +23,17 @@ HNSWIndexMetaData checkIntegrity() const; virtual void saveIndexIMP(std::ofstream &output) override; // used by index factory to load nodes connections -void restoreGraph(std::ifstream &input, EncodingVersion version); +void restoreGraph(std::ifstream &input, HNSWSerializer::EncodingVersion version); private: // Functions for index saving. -void saveIndexFields(std::ofstream &output) const; +void saveIndexFields(std::ofstream &output) const override; void saveGraph(std::ofstream &output) const; void saveLevel(std::ofstream &output, ElementLevelData &data) const; -void restoreLevel(std::ifstream &input, ElementLevelData &data, EncodingVersion version); +void restoreLevel(std::ifstream &input, ElementLevelData &data, + HNSWSerializer::EncodingVersion version); void computeIndegreeForAll(); // Functions for index loading. diff --git a/src/VecSim/algorithms/hnsw/hnsw_serializer_impl.h b/src/VecSim/algorithms/hnsw/hnsw_serializer_impl.h new file mode 100644 index 000000000..5f9dd8cbf --- /dev/null +++ b/src/VecSim/algorithms/hnsw/hnsw_serializer_impl.h @@ -0,0 +1,321 @@ +/* + * Copyright (c) 2006-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + */ + +#pragma once + +#include "hnsw_serializer.h" + +template +HNSWIndex::HNSWIndex(std::ifstream &input, const HNSWParams *params, + const AbstractIndexInitParams &abstractInitParams, + const IndexComponents &components, + HNSWSerializer::EncodingVersion version) + : VecSimIndexAbstract(abstractInitParams, components), + HNSWSerializer(version), epsilon(params->epsilon), graphDataBlocks(this->allocator), + idToMetaData(this->allocator), visitedNodesHandlerPool(0, this->allocator) { + + this->restoreIndexFields(input); + this->fieldsValidation(); + + // Since level generator is implementation-defined, we dont read its value from the file. + // We use seed = 200 and not the default value (100) to get different sequence of + // levels value than the loaded index. + levelGenerator.seed(200); + + // Set the initial capacity based on the number of elements in the loaded index. + maxElements = RoundUpInitialCapacity(this->curElementCount, this->blockSize); + this->idToMetaData.resize(maxElements); + this->visitedNodesHandlerPool.resize(maxElements); + + size_t initial_vector_size = maxElements / this->blockSize; + graphDataBlocks.reserve(initial_vector_size); +} + +template +void HNSWIndex::saveIndexIMP(std::ofstream &output) { + this->saveIndexFields(output); + this->saveGraph(output); +} + +template +void HNSWIndex::fieldsValidation() const { + if (this->M > UINT16_MAX / 2) + throw std::runtime_error("HNSW index parameter M is too large: argument overflow"); + if (this->M <= 1) + throw std::runtime_error("HNSW index parameter M cannot be 1 or 0"); +} + +template +HNSWIndexMetaData HNSWIndex::checkIntegrity() const { + HNSWIndexMetaData res = {.valid_state = false, + .memory_usage = -1, + .double_connections = HNSW_INVALID_META_DATA, + .unidirectional_connections = HNSW_INVALID_META_DATA, + .min_in_degree = HNSW_INVALID_META_DATA, + .max_in_degree = HNSW_INVALID_META_DATA, + .connections_to_repair = 0}; + + // Save the current memory usage (before we use additional memory for the integrity check). + res.memory_usage = this->getAllocationSize(); + size_t connections_checked = 0, double_connections = 0, num_deleted = 0, + min_in_degree = SIZE_MAX, max_in_degree = 0; + size_t max_level_in_graph = 0; // including marked deleted elements + for (size_t i = 0; i < this->curElementCount; i++) { + if (this->isMarkedDeleted(i)) { + num_deleted++; + } + if (getGraphDataByInternalId(i)->toplevel > max_level_in_graph) { + max_level_in_graph = getGraphDataByInternalId(i)->toplevel; + } + } + std::vector> inbound_connections_num( + this->curElementCount, std::vector(max_level_in_graph + 1, 0)); + size_t incoming_edges_sets_sizes = 0; + for (size_t i = 0; i < this->curElementCount; i++) { + for (size_t l = 0; l <= getGraphDataByInternalId(i)->toplevel; l++) { + ElementLevelData &cur = this->getElementLevelData(i, l); + std::set s; + for (unsigned int j = 0; j < cur.numLinks; j++) { + // Check if we found an invalid neighbor. + if (cur.links[j] >= this->curElementCount || cur.links[j] == i) { + return res; + } + // If the neighbor has deleted, then this connection should be repaired. + if (isMarkedDeleted(cur.links[j])) { + res.connections_to_repair++; + } + inbound_connections_num[cur.links[j]][l]++; + s.insert(cur.links[j]); + connections_checked++; + + // Check if this connection is bidirectional. + ElementLevelData &other = this->getElementLevelData(cur.links[j], l); + for (int r = 0; r < other.numLinks; r++) { + if (other.links[r] == (idType)i) { + double_connections++; + break; + } + } + } + // Check if a certain neighbor appeared more than once. + if (s.size() != cur.numLinks) { + return res; + } + incoming_edges_sets_sizes += cur.incomingUnidirectionalEdges->size(); + } + } + if (num_deleted != this->numMarkedDeleted) { + return res; + } + + // Validate that each node's in-degree is coherent with the in-degree observed by the + // outgoing edges. + for (size_t i = 0; i < this->curElementCount; i++) { + for (size_t l = 0; l <= getGraphDataByInternalId(i)->toplevel; l++) { + if (inbound_connections_num[i][l] > max_in_degree) { + max_in_degree = inbound_connections_num[i][l]; + } + if (inbound_connections_num[i][l] < min_in_degree) { + min_in_degree = inbound_connections_num[i][l]; + } + } + } + + res.double_connections = double_connections; + res.unidirectional_connections = incoming_edges_sets_sizes; + res.min_in_degree = max_in_degree; + res.max_in_degree = min_in_degree; + if (incoming_edges_sets_sizes + double_connections != connections_checked) { + return res; + } + + res.valid_state = true; + return res; +} + +template +void HNSWIndex::restoreIndexFields(std::ifstream &input) { + // Restore index build parameters + readBinaryPOD(input, this->M); + readBinaryPOD(input, this->M0); + readBinaryPOD(input, this->efConstruction); + + // Restore index search parameter + readBinaryPOD(input, this->ef); + readBinaryPOD(input, this->epsilon); + + // Restore index meta-data + this->elementGraphDataSize = sizeof(ElementGraphData) + sizeof(idType) * this->M0; + this->levelDataSize = sizeof(ElementLevelData) + sizeof(idType) * this->M; + readBinaryPOD(input, this->mult); + + // Restore index state + readBinaryPOD(input, this->curElementCount); + readBinaryPOD(input, this->numMarkedDeleted); + readBinaryPOD(input, this->maxLevel); + readBinaryPOD(input, this->entrypointNode); +} + +template +void HNSWIndex::restoreGraph(std::ifstream &input, + HNSWSerializer::EncodingVersion version) { + // Restore id to metadata vector + labelType label = 0; + elementFlags flags = 0; + for (idType id = 0; id < this->curElementCount; id++) { + readBinaryPOD(input, label); + readBinaryPOD(input, flags); + this->idToMetaData[id].label = label; + this->idToMetaData[id].flags = flags; + + // Restore label lookup by getting the label from data_level0_memory_ + setVectorId(label, id); + } + + // Todo: create vector data container and load the stored data based on the index storage params + // when other storage types will be available. + dynamic_cast(this->vectors) + ->restoreBlocks(input, this->curElementCount, + static_cast(m_version)); + + // Get graph data blocks + ElementGraphData *cur_egt; + auto tmpData = this->getAllocator()->allocate_unique(this->elementGraphDataSize); + size_t toplevel = 0; + size_t num_blocks = dynamic_cast(this->vectors)->numBlocks(); + for (size_t i = 0; i < num_blocks; i++) { + this->graphDataBlocks.emplace_back(this->blockSize, this->elementGraphDataSize, + this->allocator); + unsigned int block_len = 0; + readBinaryPOD(input, block_len); + for (size_t j = 0; j < block_len; j++) { + // Reset tmpData + memset(tmpData.get(), 0, this->elementGraphDataSize); + // Read the current element top level + readBinaryPOD(input, toplevel); + // Allocate space and structs for the current element + try { + new (tmpData.get()) + ElementGraphData(toplevel, this->levelDataSize, this->allocator); + } catch (std::runtime_error &e) { + this->log(VecSimCommonStrings::LOG_WARNING_STRING, + "Error - allocating memory for new element failed due to low memory"); + throw e; + } + // Add the current element to the current block, and update cur_egt to point to it. + this->graphDataBlocks.back().addElement(tmpData.get()); + cur_egt = (ElementGraphData *)this->graphDataBlocks.back().getElement(j); + + // Restore the current element's graph data + for (size_t k = 0; k <= toplevel; k++) { + restoreLevel(input, getElementLevelData(cur_egt, k), version); + } + } + } +} + +template +void HNSWIndex::restoreLevel(std::ifstream &input, ElementLevelData &data, + HNSWSerializer::EncodingVersion version) { + readBinaryPOD(input, data.numLinks); + for (size_t i = 0; i < data.numLinks; i++) { + readBinaryPOD(input, data.links[i]); + } + + // Restore the incoming edges of the current element + unsigned int size; + readBinaryPOD(input, size); + data.incomingUnidirectionalEdges->reserve(size); + idType id = INVALID_ID; + for (size_t i = 0; i < size; i++) { + readBinaryPOD(input, id); + data.incomingUnidirectionalEdges->push_back(id); + } +} + +template +void HNSWIndex::saveIndexFields(std::ofstream &output) const { + // Save index type + writeBinaryPOD(output, VecSimAlgo_HNSWLIB); + + // Save VecSimIndex fields + writeBinaryPOD(output, this->dim); + writeBinaryPOD(output, this->vecType); + writeBinaryPOD(output, this->metric); + writeBinaryPOD(output, this->blockSize); + writeBinaryPOD(output, this->isMulti); + writeBinaryPOD(output, this->maxElements); // This will be used to restore the index initial + // capacity + + // Save index build parameters + writeBinaryPOD(output, this->M); + writeBinaryPOD(output, this->M0); + writeBinaryPOD(output, this->efConstruction); + + // Save index search parameter + writeBinaryPOD(output, this->ef); + writeBinaryPOD(output, this->epsilon); + + // Save index meta-data + writeBinaryPOD(output, this->mult); + + // Save index state + writeBinaryPOD(output, this->curElementCount); + writeBinaryPOD(output, this->numMarkedDeleted); + writeBinaryPOD(output, this->maxLevel); + writeBinaryPOD(output, this->entrypointNode); +} + +template +void HNSWIndex::saveGraph(std::ofstream &output) const { + // Save id to metadata vector + for (idType id = 0; id < this->curElementCount; id++) { + labelType label = this->idToMetaData[id].label; + elementFlags flags = this->idToMetaData[id].flags; + writeBinaryPOD(output, label); + writeBinaryPOD(output, flags); + } + + this->vectors->saveVectorsData(output); + + // Save graph data blocks + for (size_t i = 0; i < this->graphDataBlocks.size(); i++) { + auto &block = this->graphDataBlocks[i]; + unsigned int block_len = block.getLength(); + writeBinaryPOD(output, block_len); + for (size_t j = 0; j < block_len; j++) { + ElementGraphData *cur_element = (ElementGraphData *)block.getElement(j); + writeBinaryPOD(output, cur_element->toplevel); + + // Save all the levels of the current element + for (size_t level = 0; level <= cur_element->toplevel; level++) { + saveLevel(output, getElementLevelData(cur_element, level)); + } + } + } +} + +template +void HNSWIndex::saveLevel(std::ofstream &output, ElementLevelData &data) const { + // Save the links of the current element + writeBinaryPOD(output, data.numLinks); + for (size_t i = 0; i < data.numLinks; i++) { + writeBinaryPOD(output, data.links[i]); + } + + // Save the incoming edges of the current element + unsigned int size = data.incomingUnidirectionalEdges->size(); + writeBinaryPOD(output, size); + for (idType id : *data.incomingUnidirectionalEdges) { + writeBinaryPOD(output, id); + } + + // Shrink the incoming edges vector for integrity check + data.incomingUnidirectionalEdges->shrink_to_fit(); +} diff --git a/src/VecSim/algorithms/hnsw/hnsw_single.h b/src/VecSim/algorithms/hnsw/hnsw_single.h index f8299ba32..6fbfc9967 100644 --- a/src/VecSim/algorithms/hnsw/hnsw_single.h +++ b/src/VecSim/algorithms/hnsw/hnsw_single.h @@ -40,7 +40,7 @@ class HNSWIndex_Single : public HNSWIndex { HNSWIndex_Single(std::ifstream &input, const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams, const IndexComponents &components, - Serializer::EncodingVersion version) + HNSWSerializer::EncodingVersion version) : HNSWIndex(input, params, abstractInitParams, components, version), labelLookup(this->maxElements, this->allocator) {} diff --git a/src/VecSim/algorithms/svs/svs.h b/src/VecSim/algorithms/svs/svs.h index 5e0e5615f..13870fdd4 100644 --- a/src/VecSim/algorithms/svs/svs.h +++ b/src/VecSim/algorithms/svs/svs.h @@ -26,7 +26,16 @@ #include "VecSim/algorithms/svs/svs_batch_iterator.h" #include "VecSim/algorithms/svs/svs_extensions.h" -struct SVSIndexBase { +#ifdef BUILD_TESTS +#include "svs_serializer.h" +#endif + +struct SVSIndexBase +#ifdef BUILD_TESTS + : public SVSSerializer +#endif +{ + virtual ~SVSIndexBase() = default; virtual int addVectors(const void *vectors_data, const labelType *labels, size_t n) = 0; virtual int deleteVectors(const labelType *labels, size_t n) = 0; @@ -667,6 +676,17 @@ class SVSIndex : public VecSimIndexAbstract, fl } #ifdef BUILD_TESTS + +private: + void saveIndexIMP(std::ofstream &output) override; + void impl_save(const std::string &location) override; + void saveIndexFields(std::ofstream &output) const override; + + bool compareMetadataFile(const std::string &metadataFilePath) const override; + void loadIndex(const std::string &folder_path) override; + bool checkIntegrity() const override; + +public: void fitMemory() override {} std::vector> getStoredVectorDataByLabel(labelType label) const override { assert(false && "Not implemented"); @@ -681,3 +701,8 @@ class SVSIndex : public VecSimIndexAbstract, fl svs::logging::logger_ptr getLogger() const override { return logger_; } #endif }; + +#ifdef BUILD_TESTS +// Including implementations for Serializer base +#include "svs_serializer_impl.h" +#endif diff --git a/src/VecSim/algorithms/svs/svs_extensions.h b/src/VecSim/algorithms/svs/svs_extensions.h index b2bf0c611..3903d2289 100644 --- a/src/VecSim/algorithms/svs/svs_extensions.h +++ b/src/VecSim/algorithms/svs/svs_extensions.h @@ -27,6 +27,14 @@ struct SVSStorageTraits { static constexpr bool is_compressed() { return true; } + static auto make_blocked_allocator(size_t block_size, size_t dim, + std::shared_ptr allocator) { + // SVS block size is a power of two, so we can use it directly + auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); + allocator_type data_allocator{std::move(allocator)}; + return svs::make_blocked_allocator_handle({svs_bs}, data_allocator); + } + static constexpr VecSimSvsQuantBits get_compression_mode() { return VecSimSvsQuant_Scalar; } template @@ -34,12 +42,22 @@ struct SVSStorageTraits { std::shared_ptr allocator, size_t /*leanvec_dim*/) { const auto dim = data.dimensions(); - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + return index_storage_type::compress(data, pool, blocked_alloc); + } - allocator_type data_allocator{std::move(allocator)}; - auto blocked_alloc = svs::make_blocked_allocator_handle({svs_bs}, data_allocator); + static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, + std::shared_ptr allocator) { + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + return index_storage_type::load(table, blocked_alloc); + } - return index_storage_type::compress(data, pool, blocked_alloc); + static index_storage_type load(const std::string &path, size_t block_size, size_t dim, + std::shared_ptr allocator) { + assert(svs::data::detail::is_likely_reload(path)); // TODO implement auto_load for SQDataset + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + // Load the data from disk + return svs::lib::load_from_disk(path, blocked_alloc); } static constexpr size_t element_size(size_t dims, size_t alignment = 0, @@ -94,19 +112,38 @@ struct SVSStorageTraits allocator) { + // SVS block size is a power of two, so we can use it directly + auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); + allocator_type data_allocator{std::move(allocator)}; + return svs::make_blocked_allocator_handle({svs_bs}, data_allocator); + } + template static index_storage_type create_storage(const Dataset &data, size_t block_size, Pool &pool, std::shared_ptr allocator, size_t /*leanvec_dim*/) { const auto dim = data.dimensions(); - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); - - allocator_type data_allocator{std::move(allocator)}; - auto blocked_alloc = svs::make_blocked_allocator_handle({svs_bs}, data_allocator); + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); return index_storage_type::compress(data, pool, 0, blocked_alloc); } + static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, + std::shared_ptr allocator) { + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + return index_storage_type::load(table, /*alignment=*/0, blocked_alloc); + } + + static index_storage_type load(const std::string &path, size_t block_size, size_t dim, + std::shared_ptr allocator) { + assert(svs::data::detail::is_likely_reload(path)); // TODO implement auto_load for LVQ + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + // Load the data from disk + return svs::lib::load_from_disk(path, /*alignment=*/0, blocked_alloc); + } + static constexpr size_t element_size(size_t dims, size_t alignment = 0, size_t /*leanvec_dim*/ = 0) { using primary_type = typename index_storage_type::primary_type; @@ -151,22 +188,41 @@ struct SVSStorageTraits { } } + static auto make_blocked_allocator(size_t block_size, size_t dim, + std::shared_ptr allocator) { + // SVS block size is a power of two, so we can use it directly + auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); + allocator_type data_allocator{std::move(allocator)}; + return svs::make_blocked_allocator_handle({svs_bs}, data_allocator); + } + template static index_storage_type create_storage(const Dataset &data, size_t block_size, Pool &pool, std::shared_ptr allocator, size_t leanvec_dim) { - const auto dims = data.dimensions(); - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dims)); - - allocator_type data_allocator{std::move(allocator)}; - auto blocked_alloc = svs::make_blocked_allocator_handle({svs_bs}, data_allocator); + const auto dim = data.dimensions(); + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); return index_storage_type::reduce( data, std::nullopt, pool, 0, - svs::lib::MaybeStatic(check_leanvec_dim(dims, leanvec_dim)), + svs::lib::MaybeStatic(check_leanvec_dim(dim, leanvec_dim)), blocked_alloc); } + static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, + std::shared_ptr allocator) { + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + return index_storage_type::load(table, /*alignment=*/0, blocked_alloc); + } + + static index_storage_type load(const std::string &path, size_t block_size, size_t dim, + std::shared_ptr allocator) { + assert(svs::data::detail::is_likely_reload(path)); // TODO implement auto_load for LeanVec + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + // Load the data from disk + return svs::lib::load_from_disk(path, /*alignment=*/0, blocked_alloc); + } + static constexpr size_t element_size(size_t dims, size_t alignment = 0, size_t leanvec_dim = 0) { return SVSStorageTraits::element_size( diff --git a/src/VecSim/algorithms/svs/svs_serializer.cpp b/src/VecSim/algorithms/svs/svs_serializer.cpp new file mode 100644 index 000000000..58ef82ebe --- /dev/null +++ b/src/VecSim/algorithms/svs/svs_serializer.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2006-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + */ + +#include "svs_serializer.h" + +namespace fs = std::filesystem; + +SVSSerializer::SVSSerializer(EncodingVersion version) : m_version(version) {} + +SVSSerializer::EncodingVersion SVSSerializer::ReadVersion(std::ifstream &input) { + input.seekg(0, std::ifstream::beg); + + EncodingVersion version = EncodingVersion::INVALID; + readBinaryPOD(input, version); + + if (version >= EncodingVersion::INVALID) { + input.close(); + throw std::runtime_error("Cannot load index: bad encoding version: " + + std::to_string(static_cast(version))); + } + return version; +} + +void SVSSerializer::saveIndex(const std::string &location) { + EncodingVersion version = EncodingVersion::V0; + auto metadata_path = fs::path(location) / "metadata"; + std::ofstream output(metadata_path, std::ios::binary); + writeBinaryPOD(output, version); + saveIndexIMP(output); + output.close(); + impl_save(location); +} + +SVSSerializer::EncodingVersion SVSSerializer::getVersion() const { return m_version; } diff --git a/src/VecSim/algorithms/svs/svs_serializer.h b/src/VecSim/algorithms/svs/svs_serializer.h new file mode 100644 index 000000000..d66644364 --- /dev/null +++ b/src/VecSim/algorithms/svs/svs_serializer.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2006-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + */ + +#pragma once + +#include +#include +#include +#include "VecSim/utils/serializer.h" +#include + +typedef struct { + bool valid_state; + long memory_usage; // in bytes + size_t index_size; + size_t storage_size; + size_t label_count; + size_t capacity; + size_t changes_count; + bool is_compressed; + bool is_multi; +} SVSIndexMetaData; + +// Middle layer for SVS serialization +// Abstract functions should be implemented by the templated SVS index + +class SVSSerializer : public Serializer { +public: + enum class EncodingVersion { V0, INVALID }; + + explicit SVSSerializer(EncodingVersion version = EncodingVersion::V0); + + static EncodingVersion ReadVersion(std::ifstream &input); + + void saveIndex(const std::string &location) override; + + virtual void loadIndex(const std::string &location) = 0; + + EncodingVersion getVersion() const; + + virtual bool checkIntegrity() const = 0; + +protected: + EncodingVersion m_version; + + virtual void impl_save(const std::string &location) = 0; + + // Helper function to compare the svs index fields with the metadata file + template + static void compareField(std::istream &in, const T &expected, const std::string &fieldName); + +private: + virtual bool compareMetadataFile(const std::string &metadataFilePath) const = 0; +}; + +// Implement << operator for enum class +inline std::ostream &operator<<(std::ostream &os, SVSSerializer::EncodingVersion version) { + return os << static_cast(version); +} + +template +void SVSSerializer::compareField(std::istream &in, const T &expected, + const std::string &fieldName) { + T actual; + Serializer::readBinaryPOD(in, actual); + if (!in.good()) { + throw std::runtime_error("Failed to read field: " + fieldName); + } + if (actual != expected) { + std::ostringstream msg; + msg << "Field mismatch in \"" << fieldName << "\": expected " << expected << ", got " + << actual; + throw std::runtime_error(msg.str()); + } +} diff --git a/src/VecSim/algorithms/svs/svs_serializer_impl.h b/src/VecSim/algorithms/svs/svs_serializer_impl.h new file mode 100644 index 000000000..3da957cf6 --- /dev/null +++ b/src/VecSim/algorithms/svs/svs_serializer_impl.h @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2006-Present, Redis Ltd. + * All rights reserved. + * + * Licensed under your choice of the Redis Source Available License 2.0 + * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the + * GNU Affero General Public License v3 (AGPLv3). + */ +#pragma once + +#include "svs_serializer.h" +#include "svs/index/vamana/dynamic_index.h" +#include "svs/index/vamana/multi.h" + +// Saves all relevant fields of SVSIndex to the output stream +// This function saves all template parameters and instance fields needed to reconstruct +// an SVSIndex +template +void SVSIndex::saveIndexFields( + std::ofstream &output) const { + // Save base class fields from VecSimIndexAbstract + // Note: this->vecType corresponds to DataType template parameter + // Note: this->metric corresponds to MetricType template parameter + writeBinaryPOD(output, this->dim); + writeBinaryPOD(output, this->vecType); // DataType template parameter (as VecSimType enum) + writeBinaryPOD(output, this->dataSize); + writeBinaryPOD(output, this->metric); // MetricType template parameter (as VecSimMetric enum) + writeBinaryPOD(output, this->blockSize); + writeBinaryPOD(output, this->isMulti); + + // Save SVS-specific configuration fields + writeBinaryPOD(output, this->forcePreprocessing); + + // Save build parameters + writeBinaryPOD(output, this->buildParams.alpha); + writeBinaryPOD(output, this->buildParams.graph_max_degree); + writeBinaryPOD(output, this->buildParams.window_size); + writeBinaryPOD(output, this->buildParams.max_candidate_pool_size); + writeBinaryPOD(output, this->buildParams.prune_to); + writeBinaryPOD(output, this->buildParams.use_full_search_history); + + // Save search parameters + writeBinaryPOD(output, this->search_window_size); + writeBinaryPOD(output, this->epsilon); + + // Save template parameters as metadata for validation during loading + writeBinaryPOD(output, getCompressionMode()); + + // QuantBits, ResidualBits, and IsLeanVec information + + // Save additional template parameter constants for complete reconstruction + writeBinaryPOD(output, static_cast(QuantBits)); // Template parameter QuantBits + writeBinaryPOD(output, static_cast(ResidualBits)); // Template parameter ResidualBits + writeBinaryPOD(output, static_cast(IsLeanVec)); // Template parameter IsLeanVec + writeBinaryPOD(output, static_cast(isMulti)); // Template parameter isMulti + + // Save additional metadata for validation during loading + writeBinaryPOD(output, this->lastMode); // Last search mode +} + +// Saves metadata (e.g., encoding version) to satisfy Serializer interface. +// Full index is saved separately in saveIndex() using file paths. +template +void SVSIndex::saveIndexIMP( + std::ofstream &output) { + + // Save all index fields using the dedicated function + saveIndexFields(output); +} + +// Saves metadata (e.g., encoding version) to satisfy Serializer interface. +// Full index is saved separately in saveIndex() using file paths. +template +void SVSIndex::impl_save( + const std::string &location) { + impl_->save(location + "/config", location + "/graph", location + "/data"); +} + +// This function will load the serialized svs index from the given folder path +// This function should be called after the index is created with the same parameters as the +// original index. The index fields and template parameters will be validated before loading. After +// sucssessful loading, the graph can be validated with checkIntegrity. +template +void SVSIndex::loadIndex( + const std::string &folder_path) { + svs::threads::ThreadPoolHandle threadpool_handle{VecSimSVSThreadPool{threadpool_}}; + + // Verify metadata compatibility, will throw runtime exception if not compatible + compareMetadataFile(folder_path + "/metadata"); + + if constexpr (isMulti) { + auto loaded = svs::index::vamana::auto_multi_dynamic_assemble( + folder_path + "/config", + SVS_LAZY(graph_builder_t::load(folder_path + "/graph", this->blockSize, + this->buildParams, this->getAllocator())), + SVS_LAZY(storage_traits_t::load(folder_path + "/data", this->blockSize, this->dim, + this->getAllocator())), + distance_f(), std::move(threadpool_handle), + svs::index::vamana::MultiMutableVamanaLoad::FROM_MULTI, logger_); + impl_ = std::make_unique(std::move(loaded)); + } else { + auto loaded = svs::index::vamana::auto_dynamic_assemble( + folder_path + "/config", + SVS_LAZY(graph_builder_t::load(folder_path + "/graph", this->blockSize, + this->buildParams, this->getAllocator())), + SVS_LAZY(storage_traits_t::load(folder_path + "/data", this->blockSize, this->dim, + this->getAllocator())), + distance_f(), std::move(threadpool_handle), false, logger_); + impl_ = std::make_unique(std::move(loaded)); + } +} + +template +bool SVSIndex::compareMetadataFile(const std::string &metadataFilePath) const { + std::ifstream input(metadataFilePath, std::ios::binary); + if (!input.is_open()) { + throw std::runtime_error("Failed to open metadata file: " + metadataFilePath); + } + + // To check version, use ReadVersion + SVSSerializer::ReadVersion(input); + + compareField(input, this->dim, "dim"); + compareField(input, this->vecType, "vecType"); + compareField(input, this->dataSize, "dataSize"); + compareField(input, this->metric, "metric"); + compareField(input, this->blockSize, "blockSize"); + compareField(input, this->isMulti, "isMulti"); + + compareField(input, this->forcePreprocessing, "forcePreprocessing"); + + compareField(input, this->buildParams.alpha, "buildParams.alpha"); + compareField(input, this->buildParams.graph_max_degree, "buildParams.graph_max_degree"); + compareField(input, this->buildParams.window_size, "buildParams.window_size"); + compareField(input, this->buildParams.max_candidate_pool_size, + "buildParams.max_candidate_pool_size"); + compareField(input, this->buildParams.prune_to, "buildParams.prune_to"); + compareField(input, this->buildParams.use_full_search_history, + "buildParams.use_full_search_history"); + + compareField(input, this->search_window_size, "search_window_size"); + compareField(input, this->epsilon, "epsilon"); + + auto compressionMode = getCompressionMode(); + compareField(input, compressionMode, "compression_mode"); + + compareField(input, static_cast(QuantBits), "QuantBits"); + compareField(input, static_cast(ResidualBits), "ResidualBits"); + compareField(input, static_cast(IsLeanVec), "IsLeanVec"); + compareField(input, static_cast(isMulti), "isMulti (template param)"); + + return true; +} + +template +bool SVSIndex::checkIntegrity() + const { + if (!impl_) { + throw std::runtime_error( + "SVSIndex integrity check failed: index implementation (impl_) is null."); + } + + try { + // SVS internal index integrity validation + if constexpr (isMulti) { + impl_->get_parent_index().debug_check_invariants(true); + } else { + impl_->debug_check_invariants(true); + } + } + // debug_check_invariants throws svs::lib::ANNException : public std::runtime_error in case of + // fail. + catch (...) { + throw; + } + + try { + size_t index_size = impl_->size(); + size_t storage_size = impl_->view_data().size(); + size_t capacity = storage_traits_t::storage_capacity(impl_->view_data()); + size_t label_count = this->indexLabelCount(); + + // Storage size must match index size + if (storage_size != index_size) { + throw std::runtime_error( + "SVSIndex integrity check failed: storage_size != index_size."); + } + + // Capacity must be at least index size + if (capacity < index_size) { + throw std::runtime_error("SVSIndex integrity check failed: capacity < index_size."); + } + + // Binary label validation: verify label iteration and count consistency + size_t labels_counted = 0; + bool label_validation_passed = true; + + try { + impl_->on_ids([&](size_t label) { labels_counted++; }); + + // Validate label count consistency + label_validation_passed = (labels_counted == label_count); + + // For multi-index, also ensure label count doesn't exceed index size + if constexpr (isMulti) { + label_validation_passed = label_validation_passed && (label_count <= index_size); + } + } catch (...) { + label_validation_passed = false; + } + + if (!label_validation_passed) { + throw std::runtime_error("SVSIndex integrity check failed: label validation failed."); + } + + return true; + + } catch (const std::exception &e) { + throw std::runtime_error(std::string("SVSIndex integrity check failed with exception: ") + + e.what()); + } catch (...) { + throw std::runtime_error("SVSIndex integrity check failed with unknown exception."); + } +} diff --git a/src/VecSim/algorithms/svs/svs_utils.h b/src/VecSim/algorithms/svs/svs_utils.h index 214faf716..d507f5873 100644 --- a/src/VecSim/algorithms/svs/svs_utils.h +++ b/src/VecSim/algorithms/svs/svs_utils.h @@ -221,17 +221,22 @@ struct SVSStorageTraits { return VecSimSvsQuant_NONE; // No compression for this storage } + static blocked_type make_blocked_allocator(size_t block_size, size_t dim, + std::shared_ptr allocator) { + // SVS storage element size and block size can be differ than VecSim + auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); + allocator_type data_allocator{std::move(allocator)}; + return blocked_type{{svs_bs}, data_allocator}; + } + template static index_storage_type create_storage(const Dataset &data, size_t block_size, Pool &pool, std::shared_ptr allocator, size_t /* leanvec_dim */) { const auto dim = data.dimensions(); const auto size = data.size(); - // SVS storage element size and block size can be differ than VecSim - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); // Allocate initial SVS storage for index - allocator_type data_allocator{std::move(allocator)}; - blocked_type blocked_alloc{{svs_bs}, data_allocator}; + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); index_storage_type init_data{size, dim, blocked_alloc}; // Copy data to allocated storage svs::threads::parallel_for(pool, svs::threads::StaticPartition(data.eachindex()), @@ -243,6 +248,20 @@ struct SVSStorageTraits { return init_data; } + static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, + std::shared_ptr allocator) { + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + // Load the data from disk + return index_storage_type::load(table, blocked_alloc); + } + + static index_storage_type load(const std::string &path, size_t block_size, size_t dim, + std::shared_ptr allocator) { + auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); + // Load the data from disk + return index_storage_type::load(path, blocked_alloc); + } + // SVS storage element size can be differ than VecSim DataSize static constexpr size_t element_size(size_t dims, size_t /*alignment*/ = 0, size_t /*leanvec_dim*/ = 0) { @@ -257,7 +276,15 @@ struct SVSGraphBuilder { using allocator_type = svs_details::SVSAllocator; using blocked_type = svs::data::Blocked; using graph_data_type = svs::data::BlockedData; - using graph_type = svs::graphs::SimpleGraphBase; + using graph_type = svs::graphs::SimpleGraph; + + static blocked_type make_blocked_allocator(size_t block_size, size_t graph_max_degree, + std::shared_ptr allocator) { + // SVS block size is a power of two, so we can use it directly + auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(graph_max_degree)); + allocator_type data_allocator{std::move(allocator)}; + return blocked_type{{svs_bs}, data_allocator}; + } // Build SVS Graph using custom allocator // The logic has been taken from one of `MutableVamanaIndex` constructors @@ -269,11 +296,9 @@ struct SVSGraphBuilder { SVSIdType entry_point, size_t block_size, std::shared_ptr allocator, const svs::logging::logger_ptr &logger) { - auto svs_bs = - svs_details::SVSBlockSize(block_size, element_size(parameters.graph_max_degree)); // Perform graph construction. - allocator_type data_allocator{std::move(allocator)}; - blocked_type blocked_alloc{{svs_bs}, data_allocator}; + auto blocked_alloc = + make_blocked_allocator(block_size, parameters.graph_max_degree, std::move(allocator)); auto graph = graph_type{data.size(), parameters.graph_max_degree, blocked_alloc}; // SVS incorporates an advanced software prefetching scheme with two parameters: step and // lookahead. These parameters determine how far ahead to prefetch data vectors @@ -292,6 +317,24 @@ struct SVSGraphBuilder { return graph; } + static graph_type load(const svs::lib::LoadTable &table, size_t block_size, + const svs::index::vamana::VamanaBuildParameters ¶meters, + std::shared_ptr allocator) { + auto blocked_alloc = + make_blocked_allocator(block_size, parameters.graph_max_degree, std::move(allocator)); + // Load the graph from disk + return graph_type::load(table, blocked_alloc); + } + + static graph_type load(const std::string &path, size_t block_size, + const svs::index::vamana::VamanaBuildParameters ¶meters, + std::shared_ptr allocator) { + auto blocked_alloc = + make_blocked_allocator(block_size, parameters.graph_max_degree, std::move(allocator)); + // Load the graph from disk + return graph_type::load(path, blocked_alloc); + } + // SVS Vamana graph element size static constexpr size_t element_size(size_t graph_max_degree, size_t alignment = 0) { // For every Vamana graph node SVS allocates a record with current node ID and diff --git a/src/VecSim/containers/data_blocks_container.cpp b/src/VecSim/containers/data_blocks_container.cpp index 062f69cf7..bf63b683a 100644 --- a/src/VecSim/containers/data_blocks_container.cpp +++ b/src/VecSim/containers/data_blocks_container.cpp @@ -8,7 +8,7 @@ */ #include "data_blocks_container.h" -#include "VecSim/utils/serializer.h" +#include "VecSim/algorithms/hnsw/hnsw_serializer.h" #include DataBlocksContainer::DataBlocksContainer(size_t blockSize, size_t elementBytesCount, @@ -83,7 +83,9 @@ void DataBlocksContainer::restoreBlocks(std::istream &input, size_t num_vectors, // Get number of blocks unsigned int num_blocks = 0; - if (version == Serializer::EncodingVersion_V3) { + HNSWSerializer::EncodingVersion hnsw_version = + static_cast(version); + if (hnsw_version == HNSWSerializer::EncodingVersion::V3) { // In V3, the number of blocks is serialized, so we need to read it from the file. Serializer::readBinaryPOD(input, num_blocks); } else { @@ -97,7 +99,7 @@ void DataBlocksContainer::restoreBlocks(std::istream &input, size_t num_vectors, this->blocks.emplace_back(this->block_size, this->element_bytes_count, this->allocator, this->alignment); unsigned int block_len = 0; - if (version == Serializer::EncodingVersion_V3) { + if (hnsw_version == HNSWSerializer::EncodingVersion::V3) { // In V3, the length of each block is serialized, so we need to read it from the file. Serializer::readBinaryPOD(input, block_len); } else { diff --git a/src/VecSim/containers/data_blocks_container.h b/src/VecSim/containers/data_blocks_container.h index fca9f3884..c375590f2 100644 --- a/src/VecSim/containers/data_blocks_container.h +++ b/src/VecSim/containers/data_blocks_container.h @@ -53,7 +53,8 @@ class DataBlocksContainer : public VecsimBaseObject, public RawDataContainer { void saveVectorsData(std::ostream &output) const override; // Use that in deserialization when file was created with old version (v3) that serialized // the blocks themselves and not just thw raw vector data. - void restoreBlocks(std::istream &input, size_t num_vectors, Serializer::EncodingVersion); + void restoreBlocks(std::istream &input, size_t num_vectors, + Serializer::EncodingVersion version); void shrinkToFit(); #endif diff --git a/src/VecSim/index_factories/hnsw_factory.cpp b/src/VecSim/index_factories/hnsw_factory.cpp index 389c4cd18..0fe8b18ea 100644 --- a/src/VecSim/index_factories/hnsw_factory.cpp +++ b/src/VecSim/index_factories/hnsw_factory.cpp @@ -167,7 +167,7 @@ template inline VecSimIndex *NewIndex_ChooseMultiOrSingle(std::ifstream &input, const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams, IndexComponents &components, - Serializer::EncodingVersion version) { + HNSWSerializer::EncodingVersion version) { HNSWIndex *index = nullptr; // check if single and call the ctor that loads index information from file. if (params->multi) @@ -199,7 +199,7 @@ VecSimIndex *NewIndex(const std::string &location, bool is_normalized) { throw std::runtime_error("Cannot open file"); } - Serializer::EncodingVersion version = Serializer::ReadVersion(input); + HNSWSerializer::EncodingVersion version = HNSWSerializer::ReadVersion(input); VecSimAlgo algo = VecSimAlgo_BF; Serializer::readBinaryPOD(input, algo); diff --git a/src/VecSim/index_factories/svs_factory.cpp b/src/VecSim/index_factories/svs_factory.cpp index 8219723e8..b368be6e6 100644 --- a/src/VecSim/index_factories/svs_factory.cpp +++ b/src/VecSim/index_factories/svs_factory.cpp @@ -201,6 +201,27 @@ VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized) { return NewIndexImpl(params, is_normalized); } +#if BUILD_TESTS +VecSimIndex *NewIndex(const std::string &location, const VecSimParams *params, bool is_normalized) { + auto index = NewIndexImpl(params, is_normalized); + // Side-cast to SVSIndexBase to call loadIndex + SVSIndexBase *svs_index = dynamic_cast(index); + if (svs_index != nullptr) { + try { + svs_index->loadIndex(location); + } catch (const std::exception &e) { + VecSimIndex_Free(index); + throw; + } + } else { + VecSimIndex_Free(index); + throw std::runtime_error( + "Cannot load index: Error in index creation before loading serialization"); + } + return index; +} +#endif + size_t EstimateElementSize(const SVSParams *params) { using graph_idx_type = uint32_t; // Assuming that the graph_max_degree can be unset in params. @@ -227,9 +248,14 @@ size_t EstimateInitialSize(const SVSParams *params, bool is_normalized) { // This is a temporary solution to avoid breaking the build when SVS is not available // and to allow the code to compile without SVS support. // TODO: remove HAVE_SVS when SVS will support all Redis platforms and compilers -#else // HAVE_SVS +#else // HAVE_SVS namespace SVSFactory { VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized) { return NULL; } +#if BUILD_TESTS +VecSimIndex *NewIndex(const std::string &location, const VecSimParams *params, bool is_normalized) { + return NULL; +} +#endif size_t EstimateInitialSize(const SVSParams *params, bool is_normalized) { return -1; } size_t EstimateElementSize(const SVSParams *params) { return -1; } }; // namespace SVSFactory diff --git a/src/VecSim/index_factories/svs_factory.h b/src/VecSim/index_factories/svs_factory.h index ccd910999..c4c6d04db 100644 --- a/src/VecSim/index_factories/svs_factory.h +++ b/src/VecSim/index_factories/svs_factory.h @@ -10,12 +10,17 @@ #pragma once #include // size_t +#include #include "VecSim/vec_sim.h" //typedef VecSimIndex #include "VecSim/vec_sim_common.h" // VecSimParams, SVSParams namespace SVSFactory { VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized = false); +#if BUILD_TESTS +VecSimIndex *NewIndex(const std::string &location, const VecSimParams *params, + bool is_normalized = false); +#endif size_t EstimateInitialSize(const SVSParams *params, bool is_normalized = false); size_t EstimateElementSize(const SVSParams *params); }; // namespace SVSFactory diff --git a/src/VecSim/utils/serializer.h b/src/VecSim/utils/serializer.h index f5c3a10fd..211a887b6 100644 --- a/src/VecSim/utils/serializer.h +++ b/src/VecSim/utils/serializer.h @@ -12,21 +12,37 @@ #include #include +/* + * Serializer Abstraction Layer for Vector Indexes + * ----------------------------------------------- + * This header defines the base `Serializer` class, which provides a generic interface for + * serializing vector indexes to disk. It is designed to be inherited + * by algorithm-specific serializers (e.g., HNSWSerializer, SVSSerializer), and provides a + * versioned, extensible mechanism for managing persistent representations of index state. + * Each serializer subclass must define its own EncodingVersion enum. + * How to Extend: + * 1. Derive a new class from `Serializer`, e.g., `MyIndexSerializer`. + * 2. Implement `saveIndex()` and `saveIndexIMP()`. + * 3. Implement `saveIndexFields()` to write out relevant fields in a deterministic order. + * 4. Optionally, add version-aware deserialization methods. + * + * Example Inheritance Tree: + * Serializer (abstract) + * ├── HNSWSerializer + * │ └── HNSWIndex + * └── SVSSerializer + * └── SVSIndex + */ + class Serializer { public: - typedef enum EncodingVersion { - EncodingVersion_DEPRECATED = 2, // Last deprecated version - EncodingVersion_V3, - EncodingVersion_V4, - EncodingVersion_INVALID, // This should always be last. - } EncodingVersion; + enum class EncodingVersion { INVALID }; - Serializer(EncodingVersion version = EncodingVersion_V4) : m_version(version) {} + Serializer(EncodingVersion version = EncodingVersion::INVALID) : m_version(version) {} - // Persist index into a file in the specified location with V3 encoding routine. - void saveIndex(const std::string &location); + virtual void saveIndex(const std::string &location) = 0; - EncodingVersion getVersion() const { return m_version; } + EncodingVersion getVersion() const; static EncodingVersion ReadVersion(std::ifstream &input); @@ -46,4 +62,7 @@ class Serializer { // Index memory size might be changed during index saving. virtual void saveIndexIMP(std::ofstream &output) = 0; + +private: + virtual void saveIndexFields(std::ofstream &output) const = 0; }; diff --git a/src/python_bindings/bindings.cpp b/src/python_bindings/bindings.cpp index a1922567e..f9e307f6c 100644 --- a/src/python_bindings/bindings.cpp +++ b/src/python_bindings/bindings.cpp @@ -11,6 +11,7 @@ #include "VecSim/index_factories/hnsw_factory.h" #if HAVE_SVS #include "VecSim/algorithms/svs/svs.h" +#include "VecSim/index_factories/svs_factory.h" #endif #include "VecSim/batch_iterator.h" #include "VecSim/types/bfloat16.h" @@ -566,6 +567,15 @@ class PySVSIndex : public PyVecSimIndex { } } + explicit PySVSIndex(const std::string &location, const SVSParams &svs_params) { + VecSimParams params = {.algo = VecSimAlgo_SVS, .algoParams = {.svsParams = svs_params}}; + this->index = + std::shared_ptr(SVSFactory::NewIndex(location, ¶ms), VecSimIndex_Free); + if (!this->index) { + throw std::runtime_error("Index creation failed"); + } + } + void addVectorsParallel(const py::object &input, const py::object &vectors_labels) { py::array vectors_data(input); // py::array labels(vectors_labels); @@ -587,6 +597,28 @@ class PySVSIndex : public PyVecSimIndex { assert(svs_index); svs_index->addVectors(vectors_data.data(), labels.data(), n_vectors); } + + void checkIntegrity() { + auto svs_index = dynamic_cast(this->index.get()); + assert(svs_index); + try { + svs_index->checkIntegrity(); + } catch (const std::exception &e) { + throw std::runtime_error(std::string("SVSIndex integrity check failed: ") + e.what()); + } + } + + void saveIndex(const std::string &location) { + auto svs_index = dynamic_cast(this->index.get()); + assert(svs_index); + svs_index->saveIndex(location); + } + + void loadIndex(const std::string &location) { + auto svs_index = dynamic_cast(this->index.get()); + assert(svs_index); + svs_index->loadIndex(location); + } }; class PyTiered_SVSIndex : public PyTieredIndex { @@ -805,8 +837,16 @@ PYBIND11_MODULE(VecSim, m) { py::class_(m, "SVSIndex") .def(py::init([](const SVSParams ¶ms) { return new PySVSIndex(params); }), py::arg("params")) + .def(py::init([](const std::string &location, const SVSParams ¶ms) { + return new PySVSIndex(location, params); + }), + py::arg("location"), py::arg("params")) .def("add_vector_parallel", &PySVSIndex::addVectorsParallel, py::arg("vectors"), - py::arg("labels")); + py::arg("labels")) + .def("check_integrity", &PySVSIndex::checkIntegrity) + .def("save_index", &PySVSIndex::saveIndex, py::arg("location")) + .def("load_index", &PySVSIndex::loadIndex, py::arg("location")); + py::class_(m, "Tiered_SVSIndex") .def(py::init([](const SVSParams &svs_params, const TieredSVSParams &tiered_svs_params, size_t flat_buffer_size = DEFAULT_BLOCK_SIZE) { diff --git a/tests/unit/test_common.cpp b/tests/unit/test_common.cpp index f4e1405f1..5958ab7e2 100644 --- a/tests/unit/test_common.cpp +++ b/tests/unit/test_common.cpp @@ -18,6 +18,10 @@ #include "VecSim/algorithms/hnsw/hnsw.h" #include "VecSim/algorithms/hnsw/hnsw_tiered.h" #include "VecSim/index_factories/hnsw_factory.h" +#if HAVE_SVS +#include "VecSim/index_factories/svs_factory.h" +#include "VecSim/algorithms/svs/svs.h" +#endif #include "mock_thread_pool.h" #include "tests_utils.h" #include "VecSim/index_factories/tiered_factory.h" @@ -28,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -486,7 +491,7 @@ TEST_F(SerializerTest, HNSWSerialzer) { // Use a valid version output.seekp(0, std::ios_base::beg); - Serializer::writeBinaryPOD(output, Serializer::EncodingVersion_V3); + Serializer::writeBinaryPOD(output, HNSWSerializer::EncodingVersion::V3); Serializer::writeBinaryPOD(output, 42); output.flush(); @@ -498,7 +503,7 @@ TEST_F(SerializerTest, HNSWSerialzer) { // Use a valid version output.seekp(0, std::ios_base::beg); - Serializer::writeBinaryPOD(output, Serializer::EncodingVersion_V3); + Serializer::writeBinaryPOD(output, HNSWSerializer::EncodingVersion::V3); Serializer::writeBinaryPOD(output, VecSimAlgo_HNSWLIB); Serializer::writeBinaryPOD(output, size_t(128)); @@ -512,6 +517,41 @@ TEST_F(SerializerTest, HNSWSerialzer) { output.close(); } +#if HAVE_SVS +TEST_F(SerializerTest, SVSSerializer) { + + this->file_name = std::string(getenv("ROOT")) + "/tests/unit/bad_index_svs"; + auto metadata_path = std::filesystem::path(this->file_name) / "metadata"; + + // Try to load an index from a directory that doesn't exist. + SVSParams params = { + .type = VecSimType_FLOAT32, + .dim = 1024, + .metric = VecSimMetric_L2, + }; + VecSimParams index_params = {.algo = VecSimAlgo_SVS, .algoParams = {.svsParams = params}}; + + ASSERT_EXCEPTION_MESSAGE( + SVSFactory::NewIndex(this->file_name, &index_params), std::runtime_error, + std::string("Failed to open metadata file: ") + metadata_path.string()); + + // Create directory and metadata file with invalid encoding version + std::filesystem::create_directories(this->file_name); + std::ofstream output(metadata_path, std::ios::binary); + + // Write invalid encoding version (42) + Serializer::writeBinaryPOD(output, 42); + output.flush(); + output.close(); + + ASSERT_EXCEPTION_MESSAGE(SVSFactory::NewIndex(this->file_name, &index_params), + std::runtime_error, "Cannot load index: bad encoding version: 42"); + + // Clean up + std::filesystem::remove_all(this->file_name); +} +#endif + struct logCtx { public: std::vector logBuffer; @@ -525,7 +565,6 @@ void test_log_impl(void *ctx, const char *level, const char *message) { } TEST(CommonAPITest, testlogBasic) { - logCtx log; log.prefix = "test log prefix: "; diff --git a/tests/unit/test_hnsw.cpp b/tests/unit/test_hnsw.cpp index 0f67073ba..47a5393b7 100644 --- a/tests/unit/test_hnsw.cpp +++ b/tests/unit/test_hnsw.cpp @@ -1755,7 +1755,7 @@ TYPED_TEST(HNSWTest, HNSWSerializationCurrentVersion) { // Verify that the index was loaded as expected. ASSERT_TRUE(serialized_hnsw_index->checkIntegrity().valid_state); - ASSERT_EQ(serialized_hnsw_index->getVersion(), Serializer::EncodingVersion_V4); + ASSERT_EQ(serialized_hnsw_index->getVersion(), HNSWSerializer::EncodingVersion::V4); VecSimIndexDebugInfo info2 = VecSimIndex_DebugInfo(serialized_index); ASSERT_EQ(info2.commonInfo.basicInfo.algo, VecSimAlgo_HNSWLIB); @@ -1825,7 +1825,7 @@ TYPED_TEST(HNSWTest, HNSWSerializationV3) { auto *serialized_hnsw_index = this->CastToHNSW(serialized_index); // Verify that the index was loaded as expected. - ASSERT_EQ(serialized_hnsw_index->getVersion(), Serializer::EncodingVersion_V3); + ASSERT_EQ(serialized_hnsw_index->getVersion(), HNSWSerializer::EncodingVersion::V3); ASSERT_TRUE(serialized_hnsw_index->checkIntegrity().valid_state); VecSimIndexDebugInfo info = VecSimIndex_DebugInfo(serialized_index); diff --git a/tests/unit/test_svs.cpp b/tests/unit/test_svs.cpp index bd66b2064..248c1bb49 100644 --- a/tests/unit/test_svs.cpp +++ b/tests/unit/test_svs.cpp @@ -12,13 +12,14 @@ #include "unit_test_utils.h" #include #include +#include #include #include - #if HAVE_SVS #include #include "spdlog/sinks/ostream_sink.h" #include "VecSim/algorithms/svs/svs.h" +#include "VecSim/index_factories/svs_factory.h" // There are possible cases when SVS Index cannot be created with the requested quantization mode // due to platform and/or hardware limitations or combination of requested 'compression' modes. @@ -2770,6 +2771,188 @@ TEST(SVSTest, quant_modes) { } } +TEST(SVSTest, save_load) { + namespace fs = std::filesystem; + // Limit VecSim log level to avoid printing too much information + VecSimIndexInterface::setLogCallbackFunction(svsTestLogCallBackNoDebug); + + const size_t dim = 4; + const size_t n = 100; + const size_t k = 10; + + // Helper function to convert quant_bits to string for error messages + auto quant_bits_to_string = [](VecSimSvsQuantBits quant_bits) -> std::string { + switch (quant_bits) { + case VecSimSvsQuant_NONE: + return "VecSimSvsQuant_NONE"; + case VecSimSvsQuant_Scalar: + return "VecSimSvsQuant_Scalar"; + case VecSimSvsQuant_8: + return "VecSimSvsQuant_8"; + case VecSimSvsQuant_4: + return "VecSimSvsQuant_4"; + case VecSimSvsQuant_4x4: + return "VecSimSvsQuant_4x4"; + case VecSimSvsQuant_4x8: + return "VecSimSvsQuant_4x8"; + case VecSimSvsQuant_4x8_LeanVec: + return "VecSimSvsQuant_4x8_LeanVec"; + case VecSimSvsQuant_8x8_LeanVec: + return "VecSimSvsQuant_8x8_LeanVec"; + default: + return "Unknown(" + std::to_string(static_cast(quant_bits)) + ")"; + } + }; + + // Test both single and multi variations + for (bool is_multi : {false, true}) { + for (auto quant_bits : {VecSimSvsQuant_NONE, VecSimSvsQuant_Scalar, VecSimSvsQuant_8, + VecSimSvsQuant_4, VecSimSvsQuant_4x4, VecSimSvsQuant_4x8, + VecSimSvsQuant_4x8_LeanVec, VecSimSvsQuant_8x8_LeanVec}) { + SVSParams params = { + .type = VecSimType_FLOAT32, + .dim = dim, + .metric = VecSimMetric_L2, + .multi = is_multi, + .blockSize = 1024, + /* SVS-Vamana specifics */ + .quantBits = quant_bits, + .graph_max_degree = 63, // x^2-1 to round the graph block size + .construction_window_size = 20, + .max_candidate_pool_size = 1024, + .prune_to = 60, + .use_search_history = VecSimOption_ENABLE, + }; + + VecSimParams index_params = CreateParams(params); + VecSimIndex *index = VecSimIndex_New(&index_params); + if (index == nullptr) { + if (std::get<1>(svs_details::isSVSQuantBitsSupported(quant_bits))) { + GTEST_FAIL() << "Failed to create SVS index with quant_bits: " + << quant_bits_to_string(quant_bits) + << ", multi: " << (is_multi ? "true" : "false"); + } else { + GTEST_SKIP() << "SVS LVQ is not supported for quant_bits: " + << quant_bits_to_string(quant_bits) + << ", multi: " << (is_multi ? "true" : "false"); + } + } + + std::vector> v(n); + std::vector ids(n); + + if (is_multi) { + const size_t per_label = 2; + const size_t num_labels = n / per_label; + + for (size_t i = 0; i < n; i++) { + size_t label_id = (i / per_label); + GenerateVector(v[i].data(), dim, i); + ids[i] = label_id; + } + } else { + // For single-index, each vector has a unique label (same as its index) + for (size_t i = 0; i < n; i++) { + GenerateVector(v[i].data(), dim, i); + ids[i] = i; + } + } + + auto svs_index = dynamic_cast(index); + ASSERT_NE(svs_index, nullptr) + << "Failed to cast to SVSIndexBase with quant_bits: " + << quant_bits_to_string(quant_bits) << ", multi: " << (is_multi ? "true" : "false"); + svs_index->addVectors(v.data(), ids.data(), n); + + ASSERT_EQ(VecSimIndex_IndexSize(index), n) + << "Index size mismatch after adding vectors with quant_bits: " + << quant_bits_to_string(quant_bits) << ", multi: " << (is_multi ? "true" : "false"); + + float query[] = {50, 50, 50, 50}; + auto verify_res = [&](size_t id, double score, size_t idx) { + EXPECT_DOUBLE_EQ(VecSimIndex_GetDistanceFrom_Unsafe(index, id, query), score); + // Both single and multi should return labels starting from 45 + if (is_multi) { + // For multi, that label of {50,50,50,50} is 25 + size_t expected_label = (20 + idx); + EXPECT_EQ(id, expected_label); + } else { + EXPECT_EQ(id, (idx + 45)); + } + }; + runTopKSearchTest(index, query, k, verify_res, nullptr, BY_ID); + + fs::path tmp{fs::temp_directory_path()}; + auto subdir = "vecsim_test_" + std::to_string(std::rand()); + auto index_path = tmp / subdir; + while (fs::exists(index_path)) { + subdir = "vecsim_test_" + std::to_string(std::rand()); + index_path = tmp / subdir; + } + fs::create_directories(index_path); + + try { + svs_index->saveIndex(index_path.string()); + } catch (const std::exception &e) { + GTEST_FAIL() << "Failed to save index with quant_bits: " + << quant_bits_to_string(quant_bits) + << ", multi: " << (is_multi ? "true" : "false") + << ", error: " << e.what(); + } + VecSimIndex_Free(index); + + // Recreate the index from the saved path + index = VecSimIndex_New(&index_params); + svs_index = dynamic_cast(index); + ASSERT_NE(svs_index, nullptr) + << "Failed to recreate index for loading with quant_bits: " + << quant_bits_to_string(quant_bits) << ", multi: " << (is_multi ? "true" : "false"); + + try { + svs_index->loadIndex(index_path.string()); + svs_index->checkIntegrity(); + } catch (const std::exception &e) { + GTEST_FAIL() << "Failed to load index with quant_bits: " + << quant_bits_to_string(quant_bits) + << ", multi: " << (is_multi ? "true" : "false") + << ", error: " << e.what(); + } + + // Verify the index was loaded correctly + ASSERT_EQ(VecSimIndex_IndexSize(index), n) + << "Index size mismatch after loading with quant_bits: " + << quant_bits_to_string(quant_bits) << ", multi: " << (is_multi ? "true" : "false"); + runTopKSearchTest(index, query, k, verify_res, nullptr, BY_ID); + + // Test load from file with constructor + VecSimIndex *svs_index_load = nullptr; + try { + svs_index_load = SVSFactory::NewIndex(index_path.string(), &index_params); + } catch (const std::exception &e) { + GTEST_FAIL() << "Failed to create index from file with quant_bits: " + << quant_bits_to_string(quant_bits) + << ", multi: " << (is_multi ? "true" : "false") + << ", error: " << e.what(); + } + ASSERT_NE(svs_index_load, nullptr) + << "Failed to create index from file with quant_bits: " + << quant_bits_to_string(quant_bits) << ", multi: " << (is_multi ? "true" : "false"); + + // Verify the index was loaded correctly + ASSERT_EQ(VecSimIndex_IndexSize(svs_index_load), n) + << "Index size mismatch for constructor-loaded index with quant_bits: " + << quant_bits_to_string(quant_bits) << ", multi: " << (is_multi ? "true" : "false"); + runTopKSearchTest(svs_index_load, query, k, verify_res, nullptr, BY_ID); + + VecSimIndex_Free(svs_index_load); + VecSimIndex_Free(index); + + // Cleanup + fs::remove_all(index_path); // Cleanup the saved index directory + } + } +} + TYPED_TEST(SVSTest, logging_runtime_params) { const size_t dim = 4; const size_t n = 100;