From 0fecbd9780af64ef783455c421eaadc2aa7e2074 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Sat, 5 Aug 2023 11:57:23 +0800
Subject: [PATCH 01/58] [SPARK-42746][SQL] Add the LIST_AGG() aggregate function

---
 .../main/resources/error/error-classes.json  |    5 +
 docs/sql-error-conditions.md                  |    6 +
 docs/sql-ref-ansi-compliance.md               |    1 +
 python/pyspark/sql/tests/test_functions.py    |    5 +-
 .../spark/sql/catalyst/parser/SqlBaseLexer.g4 |    1 +
 .../sql/catalyst/parser/SqlBaseLexer.tokens   | 1356 +++++++++--------
 .../sql/catalyst/parser/SqlBaseParser.g4      |    4 +
 .../catalyst/analysis/FunctionRegistry.scala  |    1 +
 .../expressions/aggregate/ListAgg.scala       |  105 ++
 .../sql/catalyst/parser/AstBuilder.scala      |   34 +-
 .../sql/errors/QueryCompilationErrors.scala   |   12 +
 .../org/apache/spark/sql/functions.scala      |   54 +
 .../sql-functions/sql-expression-schema.md    |    9 +-
 .../sql-tests/results/ansi/keywords.sql.out   |    1 +
 .../sql-tests/results/keywords.sql.out        |    1 +
 .../spark/sql/DataFrameAggregateSuite.scala   |   39 +
 .../spark/sql/DataFrameFunctionsSuite.scala   |    3 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |   51 +
 .../ThriftServerWithSparkContextSuite.scala   |    2 +-
 19 files changed, 1039 insertions(+), 651 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 0ea1eed35e463..631f84477d6cc 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -902,6 +902,11 @@
     ],
     "sqlState" : "42809"
   },
+  "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR" : {
+    "message" : [
+      "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr>."
+    ]
+  },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [
       "A column cannot have both a default value and a generation expression but column <colName> has default value: (<defaultValue>) and generation expression: (<genExpr>)."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index b59bb1789488e..ec297f3dd3329 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -562,6 +562,12 @@ No such struct field `<fieldName>` in `<fields>`.
 
 The operation `<statement>` is not allowed on the `<objectType>`: `<objectName>`.
 
+### FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR
+
+SQLSTATE: none assigned
+
+The function `<functionName>` arguments `<functionExpr>` should match the order by expression `<orderExpr>`.
+
 ### GENERATED_COLUMN_WITH_DEFAULT_VALUE
 
 SQLSTATE: none assigned
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 84af522ad2185..084ecefc44062 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -505,6 +505,7 @@ Below is a list of all the keywords in Spark SQL.
|LIMIT|non-reserved|non-reserved|non-reserved| |LINES|non-reserved|non-reserved|non-reserved| |LIST|non-reserved|non-reserved|non-reserved| +|LISTAGG|non-reserved|non-reserved|non-reserved| |LOAD|non-reserved|non-reserved|non-reserved| |LOCAL|non-reserved|non-reserved|reserved| |LOCATION|non-reserved|non-reserved|non-reserved| diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 344dddb4a6410..70769b2c9cd3e 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -78,7 +78,10 @@ def test_function_parity(self): missing_in_py = jvm_fn_set.difference(py_fn_set) # Functions that we expect to be missing in python until they are added to pyspark - expected_missing_in_py = set() + expected_missing_in_py = { + "listagg_distinct", # TODO + "listagg" # TODO + } self.assertEqual( expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 74e8ee1ecf9fe..6ccc8474e36b9 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -244,6 +244,7 @@ ILIKE: 'ILIKE'; LIMIT: 'LIMIT'; LINES: 'LINES'; LIST: 'LIST'; +LISTAGG: 'LISTAGG'; LOAD: 'LOAD'; LOCAL: 'LOCAL'; LOCATION: 'LOCATION'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens index 459749d8ffe6f..e29b1f7d64760 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens @@ -9,339 +9,374 @@ ADD=8 AFTER=9 ALL=10 ALTER=11 -ANALYZE=12 -AND=13 -ANTI=14 -ANY=15 -ANY_VALUE=16 -ARCHIVE=17 -ARRAY=18 -AS=19 -ASC=20 -AT=21 -AUTHORIZATION=22 -BETWEEN=23 -BOTH=24 -BUCKET=25 -BUCKETS=26 -BY=27 -CACHE=28 -CASCADE=29 -CASE=30 -CAST=31 -CATALOG=32 -CATALOGS=33 -CHANGE=34 -CHECK=35 -CLEAR=36 -CLUSTER=37 -CLUSTERED=38 -CODEGEN=39 -COLLATE=40 -COLLECTION=41 -COLUMN=42 -COLUMNS=43 -COMMENT=44 -COMMIT=45 -COMPACT=46 -COMPACTIONS=47 -COMPUTE=48 -CONCATENATE=49 -CONSTRAINT=50 -COST=51 -CREATE=52 -CROSS=53 -CUBE=54 -CURRENT=55 -CURRENT_DATE=56 -CURRENT_TIME=57 -CURRENT_TIMESTAMP=58 -CURRENT_USER=59 -DAY=60 -DAYS=61 -DAYOFYEAR=62 -DATA=63 -DATABASE=64 -DATABASES=65 -DATEADD=66 -DATEDIFF=67 -DBPROPERTIES=68 -DEFAULT=69 -DEFINED=70 -DELETE=71 -DELIMITED=72 -DESC=73 -DESCRIBE=74 -DFS=75 -DIRECTORIES=76 -DIRECTORY=77 -DISTINCT=78 -DISTRIBUTE=79 -DIV=80 -DROP=81 -ELSE=82 -END=83 -ESCAPE=84 -ESCAPED=85 -EXCEPT=86 -EXCHANGE=87 -EXCLUDE=88 -EXISTS=89 -EXPLAIN=90 -EXPORT=91 -EXTENDED=92 -EXTERNAL=93 -EXTRACT=94 -FALSE=95 -FETCH=96 -FIELDS=97 -FILTER=98 -FILEFORMAT=99 -FIRST=100 -FOLLOWING=101 -FOR=102 -FOREIGN=103 -FORMAT=104 -FORMATTED=105 -FROM=106 -FULL=107 -FUNCTION=108 -FUNCTIONS=109 -GLOBAL=110 -GRANT=111 -GROUP=112 -GROUPING=113 -HAVING=114 -HOUR=115 -HOURS=116 -IF=117 -IGNORE=118 -IMPORT=119 -IN=120 -INCLUDE=121 -INDEX=122 -INDEXES=123 -INNER=124 -INPATH=125 -INPUTFORMAT=126 -INSERT=127 -INTERSECT=128 -INTERVAL=129 -INTO=130 -IS=131 -ITEMS=132 -JOIN=133 -KEYS=134 -LAST=135 -LATERAL=136 -LAZY=137 -LEADING=138 -LEFT=139 -LIKE=140 -ILIKE=141 -LIMIT=142 -LINES=143 -LIST=144 -LOAD=145 -LOCAL=146 -LOCATION=147 -LOCK=148 
-LOCKS=149 -LOGICAL=150 -MACRO=151 -MAP=152 -MATCHED=153 -MERGE=154 -MICROSECOND=155 -MICROSECONDS=156 -MILLISECOND=157 -MILLISECONDS=158 -MINUTE=159 -MINUTES=160 -MONTH=161 -MONTHS=162 -MSCK=163 -NAMESPACE=164 -NAMESPACES=165 -NANOSECOND=166 -NANOSECONDS=167 -NATURAL=168 -NO=169 -NOT=170 -NULL=171 -NULLS=172 -OF=173 -OFFSET=174 -ON=175 -ONLY=176 -OPTION=177 -OPTIONS=178 -OR=179 -ORDER=180 -OUT=181 -OUTER=182 -OUTPUTFORMAT=183 -OVER=184 -OVERLAPS=185 -OVERLAY=186 -OVERWRITE=187 -PARTITION=188 -PARTITIONED=189 -PARTITIONS=190 -PERCENTILE_CONT=191 -PERCENTILE_DISC=192 -PERCENTLIT=193 -PIVOT=194 -PLACING=195 -POSITION=196 -PRECEDING=197 -PRIMARY=198 -PRINCIPALS=199 -PROPERTIES=200 -PURGE=201 -QUARTER=202 -QUERY=203 -RANGE=204 -RECORDREADER=205 -RECORDWRITER=206 -RECOVER=207 -REDUCE=208 -REFERENCES=209 -REFRESH=210 -RENAME=211 -REPAIR=212 -REPEATABLE=213 -REPLACE=214 -RESET=215 -RESPECT=216 -RESTRICT=217 -REVOKE=218 -RIGHT=219 -RLIKE=220 -ROLE=221 -ROLES=222 -ROLLBACK=223 -ROLLUP=224 -ROW=225 -ROWS=226 -SECOND=227 -SECONDS=228 -SCHEMA=229 -SCHEMAS=230 -SELECT=231 -SEMI=232 -SEPARATED=233 -SERDE=234 -SERDEPROPERTIES=235 -SESSION_USER=236 -SET=237 -SETMINUS=238 -SETS=239 -SHOW=240 -SKEWED=241 -SOME=242 -SORT=243 -SORTED=244 -SOURCE=245 -START=246 -STATISTICS=247 -STORED=248 -STRATIFY=249 -STRUCT=250 -SUBSTR=251 -SUBSTRING=252 -SYNC=253 -SYSTEM_TIME=254 -SYSTEM_VERSION=255 -TABLE=256 -TABLES=257 -TABLESAMPLE=258 -TARGET=259 -TBLPROPERTIES=260 -TEMPORARY=261 -TERMINATED=262 -THEN=263 -TIME=264 -TIMESTAMP=265 -TIMESTAMPADD=266 -TIMESTAMPDIFF=267 -TO=268 -TOUCH=269 -TRAILING=270 -TRANSACTION=271 -TRANSACTIONS=272 -TRANSFORM=273 -TRIM=274 -TRUE=275 -TRUNCATE=276 -TRY_CAST=277 -TYPE=278 -UNARCHIVE=279 -UNBOUNDED=280 -UNCACHE=281 -UNION=282 -UNIQUE=283 -UNKNOWN=284 -UNLOCK=285 -UNPIVOT=286 -UNSET=287 -UPDATE=288 -USE=289 -USER=290 -USING=291 -VALUES=292 -VERSION=293 -VIEW=294 -VIEWS=295 -WEEK=296 -WEEKS=297 -WHEN=298 -WHERE=299 -WINDOW=300 -WITH=301 -WITHIN=302 -YEAR=303 -YEARS=304 -ZONE=305 -EQ=306 -NSEQ=307 -NEQ=308 -NEQJ=309 -LT=310 -LTE=311 -GT=312 -GTE=313 -PLUS=314 -MINUS=315 -ASTERISK=316 -SLASH=317 -PERCENT=318 -TILDE=319 -AMPERSAND=320 -PIPE=321 -CONCAT_PIPE=322 -HAT=323 -COLON=324 -ARROW=325 -HENT_START=326 -HENT_END=327 -STRING=328 -DOUBLEQUOTED_STRING=329 -BIGINT_LITERAL=330 -SMALLINT_LITERAL=331 -TINYINT_LITERAL=332 -INTEGER_VALUE=333 -EXPONENT_VALUE=334 -DECIMAL_VALUE=335 -FLOAT_LITERAL=336 -DOUBLE_LITERAL=337 -BIGDECIMAL_LITERAL=338 -IDENTIFIER=339 -BACKQUOTED_IDENTIFIER=340 -SIMPLE_COMMENT=341 -BRACKETED_COMMENT=342 -WS=343 -UNRECOGNIZED=344 +ALWAYS=12 +ANALYZE=13 +AND=14 +ANTI=15 +ANY=16 +ANY_VALUE=17 +ARCHIVE=18 +ARRAY=19 +AS=20 +ASC=21 +AT=22 +AUTHORIZATION=23 +BETWEEN=24 +BIGINT=25 +BINARY=26 +BOOLEAN=27 +BOTH=28 +BUCKET=29 +BUCKETS=30 +BY=31 +BYTE=32 +CACHE=33 +CASCADE=34 +CASE=35 +CAST=36 +CATALOG=37 +CATALOGS=38 +CHANGE=39 +CHAR=40 +CHARACTER=41 +CHECK=42 +CLEAR=43 +CLUSTER=44 +CLUSTERED=45 +CODEGEN=46 +COLLATE=47 +COLLECTION=48 +COLUMN=49 +COLUMNS=50 +COMMENT=51 +COMMIT=52 +COMPACT=53 +COMPACTIONS=54 +COMPUTE=55 +CONCATENATE=56 +CONSTRAINT=57 +COST=58 +CREATE=59 +CROSS=60 +CUBE=61 +CURRENT=62 +CURRENT_DATE=63 +CURRENT_TIME=64 +CURRENT_TIMESTAMP=65 +CURRENT_USER=66 +DAY=67 +DAYS=68 +DAYOFYEAR=69 +DATA=70 +DATE=71 +DATABASE=72 +DATABASES=73 +DATEADD=74 +DATE_ADD=75 +DATEDIFF=76 +DATE_DIFF=77 +DBPROPERTIES=78 +DEC=79 +DECIMAL=80 +DEFAULT=81 +DEFINED=82 +DELETE=83 +DELIMITED=84 +DESC=85 +DESCRIBE=86 +DFS=87 +DIRECTORIES=88 +DIRECTORY=89 +DISTINCT=90 +DISTRIBUTE=91 +DIV=92 +DOUBLE=93 
+DROP=94 +ELSE=95 +END=96 +ESCAPE=97 +ESCAPED=98 +EXCEPT=99 +EXCHANGE=100 +EXCLUDE=101 +EXISTS=102 +EXPLAIN=103 +EXPORT=104 +EXTENDED=105 +EXTERNAL=106 +EXTRACT=107 +FALSE=108 +FETCH=109 +FIELDS=110 +FILTER=111 +FILEFORMAT=112 +FIRST=113 +FLOAT=114 +FOLLOWING=115 +FOR=116 +FOREIGN=117 +FORMAT=118 +FORMATTED=119 +FROM=120 +FULL=121 +FUNCTION=122 +FUNCTIONS=123 +GENERATED=124 +GLOBAL=125 +GRANT=126 +GROUP=127 +GROUPING=128 +HAVING=129 +BINARY_HEX=130 +HOUR=131 +HOURS=132 +IDENTIFIER_KW=133 +IF=134 +IGNORE=135 +IMPORT=136 +IN=137 +INCLUDE=138 +INDEX=139 +INDEXES=140 +INNER=141 +INPATH=142 +INPUTFORMAT=143 +INSERT=144 +INTERSECT=145 +INTERVAL=146 +INT=147 +INTEGER=148 +INTO=149 +IS=150 +ITEMS=151 +JOIN=152 +KEYS=153 +LAST=154 +LATERAL=155 +LAZY=156 +LEADING=157 +LEFT=158 +LIKE=159 +ILIKE=160 +LIMIT=161 +LINES=162 +LIST=163 +LISTAGG=164 +LOAD=165 +LOCAL=166 +LOCATION=167 +LOCK=168 +LOCKS=169 +LOGICAL=170 +LONG=171 +MACRO=172 +MAP=173 +MATCHED=174 +MERGE=175 +MICROSECOND=176 +MICROSECONDS=177 +MILLISECOND=178 +MILLISECONDS=179 +MINUTE=180 +MINUTES=181 +MONTH=182 +MONTHS=183 +MSCK=184 +NAME=185 +NAMESPACE=186 +NAMESPACES=187 +NANOSECOND=188 +NANOSECONDS=189 +NATURAL=190 +NO=191 +NOT=192 +NULL=193 +NULLS=194 +NUMERIC=195 +OF=196 +OFFSET=197 +ON=198 +ONLY=199 +OPTION=200 +OPTIONS=201 +OR=202 +ORDER=203 +OUT=204 +OUTER=205 +OUTPUTFORMAT=206 +OVER=207 +OVERLAPS=208 +OVERLAY=209 +OVERWRITE=210 +PARTITION=211 +PARTITIONED=212 +PARTITIONS=213 +PERCENTILE_CONT=214 +PERCENTILE_DISC=215 +PERCENTLIT=216 +PIVOT=217 +PLACING=218 +POSITION=219 +PRECEDING=220 +PRIMARY=221 +PRINCIPALS=222 +PROPERTIES=223 +PURGE=224 +QUARTER=225 +QUERY=226 +RANGE=227 +REAL=228 +RECORDREADER=229 +RECORDWRITER=230 +RECOVER=231 +REDUCE=232 +REFERENCES=233 +REFRESH=234 +RENAME=235 +REPAIR=236 +REPEATABLE=237 +REPLACE=238 +RESET=239 +RESPECT=240 +RESTRICT=241 +REVOKE=242 +RIGHT=243 +RLIKE=244 +ROLE=245 +ROLES=246 +ROLLBACK=247 +ROLLUP=248 +ROW=249 +ROWS=250 +SECOND=251 +SECONDS=252 +SCHEMA=253 +SCHEMAS=254 +SELECT=255 +SEMI=256 +SEPARATED=257 +SERDE=258 +SERDEPROPERTIES=259 +SESSION_USER=260 +SET=261 +SETMINUS=262 +SETS=263 +SHORT=264 +SHOW=265 +SINGLE=266 +SKEWED=267 +SMALLINT=268 +SOME=269 +SORT=270 +SORTED=271 +SOURCE=272 +START=273 +STATISTICS=274 +STORED=275 +STRATIFY=276 +STRING=277 +STRUCT=278 +SUBSTR=279 +SUBSTRING=280 +SYNC=281 +SYSTEM_TIME=282 +SYSTEM_VERSION=283 +TABLE=284 +TABLES=285 +TABLESAMPLE=286 +TARGET=287 +TBLPROPERTIES=288 +TEMPORARY=289 +TERMINATED=290 +THEN=291 +TIME=292 +TIMESTAMP=293 +TIMESTAMP_LTZ=294 +TIMESTAMP_NTZ=295 +TIMESTAMPADD=296 +TIMESTAMPDIFF=297 +TINYINT=298 +TO=299 +TOUCH=300 +TRAILING=301 +TRANSACTION=302 +TRANSACTIONS=303 +TRANSFORM=304 +TRIM=305 +TRUE=306 +TRUNCATE=307 +TRY_CAST=308 +TYPE=309 +UNARCHIVE=310 +UNBOUNDED=311 +UNCACHE=312 +UNION=313 +UNIQUE=314 +UNKNOWN=315 +UNLOCK=316 +UNPIVOT=317 +UNSET=318 +UPDATE=319 +USE=320 +USER=321 +USING=322 +VALUES=323 +VARCHAR=324 +VERSION=325 +VIEW=326 +VIEWS=327 +VOID=328 +WEEK=329 +WEEKS=330 +WHEN=331 +WHERE=332 +WINDOW=333 +WITH=334 +WITHIN=335 +YEAR=336 +YEARS=337 +ZONE=338 +EQ=339 +NSEQ=340 +NEQ=341 +NEQJ=342 +LT=343 +LTE=344 +GT=345 +GTE=346 +PLUS=347 +MINUS=348 +ASTERISK=349 +SLASH=350 +PERCENT=351 +TILDE=352 +AMPERSAND=353 +PIPE=354 +CONCAT_PIPE=355 +HAT=356 +COLON=357 +ARROW=358 +FAT_ARROW=359 +HENT_START=360 +HENT_END=361 +QUESTION=362 +STRING_LITERAL=363 +DOUBLEQUOTED_STRING=364 +BIGINT_LITERAL=365 +SMALLINT_LITERAL=366 +TINYINT_LITERAL=367 +INTEGER_VALUE=368 +EXPONENT_VALUE=369 +DECIMAL_VALUE=370 +FLOAT_LITERAL=371 +DOUBLE_LITERAL=372 
+BIGDECIMAL_LITERAL=373 +IDENTIFIER=374 +BACKQUOTED_IDENTIFIER=375 +SIMPLE_COMMENT=376 +BRACKETED_COMMENT=377 +WS=378 +UNRECOGNIZED=379 ';'=1 '('=2 ')'=3 @@ -353,313 +388,348 @@ UNRECOGNIZED=344 'AFTER'=9 'ALL'=10 'ALTER'=11 -'ANALYZE'=12 -'AND'=13 -'ANTI'=14 -'ANY'=15 -'ANY_VALUE'=16 -'ARCHIVE'=17 -'ARRAY'=18 -'AS'=19 -'ASC'=20 -'AT'=21 -'AUTHORIZATION'=22 -'BETWEEN'=23 -'BOTH'=24 -'BUCKET'=25 -'BUCKETS'=26 -'BY'=27 -'CACHE'=28 -'CASCADE'=29 -'CASE'=30 -'CAST'=31 -'CATALOG'=32 -'CATALOGS'=33 -'CHANGE'=34 -'CHECK'=35 -'CLEAR'=36 -'CLUSTER'=37 -'CLUSTERED'=38 -'CODEGEN'=39 -'COLLATE'=40 -'COLLECTION'=41 -'COLUMN'=42 -'COLUMNS'=43 -'COMMENT'=44 -'COMMIT'=45 -'COMPACT'=46 -'COMPACTIONS'=47 -'COMPUTE'=48 -'CONCATENATE'=49 -'CONSTRAINT'=50 -'COST'=51 -'CREATE'=52 -'CROSS'=53 -'CUBE'=54 -'CURRENT'=55 -'CURRENT_DATE'=56 -'CURRENT_TIME'=57 -'CURRENT_TIMESTAMP'=58 -'CURRENT_USER'=59 -'DAY'=60 -'DAYS'=61 -'DAYOFYEAR'=62 -'DATA'=63 -'DATABASE'=64 -'DATABASES'=65 -'DATEADD'=66 -'DATEDIFF'=67 -'DBPROPERTIES'=68 -'DEFAULT'=69 -'DEFINED'=70 -'DELETE'=71 -'DELIMITED'=72 -'DESC'=73 -'DESCRIBE'=74 -'DFS'=75 -'DIRECTORIES'=76 -'DIRECTORY'=77 -'DISTINCT'=78 -'DISTRIBUTE'=79 -'DIV'=80 -'DROP'=81 -'ELSE'=82 -'END'=83 -'ESCAPE'=84 -'ESCAPED'=85 -'EXCEPT'=86 -'EXCHANGE'=87 -'EXCLUDE'=88 -'EXISTS'=89 -'EXPLAIN'=90 -'EXPORT'=91 -'EXTENDED'=92 -'EXTERNAL'=93 -'EXTRACT'=94 -'FALSE'=95 -'FETCH'=96 -'FIELDS'=97 -'FILTER'=98 -'FILEFORMAT'=99 -'FIRST'=100 -'FOLLOWING'=101 -'FOR'=102 -'FOREIGN'=103 -'FORMAT'=104 -'FORMATTED'=105 -'FROM'=106 -'FULL'=107 -'FUNCTION'=108 -'FUNCTIONS'=109 -'GLOBAL'=110 -'GRANT'=111 -'GROUP'=112 -'GROUPING'=113 -'HAVING'=114 -'HOUR'=115 -'HOURS'=116 -'IF'=117 -'IGNORE'=118 -'IMPORT'=119 -'IN'=120 -'INCLUDE'=121 -'INDEX'=122 -'INDEXES'=123 -'INNER'=124 -'INPATH'=125 -'INPUTFORMAT'=126 -'INSERT'=127 -'INTERSECT'=128 -'INTERVAL'=129 -'INTO'=130 -'IS'=131 -'ITEMS'=132 -'JOIN'=133 -'KEYS'=134 -'LAST'=135 -'LATERAL'=136 -'LAZY'=137 -'LEADING'=138 -'LEFT'=139 -'LIKE'=140 -'ILIKE'=141 -'LIMIT'=142 -'LINES'=143 -'LIST'=144 -'LOAD'=145 -'LOCAL'=146 -'LOCATION'=147 -'LOCK'=148 -'LOCKS'=149 -'LOGICAL'=150 -'MACRO'=151 -'MAP'=152 -'MATCHED'=153 -'MERGE'=154 -'MICROSECOND'=155 -'MICROSECONDS'=156 -'MILLISECOND'=157 -'MILLISECONDS'=158 -'MINUTE'=159 -'MINUTES'=160 -'MONTH'=161 -'MONTHS'=162 -'MSCK'=163 -'NAMESPACE'=164 -'NAMESPACES'=165 -'NANOSECOND'=166 -'NANOSECONDS'=167 -'NATURAL'=168 -'NO'=169 -'NULL'=171 -'NULLS'=172 -'OF'=173 -'OFFSET'=174 -'ON'=175 -'ONLY'=176 -'OPTION'=177 -'OPTIONS'=178 -'OR'=179 -'ORDER'=180 -'OUT'=181 -'OUTER'=182 -'OUTPUTFORMAT'=183 -'OVER'=184 -'OVERLAPS'=185 -'OVERLAY'=186 -'OVERWRITE'=187 -'PARTITION'=188 -'PARTITIONED'=189 -'PARTITIONS'=190 -'PERCENTILE_CONT'=191 -'PERCENTILE_DISC'=192 -'PERCENT'=193 -'PIVOT'=194 -'PLACING'=195 -'POSITION'=196 -'PRECEDING'=197 -'PRIMARY'=198 -'PRINCIPALS'=199 -'PROPERTIES'=200 -'PURGE'=201 -'QUARTER'=202 -'QUERY'=203 -'RANGE'=204 -'RECORDREADER'=205 -'RECORDWRITER'=206 -'RECOVER'=207 -'REDUCE'=208 -'REFERENCES'=209 -'REFRESH'=210 -'RENAME'=211 -'REPAIR'=212 -'REPEATABLE'=213 -'REPLACE'=214 -'RESET'=215 -'RESPECT'=216 -'RESTRICT'=217 -'REVOKE'=218 -'RIGHT'=219 -'ROLE'=221 -'ROLES'=222 -'ROLLBACK'=223 -'ROLLUP'=224 -'ROW'=225 -'ROWS'=226 -'SECOND'=227 -'SECONDS'=228 -'SCHEMA'=229 -'SCHEMAS'=230 -'SELECT'=231 -'SEMI'=232 -'SEPARATED'=233 -'SERDE'=234 -'SERDEPROPERTIES'=235 -'SESSION_USER'=236 -'SET'=237 -'MINUS'=238 -'SETS'=239 -'SHOW'=240 -'SKEWED'=241 -'SOME'=242 -'SORT'=243 -'SORTED'=244 -'SOURCE'=245 -'START'=246 -'STATISTICS'=247 
-'STORED'=248 -'STRATIFY'=249 -'STRUCT'=250 -'SUBSTR'=251 -'SUBSTRING'=252 -'SYNC'=253 -'SYSTEM_TIME'=254 -'SYSTEM_VERSION'=255 -'TABLE'=256 -'TABLES'=257 -'TABLESAMPLE'=258 -'TARGET'=259 -'TBLPROPERTIES'=260 -'TERMINATED'=262 -'THEN'=263 -'TIME'=264 -'TIMESTAMP'=265 -'TIMESTAMPADD'=266 -'TIMESTAMPDIFF'=267 -'TO'=268 -'TOUCH'=269 -'TRAILING'=270 -'TRANSACTION'=271 -'TRANSACTIONS'=272 -'TRANSFORM'=273 -'TRIM'=274 -'TRUE'=275 -'TRUNCATE'=276 -'TRY_CAST'=277 -'TYPE'=278 -'UNARCHIVE'=279 -'UNBOUNDED'=280 -'UNCACHE'=281 -'UNION'=282 -'UNIQUE'=283 -'UNKNOWN'=284 -'UNLOCK'=285 -'UNPIVOT'=286 -'UNSET'=287 -'UPDATE'=288 -'USE'=289 -'USER'=290 -'USING'=291 -'VALUES'=292 -'VERSION'=293 -'VIEW'=294 -'VIEWS'=295 -'WEEK'=296 -'WEEKS'=297 -'WHEN'=298 -'WHERE'=299 -'WINDOW'=300 -'WITH'=301 -'WITHIN'=302 -'YEAR'=303 -'YEARS'=304 -'ZONE'=305 -'<=>'=307 -'<>'=308 -'!='=309 -'<'=310 -'>'=312 -'+'=314 -'-'=315 -'*'=316 -'/'=317 -'%'=318 -'~'=319 -'&'=320 -'|'=321 -'||'=322 -'^'=323 -':'=324 -'->'=325 -'/*+'=326 -'*/'=327 +'ALWAYS'=12 +'ANALYZE'=13 +'AND'=14 +'ANTI'=15 +'ANY'=16 +'ANY_VALUE'=17 +'ARCHIVE'=18 +'ARRAY'=19 +'AS'=20 +'ASC'=21 +'AT'=22 +'AUTHORIZATION'=23 +'BETWEEN'=24 +'BIGINT'=25 +'BINARY'=26 +'BOOLEAN'=27 +'BOTH'=28 +'BUCKET'=29 +'BUCKETS'=30 +'BY'=31 +'BYTE'=32 +'CACHE'=33 +'CASCADE'=34 +'CASE'=35 +'CAST'=36 +'CATALOG'=37 +'CATALOGS'=38 +'CHANGE'=39 +'CHAR'=40 +'CHARACTER'=41 +'CHECK'=42 +'CLEAR'=43 +'CLUSTER'=44 +'CLUSTERED'=45 +'CODEGEN'=46 +'COLLATE'=47 +'COLLECTION'=48 +'COLUMN'=49 +'COLUMNS'=50 +'COMMENT'=51 +'COMMIT'=52 +'COMPACT'=53 +'COMPACTIONS'=54 +'COMPUTE'=55 +'CONCATENATE'=56 +'CONSTRAINT'=57 +'COST'=58 +'CREATE'=59 +'CROSS'=60 +'CUBE'=61 +'CURRENT'=62 +'CURRENT_DATE'=63 +'CURRENT_TIME'=64 +'CURRENT_TIMESTAMP'=65 +'CURRENT_USER'=66 +'DAY'=67 +'DAYS'=68 +'DAYOFYEAR'=69 +'DATA'=70 +'DATE'=71 +'DATABASE'=72 +'DATABASES'=73 +'DATEADD'=74 +'DATE_ADD'=75 +'DATEDIFF'=76 +'DATE_DIFF'=77 +'DBPROPERTIES'=78 +'DEC'=79 +'DECIMAL'=80 +'DEFAULT'=81 +'DEFINED'=82 +'DELETE'=83 +'DELIMITED'=84 +'DESC'=85 +'DESCRIBE'=86 +'DFS'=87 +'DIRECTORIES'=88 +'DIRECTORY'=89 +'DISTINCT'=90 +'DISTRIBUTE'=91 +'DIV'=92 +'DOUBLE'=93 +'DROP'=94 +'ELSE'=95 +'END'=96 +'ESCAPE'=97 +'ESCAPED'=98 +'EXCEPT'=99 +'EXCHANGE'=100 +'EXCLUDE'=101 +'EXISTS'=102 +'EXPLAIN'=103 +'EXPORT'=104 +'EXTENDED'=105 +'EXTERNAL'=106 +'EXTRACT'=107 +'FALSE'=108 +'FETCH'=109 +'FIELDS'=110 +'FILTER'=111 +'FILEFORMAT'=112 +'FIRST'=113 +'FLOAT'=114 +'FOLLOWING'=115 +'FOR'=116 +'FOREIGN'=117 +'FORMAT'=118 +'FORMATTED'=119 +'FROM'=120 +'FULL'=121 +'FUNCTION'=122 +'FUNCTIONS'=123 +'GENERATED'=124 +'GLOBAL'=125 +'GRANT'=126 +'GROUP'=127 +'GROUPING'=128 +'HAVING'=129 +'X'=130 +'HOUR'=131 +'HOURS'=132 +'IDENTIFIER'=133 +'IF'=134 +'IGNORE'=135 +'IMPORT'=136 +'IN'=137 +'INCLUDE'=138 +'INDEX'=139 +'INDEXES'=140 +'INNER'=141 +'INPATH'=142 +'INPUTFORMAT'=143 +'INSERT'=144 +'INTERSECT'=145 +'INTERVAL'=146 +'INT'=147 +'INTEGER'=148 +'INTO'=149 +'IS'=150 +'ITEMS'=151 +'JOIN'=152 +'KEYS'=153 +'LAST'=154 +'LATERAL'=155 +'LAZY'=156 +'LEADING'=157 +'LEFT'=158 +'LIKE'=159 +'ILIKE'=160 +'LIMIT'=161 +'LINES'=162 +'LIST'=163 +'LISTAGG'=164 +'LOAD'=165 +'LOCAL'=166 +'LOCATION'=167 +'LOCK'=168 +'LOCKS'=169 +'LOGICAL'=170 +'LONG'=171 +'MACRO'=172 +'MAP'=173 +'MATCHED'=174 +'MERGE'=175 +'MICROSECOND'=176 +'MICROSECONDS'=177 +'MILLISECOND'=178 +'MILLISECONDS'=179 +'MINUTE'=180 +'MINUTES'=181 +'MONTH'=182 +'MONTHS'=183 +'MSCK'=184 +'NAME'=185 +'NAMESPACE'=186 +'NAMESPACES'=187 +'NANOSECOND'=188 +'NANOSECONDS'=189 +'NATURAL'=190 +'NO'=191 +'NULL'=193 +'NULLS'=194 
+'NUMERIC'=195 +'OF'=196 +'OFFSET'=197 +'ON'=198 +'ONLY'=199 +'OPTION'=200 +'OPTIONS'=201 +'OR'=202 +'ORDER'=203 +'OUT'=204 +'OUTER'=205 +'OUTPUTFORMAT'=206 +'OVER'=207 +'OVERLAPS'=208 +'OVERLAY'=209 +'OVERWRITE'=210 +'PARTITION'=211 +'PARTITIONED'=212 +'PARTITIONS'=213 +'PERCENTILE_CONT'=214 +'PERCENTILE_DISC'=215 +'PERCENT'=216 +'PIVOT'=217 +'PLACING'=218 +'POSITION'=219 +'PRECEDING'=220 +'PRIMARY'=221 +'PRINCIPALS'=222 +'PROPERTIES'=223 +'PURGE'=224 +'QUARTER'=225 +'QUERY'=226 +'RANGE'=227 +'REAL'=228 +'RECORDREADER'=229 +'RECORDWRITER'=230 +'RECOVER'=231 +'REDUCE'=232 +'REFERENCES'=233 +'REFRESH'=234 +'RENAME'=235 +'REPAIR'=236 +'REPEATABLE'=237 +'REPLACE'=238 +'RESET'=239 +'RESPECT'=240 +'RESTRICT'=241 +'REVOKE'=242 +'RIGHT'=243 +'ROLE'=245 +'ROLES'=246 +'ROLLBACK'=247 +'ROLLUP'=248 +'ROW'=249 +'ROWS'=250 +'SECOND'=251 +'SECONDS'=252 +'SCHEMA'=253 +'SCHEMAS'=254 +'SELECT'=255 +'SEMI'=256 +'SEPARATED'=257 +'SERDE'=258 +'SERDEPROPERTIES'=259 +'SESSION_USER'=260 +'SET'=261 +'MINUS'=262 +'SETS'=263 +'SHORT'=264 +'SHOW'=265 +'SINGLE'=266 +'SKEWED'=267 +'SMALLINT'=268 +'SOME'=269 +'SORT'=270 +'SORTED'=271 +'SOURCE'=272 +'START'=273 +'STATISTICS'=274 +'STORED'=275 +'STRATIFY'=276 +'STRING'=277 +'STRUCT'=278 +'SUBSTR'=279 +'SUBSTRING'=280 +'SYNC'=281 +'SYSTEM_TIME'=282 +'SYSTEM_VERSION'=283 +'TABLE'=284 +'TABLES'=285 +'TABLESAMPLE'=286 +'TARGET'=287 +'TBLPROPERTIES'=288 +'TERMINATED'=290 +'THEN'=291 +'TIME'=292 +'TIMESTAMP'=293 +'TIMESTAMP_LTZ'=294 +'TIMESTAMP_NTZ'=295 +'TIMESTAMPADD'=296 +'TIMESTAMPDIFF'=297 +'TINYINT'=298 +'TO'=299 +'TOUCH'=300 +'TRAILING'=301 +'TRANSACTION'=302 +'TRANSACTIONS'=303 +'TRANSFORM'=304 +'TRIM'=305 +'TRUE'=306 +'TRUNCATE'=307 +'TRY_CAST'=308 +'TYPE'=309 +'UNARCHIVE'=310 +'UNBOUNDED'=311 +'UNCACHE'=312 +'UNION'=313 +'UNIQUE'=314 +'UNKNOWN'=315 +'UNLOCK'=316 +'UNPIVOT'=317 +'UNSET'=318 +'UPDATE'=319 +'USE'=320 +'USER'=321 +'USING'=322 +'VALUES'=323 +'VARCHAR'=324 +'VERSION'=325 +'VIEW'=326 +'VIEWS'=327 +'VOID'=328 +'WEEK'=329 +'WEEKS'=330 +'WHEN'=331 +'WHERE'=332 +'WINDOW'=333 +'WITH'=334 +'WITHIN'=335 +'YEAR'=336 +'YEARS'=337 +'ZONE'=338 +'<=>'=340 +'<>'=341 +'!='=342 +'<'=343 +'>'=345 +'+'=347 +'-'=348 +'*'=349 +'/'=350 +'%'=351 +'~'=352 +'&'=353 +'|'=354 +'||'=355 +'^'=356 +':'=357 +'->'=358 +'=>'=359 +'/*+'=360 +'*/'=361 +'?'=362 diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 1ea0f6e583d2c..a2462ade22217 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -982,6 +982,8 @@ primaryExpression | name=(PERCENTILE_CONT | PERCENTILE_DISC) LEFT_PAREN percentage=valueExpression RIGHT_PAREN WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? ( OVER windowSpec)? #percentile + | LISTAGG LEFT_PAREN setQuantifier? aggEpxr=expression (COMMA delimiter=stringLit)? RIGHT_PAREN + WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN ( OVER windowSpec)? 
#listAgg ; literalType @@ -1393,6 +1395,7 @@ ansiNonReserved | LIMIT | LINES | LIST + | LISTAGG | LOAD | LOCAL | LOCATION @@ -1716,6 +1719,7 @@ nonReserved | LIMIT | LINES | LIST + | LISTAGG | LOAD | LOCAL | LOCATION diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 558579cdb80ac..6eeba1a94c9c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -477,6 +477,7 @@ object FunctionRegistry { expression[Percentile]("percentile"), expression[Median]("median"), expression[Skewness]("skewness"), + expression[ListAgg]("listagg"), expression[ApproximatePercentile]("percentile_approx"), expression[ApproximatePercentile]("approx_percentile", true), expression[HistogramNumeric]("histogram_numeric"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala new file mode 100644 index 0000000000000..f7819b2dd206e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.types.PhysicalDataType +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.OpenHashMap + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the concatenated input values," + + " separated by the delimiter string.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col); + a,b,c + > SELECT _FUNC_(col) FROM VALUES (NULL), ('a'), ('b') AS tab(col); + a,b + > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col); + a|b + > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); + NULL + """, + group = "agg_funcs", + since = "4.0.0") +case class ListAgg( + child: Expression, + delimiter: Expression = Literal.create(",", StringType), + reverse: Boolean = false, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer + with UnaryLike[Expression] { + + def this(child: Expression) = this(child, Literal.create(",", StringType), false, 0, 0) + def this(child: Expression, delimiter: Expression) = this(child, delimiter, false, 0, 0) + + override def update( + buffer: OpenHashMap[AnyRef, Long], + input: InternalRow): OpenHashMap[AnyRef, Long] = { + val value = child.eval(input) + if (value != null) { + val key = InternalRow.copyValue(value) + buffer.changeValue(key.asInstanceOf[AnyRef], 1L, _ + 1L) + } + buffer + } + + override def merge( + buffer: OpenHashMap[AnyRef, Long], + input: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = { + input.foreach { case (key, count) => + buffer.changeValue(key, count, _ + count) + } + buffer + } + + override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { + if (buffer.nonEmpty) { + val ordering = PhysicalDataType.ordering(child.dataType) + val sortedCounts = if (reverse) { + buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]].reverse) + } else { + buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]]) + } + UTF8String.fromString(sortedCounts.map(kc => { + List.fill(kc._2.toInt)(kc._1.toString).mkString(delimiter.eval() + .asInstanceOf[UTF8String].toString) + }).mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) + } else { + null + } + } + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int) : ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0635e6a1b44fc..70152abaa3004 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableId import 
org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last, PercentileCont, PercentileDisc} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last, ListAgg, PercentileCont, PercentileDisc} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -2203,6 +2203,38 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } } + /** + * Create a ListAgg expression. + */ + override def visitListAgg(ctx: ListAggContext): AnyRef = { + val column = expression(ctx.aggEpxr) + val sortOrder = visitSortItem(ctx.sortItem) + if (!column.semanticEquals(sortOrder.child)) { + throw QueryCompilationErrors.functionAndOrderExpressionMismatchError("list_agg", column, + sortOrder.child) + } + val listAgg = if (ctx.delimiter != null) { + sortOrder.direction match { + case Ascending => ListAgg(sortOrder.child, Literal(ctx.delimiter.getText)) + case Descending => ListAgg(sortOrder.child, Literal(ctx.delimiter.getText), true) + } + } else { + sortOrder.direction match { + case Ascending => ListAgg(sortOrder.child) + case Descending => ListAgg(sortOrder.child, Literal(","), true) + } + } + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + val aggregateExpression = listAgg.toAggregateExpression(isDistinct) + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(aggregateExpression, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(aggregateExpression, visitWindowDef(spec)) + case _ => aggregateExpression + } + } + /** * Create a Substring/Substr expression. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 1e4f779e565af..d287f0ee72589 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -876,6 +876,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("database" -> quoted)) } + def functionAndOrderExpressionMismatchError( + functionName: String, + functionExpr: Expression, + orderExpr: Expression): Throwable = { + new AnalysisException( + errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR", + messageParameters = Map( + "functionName" -> functionName, + "functionExpr" -> toSQLExpr(functionExpr), + "orderExpr" -> toSQLExpr(orderExpr))) + } + def wrongCommandForObjectTypeError( operation: String, requiredType: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e7e8b945d9186..3a8dface54f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1148,6 +1148,60 @@ object functions { */ def sum_distinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) + /** + * Aggregate function: returns the concatenated input values. 
+ * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg(e: Column): Column = withAggregateFunction { + ListAgg(e.expr) + } + + /** + * Aggregate function: returns the concatenated input values. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg(columnName: String): Column = listagg(Column(columnName)) + + /** + * Aggregate function: returns the concatenated input values, separated by the delimiter string. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg(e: Column, delimiter: String): Column = withAggregateFunction { + ListAgg(e.expr, Literal.create(delimiter, StringType)) + } + + /** + * Aggregate function: returns the concatenated input values, separated by the delimiter string. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg(columnName: String, delimiter: String): Column = + listagg(Column(columnName), delimiter) + + /** + * Aggregate function: returns the concatenated input values, separated by the delimiter string. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg_distinct(e: Column): Column = withAggregateFunction(ListAgg(e.expr), + isDistinct = true) + + /** + * Aggregate function: returns the concatenated input values, separated by the delimiter string. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg_distinct(columnName: String): Column = listagg_distinct(Column(columnName)) + /** * Aggregate function: alias for `var_samp`. * diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 71fde8c7268cc..49386b26d0c54 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -131,8 +131,8 @@ | org.apache.spark.sql.catalyst.expressions.EqualTo | == | SELECT 2 == 2 | struct<(2 = 2):boolean> | | org.apache.spark.sql.catalyst.expressions.EulerNumber | e | SELECT e() | struct | | org.apache.spark.sql.catalyst.expressions.Exp | exp | SELECT exp(0) | struct | -| org.apache.spark.sql.catalyst.expressions.Explode | explode | SELECT explode(array(10, 20)) | struct | -| org.apache.spark.sql.catalyst.expressions.Explode | explode_outer | SELECT explode_outer(array(10, 20)) | struct | +| org.apache.spark.sql.catalyst.expressions.ExplodeExpressionBuilder | explode | SELECT explode(array(10, 20)) | struct | +| org.apache.spark.sql.catalyst.expressions.ExplodeExpressionBuilder | explode_outer | SELECT explode_outer(array(10, 20)) | struct | | org.apache.spark.sql.catalyst.expressions.Expm1 | expm1 | SELECT expm1(0) | struct | | org.apache.spark.sql.catalyst.expressions.Extract | extract | SELECT extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456') | struct | | org.apache.spark.sql.catalyst.expressions.Factorial | factorial | SELECT factorial(5) | struct | @@ -212,7 +212,7 @@ | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct> | -| org.apache.spark.sql.catalyst.expressions.Mask | mask | SELECT mask('abcd-EFGH-8765-4321') | struct | +| org.apache.spark.sql.catalyst.expressions.MaskExpressionBuilder | mask | SELECT mask('abcd-EFGH-8765-4321') | struct | | 
org.apache.spark.sql.catalyst.expressions.Md5 | md5 | SELECT md5('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.MicrosToTimestamp | timestamp_micros | SELECT timestamp_micros(1230219000123123) | struct | | org.apache.spark.sql.catalyst.expressions.MillisToTimestamp | timestamp_millis | SELECT timestamp_millis(1230219000123) | struct | @@ -385,7 +385,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.Corr | corr | SELECT corr(c1, c2) FROM VALUES (3, 2), (3, 3), (6, 4) as tab(c1, c2) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Count | count | SELECT count(*) FROM VALUES (NULL), (5), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.CountIf | count_if | SELECT count_if(col % 2 = 0) FROM VALUES (NULL), (0), (1), (2), (3) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.CountMinSketchAgg | count_min_sketch | SELECT hex(count_min_sketch(col, 0.5d, 0.5d, 1)) FROM VALUES (1), (2), (1) AS tab(col) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.CountMinSketchAggExpressionBuilder | count_min_sketch | SELECT hex(count_min_sketch(col, 0.5d, 0.5d, 1)) FROM VALUES (1), (2), (1) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.CovPopulation | covar_pop | SELECT covar_pop(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.CovSample | covar_samp | SELECT covar_samp(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | @@ -397,6 +397,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last_value | SELECT last_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | listagg | SELECT listagg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Max | max | SELECT max(col) FROM VALUES (10), (50), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.MaxBy | max_by | SELECT max_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Median | median | SELECT median(col) FROM VALUES (0), (10) AS tab(col) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index fe7bec0219162..3baa4d8e1b6b8 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -159,6 +159,7 @@ LIKE false LIMIT false LINES false LIST false +LISTAGG false LOAD false LOCAL false LOCATION false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index a4fd9c82cf095..cabe7608a7180 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -159,6 +159,7 @@ LIKE 
false LIMIT false LINES false LIST false +LISTAGG false LOAD false LOCAL false LOCATION false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d78771a8f19bc..2878a6727e3c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -602,6 +602,45 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("listagg function") { + // normal case + val df = Seq(("a", "b"), ("b", "c"), ("c", "d")).toDF("a", "b") + checkAnswer( + df.selectExpr("listagg(a)", "listagg(b)"), + Seq(Row("a,b,c", "b,c,d")) + ) + checkAnswer( + df.select(listagg($"a"), listagg($"b")), + Seq(Row("a,b,c", "b,c,d")) + ) + + // distinct case + val df2 = Seq(("a", "b"), ("a", "b"), ("b", "d")).toDF("a", "b") + checkAnswer( + df2.select(listagg_distinct($"a"), listagg_distinct($"b")), + Seq(Row("a,b", "b,d")) + ) + + // null case + val df3 = Seq(("a", "b", null), ("a", "b", null), (null, null, null)).toDF("a", "b", "c") + checkAnswer( + df3.select(listagg_distinct($"a"), listagg($"a"), listagg_distinct($"b"), listagg($"b"), + listagg($"c")), + Seq(Row("a", "a,a", "b", "b,b", null)) + ) + + // custom delimiter + val df4 = Seq(("a", "b"), ("b", "c"), ("c", "d")).toDF("a", "b") + checkAnswer( + df4.selectExpr("listagg(a, '|')", "listagg(b, '|')"), + Seq(Row("a|b|c", "b|c|d")) + ) + checkAnswer( + df4.select(listagg($"a", "|"), listagg($"b", "|")), + Seq(Row("a|b|c", "b|c|d")) + ) + } + test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") { val bytesTest1 = "test1".getBytes val bytesTest2 = "test2".getBytes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c7dcb575ff050..04f2ddd71d8cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -82,7 +82,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "bucket", "days", "hours", "months", "years", // Datasource v2 partition transformations "product", // Discussed in https://github.com/apache/spark/pull/30745 "unwrap_udt", - "collect_top_k" + "collect_top_k", + "listagg_distinct" ) // We only consider functions matching this pattern, this excludes symbolic and other diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cfeccbdf648c2..133d638d23ab1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -158,6 +158,57 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + test("SPARK-42746: listagg function") { + withTempView("df", "df2") { + Seq(("a", "b"), ("a", "c"), ("b", "c"), ("b", "d"), (null, null)).toDF("a", "b") + .createOrReplaceTempView("df") + checkAnswer( + sql("select listagg(b) from df group by a"), + Row("b,c") :: Row("c,d") :: Row(null) :: Nil) + + checkAnswer( + sql("select listagg(b, '|') from df group by a"), + Row("b|c") :: Row("c|d") :: Row(null) :: Nil) + + checkAnswer( + sql("SELECT LISTAGG(a) FROM df"), + Row("a,a,b,b") :: Nil) + + checkAnswer( + sql("SELECT LISTAGG(DISTINCT a) FROM df"), + Row("a,b") :: Nil) + + 
checkAnswer( + sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY a) FROM df"), + Row("a,a,b,b") :: Nil) + + checkAnswer( + sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY a DESC) FROM df"), + Row("b,b,a,a") :: Nil) + + checkAnswer( + sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY a DESC) " + + "OVER (PARTITION BY b) FROM df"), + Row("a") :: Row("b,a") :: Row("b,a") :: Row("b") :: Row(null) :: Nil) + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df") + }, + errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR", + parameters = Map( + "functionName" -> "list_agg", + "functionExpr" -> "\"a\"", + "orderExpr" -> "\"b\"")) + + Seq((1, true), (2, false), (3, false)).toDF("a", "b").createOrReplaceTempView("df2") + + checkAnswer( + sql("SELECT LISTAGG(a), LISTAGG(b) FROM df2"), + Row("1,2,3", "false,false,true") :: Nil) + } + } + test("support table.star") { checkAnswer( sql( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 45d1f70956a41..32901a308463e 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -213,7 +213,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW
,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VARCHAR,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LISTAGG,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VARCHAR,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From 99cf932d5dfdc3868807870881a8d9cd8819b2c0 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Wed, 9 Aug 2023 12:39:46 +0800 Subject: [PATCH 02/58] update --- .../org/apache/spark/sql/functions.scala | 24 ++++--------------- .../spark/sql/DataFrameAggregateSuite.scala | 4 ---- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a8dface54f82..d5aed3a11bebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1158,33 +1158,16 @@ object functions { ListAgg(e.expr) } - /** - * Aggregate function: returns the concatenated input values. - * - * @group agg_funcs - * @since 4.0.0 - */ - def listagg(columnName: String): Column = listagg(Column(columnName)) - /** * Aggregate function: returns the concatenated input values, separated by the delimiter string. * * @group agg_funcs * @since 4.0.0 */ - def listagg(e: Column, delimiter: String): Column = withAggregateFunction { - ListAgg(e.expr, Literal.create(delimiter, StringType)) + def listagg(e: Column, delimiter: Column): Column = withAggregateFunction { + ListAgg(e.expr, delimiter.expr) } - /** - * Aggregate function: returns the concatenated input values, separated by the delimiter string. - * - * @group agg_funcs - * @since 4.0.0 - */ - def listagg(columnName: String, delimiter: String): Column = - listagg(Column(columnName), delimiter) - /** * Aggregate function: returns the concatenated input values, separated by the delimiter string. * @@ -1200,7 +1183,8 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def listagg_distinct(columnName: String): Column = listagg_distinct(Column(columnName)) + def listagg_distinct(e: Column, delimiter: Column): Column = withAggregateFunction( + ListAgg(e.expr, delimiter.expr), isDistinct = true) /** * Aggregate function: alias for `var_samp`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2878a6727e3c9..853f92dca74da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -635,10 +635,6 @@ class DataFrameAggregateSuite extends QueryTest df4.selectExpr("listagg(a, '|')", "listagg(b, '|')"), Seq(Row("a|b|c", "b|c|d")) ) - checkAnswer( - df4.select(listagg($"a", "|"), listagg($"b", "|")), - Seq(Row("a|b|c", "b|c|d")) - ) } test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") { From db513cf5c4523381f7278bc6c1fb276872d014c3 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Wed, 9 Aug 2023 13:56:59 +0800 Subject: [PATCH 03/58] update --- .../connect/client/CheckConnectJvmClientCompatibility.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 2bf9c41fb2cbd..c5c6bcbfc5daf 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -195,6 +195,9 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.listagg"), + 
ProblemFilters.exclude[DirectMissingMethodProblem](
+      "org.apache.spark.sql.functions.listagg_distinct"),
 
     // KeyValueGroupedDataset
     ProblemFilters.exclude[Problem](

From 68ed7392f0e2c54dcd74f53cff58f37dc73c0f1d Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Wed, 9 Aug 2023 15:20:25 +0800
Subject: [PATCH 04/58] format

---
 .../connect/client/CheckConnectJvmClientCompatibility.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index c5c6bcbfc5daf..631bad0aaf5bc 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -195,7 +195,8 @@ object CheckConnectJvmClientCompatibility {
     ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"),
     ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
     ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.listagg"),
+    ProblemFilters.exclude[DirectMissingMethodProblem](
+      "org.apache.spark.sql.functions.listagg"),
     ProblemFilters.exclude[DirectMissingMethodProblem](
       "org.apache.spark.sql.functions.listagg_distinct"),

From d8460c887fec985fc1471e4a04dffa7abd24c744 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Wed, 9 Aug 2023 19:20:34 +0800
Subject: [PATCH 05/58] format

---
 python/pyspark/sql/tests/test_functions.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 70769b2c9cd3e..38bad0fc765b8 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -78,10 +78,7 @@ def test_function_parity(self):
         missing_in_py = jvm_fn_set.difference(py_fn_set)
 
         # Functions that we expect to be missing in python until they are added to pyspark
-        expected_missing_in_py = {
-            "listagg_distinct",  # TODO
-            "listagg"  # TODO
-        }
+        expected_missing_in_py = {"listagg_distinct", "listagg"}  # TODO
 
         self.assertEqual(
             expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected"

From 864f65801fccd0c3f3daddda4ef9a777b3bb923b Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Thu, 10 Aug 2023 09:39:59 +0800
Subject: [PATCH 06/58] fix review

---
 common/utils/src/main/resources/error/error-classes.json     | 2 +-
 docs/sql-error-conditions.md                                 | 2 +-
 .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala    | 2 +-
 .../org/apache/spark/sql/errors/QueryCompilationErrors.scala | 4 ++--
 .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  | 4 ++--
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index b415aa4eb7b5e..5f75ca821b442 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -921,7 +921,7 @@
     ],
     "sqlState" : "42809"
   },
-  "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR" : {
+  "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH" : {
     "message" : [
       "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr>."
] diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 094c6d7300f16..e7bff2f0457ac 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -581,7 +581,7 @@ No such struct field `` in ``. The operation `` is not allowed on the ``: ``. -### FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR +### FUNCTION_AND_ORDER_EXPRESSION_MISMATCH SQLSTATE: none assigned diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5a8f341e2365d..b896d4ef690b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2210,7 +2210,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val column = expression(ctx.aggEpxr) val sortOrder = visitSortItem(ctx.sortItem) if (!column.semanticEquals(sortOrder.child)) { - throw QueryCompilationErrors.functionAndOrderExpressionMismatchError("list_agg", column, + throw QueryCompilationErrors.functionAndOrderExpressionMismatchError("LISTAGG", column, sortOrder.child) } val listAgg = if (ctx.delimiter != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 452cee90e7afa..6a803f7673131 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -889,9 +889,9 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat functionExpr: Expression, orderExpr: Expression): Throwable = { new AnalysisException( - errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR", + errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", messageParameters = Map( - "functionName" -> functionName, + "functionName" -> toSQLStmt(functionName), "functionExpr" -> toSQLExpr(functionExpr), "orderExpr" -> toSQLExpr(orderExpr))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 133d638d23ab1..e4f1a43fc0d88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -195,9 +195,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df") }, - errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH_ERROR", + errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", parameters = Map( - "functionName" -> "list_agg", + "functionName" -> "LISTAGG", "functionExpr" -> "\"a\"", "orderExpr" -> "\"b\"")) From d23654c43e23bfb6007db76593b61ad4d77cb50d Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Tue, 22 Aug 2023 22:11:52 +0800 Subject: [PATCH 07/58] format --- python/pyspark/sql/tests/test_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 862aa171752af..5521a28e50561 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -88,7 +88,7 @@ def test_function_parity(self): "schema_of_xml", # TODO: listagg functions will 
soon be added and removed from this list
             "listagg_distinct",
-            "listagg"
+            "listagg",
         }
 
         self.assertEqual(

From f048dc9289a5340caefa50dbed16527785acbee0 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Mon, 9 Oct 2023 10:34:31 +0800
Subject: [PATCH 08/58] update

---
 .../spark/sql/catalyst/expressions/aggregate/ListAgg.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala
index f7819b2dd206e..f24b4ad36baa7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala
@@ -34,6 +34,10 @@ import org.apache.spark.util.collection.OpenHashMap
       a,b,c
      > SELECT _FUNC_(col) FROM VALUES (NULL), ('a'), ('b') AS tab(col);
       a,b
+     > SELECT _FUNC_(col) FROM VALUES ('a'), ('a') AS tab(col);
+      a,a
+     > SELECT _FUNC_(DISTINCT col) FROM VALUES ('a'), ('a'), ('b') AS tab(col);
+      a,b
      > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col);
       a|b
      > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);

From 77825a75537d8557d6168697dfbec3cdbae82516 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Mon, 9 Oct 2023 16:53:54 +0800
Subject: [PATCH 09/58] update

---
 .../src/main/resources/error/error-classes.json    |  2 +-
 docs/sql-error-conditions.md                       |  2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     | 12 ++++++------
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  6 +++++-
 4 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index b182e411a79c5..a0f77fdcfe02f 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -979,7 +979,7 @@
   },
   "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH" : {
     "message" : [
-      "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr>."
+      "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr> when use DISTINCT."
     ]
   },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [

diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index c0c33b4013453..eaa1ae90d8e45 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -615,7 +615,7 @@ The operation `` is not allowed on the ``: ``
 
 SQLSTATE: none assigned
 
-The function `<functionName>` arguments `<functionExpr>` should match the order by expression `<orderExpr>`.
+The function `<functionName>` arguments `<functionExpr>` should match the order by expression `<orderExpr>` when use DISTINCT.
### GENERATED_COLUMN_WITH_DEFAULT_VALUE diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bf61e2f1d0f40..b2cfe681e3863 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2209,22 +2209,22 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitListAgg(ctx: ListAggContext): AnyRef = { val column = expression(ctx.aggEpxr) val sortOrder = visitSortItem(ctx.sortItem) - if (!column.semanticEquals(sortOrder.child)) { + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + if (!column.semanticEquals(sortOrder.child) && isDistinct) { throw QueryCompilationErrors.functionAndOrderExpressionMismatchError("LISTAGG", column, sortOrder.child) } val listAgg = if (ctx.delimiter != null) { sortOrder.direction match { - case Ascending => ListAgg(sortOrder.child, Literal(ctx.delimiter.getText)) - case Descending => ListAgg(sortOrder.child, Literal(ctx.delimiter.getText), true) + case Ascending => ListAgg(column, Literal(ctx.delimiter.getText)) + case Descending => ListAgg(column, Literal(ctx.delimiter.getText), true) } } else { sortOrder.direction match { - case Ascending => ListAgg(sortOrder.child) - case Descending => ListAgg(sortOrder.child, Literal(","), true) + case Ascending => ListAgg(column) + case Descending => ListAgg(column, Literal(","), true) } } - val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) val aggregateExpression = listAgg.toAggregateExpression(isDistinct) ctx.windowSpec match { case spec: WindowRefContext => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6a6c04dace2e9..04e4b9055a02a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -191,9 +191,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark "OVER (PARTITION BY b) FROM df"), Row("a") :: Row("b,a") :: Row("b,a") :: Row("b") :: Row(null) :: Nil) + checkAnswer( + sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df"), + Row("b,c,c,d") :: Nil) + checkError( exception = intercept[AnalysisException] { - sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df") + sql("SELECT LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df") }, errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", parameters = Map( From ce093b5a5da16215fe996c26fc041ab200b69bbc Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Wed, 11 Oct 2023 21:29:21 +0800 Subject: [PATCH 10/58] update --- .../sql/catalyst/parser/SqlBaseParser.g4 | 2 +- .../expressions/aggregate/ListAgg.scala | 109 ------------------ .../expressions/aggregate/collect.scala | 99 ++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 10 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +- 5 files changed, 113 insertions(+), 117 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 8ddc849bfdcfe..d0625dbc6930e 100644 --- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -989,7 +989,7 @@ primaryExpression WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? ( OVER windowSpec)? #percentile | LISTAGG LEFT_PAREN setQuantifier? aggEpxr=expression (COMMA delimiter=stringLit)? RIGHT_PAREN - WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN ( OVER windowSpec)? #listAgg + WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN (OVER windowSpec)? #listAgg ; literalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala deleted file mode 100644 index f24b4ad36baa7..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ListAgg.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.catalyst.types.PhysicalDataType -import org.apache.spark.sql.types.{DataType, StringType} -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.collection.OpenHashMap - -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the concatenated input values," + - " separated by the delimiter string.", - examples = """ - Examples: - > SELECT _FUNC_(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col); - a,b,c - > SELECT _FUNC_(col) FROM VALUES (NULL), ('a'), ('b') AS tab(col); - a,b - > SELECT _FUNC_(col) FROM VALUES ('a'), ('a') AS tab(col); - a,a - > SELECT _FUNC_(DISTINCT col) FROM VALUES ('a'), ('a'), ('b') AS tab(col); - a,b - > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col); - a|b - > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); - NULL - """, - group = "agg_funcs", - since = "4.0.0") -case class ListAgg( - child: Expression, - delimiter: Expression = Literal.create(",", StringType), - reverse: Boolean = false, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer - with UnaryLike[Expression] { - - def this(child: Expression) = this(child, Literal.create(",", StringType), false, 0, 0) - def this(child: Expression, delimiter: Expression) = this(child, delimiter, false, 0, 0) - - override def update( - buffer: OpenHashMap[AnyRef, Long], - input: InternalRow): OpenHashMap[AnyRef, Long] = { - val value = child.eval(input) - if (value != null) { - val key = InternalRow.copyValue(value) - buffer.changeValue(key.asInstanceOf[AnyRef], 1L, _ + 1L) - } - buffer - } - - override def merge( - buffer: OpenHashMap[AnyRef, Long], - input: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = { - input.foreach { case (key, count) => - buffer.changeValue(key, count, _ + count) - } - buffer - } - - override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { - if (buffer.nonEmpty) { - val ordering = PhysicalDataType.ordering(child.dataType) - val sortedCounts = if (reverse) { - buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]].reverse) - } else { - buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]]) - } - UTF8String.fromString(sortedCounts.map(kc => { - List.fill(kc._2.toInt)(kc._1.toString).mkString(delimiter.eval() - .asInstanceOf[UTF8String].toString) - }).mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) - } else { - null - } - } - - override def withNewMutableAggBufferOffset( - newMutableAggBufferOffset: Int) : ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def nullable: Boolean = true - - override def dataType: DataType = StringType - - override protected def withNewChildInternal(newChild: Expression): Expression = - copy(child = newChild) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 7bbc930ceab59..48c3a55e821fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -19,15 +19,18 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import scala.collection.generic.Growable import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.BoundedPriorityQueue /** @@ -245,3 +248,99 @@ case class CollectTopK( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CollectTopK = copy(inputAggBufferOffset = newInputAggBufferOffset) } + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the concatenated input values," + + " separated by the delimiter string.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col); + a,b,c + > SELECT _FUNC_(col) FROM VALUES (NULL), ('a'), ('b') AS tab(col); + a,b + > SELECT _FUNC_(col) FROM VALUES ('a'), ('a') AS tab(col); + a,a + > SELECT _FUNC_(DISTINCT col) FROM VALUES ('a'), ('a'), ('b') AS tab(col); + a,b + > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col); + a|b + > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); + NULL + """, + group = "agg_funcs", + since = "4.0.0") +case class ListAgg( + child: Expression, + delimiter: Expression = Literal.create(",", StringType), + orderExpression: Option[Expression] = Option.empty, + reverse: Boolean = false, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] { + + def this(child: Expression) = + this(child, Literal.create(",", StringType), Option.empty, false, 0, 0) + def this(child: Expression, delimiter: Expression) = + this(child, delimiter, Option.empty, false, 0, 0) + + override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) + + override protected lazy val bufferElementType: DataType = { + if (orderExpression.isDefined) { + ArrayType(StructType(Seq( + StructField("value", child.dataType), + StructField("sortOrder", orderExpression.get.dataType))), containsNull = false) + } else { + child.dataType + } + } + + override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { + if (buffer.nonEmpty) { + val sortedCounts = if (orderExpression.isDefined) { + val ordering = PhysicalDataType.ordering(orderExpression.get.dataType) + if (reverse) { + buffer.asInstanceOf[mutable.ArrayBuffer[(Any, Any)]].toSeq.sortBy(_._2)(ordering + .asInstanceOf[Ordering[Any]].reverse).map(_._1) + } else { + buffer.asInstanceOf[mutable.ArrayBuffer[(Any, Any)]].toSeq.sortBy(_._2)(ordering + .asInstanceOf[Ordering[Any]]).map(_._1) + } + } else { + buffer.toSeq + } + UTF8String.fromString(sortedCounts.map(_.toString) + .mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) + } else { + UTF8String.fromString("") + } + } + + override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { + val value = child.eval(input) + if (value != null) { + if (orderExpression.isDefined) { + buffer += 
((convertToBufferElement(value), + convertToBufferElement(orderExpression.get.eval(input)))) + } else { + buffer += convertToBufferElement(value) + } + } + buffer + } + + override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int) : ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b2cfe681e3863..0d5cd7e6ba0e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2216,13 +2216,15 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } val listAgg = if (ctx.delimiter != null) { sortOrder.direction match { - case Ascending => ListAgg(column, Literal(ctx.delimiter.getText)) - case Descending => ListAgg(column, Literal(ctx.delimiter.getText), true) + case Ascending => ListAgg(column, Literal(ctx.delimiter.getText), Some(sortOrder.child), + false) + case Descending => ListAgg(column, Literal(ctx.delimiter.getText), Some(sortOrder.child), + true) } } else { sortOrder.direction match { - case Ascending => ListAgg(column) - case Descending => ListAgg(column, Literal(","), true) + case Ascending => ListAgg(column, Literal(","), Some(sortOrder.child), false) + case Descending => ListAgg(column, Literal(","), Some(sortOrder.child), true) } } val aggregateExpression = listAgg.toAggregateExpression(isDistinct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 04e4b9055a02a..973893d7a4f37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -164,11 +164,15 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark .createOrReplaceTempView("df") checkAnswer( sql("select listagg(b) from df group by a"), - Row("b,c") :: Row("c,d") :: Row(null) :: Nil) + Row("") :: Row("b,c") :: Row("c,d") :: Nil) + + checkAnswer( + sql("select listagg(b) from df where 1 != 1"), + Row("") :: Nil) checkAnswer( sql("select listagg(b, '|') from df group by a"), - Row("b|c") :: Row("c|d") :: Row(null) :: Nil) + Row("b|c") :: Row("c|d") :: Row("") :: Nil) checkAnswer( sql("SELECT LISTAGG(a) FROM df"), @@ -189,7 +193,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer( sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY a DESC) " + "OVER (PARTITION BY b) FROM df"), - Row("a") :: Row("b,a") :: Row("b,a") :: Row("b") :: Row(null) :: Nil) + Row("a") :: Row("b,a") :: Row("b,a") :: Row("b") :: Row("") :: Nil) checkAnswer( sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df"), From 19fdafc0b84ca9fad6aaf774b1b928c816e2bd04 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Mon, 23 Oct 2023 11:34:41 +0800 Subject: [PATCH 11/58] update --- 
.../expressions/aggregate/collect.scala | 71 +++++++++---------- .../sql/catalyst/parser/AstBuilder.scala | 8 +-- .../org/apache/spark/sql/functions.scala | 10 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 4 files changed, 45 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index fa93b3cec6664..bea3d3036e8ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.QueryErrorsBase @@ -39,8 +39,7 @@ import org.apache.spark.util.BoundedPriorityQueue * We have to store all the collected elements in memory, and so notice that too many elements * can cause GC paused and eventually OutOfMemory Errors. */ -abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] - with UnaryLike[Expression] { +abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] { val child: Expression @@ -105,7 +104,8 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper case class CollectList( child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] { + inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] + with UnaryLike[Expression] { def this(child: Expression) = this(child, 0, 0) @@ -151,7 +151,7 @@ case class CollectSet( child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends Collect[mutable.HashSet[Any]] with QueryErrorsBase { + extends Collect[mutable.HashSet[Any]] with QueryErrorsBase with UnaryLike[Expression] { def this(child: Expression) = this(child, 0, 0) @@ -216,7 +216,8 @@ case class CollectTopK( num: Int, reverse: Boolean = false, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends Collect[BoundedPriorityQueue[Any]] { + inputAggBufferOffset: Int = 0) extends Collect[BoundedPriorityQueue[Any]] + with UnaryLike[Expression] { assert(num > 0) def this(child: Expression, num: Int) = this(child, num, false, 0, 0) @@ -272,43 +273,38 @@ case class CollectTopK( case class ListAgg( child: Expression, delimiter: Expression = Literal.create(",", StringType), - orderExpression: Option[Expression] = Option.empty, + orderExpression: Expression, reverse: Boolean = false, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] { + inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] + with BinaryLike[Expression] { def this(child: Expression) = - this(child, Literal.create(",", StringType), Option.empty, false, 0, 0) + this(child, Literal.create(",", StringType), child, false, 0, 0) def this(child: Expression, 
delimiter: Expression) = - this(child, delimiter, Option.empty, false, 0, 0) + this(child, delimiter, child, false, 0, 0) override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) override protected lazy val bufferElementType: DataType = { - if (orderExpression.isDefined) { - ArrayType(StructType(Seq( - StructField("value", child.dataType), - StructField("sortOrder", orderExpression.get.dataType))), containsNull = false) - } else { - child.dataType - } + StructType(Seq( + StructField("value", child.dataType), + StructField("sortOrder", orderExpression.dataType))) } override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { if (buffer.nonEmpty) { - val sortedCounts = if (orderExpression.isDefined) { - val ordering = PhysicalDataType.ordering(orderExpression.get.dataType) - if (reverse) { - buffer.asInstanceOf[mutable.ArrayBuffer[(Any, Any)]].toSeq.sortBy(_._2)(ordering - .asInstanceOf[Ordering[Any]].reverse).map(_._1) - } else { - buffer.asInstanceOf[mutable.ArrayBuffer[(Any, Any)]].toSeq.sortBy(_._2)(ordering - .asInstanceOf[Ordering[Any]]).map(_._1) - } + val ordering = PhysicalDataType.ordering(orderExpression.dataType) + val sorted = if (reverse) { + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, + orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]].reverse).map(_.get(0, + child.dataType)) } else { - buffer.toSeq + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, + orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]).map(_.get(0, + child.dataType)) } - UTF8String.fromString(sortedCounts.map(_.toString) + UTF8String.fromString(sorted.map(_.toString) .mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) } else { UTF8String.fromString("") @@ -318,12 +314,8 @@ case class ListAgg( override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) if (value != null) { - if (orderExpression.isDefined) { - buffer += ((convertToBufferElement(value), - convertToBufferElement(orderExpression.get.eval(input)))) - } else { - buffer += convertToBufferElement(value) - } + buffer += InternalRow.apply(convertToBufferElement(value), + convertToBufferElement(orderExpression.eval(input))) } buffer } @@ -341,6 +333,13 @@ case class ListAgg( override def dataType: DataType = StringType - override protected def withNewChildInternal(newChild: Expression): Expression = - copy(child = newChild) + override def left: Expression = child + + override def right: Expression = orderExpression + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = { + copy(child = newLeft, orderExpression = newRight) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9ba6644894d88..2891250e1e1b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2227,15 +2227,15 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } val listAgg = if (ctx.delimiter != null) { sortOrder.direction match { - case Ascending => ListAgg(column, Literal(ctx.delimiter.getText), Some(sortOrder.child), + case Ascending => ListAgg(column, Literal(ctx.delimiter.getText), sortOrder.child, false) - case Descending => ListAgg(column, 
Literal(ctx.delimiter.getText), Some(sortOrder.child), + case Descending => ListAgg(column, Literal(ctx.delimiter.getText), sortOrder.child, true) } } else { sortOrder.direction match { - case Ascending => ListAgg(column, Literal(","), Some(sortOrder.child), false) - case Descending => ListAgg(column, Literal(","), Some(sortOrder.child), true) + case Ascending => ListAgg(column, Literal(","), sortOrder.child, false) + case Descending => ListAgg(column, Literal(","), sortOrder.child, true) } } val aggregateExpression = listAgg.toAggregateExpression(isDistinct) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 00ede09965f02..75cfb0c2bafa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1142,7 +1142,7 @@ object functions { * @since 4.0.0 */ def listagg(e: Column): Column = withAggregateFunction { - ListAgg(e.expr) + ListAgg(e.expr, Literal(","), e.expr) } /** @@ -1152,7 +1152,7 @@ object functions { * @since 4.0.0 */ def listagg(e: Column, delimiter: Column): Column = withAggregateFunction { - ListAgg(e.expr, delimiter.expr) + ListAgg(e.expr, delimiter.expr, e.expr) } /** @@ -1161,8 +1161,8 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def listagg_distinct(e: Column): Column = withAggregateFunction(ListAgg(e.expr), - isDistinct = true) + def listagg_distinct(e: Column): Column = withAggregateFunction(ListAgg(e.expr, Literal(","), + e.expr), isDistinct = true) /** * Aggregate function: returns the concatenated input values, separated by the delimiter string. @@ -1171,7 +1171,7 @@ object functions { * @since 4.0.0 */ def listagg_distinct(e: Column, delimiter: Column): Column = withAggregateFunction( - ListAgg(e.expr, delimiter.expr), isDistinct = true) + ListAgg(e.expr, delimiter.expr, e.expr), isDistinct = true) /** * Aggregate function: alias for `var_samp`. 
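
An illustrative aside, not part of the patch series: a minimal sketch of the DataFrame-side behavior after the functions.scala rework above, assuming a local build of this branch. The data, object name, and expected outputs are assumptions for illustration. With no explicit order, ListAgg falls back to ordering by the aggregated column itself:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{listagg, listagg_distinct, lit}

object ListAggSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("listagg-sketch").getOrCreate()
    import spark.implicits._
    val df = Seq("b", "a", "b").toDF("x")
    // orderExpression defaults to the value column, so the output is sorted ascending.
    df.select(listagg($"x", lit("|"))).show()          // expected: a|b|b
    df.select(listagg_distinct($"x", lit("|"))).show() // expected: a|b
    spark.stop()
  }
}
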
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 3438b1fcd56dc..56e0a10214fc3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -197,7 +197,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
 
     checkAnswer(
       sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df"),
-      Row("b,c,c,d") :: Nil)
+      Row("a,a,b,b") :: Nil)
 
     checkError(
       exception = intercept[AnalysisException] {

From 9ae3872f3f764e2b4bc46ad3b51254da1b518bc7 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Mon, 23 Oct 2023 11:36:42 +0800
Subject: [PATCH 12/58] update

---
 .../spark/sql/catalyst/expressions/aggregate/collect.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index bea3d3036e8ba..3c7d82a7da2cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -329,7 +329,7 @@ case class ListAgg(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def nullable: Boolean = true
+  override def nullable: Boolean = false
 
   override def dataType: DataType = StringType
 

From 8e3aae398b691d08348e70082c8f48d037fbac94 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Mon, 23 Oct 2023 14:01:41 +0800
Subject: [PATCH 13/58] update

---
 common/utils/src/main/resources/error/error-classes.json      | 3 ++-
 docs/sql-error-conditions.md                                   | 4 +---
 .../spark/sql/catalyst/expressions/aggregate/collect.scala     | 3 ++-
 .../src/test/resources/sql-functions/sql-expression-schema.md  | 2 +-
 .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala   | 2 +-
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index ae693d594ca8d..0266729dcf116 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1060,7 +1060,8 @@
   "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH" : {
     "message" : [
       "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr> when use DISTINCT."
-    ]
+    ],
+    "sqlState": "42822"
   },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [

diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 131a4d634ca9f..416ee49c7bfed 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -648,7 +648,7 @@ The operation `` is not allowed on the ``: ``
 
 ### FUNCTION_AND_ORDER_EXPRESSION_MISMATCH
 
-SQLSTATE: none assigned
+[SQLSTATE: 42822](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
 
 The function `<functionName>` arguments `<functionExpr>` should match the order by expression `<orderExpr>` when use DISTINCT.
 
@@ -2316,5 +2316,3 @@ The operation `` requires a ``. But `` is a
 
 The `` requires `` parameters but the actual number is ``.
For more details see [WRONG_NUM_ARGS](sql-error-conditions-wrong-num-args-error-class.html) - - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 3c7d82a7da2cc..8e99c714e0e72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -266,7 +266,7 @@ case class CollectTopK( > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col); a|b > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); - NULL + "" """, group = "agg_funcs", since = "4.0.0") @@ -285,6 +285,7 @@ case class ListAgg( this(child, delimiter, child, false, 0, 0) override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) + override def defaultResult: Option[Literal] = Option(Literal.create("", StringType)) override protected lazy val bufferElementType: DataType = { StructType(Seq( diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 9c77b0a1b8e7f..9986d4c801fb0 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -401,7 +401,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last_value | SELECT last_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | listagg | SELECT listagg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | listagg | SELECT listagg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Max | max | SELECT max(col) FROM VALUES (10), (50), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.MaxBy | max_by | SELECT max_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Median | median | SELECT median(col) FROM VALUES (0), (10) AS tab(col) | struct | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7aa4fb052e062..4c3f8d5fe9174 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -627,7 +627,7 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer( df3.select(listagg_distinct($"a"), listagg($"a"), listagg_distinct($"b"), listagg($"b"), listagg($"c")), - Seq(Row("a", "a,a", "b", "b,b", null)) + Seq(Row("a", "a,a", "b", "b,b", "")) ) // custom delimiter From c0f14962bbdbf3c7ce69c60730c67dc5f746f575 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Mon, 23 Oct 2023 17:25:59 +0800 Subject: [PATCH 14/58] update --- common/utils/src/main/resources/error/error-classes.json | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 0266729dcf116..79c06d4a485cf 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1061,7 +1061,7 @@
     "message" : [
       "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr> when use DISTINCT."
     ],
-    "sqlState": "42822"
+    "sqlState" : "42822"
   },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [

From 2fcd3732bd2c4662f100ad5a86f8ccb63bbb68a4 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Mon, 23 Oct 2023 20:35:11 +0800
Subject: [PATCH 15/58] update

---
 .../spark/sql/catalyst/expressions/aggregate/collect.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 8e99c714e0e72..9a560354eca13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -266,7 +266,7 @@ case class CollectTopK(
     > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col);
      a|b
     > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
-     ""
+
   """,
   group = "agg_funcs",
   since = "4.0.0")

From bd088a4e7125a06945f5799c6bc256ec025c1e51 Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Tue, 24 Oct 2023 17:02:43 +0800
Subject: [PATCH 16/58] Update
 sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Co-authored-by: Jiaan Geng

---
 .../org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 5662025a94880..3f4b0f6a324b3 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -989,7 +989,7 @@ primaryExpression
         WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN
         (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? ( OVER windowSpec)? #percentile
     | LISTAGG LEFT_PAREN setQuantifier? aggEpxr=expression (COMMA delimiter=stringLit)? RIGHT_PAREN
-        WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN (OVER windowSpec)? #listAgg
+        WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN (OVER windowSpec)?
#listAgg ; literalType From dd0bfafaba170419fbcd97358315ac2e30e41fea Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Tue, 24 Oct 2023 17:55:47 +0800 Subject: [PATCH 17/58] update --- .../expressions/aggregate/collect.scala | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 9a560354eca13..0679ded7f13e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -284,26 +284,40 @@ case class ListAgg( def this(child: Expression, delimiter: Expression) = this(child, delimiter, child, false, 0, 0) + private lazy val sameExpression = orderExpression.semanticEquals(child) + override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) override def defaultResult: Option[Literal] = Option(Literal.create("", StringType)) override protected lazy val bufferElementType: DataType = { - StructType(Seq( - StructField("value", child.dataType), - StructField("sortOrder", orderExpression.dataType))) + if (sameExpression) { + child.dataType + } else { + StructType(Seq( + StructField("value", child.dataType), + StructField("sortOrder", orderExpression.dataType))) + } } override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { if (buffer.nonEmpty) { val ordering = PhysicalDataType.ordering(orderExpression.dataType) - val sorted = if (reverse) { - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, - orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]].reverse).map(_.get(0, - child.dataType)) + val sorted = if (sameExpression) { + if (reverse) { + buffer.toSeq.sorted(ordering.reverse) + } else { + buffer.toSeq.sorted(ordering) + } } else { - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, - orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]).map(_.get(0, - child.dataType)) + if (reverse) { + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, + orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]].reverse).map(_.get(0, + child.dataType)) + } else { + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, + orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]).map(_.get(0, + child.dataType)) + } } UTF8String.fromString(sorted.map(_.toString) .mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) @@ -315,8 +329,13 @@ case class ListAgg( override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) if (value != null) { - buffer += InternalRow.apply(convertToBufferElement(value), - convertToBufferElement(orderExpression.eval(input))) + val v = if (sameExpression) { + convertToBufferElement(value) + } else { + InternalRow.apply(convertToBufferElement(value), + convertToBufferElement(orderExpression.eval(input))) + } + buffer += v } buffer } From 6f54ab0b6c26969a403f4ead235b6498941c8f6d Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Tue, 24 Oct 2023 18:19:36 +0800 Subject: [PATCH 18/58] update --- .../expressions/aggregate/collect.scala | 20 +++++++++---------- .../sql/catalyst/parser/AstBuilder.scala | 16 +++------------ 2 files changed, 13 insertions(+), 23 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 0679ded7f13e3..68bac13d790af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -284,6 +284,16 @@ case class ListAgg( def this(child: Expression, delimiter: Expression) = this(child, delimiter, child, false, 0, 0) + override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty + + override def nullable: Boolean = false + + override def dataType: DataType = StringType + + override def left: Expression = child + + override def right: Expression = orderExpression + private lazy val sameExpression = orderExpression.semanticEquals(child) override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) @@ -340,8 +350,6 @@ case class ListAgg( buffer } - override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty - override def withNewMutableAggBufferOffset( newMutableAggBufferOffset: Int) : ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -349,14 +357,6 @@ case class ListAgg( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - override def nullable: Boolean = false - - override def dataType: DataType = StringType - - override def left: Expression = child - - override def right: Expression = orderExpression - override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2891250e1e1b2..3ef3acfde0b57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2225,19 +2225,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { throw QueryCompilationErrors.functionAndOrderExpressionMismatchError("LISTAGG", column, sortOrder.child) } - val listAgg = if (ctx.delimiter != null) { - sortOrder.direction match { - case Ascending => ListAgg(column, Literal(ctx.delimiter.getText), sortOrder.child, - false) - case Descending => ListAgg(column, Literal(ctx.delimiter.getText), sortOrder.child, - true) - } - } else { - sortOrder.direction match { - case Ascending => ListAgg(column, Literal(","), sortOrder.child, false) - case Descending => ListAgg(column, Literal(","), sortOrder.child, true) - } - } + val delimiter = if (ctx.delimiter != null) Literal(ctx.delimiter.getText) else Literal(",") + val reverse = sortOrder.direction == Descending + val listAgg = ListAgg(column, delimiter, sortOrder.child, reverse) val aggregateExpression = listAgg.toAggregateExpression(isDistinct) ctx.windowSpec match { case spec: WindowRefContext => From b0dc0175ed099b69fd516ad7f73ac88fc1a8acae Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Tue, 24 Oct 2023 18:59:43 +0800 Subject: [PATCH 19/58] update --- .../expressions/aggregate/collect.scala | 20 +++++++++---------- .../sql/errors/QueryCompilationErrors.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 3 files changed, 12 
insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 68bac13d790af..441906e19e957 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -272,7 +272,7 @@ case class CollectTopK( since = "4.0.0") case class ListAgg( child: Expression, - delimiter: Expression = Literal.create(",", StringType), + delimiter: Expression, orderExpression: Expression, reverse: Boolean = false, mutableAggBufferOffset: Int = 0, @@ -284,8 +284,6 @@ case class ListAgg( def this(child: Expression, delimiter: Expression) = this(child, delimiter, child, false, 0, 0) - override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty - override def nullable: Boolean = false override def dataType: DataType = StringType @@ -294,6 +292,15 @@ case class ListAgg( override def right: Expression = orderExpression + override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + private lazy val sameExpression = orderExpression.semanticEquals(child) override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) @@ -350,13 +357,6 @@ case class ListAgg( buffer } - override def withNewMutableAggBufferOffset( - newMutableAggBufferOffset: Int) : ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 1e0f52c12099a..cd05f8d829961 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -927,7 +927,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat new AnalysisException( errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", messageParameters = Map( - "functionName" -> toSQLStmt(functionName), + "functionName" -> toSQLId(functionName), "functionExpr" -> toSQLExpr(functionExpr), "orderExpr" -> toSQLExpr(orderExpr))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 269dbe8382275..2a20478d9e177 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -205,7 +205,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark }, errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", parameters = Map( - "functionName" -> "LISTAGG", + "functionName" -> 
"`LISTAGG`", "functionExpr" -> "\"a\"", "orderExpr" -> "\"b\"")) From 87999f4c98799c5365efcb0cf0e83fd8f0aa3dbc Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Tue, 24 Oct 2023 19:59:35 +0800 Subject: [PATCH 20/58] update --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2a20478d9e177..730add64d7882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -199,6 +199,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df"), Row("a,a,b,b") :: Nil) + checkAnswer( + sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b DESC) FROM df"), + Row("b,a,b,a") :: Nil) + checkError( exception = intercept[AnalysisException] { sql("SELECT LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df") From a0f0e5da554e5c4f048f351b8e5655015c1bdee0 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Tue, 24 Oct 2023 20:20:32 +0800 Subject: [PATCH 21/58] update --- .../expressions/aggregate/collect.scala | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 441906e19e957..a93b89f2b8104 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -319,23 +319,21 @@ case class ListAgg( override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { if (buffer.nonEmpty) { val ordering = PhysicalDataType.ordering(orderExpression.dataType) - val sorted = if (sameExpression) { - if (reverse) { - buffer.toSeq.sorted(ordering.reverse) - } else { - buffer.toSeq.sorted(ordering) - } - } else { - if (reverse) { - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, + lazy val sortFunc = (sameExpression, reverse) match { + case (true, true) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.sorted(ordering.reverse) + case (true, false) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.sorted(ordering) + case (false, true) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]].reverse).map(_.get(0, child.dataType)) - } else { - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].toSeq.sortBy(_.get(1, + case (false, false) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]).map(_.get(0, child.dataType)) - } } + val sorted = sortFunc(buffer) UTF8String.fromString(sorted.map(_.toString) .mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) } else { From 8ba2466be04fe692e74523e231d040232e8f55c8 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Wed, 25 Oct 2023 10:09:05 +0800 Subject: [PATCH 22/58] update --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/aggregate/collect.scala | 39 ++++++++++--------- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../org/apache/spark/sql/functions.scala | 10 
++--- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b47f52d665f14..9e3c8160e2923 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -478,7 +478,6 @@ object FunctionRegistry { expression[Percentile]("percentile"), expression[Median]("median"), expression[Skewness]("skewness"), - expression[ListAgg]("listagg"), expression[ApproximatePercentile]("percentile_approx"), expression[ApproximatePercentile]("approx_percentile", true), expression[HistogramNumeric]("histogram_numeric"), @@ -493,6 +492,7 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectList]("array_agg", true, Some("3.3.0")), expression[CollectSet]("collect_set"), + expression[ListAgg]("listagg"), expressionBuilder("count_min_sketch", CountMinSketchAggExpressionBuilder), expression[BoolAnd]("every", true), expression[BoolAnd]("bool_and"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index a93b89f2b8104..ec46d48d51842 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -272,17 +272,17 @@ case class CollectTopK( since = "4.0.0") case class ListAgg( child: Expression, - delimiter: Expression, orderExpression: Expression, + delimiter: Expression = Literal.create(",", StringType), reverse: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] with BinaryLike[Expression] { def this(child: Expression) = - this(child, Literal.create(",", StringType), child, false, 0, 0) + this(child, child, Literal.create(",", StringType), false, 0, 0) def this(child: Expression, delimiter: Expression) = - this(child, delimiter, child, false, 0, 0) + this(child, child, delimiter, false, 0, 0) override def nullable: Boolean = false @@ -316,23 +316,26 @@ case class ListAgg( } } + private lazy val sortFunc = { + val ordering = PhysicalDataType.ordering(orderExpression.dataType) + (sameExpression, reverse) match { + case (true, true) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.sorted(ordering.reverse) + case (true, false) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.sorted(ordering) + case (false, true) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, + orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]].reverse).map(_.get(0, + child.dataType)) + case (false, false) => (buffer: mutable.ArrayBuffer[Any]) => + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, + orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]).map(_.get(0, + child.dataType)) + } + } + override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { if (buffer.nonEmpty) { - val ordering = PhysicalDataType.ordering(orderExpression.dataType) - lazy val sortFunc = (sameExpression, reverse) match { - case (true, true) => (buffer: mutable.ArrayBuffer[Any]) => - buffer.sorted(ordering.reverse) - case (true, false) => 
(buffer: mutable.ArrayBuffer[Any]) => - buffer.sorted(ordering) - case (false, true) => (buffer: mutable.ArrayBuffer[Any]) => - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, - orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]].reverse).map(_.get(0, - child.dataType)) - case (false, false) => (buffer: mutable.ArrayBuffer[Any]) => - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, - orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]).map(_.get(0, - child.dataType)) - } val sorted = sortFunc(buffer) UTF8String.fromString(sorted.map(_.toString) .mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3ef3acfde0b57..b8a2deb06087c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2227,7 +2227,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } val delimiter = if (ctx.delimiter != null) Literal(ctx.delimiter.getText) else Literal(",") val reverse = sortOrder.direction == Descending - val listAgg = ListAgg(column, delimiter, sortOrder.child, reverse) + val listAgg = ListAgg(column, sortOrder.child, delimiter, reverse) val aggregateExpression = listAgg.toAggregateExpression(isDistinct) ctx.windowSpec match { case spec: WindowRefContext => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 75cfb0c2bafa2..409d022f63199 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1142,7 +1142,7 @@ object functions { * @since 4.0.0 */ def listagg(e: Column): Column = withAggregateFunction { - ListAgg(e.expr, Literal(","), e.expr) + ListAgg(e.expr, e.expr) } /** @@ -1152,7 +1152,7 @@ object functions { * @since 4.0.0 */ def listagg(e: Column, delimiter: Column): Column = withAggregateFunction { - ListAgg(e.expr, delimiter.expr, e.expr) + ListAgg(e.expr, e.expr, delimiter.expr) } /** @@ -1161,8 +1161,8 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def listagg_distinct(e: Column): Column = withAggregateFunction(ListAgg(e.expr, Literal(","), - e.expr), isDistinct = true) + def listagg_distinct(e: Column): Column = withAggregateFunction(ListAgg(e.expr, e.expr), + isDistinct = true) /** * Aggregate function: returns the concatenated input values, separated by the delimiter string. @@ -1171,7 +1171,7 @@ object functions { * @since 4.0.0 */ def listagg_distinct(e: Column, delimiter: Column): Column = withAggregateFunction( - ListAgg(e.expr, delimiter.expr, e.expr), isDistinct = true) + ListAgg(e.expr, e.expr, delimiter.expr), isDistinct = true) /** * Aggregate function: alias for `var_samp`. 
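A minimal usage sketch of the API shape at this point in the series, assuming an active SparkSession named spark, spark.implicits._ in scope, and a temp view t over a column v (all illustrative names, not from the patch; note the very next patch corrects how the SQL delimiter literal is unescaped):

    import org.apache.spark.sql.functions.{listagg, listagg_distinct, lit}

    // SQL form parsed by AstBuilder: optional delimiter plus optional WITHIN GROUP ordering
    spark.sql("SELECT listagg(v, '|') WITHIN GROUP (ORDER BY v DESC) FROM t").show()

    // DataFrame form through the entry points added in functions.scala
    val df = Seq(("x", "a"), ("x", "b")).toDF("k", "v")
    df.groupBy($"k").agg(listagg($"v"), listagg_distinct($"v", lit("|"))).show()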
From 885e812940bc41cb537d67c0ed6a621d3ab8a151 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Mon, 27 Nov 2023 11:41:26 +0800 Subject: [PATCH 23/58] update --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 ++- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a14fe0dbe25f1..5e1e80233111c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2246,7 +2246,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { throw QueryCompilationErrors.functionAndOrderExpressionMismatchError("LISTAGG", column, sortOrder.child) } - val delimiter = if (ctx.delimiter != null) Literal(ctx.delimiter.getText) else Literal(",") + val delimiter = if (ctx.delimiter != null) Literal(string(visitStringLit(ctx.delimiter))) + else Literal(",") val reverse = sortOrder.direction == Descending val listAgg = ListAgg(column, sortOrder.child, delimiter, reverse) val aggregateExpression = listAgg.toAggregateExpression(isDistinct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3b3659f0ce06f..31bb404d6eff3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -203,6 +203,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b DESC) FROM df"), Row("b,a,b,a") :: Nil) + checkAnswer( + sql("SELECT LISTAGG(a, '|') WITHIN GROUP (ORDER BY b DESC) FROM df"), + Row("b|a|b|a") :: Nil) + checkError( exception = intercept[AnalysisException] { sql("SELECT LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df") From 9e70cc57cd49f23ecc900922ef78960ddb501967 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 31 Oct 2024 11:02:59 +0100 Subject: [PATCH 24/58] [SPARK-42746] upgrade the old branch after merge --- .../CheckConnectJvmClientCompatibility.scala | 4 - docs/sql-ref-ansi-compliance.md | 1 - .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 - .../sql/catalyst/parser/SqlBaseParser.g4 | 2 - .../org/apache/spark/sql/functions.scala | 15 ++-- .../sql/catalyst/analysis/Analyzer.scala | 6 +- .../InverseDistributionFunction.scala | 26 ++++++ .../catalyst/expressions/aggregate/Mode.scala | 2 +- .../SupportsOrderingWithinGroup.scala | 24 ++--- .../expressions/aggregate/collect.scala | 88 +++++++++---------- .../expressions/aggregate/percentiles.scala | 4 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql-tests/results/ansi/keywords.sql.out | 1 - .../sql-tests/results/keywords.sql.out | 1 - .../org/apache/spark/sql/SQLQuerySuite.scala | 22 ++--- .../ThriftServerWithSparkContextSuite.scala | 2 +- 16 files changed, 99 insertions(+), 102 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 
96d3d172bd779..d176cb3d0a444 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -205,10 +205,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.functions.listagg"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.functions.listagg_distinct"), // KeyValueGroupedDataset ProblemFilters.exclude[Problem]( "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"), diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index e964e3069b52e..268f5b970d30d 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -580,7 +580,6 @@ Below is a list of all the keywords in Spark SQL. |LIMIT|non-reserved|non-reserved|non-reserved| |LINES|non-reserved|non-reserved|non-reserved| |LIST|non-reserved|non-reserved|non-reserved| -|LISTAGG|non-reserved|non-reserved|non-reserved| |LOAD|non-reserved|non-reserved|non-reserved| |LOCAL|non-reserved|non-reserved|reserved| |LOCATION|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 916817b39f61c..085e723d02bc0 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -295,7 +295,6 @@ ILIKE: 'ILIKE'; LIMIT: 'LIMIT'; LINES: 'LINES'; LIST: 'LIST'; -LISTAGG: 'LISTAGG'; LOAD: 'LOAD'; LOCAL: 'LOCAL'; LOCATION: 'LOCATION'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f6ca0823d9d79..4900c971966cc 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1661,7 +1661,6 @@ ansiNonReserved | LIMIT | LINES | LIST - | LISTAGG | LOAD | LOCAL | LOCATION @@ -2022,7 +2021,6 @@ nonReserved | LIMIT | LINES | LIST - | LISTAGG | LOAD | LOCAL | LOCATION diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index a6c81641d268c..502258fb5aed1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1153,9 +1153,7 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def listagg(e: Column): Column = withAggregateFunction { - ListAgg(e.expr, e.expr) - } + def listagg(e: Column): Column = Column.fn("listagg", e) /** * Aggregate function: returns the concatenated input values, separated by the delimiter string. 
@@ -1163,9 +1161,7 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def listagg(e: Column, delimiter: Column): Column = withAggregateFunction { - ListAgg(e.expr, e.expr, delimiter.expr) - } + def listagg(e: Column, delimiter: Column): Column = Column.fn("listagg", e, delimiter) /** * Aggregate function: returns the concatenated input values, separated by the delimiter string. @@ -1173,8 +1169,7 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def listagg_distinct(e: Column): Column = withAggregateFunction(ListAgg(e.expr, e.expr), - isDistinct = true) + def listagg_distinct(e: Column): Column = Column.fn("listagg", isDistinct = true, e) /** * Aggregate function: returns the concatenated input values, separated by the delimiter string. @@ -1182,8 +1177,8 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def listagg_distinct(e: Column, delimiter: Column): Column = withAggregateFunction( - ListAgg(e.expr, e.expr, delimiter.expr), isDistinct = true) + def listagg_distinct(e: Column, delimiter: Column): Column = + Column.fn("listagg", isDistinct = true, e, delimiter) /** * Aggregate function: alias for `var_samp`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6b64f493f4052..317f8a2fce079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2210,14 +2210,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor numArgs: Int, u: UnresolvedFunction): Expression = { func match { - case owg: SupportsOrderingWithinGroup if u.isDistinct => + case owg: InverseDistributionFunction if u.isDistinct => throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError( owg.prettyName) - case owg: SupportsOrderingWithinGroup + case owg: InverseDistributionFunction if !owg.orderingFilled && u.orderingWithinGroup.isEmpty => throw QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError( owg.prettyName) - case owg: SupportsOrderingWithinGroup + case owg: InverseDistributionFunction if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( owg.prettyName, 0, u.orderingWithinGroup.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala new file mode 100644 index 0000000000000..7e4c028f89a10 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+/**
+ * The trait used to set the [[SortOrder]] after inverse distribution functions are parsed.
+ * The order clause is mandatory for all implementations.
+ */
+trait InverseDistributionFunction
+  extends SupportsOrderingWithinGroup { self: AggregateFunction =>
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index 97add0b8e45bc..7af4d5668d719 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -37,7 +37,7 @@ case class Mode(
     inputAggBufferOffset: Int = 0,
     reverseOpt: Option[Boolean] = None)
   extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes
-  with SupportsOrderingWithinGroup with UnaryLike[Expression] {
+  with InverseDistributionFunction with UnaryLike[Expression] {
 
   def this(child: Expression) = this(child, 0, 0)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala
index 9c0502a2c1fcf..dfc28455d5b7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala
@@ -1,28 +1,14 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 
 /**
- * The trait used to set the [[SortOrder]] after inverse distribution functions parsed.
+ * The trait used to set the [[SortOrder]] on aggregate functions that support it.
+ * Ordering is optional by default.
*/ -trait SupportsOrderingWithinGroup { self: AggregateFunction => - def orderingFilled: Boolean = false +trait SupportsOrderingWithinGroup { def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction + /** Indicator that ordering was set */ + def orderingFilled: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 4293cdd8d6e7e..a0ec970f54162 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Growable +import scala.collection.mutable.{ArrayBuffer, Growable} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils, UnsafeRowUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} @@ -283,31 +282,23 @@ private[aggregate] object CollectTopK { > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col); a|b > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); - """, group = "agg_funcs", since = "4.0.0") case class ListAgg( child: Expression, - orderExpression: Expression, - delimiter: Expression = Literal.create(",", StringType), - reverse: Boolean = false, + delimiter: Expression = Literal.create(",", StringType),// TODO replace with null (empty string) + orderExpressions: Seq[SortOrder] = Nil, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] - with BinaryLike[Expression] { + with SupportsOrderingWithinGroup { def this(child: Expression) = - this(child, child, Literal.create(",", StringType), false, 0, 0) + this(child, Literal.create(",", StringType), Nil, 0, 0) def this(child: Expression, delimiter: Expression) = - this(child, child, delimiter, false, 0, 0) - - override def nullable: Boolean = false - - override def dataType: DataType = StringType - - override def left: Expression = child + this(child, delimiter, Nil, 0, 0) - override def right: Expression = orderExpression + override def dataType: DataType = StringType // TODO add binary override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty @@ -318,66 +309,75 @@ case class ListAgg( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - private lazy val sameExpression = orderExpression.semanticEquals(child) + private lazy val dontNeedSaveOrderValue = orderExpressions.isEmpty || + (orderExpressions.size == 1 && orderExpressions.head.semanticEquals(child)) override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) + // TODO make null override def defaultResult: Option[Literal] = Option(Literal.create("", 
StringType)) override protected lazy val bufferElementType: DataType = { - if (sameExpression) { + if (dontNeedSaveOrderValue) { child.dataType } else { StructType(Seq( StructField("value", child.dataType), - StructField("sortOrder", orderExpression.dataType))) + StructField("sortOrder", orderExpressions.head.dataType))) } } - private lazy val sortFunc = { - val ordering = PhysicalDataType.ordering(orderExpression.dataType) - (sameExpression, reverse) match { - case (true, true) => (buffer: mutable.ArrayBuffer[Any]) => - buffer.sorted(ordering.reverse) - case (true, false) => (buffer: mutable.ArrayBuffer[Any]) => - buffer.sorted(ordering) - case (false, true) => (buffer: mutable.ArrayBuffer[Any]) => - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, - orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]].reverse).map(_.get(0, - child.dataType)) - case (false, false) => (buffer: mutable.ArrayBuffer[Any]) => - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]].sortBy(_.get(1, - orderExpression.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]).map(_.get(0, - child.dataType)) + private[this] def sortBuffer(buffer: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { + if (!orderingFilled) { + return buffer + } + val ascendingOrdering = PhysicalDataType.ordering(orderExpressions.head.dataType) + val ordering = if (orderExpressions.head.direction == Ascending) ascendingOrdering + else ascendingOrdering.reverse + + if (dontNeedSaveOrderValue) { + buffer.sorted(ordering) + } else { + buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]] + .sortBy(_.get(1, orderExpressions.head.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]) + .map(_.get(0, child.dataType)) } } override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { if (buffer.nonEmpty) { - val sorted = sortFunc(buffer) + val sorted = sortBuffer(buffer) UTF8String.fromString(sorted.map(_.toString) .mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) } else { - UTF8String.fromString("") + UTF8String.fromString("") // TODO null } } override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) if (value != null) { - val v = if (sameExpression) { + val v = if (dontNeedSaveOrderValue) { convertToBufferElement(value) } else { InternalRow.apply(convertToBufferElement(value), - convertToBufferElement(orderExpression.eval(input))) + convertToBufferElement(orderExpressions.head.child.eval(input))) } buffer += v } buffer } - override protected def withNewChildrenInternal( - newLeft: Expression, - newRight: Expression): Expression = { - copy(child = newLeft, orderExpression = newRight) - } + override def orderingFilled: Boolean = orderExpressions.nonEmpty + override def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction = + copy(orderExpressions = orderingWithinGroup) + + override def children: Seq[Expression] = child +: delimiter +: orderExpressions.map(_.child) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy( + child = newChildren.head, + delimiter = newChildren(1), + orderExpressions = newChildren.drop(2).zip(orderExpressions) + .map { case (newExpr, oldSortOrder) => oldSortOrder.copy(child = newExpr) } + ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala index 89a6984b80852..25f37385ef97d 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala @@ -360,7 +360,7 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean extends AggregateFunction with RuntimeReplaceableAggregate with ImplicitCastInputTypes - with SupportsOrderingWithinGroup + with InverseDistributionFunction with BinaryLike[Expression] { private lazy val percentile = new Percentile(left, right, reverse) override lazy val replacement: Expression = percentile @@ -407,7 +407,7 @@ case class PercentileDisc( mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0, legacyCalculation: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_PERCENTILE_DISC_CALCULATION)) - extends PercentileBase with SupportsOrderingWithinGroup with BinaryLike[Expression] { + extends PercentileBase with InverseDistributionFunction with BinaryLike[Expression] { val frequencyExpression: Expression = Literal(1L) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f52ba2e9cf84c..caeb78d20e6a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, ClusterBySpec} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last, ListAgg} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index 94c5f398231b8..b2331ec4ab804 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -181,7 +181,6 @@ LIKE false LIMIT false LINES false LIST false -LISTAGG false LOAD false LOCAL false LOCATION false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 06c1c3f8c6a5a..a885525028623 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -181,7 +181,6 @@ LIKE false LIMIT false LINES false LIST false -LISTAGG false LOAD false LOCAL false LOCATION false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 04e8752dfe7bb..0c40842647be4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -210,21 +210,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("SELECT LISTAGG(a, '|') WITHIN GROUP (ORDER BY b DESC) FROM df"), Row("b|a|b|a") :: Nil) - checkError( - exception = intercept[AnalysisException] 
{ - sql("SELECT LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df") - }, - errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - parameters = Map( - "functionName" -> "`LISTAGG`", - "functionExpr" -> "\"a\"", - "orderExpr" -> "\"b\"")) - +// checkError( +// exception = intercept[AnalysisException] { +// sql("SELECT LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df") +// }, +// condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", +// parameters = Map( +// "functionName" -> "`LISTAGG`", +// "functionExpr" -> "\"a\"", +// "orderExpr" -> "\"b\"")) +// Seq((1, true), (2, false), (3, false)).toDF("a", "b").createOrReplaceTempView("df2") checkAnswer( sql("SELECT LISTAGG(a), LISTAGG(b) FROM df2"), - Row("1,2,3", "false,false,true") :: Nil) + Row("1,2,3", "true,false,false") :: Nil) } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 43ce5d70cdc38..71d81b06463f1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LISTAGG,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,R
ESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == "ADD,AFTER,AGGREGATE,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,Y
EAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From 338205607060b8ae75a1e31aed2c09f3a92846fa Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 31 Oct 2024 17:36:09 +0100 Subject: [PATCH 25/58] [SPARK-42746] add binary type support, type validation and set default return to null --- .../apache/spark/unsafe/types/ByteArray.java | 33 +++++++ .../spark/unsafe/array/ByteArraySuite.java | 54 ++++++++++++ .../expressions/aggregate/collect.scala | 84 ++++++++++++++---- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 88 ++++++++++++++++--- 5 files changed, 228 insertions(+), 33 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index aae47aa963201..e010e2dadf605 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -160,6 +160,39 @@ public static byte[] concat(byte[]... inputs) { return result; } + public static byte[] concatWS(byte[] delimiter, byte[]... inputs) { + // Compute the total length of the result + long totalLength = 0; + for (byte[] input : inputs) { + if (input != null) { + totalLength += input.length + delimiter.length; + } else { + return null; + } + } + if (totalLength > 0) totalLength -= delimiter.length; + // Allocate a new byte array, and copy the inputs one by one into it + final byte[] result = new byte[Ints.checkedCast(totalLength)]; + int offset = 0; + for (int i = 0; i < inputs.length; i++) { + byte[] input = inputs[i]; + int len = input.length; + Platform.copyMemory( + input, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + if(i < inputs.length - 1) { + Platform.copyMemory( + delimiter, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + offset, + delimiter.length); + offset += delimiter.length; + } + } + return result; + } + // Helper method for implementing `lpad` and `rpad`. // If the padding pattern's length is 0, return the first `len` bytes of the input byte // sequence if it is longer than `len` bytes, or a copy of the byte sequence, otherwise. 
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java index aff619175ff7b..e86d52fc90e02 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java @@ -67,4 +67,58 @@ public void testCompareBinary() { byte[] y4 = new byte[]{(byte) 100, (byte) 200}; Assertions.assertEquals(0, ByteArray.compareBinary(x4, y4)); } + + @Test + public void testConcat() { + byte[] x1 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + byte[] y1 = new byte[]{(byte) 4, (byte) 5, (byte) 6}; + byte[] result1 = ByteArray.concat(x1, y1); + byte[] expected1 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 4, (byte) 5, (byte) 6}; + Assertions.assertArrayEquals(expected1, result1); + + byte[] x2 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + byte[] y2 = new byte[0]; + byte[] result2 = ByteArray.concat(x2, y2); + byte[] expected2 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + Assertions.assertArrayEquals(expected2, result2); + + byte[] x3 = new byte[0]; + byte[] y3 = new byte[]{(byte) 4, (byte) 5, (byte) 6}; + byte[] result3 = ByteArray.concat(x3, y3); + byte[] expected3 = new byte[]{(byte) 4, (byte) 5, (byte) 6}; + Assertions.assertArrayEquals(expected3, result3); + + byte[] x4 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + byte[] y4 = null; + byte[] result4 = ByteArray.concat(x4, y4); + Assertions.assertArrayEquals(null, result4); + } + + @Test + public void testConcatWS() { + byte[] separator = new byte[]{(byte) 42}; // Separator byte array + + byte[] x1 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + byte[] y1 = new byte[]{(byte) 4, (byte) 5, (byte) 6}; + byte[] result1 = ByteArray.concatWS(separator, x1, y1); + byte[] expected1 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 42, (byte) 4, (byte) 5, (byte) 6}; + Assertions.assertArrayEquals(expected1, result1); + + byte[] x2 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + byte[] y2 = new byte[0]; + byte[] result2 = ByteArray.concatWS(separator, x2, y2); + byte[] expected2 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 42}; + Assertions.assertArrayEquals(expected2, result2); + + byte[] x3 = new byte[0]; + byte[] y3 = new byte[]{(byte) 4, (byte) 5, (byte) 6}; + byte[] result3 = ByteArray.concatWS(separator, x3, y3); + byte[] expected3 = new byte[]{(byte) 42, (byte) 4, (byte) 5, (byte) 6}; + Assertions.assertArrayEquals(expected3, result3); + + byte[] x4 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; + byte[] y4 = null; + byte[] result4 = ByteArray.concatWS(separator, x4, y4); + Assertions.assertArrayEquals(null, result4); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index a0ec970f54162..def09142067f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils, UnsafeRowUtils} +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr import 
org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType} +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} import org.apache.spark.util.BoundedPriorityQueue /** @@ -267,38 +270,42 @@ private[aggregate] object CollectTopK { } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the concatenated input values," + + usage = "_FUNC_(expr) - Returns the concatenated input non-null values," + " separated by the delimiter string.", examples = """ Examples: > SELECT _FUNC_(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col); - a,b,c - > SELECT _FUNC_(col) FROM VALUES (NULL), ('a'), ('b') AS tab(col); - a,b + abc + > SELECT _FUNC_(col) FROM VALUES ('a'), (NULL), ('b') AS tab(col); + ab > SELECT _FUNC_(col) FROM VALUES ('a'), ('a') AS tab(col); - a,a + aa > SELECT _FUNC_(DISTINCT col) FROM VALUES ('a'), ('a'), ('b') AS tab(col); - a,b - > SELECT _FUNC_(col, '|') FROM VALUES ('a'), ('b') AS tab(col); - a|b + ab + > SELECT _FUNC_(col, ', ') FROM VALUES ('a'), ('b'), ('c') AS tab(col); + a, b, c > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); + NULL """, group = "agg_funcs", - since = "4.0.0") + since = "4.0.0") // TODO change case class ListAgg( child: Expression, delimiter: Expression = Literal.create(",", StringType),// TODO replace with null (empty string) orderExpressions: Seq[SortOrder] = Nil, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] - with SupportsOrderingWithinGroup { + with SupportsOrderingWithinGroup + with ImplicitCastInputTypes { def this(child: Expression) = this(child, Literal.create(",", StringType), Nil, 0, 0) def this(child: Expression, delimiter: Expression) = this(child, delimiter, Nil, 0, 0) - override def dataType: DataType = StringType // TODO add binary + override def dataType: DataType = child.dataType + + override def nullable: Boolean = true override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty @@ -313,8 +320,8 @@ case class ListAgg( (orderExpressions.size == 1 && orderExpressions.head.semanticEquals(child)) override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) - // TODO make null - override def defaultResult: Option[Literal] = Option(Literal.create("", StringType)) + + override def defaultResult: Option[Literal] = Option(Literal.create(null, dataType)) override protected lazy val bufferElementType: DataType = { if (dontNeedSaveOrderValue) { @@ -326,6 +333,36 @@ case class ListAgg( } } + override def inputTypes: Seq[AbstractDataType] = + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), + BinaryType + ) +: + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), + BinaryType + ) +: + orderExpressions.map(_ => AnyDataType) + + + override def checkInputDataTypes(): TypeCheckResult = { + val matchInputTypes = super.checkInputDataTypes() + if (matchInputTypes.isFailure) { + return matchInputTypes + } + if (!delimiter.foldable) { + return DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> toSQLId("delimiter"), + "inputType" -> toSQLType(delimiter.dataType), + "inputExpr" -> toSQLExpr(delimiter) + ) + ) + } + TypeUtils.checkForSameTypeInputExpr(child.dataType :: delimiter.dataType :: Nil, 
prettyName) + } + private[this] def sortBuffer(buffer: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { if (!orderingFilled) { return buffer @@ -343,13 +380,24 @@ case class ListAgg( } } + private[this] def concatWSInternal(buffer: mutable.ArrayBuffer[Any]): Any = { + val delimiterValue = delimiter.eval() + dataType match { + case BinaryType => + val inputs = buffer.map(_.asInstanceOf[Array[Byte]]) + ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*) + case _: StringType => + val inputs = buffer.map(_.asInstanceOf[UTF8String]) + UTF8String.fromString(inputs.mkString(delimiterValue.toString)) + } + } + override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { if (buffer.nonEmpty) { - val sorted = sortBuffer(buffer) - UTF8String.fromString(sorted.map(_.toString) - .mkString(delimiter.eval().asInstanceOf[UTF8String].toString)) + val sortedBufferWithoutNulls = sortBuffer(buffer).filter(_ != null) + concatWSInternal(sortedBufferWithoutNulls) } else { - UTF8String.fromString("") // TODO null + null } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5b39577acefc8..d52c4e3a8fa8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -644,7 +644,7 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer( df3.select(listagg_distinct($"a"), listagg($"a"), listagg_distinct($"b"), listagg($"b"), listagg($"c")), - Seq(Row("a", "a,a", "b", "b,b", "")) + Seq(Row("a", "a,a", "b", "b,b", null)) ) // custom delimiter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0c40842647be4..ad7f086c7cd17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -161,55 +161,115 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + private[this] def hexToBytes(s: String): Array[Byte] = { + val byteArray = BigInt(s, 16).toByteArray + if (byteArray.length > 1 && byteArray(0) == 0) { + // remove sign byte if exists + byteArray.tail + } else { + byteArray + } + } + test("SPARK-42746: listagg function") { withTempView("df", "df2") { Seq(("a", "b"), ("a", "c"), ("b", "c"), ("b", "d"), (null, null)).toDF("a", "b") .createOrReplaceTempView("df") checkAnswer( sql("select listagg(b) from df group by a"), - Row("") :: Row("b,c") :: Row("c,d") :: Nil) + Row(null) :: Row("b,c") :: Row("c,d") :: Nil) checkAnswer( sql("select listagg(b) from df where 1 != 1"), - Row("") :: Nil) + Row(null) :: Nil) checkAnswer( sql("select listagg(b, '|') from df group by a"), - Row("b|c") :: Row("c|d") :: Row("") :: Nil) + Row("b|c") :: Row("c|d") :: Row(null) :: Nil) checkAnswer( - sql("SELECT LISTAGG(a) FROM df"), + sql("select listagg(a) from df"), Row("a,a,b,b") :: Nil) checkAnswer( - sql("SELECT LISTAGG(DISTINCT a) FROM df"), + sql("select listagg(distinct a) from df"), Row("a,b") :: Nil) checkAnswer( - sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY a) FROM df"), + sql("select listagg(a) within group (order by a) from df"), Row("a,a,b,b") :: Nil) checkAnswer( - sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY a DESC) FROM df"), + sql("select listagg(a) within group (order by a desc) from df"), Row("b,b,a,a") :: Nil) checkAnswer( - sql("SELECT 
LISTAGG(a) WITHIN GROUP (ORDER BY a DESC) " + - "OVER (PARTITION BY b) FROM df"), - Row("a") :: Row("b,a") :: Row("b,a") :: Row("b") :: Row("") :: Nil) + sql("""select listagg(a) within group (order by a desc) over (partition by b) from df"""), + Row("a") :: Row("b,a") :: Row("b,a") :: Row("b") :: Row(null) :: Nil) checkAnswer( - sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b) FROM df"), + sql("select listagg(a) within group (order by b) from df"), Row("a,a,b,b") :: Nil) checkAnswer( - sql("SELECT LISTAGG(a) WITHIN GROUP (ORDER BY b DESC) FROM df"), + sql("select listagg(a) within group (order by b desc) from df"), Row("b,a,b,a") :: Nil) checkAnswer( - sql("SELECT LISTAGG(a, '|') WITHIN GROUP (ORDER BY b DESC) FROM df"), + sql("select listagg(a, '|') within group (order by b desc) from df"), Row("b|a|b|a") :: Nil) + checkAnswer( + sql("select listagg(c1, X'42')from values (X'DEAD'), (X'BEEF') as t(c1)"), + Row(hexToBytes("DEAD42BEEF")) :: Nil) + + checkError( + exception = intercept[AnalysisException] { + sql("select listagg(c1) from values (array('a', 'b')) as t(c1)") + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"listagg(c1, ,)\"", + "paramIndex" -> "first", + "requiredType" -> "(\"STRING\" or \"BINARY\")", + "inputSql" -> "\"c1\"", + "inputType" -> "\"ARRAY\""), + context = ExpectedContext( + fragment = "listagg(c1)", + start = 7, + stop = 17 + )) + + checkError( + exception = intercept[AnalysisException] { + sql("select listagg(c1, ', ')from values (X'DEAD'), (X'BEEF') as t(c1)") + }, + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"listagg(c1, , )\"", + "functionName" -> "`listagg`", + "dataType" -> "(\"BINARY\" or \"STRING\")"), + context = ExpectedContext( + fragment = "listagg(c1, ', ')", + start = 7, + stop = 23 + )) + + checkError( + exception = intercept[AnalysisException] { + sql("select listagg(b, a) from df group by a") + }, + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "sqlExpr" -> "\"listagg(b, a)\"", + "inputName" -> "`delimiter`", + "inputType" -> "\"STRING\"", + "inputExpr" -> "\"a\""), + context = ExpectedContext( + fragment = "listagg(b, a)", + start = 7, + stop = 19 + )) // checkError( // exception = intercept[AnalysisException] { // sql("SELECT LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df") @@ -223,7 +283,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Seq((1, true), (2, false), (3, false)).toDF("a", "b").createOrReplaceTempView("df2") checkAnswer( - sql("SELECT LISTAGG(a), LISTAGG(b) FROM df2"), + sql("select listagg(a), listagg(b) from df2"), Row("1,2,3", "true,false,false") :: Nil) } } From 0f64921edb2647f0291338d3fc4c1a34ae49e4b0 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Fri, 1 Nov 2024 13:20:05 +0100 Subject: [PATCH 26/58] [SPARK-42746] add more validation errors --- .../resources/error/error-conditions.json | 6 ++ .../sql/catalyst/analysis/Analyzer.scala | 4 + .../sql/catalyst/analysis/CheckAnalysis.scala | 10 ++- .../expressions/aggregate/collect.scala | 22 ++++- .../sql/errors/QueryCompilationErrors.scala | 6 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 80 +++++++++++++++---- 6 files changed, 105 insertions(+), 23 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 90348441e736b..ac7359d0ae338 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ 
b/common/utils/src/main/resources/error/error-conditions.json
@@ -1563,6 +1563,12 @@
     ],
     "sqlState" : "42710"
   },
+  "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH" : {
+    "message" : [
+      "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr> when DISTINCT is used."
+    ],
+    "sqlState" : "42822"
+  },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [
       "A column cannot have both a default value and a generation expression but column <colName> has default value: (<defaultValue>) and generation expression: (<genExpr>)."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 317f8a2fce079..2f0e9aa2907ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2225,6 +2225,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           if !f.isInstanceOf[SupportsOrderingWithinGroup] && u.orderingWithinGroup.nonEmpty =>
           throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
             func.prettyName, "WITHIN GROUP (ORDER BY ...)")
+        case listAgg: ListAgg
+          if u.isDistinct && !listAgg.isOrderCompatible(u.orderingWithinGroup) =>
+          throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
+            listAgg.prettyName, listAgg.child, u.orderingWithinGroup)
         // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
         // the context of a Window clause. They do not need to be wrapped in an
         // AggregateExpression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index efb63ea181a80..1453294a339c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Median, PercentileCont, PercentileDisc}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ListAgg, Median, PercentileCont, PercentileDisc}
 import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -427,6 +427,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
             // Only allow window functions with an aggregate expression or an offset window
             // function or a Pandas window UDF.
             w.windowFunction match {
+              case agg @ AggregateExpression(fun: ListAgg, _, _, _, _)
+                // listagg(...) WITHIN GROUP (ORDER BY ...) OVER (ORDER BY ...)
is unsupported
+                if fun.orderingFilled && (w.windowSpec.orderSpec.nonEmpty ||
+                  w.windowSpec.frameSpecification !=
+                    SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)) =>
+                agg.failAnalysis(
+                  errorClass = "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC",
+                  messageParameters = Map("aggFunc" -> toSQLExpr(agg.aggregateFunction)))
               case agg @ AggregateExpression(
                 _: PercentileCont | _: PercentileDisc | _: Median, _, _, _, _)
                 if w.windowSpec.orderSpec.nonEmpty || w.windowSpec.frameSpecification !=
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index def09142067f2..b1a570d92ff5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -316,8 +316,8 @@ case class ListAgg(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  private lazy val dontNeedSaveOrderValue = orderExpressions.isEmpty ||
-    (orderExpressions.size == 1 && orderExpressions.head.semanticEquals(child))
+  /** Indicates that the result of [[child]] is enough for evaluation */
+  private lazy val dontNeedSaveOrderValue = isOrderCompatible(orderExpressions)
 
   override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)
 
@@ -329,7 +329,7 @@ case class ListAgg(
     } else {
       StructType(Seq(
         StructField("value", child.dataType),
-        StructField("sortOrder", orderExpressions.head.dataType)))
+        StructField("sortOrderValue", orderExpressions.head.dataType)))
     }
   }
 
@@ -428,4 +428,20 @@ case class ListAgg(
       orderExpressions = newChildren.drop(2).zip(orderExpressions)
         .map { case (newExpr, oldSortOrder) => oldSortOrder.copy(child = newExpr) }
     )
+
+  /**
+   * Utility function to check whether the given ordering is trivially satisfied by [[child]],
+   * i.e. it is either empty or a single expression semantically equal to [[child]].
+ * + * @see [[QueryCompilationErrors.functionAndOrderExpressionMismatchError]] + * @see [[dontNeedSaveOrderValue]] + */ + def isOrderCompatible(someOrder: Seq[SortOrder]): Boolean = { + if (someOrder.isEmpty) { + return true + } + if (someOrder.size == 1 && someOrder.head.child.semanticEquals(child)) { + return true + } + false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index fa3d18b8567c8..88581cade51ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, FunctionIdentif import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, FunctionAlreadyExistsException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, Star, TableAlreadyExistsException, UnresolvedRegex} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, InvalidUDFClassException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SortOrder, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InputParameter, Join, LogicalPlan, SerdeInfo, Window} @@ -1052,13 +1052,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def functionAndOrderExpressionMismatchError( functionName: String, functionExpr: Expression, - orderExpr: Expression): Throwable = { + orderExpr: Seq[SortOrder]): Throwable = { new AnalysisException( errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", messageParameters = Map( "functionName" -> toSQLId(functionName), "functionExpr" -> toSQLExpr(functionExpr), - "orderExpr" -> toSQLExpr(orderExpr))) + "orderExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(","))) } def wrongCommandForObjectTypeError( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ad7f086c7cd17..5645bc9696a3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -237,8 +237,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark context = ExpectedContext( fragment = "listagg(c1)", start = 7, - stop = 17 - )) + stop = 17)) checkError( exception = intercept[AnalysisException] { @@ -252,8 +251,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark context = ExpectedContext( fragment = "listagg(c1, ', ')", start = 7, - stop = 23 - )) + stop = 23)) checkError( exception = intercept[AnalysisException] { @@ -268,18 +266,68 @@ class SQLQuerySuite extends 
QueryTest with SharedSparkSession with AdaptiveSpark context = ExpectedContext( fragment = "listagg(b, a)", start = 7, - stop = 19 - )) -// checkError( -// exception = intercept[AnalysisException] { -// sql("SELECT LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df") -// }, -// condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", -// parameters = Map( -// "functionName" -> "`LISTAGG`", -// "functionExpr" -> "\"a\"", -// "orderExpr" -> "\"b\"")) -// + stop = 19)) + + checkAnswer( + sql("select listagg(a) over (order by a) from df"), + Row(null) :: Row("a,a") :: Row("a,a") :: Row("a,a,b,b") :: Row("a,a,b,b") :: Nil) + + checkError( + exception = intercept[AnalysisException] { + sql("select listagg(a) within group (order by a) over (order by a) from df") + }, + condition = "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", + parameters = Map("aggFunc" -> "\"listagg(a, ,, a)\""), + context = ExpectedContext( + fragment = "listagg(a) within group (order by a) over (order by a)", + start = 7, + stop = 60)) + + checkError( + exception = intercept[AnalysisException] { + sql("select listagg(distinct a) over (order by a) from df") + }, + condition = "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED", + parameters = Map("windowExpr" -> + ("\"listagg(DISTINCT a, ,) " + + "OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)\"")), + context = ExpectedContext( + fragment = "listagg(distinct a) over (order by a)", + start = 7, + stop = 43)) + + checkAnswer( + sql("select listagg(distinct a) within group (order by a DESC) from df"), + Row("b,a") :: Nil) + + checkError( + exception = intercept[AnalysisException] { + sql("select listagg(distinct a) within group (order by b) from df") + }, + condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + parameters = Map( + "functionName" -> "`listagg`", + "functionExpr" -> "\"a\"", + "orderExpr" -> "\"b\""), + context = ExpectedContext( + fragment = "listagg(distinct a) within group (order by b)", + start = 7, + stop = 51)) + + checkError( + exception = intercept[AnalysisException] { + sql("select listagg(distinct a) within group (order by a, b) from df") + }, + condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + parameters = Map( + "functionName" -> "`listagg`", + "functionExpr" -> "\"a\"", + "orderExpr" -> "\"a\",\"b\""), + context = ExpectedContext( + fragment = "listagg(distinct a) within group (order by a, b)", + start = 7, + stop = 54)) + Seq((1, true), (2, false), (3, false)).toDF("a", "b").createOrReplaceTempView("df2") checkAnswer( From d69ad1f2c88e61e1ef962d0e501770f34d5ca750 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Fri, 1 Nov 2024 14:48:20 +0100 Subject: [PATCH 27/58] [SPARK-42746] set default delimiter to null --- .../expressions/aggregate/collect.scala | 39 +++++++++++---- .../spark/sql/DataFrameAggregateSuite.scala | 8 ++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 48 ++++++++++++------- 3 files changed, 65 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index b1a570d92ff5f..a6b026798e698 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{ArrayBuffer, Growable} import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType @@ -291,7 +291,7 @@ private[aggregate] object CollectTopK { since = "4.0.0") // TODO change case class ListAgg( child: Expression, - delimiter: Expression = Literal.create(",", StringType),// TODO replace with null (empty string) + delimiter: Expression = Literal(null), orderExpressions: Seq[SortOrder] = Nil, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] @@ -299,7 +299,7 @@ case class ListAgg( with ImplicitCastInputTypes { def this(child: Expression) = - this(child, Literal.create(",", StringType), Nil, 0, 0) + this(child, Literal(null), Nil, 0, 0) def this(child: Expression, delimiter: Expression) = this(child, delimiter, Nil, 0, 0) @@ -340,7 +340,8 @@ case class ListAgg( ) +: TypeCollection( StringTypeWithCollation(supportsTrimCollation = true), - BinaryType + BinaryType, + NullType ) +: orderExpressions.map(_ => AnyDataType) @@ -360,7 +361,12 @@ case class ListAgg( ) ) } - TypeUtils.checkForSameTypeInputExpr(child.dataType :: delimiter.dataType :: Nil, prettyName) + if (delimiter.dataType == NullType) { + // null is the default empty delimiter so type is not important + TypeCheckSuccess + } else { + TypeUtils.checkForSameTypeInputExpr(child.dataType :: delimiter.dataType :: Nil, prettyName) + } } private[this] def sortBuffer(buffer: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { @@ -380,22 +386,35 @@ case class ListAgg( } } - private[this] def concatWSInternal(buffer: mutable.ArrayBuffer[Any]): Any = { + private[this] def getDelimiterValue: Any = { val delimiterValue = delimiter.eval() + if (delimiterValue == null) { + // default delimiter value + dataType match { + case StringType => UTF8String.fromString("") + case BinaryType => ByteArray.EMPTY_BYTE + } + } else { + delimiterValue + } + } + + private[this] def concatSkippingNulls(buffer: mutable.ArrayBuffer[Any]): Any = { + val delimiterValue = getDelimiterValue dataType match { case BinaryType => - val inputs = buffer.map(_.asInstanceOf[Array[Byte]]) + val inputs = buffer.filter(_ != null).map(_.asInstanceOf[Array[Byte]]) ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*) case _: StringType => - val inputs = buffer.map(_.asInstanceOf[UTF8String]) + val inputs = buffer.filter(_ != null).map(_.asInstanceOf[UTF8String]) UTF8String.fromString(inputs.mkString(delimiterValue.toString)) } } override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { if (buffer.nonEmpty) { - val sortedBufferWithoutNulls = sortBuffer(buffer).filter(_ != null) - concatWSInternal(sortedBufferWithoutNulls) + val sortedBufferWithoutNulls = sortBuffer(buffer) + concatSkippingNulls(sortedBufferWithoutNulls) } else { null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d52c4e3a8fa8c..d7c6c19303fb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -625,18 +625,18 @@ class DataFrameAggregateSuite extends QueryTest val df = Seq(("a", "b"), 
("b", "c"), ("c", "d")).toDF("a", "b") checkAnswer( df.selectExpr("listagg(a)", "listagg(b)"), - Seq(Row("a,b,c", "b,c,d")) + Seq(Row("abc", "bcd")) ) checkAnswer( df.select(listagg($"a"), listagg($"b")), - Seq(Row("a,b,c", "b,c,d")) + Seq(Row("abc", "bcd")) ) // distinct case val df2 = Seq(("a", "b"), ("a", "b"), ("b", "d")).toDF("a", "b") checkAnswer( df2.select(listagg_distinct($"a"), listagg_distinct($"b")), - Seq(Row("a,b", "b,d")) + Seq(Row("ab", "bd")) ) // null case @@ -644,7 +644,7 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer( df3.select(listagg_distinct($"a"), listagg($"a"), listagg_distinct($"b"), listagg($"b"), listagg($"c")), - Seq(Row("a", "a,a", "b", "b,b", null)) + Seq(Row("a", "aa", "b", "bb", null)) ) // custom delimiter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5645bc9696a3a..bc2f5b8725e6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -171,13 +171,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - test("SPARK-42746: listagg function") { + test("listagg function") { withTempView("df", "df2") { Seq(("a", "b"), ("a", "c"), ("b", "c"), ("b", "d"), (null, null)).toDF("a", "b") .createOrReplaceTempView("df") checkAnswer( sql("select listagg(b) from df group by a"), - Row(null) :: Row("b,c") :: Row("c,d") :: Nil) + Row(null) :: Row("bc") :: Row("cd") :: Nil) + + checkAnswer( + sql("select listagg(b, null) from df group by a"), + Row(null) :: Row("bc") :: Row("cd") :: Nil) checkAnswer( sql("select listagg(b) from df where 1 != 1"), @@ -187,38 +191,50 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("select listagg(b, '|') from df group by a"), Row("b|c") :: Row("c|d") :: Row(null) :: Nil) + checkAnswer( + spark.sql("select listagg(b, :param || ' ') from df group by a", Map("param" -> ",")), + Row("b, c") :: Row("c, d") :: Row(null) :: Nil) + checkAnswer( sql("select listagg(a) from df"), - Row("a,a,b,b") :: Nil) + Row("aabb") :: Nil) checkAnswer( sql("select listagg(distinct a) from df"), - Row("a,b") :: Nil) + Row("ab") :: Nil) checkAnswer( sql("select listagg(a) within group (order by a) from df"), - Row("a,a,b,b") :: Nil) + Row("aabb") :: Nil) checkAnswer( sql("select listagg(a) within group (order by a desc) from df"), - Row("b,b,a,a") :: Nil) + Row("bbaa") :: Nil) checkAnswer( sql("""select listagg(a) within group (order by a desc) over (partition by b) from df"""), - Row("a") :: Row("b,a") :: Row("b,a") :: Row("b") :: Row(null) :: Nil) + Row("a") :: Row("ba") :: Row("ba") :: Row("b") :: Row(null) :: Nil) checkAnswer( sql("select listagg(a) within group (order by b) from df"), - Row("a,a,b,b") :: Nil) + Row("aabb") :: Nil) checkAnswer( sql("select listagg(a) within group (order by b desc) from df"), - Row("b,a,b,a") :: Nil) + Row("baba") :: Nil) checkAnswer( sql("select listagg(a, '|') within group (order by b desc) from df"), Row("b|a|b|a") :: Nil) + checkAnswer( + sql("select listagg(c1)from values (X'DEAD'), (X'BEEF') as t(c1)"), + Row(hexToBytes("DEADBEEF")) :: Nil) + + checkAnswer( + sql("select listagg(c1, null)from values (X'DEAD'), (X'BEEF') as t(c1)"), + Row(hexToBytes("DEADBEEF")) :: Nil) + checkAnswer( sql("select listagg(c1, X'42')from values (X'DEAD'), (X'BEEF') as t(c1)"), Row(hexToBytes("DEAD42BEEF")) :: Nil) @@ -229,7 +245,7 @@ class SQLQuerySuite 
extends QueryTest with SharedSparkSession with AdaptiveSpark }, condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "sqlExpr" -> "\"listagg(c1, ,)\"", + "sqlExpr" -> "\"listagg(c1, NULL)\"", "paramIndex" -> "first", "requiredType" -> "(\"STRING\" or \"BINARY\")", "inputSql" -> "\"c1\"", @@ -270,14 +286,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer( sql("select listagg(a) over (order by a) from df"), - Row(null) :: Row("a,a") :: Row("a,a") :: Row("a,a,b,b") :: Row("a,a,b,b") :: Nil) + Row(null) :: Row("aa") :: Row("aa") :: Row("aabb") :: Row("aabb") :: Nil) checkError( exception = intercept[AnalysisException] { sql("select listagg(a) within group (order by a) over (order by a) from df") }, condition = "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", - parameters = Map("aggFunc" -> "\"listagg(a, ,, a)\""), + parameters = Map("aggFunc" -> "\"listagg(a, NULL, a)\""), context = ExpectedContext( fragment = "listagg(a) within group (order by a) over (order by a)", start = 7, @@ -289,7 +305,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark }, condition = "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED", parameters = Map("windowExpr" -> - ("\"listagg(DISTINCT a, ,) " + + ("\"listagg(DISTINCT a, NULL) " + "OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)\"")), context = ExpectedContext( fragment = "listagg(distinct a) over (order by a)", @@ -298,7 +314,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer( sql("select listagg(distinct a) within group (order by a DESC) from df"), - Row("b,a") :: Nil) + Row("ba") :: Nil) checkError( exception = intercept[AnalysisException] { @@ -331,8 +347,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Seq((1, true), (2, false), (3, false)).toDF("a", "b").createOrReplaceTempView("df2") checkAnswer( - sql("select listagg(a), listagg(b) from df2"), - Row("1,2,3", "true,false,false") :: Nil) + sql("select listagg(a), listagg(b, ',') from df2"), + Row("123", "true,false,false") :: Nil) } } From 6f74b67515ed3829d26c2ebe71cbbc5b5919d683 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Fri, 1 Nov 2024 17:02:20 +0100 Subject: [PATCH 28/58] [SPARK-42746] add multi expression ordering support --- .../expressions/aggregate/collect.scala | 38 ++++++++++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index a6b026798e698..e29ba3ae6e068 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -323,13 +323,33 @@ case class ListAgg( override def defaultResult: Option[Literal] = Option(Literal.create(null, dataType)) + private[this] def orderValuesField: Seq[StructField] = { + orderExpressions.zipWithIndex.map { + case (order, i) => StructField(s"sortOrderValue[$i]", order.dataType) + } + } + + private[this] def evalOrderValues(internalRow: InternalRow): Seq[Any] = { + orderExpressions.map(order => convertToBufferElement(order.child.eval(internalRow))) + } + + private[this] def bufferOrdering: Ordering[InternalRow] = { + val bufferSortOrder = 
orderExpressions.zipWithIndex.map { + case (originalOrder, i) => + originalOrder.copy( + child = BoundReference(i + 1, originalOrder.dataType, nullable = true) + ) + } + new InterpretedOrdering(bufferSortOrder) + } + override protected lazy val bufferElementType: DataType = { if (dontNeedSaveOrderValue) { child.dataType } else { - StructType(Seq( - StructField("value", child.dataType), - StructField("sortOrderValue", orderExpressions.head.dataType))) + StructType( + StructField("value", child.dataType) + +: orderValuesField) } } @@ -373,15 +393,14 @@ case class ListAgg( if (!orderingFilled) { return buffer } - val ascendingOrdering = PhysicalDataType.ordering(orderExpressions.head.dataType) - val ordering = if (orderExpressions.head.direction == Ascending) ascendingOrdering - else ascendingOrdering.reverse - if (dontNeedSaveOrderValue) { + val ascendingOrdering = PhysicalDataType.ordering(orderExpressions.head.dataType) + val ordering = if (orderExpressions.head.direction == Ascending) ascendingOrdering + else ascendingOrdering.reverse buffer.sorted(ordering) } else { buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]] - .sortBy(_.get(1, orderExpressions.head.dataType))(ordering.asInstanceOf[Ordering[AnyRef]]) + .sorted(bufferOrdering) .map(_.get(0, child.dataType)) } } @@ -426,8 +445,7 @@ case class ListAgg( val v = if (dontNeedSaveOrderValue) { convertToBufferElement(value) } else { - InternalRow.apply(convertToBufferElement(value), - convertToBufferElement(orderExpressions.head.child.eval(input))) + InternalRow.fromSeq(convertToBufferElement(value) +: evalOrderValues(input)) } buffer += v } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bc2f5b8725e6f..44453ca0bbc5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -227,6 +227,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("select listagg(a, '|') within group (order by b desc) from df"), Row("b|a|b|a") :: Nil) + checkAnswer( + sql("select listagg(a) within group (order by b desc, a asc) from df"), + Row("baba") :: Nil) + + checkAnswer( + sql("select listagg(a) within group (order by b desc, a desc) from df"), + Row("bbaa") :: Nil) + checkAnswer( sql("select listagg(c1)from values (X'DEAD'), (X'BEEF') as t(c1)"), Row(hexToBytes("DEADBEEF")) :: Nil) From 5050630c764441a2321e46e50ff4df8a063b21b6 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Mon, 4 Nov 2024 15:26:26 +0100 Subject: [PATCH 29/58] [SPARK-42746] add scala functions --- .../apache/spark/sql/FunctionTestSuite.scala | 2 + .../org/apache/spark/sql/functions.scala | 46 +++++++++++++++++-- .../expressions/aggregate/collect.scala | 30 +++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 4 files changed, 68 insertions(+), 12 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index 40b66bcb8358d..d40e0cbe55bca 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -65,6 +65,8 @@ class FunctionTestSuite extends ConnectFunSuite { testEquals("avg/mean", avg("a"), avg(a), mean(a), mean("a")) testEquals("collect_list", 
collect_list("a"), collect_list(a)) testEquals("collect_set", collect_set("a"), collect_set(a)) + testEquals("listagg", listagg("a"), listagg(a)) + testEquals("listagg_distinct", listagg_distinct("a"), listagg_distinct(a)) testEquals("corr", corr("a", "b"), corr(a, b)) testEquals( "count_distinct", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 502258fb5aed1..fc0dae905ce4a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1148,7 +1148,7 @@ object functions { def sum_distinct(e: Column): Column = Column.fn("sum", isDistinct = true, e) /** - * Aggregate function: returns the concatenated input values. + * Aggregate function: returns the concatenation of non-null input values. * * @group agg_funcs * @since 4.0.0 @@ -1156,7 +1156,16 @@ object functions { def listagg(e: Column): Column = Column.fn("listagg", e) /** - * Aggregate function: returns the concatenated input values, separated by the delimiter string. + * Aggregate function: returns the concatenation of non-null input values. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg(columnName: String): Column = listagg(Column(columnName)) + + /** + * Aggregate function: returns the concatenation of non-null input values, + * separated by the delimiter string. * * @group agg_funcs * @since 4.0.0 @@ -1164,7 +1173,17 @@ object functions { def listagg(e: Column, delimiter: Column): Column = Column.fn("listagg", e, delimiter) /** - * Aggregate function: returns the concatenated input values, separated by the delimiter string. + * Aggregate function: returns the concatenation of non-null input values, + * separated by the delimiter string. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg(columnName: String, delimiter: String): Column = + listagg(Column(columnName), lit(delimiter)) + + /** + * Aggregate function: returns the concatenation of distinct non-null input values. * * @group agg_funcs * @since 4.0.0 @@ -1172,7 +1191,16 @@ object functions { def listagg_distinct(e: Column): Column = Column.fn("listagg", isDistinct = true, e) /** - * Aggregate function: returns the concatenated input values, separated by the delimiter string. + * Aggregate function: returns the concatenation of distinct non-null input values. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg_distinct(columnName: String): Column = listagg_distinct(Column(columnName)) + + /** + * Aggregate function: returns the concatenation of distinct non-null input values, + * separated by the delimiter string. * * @group agg_funcs * @since 4.0.0 @@ -1180,6 +1208,16 @@ object functions { def listagg_distinct(e: Column, delimiter: Column): Column = Column.fn("listagg", isDistinct = true, e, delimiter) + /** + * Aggregate function: returns the concatenation of distinct non-null input values, + * separated by the delimiter string. + * + * @group agg_funcs + * @since 4.0.0 + */ + def listagg_distinct(columnName: String, delimiter: String): Column = + listagg_distinct(Column(columnName), lit(delimiter)) + /** * Aggregate function: alias for `var_samp`. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index e29ba3ae6e068..fd2bf93c525fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -287,8 +287,12 @@ private[aggregate] object CollectTopK { > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); NULL """, + note = """ + If order is not specified the function is non-deterministic because + the order of the rows may be non-deterministic after a shuffle. + """, group = "agg_funcs", - since = "4.0.0") // TODO change + since = "4.0.0") case class ListAgg( child: Expression, delimiter: Expression = Literal(null), @@ -317,12 +321,22 @@ case class ListAgg( copy(inputAggBufferOffset = newInputAggBufferOffset) /** Indicates that the result of [[child]] is enough for evaluation */ - private lazy val dontNeedSaveOrderValue = isOrderCompatible(orderExpressions) + private lazy val noNeedSaveOrderValue: Boolean = isOrderCompatible(orderExpressions) override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) override def defaultResult: Option[Literal] = Option(Literal.create(null, dataType)) + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else "" + val withinGroup = if (orderingFilled) { + s" WITHIN GROUP (ORDER BY ${orderExpressions.map(_.sql).mkString(", ")})" + } else { + "" + } + s"$prettyName($distinct${child.sql}, ${delimiter.sql})$withinGroup" + } + private[this] def orderValuesField: Seq[StructField] = { orderExpressions.zipWithIndex.map { case (order, i) => StructField(s"sortOrderValue[$i]", order.dataType) @@ -337,14 +351,15 @@ case class ListAgg( val bufferSortOrder = orderExpressions.zipWithIndex.map { case (originalOrder, i) => originalOrder.copy( - child = BoundReference(i + 1, originalOrder.dataType, nullable = true) + // first value is the evaluated child so add +1 for order's values + child = BoundReference(i + 1, originalOrder.dataType, originalOrder.child.nullable) ) } new InterpretedOrdering(bufferSortOrder) } override protected lazy val bufferElementType: DataType = { - if (dontNeedSaveOrderValue) { + if (noNeedSaveOrderValue) { child.dataType } else { StructType( @@ -393,7 +408,7 @@ case class ListAgg( if (!orderingFilled) { return buffer } - if (dontNeedSaveOrderValue) { + if (noNeedSaveOrderValue) { val ascendingOrdering = PhysicalDataType.ordering(orderExpressions.head.dataType) val ordering = if (orderExpressions.head.direction == Ascending) ascendingOrdering else ascendingOrdering.reverse @@ -401,6 +416,7 @@ case class ListAgg( } else { buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]] .sorted(bufferOrdering) + // drop order values after sort .map(_.get(0, child.dataType)) } } @@ -442,7 +458,7 @@ case class ListAgg( override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) if (value != null) { - val v = if (dontNeedSaveOrderValue) { + val v = if (noNeedSaveOrderValue) { convertToBufferElement(value) } else { InternalRow.fromSeq(convertToBufferElement(value) +: evalOrderValues(input)) @@ -470,7 +486,7 @@ case class ListAgg( * Utility func to check if given order is defined and different from [[child]]. 
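   * (An illustrative reading rather than an extra guarantee, distilled from the suites in this
   * series: LISTAGG(DISTINCT a) WITHIN GROUP (ORDER BY a) is accepted because the single order
   * expression matches the aggregated child, while ORDER BY b, or ORDER BY a, b, is rejected
   * during analysis with FUNCTION_AND_ORDER_EXPRESSION_MISMATCH.)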
* * @see [[QueryCompilationErrors.functionAndOrderExpressionMismatchError]] - * @see [[dontNeedSaveOrderValue]] + * @see [[noNeedSaveOrderValue]] */ def isOrderCompatible(someOrder: Seq[SortOrder]): Boolean = { if (someOrder.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 44453ca0bbc5b..df20f481b7751 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -164,7 +164,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark private[this] def hexToBytes(s: String): Array[Byte] = { val byteArray = BigInt(s, 16).toByteArray if (byteArray.length > 1 && byteArray(0) == 0) { - // remove sign byte if exists + // remove sign byte for positive numbers if exists byteArray.tail } else { byteArray From b75855dda59ac29836e61a37b25520f057004fab Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Mon, 4 Nov 2024 15:41:19 +0100 Subject: [PATCH 30/58] [SPARK-42746] add string_agg alias --- .../org/apache/spark/sql/functions.scala | 80 +++++++++++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 15 ++++ 3 files changed, 96 insertions(+) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index fc0dae905ce4a..7b990bac0a79a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1218,6 +1218,86 @@ object functions { def listagg_distinct(columnName: String, delimiter: String): Column = listagg_distinct(Column(columnName), lit(delimiter)) + /** + * Aggregate function: returns the concatenation of non-null input values. + * Alias for `listagg`. + * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg(e: Column): Column = Column.fn("string_agg", e) + + /** + * Aggregate function: returns the concatenation of non-null input values. + * Alias for `listagg`. + * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg(columnName: String): Column = string_agg(Column(columnName)) + + /** + * Aggregate function: returns the concatenation of non-null input values, + * separated by the delimiter string. + * Alias for `listagg`. + * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg(e: Column, delimiter: Column): Column = Column.fn("string_agg", e, delimiter) + + /** + * Aggregate function: returns the concatenation of non-null input values, + * separated by the delimiter string. + * Alias for `listagg`. + * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg(columnName: String, delimiter: String): Column = + string_agg(Column(columnName), lit(delimiter)) + + /** + * Aggregate function: returns the concatenation of distinct non-null input values. + * Alias for `listagg`. + * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg_distinct(e: Column): Column = Column.fn("string_agg", isDistinct = true, e) + + /** + * Aggregate function: returns the concatenation of distinct non-null input values. + * Alias for `listagg`. + * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg_distinct(columnName: String): Column = string_agg_distinct(Column(columnName)) + + /** + * Aggregate function: returns the concatenation of distinct non-null input values, + * separated by the delimiter string. + * Alias for `listagg`. 
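A rough usage sketch for the aliases added in this commit (the DataFrame `df` and its string
column "a" are invented for illustration; both names resolve to the same ListAgg expression):

    import org.apache.spark.sql.functions._
    df.agg(string_agg(col("a"), lit("; ")))          // same result as listagg(col("a"), lit("; "))
    df.agg(string_agg_distinct(col("a"), lit("; "))) // duplicates collapsed before concatenation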
+ * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg_distinct(e: Column, delimiter: Column): Column = + Column.fn("string_agg", isDistinct = true, e, delimiter) + + /** + * Aggregate function: returns the concatenation of distinct non-null input values, + * separated by the delimiter string. + * Alias for `listagg`. + * + * @group agg_funcs + * @since 4.0.0 + */ + def string_agg_distinct(columnName: String, delimiter: String): Column = + string_agg_distinct(Column(columnName), lit(delimiter)) + + /** * Aggregate function: alias for `var_samp`. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5890da485fb..7a63be2c825b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -506,6 +506,7 @@ object FunctionRegistry { expression[CollectList]("array_agg", true, Some("3.3.0")), expression[CollectSet]("collect_set"), expression[ListAgg]("listagg"), + expression[ListAgg]("string_agg", setAlias = true), expressionBuilder("count_min_sketch", CountMinSketchAggExpressionBuilder), expression[BoolAnd]("every", true), expression[BoolAnd]("bool_and"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index df20f481b7751..7b5a68b8caac8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -179,6 +179,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("select listagg(b) from df group by a"), Row(null) :: Row("bc") :: Row("cd") :: Nil) + checkAnswer( + sql("select string_agg(b) from df group by a"), + Row(null) :: Row("bc") :: Row("cd") :: Nil) + checkAnswer( sql("select listagg(b, null) from df group by a"), Row(null) :: Row("bc") :: Row("cd") :: Nil) @@ -307,6 +311,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark start = 7, stop = 60)) + checkError( + exception = intercept[AnalysisException] { + sql("select string_agg(a) within group (order by a) over (order by a) from df") + }, + condition = "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", + parameters = Map("aggFunc" -> "\"listagg(a, NULL, a)\""), + context = ExpectedContext( + fragment = "string_agg(a) within group (order by a) over (order by a)", + start = 7, + stop = 63)) + checkError( exception = intercept[AnalysisException] { sql("select listagg(distinct a) over (order by a) from df") From 14ee65edf0885a7f2be2fe79c8e321c82b033e73 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Mon, 4 Nov 2024 16:29:46 +0100 Subject: [PATCH 31/58] [SPARK-42746] return licence to SupportsOrderingWithinGroup --- .../aggregate/SupportsOrderingWithinGroup.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala index dfc28455d5b7f..f9b6aebf909bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.catalyst.expressions.aggregate From e638d291638907c0123b71bc2b08074a870969d0 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Mon, 4 Nov 2024 18:41:18 +0100 Subject: [PATCH 32/58] [SPARK-42746] add collation tests --- .../expressions/aggregate/collect.scala | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 47 +++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index e938527ae527b..9ef00176e44e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -429,8 +429,8 @@ case class ListAgg( if (delimiterValue == null) { // default delimiter value dataType match { - case StringType => UTF8String.fromString("") - case BinaryType => ByteArray.EMPTY_BYTE + case _: StringType => UTF8String.fromString("") + case _: BinaryType => ByteArray.EMPTY_BYTE } } else { delimiterValue diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7b5a68b8caac8..f00edfa03448d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -375,6 +375,53 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + test("listagg collation test") { + checkAnswer( + sql("select listagg(c1) within group (order by c1 collate utf8_binary)" + + " from values ('a'), ('A'), ('b'), ('B') as t(c1)"), + Row("ABab") :: Nil) + + checkAnswer( + sql("select listagg(c1) within group (order by c1 collate utf8_lcase)" + + " from values ('a'), ('A'), ('b'), ('B') as t(c1)"), + Row("aAbB") :: Nil) + + checkAnswer( + sql("select listagg(DISTINCT c1 collate utf8_binary)" + + " from values ('a'), ('A'), ('b'), ('B') as t(c1)"), + Row("aAbB") :: Nil) + + checkAnswer( + sql("select listagg(DISTINCT c1 collate utf8_lcase)" + + " from values ('a'), ('A'), ('b'), ('B') as t(c1)"), + Row("ab") :: Nil) + + checkAnswer( + sql("select listagg(DISTINCT c1 collate utf8_lcase)" + + " within group (order by c1 collate utf8_lcase)" + + " from values ('a'), ('B'), ('b'), ('A') as t(c1)"), + Row("aB") :: Nil) + + checkError( + exception = intercept[AnalysisException] { + sql( + """select listagg(DISTINCT c1 collate 
utf8_lcase) + | within group (order by c1 collate utf8_binary) + | from values ('a'), ('b'), ('A'), ('B') as t(c1)""".stripMargin) + }, + condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + parameters = Map( + "functionName" -> "`listagg`", + "functionExpr" -> "\"collate(c1, utf8_lcase)\"", + "orderExpr" -> "\"collate(c1, utf8_binary)\""), + context = ExpectedContext( + fragment = + """listagg(DISTINCT c1 collate utf8_lcase) + | within group (order by c1 collate utf8_binary)""".stripMargin, + start = 7, + stop = 93)) + } + test("support table.star") { checkAnswer( sql( From dfcc112f6214daba06fe2daf57b11bffe10fe4c3 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Mon, 4 Nov 2024 18:52:33 +0100 Subject: [PATCH 33/58] [SPARK-42746] add listagg to excludedDataFrameFunctions --- .../java/org/apache/spark/unsafe/array/ByteArraySuite.java | 2 +- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java index e86d52fc90e02..0fb83088b49f9 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java @@ -96,7 +96,7 @@ public void testConcat() { @Test public void testConcatWS() { - byte[] separator = new byte[]{(byte) 42}; // Separator byte array + byte[] separator = new byte[]{(byte) 42}; byte[] x1 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; byte[] y1 = new byte[]{(byte) 4, (byte) 5, (byte) 6}; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d4c67551343fb..4494057b1eefe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -73,7 +73,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sum_distinct", // equivalent to sum(distinct foo) "typedLit", "typedlit", // Scala only "udaf", "udf", // create function statement in sql - "call_function" // moot in SQL as you just call the function directly + "call_function", // moot in SQL as you just call the function directly + "listagg_distinct", // equivalent to listagg(distinct foo) + "string_agg_distinct" // equivalent to string_agg(distinct foo) ) val excludedSqlFunctions = Set.empty[String] @@ -83,8 +85,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "product", // Discussed in https://github.com/apache/spark/pull/30745 "unwrap_udt", "timestamp_add", - "timestamp_diff", - "listagg_distinct" + "timestamp_diff" ) // We only consider functions matching this pattern, this excludes symbolic and other From 3cbe9e9edac01b7957aba8efce21767ac793c8e8 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Mon, 4 Nov 2024 19:01:13 +0100 Subject: [PATCH 34/58] [SPARK-42746] add string_agg to expected_missing_in_py --- python/pyspark/sql/tests/test_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index bae26cfbf989f..4f6b57502cbd5 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -87,6 +87,8 @@ def test_function_parity(self): # TODO: listagg functions will soon be added 
and removed from this list
             "listagg_distinct",
             "listagg",
+            "string_agg",
+            "string_agg_distinct",
         }
 
         self.assertEqual(

From 7105c7cecfcda120c8891535dcc87debd3035450 Mon Sep 17 00:00:00 2001
From: Mikhail Nikoliukin <mikhail.nikoliukin@databricks.com>
Date: Mon, 4 Nov 2024 19:37:24 +0100
Subject: [PATCH 35/58] [SPARK-42746] add follow-up ticket

---
 python/pyspark/sql/tests/test_functions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 4f6b57502cbd5..6a68afb284781 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -84,7 +84,7 @@ def test_function_parity(self):
 
         # Functions that we expect to be missing in python until they are added to pyspark
         expected_missing_in_py = {
-            # TODO: listagg functions will soon be added and removed from this list
+            # TODO(SPARK-50220): listagg functions will soon be added and removed from this list
             "listagg_distinct",
             "listagg",
             "string_agg",

From 27cbd038f27be43f3795bd8623230489c843a5a0 Mon Sep 17 00:00:00 2001
From: Mikhail Nikoliukin <mikhail.nikoliukin@databricks.com>
Date: Tue, 5 Nov 2024 10:06:21 +0100
Subject: [PATCH 36/58] [SPARK-42746] fix formatting

---
 .../org/apache/spark/sql/functions.scala     | 51 ++++++-----------
 .../sql-functions/sql-expression-schema.md   |  3 +-
 2 files changed, 24 insertions(+), 30 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 7b990bac0a79a..4f883d9fb515b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1164,8 +1164,8 @@ object functions {
   def listagg(columnName: String): Column = listagg(Column(columnName))
 
   /**
-   * Aggregate function: returns the concatenation of non-null input values,
-   * separated by the delimiter string.
+   * Aggregate function: returns the concatenation of non-null input values, separated by the
+   * delimiter string.
    *
    * @group agg_funcs
    * @since 4.0.0
@@ -1173,8 +1173,8 @@ object functions {
   def listagg(e: Column, delimiter: Column): Column = Column.fn("listagg", e, delimiter)
 
   /**
-   * Aggregate function: returns the concatenation of non-null input values,
-   * separated by the delimiter string.
+   * Aggregate function: returns the concatenation of non-null input values, separated by the
+   * delimiter string.
    *
    * @group agg_funcs
    * @since 4.0.0
@@ -1199,8 +1199,8 @@ object functions {
   def listagg_distinct(columnName: String): Column = listagg_distinct(Column(columnName))
 
   /**
-   * Aggregate function: returns the concatenation of distinct non-null input values,
-   * separated by the delimiter string.
+   * Aggregate function: returns the concatenation of distinct non-null input values, separated by
+   * the delimiter string.
    *
    * @group agg_funcs
    * @since 4.0.0
@@ -1209,8 +1209,8 @@ object functions {
     Column.fn("listagg", isDistinct = true, e, delimiter)
 
   /**
-   * Aggregate function: returns the concatenation of distinct non-null input values,
-   * separated by the delimiter string.
+   * Aggregate function: returns the concatenation of distinct non-null input values, separated by
+   * the delimiter string.
    *
    * @group agg_funcs
    * @since 4.0.0
@@ -1219,8 +1219,7 @@ object functions {
     listagg_distinct(Column(columnName), lit(delimiter))
 
   /**
-   * Aggregate function: returns the concatenation of non-null input values.
-   * Alias for `listagg`.
+   * Aggregate function: returns the concatenation of non-null input values. 
Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1228,8 +1227,7 @@ object functions { def string_agg(e: Column): Column = Column.fn("string_agg", e) /** - * Aggregate function: returns the concatenation of non-null input values. - * Alias for `listagg`. + * Aggregate function: returns the concatenation of non-null input values. Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1237,9 +1235,8 @@ object functions { def string_agg(columnName: String): Column = string_agg(Column(columnName)) /** - * Aggregate function: returns the concatenation of non-null input values, - * separated by the delimiter string. - * Alias for `listagg`. + * Aggregate function: returns the concatenation of non-null input values, separated by the + * delimiter string. Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1247,9 +1244,8 @@ object functions { def string_agg(e: Column, delimiter: Column): Column = Column.fn("string_agg", e, delimiter) /** - * Aggregate function: returns the concatenation of non-null input values, - * separated by the delimiter string. - * Alias for `listagg`. + * Aggregate function: returns the concatenation of non-null input values, separated by the + * delimiter string. Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1258,8 +1254,8 @@ object functions { string_agg(Column(columnName), lit(delimiter)) /** - * Aggregate function: returns the concatenation of distinct non-null input values. - * Alias for `listagg`. + * Aggregate function: returns the concatenation of distinct non-null input values. Alias for + * `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1267,8 +1263,8 @@ object functions { def string_agg_distinct(e: Column): Column = Column.fn("string_agg", isDistinct = true, e) /** - * Aggregate function: returns the concatenation of distinct non-null input values. - * Alias for `listagg`. + * Aggregate function: returns the concatenation of distinct non-null input values. Alias for + * `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1276,9 +1272,8 @@ object functions { def string_agg_distinct(columnName: String): Column = string_agg_distinct(Column(columnName)) /** - * Aggregate function: returns the concatenation of distinct non-null input values, - * separated by the delimiter string. - * Alias for `listagg`. + * Aggregate function: returns the concatenation of distinct non-null input values, separated by + * the delimiter string. Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1287,9 +1282,8 @@ object functions { Column.fn("string_agg", isDistinct = true, e, delimiter) /** - * Aggregate function: returns the concatenation of distinct non-null input values, - * separated by the delimiter string. - * Alias for `listagg`. + * Aggregate function: returns the concatenation of distinct non-null input values, separated by + * the delimiter string. Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1297,7 +1291,6 @@ object functions { def string_agg_distinct(columnName: String, delimiter: String): Column = string_agg_distinct(Column(columnName), lit(delimiter)) - /** * Aggregate function: alias for `var_samp`. 
* diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 383ad790a8dd0..5e732aaff9e70 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -424,7 +424,8 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last_value | SELECT last_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | listagg | SELECT listagg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | listagg | SELECT listagg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg | string_agg | SELECT string_agg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Max | max | SELECT max(col) FROM VALUES (10), (50), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.MaxBy | max_by | SELECT max_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Median | median | SELECT median(col) FROM VALUES (0), (10) AS tab(col) | struct | From d514787921193912a321ad15a341bdbce3238ffe Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 5 Nov 2024 10:08:31 +0100 Subject: [PATCH 37/58] [SPARK-42746] remove functions with columnName --- .../org/apache/spark/sql/functions.scala | 73 ------------------- 1 file changed, 73 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 4f883d9fb515b..47d53f7b65635 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1155,14 +1155,6 @@ object functions { */ def listagg(e: Column): Column = Column.fn("listagg", e) - /** - * Aggregate function: returns the concatenation of non-null input values. - * - * @group agg_funcs - * @since 4.0.0 - */ - def listagg(columnName: String): Column = listagg(Column(columnName)) - /** * Aggregate function: returns the concatenation of non-null input values, separated by the * delimiter string. @@ -1172,16 +1164,6 @@ object functions { */ def listagg(e: Column, delimiter: Column): Column = Column.fn("listagg", e, delimiter) - /** - * Aggregate function: returns the concatenation of non-null input values, separated by the - * delimiter string. - * - * @group agg_funcs - * @since 4.0.0 - */ - def listagg(columnName: String, delimiter: String): Column = - listagg(Column(columnName), lit(delimiter)) - /** * Aggregate function: returns the concatenation of distinct non-null input values. * @@ -1190,14 +1172,6 @@ object functions { */ def listagg_distinct(e: Column): Column = Column.fn("listagg", isDistinct = true, e) - /** - * Aggregate function: returns the concatenation of distinct non-null input values. 
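The removed string overloads have direct Column-based equivalents, so call sites migrate
mechanically — a sketch, with `df` and its column "a" invented for illustration:

    import org.apache.spark.sql.functions._
    // instead of the removed listagg("a") and listagg_distinct("a", ","):
    df.agg(listagg(col("a")))
    df.agg(listagg_distinct(col("a"), lit(",")))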
- * - * @group agg_funcs - * @since 4.0.0 - */ - def listagg_distinct(columnName: String): Column = listagg_distinct(Column(columnName)) - /** * Aggregate function: returns the concatenation of distinct non-null input values, separated by * the delimiter string. @@ -1208,16 +1182,6 @@ object functions { def listagg_distinct(e: Column, delimiter: Column): Column = Column.fn("listagg", isDistinct = true, e, delimiter) - /** - * Aggregate function: returns the concatenation of distinct non-null input values, separated by - * the delimiter string. - * - * @group agg_funcs - * @since 4.0.0 - */ - def listagg_distinct(columnName: String, delimiter: String): Column = - listagg_distinct(Column(columnName), lit(delimiter)) - /** * Aggregate function: returns the concatenation of non-null input values. Alias for `listagg`. * @@ -1226,14 +1190,6 @@ object functions { */ def string_agg(e: Column): Column = Column.fn("string_agg", e) - /** - * Aggregate function: returns the concatenation of non-null input values. Alias for `listagg`. - * - * @group agg_funcs - * @since 4.0.0 - */ - def string_agg(columnName: String): Column = string_agg(Column(columnName)) - /** * Aggregate function: returns the concatenation of non-null input values, separated by the * delimiter string. Alias for `listagg`. @@ -1243,16 +1199,6 @@ object functions { */ def string_agg(e: Column, delimiter: Column): Column = Column.fn("string_agg", e, delimiter) - /** - * Aggregate function: returns the concatenation of non-null input values, separated by the - * delimiter string. Alias for `listagg`. - * - * @group agg_funcs - * @since 4.0.0 - */ - def string_agg(columnName: String, delimiter: String): Column = - string_agg(Column(columnName), lit(delimiter)) - /** * Aggregate function: returns the concatenation of distinct non-null input values. Alias for * `listagg`. @@ -1262,15 +1208,6 @@ object functions { */ def string_agg_distinct(e: Column): Column = Column.fn("string_agg", isDistinct = true, e) - /** - * Aggregate function: returns the concatenation of distinct non-null input values. Alias for - * `listagg`. - * - * @group agg_funcs - * @since 4.0.0 - */ - def string_agg_distinct(columnName: String): Column = string_agg_distinct(Column(columnName)) - /** * Aggregate function: returns the concatenation of distinct non-null input values, separated by * the delimiter string. Alias for `listagg`. @@ -1281,16 +1218,6 @@ object functions { def string_agg_distinct(e: Column, delimiter: Column): Column = Column.fn("string_agg", isDistinct = true, e, delimiter) - /** - * Aggregate function: returns the concatenation of distinct non-null input values, separated by - * the delimiter string. Alias for `listagg`. - * - * @group agg_funcs - * @since 4.0.0 - */ - def string_agg_distinct(columnName: String, delimiter: String): Column = - string_agg_distinct(Column(columnName), lit(delimiter)) - /** * Aggregate function: alias for `var_samp`. 
* From 5fd9a302dc54d5b7eab1bc92c0fe561b21074331 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 5 Nov 2024 10:28:43 +0100 Subject: [PATCH 38/58] [SPARK-42746] reformat file --- .../org/apache/spark/sql/functions.scala | 8 +- .../expressions/aggregate/collect.scala | 162 +++++++++--------- 2 files changed, 88 insertions(+), 82 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 47d53f7b65635..465fef43904a2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1157,7 +1157,7 @@ object functions { /** * Aggregate function: returns the concatenation of non-null input values, separated by the - * delimiter string. + * delimiter. * * @group agg_funcs * @since 4.0.0 @@ -1174,7 +1174,7 @@ object functions { /** * Aggregate function: returns the concatenation of distinct non-null input values, separated by - * the delimiter string. + * the delimiter. * * @group agg_funcs * @since 4.0.0 @@ -1192,7 +1192,7 @@ object functions { /** * Aggregate function: returns the concatenation of non-null input values, separated by the - * delimiter string. Alias for `listagg`. + * delimiter. Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 @@ -1210,7 +1210,7 @@ object functions { /** * Aggregate function: returns the concatenation of distinct non-null input values, separated by - * the delimiter string. Alias for `listagg`. + * the delimiter. Alias for `listagg`. * * @group agg_funcs * @since 4.0.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 9ef00176e44e7..44983159ec202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -273,8 +273,8 @@ private[aggregate] object CollectTopK { } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the concatenated input non-null values," + - " separated by the delimiter string.", + usage = "_FUNC_(expr) - Returns the concatenation of non-null input values," + + " separated by the delimiter.", examples = """ Examples: > SELECT _FUNC_(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col); @@ -295,39 +295,47 @@ private[aggregate] object CollectTopK { the order of the rows may be non-deterministic after a shuffle. 
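      Specifying WITHIN GROUP (ORDER BY ...) pins the order and makes the result deterministic,
      for example (an illustrative query in the style of the examples above):
        > SELECT listagg(col, '-') WITHIN GROUP (ORDER BY col) FROM VALUES ('b'), ('a') AS tab(col);
         a-b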
""", group = "agg_funcs", - since = "4.0.0") + since = "4.0.0" +) case class ListAgg( child: Expression, delimiter: Expression = Literal(null), orderExpressions: Seq[SortOrder] = Nil, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] + inputAggBufferOffset: Int = 0) + extends Collect[mutable.ArrayBuffer[Any]] with SupportsOrderingWithinGroup with ImplicitCastInputTypes { + override protected lazy val bufferElementType: DataType = { + if (noNeedSaveOrderValue) { + child.dataType + } else { + StructType( + StructField("value", child.dataType) + +: orderValuesField + ) + } + } + /** Indicates that the result of [[child]] is enough for evaluation */ + private lazy val noNeedSaveOrderValue: Boolean = isOrderCompatible(orderExpressions) + def this(child: Expression) = this(child, Literal(null), Nil, 0, 0) + def this(child: Expression, delimiter: Expression) = this(child, delimiter, Nil, 0, 0) - override def dataType: DataType = child.dataType - override def nullable: Boolean = true override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty - override def withNewMutableAggBufferOffset( - newMutableAggBufferOffset: Int): ImperativeAggregate = + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - /** Indicates that the result of [[child]] is enough for evaluation */ - private lazy val noNeedSaveOrderValue: Boolean = isOrderCompatible(orderExpressions) - - override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) - override def defaultResult: Option[Literal] = Option(Literal.create(null, dataType)) override def sql(isDistinct: Boolean): String = { @@ -340,37 +348,6 @@ case class ListAgg( s"$prettyName($distinct${child.sql}, ${delimiter.sql})$withinGroup" } - private[this] def orderValuesField: Seq[StructField] = { - orderExpressions.zipWithIndex.map { - case (order, i) => StructField(s"sortOrderValue[$i]", order.dataType) - } - } - - private[this] def evalOrderValues(internalRow: InternalRow): Seq[Any] = { - orderExpressions.map(order => convertToBufferElement(order.child.eval(internalRow))) - } - - private[this] def bufferOrdering: Ordering[InternalRow] = { - val bufferSortOrder = orderExpressions.zipWithIndex.map { - case (originalOrder, i) => - originalOrder.copy( - // first value is the evaluated child so add +1 for order's values - child = BoundReference(i + 1, originalOrder.dataType, originalOrder.child.nullable) - ) - } - new InterpretedOrdering(bufferSortOrder) - } - - override protected lazy val bufferElementType: DataType = { - if (noNeedSaveOrderValue) { - child.dataType - } else { - StructType( - StructField("value", child.dataType) - +: orderValuesField) - } - } - override def inputTypes: Seq[AbstractDataType] = TypeCollection( StringTypeWithCollation(supportsTrimCollation = true), @@ -381,8 +358,7 @@ case class ListAgg( BinaryType, NullType ) +: - orderExpressions.map(_ => AnyDataType) - + orderExpressions.map(_ => AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { val matchInputTypes = super.checkInputDataTypes() @@ -407,54 +383,71 @@ case class ListAgg( } } + override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { + if (buffer.nonEmpty) { + val sortedBufferWithoutNulls = 
sortBuffer(buffer) + concatSkippingNulls(sortedBufferWithoutNulls) + } else { + null + } + } + private[this] def sortBuffer(buffer: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { if (!orderingFilled) { return buffer } if (noNeedSaveOrderValue) { val ascendingOrdering = PhysicalDataType.ordering(orderExpressions.head.dataType) - val ordering = if (orderExpressions.head.direction == Ascending) ascendingOrdering + val ordering = + if (orderExpressions.head.direction == Ascending) ascendingOrdering else ascendingOrdering.reverse buffer.sorted(ordering) } else { - buffer.asInstanceOf[mutable.ArrayBuffer[InternalRow]] + buffer + .asInstanceOf[mutable.ArrayBuffer[InternalRow]] .sorted(bufferOrdering) // drop order values after sort .map(_.get(0, child.dataType)) } } - private[this] def getDelimiterValue: Any = { - val delimiterValue = delimiter.eval() - if (delimiterValue == null) { - // default delimiter value - dataType match { - case _: StringType => UTF8String.fromString("") - case _: BinaryType => ByteArray.EMPTY_BYTE - } - } else { - delimiterValue + private[this] def bufferOrdering: Ordering[InternalRow] = { + val bufferSortOrder = orderExpressions.zipWithIndex.map { + case (originalOrder, i) => + originalOrder.copy( + // first value is the evaluated child so add +1 for order's values + child = BoundReference(i + 1, originalOrder.dataType, originalOrder.child.nullable) + ) } + new InterpretedOrdering(bufferSortOrder) } + override def orderingFilled: Boolean = orderExpressions.nonEmpty + private[this] def concatSkippingNulls(buffer: mutable.ArrayBuffer[Any]): Any = { val delimiterValue = getDelimiterValue dataType match { - case BinaryType => - val inputs = buffer.filter(_ != null).map(_.asInstanceOf[Array[Byte]]) - ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*) - case _: StringType => - val inputs = buffer.filter(_ != null).map(_.asInstanceOf[UTF8String]) - UTF8String.fromString(inputs.mkString(delimiterValue.toString)) + case BinaryType => + val inputs = buffer.filter(_ != null).map(_.asInstanceOf[Array[Byte]]) + ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*) + case _: StringType => + val inputs = buffer.filter(_ != null).map(_.asInstanceOf[UTF8String]) + UTF8String.fromString(inputs.mkString(delimiterValue.toString)) } } - override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { - if (buffer.nonEmpty) { - val sortedBufferWithoutNulls = sortBuffer(buffer) - concatSkippingNulls(sortedBufferWithoutNulls) + override def dataType: DataType = child.dataType + + private[this] def getDelimiterValue: Any = { + val delimiterValue = delimiter.eval() + if (delimiterValue == null) { + // default delimiter value + dataType match { + case _: StringType => UTF8String.fromString("") + case _: BinaryType => ByteArray.EMPTY_BYTE + } } else { - null + delimiterValue } } @@ -471,20 +464,17 @@ case class ListAgg( buffer } - override def orderingFilled: Boolean = orderExpressions.nonEmpty + private[this] def evalOrderValues(internalRow: InternalRow): Seq[Any] = { + orderExpressions.map(order => convertToBufferElement(order.child.eval(internalRow))) + } + + override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) + override def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction = copy(orderExpressions = orderingWithinGroup) override def children: Seq[Expression] = child +: delimiter +: orderExpressions.map(_.child) - override protected def 
withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = - copy( - child = newChildren.head, - delimiter = newChildren(1), - orderExpressions = newChildren.drop(2).zip(orderExpressions) - .map { case (newExpr, oldSortOrder) => oldSortOrder.copy(child = newExpr) } - ) - /** * Utility func to check if given order is defined and different from [[child]]. * @@ -500,4 +490,20 @@ case class ListAgg( } false } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy( + child = newChildren.head, + delimiter = newChildren(1), + orderExpressions = newChildren + .drop(2) + .zip(orderExpressions) + .map { case (newExpr, oldSortOrder) => oldSortOrder.copy(child = newExpr) } + ) + + private[this] def orderValuesField: Seq[StructField] = { + orderExpressions.zipWithIndex.map { + case (order, i) => StructField(s"sortOrderValue[$i]", order.dataType) + } + } } From 516567a0fe3e53ca90238735aa480a5d2d2226d2 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 5 Nov 2024 10:51:57 +0100 Subject: [PATCH 39/58] [SPARK-42746] listagg with columnName from tests --- .../src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index d40e0cbe55bca..40b66bcb8358d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -65,8 +65,6 @@ class FunctionTestSuite extends ConnectFunSuite { testEquals("avg/mean", avg("a"), avg(a), mean(a), mean("a")) testEquals("collect_list", collect_list("a"), collect_list(a)) testEquals("collect_set", collect_set("a"), collect_set(a)) - testEquals("listagg", listagg("a"), listagg(a)) - testEquals("listagg_distinct", listagg_distinct("a"), listagg_distinct(a)) testEquals("corr", corr("a", "b"), corr(a, b)) testEquals( "count_distinct", From 27f445d9723cb6221741007fa5d3aa5130e42860 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 5 Nov 2024 11:32:10 +0100 Subject: [PATCH 40/58] [SPARK-42746] fix java style --- .../java/org/apache/spark/unsafe/array/ByteArraySuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java index 0fb83088b49f9..5e221b4e359d4 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java @@ -101,7 +101,8 @@ public void testConcatWS() { byte[] x1 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; byte[] y1 = new byte[]{(byte) 4, (byte) 5, (byte) 6}; byte[] result1 = ByteArray.concatWS(separator, x1, y1); - byte[] expected1 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 42, (byte) 4, (byte) 5, (byte) 6}; + byte[] expected1 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 42, + (byte) 4, (byte) 5, (byte) 6}; Assertions.assertArrayEquals(expected1, result1); byte[] x2 = new byte[]{(byte) 1, (byte) 2, (byte) 3}; From fc722df737f516985e4e6ee949a7ab949beb5312 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 5 Nov 2024 17:55:57 +0100 Subject: [PATCH 41/58] [SPARK-42746] improve doc and errors --- 
.../resources/error/error-conditions.json    |  2 +-
 .../expressions/aggregate/collect.scala      | 24 +++++++++++++++----
 .../sql/errors/QueryCompilationErrors.scala  |  6 ++---
 .../org/apache/spark/sql/SQLQuerySuite.scala |  8 +++----
 4 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 3ff52af2f598e..f7003fd949fa9 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1565,7 +1565,7 @@
   },
   "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH" : {
     "message" : [
-      "The function <functionName> arguments <functionExpr> should match the order by expression <orderExpr> when use DISTINCT."
+      "The arguments <functionArgs> of the function <functionName> do not match the ordering within group <orderExpr> when DISTINCT is used."
     ],
     "sqlState" : "42822"
   },
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 44983159ec202..29807f3ce18e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -272,13 +272,27 @@ private[aggregate] object CollectTopK {
   }
 }
+// scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns the concatenation of non-null input values," +
-    " separated by the delimiter.",
+  usage = """
+    _FUNC_(expr[, delimiter])[ WITHIN GROUP (ORDER BY key [ASC | DESC] [,...])] - Returns
+    the concatenation of non-null input values, separated by the delimiter and ordered by key.
+    If all values are null, null is returned.
+    """,
+  arguments = """
+    Arguments:
+      * expr - a string or binary expression to be concatenated.
+      * delimiter - an optional string or binary foldable expression used to separate the input values.
+        If null, the concatenation will be performed without a delimiter. Default is null.
+      * key - an optional expression for ordering the input values. Multiple keys can be specified.
+        If none are specified, the order of the rows in the result is non-deterministic.
+    """,
   examples = """
     Examples:
       > SELECT _FUNC_(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col);
        abc
+      > SELECT _FUNC_(col) WITHIN GROUP (ORDER BY col DESC) FROM VALUES ('a'), ('b'), ('c') AS tab(col);
+       cba
       > SELECT _FUNC_(col) FROM VALUES ('a'), (NULL), ('b') AS tab(col);
        ab
       > SELECT _FUNC_(col) FROM VALUES ('a'), ('a') AS tab(col);
@@ -288,15 +302,17 @@ private[aggregate] object CollectTopK {
       > SELECT _FUNC_(col, ', ') FROM VALUES ('a'), ('b'), ('c') AS tab(col);
        a, b, c
       > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
-       NULL
+       NULL
   """,
   note = """
-    If order is not specified the function is non-deterministic because
+    * If the order is not specified, the function is non-deterministic because
     the order of the rows may be non-deterministic after a shuffle.
+    * If DISTINCT is specified, then expr and key must be the same expression.
""", group = "agg_funcs", since = "4.0.0" ) +// scalastyle:on line.size.limit case class ListAgg( child: Expression, delimiter: Expression = Literal(null), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 88581cade51ca..4a370bf824c37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1051,14 +1051,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def functionAndOrderExpressionMismatchError( functionName: String, - functionExpr: Expression, + functionArgs: Expression, orderExpr: Seq[SortOrder]): Throwable = { new AnalysisException( errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", messageParameters = Map( "functionName" -> toSQLId(functionName), - "functionExpr" -> toSQLExpr(functionExpr), - "orderExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(","))) + "functionArgs" -> toSQLExpr(functionArgs), + "orderExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(", "))) } def wrongCommandForObjectTypeError( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f00edfa03448d..a5f03297a56a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -346,7 +346,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", parameters = Map( "functionName" -> "`listagg`", - "functionExpr" -> "\"a\"", + "functionArgs" -> "\"a\"", "orderExpr" -> "\"b\""), context = ExpectedContext( fragment = "listagg(distinct a) within group (order by b)", @@ -360,8 +360,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", parameters = Map( "functionName" -> "`listagg`", - "functionExpr" -> "\"a\"", - "orderExpr" -> "\"a\",\"b\""), + "functionArgs" -> "\"a\"", + "orderExpr" -> "\"a\", \"b\""), context = ExpectedContext( fragment = "listagg(distinct a) within group (order by a, b)", start = 7, @@ -412,7 +412,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", parameters = Map( "functionName" -> "`listagg`", - "functionExpr" -> "\"collate(c1, utf8_lcase)\"", + "functionArgs" -> "\"collate(c1, utf8_lcase)\"", "orderExpr" -> "\"collate(c1, utf8_binary)\""), context = ExpectedContext( fragment = From e3b1a26c801f3383510c603fd9283c8f55609060 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 14 Nov 2024 14:17:15 +0100 Subject: [PATCH 42/58] [SPARK-42746] add golden files for listagg --- .../analyzer-results/listagg.sql.out | 506 ++++++++++++++++++ .../resources/sql-tests/inputs/listagg.sql | 48 ++ .../sql-tests/results/listagg.sql.out | 436 +++++++++++++++ 3 files changed, 990 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/listagg.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/listagg.sql.out diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out new file mode 100644 index 0000000000000..d199947c30a1e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out @@ -0,0 +1,506 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE TEMP VIEW df AS +SELECT * FROM (VALUES ('a', 'b'), ('a', 'c'), ('b', 'c'), ('b', 'd'), (NULL, NULL)) AS t(a, b) +-- !query analysis +CreateViewCommand `df`, SELECT * FROM (VALUES ('a', 'b'), ('a', 'c'), ('b', 'c'), ('b', 'd'), (NULL, NULL)) AS t(a, b), false, false, LocalTempView, UNSUPPORTED, true + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +CREATE TEMP VIEW df2 AS +SELECT * FROM (VALUES (1, true), (2, false), (3, false)) AS t(a, b) +-- !query analysis +CreateViewCommand `df2`, SELECT * FROM (VALUES (1, true), (2, false), (3, false)) AS t(a, b), false, false, LocalTempView, UNSUPPORTED, true + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(b) FROM df GROUP BY a +-- !query analysis +Aggregate [a#x], [listagg(b#x, null, 0, 0) AS listagg(b, NULL)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT string_agg(b) FROM df GROUP BY a +-- !query analysis +Aggregate [a#x], [string_agg(b#x, null, 0, 0) AS string_agg(b, NULL)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(b, NULL) FROM df GROUP BY a +-- !query analysis +Aggregate [a#x], [listagg(b#x, null, 0, 0) AS listagg(b, NULL)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(b) FROM df WHERE 1 != 1 +-- !query analysis +Aggregate [listagg(b#x, null, 0, 0) AS listagg(b, NULL)#x] ++- Filter NOT (1 = 1) + +- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(b, '|') FROM df GROUP BY a +-- !query analysis +Aggregate [a#x], [listagg(b#x, |, 0, 0) AS listagg(b, |)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) FROM df +-- !query analysis +Aggregate [listagg(a#x, null, 0, 0) AS listagg(a, NULL)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(DISTINCT a) FROM df +-- !query analysis +Aggregate [listagg(distinct 
a#x, null, 0, 0) AS listagg(DISTINCT a, NULL)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a) FROM df +-- !query analysis +Aggregate [listagg(a#x, null, a#x ASC NULLS FIRST, 0, 0) AS listagg(a, NULL) WITHIN GROUP (ORDER BY a ASC NULLS FIRST)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a DESC) FROM df +-- !query analysis +Aggregate [listagg(a#x, null, a#x DESC NULLS LAST, 0, 0) AS listagg(a, NULL) WITHIN GROUP (ORDER BY a DESC NULLS LAST)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a DESC) OVER (PARTITION BY b) FROM df +-- !query analysis +Project [listagg(a, NULL) WITHIN GROUP (ORDER BY a DESC NULLS LAST) OVER (PARTITION BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x] ++- Project [a#x, b#x, listagg(a, NULL) WITHIN GROUP (ORDER BY a DESC NULLS LAST) OVER (PARTITION BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, listagg(a, NULL) WITHIN GROUP (ORDER BY a DESC NULLS LAST) OVER (PARTITION BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x] + +- Window [listagg(a#x, null, a#x DESC NULLS LAST, 0, 0) windowspecdefinition(b#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS listagg(a, NULL) WITHIN GROUP (ORDER BY a DESC NULLS LAST) OVER (PARTITION BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x], [b#x] + +- Project [a#x, b#x] + +- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b) FROM df +-- !query analysis +Aggregate [listagg(a#x, null, b#x ASC NULLS FIRST, 0, 0) AS listagg(a, NULL) WITHIN GROUP (ORDER BY b ASC NULLS FIRST)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC) FROM df +-- !query analysis +Aggregate [listagg(a#x, null, b#x DESC NULLS LAST, 0, 0) AS listagg(a, NULL) WITHIN GROUP (ORDER BY b DESC NULLS LAST)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a, '|') WITHIN GROUP (ORDER BY b DESC) FROM df +-- !query analysis +Aggregate [listagg(a#x, |, b#x DESC NULLS LAST, 0, 0) AS listagg(a, |) WITHIN GROUP (ORDER BY b DESC NULLS LAST)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- 
Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC, a ASC) FROM df +-- !query analysis +Aggregate [listagg(a#x, null, b#x DESC NULLS LAST, a#x ASC NULLS FIRST, 0, 0) AS listagg(a, NULL) WITHIN GROUP (ORDER BY b DESC NULLS LAST, a ASC NULLS FIRST)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC, a DESC) FROM df +-- !query analysis +Aggregate [listagg(a#x, null, b#x DESC NULLS LAST, a#x DESC NULLS LAST, 0, 0) AS listagg(a, NULL) WITHIN GROUP (ORDER BY b DESC NULLS LAST, a DESC NULLS LAST)#x] ++- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(c1) FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, 0, 0) AS listagg(c1, NULL)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(c1, NULL) FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, 0, 0) AS listagg(c1, NULL)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(c1, X'42') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, 0x42, 0, 0) AS listagg(c1, X'42')#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(a), listagg(b, ',') FROM df2 +-- !query analysis +Aggregate [listagg(cast(a#x as string), null, 0, 0) AS listagg(a, NULL)#x, listagg(cast(b#x as string), ,, 0, 0) AS listagg(b, ,)#x] ++- SubqueryAlias df2 + +- View (`df2`, [a#x, b#x]) + +- Project [cast(a#x as int) AS a#x, cast(b#x as boolean) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(c1) FROM (VALUES (ARRAY['a', 'b'])) AS t(c1) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "','", + "hint" : "" + } +} + + +-- !query +SELECT listagg(c1, ', ') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "sqlState" : "42K09", + "messageParameters" : { + "dataType" : "(\"BINARY\" or \"STRING\")", + "functionName" : "`listagg`", + "sqlExpr" : "\"listagg(c1, , )\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "listagg(c1, ', ')" + } ] +} + + +-- !query +SELECT listagg(b, a) FROM df GROUP BY a +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"a\"", + "inputName" : "`delimiter`", + "inputType" : "\"STRING\"", + "sqlExpr" 
: "\"listagg(b, a)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 20, + "fragment" : "listagg(b, a)" + } ] +} + + +-- !query +SELECT listagg(a) OVER (ORDER BY a) FROM df +-- !query analysis +Project [listagg(a, NULL) OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x] ++- Project [a#x, listagg(a, NULL) OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x, listagg(a, NULL) OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x] + +- Window [listagg(a#x, null, 0, 0) windowspecdefinition(a#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS listagg(a, NULL) OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x], [a#x ASC NULLS FIRST] + +- Project [a#x] + +- SubqueryAlias df + +- View (`df`, [a#x, b#x]) + +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias t + +- Project [col1#x AS a#x, col2#x AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", + "sqlState" : "42601", + "messageParameters" : { + "aggFunc" : "\"listagg(a, NULL, a)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 61, + "fragment" : "listagg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a)" + } ] +} + + +-- !query +SELECT string_agg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", + "sqlState" : "42601", + "messageParameters" : { + "aggFunc" : "\"listagg(a, NULL, a)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 64, + "fragment" : "string_agg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a)" + } ] +} + + +-- !query +SELECT listagg(DISTINCT a) OVER (ORDER BY a) FROM df +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED", + "sqlState" : "0A000", + "messageParameters" : { + "windowExpr" : "\"listagg(DISTINCT a, NULL) OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 44, + "fragment" : "listagg(DISTINCT a) OVER (ORDER BY a)" + } ] +} + + +-- !query +SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + "sqlState" : "42822", + "messageParameters" : { + "functionArgs" : "\"a\"", + "functionName" : "`listagg`", + "orderExpr" : "\"b\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 52, + "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY b)" + } ] +} + + +-- !query +SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + "sqlState" : "42822", + "messageParameters" : { + "functionArgs" : "\"a\"", + "functionName" : "`listagg`", + 
"orderExpr" : "\"a\", \"b\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 55, + "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b)" + } ] +} + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, collate(c1#x, utf8_binary) ASC NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_binary) ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(distinct collate(c1#x, utf8_binary), null, 0, 0) AS listagg(DISTINCT collate(c1, utf8_binary), NULL)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, 0, 0) AS listagg(DISTINCT collate(c1, utf8_lcase), NULL)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1) +-- !query analysis +Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + "sqlState" : "42822", + "messageParameters" : { + "functionArgs" : "\"collate(c1, utf8_lcase)\"", + "functionName" : "`listagg`", + "orderExpr" : "\"collate(c1, utf8_binary)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 93, + "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg.sql new file mode 100644 index 0000000000000..05f02961425fc --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/listagg.sql @@ -0,0 +1,48 @@ +-- Create temporary views +CREATE TEMP VIEW df AS +SELECT * FROM (VALUES ('a', 'b'), ('a', 'c'), ('b', 'c'), ('b', 'd'), (NULL, NULL)) AS t(a, b); + +CREATE TEMP VIEW df2 AS +SELECT * FROM (VALUES (1, true), (2, false), (3, false)) AS t(a, b); + +-- Test cases for listagg function +SELECT listagg(b) FROM df GROUP BY a; +SELECT string_agg(b) FROM df GROUP BY a; +SELECT listagg(b, NULL) FROM df GROUP BY a; +SELECT 
listagg(b) FROM df WHERE 1 != 1; +SELECT listagg(b, '|') FROM df GROUP BY a; +SELECT listagg(a) FROM df; +SELECT listagg(DISTINCT a) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY a) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY a DESC) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY a DESC) OVER (PARTITION BY b) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY b) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC) FROM df; +SELECT listagg(a, '|') WITHIN GROUP (ORDER BY b DESC) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC, a ASC) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC, a DESC) FROM df; +SELECT listagg(c1) FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1); +SELECT listagg(c1, NULL) FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1); +SELECT listagg(c1, X'42') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1); +SELECT listagg(a), listagg(b, ',') FROM df2; + +-- Error cases +SELECT listagg(c1) FROM (VALUES (ARRAY['a', 'b'])) AS t(c1); +SELECT listagg(c1, ', ') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1); +SELECT listagg(b, a) FROM df GROUP BY a; +SELECT listagg(a) OVER (ORDER BY a) FROM df; +SELECT listagg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df; +SELECT string_agg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df; +SELECT listagg(DISTINCT a) OVER (ORDER BY a) FROM df; +SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df; +SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df; + +-- Test cases with collations +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1); + +-- Error case with collations +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out new file mode 100644 index 0000000000000..39c7d1ce7a6c8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out @@ -0,0 +1,436 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE TEMP VIEW df AS +SELECT * FROM (VALUES ('a', 'b'), ('a', 'c'), ('b', 'c'), ('b', 'd'), (NULL, NULL)) AS t(a, b) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TEMP VIEW df2 AS +SELECT * FROM (VALUES (1, true), (2, false), (3, false)) AS t(a, b) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT listagg(b) FROM df GROUP BY a +-- !query schema +struct +-- !query output +NULL +bc +cd + + +-- !query +SELECT string_agg(b) FROM df GROUP BY a +-- !query schema +struct +-- !query output +NULL +bc +cd + + +-- !query +SELECT listagg(b, NULL) FROM df GROUP BY a +-- !query schema +struct +-- !query output +NULL +bc +cd + + +-- !query +SELECT listagg(b) FROM df WHERE 1 != 1 +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT listagg(b, '|') FROM df GROUP BY a +-- !query schema +struct +-- !query output +NULL +b|c +c|d + + +-- !query 
+SELECT listagg(a) FROM df +-- !query schema +struct +-- !query output +aabb + + +-- !query +SELECT listagg(DISTINCT a) FROM df +-- !query schema +struct +-- !query output +ab + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a) FROM df +-- !query schema +struct +-- !query output +aabb + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a DESC) FROM df +-- !query schema +struct +-- !query output +bbaa + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a DESC) OVER (PARTITION BY b) FROM df +-- !query schema +struct +-- !query output +NULL +a +b +ba +ba + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b) FROM df +-- !query schema +struct +-- !query output +aabb + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC) FROM df +-- !query schema +struct +-- !query output +baba + + +-- !query +SELECT listagg(a, '|') WITHIN GROUP (ORDER BY b DESC) FROM df +-- !query schema +struct +-- !query output +b|a|b|a + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC, a ASC) FROM df +-- !query schema +struct +-- !query output +baba + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY b DESC, a DESC) FROM df +-- !query schema +struct +-- !query output +bbaa + + +-- !query +SELECT listagg(c1) FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query schema +struct +-- !query output +ޭ�� + + +-- !query +SELECT listagg(c1, NULL) FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query schema +struct +-- !query output +ޭ�� + + +-- !query +SELECT listagg(c1, X'42') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query schema +struct +-- !query output +ޭB�� + + +-- !query +SELECT listagg(a), listagg(b, ',') FROM df2 +-- !query schema +struct +-- !query output +123 true,false,false + + +-- !query +SELECT listagg(c1) FROM (VALUES (ARRAY['a', 'b'])) AS t(c1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "','", + "hint" : "" + } +} + + +-- !query +SELECT listagg(c1, ', ') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "sqlState" : "42K09", + "messageParameters" : { + "dataType" : "(\"BINARY\" or \"STRING\")", + "functionName" : "`listagg`", + "sqlExpr" : "\"listagg(c1, , )\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "listagg(c1, ', ')" + } ] +} + + +-- !query +SELECT listagg(b, a) FROM df GROUP BY a +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"a\"", + "inputName" : "`delimiter`", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"listagg(b, a)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 20, + "fragment" : "listagg(b, a)" + } ] +} + + +-- !query +SELECT listagg(a) OVER (ORDER BY a) FROM df +-- !query schema +struct +-- !query output +NULL +aa +aa +aabb +aabb + + +-- !query +SELECT listagg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", + "sqlState" : "42601", + 
"messageParameters" : { + "aggFunc" : "\"listagg(a, NULL, a)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 61, + "fragment" : "listagg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a)" + } ] +} + + +-- !query +SELECT string_agg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", + "sqlState" : "42601", + "messageParameters" : { + "aggFunc" : "\"listagg(a, NULL, a)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 64, + "fragment" : "string_agg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a)" + } ] +} + + +-- !query +SELECT listagg(DISTINCT a) OVER (ORDER BY a) FROM df +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED", + "sqlState" : "0A000", + "messageParameters" : { + "windowExpr" : "\"listagg(DISTINCT a, NULL) OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 44, + "fragment" : "listagg(DISTINCT a) OVER (ORDER BY a)" + } ] +} + + +-- !query +SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + "sqlState" : "42822", + "messageParameters" : { + "functionArgs" : "\"a\"", + "functionName" : "`listagg`", + "orderExpr" : "\"b\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 52, + "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY b)" + } ] +} + + +-- !query +SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + "sqlState" : "42822", + "messageParameters" : { + "functionArgs" : "\"a\"", + "functionName" : "`listagg`", + "orderExpr" : "\"a\", \"b\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 55, + "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b)" + } ] +} + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +ABab + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +aAbB + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +aAbB + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +ab + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1) +-- !query schema +struct +-- !query output +aB + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) +-- !query schema 
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH",
+  "sqlState" : "42822",
+  "messageParameters" : {
+    "functionArgs" : "\"collate(c1, utf8_lcase)\"",
+    "functionName" : "`listagg`",
+    "orderExpr" : "\"collate(c1, utf8_binary)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 93,
+    "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)"
+  } ]
+}

From ad49fcf6abd1fcf646b069692b98daac731fdd5b Mon Sep 17 00:00:00 2001
From: Mikhail Nikoliukin
Date: Thu, 14 Nov 2024 17:27:15 +0100
Subject: [PATCH 43/58] [SPARK-42746] remove InverseDistributionFunction

---
 .../resources/error/error-conditions.json     | 46 +++++++++----------
 .../sql/catalyst/analysis/Analyzer.scala      | 21 ++++-----
 .../InverseDistributionFunction.scala         | 26 -----------
 .../catalyst/expressions/aggregate/Mode.scala |  6 ++-
 .../SupportsOrderingWithinGroup.scala         | 23 ++++++++--
 .../expressions/aggregate/collect.scala       |  2 +
 .../expressions/aggregate/percentiles.scala   | 16 +++++--
 .../sql/errors/QueryCompilationErrors.scala   | 22 +++++----
 .../sql-tests/analyzer-results/mode.sql.out   |  8 ++--
 .../analyzer-results/percentiles.sql.out      | 10 ++--
 .../resources/sql-tests/results/mode.sql.out  |  8 ++--
 .../sql-tests/results/percentiles.sql.out     | 10 ++--
 12 files changed, 102 insertions(+), 96 deletions(-)
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index f7003fd949fa9..ce76b2fbcebf4 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2613,29 +2613,6 @@
     },
     "sqlState" : "22006"
   },
-  "INVALID_INVERSE_DISTRIBUTION_FUNCTION" : {
-    "message" : [
-      "Invalid inverse distribution function <funcName>."
-    ],
-    "subClass" : {
-      "DISTINCT_UNSUPPORTED" : {
-        "message" : [
-          "Cannot use DISTINCT with WITHIN GROUP."
-        ]
-      },
-      "WITHIN_GROUP_MISSING" : {
-        "message" : [
-          "WITHIN GROUP is required for inverse distribution function <funcName>."
-        ]
-      },
-      "WRONG_NUM_ORDERINGS" : {
-        "message" : [
-          "Requires <expectedNum> orderings in WITHIN GROUP but got <actualNum>."
-        ]
-      }
-    },
-    "sqlState" : "42K0K"
-  },
   "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME" : {
     "message" : [
       "<fieldName> is not a valid identifier of Java and cannot be used as field name",
@@ -3338,6 +3315,29 @@
     ],
     "sqlState" : "42601"
   },
+  "INVALID_WITHIN_GROUP_EXPRESSION" : {
+    "message" : [
+      "Invalid function <funcName> with WITHIN GROUP."
+    ],
+    "subClass" : {
+      "DISTINCT_UNSUPPORTED" : {
+        "message" : [
+          "The function <funcName> does not support DISTINCT with WITHIN GROUP."
+        ]
+      },
+      "WITHIN_GROUP_MISSING" : {
+        "message" : [
+          "WITHIN GROUP is required for the function <funcName>."
+        ]
+      },
+      "WRONG_NUM_ORDERINGS" : {
+        "message" : [
+          "The function <funcName> requires <expectedNum> orderings in WITHIN GROUP but got <actualNum>."
+        ]
+      }
+    },
+    "sqlState" : "42K0K"
+  },
   "INVALID_WRITER_COMMIT_MESSAGE" : {
     "message" : [
       "The data source writer <writer> has generated an invalid number of commit messages. Expected exactly one writer commit message from each task, but received <detail>."
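
The renamed subconditions correspond one-to-one to the checks the Analyzer now performs through the SupportsOrderingWithinGroup flags introduced below. As a rough sketch (not part of this patch), the following queries should trigger each subcondition; it assumes a local SparkSession bound to `spark`, and the inline data is made up for illustration:

import org.apache.spark.sql.AnalysisException

// Each query is expected to fail analysis with one of the renamed
// INVALID_WITHIN_GROUP_EXPRESSION subconditions.
Seq(
  // WITHIN_GROUP_MISSING: percentile_cont declares isOrderingMandatory = true.
  "SELECT percentile_cont(0.5) FROM VALUES (1), (2) AS t(a)",
  // DISTINCT_UNSUPPORTED: mode declares isDistinctSupported = false.
  "SELECT mode(DISTINCT a) FROM VALUES (1), (2) AS t(a)",
  // WRONG_NUM_ORDERINGS: percentile_cont accepts exactly one ordering.
  "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY a, a) FROM VALUES (1), (2) AS t(a)"
).foreach { query =>
  try {
    spark.sql(query).collect()
  } catch {
    case e: AnalysisException => println(s"${e.getErrorClass}: $query")
  }
}

// listagg, by contrast, declares isDistinctSupported = true and an optional
// ordering, so both of these pass analysis:
spark.sql("SELECT listagg(DISTINCT a) FROM VALUES ('x'), ('y'), ('x') AS t(a)").show()
spark.sql("SELECT listagg(a) WITHIN GROUP (ORDER BY a DESC) FROM VALUES ('x'), ('y') AS t(a)").show()
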
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f534eba2b1e7f..957541545b3cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2216,16 +2216,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor numArgs: Int, u: UnresolvedFunction): Expression = { func match { - case owg: InverseDistributionFunction if u.isDistinct => - throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError( - owg.prettyName) - case owg: InverseDistributionFunction - if !owg.orderingFilled && u.orderingWithinGroup.isEmpty => - throw QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError( - owg.prettyName) - case owg: InverseDistributionFunction - if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => - throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( + case owg: SupportsOrderingWithinGroup if !owg.isDistinctSupported && u.isDistinct => + throw QueryCompilationErrors.distinctWithOrderingFunctionUnsupportedError( + owg.prettyName) + case owg: SupportsOrderingWithinGroup + if owg.isOrderingMandatory && !owg.orderingFilled && u.orderingWithinGroup.isEmpty => + throw QueryCompilationErrors.functionMissingWithinGroupError(owg.prettyName) + case owg: Mode if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => + // mode(expr1) within group (order by expr2) is not supported + throw QueryCompilationErrors.wrongNumOrderingsForFunctionError( owg.prettyName, 0, u.orderingWithinGroup.length) case f if !f.isInstanceOf[SupportsOrderingWithinGroup] && u.orderingWithinGroup.nonEmpty => @@ -2277,7 +2276,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case agg: AggregateFunction => // Note: PythonUDAF does not support these advanced clauses. if (agg.isInstanceOf[PythonUDAF]) checkUnsupportedAggregateClause(agg, u) - // After parse, the inverse distribution functions not set the ordering within group yet. + // After parse, the functions not set the ordering within group yet. val newAgg = agg match { case owg: SupportsOrderingWithinGroup if !owg.orderingFilled && u.orderingWithinGroup.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala deleted file mode 100644 index 7e4c028f89a10..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/InverseDistributionFunction.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -/** - * The trait used to set the [[SortOrder]] after inverse distribution functions parsed. - * Order clause is mandatory for all extenders. - */ -trait InverseDistributionFunction - extends SupportsOrderingWithinGroup { self: AggregateFunction => -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index 7af4d5668d719..f3eeaa96b3d46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -37,7 +37,7 @@ case class Mode( inputAggBufferOffset: Int = 0, reverseOpt: Option[Boolean] = None) extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes - with InverseDistributionFunction with UnaryLike[Expression] { + with SupportsOrderingWithinGroup with UnaryLike[Expression] { def this(child: Expression) = this(child, 0, 0) @@ -183,6 +183,8 @@ case class Mode( } override def orderingFilled: Boolean = child != UnresolvedWithinGroup + override def isOrderingMandatory: Boolean = true + override def isDistinctSupported: Boolean = false assert(orderingFilled || (!orderingFilled && reverseOpt.isEmpty)) @@ -190,7 +192,7 @@ case class Mode( child match { case UnresolvedWithinGroup => if (orderingWithinGroup.length != 1) { - throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( + throw QueryCompilationErrors.wrongNumOrderingsForFunctionError( nodeName, 1, orderingWithinGroup.length) } orderingWithinGroup.head match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala index f9b6aebf909bd..a4bcb6185fffa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala @@ -21,10 +21,27 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder /** * The trait used to set the [[SortOrder]] for supporting functions. - * By default ordering is optional. */ trait SupportsOrderingWithinGroup { def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction - /** Indicator that ordering was set */ - def orderingFilled: Boolean = false + + /** Indicator that ordering was set. */ + def orderingFilled: Boolean + + /** + * Tells Analyzer that WITHIN GROUP (ORDER BY ...) is mandatory for function. + * + * @see [[QueryCompilationErrors.functionMissingWithinGroupError]] + * @see [[org.apache.spark.sql.catalyst.analysis.Analyzer]] + */ + def isOrderingMandatory: Boolean + + /** + * Tells Analyzer that DISTINCT is supported. + * The DISTINCT can conflict with order so some functions can ban it. 
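+   * For example, [[ListAgg]] keeps DISTINCT supported, while [[Mode]],
+   * [[PercentileCont]] and [[PercentileDisc]] do not.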
+ * + * @see [[QueryCompilationErrors.functionMissingWithinGroupError]] + * @see [[org.apache.spark.sql.catalyst.analysis.Analyzer]] + */ + def isDistinctSupported: Boolean } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 29807f3ce18e6..098fcb361e6b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -323,6 +323,8 @@ case class ListAgg( with SupportsOrderingWithinGroup with ImplicitCastInputTypes { + override def isOrderingMandatory: Boolean = false + override def isDistinctSupported: Boolean = true override protected lazy val bufferElementType: DataType = { if (noNeedSaveOrderValue) { child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala index 25f37385ef97d..6dfa1b499df23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala @@ -360,7 +360,7 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean extends AggregateFunction with RuntimeReplaceableAggregate with ImplicitCastInputTypes - with InverseDistributionFunction + with SupportsOrderingWithinGroup with BinaryLike[Expression] { private lazy val percentile = new Percentile(left, right, reverse) override lazy val replacement: Expression = percentile @@ -378,7 +378,7 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean override def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction = { if (orderingWithinGroup.length != 1) { - throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( + throw QueryCompilationErrors.wrongNumOrderingsForFunctionError( nodeName, 1, orderingWithinGroup.length) } orderingWithinGroup.head match { @@ -390,6 +390,10 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): PercentileCont = this.copy(left = newLeft, right = newRight) + + override def orderingFilled: Boolean = left != UnresolvedWithinGroup + override def isOrderingMandatory: Boolean = true + override def isDistinctSupported: Boolean = false } /** @@ -407,7 +411,7 @@ case class PercentileDisc( mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0, legacyCalculation: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_PERCENTILE_DISC_CALCULATION)) - extends PercentileBase with InverseDistributionFunction with BinaryLike[Expression] { + extends PercentileBase with SupportsOrderingWithinGroup with BinaryLike[Expression] { val frequencyExpression: Expression = Literal(1L) @@ -432,7 +436,7 @@ case class PercentileDisc( override def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction = { if (orderingWithinGroup.length != 1) { - throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( + throw QueryCompilationErrors.wrongNumOrderingsForFunctionError( nodeName, 1, orderingWithinGroup.length) } orderingWithinGroup.head match { @@ -467,6 
+471,10 @@ case class PercentileDisc( toDoubleValue(higherKey) } } + + override def orderingFilled: Boolean = left != UnresolvedWithinGroup + override def isOrderingMandatory: Boolean = true + override def isDistinctSupported: Boolean = false } // scalastyle:off line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 4a370bf824c37..0ce8ce221dc94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -725,28 +725,32 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "windowExpr" -> toSQLExpr(windowExpr))) } - def distinctInverseDistributionFunctionUnsupportedError(funcName: String): Throwable = { + def distinctWithOrderingFunctionUnsupportedError(funcName: String): Throwable = { new AnalysisException( - errorClass = "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", - messageParameters = Map("funcName" -> toSQLId(funcName))) + errorClass = "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", + messageParameters = Map("funcName" -> toSQLId(funcName)) + ) } - def inverseDistributionFunctionMissingWithinGroupError(funcName: String): Throwable = { + def functionMissingWithinGroupError(funcName: String): Throwable = { new AnalysisException( - errorClass = "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WITHIN_GROUP_MISSING", - messageParameters = Map("funcName" -> toSQLId(funcName))) + errorClass = "INVALID_WITHIN_GROUP_EXPRESSION.WITHIN_GROUP_MISSING", + messageParameters = Map("funcName" -> toSQLId(funcName)) + ) } - def wrongNumOrderingsForInverseDistributionFunctionError( + def wrongNumOrderingsForFunctionError( funcName: String, validOrderingsNumber: Int, actualOrderingsNumber: Int): Throwable = { new AnalysisException( - errorClass = "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WRONG_NUM_ORDERINGS", + errorClass = "INVALID_WITHIN_GROUP_EXPRESSION.WRONG_NUM_ORDERINGS", messageParameters = Map( "funcName" -> toSQLId(funcName), "expectedNum" -> validOrderingsNumber.toString, - "actualNum" -> actualOrderingsNumber.toString)) + "actualNum" -> actualOrderingsNumber.toString + ) + ) } def aliasNumberNotMatchColumnNumberError( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/mode.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/mode.sql.out index d6ecbc72a7178..8028c344140f5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/mode.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/mode.sql.out @@ -74,7 +74,7 @@ SELECT department, mode(DISTINCT salary) FROM basic_pays GROUP BY department ORD -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`mode`" @@ -379,7 +379,7 @@ FROM basic_pays -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`mode`" @@ -401,7 +401,7 @@ FROM basic_pays -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : 
"INVALID_INVERSE_DISTRIBUTION_FUNCTION.WITHIN_GROUP_MISSING", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WITHIN_GROUP_MISSING", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`mode`" @@ -423,7 +423,7 @@ FROM basic_pays -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WRONG_NUM_ORDERINGS", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WRONG_NUM_ORDERINGS", "sqlState" : "42K0K", "messageParameters" : { "actualNum" : "1", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/percentiles.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/percentiles.sql.out index 4a31cff8c7d0f..eb8102afa47ef 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/percentiles.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/percentiles.sql.out @@ -248,7 +248,7 @@ FROM aggr -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -270,7 +270,7 @@ FROM aggr -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -342,7 +342,7 @@ FROM aggr -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WITHIN_GROUP_MISSING", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WITHIN_GROUP_MISSING", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -364,7 +364,7 @@ FROM aggr -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WITHIN_GROUP_MISSING", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WITHIN_GROUP_MISSING", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -386,7 +386,7 @@ FROM aggr -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WRONG_NUM_ORDERINGS", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WRONG_NUM_ORDERINGS", "sqlState" : "42K0K", "messageParameters" : { "actualNum" : "2", diff --git a/sql/core/src/test/resources/sql-tests/results/mode.sql.out b/sql/core/src/test/resources/sql-tests/results/mode.sql.out index ad7d59eeb1634..70f253066d4f9 100644 --- a/sql/core/src/test/resources/sql-tests/results/mode.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/mode.sql.out @@ -51,7 +51,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`mode`" @@ -373,7 +373,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`mode`" @@ -397,7 +397,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WITHIN_GROUP_MISSING", 
+ "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WITHIN_GROUP_MISSING", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`mode`" @@ -421,7 +421,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WRONG_NUM_ORDERINGS", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WRONG_NUM_ORDERINGS", "sqlState" : "42K0K", "messageParameters" : { "actualNum" : "1", diff --git a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out index cd95eee186e12..55aaa8ee7378e 100644 --- a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out @@ -222,7 +222,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -246,7 +246,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.DISTINCT_UNSUPPORTED", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.DISTINCT_UNSUPPORTED", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -324,7 +324,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WITHIN_GROUP_MISSING", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WITHIN_GROUP_MISSING", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -348,7 +348,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WITHIN_GROUP_MISSING", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WITHIN_GROUP_MISSING", "sqlState" : "42K0K", "messageParameters" : { "funcName" : "`percentile_cont`" @@ -372,7 +372,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_INVERSE_DISTRIBUTION_FUNCTION.WRONG_NUM_ORDERINGS", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.WRONG_NUM_ORDERINGS", "sqlState" : "42K0K", "messageParameters" : { "actualNum" : "2", From 0aca46cee938b2eb3ccdc0f2ce16a71bde79658e Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Fri, 15 Nov 2024 13:06:43 +0100 Subject: [PATCH 44/58] [SPARK-42746] fix golden file and small refactoring --- .../SupportsOrderingWithinGroup.scala | 4 +--- .../expressions/aggregate/collect.scala | 13 +++++----- .../analyzer-results/listagg.sql.out | 24 +++++++++++++------ .../resources/sql-tests/inputs/listagg.sql | 2 +- .../sql-tests/results/listagg.sql.out | 24 +++++++++++++------ 5 files changed, 43 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala index a4bcb6185fffa..453251ac61cde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/SupportsOrderingWithinGroup.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder /** * The trait used to set the [[SortOrder]] for supporting 
functions. */ -trait SupportsOrderingWithinGroup { +trait SupportsOrderingWithinGroup { self: AggregateFunction => def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction /** Indicator that ordering was set. */ @@ -32,7 +32,6 @@ trait SupportsOrderingWithinGroup { * Tells Analyzer that WITHIN GROUP (ORDER BY ...) is mandatory for function. * * @see [[QueryCompilationErrors.functionMissingWithinGroupError]] - * @see [[org.apache.spark.sql.catalyst.analysis.Analyzer]] */ def isOrderingMandatory: Boolean @@ -41,7 +40,6 @@ trait SupportsOrderingWithinGroup { * The DISTINCT can conflict with order so some functions can ban it. * * @see [[QueryCompilationErrors.functionMissingWithinGroupError]] - * @see [[org.apache.spark.sql.catalyst.analysis.Analyzer]] */ def isDistinctSupported: Boolean } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 098fcb361e6b7..8d7396f334ecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -323,8 +323,15 @@ case class ListAgg( with SupportsOrderingWithinGroup with ImplicitCastInputTypes { + override def orderingFilled: Boolean = orderExpressions.nonEmpty + override def isOrderingMandatory: Boolean = false + override def isDistinctSupported: Boolean = true + + override def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction = + copy(orderExpressions = orderingWithinGroup) + override protected lazy val bufferElementType: DataType = { if (noNeedSaveOrderValue) { child.dataType @@ -439,9 +446,6 @@ case class ListAgg( } new InterpretedOrdering(bufferSortOrder) } - - override def orderingFilled: Boolean = orderExpressions.nonEmpty - private[this] def concatSkippingNulls(buffer: mutable.ArrayBuffer[Any]): Any = { val delimiterValue = getDelimiterValue dataType match { @@ -488,9 +492,6 @@ case class ListAgg( override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) - override def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction = - copy(orderExpressions = orderingWithinGroup) - override def children: Seq[Expression] = child +: delimiter +: orderExpressions.map(_.child) /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out index d199947c30a1e..199d1329e0271 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out @@ -261,16 +261,26 @@ Aggregate [listagg(cast(a#x as string), null, 0, 0) AS listagg(a, NULL)#x, lista -- !query -SELECT listagg(c1) FROM (VALUES (ARRAY['a', 'b'])) AS t(c1) +SELECT listagg(c1) FROM (VALUES (ARRAY('a', 'b'))) AS t(c1) -- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "PARSE_SYNTAX_ERROR", - "sqlState" : "42601", + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", "messageParameters" : { - "error" : "','", - "hint" : "" - } + "inputSql" : "\"c1\"", + "inputType" : "\"ARRAY\"", + "paramIndex" : "first", + "requiredType" : "(\"STRING\" or \"BINARY\")", + "sqlExpr" : "\"listagg(c1, 
NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 18, + "fragment" : "listagg(c1)" + } ] } diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg.sql index 05f02961425fc..0cf49aae6a139 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/listagg.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/listagg.sql @@ -27,7 +27,7 @@ SELECT listagg(c1, X'42') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1); SELECT listagg(a), listagg(b, ',') FROM df2; -- Error cases -SELECT listagg(c1) FROM (VALUES (ARRAY['a', 'b'])) AS t(c1); +SELECT listagg(c1) FROM (VALUES (ARRAY('a', 'b'))) AS t(c1); SELECT listagg(c1, ', ') FROM (VALUES (X'DEAD'), (X'BEEF')) AS t(c1); SELECT listagg(b, a) FROM df GROUP BY a; SELECT listagg(a) OVER (ORDER BY a) FROM df; diff --git a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out index 39c7d1ce7a6c8..6b2f880f92955 100644 --- a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out @@ -182,18 +182,28 @@ struct -- !query -SELECT listagg(c1) FROM (VALUES (ARRAY['a', 'b'])) AS t(c1) +SELECT listagg(c1) FROM (VALUES (ARRAY('a', 'b'))) AS t(c1) -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "PARSE_SYNTAX_ERROR", - "sqlState" : "42601", + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", "messageParameters" : { - "error" : "','", - "hint" : "" - } + "inputSql" : "\"c1\"", + "inputType" : "\"ARRAY\"", + "paramIndex" : "first", + "requiredType" : "(\"STRING\" or \"BINARY\")", + "sqlExpr" : "\"listagg(c1, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 18, + "fragment" : "listagg(c1)" + } ] } From ca5b13aef27e75c4eee1308ec1363d414f021323 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 19 Nov 2024 14:34:18 +0100 Subject: [PATCH 45/58] [SPARK-42746] fix ThriftServerQueryTestSuite --- .../listagg-collations.sql.out | 66 ++++++++++++++++++ .../analyzer-results/listagg.sql.out | 67 ------------------- .../sql-tests/inputs/listagg-collations.sql | 9 +++ .../resources/sql-tests/inputs/listagg.sql | 12 +--- .../results/listagg-collations.sql.out | 63 +++++++++++++++++ .../sql-tests/results/listagg.sql.out | 64 ------------------ .../ThriftServerQueryTestSuite.scala | 1 + 7 files changed, 140 insertions(+), 142 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out new file mode 100644 index 0000000000000..b29f722f0fb9b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out @@ -0,0 +1,66 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, 
collate(c1#x, utf8_binary) ASC NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_binary) ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(distinct collate(c1#x, utf8_binary), null, 0, 0) AS listagg(DISTINCT collate(c1, utf8_binary), NULL)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query analysis +Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, 0, 0) AS listagg(DISTINCT collate(c1, utf8_lcase), NULL)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1) +-- !query analysis +Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + "sqlState" : "42822", + "messageParameters" : { + "functionArgs" : "\"collate(c1, utf8_lcase)\"", + "functionName" : "`listagg`", + "orderExpr" : "\"collate(c1, utf8_binary)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 93, + "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out index 199d1329e0271..a6f1821f5d3cb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out @@ -447,70 +447,3 @@ org.apache.spark.sql.AnalysisException "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b)" } ] } - - --- !query -SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query analysis -Aggregate [listagg(c1#x, null, collate(c1#x, utf8_binary) ASC NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_binary) ASC NULLS FIRST)#x] -+- SubqueryAlias t - +- Project [col1#x AS c1#x] - +- LocalRelation [col1#x] - - --- !query -SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query analysis -Aggregate [listagg(c1#x, null, collate(c1#x, utf8_lcase) ASC 
NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)#x] -+- SubqueryAlias t - +- Project [col1#x AS c1#x] - +- LocalRelation [col1#x] - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query analysis -Aggregate [listagg(distinct collate(c1#x, utf8_binary), null, 0, 0) AS listagg(DISTINCT collate(c1, utf8_binary), NULL)#x] -+- SubqueryAlias t - +- Project [col1#x AS c1#x] - +- LocalRelation [col1#x] - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query analysis -Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, 0, 0) AS listagg(DISTINCT collate(c1, utf8_lcase), NULL)#x] -+- SubqueryAlias t - +- Project [col1#x AS c1#x] - +- LocalRelation [col1#x] - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1) --- !query analysis -Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)#x] -+- SubqueryAlias t - +- Project [col1#x AS c1#x] - +- LocalRelation [col1#x] - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) --- !query analysis -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", - "messageParameters" : { - "functionArgs" : "\"collate(c1, utf8_lcase)\"", - "functionName" : "`listagg`", - "orderExpr" : "\"collate(c1, utf8_binary)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 93, - "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)" - } ] -} diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql new file mode 100644 index 0000000000000..8e608d940a0f9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql @@ -0,0 +1,9 @@ +-- Test cases with collations +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1); + +-- Error case with collations +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg.sql index 0cf49aae6a139..15c8cfa823e9b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/listagg.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/listagg.sql @@ -35,14 +35,4 @@ SELECT listagg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df; SELECT 
string_agg(a) WITHIN GROUP (ORDER BY a) OVER (ORDER BY a) FROM df; SELECT listagg(DISTINCT a) OVER (ORDER BY a) FROM df; SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df; -SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df; - --- Test cases with collations -SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); -SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); -SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1); - --- Error case with collations -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1); \ No newline at end of file +SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out new file mode 100644 index 0000000000000..2f1640def3adb --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out @@ -0,0 +1,63 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +ABab + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +aAbB + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +aAbB + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) +-- !query schema +struct +-- !query output +ab + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1) +-- !query schema +struct +-- !query output +aB + + +-- !query +SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + "sqlState" : "42822", + "messageParameters" : { + "functionArgs" : "\"collate(c1, utf8_lcase)\"", + "functionName" : "`listagg`", + "orderExpr" : "\"collate(c1, utf8_binary)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 93, + "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out index 6b2f880f92955..14804080a0871 100644 --- a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out @@ -380,67 +380,3 @@ org.apache.spark.sql.AnalysisException "fragment" : "listagg(DISTINCT a) 
WITHIN GROUP (ORDER BY a, b)" } ] } - - --- !query -SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query schema -struct --- !query output -ABab - - --- !query -SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query schema -struct --- !query output -aAbB - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query schema -struct --- !query output -aAbB - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) --- !query schema -struct --- !query output -ab - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1) --- !query schema -struct --- !query output -aB - - --- !query -SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", - "messageParameters" : { - "functionArgs" : "\"collate(c1, utf8_lcase)\"", - "functionName" : "`listagg`", - "orderExpr" : "\"collate(c1, utf8_binary)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 93, - "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)" - } ] -} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 782f549182ec2..6be4ee8f164b6 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -104,6 +104,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ "timestampNTZ/datetime-special-ansi.sql", // SPARK-47264 "collations.sql", + "listagg-collations.sql", "pipe-operators.sql", // VARIANT type "variant/named-function-arguments.sql" From 056ec618dec2c0f934731f22d2eae61f11c3ce8d Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 19 Nov 2024 15:23:24 +0100 Subject: [PATCH 46/58] [SPARK-42746] fix after merge --- .../analysis/FunctionResolution.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala index 5a27a72190325..38d00b015e074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala @@ -128,18 +128,14 @@ class FunctionResolution( numArgs: Int, u: UnresolvedFunction): Expression = { func match { - case owg: SupportsOrderingWithinGroup if u.isDistinct => - throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError( - owg.prettyName - ) - case owg: SupportsOrderingWithinGroup - if !owg.orderingFilled && u.orderingWithinGroup.isEmpty => - throw 
QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError( - owg.prettyName - ) + case owg: SupportsOrderingWithinGroup if !owg.isDistinctSupported && u.isDistinct => + throw QueryCompilationErrors.distinctWithOrderingFunctionUnsupportedError(owg.prettyName) case owg: SupportsOrderingWithinGroup - if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => - throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( + if owg.isOrderingMandatory && !owg.orderingFilled && u.orderingWithinGroup.isEmpty => + throw QueryCompilationErrors.functionMissingWithinGroupError(owg.prettyName) + case owg: Mode if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => + // mode(expr1) within group (order by expr2) is not supported + throw QueryCompilationErrors.wrongNumOrderingsForFunctionError( owg.prettyName, 0, u.orderingWithinGroup.length @@ -149,6 +145,10 @@ class FunctionResolution( func.prettyName, "WITHIN GROUP (ORDER BY ...)" ) + case listAgg: ListAgg + if u.isDistinct && !listAgg.isOrderCompatible(u.orderingWithinGroup) => + throw QueryCompilationErrors.functionAndOrderExpressionMismatchError( + listAgg.prettyName, listAgg.child, u.orderingWithinGroup) // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. @@ -198,7 +198,7 @@ class FunctionResolution( case agg: AggregateFunction => // Note: PythonUDAF does not support these advanced clauses. if (agg.isInstanceOf[PythonUDAF]) checkUnsupportedAggregateClause(agg, u) - // After parse, the inverse distribution functions not set the ordering within group yet. + // After parse, the functions not set the ordering within group yet. val newAgg = agg match { case owg: SupportsOrderingWithinGroup if !owg.orderingFilled && u.orderingWithinGroup.nonEmpty => From cb5ad3e0285c2b95f5242b7b1a775acaf395cfb6 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 21 Nov 2024 12:13:55 +0100 Subject: [PATCH 47/58] [SPARK-42746] add comments to sortBuffer --- .../expressions/aggregate/collect.scala | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 8d7396f334ecb..b3133a7905f99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -388,10 +388,9 @@ case class ListAgg( override def checkInputDataTypes(): TypeCheckResult = { val matchInputTypes = super.checkInputDataTypes() if (matchInputTypes.isFailure) { - return matchInputTypes - } - if (!delimiter.foldable) { - return DataTypeMismatch( + matchInputTypes + } else if (!delimiter.foldable) { + DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("delimiter"), @@ -399,8 +398,7 @@ case class ListAgg( "inputExpr" -> toSQLExpr(delimiter) ) ) - } - if (delimiter.dataType == NullType) { + } else if (delimiter.dataType == NullType) { // null is the default empty delimiter so type is not important TypeCheckSuccess } else { @@ -417,25 +415,44 @@ case class ListAgg( } } + /** + * Sort buffer according orderExpressions. + * If orderExpressions is empty them returns buffer as is. 
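+   * For illustration (hypothetical calls): listagg(a) WITHIN GROUP (ORDER BY a)
+   * can sort plain child values, while listagg(a) WITHIN GROUP (ORDER BY b)
+   * must sort [childValue, orderValue] rows and drop the order values afterwards.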
+ * The format of buffer is determined by [[noNeedSaveOrderValue]] + * @return sorted buffer containing only child's values + */ private[this] def sortBuffer(buffer: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { if (!orderingFilled) { + // without order return as is. return buffer } if (noNeedSaveOrderValue) { - val ascendingOrdering = PhysicalDataType.ordering(orderExpressions.head.dataType) + // Here the buffer has structure [childValue0, childValue1, ...] + // and we want to sort it by childValues + val sortOrderExpression = orderExpressions.head + val ascendingOrdering = PhysicalDataType.ordering(sortOrderExpression.dataType) val ordering = - if (orderExpressions.head.direction == Ascending) ascendingOrdering + if (sortOrderExpression.direction == Ascending) ascendingOrdering else ascendingOrdering.reverse buffer.sorted(ordering) } else { + // Here the buffer has structure + // [[childValue, orderValue0, orderValue1, ...], + // [childValue, orderValue0, orderValue1, ...], + // ...] + // and we want to sort it by tuples (orderValue0, orderValue1, ...) buffer .asInstanceOf[mutable.ArrayBuffer[InternalRow]] .sorted(bufferOrdering) - // drop order values after sort + // drop orderValues after sort .map(_.get(0, child.dataType)) } } + /** + * @return ordering by (orderValue0, orderValue1, ...) + * for InternalRow with format [childValue, orderValue0, orderValue1, ...] + */ private[this] def bufferOrdering: Ordering[InternalRow] = { val bufferSortOrder = orderExpressions.zipWithIndex.map { case (originalOrder, i) => @@ -446,6 +463,7 @@ case class ListAgg( } new InterpretedOrdering(bufferSortOrder) } + private[this] def concatSkippingNulls(buffer: mutable.ArrayBuffer[Any]): Any = { val delimiterValue = getDelimiterValue dataType match { From 07dfd82c14293d11b7fc8680982f501013088323 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 21 Nov 2024 12:15:27 +0100 Subject: [PATCH 48/58] [SPARK-42746] return SupportsOrderingWithinGroup check --- .../spark/sql/catalyst/analysis/FunctionResolution.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala index 38d00b015e074..ea3a946cc1cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala @@ -133,8 +133,9 @@ class FunctionResolution( case owg: SupportsOrderingWithinGroup if owg.isOrderingMandatory && !owg.orderingFilled && u.orderingWithinGroup.isEmpty => throw QueryCompilationErrors.functionMissingWithinGroupError(owg.prettyName) - case owg: Mode if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => - // mode(expr1) within group (order by expr2) is not supported + case owg: SupportsOrderingWithinGroup + if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => + // e.g mode(expr1) within group (order by expr2) is not supported throw QueryCompilationErrors.wrongNumOrderingsForFunctionError( owg.prettyName, 0, From be68e20aab749d278511869234568b4f15db4b8f Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 21 Nov 2024 12:15:57 +0100 Subject: [PATCH 49/58] [SPARK-42746] remove test duplicates --- .../org/apache/spark/sql/SQLQuerySuite.scala | 251 ------------------ 1 file changed, 251 deletions(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a5f03297a56a3..29eb5e0fed0b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -171,257 +171,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - test("listagg function") { - withTempView("df", "df2") { - Seq(("a", "b"), ("a", "c"), ("b", "c"), ("b", "d"), (null, null)).toDF("a", "b") - .createOrReplaceTempView("df") - checkAnswer( - sql("select listagg(b) from df group by a"), - Row(null) :: Row("bc") :: Row("cd") :: Nil) - - checkAnswer( - sql("select string_agg(b) from df group by a"), - Row(null) :: Row("bc") :: Row("cd") :: Nil) - - checkAnswer( - sql("select listagg(b, null) from df group by a"), - Row(null) :: Row("bc") :: Row("cd") :: Nil) - - checkAnswer( - sql("select listagg(b) from df where 1 != 1"), - Row(null) :: Nil) - - checkAnswer( - sql("select listagg(b, '|') from df group by a"), - Row("b|c") :: Row("c|d") :: Row(null) :: Nil) - - checkAnswer( - spark.sql("select listagg(b, :param || ' ') from df group by a", Map("param" -> ",")), - Row("b, c") :: Row("c, d") :: Row(null) :: Nil) - - checkAnswer( - sql("select listagg(a) from df"), - Row("aabb") :: Nil) - - checkAnswer( - sql("select listagg(distinct a) from df"), - Row("ab") :: Nil) - - checkAnswer( - sql("select listagg(a) within group (order by a) from df"), - Row("aabb") :: Nil) - - checkAnswer( - sql("select listagg(a) within group (order by a desc) from df"), - Row("bbaa") :: Nil) - - checkAnswer( - sql("""select listagg(a) within group (order by a desc) over (partition by b) from df"""), - Row("a") :: Row("ba") :: Row("ba") :: Row("b") :: Row(null) :: Nil) - - checkAnswer( - sql("select listagg(a) within group (order by b) from df"), - Row("aabb") :: Nil) - - checkAnswer( - sql("select listagg(a) within group (order by b desc) from df"), - Row("baba") :: Nil) - - checkAnswer( - sql("select listagg(a, '|') within group (order by b desc) from df"), - Row("b|a|b|a") :: Nil) - - checkAnswer( - sql("select listagg(a) within group (order by b desc, a asc) from df"), - Row("baba") :: Nil) - - checkAnswer( - sql("select listagg(a) within group (order by b desc, a desc) from df"), - Row("bbaa") :: Nil) - - checkAnswer( - sql("select listagg(c1)from values (X'DEAD'), (X'BEEF') as t(c1)"), - Row(hexToBytes("DEADBEEF")) :: Nil) - - checkAnswer( - sql("select listagg(c1, null)from values (X'DEAD'), (X'BEEF') as t(c1)"), - Row(hexToBytes("DEADBEEF")) :: Nil) - - checkAnswer( - sql("select listagg(c1, X'42')from values (X'DEAD'), (X'BEEF') as t(c1)"), - Row(hexToBytes("DEAD42BEEF")) :: Nil) - - checkError( - exception = intercept[AnalysisException] { - sql("select listagg(c1) from values (array('a', 'b')) as t(c1)") - }, - condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"listagg(c1, NULL)\"", - "paramIndex" -> "first", - "requiredType" -> "(\"STRING\" or \"BINARY\")", - "inputSql" -> "\"c1\"", - "inputType" -> "\"ARRAY\""), - context = ExpectedContext( - fragment = "listagg(c1)", - start = 7, - stop = 17)) - - checkError( - exception = intercept[AnalysisException] { - sql("select listagg(c1, ', ')from values (X'DEAD'), (X'BEEF') as t(c1)") - }, - condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", - parameters = Map( - "sqlExpr" -> "\"listagg(c1, , )\"", - "functionName" -> "`listagg`", - "dataType" -> 
"(\"BINARY\" or \"STRING\")"), - context = ExpectedContext( - fragment = "listagg(c1, ', ')", - start = 7, - stop = 23)) - - checkError( - exception = intercept[AnalysisException] { - sql("select listagg(b, a) from df group by a") - }, - condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", - parameters = Map( - "sqlExpr" -> "\"listagg(b, a)\"", - "inputName" -> "`delimiter`", - "inputType" -> "\"STRING\"", - "inputExpr" -> "\"a\""), - context = ExpectedContext( - fragment = "listagg(b, a)", - start = 7, - stop = 19)) - - checkAnswer( - sql("select listagg(a) over (order by a) from df"), - Row(null) :: Row("aa") :: Row("aa") :: Row("aabb") :: Row("aabb") :: Nil) - - checkError( - exception = intercept[AnalysisException] { - sql("select listagg(a) within group (order by a) over (order by a) from df") - }, - condition = "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", - parameters = Map("aggFunc" -> "\"listagg(a, NULL, a)\""), - context = ExpectedContext( - fragment = "listagg(a) within group (order by a) over (order by a)", - start = 7, - stop = 60)) - - checkError( - exception = intercept[AnalysisException] { - sql("select string_agg(a) within group (order by a) over (order by a) from df") - }, - condition = "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC", - parameters = Map("aggFunc" -> "\"listagg(a, NULL, a)\""), - context = ExpectedContext( - fragment = "string_agg(a) within group (order by a) over (order by a)", - start = 7, - stop = 63)) - - checkError( - exception = intercept[AnalysisException] { - sql("select listagg(distinct a) over (order by a) from df") - }, - condition = "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED", - parameters = Map("windowExpr" -> - ("\"listagg(DISTINCT a, NULL) " + - "OVER (ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)\"")), - context = ExpectedContext( - fragment = "listagg(distinct a) over (order by a)", - start = 7, - stop = 43)) - - checkAnswer( - sql("select listagg(distinct a) within group (order by a DESC) from df"), - Row("ba") :: Nil) - - checkError( - exception = intercept[AnalysisException] { - sql("select listagg(distinct a) within group (order by b) from df") - }, - condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - parameters = Map( - "functionName" -> "`listagg`", - "functionArgs" -> "\"a\"", - "orderExpr" -> "\"b\""), - context = ExpectedContext( - fragment = "listagg(distinct a) within group (order by b)", - start = 7, - stop = 51)) - - checkError( - exception = intercept[AnalysisException] { - sql("select listagg(distinct a) within group (order by a, b) from df") - }, - condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - parameters = Map( - "functionName" -> "`listagg`", - "functionArgs" -> "\"a\"", - "orderExpr" -> "\"a\", \"b\""), - context = ExpectedContext( - fragment = "listagg(distinct a) within group (order by a, b)", - start = 7, - stop = 54)) - - Seq((1, true), (2, false), (3, false)).toDF("a", "b").createOrReplaceTempView("df2") - - checkAnswer( - sql("select listagg(a), listagg(b, ',') from df2"), - Row("123", "true,false,false") :: Nil) - } - } - - test("listagg collation test") { - checkAnswer( - sql("select listagg(c1) within group (order by c1 collate utf8_binary)" + - " from values ('a'), ('A'), ('b'), ('B') as t(c1)"), - Row("ABab") :: Nil) - - checkAnswer( - sql("select listagg(c1) within group (order by c1 collate utf8_lcase)" + - " from values ('a'), ('A'), ('b'), ('B') as t(c1)"), - Row("aAbB") :: Nil) - - checkAnswer( - sql("select listagg(DISTINCT c1 collate utf8_binary)" + - " from values 
('a'), ('A'), ('b'), ('B') as t(c1)"), - Row("aAbB") :: Nil) - - checkAnswer( - sql("select listagg(DISTINCT c1 collate utf8_lcase)" + - " from values ('a'), ('A'), ('b'), ('B') as t(c1)"), - Row("ab") :: Nil) - - checkAnswer( - sql("select listagg(DISTINCT c1 collate utf8_lcase)" + - " within group (order by c1 collate utf8_lcase)" + - " from values ('a'), ('B'), ('b'), ('A') as t(c1)"), - Row("aB") :: Nil) - - checkError( - exception = intercept[AnalysisException] { - sql( - """select listagg(DISTINCT c1 collate utf8_lcase) - | within group (order by c1 collate utf8_binary) - | from values ('a'), ('b'), ('A'), ('B') as t(c1)""".stripMargin) - }, - condition = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - parameters = Map( - "functionName" -> "`listagg`", - "functionArgs" -> "\"collate(c1, utf8_lcase)\"", - "orderExpr" -> "\"collate(c1, utf8_binary)\""), - context = ExpectedContext( - fragment = - """listagg(DISTINCT c1 collate utf8_lcase) - | within group (order by c1 collate utf8_binary)""".stripMargin, - start = 7, - stop = 93)) - } - test("support table.star") { checkAnswer( sql( From 6a9c1fec5bf826e2f77ab9f6d54cfef71802a44e Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 21 Nov 2024 12:49:27 +0100 Subject: [PATCH 50/58] [SPARK-42746] move functionAndOrderExpressionMismatchError to CheckAnalysis --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +++++ .../analysis/FunctionResolution.scala | 4 ---- .../expressions/aggregate/collect.scala | 16 +++++++------- .../listagg-collations.sql.out | 11 ++-------- .../analyzer-results/listagg.sql.out | 22 ++++--------------- .../results/listagg-collations.sql.out | 11 ++-------- .../sql-tests/results/listagg.sql.out | 22 ++++--------------- 7 files changed, 25 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6efc6e9f68cee..4506e5bdaf2dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -423,6 +423,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB "funcName" -> toSQLExpr(wf), "windowExpr" -> toSQLExpr(w))) + case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _) + if agg.isDistinct && listAgg.needSaveOrderValue => + throw QueryCompilationErrors.functionAndOrderExpressionMismatchError( + listAgg.prettyName, listAgg.child, listAgg.orderExpressions) + case w: WindowExpression => // Only allow window functions with an aggregate expression or an offset window // function or a Pandas window UDF. 
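A quick sketch of what the relocated check accepts and rejects, reusing the df(a, b) fixture from the listagg tests (the row data is immaterial here):

-- accepted: the WITHIN GROUP ordering is the DISTINCT input itself
SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a DESC) FROM df;
-- rejected (FUNCTION_AND_ORDER_EXPRESSION_MISMATCH at this point in the series):
-- once a is deduplicated, a surviving value no longer maps to a single b
SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df;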
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala index ea3a946cc1cbe..800126e0030e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala @@ -146,10 +146,6 @@ class FunctionResolution( func.prettyName, "WITHIN GROUP (ORDER BY ...)" ) - case listAgg: ListAgg - if u.isDistinct && !listAgg.isOrderCompatible(u.orderingWithinGroup) => - throw QueryCompilationErrors.functionAndOrderExpressionMismatchError( - listAgg.prettyName, listAgg.child, u.orderingWithinGroup) // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index b3133a7905f99..1ace6d8ac1431 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -333,7 +333,7 @@ case class ListAgg( copy(orderExpressions = orderingWithinGroup) override protected lazy val bufferElementType: DataType = { - if (noNeedSaveOrderValue) { + if (!needSaveOrderValue) { child.dataType } else { StructType( @@ -342,8 +342,8 @@ case class ListAgg( ) } } - /** Indicates that the result of [[child]] is enough for evaluation */ - private lazy val noNeedSaveOrderValue: Boolean = isOrderCompatible(orderExpressions) + /** Indicates that the result of [[child]] is not enough for evaluation */ + lazy val needSaveOrderValue: Boolean = !isOrderCompatible(orderExpressions) def this(child: Expression) = this(child, Literal(null), Nil, 0, 0) @@ -418,7 +418,7 @@ case class ListAgg( /** * Sort buffer according orderExpressions. * If orderExpressions is empty them returns buffer as is. - * The format of buffer is determined by [[noNeedSaveOrderValue]] + * The format of buffer is determined by [[needSaveOrderValue]] * @return sorted buffer containing only child's values */ private[this] def sortBuffer(buffer: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { @@ -426,7 +426,7 @@ case class ListAgg( // without order return as is. return buffer } - if (noNeedSaveOrderValue) { + if (!needSaveOrderValue) { // Here the buffer has structure [childValue0, childValue1, ...] // and we want to sort it by childValues val sortOrderExpression = orderExpressions.head @@ -494,7 +494,7 @@ case class ListAgg( override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) if (value != null) { - val v = if (noNeedSaveOrderValue) { + val v = if (!needSaveOrderValue) { convertToBufferElement(value) } else { InternalRow.fromSeq(convertToBufferElement(value) +: evalOrderValues(input)) @@ -516,9 +516,9 @@ case class ListAgg( * Utility func to check if given order is defined and different from [[child]]. 
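   * e.g. for listagg(DISTINCT a) an ordering of a (ASC or DESC) is compatible,
   * while b, or the pair (a, b), is not (examples taken from the listagg tests).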
* * @see [[QueryCompilationErrors.functionAndOrderExpressionMismatchError]] - * @see [[noNeedSaveOrderValue]] + * @see [[needSaveOrderValue]] */ - def isOrderCompatible(someOrder: Seq[SortOrder]): Boolean = { + private[this] def isOrderCompatible(someOrder: Seq[SortOrder]): Boolean = { if (someOrder.isEmpty) { return true } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out index b29f722f0fb9b..60e885c4597ed 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out @@ -47,7 +47,7 @@ Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_ -- !query SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", "sqlState" : "42822", @@ -55,12 +55,5 @@ org.apache.spark.sql.AnalysisException "functionArgs" : "\"collate(c1, utf8_lcase)\"", "functionName" : "`listagg`", "orderExpr" : "\"collate(c1, utf8_binary)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 93, - "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)" - } ] + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out index a6f1821f5d3cb..9c4f18cb5b50f 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out @@ -408,7 +408,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", "sqlState" : "42822", @@ -416,21 +416,14 @@ org.apache.spark.sql.AnalysisException "functionArgs" : "\"a\"", "functionName" : "`listagg`", "orderExpr" : "\"b\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 52, - "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY b)" - } ] + } } -- !query SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", "sqlState" : "42822", @@ -438,12 +431,5 @@ org.apache.spark.sql.AnalysisException "functionArgs" : "\"a\"", "functionName" : "`listagg`", "orderExpr" : "\"a\", \"b\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 55, - "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b)" - } ] + } } diff --git a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out index 2f1640def3adb..136e2040987f7 100644 --- a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out @@ -44,7 +44,7 @@ SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", "sqlState" : "42822", @@ -52,12 +52,5 @@ org.apache.spark.sql.AnalysisException "functionArgs" : "\"collate(c1, utf8_lcase)\"", "functionName" : "`listagg`", "orderExpr" : "\"collate(c1, utf8_binary)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 93, - "fragment" : "listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary)" - } ] + } } diff --git a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out index 14804080a0871..033997350e877 100644 --- a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out @@ -339,7 +339,7 @@ SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", "sqlState" : "42822", @@ -347,14 +347,7 @@ org.apache.spark.sql.AnalysisException "functionArgs" : "\"a\"", "functionName" : "`listagg`", "orderExpr" : "\"b\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 52, - "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY b)" - } ] + } } @@ -363,7 +356,7 @@ SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.ExtendedAnalysisException { "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", "sqlState" : "42822", @@ -371,12 +364,5 @@ org.apache.spark.sql.AnalysisException "functionArgs" : "\"a\"", "functionName" : "`listagg`", "orderExpr" : "\"a\", \"b\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 55, - "fragment" : "listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b)" - } ] + } } From 0efedf3a213150629ad0336e91f97f48e435755f Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Fri, 22 Nov 2024 16:45:02 +0100 Subject: [PATCH 51/58] [SPARK-42746] FUNCTION_AND_ORDER_EXPRESSION_MISMATCH -> INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT --- .../main/resources/error/error-conditions.json | 11 +++++------ .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../expressions/aggregate/collect.scala | 4 ++-- .../sql/errors/QueryCompilationErrors.scala | 8 +++----- .../listagg-collations.sql.out | 9 ++++----- .../sql-tests/analyzer-results/listagg.sql.out | 18 ++++++++---------- .../results/listagg-collations.sql.out | 9 ++++----- .../sql-tests/results/listagg.sql.out | 18 ++++++++---------- .../org/apache/spark/sql/SQLQuerySuite.scala | 10 ---------- 9 files changed, 35 insertions(+), 54 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 3e4643c27f141..7c8b36e806030 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1565,12 +1565,6 
@@ ], "sqlState" : "42710" }, - "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH" : { - "message" : [ - "The arguments of the function do not match to ordering within group when use DISTINCT." - ], - "sqlState" : "42822" - }, "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : { "message" : [ "A column cannot have both a default value and a generation expression but column has default value: () and generation expression: ()." @@ -3357,6 +3351,11 @@ "The function does not support DISTINCT with WITHIN GROUP." ] }, + "MISMATCH_WITH_DISTINCT_INPUT": { + "message": [ + "Function is invoked with DISTINCT. The WITHIN GROUP ordering expressions must be picked from the function inputs, but got ." + ] + }, "WITHIN_GROUP_MISSING" : { "message" : [ "WITHIN GROUP is required for the function." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4506e5bdaf2dd..ce9c12a2d866f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -426,7 +426,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _) if agg.isDistinct && listAgg.needSaveOrderValue => throw QueryCompilationErrors.functionAndOrderExpressionMismatchError( - listAgg.prettyName, listAgg.child, listAgg.orderExpressions) + listAgg.prettyName, listAgg.orderExpressions) case w: WindowExpression => // Only allow window functions with an aggregate expression or an offset window diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 1ace6d8ac1431..f339879ef5981 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -417,7 +417,7 @@ case class ListAgg( /** * Sort buffer according orderExpressions. - * If orderExpressions is empty them returns buffer as is. + * If orderExpressions is empty then returns buffer as is. 
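+   * e.g. (hypothetical buffer) with ORDER BY b DESC the rows
+   * [["x", 1], ["y", 2]] sort to [["y", 2], ["x", 1]] and yield ["y", "x"].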
* The format of buffer is determined by [[needSaveOrderValue]] * @return sorted buffer containing only child's values */ @@ -472,7 +472,7 @@ case class ListAgg( ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*) case _: StringType => val inputs = buffer.filter(_ != null).map(_.asInstanceOf[UTF8String]) - UTF8String.fromString(inputs.mkString(delimiterValue.toString)) + UTF8String.concatWs(delimiterValue.asInstanceOf[UTF8String], inputs.toSeq : _*) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 8adad5a91be69..8ea3c52bda609 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1055,14 +1055,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def functionAndOrderExpressionMismatchError( functionName: String, - functionArgs: Expression, orderExpr: Seq[SortOrder]): Throwable = { new AnalysisException( - errorClass = "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", + errorClass = "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", messageParameters = Map( - "functionName" -> toSQLId(functionName), - "functionArgs" -> toSQLExpr(functionArgs), - "orderExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(", "))) + "funcName" -> toSQLId(functionName), + "orderingExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(", "))) } def wrongCommandForObjectTypeError( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out index 60e885c4597ed..2fdee7a8d8172 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out @@ -49,11 +49,10 @@ SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", + "sqlState" : "42K0K", "messageParameters" : { - "functionArgs" : "\"collate(c1, utf8_lcase)\"", - "functionName" : "`listagg`", - "orderExpr" : "\"collate(c1, utf8_binary)\"" + "funcName" : "`listagg`", + "orderingExpr" : "\"collate(c1, utf8_binary)\"" } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out index 9c4f18cb5b50f..07269b3cd9055 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out @@ -410,12 +410,11 @@ SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY b) FROM df -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", + "sqlState" : "42K0K", "messageParameters" : { - "functionArgs" : "\"a\"", - "functionName" : "`listagg`", - "orderExpr" : "\"b\"" + "funcName" : "`listagg`", + "orderingExpr" : "\"b\"" } } @@ -425,11 +424,10 @@ SELECT 
listagg(DISTINCT a) WITHIN GROUP (ORDER BY a, b) FROM df -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", + "sqlState" : "42K0K", "messageParameters" : { - "functionArgs" : "\"a\"", - "functionName" : "`listagg`", - "orderExpr" : "\"a\", \"b\"" + "funcName" : "`listagg`", + "orderingExpr" : "\"a\", \"b\"" } } diff --git a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out index 136e2040987f7..23a9b76a0841f 100644 --- a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out @@ -46,11 +46,10 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", + "sqlState" : "42K0K", "messageParameters" : { - "functionArgs" : "\"collate(c1, utf8_lcase)\"", - "functionName" : "`listagg`", - "orderExpr" : "\"collate(c1, utf8_binary)\"" + "funcName" : "`listagg`", + "orderingExpr" : "\"collate(c1, utf8_binary)\"" } } diff --git a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out index 033997350e877..b6e141f9a003f 100644 --- a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out @@ -341,12 +341,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", + "sqlState" : "42K0K", "messageParameters" : { - "functionArgs" : "\"a\"", - "functionName" : "`listagg`", - "orderExpr" : "\"b\"" + "funcName" : "`listagg`", + "orderingExpr" : "\"b\"" } } @@ -358,11 +357,10 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "FUNCTION_AND_ORDER_EXPRESSION_MISMATCH", - "sqlState" : "42822", + "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", + "sqlState" : "42K0K", "messageParameters" : { - "functionArgs" : "\"a\"", - "functionName" : "`listagg`", - "orderExpr" : "\"a\", \"b\"" + "funcName" : "`listagg`", + "orderingExpr" : "\"a\", \"b\"" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 29eb5e0fed0b2..1b8f596a999b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -161,16 +161,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - private[this] def hexToBytes(s: String): Array[Byte] = { - val byteArray = BigInt(s, 16).toByteArray - if (byteArray.length > 1 && byteArray(0) == 0) { - // remove sign byte for positive numbers if exists - byteArray.tail - } else { - byteArray - } - } - test("support table.star") { checkAnswer( sql( From 811c36c76dd5a84df758ebf1d795c86193e0ca63 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Fri, 22 Nov 2024 17:14:41 +0100 Subject: [PATCH 52/58] [SPARK-42746] 
add trim collation tests --- .../listagg-collations.sql.out | 27 +++++++++++++++++++ .../sql-tests/inputs/listagg-collations.sql | 3 +++ .../results/listagg-collations.sql.out | 26 ++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out index 2fdee7a8d8172..d23da89d35351 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out @@ -44,6 +44,33 @@ Aggregate [listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_ +- LocalRelation [col1#x] +-- !query +SELECT listagg(DISTINCT c1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc '), ('x'), ('abc')) AS t(c1) +-- !query analysis +Aggregate [listagg(distinct collate(c1#x, unicode_rtrim), null, 0, 0) AS listagg(DISTINCT collate(c1, unicode_rtrim), NULL)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1) FROM (VALUES ('abc '), ('abc '), ('abc\n'), ('abc'), ('x')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, c1#x ASC NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY c1 ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + +-- !query +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc '), ('abc\n'), ('abc'), ('x')) AS t(c1) +-- !query analysis +Aggregate [listagg(c1#x, null, collate(c1#x, unicode_rtrim) ASC NULLS FIRST, 0, 0) AS listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, unicode_rtrim) ASC NULLS FIRST)#x] ++- SubqueryAlias t + +- Project [col1#x AS c1#x] + +- LocalRelation [col1#x] + + -- !query SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql index 8e608d940a0f9..35f86183c37b3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql @@ -4,6 +4,9 @@ SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES (' SELECT listagg(DISTINCT c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1); SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1); +SELECT listagg(DISTINCT c1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc '), ('x'), ('abc')) AS t(c1); +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1) FROM (VALUES ('abc '), ('abc '), ('abc\n'), ('abc'), ('x')) AS t(c1); +SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc '), ('abc\n'), ('abc'), ('x')) AS t(c1); -- Error case with collations SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out 
b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
index 23a9b76a0841f..86d38d14f2cf8 100644
--- a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
@@ -39,6 +39,32 @@ struct
+-- !query
+SELECT listagg(DISTINCT c1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc '), ('x'), ('abc')) AS t(c1)
+-- !query schema
+struct<listagg(DISTINCT collate(c1, unicode_rtrim), NULL):string collate UNICODE_RTRIM>
+-- !query output
+abc x
+
+
+-- !query
+SELECT listagg(c1) WITHIN GROUP (ORDER BY c1) FROM (VALUES ('abc '), ('abc '), ('abc\n'), ('abc'), ('x')) AS t(c1)
+-- !query schema
+struct<listagg(c1, NULL) WITHIN GROUP (ORDER BY c1 ASC NULLS FIRST):string>
+-- !query output
+abcabc
+abc abc x
+
+
+-- !query
+SELECT listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc '), ('abc\n'), ('abc'), ('x')) AS t(c1)
+-- !query schema
+struct<listagg(c1, NULL) WITHIN GROUP (ORDER BY collate(c1, unicode_rtrim) ASC NULLS FIRST):string>
+-- !query output
+abc abc abcabc
+x
+
+
 -- !query
 SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1)
 -- !query schema

From 9c5bd3df7faa2b3925bcf60638314d13836085f9 Mon Sep 17 00:00:00 2001
From: Mikhail Nikoliukin
Date: Fri, 22 Nov 2024 17:37:29 +0100
Subject: [PATCH 53/58] [SPARK-42746] adjust error message

---
 common/utils/src/main/resources/error/error-conditions.json     | 2 +-
 .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala   | 2 +-
 .../org/apache/spark/sql/errors/QueryCompilationErrors.scala     | 2 ++
 .../sql-tests/analyzer-results/listagg-collations.sql.out        | 1 +
 .../test/resources/sql-tests/analyzer-results/listagg.sql.out    | 2 ++
 .../test/resources/sql-tests/results/listagg-collations.sql.out  | 1 +
 sql/core/src/test/resources/sql-tests/results/listagg.sql.out    | 2 ++
 7 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 7c8b36e806030..8dbdd9311cb3d 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3353,7 +3353,7 @@
     },
     "MISMATCH_WITH_DISTINCT_INPUT": {
       "message": [
-        "Function <funcName> is invoked with DISTINCT. The WITHIN GROUP ordering expressions must be picked from the function inputs, but got <orderingExpr>."
+        "The function <funcName> is invoked with DISTINCT and WITHIN GROUP but expressions <funcArg> and <orderingExpr> do not match. The WITHIN GROUP ordering expression must be picked from the function inputs."
] }, "WITHIN_GROUP_MISSING" : { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ce9c12a2d866f..4506e5bdaf2dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -426,7 +426,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _) if agg.isDistinct && listAgg.needSaveOrderValue => throw QueryCompilationErrors.functionAndOrderExpressionMismatchError( - listAgg.prettyName, listAgg.orderExpressions) + listAgg.prettyName, listAgg.child, listAgg.orderExpressions) case w: WindowExpression => // Only allow window functions with an aggregate expression or an offset window diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 8ea3c52bda609..4c970d066d31e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1055,11 +1055,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def functionAndOrderExpressionMismatchError( functionName: String, + functionArg: Expression, orderExpr: Seq[SortOrder]): Throwable = { new AnalysisException( errorClass = "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", messageParameters = Map( "funcName" -> toSQLId(functionName), + "funcArg" -> toSQLExpr(functionArg), "orderingExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(", "))) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out index d23da89d35351..ca471858a5416 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out @@ -79,6 +79,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", "sqlState" : "42K0K", "messageParameters" : { + "funcArg" : "\"collate(c1, utf8_lcase)\"", "funcName" : "`listagg`", "orderingExpr" : "\"collate(c1, utf8_binary)\"" } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out index 07269b3cd9055..84893fb9fbab1 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out @@ -413,6 +413,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", "sqlState" : "42K0K", "messageParameters" : { + "funcArg" : "\"a\"", "funcName" : "`listagg`", "orderingExpr" : "\"b\"" } @@ -427,6 +428,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT", "sqlState" : "42K0K", "messageParameters" : { + "funcArg" : "\"a\"", "funcName" : "`listagg`", "orderingExpr" : "\"a\", \"b\"" } diff --git 
a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
index 86d38d14f2cf8..cf3bac04f09ca 100644
--- a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
@@ -75,6 +75,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
   "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT",
   "sqlState" : "42K0K",
   "messageParameters" : {
+    "funcArg" : "\"collate(c1, utf8_lcase)\"",
     "funcName" : "`listagg`",
     "orderingExpr" : "\"collate(c1, utf8_binary)\""
   }
diff --git a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out
index b6e141f9a003f..15c5fc0e1cea5 100644
--- a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out
@@ -344,6 +344,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
   "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT",
   "sqlState" : "42K0K",
   "messageParameters" : {
+    "funcArg" : "\"a\"",
     "funcName" : "`listagg`",
     "orderingExpr" : "\"b\""
   }
@@ -360,6 +361,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
   "errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT",
   "sqlState" : "42K0K",
   "messageParameters" : {
+    "funcArg" : "\"a\"",
     "funcName" : "`listagg`",
     "orderingExpr" : "\"a\", \"b\""
   }

From e6d9c7091914f45a6f4b63597b51c3a5f101c85f Mon Sep 17 00:00:00 2001
From: Mikhail Nikoliukin
Date: Mon, 25 Nov 2024 16:58:02 +0100
Subject: [PATCH 54/58] [SPARK-42746] make SortOrder a child of listagg

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala    | 3 +++
 .../spark/sql/catalyst/expressions/aggregate/collect.scala   | 5 ++---
 .../org/apache/spark/sql/execution/SparkStrategies.scala     | 7 ++++++-
 .../resources/sql-tests/analyzer-results/listagg.sql.out     | 4 ++--
 .../src/test/resources/sql-tests/results/listagg.sql.out     | 4 ++--
 5 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bed7bea61597f..81c7c75f47c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2782,6 +2782,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         ne
       case e: Expression if e.foldable =>
         e // No need to create an attribute reference if it will be evaluated as a Literal.
+      case e: SortOrder =>
+        // For SortOrder, just recursively extract from the child expression.
+        e.copy(child = extractExpr(e.child))
       case e: NamedArgumentExpression =>
         // For NamedArgumentExpression, we extract the value and replace it with
         // an AttributeReference (with an internal column name, e.g. "_w0").
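
A rough sketch of the extraction rule added above, under simplified, stand-in names (`extractExpr` and the alias fallback are illustrative, not the exact analyzer helper). SortOrder is an ordering wrapper rather than an evaluable value, so the extraction keeps the wrapper and only replaces the expression it sorts on:

    def extractExpr(expr: Expression): Expression = expr match {
      case ne: NamedExpression => ne.toAttribute          // already named: reuse it
      case e if e.foldable => e                           // literals can stay inline
      case order: SortOrder =>
        order.copy(child = extractExpr(order.child))      // recurse into the sorted value
      case e => Alias(e, "_extracted")().toAttribute      // extract anything else
    }
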
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index f339879ef5981..5dbab2eea88d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -510,7 +510,7 @@ case class ListAgg( override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) - override def children: Seq[Expression] = child +: delimiter +: orderExpressions.map(_.child) + override def children: Seq[Expression] = child +: delimiter +: orderExpressions /** * Utility func to check if given order is defined and different from [[child]]. @@ -534,8 +534,7 @@ case class ListAgg( delimiter = newChildren(1), orderExpressions = newChildren .drop(2) - .zip(orderExpressions) - .map { case (newExpr, oldSortOrder) => oldSortOrder.copy(child = newExpr) } + .map(_.asInstanceOf[SortOrder]) ) private[this] def orderValuesField: Seq[StructField] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 22082aca81a22..e4b89b4f4de88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -607,7 +607,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct // aggregates have different column expressions. val distinctExpressions = - functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable) + functionsWithDistinct.head.aggregateFunction.children + .filterNot(_.foldable) + .map { + case s: SortOrder => s.child + case e => e + } val normalizedNamedDistinctExpressions = distinctExpressions.map { e => // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here // because `distinctExpressions` is not extracted during logical phase. 
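
With SortOrder now among the function's children, the planner change above unwraps it so the distinct grouping key is the ordering value itself. A small end-to-end illustration, assuming a SparkSession named `spark`; with the default NULL delimiter the distinct values are concatenated with no separator, so the expected result is "ab":

    // A DISTINCT listagg whose WITHIN GROUP ordering matches its input
    // plans as a single distinct grouping on `a`.
    spark.sql(
      """SELECT listagg(DISTINCT a) WITHIN GROUP (ORDER BY a)
        |FROM VALUES ('b'), ('a'), ('b') AS t(a)""".stripMargin
    ).show()
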
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out
index 84893fb9fbab1..71eb3f8ca76b3 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg.sql.out
@@ -353,7 +353,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
   "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC",
   "sqlState" : "42601",
   "messageParameters" : {
-    "aggFunc" : "\"listagg(a, NULL, a)\""
+    "aggFunc" : "\"listagg(a, NULL, a ASC NULLS FIRST)\""
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -373,7 +373,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
   "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC",
   "sqlState" : "42601",
   "messageParameters" : {
-    "aggFunc" : "\"listagg(a, NULL, a)\""
+    "aggFunc" : "\"listagg(a, NULL, a ASC NULLS FIRST)\""
   },
   "queryContext" : [ {
     "objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out
index 15c5fc0e1cea5..ef580704992ce 100644
--- a/sql/core/src/test/resources/sql-tests/results/listagg.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/listagg.sql.out
@@ -278,7 +278,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
   "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC",
   "sqlState" : "42601",
   "messageParameters" : {
-    "aggFunc" : "\"listagg(a, NULL, a)\""
+    "aggFunc" : "\"listagg(a, NULL, a ASC NULLS FIRST)\""
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -300,7 +300,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
   "errorClass" : "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC",
   "sqlState" : "42601",
   "messageParameters" : {
-    "aggFunc" : "\"listagg(a, NULL, a)\""
+    "aggFunc" : "\"listagg(a, NULL, a ASC NULLS FIRST)\""
   },
   "queryContext" : [ {
     "objectType" : "",

From 0bbd8af127c6e3921ba8db30562529983e7cf9ec Mon Sep 17 00:00:00 2001
From: Mikhail Nikoliukin
Date: Mon, 25 Nov 2024 18:25:14 +0100
Subject: [PATCH 55/58] [SPARK-42746] fix error-conditions

---
 common/utils/src/main/resources/error/error-conditions.json | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 8dbdd9311cb3d..9d0f04ed2e17e 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3351,8 +3351,8 @@
         "The function does not support DISTINCT with WITHIN GROUP."
       ]
     },
-    "MISMATCH_WITH_DISTINCT_INPUT": {
-      "message": [
+    "MISMATCH_WITH_DISTINCT_INPUT" : {
+      "message" : [
         "The function <funcName> is invoked with DISTINCT and WITHIN GROUP but expressions <funcArg> and <orderingExpr> do not match. The WITHIN GROUP ordering expression must be picked from the function inputs."
] }, From d96ac1e8a436b67d82f259b91a4556a8ee785ec5 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 26 Nov 2024 12:35:18 +0100 Subject: [PATCH 56/58] [SPARK-42746] deduplicate concat logic --- .../apache/spark/unsafe/types/ByteArray.java | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index e010e2dadf605..39ae8abe12225 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -135,32 +135,27 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { return Arrays.copyOfRange(bytes, start, end); } + /** + * Concatenate multiple byte arrays into one. + * If one of the inputs is null then null will be returned. + * @param inputs byte arrays to concatenate + * @return the concatenated byte array or null if one of the arguments is null + */ public static byte[] concat(byte[]... inputs) { - // Compute the total length of the result - long totalLength = 0; - for (byte[] input : inputs) { - if (input != null) { - totalLength += input.length; - } else { - return null; - } - } - - // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[Ints.checkedCast(totalLength)]; - int offset = 0; - for (byte[] input : inputs) { - int len = input.length; - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, - result, Platform.BYTE_ARRAY_OFFSET + offset, - len); - offset += len; - } - return result; + return concatWS(EMPTY_BYTE, inputs); } + /** + * Concatenate multiple byte arrays with a given delimiter. + * If the delimiter or one of the inputs is null then null will be returned. + * @param delimiter byte array to be placed between each input + * @param inputs byte arrays to concatenate + * @return the concatenated byte array or null if one of the arguments is null + */ public static byte[] concatWS(byte[] delimiter, byte[]... inputs) { + if (delimiter == null) { + return null; + } // Compute the total length of the result long totalLength = 0; for (byte[] input : inputs) { @@ -182,7 +177,7 @@ public static byte[] concatWS(byte[] delimiter, byte[]... 
inputs) { result, Platform.BYTE_ARRAY_OFFSET + offset, len); offset += len; - if(i < inputs.length - 1) { + if (delimiter.length > 0 && i < inputs.length - 1) { Platform.copyMemory( delimiter, Platform.BYTE_ARRAY_OFFSET, result, Platform.BYTE_ARRAY_OFFSET + offset, From aee0ac5018141efe070b06ba57783cdbb4ecedeb Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Tue, 26 Nov 2024 12:50:15 +0100 Subject: [PATCH 57/58] [SPARK-42746] add type safety in getDelimiterValue --- .../expressions/aggregate/collect.scala | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 5dbab2eea88d2..7789c23b50a48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, Growable} +import scala.util.{Left, Right} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -465,32 +466,37 @@ case class ListAgg( } private[this] def concatSkippingNulls(buffer: mutable.ArrayBuffer[Any]): Any = { - val delimiterValue = getDelimiterValue - dataType match { - case BinaryType => + getDelimiterValue match { + case Right(delimiterValue: Array[Byte]) => val inputs = buffer.filter(_ != null).map(_.asInstanceOf[Array[Byte]]) - ByteArray.concatWS(delimiterValue.asInstanceOf[Array[Byte]], inputs.toSeq: _*) - case _: StringType => + ByteArray.concatWS(delimiterValue, inputs.toSeq: _*) + case Left(delimiterValue: UTF8String) => val inputs = buffer.filter(_ != null).map(_.asInstanceOf[UTF8String]) - UTF8String.concatWs(delimiterValue.asInstanceOf[UTF8String], inputs.toSeq : _*) + UTF8String.concatWs(delimiterValue, inputs.toSeq: _*) } } - override def dataType: DataType = child.dataType - - private[this] def getDelimiterValue: Any = { + /** + * @return delimiter value or default empty value if delimiter is null. 
Type respects [[dataType]] + */ + private[this] def getDelimiterValue: Either[UTF8String, Array[Byte]] = { val delimiterValue = delimiter.eval() - if (delimiterValue == null) { - // default delimiter value - dataType match { - case _: StringType => UTF8String.fromString("") - case _: BinaryType => ByteArray.EMPTY_BYTE - } - } else { - delimiterValue + dataType match { + case _: StringType => + Left( + if (delimiterValue == null) UTF8String.fromString("") + else delimiterValue.asInstanceOf[UTF8String] + ) + case _: BinaryType => + Right( + if (delimiterValue == null) ByteArray.EMPTY_BYTE + else delimiterValue.asInstanceOf[Array[Byte]] + ) } } + override def dataType: DataType = child.dataType + override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) if (value != null) { From 91b759f12bac568d38ed5b94244dd65eeb1aef41 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Thu, 28 Nov 2024 12:19:18 +0100 Subject: [PATCH 58/58] [SPARK-42746] fix java indent --- .../org/apache/spark/unsafe/types/ByteArray.java | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 39ae8abe12225..f12408fb49313 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -138,6 +138,7 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { /** * Concatenate multiple byte arrays into one. * If one of the inputs is null then null will be returned. + * * @param inputs byte arrays to concatenate * @return the concatenated byte array or null if one of the arguments is null */ @@ -148,8 +149,9 @@ public static byte[] concat(byte[]... inputs) { /** * Concatenate multiple byte arrays with a given delimiter. * If the delimiter or one of the inputs is null then null will be returned. + * * @param delimiter byte array to be placed between each input - * @param inputs byte arrays to concatenate + * @param inputs byte arrays to concatenate * @return the concatenated byte array or null if one of the arguments is null */ public static byte[] concatWS(byte[] delimiter, byte[]... inputs) { @@ -173,15 +175,15 @@ public static byte[] concatWS(byte[] delimiter, byte[]... inputs) { byte[] input = inputs[i]; int len = input.length; Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, - result, Platform.BYTE_ARRAY_OFFSET + offset, - len); + input, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + offset, + len); offset += len; if (delimiter.length > 0 && i < inputs.length - 1) { Platform.copyMemory( - delimiter, Platform.BYTE_ARRAY_OFFSET, - result, Platform.BYTE_ARRAY_OFFSET + offset, - delimiter.length); + delimiter, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + offset, + delimiter.length); offset += delimiter.length; } }
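
Patches 56-58 leave a single concatenation path in ByteArray. A quick usage sketch of the refactored helpers as called from Scala (mirroring the ListAgg call sites above); the commented results follow from the javadoc in the diff:

    import org.apache.spark.unsafe.types.ByteArray

    val a = Array[Byte](1, 2)
    val b = Array[Byte](3)
    ByteArray.concat(a, b)                    // Array(1, 2, 3): delegates to concatWS with EMPTY_BYTE
    ByteArray.concatWS(Array[Byte](0), a, b)  // Array(1, 2, 0, 3): delimiter placed between inputs
    ByteArray.concatWS(null, a, b)            // null: a null delimiter short-circuits
    ByteArray.concat(a, null)                 // null: any null input yields null
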