Skip to content

Commit

Permalink
[FLINK-16024][connector/jdbc] Support filter pushdown
Browse files Browse the repository at this point in the history
This closes #20140
  • Loading branch information
Qing Lim authored and libenchao committed Nov 2, 2022
1 parent 03c0f15 commit 4967001
Show file tree
Hide file tree
Showing 8 changed files with 880 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.connector.jdbc.split;

import org.apache.flink.annotation.Internal;
import org.apache.flink.util.Preconditions;

import java.io.Serializable;

/**
 * Combines two {@link JdbcParameterValuesProvider} into one by concatenating, batch by batch, the
 * parameter rows produced by the first provider with those produced by the second.
 *
 * <p>Both providers must yield the same number of batches; this is validated at construction time.
 */
@Internal
public class CompositeJdbcParameterValuesProvider implements JdbcParameterValuesProvider {

    private final JdbcParameterValuesProvider a;
    private final JdbcParameterValuesProvider b;

    public CompositeJdbcParameterValuesProvider(
            JdbcParameterValuesProvider a, JdbcParameterValuesProvider b) {
        Preconditions.checkNotNull(a, "First JdbcParameterValuesProvider must not be null.");
        Preconditions.checkNotNull(b, "Second JdbcParameterValuesProvider must not be null.");
        Preconditions.checkArgument(
                a.getParameterValues().length == b.getParameterValues().length,
                "Both JdbcParameterValuesProvider should have the same length.");
        this.a = a;
        this.b = b;
    }

