Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr_build_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ jobs:
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
org.apache.spark.sql.CometCollationSuite
fail-fast: false
name: ${{ matrix.profile.name }}/${{ matrix.profile.scan_impl }} [${{ matrix.suite.name }}]
runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion-comet', github.run_id) || 'ubuntu-latest' }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr_build_macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ jobs:
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
org.apache.spark.sql.CometCollationSuite

fail-fast: false
name: ${{ matrix.os }}/${{ matrix.profile.name }} [${{ matrix.suite.name }}]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@

package org.apache.comet.shims

import org.apache.spark.sql.types.{DataType, StringType}

/**
 * Spark-version-specific type helpers. This shim targets Spark 4.0, where
 * `StringType` carries collation metadata directly on the type.
 */
trait CometTypeShim {
  // A `StringType` carries collation metadata in Spark 4.0. Only non-default (non-UTF8_BINARY)
  // collations have semantics Comet's byte-level hashing/sorting/equality cannot honor. The
  // default `StringType` object is `StringType(UTF8_BINARY_COLLATION_ID)`, so comparing
  // `collationId` against that instance's id picks out non-default collations without needing
  // `private[sql]` helpers on `StringType`.
  def isStringCollationType(dt: DataType): Boolean = dt match {
    case st: StringType => st.collationId != StringType.collationId
    case _ => false
  }
}
54 changes: 0 additions & 54 deletions dev/diffs/4.0.1.diff
Original file line number Diff line number Diff line change
Expand Up @@ -150,26 +150,6 @@ index 4410fe50912..43bcce2a038 100644
case _ => Map[String, String]()
}
val childrenInfo = children.flatMap {
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
index 7aca17dcb25..8afeb3b4a2f 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
@@ -64,15 +64,6 @@ WithCTE
+- CTERelationRef xxxx, true, [c1#x], false, false


--- !query
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)
--- !query analysis
-Aggregate [lower(listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0)) AS lower(listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST))#x]
-+- SubqueryAlias t
- +- Project [col1#x AS c1#x]
- +- LocalRelation [col1#x]
-
-
-- !query
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t
-- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
index 17815ed5dde..baad440b1ce 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
Expand Down Expand Up @@ -230,21 +210,6 @@ index 698ca009b4f..57d774a3617 100644

-- Test tables
CREATE table explain_temp1 (key int, val int) USING PARQUET;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
index aa3d02dc2fb..c4f878d9908 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
@@ -5,7 +5,9 @@ WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY col1) FROM (VALUES ('
-- Test cases with utf8_lcase. Lower expression added for determinism
SELECT lower(listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1);
WITH t(c1) AS (SELECT lower(listagg(DISTINCT col1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('A'), ('b'), ('B'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'b') FROM t;
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1);
+-- TODO https://github.com/apache/datafusion-comet/issues/1947
+-- TODO fix Comet for this query
+-- SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1);
-- Test cases with unicode_rtrim.
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t;
WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY col1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc\n'), ('abc'), ('x'))) SELECT replace(replace(c1, ' ', ''), '\n', '$') FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
index 41fd4de2a09..162d5a817b6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
Expand Down Expand Up @@ -367,25 +332,6 @@ index 21a3ce1e122..f4762ab98f0 100644
SET spark.sql.ansi.enabled = false;

-- In COMPENSATION views get invalidated if the type can't cast
diff --git a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
index 1f8c5822e7d..b7de4e28813 100644
--- a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
@@ -40,14 +40,6 @@ struct<len(c1):int,regexp_count(c1, a):int,regexp_count(c1, b):int>
2 1 1


--- !query
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)
--- !query schema
-struct<lower(listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)):string collate UTF8_LCASE>
--- !query output
-ab
-
-
-- !query
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t
-- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 0f42502f1d9..e9ff802141f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc}
import org.apache.comet.serde.Types.{DataType => ProtoDataType}
import org.apache.comet.serde.Types.DataType._
import org.apache.comet.serde.literals.CometLiteral
import org.apache.comet.shims.CometExprShim
import org.apache.comet.shims.{CometExprShim, CometTypeShim}

/**
* An utility object for query plan and expression serialization.
*/
object QueryPlanSerde extends Logging with CometExprShim {
object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {

private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
classOf[ArrayAppend] -> CometArrayAppend,
Expand Down Expand Up @@ -800,6 +800,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
// scalastyle:on
def supportedScalarSortElementType(dt: DataType): Boolean = {
dt match {
// Collated strings require collation-aware ordering; Comet only compares raw bytes.
case st: StringType if isStringCollationType(st) => false
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: DecimalType | _: DateType | _: TimestampType | _: TimestampNTZType |
_: BooleanType | _: BinaryType | _: StringType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MOD
import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, isCometShuffleManagerEnabled, withInfos}
import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
import org.apache.comet.serde.operator.CometSink
import org.apache.comet.shims.ShimCometShuffleExchangeExec
import org.apache.comet.shims.{CometTypeShim, ShimCometShuffleExchangeExec}

/**
* Performs a shuffle that will result in the desired partitioning.
Expand Down Expand Up @@ -219,6 +219,7 @@ case class CometShuffleExchangeExec(
object CometShuffleExchangeExec
extends CometSink[ShuffleExchangeExec]
with ShimCometShuffleExchangeExec
with CometTypeShim
with SQLConfHelper {

override def getSupportLevel(op: ShuffleExchangeExec): SupportLevel = {
Expand Down Expand Up @@ -316,6 +317,9 @@ object CometShuffleExchangeExec
* hashing complex types, see hash_funcs/utils.rs
*/
def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
// Collated strings require collation-aware hashing; Comet only hashes raw bytes,
// which would misroute rows that compare equal under the collation.
case st: StringType if isStringCollationType(st) => false
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
_: TimestampNTZType | _: DateType =>
Expand All @@ -338,6 +342,8 @@ object CometShuffleExchangeExec
* complex types.
*/
def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
// Collated strings require collation-aware ordering; Comet only compares raw bytes.
case st: StringType if isStringCollationType(st) => false
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
_: TimestampNTZType | _: DecimalType | _: DateType =>
Expand Down Expand Up @@ -498,6 +504,11 @@ object CometShuffleExchangeExec
reasons += s"unsupported hash partitioning expression: $expr"
}
}
for (dt <- expressions.map(_.dataType).distinct) {
if (isStringCollationType(dt)) {
reasons += s"unsupported hash partitioning data type for columnar shuffle: $dt"
}
}
case SinglePartition =>
// we already checked that the input types are supported
case RoundRobinPartitioning(_) =>
Expand All @@ -508,6 +519,11 @@ object CometShuffleExchangeExec
reasons += s"unsupported range partitioning sort order: $o"
}
}
for (dt <- orderings.map(_.dataType).distinct) {
if (isStringCollationType(dt)) {
reasons += s"unsupported range partitioning data type for columnar shuffle: $dt"
}
}
case _ =>
reasons +=
s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}"
Expand Down
10 changes: 9 additions & 1 deletion spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, with
import org.apache.comet.parquet.CometParquetUtils
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType}
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, isStringCollationType, supportedSortType}
import org.apache.comet.serde.operator.CometSink

