Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ import scala.reflect.ClassTag
* NOTE: when using numeric type as the value type, the user of this class should be careful to
* distinguish between 0/0.0/0L and a non-existent value
*/
/**
 * An [[OpenHashMap]] whose key hashing and equality for object keys are supplied by the
 * caller, so that keys that are only semantically equal (e.g. collated strings) collapse
 * into a single entry.
 *
 * @param initialCapacity initial table capacity
 * @param hashering       computes a key's hash; the Long result is truncated to Int downstream
 * @param equalsFunction  semantic equality predicate for two keys
 * @tparam X auxiliary type parameter forwarded to [[CollationAwareOpenHashSet]]
 */
private[spark] class CollationAwareHashMap[K: ClassTag, V: ClassTag, X](
    initialCapacity: Int,
    hashering: AnyRef => Long,
    equalsFunction: (AnyRef, AnyRef) => Boolean)
  extends OpenHashMap[K, V](initialCapacity) {

  // NOTE(review): getOpenHashSet is invoked from the OpenHashMap constructor, i.e. before
  // this subclass's fields are assigned. `initialCapacity`, `hashering` and
  // `equalsFunction` may be observed as 0/null at that point — confirm, or make the
  // superclass's key set lazily initialized.
  override def getOpenHashSet: OpenHashSet[K] =
    new CollationAwareOpenHashSet[K, X](
      initialCapacity, CollationAwareHashMap.DefaultLoadFactor, hashering, equalsFunction)
}

private[spark] object CollationAwareHashMap {
  // Single home for the load factor previously hard-coded as 0.7 in multiple places.
  val DefaultLoadFactor: Double = 0.7
}

