Skip to content

Commit

Permalink
Merge Unsafe code into the regular GeneratedAggregate, guarded by a
Browse files Browse the repository at this point in the history
configuration flag; integrate planner support and re-enable all tests.
  • Loading branch information
JoshRosen committed Apr 22, 2015
1 parent d85eeff commit 1f4b716
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 350 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Iterator;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
Expand Down Expand Up @@ -65,6 +66,32 @@ public final class UnsafeFixedWidthAggregationMap {
*/
private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];

/**
* @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
* false otherwise.
*/
public static boolean supportsGroupKeySchema(StructType schema) {
  // A group key schema is accepted only if every one of its columns has a data type listed in
  // UnsafeRow.readableFieldTypes; scanning stops at the first column that is not.
  boolean everyFieldReadable = true;
  for (StructField keyField : schema.fields()) {
    final boolean readable = UnsafeRow.readableFieldTypes.contains(keyField.dataType());
    if (!readable) {
      everyFieldReadable = false;
      break;
    }
  }
  // A schema with no fields falls through with the initial value, i.e. it is supported.
  return everyFieldReadable;
}

/**
* @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
* schema, false otherwise.
*/
public static boolean supportsAggregationBufferSchema(StructType schema) {
  // An aggregation buffer schema is accepted only if every one of its columns has a data type
  // listed in UnsafeRow.settableFieldTypes; scanning stops at the first column that is not.
  boolean everyFieldSettable = true;
  for (StructField bufferField : schema.fields()) {
    final boolean settable = UnsafeRow.settableFieldTypes.contains(bufferField.dataType());
    if (!settable) {
      everyFieldSettable = false;
      break;
    }
  }
  // A schema with no fields falls through with the initial value, i.e. it is supported.
  return everyFieldSettable;
}

/**
* Create a new UnsafeFixedWidthAggregationMap.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.spark.sql.types.DataType;
import static org.apache.spark.sql.types.DataTypes.*;

import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
Expand All @@ -34,7 +35,10 @@
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.sql.Date;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;


// TODO: pick a better name for this class, since this is potentially confusing.
Expand Down Expand Up @@ -71,6 +75,34 @@ public static int calculateBitSetWidthInBytes(int numFields) {
return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
}

/**
* Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
*/
public static final Set<DataType> settableFieldTypes;

/**
* Field types that can be read. This is a superset of the settable types: for types that are
* only readable, set() will throw UnsupportedOperationException.
*/
public static final Set<DataType> readableFieldTypes;

static {
settableFieldTypes = new HashSet<DataType>(Arrays.asList(new DataType[] {
IntegerType,
LongType,
DoubleType,
BooleanType,
ShortType,
ByteType,
FloatType
}));

// We support get() on a superset of the types for which we support set():
readableFieldTypes = new HashSet<DataType>(Arrays.asList(new DataType[] {
StringType
}));
readableFieldTypes.addAll(settableFieldTypes);
}

public UnsafeRow() { }

public void set(Object baseObject, long baseOffset, int numFields, StructType schema) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,22 @@ import org.apache.spark.sql.types._

class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers {

import UnsafeFixedWidthAggregationMap._

private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0))

test("supported schemas") {
  // Builds a one-column schema whose single field is named "x" and has the given type.
  def singleColumnSchema(columnType: DataType): StructType =
    StructType(StructField("x", columnType) :: Nil)

  // A string column is accepted as a group key but rejected as an aggregation buffer.
  val stringSchema = singleColumnSchema(StringType)
  assert(!supportsAggregationBufferSchema(stringSchema))
  assert(supportsGroupKeySchema(stringSchema))

  // An array column is rejected in both roles.
  val intArraySchema = singleColumnSchema(ArrayType(IntegerType))
  assert(!supportsAggregationBufferSchema(intArraySchema))
  assert(!supportsGroupKeySchema(intArraySchema))
}

