Skip to content

Commit

Permalink
Pivot with null as the pivot value throws NPE
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Mar 9, 2017
1 parent 932196d commit 0476565
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,12 @@ class Analyzer(
value + "_" + suffix
}
}
if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {

val shouldTwoStepAggregate =
aggregates.forall(a => PivotFirst.supportsDataType(a.dataType)) &&
!pivotValues.exists(_.dataType.acceptsType(NullType))

if (shouldTwoStepAggregate) {
// Since evaluating |pivotValues| if statements for each input row can get slow this is an
// alternate plan that instead uses two steps of aggregation.
val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
Expand Down Expand Up @@ -524,15 +529,21 @@ class Analyzer(
def ifExpr(expr: Expression) = {
If(EqualTo(pivotColumn, value), expr, Literal(null))
}
def ifNullSafeExpr(expr: Expression) = {
If(EqualNullSafe(pivotColumn, value), expr, Literal(null))
}
aggregates.map { aggregate =>
val filteredAggregate = aggregate.transformDown {
// Assumption is the aggregate function ignores nulls. This is true for all current
// AggregateFunction's with the exception of First and Last in their default mode
// (which we handle) and possibly some Hive UDAF's.
// AggregateFunction's with the exception of First, Last and Count in their
// default mode (which we handle) and possibly some Hive UDAF's.
case First(expr, _) =>
First(ifExpr(expr), Literal(true))
case Last(expr, _) =>
Last(ifExpr(expr), Literal(true))
case c: Count =>
// In case of count, `null` should be counted.
c.withNewChildren(c.children.map(ifNullSafeExpr))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
}.transform {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,10 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil
)
}

test("pivot with null should not throw NPE") {
checkAnswer(
Seq(Tuple1(None), Tuple1(Some(1))).toDF("a").groupBy($"a").pivot("a").count(),
Row(null, 1, 0) :: Row(1, 0, 1) :: Nil)
}
}

0 comments on commit 0476565

Please sign in to comment.