From 5b9b9013826d126b8d7ce986515f395d147acb91 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 4 Jul 2018 17:25:50 +0900 Subject: [PATCH 1/4] Type coercion between StructTypes. --- .../sql/catalyst/analysis/TypeCoercion.scala | 18 ++++ .../catalyst/analysis/TypeCoercionSuite.scala | 91 ++++++++++++++++++- 2 files changed, 104 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cf90e6e555fc8..099be788cd7ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -185,6 +185,15 @@ object TypeCoercion { MapType(kt, vt, valueContainsNull1 || valueContainsNull2) } } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findWiderTypeForTwo(field1.dataType, field2.dataType).map { + dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + } + case _ => None + } case _ => None }) } @@ -232,6 +241,15 @@ object TypeCoercion { MapType(kt, vt, valueContainsNull1 || valueContainsNull2) } } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findWiderTypeWithoutStringPromotionForTwo(field1.dataType, field2.dataType).map { + dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + } + case _ => None + } case _ => None }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 4e5ca1b8cdd36..d0724dd3f92ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,7 +54,7 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: StructType* is castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable when all the internal child types are castable according to the table. // Note: ArrayType* is castable when the element type is castable according to the table. // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit @@ -454,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest { def widenTestWithStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric) } def widenTestWithoutStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) } // Decimal @@ -492,6 +495,10 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(MapType(IntegerType, FloatType), containsNull = false), ArrayType(MapType(LongType, DoubleType), containsNull = false), Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(new StructType().add("num", ShortType), containsNull = false), + ArrayType(new StructType().add("num", LongType), containsNull = false), + Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) // MapType widenTestWithStringPromotion( @@ -506,6 +513,64 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), + MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), + Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + + // StructType + widenTestWithStringPromotion( + new StructType() + .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false), + new StructType() + .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true), + Some(new StructType() + .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true))) + widenTestWithStringPromotion( + new StructType() + .add("arr", ArrayType(ShortType, containsNull = false), nullable = false), + new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false), + Some(new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType() + .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false), + new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), + Some(new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + + widenTestWithStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) @@ -520,6 +585,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) widenTestWithoutStringPromotion( MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -544,6 +617,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + Some(new StructType().add("a", StringType))) + widenTestWithStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + Some(new StructType().add("a", StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { From bed6849dbf51e1981772cd353ce1a7ae4f0626e2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 5 Jul 2018 14:40:14 +0900 Subject: [PATCH 2/4] Reduce code duplication. --- .../sql/catalyst/analysis/TypeCoercion.scala | 67 +++++++------------ 1 file changed, 26 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 099be788cd7ad..5e3d5de30adba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -166,6 +166,30 @@ object TypeCoercion { case (l, r) => None } + private def mergeComplexTypes( + t1: DataType, + t2: DataType, + mergeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + mergeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + mergeFunc(kt1, kt2).flatMap { kt => + mergeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + mergeFunc(field1.dataType, field2.dataType).map { + dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + } + case _ => None + } + case _ => None + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -176,26 +200,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) - case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findWiderTypeForTwo(kt1, kt2).flatMap { kt => - findWiderTypeForTwo(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } - } - case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => - val resolver = SQLConf.get.resolver - fields1.zip(fields2).foldLeft(Option(new StructType())) { - case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => - findWiderTypeForTwo(field1.dataType, field2.dataType).map { - dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) - } - case _ => None - } - case _ => None - }) + .orElse(mergeComplexTypes(t1, t2, findWiderTypeForTwo)) } /** @@ -231,27 +236,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) - .map(ArrayType(_, containsNull1 || containsNull2)) - case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt => - findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } - } - case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => - val resolver = SQLConf.get.resolver - fields1.zip(fields2).foldLeft(Option(new StructType())) { - case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => - findWiderTypeWithoutStringPromotionForTwo(field1.dataType, field2.dataType).map { - dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) - } - case _ => None - } - case _ => None - }) + .orElse(mergeComplexTypes(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { From d1e5a78f5204cd10338df80b479a6f6d77a44e29 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 5 Jul 2018 19:54:11 +0900 Subject: [PATCH 3/4] Rename to findWiderTypeForComplex. --- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5e3d5de30adba..43db3f10bcd90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -166,7 +166,7 @@ object TypeCoercion { case (l, r) => None } - private def mergeComplexTypes( + private def findWiderTypeForComplex( t1: DataType, t2: DataType, mergeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { @@ -200,7 +200,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse(mergeComplexTypes(t1, t2, findWiderTypeForTwo)) + .orElse(findWiderTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -236,7 +236,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse(mergeComplexTypes(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) + .orElse(findWiderTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { From 1397e53e52b5b31e350b5759d84746f0fc43f5f9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 5 Jul 2018 22:09:12 +0900 Subject: [PATCH 4/4] Reuse for findTightestCommonType. --- .../sql/catalyst/analysis/TypeCoercion.scala | 36 +++++-------------- .../catalyst/analysis/TypeCoercionSuite.scala | 2 +- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 43db3f10bcd90..b6ca30c7398f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -102,25 +102,7 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - - case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) - - case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) - Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) - - case _ => None + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } /** Promotes all the way to StringType. */ @@ -166,15 +148,15 @@ object TypeCoercion { case (l, r) => None } - private def findWiderTypeForComplex( + private def findTypeForComplex( t1: DataType, t2: DataType, - mergeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - mergeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - mergeFunc(kt1, kt2).flatMap { kt => - mergeFunc(vt1, vt2).map { vt => + findTypeFunc(kt1, kt2).flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => MapType(kt, vt, valueContainsNull1 || valueContainsNull2) } } @@ -182,7 +164,7 @@ object TypeCoercion { val resolver = SQLConf.get.resolver fields1.zip(fields2).foldLeft(Option(new StructType())) { case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => - mergeFunc(field1.dataType, field2.dataType).map { + findTypeFunc(field1.dataType, field2.dataType).map { dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) } case _ => None @@ -200,7 +182,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse(findWiderTypeForComplex(t1, t2, findWiderTypeForTwo)) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -236,7 +218,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse(findWiderTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) + .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index d0724dd3f92ab..8cc5a23779a2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -397,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest { widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), StructType(Seq(StructField("a", DoubleType, nullable = false))), - None) + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))),