Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* relational expression in Arrow.
*/
class ArrowFilter extends Filter implements ArrowRel {
private final List<String> match;
private final List<List<ConditionToken>> match;

ArrowFilter(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, RexNode condition) {
super(cluster, traitSet, input, condition);
Expand Down
15 changes: 12 additions & 3 deletions arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowRel.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,24 @@ public interface ArrowRel extends RelNode {
* {@link ArrowRel} nodes into a SQL query. */
class Implementor {
@Nullable List<Integer> selectFields;
final List<String> whereClause = new ArrayList<>();
final List<List<ConditionToken>> whereClause = new ArrayList<>();
@Nullable RelOptTable table;
@Nullable ArrowTable arrowTable;

/** Adds new predicates.
*
* @param predicates Predicates
* <p>The structure is two levels of nesting:
* <ul>
* <li>Outer list: conjunction (AND) of clauses
* <li>Inner list: disjunction (OR) of conditions within a clause
* </ul>
*
* <p>Each {@link ConditionToken} represents a single unary or binary
* predicate condition.
*
* @param predicates Predicates in CNF form
*/
void addFilters(List<String> predicates) {
void addFilters(List<List<ConditionToken>> predicates) {
whereClause.addAll(predicates);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -97,9 +99,13 @@ protected ArrowFilterRule(Config config) {
RelNode convert(Filter filter) {
final RelTraitSet traitSet =
filter.getTraitSet().replace(ArrowRel.CONVENTION);
// Expand SEARCH (e.g. IN, BETWEEN) before pushing to Arrow,
// since Gandiva does not support SEARCH natively.
final RexNode condition =
RexUtil.expandSearch(filter.getCluster().getRexBuilder(), null, filter.getCondition());
return new ArrowFilter(filter.getCluster(), traitSet,
convert(filter.getInput(), ArrowRel.CONVENTION),
filter.getCondition());
condition);
}

/** Rule configuration. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public class ArrowTable extends AbstractTable
* {@link org.apache.calcite.adapter.arrow.ArrowMethod#ARROW_QUERY}. */
@SuppressWarnings("unused")
public Enumerable<Object> query(DataContext root, ImmutableIntList fields,
List<String> conditions) {
List<List<List<String>>> conditions) {
requireNonNull(fields, "fields");
final Projector projector;
final Filter filter;
Expand All @@ -119,30 +119,26 @@ public Enumerable<Object> query(DataContext root, ImmutableIntList fields,
} else {
projector = null;

final List<TreeNode> conditionNodes = new ArrayList<>(conditions.size());
for (String condition : conditions) {
String[] data = condition.split(" ");
List<TreeNode> treeNodes = new ArrayList<>(2);
treeNodes.add(
TreeBuilder.makeField(schema.getFields()
.get(schema.getFields().indexOf(schema.findField(data[0])))));

// if the split condition has more than two parts it's a binary operator
// with an additional literal node
if (data.length > 2) {
treeNodes.add(makeLiteralNode(data[2], data[3]));
final List<TreeNode> conjuncts = new ArrayList<>(conditions.size());
for (List<List<String>> orGroup : conditions) {
final List<TreeNode> disjuncts = new ArrayList<>(orGroup.size());
for (List<String> conditionParts : orGroup) {
disjuncts.add(
parseSingleCondition(
ConditionToken.fromTokenList(conditionParts)));
}
if (disjuncts.size() == 1) {
conjuncts.add(disjuncts.get(0));
} else {
conjuncts.add(TreeBuilder.makeOr(disjuncts));
}

String operator = data[1];
conditionNodes.add(
TreeBuilder.makeFunction(operator, treeNodes, new ArrowType.Bool()));
}
final Condition filterCondition;
if (conditionNodes.size() == 1) {
filterCondition = TreeBuilder.makeCondition(conditionNodes.get(0));
if (conjuncts.size() == 1) {
filterCondition = TreeBuilder.makeCondition(conjuncts.get(0));
} else {
TreeNode treeNode = TreeBuilder.makeAnd(conditionNodes);
filterCondition = TreeBuilder.makeCondition(treeNode);
filterCondition =
TreeBuilder.makeCondition(TreeBuilder.makeAnd(conjuncts));
}

try {
Expand Down Expand Up @@ -184,6 +180,26 @@ private static RelDataType deduceRowType(Schema schema,
return builder.build();
}

/**
 * Parses a single {@link ConditionToken} into a Gandiva {@link TreeNode}.
 *
 * <p>The result is a boolean function node named by the token's operator.
 * Its first argument is the field the token refers to; for a binary token,
 * a literal node built from the token's value and value type is appended as
 * the second argument.
 *
 * @param token Condition to translate
 * @return Function node of type {@code Bool} for the condition
 * @throws NullPointerException if a binary token has no value or value type
 */
private TreeNode parseSingleCondition(ConditionToken token) {
  final List<TreeNode> treeNodes = new ArrayList<>(2);
  // findField already returns the matching Field, so there is no need to
  // search schema.getFields() a second time for its index.
  treeNodes.add(TreeBuilder.makeField(schema.findField(token.fieldName)));

  if (token.isBinary()) {
    // Binary predicates compare the field against a typed literal.
    treeNodes.add(
        makeLiteralNode(
            requireNonNull(token.value, "value"),
            requireNonNull(token.valueType, "valueType")));
  }

  return TreeBuilder.makeFunction(
      token.operator, treeNodes, new ArrowType.Bool());
}

private static TreeNode makeLiteralNode(String literal, String type) {
if (type.startsWith("decimal")) {
String[] typeParts =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import com.google.common.primitives.Ints;

import java.util.ArrayList;
import java.util.List;

import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -84,6 +85,23 @@ protected ArrowToEnumerableConverter(RelOptCluster cluster,
: Expressions.call(
BuiltInMethod.IMMUTABLE_INT_LIST_IDENTITY.method,
Expressions.constant(fieldCount)),
Expressions.constant(arrowImplementor.whereClause))));
Expressions.constant(
toTokenLists(arrowImplementor.whereClause)))));
}

/**
 * Flattens structured {@link ConditionToken} conditions into nested string
 * lists so that they can be serialized through {@link Expressions#constant}.
 *
 * <p>The nesting of the input is preserved: each inner list of tokens
 * becomes an inner list of token lists, in the same order.
 *
 * @param conditions Conditions grouped as a list of token groups
 * @return The same grouping, with every token replaced by its token list
 */
private static List<List<List<String>>> toTokenLists(
    List<List<ConditionToken>> conditions) {
  final List<List<List<String>>> clauses = new ArrayList<>(conditions.size());
  for (List<ConditionToken> clause : conditions) {
    final List<List<String>> terms = new ArrayList<>(clause.size());
    clause.forEach(token -> terms.add(token.toTokenList()));
    clauses.add(terms);
  }
  return clauses;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import static java.util.Objects.requireNonNull;

/**
* Translates a {@link RexNode} expression to a Gandiva string.
* Translates a {@link RexNode} expression to Gandiva predicate tokens.
*/
class ArrowTranslator {
final RexBuilder rexBuilder;
Expand All @@ -61,13 +61,30 @@ public static ArrowTranslator create(RexBuilder rexBuilder,
return new ArrowTranslator(rexBuilder, rowType);
}

List<String> translateMatch(RexNode condition) {
List<RexNode> disjunctions = RelOptUtil.disjunctions(condition);
if (disjunctions.size() == 1) {
return translateAnd(disjunctions.get(0));
} else {
throw new UnsupportedOperationException("Unsupported disjunctive condition " + condition);
/** The maximum number of nodes allowed during CNF conversion.
*
* <p>If exceeded, {@link RexUtil#toCnf(RexBuilder, int, RexNode)} returns
* the original expression unchanged, which may cause the subsequent
* translation to Gandiva predicates to fail with an
* {@link UnsupportedOperationException}. That exception is caught by
* {@link ArrowRules.ArrowFilterRule#onMatch}, which silently skips the
* Arrow convention and falls back to an Enumerable plan. */
private static final int MAX_CNF_NODE_COUNT = 256;

/**
 * Translates a filter condition into groups of Gandiva predicate tokens.
 *
 * <p>The condition is first rewritten to conjunctive normal form. Each
 * element of the returned outer list corresponds to one conjunct, and holds
 * one token per disjunct of that conjunct.
 *
 * @param condition Filter condition; SEARCH nodes are expected to have been
 *     expanded by ArrowFilterRule before this method is called
 * @return Tokens grouped per conjunct of the CNF form
 */
List<List<ConditionToken>> translateMatch(RexNode condition) {
  final List<List<ConditionToken>> clauses = new ArrayList<>();
  // Only CNF conversion is needed here; SEARCH was expanded upstream.
  final List<RexNode> conjuncts =
      RelOptUtil.conjunctions(
          RexUtil.toCnf(rexBuilder, MAX_CNF_NODE_COUNT, condition));
  for (RexNode clause : conjuncts) {
    final List<ConditionToken> alternatives = new ArrayList<>();
    for (RexNode term : RelOptUtil.disjunctions(clause)) {
      alternatives.add(translateMatch2(term));
    }
    clauses.add(alternatives);
  }
  return clauses;
}

/**
Expand All @@ -93,34 +110,14 @@ private static Object literalValue(RexLiteral literal) {
}
}

/**
* Translate a conjunctive predicate to a SQL string.
*
* @param condition A conjunctive predicate
*
* @return SQL string for the predicate
*/
private List<String> translateAnd(RexNode condition) {
List<String> predicates = new ArrayList<>();
for (RexNode node : RelOptUtil.conjunctions(condition)) {
if (node.getKind() == SqlKind.SEARCH) {
final RexNode node2 = RexUtil.expandSearch(rexBuilder, null, node);
predicates.addAll(translateMatch(node2));
} else {
predicates.add(translateMatch2(node));
}
}
return predicates;
}

/**
* Translates a binary or unary relation.
*
* @param node A RexNode that always evaluates to a boolean expression.
* Currently, this method is only called from translateMatch.
* @return The translated SQL string for the relation.
* @return The translated condition token for the relation.
*/
private String translateMatch2(RexNode node) {
private ConditionToken translateMatch2(RexNode node) {
switch (node.getKind()) {
case EQUALS:
return translateBinary("equal", "=", (RexCall) node);
Expand All @@ -144,7 +141,7 @@ private String translateMatch2(RexNode node) {
return translateUnary("isnotfalse", (RexCall) node);
case INPUT_REF:
final RexInputRef inputRef = (RexInputRef) node;
return fieldNames.get(inputRef.getIndex()) + " istrue";
return ConditionToken.unary(fieldNames.get(inputRef.getIndex()), "istrue");
case NOT:
return translateUnary("isfalse", (RexCall) node);
default:
Expand All @@ -156,10 +153,10 @@ private String translateMatch2(RexNode node) {
* Translates a call to a binary operator, reversing arguments if
* necessary.
*/
private String translateBinary(String op, String rop, RexCall call) {
private ConditionToken translateBinary(String op, String rop, RexCall call) {
final RexNode left = call.operands.get(0);
final RexNode right = call.operands.get(1);
@Nullable String expression = translateBinary2(op, left, right);
@Nullable ConditionToken expression = translateBinary2(op, left, right);
if (expression != null) {
return expression;
}
Expand All @@ -171,7 +168,8 @@ private String translateBinary(String op, String rop, RexCall call) {
}

/** Translates a call to a binary operator. Returns null on failure. */
private @Nullable String translateBinary2(String op, RexNode left, RexNode right) {
private @Nullable ConditionToken translateBinary2(String op, RexNode left,
RexNode right) {
if (right.getKind() != SqlKind.LITERAL) {
return null;
}
Expand All @@ -189,26 +187,29 @@ private String translateBinary(String op, String rop, RexCall call) {
}
}

/** Combines a field name, operator, and literal to produce a predicate string. */
private String translateOp2(String op, String name, RexLiteral right) {
/** Combines a field name, operator, and literal to produce a binary
* condition token. */
private ConditionToken translateOp2(String op, String name,
RexLiteral right) {
Object value = literalValue(right);
String valueString = value.toString();
String valueType = getLiteralType(right.getType());

if (value instanceof String) {
final RelDataTypeField field = requireNonNull(rowType.getField(name, true, false), "field");
final RelDataTypeField field =
requireNonNull(rowType.getField(name, true, false), "field");
SqlTypeName typeName = field.getType().getSqlTypeName();
if (typeName != SqlTypeName.CHAR) {
valueString = "'" + valueString + "'";
}
}
return name + " " + op + " " + valueString + " " + valueType;
return ConditionToken.binary(name, op, valueString, valueType);
}

/** Translates a call to a unary operator. */
private String translateUnary(String op, RexCall call) {
private ConditionToken translateUnary(String op, RexCall call) {
final RexNode opNode = call.operands.get(0);
@Nullable String expression = translateUnary2(op, opNode);
@Nullable ConditionToken expression = translateUnary2(op, opNode);

if (expression != null) {
return expression;
Expand All @@ -218,21 +219,16 @@ private String translateUnary(String op, RexCall call) {
}

/** Translates a call to a unary operator. Returns null on failure. */
private @Nullable String translateUnary2(String op, RexNode opNode) {
private @Nullable ConditionToken translateUnary2(String op, RexNode opNode) {
if (opNode.getKind() == SqlKind.INPUT_REF) {
final RexInputRef inputRef = (RexInputRef) opNode;
final String name = fieldNames.get(inputRef.getIndex());
return translateUnaryOp(op, name);
return ConditionToken.unary(name, op);
}

return null;
}

/** Combines a field name and a unary operator to produce a predicate string. */
private static String translateUnaryOp(String op, String name) {
return name + " " + op;
}

private static String getLiteralType(RelDataType type) {
if (type.getSqlTypeName() == SqlTypeName.DECIMAL) {
return "decimal" + "(" + type.getPrecision() + "," + type.getScale() + ")";
Expand Down
Loading
Loading