private[spark]
class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
initialCapacity: Int)
Expand All @@ -37,7 +47,8 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](

def this() = this(64)

protected var _keySet = new OpenHashSet[K](initialCapacity)
def getOpenHashSet: OpenHashSet[K] = new OpenHashSet[K](initialCapacity, 0.7)
protected var _keySet = getOpenHashSet

// Init in constructor (instead of in declaration) to work around a Scala compiler specialization
// bug that would generate two arrays (one for Object and one for specialized T).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.reflect._
import com.google.common.hash.Hashing

import org.apache.spark.annotation.Private
import org.apache.spark.util.collection.OpenHashSet.Hasher

/**
* A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
Expand All @@ -40,10 +41,12 @@ import org.apache.spark.annotation.Private
* It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
* to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
*/

@Private
class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
initialCapacity: Int,
loadFactor: Double)
loadFactor: Double
)
extends Serializable {

require(initialCapacity <= OpenHashSet.MAX_CAPACITY,
Expand All @@ -67,7 +70,11 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
case ClassTag.Int => new IntHasher().asInstanceOf[Hasher[T]]
case ClassTag.Double => new DoubleHasher().asInstanceOf[Hasher[T]]
case ClassTag.Float => new FloatHasher().asInstanceOf[Hasher[T]]
case _ => new Hasher[T]
case _ => nonClassTagHasher()
}

/** Fallback hasher used when the key type has no specialized implementation. */
protected def nonClassTagHasher(): Hasher[T] = new Hasher[T]

protected var _capacity = nextPowerOf2(initialCapacity)
Expand Down Expand Up @@ -118,8 +125,19 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
* See: https://issues.apache.org/jira/browse/SPARK-45599
*/
/**
 * Returns true if the key stored at `pos` equals `k`.
 *
 * Keys of the specialized primitive types use plain `equals`; every other key type
 * delegates to [[nonClassTagKeyExistsAtPos]], which subclasses may override (e.g. to
 * apply collation-aware equality). Note that `@specialized` only generates optimized
 * variants for the listed primitives — any other type still reaches the default case.
 */
@annotation.nowarn("cat=other-non-cooperative-equals")
protected def keyExistsAtPos(k: T, pos: Int) = {
  classTag[T] match {
    case ClassTag.Long | ClassTag.Int | ClassTag.Double | ClassTag.Float =>
      _data(pos) equals k
    case _ => nonClassTagKeyExistsAtPos(k, _data(pos))
  }
}

/** Default (non-collation-aware) equality check for non-specialized key types. */
@annotation.nowarn("cat=other-non-cooperative-equals")
protected def nonClassTagKeyExistsAtPos(k: T, dataAtPos: T): Boolean =
  dataAtPos.equals(k)
Comment on lines +137 to +140
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

related to the comment above, could you explain what this does?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@uros-db please see comment above and let me know if we are on the same page


/**
* Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
Expand Down Expand Up @@ -291,9 +309,6 @@ object OpenHashSet {
* A set of specialized hash function implementation to avoid boxing hash code computation
* in the specialized implementation of OpenHashSet.
*/
sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable {
def hash(o: T): Int = o.hashCode()
}

class LongHasher extends Hasher[Long] {
override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
Expand All @@ -314,9 +329,50 @@ object OpenHashSet {
override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
}

// Generic fallback hasher that delegates to the key's own hashCode.
// NOTE(review): the original declaration was `sealed`; it appears to have been unsealed
// so collation-aware hashers outside this object can extend it — confirm this is intended,
// since it weakens exhaustiveness guarantees for Hasher subtypes.
class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable {
def hash(o: T): Int = o.hashCode()
}

// NOTE(review): presumably no-op grow/move callbacks supplied to the rehashing routine
// when the caller does not need resize notifications — the call sites are not visible in
// this chunk, so confirm before relying on this.
private def grow1(newSize: Int): Unit = {}
private def move1(oldPos: Int, newPos: Int): Unit = { }

private val grow = grow1 _
private val move = move1 _
}


/**
 * An [[OpenHashSet]] whose hashing and equality for non-specialized (object) keys are
 * supplied by the caller, so that values that are only semantically equal — e.g. collated
 * strings — occupy a single slot.
 *
 * @param initialCapacity initial table capacity
 * @param loadFactor      load factor that triggers rehashing
 * @param hashFunc        computes a key's hash; the Long result is truncated to Int
 * @param equalsFunction  semantic equality predicate for two keys
 * @tparam T2 auxiliary type parameter, currently unused here; kept for compatibility
 *            with callers that supply it
 */
@Private
class CollationAwareOpenHashSet[T: ClassTag, T2](
    initialCapacity: Int,
    loadFactor: Double,
    hashFunc: AnyRef => Long,
    equalsFunction: (AnyRef, AnyRef) => Boolean)
  extends OpenHashSet[T](initialCapacity, loadFactor) {

  // Replace the default `equals`-based check with the caller-supplied semantic equality.
  override def nonClassTagKeyExistsAtPos(k: T, dataAtPos: T): Boolean =
    equalsFunction(k.asInstanceOf[AnyRef], dataAtPos.asInstanceOf[AnyRef])

  // Collation-aware hasher: the supplied Long hash is truncated to Int, matching the
  // Hasher contract.
  override def nonClassTagHasher(): OpenHashSet.Hasher[T] = {
    val intHash: Any => Int = o => hashFunc(o.asInstanceOf[AnyRef]).toInt
    new CustomHasher(intHash).asInstanceOf[Hasher[T]]
  }

  /** Adapts an `Any => Int` function into a [[Hasher]]. */
  class CustomHasher(f: Any => Int) extends Hasher[Any] {
    override def hash(o: Any): Int = f(o)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,22 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.util.{GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.{CollationAwareHashMap, OpenHashMap}

case class Mode(
child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0,
reverseOpt: Option[Boolean] = None)
extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes
extends CollationAwareTypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes
with SupportsOrderingWithinGroup with UnaryLike[Expression] {

def this(child: Expression) = this(child, 0, 0)
Expand All @@ -41,6 +42,16 @@ case class Mode(
this(child, 0, 0, Some(reverse))
}

/**
 * Mode accepts binary-stable input types and all string types (collated strings are
 * handled by the collation-aware aggregation buffer); any other binary-unstable type
 * is rejected with a type-check failure.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) {
    super.checkInputDataTypes()
  } else {
    // Original message was garbled ("was a type of binary-unstable type ... supported by
    // mode") and hard-coded the function name; use prettyName once, in the right place.
    TypeCheckResult.TypeCheckFailure(
      s"The input to the function '$prettyName' was of a binary-unstable type that is " +
        "not currently supported.")
  }
}

// Returns null for empty inputs
override def nullable: Boolean = true

Expand All @@ -51,8 +62,8 @@ case class Mode(
override def prettyName: String = "mode"

override def update(
buffer: OpenHashMap[AnyRef, Long],
input: InternalRow): OpenHashMap[AnyRef, Long] = {
buffer: CollationAwareHashMap[AnyRef, Long, UTF8String],
input: InternalRow): CollationAwareHashMap[AnyRef, Long, UTF8String] = {
val key = child.eval(input)

if (key != null) {
Expand All @@ -62,19 +73,19 @@ case class Mode(
}

/** Merges `other` into `buffer` by summing the count for each key; returns `buffer`. */
override def merge(
    buffer: CollationAwareHashMap[AnyRef, Long, UTF8String],
    other: CollationAwareHashMap[AnyRef, Long, UTF8String]):
  CollationAwareHashMap[AnyRef, Long, UTF8String] = {
  other.foreach { kv =>
    val (key, cnt) = kv
    // Insert with `cnt` when absent, otherwise add `cnt` to the existing count.
    buffer.changeValue(key, cnt, existing => existing + cnt)
  }
  buffer
}

override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
override def eval(buffer: CollationAwareHashMap[AnyRef, Long, UTF8String]): Any = {
if (buffer.isEmpty) {
return null
}

reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
Expand Down Expand Up @@ -201,7 +212,8 @@ case class PandasMode(
child: Expression,
ignoreNA: Boolean = true,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer
inputAggBufferOffset: Int = 0)
extends TypedAggregateWithHashMapAsBuffer
Comment on lines +215 to +216
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like an unnecessary change, but I suppose it's a prototyping leftover so I won't comment these more for now

with ImplicitCastInputTypes with UnaryLike[Expression] {

def this(child: Expression) = this(child, true, 0, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.{CollationAwareHashMap, OpenHashMap}

/** The mode of an [[AggregateFunction]]. */
sealed trait AggregateMode
Expand Down Expand Up @@ -643,18 +646,14 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
* A special [[TypedImperativeAggregate]] that uses `OpenHashMap[AnyRef, Long]` as internal
* aggregation buffer.
*/
abstract class TypedAggregateWithHashMapAsBuffer
extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] {
override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
// Initialize new counts map instance here.
new OpenHashMap[AnyRef, Long]()
}
abstract class TypedAggregateWithHashMapAsBufferBase[HM <: OpenHashMap[AnyRef, Long]]
extends TypedImperativeAggregate[HM] {

protected def child: Expression

private lazy val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType))

override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
override def serialize(obj: HM): Array[Byte] = {
val buffer = new Array[Byte](4 << 10) // 4K
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
Expand All @@ -676,11 +675,16 @@ abstract class TypedAggregateWithHashMapAsBuffer
}
}

override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
override def deserialize(bytes: Array[Byte]): HM = {
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(bis)
try {
val counts = new OpenHashMap[AnyRef, Long]
val counts = createAggregationBuffer()
/* (64, child.dataType match {
case StringType if child.dataType.asInstanceOf[StringType].isUTF8BinaryLcaseCollation => 1
case StringType => 0
case _ => -1
}) */
// Read unsafeRow size and content in bytes.
var sizeOfNextRow = ins.readInt()
while (sizeOfNextRow >= 0) {
Expand All @@ -702,3 +706,55 @@ abstract class TypedAggregateWithHashMapAsBuffer
}
}
}

/**
 * Builds hash and equality functions that respect collation semantics for a given
 * [[DataType]]: collated strings go through [[CollationFactory]], binary-unstable structs
 * recurse field-by-field, and everything else falls back to `hashCode`/`equals`.
 *
 * NOTE(review): these are general collation utilities; consider moving them into
 * CollationFactory as suggested in review.
 */
object CollationAwareFunctionRegistry {

  /** Returns a collation-aware hash function for values of `dataType`. */
  def bytesToHashFunction(dataType: DataType): AnyRef => Long = dataType match {
    case s: StringType =>
      // Resolve the collation once, not on every hashed value.
      val collation = CollationFactory.fetchCollation(s.collationId)
      (a: AnyRef) => collation.hashFunction.applyAsLong(a.asInstanceOf[UTF8String])
    case nb: StructType if !UnsafeRowUtils.isBinaryStable(nb) =>
      // Precompute one hash function per field instead of rebuilding them for every row.
      val fieldHashers = nb.fields.map(f => bytesToHashFunction(f.dataType))
      (a: AnyRef) =>
        a.asInstanceOf[InternalRow].toSeq(nb).zip(fieldHashers).foldLeft(0L) {
          case (acc, (value, hasher)) => acc ^ hasher(value.asInstanceOf[AnyRef])
        }
    case _ =>
      (a: AnyRef) => a.hashCode().toLong
  }

  /** Returns a collation-aware equality predicate for values of `dataType`. */
  def bytesToEqualFunction(dataType: DataType): (AnyRef, AnyRef) => Boolean = dataType match {
    case s: StringType =>
      (a: AnyRef, b: AnyRef) =>
        a.asInstanceOf[UTF8String].semanticEquals(b.asInstanceOf[UTF8String], s.collationId)
    case nb: StructType if !UnsafeRowUtils.isBinaryStable(nb) =>
      // Precompute one equality function per field instead of rebuilding them per row.
      // Both rows are assumed to share the schema `nb`, so positional zip is safe.
      val fieldEquals = nb.fields.map(f => bytesToEqualFunction(f.dataType))
      (a: AnyRef, b: AnyRef) =>
        a.asInstanceOf[InternalRow].toSeq(nb)
          .zip(b.asInstanceOf[InternalRow].toSeq(nb))
          .zip(fieldEquals)
          .forall { case ((x, y), eq) => eq(x.asInstanceOf[AnyRef], y.asInstanceOf[AnyRef]) }
    case _ =>
      (a: AnyRef, b: AnyRef) => a.equals(b)
  }
}
Comment on lines +710 to +740
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these look like general collations-related Util methods, perhaps we should place this elsewhere (for example, into CollationFactory)


// A TypedAggregateWithHashMapAsBufferBase whose buffer is a CollationAwareHashMap, so
// keys that are semantically equal under the aggregate's data-type collation are counted
// in a single entry. The self-type gives access to `dataType` via UnaryLike[Expression].
abstract class CollationAwareTypedAggregateWithHashMapAsBuffer
extends TypedAggregateWithHashMapAsBufferBase[CollationAwareHashMap[AnyRef, Long, UTF8String]] {
self: UnaryLike[Expression] =>
override def createAggregationBuffer(): CollationAwareHashMap[AnyRef, Long, UTF8String] = {
// Initialize new counts map instance here.
// 64 is the initial capacity (same magic number flagged in review for OpenHashMap);
// hashing/equality are derived from the expression's data type collation.
new CollationAwareHashMap[AnyRef, Long, UTF8String](64,
CollationAwareFunctionRegistry.bytesToHashFunction(self.dataType),
CollationAwareFunctionRegistry.bytesToEqualFunction(self.dataType)
)
}
}

/**
 * A [[TypedAggregateWithHashMapAsBufferBase]] backed by a plain [[OpenHashMap]], i.e.
 * binary key equality and default hashing (the pre-collation behavior).
 */
abstract class TypedAggregateWithHashMapAsBuffer
  extends TypedAggregateWithHashMapAsBufferBase[OpenHashMap[AnyRef, Long]] {
  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
    // Initialize new counts map instance here.
    new OpenHashMap[AnyRef, Long]()
  }
}
Loading