Skip to content

Commit

Permalink
Merge pull request #2824 from tristaZero/dev
Browse files Browse the repository at this point in the history
Fix bug for plainColumn
  • Loading branch information
terrymanu committed Aug 8, 2019
2 parents fa185cb + 3eb8c5d commit 3217df7
Show file tree
Hide file tree
Showing 15 changed files with 131 additions and 55 deletions.
Expand Up @@ -266,6 +266,19 @@ public Map<String, String> getLogicAndCipherColumns(final String logicTable) {
return tables.get(logicTable).getLogicAndCipherColumns();
}

/**
* Get logic and plain columns.
*
* @param logicTable logic table
* @return logic and plain columns
*/
public Map<String, String> getLogicAndPlainColumns(final String logicTable) {
if (!tables.containsKey(logicTable)) {
return Collections.emptyMap();
}
return tables.get(logicTable).getLogicAndPlainColumns();
}

/**
* Get encrypt assisted column values.
*
Expand Down
Expand Up @@ -198,4 +198,22 @@ public String apply(final EncryptColumn input) {
}
});
}

/**
* Get logic and plain columns.
*
* @return logic and plain columns
*/
public Map<String, String> getLogicAndPlainColumns() {
return Maps.transformValues(columns, new Function<EncryptColumn, String>() {

@Override
public String apply(final EncryptColumn input) {
if (input.getPlainColumn().isPresent()) {
return input.getPlainColumn().get();
}
throw new ShardingException("Plain column is null.");
}
});
}
}
Expand Up @@ -23,7 +23,7 @@
import org.apache.shardingsphere.core.rewrite.token.generator.InsertSetQueryAndPlainColumnsTokenGenerator;
import org.apache.shardingsphere.core.rewrite.token.generator.InsertValuesTokenGenerator;
import org.apache.shardingsphere.core.rewrite.token.generator.SQLTokenGenerator;
import org.apache.shardingsphere.core.rewrite.token.generator.SelectCipherItemTokenGenerator;
import org.apache.shardingsphere.core.rewrite.token.generator.SelectEncryptItemTokenGenerator;
import org.apache.shardingsphere.core.rewrite.token.generator.UpdateEncryptColumnTokenGenerator;
import org.apache.shardingsphere.core.rewrite.token.generator.WhereEncryptColumnTokenGenerator;
import org.apache.shardingsphere.core.rule.EncryptRule;
Expand All @@ -41,7 +41,7 @@ public final class EncryptTokenGenerateEngine extends SQLTokenGenerateEngine<Enc
private static final Collection<SQLTokenGenerator> SQL_TOKEN_GENERATORS = new LinkedList<>();

