diff --git a/core/revapi.json b/core/revapi.json index f844479cd29..63e2cef5a1e 100644 --- a/core/revapi.json +++ b/core/revapi.json @@ -6887,7 +6887,70 @@ "code": "java.method.removed", "old": "method com.datastax.oss.driver.api.core.type.reflect.GenericType> com.datastax.oss.driver.api.core.type.reflect.GenericType::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType)", "justification": "Refactoring in JAVA-3061" - } + }, + { + "code": "java.class.removed", + "old": "class com.datastax.oss.driver.api.core.data.CqlVector.Builder", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.method.removed", + "old": "method com.datastax.oss.driver.api.core.data.CqlVector.Builder com.datastax.oss.driver.api.core.data.CqlVector::builder()", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.method.removed", + "old": "method java.lang.Iterable com.datastax.oss.driver.api.core.data.CqlVector::getValues()", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.generics.formalTypeParameterChanged", + "old": "class com.datastax.oss.driver.api.core.data.CqlVector", + "new": "class com.datastax.oss.driver.api.core.data.CqlVector", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.method.parameterTypeChanged", + "old": "parameter com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.CqlVectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec)", + "new": "parameter com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.VectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec)", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.method.parameterTypeParameterChanged", + "old": "parameter com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec===)", + "new": "parameter com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec===)", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.method.returnTypeTypeParametersChanged", + "old": "method com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec)", + "new": "method com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec)", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.generics.formalTypeParameterChanged", + "old": "method com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec)", + "new": "method com.datastax.oss.driver.api.core.type.codec.TypeCodec> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec)", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.method.parameterTypeParameterChanged", + "old": "parameter com.datastax.oss.driver.api.core.type.reflect.GenericType> com.datastax.oss.driver.api.core.type.reflect.GenericType::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType===)", + "new": "parameter com.datastax.oss.driver.api.core.type.reflect.GenericType> com.datastax.oss.driver.api.core.type.reflect.GenericType::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType===)", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.method.returnTypeTypeParametersChanged", + "old": "method com.datastax.oss.driver.api.core.type.reflect.GenericType> com.datastax.oss.driver.api.core.type.reflect.GenericType::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType)", + "new": "method com.datastax.oss.driver.api.core.type.reflect.GenericType> com.datastax.oss.driver.api.core.type.reflect.GenericType::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType)", + "justification": "Refactorings in PR 1666" + }, + { + "code": "java.generics.formalTypeParameterChanged", + "old": "method com.datastax.oss.driver.api.core.type.reflect.GenericType> com.datastax.oss.driver.api.core.type.reflect.GenericType::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType)", + "new": "method com.datastax.oss.driver.api.core.type.reflect.GenericType> com.datastax.oss.driver.api.core.type.reflect.GenericType::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType)", + "justification": "Refactorings in PR 1666" + } ] } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java new file mode 100644 index 00000000000..152d0f40823 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java @@ -0,0 +1,193 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datastax.oss.driver.api.core.data; + +import com.datastax.oss.driver.api.core.type.codec.TypeCodec; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import com.datastax.oss.driver.shaded.guava.common.base.Predicates; +import com.datastax.oss.driver.shaded.guava.common.base.Splitter; +import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; +import com.datastax.oss.driver.shaded.guava.common.collect.Streams; +import edu.umd.cs.findbugs.annotations.NonNull; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Representation of a vector as defined in CQL. + * + *

A CQL vector is a fixed-length array of non-null numeric values. These properties don't map + * cleanly to an existing class in the standard JDK Collections hierarchy so we provide this value + * object instead. Like other value object collections returned by the driver instances of this + * class are not immutable; think of these value objects as a representation of a vector stored in + * the database as an initial step in some additional computation. + * + *

