-
Notifications
You must be signed in to change notification settings - Fork 28.8k
[SPARK-22826][SQL] findWiderTypeForTwo Fails over StructField of Array #20010
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bbcc7c6
90c2979
eb52629
527099d
86e1929
583f674
23c83ac
a3ed2e0
e034bbc
7d146e3
d42dfa5
09e49fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ | |
import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
import org.apache.spark.sql.catalyst.plans.logical._ | ||
import org.apache.spark.sql.catalyst.rules.Rule | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types._ | ||
|
||
|
||
|
@@ -99,11 +100,22 @@ object TypeCoercion { | |
case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => | ||
Some(TimestampType) | ||
|
||
case (t1 @ ArrayType(elementType1, nullable1), t2 @ ArrayType(elementType2, nullable2)) | ||
if t1.sameType(t2) => | ||
val dataType = findTightestCommonType(elementType1, elementType2).get | ||
Some(ArrayType(dataType, nullable1 || nullable2)) | ||
|
||
case (t1 @ MapType(keyType1, valueType1, nullable1), | ||
t2 @ MapType(keyType2, valueType2, nullable2)) if t1.sameType(t2) => | ||
val keyType = findTightestCommonType(keyType1, keyType2).get | ||
val valueType = findTightestCommonType(valueType1, valueType2).get | ||
Some(MapType(keyType, valueType, nullable1 || nullable2)) | ||
|
||
case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => | ||
Some(StructType(fields1.zip(fields2).map { case (f1, f2) => | ||
// Since `t1.sameType(t2)` is true, two StructTypes have the same DataType | ||
// except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. | ||
// - Different names: use f1.name | ||
// - Names differing in case: use f1.name | ||
// - Different nullabilities: `nullable` is true iff one of them is nullable. | ||
val dataType = findTightestCommonType(f1.dataType, f2.dataType).get | ||
StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) | ||
|
@@ -148,6 +160,61 @@ object TypeCoercion { | |
case (l, r) => None | ||
} | ||
|
||
/**
 * Case 2 type widening over complex types (arrays, maps and structs).
 *
 * `widerTypeFunc` finds the wider type of two nested types (array elements, map keys, map
 * values or struct fields). The function that is passed in decides whether nested types may
 * be promoted to StringType.
 *
 * @param t1 the first type
 * @param t2 the second type
 * @param widerTypeFunc widening function applied to each pair of nested types
 * @return the widened type, or None when the two types cannot be widened: different kinds of
 *         complex type, a nested pair that `widerTypeFunc` rejects, or structs whose fields
 *         differ in number or (position-wise) in name
 */
private def findWiderTypeForTwoComplex(
    t1: DataType,
    t2: DataType,
    widerTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = {
  (t1, t2) match {
    case (_, _) if t1 == t2 => Some(t1)
    // NullType widens to whatever the other side is, regardless of which side it is on.
    case (NullType, _) => Some(t2)
    case (_, NullType) => Some(t1)

    case (ArrayType(elementType1, nullable1), ArrayType(elementType2, nullable2)) =>
      widerTypeFunc(elementType1, elementType2).map(ArrayType(_, nullable1 || nullable2))

    case (MapType(keyType1, valueType1, nullable1), MapType(keyType2, valueType2, nullable2)) =>
      // Both the key types and the value types must be widenable.
      for {
        keyType <- widerTypeFunc(keyType1, keyType2)
        valueType <- widerTypeFunc(valueType1, valueType2)
      } yield MapType(keyType, valueType, nullable1 || nullable2)

    // `zip` silently truncates to the shorter side, so structs with a different number of
    // fields must be rejected up front instead of being widened to a struct that drops fields.
    case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length =>
      // Read the configuration once rather than once per field.
      val caseSensitive = SQLConf.get.caseSensitiveAnalysis
      def sameName(f1: StructField, f2: StructField): Boolean = {
        if (caseSensitive) f1.name.equals(f2.name) else f1.name.equalsIgnoreCase(f2.name)
      }

      // In order to match Case 2 widening of types, we do not require field data types be the
      // same type, but fields having different names are considered heterogeneous. Fields are
      // compared position by position; the result keeps the name of the field from `t1`.
      val widenedFields = fields1.zip(fields2).map { case (f1, f2) =>
        if (sameName(f1, f2)) {
          widerTypeFunc(f1.dataType, f2.dataType).map { dataType =>
            StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
          }
        } else {
          None
        }
      }

      // A single non-widenable field makes the whole struct non-widenable.
      if (widenedFields.forall(_.nonEmpty)) {
        Some(StructType(widenedFields.flatten))
      } else {
        None
      }

    case _ => None
  }
}
|
||
/** | ||
* Case 2 type widening (see the classdoc comment above for TypeCoercion). | ||
* | ||
|
@@ -158,11 +225,7 @@ object TypeCoercion { | |
findTightestCommonType(t1, t2) | ||
.orElse(findWiderTypeForDecimal(t1, t2)) | ||
.orElse(stringPromotion(t1, t2)) | ||
.orElse((t1, t2) match { | ||
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => | ||
findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) | ||
case _ => None | ||
}) | ||
|
||
.orElse(findWiderTypeForTwoComplex(t1, t2, findWiderTypeForTwo)) | ||
|
||
} | ||
|
||
private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { | ||
|
@@ -182,12 +245,7 @@ object TypeCoercion { | |
t2: DataType): Option[DataType] = { | ||
findTightestCommonType(t1, t2) | ||
.orElse(findWiderTypeForDecimal(t1, t2)) | ||
.orElse((t1, t2) match { | ||
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => | ||
findWiderTypeWithoutStringPromotionForTwo(et1, et2) | ||
.map(ArrayType(_, containsNull1 || containsNull2)) | ||
case _ => None | ||
}) | ||
|
||
.orElse(findWiderTypeForTwoComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) | ||
} | ||
|
||
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -389,6 +389,51 @@ class TypeCoercionSuite extends AnalysisTest { | |
widenTest(StringType, MapType(IntegerType, StringType, true), None) | ||
widenTest(ArrayType(IntegerType), StructType(Seq()), None) | ||
|
||
widenTest( | ||
ArrayType(StringType, containsNull = true), | ||
ArrayType(StringType, containsNull = false), | ||
Some(ArrayType(StringType, containsNull = true))) | ||
widenTest( | ||
MapType(StringType, StringType, valueContainsNull = true), | ||
MapType(StringType, StringType, valueContainsNull = false), | ||
Some(MapType(StringType, StringType, valueContainsNull = true))) | ||
widenTest( | ||
StructType(Seq(StructField("a", StringType, nullable = true))), | ||
StructType(Seq(StructField("a", StringType, nullable = false))), | ||
Some(StructType(Seq(StructField("a", StringType, nullable = true))))) | ||
|
||
widenTest( | ||
StructType( | ||
Seq(StructField("a", ArrayType(StringType, containsNull = true), nullable = true))), | ||
StructType( | ||
Seq(StructField("a", ArrayType(StringType, containsNull = false), nullable = false))), | ||
Some(StructType( | ||
Seq(StructField("a", ArrayType(StringType, containsNull = true), nullable = true))))) | ||
widenTest( | ||
StructType( | ||
Seq(StructField("a", MapType(StringType, StringType, valueContainsNull = true)))), | ||
StructType( | ||
Seq(StructField("a", MapType(StringType, StringType, valueContainsNull = false)))), | ||
Some(StructType( | ||
Seq(StructField("a", MapType(StringType, StringType, valueContainsNull = true)))))) | ||
widenTest( | ||
ArrayType( | ||
StructType(Seq(StructField("a", StringType, nullable = true))), containsNull = true), | ||
ArrayType( | ||
StructType(Seq(StructField("a", StringType, nullable = false))), containsNull = false), | ||
Some(ArrayType( | ||
StructType(Seq(StructField("a", StringType, nullable = true))), containsNull = true))) | ||
widenTest( | ||
MapType( | ||
StringType, | ||
StructType(Seq(StructField("a", StringType, nullable = true))), valueContainsNull = true), | ||
MapType( | ||
StringType, | ||
StructType(Seq(StructField("a", StringType, nullable = false))), valueContainsNull = false), | ||
Some(MapType( | ||
StringType, | ||
StructType(Seq(StructField("a", StringType, nullable = true))), valueContainsNull = true))) | ||
|
||
widenTest( | ||
StructType(Seq(StructField("a", IntegerType))), | ||
StructType(Seq(StructField("b", IntegerType))), | ||
|
@@ -431,7 +476,7 @@ class TypeCoercionSuite extends AnalysisTest { | |
} | ||
} | ||
|
||
test("wider common type for decimal and array") { | ||
test("wider common type for decimal and complex types") { | ||
def widenTestWithStringPromotion( | ||
t1: DataType, | ||
t2: DataType, | ||
|
@@ -462,27 +507,139 @@ class TypeCoercionSuite extends AnalysisTest { | |
ArrayType(DoubleType, containsNull = false), | ||
Some(ArrayType(DoubleType, containsNull = true))) | ||
widenTestWithStringPromotion( | ||
ArrayType(TimestampType, containsNull = false), | ||
ArrayType(StringType, containsNull = true), | ||
ArrayType(ArrayType(IntegerType), containsNull = true), | ||
ArrayType(ArrayType(LongType), containsNull = false), | ||
Some(ArrayType(ArrayType(LongType), containsNull = true))) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we really want to do this? Does any other database have this behavior? I think for complex types, we should ignore the nullable difference, but I'm not sure if we should do type coercion inside complex types. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @mgaido91, thoughts on this? It's definitely possible for us to revert to the behavior where we don't do IntegerType-to-LongType, xType-to-StringType, etc. promotion inside complex types, which was how a previous form of this PR handled it. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this is more coherent with the behavior in other parts, and this is the behavior I would expect. But I think that we should follow @gatorsmile's suggestion and check Hive's behavior first. |
||
widenTestWithStringPromotion( | ||
ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) | ||
widenTestWithStringPromotion( | ||
StructType( | ||
Seq(StructField("a", ArrayType(LongType, containsNull = true), nullable = true))), | ||
StructType( | ||
Seq(StructField("a", ArrayType(StringType, containsNull = false), nullable = false))), | ||
Some(StructType( | ||
Seq(StructField("a", ArrayType(StringType, containsNull = true), nullable = true))))) | ||
widenTestWithStringPromotion( | ||
ArrayType(StructType(Seq(StructField("a", LongType, nullable = true))), containsNull = true), | ||
ArrayType(StructType( | ||
Seq(StructField("a", StringType, nullable = false))), containsNull = false), | ||
Some(ArrayType(StructType( | ||
Seq(StructField("a", StringType, nullable = true))), containsNull = true))) | ||
widenTestWithStringPromotion( | ||
StructType(Seq(StructField("a", ArrayType(LongType)))), | ||
StructType(Seq(StructField("b", ArrayType(StringType)))), | ||
None) | ||
widenTestWithStringPromotion( | ||
ArrayType(StructType(Seq(StructField("a", LongType)))), | ||
ArrayType(StructType(Seq(StructField("b", StringType)))), | ||
None) | ||
|
||
// MapType | ||
widenTestWithStringPromotion( | ||
MapType(StringType, ShortType, valueContainsNull = true), | ||
MapType(StringType, DoubleType, valueContainsNull = false), | ||
Some(MapType(StringType, DoubleType, valueContainsNull = true))) | ||
widenTestWithStringPromotion( | ||
MapType(StringType, MapType(StringType, IntegerType, valueContainsNull = true), | ||
valueContainsNull = true), | ||
MapType(StringType, MapType(StringType, LongType, valueContainsNull = false), | ||
valueContainsNull = false), | ||
Some(MapType(StringType, MapType(StringType, LongType, valueContainsNull = true), | ||
valueContainsNull = true))) | ||
widenTestWithStringPromotion( | ||
StructType(Seq(StructField("a", MapType( | ||
StringType, LongType, valueContainsNull = true), nullable = true))), | ||
StructType(Seq(StructField("a", MapType( | ||
StringType, StringType, valueContainsNull = false), nullable = false))), | ||
Some(StructType(Seq(StructField("a", MapType( | ||
StringType, StringType, valueContainsNull = true), nullable = true))))) | ||
widenTestWithStringPromotion( | ||
MapType(StringType, | ||
StructType(Seq(StructField("a", LongType, nullable = true))), | ||
valueContainsNull = true), | ||
MapType(StringType, | ||
StructType(Seq(StructField("a", StringType, nullable = false))), | ||
valueContainsNull = false), | ||
Some(MapType(StringType, | ||
StructType(Seq(StructField("a", StringType, nullable = true))), | ||
valueContainsNull = true))) | ||
widenTestWithStringPromotion( | ||
StructType(Seq(StructField("a", MapType(StringType, LongType)))), | ||
StructType(Seq(StructField("b", MapType(StringType, StringType)))), | ||
None) | ||
widenTestWithStringPromotion( | ||
MapType(StringType, StructType(Seq(StructField("a", LongType)))), | ||
MapType(StringType, StructType(Seq(StructField("b", StringType)))), | ||
None) | ||
|
||
// String promotion | ||
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) | ||
widenTestWithStringPromotion(StringType, TimestampType, Some(StringType)) | ||
widenTestWithStringPromotion( | ||
ArrayType(TimestampType, containsNull = true), | ||
ArrayType(StringType, containsNull = false), | ||
Some(ArrayType(StringType, containsNull = true))) | ||
widenTestWithStringPromotion( | ||
ArrayType(ArrayType(IntegerType), containsNull = false), | ||
ArrayType(ArrayType(LongType), containsNull = false), | ||
Some(ArrayType(ArrayType(LongType), containsNull = false))) | ||
ArrayType(LongType, containsNull = true), | ||
ArrayType(StringType, containsNull = false), | ||
Some(ArrayType(StringType, containsNull = true))) | ||
widenTestWithStringPromotion( | ||
MapType(StringType, TimestampType, valueContainsNull = true), | ||
MapType(StringType, StringType, valueContainsNull = false), | ||
Some(MapType(StringType, StringType, valueContainsNull = true))) | ||
widenTestWithStringPromotion( | ||
MapType(StringType, LongType, valueContainsNull = true), | ||
MapType(StringType, StringType, valueContainsNull = false), | ||
Some(MapType(StringType, StringType, valueContainsNull = true))) | ||
|
||
// Without string promotion | ||
widenTestWithoutStringPromotion(IntegerType, StringType, None) | ||
widenTestWithoutStringPromotion(StringType, TimestampType, None) | ||
widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) | ||
widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) | ||
widenTestWithoutStringPromotion( | ||
MapType(StringType, LongType), | ||
MapType(StringType, TimestampType), | ||
None) | ||
widenTestWithoutStringPromotion( | ||
MapType(StringType, StringType), | ||
MapType(StringType, TimestampType), | ||
None) | ||
widenTestWithoutStringPromotion( | ||
StructType(Seq(StructField("a", MapType(StringType, LongType)))), | ||
StructType(Seq(StructField("a", MapType(StringType, StringType)))), | ||
None) | ||
widenTestWithoutStringPromotion( | ||
StructType(Seq(StructField("a", ArrayType(LongType)))), | ||
StructType(Seq(StructField("a", ArrayType(StringType)))), | ||
None) | ||
widenTestWithoutStringPromotion( | ||
StructType(Seq(StructField("a", MapType(StringType, LongType)))), | ||
StructType(Seq(StructField("a", MapType(StringType, StringType)))), | ||
None) | ||
widenTestWithoutStringPromotion( | ||
ArrayType(StructType(Seq(StructField("a", LongType)))), | ||
ArrayType(StructType(Seq(StructField("a", StringType)))), | ||
None) | ||
|
||
// String promotion | ||
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) | ||
widenTestWithStringPromotion(StringType, TimestampType, Some(StringType)) | ||
widenTestWithStringPromotion( | ||
ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) | ||
// Although the data types promotion would not fail, tests should still return None due to field | ||
// name mismatch | ||
widenTestWithoutStringPromotion( | ||
StructType(Seq(StructField("a", MapType(StringType, LongType)))), | ||
StructType(Seq(StructField("b", MapType(StringType, LongType)))), | ||
None) | ||
widenTestWithoutStringPromotion( | ||
StructType(Seq(StructField("a", ArrayType(LongType)))), | ||
StructType(Seq(StructField("b", ArrayType(LongType)))), | ||
None) | ||
widenTestWithoutStringPromotion( | ||
StructType(Seq(StructField("a", MapType(StringType, LongType)))), | ||
StructType(Seq(StructField("b", MapType(StringType, LongType)))), | ||
None) | ||
widenTestWithStringPromotion( | ||
ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) | ||
ArrayType(StructType(Seq(StructField("a", LongType)))), | ||
ArrayType(StructType(Seq(StructField("b", LongType)))), | ||
None) | ||
} | ||
|
||
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This requires that fields with the same name must also be in the same position. Is this assumption correct?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mgaido91 That seems to be the assumption already made in `findTightestCommonType`. The `sameType` function on DataType also requires that struct fields are ordered the same; otherwise it returns false.
The difference here is that we don't require the struct fields to strictly have the same type, so we can support widening to LongType, StringType, etc. But we do require that the fields (1) have the same order, and (2) have the same name (either with strict case, or ignoring case).