-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23920][SQL]add array_remove to remove all elements that equal element from array #21069
Changes from 6 commits
f92e18c
f6a629b
1c24720
89b4f48
9281ae2
074ed88
52d2308
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1882,3 +1882,117 @@ case class ArrayRepeat(left: Expression, right: Expression) | |
} | ||
|
||
} | ||
|
||
/**
 * Removes every element that is equal to the given value from the given array.
 *
 * Null elements of the array are never considered equal to the value and are always kept.
 * Evaluation is null-safe: a null array or a null value yields a null result.
 */
@ExpressionDescription(
  usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3, null, 3), 3);
       [1,2,null]
  """, since = "2.4.0")
case class ArrayRemove(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  /** The result has the same array type as the input array. */
  override def dataType: DataType = left.dataType

  /**
   * Expected input types: an array followed by a value of the array's element type.
   *
   * `left` is not guaranteed to be an array when this is evaluated, so the element type is
   * extracted with a pattern match rather than an unconditional cast (which would throw a
   * ClassCastException instead of letting `checkInputDataTypes` report a proper failure).
   */
  override def inputTypes: Seq[AbstractDataType] = {
    val expectedElementType = left.dataType match {
      case ArrayType(et, _) => et
      // Not an array: leave the second argument's type untouched; the first slot still
      // demands ArrayType, and checkInputDataTypes rejects the expression.
      case _ => right.dataType
    }
    Seq(ArrayType, expectedElementType)
  }

  /** Element type of the input array; used by the code-generation methods below. */
  lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType

  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(right.dataType)

  /**
   * Accepts only an array whose element type equals the type of the value to remove, and
   * additionally requires that type to be orderable so elements can be compared.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    left.dataType match {
      case ArrayType(et, _) if et == right.dataType =>
        TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
      case _ =>
        TypeCheckResult.TypeCheckFailure(
          "Arguments must be an array followed by a value of same type as the array members")
    }
  }

  /**
   * Interpreted evaluation: copies every element that is null or not equal to `value` into a
   * new array, preserving the original order.
   */
  override def nullSafeEval(arr: Any, value: Any): Any = {
    val input = arr.asInstanceOf[ArrayData]
    // Upper bound on the result size; trimmed to the number of kept elements at the end.
    val kept = new Array[Any](input.numElements())
    var pos = 0
    input.foreach(right.dataType, (_, v) =>
      if (v == null || !ordering.equiv(v, value)) {
        kept(pos) = v
        pos += 1
      }
    )
    new GenericArrayData(kept.slice(0, pos))
  }

  /**
   * Generated evaluation, in two passes: first count the non-null elements equal to the value
   * to size the result exactly, then let `genCodeForResult` copy the remaining elements.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (arr, value) => {
      val numsToRemove = ctx.freshName("numsToRemove")
      val newArraySize = ctx.freshName("newArraySize")
      val i = ctx.freshName("i")
      val getValue = CodeGenerator.getValue(arr, elementType, i)
      val isEqual = ctx.genEqual(elementType, value, getValue)
      s"""
         |int $numsToRemove = 0;
         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
         |  if (!$arr.isNullAt($i) && $isEqual) {
         |    $numsToRemove = $numsToRemove + 1;
         |  }
         |}
         |int $newArraySize = $arr.numElements() - $numsToRemove;
         |${genCodeForResult(ctx, ev, arr, value, newArraySize)}
       """.stripMargin
    })
  }

  /**
   * Generates the code that fills the result array and assigns it to `ev.value`.
   *
   * @param ctx          codegen context used to allocate fresh variable names
   * @param ev           expression code whose `value` receives the resulting array
   * @param inputArray   name of the generated variable holding the input array
   * @param value        name of the generated variable holding the value to remove
   * @param newArraySize name of the generated variable holding the result length
   * @return the generated Java source fragment
   */
  def genCodeForResult(
      ctx: CodegenContext,
      ev: ExprCode,
      inputArray: String,
      value: String,
      newArraySize: String): String = {
    val values = ctx.freshName("values")
    val i = ctx.freshName("i")
    val pos = ctx.freshName("pos")
    val getValue = CodeGenerator.getValue(inputArray, elementType, i)
    val isEqual = ctx.genEqual(elementType, value, getValue)
    if (!CodeGenerator.isPrimitiveType(elementType)) {
      val arrayClass = classOf[GenericArrayData].getName
      // Null elements are copied through without evaluating the equality expression: the
      // counting pass in doGenCode never counts them as removed, and comparing against a null
      // element is not safe for every element type.
      s"""
         |int $pos = 0;
         |Object[] $values = new Object[$newArraySize];
         |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
         |  if ($inputArray.isNullAt($i)) {
         |    $values[$pos] = null;
         |    $pos = $pos + 1;
         |  }
         |  else {
         |    if (!($isEqual)) {
         |      $values[$pos] = $getValue;
         |      $pos = $pos + 1;
         |    }
         |  }
         |}
         |${ev.value} = new $arrayClass($values);
       """.stripMargin
    } else {
      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
      s"""
         |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")}
         |int $pos = 0;
         |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
         |  if ($inputArray.isNullAt($i)) {
         |    $values.setNullAt($pos);
         |    $pos = $pos + 1;
         |  }
         |  else {
         |    if (!($isEqual)) {
         |      $values.set$primitiveValueTypeName($pos, $getValue);
         |      $pos = $pos + 1;
         |    }
         |  }
         |}
         |${ev.value} = $values;
       """.stripMargin
    }
  }

  override def prettyName: String = "array_remove"
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -552,4 +552,60 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) | ||
checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) | ||
} | ||
|
||
test("Array remove") {
  // Evaluates ArrayRemove(array, element) and compares the outcome with `expected`.
  def checkRemove(array: Literal, element: Literal, expected: Any): Unit =
    checkEvaluation(ArrayRemove(array, element), expected)

  // Integer array containing duplicates.
  val ints = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
  checkRemove(ints, Literal(0), Seq(1, 2, 3, 2, 2, 5))
  checkRemove(ints, Literal(1), Seq(2, 3, 2, 2, 5))
  checkRemove(ints, Literal(2), Seq(1, 3, 5))
  checkRemove(ints, Literal(3), Seq(1, 2, 2, 2, 5))
  checkRemove(ints, Literal(5), Seq(1, 2, 3, 2, 2))
  // NOTE(review): a reviewer asked for one more case at this point; the example was not
  // preserved, and the author replied that this was the one request left unaddressed.
  checkRemove(ints, Literal(null, IntegerType), null)

  // String array containing duplicates.
  val strings = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
  checkRemove(strings, Literal(""), Seq("b", "a", "a", "c", "b"))
  checkRemove(strings, Literal("a"), Seq("b", "c", "b"))
  checkRemove(strings, Literal("b"), Seq("a", "a", "c"))
  checkRemove(strings, Literal("c"), Seq("b", "a", "a", "b"))

  // String array mixing nulls and empty strings.
  val stringsWithNulls = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType))
  checkRemove(stringsWithNulls, Literal(""), Seq(null, null))
  checkRemove(stringsWithNulls, Literal(null, StringType), null)

  // Empty array and null array.
  val noInts = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
  checkRemove(noInts, Literal(1), Seq.empty[Integer])

  val nullStrings = Literal.create(null, ArrayType(StringType))
  checkRemove(nullStrings, Literal("a"), null)

  // Primitive element types, with and without null elements.
  val intsWithNulls = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType))
  checkRemove(intsWithNulls, Literal(9), Seq(1, null, 8, null))

  val booleans = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType))
  checkRemove(booleans, Literal(false), Seq(true, true))
  // NOTE(review): a reviewer asked for one more case at this point; the example was not
  // preserved.

  // complex data types
  val binaryToRemove = Literal.create(Array[Byte](5, 6), BinaryType)
  val nullBinary = Literal.create(null, BinaryType)

  val binaries = Literal.create(
    Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](1, 2), Array[Byte](5, 6)),
    ArrayType(BinaryType))
  val binaryThenNull = Literal.create(
    Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType))
  val nullThenBinary = Literal.create(
    Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType))

  checkRemove(binaries, binaryToRemove, Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2)))
  checkRemove(binaries, nullBinary, null)
  checkRemove(binaryThenNull, binaryToRemove, Seq[Array[Byte]](Array[Byte](2, 1), null))
  checkRemove(nullThenBinary, binaryToRemove, Seq[Array[Byte]](null, Array[Byte](1, 2)))

  // Nested arrays.
  val nestedType = ArrayType(ArrayType(IntegerType))
  val nestedWithMatch = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), nestedType)
  val nestedNoMatch = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), nestedType)
  // NOTE(review): a reviewer asked "what if ..." about another nested input here (example not
  // preserved); the author replied that a `c2` case was added in a later revision.
  val nestedToRemove = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))

  checkRemove(nestedWithMatch, nestedToRemove, Seq[Seq[Int]](Seq[Int](3, 4)))
  checkRemove(nestedNoMatch, nestedToRemove, Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
}
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will cause `ClassCastException`. See #21401. Also, could you add tests similar to the ones added in #21401?