[SPARK-18429] [SQL] implement a new Aggregate for CountMinSketch #15877
Changes from all commits
a9d5e03
15c7ca5
a283e5f
dedcfca
3e86075
ca4a13f
b009ff8
1bfb6fd
@@ -0,0 +1,146 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
/**
 * This function returns a count-min sketch of a column with the given eps, confidence and seed.
 * A count-min sketch is a probabilistic data structure used for summarizing streams of data in
 * sub-linear space, which is useful for equality predicates and join size estimation.

[Review comment] Maybe something on the return type? A developer should know how to work with these bytes.

[Reply] OK, I wrote this in usage; I'll add it here too, thanks.

 * The result returned by the function is an array of bytes, which should be deserialized to a
 * `CountMinSketch` before usage.
 *
 * @param child child expression that can produce a column value with `child.eval(inputRow)`
 * @param epsExpression relative error, must be positive
 * @param confidenceExpression confidence, must be positive and less than 1.0
 * @param seedExpression random seed
 */
@ExpressionDescription(
  usage = """
    _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given
      eps, confidence and seed. The result is an array of bytes, which should be deserialized
      to a `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates
      and join size estimation.
  """)
case class CountMinSketchAgg(
    child: Expression,
    epsExpression: Expression,
    confidenceExpression: Expression,
    seedExpression: Expression,
    override val mutableAggBufferOffset: Int,
    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {

  def this(
      child: Expression,
      epsExpression: Expression,
      confidenceExpression: Expression,
      seedExpression: Expression) = {
    this(child, epsExpression, confidenceExpression, seedExpression, 0, 0)
  }

  // Mark as lazy so that they are not evaluated during tree transformation.
  private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
  private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
  private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]
  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
        !seedExpression.foldable) {
      TypeCheckFailure(
        "The eps, confidence or seed provided must be a literal or constant foldable")
    } else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
        seedExpression.eval() == null) {
      TypeCheckFailure("The eps, confidence or seed provided should not be null")
    } else if (eps <= 0D) {
      TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
    } else if (confidence <= 0D || confidence >= 1D) {
      TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
    } else {
      TypeCheckSuccess
    }
  }

[Review comment] Should we also check for negative eps and confidence values?
  override def createAggregationBuffer(): CountMinSketch = {
    CountMinSketch.create(eps, confidence, seed)
  }

  override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
    val value = child.eval(input)
    // Ignore empty rows
    if (value != null) {
      child.dataType match {
[Review comment] Let's not do a pattern match for every update. We should use an update function instead, for example:

    private[this] val doUpdate: (CountMinSketch, Any) => Unit = child.dataType match {
      case StringType => (cms, value) => cms.addBinary(value.asInstanceOf[UTF8String].getBytes)
      case ByteType => (cms, value) => cms.addLong(value.asInstanceOf[Byte])
      ...
    }

    override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
      val value = child.eval(input)
      if (value != null) {
        doUpdate(buffer, value)
      }
    }
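A hypothetical completion of the `...` above, mirroring the type collection this aggregate accepts in `inputTypes` and the internal representations Spark SQL uses for each type; the case arms beyond those the reviewer wrote are illustrative, not part of the PR:

    // Illustrative sketch only: one way the suggested doUpdate table could be
    // completed, using only CountMinSketch methods already called in this file.
    private[this] val doUpdate: (CountMinSketch, Any) => Unit = child.dataType match {
      case StringType => (cms, v) => cms.addBinary(v.asInstanceOf[UTF8String].getBytes)
      case ByteType => (cms, v) => cms.addLong(v.asInstanceOf[Byte])
      case ShortType => (cms, v) => cms.addLong(v.asInstanceOf[Short])
      case IntegerType | DateType => (cms, v) => cms.addLong(v.asInstanceOf[Int])
      case LongType | TimestampType => (cms, v) => cms.addLong(v.asInstanceOf[Long])
      case DecimalType() => (cms, v) => cms.add(v.asInstanceOf[Decimal].toJavaBigDecimal)
      // Float, Double, Boolean and Binary fall back to the generic add(Object) path,
      // matching the `case _` branch in the committed code below.
      case _ => (cms, v) => cms.add(v)
    }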
        // `Decimal` and `UTF8String` are internal types in Spark SQL; we need to convert them
        // into types acceptable to `CountMinSketch`.
        case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
        // For string type, we can get the bytes of our `UTF8String` directly and call
        // `addBinary` instead of `addString` to avoid an unnecessary conversion.
        case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
        case _ => buffer.add(value)
      }
    }
  }
  override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit = {
    buffer.mergeInPlace(input)
  }

  override def eval(buffer: CountMinSketch): Any = serialize(buffer)

  override def serialize(buffer: CountMinSketch): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    buffer.writeTo(out)
    out.toByteArray
  }

  override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
    val in = new ByteArrayInputStream(storageFormat)
    CountMinSketch.readFrom(in)
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CountMinSketchAgg =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def inputTypes: Seq[AbstractDataType] = {
    Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
      DoubleType, DoubleType, IntegerType)
  }

  override def nullable: Boolean = false

  override def dataType: DataType = BinaryType

  override def children: Seq[Expression] =
    Seq(child, epsExpression, confidenceExpression, seedExpression)

  override def prettyName: String = "count_min_sketch"
}
[Review comment] Here I use string to represent decimal because there is a one-to-one mapping between BigDecimal and String.

[Review comment] Is this true? "1.0" and "1.00" are the same value but not the same string representation.

[Reply] Sorry, I didn't describe it accurately. It should be "There is a one-to-one mapping between the distinguishable values and the result of this conversion." (from the javadoc of BigDecimal)
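A quick illustration of the distinction being made, using `java.math.BigDecimal` directly (a plain Scala REPL sketch, not part of the diff):

    import java.math.BigDecimal

    val a = new BigDecimal("1.0")
    val b = new BigDecimal("1.00")
    a.compareTo(b) == 0  // true: numerically equal
    a.equals(b)          // false: equals() also compares scale (1 vs 2)
    a.toString           // "1.0"  -- toString is injective over distinguishable values
    b.toString           // "1.00"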