-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements #28351
Changes from 5 commits
d1c7525
91382bf
586c4b7
2cd1a22
a5b1dd3
7ea0059
881ee9b
a9c3576
4782aea
67f55a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,13 +46,15 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper | |
// actual order of input rows. | ||
override lazy val deterministic: Boolean = false | ||
|
||
protected def convertToBufferElement(value: Any): Any | ||
|
||
override def update(buffer: T, input: InternalRow): T = { | ||
val value = child.eval(input) | ||
|
||
// Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. | ||
// See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator | ||
if (value != null) { | ||
buffer += InternalRow.copyValue(value) | ||
buffer += convertToBufferElement(value) | ||
} | ||
buffer | ||
} | ||
|
@@ -61,12 +63,10 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper | |
buffer ++= other | ||
} | ||
|
||
override def eval(buffer: T): Any = { | ||
new GenericArrayData(buffer.toArray) | ||
} | ||
protected val bufferElementType: DataType | ||
|
||
private lazy val projection = UnsafeProjection.create( | ||
Array[DataType](ArrayType(elementType = child.dataType, containsNull = false))) | ||
Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false))) | ||
private lazy val row = new UnsafeRow(1) | ||
|
||
override def serialize(obj: T): Array[Byte] = { | ||
|
@@ -77,7 +77,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper | |
override def deserialize(bytes: Array[Byte]): T = { | ||
val buffer = createAggregationBuffer() | ||
row.pointTo(bytes, bytes.length) | ||
row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x) | ||
row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x) | ||
buffer | ||
} | ||
} | ||
|
@@ -105,6 +105,10 @@ case class CollectList( | |
|
||
def this(child: Expression) = this(child, 0, 0) | ||
|
||
override lazy val bufferElementType = child.dataType | ||
|
||
override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value) | ||
|
||
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
||
|
@@ -114,6 +118,10 @@ case class CollectList( | |
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty | ||
|
||
override def prettyName: String = "collect_list" | ||
|
||
override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { | ||
new GenericArrayData(buffer.toArray) | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -139,6 +147,35 @@ case class CollectSet( | |
|
||
def this(child: Expression) = this(child, 0, 0) | ||
|
||
/* | ||
* SPARK-31500 | ||
* Array[Byte] (BinaryType) Scala equality does not work as expected, | ||
* so HashSet returns duplicates; we have to change the buffer element | ||
* type to drop these duplicates and make collect_set work as expected | ||
* for this data type
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you make this comment clearer for others and move it into the line 163-164? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, it's better to move to this line. I've tried to clarify the message |
||
*/ | ||
override lazy val bufferElementType = child.dataType match { | ||
case BinaryType => ArrayType(BinaryType) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ArrayType(ByteType)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, it works anyway but I think it's better in this way |
||
case other => other | ||
} | ||
|
||
override def convertToBufferElement(value: Any): Any = { | ||
val v = InternalRow.copyValue(value) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit you only need to copy for the default case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes it's true. Cleaner in that way |
||
child.dataType match { | ||
case BinaryType => UnsafeArrayData.fromPrimitiveArray(v.asInstanceOf[Array[Byte]]) | ||
case _ => v | ||
} | ||
} | ||
|
||
override def eval(buffer: mutable.HashSet[Any]): Any = { | ||
planga82 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
val array = child.dataType match { | ||
case BinaryType => | ||
buffer.iterator.map(_.asInstanceOf[UnsafeArrayData].toByteArray).toArray | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is a bit safer to cast to ArrayData here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
case _ => buffer.toArray | ||
} | ||
new GenericArrayData(array) | ||
} | ||
|
||
override def checkInputDataTypes(): TypeCheckResult = { | ||
if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) { | ||
TypeCheckResult.TypeCheckSuccess | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
|
||
package org.apache.spark.sql | ||
|
||
import scala.collection.mutable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need this import? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need it, I forget it, thanks |
||
import scala.util.Random | ||
|
||
import org.scalatest.Matchers.the | ||
|
@@ -530,6 +531,26 @@ class DataFrameAggregateSuite extends QueryTest | |
) | ||
} | ||
|
||
test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
val bytesTest1 = "test1".getBytes | ||
val bytesTest2 = "test2".getBytes | ||
val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a") | ||
val ret = df.select(collect_set($"a")).collect() | ||
.map(r => r.getAs[Seq[_]](0)).head | ||
assert(ret.length == 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
|
||
val a = "aa".getBytes | ||
val b = "bb".getBytes | ||
val c = "cc".getBytes | ||
val d = "dd".getBytes | ||
val df1 = Seq((a, b), (a, b), (c, d)) | ||
.toDF("x", "y") | ||
.select(struct($"x", $"y").as("a")) | ||
val ret1 = df1.select(collect_set($"a")).collect() | ||
.map(r => r.getAs[Seq[_]](0)).head | ||
assert(ret1.length == 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shorter way to compare, good idea! |
||
} | ||
|
||
test("collect_set functions cannot have maps") { | ||
val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1)) | ||
.toDF("a", "x", "y") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not a Scala issue. Java byte arrays use referential equality and identity hash codes. This has tripped up many, many people before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I know the main reason, I'm going to explain it better