From b216fa104a17b076f246ef236a6deab94ed16246 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 22 Feb 2017 23:03:20 +0900 Subject: [PATCH 1/4] Fix ClassCastException --- .../catalyst/expressions/aggregate/Percentile.scala | 12 ++++++++---- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 5 +++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 6b7cf7991d39d..6737da946b90c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ @@ -138,7 +138,8 @@ case class Percentile( override def update( buffer: OpenHashMap[Number, Long], input: InternalRow): OpenHashMap[Number, Long] = { - val key = child.eval(input).asInstanceOf[Number] + val scalaValue = CatalystTypeConverters.convertToScala(child.eval(input), child.dataType) + val key = scalaValue.asInstanceOf[Number] val frqValue = frequencyExpression.eval(input) // Null values are ignored in counts map. @@ -246,7 +247,8 @@ case class Percentile( val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType)) // Write pairs in counts map to byte buffer. obj.foreach { case (key, count) => - val row = InternalRow.apply(key, count) + val catalystValue = CatalystTypeConverters.convertToCatalyst(key) + val row = InternalRow.apply(catalystValue, count) val unsafeRow = projection.apply(row) out.writeInt(unsafeRow.getSizeInBytes) unsafeRow.writeToStream(out, buffer) @@ -274,7 +276,9 @@ case class Percentile( val row = new UnsafeRow(2) row.pointTo(bs, sizeOfNextRow) // Insert the pairs into counts map. - val key = row.get(0, child.dataType).asInstanceOf[Number] + val catalystValue = row.get(0, child.dataType) + val scalaValue = CatalystTypeConverters.convertToScala(catalystValue, child.dataType) + val key = scalaValue.asInstanceOf[Number] val count = row.get(1, LongType).asInstanceOf[Long] counts.update(key, count) sizeOfNextRow = ins.readInt() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e6338ab7cd800..5e65436079db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j") checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil) } + + test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") { + val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") + checkAnswer(df, Row(BigDecimal(0.0)) :: Nil) + } } From 325c95d58b0d4c801a218f497b482f619f4c1114 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 23 Feb 2017 03:35:24 +0900 Subject: [PATCH 2/4] Reuse converter funcs --- .../catalyst/expressions/aggregate/Percentile.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 6737da946b90c..376e59cd684b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -89,6 +89,10 @@ case class Percentile( case arrayData: ArrayData => arrayData.toDoubleArray().toSeq } + private lazy val toScalaValue = CatalystTypeConverters.createToScalaConverter(child.dataType) + private lazy val toCatalystValue = + CatalystTypeConverters.createToCatalystConverter(child.dataType) + override def children: Seq[Expression] = { child :: percentageExpression ::frequencyExpression :: Nil } @@ -138,8 +142,7 @@ case class Percentile( override def update( buffer: OpenHashMap[Number, Long], input: InternalRow): OpenHashMap[Number, Long] = { - val scalaValue = CatalystTypeConverters.convertToScala(child.eval(input), child.dataType) - val key = scalaValue.asInstanceOf[Number] + val key = toScalaValue(child.eval(input)).asInstanceOf[Number] val frqValue = frequencyExpression.eval(input) // Null values are ignored in counts map. @@ -247,7 +250,7 @@ case class Percentile( val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType)) // Write pairs in counts map to byte buffer. obj.foreach { case (key, count) => - val catalystValue = CatalystTypeConverters.convertToCatalyst(key) + val catalystValue = toCatalystValue(key) val row = InternalRow.apply(catalystValue, count) val unsafeRow = projection.apply(row) out.writeInt(unsafeRow.getSizeInBytes) @@ -277,8 +280,7 @@ case class Percentile( row.pointTo(bs, sizeOfNextRow) // Insert the pairs into counts map. val catalystValue = row.get(0, child.dataType) - val scalaValue = CatalystTypeConverters.convertToScala(catalystValue, child.dataType) - val key = scalaValue.asInstanceOf[Number] + val key = toScalaValue(catalystValue).asInstanceOf[Number] val count = row.get(1, LongType).asInstanceOf[Long] counts.update(key, count) sizeOfNextRow = ins.readInt() From ef26f262cc747505cb0d2a55d6ee0c531263ac0a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 23 Feb 2017 13:39:14 +0900 Subject: [PATCH 3/4] Replace Number with AnyRef --- .../expressions/aggregate/Percentile.scala | 60 ++++++++++--------- .../aggregate/PercentileSuite.scala | 31 +++++----- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 376e59cd684b2..6692a19188659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -61,7 +61,7 @@ case class Percentile( frequencyExpression : Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes { + extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes { def this(child: Expression, percentageExpression: Expression) = { this(child, percentageExpression, Literal(1L), 0, 0) @@ -89,10 +89,6 @@ case class Percentile( case arrayData: ArrayData => arrayData.toDoubleArray().toSeq } - private lazy val toScalaValue = CatalystTypeConverters.createToScalaConverter(child.dataType) - private lazy val toCatalystValue = - CatalystTypeConverters.createToCatalystConverter(child.dataType) - override def children: Seq[Expression] = { child :: percentageExpression ::frequencyExpression :: Nil } @@ -134,20 +130,30 @@ case class Percentile( } } - override def createAggregationBuffer(): OpenHashMap[Number, Long] = { + private def toLongValue(d: Any): Long = d match { + case d: Decimal => d.toLong + case n: Number => n.longValue + } + + private def toDoubleValue(d: Any): Double = d match { + case d: Decimal => d.toDouble + case n: Number => n.doubleValue + } + + override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = { // Initialize new counts map instance here. - new OpenHashMap[Number, Long]() + new OpenHashMap[AnyRef, Long]() } override def update( - buffer: OpenHashMap[Number, Long], - input: InternalRow): OpenHashMap[Number, Long] = { - val key = toScalaValue(child.eval(input)).asInstanceOf[Number] + buffer: OpenHashMap[AnyRef, Long], + input: InternalRow): OpenHashMap[AnyRef, Long] = { + val key = child.eval(input).asInstanceOf[AnyRef] val frqValue = frequencyExpression.eval(input) // Null values are ignored in counts map. if (key != null && frqValue != null) { - val frqLong = frqValue.asInstanceOf[Number].longValue() + val frqLong = toLongValue(frqValue) // add only when frequency is positive if (frqLong > 0) { buffer.changeValue(key, frqLong, _ + frqLong) @@ -159,32 +165,32 @@ case class Percentile( } override def merge( - buffer: OpenHashMap[Number, Long], - other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = { + buffer: OpenHashMap[AnyRef, Long], + other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = { other.foreach { case (key, count) => buffer.changeValue(key, count, _ + count) } buffer } - override def eval(buffer: OpenHashMap[Number, Long]): Any = { + override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { generateOutput(getPercentiles(buffer)) } - private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = { + private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = { if (buffer.isEmpty) { return Seq.empty } val sortedCounts = buffer.toSeq.sortBy(_._1)( - child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]]) + child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]]) val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { case ((key1, count1), (key2, count2)) => (key2, count1 + count2) }.tail val maxPosition = accumlatedCounts.last._2 - 1 percentages.map { percentile => - getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue() + getPercentile(accumlatedCounts, maxPosition * percentile) } } @@ -204,7 +210,7 @@ case class Percentile( * This function has been based upon similar function from HIVE * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. */ - private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { + private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = { // We may need to do linear interpolation to get the exact percentile val lower = position.floor.toLong val higher = position.ceil.toLong @@ -217,18 +223,17 @@ case class Percentile( val lowerKey = aggreCounts(lowerIndex)._1 if (higher == lower) { // no interpolation needed because position does not have a fraction - return lowerKey + return toDoubleValue(lowerKey) } val higherKey = aggreCounts(higherIndex)._1 if (higherKey == lowerKey) { // no interpolation needed because lower position and higher position has the same key - return lowerKey + return toDoubleValue(lowerKey) } // Linear interpolation to get the exact percentile - return (higher - position) * lowerKey.doubleValue() + - (position - lower) * higherKey.doubleValue() + (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey) } /** @@ -242,7 +247,7 @@ case class Percentile( } } - override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = { + override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = { val buffer = new Array[Byte](4 << 10) // 4K val bos = new ByteArrayOutputStream() val out = new DataOutputStream(bos) @@ -250,8 +255,7 @@ case class Percentile( val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType)) // Write pairs in counts map to byte buffer. obj.foreach { case (key, count) => - val catalystValue = toCatalystValue(key) - val row = InternalRow.apply(catalystValue, count) + val row = InternalRow.apply(key, count) val unsafeRow = projection.apply(row) out.writeInt(unsafeRow.getSizeInBytes) unsafeRow.writeToStream(out, buffer) @@ -266,11 +270,11 @@ case class Percentile( } } - override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = { + override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = { val bis = new ByteArrayInputStream(bytes) val ins = new DataInputStream(bis) try { - val counts = new OpenHashMap[Number, Long] + val counts = new OpenHashMap[AnyRef, Long] // Read unsafeRow size and content in bytes. var sizeOfNextRow = ins.readInt() while (sizeOfNextRow >= 0) { @@ -280,7 +284,7 @@ case class Percentile( row.pointTo(bs, sizeOfNextRow) // Insert the pairs into counts map. val catalystValue = row.get(0, child.dataType) - val key = toScalaValue(catalystValue).asInstanceOf[Number] + val key = catalystValue.asInstanceOf[AnyRef] val count = row.get(1, LongType).asInstanceOf[Long] counts.update(key, count) sizeOfNextRow = ins.readInt() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 1533fe5f90ee2..d561884dd6290 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ @@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite { val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5)) // Check empty serialize and deserialize - val buffer = new OpenHashMap[Number, Long]() + val buffer = new OpenHashMap[AnyRef, Long]() assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) // Check non-empty buffer serializa and deserialize. data.foreach { key => - buffer.changeValue(key, 1L, _ + 1L) + buffer.changeValue(new Integer(key), 1L, _ + 1L) } assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } @@ -52,31 +51,31 @@ class PercentileSuite extends SparkFunSuite { test("class Percentile, high level interface, update, merge, eval...") { val count = 10000 val percentages = Seq(0, 0.25, 0.5, 0.75, 1) - val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000) + val expectedPercentiles = Seq[Double](1, 2500.75, 5000.5, 7500.25, 10000) val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) val agg = new Percentile(childExpression, percentageExpression) // Test with rows without frequency - val rows = (1 to count).map( x => Seq(x)) - runTest( agg, rows, expectedPercentiles) + val rows = (1 to count).map(x => Seq(x)) + runTest(agg, rows, expectedPercentiles) // Test with row with frequency. Second and third columns are frequency in Int and Long val countForFrequencyTest = 1000 - val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong) + val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong) val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0) val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false) val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt) - runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency) + runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency) val frequencyExpressionLong = BoundReference(2, LongType, nullable = false) val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong) - runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency) + runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency) // Run test with Flatten data - val flattenRows = (1 to countForFrequencyTest).flatMap( current => - (1 to current).map( y => current )).map( Seq(_)) + val flattenRows = (1 to countForFrequencyTest).flatMap(current => + (1 to current).map(y => current )).map(Seq(_)) runTest(agg, flattenRows, expectedPercentilesWithFrquency) } @@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite { } val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType) - for ( dataType <- validDataTypes; + for (dataType <- validDataTypes; frequencyType <- validFrequencyTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite { StringType, DateType, TimestampType, CalendarIntervalType, NullType) - for( dataType <- invalidDataTypes; + for(dataType <- invalidDataTypes; frequencyType <- validFrequencyTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite { s"'`a`' is of ${dataType.simpleString} type.")) } - for( dataType <- validDataTypes; + for(dataType <- validDataTypes; frequencyType <- invalidFrequencyDataTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite { agg.update(buffer, InternalRow(1, -5)) agg.eval(buffer) } - assert( caught.getMessage.startsWith("Negative values found in ")) + assert(caught.getMessage.startsWith("Negative values found in ")) } private def compareEquals( - left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = { + left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = { left.size == right.size && left.forall { case (key, count) => right.apply(key) == count } From 88f4f47bae435f8cebe6d3e0ad31b3a77516014b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 23 Feb 2017 19:07:28 +0900 Subject: [PATCH 4/4] Apply review comments --- .../catalyst/expressions/aggregate/Percentile.scala | 10 ++-------- .../expressions/aggregate/PercentileSuite.scala | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 6692a19188659..8433a93ea3032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -130,11 +130,6 @@ case class Percentile( } } - private def toLongValue(d: Any): Long = d match { - case d: Decimal => d.toLong - case n: Number => n.longValue - } - private def toDoubleValue(d: Any): Double = d match { case d: Decimal => d.toDouble case n: Number => n.doubleValue @@ -153,7 +148,7 @@ case class Percentile( // Null values are ignored in counts map. if (key != null && frqValue != null) { - val frqLong = toLongValue(frqValue) + val frqLong = frqValue.asInstanceOf[Number].longValue() // add only when frequency is positive if (frqLong > 0) { buffer.changeValue(key, frqLong, _ + frqLong) @@ -283,8 +278,7 @@ case class Percentile( val row = new UnsafeRow(2) row.pointTo(bs, sizeOfNextRow) // Insert the pairs into counts map. - val catalystValue = row.get(0, child.dataType) - val key = catalystValue.asInstanceOf[AnyRef] + val key = row.get(0, child.dataType) val count = row.get(1, LongType).asInstanceOf[Long] counts.update(key, count) sizeOfNextRow = ins.readInt() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index d561884dd6290..2420ba513f287 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -51,7 +51,7 @@ class PercentileSuite extends SparkFunSuite { test("class Percentile, high level interface, update, merge, eval...") { val count = 10000 val percentages = Seq(0, 0.25, 0.5, 0.75, 1) - val expectedPercentiles = Seq[Double](1, 2500.75, 5000.5, 7500.25, 10000) + val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000) val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) val agg = new Percentile(childExpression, percentageExpression)