From 30e89111d673776a6b59b11cdb29ab8713ba6f7c Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 2 Aug 2015 20:12:03 -0700
Subject: [PATCH] [SPARK-9546][SQL] Centralize orderable data type checking.

This pull request creates two isOrderable functions in RowOrdering that can be
used to check whether a data type or a sequence of expressions can be used in
sorting.

Author: Reynold Xin

Closes #7880 from rxin/SPARK-9546 and squashes the following commits:

f9e322d [Reynold Xin] Fixed tests.
0439b43 [Reynold Xin] [SPARK-9546][SQL] Centralize orderable data type checking.
---
 .../sql/catalyst/analysis/CheckAnalysis.scala |  8 +-
 .../expressions/ExpectsInputTypes.scala       |  4 +-
 .../sql/catalyst/expressions/Expression.scala |  2 +-
 .../catalyst/expressions/RowOrdering.scala    | 93 +++++++++++++++++++
 .../sql/catalyst/expressions/SortOrder.scala  |  9 ++
 .../expressions/codegen/CodeGenerator.scala   | 12 ++-
 .../codegen/GenerateOrdering.scala            |  2 -
 .../expressions/collectionOperations.scala    | 21 +++--
 .../spark/sql/catalyst/expressions/rows.scala | 44 ---------
 .../spark/sql/catalyst/util/TypeUtils.scala   | 27 +++---
 .../apache/spark/sql/types/StructType.scala   | 12 ---
 .../analysis/AnalysisErrorSuite.scala         | 14 +--
 .../ExpressionTypeCheckingSuite.scala         | 50 +++++-----
 .../spark/sql/execution/SparkStrategies.scala | 14 +--
 .../spark/sql/DataFrameFunctionsSuite.scala   |  5 +-
 15 files changed, 173 insertions(+), 144 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 364569d8f0b40..187b238045f85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -130,11 +130,9 @@ trait CheckAnalysis {
 
         case Sort(orders, _, _) =>
           orders.foreach { order =>
-            order.dataType match {
-              case t: AtomicType => // OK
-              case NullType => // OK
-              case t =>
-                failAnalysis(s"Sorting is not supported for columns of type ${t.simpleString}")
+            if (!RowOrdering.isOrderable(order.dataType)) {
+              failAnalysis(
+                s"sorting is not supported for columns of type ${order.dataType.simpleString}")
             }
           }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index abe6457747550..2dcbd4eb15031 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -44,8 +44,8 @@ trait ExpectsInputTypes extends Expression {
   override def checkInputDataTypes(): TypeCheckResult = {
     val mismatches = children.zip(inputTypes).zipWithIndex.collect {
       case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
-        s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
-          s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
+        s"argument ${idx + 1} requires ${expected.simpleString} type, " +
+          s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type."
     }
 
     if (mismatches.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2842b3ec5a0c8..ef2fc2e8c29d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -420,7 +420,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
       TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
         s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
     } else if (!inputType.acceptsType(left.dataType)) {
-      TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
+      TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," +
         s" not ${left.dataType.simpleString}")
     } else {
       TypeCheckResult.TypeCheckSuccess
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala
new file mode 100644
index 0000000000000..873f5324c573e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+
+/**
+ * An interpreted row ordering comparator.
+ */
+class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
+
+  def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
+    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+
+  def compare(a: InternalRow, b: InternalRow): Int = {
+    var i = 0
+    while (i < ordering.size) {
+      val order = ordering(i)
+      val left = order.child.eval(a)
+      val right = order.child.eval(b)
+
+      if (left == null && right == null) {
+        // Both null, continue looking.
+      } else if (left == null) {
+        return if (order.direction == Ascending) -1 else 1
+      } else if (right == null) {
+        return if (order.direction == Ascending) 1 else -1
+      } else {
+        val comparison = order.dataType match {
+          case dt: AtomicType if order.direction == Ascending =>
+            dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case dt: AtomicType if order.direction == Descending =>
+            dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case s: StructType if order.direction == Ascending =>
+            s.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case s: StructType if order.direction == Descending =>
+            s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case other =>
+            throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+        }
+        if (comparison != 0) {
+          return comparison
+        }
+      }
+      i += 1
+    }
+    return 0
+  }
+}
+
+object RowOrdering {
+
+  /**
+   * Returns true iff the data type can be ordered (i.e. can be sorted).
+   */
+  def isOrderable(dataType: DataType): Boolean = dataType match {
+    case NullType => true
+    case dt: AtomicType => true
+    case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
+    case _ => false
+  }
+
+  /**
+   * Returns true iff outputs from the expressions can be ordered.
+   */
+  def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType))
+
+  /**
+   * Creates a [[RowOrdering]] for the given schema, in natural ascending order.
+   */
+  def forSchema(dataTypes: Seq[DataType]): RowOrdering = {
+    new RowOrdering(dataTypes.zipWithIndex.map {
+      case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+    })
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 5eb5b0d176fc1..f6a872ba446eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
@@ -36,6 +37,14 @@ case class SortOrder(child: Expression, direction: SortDirection)
   /** Sort order is not foldable because we don't have an eval for it. */
   override def foldable: Boolean = false
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (RowOrdering.isOrderable(dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}")
+    }
+  }
+
   override def dataType: DataType = child.dataType
 
   override def nullable: Boolean = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3c91227d06080..03ec4b4b4ec55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -220,7 +220,11 @@ class CodeGenContext {
   }
 
   /**
-   * Generates code for compare expression in Java.
+   * Generates code for comparing two expressions.
+   *
+   * @param dataType data type of the expressions
+   * @param c1 name of the variable of expression 1's output
+   * @param c2 name of the variable of expression 2's output
    */
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
@@ -231,7 +235,7 @@ class CodeGenContext {
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case NullType => "0"
-    case schema: StructType if schema.supportOrdering(schema) =>
+    case schema: StructType =>
       val comparisons = GenerateOrdering.genComparisons(this, schema)
       val compareFunc = freshName("compareStruct")
       val funcCode: String =
@@ -245,8 +249,8 @@
       addNewFunction(compareFunc, funcCode)
       s"this.$compareFunc($c1, $c2)"
     case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
-    case _ => throw new IllegalArgumentException(
-      "cannot generate compare code for un-comparable type")
+    case _ =>
+      throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 4da91ed8d7752..42be394c3bf5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Private
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.StructType
@@ -26,7 +25,6 @@ import org.apache.spark.sql.types.StructType
 /**
  * Inherits some default implementation for Java from `Ordering[Row]`
  */
-@Private
 class BaseOrdering extends Ordering[InternalRow] {
   def compare(a: InternalRow, b: InternalRow): Int = {
     throw new UnsupportedOperationException
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 80b8da23e880b..6ccb56578f790 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -20,6 +20,7 @@ import java.util.Comparator
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -54,15 +55,17 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
 
   override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
-    case _ @ ArrayType(n: AtomicType, _) => TypeCheckResult.TypeCheckSuccess
-    case _ @ ArrayType(n, _) => TypeCheckResult.TypeCheckFailure(
-      s"Type $n is not the AtomicType, we can not perform the ordering operations")
-    case other =>
-      TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other")
+    case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
+      TypeCheckResult.TypeCheckSuccess
+    case ArrayType(dt, _) =>
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName does not support sorting array of type ${dt.simpleString}")
+    case _ =>
+      TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
   }
 
   @transient
-  private lazy val lt = {
+  private lazy val lt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
     }
@@ -83,7 +86,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   }
 
   @transient
-  private lazy val gt = {
+  private lazy val gt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
     }
@@ -106,9 +109,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   override def nullSafeEval(array: Any, ascending: Any): Any = {
     val elementType = base.dataType.asInstanceOf[ArrayType].elementType
     val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
-    java.util.Arrays.sort(
-      data,
-      if (ascending.asInstanceOf[Boolean]) lt else gt)
+    java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt)
     new GenericArrayData(data.asInstanceOf[Array[Any]])
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 7e1031c755f83..d04434b953e41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -121,47 +121,3 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow {
 
   override def copy(): InternalRow = new GenericInternalRow(values.clone())
 }
-
-class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
-  def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
-
-  def compare(a: InternalRow, b: InternalRow): Int = {
-    var i = 0
-    while (i < ordering.size) {
-      val order = ordering(i)
-      val left = order.child.eval(a)
-      val right = order.child.eval(b)
-
-      if (left == null && right == null) {
-        // Both null, continue looking.
-      } else if (left == null) {
-        return if (order.direction == Ascending) -1 else 1
-      } else if (right == null) {
-        return if (order.direction == Ascending) 1 else -1
-      } else {
-        val comparison = order.dataType match {
-          case n: AtomicType if order.direction == Ascending =>
-            n.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
-          case n: AtomicType if order.direction == Descending =>
-            n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-          case s: StructType if order.direction == Ascending =>
-            s.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
-          case s: StructType if order.direction == Descending =>
-            s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-          case other => sys.error(s"Type $other does not support ordered operations")
-        }
-        if (comparison != 0) return comparison
-      }
-      i += 1
-    }
-    return 0
-  }
-}
-
-object RowOrdering {
-  def forSchema(dataTypes: Seq[DataType]): RowOrdering =
-    new RowOrdering(dataTypes.zipWithIndex.map {
-      case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
-    })
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 2f50d40fe25ac..0b41f92c6193c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -18,39 +18,34 @@
 package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.types._
 
 /**
  * Helper functions to check for valid data types.
  */
 object TypeUtils {
-  def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = {
-    if (t.isInstanceOf[NumericType] || t == NullType) {
+  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
+    if (dt.isInstanceOf[NumericType] || dt == NullType) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t")
+      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
     }
   }
 
-  def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
-    t match {
-      case i: AtomicType => TypeCheckResult.TypeCheckSuccess
-      case n: NullType => TypeCheckResult.TypeCheckSuccess
-      case s: StructType =>
-        if (s.supportOrdering(s)) {
-          TypeCheckResult.TypeCheckSuccess
-        } else {
-          TypeCheckResult.TypeCheckFailure(s"Fields in $s do not support ordering")
-        }
-      case other => TypeCheckResult.TypeCheckFailure(s"$t doesn't support ordering on $caller")
+  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
+    if (RowOrdering.isOrderable(dt)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
     }
   }
 
   def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
     if (types.distinct.size > 1) {
       TypeCheckResult.TypeCheckFailure(
-        s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
+        s"input to $caller should all be the same type, but it's " +
+          types.map(_.simpleString).mkString("[", ", ", "]"))
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 2f23144858198..6928707f7bf6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -302,18 +302,6 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
   }
 
   private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType))
-
-  private[sql] def supportOrdering(s: StructType): Boolean = {
-    s.fields.forall { f =>
-      if (f.dataType.isInstanceOf[AtomicType]) {
-        true
-      } else if (f.dataType.isInstanceOf[StructType]) {
-        supportOrdering(f.dataType.asInstanceOf[StructType])
-      } else {
-        false
-      }
-    }
-  }
 }
 
 object StructType extends AbstractDataType {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index aa19cdce31ec6..26935c6e3b24f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -68,22 +68,22 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
   errorTest(
     "single invalid type, single arg",
     testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
-    "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" ::
-      "'null' is of type date" ::Nil)
+    "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" ::
+      "'null' is of date type" ::Nil)
 
   errorTest(
     "single invalid type, second arg",
     testRelation.select(
       TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
-    "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" ::
-      "'null' is of type date" ::Nil)
+    "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" ::
+      "'null' is of date type" ::Nil)
 
   errorTest(
     "multiple invalid type",
     testRelation.select(
       TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
     "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
-      "expected to be of type int" :: "'null' is of type date" ::Nil)
+      "requires int type" :: "'null' is of date type" ::Nil)
 
   errorTest(
     "unresolved window function",
@@ -111,12 +111,12 @@
   errorTest(
     "bad casts",
     testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
-  "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+    "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
 
   errorTest(
     "sorting by unsupported column types",
     listRelation.orderBy('list.asc),
-    "sorting" :: "type" :: "array" :: Nil)
+    "sort" :: "type" :: "array" :: Nil)
 
   errorTest(
     "non-boolean filters",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 8f616ae9d29c3..c9bcc68f02030 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -53,9 +53,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
   }
 
   test("check types for unary arithmetic") {
-    assertError(UnaryMinus('stringField), "type (numeric or calendarinterval)")
calendarinterval)") - assertError(Abs('stringField), "expected to be of type numeric") - assertError(BitwiseNot('stringField), "expected to be of type integral") + assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type") + assertError(Abs('stringField), "requires numeric type") + assertError(BitwiseNot('stringField), "requires integral type") } test("check types for binary arithmetic") { @@ -78,21 +78,21 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "accepts (numeric or calendarinterval) type") + assertError(Add('booleanField, 'booleanField), "requires (numeric or calendarinterval) type") assertError(Subtract('booleanField, 'booleanField), - "accepts (numeric or calendarinterval) type") - assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") - assertError(Divide('booleanField, 'booleanField), "accepts numeric type") - assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") + "requires (numeric or calendarinterval) type") + assertError(Multiply('booleanField, 'booleanField), "requires numeric type") + assertError(Divide('booleanField, 'booleanField), "requires numeric type") + assertError(Remainder('booleanField, 'booleanField), "requires numeric type") - assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type") - assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type") - assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type") assertError(MaxOf('complexField, 'complexField), - s"accepts ${TypeCollection.Ordered.simpleString} type") + s"requires ${TypeCollection.Ordered.simpleString} type") assertError(MinOf('complexField, 'complexField), - s"accepts ${TypeCollection.Ordered.simpleString} type") + s"requires ${TypeCollection.Ordered.simpleString} type") } test("check types for predicates") { @@ -116,13 +116,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) assertError(LessThan('complexField, 'complexField), - s"accepts ${TypeCollection.Ordered.simpleString} type") + s"requires ${TypeCollection.Ordered.simpleString} type") assertError(LessThanOrEqual('complexField, 'complexField), - s"accepts ${TypeCollection.Ordered.simpleString} type") + s"requires ${TypeCollection.Ordered.simpleString} type") assertError(GreaterThan('complexField, 'complexField), - s"accepts ${TypeCollection.Ordered.simpleString} type") + s"requires ${TypeCollection.Ordered.simpleString} type") assertError(GreaterThanOrEqual('complexField, 'complexField), - s"accepts ${TypeCollection.Ordered.simpleString} type") + s"requires ${TypeCollection.Ordered.simpleString} type") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -145,11 +145,11 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(SumDistinct('stringField)) assertSuccess(Average('stringField)) - assertError(Min('complexField), "doesn't support ordering on function min") - assertError(Max('complexField), "doesn't support ordering on function max") 
-    assertError(Sum('booleanField), "function sum accepts numeric type")
-    assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
-    assertError(Average('booleanField), "function average accepts numeric type")
+    assertError(Min('complexField), "min does not support ordering on type")
+    assertError(Max('complexField), "max does not support ordering on type")
+    assertError(Sum('booleanField), "function sum requires numeric type")
+    assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type")
+    assertError(Average('booleanField), "function average requires numeric type")
   }
 
   test("check types for others") {
@@ -181,8 +181,8 @@
     assertSuccess(Round('intField, Literal(1)))
 
     assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
-    assertError(Round('intField, 'booleanField), "expected to be of type int")
-    assertError(Round('intField, 'complexField), "expected to be of type int")
-    assertError(Round('booleanField, 'intField), "expected to be of type numeric")
+    assertError(Round('intField, 'booleanField), "requires int type")
+    assertError(Round('intField, 'complexField), "requires int type")
+    assertError(Round('booleanField, 'intField), "requires numeric type")
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4aff52d992e6b..952ba7d45c13e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -89,18 +89,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
     }
 
-    private[this] def isValidSort(
-        leftKeys: Seq[Expression],
-        rightKeys: Seq[Expression]): Boolean = {
-      leftKeys.zip(rightKeys).forall { keys =>
-        (keys._1.dataType, keys._2.dataType) match {
-          case (l: AtomicType, r: AtomicType) => true
-          case (NullType, NullType) => true
-          case _ => false
-        }
-      }
-    }
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
@@ -111,7 +99,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
       // for now let's support inner join first, then add outer join
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.conf.sortMergeJoinEnabled && isValidSort(leftKeys, rightKeys) =>
+        if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
         val mergeJoin =
           joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 46921d14256b9..431dcf7382f16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -305,13 +305,12 @@ class DataFrameFunctionsSuite extends QueryTest {
     val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
"x")).toDF("a", "b") assert(intercept[AnalysisException] { df2.selectExpr("sort_array(a)").collect() - }.getMessage().contains("Type ArrayType(IntegerType,false) is not the AtomicType, " + - "we can not perform the ordering operations")) + }.getMessage().contains("does not support sorting array of type array")) val df3 = Seq(("xxx", "x")).toDF("a", "b") assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() - }.getMessage().contains("ArrayType(AtomicType) is expected, but we got StringType")) + }.getMessage().contains("only supports array input")) } test("array size function") {