[SPARK-47157][SQL] Refactor file listing with ScanFileListing interface
### What changes were proposed in this pull request?

In this pull request, we introduce the `ScanFileListing` trait and its implementation, the `GenericScanFileListing` class, to encapsulate and streamline the handling of file listing results. This new abstraction improves modularity and allows more flexible management of file listings within the system.
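
For orientation, a condensed sketch of the new abstraction follows; the signatures mirror the trait added by this commit (the full definitions and doc comments are in the new file in the diff below).

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BasePredicate, Expression}
import org.apache.spark.sql.execution.datasources.{FileStatusWithMetadata, PartitionedFile}

// One listed partition: its partition values, the number of files, and an iterator over them.
case class ListingPartition(
    values: InternalRow,
    numFiles: Long,
    files: Iterator[FileStatusWithMetadata])

// Planner-facing view of a file listing, independent of the concrete representation.
trait ScanFileListing {
  def partitionCount: Int
  def totalFileSize: Long
  def totalNumberOfFiles: Long
  def filterAndPruneFiles(
      boundPredicate: BasePredicate,
      dynamicFileFilters: Seq[Expression]): ScanFileListing
  def toPartitionArray: Array[PartitionedFile]
  def calculateTotalPartitionBytes: Long
  def filePartitionIterator: Iterator[ListingPartition]
  def bucketsContainSingleFile: Boolean
}

// Default implementation backed by the existing Array[PartitionDirectory] representation
// (defined inside FileSourceScanLike in this commit; method bodies elided here):
// private case class GenericScanFileListing(partitionDirectories: Array[PartitionDirectory])
//   extends ScanFileListing { ... }
```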

### Why are the changes needed?

These constructs define a standardized API for file listing operations, regardless of the underlying representation used for files and partitions. Improving the modularity of this code also enables future changes that can benefit both runtime and memory usage.
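
As an example of the effect on call sites, here is the metric helper from `FileSourceScanLike`, condensed from the diff below; the representation-specific traversal moves behind the interface:

```scala
// Previously this helper took a Seq[PartitionDirectory] and walked the files itself:
//   val filesNum  = partitions.map(_.files.size.toLong).sum
//   val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
// After this change it depends only on the ScanFileListing API, so the same call site
// works for any representation that implements the trait.
private def setFilesNumAndSizeMetric(partitions: ScanFileListing, static: Boolean): Unit = {
  val filesNum = partitions.totalNumberOfFiles
  val filesSize = partitions.totalFileSize
  // ... update the driver metrics exactly as before ...
}
```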

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

This is just a refactoring with no new behavior, so existing tests suffice.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45224 from costas-db/refactorFileListing.

Lead-authored-by: Costas Zarifis <costas.zarifis@databricks.com>
Co-authored-by: Shoumik Palkar <shoumik.palkar@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
2 people authored and HyukjinKwon committed Feb 26, 2024
1 parent 18b8606 commit e0facc3
Showing 5 changed files with 223 additions and 61 deletions.
@@ -278,26 +278,28 @@ trait FileSourceScanLike extends DataSourceScanExec {

// This field will be accessed during planning (e.g., `outputPartitioning` relies on it), and can
// only use static filters.
@transient lazy val selectedPartitions: Array[PartitionDirectory] = {
@transient lazy val selectedPartitions: ScanFileListing = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
val startTime = System.nanoTime()
// The filters may contain subquery expressions which can't be evaluated during planning.
// Here we filter out subquery expressions and get the static data/partition filters, so that
// they can be used to do pruning at the planning phase.
val staticDataFilters = dataFilters.filterNot(isDynamicPruningFilter)
val staticPartitionFilters = partitionFilters.filterNot(isDynamicPruningFilter)
val ret = relation.location.listFiles(staticPartitionFilters, staticDataFilters)
setFilesNumAndSizeMetric(ret, true)
val partitionDirectories = relation.location.listFiles(
staticPartitionFilters, staticDataFilters)
val fileListing = GenericScanFileListing(partitionDirectories.toArray)
setFilesNumAndSizeMetric(fileListing, static = true)
val timeTakenMs = NANOSECONDS.toMillis(
(System.nanoTime() - startTime) + optimizerMetadataTimeNs)
driverMetrics("metadataTime").set(timeTakenMs)
ret
}.toArray
fileListing
}

// We can only determine the actual partitions at runtime when a dynamic partition filter is
// present. This is because such a filter relies on information that is only available at run
// time (for instance the keys used in the other side of a join).
@transient protected lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = {
@transient protected lazy val dynamicallySelectedPartitions: ScanFileListing = {
val dynamicDataFilters = dataFilters.filter(isDynamicPruningFilter)
val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter)

@@ -311,15 +313,12 @@ trait FileSourceScanLike extends DataSourceScanExec {
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
}, Nil)
var ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
if (dynamicDataFilters.nonEmpty) {
val filePruningRunner = new FilePruningRunner(dynamicDataFilters)
ret = ret.map(filePruningRunner.prune)
}
setFilesNumAndSizeMetric(ret.toImmutableArraySeq, false)
val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000
val returnedFiles = selectedPartitions.filterAndPruneFiles(
boundPredicate, dynamicDataFilters)
setFilesNumAndSizeMetric(returnedFiles, false)
val timeTakenMs = NANOSECONDS.toMillis(System.nanoTime() - startTime)
driverMetrics("pruningTime").set(timeTakenMs)
ret
returnedFiles
} else {
selectedPartitions
}
@@ -376,11 +375,7 @@ trait FileSourceScanLike extends DataSourceScanExec {
// but those files combined together are not globally sorted. Given that,
// the RDD partition will not be sorted even if the relation has sort columns set
// Current solution is to check if all the buckets have a single file in it

val files = selectedPartitions.flatMap(partition => partition.files)
val bucketToFilesGrouping =
files.map(_.getPath.getName).groupBy(file => BucketingUtils.getBucketId(file))
val singleFilePartitions = bucketToFilesGrouping.forall(p => p._2.length <= 1)
val singleFilePartitions = selectedPartitions.bucketsContainSingleFile

// TODO SPARK-24528 Sort order is currently ignored if buckets are coalesced.
if (singleFilePartitions && optionalNumCoalescedBuckets.isEmpty) {
@@ -503,11 +498,9 @@ trait FileSourceScanLike extends DataSourceScanExec {
}

/** Helper for computing total number and size of files in selected partitions. */
private def setFilesNumAndSizeMetric(
partitions: Seq[PartitionDirectory],
static: Boolean): Unit = {
val filesNum = partitions.map(_.files.size.toLong).sum
val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
private def setFilesNumAndSizeMetric(partitions: ScanFileListing, static: Boolean): Unit = {
val filesNum = partitions.totalNumberOfFiles
val filesSize = partitions.totalFileSize
if (!static || !partitionFilters.exists(isDynamicPruningFilter)) {
driverMetrics("numFiles").set(filesNum)
driverMetrics("filesSize").set(filesSize)
@@ -516,7 +509,7 @@ trait FileSourceScanLike extends DataSourceScanExec {
driverMetrics("staticFilesSize").set(filesSize)
}
if (relation.partitionSchema.nonEmpty) {
driverMetrics("numPartitions").set(partitions.length)
driverMetrics("numPartitions").set(partitions.partitionCount)
}
}

@@ -531,6 +524,61 @@ trait FileSourceScanLike extends DataSourceScanExec {
None
}
} ++ driverMetrics

/**
* A file listing that represents a file list as an array of [[PartitionDirectory]]. This extends
* [[ScanFileListing]] in order to access methods for computing partition sizes and so forth.
* This is the default file listing class used for all table formats.
*/
private case class GenericScanFileListing(partitionDirectories: Array[PartitionDirectory])
extends ScanFileListing {

override def partitionCount: Int = partitionDirectories.length

override def totalFileSize: Long = partitionDirectories.map(_.files.map(_.getLen).sum).sum

override def totalNumberOfFiles: Long = partitionDirectories.map(_.files.length).sum.toLong

override def filterAndPruneFiles(
boundPredicate: BasePredicate,
dynamicFileFilters: Seq[Expression]): ScanFileListing = {
val filteredPartitions = partitionDirectories.filter(p => boundPredicate.eval(p.values))
val prunedPartitions = if (dynamicFileFilters.nonEmpty) {
val filePruningRunner = new FilePruningRunner(dynamicFileFilters)
filteredPartitions.map(filePruningRunner.prune)
} else {
filteredPartitions
}
GenericScanFileListing(prunedPartitions)
}

override def toPartitionArray: Array[PartitionedFile] = {
partitionDirectories.flatMap { p =>
p.files.map { f => PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen) }
}
}

override def calculateTotalPartitionBytes: Long = {
val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes
partitionDirectories.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
}

override def filePartitionIterator: Iterator[ListingPartition] = {
partitionDirectories.iterator.map { partitionDirectory =>
ListingPartition(
partitionDirectory.values,
partitionDirectory.files.size,
partitionDirectory.files.iterator)
}
}

override def bucketsContainSingleFile: Boolean = {
val files = partitionDirectories.flatMap(_.files)
val bucketToFilesGrouping =
files.map(_.getPath.getName).groupBy(file => BucketingUtils.getBucketId(file))
bucketToFilesGrouping.forall(p => p._2.length <= 1)
}
}
}

/**
@@ -664,16 +712,14 @@ case class FileSourceScanExec(
private def createBucketedReadRDD(
bucketSpec: BucketSpec,
readFile: (PartitionedFile) => Iterator[InternalRow],
selectedPartitions: Array[PartitionDirectory]): RDD[InternalRow] = {
selectedPartitions: ScanFileListing): RDD[InternalRow] = {
logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
val filesGroupedToBuckets =
selectedPartitions.flatMap { p =>
p.files.map(f => PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen))
}.groupBy { f =>
BucketingUtils
.getBucketId(f.toPath.getName)
.getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
}
val partitionArray = selectedPartitions.toPartitionArray
val filesGroupedToBuckets = partitionArray.groupBy { f =>
BucketingUtils
.getBucketId(f.toPath.getName)
.getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
}

val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
val bucketSet = optionalBucketSet.get
@@ -714,10 +760,10 @@ case class FileSourceScanExec(
*/
private def createReadRDD(
readFile: PartitionedFile => Iterator[InternalRow],
selectedPartitions: Array[PartitionDirectory]): RDD[InternalRow] = {
selectedPartitions: ScanFileListing): RDD[InternalRow] = {
val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes
val maxSplitBytes =
FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions.toImmutableArraySeq)
FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions)
logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
s"open cost is considered as scanning $openCostInBytes bytes.")

@@ -731,22 +777,23 @@ case class FileSourceScanExec(
_ => true
}

val splitFiles = selectedPartitions.flatMap { partition =>
partition.files.flatMap { file =>
val splitFiles = selectedPartitions.filePartitionIterator.flatMap { partition =>
val ListingPartition(partitionVals, _, fileStatusIterator) = partition
fileStatusIterator.flatMap { file =>
if (shouldProcess(file.getPath)) {
val isSplitable = relation.fileFormat.isSplitable(
relation.sparkSession, relation.options, file.getPath)
PartitionedFileUtil.splitFiles(
file = file,
isSplitable = isSplitable,
maxSplitBytes = maxSplitBytes,
partitionValues = partition.values
partitionValues = partitionVals
)
} else {
Seq.empty
}
}
}.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
}.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)

val partitions = FilePartition
.getFilePartitions(relation.sparkSession, splitFiles.toImmutableArraySeq, maxSplitBytes)
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BasePredicate, Expression}
import org.apache.spark.sql.execution.datasources.{FileStatusWithMetadata, PartitionedFile}

case class ListingPartition(
values: InternalRow,
numFiles: Long,
files: Iterator[FileStatusWithMetadata])

/**
* Trait used to represent the selected partitions and dynamically selected partitions
* during file listing.
*
* The `ScanFileListing` trait defines the core API for interacting with selected partitions,
* establishing a contract for subclasses. It is situated at the root of this package and is
* designed to provide a definition that is accessible to other packages and classes that need
* a way to represent the selected partitions and dynamically selected partitions.
*/
trait ScanFileListing {

/**
* Returns the number of partitions for the current partition representation.
*/
def partitionCount: Int

/**
* Calculates the total size in bytes of all files across the current file listing representation.
*/
def totalFileSize: Long

/**
* Returns the total number of files across the current file listing representation.
*/
def totalNumberOfFiles: Long

/**
* Filters and prunes files from the current scan file listing representation based on the given
* predicate and dynamic file filters. Initially, it filters partitions based on a static
* predicate. For partitions that pass this filter, it further prunes files using dynamic file
* filters, if any are provided. This method assumes that dynamic file filters are applicable
* only to files within partitions that have already passed the static predicate filter.
*/
def filterAndPruneFiles(
boundPredicate: BasePredicate, dynamicFileFilters: Seq[Expression]): ScanFileListing

/**
* Returns an [[Array[PartitionedFile]]] from the current ScanFileListing representation.
*/
def toPartitionArray: Array[PartitionedFile]

/**
* Returns the total partition size in bytes for the current ScanFileListing representation.
*/
def calculateTotalPartitionBytes: Long

/**
* Returns an iterator over the partitions and their files for the file listing representation.
* This allows us to iterate over the partitions without the additional overhead of materializing
* the whole collection.
*/
def filePartitionIterator: Iterator[ListingPartition]

/**
* Determines if each bucket in the current file listing representation contains at most one file.
* This function returns true if it does, or false otherwise.
*/
def bucketsContainSingleFile: Boolean
}
@@ -24,6 +24,7 @@ import org.apache.spark.Partition
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.ScanFileListing
import org.apache.spark.sql.internal.SQLConf

/**
@@ -106,16 +107,38 @@ object FilePartition extends Logging {
}
}

def maxSplitBytes(
sparkSession: SparkSession,
selectedPartitions: Seq[PartitionDirectory]): Long = {
/**
* Returns the max split bytes, given the total number of bytes taken by the selected
* partitions.
*/
def maxSplitBytes(sparkSession: SparkSession, calculateTotalBytes: => Long): Long = {
val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(sparkSession.leafNodeDefaultParallelism)
val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
val totalBytes = calculateTotalBytes
val bytesPerCore = totalBytes / minPartitionNum

Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
}

/**
* Returns the max split bytes, given the selected partitions represented using the
* [[ScanFileListing]] type.
*/
def maxSplitBytes(sparkSession: SparkSession, selectedPartitions: ScanFileListing): Long = {
val byteNum = selectedPartitions.calculateTotalPartitionBytes
maxSplitBytes(sparkSession, byteNum)
}

/**
* Returns the max split bytes, given the selected partitions represented as a sequence of
* [[PartitionDirectory]]s.
*/
def maxSplitBytes(
sparkSession: SparkSession, selectedPartitions: Seq[PartitionDirectory]): Long = {
val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
val byteNum = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
maxSplitBytes(sparkSession, byteNum)
}
}
