Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12937][SQL] bloom filter serialization #10920

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@

package org.apache.spark.util.sketch;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

public final class BitArray {
private final long[] data;
private long bitCount;

static int numWords(long numBits) {
if (numBits <= 0) {
throw new IllegalArgumentException("numBits must be positive, but got " + numBits);
}
long numWords = (long) Math.ceil(numBits / 64.0);
if (numWords > Integer.MAX_VALUE) {
throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
Expand All @@ -32,13 +38,14 @@ static int numWords(long numBits) {
}

BitArray(long numBits) {
if (numBits <= 0) {
throw new IllegalArgumentException("numBits must be positive");
}
this.data = new long[numWords(numBits)];
this(new long[numWords(numBits)]);
}

private BitArray(long[] data) {
this.data = data;
long bitCount = 0;
for (long value : data) {
bitCount += Long.bitCount(value);
for (long word : data) {
bitCount += Long.bitCount(word);
}
this.bitCount = bitCount;
}
Expand Down Expand Up @@ -78,13 +85,28 @@ void putAll(BitArray array) {
this.bitCount = bitCount;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || !(o instanceof BitArray)) return false;
/**
 * Serializes this bit array to {@code out}: the word count (32 bit) followed by
 * each 64-bit word. The stream is left open for the caller to close.
 */
void writeTo(DataOutputStream out) throws IOException {
  out.writeInt(data.length);
  for (int i = 0; i < data.length; i++) {
    out.writeLong(data[i]);
  }
}

/**
 * Deserializes a {@code BitArray} previously written by {@link #writeTo}.
 *
 * @param in the stream to read from; left open after reading
 * @throws IOException if the stream ends early or encodes an invalid word count
 */
static BitArray readFrom(DataInputStream in) throws IOException {
  int numWords = in.readInt();
  // Guard against corrupt/truncated input: a negative count would otherwise
  // surface as a NegativeArraySizeException instead of a meaningful IOException.
  if (numWords < 0) {
    throw new IOException("Invalid number of words in serialized BitArray: " + numWords);
  }
  long[] data = new long[numWords];
  for (int i = 0; i < numWords; i++) {
    data[i] = in.readLong();
  }
  return new BitArray(data);
}

BitArray bitArray = (BitArray) o;
return Arrays.equals(data, bitArray.data);
@Override
public boolean equals(Object other) {
  if (this == other) return true;
  // instanceof is null-safe, so the explicit null check was redundant.
  if (!(other instanceof BitArray)) return false;
  BitArray that = (BitArray) other;
  return Arrays.equals(data, that.data);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.util.sketch;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/**
* A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether
* an element is a member of a set. It returns false when the element is definitely not in the
Expand All @@ -39,6 +43,28 @@
* The implementation is largely based on the {@code BloomFilter} class from guava.
*/
public abstract class BloomFilter {

public enum Version {
/**
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* - Version number, always 1 (32 bit)
* - Total number of words of the underlying bit array (32 bit)
* - The words/longs (numWords * 64 bit)
* - Number of hash functions (32 bit)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we write the number of hash functions at the end rather than before the words?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Scaladoc requires an extra space before - to form an unordered list. I'll fix this one in #10911.

*/
V1(1);

private final int versionNumber;

Version(int versionNumber) {
this.versionNumber = versionNumber;
}

int getVersionNumber() {
return versionNumber;
}
}

/**
* Returns the false positive probability, i.e. the probability that
* {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that
Expand Down Expand Up @@ -83,7 +109,7 @@ public abstract class BloomFilter {
* bloom filters are appropriately sized to avoid saturating them.
*
* @param other The bloom filter to combine this bloom filter with. It is not mutated.
* @throws IllegalArgumentException if {@code isCompatible(that) == false}
* @throws IncompatibleMergeException if {@code isCompatible(other) == false}
*/
public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException;

Expand All @@ -93,6 +119,20 @@ public abstract class BloomFilter {
*/
public abstract boolean mightContain(Object item);

/**
 * Writes out this {@link BloomFilter} to an output stream in binary format.
 * It is the caller's responsibility to close the stream.
 *
 * @param out the stream to write to; it is NOT closed by this method
 * @throws IOException if writing to the underlying stream fails
 */
public abstract void writeTo(OutputStream out) throws IOException;

/**
 * Reads in a {@link BloomFilter} from an input stream.
 * It is the caller's responsibility to close the stream.
 *
 * @param in the stream to read from; it is NOT closed by this method
 * @return the deserialized bloom filter
 * @throws IOException if the stream does not contain a valid serialized bloom filter
 */
public static BloomFilter readFrom(InputStream in) throws IOException {
return BloomFilterImpl.readFrom(in);
}

/**
* Computes the optimal k (number of hashes per element inserted in Bloom filter), given the
* expected insertions and total number of bits in the Bloom filter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,40 @@

package org.apache.spark.util.sketch;

import java.io.UnsupportedEncodingException;
import java.io.*;

public class BloomFilterImpl extends BloomFilter {

private final int numHashFunctions;
private final BitArray bits;

/**
 * Creates a bloom filter backed by a fresh, all-zero bit array of {@code numBits} bits.
 */
BloomFilterImpl(int numHashFunctions, long numBits) {
  this(new BitArray(numBits), numHashFunctions);
}

/**
 * Internal constructor used by deserialization; takes ownership of {@code bits}.
 */
private BloomFilterImpl(BitArray bits, int numHashFunctions) {
  // Removed a stray leftover line that re-assigned `bits` from an out-of-scope
  // variable (`numBits`) — it did not compile and clobbered the passed-in array.
  this.bits = bits;
  this.numHashFunctions = numHashFunctions;
}

@Override
public boolean equals(Object other) {
  if (other == this) {
    return true;
  }

  // instanceof is null-safe, so the explicit null check was redundant.
  if (!(other instanceof BloomFilterImpl)) {
    return false;
  }

  BloomFilterImpl that = (BloomFilterImpl) other;

  return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits);
}

// Combines both fields so that equal filters hash equally (equals/hashCode contract).
@Override
public int hashCode() {
  return 31 * bits.hashCode() + numHashFunctions;
}

@Override
Expand Down Expand Up @@ -161,4 +185,24 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep
this.bits.putAll(that.bits);
return this;
}

/**
 * Serializes this filter in V1 binary format: version number, bit array,
 * then the number of hash functions. The stream is left open for the caller.
 */
@Override
public void writeTo(OutputStream out) throws IOException {
  DataOutputStream dataOut = new DataOutputStream(out);
  dataOut.writeInt(Version.V1.getVersionNumber());
  bits.writeTo(dataOut);
  dataOut.writeInt(numHashFunctions);
}

/**
 * Deserializes a filter written by {@link #writeTo}; rejects unknown versions.
 * The stream is left open for the caller.
 */
public static BloomFilterImpl readFrom(InputStream in) throws IOException {
  DataInputStream dataIn = new DataInputStream(in);

  int version = dataIn.readInt();
  if (version != Version.V1.getVersionNumber()) {
    throw new IOException("Unexpected Bloom filter version number (" + version + ")");
  }

  // Read order mirrors writeTo: bit array first, then the hash-function count.
  BitArray bits = BitArray.readFrom(dataIn);
  int numHashFunctions = dataIn.readInt();
  return new BloomFilterImpl(bits, numHashFunctions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,21 @@
* This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
*/
abstract public class CountMinSketch {
/**
* Version number of the serialized binary format.
*/

public enum Version {
/**
* {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
* - Version number, always 1 (32 bit)
* - Total count of added items (64 bit)
* - Depth (32 bit)
* - Width (32 bit)
* - Hash functions (depth * 64 bit)
* - Count table
* - Row 0 (width * 64 bit)
* - Row 1 (width * 64 bit)
* - ...
* - Row depth - 1 (width * 64 bit)
*/
V1(1);

private final int versionNumber;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @liancheng , I removed this as the design doc says users should not care about the version being used.

Expand All @@ -67,13 +78,11 @@ public enum Version {
this.versionNumber = versionNumber;
}

public int getVersionNumber() {
int getVersionNumber() {
return versionNumber;
}
}

public abstract Version version();

/**
* Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
*/
Expand Down Expand Up @@ -128,13 +137,13 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other)

/**
* Writes out this {@link CountMinSketch} to an output stream in binary format.
* It is the caller's responsibility to close the stream
* It is the caller's responsibility to close the stream.
*/
public abstract void writeTo(OutputStream out) throws IOException;

/**
* Reads in a {@link CountMinSketch} from an input stream.
* It is the caller's responsibility to close the stream
* It is the caller's responsibility to close the stream.
*/
public static CountMinSketch readFrom(InputStream in) throws IOException {
return CountMinSketchImpl.readFrom(in);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,6 @@
import java.util.Arrays;
import java.util.Random;

/*
* Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian
* order):
*
* - Version number, always 1 (32 bit)
* - Total count of added items (64 bit)
* - Depth (32 bit)
* - Width (32 bit)
* - Hash functions (depth * 64 bit)
* - Count table
* - Row 0 (width * 64 bit)
* - Row 1 (width * 64 bit)
* - ...
* - Row depth - 1 (width * 64 bit)
*/
class CountMinSketchImpl extends CountMinSketch {
public static final long PRIME_MODULUS = (1L << 31) - 1;

Expand Down Expand Up @@ -112,11 +97,6 @@ public int hashCode() {
return hash;
}

@Override
public Version version() {
return Version.V1;
}

private void initTablesWith(int depth, int width, int seed) {
this.table = new long[depth][width];
this.hashA = new long[depth];
Expand Down Expand Up @@ -327,7 +307,7 @@ public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMerg
public void writeTo(OutputStream out) throws IOException {
DataOutputStream dos = new DataOutputStream(out);

dos.writeInt(version().getVersionNumber());
dos.writeInt(Version.V1.getVersionNumber());

dos.writeLong(this.totalCount);
dos.writeInt(this.depth);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.util.sketch

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.reflect.ClassTag
import scala.util.Random

Expand All @@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite
class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
private final val EPSILON = 0.01

// Round-trips the given `BloomFilter` through the binary serialization format
// and asserts that the deserialized copy is equal to the original.
private def checkSerDe(filter: BloomFilter): Unit = {
  val bytes = new ByteArrayOutputStream()
  filter.writeTo(bytes)
  bytes.close()

  val input = new ByteArrayInputStream(bytes.toByteArray)
  val recovered = BloomFilter.readFrom(input)
  input.close()

  assert(filter == recovered)
}

def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
test(s"accuracy - $typeName") {
// use a fixed seed to make the test predictable.
Expand All @@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
// Also check the actual fpp is not significantly higher than we expected.
val actualFpp = errorCount.toDouble / (numItems - numInsertion)
assert(actualFpp - fpp < EPSILON)

checkSerDe(filter)
}
}

Expand All @@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite

items1.foreach(i => assert(filter1.mightContain(i)))
items2.foreach(i => assert(filter1.mightContain(i)))

checkSerDe(filter1)
}
}

Expand Down