[SPARK-42664][CONNECT] Support bloomFilter function for DataFrameStatFunctions (#42414)
Changes from all commits: beaaae6, a154c51, dfbe1c4, d600ebb, 4709dd5, fe958a6, 6ffbfa0, cf3104a, 80a6b4b, 1b88765, 473ad60
DataFrameStatFunctions.scala (Spark Connect Scala client)
@@ -18,14 +18,16 @@
package org.apache.spark.sql

import java.{lang => jl, util => ju}
import java.io.ByteArrayInputStream

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
import org.apache.spark.sql.functions.lit
-import org.apache.spark.util.sketch.CountMinSketch
+import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

/**
 * Statistic functions for `DataFrame`s.

@@ -584,6 +586,90 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
    }
    CountMinSketch.readFrom(ds.head())
  }

  /**
   * Builds a Bloom filter over a specified column.
   *
   * @param colName
   *   name of the column over which the filter is built
   * @param expectedNumItems
   *   expected number of items which will be put into the filter.
   * @param fpp
   *   expected false positive probability of the filter.
   * @since 3.5.0
   */
  def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
    buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
  }

  /**
   * Builds a Bloom filter over a specified column.
   *
   * @param col
   *   the column over which the filter is built
   * @param expectedNumItems
   *   expected number of items which will be put into the filter.
   * @param fpp
   *   expected false positive probability of the filter.
   * @since 3.5.0
   */
  def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
    buildBloomFilter(col, expectedNumItems, -1L, fpp)
  }

  /**
   * Builds a Bloom filter over a specified column.
   *
   * @param colName
   *   name of the column over which the filter is built
   * @param expectedNumItems
   *   expected number of items which will be put into the filter.
   * @param numBits
   *   expected number of bits of the filter.
   * @since 3.5.0
   */
  def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
    buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
  }

  /**
   * Builds a Bloom filter over a specified column.
   *
   * @param col
   *   the column over which the filter is built
   * @param expectedNumItems
   *   expected number of items which will be put into the filter.
   * @param numBits
   *   expected number of bits of the filter.
   * @since 3.5.0
   */
  def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
    buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
  }

  private def buildBloomFilter(
      col: Column,
      expectedNumItems: Long,
      numBits: Long,
      fpp: Double): BloomFilter = {
    def numBitsValue: Long = if (!fpp.isNaN) {
      BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
    } else {
      numBits
    }

    if (fpp <= 0d || fpp >= 1d) {
Contributor (author): In the subsequent process, […]
      throw new SparkException("False positive probability must be within range (0.0, 1.0)")
    }
    val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBitsValue))

    val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
      builder.getProjectBuilder
        .setInput(root)
        .addExpressions(agg.expr)
    }
    BloomFilter.readFrom(new ByteArrayInputStream(ds.head()))
  }
}

private object DataFrameStatFunctions {
Contributor: Maybe add a negative test case where mightContain evaluates to false?
Contributor (author): 6ffbfa0: Added checks for values that are definitely not included.
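For context, a minimal usage sketch of the new API (a hedged illustration, not taken from the PR; it assumes a Spark Connect session in scope as spark):

import org.apache.spark.util.sketch.BloomFilter

val df = spark.range(0L, 1000L).toDF("id")

// Size the filter by expected item count and target false positive probability.
val byFpp: BloomFilter = df.stat.bloomFilter("id", expectedNumItems = 1000L, fpp = 0.03)
assert(byFpp.mightContain(42L)) // inserted values are always reported as present

// Alternatively, size the filter by an explicit number of bits.
val byBits: BloomFilter = df.stat.bloomFilter("id", expectedNumItems = 1000L, numBits = 8192L)
assert(byBits.bitSize() == 8192L)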
ClientDataFrameStatSuite.scala
@@ -176,4 +176,91 @@ class ClientDataFrameStatSuite extends RemoteSparkSession {
    assert(sketch.relativeError() === 0.001)
    assert(sketch.confidence() === 0.99 +- 5e-3)
  }

  test("Bloom filter -- Long Column") {
    val session = spark
    import session.implicits._
    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toLong)
    val df = data.toDF("id")
    val negativeValues = Seq(-11, 1021, 32767).map(_.toLong)
    checkBloomFilter(data, negativeValues, df)
  }

  test("Bloom filter -- Int Column") {
    val session = spark
    import session.implicits._
    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997)
    val df = data.toDF("id")
    val negativeValues = Seq(-11, 1021, 32767)
    checkBloomFilter(data, negativeValues, df)
  }

  test("Bloom filter -- Short Column") {
    val session = spark
    import session.implicits._
    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toShort)
    val df = data.toDF("id")
    val negativeValues = Seq(-11, 1021, 32767).map(_.toShort)
    checkBloomFilter(data, negativeValues, df)
  }

  test("Bloom filter -- Byte Column") {
    val session = spark
    import session.implicits._
    val data = Seq(-32, -5, 1, 17, 39, 43, 101, 127).map(_.toByte)
    val df = data.toDF("id")
    val negativeValues = Seq(-101, 55, 113).map(_.toByte)
    checkBloomFilter(data, negativeValues, df)
  }

  test("Bloom filter -- String Column") {
    val session = spark
    import session.implicits._
    val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toString)
    val df = data.toDF("id")
    val negativeValues = Seq(-11, 1021, 32767).map(_.toString)
    checkBloomFilter(data, negativeValues, df)
  }

  private def checkBloomFilter(
      data: Seq[Any],
      notContainValues: Seq[Any],
      df: DataFrame): Unit = {
    val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
    assert(filter1.expectedFpp() - 0.03 < 1e-3)
    assert(data.forall(filter1.mightContain))
    assert(notContainValues.forall(n => !filter1.mightContain(n)))
    val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5)
    assert(filter2.bitSize() == 64 * 5)
    assert(data.forall(filter2.mightContain))
    assert(notContainValues.forall(n => !filter2.mightContain(n)))
  }

  test("Bloom filter -- Wrong dataType Column") {
    val session = spark
    import session.implicits._
    val data = Range(0, 1000).map(_.toDouble)
    val message = intercept[AnalysisException] {
      data.toDF("id").stat.bloomFilter("id", 1000, 0.03)
    }.getMessage
    assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE"))
  }

  test("Bloom filter test invalid inputs") {
    val df = spark.range(1000).toDF("id")
    val message1 = intercept[SparkException] {
      df.stat.bloomFilter("id", -1000, 100)
    }.getMessage
    assert(message1.contains("Expected insertions must be positive"))

    val message2 = intercept[SparkException] {
      df.stat.bloomFilter("id", 1000, -100)
    }.getMessage
    assert(message2.contains("Number of bits must be positive"))

    val message3 = intercept[SparkException] {
      df.stat.bloomFilter("id", 1000, -1.0)
    }.getMessage
    assert(message3.contains("False positive probability must be within range (0.0, 1.0)"))
  }
}
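Aside: mightContain can return a false positive at roughly the configured fpp rate but never a false negative, which is why the negative checks above use values well outside the inserted set. Below is a local sketch of the same serialize/deserialize round trip that buildBloomFilter performs on ds.head(), using only existing org.apache.spark.util.sketch.BloomFilter methods (the data is illustrative):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.util.sketch.BloomFilter

// Build and populate a filter locally.
val original = BloomFilter.create(1000L, 0.03)
Seq(-143L, -32L, 997L).foreach(original.putLong)

// Serialize to bytes, mirroring the binary result column of bloom_filter_agg.
val out = new ByteArrayOutputStream()
original.writeTo(out)

// Deserialize, as buildBloomFilter does with the query result.
val restored = BloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray))
assert(restored.mightContainLong(997L))        // no false negatives
assert(!restored.mightContainLong(123456789L)) // holds with probability ~(1 - fpp)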
SparkConnectPlanner.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical

@@ -1731,6 +1732,36 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
      val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
      Some(Lead(children.head, children(1), children(2), ignoreNulls))

    case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
      // [col, expectedNumItems: Long, numBits: Long]
      val children = fun.getArgumentsList.asScala.map(transformExpression)

      // Check expectedNumItems is LongType and value greater than 0L
      val expectedNumItemsExpr = children(1)
      val expectedNumItems = expectedNumItemsExpr match {
Contributor (author): @hvanhovell Do you think we should check the validity of the input here? By checking here, the error message can be exactly the same as the API in […]. Perhaps we don't need to ensure that the error message is the same as before?
Contributor: We can do that in a follow-up.
        case Literal(l: Long, LongType) => l
        case _ =>
          throw InvalidPlanInput("Expected insertions must be long literal.")
      }
      if (expectedNumItems <= 0L) {
        throw InvalidPlanInput("Expected insertions must be positive.")
      }

      val numBitsExpr = children(2)
      // Check numBits is LongType and value greater than 0L
      numBitsExpr match {
        case Literal(numBits: Long, LongType) =>
          if (numBits <= 0L) {
            throw InvalidPlanInput("Number of bits must be positive.")
          }
        case _ =>
          throw InvalidPlanInput("Number of bits must be long literal.")
      }

      Some(
        new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
          .toAggregateExpression())
| case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) => | ||
| val children = fun.getArgumentsList.asScala.map(transformExpression) | ||
| val timeCol = children.head | ||
|
|
||
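The two literal checks above share the same shape, and per the review thread they may be reworked in a follow-up. A hypothetical extraction of the common pattern (requirePositiveLongLiteral is not part of the PR; it is only a sketch of how the checks could be deduplicated inside SparkConnectPlanner, where InvalidPlanInput is already in scope):

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.LongType

// Hypothetical helper: require an argument to be a positive long literal.
private def requirePositiveLongLiteral(expr: Expression, name: String): Long = expr match {
  case Literal(value: Long, LongType) if value > 0L => value
  case Literal(_, LongType) =>
    throw InvalidPlanInput(s"$name must be positive.")
  case _ =>
    throw InvalidPlanInput(s"$name must be long literal.")
}

// Usage: requirePositiveLongLiteral(children(1), "Expected insertions")
//        requirePositiveLongLiteral(children(2), "Number of bits")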
Contributor (author): The change to public is because DataFrameStatFunctions#buildBloomFilter needs to use this method to calculate the numBits from expectedNumItems and fpp.
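For reference, the sizing formula such a method computes is the standard Bloom filter bound m = -n ln(p) / (ln 2)^2. A standalone sketch (not the library's actual source; the real implementation may truncate rather than round up):

// Optimal bit count m for n expected items at false positive probability p:
//   m = -n * ln(p) / (ln 2)^2
def optimalNumOfBitsSketch(n: Long, p: Double): Long =
  math.ceil(-n * math.log(p) / (math.log(2) * math.log(2))).toLong

// Example: n = 1000, p = 0.03 gives m ≈ 7299 bits (~0.9 KiB).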
Contributor: If you find (must be 0 < p < 1) to be quite messy, we can try changing it to (must be {@literal 0 < p < 1}).
Contributor: I am good.