diff --git a/kernel/single/distsql/handler/src/main/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutor.java b/kernel/single/distsql/handler/src/main/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutor.java index 8c3d4e0ec43f9..e06493a4624ce 100644 --- a/kernel/single/distsql/handler/src/main/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutor.java +++ b/kernel/single/distsql/handler/src/main/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutor.java @@ -23,12 +23,15 @@ import org.apache.shardingsphere.distsql.handler.engine.update.rdl.rule.spi.database.DatabaseRuleCreateExecutor; import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; +import org.apache.shardingsphere.infra.rule.attribute.datasource.DataSourceMapperRuleAttribute; import org.apache.shardingsphere.single.api.config.SingleRuleConfiguration; import org.apache.shardingsphere.single.distsql.statement.rdl.SetDefaultSingleTableStorageUnitStatement; import org.apache.shardingsphere.single.rule.SingleRule; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; +import java.util.stream.Collectors; /** * Set default single table storage unit executor. @@ -47,12 +50,18 @@ public void checkBeforeUpdate(final SetDefaultSingleTableStorageUnitStatement sq private void checkStorageUnitExist(final SetDefaultSingleTableStorageUnitStatement sqlStatement) { if (!Strings.isNullOrEmpty(sqlStatement.getDefaultStorageUnit())) { - Collection storageUnitNames = database.getResourceMetaData().getStorageUnits().keySet(); - ShardingSpherePreconditions.checkContains(storageUnitNames, sqlStatement.getDefaultStorageUnit(), + Collection dataSourceNames = new HashSet<>(database.getResourceMetaData().getStorageUnits().keySet()); + dataSourceNames.addAll(getLogicDataSourceNames()); + ShardingSpherePreconditions.checkContains(dataSourceNames, sqlStatement.getDefaultStorageUnit(), () -> new MissingRequiredStorageUnitsException(database.getName(), Collections.singleton(sqlStatement.getDefaultStorageUnit()))); } } + private Collection getLogicDataSourceNames() { + return database.getRuleMetaData().getAttributes(DataSourceMapperRuleAttribute.class).stream() + .flatMap(each -> each.getDataSourceMapper().keySet().stream()).collect(Collectors.toSet()); + } + @Override public SingleRuleConfiguration buildToBeCreatedRuleConfiguration(final SetDefaultSingleTableStorageUnitStatement sqlStatement) { SingleRuleConfiguration result = new SingleRuleConfiguration(); diff --git a/kernel/single/distsql/handler/src/test/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutorTest.java b/kernel/single/distsql/handler/src/test/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutorTest.java index 74f4bbdd7410b..4bfcc7c4cbdee 100644 --- a/kernel/single/distsql/handler/src/test/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutorTest.java +++ b/kernel/single/distsql/handler/src/test/java/org/apache/shardingsphere/single/distsql/handler/update/SetDefaultSingleTableStorageUnitExecutorTest.java @@ -19,6 +19,7 @@ import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.MissingRequiredStorageUnitsException; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; +import org.apache.shardingsphere.infra.rule.attribute.datasource.DataSourceMapperRuleAttribute; import org.apache.shardingsphere.single.api.config.SingleRuleConfiguration; import org.apache.shardingsphere.single.distsql.statement.rdl.SetDefaultSingleTableStorageUnitStatement; import org.apache.shardingsphere.single.rule.SingleRule; @@ -26,11 +27,15 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; +import java.util.Collections; + import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -41,11 +46,23 @@ class SetDefaultSingleTableStorageUnitExecutorTest { private final SetDefaultSingleTableStorageUnitExecutor executor = new SetDefaultSingleTableStorageUnitExecutor(); @Test - void assertCheckWithInvalidResource() { - executor.setDatabase(mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS)); + void assertCheckWithInvalidDataSource() { + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); + when(database.getRuleMetaData().getAttributes(any())).thenReturn(Collections.emptyList()); + executor.setDatabase(database); assertThrows(MissingRequiredStorageUnitsException.class, () -> executor.checkBeforeUpdate(new SetDefaultSingleTableStorageUnitStatement("bar_ds"))); } + @Test + void assertCheckWithLogicDataSource() { + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); + DataSourceMapperRuleAttribute ruleAttribute = mock(DataSourceMapperRuleAttribute.class, RETURNS_DEEP_STUBS); + when(ruleAttribute.getDataSourceMapper().keySet()).thenReturn(Collections.singleton("logic_ds")); + when(database.getRuleMetaData().getAttributes(any())).thenReturn(Collections.singleton(ruleAttribute)); + executor.setDatabase(database); + assertDoesNotThrow(() -> executor.checkBeforeUpdate(new SetDefaultSingleTableStorageUnitStatement("logic_ds"))); + } + @Test void assertBuild() { SingleRule rule = mock(SingleRule.class);