Skip to content

Commit

Permalink
[SPARK-14870] [SQL] Fix NPE in TPCDS q14a
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR fixes a bug in `TungstenAggregate` that manifests while aggregating by keys over nullable `BigDecimal` columns. The bug causes a `NullPointerException` while executing TPCDS q14a.

## How was this patch tested?

1. Added regression test in `DataFrameAggregateSuite`.
2. Verified that TPCDS q14a runs successfully with the fix.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12651 from sameeragarwal/tpcds-fix.
  • Loading branch information
sameeragarwal authored and davies committed Apr 25, 2016
1 parent c752b6c commit cbdcd4e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,19 @@ class CodegenContext {

/**
* Update a column in MutableRow from ExprCode.
*
* @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
*/
def updateColumn(
row: String,
dataType: DataType,
ordinal: Int,
ev: ExprCode,
nullable: Boolean): String = {
nullable: Boolean,
isVectorized: Boolean = false): String = {
if (nullable) {
// Can't call setNullAt on DecimalType, because we need to keep the offset
if (dataType.isInstanceOf[DecimalType]) {
if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
s"""
if (!${ev.isNull}) {
${setColumn(row, dataType, ordinal, ev.value)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,8 @@ case class TungstenAggregate(
updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
isVectorized = true)
}
Option(
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
df1.groupBy("key").min("value2"),
Seq(Row("a", 0), Row("b", 4))
)

checkAnswer(
decimalData.groupBy("a").agg(sum("b")),
Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(3.0)),
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(3.0)),
Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
)

checkAnswer(
decimalDataWithNulls.groupBy("a").agg(sum("b")),
Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)),
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.0)),
Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)),
Row(null, new java.math.BigDecimal(2.0)))
)
}

test("rollup") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,19 @@ private[sql] trait SQLTestData { self =>
df
}

protected lazy val decimalDataWithNulls: DataFrame = {
  // Fixture rows over two decimal columns; both `a` and `b` contain a null in at least one row.
  val rows = List(
    DecimalDataWithNulls(1, 1),
    DecimalDataWithNulls(1, null),
    DecimalDataWithNulls(2, 1),
    DecimalDataWithNulls(2, null),
    DecimalDataWithNulls(3, 1),
    DecimalDataWithNulls(3, 2),
    DecimalDataWithNulls(null, 2))
  val frame = sqlContext.sparkContext.parallelize(rows).toDF()
  // Expose the data under the temp-table name "decimalDataWithNulls" as well.
  frame.registerTempTable("decimalDataWithNulls")
  frame
}

protected lazy val binaryData: DataFrame = {
val df = sqlContext.sparkContext.parallelize(
BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) ::
Expand Down Expand Up @@ -267,6 +280,7 @@ private[sql] trait SQLTestData { self =>
negativeData
largeAndSmallInts
decimalData
decimalDataWithNulls
binaryData
upperCaseData
lowerCaseData
Expand Down Expand Up @@ -296,6 +310,7 @@ private[sql] object SQLTestData {
case class TestData3(a: Int, b: Option[Int])
case class LargeAndSmallInts(a: Int, b: Int)
case class DecimalData(a: BigDecimal, b: BigDecimal)
// Row type for the `decimalDataWithNulls` fixture, which constructs it with null in either field.
case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal)
case class BinaryData(a: Array[Byte], b: Int)
case class UpperCaseData(N: Int, L: String)
case class LowerCaseData(n: Int, l: String)
Expand Down

0 comments on commit cbdcd4e

Please sign in to comment.