test("empty map") {
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ private[spark] object SQLConf {
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val CODEGEN_ENABLED = "spark.sql.codegen"
val UNSAFE_ENABLED = "spark.sql.unsafe"
val DIALECT = "spark.sql.dialect"

val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
Expand Down Expand Up @@ -149,6 +150,8 @@ private[sql] class SQLConf extends Serializable {
*/
private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean

/**
 * Reads the UNSAFE_ENABLED setting ("spark.sql.unsafe") and parses it as a Boolean.
 * The setting is treated as "false" when it has not been set.
 */
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean

private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean

/**
Expand Down
2 changes: 2 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,8 @@ class SQLContext(@transient val sparkContext: SparkContext)

def codegenEnabled: Boolean = self.conf.codegenEnabled

// Forwards the unsafeEnabled flag from this context's SQLConf.
def unsafeEnabled: Boolean = self.conf.unsafeEnabled

def numPartitions: Int = self.conf.numShufflePartitions

def strategies: Seq[Strategy] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.MemoryAllocator

case class AggregateEvaluation(
schema: Seq[Attribute],
Expand All @@ -41,13 +42,15 @@ case class AggregateEvaluation(
* @param groupingExpressions expressions that are evaluated to determine grouping.
* @param aggregateExpressions expressions that are computed for each group.
* @param child the input data source.
* @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
*/
@DeveloperApi
case class GeneratedAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
unsafeEnabled: Boolean)
extends UnaryNode {

override def requiredChildDistribution: Seq[Distribution] =
Expand Down Expand Up @@ -225,6 +228,21 @@ case class GeneratedAggregate(
case e: Expression if groupMap.contains(e) => groupMap(e)
})

// Schema of the per-group aggregation buffer, derived from the computation attributes.
val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)

// Schema of the grouping key: one field per grouping expression, in order. The field names are
// placeholders (the expression's position rendered as a string).
val groupKeySchema: StructType = StructType(
  groupingExpressions.zipWithIndex.map { case (groupingExpr, ordinal) =>
    StructField(ordinal.toString, groupingExpr.dataType, groupingExpr.nullable)
  })

// True only when both the buffer schema and the key schema pass the
// UnsafeFixedWidthAggregationMap support checks; the key check is skipped if the buffer check
// already failed.
val schemaSupportsUnsafe: Boolean = {
  val bufferSchemaSupported =
    UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
  bufferSchemaSupported && UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
}

child.execute().mapPartitions { iter =>
// Builds a new custom class for holding the results of aggregation for a group.
val initialValues = computeFunctions.flatMap(_.initialValues)
Expand Down Expand Up @@ -265,7 +283,48 @@ case class GeneratedAggregate(

val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
} else if (unsafeEnabled && schemaSupportsUnsafe) {
log.info("Using Unsafe-based aggregator")
val aggregationMap = new UnsafeFixedWidthAggregationMap(
newAggregationBuffer(EmptyRow),
aggregationBufferSchema,
groupKeySchema,
MemoryAllocator.UNSAFE,
1024
)

while (iter.hasNext) {
val currentRow: Row = iter.next()
val groupKey: Row = groupProjection(currentRow)
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}

// Iterator over the finished aggregation map: each entry's key and value are joined and run
// through the result projection. The iterator also owns releasing the map's memory.
new Iterator[Row] {
  private[this] val mapIterator = aggregationMap.iterator()
  private[this] val resultProjection = resultProjectionBuilder()

  // The map is otherwise only freed from next(), when the last entry is returned. If the map
  // holds no entries, next() is never called, so free it here to avoid leaking its memory.
  // NOTE(review): a consumer that abandons this iterator before exhausting it still leaves the
  // map unfreed; fixing that needs a task-completion hook that is not visible from here.
  if (!mapIterator.hasNext) {
    aggregationMap.free()
  }

  def hasNext: Boolean = mapIterator.hasNext

  def next(): Row = {
    val entry = mapIterator.next()
    val result = resultProjection(joinedRow(entry.key, entry.value))
    if (hasNext) {
      result
    } else {
      // This is the last element in the iterator, so let's free the buffer. Before we do,
      // though, we need to make a defensive copy of the result so that we don't return an
      // object that might contain dangling pointers to the freed memory
      val resultCopy = result.copy()
      aggregationMap.free()
      resultCopy
    }
  }
}
} else {
if (unsafeEnabled) {
log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
}
val buffers = new java.util.HashMap[Row, MutableRow]()

var currentRow: Row = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if canBeCodeGened(
allAggregates(partialComputation) ++
allAggregates(rewrittenAggregateExpressions)) &&
codegenEnabled => {
if (self.sqlContext.getConf("spark.sql.unsafe.enabled", "false") == "true") {
execution.UnsafeGeneratedAggregate(
partial = false,
namedGroupingAttributes,
rewrittenAggregateExpressions,
execution.UnsafeGeneratedAggregate(
partial = true,
groupingExpressions,
partialComputation,
planLater(child))) :: Nil
} else {
codegenEnabled =>
execution.GeneratedAggregate(
partial = false,
namedGroupingAttributes,
Expand All @@ -151,9 +140,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
planLater(child))) :: Nil
}
}
planLater(child),
unsafeEnabled),
unsafeEnabled) :: Nil

// Cases where some aggregate can not be codegened
case PartialAggregation(
Expand Down
Loading

0 comments on commit 1f4b716

Please sign in to comment.