diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index d392fb187ad65..81772fcea0ec2 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -49,9 +49,9 @@ public enum Version {
     * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
     * <ul>
     *   <li>Version number, always 1 (32 bit)</li>
+    *   <li>Number of hash functions (32 bit)</li>
     *   <li>Total number of words of the underlying bit array (32 bit)</li>
     *   <li>The words/longs (numWords * 64 bit)</li>
-    *   <li>Number of hash functions (32 bit)</li>
     * </ul>
     */
    V1(1);
@@ -97,6 +97,21 @@ int getVersionNumber() {
    */
   public abstract boolean put(Object item);
 
+  /**
+   * A specialized variant of {@link #put(Object)} that can only be used to put a UTF-8 string.
+   */
+  public abstract boolean putString(String str);
+
+  /**
+   * A specialized variant of {@link #put(Object)} that can only be used to put a long.
+   */
+  public abstract boolean putLong(long l);
+
+  /**
+   * A specialized variant of {@link #put(Object)} that can only be used to put a byte array.
+   */
+  public abstract boolean putBinary(byte[] bytes);
+
   /**
    * Determines whether a given bloom filter is compatible with this bloom filter. For two
    * bloom filters to be compatible, they must have the same bit size.
@@ -121,6 +136,23 @@ int getVersionNumber() {
    */
   public abstract boolean mightContain(Object item);
 
+  /**
+   * A specialized variant of {@link #mightContain(Object)} that can only be used to test a
+   * UTF-8 string.
+   */
+  public abstract boolean mightContainString(String str);
+
+  /**
+   * A specialized variant of {@link #mightContain(Object)} that can only be used to test a long.
+   */
+  public abstract boolean mightContainLong(long l);
+
+  /**
+   * A specialized variant of {@link #mightContain(Object)} that can only be used to test a byte
+   * array.
+   */
+  public abstract boolean mightContainBinary(byte[] bytes);
+
   /**
    * Writes out this {@link BloomFilter} to an output stream in binary format.
    * It is the caller's responsibility to close the stream.
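The specialized variants above let callers skip the `instanceof` dispatch and autoboxing of the generic `put(Object)`/`mightContain(Object)` path. A minimal usage sketch of the new API (the example class is illustrative; `BloomFilter.create(expectedNumItems, fpp)` is the existing factory, also used by `DataFrameStatFunctions` later in this patch):

```java
import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterApiExample {
  public static void main(String[] args) {
    // 1000 expected items at a 3% target false positive probability.
    BloomFilter filter = BloomFilter.create(1000, 0.03);

    filter.putLong(42L);                    // no boxing, no instanceof dispatch
    filter.putString("spark");              // hashed via its UTF-8 bytes
    filter.putBinary(new byte[] {1, 2, 3});

    // A Bloom filter never reports a false negative.
    assert filter.mightContainLong(42L);
    assert filter.mightContainString("spark");
    assert filter.mightContain(42L);        // generic path dispatches to mightContainLong
  }
}
```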
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index 1c08d07afaeaa..35107e0b389d7 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -19,10 +19,10 @@
 
 import java.io.*;
 
-public class BloomFilterImpl extends BloomFilter {
+public class BloomFilterImpl extends BloomFilter implements Serializable {
 
-  private final int numHashFunctions;
-  private final BitArray bits;
+  private int numHashFunctions;
+  private BitArray bits;
 
   BloomFilterImpl(int numHashFunctions, long numBits) {
     this(new BitArray(numBits), numHashFunctions);
@@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) {
     this.numHashFunctions = numHashFunctions;
   }
 
+  private BloomFilterImpl() {}
+
   @Override
   public boolean equals(Object other) {
     if (other == this) {
@@ -63,55 +65,75 @@ public long bitSize() {
     return bits.bitSize();
   }
 
-  private static long hashObjectToLong(Object item) {
+  @Override
+  public boolean put(Object item) {
     if (item instanceof String) {
-      try {
-        byte[] bytes = ((String) item).getBytes("utf-8");
-        return hashBytesToLong(bytes);
-      } catch (UnsupportedEncodingException e) {
-        throw new RuntimeException("Only support utf-8 string", e);
-      }
+      return putString((String) item);
+    } else if (item instanceof byte[]) {
+      return putBinary((byte[]) item);
     } else {
-      long longValue;
-
-      if (item instanceof Long) {
-        longValue = (Long) item;
-      } else if (item instanceof Integer) {
-        longValue = ((Integer) item).longValue();
-      } else if (item instanceof Short) {
-        longValue = ((Short) item).longValue();
-      } else if (item instanceof Byte) {
-        longValue = ((Byte) item).longValue();
-      } else {
-        throw new IllegalArgumentException(
-          "Support for " + item.getClass().getName() + " not implemented"
-        );
-      }
-
-      int h1 = Murmur3_x86_32.hashLong(longValue, 0);
-      int h2 = Murmur3_x86_32.hashLong(longValue, h1);
-      return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+      return putLong(Utils.integralToLong(item));
     }
   }
 
-  private static long hashBytesToLong(byte[] bytes) {
+  @Override
+  public boolean putString(String str) {
+    return putBinary(Utils.getBytesFromUTF8String(str));
+  }
+
+  @Override
+  public boolean putBinary(byte[] bytes) {
     int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
     int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
-    return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+
+    long bitSize = bits.bitSize();
+    boolean bitsChanged = false;
+    for (int i = 1; i <= numHashFunctions; i++) {
+      int combinedHash = h1 + (i * h2);
+      // Flip all the bits if it's negative (guaranteed positive number)
+      if (combinedHash < 0) {
+        combinedHash = ~combinedHash;
+      }
+      bitsChanged |= bits.set(combinedHash % bitSize);
+    }
+    return bitsChanged;
   }
 
   @Override
-  public boolean put(Object item) {
-    long bitSize = bits.bitSize();
+  public boolean mightContainString(String str) {
+    return mightContainBinary(Utils.getBytesFromUTF8String(str));
+  }
+
+  @Override
+  public boolean mightContainBinary(byte[] bytes) {
+    int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
+    int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
+
+    long bitSize = bits.bitSize();
+    for (int i = 1; i <= numHashFunctions; i++) {
+      int combinedHash = h1 + (i * h2);
+      // Flip all the bits if it's negative (guaranteed positive number)
+      if (combinedHash < 0) {
+        combinedHash = ~combinedHash;
+      }
+      if (!bits.get(combinedHash % bitSize)) {
+        return false;
+      }
+    }
+    return true;
+  }
 
-    // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
-    // values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
-    // Note that `CountMinSketch` use a different strategy for long type, it hash the input long
-    // element with every i to produce n hash values.
-    long hash64 = hashObjectToLong(item);
-    int h1 = (int) (hash64 >> 32);
-    int h2 = (int) hash64;
+  @Override
+  public boolean putLong(long l) {
+    // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
+    // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
+    // Note that `CountMinSketch` uses a different strategy: it hashes the input long element with
+    // every i to produce n hash values.
+    // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
+    int h1 = Murmur3_x86_32.hashLong(l, 0);
+    int h2 = Murmur3_x86_32.hashLong(l, h1);
 
+    long bitSize = bits.bitSize();
     boolean bitsChanged = false;
     for (int i = 1; i <= numHashFunctions; i++) {
       int combinedHash = h1 + (i * h2);
@@ -125,12 +147,11 @@ public boolean put(Object item) {
   }
 
   @Override
-  public boolean mightContain(Object item) {
-    long bitSize = bits.bitSize();
-    long hash64 = hashObjectToLong(item);
-    int h1 = (int) (hash64 >> 32);
-    int h2 = (int) hash64;
+  public boolean mightContainLong(long l) {
+    int h1 = Murmur3_x86_32.hashLong(l, 0);
+    int h2 = Murmur3_x86_32.hashLong(l, h1);
 
+    long bitSize = bits.bitSize();
     for (int i = 1; i <= numHashFunctions; i++) {
       int combinedHash = h1 + (i * h2);
       // Flip all the bits if it's negative (guaranteed positive number)
@@ -144,6 +165,17 @@ public boolean mightContain(Object item) {
     return true;
   }
 
+  @Override
+  public boolean mightContain(Object item) {
+    if (item instanceof String) {
+      return mightContainString((String) item);
+    } else if (item instanceof byte[]) {
+      return mightContainBinary((byte[]) item);
+    } else {
+      return mightContainLong(Utils.integralToLong(item));
+    }
+  }
+
   @Override
   public boolean isCompatible(BloomFilter other) {
     if (other == null) {
@@ -191,11 +223,11 @@ public void writeTo(OutputStream out) throws IOException {
     DataOutputStream dos = new DataOutputStream(out);
 
     dos.writeInt(Version.V1.getVersionNumber());
-    bits.writeTo(dos);
     dos.writeInt(numHashFunctions);
+    bits.writeTo(dos);
   }
 
-  public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+  private void readFrom0(InputStream in) throws IOException {
     DataInputStream dis = new DataInputStream(in);
 
     int version = dis.readInt();
@@ -203,6 +235,21 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException {
       throw new IOException("Unexpected Bloom filter version number (" + version + ")");
     }
 
-    return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+    this.numHashFunctions = dis.readInt();
+    this.bits = BitArray.readFrom(dis);
+  }
+
+  public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+    BloomFilterImpl filter = new BloomFilterImpl();
+    filter.readFrom0(in);
+    return filter;
+  }
+
+  private void writeObject(ObjectOutputStream out) throws IOException {
+    writeTo(out);
+  }
+
+  private void readObject(ObjectInputStream in) throws IOException {
+    readFrom0(in);
   }
 }
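`putBinary`, `mightContainBinary`, `putLong`, and `mightContainLong` all derive their k bit indices the same way: two Murmur3 hashes combined as `h1 + i * h2`, essentially the Kirsch-Mitzenmacher double-hashing construction. A standalone sketch of just that index derivation (class and method names are illustrative, not part of the patch; in the patch h1 and h2 come from `Murmur3_x86_32`):

```java
// Illustrative only: the index-derivation scheme shared by the put*/mightContain*
// methods above. h1 and h2 are plain parameters here.
public class DoubleHashingSketch {
  static long[] bitIndices(int h1, int h2, int numHashFunctions, long bitSize) {
    long[] indices = new long[numHashFunctions];
    for (int i = 1; i <= numHashFunctions; i++) {
      int combinedHash = h1 + (i * h2);
      // Flip all the bits if the (possibly overflowed) sum is negative,
      // exactly as the patch does before taking the modulus.
      if (combinedHash < 0) {
        combinedHash = ~combinedHash;
      }
      indices[i - 1] = combinedHash % bitSize;
    }
    return indices;
  }

  public static void main(String[] args) {
    // With 5 hash functions and a 320-bit array, one element maps to 5 indices.
    for (long index : bitIndices(0x12345678, 0x0badcafe, 5, 320L)) {
      System.out.println(index);
    }
  }
}
```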
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 8cc29e4076307..e49ae22906c4c 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
   private double eps;
   private double confidence;
 
-  private CountMinSketchImpl() {
-  }
+  private CountMinSketchImpl() {}
 
   CountMinSketchImpl(int depth, int width, int seed) {
     this.depth = depth;
@@ -143,23 +142,7 @@ public void add(Object item, long count) {
     if (item instanceof String) {
       addString((String) item, count);
     } else {
-      long longValue;
-
-      if (item instanceof Long) {
-        longValue = (Long) item;
-      } else if (item instanceof Integer) {
-        longValue = ((Integer) item).longValue();
-      } else if (item instanceof Short) {
-        longValue = ((Short) item).longValue();
-      } else if (item instanceof Byte) {
-        longValue = ((Byte) item).longValue();
-      } else {
-        throw new IllegalArgumentException(
-          "Support for " + item.getClass().getName() + " not implemented"
-        );
-      }
-
-      addLong(longValue, count);
+      addLong(Utils.integralToLong(item), count);
     }
   }
 
@@ -201,13 +184,7 @@ private int hash(long item, int count) {
   }
 
   private static int[] getHashBuckets(String key, int hashCount, int max) {
-    byte[] b;
-    try {
-      b = key.getBytes("UTF-8");
-    } catch (UnsupportedEncodingException e) {
-      throw new RuntimeException(e);
-    }
-    return getHashBuckets(b, hashCount, max);
+    return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
   }
 
   private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
@@ -225,23 +202,7 @@ public long estimateCount(Object item) {
     if (item instanceof String) {
       return estimateCountForStringItem((String) item);
     } else {
-      long longValue;
-
-      if (item instanceof Long) {
-        longValue = (Long) item;
-      } else if (item instanceof Integer) {
-        longValue = ((Integer) item).longValue();
-      } else if (item instanceof Short) {
-        longValue = ((Short) item).longValue();
-      } else if (item instanceof Byte) {
-        longValue = ((Byte) item).longValue();
-      } else {
-        throw new IllegalArgumentException(
-          "Support for " + item.getClass().getName() + " not implemented"
-        );
-      }
-
-      return estimateCountForLongItem(longValue);
+      return estimateCountForLongItem(Utils.integralToLong(item));
     }
   }
 
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java
new file mode 100644
index 0000000000000..a6b33313035b0
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.UnsupportedEncodingException;
+
+public class Utils {
+  public static byte[] getBytesFromUTF8String(String str) {
+    try {
+      return str.getBytes("utf-8");
+    } catch (UnsupportedEncodingException e) {
+      throw new IllegalArgumentException("Only support utf-8 string", e);
+    }
+  }
+
+  public static long integralToLong(Object i) {
+    long longValue;
+
+    if (i instanceof Long) {
+      longValue = (Long) i;
+    } else if (i instanceof Integer) {
+      longValue = ((Integer) i).longValue();
+    } else if (i instanceof Short) {
+      longValue = ((Short) i).longValue();
+    } else if (i instanceof Byte) {
+      longValue = ((Byte) i).longValue();
+    } else {
+      throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName());
+    }
+
+    return longValue;
+  }
+}
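The new `Utils` class centralizes the two conversions that `BloomFilterImpl` and `CountMinSketchImpl` previously duplicated. A small illustration of `integralToLong`'s widening and failure behavior (the example class and values are hypothetical; it only assumes `org.apache.spark.util.sketch.Utils` from the new file above is on the classpath):

```java
import org.apache.spark.util.sketch.Utils;

public class UtilsExample {
  public static void main(String[] args) {
    // All boxed integral types widen to the same long value.
    System.out.println(Utils.integralToLong((byte) 7));   // 7
    System.out.println(Utils.integralToLong((short) 7));  // 7
    System.out.println(Utils.integralToLong(7));          // 7
    System.out.println(Utils.integralToLong(7L));         // 7

    // Non-integral types are rejected up front rather than hashed incorrectly.
    try {
      Utils.integralToLong(7.5);
    } catch (IllegalArgumentException e) {
      System.out.println(e.getMessage());  // Unsupported data type java.lang.Double
    }

    // Strings are funneled through their UTF-8 bytes.
    System.out.println(Utils.getBytesFromUTF8String("spark").length);  // 5
  }
}
```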
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 465b12bb59d1e..b0b6995a2214f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,9 +22,10 @@ import java.{lang => jl, util => ju}
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.sketch.CountMinSketch
+import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
 
 /**
  * :: Experimental ::
@@ -390,4 +391,75 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
       }
     )
   }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param colName name of the column over which the filter is built
+   * @param expectedNumItems expected number of items that will be put into the filter
+   * @param fpp expected false positive probability of the filter
+   * @since 2.0.0
+   */
+  def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
+    buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp))
+  }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param col the column over which the filter is built
+   * @param expectedNumItems expected number of items that will be put into the filter
+   * @param fpp expected false positive probability of the filter
+   * @since 2.0.0
+   */
+  def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
+    buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp))
+  }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param colName name of the column over which the filter is built
+   * @param expectedNumItems expected number of items that will be put into the filter
+   * @param numBits expected number of bits of the filter
+   * @since 2.0.0
+   */
+  def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
+    buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits))
+  }
+
+  /**
+   * Builds a Bloom filter over a specified column.
+   *
+   * @param col the column over which the filter is built
+   * @param expectedNumItems expected number of items that will be put into the filter
+   * @param numBits expected number of bits of the filter
+   * @since 2.0.0
+   */
+  def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
+    buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits))
+  }
+
+  private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = {
+    val singleCol = df.select(col)
+    val colType = singleCol.schema.head.dataType
+
+    require(colType == StringType || colType.isInstanceOf[IntegralType],
+      s"Bloom filter only supports string type and integral types, but got $colType.")
+
+    val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
+      (filter, row) =>
+        // For string type, we can get the bytes of our `UTF8String` directly and call
+        // `putBinary` instead of `putString` to avoid an unnecessary conversion.
+        filter.putBinary(row.getUTF8String(0).getBytes)
+        filter
+    } else {
+      (filter, row) =>
+        // TODO: specialize it.
+        filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
+        filter
+    }
+
+    singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
+  }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 9cf94e72d34e2..0d4c128cb36d6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -40,6 +40,7 @@
 import org.apache.spark.util.sketch.CountMinSketch;
 import static org.apache.spark.sql.functions.*;
 import static org.apache.spark.sql.types.DataTypes.*;
+import org.apache.spark.util.sketch.BloomFilter;
 
 public class JavaDataFrameSuite {
   private transient JavaSparkContext jsc;
@@ -300,6 +301,7 @@ public void pivot() {
     Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
   }
 
+  @Test
   public void testGenericLoad() {
     DataFrame df1 = context.read().format("text").load(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
@@ -347,4 +349,33 @@ public void testCountMinSketch() {
     Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
     Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
   }
+
+  @Test
+  public void testBloomFilter() {
+    DataFrame df = context.range(1000);
+
+    BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
+    assert (filter1.expectedFpp() - 0.03 < 1e-3);
+    for (int i = 0; i < 1000; i++) {
+      assert (filter1.mightContain(i));
+    }
+
+    BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03);
+    assert (filter2.expectedFpp() - 0.03 < 1e-3);
+    for (int i = 0; i < 1000; i++) {
+      assert (filter2.mightContain(i * 3));
+    }
+
+    BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
+    assert (filter3.bitSize() == 64 * 5);
+    for (int i = 0; i < 1000; i++) {
+      assert (filter3.mightContain(i));
+    }
+
+    BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
+    assert (filter4.bitSize() == 64 * 5);
+    for (int i = 0; i < 1000; i++) {
+      assert (filter4.mightContain(i * 3));
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8f3ea5a2860ba..f01f126f7696d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -246,4 +246,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
       .countMinSketch('id, depth = 10, width = 20, seed = 42)
     }
   }
+
+  // This test only verifies some basic requirements; more thorough correctness tests can be
+  // found in `BloomFilterSuite` in the spark-sketch project.
+  test("Bloom filter") {
+    val df = sqlContext.range(1000)
+
+    val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
+    assert(filter1.expectedFpp() - 0.03 < 1e-3)
+    assert(0.until(1000).forall(filter1.mightContain))
+
+    val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03)
+    assert(filter2.expectedFpp() - 0.03 < 1e-3)
+    assert(0.until(1000).forall(i => filter2.mightContain(i * 3)))
+
+    val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5)
+    assert(filter3.bitSize() == 64 * 5)
+    assert(0.until(1000).forall(filter3.mightContain))
+
+    val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5)
+    assert(filter4.bitSize() == 64 * 5)
+    assert(0.until(1000).forall(i => filter4.mightContain(i * 3)))
+  }
 }
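End to end, the new API builds the filter distributedly via `aggregate` (merging per-partition filters with `mergeInPlace`) and returns it to the driver, where it can serve as a cheap membership pre-filter. A sketch of that usage (the example class and method names are illustrative; `context` is an initialized `SQLContext`, as in the Java suite above):

```java
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterStatExample {
  // Builds a filter over the "id" column of a 1000-row range, mirroring the tests above.
  static BloomFilter buildKeyFilter(SQLContext context) {
    DataFrame keys = context.range(1000);
    return keys.stat().bloomFilter("id", 1000, 0.03);
  }

  static boolean maybePresent(BloomFilter filter, long key) {
    // No false negatives: every inserted key answers true. Keys never inserted
    // answer true only with roughly the configured 3% probability.
    return filter.mightContainLong(key);
  }
}
```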