untyped select API disallows TypedColumn without input type.

viirya committed Feb 25, 2020
1 parent 096ce42 commit 83958fb
Showing 3 changed files with 17 additions and 30 deletions.
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -97,6 +97,17 @@ class TypedColumn[-T, U](
     new TypedColumn[T, U](newExpr, encoder)
   }
 
+  /**
+   * This method is used internally in Spark SQL to check whether this `TypedColumn` still
+   * needs a specific input type and schema to be inserted by `withInputType`.
+   */
+  private[sql] def needInputType: Boolean = {
+    expr.find {
+      case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true
+      case _ => false
+    }.isDefined
+  }
+
   /**
    * Gives the [[TypedColumn]] a name (alias).
    * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated
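For context: `needInputType` returns true while some `TypedAggregateExpression` inside the column still lacks an input deserializer, i.e., while `withInputType` (a `private[sql]` method that the typed `Dataset.select` overloads call) has not yet supplied the input type and schema. Below is a minimal, hedged sketch of the user-visible side of that contract; the aggregator `SumA`, the demo object, and the session setup are illustrative, not part of this commit.

import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Illustrative aggregator over DataFrame rows (hypothetical, in the spirit of
// FooAgg from the test suite below): sums the integer column "a".
object SumA extends Aggregator[Row, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, r: Row): Long = b + r.getAs[Int]("a")
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(b: Long): Long = b
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object NeedInputTypeDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
  import spark.implicits._

  val df = Seq((1, 10), (2, 20)).toDF("a", "b")

  // SumA.toColumn starts out without an input deserializer (needInputType would
  // be true). The typed overload select[U1](c: TypedColumn[T, U1]) calls
  // withInputType internally, so by planning time the check passes.
  df.select(SumA.toColumn.name("sum_a")).show()

  spark.stop()
}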
11 changes: 3 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1432,15 +1432,10 @@ class Dataset[T] private[sql](
   def select(cols: Column*): DataFrame = withPlan {
     val untypedCols = cols.map {
       case typedCol: TypedColumn[_, _] =>
-        val isSimpleEncoder = typedCol.encoder.namedExpressions.head match {
-          case Alias(_: BoundReference, _) if !typedCol.encoder.isSerializedAsStruct => true
-          case _ => false
-        }
-        if (isSimpleEncoder) {
-          // This typed column produces simple type output that can be fit into untyped `DataFrame`.
-          typedCol.withInputType(exprEnc, logicalPlan.output)
+        if (!typedCol.needInputType) {
+          typedCol
         } else {
-          throw new AnalysisException(s"Typed column $typedCol with complex serializer " +
+          throw new AnalysisException(s"Typed column $typedCol that needs input type and schema " +
             "cannot be passed in untyped `select` API. Use the typed `Dataset.select` API instead.")
         }
 
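Why a `TypedColumn` reaches this untyped overload at all: `Dataset` defines typed `select` overloads for one through five `TypedColumn` arguments, and those apply `withInputType` to each column; a sixth typed column pushes overload resolution onto the varargs `select(cols: Column*)` above, where the column still needs its input type. A hedged sketch of both paths, reusing the illustrative `SumA` aggregator and the `spark`/`df` setup from the previous example:

import org.apache.spark.sql.AnalysisException

val c = SumA.toColumn

// One to five TypedColumns pick the typed overloads (withInputType is applied),
// so this plans and runs fine:
df.select(c.name("s1"), c.name("s2")).show()

// Six TypedColumns resolve to select(cols: Column*); each still needs its input
// type and schema, so the guard added here fails fast with AnalysisException:
try df.select(c, c, c, c, c, c)
catch { case e: AnalysisException => println(e.getMessage) }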
25 changes: 3 additions & 22 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -228,15 +228,6 @@ case class FooAgg(s: Int) extends Aggregator[Row, Int, Int] {
   def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }
 
-case class FooAggWithComplexOutput(s: Int) extends Aggregator[Row, Int, (Int, Int)] {
-  def zero: Int = s
-  def reduce(b: Int, r: Row): Int = b + r.getAs[Int](0)
-  def merge(b1: Int, b2: Int): Int = b1 + b2
-  def finish(b: Int): (Int, Int) = (1, b)
-  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
-  def outputEncoder: Encoder[(Int, Int)] = ExpressionEncoder()
-}
-
 class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
 
@@ -413,28 +404,18 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
     checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3))))
   }
 
-  test("SPARK-30590: untyped select should not accept complex typed column expressions") {
+  test("SPARK-30590: untyped select should not accept typed column without input type") {
     val df = Seq((1, 2, 3, 4, 5, 6)).toDF("a", "b", "c", "d", "e", "f")
     val fooAgg = (i: Int) => FooAgg(i).toColumn.name(s"foo_agg_$i")
 
     val agg1 = df.select(fooAgg(1), fooAgg(2), fooAgg(3), fooAgg(4), fooAgg(5))
     checkDataset(agg1, (3, 5, 7, 9, 11))
 
     // Passes typed columns to untyped `Dataset.select` API.
-    val agg2 = df.select(fooAgg(1), fooAgg(2), fooAgg(3), fooAgg(4), fooAgg(5), fooAgg(6))
-    checkAnswer(agg2, Row(3, 5, 7, 9, 11, 13) :: Nil)
-
-    val complexFooAgg = (i: Int) => FooAggWithComplexOutput(i).toColumn.name(s"foo_agg_$i")
     val err = intercept[AnalysisException] {
-      df.select(
-        complexFooAgg(1),
-        complexFooAgg(2),
-        complexFooAgg(3),
-        complexFooAgg(4),
-        complexFooAgg(5),
-        complexFooAgg(6))
+      df.select(fooAgg(1), fooAgg(2), fooAgg(3), fooAgg(4), fooAgg(5), fooAgg(6))
     }.getMessage
-    assert(err.contains("with complex serializer cannot be passed in untyped `select` API. " +
+    assert(err.contains("cannot be passed in untyped `select` API. " +
       "Use the typed `Dataset.select` API instead."))
   }
 }
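The `FooAggWithComplexOutput` helper could be dropped because the guard no longer keys off the shape of the output serializer: any typed column that has not been given its input type is rejected, even one with a simple `Int` output like `FooAgg`. For code that previously relied on six-plus simple typed columns slipping through the untyped `select`, one workable route is plain untyped aggregates, which never carry a `TypedAggregateExpression`. A hedged migration sketch, assuming the same `df` as in the test; the column pairings are illustrative, since `FooAgg`'s body is not shown in this hunk:

import org.apache.spark.sql.functions.{col, lit, sum}

// Global aggregation through untyped select: each expression below stands in
// for one FooAgg(i)-style column ("initial value plus an aggregate").
val migrated = df.select(
  (lit(1) + sum(col("b"))).as("foo_agg_1"),
  (lit(2) + sum(col("c"))).as("foo_agg_2"),
  (lit(3) + sum(col("d"))).as("foo_agg_3"),
  (lit(4) + sum(col("e"))).as("foo_agg_4"),
  (lit(5) + sum(col("f"))).as("foo_agg_5"),
  (lit(6) + sum(col("a"))).as("foo_agg_6"))
migrated.show()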