/**
Expand Down Expand Up @@ -1386,6 +1386,14 @@ trait CometBaseAggregate {
return None
}

if (groupingExpressions.exists(expr => isStringCollationType(expr.dataType))) {
// Collation-aware grouping requires collation-aware hashing/equality; Comet only
// compares raw bytes, which would put rows that compare equal under the collation
// into different groups.
withInfo(aggregate, "Grouping on non-default collated strings is not supported")
return None
}

val groupingExprsWithInput =
groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql

class CometCollationSuite extends CometTestBase {

  // Queries that group, sort, or shuffle on a non-default collated string must fall back to
  // Spark because Comet's shuffle/sort/aggregate compare raw bytes rather than collation-aware
  // keys. The shuffle-exchange rule is the primary line of defense (see #1947), so these tests
  // pin down the fallback reason it emits.
  private val hashShuffleCollationReason =
    "unsupported hash partitioning data type for columnar shuffle"
  private val rangeShuffleCollationReason =
    "unsupported range partitioning data type for columnar shuffle"

  test("listagg DISTINCT with utf8_lcase collation (issue #1947)") {
    // Regression test for the exact query removed from Spark's listagg-collations.sql by the
    // dev/diffs patch: DISTINCT + WITHIN GROUP on a collated key must fall back, not corrupt.
    checkSparkAnswerAndFallbackReason(
      "SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) " +
        "WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) " +
        "FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)",
      hashShuffleCollationReason)
  }

  test("DISTINCT on utf8_lcase collated string groups case-insensitively") {
    // Hash partitioning on the collated key drives the fallback.
    checkSparkAnswerAndFallbackReason(
      "SELECT DISTINCT c1 COLLATE utf8_lcase AS c " +
        "FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) ORDER BY c",
      hashShuffleCollationReason)
  }

  test("GROUP BY utf8_lcase collated string groups case-insensitively") {
    // Grouping keys with non-default collation require collation-aware equality.
    checkSparkAnswerAndFallbackReason(
      "SELECT lower(c1 COLLATE utf8_lcase) AS k, count(*) " +
        "FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) " +
        "GROUP BY c1 COLLATE utf8_lcase ORDER BY k",
      hashShuffleCollationReason)
  }

  test("ORDER BY utf8_lcase collated string sorts case-insensitively") {
    // A global sort uses range partitioning, so the range-partitioning reason is expected here.
    checkSparkAnswerAndFallbackReason(
      "SELECT c1 COLLATE utf8_lcase AS c " +
        "FROM (VALUES ('A'), ('b'), ('a'), ('B')) AS t(c1) ORDER BY c",
      rangeShuffleCollationReason)
  }

  test("default UTF8_BINARY string still runs through Comet") {
    // Sanity check that the collation fallback does not over-block the default string type.
    withParquetTable(Seq(("a", 1), ("b", 2), ("a", 3)), "tbl") {
      checkSparkAnswerAndOperator("SELECT DISTINCT _1 FROM tbl ORDER BY _1")
    }
  }
}
Loading