Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements #28351

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

Expand All @@ -46,13 +47,15 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
// actual order of input rows.
override lazy val deterministic: Boolean = false

// Converts an evaluated child value into the representation stored in the
// aggregation buffer (e.g. a defensive copy in CollectList, or a content-
// comparable re-encoding of binary values in CollectSet).
protected def convertToBufferElement(value: Any): Any

/**
 * Folds one input row into the aggregation buffer.
 *
 * Fix: the scraped diff retained both the pre-change line
 * (`buffer += InternalRow.copyValue(value)`) and the post-change line; only the
 * post-change `convertToBufferElement` call belongs in the final code.
 */
override def update(buffer: T, input: InternalRow): T = {
  val value = child.eval(input)

  // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
  // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
  if (value != null) {
    // Subclasses decide how the value is stored (copied or re-encoded).
    buffer += convertToBufferElement(value)
  }
  buffer
}
Expand All @@ -61,12 +64,10 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
buffer ++= other
}

// NOTE(review): in this diff this base-class eval appears on the *removed* side —
// it is superseded by per-subclass implementations in CollectList and CollectSet
// below. Confirm against the final file; if kept, subclass overrides take effect.
override def eval(buffer: T): Any = {
new GenericArrayData(buffer.toArray)
}
// Catalyst type of the elements held in the buffer, used for (de)serialization.
// May differ from child.dataType (CollectSet stores BinaryType as ArrayType(ByteType)).
protected val bufferElementType: DataType

// Projection used by serialize(): the array element type must match the buffer
// representation (`bufferElementType`), not the child's declared type.
// Fix: the scraped diff kept both the old argument line (elementType = child.dataType)
// and the new one; only the bufferElementType version belongs in the final code.
private lazy val projection = UnsafeProjection.create(
  Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
private lazy val row = new UnsafeRow(1)

override def serialize(obj: T): Array[Byte] = {
Expand All @@ -77,7 +78,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
/**
 * Rebuilds the in-memory buffer from serialized UnsafeRow bytes.
 *
 * Fix: the scraped diff retained both the pre-change decode line (using
 * `child.dataType`) and the post-change one; elements are serialized with
 * `bufferElementType`, so they must be decoded with the same type.
 */
override def deserialize(bytes: Array[Byte]): T = {
  val buffer = createAggregationBuffer()
  row.pointTo(bytes, bytes.length)
  row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x)
  buffer
}
}
Expand Down Expand Up @@ -105,6 +106,10 @@ case class CollectList(

// Convenience constructor: buffer offsets default to 0.
def this(child: Expression) = this(child, 0, 0)

// collect_list stores values as-is, so the buffer element type is the child's type.
override lazy val bufferElementType = child.dataType

// Copy the evaluated value before buffering — presumably because input rows are
// reused across calls; see InternalRow.copyValue.
override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)

// Returns a copy of this aggregate with the given mutable-buffer offset.
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

Expand All @@ -114,6 +119,10 @@ case class CollectList(
// Fresh, empty mutable buffer for one aggregation group.
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

// SQL-facing function name.
override def prettyName: String = "collect_list"

/** Wraps the collected elements into a catalyst array value. */
override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
  val elements = buffer.toArray
  new GenericArrayData(elements)
}
}

/**
Expand All @@ -139,6 +148,30 @@ case class CollectSet(

// Convenience constructor: buffer offsets default to 0.
def this(child: Expression) = this(child, 0, 0)

// Binary values are buffered as ArrayType(ByteType) rather than BinaryType so the
// HashSet compares them by content (Java byte arrays compare by reference); see
// convertToBufferElement below. All other types are buffered as-is.
override lazy val bufferElementType = child.dataType match {
case BinaryType => ArrayType(ByteType)
case other => other
}

override def convertToBufferElement(value: Any): Any = child.dataType match {
/*
 * collect_set() of BinaryType should not return duplicate elements.
 * Java byte arrays use referential equality and identity hash codes,
 * so they are re-encoded here as a catalyst array value before being
 * added to the HashSet buffer, giving content-based deduplication.
 */
case BinaryType => UnsafeArrayData.fromPrimitiveArray(value.asInstanceOf[Array[Byte]])
case _ => InternalRow.copyValue(value)
}

/**
 * Returns the distinct elements as a catalyst array. Binary elements were
 * buffered as ArrayData (see convertToBufferElement), so they are converted
 * back to Java byte arrays before being returned.
 *
 * Fix: two GitHub review-widget lines ("planga82 marked this conversation as
 * resolved." / "Show resolved Hide resolved") were scraped into the method
 * body and would not compile; they are removed here.
 */
override def eval(buffer: mutable.HashSet[Any]): Any = {
  val array = child.dataType match {
    case BinaryType =>
      buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray).toArray
    case _ => buffer.toArray
  }
  new GenericArrayData(array)
}

override def checkInputDataTypes(): TypeCheckResult = {
if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckSuccess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,22 @@ class DataFrameAggregateSuite extends QueryTest
)
}

// Regression test for SPARK-31500.
// Fix: GitHub comment-widget text ("Copy link", "Member", the reviewer's naming
// nit, etc.) was scraped into the middle of this test body and would not
// compile; it is removed here. The test logic itself is unchanged.
test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
  // Duplicate byte arrays must collapse to one element even though Java arrays
  // use referential equality.
  val bytesTest1 = "test1".getBytes
  val bytesTest2 = "test2".getBytes
  val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a")
  checkAnswer(df.select(size(collect_set($"a"))), Row(2) :: Nil)

  // The same must hold when binary columns are nested inside a struct.
  val a = "aa".getBytes
  val b = "bb".getBytes
  val c = "cc".getBytes
  val d = "dd".getBytes
  val df1 = Seq((a, b), (a, b), (c, d))
    .toDF("x", "y")
    .select(struct($"x", $"y").as("a"))
  checkAnswer(df1.select(size(collect_set($"a"))), Row(2) :: Nil)
}

test("collect_set functions cannot have maps") {
val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
.toDF("a", "x", "y")
Expand Down