[SPARK-42746][SQL] Implement LISTAGG function #48748
Conversation
},
"FUNCTION_AND_ORDER_EXPRESSION_MISMATCH" : {
  "message" : [
    "The arguments <functionArgs> of the function <functionName> do not match to ordering within group <orderExpr> when use DISTINCT."
how about
Function <funcName> is invoked with DISTINCT. The WITHIN GROUP ordering expressions must be picked from the function inputs, but got <orderingExpr>.
We can make it a sub error condition of INVALID_WITHIN_GROUP_EXPRESSION: MISMATCH_WITH_DISTINCT_INPUT
Good idea. Now, with the common INVALID_WITHIN_GROUP_EXPRESSION prefix, it looks like: "Invalid function <funcName> with WITHIN GROUP. The function is invoked with DISTINCT and WITHIN GROUP but expressions <funcArg> and <orderingExpr> do not match. The WITHIN GROUP ordering expression must be picked from the function inputs."
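For illustration, a hypothetical query that should hit this sub-condition (the table and column names here are invented):

```scala
// Invalid: with DISTINCT, the WITHIN GROUP ordering key must be one of the
// function inputs (col1 here), not a different column (col2).
spark.sql(
  """SELECT listagg(DISTINCT col1, ',') WITHIN GROUP (ORDER BY col2)
    |FROM VALUES ('a', 1), ('b', 2) AS t(col1, col2)""".stripMargin)
// expected to fail with INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT
```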
/**
 * Sort buffer according orderExpressions.
 * If orderExpressions is empty them returns buffer as is.
Suggested change:
- * If orderExpressions is empty them returns buffer as is.
+ * If orderExpressions is empty then returns buffer as is.
-- !query schema
struct<listagg(c1, NULL):binary>
-- !query output
ޭ��
The golden file test framework should print the hex string of binary values. We can improve it in followup PRs.
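As a rough sketch of that follow-up (not the framework's actual code; toHex is an invented name), hex-encoding a binary value before printing could look like:

```scala
// Render a byte array as an uppercase hex string,
// e.g. toHex(Array(0xDE.toByte, 0xAD.toByte)) == "DEAD".
def toHex(bytes: Array[Byte]): String = bytes.map("%02X".format(_)).mkString
```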
cc @mitkedb for this
  }
}

private[this] def hexToBytes(s: String): Array[Byte] = {
not needed now.
override def inputTypes: Seq[AbstractDataType] =
  TypeCollection(
    StringTypeWithCollation(supportsTrimCollation = true),
how is trim collation supported?
As I understand it, collation only affects comparison, so it matters only for DISTINCT and ORDER BY. DISTINCT is handled by the aggregation framework, and it respects trim collations. ORDER BY is handled by the code with PhysicalDataType.ordering, which respects trim collations too.
Added tests for trim collations
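For reference, a sketch of the kind of query such tests would cover (assuming a Spark 4.0 trim-collation name like UTF8_BINARY_RTRIM; the data is invented):

```scala
// With an RTRIM collation, 'a' and 'a ' compare equal, so DISTINCT keeps
// only one of them before concatenation.
spark.sql(
  """SELECT listagg(DISTINCT col COLLATE UTF8_BINARY_RTRIM, ',')
    |FROM VALUES ('a'), ('a ') AS t(col)""".stripMargin).show()
```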
LGTM except for https://github.com/apache/spark/pull/48748/files#r1852128035
  return result;
}

public static byte[] concatWS(byte[] delimiter, byte[]... inputs) {
can you please add a comment saying what this function is doing?
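For readers of this thread, a Scala reference sketch of what the new ByteArray.concatWS is intended to compute (ignoring null handling, which the real Java implementation must also consider; concatWSRef is an invented name):

```scala
// Concatenate byte arrays, inserting the delimiter between consecutive inputs:
// concatWSRef(";".getBytes, "a".getBytes, "b".getBytes) has the bytes of "a;b".
def concatWSRef(delimiter: Array[Byte], inputs: Array[Byte]*): Array[Byte] =
  inputs.reduceOption((a, b) => a ++ delimiter ++ b).getOrElse(Array.empty[Byte])
```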
    result, Platform.BYTE_ARRAY_OFFSET + offset,
    len);
  offset += len;
  if(i < inputs.length - 1) {
Suggested change:
- if(i < inputs.length - 1) {
+ if (i < inputs.length - 1) {
for (int i = 0; i < inputs.length; i++) {
  byte[] input = inputs[i];
  int len = input.length;
  Platform.copyMemory(
this seems copied from L154 above, please dedup into one place?
I didn't want to accidentally change existing behavior or performance, so I thought a little copy-paste was justified in this isolated code. But I was probably worrying too much :)
Removed
dataType match {
  case BinaryType =>
    val inputs = buffer.filter(_ != null).map(_.asInstanceOf[Array[Byte]])
    ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*)
we repeat the .asInstanceOf[Array[Byte]] two times here, can we use a pattern match to reduce this to one?
added more type strictness
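A hypothetical shape of that refactor (not necessarily the merged code), matching on the delimiter once instead of casting it in each branch:

```scala
(dataType, delimiterValue) match {
  case (BinaryType, delimiter: Array[Byte]) =>
    val inputs = buffer.filter(_ != null).map(_.asInstanceOf[Array[Byte]])
    ByteArray.concatWS(delimiter, inputs.toSeq: _*)
  case (_: StringType, delimiter: UTF8String) =>
    val inputs = buffer.filter(_ != null).map(_.asInstanceOf[UTF8String])
    UTF8String.concatWs(delimiter, inputs.toSeq: _*)
}
```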
    ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*)
  case _: StringType =>
    val inputs = buffer.filter(_ != null).map(_.asInstanceOf[UTF8String])
    UTF8String.concatWs(delimiterValue.asInstanceOf[UTF8String], inputs.toSeq: _*)
These concatenations consume input memory without bound. Do we have some kind of limit to this? If we consume a very large disk-based input table in the aggregation it could crash the executors by running out of memory. We should probably create SQLConfs with max limits for these buffers.
Yes, it's a common problem for all collect_* functions. As I tested, they now fail with OOM if the buffer is too big. And the percentile_disc doc says the same (lines 403 to 407 in ad49fcf):

 * Because the number of elements and their partial order cannot be determined in advance.
 * Therefore we have to store all the elements in memory, and so notice that too many elements can
 * cause GC paused and eventually OutOfMemory Errors.
 */
case class PercentileDisc(

I think it's a common problem and should be handled in follow-ups.
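To make that follow-up idea concrete, a hypothetical guard (the threshold and any SQLConf wiring are invented for illustration, not part of this PR):

```scala
// Fail fast instead of OOM-ing when a single group's buffer grows too large.
val maxEntries = 1000000 // in a real change this would come from a SQLConf
if (buffer.length > maxEntries) {
  throw new IllegalStateException(
    s"listagg buffer exceeded $maxEntries elements for a single group")
}
```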
    input, Platform.BYTE_ARRAY_OFFSET,
    result, Platform.BYTE_ARRAY_OFFSET + offset,
    len);
  input, Platform.BYTE_ARRAY_OFFSET,
nit: 2 spaces indentation
offset += len;
if (delimiter.length > 0 && i < inputs.length - 1) {
  Platform.copyMemory(
    delimiter, Platform.BYTE_ARRAY_OFFSET,
ditto
val sortOrderExpression = orderExpressions.head
val ascendingOrdering = PhysicalDataType.ordering(sortOrderExpression.dataType)
val ordering =
  if (sortOrderExpression.direction == Ascending) ascendingOrdering
SortOrder has a nullOrdering flag, shall we respect it here?
I'm wondering if we should reuse the code in SortExec to do sorting.
listagg filters all null values from the result, and in this case it's sorted by the same value, so null ordering does nothing.
I see, but the Spark native sorter should be more efficient and support spilling.
We can leave it for future optimization
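For context, a sketch completing the quoted snippet above (simplified; assumes a single ordering expression, as in the diff):

```scala
// Natural ordering for the key's physical type, reversed for DESC.
val ascendingOrdering = PhysicalDataType.ordering(sortOrderExpression.dataType)
val ordering =
  if (sortOrderExpression.direction == Ascending) ascendingOrdering
  else ascendingOrdering.reverse
```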
thanks, merging to master!
Follow-up PR #49231 (listagg for PySpark):

What changes were proposed in this pull request?
Added new function `listagg` to pyspark. Follow-up of #48748.

Why are the changes needed?
Allows to use native Python functions to write queries with `listagg`. E.g., `df.select(F.listagg(df.value, ",").alias("r"))`.

Does this PR introduce any user-facing change?
Yes, new functions `listagg` and `listagg_distinct` (with aliases `string_agg` and `string_agg_distinct`) in pyspark.

How was this patch tested?
Unit tests

Was this patch authored or co-authored using generative AI tooling?
Generated-by: GitHub Copilot

Closes #49231 from mikhailnik-db/SPARK-50220-listagg-for-pyspark.
Authored-by: Mikhail Nikoliukin <mikhail.nikoliukin@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
 * @group agg_funcs
 * @since 4.0.0
 */
def listagg(e: Column, delimiter: Column): Column = Column.fn("listagg", e, delimiter)
Declaring the delimiter as String here can improve UX a bit. Since it only allows foldable string literals, we can rely on the compiler instead of runtime errors, WDYT @cloud-fan
SGTM
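A sketch of the suggested String-typed variant (hypothetical signature, shown only for the discussion; lit comes from org.apache.spark.sql.functions):

```scala
import org.apache.spark.sql.functions.lit

// Accepting the delimiter as a String makes a non-literal delimiter a
// compile-time impossibility rather than a runtime error.
def listagg(e: Column, delimiter: String): Column =
  Column.fn("listagg", e, lit(delimiter))
```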
What changes were proposed in this pull request?
Implement the new aggregation function:
listagg([ALL | DISTINCT] expr[, sep]) [WITHIN GROUP (ORDER BY key [ASC | DESC] [,...])]
Why are the changes needed?
Listagg is a popular function implemented by many other vendors. For now, users have to use workarounds like this. This PR will close the gap.
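For example (illustrative data):

```scala
// Comma-separated employee names per department, sorted alphabetically.
spark.sql(
  """SELECT dept, listagg(name, ', ') WITHIN GROUP (ORDER BY name) AS names
    |FROM VALUES ('eng', 'bob'), ('eng', 'alice'), ('hr', 'carol') AS t(dept, name)
    |GROUP BY dept""".stripMargin).show()
// eng -> "alice, bob", hr -> "carol"
```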
Does this PR introduce any user-facing change?
Yes, the new listagg function. BigQuery and PostgreSQL have the same function but with the string_agg name, so I added it as an alias.

How was this patch tested?
With new unit tests
Was this patch authored or co-authored using generative AI tooling?
Generated-by: GitHub Copilot