From 3922b544054b8ef1f10147e86d690e15526ede29 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 30 May 2015 21:41:56 +0800 Subject: [PATCH 01/12] Support ordering on named_struct. --- .../sql/catalyst/expressions/predicates.scala | 102 +++++++++++++++--- .../sql/hive/execution/HiveUdfSuite.scala | 15 +++ 2 files changed, 104 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e2d1c8115e051..a1d6c8673ca44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType, StructType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -210,13 +210,25 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - lazy val ordering: Ordering[Any] = { + lazy val orderings: Either[Ordering[Any], (Int, Ordering[Any])] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) + case s: StructType => + var i = -1 + val f = s.fields.find { f => + i += 1 + f.dataType.isInstanceOf[AtomicType] + } + if (f == None) { + sys.error(s"Fields in $s do not support ordered operations") + } else { + val a = f.get.dataType.asInstanceOf[AtomicType] + Right((i, a.ordering.asInstanceOf[Ordering[Any]])) + } case other => sys.error(s"Type $other does not support ordered operations") } } @@ -230,7 +242,14 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso if (evalE2 == null) { null } else { - ordering.lt(evalE1, evalE2) + orderings match { + case Left(ordering) => + ordering.lt(evalE1, evalE2) + case Right((idx, ordering)) => + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + ordering.lt(evalE1Row(idx), evalE2Row(idx)) + } } } } @@ -239,13 +258,25 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" - lazy val ordering: Ordering[Any] = { + lazy val orderings: Either[Ordering[Any], (Int, Ordering[Any])] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) + case s: StructType => + var i = -1 + val f = s.fields.find { f => + i += 1 + f.dataType.isInstanceOf[AtomicType] + } + if (f == None) { + 
sys.error(s"Fields in $s do not support ordered operations") + } else { + val a = f.get.dataType.asInstanceOf[AtomicType] + Right((i, a.ordering.asInstanceOf[Ordering[Any]])) + } case other => sys.error(s"Type $other does not support ordered operations") } } @@ -259,7 +290,14 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo if (evalE2 == null) { null } else { - ordering.lteq(evalE1, evalE2) + orderings match { + case Left(ordering) => + ordering.lteq(evalE1, evalE2) + case Right((idx, ordering)) => + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + ordering.lteq(evalE1Row(idx), evalE2Row(idx)) + } } } } @@ -268,13 +306,25 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar override def symbol: String = ">" - lazy val ordering: Ordering[Any] = { + lazy val orderings: Either[Ordering[Any], (Int, Ordering[Any])] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) + case s: StructType => + var i = -1 + val f = s.fields.find { f => + i += 1 + f.dataType.isInstanceOf[AtomicType] + } + if (f == None) { + sys.error(s"Fields in $s do not support ordered operations") + } else { + val a = f.get.dataType.asInstanceOf[AtomicType] + Right((i, a.ordering.asInstanceOf[Ordering[Any]])) + } case other => sys.error(s"Type $other does not support ordered operations") } } @@ -288,7 +338,14 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar if (evalE2 == null) { null } else { - ordering.gt(evalE1, evalE2) + orderings match { + case Left(ordering) => + ordering.gt(evalE1, evalE2) + case Right((idx, ordering))=> + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + ordering.gt(evalE1Row(idx), evalE2Row(idx)) + } } } } @@ -297,13 +354,25 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - lazy val ordering: Ordering[Any] = { + lazy val orderings: Either[Ordering[Any], (Int, Ordering[Any])] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) + case s: StructType => + var i = -1 + val f = s.fields.find { f => + i += 1 + f.dataType.isInstanceOf[AtomicType] + } + if (f == None) { + sys.error(s"Fields in $s do not support ordered operations") + } else { + val a = f.get.dataType.asInstanceOf[AtomicType] + Right((i, a.ordering.asInstanceOf[Ordering[Any]])) + } case other => sys.error(s"Type $other does not support ordered operations") } } @@ -317,7 +386,14 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar if (evalE2 == null) { null } else { - ordering.gteq(evalE1, evalE2) + orderings match { + case Left(ordering) => + ordering.gteq(evalE1, evalE2) + case Right((idx, ordering)) => + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + ordering.gteq(evalE1Row(idx), 
evalE2Row(idx)) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 7f49eac490572..996f4d21af203 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -93,6 +93,21 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + test("Max/Min on named_struct") { + assert(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin).head() === Row("val_498")) + assert(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin).head() === Row("val_0")) + } + test("SPARK-6409 UDAFAverage test") { sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") checkAnswer( From b6e10092cc4215a01f0378435a7622d014e1fd35 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 31 May 2015 00:53:09 +0800 Subject: [PATCH 02/12] Fix scala style. --- .../org/apache/spark/sql/catalyst/expressions/predicates.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a1d6c8673ca44..879d9048f872c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -341,7 +341,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar orderings match { case Left(ordering) => ordering.gt(evalE1, evalE2) - case Right((idx, ordering))=> + case Right((idx, ordering)) => val evalE1Row = evalE1.asInstanceOf[Row] val evalE2Row = evalE2.asInstanceOf[Row] ordering.gt(evalE1Row(idx), evalE2Row(idx)) From f651b8dbfbd6b6c0cf67cd405ce472a0d5f6d2f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 1 Jun 2015 23:29:03 +0800 Subject: [PATCH 03/12] Remove Either and move orderings to BinaryComparison to reuse it. 
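Instead of each comparison operator carrying its own Either-based ordering, BinaryComparison now exposes a single orderingsAndIndex: Seq[(Ordering[Any], Int)] — one (ordering, field ordinal) pair for an atomic type, and one pair per comparable field for a struct — so struct values compare lexicographically, field by field. A rough sketch of the comparison this enables (illustrative only; compareRows is not a helper in this patch):

    // Hypothetical helper: lexicographic comparison over the collected orderings.
    def compareRows(a: Row, b: Row, orderingsAndIndex: Seq[(Ordering[Any], Int)]): Int = {
      orderingsAndIndex.foreach { case (ordering, idx) =>
        val cmp = ordering.compare(a(idx), b(idx))
        if (cmp != 0) return cmp  // the first differing field decides
      }
      0  // all compared fields are equal
    }

    // LessThan then amounts to compareRows(r1, r2, orderingsAndIndex) < 0,
    // LessThanOrEqual to <= 0, and so on for the other operators.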
--- .../sql/catalyst/expressions/predicates.scala | 177 ++++++------------ 1 file changed, 57 insertions(+), 120 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 879d9048f872c..b34919761447d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -171,6 +171,27 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => + + lazy val orderingsAndIndex: Seq[(Ordering[Any], Int)] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: AtomicType => Seq(i.ordering.asInstanceOf[Ordering[Any]]).zipWithIndex + case s: StructType => + val atomicFields = s.fields.zipWithIndex.filter(_._1.dataType.isInstanceOf[AtomicType]) + if (atomicFields.isEmpty) { + sys.error(s"Fields in $s do not support ordered operations") + } else { + // Keep each field's original ordinal so values are read from the right slot. + atomicFields.map { case (f, idx) => + (f.dataType.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]], idx) + } + } + case other => sys.error(s"Type $other does not support ordered operations") } } + + } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { @@ -210,29 +231,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - lazy val orderings: Either[Ordering[Any], (Int, Ordering[Any])] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) - case s: StructType => - var i = -1 - val f = s.fields.find { f => - i += 1 - f.dataType.isInstanceOf[AtomicType] - } - if (f == None) { - sys.error(s"Fields in $s do not support ordered operations") - } else { - val a = f.get.dataType.asInstanceOf[AtomicType] - Right((i, a.ordering.asInstanceOf[Ordering[Any]])) - } - case other => sys.error(s"Type $other does not support ordered operations") - } - } - override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -242,13 +240,18 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso if (evalE2 == null) { null } else { - orderings match { - case Left(ordering) => - ordering.lt(evalE1, evalE2) - case Right((idx, ordering)) => - val evalE1Row = evalE1.asInstanceOf[Row] - val evalE2Row = evalE2.asInstanceOf[Row] - ordering.lt(evalE1Row(idx), evalE2Row(idx)) + if (!left.dataType.isInstanceOf[StructType]) { + orderingsAndIndex(0)._1.lt(evalE1, evalE2) + } else { + // For struct, we need to compare them by order of fields + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + orderingsAndIndex.foreach { ordering => + val idx = ordering._2 + val cmp = ordering._1.compare(evalE1Row(idx), evalE2Row(idx)) + if (cmp != 0) return cmp < 0 + } + false } } } @@ -258,29 +261,6 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = 
"<=" - lazy val orderings: Either[Ordering[Any], (Int, Ordering[Any])] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) - case s: StructType => - var i = -1 - val f = s.fields.find { f => - i += 1 - f.dataType.isInstanceOf[AtomicType] - } - if (f == None) { - sys.error(s"Fields in $s do not support ordered operations") - } else { - val a = f.get.dataType.asInstanceOf[AtomicType] - Right((i, a.ordering.asInstanceOf[Ordering[Any]])) - } - case other => sys.error(s"Type $other does not support ordered operations") - } - } - override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -290,13 +270,12 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo if (evalE2 == null) { null } else { - orderings match { - case Left(ordering) => - ordering.lt(evalE1, evalE2) - case Right((idx, ordering)) => - val evalE1Row = evalE1.asInstanceOf[Row] - val evalE2Row = evalE2.asInstanceOf[Row] - ordering.lteq(evalE1Row(idx), evalE2Row(idx)) + if (orderingsAndIndex.length == 1) { + orderingsAndIndex(0)._1.lteq(evalE1, evalE2) + } else { + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + orderingsAndIndex(0)._1.lteq(evalE1Row(0), evalE2Row(0)) } } } @@ -306,29 +285,6 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">" - lazy val orderings: Either[Ordering[Any], (Int, Ordering[Any])] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) - case s: StructType => - var i = -1 - val f = s.fields.find { f => - i += 1 - f.dataType.isInstanceOf[AtomicType] - } - if (f == None) { - sys.error(s"Fields in $s do not support ordered operations") - } else { - val a = f.get.dataType.asInstanceOf[AtomicType] - Right((i, a.ordering.asInstanceOf[Ordering[Any]])) - } - case other => sys.error(s"Type $other does not support ordered operations") - } - } - override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -338,13 +294,18 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar if (evalE2 == null) { null } else { - orderings match { - case Left(ordering) => - ordering.lt(evalE1, evalE2) - case Right((idx, ordering)) => - val evalE1Row = evalE1.asInstanceOf[Row] - val evalE2Row = evalE2.asInstanceOf[Row] - ordering.gt(evalE1Row(idx), evalE2Row(idx)) + if (orderingsAndIndex.length == 1) { + orderingsAndIndex(0)._1.gt(evalE1, evalE2) + } else { + // For struct, we need to compare them by order of fields + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + orderingsAndIndex.foreach { ordering => + val idx = ordering._2 + val cmp = ordering._1.compare(evalE1Row(idx), evalE2Row(idx)) + if (cmp != 0) return cmp > 0 + } + false } } } @@ -354,29 +315,6 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - lazy val orderings: Either[Ordering[Any], (Int, 
Ordering[Any])] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => Left(i.ordering.asInstanceOf[Ordering[Any]]) - case s: StructType => - var i = -1 - val f = s.fields.find { f => - i += 1 - f.dataType.isInstanceOf[AtomicType] - } - if (f == None) { - sys.error(s"Fields in $s do not support ordered operations") - } else { - val a = f.get.dataType.asInstanceOf[AtomicType] - Right((i, a.ordering.asInstanceOf[Ordering[Any]])) - } - case other => sys.error(s"Type $other does not support ordered operations") - } - } - override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -386,13 +324,12 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar if (evalE2 == null) { null } else { - orderings match { - case Left(ordering) => - ordering.gteq(evalE1, evalE2) - case Right((idx, ordering)) => - val evalE1Row = evalE1.asInstanceOf[Row] - val evalE2Row = evalE2.asInstanceOf[Row] - ordering.gteq(evalE1Row(idx), evalE2Row(idx)) + if (!left.dataType.isInstanceOf[StructType]) { + orderingsAndIndex(0)._1.gteq(evalE1, evalE2) + } else { + // For struct, we need to compare them by order of fields + val evalE1Row = evalE1.asInstanceOf[Row] + val evalE2Row = evalE2.asInstanceOf[Row] + orderingsAndIndex.foreach { ordering => + val idx = ordering._2 + val cmp = ordering._1.compare(evalE1Row(idx), evalE2Row(idx)) + if (cmp != 0) return cmp > 0 + } + true } } } From cf58dc369a417240b60d50a1728413c195db7cbc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 1 Jun 2015 23:32:50 +0800 Subject: [PATCH 04/12] Use checkAnswer. --- .../apache/spark/sql/hive/execution/HiveUdfSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 996f4d21af203..df3b9cc64f7f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -94,18 +94,18 @@ class HiveUdfSuite extends QueryTest { } test("Max/Min on named_struct") { - assert(sql( + checkAnswer(sql( """ |SELECT max(named_struct( | "key", key, | "value", value)).value FROM src - """.stripMargin).head() === Row("val_498")) - assert(sql( + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( """ |SELECT min(named_struct( | "key", key, | "value", value)).value FROM src - """.stripMargin).head() === Row("val_0")) + """.stripMargin), Seq(Row("val_0"))) } test("SPARK-6409 UDAFAverage test") { From 3c142e406e7b7461334113d57e6f9e00a1988422 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 1 Jun 2015 23:48:44 +0800 Subject: [PATCH 05/12] Fix scala style. 
--- .../org/apache/spark/sql/hive/execution/HiveUdfSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index df3b9cc64f7f4..a6c7010e6b4a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -98,13 +98,13 @@ class HiveUdfSuite extends QueryTest { """ |SELECT max(named_struct( | "key", key, - | "value", value)).value FROM src + | "value", value)).value FROM src """.stripMargin), Seq(Row("val_498"))) checkAnswer(sql( """ |SELECT min(named_struct( | "key", key, - | "value", value)).value FROM src + | "value", value)).value FROM src """.stripMargin), Seq(Row("val_0"))) } From 94b27d5dcb73bffa4a03d6cb2ec082084be17190 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 4 Jun 2015 15:44:27 +0800 Subject: [PATCH 06/12] Remove test for error on complex type comparison. --- .../expressions/ExpressionTypeCheckingSuite.scala | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index dcb3635c5ccae..cab9b3438b4bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -89,9 +89,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") - - assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") - assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") } test("check types for predicates") { @@ -115,15 +112,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError( - LessThan('complexField, 'complexField), "operator < accepts non-complex type") - assertError( - LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") - assertError( - GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") - assertError( - GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") - assertError( If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") From 9d67f686b3b06459708b6b0d31d49aa81e55b381 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Jul 2015 23:39:35 +0800 Subject: [PATCH 07/12] Fix an incorrect merge. 
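Rebasing on master pulled in SpecializedGetters, so InternalRow's typed accessors become overrides of that interface, and Interval becomes a first-class internal type in both InternalRow and codegen. A minimal sketch of the relationship this commit restores (signatures abbreviated; not the full interface):

    // Each specialized accessor delegates to the untyped get(ordinal, dataType).
    trait SpecializedGetters {
      def getInt(ordinal: Int): Int
      def getUTF8String(ordinal: Int): UTF8String
      def getInterval(ordinal: Int): Interval
      // ... one accessor per internal type ...
    }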
--- .../spark/sql/catalyst/InternalRow.scala | 32 +++++++++++-------- .../expressions/codegen/CodeGenerator.scala | 7 ++-- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e77972a4b1dac..57863567617d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -21,13 +21,13 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{Interval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ -abstract class InternalRow extends Serializable { +abstract class InternalRow extends Serializable with SpecializedGetters { def numFields: Int @@ -39,27 +39,30 @@ abstract class InternalRow extends Serializable { def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] - def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null - def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) + override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) - def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) + override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) - def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) + override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) - def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) + override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) - def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) + override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) - def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) + override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) - def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) + override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) - def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) + override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) - def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) + override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + override def getDecimal(ordinal: Int): Decimal = + getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + + override def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString @@ -70,7 +73,8 @@ abstract class InternalRow extends Serializable { * @param ordinal position to get the struct from. 
* @param numFields number of fields the struct type has */ - def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = + getAs[InternalRow](ordinal, null) override def toString: String = s"[${this.mkString(",")}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ffed75fa0245a..a7c5ef8244e0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -79,7 +79,6 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initCode)) } - final val intervalType: String = classOf[Interval].getName final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -109,6 +108,7 @@ class CodeGenContext { case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" case StringType => s"$row.getUTF8String($ordinal)" case BinaryType => s"$row.getBinary($ordinal)" + case IntervalType => s"$row.getInterval($ordinal)" case t: StructType => s"$row.getStruct($ordinal, ${t.size})" case _ => s"($jt)$row.get($ordinal)" } @@ -150,7 +150,7 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" - case IntervalType => intervalType + case IntervalType => "Interval" case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" @@ -294,7 +294,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, - classOf[Decimal].getName + classOf[Decimal].getName, + classOf[Interval].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { From 1187a65d0173669abc39be5c2e016346c5ed36d4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Jul 2015 23:54:38 +0800 Subject: [PATCH 08/12] Fix scala style. --- .../main/scala/org/apache/spark/sql/catalyst/InternalRow.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 57863567617d2..78af073698ff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -219,7 +219,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { case l: Long => l.compare(other.get(i).asInstanceOf[Long]) case f: Float => f.compare(other.get(i).asInstanceOf[Float]) case d: Double => d.compare(other.get(i).asInstanceOf[Double]) - case a: Array[Byte] => TypeUtils.compareBinary(a, other.get(i).asInstanceOf[Array[Byte]]) + case a: Array[Byte] => + TypeUtils.compareBinary(a, other.get(i).asInstanceOf[Array[Byte]]) case u: UTF8String => u.compare(other.get(i).asInstanceOf[UTF8String]) case d: Decimal => d.compare(other.get(i).asInstanceOf[Decimal]) } From 1f661966103ff01b7997cf43f95d5be2b32473d2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 29 Jul 2015 18:20:08 +0800 Subject: [PATCH 09/12] Reuse RowOrdering and GenerateOrdering. 
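Interpreted and generated evaluation now share the same machinery: StructType keeps a private ordering built with RowOrdering.forSchema, RowOrdering recurses into StructType columns, and generated classes register per-struct compare functions via addNewFunction, which each generator splices in through declareAddedFunctions. Roughly how the interpreted side fits together (a sketch of the call path; the struct literal below is made up for illustration):

    // Schema of named_struct("key", key, "value", value).
    val schema = StructType(Seq(
      StructField("key", IntegerType),
      StructField("value", StringType)))
    // RowOrdering.forSchema yields an Ordering[InternalRow] over those field types,
    // so e.g. GreaterThan on two struct values reduces to:
    val ord = RowOrdering.forSchema(schema.fields.map(_.dataType))
    // ord.compare(leftStruct, rightStruct) > 0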
--- .../expressions/codegen/CodeGenerator.scala | 33 +++++++++++++++++-- .../codegen/GenerateMutableProjection.scala | 1 + .../codegen/GenerateOrdering.scala | 12 ++++--- .../codegen/GeneratePredicate.scala | 2 ++ .../codegen/GenerateProjection.scala | 1 + .../codegen/GenerateUnsafeProjection.scala | 1 + .../spark/sql/catalyst/expressions/rows.scala | 4 +++ .../spark/sql/catalyst/util/TypeUtils.scala | 2 +- .../apache/spark/sql/types/StructType.scala | 31 ++++++----------- 9 files changed, 59 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a7c5ef8244e0b..fd20d8576b30c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -79,6 +79,16 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initCode)) } + /** + * Holds all the functions that will be added into the generated class. + */ + val addedFunctions: mutable.Map[String, String] = + mutable.Map.empty[String, String] + + def addNewFunction(funcName: String, funcCode: String): Unit = { + addedFunctions += ((funcName, funcCode)) + } + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -214,8 +224,21 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(jt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" - case s: StructType if s.fields.forall(_.dataType.isInstanceOf[AtomicType]) => - s"$c1.compare($c2)" + case s: StructType if s.supportOrdering(s) => + val ordering = s.fields.map(_.dataType).zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + } + val comparisons = GenerateOrdering.genComparisons(ordering, this) + val funcCode: String = + s""" + public int compareStruct(InternalRow a, InternalRow b) { + InternalRow i = null; + $comparisons + return 0; + } + """ + addNewFunction("compareStruct", funcCode) + s"this.compareStruct($c1, $c2)" case other => s"$c1.compare($c2)" } @@ -262,6 +285,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ctx.mutableStates.map(_._3).mkString("\n ") } + protected def declareAddedFunctions(ctx: CodeGenContext): String = { + ctx.addedFunctions.map { case (funcName, funcCode) => + funcCode + }.mkString("\n ") + } + /** * Generates a class for a given input expression. Called when there is not cached code * already available. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 825031a4faf5e..e4a8fc24dac2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -92,6 +92,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificProjection($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index dbd4616d281c8..6115b11419249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -43,9 +43,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = in.map(BindReferences.bindReference(_, inputSchema)) - protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { - val ctx = newCodeGenContext() - + def genComparisons(ordering: Seq[SortOrder], ctx: CodeGenContext): String = { val comparisons = ordering.map { order => val eval = order.child.gen(ctx) val asc = order.direction == Ascending @@ -84,6 +82,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } """ }.mkString("\n") + comparisons + } + + protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { + val ctx = newCodeGenContext() + val comparisons = genComparisons(ordering, ctx) val code = s""" public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); @@ -93,6 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR private $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificOrdering($exprType[] expr) { expressions = expr; @@ -106,7 +111,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR return 0; } }""" - logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dfd593fb7c064..c7e718a526420 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -48,6 +48,8 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} + public SpecificPredicate($exprType[] expr) { expressions = expr; 
${initMutableStates(ctx)} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 35920147105ff..41ca09b1e7e38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -159,6 +159,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { class SpecificProjection extends ${classOf[BaseProjection].getName} { private $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificProjection($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 9a4c00e86a3ec..56a258ca0f16d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -266,6 +266,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro class SpecificProjection extends ${classOf[UnsafeProjection].getName} { ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificProjection() { ${initMutableStates(ctx)} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b7c4ece4a16fe..34a33f0ee4323 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -166,6 +166,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case n: AtomicType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case s: StructType if order.direction == Ascending => + s.ordering.asInstanceOf[Ordering[Any]].compare(left, right) + case s: StructType if order.direction == Descending => + s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case other => sys.error(s"Type $other does not support ordered operations") } if (comparison != 0) return comparison diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 217f25d035a1f..2f50d40fe25ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -37,7 +37,7 @@ object TypeUtils { case i: AtomicType => TypeCheckResult.TypeCheckSuccess case n: NullType => TypeCheckResult.TypeCheckSuccess case s: StructType => - if (s.fields.exists(_.dataType.isInstanceOf[AtomicType])) { + if (s.supportOrdering(s)) { TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure(s"Fields in $s do not support ordering") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 086ffa89d8cf0..2f23144858198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,8 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering} /** @@ -302,26 +301,16 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } - private[sql] val ordering = StructTypeOrdering.getOrdering(this) -} + private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType)) -object StructTypeOrdering { - def getOrdering(s: StructType): Ordering[Row] = { - val atomicFields = s.fields.filter(_.dataType.isInstanceOf[AtomicType]) - val orderings = atomicFields.map { f => - f.dataType.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] - }.zipWithIndex - - new Ordering[Row] { - def compare(a: Row, b: Row): Int = { - orderings.foreach { ord => - val idx = ord._2 - val cmp = ord._1.compare(a(idx), b(idx)) - if (cmp != 0) { - return cmp - } - } - return 0 + private[sql] def supportOrdering(s: StructType): Boolean = { + s.fields.forall { f => + if (f.dataType.isInstanceOf[AtomicType]) { + true + } else if (f.dataType.isInstanceOf[StructType]) { + supportOrdering(f.dataType.asInstanceOf[StructType]) + } else { + false } } } From dae6aadba7e87e0c24919958df664268f962ae69 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 31 Jul 2015 16:30:08 +0800 Subject: [PATCH 10/12] Fix nested struct. 
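With nested structs, the struct branch of the comparator generator runs once per struct level inside the same generated class, and each run emitted a method with the same name, compareStruct, so the generated Java no longer compiled. Suffixing each helper with a fresh id keeps the names unique. Sketch of the uniquifying pattern (curId is the generator's existing counter; the Java body is elided):

    // Each struct comparison registers a uniquely named helper, so nested
    // structs no longer collide on a single "compareStruct" method.
    val funId = curId.getAndIncrement
    val funcName = s"compareStruct$funId"
    addNewFunction(funcName, s"public int $funcName(InternalRow a, InternalRow b) { ... }")
    // The generated class ends up with compareStruct0, compareStruct1, ...,
    // one per struct nesting level, instead of two clashing methods.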
--- .../expressions/codegen/CodeGenerator.scala | 7 +-- .../sql/hive/execution/HiveUDFSuite.scala | 52 ++++++++++++++----- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9832509e8e937..bc46e39db0c4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -231,16 +231,17 @@ class CodeGenContext { case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) } val comparisons = GenerateOrdering.genComparisons(ordering, this) + val funId = curId.getAndIncrement val funcCode: String = s""" - public int compareStruct(InternalRow a, InternalRow b) { + public int compareStruct${funId}(InternalRow a, InternalRow b) { InternalRow i = null; $comparisons return 0; } """ - addNewFunction("compareStruct", funcCode) - s"this.compareStruct($c1, $c2)" + addNewFunction(s"compareStruct${funId}", funcCode) + s"this.compareStruct${funId}($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case _ => throw new IllegalArgumentException( "cannot generate compare code for un-comparable type") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 50fd79563e883..7069afc9f7da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.util.Utils @@ -94,18 +94,44 @@ class HiveUDFSuite extends QueryTest { } test("Max/Min on named_struct") { - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) + def testOrderInStruct(): Unit = { + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + + // nested struct cases + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", named_struct( + | "key", key, + | "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", named_struct( + | "key", key, + | "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + } + val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED) + TestHive.setConf(SQLConf.CODEGEN_ENABLED, true) + 
testOrderInStruct() + TestHive.setConf(SQLConf.CODEGEN_ENABLED, false) + testOrderInStruct() + TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) } test("SPARK-6409 UDAFAverage test") { From 3a3f40e05f2049069173da8eb628968ff1341d71 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 31 Jul 2015 16:40:51 +0800 Subject: [PATCH 11/12] Don't need to add compare to InternalRow because we can use RowOrdering. --- .../spark/sql/catalyst/InternalRow.scala | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 653b162b5e8a9..4d7325139cc4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -201,40 +201,6 @@ abstract class InternalRow extends Serializable with SpecializedGetters { } result } - - def compare(other: InternalRow): Int = { - val len = Math.min(numFields, other.numFields) - var i = 0 - while (i < len) { - val ret: Int = - if (isNullAt(i) && other.isNullAt(i)) { - 0 - } else if (isNullAt(i)) { - 1 - } else if (other.isNullAt(i)) { - -1 - } else { - get(i) match { - case b: Boolean => b.compare(other.get(i).asInstanceOf[Boolean]) - case b: Byte => b.compare(other.get(i).asInstanceOf[Byte]) - case s: Short => s.compare(other.get(i).asInstanceOf[Short]) - case n: Int => n.compare(other.get(i).asInstanceOf[Int]) - case l: Long => l.compare(other.get(i).asInstanceOf[Long]) - case f: Float => f.compare(other.get(i).asInstanceOf[Float]) - case d: Double => d.compare(other.get(i).asInstanceOf[Double]) - case a: Array[Byte] => - TypeUtils.compareBinary(a, other.get(i).asInstanceOf[Array[Byte]]) - case u: UTF8String => u.compare(other.get(i).asInstanceOf[UTF8String]) - case d: Decimal => d.compare(other.get(i).asInstanceOf[Decimal]) - } - } - if (ret != 0) { - return ret - } - i += 1 - } - numFields - other.numFields - } } object InternalRow { From d2ba8adcf5057b6d216badd855c899b701d6282f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 31 Jul 2015 16:51:32 +0800 Subject: [PATCH 12/12] Remove unused import. --- .../main/scala/org/apache/spark/sql/catalyst/InternalRow.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 4d7325139cc4b..b19bf4386b0ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}