Skip to content

Commit

Permalink
for #2084, merge SetAssignmentsSegment.columns & values to SetAssignm…
Browse files Browse the repository at this point in the history
…entsSegment.assignments
  • Loading branch information
terrymanu committed Apr 8, 2019
1 parent d54a6a5 commit 5935cf6
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 22 deletions.
Expand Up @@ -24,8 +24,6 @@
import org.apache.shardingsphere.core.parse.antlr.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.SetAssignmentsSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.CommonExpressionSegment;

import java.util.Collection;
import java.util.HashMap;
Expand All @@ -47,17 +45,15 @@ public Optional<SetAssignmentsSegment> extract(final ParserRuleContext ancestorN
if (!setAssignmentsClauseNode.isPresent()) {
return Optional.absent();
}
Collection<ColumnSegment> columnSegments = new LinkedList<>();
Collection<CommonExpressionSegment> valueSegments = new LinkedList<>();
Collection<AssignmentSegment> assignmentSegments = new LinkedList<>();
assignmentExtractor = new AssignmentExtractor(getPlaceholderIndexes(ancestorNode));
for (ParserRuleContext each : ExtractorUtils.getAllDescendantNodes(ancestorNode, RuleName.ASSIGNMENT)) {
Optional<AssignmentSegment> assignmentSegment = assignmentExtractor.extract(each);
if (assignmentSegment.isPresent()) {
columnSegments.add(assignmentSegment.get().getColumn());
valueSegments.add(assignmentSegment.get().getValue());
assignmentSegments.add(assignmentSegment.get());
}
}
return Optional.of(new SetAssignmentsSegment(setAssignmentsClauseNode.get().getStart().getStartIndex(), columnSegments, valueSegments));
return Optional.of(new SetAssignmentsSegment(setAssignmentsClauseNode.get().getStart().getStartIndex(), assignmentSegments));
}

private Map<ParserRuleContext, Integer> getPlaceholderIndexes(final ParserRuleContext rootNode) {
Expand Down
Expand Up @@ -21,9 +21,9 @@
import org.apache.shardingsphere.core.metadata.table.ShardingTableMetaData;
import org.apache.shardingsphere.core.parse.antlr.constant.QuoteCharacter;
import org.apache.shardingsphere.core.parse.antlr.filler.SQLSegmentFiller;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.SetAssignmentsSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.CommonExpressionSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.statement.SQLStatement;
import org.apache.shardingsphere.core.parse.antlr.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.parse.lexer.token.DefaultKeyword;
Expand All @@ -49,8 +49,8 @@ public final class EncryptSetAssignmentsFiller implements SQLSegmentFiller<SetAs
public void fill(final SetAssignmentsSegment sqlSegment, final SQLStatement sqlStatement, final EncryptRule encryptRule, final ShardingTableMetaData shardingTableMetaData) {
InsertStatement insertStatement = (InsertStatement) sqlStatement;
String tableName = insertStatement.getTables().getSingleTableName();
for (ColumnSegment each : sqlSegment.getColumns()) {
fillColumn(each, insertStatement, tableName);
for (AssignmentSegment each : sqlSegment.getAssignments()) {
fillColumn(each.getColumn(), insertStatement, tableName);
}
InsertValue insertValue = getInsertValue(sqlSegment, sqlStatement.getLogicSQL());
insertStatement.getInsertValues().getValues().add(insertValue);
Expand All @@ -68,8 +68,8 @@ private void fillColumn(final ColumnSegment sqlSegment, final InsertStatement in
private InsertValue getInsertValue(final SetAssignmentsSegment sqlSegment, final String sql) {
int parametersCount = 0;
List<SQLExpression> columnValues = new LinkedList<>();
for (CommonExpressionSegment each : sqlSegment.getValues()) {
Optional<SQLExpression> sqlExpression = each.convertToSQLExpression(sql);
for (AssignmentSegment each : sqlSegment.getAssignments()) {
Optional<SQLExpression> sqlExpression = each.getValue().convertToSQLExpression(sql);
if (sqlExpression.isPresent()) {
columnValues.add(sqlExpression.get());
if (sqlExpression.get() instanceof SQLPlaceholderExpression) {
Expand Down
Expand Up @@ -22,6 +22,7 @@
import org.apache.shardingsphere.core.metadata.table.ShardingTableMetaData;
import org.apache.shardingsphere.core.parse.antlr.constant.QuoteCharacter;
import org.apache.shardingsphere.core.parse.antlr.filler.SQLSegmentFiller;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.assignment.SetAssignmentsSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.CommonExpressionSegment;
Expand Down Expand Up @@ -55,19 +56,19 @@ public final class SetAssignmentsFiller implements SQLSegmentFiller<SetAssignmen
public void fill(final SetAssignmentsSegment sqlSegment, final SQLStatement sqlStatement, final ShardingRule shardingRule, final ShardingTableMetaData shardingTableMetaData) {
InsertStatement insertStatement = (InsertStatement) sqlStatement;
String tableName = insertStatement.getTables().getSingleTableName();
for (ColumnSegment each : sqlSegment.getColumns()) {
fillColumn(each, insertStatement, tableName);
for (AssignmentSegment each : sqlSegment.getAssignments()) {
fillColumn(each.getColumn(), insertStatement, tableName);
}
int columnCount = getColumnCountExcludeAssistedQueryColumns(insertStatement, shardingRule, shardingTableMetaData);
if (sqlSegment.getValues().size() != columnCount) {
if (sqlSegment.getAssignments().size() != columnCount) {
throw new SQLParsingException("INSERT INTO column size mismatch value size.");
}
AndCondition andCondition = new AndCondition();
Iterator<Column> columns = insertStatement.getColumns().iterator();
int parametersCount = 0;
List<SQLExpression> columnValues = new LinkedList<>();
for (CommonExpressionSegment each : sqlSegment.getValues()) {
SQLExpression columnValue = getColumnValue(insertStatement, shardingRule, andCondition, columns.next(), each);
for (AssignmentSegment each : sqlSegment.getAssignments()) {
SQLExpression columnValue = getColumnValue(insertStatement, shardingRule, andCondition, columns.next(), each.getValue());
columnValues.add(columnValue);
if (columnValue instanceof SQLPlaceholderExpression) {
parametersCount++;
Expand Down
Expand Up @@ -20,8 +20,6 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.SQLSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.core.parse.antlr.sql.segment.dml.expr.CommonExpressionSegment;

import java.util.Collection;

Expand All @@ -36,7 +34,5 @@ public final class SetAssignmentsSegment implements SQLSegment {

private final int startIndex;

private final Collection<ColumnSegment> columns;

private final Collection<CommonExpressionSegment> values;
private final Collection<AssignmentSegment> assignments;
}

0 comments on commit 5935cf6

Please sign in to comment.