Skip to content

Commit

Permalink
Sketch how the converters will be used in UnsafeGeneratedAggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Apr 22, 2015
1 parent 53ba9b7 commit fc4c3a8
Showing 1 changed file with 32 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ import org.apache.spark.unsafe.memory.MemoryAllocator
*/
@DeveloperApi
case class UnsafeGeneratedAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {

override def requiredChildDistribution: Seq[Distribution] =
Expand Down Expand Up @@ -267,17 +267,25 @@ case class UnsafeGeneratedAggregate(
// We're going to need to allocate a lot of empty aggregation buffers, so let's do it
// once and keep a copy of the serialized buffer and copy it into the hash map when we see
// new keys:
val javaAggregationBuffer: MutableRow =
newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
val numberOfFieldsInAggregationBuffer: Int = javaAggregationBuffer.schema.fields.length
val aggregationBufferSchema: StructType = javaAggregationBuffer.schema
// TODO perform that conversion to an UnsafeRow
// Allocate some scratch space for holding the keys that we use to index into the hash map.
val unsafeRowBuffer: Array[Long] = new Array[Long](1024)
val (emptyAggregationBuffer: Array[Long], numberOfColumnsInAggBuffer: Int) = {
val javaBuffer: MutableRow = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
val converter = new UnsafeRowConverter(javaBuffer.schema.fields.map(_.dataType))
val buffer = new Array[Long](converter.getSizeRequirement(javaBuffer))
converter.writeRow(javaBuffer, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
(buffer, javaBuffer.schema.fields.length)
}

// TODO: there's got to be an actual way of obtaining this up front.
var groupProjectionSchema: StructType = null

val keyToUnsafeRowConverter: UnsafeRowConverter = {
new UnsafeRowConverter(groupProjectionSchema.fields.map(_.dataType))
}

// Allocate some scratch space for holding the keys that we use to index into the hash map.
// 16 KB (2048 longs * 8 bytes) ought to be enough for anyone (TODO)
val unsafeRowBuffer: Array[Long] = new Array[Long](1024 * 16 / 8)

while (iter.hasNext) {
// Zero out the buffer that's used to hold the current row. This is necessary in order
// to ensure that rows hash properly, since garbage data from the previous row could
Expand All @@ -291,7 +299,13 @@ case class UnsafeGeneratedAggregate(
val currentGroup: Row = groupProjection(currentJavaRow)
// Convert the current group into an UnsafeRow so that we can use it as a key for our
// aggregation hash map
// --- TODO ---
val groupProjectionSize = keyToUnsafeRowConverter.getSizeRequirement(currentGroup)
if (groupProjectionSize > unsafeRowBuffer.length) {
throw new IllegalStateException("Group projection does not fit into buffer")
}
keyToUnsafeRowConverter.writeRow(
currentGroup, unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET)

val keyLengthInBytes: Int = 0
val loc: BytesToBytesMap#Location =
buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
Expand All @@ -308,18 +322,18 @@ case class UnsafeGeneratedAggregate(
unsafeRowBuffer,
PlatformDependent.LONG_ARRAY_OFFSET,
keyLengthInBytes,
null, // empty agg buffer
emptyAggregationBuffer,
PlatformDependent.LONG_ARRAY_OFFSET,
0 // length of the aggregation buffer
emptyAggregationBuffer.length
)
}
// Reset our pointer to point to the buffer stored in the hash map
val address = loc.getValueAddress
currentBuffer.set(
address.getBaseObject,
address.getBaseOffset,
numberOfFieldsInAggregationBuffer,
javaAggregationBuffer.schema
numberOfColumnsInAggBuffer,
null
)
// Target the projection at the current aggregation buffer and then project the updated
// values.
Expand All @@ -346,8 +360,8 @@ case class UnsafeGeneratedAggregate(
value.set(
valueAddress.getBaseObject,
valueAddress.getBaseOffset,
aggregationBufferSchema.fields.length,
aggregationBufferSchema
numberOfColumnsInAggBuffer,
null
)
// TODO: once the iterator has been fully consumed, we need to free the map so that
// its off-heap memory is reclaimed. This may mean that we'll have to perform an extra
Expand Down

0 comments on commit fc4c3a8

Please sign in to comment.