Skip to content

Commit

Permalink
Merge branch 'dev' of ssh://github.com/shardingjdbc/sharding-jdbc int…
Browse files Browse the repository at this point in the history
…o dev
  • Loading branch information
tristaZero committed Aug 31, 2018
2 parents 45f0fa0 + 92ad5e1 commit 1fe6104
Show file tree
Hide file tree
Showing 23 changed files with 648 additions and 185 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -126,64 +125,64 @@ private <O> List<O> getResults(final O firstResult, final Collection<ListenableF
* Execute all callbacks for group.
*
* @param inputs input value's map, grouped by key with inputs already partitioned into execution groups
* @param callback sharding execute callback
* @param <I> type of input value
* @param <O> type of return value
* @return execute result
* @throws SQLException throw if execute failure
*/
public <I, O> List<O> groupExecute(final Map<String, List<List<I>>> inputs, final ShardingGroupExecuteCallback<I, O> callback) throws SQLException {
    // Delegate with no dedicated callback for the first input group.
    return groupExecute(inputs, null, callback);
}

/**
* execute all callbacks for group.
*
* @param inputs input value's map
* @param maxThreadCountPerGroup max thread count for every group
* @param callback sharding execute callback
* @param firstCallback first sharding execute callback
* @param <I> type of input value
* @param <O> type of return value
* @return execute result
* @throws SQLException throw if execute failure
*/
public <I, O> List<O> groupExecute(final Map<String, Collection<I>> inputs, final int maxThreadCountPerGroup,
final ShardingGroupExecuteCallback<I, O> firstCallback, final ShardingGroupExecuteCallback<I, O> callback) throws SQLException {
public <I, O> List<O> groupExecute(
final Map<String, List<List<I>>> inputs, final ShardingGroupExecuteCallback<I, O> firstCallback, final ShardingGroupExecuteCallback<I, O> callback) throws SQLException {
if (inputs.isEmpty()) {
return Collections.emptyList();
}
Map<String, List<List<I>>> executionUnits = new HashMap<>(inputs.size(), 1);
for (Entry<String, Collection<I>> entry : inputs.entrySet()) {
executionUnits.put(entry.getKey(), Lists.partition(new ArrayList<>(entry.getValue()), maxThreadCountPerGroup));
}
String firstKey = executionUnits.keySet().iterator().next();
Iterator<List<I>> firstExecutionUnits = executionUnits.get(firstKey).iterator();
Collection<I> firstInputs = firstExecutionUnits.next();
executionUnits.put(firstKey, Lists.newArrayList(firstExecutionUnits));
Collection<ListenableFuture<Collection<O>>> restResultFutures = asyncGroupExecute(executionUnits, callback);
String firstKey = inputs.keySet().iterator().next();
Iterator<List<I>> firstInputGroup = inputs.get(firstKey).iterator();
Collection<I> firstInputs = firstInputGroup.next();
inputs.put(firstKey, Lists.newArrayList(firstInputGroup));
Collection<ListenableFuture<Collection<O>>> restResultFutures = asyncGroupExecute(inputs, callback);
return getGroupResults(syncGroupExecute(firstKey, firstInputs, null == firstCallback ? callback : firstCallback), restResultFutures);
}

private <I, O> Collection<ListenableFuture<Collection<O>>> asyncGroupExecute(final Map<String, List<List<I>>> inputs, final ShardingGroupExecuteCallback<I, O> callback) {
Collection<ListenableFuture<Collection<O>>> result = new ArrayList<>(inputs.size());
for (final Entry<String, List<List<I>>> entry : inputs.entrySet()) {
for (final List<I> each : entry.getValue()) {
result.add(executorService.submit(new Callable<Collection<O>>() {

@Override
public Collection<O> call() throws SQLException {
return callback.execute(entry.getKey(), each);
}
}));
}
Collection<ListenableFuture<Collection<O>>> result = new LinkedList<>();
for (Entry<String, List<List<I>>> entry : inputs.entrySet()) {
result.addAll(asyncGroupExecute(entry.getKey(), entry.getValue(), callback));
}
return result;
}

// Submits each input group for one key to the executor service and collects the resulting futures.
private <I, O> Collection<ListenableFuture<Collection<O>>> asyncGroupExecute(final String key, final List<List<I>> inputs, final ShardingGroupExecuteCallback<I, O> callback) {
    Collection<ListenableFuture<Collection<O>>> futures = new LinkedList<>();
    for (final List<I> inputGroup : inputs) {
        // Each group becomes one task; the callback receives the key plus that group's inputs.
        ListenableFuture<Collection<O>> future = executorService.submit(new Callable<Collection<O>>() {
            
            @Override
            public Collection<O> call() throws SQLException {
                return callback.execute(key, inputGroup);
            }
        });
        futures.add(future);
    }
    return futures;
}

private <I, O> Collection<O> syncGroupExecute(final String dataSourceName, final Collection<I> inputs, final ShardingGroupExecuteCallback<I, O> callback) throws SQLException {
return callback.execute(dataSourceName, inputs);
private <I, O> Collection<O> syncGroupExecute(final String key, final Collection<I> inputs, final ShardingGroupExecuteCallback<I, O> callback) throws SQLException {
return callback.execute(key, inputs);
}

private <O> List<O> getGroupResults(final Collection<O> firstResults, final Collection<ListenableFuture<Collection<O>>> restFutures) throws SQLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package io.shardingsphere.core.executor.sql;

import io.shardingsphere.core.constant.ConnectionMode;
import io.shardingsphere.core.event.ShardingEventBusInstance;
import io.shardingsphere.core.executor.ShardingExecuteEngine;
import io.shardingsphere.core.executor.sql.event.overall.OverallExecutionEvent;
Expand All @@ -27,10 +26,11 @@
import java.sql.SQLException;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
* SQL execute template.
Expand All @@ -45,9 +45,47 @@ public final class SQLExecuteTemplate {

private final ShardingExecuteEngine executeEngine;

private final ConnectionMode connectionMode;
/**
 * Execute statement execute units, with no dedicated callback for the first unit.
 *
 * @param executeUnits execute units
 * @param executeCallback execute callback applied to every unit
 * @param <T> class type of return value
 * @return execute result
 * @throws SQLException SQL exception
 */
public <T> List<T> execute(final Collection<? extends StatementExecuteUnit> executeUnits, final SQLExecuteCallback<T> executeCallback) throws SQLException {
    return execute(executeUnits, null, executeCallback);
}

private final int maxConnectionsSizePerQuery;
/**
 * Execute statement execute units, posting an overall execution event before and after the run.
 *
 * @param executeUnits execute units
 * @param firstExecuteCallback first execute callback
 * @param executeCallback execute callback
 * @param <T> class type of return value
 * @return execute result
 * @throws SQLException SQL exception
 */
public <T> List<T> execute(final Collection<? extends StatementExecuteUnit> executeUnits,
                           final SQLExecuteCallback<T> firstExecuteCallback, final SQLExecuteCallback<T> executeCallback) throws SQLException {
    // The event flags a parallel run when more than one unit is involved.
    OverallExecutionEvent executionEvent = new OverallExecutionEvent(1 < executeUnits.size());
    ShardingEventBusInstance.getInstance().post(executionEvent);
    try {
        List<T> executeResults = executeEngine.execute(new LinkedList<>(executeUnits), firstExecuteCallback, executeCallback);
        executionEvent.setExecuteSuccess();
        return executeResults;
        // CHECKSTYLE:OFF
    } catch (final Exception ex) {
        // CHECKSTYLE:ON
        // Record the failure on the event, then let the handler decide whether to rethrow.
        executionEvent.setExecuteFailure(ex);
        ExecutorExceptionHandler.handleException(ex);
        return Collections.emptyList();
    } finally {
        // The same event instance is re-posted to signal completion, success or failure.
        ShardingEventBusInstance.getInstance().post(executionEvent);
    }
}

/**
* Execute.
Expand All @@ -58,7 +96,7 @@ public final class SQLExecuteTemplate {
* @return execute result
* @throws SQLException SQL exception
*/
public <T> List<T> execute(final Collection<? extends StatementExecuteUnit> executeUnits, final SQLExecuteCallback<T> executeCallback) throws SQLException {
public <T> List<T> execute(final Map<String, List<List<? extends StatementExecuteUnit>>> executeUnits, final SQLExecuteCallback<T> executeCallback) throws SQLException {
return execute(executeUnits, null, executeCallback);
}

Expand All @@ -72,13 +110,12 @@ public <T> List<T> execute(final Collection<? extends StatementExecuteUnit> exec
* @return execute result
* @throws SQLException SQL exception
*/
public <T> List<T> execute(
final Collection<? extends StatementExecuteUnit> executeUnits, final SQLExecuteCallback<T> firstExecuteCallback, final SQLExecuteCallback<T> executeCallback) throws SQLException {
public <T> List<T> execute(final Map<String, List<List<? extends StatementExecuteUnit>>> executeUnits,
final SQLExecuteCallback<T> firstExecuteCallback, final SQLExecuteCallback<T> executeCallback) throws SQLException {
OverallExecutionEvent event = new OverallExecutionEvent(executeUnits.size() > 1);
ShardingEventBusInstance.getInstance().post(event);
try {
List<T> result = ConnectionMode.MEMORY_STRICTLY == connectionMode ? executeEngine.execute(new LinkedList<>(executeUnits), firstExecuteCallback, executeCallback)
: executeEngine.groupExecute(getExecuteUnitGroups(executeUnits), maxConnectionsSizePerQuery, firstExecuteCallback, executeCallback);
List<T> result = executeEngine.groupExecute(transform(executeUnits), firstExecuteCallback, executeCallback);
event.setExecuteSuccess();
return result;
// CHECKSTYLE:OFF
Expand All @@ -92,14 +129,15 @@ public <T> List<T> execute(
}
}

private Map<String, Collection<StatementExecuteUnit>> getExecuteUnitGroups(final Collection<? extends StatementExecuteUnit> executeUnits) {
Map<String, Collection<StatementExecuteUnit>> result = new LinkedHashMap<>(executeUnits.size(), 1);
for (StatementExecuteUnit each : executeUnits) {
String dataSourceName = each.getSqlExecutionUnit().getDataSource();
if (!result.keySet().contains(dataSourceName)) {
result.put(dataSourceName, new LinkedList<StatementExecuteUnit>());
// Copies the wildcard-typed unit groups into concretely typed lists so they can be handed
// to ShardingExecuteEngine.groupExecute.
private Map<String, List<List<StatementExecuteUnit>>> transform(final Map<String, List<List<? extends StatementExecuteUnit>>> executeUnits) {
    Map<String, List<List<StatementExecuteUnit>>> result = new HashMap<>(executeUnits.size());
    for (Entry<String, List<List<? extends StatementExecuteUnit>>> entry : executeUnits.entrySet()) {
        // Each key of an entry set is visited exactly once, so no containsKey guard is needed.
        List<List<StatementExecuteUnit>> groups = new LinkedList<>();
        for (List<? extends StatementExecuteUnit> each : entry.getValue()) {
            groups.add(new LinkedList<StatementExecuteUnit>(each));
        }
        result.put(entry.getKey(), groups);
    }
    return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package io.shardingsphere.core.metadata.table.executor;

import com.google.common.collect.Lists;
import io.shardingsphere.core.exception.ShardingException;
import io.shardingsphere.core.executor.ShardingExecuteEngine;
import io.shardingsphere.core.executor.ShardingGroupExecuteCallback;
Expand All @@ -33,10 +34,12 @@
import java.sql.SQLException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
* Table meta data loader.
Expand Down Expand Up @@ -68,13 +71,13 @@ public TableMetaData load(final String logicTableName, final ShardingRule shardi
return actualTableMetaDataList.iterator().next();
}

private List<TableMetaData> load(final Map<String, Collection<String>> dataNodeGroups, final ShardingDataSourceNames shardingDataSourceNames) throws SQLException {
return executeEngine.groupExecute(dataNodeGroups, maxConnectionsSizePerQuery, new ShardingGroupExecuteCallback<String, TableMetaData>() {
private List<TableMetaData> load(final Map<String, List<String>> dataNodeGroups, final ShardingDataSourceNames shardingDataSourceNames) throws SQLException {
return executeEngine.groupExecute(partitionDataNodeGroups(dataNodeGroups), new ShardingGroupExecuteCallback<String, TableMetaData>() {

@Override
public Collection<TableMetaData> execute(final String dataSourceName, final Collection<String> actualTableNames) throws SQLException {
DataSourceMetaData dataSourceMetaData = shardingDataSourceMetaData.getActualDataSourceMetaData(dataSourceName);
final String catalog = null == dataSourceMetaData ? null : dataSourceMetaData.getSchemeName();
String catalog = null == dataSourceMetaData ? null : dataSourceMetaData.getSchemeName();
return load(shardingDataSourceNames.getRawMasterDataSourceName(dataSourceName), catalog, actualTableNames);
}
});
Expand All @@ -90,6 +93,14 @@ private Collection<TableMetaData> load(final String dataSourceName, final String
return result;
}

// Splits each data source's table list into sublists of at most maxConnectionsSizePerQuery tables,
// so one connection handles one sublist.
private Map<String, List<List<String>>> partitionDataNodeGroups(final Map<String, List<String>> dataNodeGroups) {
    Map<String, List<List<String>>> result = new HashMap<>(dataNodeGroups.size(), 1);
    for (Entry<String, List<String>> each : dataNodeGroups.entrySet()) {
        List<List<String>> partitions = Lists.partition(each.getValue(), maxConnectionsSizePerQuery);
        result.put(each.getKey(), partitions);
    }
    return result;
}

private boolean isTableExist(final Connection connection, final String catalog, final String actualTableName) throws SQLException {
try (ResultSet resultSet = connection.getMetaData().getTables(catalog, null, actualTableName, null)) {
return resultSet.next();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -55,8 +55,8 @@ public SQLRouteResult(final SQLStatement sqlStatement) {
*
* @return SQL units grouped by data source name.
*/
public Map<String, Collection<SQLUnit>> getSQLUnitGroups() {
Map<String, Collection<SQLUnit>> result = new LinkedHashMap<>(executionUnits.size(), 1);
public Map<String, List<SQLUnit>> getSQLUnitGroups() {
Map<String, List<SQLUnit>> result = new LinkedHashMap<>(executionUnits.size(), 1);
for (SQLExecutionUnit each : executionUnits) {
if (!result.containsKey(each.getDataSource())) {
result.put(each.getDataSource(), new LinkedList<SQLUnit>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ private List<DataNode> generateDataNodes(final List<String> actualDataNodes, fin
*
* @return data node groups, key is data source name, value is tables belong to this data source
*/
public Map<String, Collection<String>> getDataNodeGroups() {
Map<String, Collection<String>> result = new LinkedHashMap<>(actualDataNodes.size(), 1);
public Map<String, List<String>> getDataNodeGroups() {
Map<String, List<String>> result = new LinkedHashMap<>(actualDataNodes.size(), 1);
for (DataNode each : actualDataNodes) {
String dataSourceName = each.getDataSourceName();
if (!result.containsKey(dataSourceName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import io.shardingsphere.core.constant.DatabaseType;
import io.shardingsphere.core.constant.SQLType;
import io.shardingsphere.core.executor.sql.SQLExecuteCallback;
import io.shardingsphere.core.executor.sql.SQLExecuteTemplate;
import io.shardingsphere.core.executor.sql.StatementExecuteUnit;
import io.shardingsphere.core.executor.sql.threadlocal.ExecutorDataMap;
import io.shardingsphere.core.executor.sql.threadlocal.ExecutorExceptionHandler;
Expand All @@ -30,24 +29,21 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
* PreparedStatement Executor for multiple threads to process add batch.
* Prepared statement executor to process add batch.
*
* @author zhangliang
* @author maxiaoguang
*/
@RequiredArgsConstructor
public final class BatchPreparedStatementExecutor {

private final SQLExecuteTemplate executeTemplate;
public abstract class BatchPreparedStatementExecutor {

private final DatabaseType dbType;

private final SQLType sqlType;

private final Collection<BatchPreparedStatementUnit> batchPreparedStatementUnits;

private final int batchCount;

/**
Expand All @@ -66,14 +62,14 @@ protected int[] executeSQL(final StatementExecuteUnit executeUnit) throws SQLExc
return executeUnit.getStatement().executeBatch();
}
};
return accumulate(executeTemplate.execute(batchPreparedStatementUnits, callback));
return accumulate(executeCallback(callback));
}

private int[] accumulate(final List<int[]> results) {
int[] result = new int[batchCount];
int count = 0;
for (BatchPreparedStatementUnit each : batchPreparedStatementUnits) {
for (Map.Entry<Integer, Integer> entry : each.getJdbcAndActualAddBatchCallTimesMap().entrySet()) {
for (BatchPreparedStatementUnit each : getBatchPreparedStatementUnitGroups()) {
for (Entry<Integer, Integer> entry : each.getJdbcAndActualAddBatchCallTimesMap().entrySet()) {
int value = null == results.get(count) ? 0 : results.get(count)[entry.getValue()];
if (DatabaseType.Oracle == dbType) {
result[entry.getKey()] = value;
Expand All @@ -85,4 +81,8 @@ private int[] accumulate(final List<int[]> results) {
}
return result;
}

protected abstract <T> List<T> executeCallback(SQLExecuteCallback<T> executeCallback) throws SQLException;

protected abstract Collection<BatchPreparedStatementUnit> getBatchPreparedStatementUnitGroups();
}
Loading

0 comments on commit 1fe6104

Please sign in to comment.