Skip to content

Commit

Permalink
Add Oracle SQL - Update statement (#11692)
Browse files Browse the repository at this point in the history
* add oracle update definition

* add multi-column in an assignment

* make AssignmentSegment into abstract, define new ColumnAssignmentSegment and refactor AssignmentSegment related code.

* utilize ColumnAssignmentSegment to create an instance type of AssignmentSegment

* set ColumnAssignmentSegment's columns field to final and uninitialize it

* remove commented code
  • Loading branch information
ThanoshanMV committed Aug 22, 2021
1 parent 4005fc6 commit 3c51144
Show file tree
Hide file tree
Showing 31 changed files with 545 additions and 134 deletions.
Expand Up @@ -60,7 +60,7 @@ protected boolean isNeedRewriteForEncrypt(final SQLStatementContext sqlStatement
public void rewrite(final ParameterBuilder parameterBuilder, final SQLStatementContext sqlStatementContext, final List<Object> parameters) {
String tableName = ((TableAvailable) sqlStatementContext).getAllTables().iterator().next().getTableName().getIdentifier().getValue();
for (AssignmentSegment each : getSetAssignmentSegment(sqlStatementContext.getSqlStatement()).getAssignments()) {
if (each.getValue() instanceof ParameterMarkerExpressionSegment && getEncryptRule().findEncryptor(tableName, each.getColumn().getIdentifier().getValue()).isPresent()) {
if (each.getValue() instanceof ParameterMarkerExpressionSegment && getEncryptRule().findEncryptor(tableName, each.getColumns().get(0).getIdentifier().getValue()).isPresent()) {
StandardParameterBuilder standardParameterBuilder = parameterBuilder instanceof StandardParameterBuilder
? (StandardParameterBuilder) parameterBuilder : ((GroupedParameterBuilder) parameterBuilder).getParameterBuilders().get(0);
encryptParameters(standardParameterBuilder, tableName, each, parameters);
Expand All @@ -78,7 +78,7 @@ private SetAssignmentSegment getSetAssignmentSegment(final SQLStatement sqlState
}

private void encryptParameters(final StandardParameterBuilder parameterBuilder, final String tableName, final AssignmentSegment assignmentSegment, final List<Object> parameters) {
String columnName = assignmentSegment.getColumn().getIdentifier().getValue();
String columnName = assignmentSegment.getColumns().get(0).getIdentifier().getValue();
int parameterMarkerIndex = ((ParameterMarkerExpressionSegment) assignmentSegment.getValue()).getParameterMarkerIndex();
Object originalValue = parameters.get(parameterMarkerIndex);
Object cipherValue = getEncryptRule().getEncryptValues(tableName, columnName, Collections.singletonList(originalValue)).iterator().next();
Expand Down
Expand Up @@ -57,7 +57,7 @@ public Collection<EncryptAssignmentToken> generateSQLTokens(final SQLStatementCo
Collection<EncryptAssignmentToken> result = new LinkedList<>();
String tableName = ((TableAvailable) sqlStatementContext).getAllTables().iterator().next().getTableName().getIdentifier().getValue();
for (AssignmentSegment each : getSetAssignmentSegment(sqlStatementContext.getSqlStatement()).getAssignments()) {
if (getEncryptRule().findEncryptor(tableName, each.getColumn().getIdentifier().getValue()).isPresent()) {
if (getEncryptRule().findEncryptor(tableName, each.getColumns().get(0).getIdentifier().getValue()).isPresent()) {
generateSQLToken(tableName, each).ifPresent(result::add);
}
}
Expand All @@ -84,8 +84,8 @@ private Optional<EncryptAssignmentToken> generateSQLToken(final String tableName
}

private EncryptAssignmentToken generateParameterSQLToken(final String tableName, final AssignmentSegment assignmentSegment) {
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex());
String columnName = assignmentSegment.getColumn().getIdentifier().getValue();
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex());
String columnName = assignmentSegment.getColumns().get(0).getIdentifier().getValue();
addCipherColumn(tableName, columnName, result);
addAssistedQueryColumn(tableName, columnName, result);
addPlainColumn(tableName, columnName, result);
Expand All @@ -105,7 +105,7 @@ private void addPlainColumn(final String tableName, final String columnName, fin
}

private EncryptAssignmentToken generateLiteralSQLToken(final String tableName, final AssignmentSegment assignmentSegment) {
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex());
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex());
addCipherAssignment(tableName, assignmentSegment, result);
addAssistedQueryAssignment(tableName, assignmentSegment, result);
addPlainAssignment(tableName, assignmentSegment, result);
Expand All @@ -114,22 +114,22 @@ private EncryptAssignmentToken generateLiteralSQLToken(final String tableName, f

private void addCipherAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
Object cipherValue = getEncryptRule().getEncryptValues(tableName, assignmentSegment.getColumn().getIdentifier().getValue(), Collections.singletonList(originalValue)).iterator().next();
token.addAssignment(getEncryptRule().getCipherColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()), cipherValue);
Object cipherValue = getEncryptRule().getEncryptValues(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue(), Collections.singletonList(originalValue)).iterator().next();
token.addAssignment(getEncryptRule().getCipherColumn(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue()), cipherValue);
}

private void addAssistedQueryAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
Optional<String> assistedQueryColumn = getEncryptRule().findAssistedQueryColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue());
Optional<String> assistedQueryColumn = getEncryptRule().findAssistedQueryColumn(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue());
assistedQueryColumn.ifPresent(s -> {
Object assistedQueryValue = getEncryptRule().getEncryptAssistedQueryValues(
tableName, assignmentSegment.getColumn().getIdentifier().getValue(), Collections.singletonList(originalValue)).iterator().next();
tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue(), Collections.singletonList(originalValue)).iterator().next();
token.addAssignment(s, assistedQueryValue);
});
}

private void addPlainAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
getEncryptRule().findPlainColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()).ifPresent(plainColumn -> token.addAssignment(plainColumn, originalValue));
getEncryptRule().findPlainColumn(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue()).ifPresent(plainColumn -> token.addAssignment(plainColumn, originalValue));
}
}
Expand Up @@ -60,7 +60,7 @@ public Collection<EncryptAssignmentToken> generateSQLTokens(final InsertStatemen
return result;
}
for (AssignmentSegment each : onDuplicateKeyColumnsSegments) {
if (getEncryptRule().findEncryptor(tableName, each.getColumn().getIdentifier().getValue()).isPresent()) {
if (getEncryptRule().findEncryptor(tableName, each.getColumns().get(0).getIdentifier().getValue()).isPresent()) {
generateSQLToken(tableName, each).ifPresent(result::add);
}
}
Expand All @@ -78,16 +78,16 @@ private Optional<EncryptAssignmentToken> generateSQLToken(final String tableName
}

private EncryptAssignmentToken generateParameterSQLToken(final String tableName, final AssignmentSegment assignmentSegment) {
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex());
String columnName = assignmentSegment.getColumn().getIdentifier().getValue();
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex());
String columnName = assignmentSegment.getColumns().get(0).getIdentifier().getValue();
addCipherColumn(tableName, columnName, result);
addAssistedQueryColumn(tableName, columnName, result);
addPlainColumn(tableName, columnName, result);
return result;
}

