Closed

27 commits
a909cb1 reapply changes (GideonPotok, Sep 11, 2024)
ce865a9 SPARK COLLATIONS MAP (GideonPotok, Sep 11, 2024)
d697403 formatting (GideonPotok, Sep 13, 2024)
535b16b Revert "formatting" (GideonPotok, Sep 13, 2024)
066ebd4 formatting (GideonPotok, Sep 13, 2024)
f2d0503 formatting (GideonPotok, Sep 13, 2024)
4f0cfbe formatting (GideonPotok, Sep 13, 2024)
432de23 move reason (GideonPotok, Sep 17, 2024)
a8d626b four spaces for classes (GideonPotok, Sep 20, 2024)
d621b8a fix indentation of method params (GideonPotok, Sep 20, 2024)
af97fe8 fix indentation of method params (GideonPotok, Sep 20, 2024)
bf91fe9 fix indentation of method params (GideonPotok, Sep 20, 2024)
0b7364f fix indentation of method params (GideonPotok, Sep 20, 2024)
ca564d3 fix indentation of method params (GideonPotok, Sep 20, 2024)
96c742f Update common/utils/src/main/resources/error/error-conditions.json (GideonPotok, Sep 20, 2024)
d5552cd Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre… (GideonPotok, Sep 20, 2024)
ce8986f Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre… (GideonPotok, Sep 20, 2024)
2632b91 Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre… (GideonPotok, Sep 20, 2024)
72483ac Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre… (GideonPotok, Sep 20, 2024)
4695462 fix call to throw SparkUnsupportedOperationException (GideonPotok, Sep 20, 2024)
b285a6f Apply suggestions from code review (GideonPotok, Sep 20, 2024)
e330698 fix (GideonPotok, Sep 24, 2024)
f4074be hello (GideonPotok, Sep 26, 2024)
adae8f3 passing tests (GideonPotok, Sep 28, 2024)
37efd0c passing tests (GideonPotok, Sep 28, 2024)
afd123b Added COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS. Tests pass. (GideonPotok, Sep 29, 2024)
f4c39b1 reformat error-conditions.json for test 'Error conditions are correct… (GideonPotok, Sep 30, 2024)
10 changes: 10 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -625,6 +625,11 @@
"Cannot process input data types for the expression: <expression>."
],
"subClass" : {
"BAD_INPUTS" : {
"message" : [
"The input data types to <functionName> must be valid, but found the input types <dataType>."
]
},
"MISMATCHED_TYPES" : {
"message" : [
"All input types must be the same except nullable, containsNull, valueContainsNull flags, but found the input types <inputTypes>."
@@ -1005,6 +1010,11 @@
"The input of <functionName> can't be <dataType> type data."
]
},
"UNSUPPORTED_MODE_DATA_TYPE" : {
"message" : [
"The <mode> does not support the <child> data type, because there is a \"MAP\" type with keys and/or values that have collated sub-fields."
]
},
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -17,14 +17,17 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap

@@ -50,17 +53,20 @@ case class Mode(
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

override def checkInputDataTypes(): TypeCheckResult = {
if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) {
// TODO: SPARK-49358: Mode expression for map type with collated fields
if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
!child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
!UnsafeRowUtils.isBinaryStable(f))) {
/*
* The Mode class uses collation awareness logic to handle string data.
* Complex types with collated fields are not yet supported.
* All complex types except MapType with collated fields are supported.
*/
// TODO: SPARK-48700: Mode expression for complex types (all collations)
super.checkInputDataTypes()
} else {
TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
" a type of binary-unstable type that is " +
s"not currently supported by ${prettyName}.")
TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
messageParameters =
Map("child" -> toSQLType(child.dataType),
"mode" -> toSQLId(prettyName)))
}
}
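For intuition, here is a self-contained rendering of the new acceptance rule, using toy stand-in types rather than Spark's DataType API (existsRecursively is internal to Spark, so this only mirrors its shape): a type passes unless a map with binary-unstable, i.e. collation-sensitive, keys or values occurs somewhere inside it.

sealed trait T
case object BinaryString extends T             // binary-stable string
case object CollatedString extends T           // binary-unstable (collated) string
final case class A(elem: T) extends T          // stand-in for ArrayType
final case class M(key: T, value: T) extends T // stand-in for MapType

def binaryStable(t: T): Boolean = t match {
  case CollatedString => false
  case A(e) => binaryStable(e)
  case M(k, v) => binaryStable(k) && binaryStable(v)
  case _ => true
}

def existsRecursively(t: T)(p: T => Boolean): Boolean = p(t) || (t match {
  case A(e) => existsRecursively(e)(p)
  case M(k, v) => existsRecursively(k)(p) || existsRecursively(v)(p)
  case _ => false
})

// Mirrors the new condition: accept unless some nested map is binary-unstable.
def accepted(t: T): Boolean =
  binaryStable(t) ||
    !existsRecursively(t)(f => f.isInstanceOf[M] && !binaryStable(f))

// accepted(A(CollatedString))               == true  (array of collated strings)
// accepted(M(BinaryString, CollatedString)) == false (map with collated values)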

@@ -86,6 +92,54 @@ case class Mode(
buffer
}

private def getCollationAwareBuffer(
childDataType: DataType,
buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, Long)] = {
buffer.groupMapReduce(t =>
groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values
}
def determineBufferingFunction(
childDataType: DataType): Option[AnyRef => _] = {
childDataType match {
case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None
case _ => Some(collationAwareTransform(_, childDataType))
}
}
determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer)
}

protected[sql] def collationAwareTransform(data: AnyRef, dataType: DataType): AnyRef = {
dataType match {
case _ if UnsafeRowUtils.isBinaryStable(dataType) => data
case st: StructType =>
processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData])
case st: StringType =>
CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId)
case _ =>
throw new SparkIllegalArgumentException(
errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
messageParameters = Map(
"expression" -> toSQLExpr(this),
"functionName" -> toSQLType(prettyName),
"dataType" -> toSQLType(child.dataType))
)
}
}

private def processStructTypeWithBuffer(
tuples: Seq[(Any, StructField)]): Seq[Any] = {
tuples.map(t => collationAwareTransform(t._1.asInstanceOf[AnyRef], t._2.dataType))
}

private def processArrayTypeWithBuffer(
a: ArrayType,
data: ArrayData): Seq[Any] = {
(0 until data.numElements()).map(i =>
collationAwareTransform(data.get(i, a.elementType), a.elementType))
}

override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
if (buffer.isEmpty) {
return null
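As a side note, the groupMapReduce call in groupAndReduceBuffer above can be hard to read. Here is a minimal, self-contained sketch of the same pattern, with lower-casing standing in for CollationFactory.getCollationKey and a plain Seq standing in for the flattened OpenHashMap buffer:

// (value, count) pairs, as produced by the aggregation buffer.
val buffer = Seq[(String, Long)](("a", 2L), ("A", 3L), ("b", 1L))
// Stand-in collation key: values equal under the collation share a key.
val collationKey: String => String = _.toLowerCase
val merged = buffer
  .groupMapReduce(t => collationKey(t._1))(identity)((x, y) => (x._1, x._2 + y._2))
  .values
// merged contains ("a", 5L) and ("b", 1L): the counts of "a" and "A" are
// summed, and the first representative seen is kept as the candidate value.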
@@ -102,17 +156,12 @@
* to a single value (the sum of the counts), and finally reduces the groups to a single map.
*
* The new map is then used in the rest of the Mode evaluation logic.
*
* It is expected to work for all simple and complex types with
* collated fields, except for MapType (temporarily).
*/
val collationAwareBuffer = child.dataType match {
case c: StringType if
!CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
val collationId = c.collationId
val modeMap = buffer.toSeq.groupMapReduce {
case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
case _ => buffer
}
val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)

reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
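Finally, a toy model of the recursion in collationAwareTransform, again with simplified stand-ins rather than Spark's internal types: collated strings map to their collation key, arrays and structs recurse element by element, and (mirroring the PR) there is deliberately no map case.

sealed trait DType
case object Collated extends DType                         // collated StringType
final case class Arr(elem: DType) extends DType            // ArrayType
final case class Struct(fields: Seq[DType]) extends DType  // StructType

def collationKey(s: String): String = s.toLowerCase        // stand-in for CollationFactory.getCollationKey

def transform(data: Any, dt: DType): Any = dt match {
  case Collated => collationKey(data.asInstanceOf[String])
  case Arr(elem) => data.asInstanceOf[Seq[Any]].map(transform(_, elem))
  case Struct(fields) =>
    data.asInstanceOf[Seq[Any]].zip(fields).map { case (v, f) => transform(v, f) }
}

// transform(Seq("Foo", Seq("BAR")), Struct(Seq(Collated, Arr(Collated))))
//   returns Seq("foo", Seq("bar"))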