Skip to content

Commit

Permalink
fixed #57 parse sub query for generating select items
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoht committed May 4, 2016
1 parent 2d1ad48 commit d99fe97
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 26 deletions.
Expand Up @@ -17,12 +17,6 @@

package com.dangdang.ddframe.rdb.sharding.parser.visitor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
Expand All @@ -48,6 +42,13 @@
import lombok.Getter;
import lombok.Setter;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;

/**
* 解析过程的上下文对象.
*
Expand All @@ -56,8 +57,12 @@
@Getter
public final class ParseContext {

private static final String AUTO_GEN_TOKE_KEY_TEMPLATE = "sharding_auto_gen_%d";

private static final String SHARDING_GEN_ALIAS = "sharding_gen_%s";

private final String autoGenTokenKey;

private final SQLParsedResult parsedResult = new SQLParsedResult();

@Setter
Expand All @@ -76,6 +81,24 @@ public final class ParseContext {

private boolean hasAllColumn;

@Setter
private ParseContext parentParseContext;

private List<ParseContext> subParseContext = new LinkedList<>();

private int itemIndex;

public ParseContext(final int parseContextIndex) {
this.autoGenTokenKey = String.format(AUTO_GEN_TOKE_KEY_TEMPLATE, parseContextIndex);
}

/**
* 增加查询投射项数量.
*/
public void increaseItemIndex() {
itemIndex++;
}

