diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
index 1bc665ad54b72..0e5b6f5668c0d 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
 import java.util.Arrays;
 
 public final class BitArray {
@@ -24,6 +27,9 @@ public final class BitArray {
   private long bitCount;
 
   static int numWords(long numBits) {
+    if (numBits <= 0) {
+      throw new IllegalArgumentException("numBits must be positive");
+    }
     long numWords = (long) Math.ceil(numBits / 64.0);
     if (numWords > Integer.MAX_VALUE) {
       throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
@@ -32,13 +38,14 @@ static int numWords(long numBits) {
   }
 
   BitArray(long numBits) {
-    if (numBits <= 0) {
-      throw new IllegalArgumentException("numBits must be positive");
-    }
-    this.data = new long[numWords(numBits)];
+    this(new long[numWords(numBits)]);
+  }
+
+  private BitArray(long[] data) {
+    this.data = data;
     long bitCount = 0;
-    for (long value : data) {
-      bitCount += Long.bitCount(value);
+    for (long datum : data) {
+      bitCount += Long.bitCount(datum);
     }
     this.bitCount = bitCount;
   }
@@ -78,13 +85,28 @@ void putAll(BitArray array) {
     this.bitCount = bitCount;
   }
 
-  @Override
-  public boolean equals(Object o) {
-    if (this == o) return true;
-    if (o == null || !(o instanceof BitArray)) return false;
+  void writeTo(DataOutputStream out) throws IOException {
+    out.writeInt(data.length);
+    for (long datum : data) {
+      out.writeLong(datum);
+    }
+  }
 
-    BitArray bitArray = (BitArray) o;
-    return Arrays.equals(data, bitArray.data);
+  static BitArray readFrom(DataInputStream in) throws IOException {
+    int numWords = in.readInt();
+    long[] data = new long[numWords];
+    for (int i = 0; i < numWords; i++) {
+      data[i] = in.readLong();
+    }
+    return new BitArray(data);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) return true;
+    if (other == null || !(other instanceof BitArray)) return false;
+    BitArray that = (BitArray) other;
+    return Arrays.equals(data, that.data);
   }
 
   @Override
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 38949c6311df8..de10c6a23c105 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -17,6 +17,10 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
 /**
  * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether
  * an element is a member of a set. It returns false when the element is definitely not in the
@@ -83,7 +87,7 @@ public abstract class BloomFilter {
    * bloom filters are appropriately sized to avoid saturating them.
    *
    * @param other The bloom filter to combine this bloom filter with. It is not mutated.
-   * @throws IllegalArgumentException if {@code isCompatible(that) == false}
+   * @throws IncompatibleMergeException if {@code isCompatible(that) == false}
    */
   public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException;
 
@@ -93,6 +97,20 @@ public abstract class BloomFilter {
    */
   public abstract boolean mightContain(Object item);
 
+  /**
+   * Writes out this {@link BloomFilter} to an output stream in binary format.
+   * It is the caller's responsibility to close the stream.
+   */
+  public abstract void writeTo(OutputStream out) throws IOException;
+
+  /**
+   * Reads in a {@link BloomFilter} from an input stream.
+   * It is the caller's responsibility to close the stream.
+   */
+  public static BloomFilter readFrom(InputStream in) throws IOException {
+    return BloomFilterImpl.readFrom(in);
+  }
+
   /**
    * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the
    * expected insertions and total number of bits in the Bloom filter.
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index bbd6cf719dc0e..b97043686b331 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -17,16 +17,49 @@
 
 package org.apache.spark.util.sketch;
 
-import java.io.UnsupportedEncodingException;
+import java.io.*;
 
+/*
+ * Binary format of a serialized BloomFilterImpl, version 1 (all values written in big-endian
+ * order):
+ *
+ * - Version number, always 1 (32 bit)
+ * - Total number of words of the BitArray (32 bit)
+ * - Long array inside the BitArray (numWords * 64 bit)
+ * - Number of hash functions (32 bit)
+ */
 public class BloomFilterImpl extends BloomFilter {
 
   private final int numHashFunctions;
   private final BitArray bits;
 
   BloomFilterImpl(int numHashFunctions, long numBits) {
+    this(new BitArray(numBits), numHashFunctions);
+  }
+
+  private BloomFilterImpl(BitArray bits, int numHashFunctions) {
+    this.bits = bits;
     this.numHashFunctions = numHashFunctions;
-    this.bits = new BitArray(numBits);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == this) {
+      return true;
+    }
+
+    if (other == null || !(other instanceof BloomFilterImpl)) {
+      return false;
+    }
+
+    BloomFilterImpl that = (BloomFilterImpl) other;
+
+    return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits);
+  }
+
+  @Override
+  public int hashCode() {
+    return bits.hashCode() * 31 + numHashFunctions;
   }
 
   @Override
@@ -161,4 +194,24 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep
     this.bits.putAll(that.bits);
     return this;
   }
+
+  @Override
+  public void writeTo(OutputStream out) throws IOException {
+    DataOutputStream dos = new DataOutputStream(out);
+
+    dos.writeInt(Version.V1.getVersionNumber());
+    bits.writeTo(dos);
+    dos.writeInt(numHashFunctions);
+  }
+
+  public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+    DataInputStream dis = new DataInputStream(in);
+
+    int version = dis.readInt();
+    if (version != Version.V1.getVersionNumber()) {
+      throw new IOException("Unexpected Bloom Filter version number (" + version + ")");
+    }
+
+    return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+  }
 }
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 9f4ff42403c34..004fbbf3152ff 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -55,25 +55,6 @@
  * This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
  */
 abstract public class CountMinSketch {
-  /**
-   * Version number of the serialized binary format.
-   */
-  public enum Version {
-    V1(1);
-
-    private final int versionNumber;
-
-    Version(int versionNumber) {
-      this.versionNumber = versionNumber;
-    }
-
-    public int getVersionNumber() {
-      return versionNumber;
-    }
-  }
-
-  public abstract Version version();
-
   /**
    * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
    */
@@ -128,13 +109,13 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other)
 
   /**
    * Writes out this {@link CountMinSketch} to an output stream in binary format.
-   * It is the caller's responsibility to close the stream
+   * It is the caller's responsibility to close the stream.
    */
   public abstract void writeTo(OutputStream out) throws IOException;
 
   /**
    * Reads in a {@link CountMinSketch} from an input stream.
-   * It is the caller's responsibility to close the stream
+   * It is the caller's responsibility to close the stream.
    */
   public static CountMinSketch readFrom(InputStream in) throws IOException {
     return CountMinSketchImpl.readFrom(in);
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 0209446ea3b1d..8f17ddb310115 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -112,11 +112,6 @@ public int hashCode() {
     return hash;
   }
 
-  @Override
-  public Version version() {
-    return Version.V1;
-  }
-
   private void initTablesWith(int depth, int width, int seed) {
     this.table = new long[depth][width];
     this.hashA = new long[depth];
@@ -327,7 +322,7 @@ public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMerg
   public void writeTo(OutputStream out) throws IOException {
     DataOutputStream dos = new DataOutputStream(out);
 
-    dos.writeInt(version().getVersionNumber());
+    dos.writeInt(Version.V1.getVersionNumber());
 
     dos.writeLong(this.totalCount);
     dos.writeInt(this.depth);
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java
new file mode 100644
index 0000000000000..40790c92f3aef
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+/**
+ * Version number of the serialized binary format for bloom filter or count-min sketch.
+ */
+public enum Version {
+  V1(1);
+
+  private final int versionNumber;
+
+  Version(int versionNumber) {
+    this.versionNumber = versionNumber;
+  }
+
+  int getVersionNumber() {
+    return versionNumber;
+  }
+}
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
index d2de509f19517..a0408d2da4dff 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.util.sketch
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite
 class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
   private final val EPSILON = 0.01
 
+  // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized
+  // version is equivalent to the original one.
+  private def checkSerDe(filter: BloomFilter): Unit = {
+    val out = new ByteArrayOutputStream()
+    filter.writeTo(out)
+    out.close()
+
+    val in = new ByteArrayInputStream(out.toByteArray)
+    val deserialized = BloomFilter.readFrom(in)
+    in.close()
+
+    assert(filter == deserialized)
+  }
+
   def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
     test(s"accuracy - $typeName") {
       // use a fixed seed to make the test predictable.
@@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
       // Also check the actual fpp is not significantly higher than we expected.
       val actualFpp = errorCount.toDouble / (numItems - numInsertion)
       assert(actualFpp - fpp < EPSILON)
+
+      checkSerDe(filter)
     }
   }
 
@@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
 
       items1.foreach(i => assert(filter1.mightContain(i)))
       items2.foreach(i => assert(filter1.mightContain(i)))
+
+      checkSerDe(filter1)
     }
   }
 
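For reference, a minimal usage sketch (not part of the patch) of the serialization round trip that the new writeTo/readFrom API enables. It assumes the existing BloomFilter.create(...) factory and put(...) insertion methods from the same class; the class name, item values, and sizes below are illustrative only.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterSerDeExample {
  public static void main(String[] args) throws IOException {
    // Build a filter sized for roughly 1000 items and insert a few strings.
    BloomFilter filter = BloomFilter.create(1000);
    for (int i = 0; i < 100; i++) {
      filter.put("item" + i);
    }

    // Serialize using the binary format documented in BloomFilterImpl:
    // version number, BitArray word count and words, number of hash functions.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);
    out.close();

    // Deserialize and verify the round trip: a Bloom filter never reports
    // false negatives, so every inserted item must still be found.
    ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
    BloomFilter deserialized = BloomFilter.readFrom(in);
    in.close();

    for (int i = 0; i < 100; i++) {
      if (!deserialized.mightContain("item" + i)) {
        throw new IllegalStateException("round trip lost item" + i);
      }
    }
  }
}

The checkSerDe helper in BloomFilterSuite.scala performs the same round trip through in-memory streams and relies on the BloomFilterImpl.equals method added in this patch to compare the original and deserialized filters.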