diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlDuration.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlDuration.java index 8ec509ea7f6..5bb07b92923 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlDuration.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlDuration.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import edu.umd.cs.findbugs.annotations.NonNull; +import java.io.Serializable; import java.time.Duration; import java.time.Period; import java.time.temporal.ChronoUnit; @@ -42,7 +43,9 @@ * in time, regardless of the calendar). */ @Immutable -public final class CqlDuration implements TemporalAmount { +public final class CqlDuration implements TemporalAmount, Serializable { + + private static final long serialVersionUID = 1L; @VisibleForTesting static final long NANOS_PER_MICRO = 1000L; @VisibleForTesting static final long NANOS_PER_MILLI = 1000 * NANOS_PER_MICRO; @@ -75,8 +78,11 @@ public final class CqlDuration implements TemporalAmount { private static final ImmutableList TEMPORAL_UNITS = ImmutableList.of(ChronoUnit.MONTHS, ChronoUnit.DAYS, ChronoUnit.NANOS); + /** @serial */ private final int months; + /** @serial */ private final int days; + /** @serial */ private final long nanoseconds; private CqlDuration(int months, int days, long nanoseconds) { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java index 152d0f40823..2889ea5eb24 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java @@ -22,6 +22,12 @@ import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; import com.datastax.oss.driver.shaded.guava.common.collect.Streams; import edu.umd.cs.findbugs.annotations.NonNull; +import java.io.IOException; +import java.io.InvalidObjectException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; @@ -44,7 +50,7 @@ * where possible we've tried to make the API of this class similar to the equivalent methods on * {@link List}. */ -public class CqlVector implements Iterable { +public class CqlVector implements Iterable, Serializable { /** * Create a new CqlVector containing the specified values. @@ -190,4 +196,58 @@ public int hashCode() { public String toString() { return Iterables.toString(this.list); } + + /** + * Serialization proxy for CqlVector. Allows serialization regardless of implementation of list + * field. + * + * @param inner type of CqlVector, assume Number is always Serializable. + */ + private static class SerializationProxy implements Serializable { + + private static final long serialVersionUID = 1; + + private transient List list; + + SerializationProxy(CqlVector vector) { + this.list = vector.list; + } + + // Reconstruct CqlVector's list of elements. + private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException { + stream.defaultReadObject(); + + int size = stream.readInt(); + list = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + list.add((T) stream.readObject()); + } + } + + // Return deserialized proxy object as CqlVector. + private Object readResolve() throws ObjectStreamException { + return new CqlVector(list); + } + + // Write size of CqlVector followed by items in order. + private void writeObject(ObjectOutputStream stream) throws IOException { + stream.defaultWriteObject(); + + stream.writeInt(list.size()); + for (T item : list) { + stream.writeObject(item); + } + } + } + + /** @serialData The number of elements in the vector, followed by each element in-order. */ + private Object writeReplace() { + return new SerializationProxy(this); + } + + private void readObject(@SuppressWarnings("unused") ObjectInputStream stream) + throws InvalidObjectException { + // Should never be called since we serialized a proxy + throw new InvalidObjectException("Proxy required"); + } } diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlDurationTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlDurationTest.java index 56c0b00b5e3..f5c263f0594 100644 --- a/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlDurationTest.java +++ b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlDurationTest.java @@ -20,6 +20,7 @@ import static org.assertj.core.api.Assertions.fail; import com.datastax.oss.driver.TestDataProviders; +import com.datastax.oss.driver.internal.SerializationHelper; import com.tngtech.java.junit.dataprovider.DataProviderRunner; import com.tngtech.java.junit.dataprovider.UseDataProvider; import java.time.ZonedDateTime; @@ -190,4 +191,18 @@ public void should_subtract_from_temporal() { assertThat(dateTime.minus(CqlDuration.from("1h15s15ns"))) .isEqualTo("2018-10-03T22:59:44.999999985-07:00[America/Los_Angeles]"); } + + @Test + public void should_serialize_and_deserialize() throws Exception { + CqlDuration initial = CqlDuration.from("3mo2d15s"); + CqlDuration deserialized = SerializationHelper.serializeAndDeserialize(initial); + assertThat(deserialized).isEqualTo(initial); + } + + @Test + public void should_serialize_and_deserialize_negative() throws Exception { + CqlDuration initial = CqlDuration.from("-2d15m"); + CqlDuration deserialized = SerializationHelper.serializeAndDeserialize(initial); + assertThat(deserialized).isEqualTo(initial); + } } diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java index ecf8f1249d0..75dfbc26e42 100644 --- a/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java +++ b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java @@ -19,9 +19,12 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; +import com.datastax.oss.driver.internal.SerializationHelper; import com.datastax.oss.driver.shaded.guava.common.collect.Iterators; +import java.util.AbstractList; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import org.assertj.core.util.Lists; @@ -195,4 +198,37 @@ public void should_correctly_compare_vectors() { assertThat(vector1).isNotSameAs(vector5); assertThat(vector1).isNotEqualTo(vector5); } + + @Test + public void should_serialize_and_deserialize() throws Exception { + CqlVector initial = CqlVector.newInstance(VECTOR_ARGS); + CqlVector deserialized = SerializationHelper.serializeAndDeserialize(initial); + assertThat(deserialized).isEqualTo(initial); + } + + @Test + public void should_serialize_and_deserialize_empty_vector() throws Exception { + CqlVector initial = CqlVector.newInstance(Collections.emptyList()); + CqlVector deserialized = SerializationHelper.serializeAndDeserialize(initial); + assertThat(deserialized).isEqualTo(initial); + } + + @Test + public void should_serialize_and_deserialize_unserializable_list() throws Exception { + CqlVector initial = + CqlVector.newInstance( + new AbstractList() { + @Override + public Float get(int index) { + return VECTOR_ARGS[index]; + } + + @Override + public int size() { + return VECTOR_ARGS.length; + } + }); + CqlVector deserialized = SerializationHelper.serializeAndDeserialize(initial); + assertThat(deserialized).isEqualTo(initial); + } }