/**
* 设置当前正在访问的表.
*
Expand Down Expand Up @@ -329,5 +352,4 @@ public void registerSelectItem(final String selectItem) {
}
selectItems.add(rawItemExpr);
}

}
Expand Up @@ -17,9 +17,6 @@

package com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql;

import java.util.Arrays;
import java.util.Collections;

import com.alibaba.druid.sql.ast.SQLHint;
import com.alibaba.druid.sql.ast.expr.SQLBetweenExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
Expand All @@ -38,18 +35,24 @@
import com.dangdang.ddframe.rdb.sharding.parser.visitor.SQLVisitor;
import com.dangdang.ddframe.rdb.sharding.util.SQLUtil;

import java.util.Arrays;
import java.util.Collections;

/**
* MySQL解析基础访问器.
*
* @author zhangliang
*/
public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements SQLVisitor {

private final ParseContext parseContext = new ParseContext();
private ParseContext parseContext;

private int parseContextIndex;

protected AbstractMySQLVisitor() {
super(new SQLBuilder());
setPrettyFormat(false);
parseContext = new ParseContext(parseContextIndex);
}

@Override
Expand All @@ -62,6 +65,25 @@ public final ParseContext getParseContext() {
return parseContext;
}

protected final void stepInQuery() {
if (0 == parseContextIndex) {
parseContextIndex++;
return;
}
ParseContext parseContext = new ParseContext(parseContextIndex++);
parseContext.setShardingColumns(this.parseContext.getShardingColumns());
parseContext.setParentParseContext(this.parseContext);
this.parseContext.getSubParseContext().add(parseContext);
this.parseContext = parseContext;
}

protected final void stepOutQuery() {
if (null == parseContext.getParentParseContext()) {
return;
}
parseContext = parseContext.getParentParseContext();
}

@Override
public final SQLBuilder getSQLBuilder() {
return (SQLBuilder) appender;
Expand All @@ -86,7 +108,7 @@ public final boolean visit(final SQLVariantRefExpr x) {

@Override
public final boolean visit(final SQLExprTableSource x) {
return visit(x, parseContext.addTable(x));
return visit(x, getParseContext().addTable(x));
}

private boolean visit(final SQLExprTableSource x, final Table table) {
Expand Down Expand Up @@ -128,7 +150,7 @@ public final boolean visit(final SQLPropertyExpr x) {
return super.visit(x);
}
String tableOrAliasName = ((SQLIdentifierExpr) x.getOwner()).getLowerName();
if (parseContext.isBinaryOperateWithAlias(x, tableOrAliasName)) {
if (getParseContext().isBinaryOperateWithAlias(x, tableOrAliasName)) {
return super.visit(x);
}
printToken(tableOrAliasName);
Expand Down
Expand Up @@ -17,8 +17,6 @@

package com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql;

import java.util.List;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLOrderBy;
import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr;
Expand All @@ -31,7 +29,6 @@
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlSelectGroupByExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor;
Expand All @@ -44,27 +41,25 @@
import com.google.common.base.Optional;
import com.google.common.base.Strings;

import java.util.List;

/**
* MySQL的SELECT语句访问器.
*
* @author gaohongtao, zhangliang
*/
public class MySQLSelectVisitor extends AbstractMySQLVisitor {

private static final String AUTO_GEN_TOKE_KEY = "sharding_auto_gen";

// TODO 封装到方法内部
private int itemIndex;

@Override
protected void printSelectList(final List<SQLSelectItem> selectList) {
super.printSelectList(selectList);
// TODO 提炼成print,或者是否不应该由token的方式替换?
getSQLBuilder().appendToken(AUTO_GEN_TOKE_KEY, false);
getSQLBuilder().appendToken(getParseContext().getAutoGenTokenKey(), false);
}

@Override
public boolean visit(final MySqlSelectQueryBlock x) {
stepInQuery();
if (x.getFrom() instanceof SQLExprTableSource) {
SQLExprTableSource tableExpr = (SQLExprTableSource) x.getFrom();
getParseContext().setCurrentTable(tableExpr.getExpr().toString(), Optional.fromNullable(tableExpr.getAlias()));
Expand All @@ -80,7 +75,7 @@ public boolean visit(final MySqlSelectQueryBlock x) {
*/
// TODO SELECT * 导致index不准,不支持SELECT *,且生产环境不建议使用SELECT *
public boolean visit(final SQLSelectItem x) {
itemIndex++;
getParseContext().increaseItemIndex();
if (Strings.isNullOrEmpty(x.getAlias())) {
SQLExpr expr = x.getExpr();
if (expr instanceof SQLIdentifierExpr) {
Expand Down Expand Up @@ -111,7 +106,7 @@ public boolean visit(final SQLAggregateExpr x) {
x.accept(new MySqlOutputVisitor(expression));
// TODO index获取不准,考虑使用别名替换
AggregationColumn column = new AggregationColumn(expression.toString(), aggregationType, Optional.fromNullable(((SQLSelectItem) x.getParent()).getAlias()),
null == x.getOption() ? Optional.<String>absent() : Optional.of(x.getOption().toString()), itemIndex);
null == x.getOption() ? Optional.<String>absent() : Optional.of(x.getOption().toString()), getParseContext().getItemIndex());
getParseContext().getParsedResult().getMergeContext().getAggregationColumns().add(column);
if (AggregationType.AVG.equals(aggregationType)) {
getParseContext().addDerivedColumnsForAvgColumn(column);
Expand Down Expand Up @@ -190,7 +185,7 @@ public boolean visit(final MySqlSelectQueryBlock.Limit x) {
}

@Override
public void endVisit(final SQLSelectStatement x) {
public void endVisit(final MySqlSelectQueryBlock x) {
StringBuilder derivedSelectItems = new StringBuilder();
for (AggregationColumn aggregationColumn : getParseContext().getParsedResult().getMergeContext().getAggregationColumns()) {
for (AggregationColumn derivedColumn : aggregationColumn.getDerivedColumns()) {
Expand All @@ -206,8 +201,9 @@ public void endVisit(final SQLSelectStatement x) {
}
}
if (0 != derivedSelectItems.length()) {
getSQLBuilder().buildSQL(AUTO_GEN_TOKE_KEY, derivedSelectItems.toString());
getSQLBuilder().buildSQL(getParseContext().getAutoGenTokenKey(), derivedSelectItems.toString());
}
super.endVisit(x);
stepOutQuery();
}
}
@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="UTF-8"?>
<asserts>
<assert id="assertSelectWithOrderByForAliasAndSubQuery" sql="SELECT price FROM (SELECT o.user_id,o.price FROM order o WHERE o.order_id = 1 ORDER BY o.order_id) order by user_id" expected-sql="SELECT price[Token(, user_id AS sharding_gen_1)] FROM (SELECT [Token(o)].user_id, [Token(o)].price[Token(, order_id AS sharding_gen_1)] FROM [Token(order)] o WHERE o.order_id = 1 ORDER BY o.order_id ) ORDER BY user_id">
<tables>
</tables>
<condition-contexts>
<condition-context>
</condition-context>
</condition-contexts>
<order-by-columns>
<order-by-column name="user_id" alias="sharding_gen_1" order-by-type="ASC" />
</order-by-columns>
</assert>
<assert id="assertSelectWithGroupByAndSubQuery" sql="SELECT AVG(i.SUM_PRICE) avg FROM (SELECT o.order_id,SUM(o.price) AS SUM_PRICE FROM order o WHERE o.order_id = 1 GROUP BY o.order_id) i" expected-sql="SELECT AVG(i.SUM_PRICE) AS avg[Token(, COUNT(i.SUM_PRICE) AS sharding_gen_1, SUM(i.SUM_PRICE) AS sharding_gen_2)] FROM (SELECT [Token(o)].order_id, SUM(o.price) AS SUM_PRICE[Token(, o.order_id AS sharding_gen_1)] FROM [Token(order)] o WHERE o.order_id = 1 GROUP BY o.order_id ) i">
<tables>
</tables>
<condition-contexts>
<condition-context>
</condition-context>
</condition-contexts>
<aggregation-columns>
<aggregation-column expression="AVG(i.SUM_PRICE)" aggregation-type="AVG" alias="avg" index="1">
<derived-column expression="COUNT(i.SUM_PRICE)" aggregation-type="COUNT" alias="sharding_gen_1"/>
<derived-column expression="SUM(i.SUM_PRICE)" aggregation-type="SUM" alias="sharding_gen_2" />
</aggregation-column>
<aggregation-column expression="COUNT(i.SUM_PRICE)" aggregation-type="COUNT" alias="sharding_gen_1" />
<aggregation-column expression="SUM(i.SUM_PRICE)" aggregation-type="SUM" alias="sharding_gen_2" />
</aggregation-columns>
</assert>
<assert id="assertSelectWithWhereSubQuery" sql="SELECT * FROM order o WHERE o.order_id = 2 and exists (select 1 from t_user u where u.user_id = o.user_id and u.user_id = 1)" expected-sql="SELECT * FROM [Token(order)] o WHERE o.order_id = 2 AND EXISTS (SELECT 1 FROM [Token(t_user)] u WHERE u.user_id = [Token(o)].user_id AND u.user_id = 1)">
<tables>
<table name="order" alias="o" />
</tables>
<condition-contexts>
<condition-context>
<condition column-name="order_id" table-name="order" operator="EQUAL">
<value value="2" type="java.lang.Integer" />
</condition>
</condition-context>
</condition-contexts>
</assert>
</asserts>

0 comments on commit d99fe97

Please sign in to comment.