Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SOLR-16836: introduce support for high dimensional vectors #1680

Merged
merged 6 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions solr/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ New Features
* SOLR-16719: AffinityPlacementFactory now supports spreading replicas across domains within the availablity zone and
optionally fail the request if more than a configurable number of replicas need to be placed in a single domain. (Houston Putman, Tomás Fernández Löbbe)

* SOLR-16836: Introduced support for high dimensional vectors (Alessandro Benedetti).

Improvements
---------------------

Expand Down
67 changes: 65 additions & 2 deletions solr/core/src/java/org/apache/solr/schema/DenseVectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@
import static org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;

import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand All @@ -43,6 +47,8 @@
import org.apache.solr.util.vector.ByteDenseVectorParser;
import org.apache.solr.util.vector.DenseVectorParser;
import org.apache.solr.util.vector.FloatDenseVectorParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Provides a field type to support Lucene's {@link org.apache.lucene.document.KnnVectorField}. See
Expand All @@ -53,6 +59,7 @@
* Only {@code Indexed} and {@code Stored} attributes are supported.
*/
public class DenseVectorField extends FloatPointField {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
public static final String HNSW_ALGORITHM = "hnsw";
public static final String DEFAULT_KNN_ALGORITHM = HNSW_ALGORITHM;
static final String KNN_VECTOR_DIMENSION = "vectorDimension";
Expand Down Expand Up @@ -182,6 +189,31 @@ public void checkSchemaField(final SchemaField field) throws SolrException {
SolrException.ErrorCode.SERVER_ERROR,
getClass().getSimpleName() + " fields can not have docValues: " + field.getName());
}

switch (vectorEncoding) {
case FLOAT32:
if (dimension > FloatVectorValues.MAX_DIMENSIONS) {
if (log.isWarnEnabled()) {
alessandrobenedetti marked this conversation as resolved.
Show resolved Hide resolved
log.warn(
"The vector dimension {} specified for field {} exceeds the current Lucene default max dimension of {}. It's un-tested territory, extra caution and benchmarks are recommended for production systems.",
dimension,
field.getName(),
FloatVectorValues.MAX_DIMENSIONS);
}
}
break;
case BYTE:
if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
if (log.isWarnEnabled()) {
log.warn(
"The vector dimension {} specified for field {} exceeds the current Lucene default max dimension of {}. It's un-tested territory, extra caution and benchmarks are recommended for production systems.",
dimension,
field.getName(),
ByteVectorValues.MAX_DIMENSIONS);
}
}
break;
}
}

@Override
Expand Down Expand Up @@ -218,22 +250,53 @@ public List<IndexableField> createFields(SchemaField field, Object value) {

@Override
public IndexableField createField(SchemaField field, Object vectorValue) {
FieldType denseVectorFieldType = getDenseVectorFieldType();

if (vectorValue == null) return null;
DenseVectorParser vectorBuilder = (DenseVectorParser) vectorValue;
switch (vectorEncoding) {
case BYTE:
return new KnnByteVectorField(
field.getName(), vectorBuilder.getByteVector(), similarityFunction);
field.getName(), vectorBuilder.getByteVector(), denseVectorFieldType);
case FLOAT32:
return new KnnFloatVectorField(
field.getName(), vectorBuilder.getFloatVector(), similarityFunction);
field.getName(), vectorBuilder.getFloatVector(), denseVectorFieldType);
default:
throw new SolrException(
SolrException.ErrorCode.SERVER_ERROR,
"Unexpected state. Vector Encoding: " + vectorEncoding);
}
}

/**
* This is needed at the moment to support dimensions higher than a hard-coded arbitrary Lucene
* max dimension. N.B. this may stop working and need changes when adopting future Lucene
* releases.
*
* @return a FieldType compatible with Dense vectors
*/
private FieldType getDenseVectorFieldType() {
alessandrobenedetti marked this conversation as resolved.
Show resolved Hide resolved
FieldType vectorFieldType =
new FieldType() {
@Override
public int vectorDimension() {
return dimension;
}

@Override
public VectorEncoding vectorEncoding() {
return vectorEncoding;
}

@Override
public VectorSimilarityFunction vectorSimilarityFunction() {
return similarityFunction;
}
};

return vectorFieldType;
}

