[SPARK-25048][SQL] Pivoting by multiple columns in Scala/Java #22316

Closed · wants to merge 11 commits
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{NumericType, StructType}

@@ -406,6 +407,14 @@ class RelationalGroupedDataset protected[sql](
* df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
* }}}
*
* For pivoting by multiple columns, use the `struct` function to combine the columns and values:
Member:

Since the documentation states it's an overloaded version of the `pivot` method with `pivotColumn` of the `String` type, shall we move this content to that method?

Also, I would document this, for instance:

> From Spark 2.4.0, values can be literal columns, for instance, struct. For pivoting by multiple columns, use the struct function to combine the columns and values.

*
* {{{
* df.groupBy($"year")
* .pivot(struct($"course", $"training"), Seq(struct(lit("java"), lit("Experts"))))
* .agg(sum($"earnings"))
* }}}
*
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 2.4.0
@@ -416,7 +425,7 @@ class RelationalGroupedDataset protected[sql](
new RelationalGroupedDataset(
df,
groupingExprs,
-RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))
+RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(lit(_).expr))
Contributor:

What do you think about `map(lit).map(_.expr)` instead?

Member Author:

I don't see any advantage to this. It is longer and slower.
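
For context, a minimal sketch of the two spellings under discussion (the `values` sequence here is hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.lit

val values: Seq[Any] = Seq("dotNET", "Java")

// One pass: wrap each value in a Column and immediately take its expression.
val onePass: Seq[Expression] = values.map(lit(_).expr)

// Two passes: build an intermediate Seq[Column] first, then map it to expressions.
val twoPasses: Seq[Expression] = values.map(lit).map(_.expr)
```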

Member:

@MaxGekk, just to be doubly sure, shall we use `Try(...).getOrElse(lit(...).expr)`? It looks like there's at least one case of a potential behaviour change around scale and precision.

Member Author:

> It looks like there's at least one case of a potential behaviour change around scale and precision.

Could you explain, please, why you expect a behavior change?

Contributor:

Now we eventually call `Literal.create` instead of `Literal.apply`. I'm not sure if there is a behavior change, though.

Contributor:

From a quick look, it seems `Literal.create` is more powerful and should not cause regressions.

Member:

That's true in general, but is decimal precision specifically handled more correctly?
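
For reference, a rough sketch of the scale/precision concern (assuming Spark 2.4's `Literal`; the resulting types shown in the comments are assumptions, not verified output):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal

// Literal.apply derives the decimal type from the runtime value,
// so "1.23" is assumed to yield DecimalType(3, 2).
val byApply = Literal(BigDecimal("1.23"))

// Literal.create (reached via lit/typedLit) goes through ScalaReflection,
// which is assumed to map BigDecimal to the system default DecimalType(38, 18).
val byCreate = Literal.create(BigDecimal("1.23"))

println(byApply.dataType)  // assumed: DecimalType(3,2)
println(byCreate.dataType) // assumed: DecimalType(38,18)
```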

case _: RelationalGroupedDataset.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
@@ -561,5 +570,5 @@ private[sql] object RelationalGroupedDataset {
/**
* To indicate it's the PIVOT
*/
-private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
+private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType
}
@@ -306,6 +306,22 @@ public void pivot() {
Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
}

@Test
public void pivotColumnValues() {
Dataset<Row> df = spark.table("courseSales");
List<Row> actual = df.groupBy("year")
.pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
.agg(sum("earnings")).orderBy("year").collectAsList();

Assert.assertEquals(2012, actual.get(0).getInt(0));
Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);

Assert.assertEquals(2013, actual.get(1).getInt(0));
Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
}

private String getResource(String resource) {
try {
// The following "getResource" has different behaviors in SBT and Maven.
@@ -308,4 +308,27 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {

assert(exception.getMessage.contains("aggregate functions are not allowed"))
}

test("pivoting column list with values") {
val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
.pivot(struct(lower($"sales.course"), $"training"), Seq(
struct(lit("dotnet"), lit("Experts")),
struct(lit("java"), lit("Dummies")))
).agg(sum($"sales.earnings"))

checkAnswer(df, expected)
}

test("pivoting column list") {
val exception = intercept[RuntimeException] {
trainingSales
.groupBy($"sales.year")
.pivot(struct(lower($"sales.course"), $"training"))
.agg(sum($"sales.earnings"))
.collect()
Member (@maropu, Sep 3, 2018):

Don't we need this `.collect()` to catch the `RuntimeException`? By the way, IMHO `AnalysisException` is better than `RuntimeException` in this case. Can't we use it?

Member Author:

My changes don't throw the exception; it is thrown in the `collect()`.

@maropu Do you propose to catch the `RuntimeException` and replace it with `AnalysisException`?

Member:

I tried this in your branch:

scala> df.show
+--------+--------------------+
|training|               sales|
+--------+--------------------+
| Experts|[dotNET, 2012, 10...|
| Experts|[JAVA, 2012, 2000...|
| Dummies|[dotNet, 2012, 50...|
| Experts|[dotNET, 2013, 48...|
| Dummies|[Java, 2013, 3000...|
+--------+--------------------+

scala> df.groupBy($"sales.year").pivot(struct(lower($"sales.course"), $"training")).agg(sum($"sales.earnings"))
java.lang.RuntimeException: Unsupported literal type class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema [dotnet,Dummies]
  at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:78)
  at org.apache.spark.sql.catalyst.expressions.Literal$$anonfun$create$2.apply(literals.scala:164)
  at org.apache.spark.sql.catalyst.expressions.Literal$$anonfun$create$2.apply(literals.scala:164)
  at scala.util.Try.getOrElse(Try.scala:79)
  at org.apache.spark.sql.catalyst.expressions.Literal$.create(literals.scala:163)
  at org.apache.spark.sql.functions$.typedLit(functions.scala:127)

Am I missing something?

Member Author:

> Am I missing something?

No, you don't. The exception is certainly thrown inside `lit`, because `collect()` returns complex values which cannot be "wrapped" by `lit`. This is exactly what the test I added checks, to show the existing behavior.

> IMHO `AnalysisException` is better than `RuntimeException` in this case.

@maropu Could you explain, please, why you think `AnalysisException` is better for an error that occurs at run time?

Just in case: in this PR, I don't aim to change the behavior of the existing method `def pivot(pivotColumn: Column): RelationalGroupedDataset`. I believe that should be discussed separately, with regard to the need to change user-visible behavior. The PR aims to improve `def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset` to allow users to specify struct literals in particular. Please see the description.
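
To make the failure mode concrete, a minimal sketch (the `Row` value is hypothetical, mirroring the stack trace above):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.lit

// Without an explicit values list, pivot collects the distinct values of the
// pivot column and wraps each of them in lit(...). For a struct column those
// distinct values come back as Rows, which lit cannot convert to a literal:
lit(Row("dotnet", "Dummies"))
// assumed to throw: java.lang.RuntimeException: Unsupported literal type class ...
```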

Member:

I think invalid queries basically throw `AnalysisException`. But yeah, indeed, we'd better keep the current behaviour. Thanks!

}
assert(exception.getMessage.contains("Unsupported literal type"))
}
}