Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: update join condition placeholder param error #5052

1 change: 1 addition & 0 deletions changes/en-us/develop.md
Expand Up @@ -34,6 +34,7 @@ Add changes here for all PR submitted to the develop branch.
- [[#5033](https://github.com/seata/seata/pull/5023)] fix mysql InsertOnDuplicateUpdate insert value type recognize error
- [[#5038](https://github.com/seata/seata/pull/5038)] remove @EnableConfigurationProperties({SagaAsyncThreadPoolProperties.class})
- [[#5050](https://github.com/seata/seata/pull/5050)] fix global session is not change to Committed in saga mode
- [[#5052](https://github.com/seata/seata/pull/5052)] fix update join condition placeholder param error
- [[#5031](https://github.com/seata/seata/pull/5031)] fix mysql InsertOnDuplicateUpdate should not use null index value as image sql query condition
- [[#5075](https://github.com/seata/seata/pull/5075)] fix InsertOnDuplicateUpdateExecutor could not intercept the sql which has no primary and unique key

Expand Down
1 change: 1 addition & 0 deletions changes/zh-cn/develop.md
Expand Up @@ -36,6 +36,7 @@
- [[#5033](https://github.com/seata/seata/pull/5023)] 修复InsertOnDuplicateUpdate中插入值解析为String类型导致的类型识别错误
- [[#5038](https://github.com/seata/seata/pull/5038)] 修复SagaAsyncThreadPoolProperties冲突问题
- [[#5050](https://github.com/seata/seata/pull/5050)] 修复Saga模式下全局状态未正确更改成Committed
- [[#5052](https://github.com/seata/seata/pull/5052)] 修复update join条件中占位符参数问题
- [[#5031](https://github.com/seata/seata/pull/5031)] 修复InsertOnDuplicateUpdate中不应该使用null值索引作为查询条件
- [[#5075](https://github.com/seata/seata/pull/5075)] 修复InsertOnDuplicateUpdate无法拦截无主键和唯一索引的SQL

Expand Down
Expand Up @@ -412,7 +412,7 @@ protected String buildLockKey(TableRecords rowsIncludingPK) {
sb.append(":");
int rowSequence = 0;
List<Map<String, Field>> pksRows = rowsIncludingPK.pkRows();
List<String> primaryKeysOnlyName = getTableMeta().getPrimaryKeyOnlyName();
List<String> primaryKeysOnlyName = rowsIncludingPK.getTableMeta().getPrimaryKeyOnlyName();
for (Map<String, Field> rowMap : pksRows) {
int pkSplitIndex = 0;
for (String pkName : primaryKeysOnlyName) {
Expand Down
Expand Up @@ -29,8 +29,10 @@
import io.seata.common.util.CollectionUtils;
import io.seata.core.protocol.Version;
import io.seata.rm.datasource.ConnectionProxy;
import io.seata.rm.datasource.sql.struct.Field;
import io.seata.rm.datasource.sql.struct.TableMetaCacheFactory;
import io.seata.rm.datasource.undo.SQLUndoLog;
import io.seata.sqlparser.ParametersHolder;
import io.seata.sqlparser.SQLType;
import io.seata.common.exception.ShouldNeverHappenException;
import io.seata.common.util.IOUtil;
Expand All @@ -44,6 +46,7 @@
import io.seata.sqlparser.SQLRecognizer;
import io.seata.sqlparser.SQLUpdateRecognizer;
import io.seata.sqlparser.util.ColumnUtils;
import io.seata.sqlparser.JoinRecognizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -79,25 +82,24 @@ protected TableRecords beforeImage() throws SQLException {
// update join sql,like update t1 inner join t2 on t1.id = t2.id set t1.name = ?; tableItems = {"update t1 inner join t2","t1","t2"}
String[] tableItems = tableNames.split(recognizer.MULTI_TABLE_NAME_SEPERATOR);
String joinTable = tableItems[0];
int itemTableIndex = 1;
final int itemTableIndex = 1;
String suffixCommonCondition = buildBeforeImageSQLCommonConditionSuffix(paramAppenderList);
for (int i = itemTableIndex; i < tableItems.length; i++) {
List<String> itemTableUpdateColumns = getItemUpdateColumns(this.getTableMeta(tableItems[i]), recognizer.getUpdateColumns());
if (CollectionUtils.isEmpty(itemTableUpdateColumns)) {
continue;
}
String selectSQL = buildBeforeImageSQL(joinTable, tableItems[i], itemTableUpdateColumns, paramAppenderList);
String selectSQL = buildBeforeImageSQL(joinTable, tableItems[i], suffixCommonCondition, itemTableUpdateColumns);
TableRecords tableRecords = buildTableRecords(getTableMeta(tableItems[i]), selectSQL, paramAppenderList);
beforeImagesMap.put(tableItems[i], tableRecords);
}
return null;
}

private String buildBeforeImageSQL(String joinTable, String itemTable, List<String> itemTableUpdateColumns,
ArrayList<List<Object>> paramAppenderList) {
private String buildBeforeImageSQLCommonConditionSuffix(ArrayList<List<Object>> paramAppenderList) {
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
TableMeta itemTableMeta = getTableMeta(itemTable);
StringBuilder prefix = new StringBuilder("SELECT ");
StringBuilder suffix = new StringBuilder(" FROM ").append(joinTable);
StringBuilder suffix = new StringBuilder();
buildJoinCondition(recognizer,paramAppenderList);
String whereCondition = buildWhereCondition(recognizer, paramAppenderList);
String orderByCondition = buildOrderCondition(recognizer, paramAppenderList);
String limitCondition = buildLimitCondition(recognizer, paramAppenderList);
Expand All @@ -110,6 +112,21 @@ private String buildBeforeImageSQL(String joinTable, String itemTable, List<Stri
if (StringUtils.isNotBlank(limitCondition)) {
suffix.append(" ").append(limitCondition);
}
return suffix.toString();
}

private void buildJoinCondition(SQLUpdateRecognizer recognizer, ArrayList<List<Object>> paramAppenderList) {
if (statementProxy instanceof ParametersHolder) {
((JoinRecognizer)recognizer).getJoinCondition((ParametersHolder) statementProxy,paramAppenderList);
}
}

private String buildBeforeImageSQL(String joinTable, String itemTable,String suffixCondition, List<String> itemTableUpdateColumns) {
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
TableMeta itemTableMeta = getTableMeta(itemTable);
StringBuilder prefix = new StringBuilder("SELECT ");
StringBuilder suffix = new StringBuilder(" FROM ").append(joinTable);
suffix.append(suffixCondition);
//maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
suffix.append(GROUP_BY);
List<String> pkColumnNames = getColumnNamesWithTablePrefixList(itemTable, recognizer.getTableAlias(itemTable), itemTableMeta.getPrimaryKeyOnlyName());
Expand All @@ -129,16 +146,18 @@ protected TableRecords afterImage(TableRecords beforeImage) throws SQLException
String tableNames = recognizer.getTableName();
String[] tableItems = tableNames.split(recognizer.MULTI_TABLE_NAME_SEPERATOR);
String joinTable = tableItems[0];
int itemTableIndex = 1;
final int itemTableIndex = 1;
ArrayList<List<Object>> joinConditionParams = new ArrayList<>();
buildJoinCondition(recognizer,joinConditionParams);
for (int i = itemTableIndex; i < tableItems.length; i++) {
TableRecords tableBeforeImage = beforeImagesMap.get(tableItems[i]);
if (tableBeforeImage == null) {
if (tableBeforeImage == null || CollectionUtils.isEmpty(tableBeforeImage.getRows())) {
continue;
}
String selectSQL = buildAfterImageSQL(joinTable, tableItems[i], tableBeforeImage);
ResultSet rs = null;
try (PreparedStatement pst = statementProxy.getConnection().prepareStatement(selectSQL)) {
SqlGenerateUtils.setParamForPk(tableBeforeImage.pkRows(), getTableMeta(tableItems[i]).getPrimaryKeyOnlyName(), pst);
setAfterImageSQLPlaceHolderParams(joinConditionParams,tableBeforeImage.pkRows(), getTableMeta(tableItems[i]).getPrimaryKeyOnlyName(), pst);
rs = pst.executeQuery();
TableRecords afterImage = TableRecords.buildRecords(getTableMeta(tableItems[i]), rs);
afterImagesMap.put(tableItems[i], afterImage);
Expand All @@ -149,6 +168,29 @@ protected TableRecords afterImage(TableRecords beforeImage) throws SQLException
return null;
}

private void setAfterImageSQLPlaceHolderParams(ArrayList<List<Object>> joinConditionParams,
List<Map<String, Field>> pkRowsList, List<String> pkColumnNameList,
PreparedStatement pst) throws SQLException {
int paramIndex = 1;
if (CollectionUtils.isNotEmpty(joinConditionParams)) {
for (int i = 0, ts = joinConditionParams.size(); i < ts; i++) {
List<Object> paramAppender = joinConditionParams.get(i);
for (int j = 0, ds = paramAppender.size(); j < ds; j++) {
pst.setObject(paramIndex, paramAppender.get(j));
paramIndex++;
}
}
}
for (int i = 0; i < pkRowsList.size(); i++) {
Map<String, Field> rowData = pkRowsList.get(i);
for (String columnName : pkColumnNameList) {
Field pkField = rowData.get(columnName);
pst.setObject(paramIndex, pkField.getValue(), pkField.getType());
paramIndex++;
}
}
}

private String buildAfterImageSQL(String joinTable, String itemTable,
TableRecords beforeImage) throws SQLException {
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
Expand Down
@@ -0,0 +1,34 @@
/*
* Copyright 1999-2019 Seata.io Group.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seata.sqlparser;

import java.util.ArrayList;
import java.util.List;

/**
* The interface Where recognizer.
*
* @author renliangyu857
*/
public interface JoinRecognizer {
/**
* Gets join condition.
* @param parametersHolder the parameters holder
* @param paramAppenderList the param appender list
* @return the join condition
*/
String getJoinCondition(ParametersHolder parametersHolder, ArrayList<List<Object>> paramAppenderList);
}
Expand Up @@ -133,6 +133,17 @@ protected String getOrderByCondition(SQLOrderBy sqlOrderBy, final ParametersHold
return sb.toString();
}

protected String getJoinCondition(SQLExpr joinCondition,final ParametersHolder parametersHolder,
final ArrayList<List<Object>> paramAppenderList) {
if (Objects.isNull(joinCondition)) {
return StringUtils.EMPTY;
}

StringBuilder sb = new StringBuilder();
executeVisit(joinCondition, createOutputVisitor(parametersHolder, paramAppenderList, sb));
return sb.toString();
}

public String getDbType() {
return JdbcConstants.MYSQL;
}
Expand Down
Expand Up @@ -34,6 +34,7 @@
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor;
import io.seata.sqlparser.JoinRecognizer;
import io.seata.sqlparser.util.ColumnUtils;
import io.seata.sqlparser.ParametersHolder;
import io.seata.sqlparser.SQLType;
Expand All @@ -46,7 +47,7 @@
*
* @author sharajava
*/
public class MySQLUpdateRecognizer extends BaseMySQLRecognizer implements SQLUpdateRecognizer {
public class MySQLUpdateRecognizer extends BaseMySQLRecognizer implements SQLUpdateRecognizer, JoinRecognizer {

private final MySqlUpdateStatement ast;

Expand Down Expand Up @@ -191,6 +192,15 @@ public String getOrderByCondition(ParametersHolder parametersHolder, ArrayList<L
return super.getOrderByCondition(sqlOrderBy, parametersHolder, paramAppenderList);
}

@Override
public String getJoinCondition(ParametersHolder parametersHolder, ArrayList<List<Object>> paramAppenderList) {
if (!(ast.getTableSource() instanceof SQLJoinTableSource)) {
return "";
}
SQLExpr joinCondition = ((SQLJoinTableSource) ast.getTableSource()).getCondition();
return super.getJoinCondition(joinCondition, parametersHolder, paramAppenderList);
}

@Override
protected SQLStatement getAst() {
return ast;
Expand Down