Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-18429] [SQL] implement a new Aggregate for CountMinSketch #15877

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
* <li>{@link Integer}</li>
* <li>{@link Long}</li>
* <li>{@link String}</li>
* <li>{@link Float}</li>
* <li>{@link Double}</li>
* <li>{@link java.math.BigDecimal}</li>
* <li>{@link Boolean}</li>
* </ul>
* A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
* <ol>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Random;

Expand Down Expand Up @@ -152,6 +153,16 @@ public void add(Object item) {
public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
} else if (item instanceof BigDecimal) {
addString(((BigDecimal) item).toString(), count);
Copy link
Contributor Author

@wzhfy wzhfy Nov 23, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I use string to represent decimal because there is a one-to-one mapping between BigDecimal and String.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this true?

"1.0" and "1.00" are the same value but do not have the same string representation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't describe it accurately. It should be: "There is a one-to-one mapping between the distinguishable BigDecimal values and the result of this conversion." (from the Javadoc of BigDecimal.toString)

} else if (item instanceof byte[]) {
addBinary((byte[]) item, count);
} else if (item instanceof Float) {
addLong(Float.floatToIntBits((Float) item), count);
} else if (item instanceof Double) {
addLong(Double.doubleToLongBits((Double) item), count);
} else if (item instanceof Boolean) {
addLong(((Boolean) item) ? 1L : 0L, count);
} else {
addLong(Utils.integralToLong(item), count);
}
Expand Down Expand Up @@ -216,10 +227,6 @@ private int hash(long item, int count) {
return ((int) hash) % width;
}

private static int[] getHashBuckets(String key, int hashCount, int max) {
return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
}

private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
int[] result = new int[hashCount];
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
Expand All @@ -233,7 +240,18 @@ private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
@Override
public long estimateCount(Object item) {
if (item instanceof String) {
return estimateCountForStringItem((String) item);
return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item));
} else if (item instanceof BigDecimal) {
return estimateCountForBinaryItem(
Utils.getBytesFromUTF8String(((BigDecimal) item).toString()));
} else if (item instanceof byte[]) {
return estimateCountForBinaryItem((byte[]) item);
} else if (item instanceof Float) {
return estimateCountForLongItem(Float.floatToIntBits((Float) item));
} else if (item instanceof Double) {
return estimateCountForLongItem(Double.doubleToLongBits((Double) item));
} else if (item instanceof Boolean) {
return estimateCountForLongItem(((Boolean) item) ? 1L : 0L);
} else {
return estimateCountForLongItem(Utils.integralToLong(item));
}
Expand All @@ -247,7 +265,7 @@ private long estimateCountForLongItem(long item) {
return res;
}

private long estimateCountForStringItem(String item) {
private long estimateCountForBinaryItem(byte[] item) {
long res = Long.MAX_VALUE;
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.util.sketch

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import scala.reflect.ClassTag
import scala.util.Random
Expand All @@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
}

def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
// Maps an item to the value used as its key when counting exact frequencies.
def getProbeItem(item: T): Any = item match {
  // A byte array is keyed by the UTF-8 string decoded from its content
  case raw: Array[Byte] => new String(raw, StandardCharsets.UTF_8)
  // Any other item is used as its own key
  case other => other
}

test(s"accuracy - $typeName") {
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)
Expand All @@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
sampledItems.groupBy(identity).mapValues(_.length.toLong)
sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}

val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
Expand All @@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val probCorrect = {
val numErrors = allItems.map { item =>
val count = exactFreq.getOrElse(item, 0L)
val count = exactFreq.getOrElse(getProbeItem(item), 0L)
val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
if (ratio > epsOfTotalCount) 1 else 0
}.sum
Expand Down Expand Up @@ -135,6 +142,18 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }

testItemType[Float]("Float") { _.nextFloat() }

testItemType[Double]("Double") { _.nextDouble() }

testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) }

testItemType[Boolean]("Boolean") { _.nextBoolean() }

testItemType[Array[Byte]]("Binary") { r =>
Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20)))
}