    @Override
    public Serializable[][] getParameterValues() {
        // Fetch each provider's batches once: getParameterValues() may recompute its
        // result on every invocation, so calling it inside the loop would be wasteful.
        Serializable[][] aParams = a.getParameterValues();
        Serializable[][] bParams = b.getParameterValues();
        int batchNum = aParams.length;
        Serializable[][] parameters = new Serializable[batchNum][];
        for (int i = 0; i < batchNum; i++) {
            Serializable[] aSlice = aParams[i];
            Serializable[] bSlice = bParams[i];
            int totalLen = aSlice.length + bSlice.length;

            // Row i of the result is [aSlice..., bSlice...].
            Serializable[] batchParams = new Serializable[totalLen];
            System.arraycopy(aSlice, 0, batchParams, 0, aSlice.length);
            System.arraycopy(bSlice, 0, batchParams, aSlice.length, bSlice.length);
            parameters[i] = batchParams;
        }
        return parameters;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,51 @@
import org.apache.flink.connector.jdbc.dialect.JdbcDialect;
import org.apache.flink.connector.jdbc.internal.options.JdbcConnectorOptions;
import org.apache.flink.connector.jdbc.internal.options.JdbcReadOptions;
import org.apache.flink.connector.jdbc.split.CompositeJdbcParameterValuesProvider;
import org.apache.flink.connector.jdbc.split.JdbcGenericParameterValuesProvider;
import org.apache.flink.connector.jdbc.split.JdbcNumericBetweenParametersProvider;
import org.apache.flink.connector.jdbc.split.JdbcParameterValuesProvider;
import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.connector.Projection;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.InputFormatProvider;
import org.apache.flink.table.connector.source.LookupTableSource;
import org.apache.flink.table.connector.source.ScanTableSource;
import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown;
import org.apache.flink.table.connector.source.abilities.SupportsLimitPushDown;
import org.apache.flink.table.connector.source.abilities.SupportsProjectionPushDown;
import org.apache.flink.table.connector.source.lookup.LookupFunctionProvider;
import org.apache.flink.table.connector.source.lookup.PartialCachingLookupProvider;
import org.apache.flink.table.connector.source.lookup.cache.LookupCache;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;

import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/** A {@link DynamicTableSource} for JDBC. */
@Internal
public class JdbcDynamicTableSource
implements ScanTableSource,
LookupTableSource,
SupportsProjectionPushDown,
SupportsLimitPushDown {
SupportsLimitPushDown,
SupportsFilterPushDown {
private static final Logger LOG = LoggerFactory.getLogger(JdbcDynamicTableSource.class);

private final JdbcConnectorOptions options;
private final JdbcReadOptions readOptions;
Expand All @@ -57,6 +75,8 @@ public class JdbcDynamicTableSource
private DataType physicalRowDataType;
private final String dialectName;
private long limit = -1;
private List<String> resolvedPredicates = new ArrayList<>();
private Serializable[] pushdownParams = new Serializable[0];

public JdbcDynamicTableSource(
JdbcConnectorOptions options,
Expand Down Expand Up @@ -117,21 +137,46 @@ public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderCon
options.getTableName(),
DataType.getFieldNames(physicalRowDataType).toArray(new String[0]),
new String[0]);
final List<String> predicates = new ArrayList<String>();

if (readOptions.getPartitionColumnName().isPresent()) {
long lowerBound = readOptions.getPartitionLowerBound().get();
long upperBound = readOptions.getPartitionUpperBound().get();
int numPartitions = readOptions.getNumPartitions().get();

Serializable[][] allPushdownParams = replicatePushdownParamsForN(numPartitions);
JdbcParameterValuesProvider allParams =
new CompositeJdbcParameterValuesProvider(
new JdbcNumericBetweenParametersProvider(lowerBound, upperBound)
.ofBatchNum(numPartitions),
new JdbcGenericParameterValuesProvider(allPushdownParams));

builder.setParametersProvider(allParams);

predicates.add(
dialect.quoteIdentifier(readOptions.getPartitionColumnName().get())
+ " BETWEEN ? AND ?");
} else {
builder.setParametersProvider(
new JdbcNumericBetweenParametersProvider(lowerBound, upperBound)
.ofBatchNum(numPartitions));
query +=
" WHERE "
+ dialect.quoteIdentifier(readOptions.getPartitionColumnName().get())
+ " BETWEEN ? AND ?";
new JdbcGenericParameterValuesProvider(replicatePushdownParamsForN(1)));
}

predicates.addAll(this.resolvedPredicates);

if (predicates.size() > 0) {
String joinedConditions =
predicates.stream()
.map(pred -> String.format("(%s)", pred))
.collect(Collectors.joining(" AND "));
query += " WHERE " + joinedConditions;
}

if (limit >= 0) {
query = String.format("%s %s", query, dialect.getLimitClause(limit));
}

LOG.debug("Query generated for JDBC scan: " + query);

builder.setQuery(query);
final RowType rowType = (RowType) physicalRowDataType.getLogicalType();
builder.setRowConverter(dialect.getRowConverter(rowType));
Expand Down Expand Up @@ -159,8 +204,12 @@ public void applyProjection(int[][] projectedFields, DataType producedDataType)

@Override
public DynamicTableSource copy() {
return new JdbcDynamicTableSource(
options, readOptions, lookupMaxRetryTimes, cache, physicalRowDataType);
JdbcDynamicTableSource newSource =
new JdbcDynamicTableSource(
options, readOptions, lookupMaxRetryTimes, cache, physicalRowDataType);
newSource.resolvedPredicates = new ArrayList<>(this.resolvedPredicates);
newSource.pushdownParams = Arrays.copyOf(this.pushdownParams, this.pushdownParams.length);
return newSource;
}

@Override
Expand All @@ -183,7 +232,9 @@ public boolean equals(Object o) {
&& Objects.equals(cache, that.cache)
&& Objects.equals(physicalRowDataType, that.physicalRowDataType)
&& Objects.equals(dialectName, that.dialectName)
&& Objects.equals(limit, that.limit);
&& Objects.equals(limit, that.limit)
&& Objects.equals(resolvedPredicates, that.resolvedPredicates)
&& Arrays.deepEquals(pushdownParams, that.pushdownParams);
}

@Override
Expand All @@ -195,11 +246,51 @@ public int hashCode() {
cache,
physicalRowDataType,
dialectName,
limit);
limit,
resolvedPredicates,
pushdownParams);
}

@Override
public void applyLimit(long limit) {
    // Record the pushed-down LIMIT; it is appended to the generated scan query via the
    // dialect's limit clause when the scan runtime provider builds the SQL statement.
    this.limit = limit;
}

/**
 * Attempts to push each filter down into the JDBC scan query.
 *
 * <p>A filter that can be translated into a parameterized SQL predicate is accepted: its SQL
 * fragment is appended to {@code resolvedPredicates} and its parameters to {@code
 * pushdownParams}. Filters that cannot be translated are returned as remaining so the planner
 * keeps evaluating them.
 */
@Override
public Result applyFilters(List<ResolvedExpression> filters) {
    final List<ResolvedExpression> accepted = new ArrayList<>();
    final List<ResolvedExpression> remaining = new ArrayList<>();

    for (ResolvedExpression candidate : filters) {
        Optional<ParameterizedPredicate> maybePredicate = parseFilterToPredicate(candidate);
        if (!maybePredicate.isPresent()) {
            remaining.add(candidate);
            continue;
        }
        ParameterizedPredicate predicate = maybePredicate.get();
        this.pushdownParams = ArrayUtils.addAll(this.pushdownParams, predicate.getParameters());
        this.resolvedPredicates.add(predicate.getPredicate());
        accepted.add(candidate);
    }

    return Result.of(accepted, remaining);
}

/**
 * Translates a single resolved filter expression into a parameterized SQL predicate, if
 * possible. Only {@link CallExpression}s are considered; anything else yields
 * {@link Optional#empty()}.
 */
private Optional<ParameterizedPredicate> parseFilterToPredicate(ResolvedExpression filter) {
    if (!(filter instanceof CallExpression)) {
        return Optional.empty();
    }
    JdbcFilterPushdownPreparedStatementVisitor visitor =
            new JdbcFilterPushdownPreparedStatementVisitor(
                    this.options.getDialect()::quoteIdentifier);
    return ((CallExpression) filter).accept(visitor);
}

/**
 * Builds an {@code n}-row parameter matrix where every row is the current set of pushed-down
 * filter parameters (one identical row per partition/split).
 *
 * <p>The rows intentionally share the same array reference: consumers such as
 * {@code CompositeJdbcParameterValuesProvider} copy the values out rather than mutate them.
 * Using {@code new Serializable[n][]} plus {@link Arrays#fill} avoids the original version's
 * allocation of {@code n} inner arrays that were immediately discarded when each row was
 * overwritten.
 */
private Serializable[][] replicatePushdownParamsForN(int n) {
    Serializable[][] allPushdownParams = new Serializable[n][];
    Arrays.fill(allPushdownParams, this.pushdownParams);
    return allPushdownParams;
}
}

0 comments on commit 4967001

Please sign in to comment.