Skip to content

Commit

Permalink
[HUDI-1879] Support Partition Prune For MergeOnRead Snapshot Table
Browse files Browse the repository at this point in the history
  • Loading branch information
pengzhiwei2018 committed May 26, 2021
1 parent a5789c4 commit c4b07cb
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 5 deletions.
Expand Up @@ -262,7 +262,14 @@ case class HoodieFileIndex(
// If the partition column size is not equal to the partition fragment size
// and the partition column size is 1, we map the whole partition path
// to the partition column which can benefit from the partition prune.
InternalRow.fromSeq(Seq(UTF8String.fromString(partitionPath)))
val prefix = s"${partitionSchema.fieldNames.head}="
val partitionValue = if (partitionPath.startsWith(prefix)) {
// support hive style partition path
partitionPath.substring(prefix.length)
} else {
partitionPath
}
InternalRow.fromSeq(Seq(UTF8String.fromString(partitionValue)))
} else if (partitionFragments.length != partitionSchema.fields.length &&
partitionSchema.fields.length > 1) {
// If the partition column size is not equal to the partition fragments size
Expand Down
Expand Up @@ -28,8 +28,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal}
import org.apache.spark.sql.execution.datasources.{FileStatusCache, InMemoryFileIndex, Spark2ParsePartitionUtil, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -128,4 +130,98 @@ object HoodieSparkUtils {
new Spark3ParsePartitionUtil(conf)
}
}

/**
* Convert Filters to Catalyst Expressions and joined by And. If convert success return an
* Non-Empty Option[Expression],or else return None.
*/
def convertToCatalystExpressions(filters: Array[Filter],
tableSchema: StructType): Option[Expression] = {
val expressions = filters.map(convertToCatalystExpression(_, tableSchema))
if (expressions.forall(p => p.isDefined)) {
if (expressions.isEmpty) {
None
} else if (expressions.length == 1) {
expressions(0)
} else {
Some(expressions.map(_.get).reduce(org.apache.spark.sql.catalyst.expressions.And))
}
} else {
None
}
}

/**
* Convert Filter to Catalyst Expression. If convert success return an Non-Empty
* Option[Expression],or else return None.
*/
def convertToCatalystExpression(filter: Filter, tableSchema: StructType): Option[Expression] = {
Option(
filter match {
case EqualTo(attribute, value) =>
org.apache.spark.sql.catalyst.expressions.EqualTo(toAttribute(attribute, tableSchema), Literal.create(value))
case EqualNullSafe(attribute, value) =>
org.apache.spark.sql.catalyst.expressions.EqualNullSafe(toAttribute(attribute, tableSchema), Literal.create(value))
case GreaterThan(attribute, value) =>
org.apache.spark.sql.catalyst.expressions.GreaterThan(toAttribute(attribute, tableSchema), Literal.create(value))
case GreaterThanOrEqual(attribute, value) =>
org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual(toAttribute(attribute, tableSchema), Literal.create(value))
case LessThan(attribute, value) =>
org.apache.spark.sql.catalyst.expressions.LessThan(toAttribute(attribute, tableSchema), Literal.create(value))
case LessThanOrEqual(attribute, value) =>
org.apache.spark.sql.catalyst.expressions.LessThanOrEqual(toAttribute(attribute, tableSchema), Literal.create(value))
case In(attribute, values) =>
val attrExp = toAttribute(attribute, tableSchema)
val valuesExp = values.map(v => Literal.create(v))
org.apache.spark.sql.catalyst.expressions.In(attrExp, valuesExp)
case IsNull(attribute) =>
org.apache.spark.sql.catalyst.expressions.IsNull(toAttribute(attribute, tableSchema))
case IsNotNull(attribute) =>
org.apache.spark.sql.catalyst.expressions.IsNotNull(toAttribute(attribute, tableSchema))
case And(left, right) =>
val leftExp = convertToCatalystExpression(left, tableSchema)
val rightExp = convertToCatalystExpression(right, tableSchema)
if (leftExp.isEmpty || rightExp.isEmpty) {
null
} else {
org.apache.spark.sql.catalyst.expressions.And(leftExp.get, rightExp.get)
}
case Or(left, right) =>
val leftExp = convertToCatalystExpression(left, tableSchema)
val rightExp = convertToCatalystExpression(right, tableSchema)
if (leftExp.isEmpty || rightExp.isEmpty) {
null
} else {
org.apache.spark.sql.catalyst.expressions.Or(leftExp.get, rightExp.get)
}
case Not(child) =>
val childExp = convertToCatalystExpression(child, tableSchema)
if (childExp.isEmpty) {
null
} else {
org.apache.spark.sql.catalyst.expressions.Not(childExp.get)
}
case StringStartsWith(attribute, value) =>
val leftExp = toAttribute(attribute, tableSchema)
val rightExp = Literal.create(s"$value%")
org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
case StringEndsWith(attribute, value) =>
val leftExp = toAttribute(attribute, tableSchema)
val rightExp = Literal.create(s"%$value")
org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
case StringContains(attribute, value) =>
val leftExp = toAttribute(attribute, tableSchema)
val rightExp = Literal.create(s"%$value%")
org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
case _=> null
}
)
}

