Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.iceberg.expressions.Evaluator;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.ExpressionUtil;
import org.apache.iceberg.expressions.ExpressionVisitors;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.Projections;
import org.apache.iceberg.metrics.ScanReport;
Expand Down Expand Up @@ -195,6 +196,10 @@ private Expression convertRuntimePredicates(Predicate[] predicates) {
}

/**
 * Builds the human-readable description of the runtime filters of this scan.
 *
 * <p>NOTE(review): this span comes from a rendered diff, so it shows two {@code return}
 * statements back to back. The first looks like the pre-change line and the stream pipeline
 * like its replacement; only one of them can exist in the compiled file. Confirm against the
 * actual source before relying on either form.
 *
 * @return the description string of {@code runtimeFilters}
 */
protected String runtimeFiltersDesc() {
// Older form: hands the runtime filter collection to Spark3Util.describe unchanged.
return Spark3Util.describe(runtimeFilters);
// Newer form: every runtime filter is run through ExpressionFlattener, each resulting
// expression is described on its own, the descriptions are sorted in natural String order
// and joined with ", ". Presumably this makes the text independent of the order in which
// the filters were supplied (verify against the equality test that accompanies the change).
return runtimeFilters.stream()
.flatMap(x -> ExpressionVisitors.visit(x, ExpressionFlattener.INSTANCE).stream())
.map(Spark3Util::describe)
.sorted()
.collect(Collectors.joining(", "));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.StatisticsFile;
import org.apache.iceberg.Table;
import org.apache.iceberg.expressions.BoundPredicate;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.ExpressionVisitors;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.UnboundPredicate;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.metrics.ScanReport;
import org.apache.iceberg.relocated.com.google.common.base.Strings;
Expand Down Expand Up @@ -161,7 +164,11 @@ protected Expression filter() {
}

/**
 * Builds the human-readable description of the filters pushed down to this scan.
 *
 * <p>NOTE(review): this span comes from a rendered diff, so it shows two {@code return}
 * statements back to back. The first looks like the pre-change line and the stream pipeline
 * like its replacement; only one of them can exist in the compiled file. Confirm against the
 * actual source before relying on either form.
 *
 * @return the description string of {@code filters}
 */
protected String filtersDesc() {
// Older form: hands the filter collection to Spark3Util.describe unchanged.
return Spark3Util.describe(filters);
// Newer form: every filter is run through ExpressionFlattener, each resulting expression is
// described on its own, the descriptions are sorted in natural String order and joined
// with ", ". Presumably this makes the text independent of the order in which the filters
// were pushed (verify against the tests that accompany the change).
return filters.stream()
.flatMap(x -> ExpressionVisitors.visit(x, ExpressionFlattener.INSTANCE).stream())
.map(Spark3Util::describe)
.sorted()
.collect(Collectors.joining(", "));
}

protected Types.StructType groupingKeyType() {
Expand Down Expand Up @@ -384,4 +391,52 @@ protected long adjustSplitSize(List<? extends ScanTask> tasks, long splitSize) {
return splitSize;
}
}

/**
 * Visitor that turns an expression into the list of expressions joined by its top-level
 * {@code AND} nodes.
 *
 * <p>Only {@code and} concatenates the lists produced for its children. Every other node
 * (constants, predicates, {@code not}, {@code or}) yields a single-element list; {@code not}
 * and {@code or} first fold each child list back into one expression via
 * {@link #mergeExpressions(List)}.
 */
static class ExpressionFlattener extends ExpressionVisitors.ExpressionVisitor<List<Expression>> {

  /** Shared instance; the visitor keeps no state of its own. */
  protected static final ExpressionFlattener INSTANCE = new ExpressionFlattener();

  private ExpressionFlattener() {}

  /** Returns a one-element list holding the always-true constant. */
  @Override
  public List<Expression> alwaysTrue() {
    Expression constant = Expressions.alwaysTrue();
    return List.of(constant);
  }

  /** Returns a one-element list holding the always-false constant. */
  @Override
  public List<Expression> alwaysFalse() {
    Expression constant = Expressions.alwaysFalse();
    return List.of(constant);
  }

  /** Folds the child list into one expression and wraps its negation in a one-element list. */
  @Override
  public List<Expression> not(List<Expression> result) {
    Expression operand = mergeExpressions(result);
    return List.of(Expressions.not(operand));
  }

  /** Concatenates the left child's list and the right child's list, left elements first. */
  @Override
  public List<Expression> and(List<Expression> leftResult, List<Expression> rightResult) {
    List<Expression> combined = Lists.newArrayList();
    combined.addAll(leftResult);
    combined.addAll(rightResult);
    return combined;
  }

  /** Folds each side into one expression and returns their disjunction as a single element. */
  @Override
  public List<Expression> or(List<Expression> leftResult, List<Expression> rightResult) {
    Expression lhs = mergeExpressions(leftResult);
    Expression rhs = mergeExpressions(rightResult);
    return List.of(Expressions.or(lhs, rhs));
  }

  /** A bound predicate is a leaf: it is returned unchanged as a one-element list. */
  @Override
  public <T> List<Expression> predicate(BoundPredicate<T> pred) {
    return List.of(pred);
  }

  /** An unbound predicate is a leaf: it is returned unchanged as a one-element list. */
  @Override
  public <T> List<Expression> predicate(UnboundPredicate<T> pred) {
    return List.of(pred);
  }

  /**
   * Folds a list of expressions into one by combining them with {@code Expressions.and} in
   * list order, starting from the always-true constant.
   */
  private Expression mergeExpressions(List<Expression> toMerge) {
    Expression merged = Expressions.alwaysTrue();
    for (Expression part : toMerge) {
      merged = Expressions.and(merged, part);
    }
    return merged;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters;
import org.apache.spark.sql.sources.And;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.In;
import org.apache.spark.sql.sources.LessThan;
import org.apache.spark.sql.sources.Not;
import org.apache.spark.sql.sources.StringStartsWith;
Expand Down Expand Up @@ -269,6 +271,53 @@ public void testUnpartitionedTimestampFilter() {
"ts < cast('2017-12-22 00:00:00+00:00' as timestamp)"));
}

/**
 * Checks that two scans over the same table compare equal when the same filters reach them in
 * a different order, both for filters pushed through the scan builder and for runtime filters
 * applied to the built scans afterwards.
 *
 * <p>The test runs with {@code spark.sql.caseSensitive=false} and refers to the {@code ID}
 * column in upper case; the previous value of the setting is restored in the {@code finally}
 * block.
 */
@TestTemplate
public void testSparkBatchScanEquality_Bug16563() {
  CaseInsensitiveStringMap options =
      new CaseInsensitiveStringMap(ImmutableMap.of("path", unpartitioned.toString()));

  // Remember the session-wide setting so it can be restored after the test.
  String caseSensitivityBeforeTest = TestFilteredScan.spark.conf().get("spark.sql.caseSensitive");
  TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", "false");
  try {
    SparkScanBuilder builder1 =
        new SparkScanBuilder(spark, TABLES.load(options.get("path")), options);
    Filter f1 = GreaterThan.apply("ID", 10);
    Filter f2 = LessThan.apply("data", "abc");

    pushFilters(builder1, f1, f2);

    // Same filters, pushed in the opposite order.
    SparkScanBuilder builder2 =
        new SparkScanBuilder(spark, TABLES.load(options.get("path")), options);
    pushFilters(builder2, f2, f1);

    Scan scan1 = builder1.build();
    Scan scan2 = builder2.build();

    assertThat(scan1)
        .as("Scans built from the same filters pushed in a different order should be equal")
        .isEqualTo(scan2);

    Filter runtimeFilter1 = In.apply("ID", new Object[] {10, 12, 13});
    Filter runtimeFilter2 = In.apply("data", new Object[] {"abc", "def", "cde"});

    ((SparkRuntimeFilterableScan) scan1)
        .filter(
            Arrays.stream(new Filter[] {runtimeFilter1, runtimeFilter2})
                .map(Filter::toV2)
                .toArray(Predicate[]::new));

    // Same runtime filters, applied in the opposite order.
    ((SparkRuntimeFilterableScan) scan2)
        .filter(
            Arrays.stream(new Filter[] {runtimeFilter2, runtimeFilter1})
                .map(Filter::toV2)
                .toArray(Predicate[]::new));

    assertThat(scan1)
        .as("Scans should stay equal after the same runtime filters are applied in a different order")
        .isEqualTo(scan2);
  } finally {
    // return global conf to previous state
    TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", caseSensitivityBeforeTest);
  }
}

@TestTemplate
public void limitPushedDownToSparkScan() {
assumeThat(fileFormat)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.iceberg.spark.SparkV2Filters;
import org.apache.iceberg.spark.TestBaseWithCatalog;
import org.apache.iceberg.spark.functions.BucketFunction;
import org.apache.iceberg.spark.functions.DaysFunction;
Expand Down Expand Up @@ -1061,6 +1062,14 @@ public void testCopyOnWriteScanDescription() throws Exception {
() -> {
Predicate predicate1 = new Predicate("=", expressions(fieldRef("id"), intLit(2)));
Predicate predicate2 = new Predicate("<", expressions(fieldRef("id"), intLit(10)));
String filter1Desc = Spark3Util.describe(SparkV2Filters.convert(predicate1));
String filter2Desc = Spark3Util.describe(SparkV2Filters.convert(predicate2));
String expectedFilterDesc;
if (filter1Desc.compareTo(filter2Desc) < 0) {
expectedFilterDesc = filter1Desc + ", " + filter2Desc;
} else {
expectedFilterDesc = filter2Desc + ", " + filter1Desc;
}
pushFilters(builder, predicate1, predicate2);

Scan scan = builder.buildCopyOnWriteScan();
Expand All @@ -1071,7 +1080,7 @@ public void testCopyOnWriteScanDescription() throws Exception {
assertThat(description).contains("schemaId=" + table.schema().schemaId());
assertThat(description).contains("snapshotId=" + table.currentSnapshot().snapshotId());
assertThat(description).contains("branch=null");
assertThat(description).contains("filters=id = 2, id < 10");
assertThat(description).contains("filters=" + expectedFilterDesc);
assertThat(description).contains("groupedBy=data");
});
}
Expand Down
Loading