private EncryptAssignmentToken generateLiteralSQLToken(final String tableName, final AssignmentSegment assignmentSegment) {
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex());
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex());
addCipherAssignment(tableName, assignmentSegment, result);
addAssistedQueryAssignment(tableName, assignmentSegment, result);
addPlainAssignment(tableName, assignmentSegment, result);
Expand All @@ -108,21 +108,22 @@ private void addPlainColumn(final String tableName, final String columnName, fin

private void addCipherAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
Object cipherValue = getEncryptRule().getEncryptValues(tableName, assignmentSegment.getColumn().getIdentifier().getValue(), Collections.singletonList(originalValue)).iterator().next();
token.addAssignment(getEncryptRule().getCipherColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()), cipherValue);
Object cipherValue = getEncryptRule().getEncryptValues(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue(), Collections.singletonList(originalValue)).iterator().next();
token.addAssignment(getEncryptRule().getCipherColumn(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue()), cipherValue);
}

private void addAssistedQueryAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
getEncryptRule().findAssistedQueryColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()).ifPresent(assistedQueryColumn -> {
getEncryptRule().findAssistedQueryColumn(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue()).ifPresent(assistedQueryColumn -> {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
Object assistedQueryValue = getEncryptRule().getEncryptAssistedQueryValues(tableName, assignmentSegment.getColumn().getIdentifier().getValue(), Collections.singletonList(originalValue))
Object assistedQueryValue = getEncryptRule()
.getEncryptAssistedQueryValues(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue(), Collections.singletonList(originalValue))
.iterator().next();
token.addAssignment(assistedQueryColumn, assistedQueryValue);
});
}

private void addPlainAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
getEncryptRule().findPlainColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()).ifPresent(plainColumn -> token.addAssignment(plainColumn, originalValue));
getEncryptRule().findPlainColumn(tableName, assignmentSegment.getColumns().get(0).getIdentifier().getValue()).ifPresent(plainColumn -> token.addAssignment(plainColumn, originalValue));
}
}
Expand Up @@ -38,7 +38,7 @@ protected boolean isNeedRewriteForShadow(final SQLStatementContext sqlStatementC
}

private boolean isContainShadowColumn(final UpdateStatement updateStatement) {
return updateStatement.getSetAssignment().getAssignments().stream().anyMatch(each -> each.getColumn().getIdentifier().getValue().equals(getShadowColumn()));
return updateStatement.getSetAssignment().getAssignments().stream().anyMatch(each -> each.getColumns().get(0).getIdentifier().getValue().equals(getShadowColumn()));
}

