From 10fac82d5aef8b90300cc83b29c5a54f0fd18ab6 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Wed, 8 Jul 2015 16:33:21 -0700
Subject: [PATCH 1/4] [SPARK-8926][SQL] Good errors for ExpectsInputTypes
 expressions

---
 .../catalyst/analysis/HiveTypeCoercion.scala  |  12 +-
 .../expressions/ExpectsInputTypes.scala       |  13 +-
 .../spark/sql/types/AbstractDataType.scala    |  30 +++-
 .../apache/spark/sql/types/ArrayType.scala    |   8 +-
 .../org/apache/spark/sql/types/DataType.scala |   4 +-
 .../apache/spark/sql/types/DecimalType.scala  |   8 +-
 .../org/apache/spark/sql/types/MapType.scala  |   8 +-
 .../apache/spark/sql/types/StructType.scala   |   8 +-
 .../analysis/AnalysisErrorSuite.scala         | 167 ++++++++++++++++++
 .../sql/catalyst/analysis/AnalysisSuite.scala | 126 ++-----------
 .../analysis/HiveTypeCoercionSuite.scala      |   8 +
 .../apache/spark/sql/hive/HiveContext.scala   |   2 +-
 12 files changed, 252 insertions(+), 142 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5367b7f3308ee..8cb71995eb818 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -702,11 +702,19 @@ object HiveTypeCoercion {
     @Nullable val ret: Expression = (inType, expectedType) match {

       // If the expected type is already a parent of the input type, no need to cast.
-      case _ if expectedType.isParentOf(inType) => e
+      case _ if expectedType.isSameType(inType) => e

       // Cast null type (usually from null literals) into target types
       case (NullType, target) => Cast(e, target.defaultConcreteType)

+      // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
+      // already a number, leave it as is.
+      case (_: NumericType, NumericType) => e
+
+      // If the function accepts any numeric type and the input is a string, we follow the Hive
+      // convention and cast that input into a double.
+      case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
+
       // Implicit cast among numeric types
       // If input is a numeric type but not decimal, and we expect a decimal type,
       // cast the input to unlimited precision decimal.
@@ -732,7 +740,7 @@ object HiveTypeCoercion {
       // First see if we can find our input type in the type collection. If we can, then just
       // use the current expression; otherwise, find the first one we can implicitly cast.
       case (_, TypeCollection(types)) =>
-        if (types.exists(_.isParentOf(inType))) {
+        if (types.exists(_.isSameType(inType))) {
           e
         } else {
           types.flatMap(implicitCast(e, _)).headOption.orNull
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index 916e30154d4f1..14bd4eb0e4a58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -37,7 +37,16 @@ trait ExpectsInputTypes { self: Expression =>
   def inputTypes: Seq[AbstractDataType]

   override def checkInputDataTypes(): TypeCheckResult = {
-    // TODO: implement proper type checking.
-    TypeCheckResult.TypeCheckSuccess
+    val mismatches = children.zip(inputTypes).zipWithIndex.collect {
+      case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
+        s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
+          s"however, ${child.prettyString} is of type ${child.dataType}."
+    }
+
+    if (mismatches.isEmpty) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
+    }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index fb1b47e946214..ad75fa2e31d90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -34,9 +34,16 @@ private[sql] abstract class AbstractDataType {
   private[sql] def defaultConcreteType: DataType

   /**
-   * Returns true if this data type is a parent of the `childCandidate`.
+   * Returns true if this data type is the same type as `other`. This is different from equality,
+   * as equality also considers data type parametrization, such as decimal precision.
    */
-  private[sql] def isParentOf(childCandidate: DataType): Boolean
+  private[sql] def isSameType(other: DataType): Boolean
+
+  /**
+   * Returns true if `other` is an acceptable input type for a function that expects this,
+   * possibly abstract, DataType.
+   */
+  private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)

   /** Readable string representation for the type. */
   private[sql] def simpleString: String
@@ -58,11 +65,14 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])

   require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")

-  private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType
+  override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType
+
+  override private[sql] def isSameType(other: DataType): Boolean = false

-  private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    types.exists(_.isSameType(other))

-  private[sql] override def simpleString: String = {
+  override private[sql] def simpleString: String = {
     types.map(_.simpleString).mkString("(", " or ", ")")
   }
 }
@@ -108,7 +118,7 @@ abstract class NumericType extends AtomicType {
 }


-private[sql] object NumericType {
+private[sql] object NumericType extends AbstractDataType {
   /**
    * Enables matching against NumericType for expressions:
    * {{{
@@ -117,6 +127,14 @@ private[sql] object NumericType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+
+  override private[sql] def defaultConcreteType: DataType = DoubleType
+
+  override private[sql] def simpleString: String = "numeric"
+
+  override private[sql] def isSameType(other: DataType): Boolean = false
+
+  override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
 }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 43413ec761e6b..76ca7a84c1d1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -26,13 +26,13 @@ object ArrayType extends AbstractDataType {
   /** Construct a
[[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) - private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) + override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[ArrayType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[ArrayType] } - private[sql] override def simpleString: String = "array" + override private[sql] def simpleString: String = "array" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7d00047d08d74..806c48ccc9a6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -75,9 +75,9 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType - private[sql] override def defaultConcreteType: DataType = this + override private[sql] def defaultConcreteType: DataType = this - private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate + override private[sql] def isSameType(other: DataType): Boolean = this == other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 127b16ff85bed..a1cafeab1704d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -84,13 +84,13 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ object DecimalType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = Unlimited + override private[sql] def defaultConcreteType: DataType = Unlimited - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[DecimalType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[DecimalType] } - private[sql] override def simpleString: String = "decimal" + override private[sql] def simpleString: String = "decimal" val Unlimited: DecimalType = DecimalType(None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 868dea13d971e..ddead10bc2171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -69,13 +69,13 @@ case class MapType( object MapType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType) + override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[MapType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[MapType] } - private[sql] override def simpleString: String = "map" + override private[sql] def simpleString: String = "map" /** * Construct a [[MapType]] 
object with the given key type and value type. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3b17566d54d9b..538c675566914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -303,13 +303,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = new StructType + override private[sql] def defaultConcreteType: DataType = new StructType - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[StructType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[StructType] } - private[sql] override def simpleString: String = "struct" + override private[sql] def simpleString: String = "struct" def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala new file mode 100644 index 0000000000000..85960bbdf9f25 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.{InternalRow, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +case class TestFunction( + children: Seq[Expression], + inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = ??? 
+ override def dataType: DataType = StringType +} + +case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output: Seq[Attribute] = Nil +} + +class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { + import AnalysisSuite._ + + def errorTest( + name: String, + plan: LogicalPlan, + errorMessages: Seq[String], + caseSensitive: Boolean = true): Unit = { + test(name) { + val error = intercept[AnalysisException] { + if (caseSensitive) { + caseSensitiveAnalyze(plan) + } else { + caseInsensitiveAnalyze(plan) + } + } + + errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) + } + } + + val dateLit = Literal.create(null, DateType) + + errorTest( + "single invalid type, single arg", + testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" :: + "null is of type datetype" ::Nil) + + errorTest( + "single invalid type, second arg", + testRelation.select( + TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" :: + "null is of type datetype" ::Nil) + + errorTest( + "multiple invalid type", + testRelation.select( + TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: + "expected to be of type int" :: "null is of type datetype" ::Nil) + + errorTest( + "unresolved window function", + testRelation2.select( + WindowExpression( + UnresolvedWindowFunction( + "lead", + UnresolvedAttribute("c") :: Nil), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "lead" :: "window functions currently requires a HiveContext" :: Nil) + + errorTest( + "too many generators", + listRelation.select(Explode('list).as('a), Explode('list).as('b)), + "only one generator" :: "explode" :: Nil) + + errorTest( + "unresolved attributes", + testRelation.select('abcd), + "cannot resolve" :: "abcd" :: Nil) + + errorTest( + "bad casts", + testRelation.select(Literal(1).cast(BinaryType).as('badCast)), + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + + errorTest( + "non-boolean filters", + testRelation.where(Literal(1)), + "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + + errorTest( + "missing group by", + testRelation2.groupBy('a)('b), + "'b'" :: "group by" :: Nil + ) + + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + + errorTest( + "catch all unresolved plan", + UnresolvedTestPlan(), + "unresolved" :: Nil) + + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + 
LocalRelation( + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 77ca080f366cd..58df1de983a09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { +object AnalysisSuite { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -61,25 +61,28 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil ))()) val nestedRelation2 = LocalRelation( AttributeReference("top", StructType( StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil ))()) val listRelation = LocalRelation( AttributeReference("list", ArrayType(IntegerType))()) - before { - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - } + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) +} + + +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { + import AnalysisSuite._ test("union project *") { val plan = (1 to 100) @@ -149,91 +152,6 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } - def errorTest( - name: String, - plan: LogicalPlan, - errorMessages: Seq[String], - caseSensitive: Boolean = true): Unit = { - test(name) { - val error = intercept[AnalysisException] { - if (caseSensitive) { - caseSensitiveAnalyze(plan) - } else { - caseInsensitiveAnalyze(plan) - } - } - - errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) - } - } - - errorTest( - "unresolved window function", - testRelation2.select( - WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), - WindowSpecDefinition( - UnresolvedAttribute("a") :: Nil, - SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, - UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) - - errorTest( - "too many generators", - listRelation.select(Explode('list).as('a), Explode('list).as('b)), - "only one generator" :: "explode" :: Nil) - - errorTest( - "unresolved attributes", - testRelation.select('abcd), - "cannot 
resolve" :: "abcd" :: Nil) - - errorTest( - "bad casts", - testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) - - errorTest( - "non-boolean filters", - testRelation.where(Literal(1)), - "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) - - errorTest( - "missing group by", - testRelation2.groupBy('a)('b), - "'b'" :: "group by" :: Nil - ) - - errorTest( - "ambiguous field", - nestedRelation.select($"top.duplicateField"), - "Ambiguous reference to fields" :: "duplicateField" :: Nil, - caseSensitive = false) - - errorTest( - "ambiguous field due to case insensitivity", - nestedRelation.select($"top.differentCase"), - "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, - caseSensitive = false) - - errorTest( - "missing field", - nestedRelation2.select($"top.c"), - "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, - caseSensitive = false) - - case class UnresolvedTestPlan() extends LeafNode { - override lazy val resolved = false - override def output: Seq[Attribute] = Nil - } - - errorTest( - "catch all unresolved plan", - UnresolvedTestPlan(), - "unresolved" :: Nil) - test("divide should be casted into fractional types") { val testRelation2 = LocalRelation( @@ -258,22 +176,4 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } - - test("SPARK-6452 regression test") { - // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) - val plan = - Aggregate( - Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, - LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) - - assert(plan.resolved) - - val message = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - }.getMessage - - assert(message.contains("resolved attribute(s) a#1 missing from a#2")) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 93db33d44eb25..65c9d31ef4199 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -77,6 +77,14 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) + + shouldCast(StringType, NumericType, DoubleType) + + // NumericType should not be changed when function accepts any of them. 
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + DecimalType.Unlimited, DecimalType(10,2)).foreach { tpe => + shouldCast(tpe, NumericType, tpe) + } } test("ineligible implicit type cast") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 439d8cab5f257..bbc39b892b79e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -359,7 +359,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { hiveconf.set(key, value) } - private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { setConf(entry.key, entry.stringConverter(value)) } From 5428fda5a2d7d0d958f5dc47132f096a94df9113 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 8 Jul 2015 16:45:41 -0700 Subject: [PATCH 2/4] style --- .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +- .../spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 85960bbdf9f25..683a06b746e51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -32,7 +32,7 @@ case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = ??? + override def eval(input: InternalRow): Any = throw new NotImplementedError override def dataType: DataType = StringType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 65c9d31ef4199..6e3aa0eebeb15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -82,7 +82,7 @@ class HiveTypeCoercionSuite extends PlanTest { // NumericType should not be changed when function accepts any of them. 
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - DecimalType.Unlimited, DecimalType(10,2)).foreach { tpe => + DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => shouldCast(tpe, NumericType, tpe) } } From 137160d2769bb6e0a3a5f2f84fd1d903ed59099a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 8 Jul 2015 16:46:51 -0700 Subject: [PATCH 3/4] style --- .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 683a06b746e51..8f4c1bcad3fc0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -32,7 +32,7 @@ case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = throw new NotImplementedError + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def dataType: DataType = StringType } From c654a0e4f308d53e82ab78e6679eabd43649d9db Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 8 Jul 2015 17:50:51 -0700 Subject: [PATCH 4/4] fix udts and make errors pretty --- .../spark/sql/catalyst/expressions/ExpectsInputTypes.scala | 2 +- .../scala/org/apache/spark/sql/types/UserDefinedType.scala | 5 ++++- .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 6 +++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 14bd4eb0e4a58..986cc09499d1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -40,7 +40,7 @@ trait ExpectsInputTypes { self: Expression => val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, ${child.prettyString} is of type ${child.dataType}." + s"however, ${child.prettyString} is of type ${child.dataType.simpleString}." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 6b20505c6009a..e47cfb4833bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -77,5 +77,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * For UDT, asNullable will not change the nullability of its internal sqlType and just returns * itself. 
*/ - private[spark] override def asNullable: UserDefinedType[UserType] = this + override private[spark] def asNullable: UserDefinedType[UserType] = this + + override private[sql] def acceptsType(dataType: DataType) = + this.getClass == dataType.getClass } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8f4c1bcad3fc0..73236c3acbca2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -68,21 +68,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" :: - "null is of type datetype" ::Nil) + "null is of type date" ::Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" :: - "null is of type datetype" ::Nil) + "null is of type date" ::Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "expected to be of type int" :: "null is of type datetype" ::Nil) + "expected to be of type int" :: "null is of type date" ::Nil) errorTest( "unresolved window function",
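
The heart of this series is the new `checkInputDataTypes` in patch 1. The following self-contained sketch mirrors that zip/collect pattern outside of Catalyst; `SimpleType` and `Arg` are hypothetical stand-ins for `AbstractDataType` and `Expression`, not Spark classes, and only the pairing logic and message format follow the patch.

```scala
// Self-contained sketch of the checkInputDataTypes logic above; SimpleType
// and Arg are hypothetical stand-ins, not Catalyst classes.
sealed trait TypeCheckResult
case object TypeCheckSuccess extends TypeCheckResult
case class TypeCheckFailure(message: String) extends TypeCheckResult

case class SimpleType(simpleString: String) {
  // A concrete type accepts only itself; Spark's abstract types widen this.
  def acceptsType(other: SimpleType): Boolean = this == other
}
case class Arg(prettyString: String, dataType: SimpleType)

object CheckInputTypesSketch extends App {
  def check(children: Seq[Arg], inputTypes: Seq[SimpleType]): TypeCheckResult = {
    // Pair each argument with its expected type, keep only the failing pairs
    // (remembering their positions), and render one message per failure.
    val mismatches = children.zip(inputTypes).zipWithIndex.collect {
      case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
        s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
          s"however, ${child.prettyString} is of type ${child.dataType.simpleString}."
    }
    if (mismatches.isEmpty) TypeCheckSuccess
    else TypeCheckFailure(mismatches.mkString(" "))
  }

  val (int, date) = (SimpleType("int"), SimpleType("date"))
  println(check(Seq(Arg("null", date)), Seq(int)))
  // TypeCheckFailure(Argument 1 is expected to be of type int, however, null is of type date.)
}
```

Collecting every mismatch before failing, rather than stopping at the first one, is what lets the analyzer report "argument 1" and "argument 2" problems in a single error, as the "multiple invalid type" test above expects.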
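The `isSameType` / `acceptsType` split introduced in `AbstractDataType` can be illustrated the same way. The miniature hierarchy below is a hypothetical reduction, not Spark's API: parametrized decimals count as the same type while comparing unequal, and the abstract `NumericT` (standing in for the `NumericType` companion object) accepts every concrete numeric type, with strings falling back to double per the Hive convention in `implicitCast`.

```scala
// Hypothetical miniature of the AbstractDataType design in patch 1; the
// names (AbstractType, NumericT, ...) are stand-ins, not Spark's classes.
sealed trait AbstractType {
  def isSameType(other: ConcreteType): Boolean
  // By default a function input must be the same type; abstract types widen this.
  def acceptsType(other: ConcreteType): Boolean = isSameType(other)
}
sealed trait ConcreteType extends AbstractType {
  override def isSameType(other: ConcreteType): Boolean = this == other
}
case object StringT extends ConcreteType
sealed trait NumericConcrete extends ConcreteType
case object IntT extends NumericConcrete
case object DoubleT extends NumericConcrete
case class DecimalT(precision: Int, scale: Int) extends NumericConcrete {
  // Same type regardless of parametrization, unlike equality.
  override def isSameType(other: ConcreteType): Boolean = other.isInstanceOf[DecimalT]
}
case object NumericT extends AbstractType {
  // Never "the same" as any concrete type, but accepts every numeric input.
  override def isSameType(other: ConcreteType): Boolean = false
  override def acceptsType(other: ConcreteType): Boolean = other.isInstanceOf[NumericConcrete]
  val defaultConcreteType: ConcreteType = DoubleT
}

object CoercionSketch extends App {
  // Mirrors the two new implicitCast cases: numeric inputs pass through
  // unchanged, and strings follow the Hive convention of casting to double.
  def implicitCastTo(expected: AbstractType, in: ConcreteType): Option[ConcreteType] =
    (in, expected) match {
      case _ if expected.acceptsType(in) => Some(in)
      case (StringT, NumericT)           => Some(NumericT.defaultConcreteType)
      case _                             => None
    }

  assert(DecimalT(10, 2).isSameType(DecimalT(38, 18)))        // same type...
  assert(DecimalT(10, 2) != DecimalT(38, 18))                 // ...but not equal
  assert(implicitCastTo(NumericT, IntT) == Some(IntT))        // numbers left as is
  assert(implicitCastTo(NumericT, StringT) == Some(DoubleT))  // string -> double
  assert(implicitCastTo(IntT, StringT).isEmpty)               // no valid cast
}
```

This separation is why `shouldCast(tpe, NumericType, tpe)` holds for every concrete numeric type in the HiveTypeCoercionSuite additions: acceptance short-circuits before any cast is inserted.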