[SPARK-11645][SQL] Remove OpenHashSet for the old aggregate. #9621

Status: Closed · wants to merge 1 commit
@@ -33,10 +33,6 @@ import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types._
 
 
-// These classes are here to avoid issues with serialization and integration with quasiquotes.
-class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
-class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
-
 /**
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
  *
@@ -205,8 +201,6 @@ class CodeGenContext {
     case _: StructType => "InternalRow"
     case _: ArrayType => "ArrayData"
     case _: MapType => "MapData"
-    case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
-    case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
     case udt: UserDefinedType[_] => javaType(udt.sqlType)
     case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
     case ObjectType(cls) => cls.getName
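Review note (not part of the diff): the deleted `IntegerHashSet`/`LongHashSet` classes existed because the code generator emits Java source and needs a stable, concrete class name to splice in; a parameterized `OpenHashSet[Int]` erases to a raw generic type, so a concrete subclass pins the element type, and `javaType` returned `classOf[IntegerHashSet].getName`. A minimal sketch of that idea — the package and object names below are illustrative, and it only compiles under an `org.apache.spark` package because `OpenHashSet` is `private[spark]`:

package org.apache.spark.example

import org.apache.spark.util.collection.OpenHashSet

// A concrete subclass gives the erased OpenHashSet[Int] a nameable class.
class IntegerHashSet extends OpenHashSet[Int]

object Demo {
  def main(args: Array[String]): Unit = {
    val set = new IntegerHashSet
    set.add(1); set.add(2); set.add(1)
    assert(set.size == 2)      // duplicate adds are no-ops
    assert(set.contains(2))
    // This is the kind of string the code generator spliced into generated Java:
    println(classOf[IntegerHashSet].getName)
  }
}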
@@ -39,7 +39,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
     case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
     case t: ArrayType if canSupport(t.elementType) => true
     case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
-    case dt: OpenHashSetUDT => false // it's not a standard UDT
     case udt: UserDefinedType[_] => canSupport(udt.sqlType)
     case _ => false
   }
@@ -309,13 +308,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
     in.map(BindReferences.bindReference(_, inputSchema))
 
   def generate(
-    expressions: Seq[Expression],
-    subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+      expressions: Seq[Expression],
+      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
     create(canonicalize(expressions), subexpressionEliminationEnabled)
   }
 
   protected def create(expressions: Seq[Expression]): UnsafeProjection = {
-    create(expressions, false)
+    create(expressions, subexpressionEliminationEnabled = false)
   }
 
   private def create(
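Review note: besides the parameter re-indentation, the change from `create(expressions, false)` to `create(expressions, subexpressionEliminationEnabled = false)` is a pure readability fix; a bare boolean says nothing at the call site. A minimal illustration with a simplified signature (not Spark's actual one):

def project(exprs: Seq[String], subexpressionEliminationEnabled: Boolean): Unit = ()

project(Seq("a + 1"), false)                                    // what does false mean here?
project(Seq("a + 1"), subexpressionEliminationEnabled = false)  // self-documenting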

This file was deleted.

@@ -22,19 +22,16 @@ import java.util.{HashMap => JavaHashMap}
 
 import scala.reflect.ClassTag
 
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Kryo, Serializer}
 import com.twitter.chill.ResourcePool
 
 import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.util.MutablePair
-import org.apache.spark.util.collection.OpenHashSet
 import org.apache.spark.{SparkConf, SparkEnv}
 
 private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
   override def newKryo(): Kryo = {
     val kryo = super.newKryo()
@@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
-    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
-      new HyperLogLogSerializer)
     kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer)
     kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer)
 
-    // Specific hashsets must come first TODO: Move to core.
-    kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
-    kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
-    kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
-      new OpenHashSetSerializer)
     kryo.register(classOf[Decimal])
     kryo.register(classOf[JavaHashMap[_, _]])
 
@@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
 }
 
 private[execution] class KryoResourcePool(size: Int)
-  extends ResourcePool[SerializerInstance](size) {
+    extends ResourcePool[SerializerInstance](size) {
 
   val ser: SparkSqlSerializer = {
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
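Review note: for context on what the removed registrations did — `kryo.register(cls, serializer)` assigns the class a compact id and binds a hand-rolled `Serializer` in place of Kryo's default field serializer. A self-contained sketch of the same pattern; `Point` and `PointSerializer` are illustrative names, not Spark classes:

import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}

// Illustrative value type and custom serializer, mirroring the style above.
case class Point(x: Int, y: Int)

class PointSerializer extends Serializer[Point] {
  def write(kryo: Kryo, output: Output, p: Point) {
    output.writeInt(p.x)
    output.writeInt(p.y)
  }
  def read(kryo: Kryo, input: Input, tpe: Class[Point]): Point =
    Point(input.readInt(), input.readInt())
}

object RegistrationDemo {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    // As in newKryo() above: registration binds the custom serializer
    // and avoids writing the full class name into the stream.
    kryo.register(classOf[Point], new PointSerializer)
  }
}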
@@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] {
     new java.math.BigDecimal(input.readString())
   }
 }
-
-private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
-  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
-    val bytes = hyperLogLog.getBytes()
-    output.writeInt(bytes.length)
-    output.writeBytes(bytes)
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
-    val length = input.readInt()
-    val bytes = input.readBytes(length)
-    HyperLogLog.Builder.build(bytes)
-  }
-}
-
-private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
-  def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
-    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val row = iterator.next()
-      rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
-    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
-    val numItems = input.readInt()
-    val set = new OpenHashSet[Any](numItems + 1)
-    var i = 0
-    while (i < numItems) {
-      val row =
-        new GenericInternalRow(rowSerializer.read(
-          kryo,
-          input,
-          classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
-      set.add(row)
-      i += 1
-    }
-    set
-  }
-}
-
-private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] {
-  def write(kryo: Kryo, output: Output, hs: IntegerHashSet) {
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val value: Int = iterator.next()
-      output.writeInt(value)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = {
-    val numItems = input.readInt()
-    val set = new IntegerHashSet
-    var i = 0
-    while (i < numItems) {
-      val value = input.readInt()
-      set.add(value)
-      i += 1
-    }
-    set
-  }
-}
-
-private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
-  def write(kryo: Kryo, output: Output, hs: LongHashSet) {
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val value = iterator.next()
-      output.writeLong(value)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = {
-    val numItems = input.readInt()
-    val set = new LongHashSet
-    var i = 0
-    while (i < numItems) {
-      val value = input.readLong()
-      set.add(value)
-      i += 1
-    }
-    set
-  }
-}
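Review note: all four deleted serializers share one framing pattern — `write` emits a count (or byte length) followed by the payload, and `read` mirrors `write` step for step; `OpenHashSetSerializer` also pre-sizes the set to `numItems + 1`, presumably to avoid an immediate resize on the last insert. A hedged round-trip sketch of the pattern, reusing the illustrative `Point`/`PointSerializer` from the earlier note:

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

object RoundTripDemo {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    val out = new Output(4096)            // fixed buffer; ample for this sketch
    new PointSerializer().write(kryo, out, Point(1, 2))

    val back = new PointSerializer().read(kryo, new Input(out.toBytes), classOf[Point])
    assert(back == Point(1, 2))           // read mirrors write, field for field
  }
}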