From 4fb313bd4f69d384fdacd1e2457b7be7772527ce Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Sat, 13 Nov 2021 23:04:29 -0800
Subject: [PATCH] address comments

---
 .../execution/datasources/AggregatePushDownUtils.scala   | 5 +++++
 .../spark/sql/execution/datasources/orc/OrcUtils.scala   | 4 +++-
 .../sql/execution/datasources/parquet/ParquetUtils.scala | 4 +++-
 .../datasources/v2/orc/OrcPartitionReaderFactory.scala   | 8 ++------
 .../v2/parquet/ParquetPartitionReaderFactory.scala       | 9 ++-------
 5 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
index 2fafdddcad6cc..f00217e8eb0f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
@@ -168,6 +168,11 @@ object AggregatePushDownUtils {
       aggregation: Aggregation,
       partitionValues: InternalRow): InternalRow = {
     val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head)
+    assert(groupByColNames.length == partitionSchema.length &&
+      groupByColNames.length == partitionValues.numFields, "The number of group by columns " +
+      s"${groupByColNames.length} should be the same as partition schema length " +
+      s"${partitionSchema.length} and the number of fields ${partitionValues.numFields} " +
+      s"in partitionValues")
     var reorderedPartColValues = Array.empty[Any]
     if (!partitionSchema.names.sameElements(groupByColNames)) {
       groupByColNames.foreach { col =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index fd9c4e1e7fc0e..fb599d0b21c54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -499,7 +499,9 @@ object OrcUtils extends Logging {
       (0 until schemaWithoutGroupBy.length).toArray)
     val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
     if (aggregation.groupByColumns.nonEmpty) {
-      new JoinedRow(partitionValues, resultRow)
+      val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
+        partitionSchema, aggregation, partitionValues)
+      new JoinedRow(reOrderedPartitionValues, resultRow)
     } else {
       resultRow
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 53e6514f8bb82..4dff5d89f3ad2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -203,7 +203,9 @@ object ParquetUtils {
     }

     if (aggregation.groupByColumns.nonEmpty) {
-      new JoinedRow(partitionValues, converter.currentRecord)
+      val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
+        partitionSchema, aggregation, partitionValues)
+      new JoinedRow(reorderedPartitionValues, converter.currentRecord)
     } else {
       converter.currentRecord
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 49aa24319e5b5..1363a487d7fa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -190,11 +190,9 @@ case class OrcPartitionReaderFactory(
       private var hasNext = true
       private lazy val row: InternalRow = {
         Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
-          val partitionValues = AggregatePushDownUtils.reOrderPartitionCol(
-            partitionSchema, aggregation.get, file.partitionValues)
           OrcUtils.createAggInternalRowFromFooter(
             reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
-            readDataSchema, partitionValues)
+            readDataSchema, file.partitionValues)
         }
       }

@@ -220,11 +218,9 @@ case class OrcPartitionReaderFactory(
       private var hasNext = true
       private lazy val batch: ColumnarBatch = {
         Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
-          val partitionValues = AggregatePushDownUtils.reOrderPartitionCol(
-            partitionSchema, aggregation.get, file.partitionValues)
           val row = OrcUtils.createAggInternalRowFromFooter(
             reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
-            readDataSchema, partitionValues)
+            readDataSchema, file.partitionValues)
           AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false)
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 4c0c71ad9579f..56d6cc7f57b92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -135,12 +135,9 @@ case class ParquetPartitionReaderFactory(
       private lazy val row: InternalRow = {
         val footer = getFooter(file)

-        val partitionValues = AggregatePushDownUtils.reOrderPartitionCol(
-          partitionSchema, aggregation.get, file.partitionValues)
-
         if (footer != null && footer.getBlocks.size > 0) {
           ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
-            partitionSchema, aggregation.get, readDataSchema, partitionValues,
+            partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
             getDatetimeRebaseMode(footer.getFileMetaData))
         } else {
           null
@@ -182,10 +179,8 @@
       private val batch: ColumnarBatch = {
         val footer = getFooter(file)
         if (footer != null && footer.getBlocks.size > 0) {
-          val partitionValues = AggregatePushDownUtils.reOrderPartitionCol(
-            partitionSchema, aggregation.get, file.partitionValues)
           val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath,
-            dataSchema, partitionSchema, aggregation.get, readDataSchema, partitionValues,
+            dataSchema, partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
             getDatetimeRebaseMode(footer.getFileMetaData))
           AggregatePushDownUtils.convertAggregatesRowToBatch(
             row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined)
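Note on the change: the two PartitionReaderFactory classes no longer pre-compute the reordered partition row; they pass file.partitionValues through, and OrcUtils/ParquetUtils call AggregatePushDownUtils.reOrderPartitionCol themselves, now guarded by the new assert. For context, here is a minimal, self-contained Scala sketch of the reordering idea only. It is not the Spark implementation; the names partitionColNames, groupByColNames and partitionValues are hypothetical stand-ins for the real StructType/Aggregation/InternalRow arguments.

object ReorderPartitionColsSketch {
  // Illustrative only: partition values arrive in partition-schema order, but the
  // pushed-down aggregate expects them in GROUP BY order, so they may need to be
  // permuted before being joined to the aggregated result row.
  def reorder(
      partitionColNames: Seq[String],
      groupByColNames: Seq[String],
      partitionValues: Seq[Any]): Seq[Any] = {
    assert(groupByColNames.length == partitionColNames.length &&
      groupByColNames.length == partitionValues.length,
      "group by columns, partition schema and partition values must have the same length")
    if (partitionColNames == groupByColNames) {
      partitionValues
    } else {
      // look up each group-by column's position in the partition schema
      groupByColNames.map(col => partitionValues(partitionColNames.indexOf(col)))
    }
  }

  def main(args: Array[String]): Unit = {
    // partition schema order (year, month); GROUP BY order (month, year)
    println(reorder(Seq("year", "month"), Seq("month", "year"), Seq(2021, 11)))
    // => List(11, 2021)
  }
}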