@@ -199,9 +199,9 @@ private static int optimalNumOfHashFunctions(long n, long m) {
* See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula.
*
* @param n expected insertions (must be positive)
* @param p false positive rate (must be 0 < p < 1)
* @param p false positive rate (must be 0 &lt; p &lt; 1)
*/
private static long optimalNumOfBits(long n, double p) {
public static long optimalNumOfBits(long n, double p) {
Contributor Author: The change to public is because DataFrameStatFunctions#buildBloomFilter needs this method to calculate numBits from expectedNumItems and fpp.

Contributor Author: If you find (must be 0 &lt; p &lt; 1) too messy, we can change it to (must be {@literal 0 < p < 1}).

Contributor: I am good.

return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
}
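The formula referenced above is m = -n * ln(p) / (ln 2)^2. For illustration only, here is a minimal Scala sketch of the same calculation (the object and method names are ours, not Spark's); for n = 1000 and p = 0.03 it yields roughly 7298 bits:

// Illustrative sketch of the optimal-bits formula m = -n * ln(p) / (ln 2)^2,
// mirroring BloomFilter.optimalNumOfBits above; names are illustrative.
object OptimalBitsSketch {
  def optimalNumOfBits(n: Long, p: Double): Long = {
    require(n > 0, "Expected insertions must be positive")
    require(p > 0d && p < 1d, "False positive probability must be within range (0.0, 1.0)")
    (-n * math.log(p) / (math.log(2) * math.log(2))).toLong
  }

  def main(args: Array[String]): Unit = {
    // 1000 expected insertions at a 3% false positive rate -> ~7298 bits (~912 bytes).
    println(optimalNumOfBits(1000L, 0.03))
  }
}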

@@ -18,14 +18,16 @@
package org.apache.spark.sql

import java.{lang => jl, util => ju}
import java.io.ByteArrayInputStream

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
import org.apache.spark.sql.functions.lit
import org.apache.spark.util.sketch.CountMinSketch
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

/**
* Statistic functions for `DataFrame`s.
@@ -584,6 +586,90 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
}
CountMinSketch.readFrom(ds.head())
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName
* name of the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param fpp
* expected false positive probability of the filter.
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col
* the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param fpp
* expected false positive probability of the filter.
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(col, expectedNumItems, -1L, fpp)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName
* name of the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param numBits
* expected number of bits of the filter.
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col
* the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param numBits
* expected number of bits of the filter.
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
}

private def buildBloomFilter(
col: Column,
expectedNumItems: Long,
numBits: Long,
fpp: Double): BloomFilter = {
def numBitsValue: Long = if (!fpp.isNaN) {
BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
} else {
numBits
}

if (fpp <= 0d || fpp >= 1d) {
@LuciferYang (Contributor Author), Aug 10, 2023: fpp is not involved in the subsequent processing, so a check is added here. Otherwise, if the user passes an invalid fpp value, the error message would be "Number of bits must be positive", which is confusing.

throw new SparkException("False positive probability must be within range (0.0, 1.0)")
}
val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBitsValue))

val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
builder.getProjectBuilder
.setInput(root)
.addExpressions(agg.expr)
}
BloomFilter.readFrom(new ByteArrayInputStream(ds.head()))
}
}

private object DataFrameStatFunctions {
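For reference, a usage sketch of the new client API, assuming an existing Spark Connect SparkSession named spark (the column name and numbers are illustrative and mirror the test suite below):

// Illustrative only: exercising both overload families of the new bloomFilter API.
import org.apache.spark.util.sketch.BloomFilter

val df = spark.range(1000).toDF("id")

// fpp-based overloads derive the bit size via BloomFilter.optimalNumOfBits(1000, 0.03).
val byFpp: BloomFilter = df.stat.bloomFilter("id", 1000L, 0.03)
assert(byFpp.mightContain(1L))

// numBits-based overloads take the bit size as given.
val byBits: BloomFilter = df.stat.bloomFilter("id", 1000L, 64L * 5)
assert(byBits.bitSize() == 64 * 5)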
Contributor: Maybe add a negative test case where mightContain evaluates to false?

Contributor Author: 6ffbfa0: Added checks for values that are definitely not included.

@@ -176,4 +176,91 @@ class ClientDataFrameStatSuite extends RemoteSparkSession {
assert(sketch.relativeError() === 0.001)
assert(sketch.confidence() === 0.99 +- 5e-3)
}

test("Bloom filter -- Long Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toLong)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767).map(_.toLong)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- Int Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- Short Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toShort)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767).map(_.toShort)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- Byte Column") {
val session = spark
import session.implicits._
val data = Seq(-32, -5, 1, 17, 39, 43, 101, 127).map(_.toByte)
val df = data.toDF("id")
val negativeValues = Seq(-101, 55, 113).map(_.toByte)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- String Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toString)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767).map(_.toString)
checkBloomFilter(data, negativeValues, df)
}

private def checkBloomFilter(
data: Seq[Any],
notContainValues: Seq[Any],
df: DataFrame): Unit = {
val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
assert(filter1.expectedFpp() - 0.03 < 1e-3)
assert(data.forall(filter1.mightContain))
assert(notContainValues.forall(n => !filter1.mightContain(n)))
Contributor Author: Added checks for values that are definitely not included.

val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5)
assert(filter2.bitSize() == 64 * 5)
assert(data.forall(filter2.mightContain))
assert(notContainValues.forall(n => !filter2.mightContain(n)))
}

test("Bloom filter -- Wrong dataType Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000).map(_.toDouble)
val message = intercept[AnalysisException] {
data.toDF("id").stat.bloomFilter("id", 1000, 0.03)
}.getMessage
assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE"))
}

test("Bloom filter test invalid inputs") {
val df = spark.range(1000).toDF("id")
val message1 = intercept[SparkException] {
df.stat.bloomFilter("id", -1000, 100)
}.getMessage
assert(message1.contains("Expected insertions must be positive"))

val message2 = intercept[SparkException] {
df.stat.bloomFilter("id", 1000, -100)
}.getMessage
assert(message2.contains("Number of bits must be positive"))

val message3 = intercept[SparkException] {
df.stat.bloomFilter("id", 1000, -1.0)
}.getMessage
assert(message3.contains("False positive probability must be within range (0.0, 1.0)"))
}
}
@@ -163,9 +163,6 @@ object CheckConnectJvmClientCompatibility {
// DataFrameNaFunctions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"),

// DataFrameStatFunctions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),

// Dataset
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.Dataset$" // private[sql]
@@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
@@ -1731,6 +1732,36 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
Some(Lead(children.head, children(1), children(2), ignoreNulls))

case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
// [col, expectedNumItems: Long, numBits: Long]
val children = fun.getArgumentsList.asScala.map(transformExpression)

// Check expectedNumItems is LongType and value greater than 0L
val expectedNumItemsExpr = children(1)
val expectedNumItems = expectedNumItemsExpr match {
Contributor Author: After changing to Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)), the logic indeed looks simpler, and I have a point for discussion.

@hvanhovell Do you think we should check the validity of the input here? By checking here, the error messages can be exactly the same as the API in sql/core. However, if we rely on the validation mechanism of BloomFilterAggregate, the error messages will differ, but the code will be more concise.

Perhaps we don't need to ensure that the error messages are the same as before?

Contributor: We can do that in a follow-up.

case Literal(l: Long, LongType) => l
case _ =>
throw InvalidPlanInput("Expected insertions must be long literal.")
}
if (expectedNumItems <= 0L) {
throw InvalidPlanInput("Expected insertions must be positive.")
}

val numBitsExpr = children(2)
// Check numBits is LongType and value greater than 0L
numBitsExpr match {
case Literal(numBits: Long, LongType) =>
if (numBits <= 0L) {
throw InvalidPlanInput("Number of bits must be positive.")
}
case _ =>
throw InvalidPlanInput("Number of bits must be long literal.")
}

Some(
new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
.toAggregateExpression())

case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val timeCol = children.head
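As a standalone restatement of the literal-extraction pattern used in the bloom_filter_agg case above (the helper name is ours, and InvalidPlanInput is swapped for IllegalArgumentException so the sketch compiles outside the planner):

// Minimal sketch: validating that an argument is a positive long literal, as done
// above for expectedNumItems and numBits. Only Catalyst's Literal/LongType are assumed.
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.LongType

object LiteralChecks {
  def requirePositiveLongLiteral(e: Expression, name: String): Long = e match {
    case Literal(value: Long, LongType) if value > 0L => value
    case Literal(_, LongType) =>
      throw new IllegalArgumentException(s"$name must be positive.")
    case _ =>
      throw new IllegalArgumentException(s"$name must be a long literal.")
  }
}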
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.BloomFilter

/**
@@ -78,7 +79,7 @@ case class BloomFilterAggregate(
"exprName" -> "estimatedNumItems or numBits"
)
)
case (LongType, LongType, LongType) =>
case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) =>
if (!estimatedNumItemsExpression.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
@@ -150,6 +151,15 @@
Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))

// Mark as lazy so that `updater` is not evaluated during tree transformation.
private lazy val updater: BloomFilterUpdater = child.dataType match {
case LongType => LongUpdater
case IntegerType => IntUpdater
case ShortType => ShortUpdater
case ByteType => ByteUpdater
case StringType => BinaryUpdater
}

override def first: Expression = child

override def second: Expression = estimatedNumItemsExpression
@@ -174,7 +184,7 @@ case class BloomFilterAggregate(
if (value == null) {
return buffer
}
buffer.putLong(value.asInstanceOf[Long])
updater.update(buffer, value)
buffer
}

@@ -224,3 +234,32 @@ object BloomFilterAggregate {
bloomFilter
}
}

private trait BloomFilterUpdater {
def update(bf: BloomFilter, v: Any): Boolean
}

private object LongUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Long])
}

private object IntUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Int])
}

private object ShortUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Short])
}

private object ByteUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Byte])
}

private object BinaryUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putBinary(v.asInstanceOf[UTF8String].getBytes)
}