[SPARK-23821][SQL] Collection function: flatten
## What changes were proposed in this pull request?

This PR adds a new collection function that transforms an array of arrays into a single array (see the usage sketch below). The PR comprises:
- An expression for flattening nested array structures
- A `flatten` function for the Scala/Java DataFrame API
- A wrapper for PySpark
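
As a quick orientation, here is a minimal usage sketch of the new API (the sample data and the column name `arrays` are illustrative, not part of the patch):

```
import spark.implicits._
import org.apache.spark.sql.functions.flatten

// flatten removes exactly one level of nesting;
// a null inner array makes the whole result null.
val df = Seq(
  Seq(Seq(1, 2, 3), Seq(4, 5)),
  Seq(null, Seq(6))
).toDF("arrays")

df.select(flatten($"arrays")).show()
// +---------------+
// |flatten(arrays)|
// +---------------+
// |[1, 2, 3, 4, 5]|
// |           null|
// +---------------+
```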

## How was this patch tested?

New tests were added to:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite

## Codegen examples
### Primitive type
```
val df = Seq(
  Seq(Seq(1, 2), Seq(4, 5)),
  Seq(null, Seq(1))
).toDF("i")
df.filter($"i".isNotNull || $"i".isNull).select(flatten($"i")).debugCodegen
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         boolean filter_value = true;
/* 038 */
/* 039 */         if (!(!inputadapter_isNull)) {
/* 040 */           filter_value = inputadapter_isNull;
/* 041 */         }
/* 042 */         if (!filter_value) continue;
/* 043 */
/* 044 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 045 */
/* 046 */         boolean project_isNull = inputadapter_isNull;
/* 047 */         ArrayData project_value = null;
/* 048 */
/* 049 */         if (!inputadapter_isNull) {
/* 050 */           for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) {
/* 051 */             project_isNull |= inputadapter_value.isNullAt(z);
/* 052 */           }
/* 053 */           if (!project_isNull) {
/* 054 */             long project_numElements = 0;
/* 055 */             for (int z = 0; z < inputadapter_value.numElements(); z++) {
/* 056 */               project_numElements += inputadapter_value.getArray(z).numElements();
/* 057 */             }
/* 058 */             if (project_numElements > 2147483632) {
/* 059 */               throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 060 */                 project_numElements + " elements due to exceeding the array size limit 2147483632.");
/* 061 */             }
/* 062 */
/* 063 */             long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 064 */               project_numElements,
/* 065 */               4);
/* 066 */             if (project_size > 2147483632) {
/* 067 */               throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 068 */                 project_size + " bytes of data due to exceeding the limit 2147483632" +
/* 069 */                 " bytes for UnsafeArrayData.");
/* 070 */             }
/* 071 */
/* 072 */             byte[] project_array = new byte[(int)project_size];
/* 073 */             UnsafeArrayData project_tempArrayData = new UnsafeArrayData();
/* 074 */             Platform.putLong(project_array, 16, project_numElements);
/* 075 */             project_tempArrayData.pointTo(project_array, 16, (int)project_size);
/* 076 */             int project_counter = 0;
/* 077 */             for (int k = 0; k < inputadapter_value.numElements(); k++) {
/* 078 */               ArrayData arr = inputadapter_value.getArray(k);
/* 079 */               for (int l = 0; l < arr.numElements(); l++) {
/* 080 */                 if (arr.isNullAt(l)) {
/* 081 */                   project_tempArrayData.setNullAt(project_counter);
/* 082 */                 } else {
/* 083 */                   project_tempArrayData.setInt(
/* 084 */                     project_counter,
/* 085 */                     arr.getInt(l)
/* 086 */                   );
/* 087 */                 }
/* 088 */                 project_counter++;
/* 089 */               }
/* 090 */             }
/* 091 */             project_value = project_tempArrayData;
/* 092 */
/* 093 */           }
/* 094 */
/* 095 */         }
```
### Non-primitive type
```
val df = Seq(
  Seq(Seq("a", "b"), Seq(null, "d")),
  Seq(null, Seq("a"))
).toDF("s")
df.filter($"s".isNotNull || $"s".isNull).select(flatten($"s")).debugCodegen
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         boolean filter_value = true;
/* 038 */
/* 039 */         if (!(!inputadapter_isNull)) {
/* 040 */           filter_value = inputadapter_isNull;
/* 041 */         }
/* 042 */         if (!filter_value) continue;
/* 043 */
/* 044 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 045 */
/* 046 */         boolean project_isNull = inputadapter_isNull;
/* 047 */         ArrayData project_value = null;
/* 048 */
/* 049 */         if (!inputadapter_isNull) {
/* 050 */           for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) {
/* 051 */             project_isNull |= inputadapter_value.isNullAt(z);
/* 052 */           }
/* 053 */           if (!project_isNull) {
/* 054 */             long project_numElements = 0;
/* 055 */             for (int z = 0; z < inputadapter_value.numElements(); z++) {
/* 056 */               project_numElements += inputadapter_value.getArray(z).numElements();
/* 057 */             }
/* 058 */             if (project_numElements > 2147483632) {
/* 059 */               throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 060 */                 project_numElements + " elements due to exceeding the array size limit 2147483632.");
/* 061 */             }
/* 062 */
/* 063 */             Object[] project_arrayObject = new Object[(int)project_numElements];
/* 064 */             int project_counter = 0;
/* 065 */             for (int k = 0; k < inputadapter_value.numElements(); k++) {
/* 066 */               ArrayData arr = inputadapter_value.getArray(k);
/* 067 */               for (int l = 0; l < arr.numElements(); l++) {
/* 068 */                 project_arrayObject[project_counter] = arr.getUTF8String(l);
/* 069 */                 project_counter++;
/* 070 */               }
/* 071 */             }
/* 072 */             project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObject);
/* 073 */
/* 074 */           }
/* 075 */
/* 076 */         }
```

Author: mn-mikke <mrkAha12346github>

Closes #20938 from mn-mikke/feature/array-api-flatten-to-master.
mn-mikke authored and ueshin committed Apr 25, 2018
1 parent d6c26d1 commit 5fea17b
Showing 6 changed files with 376 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2191,6 +2191,23 @@ def reverse(col):
    return Column(sc._jvm.functions.reverse(_to_java_column(col)))


@since(2.4)
def flatten(col):
    """
    Collection function: creates a single array from an array of arrays.
    If a structure of nested arrays is deeper than two levels,
    only one level of nesting is removed.

    :param col: name of column or expression

    >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
    >>> df.select(flatten(df.data).alias('r')).collect()
    [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.flatten(_to_java_column(col)))


@since(2.3)
def map_keys(col):
    """
@@ -413,6 +413,7 @@ object FunctionRegistry {
    expression[ArrayMax]("array_max"),
    expression[Reverse]("reverse"),
    expression[Concat]("concat"),
    expression[Flatten]("flatten"),
    CreateStruct.registryEntry,

    // misc functions
@@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression {

  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}

/**
 * Transforms an array of arrays into a single array.
 */
@ExpressionDescription(
  usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(array(1, 2), array(3, 4)));
       [1,2,3,4]
  """,
  since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {

  private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

  private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

  override def nullable: Boolean = child.nullable || childDataType.containsNull

  override def dataType: DataType = childDataType.elementType

  lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(_: ArrayType, _) =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(
        s"The argument should be an array of arrays, " +
        s"but '${child.sql}' is of ${child.dataType.simpleString} type."
      )
  }

  override def nullSafeEval(child: Any): Any = {
    val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType)

    // Any null inner array makes the whole result null.
    if (elements.contains(null)) {
      null
    } else {
      val arrayData = elements.map(_.asInstanceOf[ArrayData])
      val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
      if (numberOfElements > MAX_ARRAY_LENGTH) {
        throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
          s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
      }
      val flattenedData = new Array(numberOfElements.toInt)
      var position = 0
      for (ad <- arrayData) {
        val arr = ad.toObjectArray(elementType)
        Array.copy(arr, 0, flattenedData, position, arr.length)
        position += arr.length
      }
      new GenericArrayData(flattenedData)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val code = if (CodeGenerator.isPrimitiveType(elementType)) {
        genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value)
      } else {
        genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
      }
      if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
    })
  }

  private def nullElementsProtection(
      ev: ExprCode,
      childVariableName: String,
      coreLogic: String): String = {
    s"""
    |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
    |  ${ev.isNull} |= $childVariableName.isNullAt(z);
    |}
    |if (!${ev.isNull}) {
    |  $coreLogic
    |}
    """.stripMargin
  }

  private def genCodeForNumberOfElements(
      ctx: CodegenContext,
      childVariableName: String): (String, String) = {
    val variableName = ctx.freshName("numElements")
    val code = s"""
      |long $variableName = 0;
      |for (int z = 0; z < $childVariableName.numElements(); z++) {
      |  $variableName += $childVariableName.getArray(z).numElements();
      |}
      |if ($variableName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
      |    $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
      |}
      """.stripMargin
    (code, variableName)
  }

  private def genCodeForFlattenOfPrimitiveElements(
      ctx: CodegenContext,
      childVariableName: String,
      arrayDataName: String): String = {
    val arrayName = ctx.freshName("array")
    val arraySizeName = ctx.freshName("size")
    val counter = ctx.freshName("counter")
    val tempArrayDataName = ctx.freshName("tempArrayData")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

    val unsafeArraySizeInBytes = s"""
      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
      |  $numElemName,
      |  ${elementType.defaultSize});
      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
      |    $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
      |    " bytes for UnsafeArrayData.");
      |}
      """.stripMargin
    val baseOffset = Platform.BYTE_ARRAY_OFFSET

    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

    s"""
    |$numElemCode
    |$unsafeArraySizeInBytes
    |byte[] $arrayName = new byte[(int)$arraySizeName];
    |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
    |Platform.putLong($arrayName, $baseOffset, $numElemName);
    |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
    |int $counter = 0;
    |for (int k = 0; k < $childVariableName.numElements(); k++) {
    |  ArrayData arr = $childVariableName.getArray(k);
    |  for (int l = 0; l < arr.numElements(); l++) {
    |    if (arr.isNullAt(l)) {
    |      $tempArrayDataName.setNullAt($counter);
    |    } else {
    |      $tempArrayDataName.set$primitiveValueTypeName(
    |        $counter,
    |        ${CodeGenerator.getValue("arr", elementType, "l")}
    |      );
    |    }
    |    $counter++;
    |  }
    |}
    |$arrayDataName = $tempArrayDataName;
    """.stripMargin
  }

  private def genCodeForFlattenOfNonPrimitiveElements(
      ctx: CodegenContext,
      childVariableName: String,
      arrayDataName: String): String = {
    val genericArrayClass = classOf[GenericArrayData].getName
    val arrayName = ctx.freshName("arrayObject")
    val counter = ctx.freshName("counter")
    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

    s"""
    |$numElemCode
    |Object[] $arrayName = new Object[(int)$numElemName];
    |int $counter = 0;
    |for (int k = 0; k < $childVariableName.numElements(); k++) {
    |  ArrayData arr = $childVariableName.getArray(k);
    |  for (int l = 0; l < arr.numElements(); l++) {
    |    $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")};
    |    $counter++;
    |  }
    |}
    |$arrayDataName = new $genericArrayClass($arrayName);
    """.stripMargin
  }

  override def prettyName: String = "flatten"
}
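
As a reading aid (not part of the patch), the null semantics implemented by `nullSafeEval` above are equivalent to this plain-Scala sketch:

```
// Illustrative only: a null input or any null inner array yields null
// (modeled as None here); otherwise the inner arrays are concatenated.
def flattenSketch[A](arrays: Seq[Seq[A]]): Option[Seq[A]] =
  if (arrays == null || arrays.contains(null)) None
  else Some(arrays.flatten)

assert(flattenSketch(Seq(Seq(1, 2), Seq(3))) == Some(Seq(1, 2, 3)))
assert(flattenSketch[Int](Seq(null, Seq(1))).isEmpty)
```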
@@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

    checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
  }

test("Flatten") {
// Primitive-type test cases
val intArrayType = ArrayType(ArrayType(IntegerType))

// Main test cases (primitive type)
val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType)
val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType)

checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6))
checkEvaluation(Flatten(aim2), Seq(1, 2, 3))

// Test cases with an empty array (primitive type)
val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType)
val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType)
val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType)
val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType)
val aie5 = Literal.create(Seq(Seq.empty), intArrayType)
val aie6 = Literal.create(Seq.empty, intArrayType)

checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4))
checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4))
checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4))
checkEvaluation(Flatten(aie4), Seq.empty)
checkEvaluation(Flatten(aie5), Seq.empty)
checkEvaluation(Flatten(aie6), Seq.empty)

// Test cases with null elements (primitive type)
val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType)
val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType)
val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType)

checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null))
checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null))
checkEvaluation(Flatten(ain3), Seq(null, null, null, null))

// Test cases with a null array (primitive type)
val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType)
val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType)
val aia3 = Literal.create(Seq(null), intArrayType)
val aia4 = Literal.create(null, intArrayType)

checkEvaluation(Flatten(aia1), null)
checkEvaluation(Flatten(aia2), null)
checkEvaluation(Flatten(aia3), null)
checkEvaluation(Flatten(aia4), null)

// Non-primitive-type test cases
val strArrayType = ArrayType(ArrayType(StringType))
val arrArrayType = ArrayType(ArrayType(ArrayType(StringType)))

// Main test cases (non-primitive type)
val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType)
val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType)
val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType)

checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f"))
checkEvaluation(Flatten(asm2), Seq("a", "b"))
checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e")))

// Test cases with an empty array (non-primitive type)
val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType)
val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType)
val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType)
val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType)
val ase5 = Literal.create(Seq(Seq.empty), strArrayType)
val ase6 = Literal.create(Seq.empty, strArrayType)

checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d"))
checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d"))
checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d"))
checkEvaluation(Flatten(ase4), Seq.empty)
checkEvaluation(Flatten(ase5), Seq.empty)
checkEvaluation(Flatten(ase6), Seq.empty)

// Test cases with null elements (non-primitive type)
val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType)
val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType)
val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType)

checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null))
checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null))
checkEvaluation(Flatten(asn3), Seq(null, null, null, null))

// Test cases with a null array (non-primitive type)
val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType)
val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType)
val asa3 = Literal.create(Seq(null), strArrayType)
val asa4 = Literal.create(null, strArrayType)

checkEvaluation(Flatten(asa1), null)
checkEvaluation(Flatten(asa2), null)
checkEvaluation(Flatten(asa3), null)
checkEvaluation(Flatten(asa4), null)
}
}
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3340,6 +3340,14 @@ object functions {
   */
  def reverse(e: Column): Column = withExpr { Reverse(e.expr) }

  /**
   * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than
   * two levels, only one level of nesting is removed.
   *
   * @group collection_funcs
   * @since 2.4.0
   */
  def flatten(e: Column): Column = withExpr { Flatten(e.expr) }

  /**
   * Returns an unordered array containing the keys of the map.
   * @group collection_funcs
