From fc4c3a8aa5b345526298379124530b6c2793d9e5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Apr 2015 22:26:51 -0700 Subject: [PATCH] Sketch how the converters will be used in UnsafeGeneratedAggregate --- .../execution/UnsafeGeneratedAggregate.scala | 50 ++++++++++++------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala index 95668ae5c69e5..485e35c849f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala @@ -35,10 +35,10 @@ import org.apache.spark.unsafe.memory.MemoryAllocator */ @DeveloperApi case class UnsafeGeneratedAggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan) + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode { override def requiredChildDistribution: Seq[Distribution] = @@ -267,17 +267,25 @@ case class UnsafeGeneratedAggregate( // We're going to need to allocate a lot of empty aggregation buffers, so let's do it // once and keep a copy of the serialized buffer and copy it into the hash map when we see // new keys: - val javaAggregationBuffer: MutableRow = - newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - val numberOfFieldsInAggregationBuffer: Int = javaAggregationBuffer.schema.fields.length - val aggregationBufferSchema: StructType = javaAggregationBuffer.schema - // TODO perform that conversion to an UnsafeRow - // Allocate some scratch space for holding the keys that we use to index into the hash map. 
- val unsafeRowBuffer: Array[Long] = new Array[Long](1024) + val (emptyAggregationBuffer: Array[Long], numberOfColumnsInAggBuffer: Int) = { + val javaBuffer: MutableRow = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + val converter = new UnsafeRowConverter(javaBuffer.schema.fields.map(_.dataType)) + val buffer = new Array[Long](converter.getSizeRequirement(javaBuffer)) + converter.writeRow(javaBuffer, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + (buffer, javaBuffer.schema.fields.length) + } // TODO: there's got got to be an actual way of obtaining this up front. var groupProjectionSchema: StructType = null + val keyToUnsafeRowConverter: UnsafeRowConverter = { + new UnsafeRowConverter(groupProjectionSchema.fields.map(_.dataType)) + } + + // Allocate some scratch space for holding the keys that we use to index into the hash map. + // 16 KB ought to be enough for anyone (TODO) + val unsafeRowBuffer: Array[Long] = new Array[Long](1024 * 16 / 8) + while (iter.hasNext) { // Zero out the buffer that's used to hold the current row. 
This is necessary in order // to ensure that rows hash properly, since garbage data from the previous row could @@ -291,7 +299,13 @@ case class UnsafeGeneratedAggregate( val currentGroup: Row = groupProjection(currentJavaRow) // Convert the current group into an UnsafeRow so that we can use it as a key for our // aggregation hash map - // --- TODO --- + val groupProjectionSize = keyToUnsafeRowConverter.getSizeRequirement(currentGroup) + if (groupProjectionSize > unsafeRowBuffer.length) { + throw new IllegalStateException("Group projection does not fit into buffer") + } + keyToUnsafeRowConverter.writeRow( + currentGroup, unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + val keyLengthInBytes: Int = 0 val loc: BytesToBytesMap#Location = buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes) @@ -308,9 +322,9 @@ case class UnsafeGeneratedAggregate( unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes, - null, // empty agg buffer + emptyAggregationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - 0 // length of the aggregation buffer + emptyAggregationBuffer.length ) } // Reset our pointer to point to the buffer stored in the hash map @@ -318,8 +332,8 @@ case class UnsafeGeneratedAggregate( currentBuffer.set( address.getBaseObject, address.getBaseOffset, - numberOfFieldsInAggregationBuffer, - javaAggregationBuffer.schema + numberOfColumnsInAggBuffer, + null ) // Target the projection at the current aggregation buffer and then project the updated // values. @@ -346,8 +360,8 @@ case class UnsafeGeneratedAggregate( value.set( valueAddress.getBaseObject, valueAddress.getBaseOffset, - aggregationBufferSchema.fields.length, - aggregationBufferSchema + numberOfColumnsInAggBuffer, + null ) // TODO: once the iterator has been fully consumed, we need to free the map so that // its off-heap memory is reclaimed. This may mean that we'll have to perform an extra