From 92ad5e126a4a9766ae2799e087b72d0cfd268120 Mon Sep 17 00:00:00 2001 From: terrymanu Date: Fri, 31 Aug 2018 00:53:57 +0800 Subject: [PATCH] for #1025, refactor ShardingExecuteEngine, do not partition twice --- .../core/executor/ShardingExecuteEngine.java | 57 +++++++------ .../core/executor/sql/SQLExecuteTemplate.java | 70 ++++++++++++---- .../table/executor/TableMetaDataLoader.java | 17 +++- .../core/routing/SQLRouteResult.java | 6 +- .../shardingsphere/core/rule/TableRule.java | 4 +- .../batch/BatchPreparedStatementExecutor.java | 20 ++--- ...trictlyBatchPreparedStatementExecutor.java | 81 +++++++++++++++++++ ...trictlyBatchPreparedStatementExecutor.java | 56 +++++++++++++ ...tionStrictlyPreparedStatementExecutor.java | 68 ++++++++++++++++ ...moryStrictlyPreparedStatementExecutor.java | 49 +++++++++++ .../prepared/PreparedStatementExecutor.java | 18 ++--- .../ConnectionStrictlyStatementExecutor.java | 66 +++++++++++++++ .../MemoryStrictlyStatementExecutor.java | 49 +++++++++++ .../executor/statement/StatementExecutor.java | 18 ++--- .../statement/ShardingPreparedStatement.java | 62 ++++++++++---- .../core/statement/ShardingStatement.java | 28 ++++--- .../executor/AbstractBaseExecutorTest.java | 3 +- .../BatchPreparedStatementExecutorTest.java | 21 ++--- .../PreparedStatementExecutorTest.java | 59 +++++++------- .../core/executor/StatementExecutorTest.java | 45 ++++++----- .../execution/ExecuteEventListenerTest.java | 5 +- .../ConnectionStrictlyExecuteEngine.java | 28 ++++--- .../stream/MemoryStrictlyExecuteEngine.java | 3 +- 23 files changed, 648 insertions(+), 185 deletions(-) create mode 100644 sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/ConnectionStrictlyBatchPreparedStatementExecutor.java create mode 100644 sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/MemoryStrictlyBatchPreparedStatementExecutor.java create mode 100644 sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/ConnectionStrictlyPreparedStatementExecutor.java create mode 100644 sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/MemoryStrictlyPreparedStatementExecutor.java create mode 100644 sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/ConnectionStrictlyStatementExecutor.java create mode 100644 sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/MemoryStrictlyStatementExecutor.java diff --git a/sharding-core/src/main/java/io/shardingsphere/core/executor/ShardingExecuteEngine.java b/sharding-core/src/main/java/io/shardingsphere/core/executor/ShardingExecuteEngine.java index 641efacab9c8d..d8d57746f1887 100644 --- a/sharding-core/src/main/java/io/shardingsphere/core/executor/ShardingExecuteEngine.java +++ b/sharding-core/src/main/java/io/shardingsphere/core/executor/ShardingExecuteEngine.java @@ -27,7 +27,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -126,22 +125,20 @@ private List getResults(final O firstResult, final Collection type of input value * @param type of return value * @return execute result * @throws SQLException throw if execute failure */ - public List groupExecute(final Map> inputs, final int maxThreadCountPerGroup, final ShardingGroupExecuteCallback callback) throws SQLException { - return groupExecute(inputs, maxThreadCountPerGroup, null, callback); + public List groupExecute(final Map>> inputs, final ShardingGroupExecuteCallback callback) throws SQLException { + return groupExecute(inputs, null, callback); } /** * execute all callbacks for group. * * @param inputs input value's map - * @param maxThreadCountPerGroup max thread count for every group * @param callback sharding execute callback * @param firstCallback first sharding execute callback * @param type of input value @@ -149,41 +146,43 @@ public List groupExecute(final Map> inputs, fina * @return execute result * @throws SQLException throw if execute failure */ - public List groupExecute(final Map> inputs, final int maxThreadCountPerGroup, - final ShardingGroupExecuteCallback firstCallback, final ShardingGroupExecuteCallback callback) throws SQLException { + public List groupExecute( + final Map>> inputs, final ShardingGroupExecuteCallback firstCallback, final ShardingGroupExecuteCallback callback) throws SQLException { if (inputs.isEmpty()) { return Collections.emptyList(); } - Map>> executionUnits = new HashMap<>(inputs.size(), 1); - for (Entry> entry : inputs.entrySet()) { - executionUnits.put(entry.getKey(), Lists.partition(new ArrayList<>(entry.getValue()), maxThreadCountPerGroup)); - } - String firstKey = executionUnits.keySet().iterator().next(); - Iterator> firstExecutionUnits = executionUnits.get(firstKey).iterator(); - Collection firstInputs = firstExecutionUnits.next(); - executionUnits.put(firstKey, Lists.newArrayList(firstExecutionUnits)); - Collection>> restResultFutures = asyncGroupExecute(executionUnits, callback); + String firstKey = inputs.keySet().iterator().next(); + Iterator> firstInputGroup = inputs.get(firstKey).iterator(); + Collection firstInputs = firstInputGroup.next(); + inputs.put(firstKey, Lists.newArrayList(firstInputGroup)); + Collection>> restResultFutures = asyncGroupExecute(inputs, callback); return getGroupResults(syncGroupExecute(firstKey, firstInputs, null == firstCallback ? callback : firstCallback), restResultFutures); } private Collection>> asyncGroupExecute(final Map>> inputs, final ShardingGroupExecuteCallback callback) { - Collection>> result = new ArrayList<>(inputs.size()); - for (final Entry>> entry : inputs.entrySet()) { - for (final List each : entry.getValue()) { - result.add(executorService.submit(new Callable>() { - - @Override - public Collection call() throws SQLException { - return callback.execute(entry.getKey(), each); - } - })); - } + Collection>> result = new LinkedList<>(); + for (Entry>> entry : inputs.entrySet()) { + result.addAll(asyncGroupExecute(entry.getKey(), entry.getValue(), callback)); + } + return result; + } + + private Collection>> asyncGroupExecute(final String key, final List> inputs, final ShardingGroupExecuteCallback callback) { + Collection>> result = new LinkedList<>(); + for (final List each : inputs) { + result.add(executorService.submit(new Callable>() { + + @Override + public Collection call() throws SQLException { + return callback.execute(key, each); + } + })); } return result; } - private Collection syncGroupExecute(final String dataSourceName, final Collection inputs, final ShardingGroupExecuteCallback callback) throws SQLException { - return callback.execute(dataSourceName, inputs); + private Collection syncGroupExecute(final String key, final Collection inputs, final ShardingGroupExecuteCallback callback) throws SQLException { + return callback.execute(key, inputs); } private List getGroupResults(final Collection firstResults, final Collection>> restFutures) throws SQLException { diff --git a/sharding-core/src/main/java/io/shardingsphere/core/executor/sql/SQLExecuteTemplate.java b/sharding-core/src/main/java/io/shardingsphere/core/executor/sql/SQLExecuteTemplate.java index 7765f8a3e0d18..1f6836c63853c 100644 --- a/sharding-core/src/main/java/io/shardingsphere/core/executor/sql/SQLExecuteTemplate.java +++ b/sharding-core/src/main/java/io/shardingsphere/core/executor/sql/SQLExecuteTemplate.java @@ -17,7 +17,6 @@ package io.shardingsphere.core.executor.sql; -import io.shardingsphere.core.constant.ConnectionMode; import io.shardingsphere.core.event.ShardingEventBusInstance; import io.shardingsphere.core.executor.ShardingExecuteEngine; import io.shardingsphere.core.executor.sql.event.overall.OverallExecutionEvent; @@ -27,10 +26,11 @@ import java.sql.SQLException; import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashMap; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Map.Entry; /** * SQL execute template. @@ -45,9 +45,47 @@ public final class SQLExecuteTemplate { private final ShardingExecuteEngine executeEngine; - private final ConnectionMode connectionMode; + /** + * Execute. + * + * @param executeUnits execute units + * @param executeCallback execute callback + * @param class type of return value + * @return execute result + * @throws SQLException SQL exception + */ + public List execute(final Collection executeUnits, final SQLExecuteCallback executeCallback) throws SQLException { + return execute(executeUnits, null, executeCallback); + } - private final int maxConnectionsSizePerQuery; + /** + * Execute. + * + * @param executeUnits execute units + * @param firstExecuteCallback first execute callback + * @param executeCallback execute callback + * @param class type of return value + * @return execute result + * @throws SQLException SQL exception + */ + public List execute(final Collection executeUnits, + final SQLExecuteCallback firstExecuteCallback, final SQLExecuteCallback executeCallback) throws SQLException { + OverallExecutionEvent event = new OverallExecutionEvent(executeUnits.size() > 1); + ShardingEventBusInstance.getInstance().post(event); + try { + List result = executeEngine.execute(new LinkedList<>(executeUnits), firstExecuteCallback, executeCallback); + event.setExecuteSuccess(); + return result; + // CHECKSTYLE:OFF + } catch (final Exception ex) { + // CHECKSTYLE:ON + event.setExecuteFailure(ex); + ExecutorExceptionHandler.handleException(ex); + return Collections.emptyList(); + } finally { + ShardingEventBusInstance.getInstance().post(event); + } + } /** * Execute. @@ -58,7 +96,7 @@ public final class SQLExecuteTemplate { * @return execute result * @throws SQLException SQL exception */ - public List execute(final Collection executeUnits, final SQLExecuteCallback executeCallback) throws SQLException { + public List execute(final Map>> executeUnits, final SQLExecuteCallback executeCallback) throws SQLException { return execute(executeUnits, null, executeCallback); } @@ -72,13 +110,12 @@ public List execute(final Collection exec * @return execute result * @throws SQLException SQL exception */ - public List execute( - final Collection executeUnits, final SQLExecuteCallback firstExecuteCallback, final SQLExecuteCallback executeCallback) throws SQLException { + public List execute(final Map>> executeUnits, + final SQLExecuteCallback firstExecuteCallback, final SQLExecuteCallback executeCallback) throws SQLException { OverallExecutionEvent event = new OverallExecutionEvent(executeUnits.size() > 1); ShardingEventBusInstance.getInstance().post(event); try { - List result = ConnectionMode.MEMORY_STRICTLY == connectionMode ? executeEngine.execute(new LinkedList<>(executeUnits), firstExecuteCallback, executeCallback) - : executeEngine.groupExecute(getExecuteUnitGroups(executeUnits), maxConnectionsSizePerQuery, firstExecuteCallback, executeCallback); + List result = executeEngine.groupExecute(transform(executeUnits), firstExecuteCallback, executeCallback); event.setExecuteSuccess(); return result; // CHECKSTYLE:OFF @@ -92,14 +129,15 @@ public List execute( } } - private Map> getExecuteUnitGroups(final Collection executeUnits) { - Map> result = new LinkedHashMap<>(executeUnits.size(), 1); - for (StatementExecuteUnit each : executeUnits) { - String dataSourceName = each.getSqlExecutionUnit().getDataSource(); - if (!result.keySet().contains(dataSourceName)) { - result.put(dataSourceName, new LinkedList()); + private Map>> transform(final Map>> executeUnits) { + Map>> result = new HashMap<>(executeUnits.size()); + for (Entry>> entry : executeUnits.entrySet()) { + if (!result.containsKey(entry.getKey())) { + result.put(entry.getKey(), new LinkedList>()); + } + for (List each : entry.getValue()) { + result.get(entry.getKey()).add(new LinkedList<>(each)); } - result.get(dataSourceName).add(each); } return result; } diff --git a/sharding-core/src/main/java/io/shardingsphere/core/metadata/table/executor/TableMetaDataLoader.java b/sharding-core/src/main/java/io/shardingsphere/core/metadata/table/executor/TableMetaDataLoader.java index 3155339a23b34..57176765d956b 100644 --- a/sharding-core/src/main/java/io/shardingsphere/core/metadata/table/executor/TableMetaDataLoader.java +++ b/sharding-core/src/main/java/io/shardingsphere/core/metadata/table/executor/TableMetaDataLoader.java @@ -17,6 +17,7 @@ package io.shardingsphere.core.metadata.table.executor; +import com.google.common.collect.Lists; import io.shardingsphere.core.exception.ShardingException; import io.shardingsphere.core.executor.ShardingExecuteEngine; import io.shardingsphere.core.executor.ShardingGroupExecuteCallback; @@ -33,10 +34,12 @@ import java.sql.SQLException; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Map.Entry; /** * Table meta data loader. @@ -68,13 +71,13 @@ public TableMetaData load(final String logicTableName, final ShardingRule shardi return actualTableMetaDataList.iterator().next(); } - private List load(final Map> dataNodeGroups, final ShardingDataSourceNames shardingDataSourceNames) throws SQLException { - return executeEngine.groupExecute(dataNodeGroups, maxConnectionsSizePerQuery, new ShardingGroupExecuteCallback() { + private List load(final Map> dataNodeGroups, final ShardingDataSourceNames shardingDataSourceNames) throws SQLException { + return executeEngine.groupExecute(partitionDataNodeGroups(dataNodeGroups), new ShardingGroupExecuteCallback() { @Override public Collection execute(final String dataSourceName, final Collection actualTableNames) throws SQLException { DataSourceMetaData dataSourceMetaData = shardingDataSourceMetaData.getActualDataSourceMetaData(dataSourceName); - final String catalog = null == dataSourceMetaData ? null : dataSourceMetaData.getSchemeName(); + String catalog = null == dataSourceMetaData ? null : dataSourceMetaData.getSchemeName(); return load(shardingDataSourceNames.getRawMasterDataSourceName(dataSourceName), catalog, actualTableNames); } }); @@ -90,6 +93,14 @@ private Collection load(final String dataSourceName, final String return result; } + private Map>> partitionDataNodeGroups(final Map> dataNodeGroups) { + Map>> result = new HashMap<>(dataNodeGroups.size(), 1); + for (Entry> entry : dataNodeGroups.entrySet()) { + result.put(entry.getKey(), Lists.partition(entry.getValue(), maxConnectionsSizePerQuery)); + } + return result; + } + private boolean isTableExist(final Connection connection, final String catalog, final String actualTableName) throws SQLException { try (ResultSet resultSet = connection.getMetaData().getTables(catalog, null, actualTableName, null)) { return resultSet.next(); diff --git a/sharding-core/src/main/java/io/shardingsphere/core/routing/SQLRouteResult.java b/sharding-core/src/main/java/io/shardingsphere/core/routing/SQLRouteResult.java index 8759a6592bd71..4b893f69372bb 100644 --- a/sharding-core/src/main/java/io/shardingsphere/core/routing/SQLRouteResult.java +++ b/sharding-core/src/main/java/io/shardingsphere/core/routing/SQLRouteResult.java @@ -22,10 +22,10 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; -import java.util.Collection; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Set; @@ -55,8 +55,8 @@ public SQLRouteResult(final SQLStatement sqlStatement) { * * @return SQL units grouped by data source name. */ - public Map> getSQLUnitGroups() { - Map> result = new LinkedHashMap<>(executionUnits.size(), 1); + public Map> getSQLUnitGroups() { + Map> result = new LinkedHashMap<>(executionUnits.size(), 1); for (SQLExecutionUnit each : executionUnits) { if (!result.containsKey(each.getDataSource())) { result.put(each.getDataSource(), new LinkedList()); diff --git a/sharding-core/src/main/java/io/shardingsphere/core/rule/TableRule.java b/sharding-core/src/main/java/io/shardingsphere/core/rule/TableRule.java index f2cd56f4b9cac..e4bf88c1ef352 100644 --- a/sharding-core/src/main/java/io/shardingsphere/core/rule/TableRule.java +++ b/sharding-core/src/main/java/io/shardingsphere/core/rule/TableRule.java @@ -99,8 +99,8 @@ private List generateDataNodes(final List actualDataNodes, fin * * @return data node groups, key is data source name, value is tables belong to this data source */ - public Map> getDataNodeGroups() { - Map> result = new LinkedHashMap<>(actualDataNodes.size(), 1); + public Map> getDataNodeGroups() { + Map> result = new LinkedHashMap<>(actualDataNodes.size(), 1); for (DataNode each : actualDataNodes) { String dataSourceName = each.getDataSourceName(); if (!result.containsKey(dataSourceName)) { diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/BatchPreparedStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/BatchPreparedStatementExecutor.java index 2232b9876239f..276b35476a179 100644 --- a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/BatchPreparedStatementExecutor.java +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/BatchPreparedStatementExecutor.java @@ -20,7 +20,6 @@ import io.shardingsphere.core.constant.DatabaseType; import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.executor.sql.SQLExecuteCallback; -import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; import io.shardingsphere.core.executor.sql.StatementExecuteUnit; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorDataMap; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorExceptionHandler; @@ -30,24 +29,21 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Map.Entry; /** - * PreparedStatement Executor for multiple threads to process add batch. + * Prepared statement executor to process add batch. * * @author zhangliang * @author maxiaoguang */ @RequiredArgsConstructor -public final class BatchPreparedStatementExecutor { - - private final SQLExecuteTemplate executeTemplate; +public abstract class BatchPreparedStatementExecutor { private final DatabaseType dbType; private final SQLType sqlType; - private final Collection batchPreparedStatementUnits; - private final int batchCount; /** @@ -66,14 +62,14 @@ protected int[] executeSQL(final StatementExecuteUnit executeUnit) throws SQLExc return executeUnit.getStatement().executeBatch(); } }; - return accumulate(executeTemplate.execute(batchPreparedStatementUnits, callback)); + return accumulate(executeCallback(callback)); } private int[] accumulate(final List results) { int[] result = new int[batchCount]; int count = 0; - for (BatchPreparedStatementUnit each : batchPreparedStatementUnits) { - for (Map.Entry entry : each.getJdbcAndActualAddBatchCallTimesMap().entrySet()) { + for (BatchPreparedStatementUnit each : getBatchPreparedStatementUnitGroups()) { + for (Entry entry : each.getJdbcAndActualAddBatchCallTimesMap().entrySet()) { int value = null == results.get(count) ? 0 : results.get(count)[entry.getValue()]; if (DatabaseType.Oracle == dbType) { result[entry.getKey()] = value; @@ -85,4 +81,8 @@ private int[] accumulate(final List results) { } return result; } + + protected abstract List executeCallback(SQLExecuteCallback executeCallback) throws SQLException; + + protected abstract Collection getBatchPreparedStatementUnitGroups(); } diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/ConnectionStrictlyBatchPreparedStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/ConnectionStrictlyBatchPreparedStatementExecutor.java new file mode 100644 index 0000000000000..293a7b5921552 --- /dev/null +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/ConnectionStrictlyBatchPreparedStatementExecutor.java @@ -0,0 +1,81 @@ +/* + * Copyright 2016-2018 shardingsphere.io. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *

