-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[WIP] [PROOF OF CONCEPT] [SPARK] [SQL] Collation Mode #46917
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0bab248
f054589
5d171d6
a52a5e4
1c11d32
e6bc82c
94869d7
b114de5
38e31db
13d5bae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ import scala.reflect._ | |
| import com.google.common.hash.Hashing | ||
|
|
||
| import org.apache.spark.annotation.Private | ||
| import org.apache.spark.util.collection.OpenHashSet.Hasher | ||
|
|
||
| /** | ||
| * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never | ||
|
|
@@ -40,10 +41,12 @@ import org.apache.spark.annotation.Private | |
| * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed | ||
| * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). | ||
| */ | ||
|
|
||
| @Private | ||
| class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( | ||
| initialCapacity: Int, | ||
| loadFactor: Double) | ||
| loadFactor: Double | ||
| ) | ||
| extends Serializable { | ||
|
|
||
| require(initialCapacity <= OpenHashSet.MAX_CAPACITY, | ||
|
|
@@ -67,7 +70,11 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( | |
| case ClassTag.Int => new IntHasher().asInstanceOf[Hasher[T]] | ||
| case ClassTag.Double => new DoubleHasher().asInstanceOf[Hasher[T]] | ||
| case ClassTag.Float => new FloatHasher().asInstanceOf[Hasher[T]] | ||
| case _ => new Hasher[T] | ||
| case _ => nonClassTagHasher() | ||
| } | ||
|
|
||
| protected def nonClassTagHasher(): Hasher[T] = { | ||
| new Hasher[T] | ||
| } | ||
|
|
||
| protected var _capacity = nextPowerOf2(initialCapacity) | ||
|
|
@@ -118,8 +125,19 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( | |
| * See: https://issues.apache.org/jira/browse/SPARK-45599 | ||
| */ | ||
| @annotation.nowarn("cat=other-non-cooperative-equals") | ||
| private def keyExistsAtPos(k: T, pos: Int) = | ||
| _data(pos) equals k | ||
| protected def keyExistsAtPos(k: T, pos: Int) = { | ||
| classTag[T] match { | ||
| case ClassTag.Long | ClassTag.Int | ClassTag.Double | ClassTag.Float => | ||
| _data(pos) equals k | ||
| case _ => nonClassTagKeyExistsAtPos(k, _data(pos)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wait, since this is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @uros-db the specialized annotation is a performance optimization for versions of the generic class with those primitive types. But the type can still be anything else. The annotation results in the compiler generating the versions of generic classes for the specific types being specialized. The generic class can still be any other type, but boxing and unboxing will occur. Source: https://www.scala-lang.org/api/current/scala/specialized.html |
||
|
|
||
| } | ||
| } | ||
|
|
||
| @annotation.nowarn("cat=other-non-cooperative-equals") | ||
| protected def nonClassTagKeyExistsAtPos(k: T, dataAtPos: T): Boolean = { | ||
| dataAtPos equals k | ||
| } | ||
|
Comment on lines
+137
to
+140
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. related to the comment above, could you explain what this does?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @uros-db please see comment above and let me know if we are on the same page |
||
|
|
||
| /** | ||
| * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. | ||
|
|
@@ -291,9 +309,6 @@ object OpenHashSet { | |
| * A set of specialized hash function implementation to avoid boxing hash code computation | ||
| * in the specialized implementation of OpenHashSet. | ||
| */ | ||
| sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable { | ||
| def hash(o: T): Int = o.hashCode() | ||
| } | ||
|
|
||
| class LongHasher extends Hasher[Long] { | ||
| override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt | ||
|
|
@@ -314,9 +329,50 @@ object OpenHashSet { | |
| override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o) | ||
| } | ||
|
|
||
| class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable { | ||
| def hash(o: T): Int = o.hashCode() | ||
| } | ||
|
|
||
| private def grow1(newSize: Int): Unit = {} | ||
| private def move1(oldPos: Int, newPos: Int): Unit = { } | ||
|
|
||
| private val grow = grow1 _ | ||
| private val move = move1 _ | ||
| } | ||
|
|
||
|
|
||
| @Private | ||
| class CollationAwareOpenHashSet[T: ClassTag, T2]( | ||
| initialCapacity: Int, | ||
| loadFactor: Double, hashFunc: AnyRef => Long, | ||
| equalsFunction: (AnyRef, AnyRef) => Boolean) | ||
| extends OpenHashSet[T](initialCapacity, loadFactor) { | ||
|
|
||
| override def nonClassTagKeyExistsAtPos(k: T, dataAtPos: T): Boolean = { | ||
| equalsFunction(k.asInstanceOf[AnyRef], dataAtPos.asInstanceOf[AnyRef]) | ||
| } | ||
|
|
||
| override def nonClassTagHasher(): OpenHashSet.Hasher[T] = { | ||
| val f: Object => Int = o => hashFunc(o) | ||
| .toInt | ||
| new CustomHasher(f.asInstanceOf[Any => Int]).asInstanceOf[Hasher[T]] | ||
| } | ||
|
|
||
| class CustomHasher(f: Any => Int) extends Hasher[Any] { | ||
| override def hash(o: Any): Int = { | ||
| f(o) | ||
| } | ||
| } | ||
| /* | ||
| , | ||
| var specialPassedInHasher: Option[Object => Int] = Some(o => { | ||
| val i = CollationFactory.fetchCollation(1) | ||
| .hashFunction.applyAsLong(o.asInstanceOf[UTF8String]) | ||
| .toInt | ||
| // scalastyle:off println | ||
| println(s"Hashing: $o -> $i") | ||
| // scalastyle:on println | ||
| i | ||
| }) | ||
| */ | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,21 +18,22 @@ | |
| package org.apache.spark.sql.catalyst.expressions.aggregate | ||
|
|
||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup} | ||
| import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup} | ||
| import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} | ||
| import org.apache.spark.sql.catalyst.trees.UnaryLike | ||
| import org.apache.spark.sql.catalyst.types.PhysicalDataType | ||
| import org.apache.spark.sql.catalyst.util.GenericArrayData | ||
| import org.apache.spark.sql.catalyst.util.{GenericArrayData, UnsafeRowUtils} | ||
| import org.apache.spark.sql.errors.QueryCompilationErrors | ||
| import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType} | ||
| import org.apache.spark.util.collection.OpenHashMap | ||
| import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType} | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
| import org.apache.spark.util.collection.{CollationAwareHashMap, OpenHashMap} | ||
|
|
||
| case class Mode( | ||
| child: Expression, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0, | ||
| reverseOpt: Option[Boolean] = None) | ||
| extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes | ||
| extends CollationAwareTypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes | ||
| with SupportsOrderingWithinGroup with UnaryLike[Expression] { | ||
|
|
||
| def this(child: Expression) = this(child, 0, 0) | ||
|
|
@@ -41,6 +42,16 @@ case class Mode( | |
| this(child, 0, 0, Some(reverse)) | ||
| } | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) { | ||
| super.checkInputDataTypes() | ||
| } else { | ||
| TypeCheckResult.TypeCheckFailure( | ||
| "The input to the function 'mode' was a type of binary-unstable type that is " + | ||
| s"not currently supported by ${prettyName}.") | ||
| } | ||
| } | ||
|
|
||
| // Returns null for empty inputs | ||
| override def nullable: Boolean = true | ||
|
|
||
|
|
@@ -51,8 +62,8 @@ case class Mode( | |
| override def prettyName: String = "mode" | ||
|
|
||
| override def update( | ||
| buffer: OpenHashMap[AnyRef, Long], | ||
| input: InternalRow): OpenHashMap[AnyRef, Long] = { | ||
| buffer: CollationAwareHashMap[AnyRef, Long, UTF8String], | ||
| input: InternalRow): CollationAwareHashMap[AnyRef, Long, UTF8String] = { | ||
| val key = child.eval(input) | ||
|
|
||
| if (key != null) { | ||
|
|
@@ -62,19 +73,19 @@ case class Mode( | |
| } | ||
|
|
||
| override def merge( | ||
| buffer: OpenHashMap[AnyRef, Long], | ||
| other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = { | ||
| buffer: CollationAwareHashMap[AnyRef, Long, UTF8String], | ||
| other: CollationAwareHashMap[AnyRef, Long, UTF8String]): | ||
| CollationAwareHashMap[AnyRef, Long, UTF8String] = { | ||
| other.foreach { case (key, count) => | ||
| buffer.changeValue(key, count, _ + count) | ||
| } | ||
| buffer | ||
| } | ||
|
|
||
| override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { | ||
| override def eval(buffer: CollationAwareHashMap[AnyRef, Long, UTF8String]): Any = { | ||
| if (buffer.isEmpty) { | ||
| return null | ||
| } | ||
|
|
||
| reverseOpt.map { reverse => | ||
| val defaultKeyOrdering = if (reverse) { | ||
| PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse | ||
|
|
@@ -201,7 +212,8 @@ case class PandasMode( | |
| child: Expression, | ||
| ignoreNA: Boolean = true, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer | ||
| inputAggBufferOffset: Int = 0) | ||
| extends TypedAggregateWithHashMapAsBuffer | ||
|
Comment on lines
+215
to
+216
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like an unnecessary change, but I suppose it's a prototyping leftover so I won't comment these more for now |
||
| with ImplicitCastInputTypes with UnaryLike[Expression] { | ||
|
|
||
| def this(child: Expression) = this(child, true, 0, 0) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,10 +24,13 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute | |
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} | ||
| import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, TreePattern} | ||
| import org.apache.spark.sql.catalyst.trees.UnaryLike | ||
| import org.apache.spark.sql.catalyst.types.DataTypeUtils | ||
| import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils} | ||
| import org.apache.spark.sql.errors.QueryExecutionErrors | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.util.collection.OpenHashMap | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
| import org.apache.spark.util.collection.{CollationAwareHashMap, OpenHashMap} | ||
|
|
||
| /** The mode of an [[AggregateFunction]]. */ | ||
| sealed trait AggregateMode | ||
|
|
@@ -643,18 +646,14 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { | |
| * A special [[TypedImperativeAggregate]] that uses `OpenHashMap[AnyRef, Long]` as internal | ||
| * aggregation buffer. | ||
| */ | ||
| abstract class TypedAggregateWithHashMapAsBuffer | ||
| extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] { | ||
| override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = { | ||
| // Initialize new counts map instance here. | ||
| new OpenHashMap[AnyRef, Long]() | ||
| } | ||
| abstract class TypedAggregateWithHashMapAsBufferBase[HM <: OpenHashMap[AnyRef, Long]] | ||
| extends TypedImperativeAggregate[HM] { | ||
|
|
||
| protected def child: Expression | ||
|
|
||
| private lazy val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType)) | ||
|
|
||
| override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = { | ||
| override def serialize(obj: HM): Array[Byte] = { | ||
| val buffer = new Array[Byte](4 << 10) // 4K | ||
| val bos = new ByteArrayOutputStream() | ||
| val out = new DataOutputStream(bos) | ||
|
|
@@ -676,11 +675,16 @@ abstract class TypedAggregateWithHashMapAsBuffer | |
| } | ||
| } | ||
|
|
||
| override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = { | ||
| override def deserialize(bytes: Array[Byte]): HM = { | ||
| val bis = new ByteArrayInputStream(bytes) | ||
| val ins = new DataInputStream(bis) | ||
| try { | ||
| val counts = new OpenHashMap[AnyRef, Long] | ||
| val counts = createAggregationBuffer() | ||
| /* (64, child.dataType match { | ||
| case StringType if child.dataType.asInstanceOf[StringType].isUTF8BinaryLcaseCollation => 1 | ||
| case StringType => 0 | ||
| case _ => -1 | ||
| }) */ | ||
| // Read unsafeRow size and content in bytes. | ||
| var sizeOfNextRow = ins.readInt() | ||
| while (sizeOfNextRow >= 0) { | ||
|
|
@@ -702,3 +706,55 @@ abstract class TypedAggregateWithHashMapAsBuffer | |
| } | ||
| } | ||
| } | ||
|
|
||
| object CollationAwareFunctionRegistry { | ||
| def bytesToHashFunction(dataType: DataType): AnyRef => Long = { | ||
| val hashFunction = dataType match { | ||
| case s: StringType => a: AnyRef => CollationFactory.fetchCollation(s.collationId) | ||
| .hashFunction.applyAsLong(a.asInstanceOf[UTF8String]) | ||
| case nb: StructType if !UnsafeRowUtils.isBinaryStable(nb) => a: AnyRef => | ||
| a.asInstanceOf[InternalRow].toSeq(nb).zip(nb.fields.toSeq).foldLeft(0L)((acc, b) | ||
| => acc ^ bytesToHashFunction(b._2.dataType)(b._1.asInstanceOf[AnyRef])) | ||
| case _ => a: AnyRef => a.hashCode().toLong | ||
| } | ||
| hashFunction | ||
| } | ||
| def bytesToEqualFunction(dataType: DataType): (AnyRef, AnyRef) => Boolean = { | ||
| val equalFunction = dataType match { | ||
| case s: StringType => (a: AnyRef, b: AnyRef) => | ||
| a.asInstanceOf[UTF8String].semanticEquals(b.asInstanceOf[UTF8String], s.collationId) | ||
| case nb: StructType if !UnsafeRowUtils.isBinaryStable(nb) => (a: AnyRef, b: AnyRef) => | ||
| a.asInstanceOf[InternalRow].toSeq(nb) | ||
| .zip(b.asInstanceOf[InternalRow].toSeq(nb)).zipWithIndex | ||
| .forall { case ((a, b), i) => | ||
| bytesToEqualFunction( | ||
| nb.fields(i).dataType | ||
| )(a.asInstanceOf[AnyRef], b.asInstanceOf[AnyRef]) | ||
| } | ||
| case _ => (a: AnyRef, b: AnyRef) => | ||
| a.equals(b) | ||
| } | ||
| equalFunction | ||
|
|
||
| } | ||
| } | ||
|
Comment on lines
+710
to
+740
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these look like general collations-related Util methods, perhaps we should place this elsewhere (for example, into CollationFactory) |
||
|
|
||
| abstract class CollationAwareTypedAggregateWithHashMapAsBuffer | ||
| extends TypedAggregateWithHashMapAsBufferBase[CollationAwareHashMap[AnyRef, Long, UTF8String]] { | ||
| self: UnaryLike[Expression] => | ||
| override def createAggregationBuffer(): CollationAwareHashMap[AnyRef, Long, UTF8String] = { | ||
| // Initialize new counts map instance here. | ||
| new CollationAwareHashMap[AnyRef, Long, UTF8String](64, | ||
| CollationAwareFunctionRegistry.bytesToHashFunction(self.dataType), | ||
| CollationAwareFunctionRegistry.bytesToEqualFunction(self.dataType) | ||
| ) | ||
| } | ||
| } | ||
|
|
||
| abstract class TypedAggregateWithHashMapAsBuffer | ||
| extends TypedAggregateWithHashMapAsBufferBase[OpenHashMap[AnyRef, Long]]{ | ||
| override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = { | ||
| // Initialize new counts map instance here. | ||
| new OpenHashMap[AnyRef, Long]() | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's 0.7 here? since I see it in multiple lines, we should consider separating it out into one place, like a constant