Extract aggregation map into its own class.
This makes the code much easier to understand and
will allow me to implement unsafe versions of both
GeneratedAggregate and the regular Aggregate operator.
JoshRosen committed Apr 22, 2015
1 parent d2bb986 commit b3eaccd
Showing 3 changed files with 229 additions and 100 deletions.
UnsafeFixedWidthAggregationMap.java (new file)
@@ -0,0 +1,207 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions;

import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryLocation;

/**
 * Unsafe-based HashMap for performing aggregations in which the aggregated values are
 * fixed-width. This map is NOT thread-safe.
 */
public final class UnsafeFixedWidthAggregationMap {

  /**
   * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
   * map, we copy this buffer and use it as the value.
   */
  private final long[] emptyAggregationBuffer;

  private final StructType aggregationBufferSchema;

  private final StructType groupingKeySchema;

  /**
   * Encodes grouping keys as UnsafeRows.
   */
  private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;

  /**
   * A hashmap which maps from opaque byte-array keys to byte-array values.
   */
  private final BytesToBytesMap map;

  /**
   * Re-used pointer to the current aggregation buffer.
   */
  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();

  /**
   * Scratch space that is used when encoding grouping keys into UnsafeRow format.
   *
   * By default, this is a 1 KB array (128 longs), but it will grow as necessary in case larger
   * keys are encountered.
   */
  private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];

  /**
   * Create a new UnsafeFixedWidthAggregationMap.
   *
   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion
   * @param groupingKeySchema the schema of the grouping key, used for row conversion
   * @param allocator the memory allocator used to allocate our Unsafe memory structures
   * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing)
   */
  public UnsafeFixedWidthAggregationMap(
      Row emptyAggregationBuffer,
      StructType aggregationBufferSchema,
      StructType groupingKeySchema,
      MemoryAllocator allocator,
      long initialCapacity) {
    this.emptyAggregationBuffer =
      convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
    this.aggregationBufferSchema = aggregationBufferSchema;
    this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
    this.groupingKeySchema = groupingKeySchema;
    this.map = new BytesToBytesMap(allocator, initialCapacity);
  }

  /**
   * Convert a Java object row into an UnsafeRow, storing it in a freshly allocated long array.
   */
  private static long[] convertToUnsafeRow(Row javaRow, StructType schema) {
    final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
    final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)];
    final long writtenLength =
      converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET);
    assert (writtenLength == unsafeRow.length) : "Size requirement calculation was wrong!";
    return unsafeRow;
  }

  /**
   * Return the aggregation buffer for the current group. For efficiency, all calls to this method
   * return the same object.
   */
  public UnsafeRow getAggregationBuffer(Row groupingKey) {
    // Zero out the scratch space that's used to hold the current key. This is necessary in order
    // to ensure that keys hash properly, since garbage data from the previous key could
    // otherwise end up as padding in this key.
    Arrays.fill(groupingKeyConversionScratchSpace, 0);
    final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
    // Grow the scratch space if this key is larger than any seen so far.
    if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
      groupingKeyConversionScratchSpace = new long[groupingKeySize];
    }
    final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
      groupingKey,
      groupingKeyConversionScratchSpace,
      PlatformDependent.LONG_ARRAY_OFFSET);
    assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";

    // Probe our map using the serialized key
    final BytesToBytesMap.Location loc = map.lookup(
      groupingKeyConversionScratchSpace,
      PlatformDependent.LONG_ARRAY_OFFSET,
      groupingKeySize);
    if (!loc.isDefined()) {
      // This is the first time that we've seen this grouping key, so we'll insert a copy of the
      // empty aggregation buffer into the map:
      loc.storeKeyAndValue(
        groupingKeyConversionScratchSpace,
        PlatformDependent.LONG_ARRAY_OFFSET,
        groupingKeySize,
        emptyAggregationBuffer,
        PlatformDependent.LONG_ARRAY_OFFSET,
        emptyAggregationBuffer.length
      );
    }

    // Reset the pointer to point to the value that we just stored or looked up:
    final MemoryLocation address = loc.getValueAddress();
    currentAggregationBuffer.set(
      address.getBaseObject(),
      address.getBaseOffset(),
      aggregationBufferSchema.length(),
      aggregationBufferSchema
    );
    return currentAggregationBuffer;
  }

  /**
   * Mutable key/value pair returned by {@link #iterator()}; the same instance is re-used.
   */
  public static class MapEntry {
    public final UnsafeRow key = new UnsafeRow();
    public final UnsafeRow value = new UnsafeRow();
  }

  /**
   * Returns an iterator over the keys and values in this map.
   *
   * For efficiency, each call returns the same object.
   */
  public Iterator<MapEntry> iterator() {
    return new Iterator<MapEntry>() {

      private final MapEntry entry = new MapEntry();
      private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();

      @Override
      public boolean hasNext() {
        return mapLocationIterator.hasNext();
      }

      @Override
      public MapEntry next() {
        final BytesToBytesMap.Location loc = mapLocationIterator.next();
        final MemoryLocation keyAddress = loc.getKeyAddress();
        final MemoryLocation valueAddress = loc.getValueAddress();
        entry.key.set(
          keyAddress.getBaseObject(),
          keyAddress.getBaseOffset(),
          groupingKeySchema.length(),
          groupingKeySchema
        );
        entry.value.set(
          valueAddress.getBaseObject(),
          valueAddress.getBaseOffset(),
          aggregationBufferSchema.length(),
          aggregationBufferSchema
        );
        return entry;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }

  /**
   * Free the unsafe memory associated with this map.
   */
  public void free() {
    map.free();
  }

}
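The Scala changes below drive this class once per partition. As a quick orientation, here is a minimal sketch (not part of the commit) of the intended call pattern; the schemas, Row values, and the COUNT-style update are invented for illustration:

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types.{LongType, StructField, StructType}
  import org.apache.spark.sql.catalyst.expressions.UnsafeFixedWidthAggregationMap
  import org.apache.spark.unsafe.memory.MemoryAllocator

  val groupKeySchema = StructType(StructField("0", LongType, nullable = false) :: Nil)
  val aggBufferSchema = StructType(StructField("count", LongType, nullable = false) :: Nil)

  val map = new UnsafeFixedWidthAggregationMap(
    Row(0L),                // the "zero" buffer copied in for each new key
    aggBufferSchema,        // values must be fixed-width
    groupKeySchema,
    MemoryAllocator.UNSAFE, // off-heap allocation
    1024                    // initial capacity hint
  )

  // Per input row: fetch (or create) the group's buffer and mutate it in place.
  // getAggregationBuffer() re-uses a single UnsafeRow across calls.
  val buffer = map.getAggregationBuffer(Row(42L))
  buffer.setLong(0, buffer.getLong(0) + 1L) // e.g. a COUNT update

  // Drain the map. The iterator re-uses one MapEntry, so copy anything that
  // must outlive the next call to next() or the call to free().
  val it = map.iterator()
  while (it.hasNext) {
    val entry = it.next()
    // ... read entry.key / entry.value here ...
  }
  map.free() // any un-copied UnsafeRow now points at freed memory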
@@ -154,6 +154,10 @@ private case object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter

 class UnsafeRowConverter(fieldTypes: Array[DataType]) {
 
+  def this(schema: StructType) {
+    this(schema.fields.map(_.dataType))
+  }
+
   private[this] val unsafeRow = new UnsafeRow()
 
   private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
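The auxiliary constructor added above is a convenience for callers that already hold a StructType. A one-line sketch, assuming some `schema: StructType`:

  val converter = new UnsafeRowConverter(schema)
  // ...is shorthand for the existing primary constructor:
  val converter2 = new UnsafeRowConverter(schema.fields.map(_.dataType))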
@@ -17,16 +17,12 @@

 package org.apache.spark.sql.execution
 
-import java.util
-
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.trees._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.memory.MemoryAllocator
 
 // TODO: finish cleaning up documentation instead of just copying it

@@ -258,128 +254,50 @@ case class UnsafeGeneratedAggregate(
         val resultProjection = resultProjectionBuilder()
         Iterator(resultProjection(buffer))
       } else {
-        // TODO: if we knew how many groups to expect, we could size this hashmap appropriately
-        val buffers = new BytesToBytesMap(MemoryAllocator.UNSAFE, 128)
-
-        // Set up the mutable "pointers" that we'll re-use when pointing to key and value rows
-        val currentBuffer: UnsafeRow = new UnsafeRow()
-
-        // We're going to need to allocate a lot of empty aggregation buffers, so let's do it
-        // once and keep a copy of the serialized buffer and copy it into the hash map when we see
-        // new keys:
-        val emptyAggregationBuffer: Array[Long] = {
-          val javaBuffer: MutableRow = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-          val fieldTypes = StructType.fromAttributes(computationSchema).map(_.dataType).toArray
-          val converter = new UnsafeRowConverter(fieldTypes)
-          val buffer = new Array[Long](converter.getSizeRequirement(javaBuffer))
-          converter.writeRow(javaBuffer, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
-          buffer
-        }
-
-        val keyToUnsafeRowConverter: UnsafeRowConverter = {
-          new UnsafeRowConverter(groupingExpressions.map(_.dataType).toArray)
-        }
-
         val aggregationBufferSchema = StructType.fromAttributes(computationSchema)
-        val keySchema: StructType = {
+
+        val groupKeySchema: StructType = {
           val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
             StructField(idx.toString, expr.dataType, expr.nullable)
           }
           StructType(fields)
         }
 
-        // Allocate some scratch space for holding the keys that we use to index into the hash map.
-        // 16 MB ought to be enough for anyone (TODO)
-        val unsafeRowBuffer: Array[Long] = new Array[Long](1024 * 16 / 8)
+        val aggregationMap = new UnsafeFixedWidthAggregationMap(
+          newAggregationBuffer(EmptyRow),
+          aggregationBufferSchema,
+          groupKeySchema,
+          MemoryAllocator.UNSAFE,
+          1024
+        )
 
         while (iter.hasNext) {
-          // Zero out the buffer that's used to hold the current row. This is necessary in order
-          // to ensure that rows hash properly, since garbage data from the previous row could
-          // otherwise end up as padding in this row.
-          util.Arrays.fill(unsafeRowBuffer, 0)
-          // Grab the next row from our input iterator and compute its group projection.
-          // In the long run, it might be nice to use Unsafe rows for this as well, but for now
-          // we'll just rely on the existing code paths to compute the projection.
-          val currentJavaRow = iter.next()
-          val currentGroup: Row = groupProjection(currentJavaRow)
-          // Convert the current group into an UnsafeRow so that we can use it as a key for our
-          // aggregation hash map
-          val groupProjectionSize = keyToUnsafeRowConverter.getSizeRequirement(currentGroup)
-          if (groupProjectionSize > unsafeRowBuffer.length) {
-            throw new IllegalStateException("Group projection does not fit into buffer")
-          }
-          val keyLengthInBytes: Int = keyToUnsafeRowConverter.writeRow(
-            currentGroup, unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET).toInt // TODO
-
-          val loc: BytesToBytesMap#Location =
-            buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
-          if (!loc.isDefined) {
-            // This is the first time that we've seen this key, so we'll copy the empty aggregation
-            // buffer row that we created earlier. TODO: this doesn't work very well for aggregates
-            // where the size of the aggregate buffer is different for different rows (even if the
-            // size of buffers don't grow once created, as is the case for things like grabbing the
-            // first row's value for a string-valued column (or the shortest string)).
-            loc.storeKeyAndValue(
-              unsafeRowBuffer,
-              PlatformDependent.LONG_ARRAY_OFFSET,
-              keyLengthInBytes,
-              emptyAggregationBuffer,
-              PlatformDependent.LONG_ARRAY_OFFSET,
-              emptyAggregationBuffer.length
-            )
-          }
-          // Reset our pointer to point to the buffer stored in the hash map
-          val address = loc.getValueAddress
-          currentBuffer.set(
-            address.getBaseObject,
-            address.getBaseOffset,
-            aggregationBufferSchema.length,
-            aggregationBufferSchema
-          )
-          // Target the projection at the current aggregation buffer and then project the updated
-          // values.
-          updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentJavaRow))
+          val currentRow: Row = iter.next()
+          val groupKey: Row = groupProjection(currentRow)
+          val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+          updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
         }
 
         new Iterator[Row] {
-          private[this] val resultIterator = buffers.iterator()
+          private[this] val mapIterator = aggregationMap.iterator()
           private[this] val resultProjection = resultProjectionBuilder()
-          private[this] val key: UnsafeRow = new UnsafeRow()
-          private[this] val value: UnsafeRow = new UnsafeRow()
 
-          def hasNext: Boolean = resultIterator.hasNext
+          def hasNext: Boolean = mapIterator.hasNext
 
           def next(): Row = {
-            val currentGroup: BytesToBytesMap#Location = resultIterator.next()
-            val keyAddress = currentGroup.getKeyAddress
-            key.set(
-              keyAddress.getBaseObject,
-              keyAddress.getBaseOffset,
-              groupingExpressions.length,
-              keySchema)
-            val valueAddress = currentGroup.getValueAddress
-            value.set(
-              valueAddress.getBaseObject,
-              valueAddress.getBaseOffset,
-              aggregationBufferSchema.length,
-              aggregationBufferSchema)
-            val result = resultProjection(joinedRow(key, value))
+            val entry = mapIterator.next()
+            val result = resultProjection(joinedRow(entry.key, entry.value))
             if (hasNext) {
              result
            } else {
              // This is the last element in the iterator, so let's free the buffer. Before we do,
              // though, we need to make a defensive copy of the result so that we don't return an
              // object that might contain dangling pointers to the freed memory
              val resultCopy = result.copy()
-              buffers.free()
+              aggregationMap.free()
              resultCopy
            }
          }
-
-          override def finalize(): Unit = {
-            buffers.free()
-          }
        }
      }
    }
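Note the shutdown order in next() above: both the iterator's entries and the aggregation buffers are backed by memory owned by the map, so the final result must be materialized before free(). A sketch of that pattern, using the names from the diff:

  val result = resultProjection(joinedRow(entry.key, entry.value))
  val resultCopy = result.copy() // materialize before the backing memory is released
  aggregationMap.free()          // invalidates entry.key, entry.value, and result
  resultCopy                     // safe to return: owns its own storage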
