Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace java.util.stream with for-each in high-frequency code #13845

Merged
merged 5 commits into from Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -138,12 +138,16 @@ private Collection<SelectStatement> getSelectStatements(final SQLStatementContex
Collection<SelectStatement> result = new LinkedList<>();
if (sqlStatementContext instanceof SelectStatementContext) {
result.add(((SelectStatementContext) sqlStatementContext).getSqlStatement());
result.addAll(((SelectStatementContext) sqlStatementContext).getSubqueryContexts().values().stream().map(SelectStatementContext::getSqlStatement).collect(Collectors.toList()));
for (SelectStatementContext each : ((SelectStatementContext) sqlStatementContext).getSubqueryContexts().values()) {
result.add(each.getSqlStatement());
}
}
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
SelectStatementContext selectStatementContext = ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext();
result.add(selectStatementContext.getSqlStatement());
result.addAll(selectStatementContext.getSubqueryContexts().values().stream().map(SelectStatementContext::getSqlStatement).collect(Collectors.toList()));
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
result.add(each.getSqlStatement());
}
}
return result;
}
Expand Down
Expand Up @@ -53,7 +53,6 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Sharding condition engine for where clause.
Expand Down Expand Up @@ -95,9 +94,16 @@ private Collection<ShardingCondition> createShardingConditions(final SQLStatemen
}

/**
 * Resolves the owning table name for every column referenced by the given predicates.
 *
 * @param sqlStatementContext SQL statement context supplying the tables context
 * @param andPredicates predicates whose column references are to be resolved
 * @return map from column expression to resolved table name
 */
private Map<String, String> getColumnTableNames(final SQLStatementContext<?> sqlStatementContext, final Collection<AndPredicate> andPredicates) {
    Collection<ColumnProjection> result = new LinkedList<>();
    // Flatten predicates -> expressions -> column segments without stream overhead (hot path).
    for (AndPredicate each : andPredicates) {
        for (ExpressionSegment expressionSegment : each.getPredicates()) {
            for (ColumnSegment columnSegment : ColumnExtractor.extract(expressionSegment)) {
                result.add(buildColumnProjection(columnSegment));
            }
        }
    }
    return sqlStatementContext.getTablesContext().findTableName(result, schema);
}

