Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-25832][SQL][BRANCH-2.4] Revert newly added map related functions #22827

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ exportMethods("%<=>%",
"lower",
"lpad",
"ltrim",
"map_entries",
"map_from_arrays",
"map_keys",
"map_values",
Expand Down
15 changes: 1 addition & 14 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ NULL
#' head(select(tmp, sort_array(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3)))
#' head(select(tmp3, map_keys(tmp3$v3), map_values(tmp3$v3)))
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
Expand Down Expand Up @@ -3252,19 +3252,6 @@ setMethod("flatten",
column(jc)
})

#' @details
#' \code{map_entries}: Returns the entries of the given map as an unordered array.
#'
#' @rdname column_collection_functions
#' @aliases map_entries map_entries,Column-method
#' @note map_entries since 2.4.0
setMethod("map_entries",
          signature(x = "Column"),
          function(x) {
            # Call the JVM-side function on the wrapped Java column and wrap the result again.
            column(callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc))
          })

#' @details
#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for
#' keys. The array in the second column is used for values. All elements in the array for key
Expand Down
4 changes: 0 additions & 4 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1076,10 +1076,6 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") })
#' @name NULL
setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_entries",
           function(x) {
             standardGeneric("map_entries")
           })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") })
Expand Down
7 changes: 1 addition & 6 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1570,13 +1570,8 @@ test_that("column functions", {
result <- collect(select(df, flatten(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L)))

# Test map_entries(), map_keys(), map_values() and element_at()
# Test map_keys(), map_values() and element_at()
df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
result <- collect(select(df, map_entries(df$map)))[[1]]
expected_entries <- list(listToStruct(list(key = "x", value = 1)),
listToStruct(list(key = "y", value = 2)))
expect_equal(result, list(expected_entries))

result <- collect(select(df, map_keys(df$map)))[[1]]
expect_equal(result, list(list("x", "y")))

Expand Down
20 changes: 0 additions & 20 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2540,26 +2540,6 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))


@since(2.4)
def map_entries(col):
    """
    Collection function: Returns the entries of the given map as an unordered array.

    :param col: name of column or expression

    >>> from pyspark.sql.functions import map_entries
    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
    >>> df.select(map_entries("data").alias("entries")).show()
    +----------------+
    |         entries|
    +----------------+
    |[[1, a], [2, b]]|
    +----------------+
    """
    # Resolve the JVM-side functions object from the active context, then delegate to it.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    java_column = _to_java_column(col)
    return Column(jvm_functions.map_entries(java_column))