test("incompatible merge") {
intercept[IncompatibleMergeException] {
CountMinSketch.create(10, 10, 1).mergeInPlace(null)
Expand Down
5 changes: 5 additions & 0 deletions sql/catalyst/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ object FunctionRegistry {
expression[VarianceSamp]("var_samp"),
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),

// string functions
expression[Ascii]("ascii"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch

/**
* This function returns a count-min sketch of a column with the given eps, confidence and seed.
* A count-min sketch is a probabilistic data structure used for summarizing streams of data in
* sub-linear space, which is useful for equality predicates and join size estimation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something on the return type? A developer should know how to work with these bytes.

Copy link
Contributor Author

@wzhfy wzhfy Nov 15, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I wrote this in usage, I'll add it here too, thanks.

* The result returned by the function is an array of bytes, which should be deserialized to a
* `CountMinSketch` before usage.
*
* @param child child expression that can produce column value with `child.eval(inputRow)`
* @param epsExpression relative error, must be positive
* @param confidenceExpression confidence, must be positive and less than 1.0
* @param seedExpression random seed
*/
@ExpressionDescription(
usage = """
_FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given esp,
confidence and seed. The result is an array of bytes, which should be deserialized to a
`CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join
size estimation.
""")
case class CountMinSketchAgg(
child: Expression,
epsExpression: Expression,
confidenceExpression: Expression,
seedExpression: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {

/**
 * Four-argument constructor that sets both aggregation buffer offsets to 0.
 *
 * @param child expression producing the values to add to the sketch
 * @param epsExpression expression for the relative error
 * @param confidenceExpression expression for the confidence
 * @param seedExpression expression for the random seed
 */
def this(
    child: Expression,
    epsExpression: Expression,
    confidenceExpression: Expression,
    seedExpression: Expression) =
  this(
    child,
    epsExpression,
    confidenceExpression,
    seedExpression,
    mutableAggBufferOffset = 0,
    inputAggBufferOffset = 0)

// Sketch parameters, obtained by evaluating the corresponding expressions without passing an
// input row and casting the results to Double / Double / Int respectively.
// Mark as lazy so that they are not evaluated during tree transformation.
private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]

override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else if (!epsExpression.foldable || !confidenceExpression.foldable ||
!seedExpression.foldable) {
TypeCheckFailure(
"The eps, confidence or seed provided must be a literal or constant foldable")
} else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also check for negative eps and confidence values?

seedExpression.eval() == null) {
TypeCheckFailure("The eps, confidence or seed provided should not be null")
} else if (eps <= 0D) {
TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
} else if (confidence <= 0D || confidence >= 1D) {
TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
} else {
TypeCheckSuccess
}
}

/** Creates a new sketch from the evaluated eps, confidence and seed values. */
override def createAggregationBuffer(): CountMinSketch =
  CountMinSketch.create(eps, confidence, seed)

override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
val value = child.eval(input)
// Ignore empty rows
if (value != null) {
child.dataType match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not do a pattern match for every update. We should use an update function instead, for example:

private[this] val doUpdate: (CountMinSketch, Any) => Unit = child.dataType match {
  case StringType => (cms, value) => cms.addBinary(value.asInstanceOf[UTF8String].getBytes)
  case ByteType => (cms, value) => cms.addLong(value.asInstanceOf[Byte])
  ...
}

override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
  val value = child.eval(input)
  if (value != null) {
    doUpdate(buffer, value)
  }
}

// `Decimal` and `UTF8String` are internal types in spark sql, we need to convert them
// into acceptable types for `CountMinSketch`.
case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
// For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
// instead of `addString` to avoid unnecessary conversion.
case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
case _ => buffer.add(value)
}
}
}

/** Merges `input` into `buffer` by calling `mergeInPlace` on `buffer`. */
override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit =
  buffer.mergeInPlace(input)

/** The final result of the aggregate is the serialized form of the sketch (see `serialize`). */
override def eval(buffer: CountMinSketch): Any = serialize(buffer)

/**
 * Writes the sketch to an in-memory stream and returns the bytes that were written.
 *
 * @param buffer the sketch to serialize
 * @return the bytes produced by `buffer.writeTo`
 */
override def serialize(buffer: CountMinSketch): Array[Byte] = {
  val byteStream = new ByteArrayOutputStream()
  buffer.writeTo(byteStream)
  byteStream.toByteArray
}

/**
 * Rebuilds a sketch by reading the given bytes through an in-memory stream.
 *
 * @param storageFormat bytes to pass to `CountMinSketch.readFrom`
 * @return the sketch read from those bytes
 */
override def deserialize(storageFormat: Array[Byte]): CountMinSketch =
  CountMinSketch.readFrom(new ByteArrayInputStream(storageFormat))

/** Returns a copy of this aggregate with `mutableAggBufferOffset` replaced by the given value. */
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

/** Returns a copy of this aggregate with `inputAggBufferOffset` replaced by the given value. */
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CountMinSketchAgg =
copy(inputAggBufferOffset = newInputAggBufferOffset)

/**
 * Expected input types, listed in the same order as `children`: the sketched column, then
 * eps, confidence and seed.
 */
override def inputTypes: Seq[AbstractDataType] = {
  // Types accepted for the column being sketched
  val columnTypes =
    TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType)
  // eps and confidence are doubles, the seed is an integer
  Seq(columnTypes, DoubleType, DoubleType, IntegerType)
}

/** The result of this aggregate is declared non-nullable. */
override def nullable: Boolean = false

/** The result type is binary: `eval` returns the byte array produced by `serialize`. */
override def dataType: DataType = BinaryType

/** Child expressions in the order: sketched column, eps, confidence, seed. */
override def children: Seq[Expression] =
Seq(child, epsExpression, confidenceExpression, seedExpression)

/** Name reported for this expression: `count_min_sketch`. */
override def prettyName: String = "count_min_sketch"
}
Loading