private ColumnProjection buildColumnProjection(final ColumnSegment segment) {
Expand Down
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties;
import org.apache.shardingsphere.infra.config.properties.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingConditions;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
import org.apache.shardingsphere.sharding.route.engine.type.broadcast.ShardingDataSourceGroupBroadcastRoutingEngine;
Expand Down Expand Up @@ -203,8 +204,12 @@ private static ShardingRouteEngine getDQLRouteEngineForShardingTable(final Shard
}

/**
 * Gets the logic table name of the first sharding condition value if any,
 * otherwise falls back to the first of the given table names.
 *
 * @param shardingConditions sharding conditions to inspect
 * @param tableNames fallback table names (must be non-empty for the fallback path)
 * @return logic table name
 */
private static String getLogicTableName(final ShardingConditions shardingConditions, final Collection<String> tableNames) {
    // Return on the first condition value encountered — equivalent of findFirst() without streams.
    for (ShardingCondition each : shardingConditions.getConditions()) {
        for (ShardingConditionValue shardingConditionValue : each.getValues()) {
            return shardingConditionValue.getTableName();
        }
    }
    return tableNames.iterator().next();
}

private static boolean isShardingStandardQuery(final Collection<String> tableNames, final ShardingRule shardingRule) {
Expand Down
Expand Up @@ -51,7 +51,10 @@ public RouteContext route(final ShardingRule shardingRule) {
String dataSourceName = getRandomDataSourceName(shardingRule.getDataSourceNames());
RouteMapper dataSourceMapper = new RouteMapper(dataSourceName, dataSourceName);
if (shardingRule.isAllBroadcastTables(logicTables)) {
List<RouteMapper> tableMappers = logicTables.stream().map(each -> new RouteMapper(each, each)).collect(Collectors.toCollection(() -> new ArrayList<>(logicTables.size())));
List<RouteMapper> tableMappers = new ArrayList<>(logicTables.size());
for (String each : logicTables) {
tableMappers.add(new RouteMapper(each, each));
}
result.getRouteUnits().add(new RouteUnit(dataSourceMapper, tableMappers));
} else if (logicTables.isEmpty()) {
result.getRouteUnits().add(new RouteUnit(dataSourceMapper, Collections.emptyList()));
Expand Down
Expand Up @@ -273,7 +273,12 @@ private Optional<BindingTableRule> findBindingTableRule(final Collection<String>
* @return binding table rule
*/
public Optional<BindingTableRule> findBindingTableRule(final String logicTableName) {
    // First matching rule wins — equivalent of filter(...).findFirst() without streams.
    for (BindingTableRule each : bindingTableRules) {
        if (each.hasLogicTable(logicTableName)) {
            return Optional.of(each);
        }
    }
    return Optional.empty();
}

/**
Expand All @@ -293,7 +298,15 @@ public boolean isAllBroadcastTables(final Collection<String> logicTableNames) {
* @return whether logic table is all sharding table or not
*/
public boolean isAllShardingTables(final Collection<String> logicTableNames) {
    // An empty collection is NOT considered "all sharding" (allMatch on empty would be true).
    if (logicTableNames.isEmpty()) {
        return false;
    }
    // Short-circuit on the first non-sharding table.
    for (String each : logicTableNames) {
        if (!isShardingTable(each)) {
            return false;
        }
    }
    return true;
}

/**
Expand Down Expand Up @@ -429,7 +442,13 @@ public DataNode getDataNode(final String logicTableName) {
* @return sharding logic table names
*/
public Collection<String> getShardingLogicTableNames(final Collection<String> logicTableNames) {
    // Keep only names recognized as sharding tables, preserving input order.
    Collection<String> result = new LinkedList<>();
    for (String each : logicTableNames) {
        if (isShardingTable(each)) {
            result.add(each);
        }
    }
    return result;
}

/**
Expand Down
Expand Up @@ -36,7 +36,6 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Pagination context engine.
Expand All @@ -57,7 +56,10 @@ public PaginationContext createPaginationContext(final SelectStatement selectSta
return new LimitPaginationContextEngine().createPaginationContext(limitSegment.get(), parameters);
}
Optional<TopProjectionSegment> topProjectionSegment = findTopProjection(selectStatement);
Collection<ExpressionSegment> expressions = getWhereSegments(selectStatement).stream().map(WhereSegment::getExpr).collect(Collectors.toList());
Collection<ExpressionSegment> expressions = new LinkedList<>();
for (WhereSegment each : getWhereSegments(selectStatement)) {
expressions.add(each.getExpr());
}
if (topProjectionSegment.isPresent()) {
return new TopPaginationContextEngine().createPaginationContext(topProjectionSegment.get(), expressions, parameters);
}
Expand Down
Expand Up @@ -33,7 +33,10 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
Expand Down Expand Up @@ -66,19 +69,29 @@ public TablesContext(final Collection<? extends TableSegment> tableSegments, fin
if (tableSegments.isEmpty()) {
return;
}
Collection<SimpleTableSegment> simpleTableSegments = tableSegments.stream().filter(each
-> each instanceof SimpleTableSegment).map(each -> (SimpleTableSegment) each).collect(Collectors.toList());
for (SimpleTableSegment each : simpleTableSegments) {
tables.add(each);
tableNames.add(each.getTableName().getIdentifier().getValue());
each.getOwner().ifPresent(optional -> schemaNames.add(optional.getIdentifier().getValue()));
for (TableSegment each : tableSegments) {
if (!(each instanceof SimpleTableSegment)) {
continue;
}
SimpleTableSegment simpleTableSegment = (SimpleTableSegment) each;
tables.add(simpleTableSegment);
tableNames.add(simpleTableSegment.getTableName().getIdentifier().getValue());
simpleTableSegment.getOwner().ifPresent(owner -> schemaNames.add(owner.getIdentifier().getValue()));
}
Collection<SubqueryTableSegment> subqueryTableSegments = tableSegments.stream().filter(each
-> each instanceof SubqueryTableSegment).map(each -> (SubqueryTableSegment) each).collect(Collectors.toList());
for (SubqueryTableSegment each : subqueryTableSegments) {
SelectStatementContext subqueryContext = subqueryContexts.get(each.getSubquery().getStartIndex());
for (TableSegment each : tableSegments) {
if (!(each instanceof SubqueryTableSegment)) {
continue;
}
SubqueryTableSegment subqueryTableSegment = (SubqueryTableSegment) each;
SelectStatementContext subqueryContext = subqueryContexts.get(subqueryTableSegment.getSubquery().getStartIndex());
Collection<SubqueryTableContext> subqueryTableContexts = new SubqueryTableContextEngine().createSubqueryTableContexts(subqueryContext, each.getAlias().orElse(null));
subqueryTables.putAll(subqueryTableContexts.stream().filter(subQuery -> null != subQuery.getAlias()).collect(Collectors.groupingBy(SubqueryTableContext::getAlias)));
Map<String, List<SubqueryTableContext>> result = new HashMap<>();
for (SubqueryTableContext subQuery : subqueryTableContexts) {
if (null != subQuery.getAlias()) {
result.computeIfAbsent(subQuery.getAlias(), unused -> new LinkedList<>()).add(subQuery);
}
}
subqueryTables.putAll(result);
}
}

Expand All @@ -101,11 +114,20 @@ public Collection<String> getTableNames() {
public Map<String, String> findTableName(final Collection<ColumnProjection> columns, final ShardingSphereSchema schema) {
if (1 == tables.size()) {
String tableName = tables.iterator().next().getTableName().getIdentifier().getValue();
return columns.stream().collect(Collectors.toMap(ColumnProjection::getExpression, each -> tableName, (oldValue, currentValue) -> oldValue));
Map<String, String> result = new LinkedHashMap<>(columns.size(), 1);
for (ColumnProjection each : columns) {
result.putIfAbsent(each.getExpression(), tableName);
}
return result;
}
Map<String, String> result = new HashMap<>(columns.size(), 1);
result.putAll(findTableNameFromSQL(getOwnerColumnNames(columns)));
Collection<String> columnNames = columns.stream().filter(each -> null == each.getOwner()).map(ColumnProjection::getName).collect(Collectors.toSet());
Collection<String> columnNames = new LinkedHashSet<>();
for (ColumnProjection each : columns) {
if (null == each.getOwner()) {
columnNames.add(each.getName());
}
}
result.putAll(findTableNameFromMetaData(columnNames, schema));
if (result.size() < columns.size() && !subqueryTables.isEmpty()) {
appendRemainingResult(columns, result);
Expand Down Expand Up @@ -185,6 +207,9 @@ private Optional<String> findTableNameFromSubquery(final String columnName, fina
*/
public Optional<String> getSchemaName() {
    // At most one distinct schema is supported across all referenced tables.
    Preconditions.checkState(schemaNames.size() <= 1, "Can not support multiple different schema.");
    // Return the single element if present — equivalent of findFirst() without streams.
    for (String each : schemaNames) {
        return Optional.of(each);
    }
    return Optional.empty();
}
}
Expand Up @@ -255,11 +255,14 @@ public Optional<WhereSegment> getWhere() {
}

/**
 * Collects all table segments of this select statement: the extractor's rewrite
 * tables plus any subquery table segments found in the table context.
 *
 * @return all table segments
 */
private Collection<TableSegment> getAllTableSegments() {
    TableExtractor tableExtractor = new TableExtractor();
    tableExtractor.extractTablesFromSelect(getSqlStatement());
    Collection<TableSegment> result = new LinkedList<>(tableExtractor.getRewriteTables());
    // Append only subquery tables; other segment kinds are already covered by rewrite tables.
    for (TableSegment each : tableExtractor.getTableContext()) {
        if (each instanceof SubqueryTableSegment) {
            result.add(each);
        }
    }
    return result;
}
}
Expand Up @@ -29,6 +29,7 @@
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -74,7 +75,11 @@ private static List<RouteMapper> getRouteTableRouteMappers(final Collection<Rout
if (null == tableMappers) {
return Collections.emptyList();
}
return tableMappers.stream().map(routeMapper -> new RouteMapper(routeMapper.getLogicName(), routeMapper.getActualName())).collect(Collectors.toList());
List<RouteMapper> result = new ArrayList<>(tableMappers.size());
for (RouteMapper each : tableMappers) {
result.add(new RouteMapper(each.getLogicName(), each.getActualName()));
}
return result;
}

private static List<RouteMapper> getGenericTableRouteMappers(final SQLStatementContext<?> sqlStatementContext) {
Expand Down
Expand Up @@ -33,6 +33,7 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
* SQL rewrite entry.
Expand Down Expand Up @@ -80,6 +81,8 @@ private SQLRewriteContext createSQLRewriteContext(final String sql, final List<O

@SuppressWarnings({"unchecked", "rawtypes"})
private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
    // Each decorator is applied with the rule it is registered under.
    for (Entry<ShardingSphereRule, SQLRewriteContextDecorator> entry : decorators.entrySet()) {
        entry.getValue().decorate(entry.getKey(), props, sqlRewriteContext, routeContext);
    }
}
}
Expand Up @@ -35,9 +35,9 @@
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

/**
* Driver JDBC executor.
Expand Down Expand Up @@ -145,6 +145,11 @@ private <T> List<T> doExecute(final ExecutionGroupContext<JDBCExecutionUnit> exe
}

/**
 * Refreshes meta data for the logic data sources touched by the given route units.
 *
 * @param sqlStatement executed SQL statement
 * @param routeUnits route units whose logic data source names drive the refresh
 * @throws SQLException when the refresh engine fails
 */
private void refreshMetaData(final SQLStatement sqlStatement, final Collection<RouteUnit> routeUnits) throws SQLException {
    // Pre-sized to avoid reallocation on this hot path.
    List<String> logicDataSourceNames = new ArrayList<>(routeUnits.size());
    for (RouteUnit each : routeUnits) {
        logicDataSourceNames.add(each.getDataSourceMapper().getLogicName());
    }
    metadataRefreshEngine.refresh(sqlStatement, logicDataSourceNames);
}
}
Expand Up @@ -35,6 +35,7 @@

import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
Expand Down Expand Up @@ -92,15 +93,24 @@ public void addBatchForExecutionUnits(final Collection<ExecutionUnit> executionU
}

private Collection<BatchExecutionUnit> createBatchExecutionUnits(final Collection<ExecutionUnit> executionUnits) {
    // Pre-sized: one batch execution unit per execution unit.
    List<BatchExecutionUnit> result = new ArrayList<>(executionUnits.size());
    for (ExecutionUnit each : executionUnits) {
        result.add(new BatchExecutionUnit(each));
    }
    return result;
}

// Revises every previously-registered batch execution unit that matches a new one.
private void handleOldBatchExecutionUnits(final Collection<BatchExecutionUnit> newExecutionUnits) {
    for (BatchExecutionUnit each : newExecutionUnits) {
        reviseBatchExecutionUnits(each);
    }
}

private void reviseBatchExecutionUnits(final BatchExecutionUnit batchExecutionUnit) {
    // Revise every stored unit equal to the incoming one (equals may match several entries).
    for (BatchExecutionUnit each : batchExecutionUnits) {
        if (each.equals(batchExecutionUnit)) {
            reviseBatchExecutionUnit(each, batchExecutionUnit);
        }
    }
}

private void reviseBatchExecutionUnit(final BatchExecutionUnit oldBatchExecutionUnit, final BatchExecutionUnit newBatchExecutionUnit) {
Expand All @@ -110,7 +120,9 @@ private void reviseBatchExecutionUnit(final BatchExecutionUnit oldBatchExecution

private void handleNewBatchExecutionUnits(final Collection<BatchExecutionUnit> newExecutionUnits) {
    // Drop units already registered, then tag genuinely new ones with the current batch count.
    newExecutionUnits.removeAll(batchExecutionUnits);
    for (BatchExecutionUnit each : newExecutionUnits) {
        each.mapAddBatchCount(batchCount);
    }
    batchExecutionUnits.addAll(newExecutionUnits);
}

Expand Down Expand Up @@ -184,8 +196,11 @@ private boolean isSameDataSourceAndSQL(final BatchExecutionUnit batchExecutionUn
*/
public List<Statement> getStatements() {
List<Statement> result = new LinkedList<>();
for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
result.addAll(each.getInputs().stream().map(JDBCExecutionUnit::getStorageResource).collect(Collectors.toList()));
for (ExecutionGroup<JDBCExecutionUnit> eachGroup : executionGroupContext.getInputGroups()) {
for (JDBCExecutionUnit eachUnit : eachGroup.getInputs()) {
Statement storageResource = eachUnit.getStorageResource();
result.add(storageResource);
}
}
return result;
}
Expand Down