
[SPARK-30590][SQL] Untyped select API cannot take typed column expression that needs input type #27499

Closed · wants to merge 13 commits
5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -492,7 +492,10 @@ object MimaExcludes {
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setPredictionCol"),

// [SPARK-29543][SS][UI] Init structured streaming ui
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.this")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.this"),

// [SPARK-30590][SQL] Untyped select API cannot take typed column expression
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.functions.count")
Member Author:
Put it under the 3.0 exclude rules temporarily. The version number in the master branch is still 3.0.0.

)

// Exclude rules for 2.4.x
14 changes: 10 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1430,6 +1430,11 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def select(cols: Column*): DataFrame = withPlan {
cols.find(_.isInstanceOf[TypedColumn[_, _]]).foreach { typedCol =>
throw new AnalysisException(s"$typedCol is a typed column that " +
"cannot be passed in untyped `select` API. If you are going to select " +
"multiple typed columns, you can use `Dataset.selectUntyped` API.")
}
Project(cols.map(_.named), logicalPlan)
}
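The guard added above can be sketched outside Spark with hypothetical stand-ins for the `Column`/`TypedColumn` hierarchy (the class and method names below are illustrative, not Spark's real API):

```scala
// Minimal sketch of the new guard in untyped select, assuming a toy
// Column/TypedColumn hierarchy standing in for Spark's.
class Column(val name: String)
class TypedColumn(name: String) extends Column(name)

def select(cols: Column*): Seq[String] = {
  // Reject any typed column up front, mirroring the AnalysisException
  // thrown in the real implementation.
  cols.find(_.isInstanceOf[TypedColumn]).foreach { typedCol =>
    throw new IllegalArgumentException(
      s"${typedCol.name} is a typed column that cannot be passed in untyped `select`")
  }
  cols.map(_.name)
}
```

Untyped columns pass through unchanged; the first typed column encountered aborts the whole call, which is why the test below intercepts an `AnalysisException`.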

@@ -1493,11 +1498,12 @@ class Dataset[T] private[sql](
}

/**
* Internal helper function for building typed selects that return tuples. For simplicity and
* code reuse, we do this without the help of the type system and then use helper functions
* that cast appropriately for the user facing interface.
* Selects a set of typed column based expressions.
*
* @group typedrel
* @since 3.1.0
*/
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
3 changes: 1 addition & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -352,8 +352,7 @@ object functions {
* @group agg_funcs
* @since 1.3.0
*/
def count(columnName: String): TypedColumn[Any, Long] =
count(Column(columnName)).as(ExpressionEncoder[Long]())
def count(columnName: String): Column = count(Column(columnName))
Member Author:
This seems to me to be wrongly a TypedColumn; count is a DeclarativeAggregate.

Contributor (@cloud-fan, Feb 21, 2020):
Member:

It seems like the right change, but let's revert this line considering it's the code freeze period.

Member Author:

Ok. :)


/**
* Aggregate function: returns the number of distinct items in a group.
@@ -219,6 +219,15 @@ case class OptionBooleanIntAggregator(colName: String)
def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder()
}

case class FooAgg(s: Int) extends Aggregator[Row, Int, Int] {
def zero: Int = s
def reduce(b: Int, r: Row): Int = b + r.getAs[Int](0)
def merge(b1: Int, b2: Int): Int = b1 + b2
def finish(b: Int): Int = b
def bufferEncoder: Encoder[Int] = Encoders.scalaInt
def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
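For reference, the `zero`/`reduce`/`merge`/`finish` contract that `FooAgg` implements can be exercised without Spark by folding values by hand. This is a sketch with a plain `Seq[Int]` standing in for the `r.getAs[Int](0)` lookups on Spark `Row`s:

```scala
// Pure-Scala sketch of the Aggregator contract FooAgg implements,
// with plain Ints standing in for Row column values.
case class FooAggSketch(s: Int) {
  def zero: Int = s
  def reduce(b: Int, v: Int): Int = b + v
  def merge(b1: Int, b2: Int): Int = b1 + b2
  def finish(b: Int): Int = b

  // Aggregate one partition's values, starting from the zero buffer.
  def runPartition(values: Seq[Int]): Int = values.foldLeft(zero)(reduce)
}

val agg = FooAggSketch(1)
val partial1 = agg.runPartition(Seq(2, 3)) // 1 + 2 + 3 = 6
val partial2 = agg.runPartition(Seq(4))    // 1 + 4 = 5
val result   = agg.finish(agg.merge(partial1, partial2)) // 11
```

Note that in this model each partition buffer starts from `zero`, so the initial value `s` is counted once per partition buffer that gets merged; the exact number of buffers is an engine detail, not part of the contract.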

class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
import testImplicits._

@@ -394,4 +403,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
checkAnswer(group, Row("bob", Row(true, 3)) :: Nil)
checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3))))
}

test("SPARK-30590: select multiple typed column expressions") {
val df = Seq((1, 2, 3, 4, 5, 6)).toDF("a", "b", "c", "d", "e", "f")
val fooAgg = (i: Int) => FooAgg(i).toColumn.name(s"foo_agg_$i")

val agg1 = df.select(fooAgg(1), fooAgg(2), fooAgg(3), fooAgg(4), fooAgg(5))
checkDataset(agg1, (3, 5, 7, 9, 11))

val agg2 = df.selectUntyped(fooAgg(1), fooAgg(2), fooAgg(3), fooAgg(4), fooAgg(5), fooAgg(6))
.asInstanceOf[Dataset[(Int, Int, Int, Int, Int, Int)]]
checkDataset(agg2, (3, 5, 7, 9, 11, 13))

val err = intercept[AnalysisException] {
df.select(fooAgg(1), fooAgg(2), fooAgg(3), fooAgg(4), fooAgg(5), fooAgg(6))
Contributor:

Not related to this PR, just a note:

We have 5 overloads of typed select, and typed count is supported in both typed and untyped select. That said, if we add a 6th overload of typed select, it can break queries that call the untyped select with 6 typed counts.

I'm not sure what the best way to move forward is. Maybe we should add new typedSelect methods to disambiguate from the untyped version.

Member Author:

Yeah, to be clear: if we add a 6th overload of typed select, a call to the untyped select with 6 typed counts could return Dataset[(Long, Long, ...)] instead of DataFrame.

I think you meant something like the existing selectUntyped? Although its naming is confusing.

}.getMessage
assert(err.contains("a typed column that cannot be passed in untyped `select` API"))
}
}
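The overload-resolution hazard the reviewers describe above can be reproduced in miniature: fixed-arity overloads for a subtype are more specific than a varargs overload for the supertype, so adding one more fixed-arity overload silently changes which method an existing call site binds to. The classes below are hypothetical stand-ins, not Spark's API:

```scala
// Miniature sketch of the typed/untyped select ambiguity: Dataset has
// fixed-arity TypedColumn overloads of select plus a varargs Column
// overload. Toy stand-in classes below.
class Col
class TypedCol extends Col

object Ds {
  def select(c1: TypedCol): String = "typed-1"
  def select(c1: TypedCol, c2: TypedCol): String = "typed-2"
  def select(cols: Col*): String = "untyped"
}

val t = new TypedCol
// One or two typed columns bind to the more specific typed overloads...
val a = Ds.select(t)       // "typed-1"
val b = Ds.select(t, t)    // "typed-2"
// ...but three fall through to the varargs untyped overload. Adding a
// 3-arg typed overload later would silently rebind this call site.
val c = Ds.select(t, t, t) // "untyped"
```

This mirrors the real situation: calls with up to 5 typed columns hit the typed selects, while a 6-column call falls through to the untyped varargs select, and a future 6-arity typed overload would change that call's result type.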
@@ -597,7 +597,8 @@ class DatasetSuite extends QueryTest
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
ds.groupByKey(_._1).agg(sum("_2").as[Long],
sum($"_2" + 1).as[Long], count("*").as[Long]),
("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
}

@@ -40,7 +40,7 @@ class DeprecatedDatasetAggregatorSuite extends QueryTest with SharedSparkSession
ds.groupByKey(_._1).agg(
typed.sum(_._2),
expr("sum(_2)").as[Long],
count("*")),
count("*").as[Long]),
("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
}