While we don't implement any Collection APIs we do implement Iterable. We also attempt to play + * nice with the Streams API in order to better facilitate integration with data pipelines. Finally, + * where possible we've tried to make the API of this class similar to the equivalent methods on + * {@link List}. + */ +public class CqlVector implements Iterable { + + /** + * Create a new CqlVector containing the specified values. + * + * @param vals the collection of values to wrap. + * @return a CqlVector wrapping those values + */ + public static CqlVector newInstance(V... vals) { + + // Note that Array.asList() guarantees the return of an array which implements RandomAccess + return new CqlVector(Arrays.asList(vals)); + } + + /** + * Create a new CqlVector that "wraps" an existing ArrayList. Modifications to the passed + * ArrayList will also be reflected in the returned CqlVector. + * + * @param list the collection of values to wrap. + * @return a CqlVector wrapping those values + */ + public static CqlVector newInstance(List list) { + Preconditions.checkArgument(list != null, "Input list should not be null"); + return new CqlVector(list); + } + + /** + * Create a new CqlVector instance from the specified string representation. Note that this method + * is intended to mirror {@link #toString()}; passing this method the output from a toString + * call on some CqlVector should return a CqlVector that is equal to the origin instance. + * + * @param str a String representation of a CqlVector + * @param subtypeCodec + * @return a new CqlVector built from the String representation + */ + public static CqlVector from( + @NonNull String str, @NonNull TypeCodec subtypeCodec) { + Preconditions.checkArgument(str != null, "Cannot create CqlVector from null string"); + Preconditions.checkArgument(!str.isEmpty(), "Cannot create CqlVector from empty string"); + ArrayList vals = + Streams.stream(Splitter.on(", ").split(str.substring(1, str.length() - 1))) + .map(subtypeCodec::parse) + .collect(Collectors.toCollection(ArrayList::new)); + return new CqlVector(vals); + } + + private final List list; + + private CqlVector(@NonNull List list) { + + Preconditions.checkArgument( + Iterables.all(list, Predicates.notNull()), "CqlVectors cannot contain null values"); + this.list = list; + } + + /** + * Retrieve the value at the specified index. Modelled after {@link List#get(int)} + * + * @param idx the index to retrieve + * @return the value at the specified index + */ + public T get(int idx) { + return list.get(idx); + } + + /** + * Update the value at the specified index. Modelled after {@link List#set(int, Object)} + * + * @param idx the index to set + * @param val the new value for the specified index + * @return the old value for the specified index + */ + public T set(int idx, T val) { + return list.set(idx, val); + } + + /** + * Return the size of this vector. Modelled after {@link List#size()} + * + * @return the vector size + */ + public int size() { + return this.list.size(); + } + + /** + * Return a CqlVector consisting of the contents of a portion of this vector. Modelled after + * {@link List#subList(int, int)} + * + * @param from the index to start from (inclusive) + * @param to the index to end on (exclusive) + * @return a new CqlVector wrapping the sublist + */ + public CqlVector subVector(int from, int to) { + return new CqlVector(this.list.subList(from, to)); + } + + /** + * Return a boolean indicating whether the vector is empty. Modelled after {@link List#isEmpty()} + * + * @return true if the list is empty, false otherwise + */ + public boolean isEmpty() { + return this.list.isEmpty(); + } + + /** + * Create an {@link Iterator} for this vector + * + * @return the generated iterator + */ + @Override + public Iterator iterator() { + return this.list.iterator(); + } + + /** + * Create a {@link Stream} of the values in this vector + * + * @return the Stream instance + */ + public Stream stream() { + return this.list.stream(); + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } else if (o instanceof CqlVector) { + CqlVector that = (CqlVector) o; + return this.list.equals(that.list); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(list); + } + + @Override + public String toString() { + return Iterables.toString(this.list); + } +} diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableById.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableById.java index 6c6cf95a568..1b4197667e9 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableById.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableById.java @@ -529,7 +529,7 @@ default CqlDuration getCqlDuration(@NonNull CqlIdentifier id) { * @throws IllegalArgumentException if the id is invalid. */ @Nullable - default List getVector( + default CqlVector getVector( @NonNull CqlIdentifier id, @NonNull Class elementsClass) { return getVector(firstIndexOf(id), elementsClass); } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByIndex.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByIndex.java index a805342defc..0efb003ca24 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByIndex.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByIndex.java @@ -444,8 +444,9 @@ default CqlDuration getCqlDuration(int i) { * @throws IndexOutOfBoundsException if the index is invalid. */ @Nullable - default List getVector(int i, @NonNull Class elementsClass) { - return get(i, GenericType.listOf(elementsClass)); + default CqlVector getVector( + int i, @NonNull Class elementsClass) { + return get(i, GenericType.vectorOf(elementsClass)); } /** diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByName.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByName.java index 3214994c04a..377f8292002 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByName.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByName.java @@ -525,9 +525,9 @@ default CqlDuration getCqlDuration(@NonNull String name) { * @throws IllegalArgumentException if the name is invalid. */ @Nullable - default List getVector( + default CqlVector getVector( @NonNull String name, @NonNull Class elementsClass) { - return getList(firstIndexOf(name), elementsClass); + return getVector(firstIndexOf(name), elementsClass); } /** diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableById.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableById.java index 3c17f0cb6f1..84055b0e964 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableById.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableById.java @@ -571,9 +571,9 @@ default SelfT setCqlDuration(@NonNull CqlIdentifier id, @Nullable CqlDuration v) */ @NonNull @CheckReturnValue - default SelfT setVector( + default SelfT setVector( @NonNull CqlIdentifier id, - @Nullable List v, + @Nullable CqlVector v, @NonNull Class elementsClass) { SelfT result = null; for (Integer i : allIndicesOf(id)) { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByIndex.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByIndex.java index 52bc92d4c09..01e6d5cdf58 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByIndex.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByIndex.java @@ -423,9 +423,9 @@ default SelfT setCqlDuration(int i, @Nullable CqlDuration v) { */ @NonNull @CheckReturnValue - default SelfT setVector( - int i, @Nullable List v, @NonNull Class elementsClass) { - return set(i, v, GenericType.listOf(elementsClass)); + default SelfT setVector( + int i, @Nullable CqlVector v, @NonNull Class elementsClass) { + return set(i, v, GenericType.vectorOf(elementsClass)); } /** diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByName.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByName.java index 559ad40cbff..a78753789e3 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByName.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByName.java @@ -570,8 +570,10 @@ default SelfT setCqlDuration(@NonNull String name, @Nullable CqlDuration v) { */ @NonNull @CheckReturnValue - default SelfT setVector( - @NonNull String name, @Nullable List v, @NonNull Class elementsClass) { + default SelfT setVector( + @NonNull String name, + @Nullable CqlVector v, + @NonNull Class elementsClass) { SelfT result = null; for (Integer i : allIndicesOf(name)) { result = (result == null ? this : result).setVector(i, v, elementsClass); diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/ExtraTypeCodecs.java b/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/ExtraTypeCodecs.java index 6bf044ebf03..65571e01f75 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/ExtraTypeCodecs.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/ExtraTypeCodecs.java @@ -16,8 +16,10 @@ package com.datastax.oss.driver.api.core.type.codec; import com.datastax.oss.driver.api.core.session.SessionBuilder; +import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.api.core.type.codec.registry.MutableCodecRegistry; import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import com.datastax.oss.driver.internal.core.type.DefaultVectorType; import com.datastax.oss.driver.internal.core.type.codec.SimpleBlobCodec; import com.datastax.oss.driver.internal.core.type.codec.TimestampCodec; import com.datastax.oss.driver.internal.core.type.codec.extras.OptionalCodec; @@ -36,6 +38,7 @@ import com.datastax.oss.driver.internal.core.type.codec.extras.time.PersistentZonedTimestampCodec; import com.datastax.oss.driver.internal.core.type.codec.extras.time.TimestampMillisCodec; import com.datastax.oss.driver.internal.core.type.codec.extras.time.ZonedTimestampCodec; +import com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec; import com.fasterxml.jackson.databind.ObjectMapper; import edu.umd.cs.findbugs.annotations.NonNull; import java.nio.ByteBuffer; @@ -479,4 +482,9 @@ public static TypeCodec json( @NonNull Class javaType, @NonNull ObjectMapper objectMapper) { return new JsonCodec<>(javaType, objectMapper); } + + /** Builds a new codec that maps CQL float vectors of the specified size to an array of floats. */ + public static TypeCodec floatVectorToArray(int dimensions) { + return new FloatVectorToArrayCodec(new DefaultVectorType(DataTypes.FLOAT, dimensions)); + } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/TypeCodecs.java b/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/TypeCodecs.java index e824e7f41fc..d4cf3ddb0c0 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/TypeCodecs.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/type/codec/TypeCodecs.java @@ -16,6 +16,7 @@ package com.datastax.oss.driver.api.core.type.codec; import com.datastax.oss.driver.api.core.data.CqlDuration; +import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.data.TupleValue; import com.datastax.oss.driver.api.core.data.UdtValue; import com.datastax.oss.driver.api.core.type.CustomType; @@ -207,12 +208,17 @@ public static TypeCodec tupleOf(@NonNull TupleType cqlType) { return new TupleCodec(cqlType); } - public static TypeCodec> vectorOf( + public static TypeCodec> vectorOf( @NonNull VectorType type, @NonNull TypeCodec subtypeCodec) { return new VectorCodec( DataTypes.vectorOf(subtypeCodec.getCqlType(), type.getDimensions()), subtypeCodec); } + public static TypeCodec> vectorOf( + int dimensions, @NonNull TypeCodec subtypeCodec) { + return new VectorCodec(DataTypes.vectorOf(subtypeCodec.getCqlType(), dimensions), subtypeCodec); + } + /** * Builds a new codec that maps a CQL user defined type to the driver's {@link UdtValue}, for the * given type definition. diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/type/reflect/GenericType.java b/core/src/main/java/com/datastax/oss/driver/api/core/type/reflect/GenericType.java index a1977e39f23..350e869ae52 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/type/reflect/GenericType.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/type/reflect/GenericType.java @@ -16,6 +16,7 @@ package com.datastax.oss.driver.api.core.type.reflect; import com.datastax.oss.driver.api.core.data.CqlDuration; +import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.data.GettableByIndex; import com.datastax.oss.driver.api.core.data.TupleValue; import com.datastax.oss.driver.api.core.data.UdtValue; @@ -147,6 +148,23 @@ public static GenericType> setOf(@NonNull GenericType elementType) return new GenericType<>(token); } + @NonNull + public static GenericType> vectorOf( + @NonNull Class elementType) { + TypeToken> token = + new TypeToken>() {}.where( + new TypeParameter() {}, TypeToken.of(elementType)); + return new GenericType<>(token); + } + + @NonNull + public static GenericType> vectorOf( + @NonNull GenericType elementType) { + TypeToken> token = + new TypeToken>() {}.where(new TypeParameter() {}, elementType.token); + return new GenericType<>(token); + } + @NonNull public static GenericType> mapOf( @NonNull Class keyType, @NonNull Class valueType) { diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodec.java b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodec.java index 75b3e46ddfd..a94ae728725 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodec.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodec.java @@ -16,13 +16,13 @@ package com.datastax.oss.driver.internal.core.type.codec; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.type.DataType; import com.datastax.oss.driver.api.core.type.VectorType; import com.datastax.oss.driver.api.core.type.codec.TypeCodec; import com.datastax.oss.driver.api.core.type.reflect.GenericType; -import com.datastax.oss.driver.shaded.guava.common.base.Splitter; +import com.datastax.oss.driver.internal.core.type.DefaultVectorType; import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; -import com.datastax.oss.driver.shaded.guava.common.collect.Streams; import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; import java.nio.ByteBuffer; @@ -30,23 +30,26 @@ import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; -import java.util.stream.Collectors; -public class VectorCodec implements TypeCodec> { +public class VectorCodec implements TypeCodec> { private final VectorType cqlType; - private final GenericType> javaType; + private final GenericType> javaType; private final TypeCodec subtypeCodec; - public VectorCodec(VectorType cqlType, TypeCodec subtypeCodec) { + public VectorCodec(@NonNull VectorType cqlType, @NonNull TypeCodec subtypeCodec) { this.cqlType = cqlType; this.subtypeCodec = subtypeCodec; - this.javaType = GenericType.listOf(subtypeCodec.getJavaType()); + this.javaType = GenericType.vectorOf(subtypeCodec.getJavaType()); + } + + public VectorCodec(int dimensions, @NonNull TypeCodec subtypeCodec) { + this(new DefaultVectorType(subtypeCodec.getCqlType(), dimensions), subtypeCodec); } @NonNull @Override - public GenericType> getJavaType() { + public GenericType> getJavaType() { return this.javaType; } @@ -59,7 +62,7 @@ public DataType getCqlType() { @Nullable @Override public ByteBuffer encode( - @Nullable List value, @NonNull ProtocolVersion protocolVersion) { + @Nullable CqlVector value, @NonNull ProtocolVersion protocolVersion) { if (value == null || cqlType.getDimensions() <= 0) { return null; } @@ -103,7 +106,7 @@ public ByteBuffer encode( @Nullable @Override - public List decode( + public CqlVector decode( @Nullable ByteBuffer bytes, @NonNull ProtocolVersion protocolVersion) { if (bytes == null || bytes.remaining() == 0) { return null; @@ -133,27 +136,20 @@ Elements should at least precede themselves with their size (along the lines of /* Restore the input ByteBuffer to its original state */ bytes.rewind(); - return rv; + return CqlVector.newInstance(rv); } @NonNull @Override - public String format(@Nullable List value) { + public String format(@Nullable CqlVector value) { return value == null ? "NULL" : Iterables.toString(value); } @Nullable @Override - public List parse(@Nullable String value) { + public CqlVector parse(@Nullable String value) { return (value == null || value.isEmpty() || value.equalsIgnoreCase("NULL")) ? null - : this.from(value); - } - - private List from(@Nullable String value) { - - return Streams.stream(Splitter.on(", ").split(value.substring(1, value.length() - 1))) - .map(subtypeCodec::parse) - .collect(Collectors.toCollection(ArrayList::new)); + : CqlVector.from(value, this.subtypeCodec); } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/extras/vector/AbstractVectorToArrayCodec.java b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/extras/vector/AbstractVectorToArrayCodec.java new file mode 100644 index 00000000000..79db9f6bc8a --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/extras/vector/AbstractVectorToArrayCodec.java @@ -0,0 +1,140 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datastax.oss.driver.internal.core.type.codec.extras.vector; + +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.VectorType; +import com.datastax.oss.driver.api.core.type.codec.TypeCodec; +import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.util.Objects; + +/** Common super-class for all codecs which map a CQL vector type onto a primitive array */ +public abstract class AbstractVectorToArrayCodec implements TypeCodec { + + @NonNull protected final VectorType cqlType; + @NonNull protected final GenericType javaType; + + /** + * @param cqlType The CQL type. Must be a list type. + * @param arrayType The Java type. Must be an array class. + */ + protected AbstractVectorToArrayCodec( + @NonNull VectorType cqlType, @NonNull GenericType arrayType) { + this.cqlType = Objects.requireNonNull(cqlType, "cqlType cannot be null"); + this.javaType = Objects.requireNonNull(arrayType, "arrayType cannot be null"); + if (!arrayType.isArray()) { + throw new IllegalArgumentException("Expecting Java array class, got " + arrayType); + } + } + + @NonNull + @Override + public GenericType getJavaType() { + return this.javaType; + } + + @NonNull + @Override + public DataType getCqlType() { + return this.cqlType; + } + + @Nullable + @Override + public ByteBuffer encode(@Nullable ArrayT array, @NonNull ProtocolVersion protocolVersion) { + if (array == null) { + return null; + } + int length = Array.getLength(array); + int totalSize = length * sizeOfComponentType(); + ByteBuffer output = ByteBuffer.allocate(totalSize); + for (int i = 0; i < length; i++) { + serializeElement(output, array, i, protocolVersion); + } + output.flip(); + return output; + } + + @Nullable + @Override + public ArrayT decode(@Nullable ByteBuffer bytes, @NonNull ProtocolVersion protocolVersion) { + if (bytes == null || bytes.remaining() == 0) { + throw new IllegalArgumentException( + "Input ByteBuffer must not be null and must have non-zero remaining bytes"); + } + ByteBuffer input = bytes.duplicate(); + int length = this.cqlType.getDimensions(); + int elementSize = sizeOfComponentType(); + ArrayT array = newInstance(); + for (int i = 0; i < length; i++) { + // Null elements can happen on the decode path, but we cannot tolerate them + if (elementSize < 0) { + throw new NullPointerException("Primitive arrays cannot store null elements"); + } else { + deserializeElement(input, array, i, protocolVersion); + } + } + return array; + } + + /** + * Creates a new array instance with a size matching the specified vector. + * + * @return a new array instance with a size matching the specified vector. + */ + @NonNull + protected abstract ArrayT newInstance(); + + /** + * Return the size in bytes of the array component type. + * + * @return the size in bytes of the array component type. + */ + protected abstract int sizeOfComponentType(); + + /** + * Write the {@code index}th element of {@code array} to {@code output}. + * + * @param output The ByteBuffer to write to. + * @param array The array to read from. + * @param index The element index. + * @param protocolVersion The protocol version to use. + */ + protected abstract void serializeElement( + @NonNull ByteBuffer output, + @NonNull ArrayT array, + int index, + @NonNull ProtocolVersion protocolVersion); + + /** + * Read the {@code index}th element of {@code array} from {@code input}. + * + * @param input The ByteBuffer to read from. + * @param array The array to write to. + * @param index The element index. + * @param protocolVersion The protocol version to use. + */ + protected abstract void deserializeElement( + @NonNull ByteBuffer input, + @NonNull ArrayT array, + int index, + @NonNull ProtocolVersion protocolVersion); +} diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/extras/vector/FloatVectorToArrayCodec.java b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/extras/vector/FloatVectorToArrayCodec.java new file mode 100644 index 00000000000..80c035e96d3 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/extras/vector/FloatVectorToArrayCodec.java @@ -0,0 +1,105 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datastax.oss.driver.internal.core.type.codec.extras.vector; + +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.type.VectorType; +import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import com.datastax.oss.driver.internal.core.type.codec.FloatCodec; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import com.datastax.oss.driver.shaded.guava.common.base.Splitter; +import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Objects; + +/** A codec that maps CQL vectors to the Java type {@code float[]}. */ +public class FloatVectorToArrayCodec extends AbstractVectorToArrayCodec { + + public FloatVectorToArrayCodec(VectorType type) { + super(type, GenericType.of(float[].class)); + } + + @Override + public boolean accepts(@NonNull Class javaClass) { + Objects.requireNonNull(javaClass); + return float[].class.equals(javaClass); + } + + @Override + public boolean accepts(@NonNull Object value) { + Objects.requireNonNull(value); + return value instanceof float[]; + } + + @NonNull + @Override + protected float[] newInstance() { + return new float[cqlType.getDimensions()]; + } + + @Override + protected int sizeOfComponentType() { + return 4; + } + + @Override + protected void serializeElement( + @NonNull ByteBuffer output, + @NonNull float[] array, + int index, + @NonNull ProtocolVersion protocolVersion) { + output.putFloat(array[index]); + } + + @Override + protected void deserializeElement( + @NonNull ByteBuffer input, + @NonNull float[] array, + int index, + @NonNull ProtocolVersion protocolVersion) { + array[index] = input.getFloat(); + } + + @NonNull + @Override + public String format(@Nullable float[] value) { + return value == null ? "NULL" : Arrays.toString(value); + } + + @Nullable + @Override + public float[] parse(@Nullable String str) { + Preconditions.checkArgument(str != null, "Cannot create float array from null string"); + Preconditions.checkArgument(!str.isEmpty(), "Cannot create float array from empty string"); + + FloatCodec codec = new FloatCodec(); + float[] rv = this.newInstance(); + Iterator strIter = + Splitter.on(", ").trimResults().split(str.substring(1, str.length() - 1)).iterator(); + for (int i = 0; i < rv.length; ++i) { + String strVal = strIter.next(); + if (strVal == null) { + throw new IllegalArgumentException("Null element observed in float array string"); + } + Float f = codec.parse(strVal); + rv[i] = f.floatValue(); + } + return rv; + } +} diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistry.java b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistry.java index ca282c3e355..cb5d45255e1 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistry.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistry.java @@ -16,6 +16,7 @@ package com.datastax.oss.driver.internal.core.type.codec.registry; import com.datastax.oss.driver.api.core.data.CqlDuration; +import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.data.TupleValue; import com.datastax.oss.driver.api.core.data.UdtValue; import com.datastax.oss.driver.api.core.type.*; @@ -371,6 +372,23 @@ protected GenericType inspectType(@NonNull Object value, @Nullable DataType c inspectType(firstValue, cqlType == null ? null : ((MapType) cqlType).getValueType()); return GenericType.mapOf(keyType, valueType); } + } else if (value instanceof CqlVector) { + CqlVector vector = (CqlVector) value; + if (vector.isEmpty()) { + return cqlType == null ? JAVA_TYPE_FOR_EMPTY_CQLVECTORS : inferJavaTypeFromCqlType(cqlType); + } else { + Object firstElement = vector.iterator().next(); + if (firstElement == null) { + throw new IllegalArgumentException( + "Can't infer vector codec because the first element is null " + + "(note that CQL does not allow null values in collections)"); + } + GenericType elementType = + (GenericType) + inspectType( + firstElement, cqlType == null ? null : ((VectorType) cqlType).getElementType()); + return GenericType.vectorOf(elementType); + } } else { // There's not much more we can do return GenericType.of(value.getClass()); @@ -390,6 +408,11 @@ protected GenericType inferJavaTypeFromCqlType(@NonNull DataType cqlType) { DataType valueType = ((MapType) cqlType).getValueType(); return GenericType.mapOf( inferJavaTypeFromCqlType(keyType), inferJavaTypeFromCqlType(valueType)); + } else if (cqlType instanceof VectorType) { + DataType elementType = ((VectorType) cqlType).getElementType(); + GenericType numberType = + (GenericType) inferJavaTypeFromCqlType(elementType); + return GenericType.vectorOf(numberType); } switch (cqlType.getProtocolCode()) { case ProtocolConstants.DataType.CUSTOM: @@ -492,6 +515,22 @@ protected DataType inferCqlTypeFromValue(@NonNull Object value) { return null; } return DataTypes.mapOf(keyType, valueType); + } else if (value instanceof CqlVector) { + CqlVector vector = (CqlVector) value; + if (vector.isEmpty()) { + return CQL_TYPE_FOR_EMPTY_VECTORS; + } + Object firstElement = vector.iterator().next(); + if (firstElement == null) { + throw new IllegalArgumentException( + "Can't infer vector codec because the first element is null " + + "(note that CQL does not allow null values in collections)"); + } + DataType elementType = inferCqlTypeFromValue(firstElement); + if (elementType == null) { + return null; + } + return DataTypes.vectorOf(elementType, vector.size()); } Class javaClass = value.getClass(); if (ByteBuffer.class.isAssignableFrom(javaClass)) { @@ -538,7 +577,7 @@ protected DataType inferCqlTypeFromValue(@NonNull Object value) { return null; } - private TypeCodec getElementCodec( + private TypeCodec getElementCodecForCqlAndJavaType( ContainerType cqlType, TypeToken token, boolean isJavaCovariant) { DataType elementCqlType = cqlType.getElementType(); @@ -550,6 +589,14 @@ private TypeCodec getElementCodec( return codecFor(elementCqlType); } + private TypeCodec getElementCodecForJavaType( + ParameterizedType parameterizedType, boolean isJavaCovariant) { + + Type[] typeArguments = parameterizedType.getActualTypeArguments(); + GenericType elementType = GenericType.of(typeArguments[0]); + return codecFor(elementType, isJavaCovariant); + } + // Try to create a codec when we haven't found it in the cache @NonNull protected TypeCodec createCodec( @@ -565,11 +612,11 @@ protected TypeCodec createCodec( TypeToken token = javaType.__getToken(); if (cqlType instanceof ListType && List.class.isAssignableFrom(token.getRawType())) { TypeCodec elementCodec = - getElementCodec((ContainerType) cqlType, token, isJavaCovariant); + getElementCodecForCqlAndJavaType((ContainerType) cqlType, token, isJavaCovariant); return TypeCodecs.listOf(elementCodec); } else if (cqlType instanceof SetType && Set.class.isAssignableFrom(token.getRawType())) { TypeCodec elementCodec = - getElementCodec((ContainerType) cqlType, token, isJavaCovariant); + getElementCodecForCqlAndJavaType((ContainerType) cqlType, token, isJavaCovariant); return TypeCodecs.setOf(elementCodec); } else if (cqlType instanceof MapType && Map.class.isAssignableFrom(token.getRawType())) { DataType keyCqlType = ((MapType) cqlType).getKeyType(); @@ -593,9 +640,14 @@ protected TypeCodec createCodec( } else if (cqlType instanceof UserDefinedType && UdtValue.class.isAssignableFrom(token.getRawType())) { return TypeCodecs.udtOf((UserDefinedType) cqlType); - } else if (cqlType instanceof VectorType && List.class.isAssignableFrom(token.getRawType())) { + } else if (cqlType instanceof VectorType + && CqlVector.class.isAssignableFrom(token.getRawType())) { VectorType vectorType = (VectorType) cqlType; - TypeCodec elementCodec = getElementCodec(vectorType, token, isJavaCovariant); + /* For a vector type we'll always get back an instance of TypeCodec due to the + * type of CqlVector... but getElementCodecForCqlAndJavaType() is a generalized function that can't + * return this more precise type. Thus the cast here. */ + TypeCodec elementCodec = + uncheckedCast(getElementCodecForCqlAndJavaType(vectorType, token, isJavaCovariant)); return TypeCodecs.vectorOf(vectorType, elementCodec); } else if (cqlType instanceof CustomType && ByteBuffer.class.isAssignableFrom(token.getRawType())) { @@ -612,15 +664,13 @@ protected TypeCodec createCodec(@NonNull GenericType javaType, boolean isJ TypeToken token = javaType.__getToken(); if (List.class.isAssignableFrom(token.getRawType()) && token.getType() instanceof ParameterizedType) { - Type[] typeArguments = ((ParameterizedType) token.getType()).getActualTypeArguments(); - GenericType elementType = GenericType.of(typeArguments[0]); - TypeCodec elementCodec = codecFor(elementType, isJavaCovariant); + TypeCodec elementCodec = + getElementCodecForJavaType((ParameterizedType) token.getType(), isJavaCovariant); return TypeCodecs.listOf(elementCodec); } else if (Set.class.isAssignableFrom(token.getRawType()) && token.getType() instanceof ParameterizedType) { - Type[] typeArguments = ((ParameterizedType) token.getType()).getActualTypeArguments(); - GenericType elementType = GenericType.of(typeArguments[0]); - TypeCodec elementCodec = codecFor(elementType, isJavaCovariant); + TypeCodec elementCodec = + getElementCodecForJavaType((ParameterizedType) token.getType(), isJavaCovariant); return TypeCodecs.setOf(elementCodec); } else if (Map.class.isAssignableFrom(token.getRawType()) && token.getType() instanceof ParameterizedType) { @@ -631,6 +681,9 @@ protected TypeCodec createCodec(@NonNull GenericType javaType, boolean isJ TypeCodec valueCodec = codecFor(valueType, isJavaCovariant); return TypeCodecs.mapOf(keyCodec, valueCodec); } + /* Note that this method cannot generate TypeCodec instances for any CqlVector type. VectorCodec needs + * to know the dimensions of the vector it will be operating on and there's no way to determine that from + * the Java type alone. */ throw new CodecNotFoundException(null, javaType); } @@ -652,6 +705,11 @@ protected TypeCodec createCodec(@NonNull DataType cqlType) { TypeCodec keyCodec = codecFor(keyType); TypeCodec valueCodec = codecFor(valueType); return TypeCodecs.mapOf(keyCodec, valueCodec); + } else if (cqlType instanceof VectorType) { + VectorType vectorType = (VectorType) cqlType; + TypeCodec elementCodec = + uncheckedCast(codecFor(vectorType.getElementType())); + return TypeCodecs.vectorOf(vectorType, elementCodec); } else if (cqlType instanceof TupleType) { return TypeCodecs.tupleOf((TupleType) cqlType); } else if (cqlType instanceof UserDefinedType) { @@ -687,8 +745,11 @@ private static TypeCodec uncheckedCast( GenericType.setOf(Boolean.class); private static final GenericType> JAVA_TYPE_FOR_EMPTY_MAPS = GenericType.mapOf(Boolean.class, Boolean.class); + private static final GenericType> JAVA_TYPE_FOR_EMPTY_CQLVECTORS = + GenericType.vectorOf(Number.class); private static final DataType CQL_TYPE_FOR_EMPTY_LISTS = DataTypes.listOf(DataTypes.BOOLEAN); private static final DataType CQL_TYPE_FOR_EMPTY_SETS = DataTypes.setOf(DataTypes.BOOLEAN); private static final DataType CQL_TYPE_FOR_EMPTY_MAPS = DataTypes.mapOf(DataTypes.BOOLEAN, DataTypes.BOOLEAN); + private static final DataType CQL_TYPE_FOR_EMPTY_VECTORS = DataTypes.vectorOf(DataTypes.INT, 0); } diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java new file mode 100644 index 00000000000..ecf8f1249d0 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java @@ -0,0 +1,198 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datastax.oss.driver.api.core.data; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; +import com.datastax.oss.driver.shaded.guava.common.collect.Iterators; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.assertj.core.util.Lists; +import org.junit.Test; + +public class CqlVectorTest { + + private static final Float[] VECTOR_ARGS = {1.0f, 2.5f}; + + private void validate_built_vector(CqlVector vec) { + + assertThat(vec.size()).isEqualTo(2); + assertThat(vec.isEmpty()).isFalse(); + assertThat(vec.get(0)).isEqualTo(VECTOR_ARGS[0]); + assertThat(vec.get(1)).isEqualTo(VECTOR_ARGS[1]); + } + + @Test + public void should_build_vector_from_elements() { + + validate_built_vector(CqlVector.newInstance(VECTOR_ARGS)); + } + + @Test + public void should_build_vector_from_list() { + + validate_built_vector(CqlVector.newInstance(Lists.newArrayList(VECTOR_ARGS))); + } + + @Test + public void should_build_vector_from_tostring_output() { + + CqlVector vector1 = CqlVector.newInstance(VECTOR_ARGS); + CqlVector vector2 = CqlVector.from(vector1.toString(), TypeCodecs.FLOAT); + assertThat(vector2).isEqualTo(vector1); + } + + @Test + public void should_throw_from_null_string() { + + assertThatThrownBy( + () -> { + CqlVector.from(null, TypeCodecs.FLOAT); + }) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void should_throw_from_empty_string() { + + assertThatThrownBy( + () -> { + CqlVector.from("", TypeCodecs.FLOAT); + }) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void should_throw_when_building_with_nulls() { + + assertThatThrownBy( + () -> { + CqlVector.newInstance(1.1f, null, 2.2f); + }) + .isInstanceOf(IllegalArgumentException.class); + + Float[] theArray = new Float[] {1.1f, null, 2.2f}; + assertThatThrownBy( + () -> { + CqlVector.newInstance(theArray); + }) + .isInstanceOf(IllegalArgumentException.class); + + List theList = Lists.newArrayList(1.1f, null, 2.2f); + assertThatThrownBy( + () -> { + CqlVector.newInstance(theList); + }) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void should_build_empty_vector() { + + CqlVector vector = CqlVector.newInstance(); + assertThat(vector.isEmpty()).isTrue(); + assertThat(vector.size()).isEqualTo(0); + } + + @Test + public void should_behave_mostly_like_a_list() { + + CqlVector vector = CqlVector.newInstance(VECTOR_ARGS); + assertThat(vector.get(0)).isEqualTo(VECTOR_ARGS[0]); + Float newVal = VECTOR_ARGS[0] * 2; + vector.set(0, newVal); + assertThat(vector.get(0)).isEqualTo(newVal); + assertThat(vector.isEmpty()).isFalse(); + assertThat(vector.size()).isEqualTo(2); + assertThat(Iterators.toArray(vector.iterator(), Float.class)).isEqualTo(VECTOR_ARGS); + } + + @Test + public void should_play_nicely_with_streams() { + + CqlVector vector = CqlVector.newInstance(VECTOR_ARGS); + List results = + vector.stream() + .map((f) -> f * 2) + .collect(Collectors.toCollection(() -> new ArrayList())); + for (int i = 0; i < vector.size(); ++i) { + assertThat(results.get(i)).isEqualTo(vector.get(i) * 2); + } + } + + @Test + public void should_reflect_changes_to_mutable_list() { + + List theList = Lists.newArrayList(1.1f, 2.2f, 3.3f); + CqlVector vector = CqlVector.newInstance(theList); + assertThat(vector.size()).isEqualTo(3); + assertThat(vector.get(2)).isEqualTo(3.3f); + + float newVal1 = 4.4f; + theList.set(2, newVal1); + assertThat(vector.size()).isEqualTo(3); + assertThat(vector.get(2)).isEqualTo(newVal1); + + float newVal2 = 5.5f; + theList.add(newVal2); + assertThat(vector.size()).isEqualTo(4); + assertThat(vector.get(3)).isEqualTo(newVal2); + } + + @Test + public void should_reflect_changes_to_array() { + + Float[] theArray = new Float[] {1.1f, 2.2f, 3.3f}; + CqlVector vector = CqlVector.newInstance(theArray); + assertThat(vector.size()).isEqualTo(3); + assertThat(vector.get(2)).isEqualTo(3.3f); + + float newVal1 = 4.4f; + theArray[2] = newVal1; + assertThat(vector.size()).isEqualTo(3); + assertThat(vector.get(2)).isEqualTo(newVal1); + } + + @Test + public void should_correctly_compare_vectors() { + + Float[] args = VECTOR_ARGS.clone(); + CqlVector vector1 = CqlVector.newInstance(args); + CqlVector vector2 = CqlVector.newInstance(args); + CqlVector vector3 = CqlVector.newInstance(Lists.newArrayList(args)); + assertThat(vector1).isNotSameAs(vector2); + assertThat(vector1).isEqualTo(vector2); + assertThat(vector1).isNotSameAs(vector3); + assertThat(vector1).isEqualTo(vector3); + + Float[] differentArgs = args.clone(); + float newVal = differentArgs[0] * 2; + differentArgs[0] = newVal; + CqlVector vector4 = CqlVector.newInstance(differentArgs); + assertThat(vector1).isNotSameAs(vector4); + assertThat(vector1).isNotEqualTo(vector4); + + Float[] biggerArgs = Arrays.copyOf(args, args.length + 1); + biggerArgs[biggerArgs.length - 1] = newVal; + CqlVector vector5 = CqlVector.newInstance(biggerArgs); + assertThat(vector1).isNotSameAs(vector5); + assertThat(vector1).isNotEqualTo(vector5); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodecTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodecTest.java index 9b463dcb53e..82ec7b5ed67 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodecTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodecTest.java @@ -18,18 +18,20 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.api.core.type.VectorType; import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; import com.datastax.oss.driver.api.core.type.reflect.GenericType; import com.datastax.oss.driver.internal.core.type.DefaultVectorType; -import com.google.common.collect.Lists; -import java.util.List; +import java.util.Arrays; import org.junit.Test; -public class VectorCodecTest extends CodecTestBase> { +public class VectorCodecTest extends CodecTestBase> { - private static final List VECTOR = Lists.newArrayList(1.0f, 2.5f); + private static final Float[] VECTOR_ARGS = {1.0f, 2.5f}; + + private static final CqlVector VECTOR = CqlVector.newInstance(VECTOR_ARGS); private static final String VECTOR_HEX_STRING = "0x" + "3f800000" + "40200000"; @@ -49,21 +51,21 @@ public void should_encode() { /** Too few eleements will cause an exception, extra elements will be silently ignored */ @Test public void should_throw_on_encode_with_too_few_elements() { - assertThatThrownBy(() -> encode(VECTOR.subList(0, 1))) + assertThatThrownBy(() -> encode(VECTOR.subVector(0, 1))) .isInstanceOf(IllegalArgumentException.class); } @Test public void should_throw_on_encode_with_empty_list() { - assertThatThrownBy(() -> encode(Lists.newArrayList())) + assertThatThrownBy(() -> encode(CqlVector.newInstance())) .isInstanceOf(IllegalArgumentException.class); } @Test public void should_encode_with_too_many_elements() { - List doubleVector = Lists.newArrayList(VECTOR); - doubleVector.addAll(VECTOR); - assertThat(encode(doubleVector)).isEqualTo(VECTOR_HEX_STRING); + Float[] doubledVectorContents = Arrays.copyOf(VECTOR_ARGS, VECTOR_ARGS.length * 2); + System.arraycopy(VECTOR_ARGS, 0, doubledVectorContents, VECTOR_ARGS.length, VECTOR_ARGS.length); + assertThat(encode(CqlVector.newInstance(doubledVectorContents))).isEqualTo(VECTOR_HEX_STRING); } @Test @@ -118,14 +120,14 @@ public void should_accept_vector_type_correct_dimension_only() { @Test public void should_accept_generic_type() { - assertThat(codec.accepts(GenericType.listOf(GenericType.FLOAT))).isTrue(); - assertThat(codec.accepts(GenericType.listOf(GenericType.INTEGER))).isFalse(); + assertThat(codec.accepts(GenericType.vectorOf(GenericType.FLOAT))).isTrue(); + assertThat(codec.accepts(GenericType.vectorOf(GenericType.INTEGER))).isFalse(); assertThat(codec.accepts(GenericType.of(Integer.class))).isFalse(); } @Test public void should_accept_raw_type() { - assertThat(codec.accepts(List.class)).isTrue(); + assertThat(codec.accepts(CqlVector.class)).isTrue(); assertThat(codec.accepts(Integer.class)).isFalse(); } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistryTestDataProviders.java b/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistryTestDataProviders.java index 64bbd800c92..a0d0b77ca87 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistryTestDataProviders.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/type/codec/registry/CachingCodecRegistryTestDataProviders.java @@ -17,6 +17,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.data.CqlDuration; +import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.data.TupleValue; import com.datastax.oss.driver.api.core.data.UdtValue; import com.datastax.oss.driver.api.core.type.DataTypes; @@ -285,6 +286,55 @@ public static Object[][] collectionsWithCqlAndJavaTypes() ImmutableMap.of( ImmutableMap.of(udtValue, udtValue), ImmutableMap.of(tupleValue, tupleValue)) }, + // vectors + { + DataTypes.vectorOf(DataTypes.INT, 1), + GenericType.vectorOf(Integer.class), + GenericType.vectorOf(Integer.class), + CqlVector.newInstance(1) + }, + { + DataTypes.vectorOf(DataTypes.BIGINT, 1), + GenericType.vectorOf(Long.class), + GenericType.vectorOf(Long.class), + CqlVector.newInstance(1l) + }, + { + DataTypes.vectorOf(DataTypes.SMALLINT, 1), + GenericType.vectorOf(Short.class), + GenericType.vectorOf(Short.class), + CqlVector.newInstance((short) 1) + }, + { + DataTypes.vectorOf(DataTypes.TINYINT, 1), + GenericType.vectorOf(Byte.class), + GenericType.vectorOf(Byte.class), + CqlVector.newInstance((byte) 1) + }, + { + DataTypes.vectorOf(DataTypes.FLOAT, 1), + GenericType.vectorOf(Float.class), + GenericType.vectorOf(Float.class), + CqlVector.newInstance(1.0f) + }, + { + DataTypes.vectorOf(DataTypes.DOUBLE, 1), + GenericType.vectorOf(Double.class), + GenericType.vectorOf(Double.class), + CqlVector.newInstance(1.0d) + }, + { + DataTypes.vectorOf(DataTypes.DECIMAL, 1), + GenericType.vectorOf(BigDecimal.class), + GenericType.vectorOf(BigDecimal.class), + CqlVector.newInstance(BigDecimal.ONE) + }, + { + DataTypes.vectorOf(DataTypes.VARINT, 1), + GenericType.vectorOf(BigInteger.class), + GenericType.vectorOf(BigInteger.class), + CqlVector.newInstance(BigInteger.ONE) + }, }; }