+ */ + +package io.shardingsphere.core.executor.batch; + +import io.shardingsphere.core.constant.DatabaseType; +import io.shardingsphere.core.constant.SQLType; +import io.shardingsphere.core.executor.sql.SQLExecuteCallback; +import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; +import io.shardingsphere.core.executor.sql.StatementExecuteUnit; + +import java.sql.SQLException; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Prepared statement executor to process add batch for connection strictly mode. + * + * @author zhangliang + */ +public final class ConnectionStrictlyBatchPreparedStatementExecutor extends BatchPreparedStatementExecutor { + + private final SQLExecuteTemplate executeTemplate; + + private final Map>> batchPreparedStatementUnitGroups; + + public ConnectionStrictlyBatchPreparedStatementExecutor(final DatabaseType dbType, final SQLType sqlType, final int batchCount, + final SQLExecuteTemplate executeTemplate, final Map>> batchPreparedStatementUnitGroups) { + super(dbType, sqlType, batchCount); + this.executeTemplate = executeTemplate; + this.batchPreparedStatementUnitGroups = batchPreparedStatementUnitGroups; + } + + @Override + protected List executeCallback(final SQLExecuteCallback executeCallback) throws SQLException { + return executeTemplate.execute(transformBatchPreparedStatementUnitGroups(), executeCallback); + } + + private Map>> transformBatchPreparedStatementUnitGroups() { + Map>> result = new HashMap<>(batchPreparedStatementUnitGroups.size(), 1); + for (Map.Entry>> entry : batchPreparedStatementUnitGroups.entrySet()) { + List> batchPreparedStatementUnitGroups = entry.getValue(); + for (List each : batchPreparedStatementUnitGroups) { + if (!result.containsKey(entry.getKey())) { + result.put(entry.getKey(), new LinkedList>()); + } + result.get(entry.getKey()).add(new LinkedList<>(each)); + } + } + return result; + } + + @Override + protected Collection getBatchPreparedStatementUnitGroups() { + Collection result = new LinkedList<>(); + for (Entry>> entry : batchPreparedStatementUnitGroups.entrySet()) { + for (List each : entry.getValue()) { + result.addAll(each); + } + } + return result; + } +} diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/MemoryStrictlyBatchPreparedStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/MemoryStrictlyBatchPreparedStatementExecutor.java new file mode 100644 index 0000000000000..114f9890ab199 --- /dev/null +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/batch/MemoryStrictlyBatchPreparedStatementExecutor.java @@ -0,0 +1,56 @@ +/* + * Copyright 2016-2018 shardingsphere.io. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *

+ */ + +package io.shardingsphere.core.executor.batch; + +import io.shardingsphere.core.constant.DatabaseType; +import io.shardingsphere.core.constant.SQLType; +import io.shardingsphere.core.executor.sql.SQLExecuteCallback; +import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; + +import java.sql.SQLException; +import java.util.Collection; +import java.util.List; + +/** + * Prepared statement executor to process add batch for memory strictly mode. + * + * @author zhangliang + */ +public final class MemoryStrictlyBatchPreparedStatementExecutor extends BatchPreparedStatementExecutor { + + private final SQLExecuteTemplate executeTemplate; + + private final Collection batchPreparedStatementUnits; + + public MemoryStrictlyBatchPreparedStatementExecutor(final DatabaseType dbType, final SQLType sqlType, final int batchCount, + final SQLExecuteTemplate executeTemplate, final Collection batchPreparedStatementUnits) { + super(dbType, sqlType, batchCount); + this.executeTemplate = executeTemplate; + this.batchPreparedStatementUnits = batchPreparedStatementUnits; + } + + @Override + protected List executeCallback(final SQLExecuteCallback executeCallback) throws SQLException { + return executeTemplate.execute(batchPreparedStatementUnits, executeCallback); + } + + @Override + protected Collection getBatchPreparedStatementUnitGroups() { + return batchPreparedStatementUnits; + } +} diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/ConnectionStrictlyPreparedStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/ConnectionStrictlyPreparedStatementExecutor.java new file mode 100644 index 0000000000000..a7d15b271d177 --- /dev/null +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/ConnectionStrictlyPreparedStatementExecutor.java @@ -0,0 +1,68 @@ +/* + * Copyright 2016-2018 shardingsphere.io. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *

+ */ + +package io.shardingsphere.core.executor.prepared; + +import io.shardingsphere.core.constant.SQLType; +import io.shardingsphere.core.executor.sql.SQLExecuteCallback; +import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; +import io.shardingsphere.core.executor.sql.StatementExecuteUnit; + +import java.sql.SQLException; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Prepared statement executor for connection strictly mode. + * + * @author zhangliang + */ +public final class ConnectionStrictlyPreparedStatementExecutor extends PreparedStatementExecutor { + + private final SQLExecuteTemplate executeTemplate; + + private final Map>> preparedStatementUnitGroups; + + public ConnectionStrictlyPreparedStatementExecutor( + final SQLType sqlType, final SQLExecuteTemplate executeTemplate, final Map>> preparedStatementUnitGroups) { + super(sqlType); + this.executeTemplate = executeTemplate; + this.preparedStatementUnitGroups = preparedStatementUnitGroups; + } + + @Override + protected List executeCallback(final SQLExecuteCallback executeCallback) throws SQLException { + return executeTemplate.execute(transformPreparedStatementUnitGroups(), executeCallback); + } + + private Map>> transformPreparedStatementUnitGroups() { + Map>> result = new HashMap<>(preparedStatementUnitGroups.size(), 1); + for (Entry>> entry : preparedStatementUnitGroups.entrySet()) { + List> preparedStatementUnitGroups = entry.getValue(); + for (List each : preparedStatementUnitGroups) { + if (!result.containsKey(entry.getKey())) { + result.put(entry.getKey(), new LinkedList>()); + } + result.get(entry.getKey()).add(new LinkedList<>(each)); + } + } + return result; + } +} diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/MemoryStrictlyPreparedStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/MemoryStrictlyPreparedStatementExecutor.java new file mode 100644 index 0000000000000..451cd1201bdc6 --- /dev/null +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/MemoryStrictlyPreparedStatementExecutor.java @@ -0,0 +1,49 @@ +/* + * Copyright 2016-2018 shardingsphere.io. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *

+ */ + +package io.shardingsphere.core.executor.prepared; + +import io.shardingsphere.core.constant.SQLType; +import io.shardingsphere.core.executor.sql.SQLExecuteCallback; +import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; + +import java.sql.SQLException; +import java.util.Collection; +import java.util.List; + +/** + * Prepared statement executor for memory strictly mode. + * + * @author zhangliang + */ +public final class MemoryStrictlyPreparedStatementExecutor extends PreparedStatementExecutor { + + private final SQLExecuteTemplate executeTemplate; + + private final Collection preparedStatementUnits; + + public MemoryStrictlyPreparedStatementExecutor(final SQLType sqlType, final SQLExecuteTemplate executeTemplate, final Collection preparedStatementUnits) { + super(sqlType); + this.executeTemplate = executeTemplate; + this.preparedStatementUnits = preparedStatementUnits; + } + + @Override + protected List executeCallback(final SQLExecuteCallback executeCallback) throws SQLException { + return executeTemplate.execute(preparedStatementUnits, executeCallback); + } +} diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/PreparedStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/PreparedStatementExecutor.java index 5114a079c1bf8..fa9bb708d833c 100644 --- a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/PreparedStatementExecutor.java +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/prepared/PreparedStatementExecutor.java @@ -19,7 +19,6 @@ import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.executor.sql.SQLExecuteCallback; -import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; import io.shardingsphere.core.executor.sql.StatementExecuteUnit; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorDataMap; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorExceptionHandler; @@ -28,26 +27,21 @@ import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; -import java.util.Collection; import java.util.List; import java.util.Map; /** - * PreparedStatement Executor for multiple threads. + * Prepared statement executor. * * @author zhangliang * @author caohao * @author maxiaoguang */ @RequiredArgsConstructor -public final class PreparedStatementExecutor { - - private final SQLExecuteTemplate executeTemplate; +public abstract class PreparedStatementExecutor { private final SQLType sqlType; - private final Collection preparedStatementUnits; - /** * Execute query. * @@ -64,7 +58,7 @@ protected ResultSet executeSQL(final StatementExecuteUnit executeUnit) throws SQ return ((PreparedStatement) executeUnit.getStatement()).executeQuery(); } }; - return executeTemplate.execute(preparedStatementUnits, executeCallback); + return executeCallback(executeCallback); } /** @@ -83,7 +77,7 @@ protected Integer executeSQL(final StatementExecuteUnit executeUnit) throws SQLE return ((PreparedStatement) executeUnit.getStatement()).executeUpdate(); } }; - List results = executeTemplate.execute(preparedStatementUnits, executeCallback); + List results = executeCallback(executeCallback); return accumulate(results); } @@ -111,10 +105,12 @@ protected Boolean executeSQL(final StatementExecuteUnit executeUnit) throws SQLE return ((PreparedStatement) executeUnit.getStatement()).execute(); } }; - List result = executeTemplate.execute(preparedStatementUnits, executeCallback); + List result = executeCallback(executeCallback); if (null == result || result.isEmpty() || null == result.get(0)) { return false; } return result.get(0); } + + protected abstract List executeCallback(SQLExecuteCallback executeCallback) throws SQLException; } diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/ConnectionStrictlyStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/ConnectionStrictlyStatementExecutor.java new file mode 100644 index 0000000000000..3d265b864edd9 --- /dev/null +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/ConnectionStrictlyStatementExecutor.java @@ -0,0 +1,66 @@ +/* + * Copyright 2016-2018 shardingsphere.io. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *

+ */ + +package io.shardingsphere.core.executor.statement; + +import io.shardingsphere.core.constant.SQLType; +import io.shardingsphere.core.executor.sql.SQLExecuteCallback; +import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; +import io.shardingsphere.core.executor.sql.StatementExecuteUnit; + +import java.sql.SQLException; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * Statement executor for connection strictly mode. + * + * @author zhangliang + */ +public final class ConnectionStrictlyStatementExecutor extends StatementExecutor { + + private final SQLExecuteTemplate executeTemplate; + + private final Map>> statementUnitGroups; + + public ConnectionStrictlyStatementExecutor(final SQLType sqlType, final SQLExecuteTemplate executeTemplate, final Map>> statementUnitGroups) { + super(sqlType); + this.executeTemplate = executeTemplate; + this.statementUnitGroups = statementUnitGroups; + } + + @Override + protected List executeCallback(final SQLExecuteCallback executeCallback) throws SQLException { + return executeTemplate.execute(transformStatementUnitGroups(), executeCallback); + } + + private Map>> transformStatementUnitGroups() { + Map>> result = new HashMap<>(statementUnitGroups.size(), 1); + for (Map.Entry>> entry : statementUnitGroups.entrySet()) { + List> statementUnitGroups = entry.getValue(); + for (List each : statementUnitGroups) { + if (!result.containsKey(entry.getKey())) { + result.put(entry.getKey(), new LinkedList>()); + } + result.get(entry.getKey()).add(new LinkedList<>(each)); + } + } + return result; + } +} diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/MemoryStrictlyStatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/MemoryStrictlyStatementExecutor.java new file mode 100644 index 0000000000000..b53e8f59de405 --- /dev/null +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/MemoryStrictlyStatementExecutor.java @@ -0,0 +1,49 @@ +/* + * Copyright 2016-2018 shardingsphere.io. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *

+ */ + +package io.shardingsphere.core.executor.statement; + +import io.shardingsphere.core.constant.SQLType; +import io.shardingsphere.core.executor.sql.SQLExecuteCallback; +import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; + +import java.sql.SQLException; +import java.util.Collection; +import java.util.List; + +/** + * Statement executor for memory strictly mode. + * + * @author zhangliang + */ +public final class MemoryStrictlyStatementExecutor extends StatementExecutor { + + private final SQLExecuteTemplate executeTemplate; + + private final Collection statementUnits; + + public MemoryStrictlyStatementExecutor(final SQLType sqlType, final SQLExecuteTemplate executeTemplate, final Collection statementUnits) { + super(sqlType); + this.executeTemplate = executeTemplate; + this.statementUnits = statementUnits; + } + + @Override + protected List executeCallback(final SQLExecuteCallback executeCallback) throws SQLException { + return executeTemplate.execute(statementUnits, executeCallback); + } +} diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/StatementExecutor.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/StatementExecutor.java index 8809208865833..891a4d8c959d4 100644 --- a/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/StatementExecutor.java +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/executor/statement/StatementExecutor.java @@ -19,7 +19,6 @@ import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.executor.sql.SQLExecuteCallback; -import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; import io.shardingsphere.core.executor.sql.StatementExecuteUnit; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorDataMap; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorExceptionHandler; @@ -28,12 +27,11 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; -import java.util.Collection; import java.util.List; import java.util.Map; /** - * Statement Executor for multiple threads. + * Statement executor. * * @author gaohongtao * @author caohao @@ -41,14 +39,10 @@ * @author maxiaoguang */ @RequiredArgsConstructor -public final class StatementExecutor { - - private final SQLExecuteTemplate executeTemplate; +public abstract class StatementExecutor { private final SQLType sqlType; - private final Collection statementUnits; - /** * Execute query. * @@ -65,7 +59,7 @@ protected ResultSet executeSQL(final StatementExecuteUnit executeUnit) throws SQ return executeUnit.getStatement().executeQuery(executeUnit.getSqlExecutionUnit().getSqlUnit().getSql()); } }; - return executeTemplate.execute(statementUnits, executeCallback); + return executeCallback(executeCallback); } /** @@ -145,7 +139,7 @@ protected Integer executeSQL(final StatementExecuteUnit executeUnit) throws SQLE return updater.executeUpdate(executeUnit.getStatement(), executeUnit.getSqlExecutionUnit().getSqlUnit().getSql()); } }; - List results = executeTemplate.execute(statementUnits, executeCallback); + List results = executeCallback(executeCallback); return accumulate(results); } @@ -234,13 +228,15 @@ protected Boolean executeSQL(final StatementExecuteUnit executeUnit) throws SQLE return executor.execute(executeUnit.getStatement(), executeUnit.getSqlExecutionUnit().getSqlUnit().getSql()); } }; - List result = executeTemplate.execute(statementUnits, executeCallback); + List result = executeCallback(executeCallback); if (null == result || result.isEmpty() || null == result.get(0)) { return false; } return result.get(0); } + protected abstract List executeCallback(SQLExecuteCallback executeCallback) throws SQLException; + private interface Updater { int executeUpdate(Statement statement, String sql) throws SQLException; diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingPreparedStatement.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingPreparedStatement.java index 70572b10715d9..b3460c0e1b9e2 100644 --- a/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingPreparedStatement.java +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingPreparedStatement.java @@ -24,8 +24,11 @@ import io.shardingsphere.core.constant.ConnectionMode; import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.event.ShardingEventBusInstance; -import io.shardingsphere.core.executor.batch.BatchPreparedStatementExecutor; import io.shardingsphere.core.executor.batch.BatchPreparedStatementUnit; +import io.shardingsphere.core.executor.batch.ConnectionStrictlyBatchPreparedStatementExecutor; +import io.shardingsphere.core.executor.batch.MemoryStrictlyBatchPreparedStatementExecutor; +import io.shardingsphere.core.executor.prepared.ConnectionStrictlyPreparedStatementExecutor; +import io.shardingsphere.core.executor.prepared.MemoryStrictlyPreparedStatementExecutor; import io.shardingsphere.core.executor.prepared.PreparedStatementExecutor; import io.shardingsphere.core.executor.prepared.PreparedStatementUnit; import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; @@ -63,8 +66,10 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Map.Entry; import java.util.Objects; @@ -248,11 +253,11 @@ private void sqlRoute() { } private PreparedStatementExecutor getPreparedStatementExecutor() throws SQLException { - ConnectionMode connectionMode = connection.getShardingDataSource().getShardingContext().getConnectionMode(); - int maxConnectionsSizePerQuery = connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery(); - SQLExecuteTemplate sqlExecuteTemplate = new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine(), connectionMode, maxConnectionsSizePerQuery); - Collection executeUnits = ConnectionMode.MEMORY_STRICTLY == connectionMode ? getExecuteUnitsForMemoryStrictly() : getExecuteUnitsForConnectionStrictly(); - return new PreparedStatementExecutor(sqlExecuteTemplate, routeResult.getSqlStatement().getType(), executeUnits); + SQLExecuteTemplate sqlExecuteTemplate = new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine()); + if (ConnectionMode.MEMORY_STRICTLY == connection.getShardingDataSource().getShardingContext().getConnectionMode()) { + return new MemoryStrictlyPreparedStatementExecutor(routeResult.getSqlStatement().getType(), sqlExecuteTemplate, getExecuteUnitsForMemoryStrictly()); + } + return new ConnectionStrictlyPreparedStatementExecutor(routeResult.getSqlStatement().getType(), sqlExecuteTemplate, getExecuteUnitsForConnectionStrictly()); } private Collection getExecuteUnitsForMemoryStrictly() throws SQLException { @@ -263,15 +268,21 @@ private Collection getExecuteUnitsForMemoryStrictly() thr return result; } - private Collection getExecuteUnitsForConnectionStrictly() throws SQLException { - Collection result = new LinkedList<>(); - for (Entry> entry : routeResult.getSQLUnitGroups().entrySet()) { + private Map>> getExecuteUnitsForConnectionStrictly() throws SQLException { + Map> sqlUnitGroups = routeResult.getSQLUnitGroups(); + Map>> result = new HashMap<>(sqlUnitGroups.size(), 1); + for (Entry> entry : sqlUnitGroups.entrySet()) { String dataSourceName = entry.getKey(); for (List sqlUnitList : Lists.partition(new ArrayList<>(entry.getValue()), connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery())) { Connection connection = this.connection.getConnection(dataSourceName); + List preparedStatementUnits = new LinkedList<>(); for (SQLUnit each : sqlUnitList) { - result.add(getPreparedStatementUnit(connection, new SQLExecutionUnit(dataSourceName, each))); + preparedStatementUnits.add(getPreparedStatementUnit(connection, new SQLExecutionUnit(dataSourceName, each))); } + if (!result.containsKey(dataSourceName)) { + result.put(dataSourceName, new LinkedList>()); + } + result.get(dataSourceName).add(preparedStatementUnits); } } return result; @@ -309,15 +320,38 @@ private PreparedStatement createPreparedStatement(final Connection connection, f @Override public int[] executeBatch() throws SQLException { try { - ConnectionMode connectionMode = connection.getShardingDataSource().getShardingContext().getConnectionMode(); - int maxConnectionsSizePerQuery = connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery(); - return new BatchPreparedStatementExecutor(new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine(), connectionMode, maxConnectionsSizePerQuery), - connection.getShardingDataSource().getShardingContext().getDatabaseType(), routeResult.getSqlStatement().getType(), batchStatementUnits, batchCount).executeBatch(); + SQLExecuteTemplate sqlExecuteTemplate = new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine()); + if (ConnectionMode.MEMORY_STRICTLY == connection.getShardingDataSource().getShardingContext().getConnectionMode()) { + return new MemoryStrictlyBatchPreparedStatementExecutor(connection.getShardingDataSource().getShardingContext().getDatabaseType(), + routeResult.getSqlStatement().getType(), batchCount, sqlExecuteTemplate, batchStatementUnits).executeBatch(); + } + return new ConnectionStrictlyBatchPreparedStatementExecutor(connection.getShardingDataSource().getShardingContext().getDatabaseType(), + routeResult.getSqlStatement().getType(), batchCount, sqlExecuteTemplate, partitionBatchPreparedStatementUnitGroups()).executeBatch(); } finally { clearBatch(); } } + private Map>> partitionBatchPreparedStatementUnitGroups() { + Map>> result = new HashMap<>(batchStatementUnits.size(), 1); + for (Entry> entry : getBatchPreparedStatementUnitGroups().entrySet()) { + result.put(entry.getKey(), Lists.partition(entry.getValue(), connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery())); + } + return result; + } + + private Map> getBatchPreparedStatementUnitGroups() { + Map> result = new HashMap<>(batchStatementUnits.size(), 1); + for (BatchPreparedStatementUnit each : batchStatementUnits) { + String dataSourceName = each.getSqlExecutionUnit().getDataSource(); + if (!result.containsKey(dataSourceName)) { + result.put(dataSourceName, new LinkedList()); + } + result.get(dataSourceName).add(each); + } + return result; + } + @Override public ResultSet getGeneratedKeys() throws SQLException { Optional generatedKey = getGeneratedKey(); diff --git a/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingStatement.java b/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingStatement.java index 55fd500fbc902..3ec7ce4cb2cc4 100644 --- a/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingStatement.java +++ b/sharding-jdbc/src/main/java/io/shardingsphere/core/jdbc/core/statement/ShardingStatement.java @@ -25,6 +25,8 @@ import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; import io.shardingsphere.core.executor.sql.result.MemoryQueryResult; import io.shardingsphere.core.executor.sql.result.StreamQueryResult; +import io.shardingsphere.core.executor.statement.ConnectionStrictlyStatementExecutor; +import io.shardingsphere.core.executor.statement.MemoryStrictlyStatementExecutor; import io.shardingsphere.core.executor.statement.StatementExecutor; import io.shardingsphere.core.executor.statement.StatementUnit; import io.shardingsphere.core.jdbc.adapter.AbstractStatementAdapter; @@ -58,8 +60,10 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Map.Entry; /** @@ -240,11 +244,11 @@ public boolean execute(final String sql, final String[] columnNames) throws SQLE } private StatementExecutor getStatementExecutor() throws SQLException { - ConnectionMode connectionMode = connection.getShardingDataSource().getShardingContext().getConnectionMode(); - int maxConnectionsSizePerQuery = connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery(); - SQLExecuteTemplate sqlExecuteTemplate = new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine(), connectionMode, maxConnectionsSizePerQuery); - Collection executeUnits = ConnectionMode.MEMORY_STRICTLY == connectionMode ? getExecuteUnitsForMemoryStrictly() : getExecuteUnitsForConnectionStrictly(); - return new StatementExecutor(sqlExecuteTemplate, routeResult.getSqlStatement().getType(), executeUnits); + SQLExecuteTemplate sqlExecuteTemplate = new SQLExecuteTemplate(connection.getShardingDataSource().getShardingContext().getExecuteEngine()); + if (ConnectionMode.MEMORY_STRICTLY == connection.getShardingDataSource().getShardingContext().getConnectionMode()) { + return new MemoryStrictlyStatementExecutor(routeResult.getSqlStatement().getType(), sqlExecuteTemplate, getExecuteUnitsForMemoryStrictly()); + } + return new ConnectionStrictlyStatementExecutor(routeResult.getSqlStatement().getType(), sqlExecuteTemplate, getExecuteUnitsForConnectionStrictly()); } private Collection getExecuteUnitsForMemoryStrictly() throws SQLException { @@ -255,15 +259,21 @@ private Collection getExecuteUnitsForMemoryStrictly() throws SQLE return result; } - private Collection getExecuteUnitsForConnectionStrictly() throws SQLException { - Collection result = new LinkedList<>(); - for (Entry> entry : routeResult.getSQLUnitGroups().entrySet()) { + private Map>> getExecuteUnitsForConnectionStrictly() throws SQLException { + Map> sqlUnitGroups = routeResult.getSQLUnitGroups(); + Map>> result = new HashMap<>(sqlUnitGroups.size(), 1); + for (Entry> entry : sqlUnitGroups.entrySet()) { String dataSourceName = entry.getKey(); for (List sqlUnitList : Lists.partition(new ArrayList<>(entry.getValue()), connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery())) { Connection connection = this.connection.getConnection(dataSourceName); + List statementUnits = new LinkedList<>(); for (SQLUnit each : sqlUnitList) { - result.add(getStatementUnit(connection, new SQLExecutionUnit(dataSourceName, each))); + statementUnits.add(getStatementUnit(connection, new SQLExecutionUnit(dataSourceName, each))); + } + if (!result.containsKey(dataSourceName)) { + result.put(dataSourceName, new LinkedList>()); } + result.get(dataSourceName).add(statementUnits); } } return result; diff --git a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/AbstractBaseExecutorTest.java b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/AbstractBaseExecutorTest.java index fbd1088019fa1..ddd26bdc1b9ef 100644 --- a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/AbstractBaseExecutorTest.java +++ b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/AbstractBaseExecutorTest.java @@ -17,7 +17,6 @@ package io.shardingsphere.core.executor; -import io.shardingsphere.core.constant.ConnectionMode; import io.shardingsphere.core.event.ShardingEventBusInstance; import io.shardingsphere.core.executor.fixture.EventCaller; import io.shardingsphere.core.executor.fixture.ExecutorTestUtil; @@ -54,7 +53,7 @@ public void setUp() { MockitoAnnotations.initMocks(this); ExecutorExceptionHandler.setExceptionThrown(false); executeEngine = new ShardingExecuteEngine(Runtime.getRuntime().availableProcessors()); - executeTemplate = new SQLExecuteTemplate(executeEngine, ConnectionMode.MEMORY_STRICTLY, 1); + executeTemplate = new SQLExecuteTemplate(executeEngine); overallExecutionEventListener = new TestOverallExecutionEventListener(eventCaller); dqlExecutionEventListener = new TestDQLExecutionEventListener(eventCaller); dmlExecutionEventListener = new TestDMLExecutionEventListener(eventCaller); diff --git a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/BatchPreparedStatementExecutorTest.java b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/BatchPreparedStatementExecutorTest.java index c6dfd2c564767..fea2a1c3c07ef 100644 --- a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/BatchPreparedStatementExecutorTest.java +++ b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/BatchPreparedStatementExecutorTest.java @@ -22,6 +22,7 @@ import io.shardingsphere.core.event.ShardingEventType; import io.shardingsphere.core.executor.batch.BatchPreparedStatementExecutor; import io.shardingsphere.core.executor.batch.BatchPreparedStatementUnit; +import io.shardingsphere.core.executor.batch.MemoryStrictlyBatchPreparedStatementExecutor; import io.shardingsphere.core.rewrite.SQLBuilder; import io.shardingsphere.core.routing.SQLExecutionUnit; import org.junit.Test; @@ -47,8 +48,8 @@ public final class BatchPreparedStatementExecutorTest extends AbstractBaseExecut @SuppressWarnings("unchecked") @Test public void assertNoPreparedStatement() throws SQLException { - BatchPreparedStatementExecutor actual = new BatchPreparedStatementExecutor(getExecuteTemplate(), DatabaseType.MySQL, SQLType.DML, - Collections.emptyList(), 2); + BatchPreparedStatementExecutor actual = new MemoryStrictlyBatchPreparedStatementExecutor( + DatabaseType.MySQL, SQLType.DML, 2, getExecuteTemplate(), Collections.emptyList()); assertThat(actual.executeBatch(), is(new int[] {0, 0})); } @@ -57,8 +58,8 @@ public void assertExecuteBatchForSinglePreparedStatementSuccess() throws SQLExce PreparedStatement preparedStatement = mock(PreparedStatement.class); when(preparedStatement.executeBatch()).thenReturn(new int[] {10, 20}); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - BatchPreparedStatementExecutor actual = new BatchPreparedStatementExecutor(getExecuteTemplate(), DatabaseType.MySQL, SQLType.DML, - createPreparedStatementUnits(SQL, preparedStatement, "ds_0", 2), 2); + BatchPreparedStatementExecutor actual = new MemoryStrictlyBatchPreparedStatementExecutor( + DatabaseType.MySQL, SQLType.DML, 2, getExecuteTemplate(), createPreparedStatementUnits(SQL, preparedStatement, "ds_0", 2)); assertThat(actual.executeBatch(), is(new int[] {10, 20})); verify(preparedStatement).executeBatch(); verify(getEventCaller(), times(4)).verifyDataSource("ds_0"); @@ -78,8 +79,8 @@ public void assertExecuteBatchForMultiplePreparedStatementsSuccess() throws SQLE when(preparedStatement2.executeBatch()).thenReturn(new int[] {20, 40}); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - BatchPreparedStatementExecutor actual = new BatchPreparedStatementExecutor(getExecuteTemplate(), DatabaseType.MySQL, SQLType.DML, - createPreparedStatementUnits(SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1", 2), 2); + BatchPreparedStatementExecutor actual = new MemoryStrictlyBatchPreparedStatementExecutor(DatabaseType.MySQL, SQLType.DML, 2, getExecuteTemplate(), + createPreparedStatementUnits(SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1", 2)); assertThat(actual.executeBatch(), is(new int[] {30, 60})); verify(preparedStatement1).executeBatch(); verify(preparedStatement2).executeBatch(); @@ -99,8 +100,8 @@ public void assertExecuteBatchForSinglePreparedStatementFailure() throws SQLExce SQLException exp = new SQLException(); when(preparedStatement.executeBatch()).thenThrow(exp); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - BatchPreparedStatementExecutor actual = new BatchPreparedStatementExecutor(getExecuteTemplate(), DatabaseType.MySQL, SQLType.DML, - createPreparedStatementUnits(SQL, preparedStatement, "ds_0", 2), 2); + BatchPreparedStatementExecutor actual = new MemoryStrictlyBatchPreparedStatementExecutor(DatabaseType.MySQL, SQLType.DML, 2, getExecuteTemplate(), + createPreparedStatementUnits(SQL, preparedStatement, "ds_0", 2)); assertThat(actual.executeBatch(), is(new int[] {0, 0})); verify(preparedStatement).executeBatch(); verify(getEventCaller(), times(4)).verifyDataSource("ds_0"); @@ -121,8 +122,8 @@ public void assertExecuteBatchForMultiplePreparedStatementsFailure() throws SQLE when(preparedStatement2.executeBatch()).thenThrow(exp); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - BatchPreparedStatementExecutor actual = new BatchPreparedStatementExecutor(getExecuteTemplate(), DatabaseType.MySQL, SQLType.DML, - createPreparedStatementUnits(SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1", 2), 2); + BatchPreparedStatementExecutor actual = new MemoryStrictlyBatchPreparedStatementExecutor(DatabaseType.MySQL, SQLType.DML, 2, getExecuteTemplate(), + createPreparedStatementUnits(SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1", 2)); assertThat(actual.executeBatch(), is(new int[] {0, 0})); verify(preparedStatement1).executeBatch(); verify(preparedStatement2).executeBatch(); diff --git a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/PreparedStatementExecutorTest.java b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/PreparedStatementExecutorTest.java index 1c2578799f61f..edbf7c8a55610 100644 --- a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/PreparedStatementExecutorTest.java +++ b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/PreparedStatementExecutorTest.java @@ -19,6 +19,7 @@ import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.event.ShardingEventType; +import io.shardingsphere.core.executor.prepared.MemoryStrictlyPreparedStatementExecutor; import io.shardingsphere.core.executor.prepared.PreparedStatementExecutor; import io.shardingsphere.core.executor.prepared.PreparedStatementUnit; import io.shardingsphere.core.rewrite.SQLBuilder; @@ -54,7 +55,7 @@ public final class PreparedStatementExecutorTest extends AbstractBaseExecutorTes @SuppressWarnings("unchecked") @Test public void assertNoStatement() throws SQLException { - PreparedStatementExecutor actual = new PreparedStatementExecutor(getExecuteTemplate(), SQLType.DQL, Collections.emptyList()); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor(SQLType.DQL, getExecuteTemplate(), Collections.emptyList()); assertFalse(actual.execute()); assertThat(actual.executeUpdate(), is(0)); assertThat(actual.executeQuery().size(), is(0)); @@ -66,8 +67,8 @@ public void assertExecuteQueryForSinglePreparedStatementSuccess() throws SQLExce ResultSet resultSet = mock(ResultSet.class); when(preparedStatement.executeQuery()).thenReturn(resultSet); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DQL, createPreparedStatementUnits(DQL_SQL, preparedStatement, "ds_0")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DQL, getExecuteTemplate(), createPreparedStatementUnits(DQL_SQL, preparedStatement, "ds_0")); assertThat(actual.executeQuery(), is(Collections.singletonList(resultSet))); verify(preparedStatement).executeQuery(); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -88,8 +89,8 @@ public void assertExecuteQueryForMultiplePreparedStatementsSuccess() throws SQLE when(preparedStatement2.executeQuery()).thenReturn(resultSet2); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DQL, createPreparedStatementUnits(DQL_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DQL, getExecuteTemplate(), createPreparedStatementUnits(DQL_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); List actualResultSets = actual.executeQuery(); assertThat(actualResultSets, hasItem(resultSet1)); assertThat(actualResultSets, hasItem(resultSet2)); @@ -110,8 +111,8 @@ public void assertExecuteQueryForSinglePreparedStatementFailure() throws SQLExce SQLException exp = new SQLException(); when(preparedStatement.executeQuery()).thenThrow(exp); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DQL, createPreparedStatementUnits(DQL_SQL, preparedStatement, "ds_0")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DQL, getExecuteTemplate(), createPreparedStatementUnits(DQL_SQL, preparedStatement, "ds_0")); assertThat(actual.executeQuery(), is(Collections.singletonList((ResultSet) null))); verify(preparedStatement).executeQuery(); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -131,8 +132,8 @@ public void assertExecuteQueryForMultiplePreparedStatementsFailure() throws SQLE when(preparedStatement2.executeQuery()).thenThrow(exp); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DQL, createPreparedStatementUnits(DQL_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DQL, getExecuteTemplate(), createPreparedStatementUnits(DQL_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); List actualResultSets = actual.executeQuery(); assertThat(actualResultSets, is(Arrays.asList((ResultSet) null, null))); verify(preparedStatement1).executeQuery(); @@ -151,8 +152,8 @@ public void assertExecuteUpdateForSinglePreparedStatementSuccess() throws SQLExc PreparedStatement preparedStatement = mock(PreparedStatement.class); when(preparedStatement.executeUpdate()).thenReturn(10); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); assertThat(actual.executeUpdate(), is(10)); verify(preparedStatement).executeUpdate(); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -171,8 +172,8 @@ public void assertExecuteUpdateForMultiplePreparedStatementsSuccess() throws SQL when(preparedStatement2.executeUpdate()).thenReturn(20); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); assertThat(actual.executeUpdate(), is(30)); verify(preparedStatement1).executeUpdate(); verify(preparedStatement2).executeUpdate(); @@ -191,8 +192,8 @@ public void assertExecuteUpdateForSinglePreparedStatementFailure() throws SQLExc SQLException exp = new SQLException(); when(preparedStatement.executeUpdate()).thenThrow(exp); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); assertThat(actual.executeUpdate(), is(0)); verify(preparedStatement).executeUpdate(); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -212,8 +213,8 @@ public void assertExecuteUpdateForMultiplePreparedStatementsFailure() throws SQL when(preparedStatement2.executeUpdate()).thenThrow(exp); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); assertThat(actual.executeUpdate(), is(0)); verify(preparedStatement1).executeUpdate(); verify(preparedStatement2).executeUpdate(); @@ -231,8 +232,8 @@ public void assertExecuteForSinglePreparedStatementSuccessWithDML() throws SQLEx PreparedStatement preparedStatement = mock(PreparedStatement.class); when(preparedStatement.execute()).thenReturn(false); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); assertFalse(actual.execute()); verify(preparedStatement).execute(); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -251,8 +252,8 @@ public void assertExecuteForMultiplePreparedStatementsSuccessWithDML() throws SQ when(preparedStatement2.execute()).thenReturn(false); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); assertFalse(actual.execute()); verify(preparedStatement1).execute(); verify(preparedStatement2).execute(); @@ -271,8 +272,8 @@ public void assertExecuteForSinglePreparedStatementFailureWithDML() throws SQLEx SQLException exp = new SQLException(); when(preparedStatement.execute()).thenThrow(exp); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement, "ds_0")); assertFalse(actual.execute()); verify(preparedStatement).execute(); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -292,8 +293,8 @@ public void assertExecuteForMultiplePreparedStatementsFailureWithDML() throws SQ when(preparedStatement2.execute()).thenThrow(exp); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DML, createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DML, getExecuteTemplate(), createPreparedStatementUnits(DML_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); assertFalse(actual.execute()); verify(preparedStatement1).execute(); verify(preparedStatement2).execute(); @@ -311,8 +312,8 @@ public void assertExecuteForSinglePreparedStatementWithDQL() throws SQLException PreparedStatement preparedStatement = mock(PreparedStatement.class); when(preparedStatement.execute()).thenReturn(true); when(preparedStatement.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DQL, createPreparedStatementUnits(DQL_SQL, preparedStatement, "ds_0")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DQL, getExecuteTemplate(), createPreparedStatementUnits(DQL_SQL, preparedStatement, "ds_0")); assertTrue(actual.execute()); verify(preparedStatement).execute(); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -331,8 +332,8 @@ public void assertExecuteForMultiplePreparedStatements() throws SQLException { when(preparedStatement2.execute()).thenReturn(true); when(preparedStatement1.getConnection()).thenReturn(mock(Connection.class)); when(preparedStatement2.getConnection()).thenReturn(mock(Connection.class)); - PreparedStatementExecutor actual = new PreparedStatementExecutor( - getExecuteTemplate(), SQLType.DQL, createPreparedStatementUnits(DQL_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); + PreparedStatementExecutor actual = new MemoryStrictlyPreparedStatementExecutor( + SQLType.DQL, getExecuteTemplate(), createPreparedStatementUnits(DQL_SQL, preparedStatement1, "ds_0", preparedStatement2, "ds_1")); assertTrue(actual.execute()); verify(preparedStatement1).execute(); verify(preparedStatement2).execute(); diff --git a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/StatementExecutorTest.java b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/StatementExecutorTest.java index 6a6abeb64fcf3..84116d0a65085 100644 --- a/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/StatementExecutorTest.java +++ b/sharding-jdbc/src/test/java/io/shardingsphere/core/executor/StatementExecutorTest.java @@ -20,6 +20,7 @@ import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.event.ShardingEventType; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorExceptionHandler; +import io.shardingsphere.core.executor.statement.MemoryStrictlyStatementExecutor; import io.shardingsphere.core.executor.statement.StatementExecutor; import io.shardingsphere.core.executor.statement.StatementUnit; import io.shardingsphere.core.rewrite.SQLBuilder; @@ -54,7 +55,7 @@ public final class StatementExecutorTest extends AbstractBaseExecutorTest { @Test public void assertNoStatement() throws SQLException { - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DQL, Collections.emptyList()); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DQL, getExecuteTemplate(), Collections.emptyList()); assertFalse(actual.execute()); assertThat(actual.executeUpdate(), is(0)); assertThat(actual.executeQuery().size(), is(0)); @@ -66,7 +67,7 @@ public void assertExecuteQueryForSingleStatementSuccess() throws SQLException { ResultSet resultSet = mock(ResultSet.class); when(statement.executeQuery(DQL_SQL)).thenReturn(resultSet); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DQL, createStatementUnits(DQL_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DQL, getExecuteTemplate(), createStatementUnits(DQL_SQL, statement, "ds_0")); assertThat(actual.executeQuery(), is(Collections.singletonList(resultSet))); verify(statement).executeQuery(DQL_SQL); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -87,7 +88,7 @@ public void assertExecuteQueryForMultipleStatementsSuccess() throws SQLException when(statement1.getConnection()).thenReturn(mock(Connection.class)); when(statement2.executeQuery(DQL_SQL)).thenReturn(resultSet2); when(statement2.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DQL, createStatementUnits(DQL_SQL, statement1, "ds_0", statement2, "ds_1")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DQL, getExecuteTemplate(), createStatementUnits(DQL_SQL, statement1, "ds_0", statement2, "ds_1")); List actualResultSets = actual.executeQuery(); assertThat(actualResultSets, hasItem(resultSet1)); assertThat(actualResultSets, hasItem(resultSet2)); @@ -108,7 +109,7 @@ public void assertExecuteQueryForSingleStatementFailure() throws SQLException { SQLException exp = new SQLException(); when(statement.executeQuery(DQL_SQL)).thenThrow(exp); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DQL, createStatementUnits(DQL_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DQL, getExecuteTemplate(), createStatementUnits(DQL_SQL, statement, "ds_0")); assertThat(actual.executeQuery(), is(Collections.singletonList((ResultSet) null))); verify(statement).executeQuery(DQL_SQL); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -128,7 +129,7 @@ public void assertExecuteQueryForMultipleStatementsFailure() throws SQLException when(statement2.executeQuery(DQL_SQL)).thenThrow(exp); when(statement1.getConnection()).thenReturn(mock(Connection.class)); when(statement2.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DQL, createStatementUnits(DQL_SQL, statement1, "ds_0", statement2, "ds_1")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DQL, getExecuteTemplate(), createStatementUnits(DQL_SQL, statement1, "ds_0", statement2, "ds_1")); List actualResultSets = actual.executeQuery(); assertThat(actualResultSets, is(Arrays.asList((ResultSet) null, null))); verify(statement1).executeQuery(DQL_SQL); @@ -147,7 +148,7 @@ public void assertExecuteUpdateForSingleStatementSuccess() throws SQLException { Statement statement = mock(Statement.class); when(statement.executeUpdate(DML_SQL)).thenReturn(10); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertThat(actual.executeUpdate(), is(10)); verify(statement).executeUpdate(DML_SQL); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -166,7 +167,7 @@ public void assertExecuteUpdateForMultipleStatementsSuccess() throws SQLExceptio when(statement2.executeUpdate(DML_SQL)).thenReturn(20); when(statement1.getConnection()).thenReturn(mock(Connection.class)); when(statement2.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); assertThat(actual.executeUpdate(), is(30)); verify(statement1).executeUpdate(DML_SQL); verify(statement2).executeUpdate(DML_SQL); @@ -185,7 +186,7 @@ public void assertExecuteUpdateForSingleStatementFailure() throws SQLException { SQLException exp = new SQLException(); when(statement.executeUpdate(DML_SQL)).thenThrow(exp); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertThat(actual.executeUpdate(), is(0)); verify(statement).executeUpdate(DML_SQL); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -205,7 +206,7 @@ public void assertExecuteUpdateForMultipleStatementsFailure() throws SQLExceptio when(statement2.executeUpdate(DML_SQL)).thenThrow(exp); when(statement1.getConnection()).thenReturn(mock(Connection.class)); when(statement2.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); assertThat(actual.executeUpdate(), is(0)); verify(statement1).executeUpdate(DML_SQL); verify(statement2).executeUpdate(DML_SQL); @@ -223,7 +224,7 @@ public void assertExecuteUpdateWithAutoGeneratedKeys() throws SQLException { Statement statement = mock(Statement.class); when(statement.executeUpdate(DML_SQL, Statement.NO_GENERATED_KEYS)).thenReturn(10); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertThat(actual.executeUpdate(Statement.NO_GENERATED_KEYS), is(10)); verify(statement).executeUpdate(DML_SQL, Statement.NO_GENERATED_KEYS); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -239,7 +240,7 @@ public void assertExecuteUpdateWithColumnIndexes() throws SQLException { Statement statement = mock(Statement.class); when(statement.executeUpdate(DML_SQL, new int[] {1})).thenReturn(10); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertThat(actual.executeUpdate(new int[] {1}), is(10)); verify(statement).executeUpdate(DML_SQL, new int[] {1}); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -255,7 +256,7 @@ public void assertExecuteUpdateWithColumnNames() throws SQLException { Statement statement = mock(Statement.class); when(statement.executeUpdate(DML_SQL, new String[] {"col"})).thenReturn(10); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertThat(actual.executeUpdate(new String[] {"col"}), is(10)); verify(statement).executeUpdate(DML_SQL, new String[] {"col"}); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -271,7 +272,7 @@ public void assertExecuteForSingleStatementSuccessWithDML() throws SQLException Statement statement = mock(Statement.class); when(statement.execute(DML_SQL)).thenReturn(false); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertFalse(actual.execute()); verify(statement).execute(DML_SQL); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -290,7 +291,7 @@ public void assertExecuteForMultipleStatementsSuccessWithDML() throws SQLExcepti when(statement2.execute(DML_SQL)).thenReturn(false); when(statement1.getConnection()).thenReturn(mock(Connection.class)); when(statement2.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); assertFalse(actual.execute()); verify(statement1).execute(DML_SQL); verify(statement2).execute(DML_SQL); @@ -309,7 +310,7 @@ public void assertExecuteForSingleStatementFailureWithDML() throws SQLException SQLException exp = new SQLException(); when(statement.execute(DML_SQL)).thenThrow(exp); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertFalse(actual.execute()); verify(statement).execute(DML_SQL); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -329,7 +330,7 @@ public void assertExecuteForMultipleStatementsFailureWithDML() throws SQLExcepti when(statement2.execute(DML_SQL)).thenThrow(exp); when(statement1.getConnection()).thenReturn(mock(Connection.class)); when(statement2.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement1, "ds_0", statement2, "ds_1")); assertFalse(actual.execute()); verify(statement1).execute(DML_SQL); verify(statement2).execute(DML_SQL); @@ -347,7 +348,7 @@ public void assertExecuteForSingleStatementWithDQL() throws SQLException { Statement statement = mock(Statement.class); when(statement.execute(DQL_SQL)).thenReturn(true); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DQL, createStatementUnits(DQL_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DQL, getExecuteTemplate(), createStatementUnits(DQL_SQL, statement, "ds_0")); assertTrue(actual.execute()); verify(statement).execute(DQL_SQL); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -366,7 +367,7 @@ public void assertExecuteForMultipleStatements() throws SQLException { when(statement2.execute(DQL_SQL)).thenReturn(true); when(statement1.getConnection()).thenReturn(mock(Connection.class)); when(statement2.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DQL, createStatementUnits(DQL_SQL, statement1, "ds_0", statement2, "ds_1")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DQL, getExecuteTemplate(), createStatementUnits(DQL_SQL, statement1, "ds_0", statement2, "ds_1")); assertTrue(actual.execute()); verify(statement1).execute(DQL_SQL); verify(statement2).execute(DQL_SQL); @@ -384,7 +385,7 @@ public void assertExecuteWithAutoGeneratedKeys() throws SQLException { Statement statement = mock(Statement.class); when(statement.execute(DML_SQL, Statement.NO_GENERATED_KEYS)).thenReturn(false); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertFalse(actual.execute(Statement.NO_GENERATED_KEYS)); verify(statement).execute(DML_SQL, Statement.NO_GENERATED_KEYS); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -400,7 +401,7 @@ public void assertExecuteWithColumnIndexes() throws SQLException { Statement statement = mock(Statement.class); when(statement.execute(DML_SQL, new int[] {1})).thenReturn(false); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertFalse(actual.execute(new int[] {1})); verify(statement).execute(DML_SQL, new int[] {1}); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -416,7 +417,7 @@ public void assertExecuteWithColumnNames() throws SQLException { Statement statement = mock(Statement.class); when(statement.execute(DML_SQL, new String[] {"col"})).thenReturn(false); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); assertFalse(actual.execute(new String[] {"col"})); verify(statement).execute(DML_SQL, new String[] {"col"}); verify(getEventCaller(), times(2)).verifyDataSource("ds_0"); @@ -434,7 +435,7 @@ public void assertOverallExceptionFailure() throws SQLException { SQLException exp = new SQLException(); when(statement.execute(DML_SQL)).thenThrow(exp); when(statement.getConnection()).thenReturn(mock(Connection.class)); - StatementExecutor actual = new StatementExecutor(getExecuteTemplate(), SQLType.DML, createStatementUnits(DML_SQL, statement, "ds_0")); + StatementExecutor actual = new MemoryStrictlyStatementExecutor(SQLType.DML, getExecuteTemplate(), createStatementUnits(DML_SQL, statement, "ds_0")); try { assertFalse(actual.execute()); } catch (final SQLException ignore) { diff --git a/sharding-opentracing/src/test/java/io/shardingsphere/opentracing/listener/execution/ExecuteEventListenerTest.java b/sharding-opentracing/src/test/java/io/shardingsphere/opentracing/listener/execution/ExecuteEventListenerTest.java index c14da78afaebd..e9f21f0a2b034 100644 --- a/sharding-opentracing/src/test/java/io/shardingsphere/opentracing/listener/execution/ExecuteEventListenerTest.java +++ b/sharding-opentracing/src/test/java/io/shardingsphere/opentracing/listener/execution/ExecuteEventListenerTest.java @@ -17,15 +17,14 @@ package io.shardingsphere.opentracing.listener.execution; -import io.shardingsphere.core.constant.ConnectionMode; import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.executor.ShardingExecuteEngine; +import io.shardingsphere.core.executor.batch.BatchPreparedStatementUnit; import io.shardingsphere.core.executor.sql.SQLExecuteCallback; import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; import io.shardingsphere.core.executor.sql.StatementExecuteUnit; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorDataMap; import io.shardingsphere.core.executor.sql.threadlocal.ExecutorExceptionHandler; -import io.shardingsphere.core.executor.batch.BatchPreparedStatementUnit; import io.shardingsphere.core.executor.statement.StatementUnit; import io.shardingsphere.core.routing.SQLExecutionUnit; import io.shardingsphere.core.routing.SQLUnit; @@ -52,7 +51,7 @@ public final class ExecuteEventListenerTest extends BaseEventListenerTest { private ShardingExecuteEngine executeEngine = new ShardingExecuteEngine(5); - private final SQLExecuteTemplate sqlExecuteTemplate = new SQLExecuteTemplate(executeEngine, ConnectionMode.MEMORY_STRICTLY, 1); + private final SQLExecuteTemplate sqlExecuteTemplate = new SQLExecuteTemplate(executeEngine); @After public void tearDown() { diff --git a/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/memory/ConnectionStrictlyExecuteEngine.java b/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/memory/ConnectionStrictlyExecuteEngine.java index 216b5541d1e5c..9584e2dbb37f7 100644 --- a/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/memory/ConnectionStrictlyExecuteEngine.java +++ b/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/memory/ConnectionStrictlyExecuteEngine.java @@ -18,7 +18,6 @@ package io.shardingsphere.proxy.backend.jdbc.execute.memory; import com.google.common.collect.Lists; -import io.shardingsphere.core.constant.ConnectionMode; import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.executor.sql.SQLExecuteCallback; import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; @@ -49,6 +48,7 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -66,7 +66,7 @@ public final class ConnectionStrictlyExecuteEngine extends JDBCExecuteEngine { public ConnectionStrictlyExecuteEngine(final BackendConnection backendConnection, final JDBCExecutorWrapper jdbcExecutorWrapper) { super(backendConnection, jdbcExecutorWrapper); - sqlExecuteTemplate = new SQLExecuteTemplate(BackendExecutorContext.getInstance().getExecuteEngine(), ConnectionMode.CONNECTION_STRICTLY, 1); + sqlExecuteTemplate = new SQLExecuteTemplate(BackendExecutorContext.getInstance().getExecuteEngine()); } @Override @@ -75,7 +75,7 @@ public ExecuteResponse execute(final SQLRouteResult routeResult) throws SQLExcep SQLType sqlType = routeResult.getSqlStatement().getType(); boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown(); Map dataMap = ExecutorDataMap.getDataMap(); - Collection executeResponseUnits = sqlExecuteTemplate.execute(getStatementExecuteUnits(routeResult, isReturnGeneratedKeys), + Collection executeResponseUnits = sqlExecuteTemplate.execute(partitionStatementExecuteUnits(routeResult, isReturnGeneratedKeys), new FirstConnectionStrictlySQLExecuteCallback(sqlType, isExceptionThrown, dataMap, isReturnGeneratedKeys), new ConnectionStrictlySQLExecuteCallback(sqlType, isExceptionThrown, dataMap, isReturnGeneratedKeys)); ExecuteResponseUnit firstExecuteResponseUnit = executeResponseUnits.iterator().next(); @@ -83,16 +83,26 @@ public ExecuteResponse execute(final SQLRouteResult routeResult) throws SQLExcep ? getExecuteQueryResponse(((ExecuteQueryResponseUnit) firstExecuteResponseUnit).getQueryResponsePackets(), executeResponseUnits) : new ExecuteUpdateResponse(executeResponseUnits); } - private Collection getStatementExecuteUnits(final SQLRouteResult routeResult, final boolean isReturnGeneratedKeys) throws SQLException { - Collection result = new LinkedList<>(); - for (Entry> entry : routeResult.getSQLUnitGroups().entrySet()) { - result.addAll(getStatementExecuteUnits(entry.getKey(), entry.getValue(), isReturnGeneratedKeys)); + private Map>> partitionStatementExecuteUnits(final SQLRouteResult routeResult, final boolean isReturnGeneratedKeys) throws SQLException { + Map> statementExecuteUnits = getStatementExecuteUnits(routeResult, isReturnGeneratedKeys); + Map>> result = new HashMap<>(statementExecuteUnits.size(), 1); + for (Entry> entry : statementExecuteUnits.entrySet()) { + result.put(entry.getKey(), Lists.>newArrayList(Lists.partition(entry.getValue(), RuleRegistry.getInstance().getMaxConnectionsSizePerQuery()))); } return result; } - private Collection getStatementExecuteUnits(final String dataSourceName, final Collection sqlUnits, final boolean isReturnGeneratedKeys) throws SQLException { - Collection result = new LinkedList<>(); + private Map> getStatementExecuteUnits(final SQLRouteResult routeResult, final boolean isReturnGeneratedKeys) throws SQLException { + Map> sqlUnitGroups = routeResult.getSQLUnitGroups(); + Map> result = new HashMap<>(sqlUnitGroups.size(), 1); + for (Entry> entry : sqlUnitGroups.entrySet()) { + result.put(entry.getKey(), getStatementExecuteUnits(entry.getKey(), entry.getValue(), isReturnGeneratedKeys)); + } + return result; + } + + private List getStatementExecuteUnits(final String dataSourceName, final Collection sqlUnits, final boolean isReturnGeneratedKeys) throws SQLException { + List result = new LinkedList<>(); for (List sqlUnitList : Lists.partition(new ArrayList<>(sqlUnits), RuleRegistry.getInstance().getMaxConnectionsSizePerQuery())) { Connection connection = getBackendConnection().getConnection(dataSourceName); for (SQLUnit each : sqlUnitList) { diff --git a/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/stream/MemoryStrictlyExecuteEngine.java b/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/stream/MemoryStrictlyExecuteEngine.java index 11da7c023f77a..3f4f03cfc9329 100644 --- a/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/stream/MemoryStrictlyExecuteEngine.java +++ b/sharding-proxy/src/main/java/io/shardingsphere/proxy/backend/jdbc/execute/stream/MemoryStrictlyExecuteEngine.java @@ -17,7 +17,6 @@ package io.shardingsphere.proxy.backend.jdbc.execute.stream; -import io.shardingsphere.core.constant.ConnectionMode; import io.shardingsphere.core.constant.SQLType; import io.shardingsphere.core.executor.sql.SQLExecuteCallback; import io.shardingsphere.core.executor.sql.SQLExecuteTemplate; @@ -65,7 +64,7 @@ public final class MemoryStrictlyExecuteEngine extends JDBCExecuteEngine { public MemoryStrictlyExecuteEngine(final BackendConnection backendConnection, final JDBCExecutorWrapper jdbcExecutorWrapper) { super(backendConnection, jdbcExecutorWrapper); - sqlExecuteTemplate = new SQLExecuteTemplate(BackendExecutorContext.getInstance().getExecuteEngine(), ConnectionMode.MEMORY_STRICTLY, 1); + sqlExecuteTemplate = new SQLExecuteTemplate(BackendExecutorContext.getInstance().getExecuteEngine()); } @Override