From 85f3106fb422f872747c4be7b0a83cc5b2ec8930 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 25 Mar 2015 18:58:10 -0700 Subject: [PATCH 1/4] add CreateStruct --- .../sql/catalyst/analysis/Analyzer.scala | 6 ++++ .../catalyst/expressions/complexTypes.scala | 28 ++++++++++++++++++- .../ExpressionEvaluationSuite.scala | 7 +++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 44eceb0b372e6..871d306377672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -212,6 +212,12 @@ class Analyzer(catalog: Catalog, case o => o :: Nil } Alias(c.copy(children = expandedArgs), name)() :: Nil + case Alias(c @ CreateStruct(args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(c.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 3fd78db297462..e5afb3602370f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -142,3 +142,29 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def toString: String = s"Array(${children.mkString(",")})" } + +/** + * Returns a Row containing the evaluation of all children expressions. + */ +case class CreateStruct(children: Seq[Expression]) extends Expression { + override type EvaluatedType = Row + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override def dataType: StructType = { + assert(resolved, s"CreateStruct is called with unresolved children: $children.") + val fields = children.map { + case named: NamedExpression => + StructField(named.name, named.dataType, named.nullable, named.metadata) + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: Row): EvaluatedType = { + Row(children.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index dcfd8b28cb02a..0ff8a7624f6bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1080,4 +1080,11 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + test("CreateStruct") { + val row = Row(1, 2, 3) + val c1 = 'a.int.at(0) + val c3 = 'a.int.at(2) + checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + } } From 85dd55915a4b11371b7dacb9091fdff29b0195df Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 26 Mar 2015 09:47:56 -0700 Subject: [PATCH 2/4] use NamedExpr --- .../spark/sql/catalyst/expressions/complexTypes.scala | 9 ++++----- .../catalyst/expressions/ExpressionEvaluationSuite.scala | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index e5afb3602370f..696c4e10b5c24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -146,18 +146,17 @@ case class CreateArray(children: Seq[Expression]) extends Expression { /** * Returns a Row containing the evaluation of all children expressions. */ -case class CreateStruct(children: Seq[Expression]) extends Expression { +case class CreateStruct(children: Seq[NamedExpression]) extends Expression { override type EvaluatedType = Row override def foldable: Boolean = children.forall(_.foldable) override lazy val resolved: Boolean = childrenResolved - override def dataType: StructType = { + override lazy val dataType: StructType = { assert(resolved, s"CreateStruct is called with unresolved children: $children.") - val fields = children.map { - case named: NamedExpression => - StructField(named.name, named.dataType, named.nullable, named.metadata) + val fields = children.map { child => + StructField(child.name, child.dataType, child.nullable, child.metadata) } StructType(fields) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 0ff8a7624f6bb..0591c81ef9007 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1083,8 +1083,8 @@ class ExpressionEvaluationSuite extends FunSuite { test("CreateStruct") { val row = Row(1, 2, 3) - val c1 = 'a.int.at(0) - val c3 = 'a.int.at(2) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) } } From ae7ac3e1504947ec41232ef41f24b5e0e6c30eca Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 31 Mar 2015 00:29:51 -0700 Subject: [PATCH 3/4] move unit test to a separate suite --- .../catalyst/expressions/complexTypes.scala | 1 + .../ExpressionEvaluationSuite.scala | 54 +++++++++++-------- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 696c4e10b5c24..74883cbacd2e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -145,6 +145,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { /** * Returns a Row containing the evaluation of all children expressions. + * TODO: [[CreateStruct]] does not support codegen. */ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { override type EvaluatedType = Row diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 0591c81ef9007..1183a0d899dda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -30,7 +30,34 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.types._ -class ExpressionEvaluationSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends FunSuite { + + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.eval(inputRow) + } + + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} + +class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("literals") { checkEvaluation(Literal(1), 1) @@ -134,27 +161,6 @@ class ExpressionEvaluationSuite extends FunSuite { } } - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } - test("IN") { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) @@ -1080,6 +1086,10 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } +} + +// TODO: Make the tests work with codegen. +class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { test("CreateStruct") { val row = Row(1, 2, 3) From 3795c57afe2bd1b7eb7da2c214918dbdaec8a8b2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 31 Mar 2015 00:51:42 -0700 Subject: [PATCH 4/4] update error message --- .../apache/spark/sql/catalyst/expressions/complexTypes.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 74883cbacd2e4..3b2b9211268a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -155,7 +155,8 @@ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { override lazy val resolved: Boolean = childrenResolved override lazy val dataType: StructType = { - assert(resolved, s"CreateStruct is called with unresolved children: $children.") + assert(resolved, + s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") val fields = children.map { child => StructField(child.name, child.dataType, child.nullable, child.metadata) }