diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 0f8570fe470bd..3be6505dc045f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -141,7 +141,7 @@ public static MapType createMapType(DataType keyType, DataType valueType) { if (valueType == null) { throw new IllegalArgumentException("valueType should not be null."); } - return new MapType(keyType, valueType, true); + return new MapType(keyType, valueType, true, false); } /** @@ -159,7 +159,7 @@ public static MapType createMapType( if (valueType == null) { throw new IllegalArgumentException("valueType should not be null."); } - return new MapType(keyType, valueType, valueContainsNull); + return new MapType(keyType, valueType, valueContainsNull, false); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0155741ddbc1d..886ef89694502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -104,6 +104,7 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables :: + SortMaps :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -2329,3 +2330,49 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { CreateNamedStruct(children.toList) } } + +/** + * Makes map-typed expressions comparable by wrapping them in [[OrderMaps]], which sorts map entries by key. + */ +object SortMaps extends Rule[LogicalPlan] { + private def containsUnorderedMap(e: Expression): Boolean = + e.resolved && MapType.containsUnorderedMap(e.dataType) + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case cmp @ BinaryComparison(left, right) if containsUnorderedMap(left) => + cmp.withNewChildren(OrderMaps(left) :: right :: Nil) + case cmp @ BinaryComparison(left, right) if containsUnorderedMap(right) => + cmp.withNewChildren(left :: OrderMaps(right) :: Nil) + case sort: SortOrder if containsUnorderedMap(sort.child) => + sort.copy(child = OrderMaps(sort.child)) + } transform { + case a: Aggregate if a.resolved && a.groupingExpressions.exists(containsUnorderedMap) => + // Modify the top level grouping expressions + val replacements = a.groupingExpressions.collect { + case a: Attribute if containsUnorderedMap(a) => + a -> Alias(OrderMaps(a), a.name)(exprId = a.exprId, qualifier = a.qualifier) + case e if containsUnorderedMap(e) => + e -> OrderMaps(e) + } + + // Transform the expression tree. + a.transformExpressionsUp { + case e => + // TODO create an expression map!
+ replacements + .find(_._1.semanticEquals(e)) + .map(_._2) + .getOrElse(e) + } + + case Distinct(child) if child.resolved && child.output.exists(containsUnorderedMap) => + val projectList = child.output.map { a => + if (containsUnorderedMap(a)) { + Alias(OrderMaps(a), a.name)(exprId = a.exprId, qualifier = a.qualifier) + } else { + a + } + } + Distinct(Project(projectList, child)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 26d26385904f6..64e175cdcfa4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -244,7 +245,7 @@ trait CheckAnalysis extends PredicateHelper { def checkValidGroupingExprs(expr: Expression): Unit = { // Check if the data type of expr is orderable. - if (!RowOrdering.isOrderable(expr.dataType)) { + if (!TypeUtils.isOrderable(expr.dataType)) { failAnalysis( s"expression ${expr.sql} cannot be used as a grouping expression " + s"because its data type ${expr.dataType.simpleString} is not an orderable " + @@ -265,7 +266,7 @@ trait CheckAnalysis extends PredicateHelper { case Sort(orders, _, _) => orders.foreach { order => - if (!RowOrdering.isOrderable(order.dataType)) { + if (!TypeUtils.isOrderable(order.dataType)) { failAnalysis( s"sorting is not supported for columns of type ${order.dataType.simpleString}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e95e97b9dc6cb..3e65b40b33d8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -133,7 +133,7 @@ object RowEncoder { ObjectType(classOf[Object])) } - case t @ MapType(kt, vt, valueNullable) => + case t @ MapType(kt, vt, valueNullable, _) => val keys = Invoke( Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), @@ -279,7 +279,7 @@ object RowEncoder { "make", arrayData :: Nil) - case MapType(kt, vt, valueNullable) => + case MapType(kt, vt, valueNullable, _) => val keyArrayType = ArrayType(kt, false) val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 4db1ae6faa159..f57d4b2022382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -67,7 +67,7 @@ object Cast { canCast(fromType, toType) && resolvableNullability(fn || forceNullable(fromType, toType), tn) - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + case (MapType(fromKey, fromValue, fn, _), MapType(toKey, toValue, tn, false)) => canCast(fromKey, toKey) && 
(!forceNullable(fromKey, toKey)) && canCast(fromValue, toValue) && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3bebd552ef51a..653d8af43119d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -61,7 +62,7 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: override def foldable: Boolean = false override def checkInputDataTypes(): TypeCheckResult = { - if (RowOrdering.isOrderable(dataType)) { + if (TypeUtils.isOrderable(dataType)) { TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 09007b7c89fe3..3088844cb3d3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -484,6 +484,7 @@ class CodegenContext { case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" case struct: StructType => genComp(struct, c1, c2) + " == 0" + case map: MapType if map.ordered => genComp(map, c1, c2) + " == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException( @@ -554,6 +555,47 @@ class CodegenContext { """ addNewFunction(compareFunc, funcCode) s"this.$compareFunc($c1, $c2)" + case MapType(keyType, valueType, _, true) => + val compareFunc = freshName("compareMap") + val funcCode: String = + s""" + public int $compareFunc(MapData a, MapData b) { + int lengthA = a.numElements(); + int lengthB = b.numElements(); + ArrayData aKeys = a.keyArray(); + ArrayData aValues = a.valueArray(); + ArrayData bKeys = b.keyArray(); + ArrayData bValues = b.valueArray(); + int minLength = (lengthA > lengthB) ? 
lengthB : lengthA; + for (int i = 0; i < minLength; i++) { + ${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")}; + ${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")}; + int comp = ${genComp(keyType, "keyA", "keyB")}; + if (comp != 0) { + return comp; + } + boolean isNullA = aValues.isNullAt(i); + boolean isNullB = bValues.isNullAt(i); + if (isNullA && isNullB) { + // Nothing + } else if (isNullA) { + return -1; + } else if (isNullB) { + return 1; + } else { + ${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")}; + ${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")}; + comp = ${genComp(valueType, "valueA", "valueB")}; + if (comp != 0) { + return comp; + } + } + } + return lengthA - lengthB; + } + """ + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" case schema: StructType => INPUT_ROW = "i" val comparisons = GenerateOrdering.genComparisons(this, schema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b1cb6edefb852..25c60ac5a9f1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -130,7 +130,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) - case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) + case MapType(keyType, valueType, _, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7e4c9089a2cb9..604e307f9c09f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -38,7 +38,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true - case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true + case MapType(kt, vt, _, _) if canSupport(kt) && canSupport(vt) => true case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -126,7 +126,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case m @ MapType(kt, vt, _) => + case m @ MapType(kt, vt, _, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. 
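
For reference, the `compareMap` function generated in `CodeGenerator.scala` above performs a lexicographic comparison over key-sorted entries: keys decide first, then values (a null value orders before a non-null value), and only when all shared entries are equal does the shorter map sort first. Below is a minimal, self-contained Scala sketch of that same algorithm; `MapCompareSketch`, `compareMaps` and the use of `Seq[(K, Option[V])]` in place of `MapData` are illustrative stand-ins, not Spark APIs.

```scala
// Illustrative sketch of the lexicographic map comparison the generated code performs.
// Both inputs are assumed to already be sorted by key, which is what OrderMaps ensures.
object MapCompareSketch {
  def compareMaps[K, V](a: Seq[(K, Option[V])], b: Seq[(K, Option[V])])
                       (implicit kOrd: Ordering[K], vOrd: Ordering[V]): Int = {
    val minLength = math.min(a.length, b.length)
    var i = 0
    while (i < minLength) {
      // Keys are compared first.
      val keyComp = kOrd.compare(a(i)._1, b(i)._1)
      if (keyComp != 0) return keyComp
      // Then values, with missing (null) values ordering before present ones.
      (a(i)._2, b(i)._2) match {
        case (None, None)       => // equal, keep scanning
        case (None, _)          => return -1
        case (_, None)          => return 1
        case (Some(x), Some(y)) =>
          val comp = vOrd.compare(x, y)
          if (comp != 0) return comp
      }
      i += 1
    }
    // All shared entries are equal: the shorter map sorts first.
    a.length - b.length
  }

  def main(args: Array[String]): Unit = {
    val m1: Seq[(String, Option[Int])] = Seq("a" -> Some(1), "b" -> Some(2))
    val m2: Seq[(String, Option[Int])] = Seq("a" -> Some(1), "b" -> Some(3))
    println(compareMaps(m1, m2)) // negative: m1 < m2
    println(compareMaps(m1, m1)) // 0
  }
}
```

The comparison is only deterministic because both operands have their entries sorted by key beforehand; that is exactly the invariant the `ordered` flag on `MapType` records.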
@@ -209,7 +209,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case m @ MapType(kt, vt, _) => + case m @ MapType(kt, vt, _, _) => s""" final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c863ba434120d..c6e990eba4996 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ /** @@ -139,7 +139,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + case ArrayType(dt, _) if TypeUtils.isOrderable(dt) => ascendingOrder match { case Literal(_: Boolean, BooleanType) => TypeCheckResult.TypeCheckSuccess @@ -287,3 +287,144 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * This expression orders all maps in an expression's result. This expression enables the use of + * maps in comparisons and equality operations. + */ +case class OrderMaps(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(ArrayType, MapType, StructType)) + + /** Create a data type in which all maps are ordered. */ + private[this] def createDataType(dataType: DataType): DataType = dataType match { + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = createDataType(field.dataType)) + }) + case ArrayType(elementType, containsNull) => + ArrayType(createDataType(elementType), containsNull) + case MapType(keyType, valueType, valueContainsNull, false) => + MapType( + createDataType(keyType), + createDataType(valueType), + valueContainsNull, + ordered = true) + case _ => + dataType + } + + override lazy val dataType: DataType = createDataType(child.dataType) + + private[this] val identity = (id: Any) => id + + /** + * Create a function that transforms a Spark SQL datum to a new datum for which all MapData + * elements have been ordered. 
+ */ + private[this] def createTransform(dataType: DataType): Option[Any => Any] = { + dataType match { + case m @ MapType(keyType, valueType, _, false) => + val keyTransform = createTransform(keyType).getOrElse(identity) + val valueTransform = createTransform(valueType).getOrElse(identity) + val ordering = Ordering.Tuple2(m.interpretedKeyOrdering, m.interpretedValueOrdering) + Option((data: Any) => { + val input = data.asInstanceOf[MapData] + val length = input.numElements() + val buffer = Array.ofDim[(Any, Any)](length) + + // Move the entries into a temporary buffer. + var i = 0 + val keys = input.keyArray() + val values = input.valueArray() + while (i < length) { + val key = keyTransform(keys.get(i, keyType)) + val value = if (!values.isNullAt(i)) { + valueTransform(values.get(i, valueType)) + } else { + null + } + buffer(i) = key -> value + i += 1 + } + + // Sort the buffer. + java.util.Arrays.sort(buffer, ordering) + + // Recreate the map data. + i = 0 + val sortedKeys = Array.ofDim[Any](length) + val sortedValues = Array.ofDim[Any](length) + while (i < length) { + sortedKeys(i) = buffer(i)._1 + sortedValues(i) = buffer(i)._2 + i += 1 + } + ArrayBasedMapData(sortedKeys, sortedValues) + }) + case ArrayType(dt, _) => + createTransform(dt).map { transform => + data: Any => { + val input = data.asInstanceOf[ArrayData] + val length = input.numElements() + val output = Array.ofDim[Any](length) + var i = 0 + while (i < length) { + if (!input.isNullAt(i)) { + output(i) = transform(input.get(i, dt)) + } + i += 1 + } + new GenericArrayData(output) + } + } + case StructType(fields) => + val transformOpts = fields.map { field => + createTransform(field.dataType) + } + // Only transform a struct if a meaningful transformation has been defined. + if (transformOpts.exists(_.isDefined)) { + val transforms = transformOpts.zip(fields).map { case (opt, field) => + val dataType = field.dataType + val transform = opt.getOrElse(identity) + (input: InternalRow, i: Int) => { + transform(input.get(i, dataType)) + } + } + val length = fields.length + val tf = (data: Any) => { + val input = data.asInstanceOf[InternalRow] + val output = Array.ofDim[Any](length) + var i = 0 + while (i < length) { + if (!input.isNullAt(i)) { + output(i) = transforms(i)(input, i) + } + i += 1 + } + new GenericInternalRow(output) + } + Some(tf) + } else { + None + } + case _ => + None + } + } + + @transient private[this] lazy val transform = { + createTransform(child.dataType).getOrElse(identity) + } + + override protected def nullSafeEval(input: Any): Any = transform(input) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // TODO we should code generate this.
+ val tf = ctx.addReferenceObj("transform", transform, classOf[Any => Any].getCanonicalName) + nullSafeCodeGen(ctx, ev, eval => { + s"${ev.value} = (${ctx.boxedType(dataType)})$tf.apply($eval);" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 0c256c3d890f1..ad0f21b00c7c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -61,7 +61,7 @@ object ExtractValue { case (_: ArrayType, _) => GetArrayItem(child, extraction) - case (MapType(kt, _, _), _) => GetMapValue(child, extraction) + case (MapType(kt, _, _, _), _) => GetMapValue(child, extraction) case (otherType, _) => val errorMsg = otherType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 6c38f4998e914..ac2c5bd1d32f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -229,7 +229,7 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with new StructType() .add("col", et, containsNull) } - case MapType(kt, vt, valueContainsNull) => + case MapType(kt, vt, valueContainsNull, _) => if (position) { new StructType() .add("pos", IntegerType, nullable = false) @@ -255,7 +255,7 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with }) rows } - case MapType(kt, vt, _) => + case MapType(kt, vt, _, _) => val inputMap = child.eval(input).asInstanceOf[MapData] if (inputMap == null) { Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index e14f0544c2b81..98b2640b4db69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -407,7 +407,7 @@ abstract class HashExpression[E] extends Expression { case BinaryType => genHashBytes(input, result) case StringType => genHashString(input, result) case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull) - case MapType(kt, vt, valueContainsNull) => + case MapType(kt, vt, valueContainsNull, _) => genHashForMap(ctx, input, result, kt, vt, valueContainsNull) case StructType(fields) => genHashForStruct(ctx, input, result, fields) case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx) @@ -474,7 +474,7 @@ abstract class InterpretedHashFunction { case udt: UserDefinedType[_] => val mapType = udt.sqlType.asInstanceOf[MapType] mapType.keyType -> mapType.valueType - case MapType(kt, vt, _) => kt -> vt + case MapType(kt, vt, _, _) => kt -> vt } val keys = map.keyArray() val values = map.valueArray() @@ -756,7 +756,7 @@ object HiveHashFunction extends InterpretedHashFunction { case udt: UserDefinedType[_] => val mapType = udt.sqlType.asInstanceOf[MapType] mapType.keyType -> mapType.valueType - case MapType(_kt, _vt, _) => _kt -> _vt + case MapType(_kt, _vt, _, _) => _kt -> _vt } val keys = map.keyArray() val values = 
map.valueArray() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5c27179ec3b46..ad20815f43120 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -540,7 +540,7 @@ case class MapObjects private( val genFunctionValue = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) - case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case MapType(_, _, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) case _ => genFunction.value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index e24a3de3cfdbe..2ee803f2513d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -18,22 +18,32 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * An interpreted row ordering comparator. */ -class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { +class InterpretedOrdering(orders: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) + @transient private[this] lazy val orderings = orders.toIndexedSeq.map { order => + val ordering = TypeUtils.getInterpretedOrdering(order.dataType) + if (order.direction == Ascending) { + ordering + } else { + ordering.reverse + } + } + def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 - val size = ordering.size + val size = orders.size while (i < size) { - val order = ordering(i) + val order = orders(i) val left = order.child.eval(a) val right = order.child.eval(b) @@ -44,29 +54,14 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow } else if (right == null) { return if (order.nullOrdering == NullsFirst) 1 else -1 } else { - val comparison = order.dataType match { - case dt: AtomicType if order.direction == Ascending => - dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right) - case dt: AtomicType if order.direction == Descending => - dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) - case a: ArrayType if order.direction == Ascending => - a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) - case a: ArrayType if order.direction == Descending => - a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) - case s: StructType if order.direction == Ascending => - s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) - case s: StructType if order.direction == Descending => - s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) - case other => - throw new IllegalArgumentException(s"Type $other does not support ordered operations") - } + val comparison = 
orderings(i).compare(left, right) if (comparison != 0) { return comparison } } i += 1 } - return 0 + 0 } } @@ -83,21 +78,10 @@ object InterpretedOrdering { } object RowOrdering { - - /** - * Returns true iff the data type can be ordered (i.e. can be sorted). - */ - def isOrderable(dataType: DataType): Boolean = dataType match { - case NullType => true - case dt: AtomicType => true - case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) - case array: ArrayType => isOrderable(array.elementType) - case udt: UserDefinedType[_] => isOrderable(udt.sqlType) - case _ => false - } - /** * Returns true iff outputs from the expressions can be ordered. */ - def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType)) + def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall { e => + TypeUtils.isOrderable(e.dataType) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3fcbb05372d87..f1decf8f9ff43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -417,11 +417,10 @@ case class EqualTo(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - // TODO: although map type is not orderable, technically map type should be able to be used - // in equality comparison, remove this type check once we support it. - if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { - TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " + - s"input type is ${left.dataType.catalogString}.") + // Maps are only allowed when they are ordered. + if (MapType.containsUnorderedMap(left.dataType)) { + TypeCheckResult.TypeCheckFailure( + s"Cannot use unordered map type in EqualTo: ${left.dataType.catalogString}.") } else { TypeCheckResult.TypeCheckSuccess } @@ -450,11 +449,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - // TODO: although map type is not orderable, technically map type should be able to be used - // in equality comparison, remove this type check once we support it. - if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { - TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " + - s"input type is ${left.dataType.catalogString}.") + // Maps are only allowed when they are ordered.
+ if (MapType.containsUnorderedMap(left.dataType)) { + TypeCheckResult.TypeCheckFailure( + s"Cannot use unordered map type in EqualNullSafe: ${left.dataType.catalogString}.") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6958398e03f70..3687ddd167dcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -521,8 +521,8 @@ object SimplifyCasts extends Rule[LogicalPlan] { case Cast(e, dataType) if e.dataType == dataType => e case c @ Cast(e, dataType) => (e.dataType, dataType) match { case (ArrayType(from, false), ArrayType(to, true)) if from == to => e - case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true)) - if fromKey == toKey && fromValue == toValue => e + case (MapType(fromKey, fromValue, false, fromOrder), MapType(toKey, toValue, true, toOrder)) + if fromKey == toKey && fromValue == toValue && (!toOrder || fromOrder) => e case _ => c } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 7101ca5a17de9..4f34b5f3278bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.types._ /** @@ -34,7 +33,7 @@ object TypeUtils { } def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = { - if (RowOrdering.isOrderable(dt)) { + if (isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") @@ -65,13 +64,34 @@ object TypeUtils { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case m: MapType if m.ordered => m.interpretedOrdering.asInstanceOf[Ordering[Any]] + case other => + throw new IllegalArgumentException(s"Type $other does not support ordered operations") } } + /** + * Returns true iff the data type can be ordered (i.e. can be sorted). 
+ */ + def isOrderable(dataType: DataType): Boolean = dataType match { + case NullType => true + case dt: AtomicType => true + case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) + case array: ArrayType => isOrderable(array.elementType) + case MapType(keyType, valueType, _, true) => isOrderable(keyType) && isOrderable(valueType) + case udt: UserDefinedType[_] => isOrderable(udt.sqlType) + case _ => false + } + def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { - for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res + var i = 0 + val length = scala.math.min(x.length, y.length) + while (i < length) { + val res = x(i) - y(i) + if (res != 0) { + return res + } + i += 1 } x.length - y.length } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 2642d9395ba88..dbe1ce667e453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -203,7 +203,8 @@ object DataType { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) - case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => + case (MapType(leftKeyType, leftValueType, _, _), + MapType(rightKeyType, rightValueType, _, _)) => equalsIgnoreNullability(leftKeyType, rightKeyType) && equalsIgnoreNullability(leftValueType, rightValueType) case (StructType(leftFields), StructType(rightFields)) => @@ -234,7 +235,7 @@ object DataType { case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + case (MapType(fromKey, fromValue, fn, _), MapType(toKey, toValue, tn, _)) => (tn || !fn) && equalsIgnoreCompatibleNullability(fromKey, toKey) && equalsIgnoreCompatibleNullability(fromValue, toValue) @@ -260,7 +261,7 @@ object DataType { case (ArrayType(fromElement, _), ArrayType(toElement, _)) => equalsIgnoreCaseAndNullability(fromElement, toElement) - case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + case (MapType(fromKey, fromValue, _, _), MapType(toKey, toValue, _, _)) => equalsIgnoreCaseAndNullability(fromKey, toKey) && equalsIgnoreCaseAndNullability(fromValue, toValue) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 3a32aa43d1c3a..21812f2673cc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.types +import scala.math.Ordering + import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.catalyst.util.{MapData, TypeUtils} /** * The data type for Maps. Keys in a map are not allowed to have `null` values. 
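
The `isOrderable` check added to `TypeUtils` above treats a map as orderable only when its `ordered` flag is set, and it recurses through nested array, struct and map element types. The following toy sketch reproduces that decision logic under a simplified, hypothetical type model; none of these case classes are Spark types, they only mirror the recursion shown in the patch.

```scala
// Simplified, stand-in type model (not Spark's DataType hierarchy).
sealed trait SqlType
case object IntT extends SqlType
case object StringT extends SqlType
case class ArrayT(element: SqlType) extends SqlType
case class StructT(fields: Seq[SqlType]) extends SqlType
case class MapT(key: SqlType, value: SqlType, ordered: Boolean) extends SqlType

object OrderableSketch {
  // Mirrors TypeUtils.isOrderable: only ordered maps participate in ordering,
  // and the check recurses into nested element and field types.
  def isOrderable(t: SqlType): Boolean = t match {
    case IntT | StringT    => true
    case ArrayT(e)         => isOrderable(e)
    case StructT(fs)       => fs.forall(isOrderable)
    case MapT(k, v, true)  => isOrderable(k) && isOrderable(v)
    case MapT(_, _, false) => false
  }

  def main(args: Array[String]): Unit = {
    println(isOrderable(MapT(StringT, IntT, ordered = false)))                       // false
    println(isOrderable(MapT(StringT, IntT, ordered = true)))                        // true
    println(isOrderable(StructT(Seq(IntT, MapT(StringT, IntT, ordered = false)))))   // false
  }
}
```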
@@ -33,9 +36,10 @@ import org.apache.spark.annotation.InterfaceStability */ @InterfaceStability.Stable case class MapType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean, + ordered: Boolean = false) extends DataType { /** No-arg constructor for kryo. */ def this() = this(null, null, false) @@ -68,11 +72,66 @@ case class MapType( override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" override private[spark] def asNullable: MapType = - MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true, ordered) override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) } + + @transient + private[sql] lazy val interpretedKeyOrdering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + + @transient + private[sql] lazy val interpretedValueOrdering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(valueType) + + @transient + private[sql] lazy val interpretedOrdering: Ordering[MapData] = new Ordering[MapData] { + assert(ordered) + val keyOrdering = interpretedKeyOrdering + val valueOrdering = interpretedValueOrdering + def compare(left: MapData, right: MapData): Int = { + val leftKeys = left.keyArray() + val leftValues = left.valueArray() + val rightKeys = right.keyArray() + val rightValues = right.valueArray() + val minLength = scala.math.min(leftKeys.numElements(), rightKeys.numElements()) + var i = 0 + while (i < minLength) { + val keyComp = keyOrdering.compare(leftKeys.get(i, keyType), rightKeys.get(i, keyType)) + if (keyComp != 0) { + return keyComp + } + // TODO this has been taken from ArrayData. Perhaps we should factor out the common code. + val isNullLeft = leftValues.isNullAt(i) + val isNullRight = rightValues.isNullAt(i) + if (isNullLeft && isNullRight) { + // Do nothing. + } else if (isNullLeft) { + return -1 + } else if (isNullRight) { + return 1 + } else { + val comp = valueOrdering.compare( + leftValues.get(i, valueType), + rightValues.get(i, valueType)) + if (comp != 0) { + return comp + } + } + i += 1 + } + val diff = left.numElements() - right.numElements() + if (diff < 0) { + -1 + } else if (diff > 0) { + 1 + } else { + 0 + } + } + } } /** @@ -94,5 +153,15 @@ object MapType extends AbstractDataType { * The `valueContainsNull` is true. */ def apply(keyType: DataType, valueType: DataType): MapType = - MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) + new MapType(keyType, valueType, valueContainsNull = true, ordered = false) + + /** + * Check if a dataType contains an unordered map. 
+ */ + private[sql] def containsUnorderedMap(dataType: DataType): Boolean = { + dataType.existsRecursively { + case m: MapType => !m.ordered + case _ => false + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 0205c13aa986d..7ff5be98ee24b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -452,12 +452,13 @@ object StructType extends AbstractDataType { merge(leftElementType, rightElementType), leftContainsNull || rightContainsNull) - case (MapType(leftKeyType, leftValueType, leftContainsNull), - MapType(rightKeyType, rightValueType, rightContainsNull)) => + case (MapType(leftKeyType, leftValueType, leftContainsNull, leftOrdered), + MapType(rightKeyType, rightValueType, rightContainsNull, rightOrdered)) => MapType( merge(leftKeyType, rightKeyType), merge(leftValueType, rightValueType), - leftContainsNull || rightContainsNull) + leftContainsNull || rightContainsNull, + leftOrdered && rightOrdered) case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 850869799507f..57e7dfa483724 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -200,7 +200,7 @@ object RandomDataGenerator { forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } - case MapType(keyType, valueType, valueContainsNull) => + case MapType(keyType, valueType, valueContainsNull, false) => for ( keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8c1faea2394c6..83cd56ddb83a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -212,8 +212,8 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "sorting by unsupported column types", - mapRelation.orderBy('map.asc), - "sort" :: "type" :: "map" :: Nil) + intervalRelation.orderBy('interval.asc), + "sort" :: "type" :: "calendarinterval" :: Nil) errorTest( "sorting by attributes are not from grouping expressions", @@ -439,12 +439,7 @@ class AnalysisErrorSuite extends AnalysisTest { checkDataType(dataType, shouldSuccess = true) } - val unsupportedDataTypes = Seq( - MapType(StringType, LongType), - new StructType() - .add("f1", FloatType, nullable = true) - .add("f2", MapType(StringType, LongType), nullable = true), - new UngroupableUDT()) + val unsupportedDataTypes = Seq(new UngroupableUDT()) unsupportedDataTypes.foreach { dataType => checkDataType(dataType, shouldSuccess = false) } @@ -465,7 +460,7 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." 
:: Nil) } - test("Join can work on binary types but can't work on map types") { + test("Join should work on map types") { val left = LocalRelation('a.binary, 'b.map(StringType, StringType)) val right = LocalRelation('c.binary, 'd.map(StringType, StringType)) @@ -480,7 +475,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil) + assertAnalysisSuccess(plan2) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 744057b7c5f4c..4bb9d523f6423 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -40,8 +40,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { val e = intercept[AnalysisException] { assertSuccess(expr) } - assert(e.getMessage.contains( - s"cannot resolve '${expr.sql}' due to data type mismatch:")) + assert(e.getMessage.contains("cannot resolve ")) + assert(e.getMessage.contains("due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } @@ -51,8 +51,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } def assertErrorForDifferingTypes(expr: Expression): Unit = { - assertError(expr, - s"differing types in '${expr.sql}'") + assertError(expr, "differing types in") } test("check types for unary arithmetic") { @@ -99,6 +98,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(LessThanOrEqual('intField, 'stringField)) assertSuccess(GreaterThan('intField, 'stringField)) assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + assertSuccess(EqualTo('mapField, 'mapField)) + assertSuccess(EqualNullSafe('mapField, 'mapField)) // We will transform EqualTo with numeric and boolean types to CaseKeyWhen assertSuccess(EqualTo('intField, 'booleanField)) @@ -111,8 +112,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo") - assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe") assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") assertError(LessThanOrEqual('mapField, 'mapField), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index 3741a6ba95a86..21e8aba7195e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -57,4 +57,7 @@ object TestRelations { val mapRelation = LocalRelation( AttributeReference("map", MapType(IntegerType, IntegerType))()) + + val intervalRelation = LocalRelation( + AttributeReference("interval", CalendarIntervalType)()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f80214af43fc1..69bb3d0d985eb 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -175,7 +175,7 @@ case class GenerateExec( case ArrayType(dataType, nullable) => ("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks))) - case MapType(keyType, valueType, valueContainsNull) => + case MapType(keyType, valueType, valueContainsNull, _) => // Materialize the key and the value arrays before we enter the loop. val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index b3ef29f6e34c4..989ce11f4850d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -176,7 +176,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -194,7 +194,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index b4f36ce3752c0..36bdc66cb2ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -477,7 +477,7 @@ private[parquet] class ParquetSchemaConverter( // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. 
- case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat => + case MapType(keyType, valueType, valueContainsNull, _) if writeLegacyParquetFormat => // group (MAP) { // repeated group map (MAP_KEY_VALUE) { // required key; @@ -508,7 +508,7 @@ private[parquet] class ParquetSchemaConverter( .named("list")) .named(field.name) - case MapType(keyType, valueType, valueContainsNull) => + case MapType(keyType, valueType, valueContainsNull, _) => // group (MAP) { // repeated group key_value { // required key; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5ba44ff9f5d9d..0fbf2f7419d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -25,10 +25,11 @@ import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} @@ -161,7 +162,7 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl checkDuplication(normalizedSortCols, "sort") schema.filter(f => normalizedSortCols.contains(f.name)).map(_.dataType).foreach { - case dt if RowOrdering.isOrderable(dt) => // OK + case dt if TypeUtils.isOrderable(dt) => // OK case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 46fd54e5c7420..a2a17e41b2898 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -39,7 +39,7 @@ object EvaluatePython { case _: StructType => true case _: UserDefinedType[_] => true case ArrayType(elementType, _) => needConversionInPython(elementType) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => needConversionInPython(keyType) || needConversionInPython(valueType) case _ => false } @@ -124,7 +124,7 @@ object EvaluatePython { case (c, ArrayType(elementType, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) - case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => + case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _, _)) => ArrayBasedMapData( javaMap, (key: Any) => fromJava(key, keyType), diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index c44fc3d393862..5517f68c275b6 100644 --- 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -161,7 +161,7 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), schema.apply("b")); ArrayType valueType = new ArrayType(DataTypes.IntegerType, false); - MapType mapType = new MapType(DataTypes.StringType, valueType, true); + MapType mapType = new MapType(DataTypes.StringType, valueType, true, false); Assert.assertEquals( new StructField("c", mapType, true, Metadata.empty()), schema.apply("c")); diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 52aa1088acd4a..0568cd11fce19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -747,7 +747,7 @@ private[hive] trait HiveInspectors { def toInspector(dataType: DataType): ObjectInspector = dataType match { case ArrayType(tpe, _) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toInspector(keyType), toInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector @@ -814,7 +814,7 @@ private[hive] trait HiveInspectors { list.add(wrap(e, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } - case Literal(value, MapType(keyType, valueType, _)) => + case Literal(value, MapType(keyType, valueType, _, _)) => val keyOI = toInspector(keyType) val valueOI = toInspector(valueType) if (value == null) { @@ -1014,7 +1014,7 @@ private[hive] trait HiveInspectors { getStructTypeInfo( java.util.Arrays.asList(fields.map(_.name) : _*), java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo case BooleanType => booleanTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 81cd65c3cc337..c6e8ce3a2a450 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -414,7 +414,7 @@ private[spark] object HiveUtils extends Logging { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -437,7 +437,7 @@ private[spark] object HiveUtils extends Logging { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3de1f4aeb74dc..25de09d2c7368 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -107,7 +107,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { def toWritableInspector(dataType: DataType): ObjectInspector = dataType match { case ArrayType(tpe, _) => ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe)) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toWritableInspector(keyType), toWritableInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.writableStringObjectInspector
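
Taken together, these changes let map-typed columns flow through equality predicates, joins, GROUP BY and DISTINCT: the analyzer's `SortMaps` rule wraps such expressions in `OrderMaps`, so two maps holding the same key/value pairs compare equal regardless of the order in which their entries were built. The sketch below is an illustrative usage example only; with this patch applied the `groupBy` on a map column is expected to analyze and collapse both rows into a single group, whereas stock Spark rejects grouping on maps.

```scala
import org.apache.spark.sql.SparkSession

object OrderedMapUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("ordered-map-sketch")
      .getOrCreate()
    import spark.implicits._

    // The same logical map contents, built in different entry orders.
    val df = Seq(
      (1, Map("a" -> 1, "b" -> 2)),
      (2, Map("b" -> 2, "a" -> 1))
    ).toDF("id", "m")

    // With SortMaps/OrderMaps in place, grouping on the map column should
    // produce a single group with count = 2.
    df.groupBy($"m").count().show()

    spark.stop()
  }
}
```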