[SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements #28351

Closed
Changes from 5 commits
Collect.scala
@@ -46,13 +46,15 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
  // actual order of input rows.
  override lazy val deterministic: Boolean = false

+  protected def convertToBufferElement(value: Any): Any
+
  override def update(buffer: T, input: InternalRow): T = {
    val value = child.eval(input)

    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
    if (value != null) {
-      buffer += InternalRow.copyValue(value)
+      buffer += convertToBufferElement(value)
    }
    buffer
  }
@@ -61,12 +63,10 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
    buffer ++= other
  }

-  override def eval(buffer: T): Any = {
-    new GenericArrayData(buffer.toArray)
-  }
+  protected val bufferElementType: DataType

  private lazy val projection = UnsafeProjection.create(
-    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
  private lazy val row = new UnsafeRow(1)

  override def serialize(obj: T): Array[Byte] = {
@@ -77,7 +77,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
  override def deserialize(bytes: Array[Byte]): T = {
    val buffer = createAggregationBuffer()
    row.pointTo(bytes, bytes.length)
-    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x)
    buffer
  }
}
@@ -105,6 +105,10 @@ case class CollectList(

  def this(child: Expression) = this(child, 0, 0)

+  override lazy val bufferElementType = child.dataType
+
+  override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)
+
  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

@@ -114,6 +118,10 @@ case class CollectList(
  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def prettyName: String = "collect_list"
+
+  override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
+    new GenericArrayData(buffer.toArray)
+  }
}

/**
@@ -139,6 +147,35 @@ case class CollectSet(

  def this(child: Expression) = this(child, 0, 0)

+  /*
+   * SPARK-31500
+   * Array[Byte](BinaryType) Scala equality don't works as expected
Member: nit:

    /* 
     * SPARK-31500: Array[Byte](BinaryType) Scala equality don't works as expected

Contributor: It is not a Scala issue. Java byte arrays use referential equality and identity hash codes. This has tripped up many, many people before.

Contributor Author: Yes, I know the main reason; I'm going to explain it better.
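[Editor's note] A minimal standalone sketch (not part of this PR) of the behavior described above: Java byte arrays inherit reference-based equals/hashCode from Object, so a mutable.HashSet[Any] keeps two equal-content arrays as separate elements.

    import scala.collection.mutable

    object ByteArrayEqualityDemo extends App {
      val a = "test1".getBytes("UTF-8")
      val b = "test1".getBytes("UTF-8")
      println(a == b)                          // false: Array[Byte] equality is referential
      println(a.sameElements(b))               // true: the contents are equal
      println(mutable.HashSet[Any](a, b).size) // 2: identity hash codes keep both copies
    }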

+   * so HashSet return duplicates, we have to change types to drop
+   * this duplicates and make collect_set work as expected for this
+   * data type
Member: Could you make this comment clearer for others and move it to lines 163-164?

Contributor Author: OK, it's better to move it to that line. I've tried to clarify the message.

+   */
+  override lazy val bufferElementType = child.dataType match {
+    case BinaryType => ArrayType(BinaryType)
Member: ArrayType(ByteType)?

Contributor Author: Good catch; it works anyway, but I think it's better this way.
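[Editor's note] Why the buffer-type switch deduplicates (a sketch, not part of this PR): UnsafeArrayData overrides equals/hashCode to compare the underlying bytes, so two instances built from equal byte arrays collapse to a single HashSet entry.

    import scala.collection.mutable
    import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

    object UnsafeArrayDataDemo extends App {
      val a = UnsafeArrayData.fromPrimitiveArray("test1".getBytes("UTF-8"))
      val b = UnsafeArrayData.fromPrimitiveArray("test1".getBytes("UTF-8"))
      println(a == b)                          // true: equality is content-based
      println(mutable.HashSet[Any](a, b).size) // 1: the duplicate collapses
    }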

+    case other => other
+  }
+
+  override def convertToBufferElement(value: Any): Any = {
+    val v = InternalRow.copyValue(value)
Contributor: Nit: you only need to copy for the default case.

Contributor Author: Yes, it's true. Cleaner that way.
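[Editor's note] One way to apply this nit (a sketch, not the committed code): UnsafeArrayData.fromPrimitiveArray already copies the input bytes into fresh storage, so the defensive copy is only needed in the default branch.

    override def convertToBufferElement(value: Any): Any = child.dataType match {
      // fromPrimitiveArray copies the bytes, so no extra InternalRow.copyValue is needed
      case BinaryType => UnsafeArrayData.fromPrimitiveArray(value.asInstanceOf[Array[Byte]])
      case _ => InternalRow.copyValue(value)
    }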

+    child.dataType match {
+      case BinaryType => UnsafeArrayData.fromPrimitiveArray(v.asInstanceOf[Array[Byte]])
+      case _ => v
+    }
+  }
+
+  override def eval(buffer: mutable.HashSet[Any]): Any = {
planga82 marked this conversation as resolved.
+    val array = child.dataType match {
+      case BinaryType =>
+        buffer.iterator.map(_.asInstanceOf[UnsafeArrayData].toByteArray).toArray
Contributor: It is a bit safer to cast to ArrayData here.

Contributor Author: OK
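[Editor's note] The method with the suggested cast applied (a sketch, not the committed code); ArrayData declares toByteArray, so the looser cast still recovers the original bytes.

    override def eval(buffer: mutable.HashSet[Any]): Any = {
      val array = child.dataType match {
        case BinaryType =>
          buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray).toArray
        case _ => buffer.toArray
      }
      new GenericArrayData(array)
    }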

+      case _ => buffer.toArray
+    }
+    new GenericArrayData(array)
+  }
+
  override def checkInputDataTypes(): TypeCheckResult = {
    if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
      TypeCheckResult.TypeCheckSuccess
DataFrameAggregateSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql

+import scala.collection.mutable
Contributor: Do we still need this import?

Contributor Author: We don't need it, I forgot it, thanks.

import scala.util.Random

import org.scalatest.Matchers.the
@@ -530,6 +531,26 @@ class DataFrameAggregateSuite extends QueryTest
    )
  }

test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
Member: nit: collect_set() of BinaryType should not return duplicate elements?

+    val bytesTest1 = "test1".getBytes
+    val bytesTest2 = "test2".getBytes
+    val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a")
+    val ret = df.select(collect_set($"a")).collect()
+      .map(r => r.getAs[Seq[_]](0)).head
+    assert(ret.length == 2)
Member: nit: checkAnswer(df.select(size(collect_set($"a"))), Row(2) :: Nil)?


+    val a = "aa".getBytes
+    val b = "bb".getBytes
+    val c = "cc".getBytes
+    val d = "dd".getBytes
+    val df1 = Seq((a, b), (a, b), (c, d))
+      .toDF("x", "y")
+      .select(struct($"x", $"y").as("a"))
+    val ret1 = df1.select(collect_set($"a")).collect()
+      .map(r => r.getAs[Seq[_]](0)).head
+    assert(ret1.length == 2)
Member: nit: checkAnswer(df1.select(size(collect_set($"a"))), Row(2) :: Nil)?

Member: +1

Contributor Author: Shorter way to compare, good idea!

+  }

test("collect_set functions cannot have maps") {
val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
.toDF("a", "x", "y")