Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/VecSim/algorithms/brute_force/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ class BruteForceIndex : public VecSimIndexAbstract<DataType, DistType> {
size_t indexCapacity() const override;
std::unique_ptr<RawDataContainer::Iterator> getVectorsIterator() const;
/**
 * @brief Get a typed pointer to the stored data of the element with the given internal id.
 *
 * The element is fetched with a class-qualified call to DataBlocksContainer::getElement,
 * which is resolved statically instead of through the RawDataContainer virtual table.
 *
 * @note Relies on @c vectors pointing to a DataBlocksContainer. The downcast is an
 *       unchecked static_cast, so any other container type here is undefined behavior.
 *
 * @param id internal id of the element to fetch.
 * @return the element's data, reinterpreted as a pointer to const DataType.
 */
const DataType *getDataByInternalId(idType id) const {
    // `vectors` is always a DataBlocksContainer; skip the RawDataContainer vtable.
    return reinterpret_cast<const DataType *>(
        static_cast<const DataBlocksContainer *>(this->vectors)
            ->DataBlocksContainer::getElement(id));
}
VecSimQueryReply *topKQuery(const void *queryBlob, size_t k,
VecSimQueryParams *queryParams) const override;
Expand Down
4 changes: 3 additions & 1 deletion src/VecSim/algorithms/hnsw/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ labelType HNSWIndex<DataType, DistType>::getEntryPointLabel() const {

/**
 * @brief Get the raw stored data of the element with the given internal id.
 *
 * The element is fetched with a class-qualified call to DataBlocksContainer::getElement,
 * which is resolved statically instead of through the RawDataContainer virtual table.
 *
 * @note Relies on @c vectors pointing to a DataBlocksContainer. The downcast is an
 *       unchecked static_cast, so any other container type here is undefined behavior.
 *
 * @param internal_id internal id of the element to fetch.
 * @return pointer to the element's data as returned by the container.
 */
template <typename DataType, typename DistType>
const char *HNSWIndex<DataType, DistType>::getDataByInternalId(idType internal_id) const {
    // `vectors` is always a DataBlocksContainer; skip the RawDataContainer vtable on the hot path.
    return static_cast<const DataBlocksContainer *>(this->vectors)
        ->DataBlocksContainer::getElement(internal_id);
}

template <typename DataType, typename DistType>
Expand Down
5 changes: 0 additions & 5 deletions src/VecSim/containers/data_blocks_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ RawDataContainer::Status DataBlocksContainer::addElement(const void *element, si
return Status::OK;
}

// Return the stored data of element `id`: the quotient of `id` by the block size selects
// the block, and the remainder selects the element within that block.
const char *DataBlocksContainer::getElement(size_t id) const {
    assert(id < element_count);
    const size_t block_index = id / this->block_size;
    const size_t index_in_block = id % this->block_size;
    return blocks.at(block_index).getElement(index_in_block);
}

RawDataContainer::Status DataBlocksContainer::removeElement(size_t id) {
assert(id == element_count - 1); // only the last element can be removed
blocks.back().popLastElement();
Expand Down
6 changes: 5 additions & 1 deletion src/VecSim/containers/data_blocks_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ class DataBlocksContainer : public VecsimBaseObject, public RawDataContainer {

Status addElement(const void *element, size_t id) override;

// Inlined so the hot search path (via getDataByInternalId) can fold the indexing arithmetic.
// `id / block_size` selects the block and `id % block_size` the element inside it.
// The assert documents the precondition that `id` refers to a stored element.
const char *getElement(size_t id) const override {
    assert(id < element_count);
    return blocks[id / block_size].getElement(id % block_size);
}

Status removeElement(size_t id) override;

Expand Down
5 changes: 5 additions & 0 deletions src/VecSim/spaces/computer/calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class IndexCalculatorInterface : public VecsimBaseObject {
virtual ~IndexCalculatorInterface() = default;

virtual DistType calcDistance(const void *v1, const void *v2, size_t dim) const = 0;

// Raw distance function; cached by the index to skip the vtable on the hot path.
virtual spaces::dist_func_t<DistType> getDistFunc() const = 0;
};

/**
Expand Down Expand Up @@ -56,4 +59,6 @@ class DistanceCalculatorCommon
// Forward both vectors and the dimension to the stored distance function.
DistType calcDistance(const void *v1, const void *v2, size_t dim) const override {
    const DistType distance = this->dist_func(v1, v2, dim);
    return distance;
}

// Expose the stored distance function so callers can invoke it directly.
spaces::dist_func_t<DistType> getDistFunc() const override {
    return this->dist_func;
}
};
16 changes: 13 additions & 3 deletions src/VecSim/vec_sim_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ struct VecSimIndexAbstract : public VecSimIndexInterface {
RawDataContainer *vectors; // The raw vectors data container.
private:
IndexCalculatorInterface<DistType> *indexCalculator; // Distance calculator.
spaces::dist_func_t<DistType> cachedDistFunc; // Cached dist func, used on the hot path.
PreprocessorsContainerAbstract *preprocessors; // Storage and query preprocessors.

size_t inputBlobSize; // The size of input vectors/queries blob in bytes. May differ from dim *
Expand Down Expand Up @@ -120,8 +121,11 @@ struct VecSimIndexAbstract : public VecSimIndexInterface {
metric(params.metric),
blockSize(params.blockSize ? params.blockSize : DEFAULT_BLOCK_SIZE), lastMode(EMPTY_MODE),
isMulti(params.multi), logCallbackCtx(params.logCtx),
indexCalculator(components.indexCalculator), preprocessors(components.preprocessors),
inputBlobSize(params.inputBlobSize), storedDataSize(params.storedDataSize) {
indexCalculator(components.indexCalculator),
cachedDistFunc(components.indexCalculator ? components.indexCalculator->getDistFunc()
: nullptr),
preprocessors(components.preprocessors), inputBlobSize(params.inputBlobSize),
storedDataSize(params.storedDataSize) {
assert(VecSimType_sizeof(vecType));
assert(storedDataSize);
assert(inputBlobSize);
Expand All @@ -142,10 +146,16 @@ struct VecSimIndexAbstract : public VecSimIndexInterface {
/**
 * @brief Calculate the distance between two vectors based on index parameters.
 *
 * Uses the cached dist func to avoid the indexCalculator vtable on the hot path.
 *
 * @param vector_data1 first vector blob, passed as-is to the cached dist func.
 * @param vector_data2 second vector blob, passed as-is to the cached dist func.
 *
 * @note Precondition: @c cachedDistFunc must be non-null. Subclasses that construct
 *       this index with a null @c indexCalculator (e.g. SVS, which uses its own
 *       internal distance kernels) must not call this method. The precondition is
 *       enforced with an assert, which is compiled out when NDEBUG is defined.
 *
 * @return the distance between the vectors.
 */
DistType calcDistance(const void *vector_data1, const void *vector_data2) const {
    assert(cachedDistFunc != nullptr);
    return cachedDistFunc(vector_data1, vector_data2, this->dim);
}

/**
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_components.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class DistanceCalculatorDummy : public DistanceCalculatorInterface<DistType, dum
// The vector arguments and `dim` are not used; the dummy dist func is called with the
// constant 7 and its result is returned.
virtual DistType calcDistance(const void *v1, const void *v2, size_t dim) const {
    const DistType result = this->dist_func(7);
    return result;
}

// The dummy's dist func does not have the standard signature, so there is nothing to
// return in the standard slot; callers receive a null dist func.
spaces::dist_func_t<DistType> getDistFunc() const override {
    return nullptr;
}
};

} // namespace dummyCalcultor
Expand Down
Loading