From 5603918ae963f78aafb2d1f4f2bd9d566870495b Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 4 Aug 2018 15:38:08 +0200
Subject: [PATCH 1/8] Initial implementation

---
 .../expressions/collectionOperations.scala    | 43 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala          | 10 +++++
 2 files changed, 53 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b03bd7d942d72..3d118e926db90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4256,3 +4256,46 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
 
   override def prettyName: String = "array_except"
 }
+
+/**
+ * An expression that flattens a nested struct into a flat one.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Flattens nested structs in `expr` into a single flat struct.
+  """,
+  examples = """
+    Examples:
+      >
+  """)
+case class StructFlatten(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+  val depth = 1
+  val delimiter = "_"
+
+  def fieldName(prefix: String, name: String): String = {
+    if (prefix.isEmpty) name else prefix + delimiter + name
+  }
+  def flatField(field: StructField, prefix: String): Array[StructField] = field match {
+    case f @ StructField(name, st: StructType, _, _) =>
+      flatStruct(st, fieldName(prefix, field.name))
+    case _ => Array(field.copy(name = fieldName(prefix, field.name)))
+  }
+  def flatStruct(st: StructType, prefix: String): Array[StructField] = {
+    st.fields.flatMap(field => flatField(field, prefix))
+  }
+  override def dataType: DataType = child.dataType match {
+    case st: StructType => st.copy(fields = flatStruct(st, ""))
+    case other => other
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StructType))
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    ???
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    ???
+  }
+}
\ No newline at end of file
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 2f6f9064f9e62..ed727c43d9aeb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -23,6 +23,7 @@ import java.util.TimeZone
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
+
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
@@ -1618,4 +1619,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
     assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
   }
+
+  test("struct flatten") {
+    val struct = CreateStruct(Seq(CreateStruct(Seq(Literal(1)))))
+    val expectedSchema = StructType(Seq(
+      StructField("col1_col1", IntegerType, false)
+    ))
+
+    assert(StructFlatten(struct).dataType == expectedSchema)
+  }
 }

From 0be0d059b8bf571068226c515888a64093468cff Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 4 Aug 2018 18:07:45 +0200
Subject: [PATCH 2/8] Making the depth and delimiter parameters
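
The flattening depth and the field name delimiter become parameters of
the expression, with defaults that flatten all levels and join names
with an underscore, for example (illustrative call only; `structExpr`
stands for any struct-typed expression):

  StructFlatten(structExpr, depth = 2, delimiter = ".")

The expression switches from ExpectsInputTypes with dedicated codegen
to CodegenFallback, and interpreted evaluation is implemented in
nullSafeEval. Note that only GenericInternalRow inputs are flattened
so far; any other input is returned as is.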
+ } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2f6f9064f9e62..ed727c43d9aeb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.util.Random import org.apache.spark.SparkFunSuite + import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeTestUtils @@ -1618,4 +1619,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("struct flatten") { + val struct = CreateStruct(Seq(CreateStruct(Seq(Literal(1))))) + val expectedSchema = StructType(Seq( + StructField("col1_col1", IntegerType, false) + )) + + assert(StructFlatten(struct).dataType == expectedSchema) + } } From 0be0d059b8bf571068226c515888a64093468cff Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 4 Aug 2018 18:07:45 +0200 Subject: [PATCH 2/8] Making the depth and delimiter as parameters --- .../expressions/collectionOperations.scala | 36 ++++++++++--------- .../CollectionExpressionsSuite.scala | 22 ++++++++---- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3d118e926db90..a2fb3ed3e4d73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4268,34 +4268,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike Examples: > """) -case class StructFlatten(child: Expression) extends UnaryExpression with ExpectsInputTypes { - val depth = 1 - val delimiter = "_" +case class StructFlatten( + child: Expression, + depth: Int = Int.MaxValue, + delimiter: String = "_") extends UnaryExpression with CodegenFallback { def fieldName(prefix: String, name: String): String = { if (prefix.isEmpty) name else prefix + delimiter + name } - def flatField(field: StructField, prefix: String): Array[StructField] = field match { - case f @ StructField(name, st: StructType, _, _) => - flatStruct(st, fieldName(prefix, field.name)) + def flatStructField(field: StructField, prefix: String): Array[StructField] = field match { + case StructField(name, st: StructType, _, _) => + flatStructType(st, fieldName(prefix, name)) case _ => Array(field.copy(name = fieldName(prefix, field.name))) } - def flatStruct(st: StructType, prefix: String): Array[StructField] = { - st.fields.flatMap(field => flatField(field, prefix)) + def flatStructType(st: StructType, prefix: String): Array[StructField] = { + st.fields.flatMap(field => flatStructField(field, prefix)) } override def dataType: DataType = child.dataType match { - case st: StructType => st.copy(fields = flatStruct(st, "")) + case st: StructType => st.copy(fields = flatStructType(st, "")) case other => 
---
 .../expressions/collectionOperations.scala    | 10 +++++-----
 .../CollectionExpressionsSuite.scala          | 27 ++++++++++++++++----
 2 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index a2fb3ed3e4d73..0047d20d4e877 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4285,21 +4285,21 @@ case class StructFlatten(
     st.fields.flatMap(field => flatStructField(field, prefix))
   }
   override def dataType: DataType = child.dataType match {
-    case st: StructType => st.copy(fields = flatStructType(st, ""))
+    case st: StructType if depth > 0 => st.copy(fields = flatStructType(st, ""))
     case other => other
   }
 
-  def flatColumn(column: Any): Array[Any] = column match {
+  def flatValue(value: Any): Array[Any] = value match {
     case row: GenericInternalRow => flatRow(row).values
-    case _ => Array(column)
+    case _ => Array(value)
   }
 
   def flatRow(st: GenericInternalRow): GenericInternalRow = {
-    val values = st.values.flatMap(column => flatColumn(column))
+    val values = st.values.flatMap(column => flatValue(column))
     new GenericInternalRow(values)
   }
   override def nullSafeEval(input: Any): Any = input match {
-    case row: GenericInternalRow => flatRow(row)
+    case row: GenericInternalRow if depth > 0 => flatRow(row)
     case other => other
   }
 }
\ No newline at end of file
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 545a27fcd86e0..b79e8f76c38a0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1621,21 +1621,38 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
   }
 
   test("struct flatten") {
-    // level = 2
+    // 2 nested structs, depth = default
     val struct1 = CreateStruct(Seq(CreateStruct(Seq(Literal(1)))))
     val expectedSchema1 = StructType(Seq(StructField("col1_col1", IntegerType, false)))
     assert(StructFlatten(struct1).dataType == expectedSchema1)
     checkEvaluation(StructFlatten(struct1), Row(1))
 
-    // level = 3
+    // 3 nested structs, depth = default, delimiter = "-"
     val struct2 = CreateNamedStruct(Seq(Literal("level0"), CreateNamedStruct(Seq(
       Literal("level1"), CreateNamedStruct(Seq(
        Literal("col1"), Literal(1), Literal("col2"), Literal("a")
      ))))))
     val expectedSchema2 = StructType(Seq(
-      StructField("level0_level1_col1", IntegerType, false),
-      StructField("level0_level1_col2", StringType, false)))
-    assert(StructFlatten(struct2).dataType == expectedSchema2)
+      StructField("level0-level1-col1", IntegerType, false),
+      StructField("level0-level1-col2", StringType, false)))
+    assert(StructFlatten(struct2, delimiter = "-").dataType == expectedSchema2)
     checkEvaluation(StructFlatten(struct2), Row(1, "a"))
+
+    // 3 nested structs, depth = 0
+    val expectedSchema3 = StructType(Seq(StructField("level0",
+      StructType(Seq(StructField("level1",
+        StructType(Seq(
+          StructField("col1", IntegerType, false), StructField("col2", StringType, false)
+        )), false)
+      )), false)
+    ))
+    assert(StructFlatten(struct2, depth = 0).dataType == expectedSchema3)
+    checkEvaluation(StructFlatten(struct2, depth = 0), Row(Row(Row(1, "a"))))
+
+    // 3 nested structs, depth = 1
+//    val expectedSchema3 = StructType(Seq(StructField("level0_level1", StructType(Seq(
+//      StructField("col1", IntegerType, false),
+//      StructField("col2", StringType, false))), false)))
+//    assert(StructFlatten(struct2, depth = 1).dataType == expectedSchema3)
   }
 }

From cd88a2125ba6932ba1fdceca1a24d57124a23afa Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 4 Aug 2018 20:21:19 +0200
Subject: [PATCH 4/8] Test for depth = 1
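
The remaining depth d is threaded through flatStructField,
flatStructType, flatValue and flatRow so that only the requested
number of nesting levels is collapsed. For instance (simplified
sketch of the expected schemas), with depth = 1

  struct<level0:struct<level1:struct<col1:int>>>

should become

  struct<level0_level1:struct<col1:int>>

as the new test asserts.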
---
 .../expressions/collectionOperations.scala    | 32 +++++++++++--------
 .../CollectionExpressionsSuite.scala          | 11 ++++---
 2 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 0047d20d4e877..c417f83a80679 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4276,30 +4276,36 @@ case class StructFlatten(
   def fieldName(prefix: String, name: String): String = {
     if (prefix.isEmpty) name else prefix + delimiter + name
   }
-  def flatStructField(field: StructField, prefix: String): Array[StructField] = field match {
-    case StructField(name, st: StructType, _, _) =>
-      flatStructType(st, fieldName(prefix, name))
-    case _ => Array(field.copy(name = fieldName(prefix, field.name)))
+  def flatStructField(field: StructField, prefix: String, d: Int): Array[StructField] = {
+    field match {
+      case StructField(name, st: StructType, _, _) if d > 0 =>
+        flatStructType(st, fieldName(prefix, name), d - 1)
+      case _ => Array(field.copy(name = fieldName(prefix, field.name)))
+    }
   }
-  def flatStructType(st: StructType, prefix: String): Array[StructField] = {
-    st.fields.flatMap(field => flatStructField(field, prefix))
+  def flatStructType(st: StructType, prefix: String, d: Int): Array[StructField] = {
+    st.fields.flatMap(field => flatStructField(field, prefix, d))
   }
   override def dataType: DataType = child.dataType match {
-    case st: StructType if depth > 0 => st.copy(fields = flatStructType(st, ""))
+    case st: StructType if depth > 0 => st.copy(fields = flatStructType(st, "", depth))
     case other => other
   }
 
-  def flatValue(value: Any): Array[Any] = value match {
-    case row: GenericInternalRow => flatRow(row).values
+  def flatValue(value: Any, d: Int): Array[Any] = value match {
+    case row: GenericInternalRow if d > 0 => flatRow(row, d - 1).values
     case _ => Array(value)
   }
 
-  def flatRow(st: GenericInternalRow): GenericInternalRow = {
-    val values = st.values.flatMap(column => flatValue(column))
-    new GenericInternalRow(values)
+  def flatRow(struct: GenericInternalRow, d: Int): GenericInternalRow = {
+    if (d > 0) {
+      val values = struct.values.flatMap(column => flatValue(column, d))
+      new GenericInternalRow(values)
+    } else {
+      struct
+    }
   }
   override def nullSafeEval(input: Any): Any = input match {
-    case row: GenericInternalRow if depth > 0 => flatRow(row)
+    case row: GenericInternalRow => flatRow(row, depth)
     case other => other
   }
 }
\ No newline at end of file
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index b79e8f76c38a0..a41eaaaf4fb08 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1650,9 +1650,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(StructFlatten(struct2, depth = 0), Row(Row(Row(1, "a"))))
 
     // 3 nested structs, depth = 1
-//    val expectedSchema3 = StructType(Seq(StructField("level0_level1", StructType(Seq(
-//      StructField("col1", IntegerType, false),
-//      StructField("col2", StringType, false))), false)))
-//    assert(StructFlatten(struct2, depth = 1).dataType == expectedSchema3)
+    val expectedSchema4 = StructType(Seq(StructField("level0_level1",
+      StructType(Seq(
+        StructField("col1", IntegerType, false), StructField("col2", StringType, false)
+      )), false)
+    ))
+    assert(StructFlatten(struct2, depth = 1).dataType == expectedSchema4)
+    checkEvaluation(StructFlatten(struct2, depth = 1), Row(Row(1, "a")))
   }
 }

From b0da02d37ac6db38f63bac95dc295ac37fe4a692 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 4 Aug 2018 20:30:18 +0200
Subject: [PATCH 5/8] Renaming st to struct
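
A pure rename of the st parameter to struct in flatStructType for
readability; no behavior change.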
---
 .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c417f83a80679..6b5d9807daff9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4283,8 +4283,8 @@ case class StructFlatten(
       case _ => Array(field.copy(name = fieldName(prefix, field.name)))
     }
   }
-  def flatStructType(st: StructType, prefix: String, d: Int): Array[StructField] = {
-    st.fields.flatMap(field => flatStructField(field, prefix, d))
+  def flatStructType(struct: StructType, prefix: String, d: Int): Array[StructField] = {
+    struct.fields.flatMap(field => flatStructField(field, prefix, d))
   }
   override def dataType: DataType = child.dataType match {
     case st: StructType if depth > 0 => st.copy(fields = flatStructType(st, "", depth))

From ec361791b83d71f29823157a2c2b49162ddb5901 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 4 Aug 2018 21:24:37 +0200
Subject: [PATCH 6/8] Negative tests
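
Flattening should splice in nested structs only; arrays and maps that
are nested in a struct have to stay intact. The new test checks both
cases, the existing test is renamed to "flatten structures", and a
stray blank line in the imports is removed.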
---
 .../expressions/CollectionExpressionsSuite.scala | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index a41eaaaf4fb08..9a5fdba0a9b63 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -23,7 +23,6 @@ import java.util.TimeZone
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
-
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
@@ -1620,7 +1619,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
   }
 
-  test("struct flatten") {
+  test("flatten structures") {
     // 2 nested structs, depth = default
     val struct1 = CreateStruct(Seq(CreateStruct(Seq(Literal(1)))))
     val expectedSchema1 = StructType(Seq(StructField("col1_col1", IntegerType, false)))
@@ -1658,4 +1657,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     assert(StructFlatten(struct2, depth = 1).dataType == expectedSchema4)
     checkEvaluation(StructFlatten(struct2, depth = 1), Row(Row(1, "a")))
   }
+
+  test("flatten structures shouldn't change maps and arrays") {
+    val arr = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+    val struct1 = CreateStruct(Seq(arr, CreateStruct(Seq(Literal(3)))))
+    checkEvaluation(StructFlatten(struct1), Row(Seq(1, 2), 3))
+
+    val map = Literal.create(Map("a" -> 1), MapType(StringType, IntegerType))
+    val struct2 = CreateStruct(Seq(map, CreateStruct(Seq(Literal(3)))))
+    checkEvaluation(StructFlatten(struct2), Row(Map("a" -> 1), 3))
+  }
 }

From ced63d7f093c168e2bc9457b6c08b87bfe6c0751 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 4 Aug 2018 22:10:00 +0200
Subject: [PATCH 7/8] Register struct_flatten
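
The struct_flatten function is registered in FunctionRegistry. The
auxiliary constructor lets the registry instantiate the expression
from a single argument. Expected SQL behavior (illustrative, in line
with the DataFrame test below):

  SELECT struct_flatten(named_struct('a', named_struct('b', 1)));
  -- a struct with the single field a_b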
---
 .../sql/catalyst/analysis/FunctionRegistry.scala |  1 +
 .../expressions/collectionOperations.scala       |  2 ++
 .../spark/sql/DataFrameFunctionsSuite.scala      | 16 ++++++++++++++++
 3 files changed, 19 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f7517486e5411..92e6434d93292 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -441,6 +441,7 @@ object FunctionRegistry {
     expression[ArrayRemove]("array_remove"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
+    expression[StructFlatten]("struct_flatten"),
     CreateStruct.registryEntry,
 
     // misc functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6b5d9807daff9..4f0a25941e4ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4273,6 +4273,8 @@ case class StructFlatten(
     depth: Int = Int.MaxValue,
     delimiter: String = "_") extends UnaryExpression with CodegenFallback {
 
+  def this(child: Expression) = this(child, Int.MaxValue, "_")
+
   def fieldName(prefix: String, name: String): String = {
     if (prefix.isEmpty) name else prefix + delimiter + name
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 923482024b033..6c46a8ad24979 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1879,6 +1879,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
     assert(ex.getMessage.contains("Cannot use null as map key"))
   }
+
+  test("struct_flatten function") {
+    val df = spark.range(0, 10)
+      .select(struct(
+        'id as "col1",
+        struct(('id + 1) as "col3") as "col2"
+      ) as "st")
+    val flatten = df.selectExpr("struct_flatten(st)")
+    val expected = spark.range(0, 10)
+      .select(struct(
+        'id as "col1",
+        ('id + 1) as "col2_col3"
+      ) as "st")
+
+    checkAnswer(flatten, expected)
+  }
 }
 
 object DataFrameFunctionsSuite {

From 8de14652b838ea053f430d17129c73c85cb2e0cb Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 4 Aug 2018 23:00:07 +0200
Subject: [PATCH 8/8] Making Scala style checker happy

---
 .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 5fea3c7b76382..1456774133867 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4301,4 +4301,4 @@ case class StructFlatten(
     case row: GenericInternalRow => flatRow(row, depth)
     case other => other
   }
-}
\ No newline at end of file
+}