Skip to content

Commit

Permalink
add where segment parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
tristaZero committed Jan 27, 2020
1 parent 11dfef5 commit af6fbbd
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 11 deletions.
Expand Up @@ -98,7 +98,7 @@ public void judgeForInsert() {
@Test
public void judgeForWhereSegment() {
SelectStatement selectStatement = new SelectStatement();
WhereSegment whereSegment = new WhereSegment(0, 0, 0);
WhereSegment whereSegment = new WhereSegment(0, 0);
AndPredicate andPredicate = new AndPredicate();
andPredicate.getPredicates().addAll(Collections.singletonList(new PredicateSegment(0, 0,
new ColumnSegment(0, 0, "shadow"),
Expand Down
Expand Up @@ -94,7 +94,8 @@ private UpdateStatement createUpdateStatementAndParameters(final Object sharding
Collection<AssignmentSegment> assignments = Collections.singletonList(new AssignmentSegment(0, 0, new ColumnSegment(0, 0, "id"), new LiteralExpressionSegment(0, 0, shardingColumnParameter)));
SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(0, 0, assignments);
result.setSetAssignment(setAssignmentSegment);
WhereSegment where = new WhereSegment(0, 0, 1);
WhereSegment where = new WhereSegment(0, 0);
where.setParametersCount(1);
where.setParameterMarkerStartIndex(0);
AndPredicate andPre = new AndPredicate();
andPre.getPredicates().add(new PredicateSegment(0, 1, new ColumnSegment(0, 0, "id"), new PredicateCompareRightValue("=", new ParameterMarkerExpressionSegment(0, 0, 0))));
Expand Down
Expand Up @@ -44,7 +44,8 @@ public Optional<WhereSegment> extract(final ParserRuleContext ancestorNode, fina
if (!whereNode.isPresent()) {
return Optional.absent();
}
WhereSegment result = new WhereSegment(whereNode.get().getStart().getStartIndex(), whereNode.get().getStop().getStopIndex(), parameterMarkerIndexes.size());
WhereSegment result = new WhereSegment(whereNode.get().getStart().getStartIndex(), whereNode.get().getStop().getStopIndex());
result.setParametersCount(parameterMarkerIndexes.size());
Optional<OrPredicateSegment> orPredicateSegment = predicateExtractor.extract(whereNode.get(), parameterMarkerIndexes);
if (orPredicateSegment.isPresent()) {
result.getAndPredicates().addAll(orPredicateSegment.get().getAndPredicates());
Expand Down
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate;

import lombok.Getter;
import org.apache.shardingsphere.sql.parser.sql.segment.SQLSegment;

import java.util.Collection;
import java.util.LinkedList;
Expand All @@ -26,9 +27,14 @@
* And predicate.
*
* @author duhongjun
* @author panjuan
*/
@Getter
public final class AndPredicate {
public final class AndPredicate implements SQLSegment {

private final int startIndex = 0;

private final int stopIndex = 0;

private final Collection<PredicateSegment> predicates = new LinkedList<>();
}
Expand Up @@ -29,6 +29,7 @@
* Where segment.
*
* @author duhongjun
* @author panjuan
*/
@RequiredArgsConstructor
@Getter
Expand All @@ -39,7 +40,7 @@ public final class WhereSegment implements SQLSegment {

private final int stopIndex;

private final int parametersCount;
private int parametersCount;

private final Collection<AndPredicate> andPredicates = new LinkedList<>();

Expand Down
Expand Up @@ -17,6 +17,8 @@

package org.apache.shardingsphere.sql.parser;

import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.shardingsphere.sql.parser.api.SQLVisitor;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementBaseVisitor;
Expand Down Expand Up @@ -62,8 +64,10 @@
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.UnreservedWord_Context;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.UseContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WeightStringFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WhereClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WindowFunctionContext;
import org.apache.shardingsphere.sql.parser.core.constant.AggregationType;
import org.apache.shardingsphere.sql.parser.core.constant.LogicalOperator;
import org.apache.shardingsphere.sql.parser.sql.ASTNode;
import org.apache.shardingsphere.sql.parser.sql.segment.dal.FromSchemaSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dal.ShowLikeSegment;
Expand All @@ -80,6 +84,11 @@
import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.AggregationDistinctProjectionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.AggregationProjectionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ExpressionProjectionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.OrPredicateSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.PredicateSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateRightValue;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.SchemaSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dal.dialect.mysql.ShowTableStatusStatement;
Expand Down Expand Up @@ -147,7 +156,8 @@ public ASTNode visitShowLike(final ShowLikeContext ctx) {

@Override
public ASTNode visitInsert(final InsertContext ctx) {
// TODO :Since there is no segment for insertValuesClause, InsertStatement is created by sub rule.
// TODO :FIXME, no parsing for on duplicate phrase
// TODO :FIXME, since there is no segment for insertValuesClause, InsertStatement is created by sub rule.
InsertStatement result = null != ctx.insertValuesClause() ? (InsertStatement) visit(ctx.insertValuesClause()) : (InsertStatement) visit(ctx.setAssignmentsClause());
TableSegment table = (TableSegment) visit(ctx.tableName());
result.setTable(table);
Expand Down Expand Up @@ -203,8 +213,7 @@ public ASTNode visitAssignment(final AssignmentContext ctx) {
public ASTNode visitAssignmentValue(final AssignmentValueContext ctx) {
ExprContext expr = ctx.expr();
if (null != expr) {
ASTNode value = visit(expr);
return createExpressionSegment(value, expr);
return visit(expr);
}
return new CommonExpressionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), ctx.getText());
}
Expand All @@ -214,6 +223,15 @@ public ASTNode visitBlobValue(final BlobValueContext ctx) {
return new LiteralValue(ctx.STRING_().getText());
}

@Override
public ASTNode visitWhereClause(final WhereClauseContext ctx) {
WhereSegment result = new WhereSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex());
result.setParameterMarkerStartIndex(currentParameterIndex);
result.getAndPredicates().addAll(((OrPredicateSegment) visit(ctx.expr())).getAndPredicates());
result.setParametersCount(currentParameterIndex);
return result;
}

// TCLStatement.g4
@Override
public ASTNode visitSetAutoCommit(final SetAutoCommitContext ctx) {
Expand Down Expand Up @@ -268,17 +286,25 @@ public ASTNode visitColumnName(final ColumnNameContext ctx) {
@Override
public ASTNode visitExpr(final ExprContext ctx) {
BooleanPrimaryContext bool = ctx.booleanPrimary();
ASTNode result;
if (null != bool) {
return visit(bool);
result = visit(bool);
} else if (null != ctx.logicalOperator()) {
result = mergePredicateSegment(visit(ctx.expr(0)), visit(ctx.expr(1)), ctx.logicalOperator().getText());
} else {
result = new LiteralValue(ctx.getText());
}
return new LiteralValue(ctx.getText());
return createExpressionSegment(result, ctx);
}

@Override
public ASTNode visitBooleanPrimary(final BooleanPrimaryContext ctx) {
if (null != ctx.subquery()) {
return new SubquerySegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), ctx.subquery().getText());
}
if (null != ctx.comparisonOperator()) {
return new PredicateSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (ColumnSegment) visit(ctx.booleanPrimary()), (PredicateRightValue) ctx.predicate());
}
if (null != ctx.predicate()) {
return visit(ctx.predicate());
}
Expand Down Expand Up @@ -323,6 +349,9 @@ public ASTNode visitSimpleExpr(final SimpleExprContext ctx) {
if (null != ctx.functionCall()) {
return visit(ctx.functionCall());
}
if (null != ctx.columnName()) {
return visit(ctx.columnName());
}
return new CommonExpressionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), ctx.getText());
}

Expand Down Expand Up @@ -549,6 +578,51 @@ private String getDistinctExpression(final AggregationFunctionContext ctx) {
return result.toString();
}

private OrPredicateSegment mergePredicateSegment(final ASTNode left, final ASTNode right, final String operator) {
Optional<LogicalOperator> logicalOperator = LogicalOperator.valueFrom(operator);
Preconditions.checkState(logicalOperator.isPresent());
if (LogicalOperator.OR == logicalOperator.get()) {
return mergeOrPredicateSegment(left, right);
}
return mergeAndPredicateSegment(left, right);
}

private OrPredicateSegment mergeOrPredicateSegment(final ASTNode left, final ASTNode right) {
OrPredicateSegment result = new OrPredicateSegment();
result.getAndPredicates().addAll(getAndPredicates(left));
result.getAndPredicates().addAll(getAndPredicates(right));
return result;
}

private OrPredicateSegment mergeAndPredicateSegment(final ASTNode left, final ASTNode right) {
OrPredicateSegment result = new OrPredicateSegment();
for (AndPredicate eachLeft : getAndPredicates(left)) {
for (AndPredicate eachRight : getAndPredicates(right)) {
result.getAndPredicates().add(createAndPredicate(eachLeft, eachRight));
}
}
return result;
}

private AndPredicate createAndPredicate(final AndPredicate left, final AndPredicate right) {
AndPredicate result = new AndPredicate();
result.getPredicates().addAll(left.getPredicates());
result.getPredicates().addAll(right.getPredicates());
return result;
}

private Collection<AndPredicate> getAndPredicates(final ASTNode astNode) {
if (astNode instanceof OrPredicateSegment) {
return ((OrPredicateSegment) astNode).getAndPredicates();
}
if (astNode instanceof AndPredicate) {
return Collections.singleton((AndPredicate) astNode);
}
AndPredicate andPredicate = new AndPredicate();
andPredicate.getPredicates().add((PredicateSegment) astNode);
return Collections.singleton(andPredicate);
}

// TODO :FIXME, sql case id: insert_with_str_to_date
private void calculateParameterCount(final Collection<ExprContext> exprContexts) {
for (ExprContext each : exprContexts) {
Expand Down
Expand Up @@ -65,7 +65,9 @@ public void assertCreatePaginationContextWhenLimitSegmentTopSegmentAbsentAndWher
SelectStatement selectStatement = mock(SelectStatement.class);
when(selectStatement.findSQLSegment(LimitSegment.class)).thenReturn(Optional.<LimitSegment>absent());
when(selectStatement.findSQLSegment(TopProjectionSegment.class)).thenReturn(Optional.<TopProjectionSegment>absent());
when(selectStatement.findSQLSegment(WhereSegment.class)).thenReturn(Optional.of(new WhereSegment(0, 10, 10)));
WhereSegment whereSegment = new WhereSegment(0, 10);
whereSegment.setParametersCount(10);
when(selectStatement.findSQLSegment(WhereSegment.class)).thenReturn(Optional.of(whereSegment));
ProjectionsContext projectionsContext = mock(ProjectionsContext.class);
when(projectionsContext.findAlias(anyString())).thenReturn(Optional.<String>absent());
PaginationContext paginationContext = new PaginationContextEngine().createPaginationContext(selectStatement, projectionsContext, Collections.emptyList());
Expand Down

0 comments on commit af6fbbd

Please sign in to comment.