private def toAttribute(columnName: String, tableSchema: StructType): AttributeReference = {
val field = tableSchema.find(p => p.name == columnName)
assert(field.isDefined, s"Cannot find column: $columnName, Table Columns are: " +
s"${tableSchema.fieldNames.mkString(",")}")
AttributeReference(columnName, field.get.dataType, field.get.nullable)()
}
}
Expand Up @@ -67,7 +67,6 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
DataSourceReadOptions.REALTIME_MERGE_OPT_KEY,
DataSourceReadOptions.DEFAULT_REALTIME_MERGE_OPT_VAL)
private val maxCompactionMemoryInBytes = getMaxCompactionMemoryInBytes(jobConf)
private val fileIndex = buildFileIndex()
private val preCombineField = {
val preCombineFieldFromTableConfig = metaClient.getTableConfig.getPreCombineField
if (preCombineFieldFromTableConfig != null) {
Expand All @@ -94,6 +93,8 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
})
val requiredAvroSchema = AvroConversionUtils
.convertStructTypeToAvroSchema(requiredStructSchema, tableAvroSchema.getName, tableAvroSchema.getNamespace)

val fileIndex = buildFileIndex(filters)
val hoodieTableState = HoodieMergeOnReadTableState(
tableStructSchema,
requiredStructSchema,
Expand Down Expand Up @@ -131,15 +132,28 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
rdd.asInstanceOf[RDD[Row]]
}