@Override
public Object toObject(IndexableField f) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<?xml version="1.0" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

<!-- Test schema file for DenseVectorField -->

<schema name="schema-densevector" version="1.0">
<fieldType name="string" class="solr.StrField" multiValued="true"/>
<fieldType name="knn_vector" class="solr.DenseVectorField" vectorDimension="2048" similarityFunction="cosine" />
<fieldType name="plong" class="solr.LongPointField" useDocValuesAsStored="false"/>

<field name="id" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
<field name="vector" type="knn_vector" indexed="true" stored="true"/>
<field name="string_field" type="string" indexed="true" stored="true" multiValued="false" required="false"/>

<field name="_version_" type="plong" indexed="true" stored="true" multiValued="false" />
<field name="_text_" type="text_general" indexed="true" stored="false" multiValued="true"/>
<copyField source="*" dest="_text_"/>
<fieldType name="text_general" class="solr.TextField" positionIncrementGap="100" multiValued="true">
<analyzer type="index">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" words="stopwords.txt" ignoreCase="true"/>
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
<analyzer type="query">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" words="stopwords.txt" ignoreCase="true"/>
<filter class="solr.SynonymGraphFilterFactory" synonyms="synonyms.txt" ignoreCase="true" expand="true"/>
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
</fieldType>

<uniqueKey>id</uniqueKey>
</schema>
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
<schema name="schema-densevector" version="1.0">
<fieldType name="string" class="solr.StrField" multiValued="true"/>
<fieldType name="knn_vector" class="solr.DenseVectorField" vectorDimension="4" similarityFunction="cosine" />
<fieldType name="plong" class="solr.LongPointField" useDocValuesAsStored="false"/>

<fieldType name="knn_vector_byte_encoding" class="solr.DenseVectorField" vectorDimension="4" similarityFunction="cosine" vectorEncoding="BYTE"/>

<fieldType name="high_dimensional_float_knn_vector" class="solr.DenseVectorField" vectorDimension="2048" similarityFunction="cosine" vectorEncoding="FLOAT32"/>
<fieldType name="high_dimensional_byte_knn_vector" class="solr.DenseVectorField" vectorDimension="2048" similarityFunction="cosine" vectorEncoding="BYTE"/>
<fieldType name="plong" class="solr.LongPointField" useDocValuesAsStored="false"/>

<field name="id" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
<field name="vector" type="knn_vector" indexed="true" stored="true"/>
<field name="vector2" type="knn_vector" indexed="true" stored="true"/>
<field name="vector_byte_encoding" type="knn_vector_byte_encoding" indexed="true" stored="true" />
<field name="2048_byte_vector" type="high_dimensional_byte_knn_vector" indexed="true" stored="true" />
<field name="2048_float_vector" type="high_dimensional_float_knn_vector" indexed="true" stored="true" />
<field name="string_field" type="string" indexed="true" stored="true" multiValued="false" required="false"/>

<field name="_version_" type="plong" indexed="true" stored="true" multiValued="false" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import static org.hamcrest.core.Is.is;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -456,6 +458,25 @@ public void indexing_correctDocument_shouldBeIndexed() throws Exception {
}
}

@Test
public void indexing_highDimensionalityVectorDocument_shouldBeIndexed() throws Exception {
try {
initCore("solrconfig-basic.xml", "schema-densevector-high-dimensionality.xml");

List<Float> highDimensionalityVector = new ArrayList<>();
for (float i = 0; i < 2048f; i++) {
highDimensionalityVector.add(i);
}
SolrInputDocument correctDoc = new SolrInputDocument();
correctDoc.addField("id", "0");
correctDoc.addField("vector", highDimensionalityVector);

assertU(adoc(correctDoc));
alessandrobenedetti marked this conversation as resolved.
Show resolved Hide resolved
} finally {
deleteCore();
}
}