@Override
Expand All @@ -55,7 +55,7 @@ private void doShadowRewrite(final ParameterBuilder parameterBuilder, final Upda
private int getShadowColumnIndex(final UpdateStatement sqlStatement) {
int count = 0;
for (AssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
if (each.getColumn().getIdentifier().getValue().equals(getShadowColumn())) {
if (each.getColumns().get(0).getIdentifier().getValue().equals(getShadowColumn())) {
return count;
}
count++;
Expand Down
Expand Up @@ -42,7 +42,7 @@ protected boolean isGenerateSQLTokenForShadow(final SQLStatementContext sqlState
}

private boolean isContainShadowColumn(final Collection<AssignmentSegment> assignments) {
return assignments.stream().anyMatch(each -> each.getColumn().getIdentifier().getValue().equals(getShadowColumn()));
return assignments.stream().anyMatch(each -> each.getColumns().get(0).getIdentifier().getValue().equals(getShadowColumn()));
}

@Override
Expand All @@ -52,7 +52,7 @@ public Collection<? extends SQLToken> generateSQLTokens(final UpdateStatementCon

private Collection<RemoveToken> generateRemoveTokenForShadow(final Collection<AssignmentSegment> assignments) {
List<AssignmentSegment> assignmentSegments = (LinkedList<AssignmentSegment>) assignments;
return IntStream.range(0, assignmentSegments.size()).filter(i -> getShadowColumn().equals(assignmentSegments.get(i).getColumn().getIdentifier().getValue()))
return IntStream.range(0, assignmentSegments.size()).filter(i -> getShadowColumn().equals(assignmentSegments.get(i).getColumns().get(0).getIdentifier().getValue()))
.mapToObj(i -> createRemoveToken(assignmentSegments, i)).collect(Collectors.toCollection(LinkedList::new));
}

Expand Down
Expand Up @@ -21,6 +21,7 @@
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.apache.shardingsphere.shadow.rule.ShadowRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
Expand All @@ -31,6 +32,8 @@
import org.junit.Test;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
Expand All @@ -56,7 +59,10 @@ private void mockUpdateStatementContext(final String shadowColumn) {
}

private SetAssignmentSegment createSetAssignmentSegment(final String shadowColumn) {
return new SetAssignmentSegment(0, 20, Collections.singletonList(new AssignmentSegment(0, 15, new ColumnSegment(0, 15, new IdentifierValue(shadowColumn)), mock(ExpressionSegment.class))));
List<ColumnSegment> columns = new LinkedList<>();
columns.add(new ColumnSegment(0, 15, new IdentifierValue(shadowColumn)));
AssignmentSegment assignment = new ColumnAssignmentSegment(0, 15, columns, mock(ExpressionSegment.class));
return new SetAssignmentSegment(0, 20, Collections.singletonList(assignment));
}

private void initShadowUpdateValueParameterRewriter(final String shadowColumn) {
Expand Down
Expand Up @@ -20,6 +20,7 @@
import org.apache.shardingsphere.infra.binder.statement.dml.UpdateStatementContext;
import org.apache.shardingsphere.shadow.rule.ShadowRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
Expand All @@ -31,6 +32,7 @@

import java.util.Collection;
import java.util.LinkedList;
import java.util.List;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
Expand Down Expand Up @@ -65,7 +67,10 @@ private SetAssignmentSegment createSetAssignmentSegment(final String shadowColum
}

private AssignmentSegment createAssignmentSegment(final int startIndex, final int stopIndex, final IdentifierValue identifierValue) {
return new AssignmentSegment(startIndex, stopIndex, new ColumnSegment(startIndex, stopIndex, identifierValue), mock(ExpressionSegment.class));
List<ColumnSegment> columns = new LinkedList<>();
columns.add(new ColumnSegment(startIndex, stopIndex, identifierValue));
AssignmentSegment result = new ColumnAssignmentSegment(startIndex, stopIndex, columns, mock(ExpressionSegment.class));
return result;
}

private void initShadowUpdateColumnTokenGenerator(final String shadowColumn) {
Expand Down
Expand Up @@ -72,7 +72,7 @@ public void preValidate(final ShardingRule shardingRule, final SQLStatementConte

private boolean isUpdateShardingKey(final ShardingRule shardingRule, final OnDuplicateKeyColumnsSegment onDuplicateKeyColumnsSegment, final String tableName) {
for (AssignmentSegment each : onDuplicateKeyColumnsSegment.getColumns()) {
if (shardingRule.isShardingColumn(each.getColumn().getIdentifier().getValue(), tableName)) {
if (shardingRule.isShardingColumn(each.getColumns().get(0).getIdentifier().getValue(), tableName)) {
return true;
}
}
Expand Down
Expand Up @@ -50,7 +50,7 @@ public void preValidate(final ShardingRule shardingRule, final SQLStatementConte
UpdateStatement sqlStatement = sqlStatementContext.getSqlStatement();
String tableName = sqlStatementContext.getTablesContext().getTables().iterator().next().getTableName().getIdentifier().getValue();
for (AssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
String shardingColumn = each.getColumn().getIdentifier().getValue();
String shardingColumn = each.getColumns().get(0).getIdentifier().getValue();
if (shardingRule.isShardingColumn(shardingColumn, tableName)) {
Optional<Object> shardingColumnSetAssignmentValue = getShardingColumnSetAssignmentValue(each, parameters);
Optional<Object> shardingValue = Optional.empty();
Expand Down

0 comments on commit 3c51144

Please sign in to comment.