static {
SQL_TOKEN_GENERATORS.add(new SelectCipherItemTokenGenerator());
SQL_TOKEN_GENERATORS.add(new SelectEncryptItemTokenGenerator());
SQL_TOKEN_GENERATORS.add(new UpdateEncryptColumnTokenGenerator());
SQL_TOKEN_GENERATORS.add(new WhereEncryptColumnTokenGenerator());
SQL_TOKEN_GENERATORS.add(new InsertCipherNameTokenGenerator());
Expand Down
Expand Up @@ -24,7 +24,7 @@
import org.apache.shardingsphere.core.parse.sql.segment.dml.item.SelectItemsSegment;
import org.apache.shardingsphere.core.parse.sql.statement.dml.SelectStatement;
import org.apache.shardingsphere.core.rewrite.builder.ParameterBuilder;
import org.apache.shardingsphere.core.rewrite.token.pojo.SelectCipherItemToken;
import org.apache.shardingsphere.core.rewrite.token.pojo.SelectEncryptItemToken;
import org.apache.shardingsphere.core.rule.EncryptRule;

import java.util.Collection;
Expand All @@ -36,19 +36,21 @@
*
* @author panjuan
*/
public final class SelectCipherItemTokenGenerator implements CollectionSQLTokenGenerator<EncryptRule> {
public final class SelectEncryptItemTokenGenerator implements CollectionSQLTokenGenerator<EncryptRule> {

private EncryptRule encryptRule;

private OptimizedStatement optimizedStatement;

private boolean isQueryWithCipherColumn;

@Override
public Collection<SelectCipherItemToken> generateSQLTokens(final OptimizedStatement optimizedStatement,
final ParameterBuilder parameterBuilder, final EncryptRule rule, final boolean isQueryWithCipherColumn) {
public Collection<SelectEncryptItemToken> generateSQLTokens(final OptimizedStatement optimizedStatement,
final ParameterBuilder parameterBuilder, final EncryptRule rule, final boolean isQueryWithCipherColumn) {
if (!isNeedToGenerateSQLToken(optimizedStatement)) {
return Collections.emptyList();
}
initParameters(rule, optimizedStatement);
initParameters(rule, optimizedStatement, isQueryWithCipherColumn);
return createSelectCipherItemTokens();
}

Expand All @@ -64,13 +66,14 @@ private boolean isSelectStatementWithTable(final OptimizedStatement optimizedSta
return optimizedStatement.getSQLStatement() instanceof SelectStatement && !optimizedStatement.getTables().isEmpty();
}

private void initParameters(final EncryptRule rule, final OptimizedStatement optimizedStatement) {
private void initParameters(final EncryptRule rule, final OptimizedStatement optimizedStatement, final boolean isQueryWithCipherColumn) {
encryptRule = rule;
this.optimizedStatement = optimizedStatement;
this.isQueryWithCipherColumn = isQueryWithCipherColumn;
}

private Collection<SelectCipherItemToken> createSelectCipherItemTokens() {
Collection<SelectCipherItemToken> result = new LinkedList<>();
private Collection<SelectEncryptItemToken> createSelectCipherItemTokens() {
Collection<SelectEncryptItemToken> result = new LinkedList<>();
SelectItemsSegment selectItemsSegment = optimizedStatement.getSQLStatement().findSQLSegment(SelectItemsSegment.class).get();
String tableName = optimizedStatement.getTables().getSingleTableName();
Collection<String> logicColumns = encryptRule.getLogicColumns(tableName);
Expand All @@ -86,8 +89,12 @@ private boolean isLogicColumn(final SelectItemSegment each, final Collection<Str
return each instanceof ColumnSelectItemSegment && logicColumns.contains(((ColumnSelectItemSegment) each).getName());
}

private SelectCipherItemToken createSelectCipherItemToken(final SelectItemSegment each, final String tableName) {
return new SelectCipherItemToken(each.getStartIndex(),
each.getStopIndex(), encryptRule.getCipherColumn(tableName, ((ColumnSelectItemSegment) each).getName()));
private SelectEncryptItemToken createSelectCipherItemToken(final SelectItemSegment each, final String tableName) {
String columnName = ((ColumnSelectItemSegment) each).getName();
Optional<String> plainColumn = encryptRule.getPlainColumn(tableName, columnName);
if (!isQueryWithCipherColumn && plainColumn.isPresent()) {
return new SelectEncryptItemToken(each.getStartIndex(), each.getStopIndex(), plainColumn.get());
}
return new SelectEncryptItemToken(each.getStartIndex(), each.getStopIndex(), encryptRule.getCipherColumn(tableName, columnName));
}
}
Expand Up @@ -25,13 +25,13 @@
* @author panjuan
*/
@Getter
public final class SelectCipherItemToken extends SQLToken implements Substitutable {
public final class SelectEncryptItemToken extends SQLToken implements Substitutable {

private final int stopIndex;

private final String selectItemName;

public SelectCipherItemToken(final int startIndex, final int stopIndex, final String selectItemName) {
public SelectEncryptItemToken(final int startIndex, final int stopIndex, final String selectItemName) {
super(startIndex);
this.stopIndex = stopIndex;
this.selectItemName = selectItemName;
Expand Down
Expand Up @@ -901,7 +901,7 @@ public void assertRewriteSelectInWithShardingEncryptorWithCipher() {
@Test
public void assertRewriteSelectInWithShardingEncryptorWithPlain() {
SQLRewriteEngine rewriteEngine = createSQLRewriteEngine(createSQLRouteResultForSelectInWithShardingEncryptor(), "SELECT id FROM table_z WHERE id in (3,5)", Collections.emptyList(), false);
assertThat(rewriteEngine.generateSQL(null, logicTableAndActualTables).getSql(), is("SELECT cipher FROM table_z WHERE plain IN ('3', '5')"));
assertThat(rewriteEngine.generateSQL(null, logicTableAndActualTables).getSql(), is("SELECT plain FROM table_z WHERE plain IN ('3', '5')"));
}

private SQLRouteResult createSQLRouteResultForSelectInWithShardingEncryptor() {
Expand Down Expand Up @@ -1030,7 +1030,7 @@ public void assertRewriteSelectEqualWithShardingEncryptorWithCipher() {
public void assertRewriteSelectEqualWithShardingEncryptorWithPlain() {
SQLRewriteEngine rewriteEngine = createSQLRewriteEngine(
createSQLRouteResultForSelectEqualWithShardingEncryptor(), "SELECT id FROM table_z WHERE id=? AND name=?", Arrays.<Object>asList(1, "x"), false);
assertThat(rewriteEngine.generateSQL().getSql(), is("SELECT cipher FROM table_z WHERE plain = ? AND name=?"));
assertThat(rewriteEngine.generateSQL().getSql(), is("SELECT plain FROM table_z WHERE plain = ? AND name=?"));
assertThat(getParameterBuilder(rewriteEngine).getParameters().get(0), is((Object) 1));
}

Expand Down Expand Up @@ -1062,7 +1062,7 @@ public void assertRewriteSelectInWithShardingEncryptorWithParameterWithCipher()
public void assertRewriteSelectInWithShardingEncryptorWithParameterWithPlain() {
SQLRewriteEngine rewriteEngine = createSQLRewriteEngine(
createSQLRouteResultForSelectInWithShardingEncryptorWithParameter(), "SELECT id FROM table_z WHERE id in (?, ?) or id = 3", Arrays.<Object>asList(1, 2), false);
assertThat(rewriteEngine.generateSQL(null, logicTableAndActualTables).getSql(), is("SELECT cipher FROM table_z WHERE plain IN (?, ?) or plain = '3'"));
assertThat(rewriteEngine.generateSQL(null, logicTableAndActualTables).getSql(), is("SELECT plain FROM table_z WHERE plain IN (?, ?) or plain = '3'"));
assertThat(getParameterBuilder(rewriteEngine).getParameters().get(0), is((Object) 1));
assertThat(getParameterBuilder(rewriteEngine).getParameters().get(1), is((Object) 2));
}
Expand Down
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Preconditions;
import lombok.AccessLevel;
import lombok.Getter;
import org.apache.shardingsphere.core.constant.properties.ShardingPropertiesConstant;
import org.apache.shardingsphere.core.route.SQLRouteResult;
import org.apache.shardingsphere.core.rule.ShardingRule;
import org.apache.shardingsphere.shardingjdbc.jdbc.adapter.executor.ForceExecuteCallback;
Expand All @@ -42,6 +43,7 @@
* Adapter for {@code ResultSet}.
*
* @author zhangliang
* @author panjuan
*/
public abstract class AbstractResultSetAdapter extends AbstractUnsupportedOperationResultSet {

Expand Down Expand Up @@ -70,23 +72,41 @@ public AbstractResultSetAdapter(final List<ResultSet> resultSets, final Statemen

@Override
public final ResultSetMetaData getMetaData() throws SQLException {
return new ShardingResultSetMetaData(resultSets.get(0).getMetaData(), getShardingRule(), sqlRouteResult.getOptimizedStatement());
return new ShardingResultSetMetaData(resultSets.get(0).getMetaData(), getShardingRule(), sqlRouteResult.getOptimizedStatement(), logicAndActualColumns);
}

private Map<String, String> createLogicAndActualColumns() {
return isQueryWithCipherColumn() ? createLogicAndCipherColumns() : createLogicAndPlainColumns();
}

private Map<String, String> createLogicAndCipherColumns() {
Map<String, String> result = new LinkedHashMap<>();
for (String each : sqlRouteResult.getOptimizedStatement().getTables().getTableNames()) {
result.putAll(getShardingRule().getEncryptRule().getLogicAndCipherColumns(each));
}
return result;
}

private Map<String, String> createLogicAndPlainColumns() {
Map<String, String> result = new LinkedHashMap<>();
for (String each : sqlRouteResult.getOptimizedStatement().getTables().getTableNames()) {
result.putAll(getShardingRule().getEncryptRule().getLogicAndPlainColumns(each));
}
return result;
}

private ShardingRule getShardingRule() {
return statement instanceof ShardingPreparedStatement
? ((ShardingPreparedStatement) statement).getConnection().getRuntimeContext().getRule()
: ((ShardingStatement) statement).getConnection().getRuntimeContext().getRule();
}

private boolean isQueryWithCipherColumn() {
return statement instanceof ShardingPreparedStatement
? ((ShardingPreparedStatement) statement).getConnection().getRuntimeContext().getProps().<Boolean>getValue(ShardingPropertiesConstant.QUERY_WITH_CIPHER_COLUMN)
: ((ShardingStatement) statement).getConnection().getRuntimeContext().getProps().<Boolean>getValue(ShardingPropertiesConstant.QUERY_WITH_CIPHER_COLUMN);
}

@Override
public final int findColumn(final String columnLabel) throws SQLException {
return resultSets.get(0).findColumn(getActualColumnLabel(columnLabel));
Expand Down
Expand Up @@ -17,11 +17,13 @@

package org.apache.shardingsphere.shardingjdbc.jdbc.core.resultset;

import org.apache.shardingsphere.core.constant.properties.ShardingPropertiesConstant;
import org.apache.shardingsphere.core.execute.sql.execute.result.QueryResult;
import org.apache.shardingsphere.core.execute.sql.execute.result.StreamQueryResult;
import org.apache.shardingsphere.core.merge.dql.iterator.IteratorStreamMergedResult;
import org.apache.shardingsphere.core.optimize.api.statement.OptimizedStatement;
import org.apache.shardingsphere.core.rule.EncryptRule;
import org.apache.shardingsphere.shardingjdbc.jdbc.core.context.EncryptRuntimeContext;
import org.apache.shardingsphere.shardingjdbc.jdbc.unsupported.AbstractUnsupportedOperationResultSet;

import java.io.InputStream;
Expand Down Expand Up @@ -63,24 +65,36 @@ public final class EncryptResultSet extends AbstractUnsupportedOperationResultSe

private final Map<String, String> logicAndActualColumns;

public EncryptResultSet(final EncryptRule encryptRule, final OptimizedStatement optimizedStatement, final Statement encryptStatement, final ResultSet resultSet) {
this.encryptRule = encryptRule;
public EncryptResultSet(final EncryptRuntimeContext encryptRuntimeContext, final OptimizedStatement optimizedStatement, final Statement encryptStatement, final ResultSet resultSet) {
this.encryptRule = encryptRuntimeContext.getRule();
this.optimizedStatement = optimizedStatement;
this.encryptStatement = encryptStatement;
originalResultSet = resultSet;
QueryResult queryResult = new StreamQueryResult(resultSet, encryptRule);
this.resultSet = new IteratorStreamMergedResult(Collections.singletonList(queryResult));
logicAndActualColumns = createLogicAndActualColumns();
logicAndActualColumns = createLogicAndActualColumns(encryptRuntimeContext.getProps().<Boolean>getValue(ShardingPropertiesConstant.QUERY_WITH_CIPHER_COLUMN));
}

private Map<String, String> createLogicAndActualColumns() {
private Map<String, String> createLogicAndActualColumns(final boolean isQueryWithCipherColumn) {
return isQueryWithCipherColumn ? createLogicAndCipherColumns() : createLogicAndPlainColumns();
}

private Map<String, String> createLogicAndCipherColumns() {
Map<String, String> result = new LinkedHashMap<>();
for (String each : optimizedStatement.getTables().getTableNames()) {
result.putAll(encryptRule.getLogicAndCipherColumns(each));
}
return result;
}

private Map<String, String> createLogicAndPlainColumns() {
Map<String, String> result = new LinkedHashMap<>();
for (String each : optimizedStatement.getTables().getTableNames()) {
result.putAll(encryptRule.getLogicAndPlainColumns(each));
}
return result;
}

@Override
public boolean next() throws SQLException {
return resultSet.next();
Expand Down Expand Up @@ -357,7 +371,7 @@ public Object getObject(final String columnLabel) throws SQLException {

@Override
public ResultSetMetaData getMetaData() throws SQLException {
return new EncryptResultSetMetaData(originalResultSet.getMetaData(), encryptRule, optimizedStatement);
return new EncryptResultSetMetaData(originalResultSet.getMetaData(), encryptRule, optimizedStatement, logicAndActualColumns);
}

@Override
Expand Down
Expand Up @@ -25,7 +25,6 @@
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
Expand All @@ -45,19 +44,12 @@ public final class EncryptResultSetMetaData extends WrapperAdapter implements Re

private final Map<String, String> logicAndActualColumns;

public EncryptResultSetMetaData(final ResultSetMetaData resultSetMetaData, final EncryptRule encryptRule, final OptimizedStatement optimizedStatement) {
public EncryptResultSetMetaData(final ResultSetMetaData resultSetMetaData,
final EncryptRule encryptRule, final OptimizedStatement optimizedStatement, final Map<String, String> logicAndActualColumns) {
this.resultSetMetaData = resultSetMetaData;
this.encryptRule = encryptRule;
this.optimizedStatement = optimizedStatement;
logicAndActualColumns = createLogicAndActualColumns();
}

private Map<String, String> createLogicAndActualColumns() {
Map<String, String> result = new LinkedHashMap<>();
for (String each : optimizedStatement.getTables().getTableNames()) {
result.putAll(encryptRule.getLogicAndCipherColumns(each));
}
return result;
this.logicAndActualColumns = logicAndActualColumns;
}

@Override
Expand Down
Expand Up @@ -26,7 +26,6 @@
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
Expand All @@ -47,19 +46,12 @@ public final class ShardingResultSetMetaData extends WrapperAdapter implements R

private final Map<String, String> logicAndActualColumns;

public ShardingResultSetMetaData(final ResultSetMetaData resultSetMetaData, final ShardingRule shardingRule, final OptimizedStatement optimizedStatement) {
public ShardingResultSetMetaData(final ResultSetMetaData resultSetMetaData,
final ShardingRule shardingRule, final OptimizedStatement optimizedStatement, final Map<String, String> logicAndActualColumns) {
this.resultSetMetaData = resultSetMetaData;
this.shardingRule = shardingRule;
this.optimizedStatement = optimizedStatement;
logicAndActualColumns = createLogicAndActualColumns();
}

private Map<String, String> createLogicAndActualColumns() {
Map<String, String> result = new LinkedHashMap<>();
for (String each : optimizedStatement.getTables().getTableNames()) {
result.putAll(shardingRule.getEncryptRule().getLogicAndCipherColumns(each));
}
return result;
this.logicAndActualColumns = logicAndActualColumns;
}

@Override
Expand Down

0 comments on commit 3217df7

Please sign in to comment.