@since(2.4)
def map_from_entries(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,6 @@ object FunctionRegistry {
expression[MapFromArrays]("map_from_arrays"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[MapFromEntries]("map_from_entries"),
expression[MapConcat]("map_concat"),
expression[Size]("size"),
Expand All @@ -433,13 +432,9 @@ object FunctionRegistry {
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformValues]("transform_values"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
expression[ZipWith]("zip_with"),

CreateStruct.registryEntry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ object TypeCoercion {
BooleanEquality ::
FunctionArgumentConversion ::
ConcatCoercion(conf) ::
MapZipWithCoercion ::
EltCoercion(conf) ::
CaseWhenCoercion ::
IfCoercion ::
Expand Down Expand Up @@ -763,30 +762,6 @@ object TypeCoercion {
}
}

/**
 * Type coercion rule for [[MapZipWith]]: when its two [[MapType]] arguments have different
 * key types, casts both arguments so that they share a common key type.
 */
object MapZipWithCoercion extends TypeCoercionRule {
  override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
    // Lambda function isn't resolved when the rule is executed.
    case zipExpr @ MapZipWith(leftMap, rightMap, lambda)
        if zipExpr.arguments.forall(arg => arg.resolved && MapType.acceptsType(arg.dataType)) &&
          !zipExpr.leftKeyType.sameType(zipExpr.rightKeyType) =>
      val leftKey = zipExpr.leftKeyType
      val rightKey = zipExpr.rightKeyType
      // A wider key type is only used when Cast.forceNullable is false for both key casts;
      // otherwise the expression is left untouched.
      val commonKeyType = findWiderTypeForTwo(leftKey, rightKey).filter { candidate =>
        !Cast.forceNullable(leftKey, candidate) && !Cast.forceNullable(rightKey, candidate)
      }
      commonKeyType.map { finalKeyType =>
        val widenedLeft = castIfNotSameType(
          leftMap,
          MapType(finalKeyType, zipExpr.leftValueType, zipExpr.leftValueContainsNull))
        val widenedRight = castIfNotSameType(
          rightMap,
          MapType(finalKeyType, zipExpr.rightValueType, zipExpr.rightValueContainsNull))
        MapZipWith(widenedLeft, widenedRight, lambda)
      }.getOrElse(zipExpr)
  }
}

/**
* Coerces the types of [[Elt]] children to expected ones.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,174 +340,6 @@ case class MapValues(child: Expression)
override def prettyName: String = "map_values"
}

/**
 * Returns an unordered array of all entries in the given map.
 *
 * Each entry is a struct with the fields "key" and "value". The "key" field is declared
 * non-nullable, the "value" field takes its nullability from the map's `valueContainsNull`,
 * and the resulting array type is built with `containsNull = false`.
 */
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[{"key":1,"value":"a"},{"key":2,"value":"b"}]
""",
since = "2.4.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {

// The single child must be a map.
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

// The child's data type viewed as a MapType; computed lazily and not serialized.
@transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]

// Result type: array of struct<key, value> derived from the child's key/value types.
override def dataType: DataType = {
ArrayType(
StructType(
StructField("key", childDataType.keyType, false) ::
StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
Nil),
false)
}

/**
 * Interpreted evaluation: walks the map's key and value arrays in parallel and builds one
 * two-field row per entry, returning the rows wrapped in a [[GenericArrayData]].
 */
override protected def nullSafeEval(input: Any): Any = {
val childMap = input.asInstanceOf[MapData]
val keys = childMap.keyArray()
val values = childMap.valueArray()
val length = childMap.numElements()
val resultData = new Array[AnyRef](length)
var i = 0
while (i < length) {
val key = keys.get(i, childDataType.keyType)
val value = values.get(i, childDataType.valueType)
val row = new GenericInternalRow(Array[Any](key, value))
resultData.update(i, row)
i += 1
}
new GenericArrayData(resultData)
}

/**
 * Generates Java code that allocates the result array and fills it with one struct per
 * map entry. When both key and value types are primitive, a fixed element size is passed to
 * the allocator and the generated code picks the unsafe layout at runtime if the allocated
 * array is an `UnsafeArrayData`; in every other case the generic-row path is used.
 */
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val arrayData = ctx.freshName("arrayData")
val numElements = ctx.freshName("numElements")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)

// Size of one struct: the null bitset for two fields plus one word per field.
val wordSize = UnsafeRow.WORD_SIZE
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
// Per-element size adds one extra word to the struct size (the primitive path below stores
// a packed offset/size long per element in addition to the struct bytes).
// NOTE(review): -1 presumably tells ArrayData.allocateArrayData that no fixed-size unsafe
// layout applies -- confirm against that method.
val (isPrimitive, elementSize) = if (isKeyPrimitive && isValuePrimitive) {
(true, structSize + wordSize)
} else {
(false, -1)
}

val allocation =
s"""
|ArrayData $arrayData = ArrayData.allocateArrayData(
| $elementSize, $numElements, " $prettyName failed.");
""".stripMargin

// Even on the primitive path the allocated array is checked at runtime, with the generic
// path as the fallback when it is not an UnsafeArrayData.
val code = if (isPrimitive) {
val genCodeForPrimitive = genCodeForPrimitiveElements(
ctx, arrayData, keys, values, ev.value, numElements, structSize)
s"""
|if ($arrayData instanceof UnsafeArrayData) {
| $genCodeForPrimitive
|} else {
| ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}
|}
""".stripMargin
} else {
s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}"
}

s"""
|final int $numElements = $c.numElements();
|final ArrayData $keys = $c.keyArray();
|final ArrayData $values = $c.valueArray();
|$allocation
|$code
""".stripMargin
})
}

// Java expression reading the key at `index` from the array variable `varName`.
private def getKey(varName: String, index: String) =
CodeGenerator.getValue(varName, childDataType.keyType, index)

// Java expression reading the value at `index` from the array variable `varName`.
private def getValue(varName: String, index: String) =
CodeGenerator.getValue(varName, childDataType.valueType, index)

/**
 * Generates the fill loop for the unsafe layout: the struct bytes are placed after the array
 * header and `numElements` words, and each array slot receives a long that packs the
 * struct's offset and size.
 */
private def genCodeForPrimitiveElements(
ctx: CodegenContext,
arrayData: String,
keys: String,
values: String,
resultArrayData: String,
numElements: String,
structSize: Int): String = {
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val baseObject = ctx.freshName("baseObject")
val unsafeRow = ctx.freshName("unsafeRow")
val structsOffset = ctx.freshName("structsOffset")
val offset = ctx.freshName("offset")
val z = ctx.freshName("z")
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val wordSize = UnsafeRow.WORD_SIZE
val structSizeAsLong = s"${structSize}L"

// Statement writing the z-th key into column 0 of the row.
val setKey = CodeGenerator.setColumn(unsafeRow, childDataType.keyType, 0, getKey(keys, z))

// NOTE(review): presumably writes values[z] into column 1 of the row, with a null check
// when valueContainsNull is true -- confirm against CodeGenerator.createArrayAssignment.
val valueAssignmentChecked = CodeGenerator.createArrayAssignment(
unsafeRow, childDataType.valueType, values, "1", z, childDataType.valueContainsNull)

// In the generated loop, ($offset << 32) + size puts the struct's offset in the upper
// 32 bits and its size in the lower 32 bits of the element slot; a single UnsafeRow(2) is
// re-pointed at each struct's location before the key and value are written.
s"""
|UnsafeArrayData $unsafeArrayData = (UnsafeArrayData)$arrayData;
|Object $baseObject = $unsafeArrayData.getBaseObject();
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
|UnsafeRow $unsafeRow = new UnsafeRow(2);
|for (int $z = 0; $z < $numElements; $z++) {
| long $offset = $structsOffset + $z * $structSizeAsLong;
| $unsafeArrayData.setLong($z, ($offset << 32) + $structSizeAsLong);
| $unsafeRow.pointTo($baseObject, $baseOffset + $offset, $structSize);
| $setKey;
| $valueAssignmentChecked
|}
|$resultArrayData = $arrayData;
""".stripMargin
}

/**
 * Generates the fill loop for the generic layout: the allocated array is cast to
 * `GenericArrayData` and each slot is updated with a new `GenericInternalRow(key, value)`.
 */
private def genCodeForAnyElements(
ctx: CodegenContext,
arrayData: String,
keys: String,
values: String,
resultArrayData: String,
numElements: String): String = {
val z = ctx.freshName("z")
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
// A primitive value that may be null needs an explicit isNullAt check so that null is
// stored in the row instead of the primitive read.
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
s"$values.isNullAt($z) ? null : (Object)${getValue(values, z)}"
} else {
getValue(values, z)
}

val rowClass = classOf[GenericInternalRow].getName
val genericArrayDataClass = classOf[GenericArrayData].getName
val genericArrayData = ctx.freshName("genericArrayData")
val rowObject = s"new $rowClass(new Object[]{${getKey(keys, z)}, $getValueWithCheck})"
s"""
|$genericArrayDataClass $genericArrayData = ($genericArrayDataClass)$arrayData;
|for (int $z = 0; $z < $numElements; $z++) {
| $genericArrayData.update($z, $rowObject);
|}
|$resultArrayData = $arrayData;
""".stripMargin
}

// SQL-facing function name, also interpolated into the allocation failure message above.
override def prettyName: String = "map_entries"
}

/**
* Returns the union of all the given maps.
*/
Expand Down
Loading