Skip to content

Commit

Permalink
[SPARK-27288][SQL] Pruning nested field in complex map key from object serializers
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

In the original PR #24158, pruning nested fields in a complex map key was not supported, because some methods in schema pruning didn't support it at that time. This is a followup to add that support.

## How was this patch tested?

Added tests.

Closes #24220 from viirya/SPARK-26847-followup.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
  • Loading branch information
viirya authored and maropu committed Mar 27, 2019
1 parent fac3110 commit 93ff690
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
fields.map(f => collectStructType(f.dataType, structs))
case ArrayType(elementType, _) =>
collectStructType(elementType, structs)
case MapType(_, valueType, _) =>
// Because we can't select a field from struct in key, so we skip key type.
case MapType(keyType, valueType, _) =>
collectStructType(keyType, structs)
collectStructType(valueType, structs)
// We don't use UserDefinedType in those serializers.
case _: UserDefinedType[_] =>
Expand Down Expand Up @@ -179,13 +179,20 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {

val transformedSerializer = serializer.transformDown {
case m: ExternalMapToCatalyst =>
val prunedKeyConverter = m.keyConverter.transformDown {
case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
val prunedType = prunedStructTypes(structTypeIndex)
structTypeIndex += 1
pruneNamedStruct(s, prunedType)
}
val prunedValueConverter = m.valueConverter.transformDown {
case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
val prunedType = prunedStructTypes(structTypeIndex)
structTypeIndex += 1
pruneNamedStruct(s, prunedType)
}
m.copy(valueConverter = alignNullTypeInIf(prunedValueConverter))
m.copy(keyConverter = alignNullTypeInIf(prunedKeyConverter),
valueConverter = alignNullTypeInIf(prunedValueConverter))
case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
val prunedType = prunedStructTypes(structTypeIndex)
structTypeIndex += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ class ObjectSerializerPruningSuite extends PlanTest {
Seq(StructType.fromDDL("a struct<a:int, b:int>, b int"),
StructType.fromDDL("a int, b int")),
Seq(StructType.fromDDL("a int, b int, c string")),
Seq.empty[StructType],
Seq(StructType.fromDDL("c long, d string"))
Seq(StructType.fromDDL("a struct<a:int, b:int>, b int"),
StructType.fromDDL("a int, b int")),
Seq(StructType.fromDDL("a int, b int"), StructType.fromDDL("c long, d string"))
)

dataTypes.zipWithIndex.foreach { case (dt, idx) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
val structs = serializer.collect {
case c: CreateNamedStruct => Seq(c)
case m: ExternalMapToCatalyst =>
m.valueConverter.collect {
m.keyConverter.collect {
case c: CreateNamedStruct => c
} ++ m.valueConverter.collect {
case c: CreateNamedStruct => c
}
}.flatten
Expand Down Expand Up @@ -123,6 +125,21 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
val df2 = mapDs.select("_1.k._2")
testSerializer(df2, Seq(Seq("_2")))
checkAnswer(df2, Seq(Row(11), Row(22), Row(33)))

val df3 = mapDs.select(expr("map_values(_1)._2[0]"))
testSerializer(df3, Seq(Seq("_2")))
checkAnswer(df3, Seq(Row(11), Row(22), Row(33)))
}
}

test("Pruned nested serializers: map of complex key") {
withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
val mapData = Seq((Map((("1", 1), "a_1")), 1), (Map((("2", 2), "b_1")), 2),
(Map((("3", 3), "c_1")), 3))
val mapDs = mapData.toDS().map(t => (t._1, t._2 + 1))
val df1 = mapDs.select(expr("map_keys(_1)._1[0]"))
testSerializer(df1, Seq(Seq("_1")))
checkAnswer(df1, Seq(Row("1"), Row("2"), Row("3")))
}
}
}

0 comments on commit 93ff690

Please sign in to comment.