diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2ab5cb666fbcd..8420c54f7c335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,6 +150,7 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ + // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -159,19 +160,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + case b @ BinaryOperator(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + case b @ BinaryOperator(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryExpression(StringNaN, right @ FloatType()) => + case b @ BinaryOperator(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryExpression(left @ FloatType(), StringNaN) => + case b @ BinaryOperator(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryExpression(left @ StringNaN, StringNaN) => + case b @ BinaryOperator(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -245,12 +246,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryExpressions. + // Also widen types for BinaryOperator. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -478,7 +479,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala new file mode 100644 index 0000000000000..450fc4165f93b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.DataType + + +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait AutoCastInputTypes { self: Expression => + + def inputTypes: Seq[DataType] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e18a3118945e8..cafbbafdca207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,17 +119,6 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns a string representation of this expression that does not have developer centric - * debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -154,71 +143,40 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess -} - -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable - - override def toString: String = s"($left $symbol $right)" /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f accepts two variable names and returns Java code to compute the output. + * Returns a user-facing string representation of this expression's name. + * This should usually match the name of the function in SQL. */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def prettyName: String = getClass.getSimpleName.toLowerCase /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Returns a user-facing string representation of this expression, i.e. does not have developer + * centric debugging information like the expression id. */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } -} -private[sql] object BinaryExpression { - def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def toString: String = prettyName + children.mkString("(", ",", ")") } + +/** + * A leaf expression, i.e. one without any child expressions. + */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } + +/** + * An expression with one input and one output. The output is by default evaluated to null + * if the input is evaluated to null. + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -265,39 +223,76 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } + /** - * An trait that gets mixin to define the expected input types of an expression. + * An expression with two inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. */ -trait ExpectsInputTypes { self: Expression => +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + * @param f accepts two variable names and returns Java code to compute the output. */ - def inputTypes: Seq[Any] + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } } + /** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. + * An expression that has two inputs that are expected to the be same type. If the two inputs have + * different types, the analyzer will find the tightest common type and do the proper type casting. */ -trait AutoCastInputTypes { self: Expression => +abstract class BinaryOperator extends BinaryExpression { + self: Product => - def inputTypes: Seq[DataType] + def symbol: String - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } + override def toString: String = s"($left $symbol $right)" +} + + +private[sql] object BinaryOperator { + def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ebabb6f117851..caf021b016a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index a9fc54c548f49..da520f56b430e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -128,7 +128,6 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -162,7 +161,6 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -401,8 +399,6 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -494,8 +490,6 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString: String = s"SUM($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5363b3556886a..4fbf4c87009c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def toString: String = s"positive($child)" + override def prettyName: String = "positive" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,8 +69,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { - override def toString: String = s"Abs($child)" - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -79,10 +77,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryExpression { +abstract class BinaryArithmetic extends BinaryOperator { self: Product => - override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -360,7 +357,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - override def toString: String = s"MaxOf($left, $right)" + + override def symbol: String = "max" + override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -413,5 +412,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def toString: String = s"MinOf($left, $right)" + override def symbol: String = "min" + override def prettyName: String = symbol } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5def57b067424..67e7dc4ec8b14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString: String = s"Array(${children.mkString(",")})" + override def prettyName: String = "array" } /** @@ -71,4 +71,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } + + override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 78be2824347d7..145d323a9f0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,8 +38,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a777f77add2db..34df89a163895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryExpression with Predicate { +abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index daa9f4403ffab..5d51a4ca65332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,8 +137,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType - override def symbol: String = "++=" - override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 4cbfc4e084948..b020f2bbc5818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,8 +75,6 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "LIKE" - // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -101,14 +99,16 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,9 +134,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase() - - override def toString: String = s"Upper($child)" + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -148,9 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase() - - override def toString: String = s"Lower($child)" + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -178,8 +174,6 @@ trait StringComparison extends AutoCastInputTypes { } } - override def symbol: String = nodeName - override def toString: String = s"$nodeName($left, $right)" } @@ -284,12 +278,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } - - override def toString: String = len match { - // TODO: This is broken because max is not an integer value. - case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" - case _ => s"SUBSTR($str, $pos, $len)" - } } /** @@ -304,9 +292,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } - override def toString: String = s"length($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } + + override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index bda217935cb05..86792f0217572 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryExpression => actual.append(b.symbol); + case b: BinaryOperator => actual.append(b.symbol); case l: Literal => actual.append(l.toString); }