
[SPARK-18429] [SQL] implement a new Aggregate for CountMinSketch #15877

Closed
wants to merge 8 commits into from

Conversation

wzhfy
Contributor

@wzhfy wzhfy commented Nov 14, 2016

What changes were proposed in this pull request?

This PR implements a new aggregate function that generates a count-min sketch; it is a wrapper around CountMinSketch.

How was this patch tested?

Added test cases.

@wzhfy
Contributor Author

wzhfy commented Nov 14, 2016

cc @rxin

@rxin
Contributor

rxin commented Nov 14, 2016

cc @liancheng

@SparkQA

SparkQA commented Nov 14, 2016

Test build #68604 has finished for PR 15877 at commit a4753e4.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • case class CountMinSketchAgg(

@wzhfy
Contributor Author

wzhfy commented Nov 14, 2016

retest this please

@SparkQA

SparkQA commented Nov 14, 2016

Test build #68610 has finished for PR 15877 at commit a4753e4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • case class CountMinSketchAgg(

buffer.mergeInPlace(input)
}

override def eval(buffer: CountMinSketch): Any = new GenericArrayData(serialize(buffer))
Contributor

Is this an array of bytes?

Contributor Author

yes

Contributor

It is better to just return the byte array and change the data type to BinaryType.

Contributor Author

Yes, that's better, thanks!

}

override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
Contributor

I don't think we need to check this (the super class does not implement this).

Contributor Author

ExpectsInputTypes.checkInputDataTypes() checks validity of input types, right?

Contributor

That is fair.

}

override def createAggregationBuffer(): CountMinSketch = {
val eps: Double = epsExpression.eval().asInstanceOf[Double]
Contributor

Should we cache this in lazy vals? I am not sure about the performance implications.

Contributor Author

OK, I'll change them to lazy vals.

// ignore empty rows
if (value != null) {
// UTF8String is a spark sql type, while CountMinSketch accepts String type
buffer.add(if (value.isInstanceOf[UTF8String]) value.toString else value)
Contributor

How bad would it be to add support for UTF8 string to CMS? Or to pass the UTF8 byte array to CMS?

Contributor Author

Yes, we should pass the byte array to CMS.
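For illustration only (not part of the PR): the byte-based path hands the raw UTF-8 bytes of the key to the sketch via addBinary, skipping the per-row UTF8String-to-java.lang.String conversion. The helper name below is hypothetical.

```java
import java.nio.charset.StandardCharsets;

// Illustration only: raw UTF-8 bytes as the sketch key. The class and method
// names are hypothetical; addBinary is the CountMinSketch entry point assumed.
public class Utf8Key {
    static byte[] utf8Key(String s) {
        return s.getBytes(StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        byte[] key = utf8Key("café");
        System.out.println(key.length); // 'é' encodes as two bytes, so 5 total
        // cms.addBinary(key);  // hypothetical use with a CountMinSketch buffer
    }
}
```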

/**
* This function returns a count-min sketch of a column with the given eps, confidence and seed.
* A count-min sketch is a probabilistic data structure used for summarizing streams of data in
* sub-linear space, which is useful for equality predicates and join size estimation.
Contributor

Maybe add something about the return type? A developer should know how to work with these bytes.

Contributor Author

@wzhfy wzhfy Nov 15, 2016

OK, I wrote this in the usage string; I'll add it here too, thanks.

copy(inputAggBufferOffset = newInputAggBufferOffset)

override def inputTypes: Seq[AbstractDataType] = {
// currently `CountMinSketch` supports integral and string types
Contributor

Should we expand this?

Contributor Author

@rxin suggested that for unsupported types, we hash the value before building the count-min sketch, i.e. CountMinSketchAgg(hash(col)).

agg.merge(mergeBuffer, group1Buffer)
agg.merge(mergeBuffer, group2Buffer)
checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
}
Contributor

This might also be a good place to test merging in a different order, and the merging of an empty partition.

Contributor Author

Ok, I'll also test these.

data: Array[T],
exactFreq: Map[T, Long]): Unit = {
result match {
case arrayData: ArrayData =>
Contributor

Add case _ => fail("unexpected return type") to get a nicer error when something goes wrong there.

!seedExpression.foldable) {
TypeCheckFailure(
"The eps, confidence or seed provided must be a literal or constant foldable")
} else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
Contributor

Should we also check for negative eps and confidence values?
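For illustration, a minimal sketch of the driver-side validation being suggested; the class name, method name, and messages are hypothetical, not the PR's actual code:

```java
// Minimal sketch of the suggested analysis-time check: reject null, negative,
// or out-of-range eps/confidence on the driver instead of letting the
// CountMinSketch constructor throw on an executor. Names are illustrative.
public class SketchParamCheck {
    static void checkSketchParams(Double eps, Double confidence) {
        if (eps == null || confidence == null) {
            throw new IllegalArgumentException("eps and confidence must not be null");
        }
        if (eps <= 0.0) {
            throw new IllegalArgumentException("Relative error must be positive, got " + eps);
        }
        if (confidence <= 0.0 || confidence >= 1.0) {
            throw new IllegalArgumentException("Confidence must be within (0.0, 1.0), got " + confidence);
        }
    }

    public static void main(String[] args) {
        checkSketchParams(0.01, 0.99); // valid arguments: no exception
        try {
            checkSketchParams(-1.0, 0.99); // negative eps is rejected up front
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```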

confidenceExpression = Literal(confidence),
seedExpression = Literal(seed))
val err = intercept[IllegalArgumentException] {
invalidAgg.createAggregationBuffer()
Contributor

See my comment in the CMS agg. This is too late to throw such an error. I'd rather have driver-side errors than executor-side errors.

Contributor Author

Yes, we should have driver side errors, thanks.

Contributor

@hvanhovell hvanhovell left a comment

This looks pretty good! I have left a few minor comments. Also consider registering this aggregate in the FunctionRegistry and adding it to functions.scala.

@rxin
Contributor

rxin commented Nov 15, 2016

Yes, please register count_min_sketch and the alias cmsketch in FunctionRegistry.

case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
case ByteType => buffer.addLong(value.asInstanceOf[Byte])
case ShortType => buffer.addLong(value.asInstanceOf[Short])
case IntegerType => buffer.addLong(value.asInstanceOf[Int])
Contributor

Add DateType?

case ByteType => buffer.addLong(value.asInstanceOf[Byte])
case ShortType => buffer.addLong(value.asInstanceOf[Short])
case IntegerType => buffer.addLong(value.asInstanceOf[Int])
case LongType => buffer.addLong(value.asInstanceOf[Long])
Contributor

Add TimestampType?

val value = child.eval(input)
// ignore empty rows
if (value != null) {
child.dataType match {
Contributor

A general question: which is faster, a pattern match at runtime or a virtual function call here?

cc @davies @cloud-fan

Contributor

virtual function dispatch is usually a lot faster than pattern match.

Contributor

Although I don't know if it matters much here, given we are going to run it through many hash functions.

@SparkQA

SparkQA commented Nov 15, 2016

Test build #68664 has finished for PR 15877 at commit 0cca205.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 15, 2016

Test build #68666 has finished for PR 15877 at commit 2064846.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -261,6 +261,8 @@ object FunctionRegistry {
expression[VarianceSamp]("var_samp"),
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
expression[CountMinSketchAgg]("cmsketch"),
Contributor

Actually, I take my word back. Let's add only count_min_sketch. I don't think it's worth having an alias, given this sketch is difficult to consume (it returns binary).

* @group agg_funcs
* @since 2.2.0
*/
def count_min_sketch(e: Column, eps: Double, confidence: Double, seed: Int): Column = {
Contributor

let's not add these for now.

@SparkQA

SparkQA commented Nov 16, 2016

Test build #68699 has finished for PR 15877 at commit 7bfdd40.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 16, 2016

Test build #68703 has finished for PR 15877 at commit 6143997.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@wzhfy
Contributor Author

wzhfy commented Nov 17, 2016

Hi @hvanhovell @rxin, I've updated this pr, does it look good to you now?

val value = child.eval(input)
// Ignore empty rows
if (value != null) {
child.dataType match {
Contributor

Let's not do a pattern match for every update. We should use an update function instead, for example:

private[this] val doUpdate: (CountMinSketch, Any) => Unit = child.dataType match {
  case StringType => (cms, value) => cms.addBinary(value.asInstanceOf[UTF8String].getBytes)
  case ByteType => (cms, value) => cms.addLong(value.asInstanceOf[Byte])
  ...
}

override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
  val value = child.eval(input)
  if (value != null) {
    doUpdate(buffer, value)
  }
}

// Currently `CountMinSketch` supports integral (date/timestamp is represented as int/long
// internally) and string types.
Seq(TypeCollection(IntegralType, StringType, DateType, TimestampType),
DoubleType, DoubleType, IntegerType)
Contributor

Also add FloatType (use Float.floatToIntBits), DoubleType (use Double.doubleToLongBits), BooleanType and BinaryType? We could also add support for Decimal, but that would be a bit harder to get right.
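For reference, the conversions mentioned here are plain java.lang APIs that turn a floating-point value into a stable integer bit pattern suitable for CountMinSketch.addLong. The class and method names below are illustrative, not the PR's code:

```java
// Sketch of the suggested encoding: floatToIntBits/doubleToLongBits produce a
// stable key per value (canonicalizing NaNs), which can be fed to addLong.
public class BitKeys {
    static long floatKey(float f) {
        return Float.floatToIntBits(f); // int bits, sign-extended to long
    }

    static long doubleKey(double d) {
        return Double.doubleToLongBits(d);
    }

    public static void main(String[] args) {
        System.out.println(floatKey(1.5f) == floatKey(1.5f));  // same value, same key
        System.out.println(floatKey(0.0f) == floatKey(-0.0f)); // +0.0 and -0.0 get distinct keys
    }
}
```

Note that +0.0 and -0.0 compare equal as floats but have different bit patterns, so this encoding counts them separately; whether that is acceptable is a design choice.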

Contributor Author

@rxin @hvanhovell If we really want to support all these types, would it be better to move this conversion and pattern-match logic into CountMinSketch? That is, make CMS support these types itself. Then when users query, e.g., a float column, they don't need conversions like cms.estimateCount(Float.floatToIntBits(value)).

Contributor

Yes if we want to add support for those I think it'd make sense to do it in count-min sketch itself too.

@SparkQA

SparkQA commented Nov 22, 2016

Test build #68985 has finished for PR 15877 at commit ca4a13f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 22, 2016

Test build #68994 has finished for PR 15877 at commit b009ff8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@wzhfy
Contributor Author

wzhfy commented Nov 23, 2016

cc @rxin @hvanhovell

@@ -152,6 +153,16 @@ public void add(Object item) {
public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
} else if (item instanceof BigDecimal) {
addString(((BigDecimal) item).toString(), count);
Contributor Author

@wzhfy wzhfy Nov 23, 2016

Here I use a string to represent the decimal because there is a one-to-one mapping between BigDecimal and String.

Contributor

Is this true?

"1.0" and "1.00" are the same value but not the same string representation.

Contributor Author

Sorry I didn't describe it accurately. It should be "There is a one-to-one mapping between the distinguishable values and the result of this conversion." (from java doc of BigDecimal)
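A small standalone illustration of this point (not part of the PR): BigDecimal#toString distinguishes values that differ only in scale, while compareTo treats them as numerically equal, so the string form is one-to-one with distinguishable values:

```java
import java.math.BigDecimal;

// BigDecimal#toString is one-to-one with distinguishable values (unscaled
// value + scale), so "1.0" and "1.00" map to different string keys even
// though compareTo considers them equal.
public class DecimalKeys {
    public static void main(String[] args) {
        BigDecimal a = new BigDecimal("1.0");
        BigDecimal b = new BigDecimal("1.00");
        System.out.println(a.compareTo(b) == 0);               // true: numerically equal
        System.out.println(a.toString().equals(b.toString())); // false: distinct strings
    }
}
```

This means a sketch keyed on toString counts 1.0 and 1.00 separately, which is the behavior the thread is debating.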

@SparkQA

SparkQA commented Nov 23, 2016

Test build #69049 has finished for PR 15877 at commit 1bfb6fd.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 23, 2016

Test build #69047 has finished for PR 15877 at commit a6bbefc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@rxin
Contributor

rxin commented Nov 29, 2016

Thanks - I'm going to merge this in master. I will submit a follow-up PR to simplify this a little bit, and remove the handling of float/double/decimal types and require explicit user action on how to turn that into long.

@asfgit asfgit closed this in d57a594 Nov 29, 2016
@rxin
Contributor

rxin commented Nov 29, 2016

Hey guys - after looking at the PR more, I'm afraid we have gone overboard with testing here. Most of the test cases written just repeat each other and do exactly the same thing. For testing something like this I'd probably just have a simple end-to-end test and be done with it, because most of the complicated logic is isolated in the actual CountMinSketch implementation itself, which already has good test coverage.

assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
}

def testHighLevelInterface[T: ClassTag](
Contributor

@wzhfy can you comment on why we need to test both the high level interface and the low level interface?

Contributor Author

I just followed the style in ApproximatePercentileSuite, which is also a TypedImperativeAggregate. I thought they were used to test different levels of operations of a TypedImperativeAggregate, e.g. update(buffer: InternalRow, input: InternalRow) and update(buffer: T, input: InternalRow).

robert3005 pushed a commit to palantir/spark that referenced this pull request Dec 2, 2016
## What changes were proposed in this pull request?

This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch.

## How was this patch tested?

add test cases

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes apache#15877 from wzhfy/cms.
robert3005 pushed a commit to palantir/spark that referenced this pull request Dec 15, 2016
uzadude pushed a commit to uzadude/spark that referenced this pull request Jan 27, 2017