From a0dcaa8a52dcda16abccc93d3b63c251e60da3a3 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 26 Jan 2016 12:36:06 -0800
Subject: [PATCH 1/2] DataFrame API for Bloom filter

---
 .../apache/spark/util/sketch/BloomFilter.java |  12 +-
 .../spark/util/sketch/BloomFilterImpl.java    | 157 ++++++++++++------
 sql/core/pom.xml                              |   5 +
 .../spark/sql/DataFrameStatFunctions.scala    |  76 +++++++++
 .../apache/spark/sql/JavaDataFrameSuite.java  |  31 ++++
 .../apache/spark/sql/DataFrameStatSuite.scala |  22 +++
 6 files changed, 253 insertions(+), 50 deletions(-)

diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 00378d58518f6..345f6185e322e 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -48,9 +48,9 @@ public enum Version {
     /**
      * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
      *  - Version number, always 1 (32 bit)
+     *  - Number of hash functions (32 bit)
      *  - Total number of words of the underlying bit array (32 bit)
      *  - The words/longs (numWords * 64 bit)
-     *  - Number of hash functions (32 bit)
      */
    V1(1);
 
@@ -95,6 +95,16 @@ int getVersionNumber() {
    */
   public abstract boolean put(Object item);
 
+  /**
+   * A specific version of {@link #put(Object)} that can only be used to put a byte array.
+   */
+  public abstract boolean putBinary(byte[] bytes);
+
+  /**
+   * A specific version of {@link #put(Object)} that can only be used to put a long.
+   */
+  public abstract boolean putLong(long l);
+
   /**
    * Determines whether a given bloom filter is compatible with this bloom filter. For two
    * bloom filters to be compatible, they must have the same bit size.
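
As a quick illustration (a minimal usage sketch, not part of the diff itself), the specialized put methods added above can be exercised through the existing `BloomFilter.create(expectedNumItems, fpp)` factory that the DataFrame code later in this patch also relies on:

    import org.apache.spark.util.sketch.BloomFilter

    // Sketch: create a filter sized for ~1000 items with a ~3% false positive rate.
    val filter = BloomFilter.create(1000, 0.03)
    filter.putLong(42L)                           // integral values hash via hashLong
    filter.putBinary("spark".getBytes("utf-8"))   // byte arrays hash via hashUnsafeBytes
    assert(filter.mightContain(42L))              // inserted items are always found
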
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 1c08d07afaeaa..588b4ffab2e6b 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,10 +19,10 @@ import java.io.*; -public class BloomFilterImpl extends BloomFilter { +public class BloomFilterImpl extends BloomFilter implements Serializable { - private final int numHashFunctions; - private final BitArray bits; + private int numHashFunctions; + private BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { this(new BitArray(numBits), numHashFunctions); @@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) { this.numHashFunctions = numHashFunctions; } + private BloomFilterImpl() {} + @Override public boolean equals(Object other) { if (other == this) { @@ -63,55 +65,90 @@ public long bitSize() { return bits.bitSize(); } - private static long hashObjectToLong(Object item) { - if (item instanceof String) { - try { - byte[] bytes = ((String) item).getBytes("utf-8"); - return hashBytesToLong(bytes); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException("Only support utf-8 string", e); - } + private byte[] getBytesFromUTF8String(Object s) { + try { + return ((String) s).getBytes("utf-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException("Only support utf-8 string", e); + } + } + + private long integralToLong(Object i) { + long longValue; + + if (i instanceof Long) { + longValue = (Long) i; + } else if (i instanceof Integer) { + longValue = ((Integer) i).longValue(); + } else if (i instanceof Short) { + longValue = ((Short) i).longValue(); + } else if (i instanceof Byte) { + longValue = ((Byte) i).longValue(); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } + throw new IllegalArgumentException( + "Support for " + i.getClass().getName() + " not implemented" + ); + } - int h1 = Murmur3_x86_32.hashLong(longValue, 0); - int h2 = Murmur3_x86_32.hashLong(longValue, h1); - return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + return longValue; + } + + @Override + public boolean put(Object item) { + if (item instanceof String) { + return putBinary(getBytesFromUTF8String(item)); + } else { + return putLong(integralToLong(item)); } } - private static long hashBytesToLong(byte[] bytes) { + @Override + public boolean putBinary(byte[] bytes) { int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); - return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; } - 
@Override
-  public boolean put(Object item) {
+  private boolean mightContainBinary(byte[] bytes) {
+    int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
+    int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
+
+    long bitSize = bits.bitSize();
+    for (int i = 1; i <= numHashFunctions; i++) {
+      int combinedHash = h1 + (i * h2);
+      // Flip all the bits if it's negative (guaranteed positive number)
+      if (combinedHash < 0) {
+        combinedHash = ~combinedHash;
+      }
+      if (!bits.get(combinedHash % bitSize)) {
+        return false;
+      }
+    }
+    return true;
+  }
 
-    // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
-    // values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
-    // Note that `CountMinSketch` use a different strategy for long type, it hash the input long
-    // element with every i to produce n hash values.
-    long hash64 = hashObjectToLong(item);
-    int h1 = (int) (hash64 >> 32);
-    int h2 = (int) hash64;
+  @Override
+  public boolean putLong(long l) {
+    // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
+    // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
+    // Note that `CountMinSketch` uses a different strategy: it hashes the input long element with
+    // every i to produce n hash values.
+    // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
+    int h1 = Murmur3_x86_32.hashLong(l, 0);
+    int h2 = Murmur3_x86_32.hashLong(l, h1);
 
+    long bitSize = bits.bitSize();
     boolean bitsChanged = false;
     for (int i = 1; i <= numHashFunctions; i++) {
       int combinedHash = h1 + (i * h2);
@@ -124,13 +161,11 @@ public boolean put(Object item) {
     return bitsChanged;
   }
 
-  @Override
-  public boolean mightContain(Object item) {
-    long bitSize = bits.bitSize();
-    long hash64 = hashObjectToLong(item);
-    int h1 = (int) (hash64 >> 32);
-    int h2 = (int) hash64;
+  private boolean mightContainLong(long l) {
+    int h1 = Murmur3_x86_32.hashLong(l, 0);
+    int h2 = Murmur3_x86_32.hashLong(l, h1);
 
+    long bitSize = bits.bitSize();
     for (int i = 1; i <= numHashFunctions; i++) {
       int combinedHash = h1 + (i * h2);
       // Flip all the bits if it's negative (guaranteed positive number)
@@ -144,6 +179,15 @@ public boolean mightContain(Object item) {
     return true;
   }
 
+  @Override
+  public boolean mightContain(Object item) {
+    if (item instanceof String) {
+      return mightContainBinary(getBytesFromUTF8String(item));
+    } else {
+      return mightContainLong(integralToLong(item));
+    }
+  }
+
   @Override
   public boolean isCompatible(BloomFilter other) {
     if (other == null) {
@@ -191,11 +235,11 @@ public void writeTo(OutputStream out) throws IOException {
 
     DataOutputStream dos = new DataOutputStream(out);
     dos.writeInt(Version.V1.getVersionNumber());
-    bits.writeTo(dos);
     dos.writeInt(numHashFunctions);
+    bits.writeTo(dos);
   }
 
-  public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+  private void readFrom0(InputStream in) throws IOException {
     DataInputStream dis = new DataInputStream(in);
 
     int version = dis.readInt();
@@ -203,6 +247,21 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException {
       throw new IOException("Unexpected Bloom filter version number (" + version + ")");
     }
 
-    return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+    this.numHashFunctions = dis.readInt();
+    this.bits = BitArray.readFrom(dis);
+  }
+
+  public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+    BloomFilterImpl filter =
new BloomFilterImpl(); + filter.readFrom0(in); + return filter; + } + + private void writeObject(ObjectOutputStream out) throws IOException { + writeTo(out); + } + + private void readObject(ObjectInputStream in) throws IOException { + readFrom0(in); } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 31b364f351d56..2b45ad9cd2eca 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -42,6 +42,11 @@ 1.5.6 jar + + org.apache.spark + spark-sketch_${scala.binary.version} + ${project.version} + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index e66aa5f947181..08cf7ac662e0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,7 +22,10 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ +import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.util.sketch.BloomFilter /** * :: Experimental :: @@ -309,4 +312,77 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. 
+   *
+   * @since 2.0.0
+   */
+  def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
+    buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits))
+  }
+
+  private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = {
+    val singleCol = df.select(col)
+    val colType = singleCol.schema.head.dataType
+
+    require(colType == StringType || colType.isInstanceOf[IntegralType],
+      s"Bloom filter only supports string type and integral types, but got $colType.")
+
+    val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
+      (filter, row) =>
+        filter.putBinary(row.getUTF8String(0).getBytes)
+        filter
+    } else {
+      (filter, row) =>
+        // TODO: specialize it.
+        filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
+        filter
+    }
+
+    singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
+  }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ac1607ba3521a..c4e4edea3bc08 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -39,6 +39,7 @@
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.types.*;
 import static org.apache.spark.sql.types.DataTypes.*;
+import org.apache.spark.util.sketch.BloomFilter;
 
 public class JavaDataFrameSuite {
   private transient JavaSparkContext jsc;
@@ -299,6 +300,7 @@ public void pivot() {
     Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
   }
 
+  @Test
   public void testGenericLoad() {
     DataFrame df1 = context.read().format("text").load(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
@@ -321,4 +323,33 @@ public void testTextLoad() {
       Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
     Assert.assertEquals(5L, df2.count());
   }
+
+  @Test
+  public void testBloomFilter() {
+    DataFrame df = context.range(1000);
+
+    BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
+    assert(filter1.expectedFpp() - 0.03 < 1e-3);
+    for (int i = 0; i < 1000; i++) {
+      assert(filter1.mightContain(i));
+    }
+
+    BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03);
+    assert(filter2.expectedFpp() - 0.03 < 1e-3);
+    for (int i = 0; i < 1000; i++) {
+      assert(filter2.mightContain(i * 3));
+    }
+
+    BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
+    assert(filter3.bitSize() == 64 * 5);
+    for (int i = 0; i < 1000; i++) {
+      assert(filter3.mightContain(i));
+    }
+
+    BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
+    assert(filter4.bitSize() == 64 * 5);
+    for (int i = 0; i < 1000; i++) {
+      assert(filter4.mightContain(i * 3));
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 63ad6c439a870..bb60252419074 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -210,4 +210,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
       sampled.groupBy("key").count().orderBy("key"),
       Seq(Row(0, 6), Row(1, 11)))
   }
+
+  // This test only verifies some basic requirements; more correctness tests can be found in
+  // `BloomFilterSuite` in the spark-sketch project.
+ test("Bloom filter") { + val df = sqlContext.range(1000) + + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(filter1.mightContain)) + + val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03) + assert(filter2.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(i => filter2.mightContain(i * 3))) + + val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter3.bitSize() == 64 * 5) + assert(0.until(1000).forall(filter3.mightContain)) + + val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5) + assert(filter4.bitSize() == 64 * 5) + assert(0.until(1000).forall(i => filter4.mightContain(i * 3))) + } } From bd0671ca177457d5f433c2744ed82c7e4960c965 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 27 Jan 2016 10:57:12 -0800 Subject: [PATCH 2/2] address comments --- .../apache/spark/util/sketch/BloomFilter.java | 28 +++++++++- .../spark/util/sketch/BloomFilterImpl.java | 56 ++++++++----------- .../spark/util/sketch/CountMinSketchImpl.java | 47 ++-------------- .../org/apache/spark/util/sketch/Utils.java | 48 ++++++++++++++++ .../spark/sql/DataFrameStatFunctions.scala | 2 + 5 files changed, 101 insertions(+), 80 deletions(-) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 1ff743461791f..81772fcea0ec2 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -98,15 +98,20 @@ int getVersionNumber() { public abstract boolean put(Object item); /** - * A specific version of {@link #put(Object)}, that can only be used to put byte array. + * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string. */ - public abstract boolean putBinary(byte[] bytes); + public abstract boolean putString(String str); /** - * A specific version of {@link #put(Object)}, that can only be used to put long. + * A specialized variant of {@link #put(Object)}, that can only be used to put long. */ public abstract boolean putLong(long l); + /** + * A specialized variant of {@link #put(Object)}, that can only be used to put byte array. + */ + public abstract boolean putBinary(byte[] bytes); + /** * Determines whether a given bloom filter is compatible with this bloom filter. For two * bloom filters to be compatible, they must have the same bit size. @@ -131,6 +136,23 @@ int getVersionNumber() { */ public abstract boolean mightContain(Object item); + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8 + * string. + */ + public abstract boolean mightContainString(String str); + + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long. + */ + public abstract boolean mightContainLong(long l); + + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte + * array. + */ + public abstract boolean mightContainBinary(byte[] bytes); + /** * Writes out this {@link BloomFilter} to an output stream in binary format. * It is the caller's responsibility to close the stream. 
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 588b4ffab2e6b..35107e0b389d7 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -65,43 +65,22 @@ public long bitSize() { return bits.bitSize(); } - private byte[] getBytesFromUTF8String(Object s) { - try { - return ((String) s).getBytes("utf-8"); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException("Only support utf-8 string", e); - } - } - - private long integralToLong(Object i) { - long longValue; - - if (i instanceof Long) { - longValue = (Long) i; - } else if (i instanceof Integer) { - longValue = ((Integer) i).longValue(); - } else if (i instanceof Short) { - longValue = ((Short) i).longValue(); - } else if (i instanceof Byte) { - longValue = ((Byte) i).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + i.getClass().getName() + " not implemented" - ); - } - - return longValue; - } - @Override public boolean put(Object item) { if (item instanceof String) { - return putBinary(getBytesFromUTF8String(item)); + return putString((String) item); + } else if (item instanceof byte[]) { + return putBinary((byte[]) item); } else { - return putLong(integralToLong(item)); + return putLong(Utils.integralToLong(item)); } } + @Override + public boolean putString(String str) { + return putBinary(Utils.getBytesFromUTF8String(str)); + } + @Override public boolean putBinary(byte[] bytes) { int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); @@ -120,7 +99,13 @@ public boolean putBinary(byte[] bytes) { return bitsChanged; } - private boolean mightContainBinary(byte[] bytes) { + @Override + public boolean mightContainString(String str) { + return mightContainBinary(Utils.getBytesFromUTF8String(str)); + } + + @Override + public boolean mightContainBinary(byte[] bytes) { int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); @@ -161,7 +146,8 @@ public boolean putLong(long l) { return bitsChanged; } - private boolean mightContainLong(long l) { + @Override + public boolean mightContainLong(long l) { int h1 = Murmur3_x86_32.hashLong(l, 0); int h2 = Murmur3_x86_32.hashLong(l, h1); @@ -182,9 +168,11 @@ private boolean mightContainLong(long l) { @Override public boolean mightContain(Object item) { if (item instanceof String) { - return mightContainBinary(getBytesFromUTF8String(item)); + return mightContainString((String) item); + } else if (item instanceof byte[]) { + return mightContainBinary((byte[]) item); } else { - return mightContainLong(integralToLong(item)); + return mightContainLong(Utils.integralToLong(item)); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 8cc29e4076307..e49ae22906c4c 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { private double eps; private double confidence; - private CountMinSketchImpl() { - } + private 
CountMinSketchImpl() {} CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; @@ -143,23 +142,7 @@ public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - addLong(longValue, count); + addLong(Utils.integralToLong(item), count); } } @@ -201,13 +184,7 @@ private int hash(long item, int count) { } private static int[] getHashBuckets(String key, int hashCount, int max) { - byte[] b; - try { - b = key.getBytes("UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException(e); - } - return getHashBuckets(b, hashCount, max); + return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); } private static int[] getHashBuckets(byte[] b, int hashCount, int max) { @@ -225,23 +202,7 @@ public long estimateCount(Object item) { if (item instanceof String) { return estimateCountForStringItem((String) item); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - return estimateCountForLongItem(longValue); + return estimateCountForLongItem(Utils.integralToLong(item)); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java new file mode 100644 index 0000000000000..a6b33313035b0 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.UnsupportedEncodingException;
+
+public class Utils {
+  public static byte[] getBytesFromUTF8String(String str) {
+    try {
+      return str.getBytes("utf-8");
+    } catch (UnsupportedEncodingException e) {
+      throw new IllegalArgumentException("Only support utf-8 string", e);
+    }
+  }
+
+  public static long integralToLong(Object i) {
+    long longValue;
+
+    if (i instanceof Long) {
+      longValue = (Long) i;
+    } else if (i instanceof Integer) {
+      longValue = ((Integer) i).longValue();
+    } else if (i instanceof Short) {
+      longValue = ((Short) i).longValue();
+    } else if (i instanceof Byte) {
+      longValue = ((Byte) i).longValue();
+    } else {
+      throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName());
+    }
+
+    return longValue;
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index f22ab06c1dd26..b0b6995a2214f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -449,6 +449,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
     val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
       (filter, row) =>
+        // For string type, we can get the bytes of our `UTF8String` directly and call `putBinary`
+        // instead of `putString` to avoid unnecessary conversion.
        filter.putBinary(row.getUTF8String(0).getBytes)
        filter
     } else {
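
Taken together, the two patches yield the following end-to-end usage, mirrored from the tests added in patch 1 (a sketch; `sqlContext` is assumed to be an existing SQLContext):

    val df = sqlContext.range(1000)
    // Size the filter for ~1000 distinct items and a ~3% false positive rate.
    val filter = df.stat.bloomFilter("id", expectedNumItems = 1000, fpp = 0.03)
    assert(filter.mightContain(500L))   // inserted values are always found
    // mightContain may also return true for values never inserted, with probability ~fpp
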