[SPARK-18134][SQL] Comparable MapTypes [POC] #15970

Closed
wants to merge 10 commits
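Before diving into the diff, here is the comparison semantics this POC targets, as a standalone Scala sketch. Illustration only: it mirrors the comparator generated in CodegenContext further down, assumes the entry sequences are already sorted by key, and uses Option to stand in for nullable map values.

// Compare two key-sorted maps entry by entry: keys first, then values
// (a null value sorts before a non-null one), then map size as a tie-break.
def compareMaps[K, V](a: Seq[(K, Option[V])], b: Seq[(K, Option[V])])
                     (implicit keyOrd: Ordering[K], valOrd: Ordering[V]): Int = {
  val n = math.min(a.length, b.length)
  var i = 0
  while (i < n) {
    val (ka, va) = a(i)
    val (kb, vb) = b(i)
    val kc = keyOrd.compare(ka, kb)
    if (kc != 0) return kc
    (va, vb) match {
      case (None, None) =>                  // both values null: entries are equal
      case (None, _)    => return -1        // null value sorts first
      case (_, None)    => return 1
      case (Some(x), Some(y)) =>
        val vc = valOrd.compare(x, y)
        if (vc != 0) return vc
    }
    i += 1
  }
  a.length - b.length                       // equal prefix: shorter map sorts first
}

Under this scheme Seq(1 -> Some(2)) sorts before Seq(1 -> Some(2), 3 -> Some(4)), matching the lengthA - lengthB tie-break in the generated comparator below.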
@@ -141,7 +141,7 @@ public static MapType createMapType(DataType keyType, DataType valueType) {
if (valueType == null) {
throw new IllegalArgumentException("valueType should not be null.");
}
return new MapType(keyType, valueType, true);
return new MapType(keyType, valueType, true, false);
}

/**
@@ -159,7 +159,7 @@ public static MapType createMapType(
if (valueType == null) {
throw new IllegalArgumentException("valueType should not be null.");
}
return new MapType(keyType, valueType, valueContainsNull);
return new MapType(keyType, valueType, valueContainsNull, false);
}

/**
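A quick Scala-side illustration of the default chosen above, sketching the intent rather than quoting the patch: the ordered flag and MapType.containsUnorderedMap are introduced by this POC, while the two-argument MapType factory is existing API.

import org.apache.spark.sql.types._

// Maps built through the public factories stay unordered, so they still cannot be
// compared, sorted or grouped on until the analyzer rewrites them (see SortMaps below).
val m = MapType(StringType, IntegerType)  // valueContainsNull = true, ordered = false
assert(MapType.containsUnorderedMap(m))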
@@ -104,6 +104,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables ::
SortMaps ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
@@ -2329,3 +2330,49 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
CreateNamedStruct(children.toList)
}
}

/**
 * MapType expressions are not comparable by default. This rule rewrites expressions
 * that need an ordering over maps (comparisons, sort orders, grouping expressions
 * and Distinct) to operate on [[OrderMaps]] instead.
 */
object SortMaps extends Rule[LogicalPlan] {
private def containsUnorderedMap(e: Expression): Boolean =
e.resolved && MapType.containsUnorderedMap(e.dataType)

override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
case cmp @ BinaryComparison(left, right) if containsUnorderedMap(left) =>
cmp.withNewChildren(OrderMaps(left) :: right :: Nil)
case cmp @ BinaryComparison(left, right) if containsUnorderedMap(right) =>
cmp.withNewChildren(left :: OrderMaps(right) :: Nil)
case sort: SortOrder if containsUnorderedMap(sort.child) =>
sort.copy(child = OrderMaps(sort.child))
} transform {
case a: Aggregate if a.resolved && a.groupingExpressions.exists(containsUnorderedMap) =>
// Modify the top level grouping expressions
val replacements = a.groupingExpressions.collect {
case a: Attribute if containsUnorderedMap(a) =>
a -> Alias(OrderMaps(a), a.name)(exprId = a.exprId, qualifier = a.qualifier)
case e if containsUnorderedMap(e) =>
e -> OrderMaps(e)
}

// Transform the expression tree.
a.transformExpressionsUp {
case e =>
// TODO create an expression map!
replacements
.find(_._1.semanticEquals(e))
.map(_._2)
.getOrElse(e)
}

case Distinct(child) if child.resolved && child.output.exists(containsUnorderedMap) =>
val projectList = child.output.map { a =>
if (containsUnorderedMap(a)) {
Alias(OrderMaps(a), a.name)(exprId = a.exprId, qualifier = a.qualifier)
} else {
a
}
}
Distinct(Project(projectList, child))
}
}
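A hedged sketch of how the rule could be exercised. LocalRelation, AttributeReference and the Catalyst plan DSL are existing test utilities, while SortMaps and OrderMaps come from this patch, so treat this as intent, not a verified test.

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

// A relation with a single (unordered) map column that we group on.
val m = AttributeReference("m", MapType(StringType, IntegerType))()
val grouped = LocalRelation(m).groupBy(m)(m)

// The rule should rewrite the grouping key to Alias(OrderMaps(m), "m"),
// keeping m's exprId so downstream references to the column stay valid.
val rewritten = SortMaps(grouped)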
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
@@ -244,7 +245,7 @@ trait CheckAnalysis extends PredicateHelper {

def checkValidGroupingExprs(expr: Expression): Unit = {
// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
if (!TypeUtils.isOrderable(expr.dataType)) {
failAnalysis(
s"expression ${expr.sql} cannot be used as a grouping expression " +
s"because its data type ${expr.dataType.simpleString} is not an orderable " +
@@ -265,7 +266,7 @@

case Sort(orders, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
if (!TypeUtils.isOrderable(order.dataType)) {
failAnalysis(
s"sorting is not supported for columns of type ${order.dataType.simpleString}")
}
@@ -133,7 +133,7 @@ object RowEncoder {
ObjectType(classOf[Object]))
}

case t @ MapType(kt, vt, valueNullable) =>
case t @ MapType(kt, vt, valueNullable, _) =>
val keys =
Invoke(
Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
@@ -279,7 +279,7 @@ object RowEncoder {
"make",
arrayData :: Nil)

case MapType(kt, vt, valueNullable) =>
case MapType(kt, vt, valueNullable, _) =>
val keyArrayType = ArrayType(kt, false)
val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))

@@ -67,7 +67,7 @@ object Cast {
canCast(fromType, toType) &&
resolvableNullability(fn || forceNullable(fromType, toType), tn)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
case (MapType(fromKey, fromValue, fn, _), MapType(toKey, toValue, tn, false)) =>
canCast(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
canCast(fromValue, toValue) &&
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._

@@ -61,7 +62,7 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering:
override def foldable: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = {
if (RowOrdering.isOrderable(dataType)) {
if (TypeUtils.isOrderable(dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}")
@@ -484,6 +484,7 @@ class CodegenContext {
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case map: MapType if map.ordered => genComp(map, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case _ =>
throw new IllegalArgumentException(
@@ -554,6 +555,47 @@
"""
addNewFunction(compareFunc, funcCode)
s"this.$compareFunc($c1, $c2)"
case MapType(keyType, valueType, _, true) =>
val compareFunc = freshName("compareMap")
val funcCode: String =
s"""
public int $compareFunc(MapData a, MapData b) {
int lengthA = a.numElements();
int lengthB = b.numElements();
ArrayData aKeys = a.keyArray();
ArrayData aValues = a.valueArray();
ArrayData bKeys = b.keyArray();
ArrayData bValues = b.valueArray();
int minLength = (lengthA > lengthB) ? lengthB : lengthA;
for (int i = 0; i < minLength; i++) {
${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
int comp = ${genComp(keyType, "keyA", "keyB")};
if (comp != 0) {
return comp;
}
boolean isNullA = aValues.isNullAt(i);
boolean isNullB = bValues.isNullAt(i);
if (isNullA && isNullB) {
// Both values are null: treat these entries as equal and keep scanning.
} else if (isNullA) {
return -1;
} else if (isNullB) {
return 1;
} else {
${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
comp = ${genComp(valueType, "valueA", "valueB")};
if (comp != 0) {
return comp;
}
}
}
return lengthA - lengthB;
}
"""
addNewFunction(compareFunc, funcCode)
s"this.$compareFunc($c1, $c2)"
case schema: StructType =>
INPUT_ROW = "i"
val comparisons = GenerateOrdering.genComparisons(this, schema)
@@ -130,7 +130,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
dataType: DataType): ExprCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case MapType(keyType, valueType, _, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => ExprCode("", "false", s"$input.clone()")
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
@@ -38,7 +38,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _: CalendarIntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case MapType(kt, vt, _, _) if canSupport(kt) && canSupport(vt) => true
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
@@ -126,7 +126,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case m @ MapType(kt, vt, _) =>
case m @ MapType(kt, vt, _, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
@@ -209,7 +209,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case m @ MapType(kt, vt, _) =>
case m @ MapType(kt, vt, _, _) =>
s"""
final int $tmpCursor = $bufferHolder.cursor;
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}