[SPARK-23877][SQL]: Use filter predicates to prune partitions in metadata-only queries #20988
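For context, the query shape this change targets is a metadata-only aggregate over partition columns with a predicate on those columns. The sketch below mirrors the test suite added at the end of this diff; it assumes a Hive-enabled SparkSession named spark and that the metadata-only optimization (spark.sql.optimizer.metadataOnly) is enabled.

// Setup mirroring the new test suite: a Hive table partitioned by `part`, with 11 partitions.
spark.sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)")
(0 to 10).foreach(p => spark.sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)"))

// This query reads only the partition column, so it can be answered from partition metadata.
// Previously the rule listed all 11 partitions from the metastore; with this change the
// `part < 5` predicate is pushed down, so only the 5 matching partitions are fetched.
spark.sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").show()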

File: OptimizeMetadataOnlyQuery.scala
@@ -49,9 +49,9 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
}

plan.transform {
case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) =>
case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) =>
// We only apply this optimization when only partitioned attributes are scanned.
if (a.references.subsetOf(partAttrs)) {
if (a.references.subsetOf(attrs)) {
val aggFunctions = aggExprs.flatMap(_.collect {
case agg: AggregateExpression => agg
})
@@ -67,7 +67,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
})
}
if (isAllDistinctAgg) {
a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation)))
a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters)))
} else {
a
}
@@ -98,27 +98,49 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
*/
private def replaceTableScanWithPartitionMetadata(
child: LogicalPlan,
relation: LogicalPlan): LogicalPlan = {
relation: LogicalPlan,
partFilters: Seq[Expression]): LogicalPlan = {
// this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the
// relation's schema. PartitionedRelation ensures that the filters only reference partition cols
val relFilters = partFilters.map { e =>
e transform {
case a: AttributeReference =>
a.withName(relation.output.find(_.semanticEquals(a)).get.name)
}
}

child transform {
case plan if plan eq relation =>
relation match {
case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) =>
val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
val partitionData = fsRelation.location.listFiles(Nil, Nil)
LocalRelation(partAttrs, partitionData.map(_.values), isStreaming)
val partitionData = fsRelation.location.listFiles(relFilters, Nil)
// partition data may be a stream, which can cause serialization to hit stack level too
// deep exceptions because it is a recursive structure in memory. converting to array
// avoids the problem.
LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming)

case relation: HiveTableRelation =>
val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
val caseInsensitiveProperties =
CaseInsensitiveMap(relation.tableMeta.storage.properties)
val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(SQLConf.get.sessionLocalTimeZone)
val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p =>
val partitions = if (partFilters.nonEmpty) {
catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters)
} else {
catalog.listPartitions(relation.tableMeta.identifier)
}

val partitionData = partitions.map { p =>
InternalRow.fromSeq(partAttrs.map { attr =>
Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval()
})
}
LocalRelation(partAttrs, partitionData)
// partition data may be a stream, which can cause serialization to hit stack level too
// deep exceptions because it is a recursive structure in memory. converting to array
// avoids the problem.
LocalRelation(partAttrs, partitionData.toArray)

case _ =>
throw new IllegalStateException(s"unrecognized table scan node: $relation, " +
@@ -129,35 +151,47 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic

/**
* A pattern that finds the partitioned table relation node inside the given plan, and returns a
* pair of the partition attributes and the table relation node.
* tuple of the partition attributes, the attributes accessible in the plan, the partition
* filters, and the table relation node.
*
* It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with
* deterministic expressions, and returns result after reaching the partitioned table relation
* node.
*/
object PartitionedRelation {

def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match {
case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)
if fsRelation.partitionSchema.nonEmpty =>
val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
Some((AttributeSet(partAttrs), l))

case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
Some((AttributeSet(partAttrs), relation))

case p @ Project(projectList, child) if projectList.forall(_.deterministic) =>
unapply(child).flatMap { case (partAttrs, relation) =>
if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None
}
object PartitionedRelation extends PredicateHelper {

def unapply(
plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = {
plan match {
case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)
if fsRelation.partitionSchema.nonEmpty =>
val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l))
Some((partAttrs, partAttrs, Nil, l))

case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
val partAttrs = AttributeSet(
getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation))
Some((partAttrs, partAttrs, Nil, relation))

case p @ Project(projectList, child) if projectList.forall(_.deterministic) =>
unapply(child).flatMap { case (partAttrs, attrs, filters, relation) =>
if (p.references.subsetOf(attrs)) {
Some((partAttrs, p.outputSet, filters, relation))
} else {
None
}
}

case f @ Filter(condition, child) if condition.deterministic =>
unapply(child).flatMap { case (partAttrs, relation) =>
if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None
}
case f @ Filter(condition, child) if condition.deterministic =>
unapply(child).flatMap { case (partAttrs, attrs, filters, relation) =>
if (f.references.subsetOf(partAttrs)) {
Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation))
} else {
None
}
}

case _ => None
case _ => None
}
}
}
}
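The Filter case in PartitionedRelation above splits a deterministic filter condition into its conjuncts with splitConjunctivePredicates, so that each predicate can be checked against the partition columns and later passed to the metastore. Below is a minimal standalone sketch of that helper's behavior, using an illustrative attribute and literals (not part of this diff):

import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThanOrEqual, LessThan, Literal, PredicateHelper}
import org.apache.spark.sql.types.IntegerType

object SplitDemo extends PredicateHelper {
  // part < 5 AND part >= 0 is split into Seq(part < 5, part >= 0); each conjunct references
  // only the partition column, so PartitionedRelation would accumulate both as filters.
  val part = AttributeReference("part", IntegerType)()
  val conjuncts = splitConjunctivePredicates(
    And(LessThan(part, Literal(5)), GreaterThanOrEqual(part, Literal(0))))
}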
New file: OptimizeHiveMetadataOnlyQuerySuite.scala
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfter

import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton
with BeforeAndAfter with SQLTestUtils {

import spark.implicits._

before {
sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)")
(0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)"))
}

test("SPARK-23877: validate metadata-only query pushes filters to metastore") {
withTable("metadata_only") {
val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

// verify the number of matching partitions
assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5)

// verify that the partition predicate was pushed down to the metastore
assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5)
}
}

test("SPARK-23877: filter on projected expression") {
withTable("metadata_only") {
val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

// verify the matching partitions
val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr,
Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]),
spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child)))
.queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType))))

checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x"))

// verify that the partition predicate was not pushed down to the metastore
assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11)
}
}
}
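Outside the suite, a quick manual check of the same behavior could look like the following hedged sketch; it assumes the metadata_only table from the example near the top, a Hive-enabled SparkSession named spark, and the metadata-only optimization enabled.

import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val before = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount
val df = spark.sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5")
df.collect()

// When the rule fires, the table scan is replaced by a LocalRelation of partition values,
// and only the matching partitions (part = 0..4) are fetched from the metastore.
val usesLocalRelation = df.queryExecution.optimizedPlan.collectFirst { case l: LocalRelation => l }.isDefined
val fetched = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - before
println(s"usesLocalRelation=$usesLocalRelation, partitionsFetched=$fetched")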