@Test
public void query_vectorFloatEncoded_storedField_shouldBeReturnedInResults() throws Exception {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
Expand Down Expand Up @@ -214,6 +215,92 @@ public void correctVectorField_shouldSearchOnThatField() {
"//result/doc[3]/str[@name='id'][.='12']");
}

@Test
public void highDimensionFloatVectorField_shouldSearchOnThatField() {
int highDimension = 2048;
List<SolrInputDocument> docsToIndex = this.prepareHighDimensionFloatVectorsDocs(highDimension);
for (SolrInputDocument doc : docsToIndex) {
assertU(adoc(doc));
}
assertU(commit());

float[] highDimensionalityQueryVector = new float[highDimension];
for (int i = 0; i < highDimension; i++) {
highDimensionalityQueryVector[i] = i;
}
String vectorToSearch = Arrays.toString(highDimensionalityQueryVector);

assertQ(
req(CommonParams.Q, "{!knn f=2048_float_vector topK=1}" + vectorToSearch, "fl", "id"),
"//result[@numFound='1']",
"//result/doc[1]/str[@name='id'][.='1']");
}

@Test
public void highDimensionByteVectorField_shouldSearchOnThatField() {
int highDimension = 2048;
List<SolrInputDocument> docsToIndex = this.prepareHighDimensionByteVectorsDocs(highDimension);
for (SolrInputDocument doc : docsToIndex) {
assertU(adoc(doc));
}
assertU(commit());

byte[] highDimensionalityQueryVector = new byte[highDimension];
for (int i = 0; i < highDimension; i++) {
highDimensionalityQueryVector[i] = (byte) (i % 127);
}
String vectorToSearch = Arrays.toString(highDimensionalityQueryVector);

assertQ(
req(CommonParams.Q, "{!knn f=2048_byte_vector topK=1}" + vectorToSearch, "fl", "id"),
"//result[@numFound='1']",
"//result/doc[1]/str[@name='id'][.='1']");
}

private List<SolrInputDocument> prepareHighDimensionFloatVectorsDocs(int highDimension) {
int docsCount = 13;
String field = "2048_float_vector";
List<SolrInputDocument> docs = new ArrayList<>(docsCount);

for (int i = 1; i < docsCount + 1; i++) {
SolrInputDocument doc = new SolrInputDocument();
doc.addField(IDField, i);
docs.add(doc);
}

for (int i = 0; i < docsCount; i++) {
List<Integer> highDimensionalityVector = new ArrayList<>();
for (int j = i * highDimension; j < highDimension; j++) {
highDimensionalityVector.add(j);
}
docs.get(i).addField(field, highDimensionalityVector);
}
Collections.reverse(docs);
return docs;
}

private List<SolrInputDocument> prepareHighDimensionByteVectorsDocs(int highDimension) {
int docsCount = 13;
String field = "2048_byte_vector";
List<SolrInputDocument> docs = new ArrayList<>(docsCount);

for (int i = 1; i < docsCount + 1; i++) {
SolrInputDocument doc = new SolrInputDocument();
doc.addField(IDField, i);
docs.add(doc);
}

for (int i = 0; i < docsCount; i++) {
List<Integer> highDimensionalityVector = new ArrayList<>();
for (int j = i * highDimension; j < highDimension; j++) {
highDimensionalityVector.add(j % 127);
}
docs.get(i).addField(field, highDimensionalityVector);
}
Collections.reverse(docs);
return docs;
}

@Test
public void vectorByteEncodingField_shouldSearchOnThatField() {
String vectorToSearch = "[2, 2, 1, 3]";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ s|Required |Default: none
The dimension of the dense vector to pass in.
+
Accepted values:
Any integer < = `1024`.
Any integer.

`similarityFunction`::
+
Expand Down