Spark 3.4: Support pushing down system functions by V2 filters (#7886)
ConeyLiu committed Aug 1, 2023
1 parent 8203b72 commit 6875577
Showing 8 changed files with 1,394 additions and 89 deletions.
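
This commit teaches SparkV2Filters to translate Spark V2 predicates over Iceberg's system functions (years, months, days, hours, bucket, truncate) into Iceberg expressions, so they can participate in pushdown and pruning. Below is a minimal sketch of the query shape this targets, assuming a Spark session with an Iceberg catalog named local and a table local.db.events bucketed by id; the catalog, table, and column names are hypothetical.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SystemFunctionPushdownExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("pushdown-example").getOrCreate();

    // The bucket predicate below reaches the scan as a V2 Predicate wrapping a
    // UserDefinedScalarFunc; with this change SparkV2Filters converts it to the
    // Iceberg expression bucket("id", 16) == 4 so it can be used for pruning.
    Dataset<Row> rows =
        spark.sql("SELECT * FROM local.db.events WHERE local.system.bucket(16, id) = 4");
    rows.show();
  }
}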
SparkV2Filters.java

@@ -19,26 +19,33 @@
 package org.apache.iceberg.spark;

 import static org.apache.iceberg.expressions.Expressions.and;
+import static org.apache.iceberg.expressions.Expressions.bucket;
+import static org.apache.iceberg.expressions.Expressions.day;
 import static org.apache.iceberg.expressions.Expressions.equal;
 import static org.apache.iceberg.expressions.Expressions.greaterThan;
 import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.hour;
 import static org.apache.iceberg.expressions.Expressions.in;
 import static org.apache.iceberg.expressions.Expressions.isNaN;
 import static org.apache.iceberg.expressions.Expressions.isNull;
 import static org.apache.iceberg.expressions.Expressions.lessThan;
 import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.month;
 import static org.apache.iceberg.expressions.Expressions.not;
 import static org.apache.iceberg.expressions.Expressions.notEqual;
 import static org.apache.iceberg.expressions.Expressions.notIn;
 import static org.apache.iceberg.expressions.Expressions.notNaN;
 import static org.apache.iceberg.expressions.Expressions.notNull;
 import static org.apache.iceberg.expressions.Expressions.or;
 import static org.apache.iceberg.expressions.Expressions.ref;
 import static org.apache.iceberg.expressions.Expressions.startsWith;
+import static org.apache.iceberg.expressions.Expressions.truncate;
+import static org.apache.iceberg.expressions.Expressions.year;

 import java.util.Arrays;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expression.Operation;
@@ -47,10 +54,12 @@
 import org.apache.iceberg.expressions.UnboundTerm;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
 import org.apache.iceberg.util.NaNUtil;
 import org.apache.iceberg.util.Pair;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
 import org.apache.spark.sql.connector.expressions.filter.And;
 import org.apache.spark.sql.connector.expressions.filter.Not;
 import org.apache.spark.sql.connector.expressions.filter.Or;
Expand All @@ -59,6 +68,9 @@

public class SparkV2Filters {

public static final Set<String> SUPPORTED_FUNCTIONS =
ImmutableSet.of("years", "months", "days", "hours", "bucket", "truncate");

private static final String TRUE = "ALWAYS_TRUE";
private static final String FALSE = "ALWAYS_FALSE";
private static final String EQ = "=";
@@ -98,6 +110,18 @@ public class SparkV2Filters {

   private SparkV2Filters() {}

+  public static Expression convert(Predicate[] predicates) {
+    Expression expression = Expressions.alwaysTrue();
+    for (Predicate predicate : predicates) {
+      Expression converted = convert(predicate);
+      Preconditions.checkArgument(
+          converted != null, "Cannot convert Spark predicate to Iceberg expression: %s", predicate);
+      expression = Expressions.and(expression, converted);
+    }
+
+    return expression;
+  }
+
   @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"})
   public static Expression convert(Predicate predicate) {
     Operation op = FILTERS.get(predicate.name());
@@ -110,51 +134,69 @@ public static Expression convert(Predicate predicate) {
         return Expressions.alwaysFalse();

       case IS_NULL:
-        return isRef(child(predicate)) ? isNull(SparkUtil.toColumnName(child(predicate))) : null;
+        if (canConvertToTerm(child(predicate))) {
+          UnboundTerm<Object> term = toTerm(child(predicate));
+          return term != null ? isNull(term) : null;
+        }
+
+        return null;

       case NOT_NULL:
-        return isRef(child(predicate)) ? notNull(SparkUtil.toColumnName(child(predicate))) : null;
+        if (canConvertToTerm(child(predicate))) {
+          UnboundTerm<Object> term = toTerm(child(predicate));
+          return term != null ? notNull(term) : null;
+        }
+
+        return null;

       case LT:
-        if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(leftChild(predicate));
-          return lessThan(columnName, convertLiteral(rightChild(predicate)));
-        } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(rightChild(predicate));
-          return greaterThan(columnName, convertLiteral(leftChild(predicate)));
+        if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(leftChild(predicate));
+          return term != null ? lessThan(term, convertLiteral(rightChild(predicate))) : null;
+        } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(rightChild(predicate));
+          return term != null ? greaterThan(term, convertLiteral(leftChild(predicate))) : null;
         } else {
           return null;
         }

       case LT_EQ:
-        if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(leftChild(predicate));
-          return lessThanOrEqual(columnName, convertLiteral(rightChild(predicate)));
-        } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(rightChild(predicate));
-          return greaterThanOrEqual(columnName, convertLiteral(leftChild(predicate)));
+        if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(leftChild(predicate));
+          return term != null
+              ? lessThanOrEqual(term, convertLiteral(rightChild(predicate)))
+              : null;
+        } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(rightChild(predicate));
+          return term != null
+              ? greaterThanOrEqual(term, convertLiteral(leftChild(predicate)))
+              : null;
         } else {
           return null;
         }

       case GT:
-        if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(leftChild(predicate));
-          return greaterThan(columnName, convertLiteral(rightChild(predicate)));
-        } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(rightChild(predicate));
-          return lessThan(columnName, convertLiteral(leftChild(predicate)));
+        if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(leftChild(predicate));
+          return term != null ? greaterThan(term, convertLiteral(rightChild(predicate))) : null;
+        } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(rightChild(predicate));
+          return term != null ? lessThan(term, convertLiteral(leftChild(predicate))) : null;
         } else {
           return null;
         }

       case GT_EQ:
-        if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(leftChild(predicate));
-          return greaterThanOrEqual(columnName, convertLiteral(rightChild(predicate)));
-        } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = SparkUtil.toColumnName(rightChild(predicate));
-          return lessThanOrEqual(columnName, convertLiteral(leftChild(predicate)));
+        if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(leftChild(predicate));
+          return term != null
+              ? greaterThanOrEqual(term, convertLiteral(rightChild(predicate)))
+              : null;
+        } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
+          UnboundTerm<Object> term = toTerm(rightChild(predicate));
+          return term != null
+              ? lessThanOrEqual(term, convertLiteral(leftChild(predicate)))
+              : null;
         } else {
           return null;
         }
@@ -191,13 +233,17 @@ public static Expression convert(Predicate predicate) {

       case IN:
         if (isSupportedInPredicate(predicate)) {
-          return in(
-              SparkUtil.toColumnName(childAtIndex(predicate, 0)),
-              Arrays.stream(predicate.children())
-                  .skip(1)
-                  .map(val -> convertLiteral(((Literal<?>) val)))
-                  .filter(Objects::nonNull)
-                  .collect(Collectors.toList()));
+          UnboundTerm<Object> term = toTerm(childAtIndex(predicate, 0));
+
+          return term != null
+              ? in(
+                  term,
+                  Arrays.stream(predicate.children())
+                      .skip(1)
+                      .map(val -> convertLiteral(((Literal<?>) val)))
+                      .filter(Objects::nonNull)
+                      .collect(Collectors.toList()))
+              : null;
         } else {
           return null;
         }
@@ -206,18 +252,23 @@
         Not notPredicate = (Not) predicate;
         Predicate childPredicate = notPredicate.child();
         if (childPredicate.name().equals(IN) && isSupportedInPredicate(childPredicate)) {
+          UnboundTerm<Object> term = toTerm(childAtIndex(childPredicate, 0));
+          if (term == null) {
+            return null;
+          }
+
          // infer an extra notNull predicate for Spark NOT IN filters
          // as Iceberg expressions don't follow the 3-value SQL boolean logic
          // col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg
           Expression notIn =
               notIn(
-                  SparkUtil.toColumnName(childAtIndex(childPredicate, 0)),
+                  term,
                   Arrays.stream(childPredicate.children())
                       .skip(1)
                       .map(val -> convertLiteral(((Literal<?>) val)))
                       .filter(Objects::nonNull)
                       .collect(Collectors.toList()));
-          return and(notNull(SparkUtil.toColumnName(childAtIndex(childPredicate, 0))), notIn);
+          return and(notNull(term), notIn);
         } else if (hasNoInFilter(childPredicate)) {
           Expression child = convert(childPredicate);
           if (child != null) {
@@ -258,15 +309,13 @@ public static Expression convert(Predicate predicate) {
   }

   private static Pair<UnboundTerm<Object>, Object> predicateChildren(Predicate predicate) {
-    if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-      UnboundTerm<Object> term = ref(SparkUtil.toColumnName(leftChild(predicate)));
-      Object value = convertLiteral(rightChild(predicate));
-      return Pair.of(term, value);
+    if (canConvertToTerm(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
+      UnboundTerm<Object> term = toTerm(leftChild(predicate));
+      return term != null ? Pair.of(term, convertLiteral(rightChild(predicate))) : null;

-    } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-      UnboundTerm<Object> term = ref(SparkUtil.toColumnName(rightChild(predicate)));
-      Object value = convertLiteral(leftChild(predicate));
-      return Pair.of(term, value);
+    } else if (canConvertToTerm(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
+      UnboundTerm<Object> term = toTerm(rightChild(predicate));
+      return term != null ? Pair.of(term, convertLiteral(leftChild(predicate))) : null;

     } else {
       return null;
@@ -302,10 +351,26 @@ private static <T> T childAtIndex(Predicate predicate, int index) {
     return (T) predicate.children()[index];
   }

+  private static boolean canConvertToTerm(
+      org.apache.spark.sql.connector.expressions.Expression expr) {
+    return isRef(expr) || isSystemFunc(expr);
+  }
+
   private static boolean isRef(org.apache.spark.sql.connector.expressions.Expression expr) {
     return expr instanceof NamedReference;
   }

+  private static boolean isSystemFunc(org.apache.spark.sql.connector.expressions.Expression expr) {
+    if (expr instanceof UserDefinedScalarFunc) {
+      UserDefinedScalarFunc udf = (UserDefinedScalarFunc) expr;
+      return udf.canonicalName().startsWith("iceberg")
+          && SUPPORTED_FUNCTIONS.contains(udf.name())
+          && Arrays.stream(udf.children()).allMatch(child -> isLiteral(child) || isRef(child));
+    }
+
+    return false;
+  }
+
   private static boolean isLiteral(org.apache.spark.sql.connector.expressions.Expression expr) {
     return expr instanceof Literal;
   }
@@ -360,10 +425,57 @@ private static boolean hasNoInFilter(Predicate predicate) {
   }

   private static boolean isSupportedInPredicate(Predicate predicate) {
-    if (!isRef(childAtIndex(predicate, 0))) {
+    if (!canConvertToTerm(childAtIndex(predicate, 0))) {
       return false;
     } else {
       return Arrays.stream(predicate.children()).skip(1).allMatch(SparkV2Filters::isLiteral);
     }
   }
+
+  /** Should be called after {@link #canConvertToTerm} passed */
+  private static <T> UnboundTerm<Object> toTerm(T input) {
+    if (input instanceof NamedReference) {
+      return Expressions.ref(SparkUtil.toColumnName((NamedReference) input));
+    } else if (input instanceof UserDefinedScalarFunc) {
+      return udfToTerm((UserDefinedScalarFunc) input);
+    } else {
+      return null;
+    }
+  }
+
+  @SuppressWarnings("checkstyle:CyclomaticComplexity")
+  private static UnboundTerm<Object> udfToTerm(UserDefinedScalarFunc udf) {
+    org.apache.spark.sql.connector.expressions.Expression[] children = udf.children();
+    String udfName = udf.name().toLowerCase(Locale.ROOT);
+    if (children.length == 1) {
+      org.apache.spark.sql.connector.expressions.Expression child = children[0];
+      if (isRef(child)) {
+        String column = SparkUtil.toColumnName((NamedReference) child);
+        switch (udfName) {
+          case "years":
+            return year(column);
+          case "months":
+            return month(column);
+          case "days":
+            return day(column);
+          case "hours":
+            return hour(column);
+        }
+      }
+    } else if (children.length == 2) {
+      if (isLiteral(children[0]) && isRef(children[1])) {
+        String column = SparkUtil.toColumnName((NamedReference) children[1]);
+        switch (udfName) {
+          case "bucket":
+            int numBuckets = (Integer) convertLiteral((Literal<?>) children[0]);
+            return bucket(column, numBuckets);
+          case "truncate":
+            int width = (Integer) convertLiteral((Literal<?>) children[0]);
+            return truncate(column, width);
+        }
+      }
+    }
+
+    return null;
+  }
 }
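
For reference, a small standalone sketch of the mapping udfToTerm performs, written against Iceberg's public expression API; the column names ts, id, and s are hypothetical. Note the argument-order flip for bucket and truncate: Spark passes (width, column) while Iceberg's factory methods take the column first.

import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;

public class TermMappingExample {
  public static void main(String[] args) {
    // years(ts) = 2023       ->  year("ts") == 2023
    Expression byYear = Expressions.equal(Expressions.year("ts"), 2023);

    // bucket(16, id) = 4     ->  bucket("id", 16) == 4
    Expression byBucket = Expressions.equal(Expressions.bucket("id", 16), 4);

    // truncate(4, s) = "ice" ->  truncate("s", 4) == "ice"
    Expression byPrefix = Expressions.equal(Expressions.truncate("s", 4), "ice");

    System.out.println(byYear + "\n" + byBucket + "\n" + byPrefix);
  }
}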
SparkBatchQueryScan.java

@@ -45,20 +45,20 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
 import org.apache.iceberg.spark.Spark3Util;
-import org.apache.iceberg.spark.SparkFilters;
 import org.apache.iceberg.spark.SparkReadConf;
 import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.spark.SparkV2Filters;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.Statistics;
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering;
-import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 class SparkBatchQueryScan extends SparkPartitioningAwareScan<PartitionScanTask>
-    implements SupportsRuntimeFiltering {
+    implements SupportsRuntimeV2Filtering {

   private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class);

@@ -119,8 +119,8 @@ public NamedReference[] filterAttributes() {
   }

   @Override
-  public void filter(Filter[] filters) {
-    Expression runtimeFilterExpr = convertRuntimeFilters(filters);
+  public void filter(Predicate[] predicates) {
+    Expression runtimeFilterExpr = convertRuntimeFilters(predicates);

     if (runtimeFilterExpr != Expressions.alwaysTrue()) {
       Map<Integer, Evaluator> evaluatorsBySpecId = Maps.newHashMap();
@@ -160,11 +160,11 @@ public void filter(Filter[] filters) {

   // at this moment, Spark can only pass IN filters for a single attribute
   // if there are multiple filter attributes, Spark will pass two separate IN filters
-  private Expression convertRuntimeFilters(Filter[] filters) {
+  private Expression convertRuntimeFilters(Predicate[] predicates) {
     Expression runtimeFilterExpr = Expressions.alwaysTrue();

-    for (Filter filter : filters) {
-      Expression expr = SparkFilters.convert(filter);
+    for (Predicate predicate : predicates) {
+      Expression expr = SparkV2Filters.convert(predicate);
       if (expr != null) {
         try {
           Binder.bind(expectedSchema().asStruct(), expr, caseSensitive());
@@ -173,7 +173,7 @@ private Expression convertRuntimeFilters(Filter[] filters) {
           LOG.warn("Failed to bind {} to expected schema, skipping runtime filter", expr, e);
         }
       } else {
-        LOG.warn("Unsupported runtime filter {}", filter);
+        LOG.warn("Unsupported runtime filter {}", predicate);
       }
     }
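
The scan now receives runtime filters as V2 Predicate objects and converts them through SparkV2Filters rather than the V1 SparkFilters path. Below is a sketch of the two conversion behaviors available after this commit; the helper class is illustrative, not part of the commit. The new array overload of convert AND-combines results and throws on an unsupported predicate, while the per-predicate overload returns null, which the runtime-filter path above logs and skips.

import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.spark.SparkV2Filters;
import org.apache.spark.sql.connector.expressions.filter.Predicate;

public class ConvertBehaviorExample {
  // strict: fails fast with IllegalArgumentException if any predicate is unsupported
  public static Expression strict(Predicate[] predicates) {
    return SparkV2Filters.convert(predicates);
  }

  // lenient: drop unconvertible predicates, mirroring the runtime-filtering path
  public static Expression lenient(Predicate[] predicates) {
    Expression expr = Expressions.alwaysTrue();
    for (Predicate predicate : predicates) {
      Expression converted = SparkV2Filters.convert(predicate);
      if (converted != null) {
        expr = Expressions.and(expr, converted);
      }
    }
    return expr;
  }
}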
