From ec4731d952765d586f641de0c4a056772e2c1097 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Fri, 12 Feb 2016 10:22:02 -0800 Subject: [PATCH 1/2] avoid checking nullability for complex data type --- .../sql/catalyst/expressions/Expression.scala | 2 + .../sql/catalyst/expressions/SortOrder.scala | 3 ++ .../expressions/aggregate/First.scala | 2 + .../catalyst/expressions/aggregate/Last.scala | 2 + .../catalyst/expressions/aggregate/Max.scala | 2 + .../catalyst/expressions/aggregate/Min.scala | 2 + .../catalyst/expressions/aggregate/Sum.scala | 8 ++++ .../expressions/aggregate/interfaces.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 7 +++ .../expressions/bitwiseExpressions.scala | 2 + .../expressions/collectionOperations.scala | 1 + .../expressions/complexTypeCreator.scala | 46 +++++++++++++++++++ .../expressions/complexTypeExtractors.scala | 6 +++ .../expressions/conditionalExpressions.scala | 8 ++++ .../expressions/decimalExpressions.scala | 1 + .../expressions/mathExpressions.scala | 27 +++++++++++ .../expressions/namedExpressions.scala | 4 +- .../expressions/nullExpressions.scala | 4 ++ .../sql/catalyst/expressions/objects.scala | 6 +++ .../expressions/stringExpressions.scala | 2 + .../expressions/windowExpressions.scala | 3 ++ .../spark/sql/DataFrameComplexTypeSuite.scala | 11 +++++ 22 files changed, 149 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7dacdafb7141d..9ff89a11a6bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -137,6 +137,8 @@ abstract class Expression extends TreeNode[Expression] { */ def dataType: DataType + def prettyDataType: DataType = dataType + /** * Returns true if all the children of this expression have 
been resolved to a specific schema * and false if any still contains any unresolved placeholders. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index e0c3b22a3c389..e505b568b7c0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -55,6 +55,9 @@ case class SortOrder(child: Expression, direction: SortDirection) } override def dataType: DataType = child.dataType + + override def prettyDataType: DataType = child.prettyDataType + override def nullable: Boolean = child.nullable override def toString: String = s"$child ${direction.sql}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index b8ab0364dd8f3..4d5dc65fd624d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -53,6 +53,8 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara // Return data type. override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType + // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index b05d74b49b591..27dbcf6d6c159 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -50,6 +50,8 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat // Return data type. override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType + // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index c534fe495fc13..f6b460eb43d70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -33,6 +33,8 @@ case class Max(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType + // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 35289b468183c..cf3742bbce711 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -33,6 +33,8 @@ case class Min(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType + // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index ad217f25b5a26..788536b7a2b8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -33,6 +33,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate { // Return data type. 
override def dataType: DataType = resultType + override def prettyDataType: DataType = prettyResultType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType)) @@ -45,6 +47,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate { case _ => child.dataType } + private lazy val prettyResultType = child.prettyDataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) + case _ => child.prettyDataType + } + private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d31ccf9985360..f1ee4f868c451 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -114,6 +114,7 @@ private[sql] case class AggregateExpression( override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType + override def prettyDataType: DataType = aggregateFunction.prettyDataType override def foldable: Boolean = false override def nullable: Boolean = aggregateFunction.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b2df79a58884b..95bbabb8f13ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -32,6 +32,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression override def dataType: DataType = child.dataType + 
override def prettyDataType: DataType = child.prettyDataType + override def toString: String = s"-$child" private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -69,6 +71,7 @@ case class UnaryPositive(child: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = defineCodeGen(ctx, ev, c => c) @@ -91,6 +94,8 @@ case class Abs(child: Expression) override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType + private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { @@ -107,6 +112,8 @@ abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType + override def prettyDataType: DataType = left.prettyDataType + override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 3a0a882e3876e..bdbd3fa1d9d51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -117,6 +117,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType + override def toString: String = s"~$child" private lazy val not: (Any) => Any = dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c71cb73d65bf6..8b2d3a9bb6ebd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -60,6 +60,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def left: Expression = base override def right: Expression = ascendingOrder override def dataType: DataType = base.dataType + override def prettyDataType: DataType = base.prettyDataType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3d4819c55a2d5..ad098f28c6afd 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -42,6 +42,11 @@ case class CreateArray(children: Seq[Expression]) extends Expression { containsNull = children.exists(_.nullable)) } + override def prettyDataType: DataType = { + ArrayType( + children.headOption.map(_.prettyDataType).getOrElse(NullType)) + } + override def nullable: Boolean = false override def eval(input: InternalRow): Any = { @@ -173,6 +178,19 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { StructType(fields) } + override lazy val prettyDataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.prettyDataType, nullable = true, ne.metadata) + case _ => + // give the default value to the nullable, since we just care about datatype + StructField(s"col${idx + 1}", child.prettyDataType, nullable = true, Metadata.empty) + } + } + StructType(fields) + } + override def nullable: Boolean = false override def eval(input: InternalRow): Any = { @@ -233,6 +251,14 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { StructType(fields) } + override lazy val prettyDataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.prettyDataType, nullable = true, Metadata.empty) + } + StructType(fields) + } + override def foldable: Boolean = valExprs.forall(_.foldable) override def nullable: Boolean = false @@ -302,6 +328,19 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { StructType(fields) } + override lazy val prettyDataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.prettyDataType, 
nullable = true, ne.metadata) + case _ => + // for the prettyDataType, we just give the default nullable value + StructField(s"col${idx + 1}", child.prettyDataType, nullable = true, Metadata.empty) + } + } + StructType(fields) + } + override def nullable: Boolean = false override def eval(input: InternalRow): Any = { @@ -338,6 +377,13 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression StructType(fields) } + override lazy val prettyDataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.prettyDataType, nullable = true, Metadata.empty) + } + StructType(fields) + } + override def foldable: Boolean = valExprs.forall(_.foldable) override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3b4468f55ca73..ee98b7a923774 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -107,8 +107,10 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] extends UnaryExpression with ExtractValue { private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] + private[sql] lazy val prettyChildSchema = child.prettyDataType.asInstanceOf[StructType] override def dataType: DataType = childSchema(ordinal).dataType + override def prettyDataType: DataType = prettyChildSchema(ordinal).dataType override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable override def toString: String = { @@ -229,6 +231,8 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType + override def prettyDataType: 
DataType = child.prettyDataType.asInstanceOf[ArrayType].elementType + protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() @@ -278,6 +282,8 @@ case class GetMapValue(child: Expression, key: Expression) override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + override def prettyDataType: DataType = child.prettyDataType.asInstanceOf[MapType].valueType + // todo: current search is O(n), improve it. protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val map = value.asInstanceOf[MapData] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index e97e08947a500..10bc1f057c202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -47,6 +47,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def dataType: DataType = trueValue.dataType + override def prettyDataType: DataType = trueValue.prettyDataType + override def eval(input: InternalRow): Any = { if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) @@ -102,6 +104,8 @@ abstract class CaseWhenBase( override def dataType: DataType = branches.head._2.dataType + override def prettyDataType: DataType = thenList.head.prettyDataType + override def nullable: Boolean = { // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) @@ -307,6 +311,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def dataType: DataType = children.head.dataType + override 
def prettyDataType: DataType = children.head.prettyDataType + override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) @@ -367,6 +373,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def dataType: DataType = children.head.dataType + override def prettyDataType: DataType = children.head.prettyDataType + override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index fa5dea6841149..417352007c85f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -69,6 +69,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un */ case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType override def eval(input: InternalRow): Any = child.eval(input) override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 5152265152aed..4cc9002880735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -189,6 +189,13 @@ case class Ceil(child: Expression) extends 
UnaryMathExpression(math.ceil, "CEIL" case _ => LongType } + override def prettyDataType: DataType = child.prettyDataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, DecimalType)) @@ -276,6 +283,13 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case _ => LongType } + override def prettyDataType: DataType = child.prettyDataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, DecimalType)) @@ -646,6 +660,8 @@ case class ShiftLeft(left: Expression, right: Expression) override def dataType: DataType = left.dataType + override def prettyDataType: DataType = left.prettyDataType + protected override def nullSafeEval(input1: Any, input2: Any): Any = { input1 match { case l: jl.Long => l << input2.asInstanceOf[jl.Integer] @@ -676,6 +692,8 @@ case class ShiftRight(left: Expression, right: Expression) override def dataType: DataType = left.dataType + override def prettyDataType: DataType = left.prettyDataType + protected override def nullSafeEval(input1: Any, input2: Any): Any = { input1 match { case l: jl.Long => l >> input2.asInstanceOf[jl.Integer] @@ -706,6 +724,8 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) override def dataType: DataType = left.dataType + override def prettyDataType: DataType = left.prettyDataType + protected override def nullSafeEval(input1: Any, input2: Any): Any = { input1 match { case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer] @@ -810,6 +830,13 @@ abstract class RoundBase(child: Expression, scale: Expression, case t => t } + override lazy val prettyDataType: DataType = 
child.prettyDataType match { + // if the new scale is bigger which means we are scaling up, + // keep the original scale as `Decimal` does + case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) + case t => t + } + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c083f12724dbb..e6ac11fa95895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -86,9 +86,10 @@ trait NamedExpression extends Expression { /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression + /** avoid checking nullability for complex data type */ protected def typeSuffix = if (resolved) { - dataType match { + prettyDataType match { case LongType => "L" case _ => "" } @@ -146,6 +147,7 @@ case class Alias(child: Expression, name: String)( override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") override def dataType: DataType = child.dataType + override def prettyDataType: DataType = child.prettyDataType override def nullable: Boolean = child.nullable override def metadata: Metadata = { explicitMetadata.getOrElse { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 421200e147b7a..0ae9d63511769 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -55,6 +55,8 @@ 
case class Coalesce(children: Seq[Expression]) extends Expression { override def dataType: DataType = children.head.dataType + override def prettyDataType: DataType = children.head.prettyDataType + override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator @@ -136,6 +138,8 @@ case class NaNvl(left: Expression, right: Expression) override def dataType: DataType = left.dataType + override def prettyDataType: DataType = left.prettyDataType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType), TypeCollection(DoubleType, FloatType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 1e418540a2624..52853b15d5a60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -444,6 +444,8 @@ case class MapObjects private( override def dataType: DataType = ArrayType(lambdaFunction.dataType) + override def prettyDataType: DataType = ArrayType(lambdaFunction.prettyDataType) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) @@ -643,6 +645,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType + override def prettyDataType: DataType = beanInstance.prettyDataType + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -683,6 +687,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def dataType: DataType = child.dataType + override def 
prettyDataType: DataType = child.prettyDataType + override def nullable: Boolean = false override def eval(input: InternalRow): Any = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 78e846d3f580e..f882042aab232 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -784,6 +784,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: DataType = str.dataType + override def prettyDataType: DataType = str.prettyDataType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index c0b453dccf5e9..b2f1a6a44b1d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -292,6 +292,7 @@ case class WindowExpression( override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil override def dataType: DataType = windowFunction.dataType + override def prettyDataType: DataType = windowFunction.prettyDataType override def foldable: Boolean = windowFunction.foldable override def nullable: Boolean = windowFunction.nullable @@ -366,6 +367,8 @@ abstract class OffsetWindowFunction override def dataType: DataType = input.dataType + override def prettyDataType: DataType = input.prettyDataType + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, 
NullType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 72f676e6225ee..df126ac65cc56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -58,4 +58,15 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val nullIntRow = df.selectExpr("i[1]").collect()(0) assert(nullIntRow == org.apache.spark.sql.Row(null)) } + + test("SPARK-13253") { + val data = sparkContext.parallelize(Array.range(0, 10).map(x => (x, x + 1))) + val df = data.toDF("a", "b") + val arrayCol1 = functions.array(df("a"), df("b")).as("arrayCol1") + arrayCol1.toString + + val arrayCol2 = functions.struct(df("a"), df("b")).as("arrayCol2") + arrayCol2.toString + } + } From 86b7556e976868b375a1d8dcedc5f200be871e23 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 20 Apr 2016 00:11:51 -0700 Subject: [PATCH 2/2] fixing rebase conditionalExpressions --- .../spark/sql/catalyst/expressions/conditionalExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 10bc1f057c202..7c0ee1fec8b9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -104,7 +104,7 @@ abstract class CaseWhenBase( override def dataType: DataType = branches.head._2.dataType - override def prettyDataType: DataType = thenList.head.prettyDataType + override def prettyDataType: DataType = branches.head._2.prettyDataType override def nullable: Boolean = { // Result is nullable if any of 
the branch is nullable, or if the else value is nullable