def buildFileIndex(): List[HoodieMergeOnReadFileSplit] = {
def buildFileIndex(filters: Array[Filter]): List[HoodieMergeOnReadFileSplit] = {

val fileStatuses = if (globPaths.isDefined) {
// Load files from the global paths if it has defined to be compatible with the original mode
val inMemoryFileIndex = HoodieSparkUtils.createInMemoryFileIndex(sqlContext.sparkSession, globPaths.get)
inMemoryFileIndex.allFiles()
} else { // Load files by the HoodieFileIndex.
val hoodieFileIndex = HoodieFileIndex(sqlContext.sparkSession, metaClient,
Some(tableStructSchema), optParams, FileStatusCache.getOrCreate(sqlContext.sparkSession))
hoodieFileIndex.allFiles

// Get partition filter and convert to catalyst expression
val partitionColumns = hoodieFileIndex.partitionSchema.fieldNames.toSet
val partitionFilters = filters.filter(f => f.references.forall(p => partitionColumns.contains(p)))
val partitionFilterExpression =
HoodieSparkUtils.convertToCatalystExpressions(partitionFilters, tableStructSchema)

// if convert success to catalyst expression, use the partition prune
if (partitionFilterExpression.isDefined) {
hoodieFileIndex.listFiles(Seq(partitionFilterExpression.get), Seq.empty).flatMap(_.files)
} else {
hoodieFileIndex.allFiles
}
}

if (fileStatuses.isEmpty) { // If this an empty table, return an empty split list.
Expand Down
@@ -0,0 +1,165 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hudi

import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpressions
import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpression
import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType}
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

import scala.collection.mutable.ArrayBuffer

class TestConvertFilterToCatalystExpression {

private lazy val tableSchema = {
val fields = new ArrayBuffer[StructField]()
fields.append(StructField("id", LongType, nullable = false))
fields.append(StructField("name", StringType, nullable = true))
fields.append(StructField("price", DoubleType, nullable = true))
fields.append(StructField("ts", IntegerType, nullable = false))
StructType(fields)
}

@Test
def testBaseConvert(): Unit = {
checkConvertFilter(eq("id", 1), "(`id` = 1)")
checkConvertFilter(eqs("name", "a1"), "(`name` <=> 'a1')")
checkConvertFilter(lt("price", 10), "(`price` < 10)")
checkConvertFilter(lte("ts", 1), "(`ts` <= 1)")
checkConvertFilter(gt("price", 10), "(`price` > 10)")
checkConvertFilter(gte("price", 10), "(`price` >= 10)")
checkConvertFilter(in("id", 1, 2 , 3), "(`id` IN (1, 2, 3))")
checkConvertFilter(isNull("id"), "(`id` IS NULL)")
checkConvertFilter(isNotNull("name"), "(`name` IS NOT NULL)")
checkConvertFilter(and(lt("ts", 10), gt("ts", 1)),
"((`ts` < 10) AND (`ts` > 1))")
checkConvertFilter(or(lte("ts", 10), gte("ts", 1)),
"((`ts` <= 10) OR (`ts` >= 1))")
checkConvertFilter(not(and(lt("ts", 10), gt("ts", 1))),
"(NOT ((`ts` < 10) AND (`ts` > 1)))")
checkConvertFilter(startWith("name", "ab"), "`name` LIKE 'ab%'")
checkConvertFilter(endWith("name", "cd"), "`name` LIKE '%cd'")
checkConvertFilter(contains("name", "e"), "`name` LIKE '%e%'")
}

@Test
def testConvertFilters(): Unit = {
checkConvertFilters(Array.empty[Filter], null)
checkConvertFilters(Array(eq("id", 1)), "(`id` = 1)")
checkConvertFilters(Array(lt("ts", 10), gt("ts", 1)),
"((`ts` < 10) AND (`ts` > 1))")
}

@Test
def testUnSupportConvert(): Unit = {
checkConvertFilters(Array(unsupport()), null)
checkConvertFilters(Array(and(unsupport(), eq("id", 1))), null)
checkConvertFilters(Array(or(unsupport(), eq("id", 1))), null)
checkConvertFilters(Array(and(eq("id", 1), not(unsupport()))), null)
}

private def checkConvertFilter(filter: Filter, expectExpression: String): Unit = {
val exp = convertToCatalystExpression(filter, tableSchema)
if (expectExpression == null) {
assertEquals(exp.isEmpty, true)
} else {
assertEquals(exp.isDefined, true)
assertEquals(expectExpression, exp.get.sql)
}
}

private def checkConvertFilters(filters: Array[Filter], expectExpression: String): Unit = {
val exp = convertToCatalystExpressions(filters, tableSchema)
if (expectExpression == null) {
assertEquals(exp.isEmpty, true)
} else {
assertEquals(exp.isDefined, true)
assertEquals(expectExpression, exp.get.sql)
}
}

private def eq(attribute: String, value: Any): Filter = {
EqualTo(attribute, value)
}

private def eqs(attribute: String, value: Any): Filter = {
EqualNullSafe(attribute, value)
}

private def gt(attribute: String, value: Any): Filter = {
GreaterThan(attribute, value)
}

private def gte(attribute: String, value: Any): Filter = {
GreaterThanOrEqual(attribute, value)
}

private def lt(attribute: String, value: Any): Filter = {
LessThan(attribute, value)
}

private def lte(attribute: String, value: Any): Filter = {
LessThanOrEqual(attribute, value)
}

private def in(attribute: String, values: Any*): Filter = {
In(attribute, values.toArray)
}

private def isNull(attribute: String): Filter = {
IsNull(attribute)
}

private def isNotNull(attribute: String): Filter = {
IsNotNull(attribute)
}

private def and(left: Filter, right: Filter): Filter = {
And(left, right)
}

private def or(left: Filter, right: Filter): Filter = {
Or(left, right)
}

private def not(child: Filter): Filter = {
Not(child)
}

private def startWith(attribute: String, value: String): Filter = {
StringStartsWith(attribute, value)
}

private def endWith(attribute: String, value: String): Filter = {
StringEndsWith(attribute, value)
}

private def contains(attribute: String, value: String): Filter = {
StringContains(attribute, value)
}

private def unsupport(): Filter = {
UnSupportFilter("")
}

case class UnSupportFilter(value: Any) extends Filter {
override def references: Array[String] = Array.empty
}
}

0 comments on commit c4b07cb

Please sign in to comment.