[SPARK-30323][SQL] Support filters pushdown in CSV datasource #26973

Closed · wants to merge 50 commits

Changes from all 50 commits
895638f
Return Seq[InternalRow] from convert()
MaxGekk Dec 16, 2019
4bc8d9b
Pass filters to CSV datasource v1
MaxGekk Dec 16, 2019
0124199
Add CSVFilters
MaxGekk Dec 16, 2019
fb8912e
Add filterToExpression
MaxGekk Dec 16, 2019
c2515b6
Initial impl of CSVFilters
MaxGekk Dec 17, 2019
9ced607
Support filters push down in CSV v2
MaxGekk Dec 17, 2019
20dbef0
Add a test to CSVSuite
MaxGekk Dec 17, 2019
becfe1e
Keep only one predicate per field
MaxGekk Dec 17, 2019
77e7d54
Add a benchmark
MaxGekk Dec 18, 2019
415e4ce
SQL config `spark.sql.csv.filterPushdown.enabled`
MaxGekk Dec 18, 2019
3db517f
Use SQL config in CSVBenchmark
MaxGekk Dec 18, 2019
98963bc
Refactoring
MaxGekk Dec 18, 2019
05111a5
Add comments for skipRow
MaxGekk Dec 18, 2019
899cf17
Apply filters only on CSV level
MaxGekk Dec 18, 2019
d08fe58
Add a comment for `predicates`
MaxGekk Dec 18, 2019
b0a34b3
Add a comment for CSVFilters
MaxGekk Dec 18, 2019
5fe5600
Add a comment for `unsupportedFilters`
MaxGekk Dec 18, 2019
a7f3006
Add comments
MaxGekk Dec 19, 2019
c989bee
Add tests to UnivocityParserSuite
MaxGekk Dec 19, 2019
124c45d
Support AlwaysTrue and AlwaysFalse filters
MaxGekk Dec 19, 2019
d7932c2
Add tests for filterToExpression()
MaxGekk Dec 19, 2019
bb0abf4
Add tests for readSchema
MaxGekk Dec 19, 2019
1c707e5
Add tests for skipRow()
MaxGekk Dec 19, 2019
11bcbc6
Benchmarks at the commit 67b644c3d7
MaxGekk Dec 20, 2019
a5088bd
Revert "Benchmarks at the commit 67b644c3d7"
MaxGekk Dec 20, 2019
f0cc83c
Update benchmarks results
MaxGekk Dec 20, 2019
e7b3304
Merge remote-tracking branch 'remotes/origin/master' into csv-filters…
MaxGekk Dec 20, 2019
f24e873
Add equals(), hashCode() and description() to CSVScan
MaxGekk Dec 20, 2019
55ebb60
Tests for column pruning on/off + refactoring
MaxGekk Dec 22, 2019
170944c
Simplifying parsedSchema initialization
MaxGekk Dec 22, 2019
c18ea7b
Merge remote-tracking branch 'origin/master' into csv-filters-pushdown
MaxGekk Dec 24, 2019
17e742f
Merge remote-tracking branch 'remotes/origin/master' into csv-filters…
MaxGekk Jan 9, 2020
f296259
Test the multiLine mode
MaxGekk Jan 9, 2020
c7eac1f
Merge remote-tracking branch 'remotes/origin/master' into csv-filters…
MaxGekk Jan 10, 2020
61eaa36
Follow-up merging
MaxGekk Jan 10, 2020
f0aa0a8
Bug fix
MaxGekk Jan 11, 2020
1cc46b1
Merge remote-tracking branch 'remotes/origin/master' into csv-filters…
MaxGekk Jan 11, 2020
711a703
Apply filters w/o refs at pos 0 only
MaxGekk Jan 13, 2020
a3a99b1
Merge remote-tracking branch 'remotes/origin/master' into csv-filters…
MaxGekk Jan 13, 2020
4a25815
Fix build error
MaxGekk Jan 13, 2020
c03ae06
Put literal filters in front of others
MaxGekk Jan 13, 2020
18389b0
Test more options/modes in the end-to-end test
MaxGekk Jan 13, 2020
e302fa4
Bug fix: malformed input + permissive mode + columnNameOfCorruptRecor…
MaxGekk Jan 13, 2020
96e9554
Remove unnecessary setNullAt
MaxGekk Jan 13, 2020
1be5534
Remove checkFilters()
MaxGekk Jan 14, 2020
9217536
Remove private[sql] for parsedSchema
MaxGekk Jan 14, 2020
15c9648
Simplify code assuming that requireSchema contains all filter refs
MaxGekk Jan 14, 2020
f2c3b3e
Merge remote-tracking branch 'remotes/origin/master' into csv-filters…
MaxGekk Jan 15, 2020
df30439
Use intercept in UnivocityParserSuite
MaxGekk Jan 15, 2020
06be013
Remove nested getSchema() in UnivocityParserSuite
MaxGekk Jan 15, 2020
@@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.csv

import scala.util.Try

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{BooleanType, StructType}

/**
* An instance of this class compiles filters to predicates and allows applying
* the predicates to an internal row with partially initialized values
* converted from parsed CSV fields.
*
* @param filters The filters pushed down to the CSV datasource.
* @param requiredSchema The schema with only the fields requested by the upper layer.
*/
class CSVFilters(filters: Seq[sources.Filter], requiredSchema: StructType) {
/**
* Filters converted to predicates and grouped by the maximum field index
* in the read schema. For example, if a filter refers to 2 attributes,
* attrA with field index 5 and attrB with field index 10 in the read schema:
* 0 === $"attrA" or $"attrB" < 100
* the filter is compiled to a predicate and placed in the `predicates`
* array at position 10. In this way, once a row has its fields at indexes
* 0 to 10 initialized, the predicate can be applied to the row to check
* whether the row should be skipped or not.
* Multiple predicates with the same maximum reference index are combined
* by the `And` expression.
*/
private val predicates: Array[BasePredicate] = {
val len = requiredSchema.fields.length
val groupedPredicates = Array.fill[BasePredicate](len)(null)
if (SQLConf.get.csvFilterPushDown) {
val groupedFilters = Array.fill(len)(Seq.empty[sources.Filter])
for (filter <- filters) {
val refs = filter.references
val index = if (refs.isEmpty) {
// For example, AlwaysTrue and AlwaysFalse don't have any references.
// Filters w/o refs always return the same result. Taking into account
// that predicates are combined via And, we can apply such filters
// just once, at position 0.
0
} else {
// The required schema must contain the attributes of all filters.
// Accordingly, fieldIndex() always returns a valid index.
refs.map(requiredSchema.fieldIndex).max
}
groupedFilters(index) :+= filter
}
if (len > 0 && !groupedFilters(0).isEmpty) {
// We assume that filters w/o refs, like AlwaysTrue and AlwaysFalse,
// can be evaluated faster than the others, so we put them in front.
val (literals, others) = groupedFilters(0).partition(_.references.isEmpty)
groupedFilters(0) = literals ++ others
}
for (i <- 0 until len) {
if (!groupedFilters(i).isEmpty) {
val reducedExpr = groupedFilters(i)
.flatMap(CSVFilters.filterToExpression(_, toRef))
.reduce(And)
groupedPredicates(i) = Predicate.create(reducedExpr)
}
}
}
groupedPredicates
}

/**
* Applies the predicate grouped at the given field index, i.e. the filters
* whose maximum field reference is `index`, to the partially initialized row.
* @param row The internal row to check.
* @param index Maximum field index. The function assumes that all fields
* from 0 to the index position are set.
* @return false if the row passes the applicable filters or there are no
* applicable filters at this index;
* true if at least one filter evaluates to false, i.e. the row should be skipped.
*/
def skipRow(row: InternalRow, index: Int): Boolean = {
val predicate = predicates(index)
predicate != null && !predicate.eval(row)
}

// Finds a filter attribute in the read schema and converts it to a `BoundReference`
private def toRef(attr: String): Option[BoundReference] = {
requiredSchema.getFieldIndex(attr).map { index =>
val field = requiredSchema(index)
BoundReference(index, field.dataType, field.nullable)
}
}
}

object CSVFilters {
private def checkFilterRefs(filter: sources.Filter, schema: StructType): Boolean = {
val fieldNames = schema.fields.map(_.name).toSet
filter.references.forall(fieldNames.contains(_))
}

/**
* Returns the filters currently supported by the CSV datasource.
* @param filters The filters pushed down to the CSV datasource.
* @param schema The data schema of CSV files.
* @return A subset of `filters` that can be handled by the CSV datasource.
*/
def pushedFilters(filters: Array[sources.Filter], schema: StructType): Array[sources.Filter] = {
filters.filter(checkFilterRefs(_, schema))
}

private def zip[A, B](a: Option[A], b: Option[B]): Option[(A, B)] = {
a.zip(b).headOption
}

private def toLiteral(value: Any): Option[Literal] = {
Try(Literal(value)).toOption
}

/**
* Converts a filter to an expression and binds it to row positions.
*
* @param filter The filter to convert.
* @param toRef The function that converts a filter attribute to a bound reference.
* @return Some expression with resolved attributes, or None if the given filter
* cannot be converted to an expression.
*/
def filterToExpression(
filter: sources.Filter,
toRef: String => Option[BoundReference]): Option[Expression] = {
def zipAttributeAndValue(name: String, value: Any): Option[(BoundReference, Literal)] = {
zip(toRef(name), toLiteral(value))
}
def translate(filter: sources.Filter): Option[Expression] = filter match {
case sources.And(left, right) =>
zip(translate(left), translate(right)).map(And.tupled)
case sources.Or(left, right) =>
zip(translate(left), translate(right)).map(Or.tupled)
case sources.Not(child) =>
translate(child).map(Not)
case sources.EqualTo(attribute, value) =>
zipAttributeAndValue(attribute, value).map(EqualTo.tupled)
case sources.EqualNullSafe(attribute, value) =>
zipAttributeAndValue(attribute, value).map(EqualNullSafe.tupled)
case sources.IsNull(attribute) =>
toRef(attribute).map(IsNull)
case sources.IsNotNull(attribute) =>
toRef(attribute).map(IsNotNull)
case sources.In(attribute, values) =>
val literals = values.toSeq.flatMap(toLiteral)
if (literals.length == values.length) {
toRef(attribute).map(In(_, literals))
} else {
None
}
case sources.GreaterThan(attribute, value) =>
zipAttributeAndValue(attribute, value).map(GreaterThan.tupled)
case sources.GreaterThanOrEqual(attribute, value) =>
zipAttributeAndValue(attribute, value).map(GreaterThanOrEqual.tupled)
case sources.LessThan(attribute, value) =>
zipAttributeAndValue(attribute, value).map(LessThan.tupled)
case sources.LessThanOrEqual(attribute, value) =>
zipAttributeAndValue(attribute, value).map(LessThanOrEqual.tupled)
case sources.StringContains(attribute, value) =>
zipAttributeAndValue(attribute, value).map(Contains.tupled)
case sources.StringStartsWith(attribute, value) =>
zipAttributeAndValue(attribute, value).map(StartsWith.tupled)
case sources.StringEndsWith(attribute, value) =>
zipAttributeAndValue(attribute, value).map(EndsWith.tupled)
case sources.AlwaysTrue() =>
Some(Literal(true, BooleanType))
case sources.AlwaysFalse() =>
Some(Literal(false, BooleanType))
}
translate(filter)
}
}
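
Not part of the diff: a minimal sketch of how CSVFilters above might be exercised, assuming spark.sql.csv.filterPushdown.enabled is left at its default (true) and the catalyst package is accessible from the calling code; the schema, filter, and values are made up for illustration.

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Required schema: name at field index 0, age at field index 1.
val requiredSchema = new StructType().add("name", StringType).add("age", IntegerType)
// The pushed filter references only `age`, so its predicate is grouped at index 1.
val csvFilters = new CSVFilters(Seq(sources.GreaterThan("age", 18)), requiredSchema)

val row = new GenericInternalRow(2)
row(0) = UTF8String.fromString("Alice")
assert(!csvFilters.skipRow(row, 0)) // no predicate is registered at index 0
row(1) = 17
assert(csvFilters.skipRow(row, 1))  // 17 > 18 is false, so the row should be skipped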
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -39,15 +40,20 @@ import org.apache.spark.unsafe.types.UTF8String
* @param requiredSchema The schema of the data that should be output for each row. This should be a
* subset of the columns in dataSchema.
* @param options Configuration options for a CSV parser.
* @param filters The pushdown filters that should be applied to converted values.
*/
class UnivocityParser(
dataSchema: StructType,
requiredSchema: StructType,
val options: CSVOptions) extends Logging {
val options: CSVOptions,
filters: Seq[Filter]) extends Logging {
require(requiredSchema.toSet.subsetOf(dataSchema.toSet),
s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " +
s"dataSchema (${dataSchema.catalogString}).")

def this(dataSchema: StructType, requiredSchema: StructType, options: CSVOptions) = {
this(dataSchema, requiredSchema, options, Seq.empty)
}
def this(schema: StructType, options: CSVOptions) = this(schema, schema, options)

// A `ValueConverter` is responsible for converting the given value to a desired type.
@@ -72,7 +78,11 @@
new CsvParser(parserSetting)
}

private val row = new GenericInternalRow(requiredSchema.length)
// Pre-allocated single-element Seq returned for rows that pass the filters,
// to avoid the overhead of the seq builder.
private val requiredRow = Seq(new GenericInternalRow(requiredSchema.length))
// Pre-allocated empty sequence returned when the parsed row does not pass the filters.
// We pre-allocate it to avoid unnecessary invocations of the seq builder.
private val noRows = Seq.empty[InternalRow]

private val timestampFormatter = TimestampFormatter(
options.timestampFormat,
@@ -83,6 +93,8 @@
options.zoneId,
options.locale)

private val csvFilters = new CSVFilters(filters, requiredSchema)

// Retrieve the raw record string.
private def getCurrentInput: UTF8String = {
UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd)
@@ -194,7 +206,7 @@
private val doParse = if (options.columnPruning && requiredSchema.isEmpty) {
// If `columnPruning` is enabled and only partition attributes are scanned,
// `schema` is empty.
(_: String) => InternalRow.empty
(_: String) => Seq(InternalRow.empty)
} else {
// Parse if columnPruning is disabled or requiredSchema is non-empty.
(input: String) => convert(tokenizer.parseLine(input))
@@ -204,15 +216,15 @@
* Parses a single CSV string and turns it into either one resulting row or no row (if the
* record is malformed or skipped by the pushed filters).
*/
def parse(input: String): InternalRow = doParse(input)
def parse(input: String): Seq[InternalRow] = doParse(input)

private val getToken = if (options.columnPruning) {
(tokens: Array[String], index: Int) => tokens(index)
} else {
(tokens: Array[String], index: Int) => tokens(tokenIndexArr(index))
}

private def convert(tokens: Array[String]): InternalRow = {
private def convert(tokens: Array[String]): Seq[InternalRow] = {
if (tokens == null) {
throw BadRecordException(
() => getCurrentInput,
@@ -229,7 +241,7 @@
}
def getPartialResult(): Option[InternalRow] = {
try {
Some(convert(checkedTokens))
convert(checkedTokens).headOption
} catch {
case _: BadRecordException => None
}
@@ -242,24 +254,40 @@
new RuntimeException("Malformed CSV record"))
} else {
// When the length of the returned tokens is identical to the length of the parsed schema,
// we just need to convert the tokens that correspond to the required columns.
var badRecordException: Option[Throwable] = None
// we just need to:
// 1. Convert the tokens that correspond to the required schema.
// 2. Apply the pushdown filters to `requiredRow`.
var i = 0
val row = requiredRow.head
var skipRow = false
var badRecordException: Option[Throwable] = None
while (i < requiredSchema.length) {
try {
row(i) = valueConverters(i).apply(getToken(tokens, i))
if (!skipRow) {
row(i) = valueConverters(i).apply(getToken(tokens, i))
if (csvFilters.skipRow(row, i)) {
skipRow = true
}
}
if (skipRow) {
row.setNullAt(i)
}
} catch {
case NonFatal(e) =>
badRecordException = badRecordException.orElse(Some(e))
row.setNullAt(i)
}
i += 1
}

if (badRecordException.isEmpty) {
row
if (skipRow) {
noRows
} else {
throw BadRecordException(() => getCurrentInput, () => Some(row), badRecordException.get)
if (badRecordException.isDefined) {
throw BadRecordException(
() => getCurrentInput, () => requiredRow.headOption, badRecordException.get)
} else {
requiredRow
}
}
}
}
@@ -291,7 +319,7 @@ private[sql] object UnivocityParser {
schema: StructType): Iterator[InternalRow] = {
val tokenizer = parser.tokenizer
val safeParser = new FailureSafeParser[Array[String]](
input => Seq(parser.convert(input)),
input => parser.convert(input),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
@@ -344,7 +372,7 @@ private[sql] object UnivocityParser {
val filteredLines: Iterator[String] = CSVExprUtils.filterCommentAndEmpty(lines, options)

val safeParser = new FailureSafeParser[String](
input => Seq(parser.parse(input)),
input => parser.parse(input),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
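
Not part of the diff: a rough sketch of the new UnivocityParser contract (the Seq-returning parse() plus the filters argument), assuming the default pushdown config and a build where the catalyst CSV classes are visible; the schema, options, and inputs are hypothetical.

import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser}
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema = new StructType().add("name", StringType).add("age", IntegerType)
val options = new CSVOptions(Map.empty[String, String], true, "UTC")
// The new 4-argument constructor takes the pushed filters.
val parser = new UnivocityParser(schema, schema, options, Seq(sources.GreaterThan("age", 18)))

// parse() now returns Seq[InternalRow]: one row if it passes the filters, none otherwise.
assert(parser.parse("Alice,30").length == 1)
assert(parser.parse("Bob,10").isEmpty)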
@@ -114,7 +114,7 @@ case class CsvToStructs(
StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions)
new FailureSafeParser[String](
input => Seq(rawParser.parse(input)),
input => rawParser.parse(input),
mode,
nullableSchema,
parsedOptions.columnNameOfCorruptRecord)
@@ -2153,6 +2153,11 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val CSV_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.csv.filterPushdown.enabled")
.doc("When true, enable filter pushdown to CSV datasource.")
.booleanConf
.createWithDefault(true)
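
Not part of the diff: a user-facing sketch of the new flag; the session, path, and schema below are hypothetical.

import org.apache.spark.sql.functions.col

// The flag is on by default and can be disabled per session.
spark.conf.set("spark.sql.csv.filterPushdown.enabled", true)

// With pushdown enabled, the predicate below reaches UnivocityParser, which can
// skip non-matching CSV rows while parsing instead of converting them first.
spark.read
  .schema("name STRING, age INT")
  .csv("/tmp/people.csv") // hypothetical path
  .where(col("age") > 18)
  .show()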

/**
* Holds information about keys that have been deprecated.
*
@@ -2722,6 +2727,8 @@ class SQLConf extends Serializable with Logging {

def ignoreDataLocality: Boolean = getConf(SQLConf.IGNORE_DATA_LOCALITY)

def csvFilterPushDown: Boolean = getConf(CSV_FILTER_PUSHDOWN_ENABLED)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */