diff --git a/be/src/pipeline/exec/cache_source_operator.cpp b/be/src/pipeline/exec/cache_source_operator.cpp index cc986b22ad7650..3adb97d969f05d 100644 --- a/be/src/pipeline/exec/cache_source_operator.cpp +++ b/be/src/pipeline/exec/cache_source_operator.cpp @@ -38,9 +38,6 @@ Status CacheSourceLocalState::init(RuntimeState* state, LocalStateInfo& info) { ->data_queue.set_source_dependency(_shared_state->source_deps.front()); const auto& scan_ranges = info.scan_ranges; bool hit_cache = false; - if (scan_ranges.size() > 1) { - return Status::InternalError("CacheSourceOperator only support one scan range, plan error"); - } const auto& cache_param = _parent->cast()._cache_param; // 1. init the slot orders @@ -60,8 +57,20 @@ Status CacheSourceLocalState::init(RuntimeState* state, LocalStateInfo& info) { // 2. build cache key by digest_tablet_id RETURN_IF_ERROR(QueryCache::build_cache_key(scan_ranges, cache_param, &_cache_key, &_version)); - custom_profile()->add_info_string( - "CacheTabletId", std::to_string(scan_ranges[0].scan_range.palo_scan_range.tablet_id)); + std::vector cache_tablet_ids; + cache_tablet_ids.reserve(scan_ranges.size()); + for (const auto& scan_range : scan_ranges) { + cache_tablet_ids.push_back(scan_range.scan_range.palo_scan_range.tablet_id); + } + std::sort(cache_tablet_ids.begin(), cache_tablet_ids.end()); + std::string tablet_ids_str; + for (size_t i = 0; i < cache_tablet_ids.size(); ++i) { + tablet_ids_str += std::to_string(cache_tablet_ids[i]); + if (i < cache_tablet_ids.size() - 1) { + tablet_ids_str += ","; + } + } + custom_profile()->add_info_string("CacheTabletId", tablet_ids_str); // 3. lookup the cache and find proper slot order hit_cache = _global_cache->lookup(_cache_key, _version, &_query_cache_handle); diff --git a/be/src/pipeline/query_cache/query_cache.h b/be/src/pipeline/query_cache/query_cache.h index 4ac06bd511670b..b47d6306283215 100644 --- a/be/src/pipeline/query_cache/query_cache.h +++ b/be/src/pipeline/query_cache/query_cache.h @@ -109,27 +109,76 @@ class QueryCache : public LRUCachePolicy { static Status build_cache_key(const std::vector& scan_ranges, const TQueryCacheParam& cache_param, std::string* cache_key, int64_t* version) { - if (scan_ranges.size() > 1) { - return Status::InternalError( - "CacheSourceOperator only support one scan range, plan error"); + if (scan_ranges.empty()) { + return Status::InternalError("scan_ranges is empty, plan error"); } - auto& scan_range = scan_ranges[0]; - DCHECK(scan_range.scan_range.__isset.palo_scan_range); - auto tablet_id = scan_range.scan_range.palo_scan_range.tablet_id; - - std::from_chars(scan_range.scan_range.palo_scan_range.version.data(), - scan_range.scan_range.palo_scan_range.version.data() + - scan_range.scan_range.palo_scan_range.version.size(), - *version); - - auto find_tablet = cache_param.tablet_to_range.find(tablet_id); - if (find_tablet == cache_param.tablet_to_range.end()) { - return Status::InternalError("Not find tablet in partition_to_tablets, plan error"); + + std::string digest; + try { + digest = cache_param.digest; + } catch (const std::exception&) { + return Status::InternalError("digest is invalid, plan error"); + } + if (digest.empty()) { + return Status::InternalError("digest is empty, plan error"); + } + + if (cache_param.tablet_to_range.empty()) { + return Status::InternalError("tablet_to_range is empty, plan error"); + } + + std::vector tablet_ids; + tablet_ids.reserve(scan_ranges.size()); + for (const auto& scan_range : scan_ranges) { + auto tablet_id = scan_range.scan_range.palo_scan_range.tablet_id; + tablet_ids.push_back(tablet_id); + } + std::sort(tablet_ids.begin(), tablet_ids.end()); + + int64_t first_version = -1; + std::string first_tablet_range; + for (size_t i = 0; i < tablet_ids.size(); ++i) { + auto tablet_id = tablet_ids[i]; + + auto find_tablet = cache_param.tablet_to_range.find(tablet_id); + if (find_tablet == cache_param.tablet_to_range.end()) { + return Status::InternalError("Not find tablet in partition_to_tablets, plan error"); + } + + auto scan_range_iter = + std::find_if(scan_ranges.begin(), scan_ranges.end(), + [&tablet_id](const TScanRangeParams& range) { + return range.scan_range.palo_scan_range.tablet_id == tablet_id; + }); + int64_t current_version = -1; + std::from_chars(scan_range_iter->scan_range.palo_scan_range.version.data(), + scan_range_iter->scan_range.palo_scan_range.version.data() + + scan_range_iter->scan_range.palo_scan_range.version.size(), + current_version); + + if (i == 0) { + first_version = current_version; + first_tablet_range = find_tablet->second; + } else { + if (current_version != first_version) { + return Status::InternalError( + "All tablets in one instance must have the same version, plan error"); + } + if (find_tablet->second != first_tablet_range) { + return Status::InternalError( + "All tablets in one instance must have the same tablet_to_range, plan " + "error"); + } + } } - *cache_key = cache_param.digest + - std::string(reinterpret_cast(&tablet_id), sizeof(tablet_id)) + - find_tablet->second; + *version = first_version; + + *cache_key = digest; + for (auto tablet_id : tablet_ids) { + *cache_key += std::string(reinterpret_cast(&tablet_id), sizeof(tablet_id)); + } + *cache_key += first_tablet_range; return Status::OK(); } diff --git a/be/test/pipeline/exec/query_cache_test.cpp b/be/test/pipeline/exec/query_cache_test.cpp index 0f0d37f2e4de42..c8ee7082a6e7f8 100644 --- a/be/test/pipeline/exec/query_cache_test.cpp +++ b/be/test/pipeline/exec/query_cache_test.cpp @@ -40,9 +40,24 @@ TEST_F(QueryCacheTest, create_global_cache) { TEST_F(QueryCacheTest, build_cache_key) { { std::vector scan_ranges; - scan_ranges.push_back({}); - scan_ranges.push_back({}); + TScanRangeParams scan_range1; + TPaloScanRange palp_scan_range1; + palp_scan_range1.__set_tablet_id(1); + palp_scan_range1.__set_version("100"); + scan_range1.scan_range.__set_palo_scan_range(palp_scan_range1); + scan_ranges.emplace_back(scan_range1); + + TScanRangeParams scan_range2; + TPaloScanRange palp_scan_range2; + palp_scan_range2.__set_tablet_id(2); + palp_scan_range2.__set_version("100"); + scan_range2.scan_range.__set_palo_scan_range(palp_scan_range2); + scan_ranges.emplace_back(scan_range2); + TQueryCacheParam cache_param; + cache_param.__set_digest("test_digest"); + cache_param.tablet_to_range.insert({1, "range_abc"}); + cache_param.tablet_to_range.insert({2, "range_xyz"}); std::string cache_key; int64_t version = 0; auto st = QueryCache::build_cache_key(scan_ranges, cache_param, &cache_key, &version); @@ -59,6 +74,7 @@ TEST_F(QueryCacheTest, build_cache_key) { scan_range.scan_range.__set_palo_scan_range(palp_scan_range); scan_ranges.push_back(scan_range); TQueryCacheParam cache_param; + cache_param.__set_digest("test_digest"); std::string cache_key; int64_t version = 0; auto st = QueryCache::build_cache_key(scan_ranges, cache_param, &cache_key, &version); @@ -87,6 +103,156 @@ TEST_F(QueryCacheTest, build_cache_key) { } } +TEST_F(QueryCacheTest, build_cache_key_multiple_tablets) { + { + std::vector scan_ranges; + TScanRangeParams scan_range1; + TPaloScanRange palp_scan_range1; + palp_scan_range1.__set_tablet_id(3); + palp_scan_range1.__set_version("100"); + scan_range1.scan_range.__set_palo_scan_range(palp_scan_range1); + scan_ranges.push_back(scan_range1); + + TScanRangeParams scan_range2; + TPaloScanRange palp_scan_range2; + palp_scan_range2.__set_tablet_id(1); + palp_scan_range2.__set_version("100"); + scan_range2.scan_range.__set_palo_scan_range(palp_scan_range2); + scan_ranges.push_back(scan_range2); + + TScanRangeParams scan_range3; + TPaloScanRange palp_scan_range3; + palp_scan_range3.__set_tablet_id(2); + palp_scan_range3.__set_version("100"); + scan_range3.scan_range.__set_palo_scan_range(palp_scan_range3); + scan_ranges.push_back(scan_range3); + + TQueryCacheParam cache_param; + cache_param.__set_digest("test_digest"); + cache_param.tablet_to_range.insert({1, "range_abc"}); + cache_param.tablet_to_range.insert({2, "range_abc"}); + cache_param.tablet_to_range.insert({3, "range_abc"}); + + std::string cache_key; + int64_t version = 0; + auto st = QueryCache::build_cache_key(scan_ranges, cache_param, &cache_key, &version); + + EXPECT_TRUE(st.ok()); + EXPECT_EQ(version, 100); + + int64_t expected_tablet1 = 1; + int64_t expected_tablet2 = 2; + int64_t expected_tablet3 = 3; + std::string expected_key = + "test_digest" + + std::string(reinterpret_cast(&expected_tablet1), sizeof(expected_tablet1)) + + std::string(reinterpret_cast(&expected_tablet2), sizeof(expected_tablet2)) + + std::string(reinterpret_cast(&expected_tablet3), sizeof(expected_tablet3)) + + "range_abc"; + + EXPECT_EQ(cache_key, expected_key); + } + + { + std::vector scan_ranges; + TScanRangeParams scan_range1; + TPaloScanRange palp_scan_range1; + palp_scan_range1.__set_tablet_id(1); + palp_scan_range1.__set_version("100"); + scan_range1.scan_range.__set_palo_scan_range(palp_scan_range1); + scan_ranges.push_back(scan_range1); + + TScanRangeParams scan_range2; + TPaloScanRange palp_scan_range2; + palp_scan_range2.__set_tablet_id(2); + palp_scan_range2.__set_version("200"); + scan_range2.scan_range.__set_palo_scan_range(palp_scan_range2); + scan_ranges.push_back(scan_range2); + + TQueryCacheParam cache_param; + cache_param.__set_digest("test_digest"); + cache_param.tablet_to_range.insert({1, "range_abc"}); + cache_param.tablet_to_range.insert({2, "range_abc"}); + + std::string cache_key; + int64_t version = 0; + auto st = QueryCache::build_cache_key(scan_ranges, cache_param, &cache_key, &version); + + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(st.msg().find("same version") != std::string::npos); + } + + { + std::vector scan_ranges; + TScanRangeParams scan_range1; + TPaloScanRange palp_scan_range1; + palp_scan_range1.__set_tablet_id(1); + palp_scan_range1.__set_version("100"); + scan_range1.scan_range.__set_palo_scan_range(palp_scan_range1); + scan_ranges.push_back(scan_range1); + + TScanRangeParams scan_range2; + TPaloScanRange palp_scan_range2; + palp_scan_range2.__set_tablet_id(2); + palp_scan_range2.__set_version("100"); + scan_range2.scan_range.__set_palo_scan_range(palp_scan_range2); + scan_ranges.push_back(scan_range2); + + TQueryCacheParam cache_param; + cache_param.__set_digest("test_digest"); + cache_param.tablet_to_range.insert({1, "range_abc"}); + cache_param.tablet_to_range.insert({2, "range_xyz"}); + + std::string cache_key; + int64_t version = 0; + auto st = QueryCache::build_cache_key(scan_ranges, cache_param, &cache_key, &version); + + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(st.msg().find("same tablet_to_range") != std::string::npos); + } + + { + std::vector scan_ranges; + TScanRangeParams scan_range1; + TPaloScanRange palp_scan_range1; + palp_scan_range1.__set_tablet_id(1); + palp_scan_range1.__set_version("100"); + scan_range1.scan_range.__set_palo_scan_range(palp_scan_range1); + scan_ranges.push_back(scan_range1); + + TScanRangeParams scan_range2; + TPaloScanRange palp_scan_range2; + palp_scan_range2.__set_tablet_id(2); + palp_scan_range2.__set_version("100"); + scan_range2.scan_range.__set_palo_scan_range(palp_scan_range2); + scan_ranges.push_back(scan_range2); + + TQueryCacheParam cache_param; + cache_param.__set_digest("test_digest"); + cache_param.tablet_to_range.insert({1, "range_abc"}); + cache_param.tablet_to_range.insert({3, "range_abc"}); + + std::string cache_key; + int64_t version = 0; + auto st = QueryCache::build_cache_key(scan_ranges, cache_param, &cache_key, &version); + + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(st.msg().find("Not find tablet") != std::string::npos); + } + + { + std::vector scan_ranges; + TQueryCacheParam cache_param; + cache_param.__set_digest("test_digest"); + std::string cache_key; + int64_t version = 0; + auto st = QueryCache::build_cache_key(scan_ranges, cache_param, &cache_key, &version); + + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(st.msg().find("empty") != std::string::npos); + } +} + TEST_F(QueryCacheTest, insert_and_lookup) { std::unique_ptr query_cache {QueryCache::create_global_cache(1024 * 1024 * 1024)}; std::string cache_key = "be ut"; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedScanSingleOlapTableJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedScanSingleOlapTableJob.java index 649e2fa9bb28ea..5516bdf5ddd34c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedScanSingleOlapTableJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedScanSingleOlapTableJob.java @@ -17,6 +17,9 @@ package org.apache.doris.nereids.trees.plans.distribute.worker.job; +import org.apache.doris.catalog.MaterializedIndex; +import org.apache.doris.catalog.Partition; +import org.apache.doris.catalog.Tablet; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.trees.plans.distribute.DistributeContext; import org.apache.doris.nereids.trees.plans.distribute.worker.DistributedPlanWorker; @@ -25,16 +28,27 @@ import org.apache.doris.planner.ExchangeNode; import org.apache.doris.planner.OlapScanNode; import org.apache.doris.planner.PlanFragment; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.thrift.TScanRangeParams; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ListMultimap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; /** UnassignedScanSingleOlapTableJob */ public class UnassignedScanSingleOlapTableJob extends AbstractUnassignedScanJob { + private static final Logger LOG = LogManager.getLogger(UnassignedScanSingleOlapTableJob.class); + private OlapScanNode olapScanNode; private final ScanWorkerSelector scanWorkerSelector; @@ -81,9 +95,135 @@ protected List insideMachineParallelization( // instance 5: olapScanNode1: ScanRanges([tablet_10007]) // ], // } + if (usePartitionParallelismForQueryCache(workerToScanRanges, distributeContext)) { + try { + // Best effort optimization for query cache: keep tablets in same partition + // on the same instance to reduce BE concurrency pressure. + List partitionInstances = insideMachineParallelizationByPartition(workerToScanRanges); + if (partitionInstances != null) { + return partitionInstances; + } + } catch (Exception e) { + LOG.warn("Failed to assign query cache instances by partition, fallback to default planning", + e); + } + } + return super.insideMachineParallelization(workerToScanRanges, inputJobs, distributeContext); } + private List insideMachineParallelizationByPartition( + Map workerToScanRanges) { + List selectedPartitionIds = new ArrayList<>(olapScanNode.getSelectedPartitionIds()); + Map tabletToPartitionId = buildTabletToPartitionId(selectedPartitionIds); + if (tabletToPartitionId.size() != olapScanNode.getScanTabletIds().size()) { + return null; + } + + ConnectContext context = statementContext.getConnectContext(); + List instances = new ArrayList<>(); + for (Map.Entry entry : workerToScanRanges.entrySet()) { + DistributedPlanWorker worker = entry.getKey(); + ScanSource scanSource = entry.getValue().scanSource; + if (!(scanSource instanceof DefaultScanSource)) { + return null; + } + + DefaultScanSource defaultScanSource = (DefaultScanSource) scanSource; + ScanRanges scanRanges = defaultScanSource.scanNodeToScanRanges.get(olapScanNode); + if (scanRanges == null) { + return null; + } + if (scanRanges.params.isEmpty()) { + continue; + } + + Map partitionToScanRanges = splitScanRangesByPartition(scanRanges, tabletToPartitionId); + if (partitionToScanRanges == null) { + return null; + } + + // One partition on one BE maps to one instance. Different BEs may miss some partitions. + for (Long partitionId : selectedPartitionIds) { + ScanRanges partitionScanRanges = partitionToScanRanges.remove(partitionId); + if (partitionScanRanges == null || partitionScanRanges.params.isEmpty()) { + continue; + } + instances.add(assignWorkerAndDataSources( + instances.size(), context.nextInstanceId(), worker, + new DefaultScanSource(ImmutableMap.of(olapScanNode, partitionScanRanges)))); + } + + if (!partitionToScanRanges.isEmpty()) { + return null; + } + } + return instances; + } + + private boolean usePartitionParallelismForQueryCache( + Map workerToScanRanges, + DistributeContext distributeContext) { + if (fragment.queryCacheParam == null || workerToScanRanges.isEmpty()) { + return false; + } + + ConnectContext context = statementContext.getConnectContext(); + if (context == null || useLocalShuffleToAddParallel(distributeContext)) { + return false; + } + + long totalTabletNum = olapScanNode.getScanTabletIds().size(); + int parallelPipelineTaskNum = Math.max( + context.getSessionVariable().getParallelExecInstanceNum(), 1); + long threshold = (long) parallelPipelineTaskNum * workerToScanRanges.size(); + return totalTabletNum > threshold; + } + + private Map buildTabletToPartitionId(List selectedPartitionIds) { + long selectedIndexId = olapScanNode.getSelectedIndexId(); + if (selectedIndexId == -1) { + selectedIndexId = olapScanNode.getOlapTable().getBaseIndexId(); + } + + Set scanTabletIds = new LinkedHashSet<>(olapScanNode.getScanTabletIds()); + Map tabletToPartitionId = new LinkedHashMap<>(scanTabletIds.size()); + for (Long partitionId : selectedPartitionIds) { + Partition partition = olapScanNode.getOlapTable().getPartition(partitionId); + if (partition == null) { + continue; + } + MaterializedIndex index = partition.getIndex(selectedIndexId); + if (index == null) { + continue; + } + for (Tablet tablet : index.getTablets()) { + long tabletId = tablet.getId(); + if (scanTabletIds.contains(tabletId)) { + tabletToPartitionId.put(tabletId, partitionId); + } + } + } + return tabletToPartitionId; + } + + private Map splitScanRangesByPartition( + ScanRanges scanRanges, Map tabletToPartitionId) { + Map partitionToScanRanges = new LinkedHashMap<>(); + for (int i = 0; i < scanRanges.params.size(); i++) { + TScanRangeParams scanRangeParams = scanRanges.params.get(i); + long tabletId = scanRangeParams.getScanRange().getPaloScanRange().getTabletId(); + Long partitionId = tabletToPartitionId.get(tabletId); + if (partitionId == null) { + return null; + } + partitionToScanRanges + .computeIfAbsent(partitionId, id -> new ScanRanges()) + .addScanRange(scanRangeParams, scanRanges.bytes.get(i)); + } + return partitionToScanRanges; + } + @Override protected List fillUpAssignedJobs(List assignedJobs, DistributedPlanWorkerManager workerManager, ListMultimap inputJobs) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedScanSingleOlapTableJobTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedScanSingleOlapTableJobTest.java new file mode 100644 index 00000000000000..e4cf2e7d8c0278 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedScanSingleOlapTableJobTest.java @@ -0,0 +1,376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.distribute.worker.job; + +import org.apache.doris.catalog.LocalTablet; +import org.apache.doris.catalog.MaterializedIndex; +import org.apache.doris.catalog.OlapTable; +import org.apache.doris.catalog.Partition; +import org.apache.doris.catalog.Tablet; +import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.trees.plans.distribute.DistributeContext; +import org.apache.doris.nereids.trees.plans.distribute.worker.DistributedPlanWorker; +import org.apache.doris.nereids.trees.plans.distribute.worker.DistributedPlanWorkerManager; +import org.apache.doris.nereids.trees.plans.distribute.worker.ScanWorkerSelector; +import org.apache.doris.planner.DataPartition; +import org.apache.doris.planner.OlapScanNode; +import org.apache.doris.planner.PlanFragment; +import org.apache.doris.planner.PlanFragmentId; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.OriginStatement; +import org.apache.doris.thrift.TPaloScanRange; +import org.apache.doris.thrift.TQueryCacheParam; +import org.apache.doris.thrift.TScanRange; +import org.apache.doris.thrift.TScanRangeParams; +import org.apache.doris.thrift.TUniqueId; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class UnassignedScanSingleOlapTableJobTest { + @Test + public void testQueryCacheAssignByPartition() { + ConnectContext connectContext = new ConnectContext(); + connectContext.setThreadLocalInfo(); + connectContext.setQueryId(new TUniqueId(1, 1)); + connectContext.getSessionVariable().parallelPipelineTaskNum = 1; + StatementContext statementContext = new StatementContext( + connectContext, new OriginStatement("select * from t", 0)); + connectContext.setStatementContext(statementContext); + + long partitionOne = 100L; + long partitionTwo = 200L; + long selectedIndexId = 10L; + Map tabletToPartition = ImmutableMap.of( + 1L, partitionOne, + 2L, partitionOne, + 3L, partitionOne, + 4L, partitionTwo, + 5L, partitionTwo, + 6L, partitionTwo + ); + + OlapScanNode olapScanNode = Mockito.mock(OlapScanNode.class); + OlapTable olapTable = Mockito.mock(OlapTable.class); + Mockito.when(olapScanNode.getSelectedPartitionIds()) + .thenReturn(Arrays.asList(partitionOne, partitionTwo)); + Mockito.when(olapScanNode.getSelectedIndexId()).thenReturn(selectedIndexId); + Mockito.when(olapScanNode.getOlapTable()).thenReturn(olapTable); + Mockito.when(olapScanNode.getScanTabletIds()) + .thenReturn(new ArrayList<>(tabletToPartition.keySet())); + + Partition firstPartition = Mockito.mock(Partition.class); + MaterializedIndex firstIndex = Mockito.mock(MaterializedIndex.class); + Mockito.when(olapTable.getPartition(partitionOne)).thenReturn(firstPartition); + Mockito.when(firstPartition.getIndex(selectedIndexId)).thenReturn(firstIndex); + Mockito.when(firstIndex.getTablets()).thenReturn(ImmutableList.of( + tablet(1L), tablet(2L), tablet(3L) + )); + + Partition secondPartition = Mockito.mock(Partition.class); + MaterializedIndex secondIndex = Mockito.mock(MaterializedIndex.class); + Mockito.when(olapTable.getPartition(partitionTwo)).thenReturn(secondPartition); + Mockito.when(secondPartition.getIndex(selectedIndexId)).thenReturn(secondIndex); + Mockito.when(secondIndex.getTablets()).thenReturn(ImmutableList.of( + tablet(4L), tablet(5L), tablet(6L) + )); + + PlanFragment fragment = new PlanFragment(new PlanFragmentId(0), null, DataPartition.RANDOM); + fragment.queryCacheParam = new TQueryCacheParam(); + + DistributedPlanWorker worker1 = new TestWorker(1L, "be1"); + DistributedPlanWorker worker2 = new TestWorker(2L, "be2"); + Map workerToScanSources + = new LinkedHashMap<>(); + // Same partition tablets on one BE should be grouped into one instance. + workerToScanSources.put(worker1, new UninstancedScanSource(new DefaultScanSource( + ImmutableMap.of(olapScanNode, scanRanges(1L, 2L, 4L))))); + workerToScanSources.put(worker2, new UninstancedScanSource(new DefaultScanSource( + ImmutableMap.of(olapScanNode, scanRanges(3L, 5L, 6L))))); + + ScanWorkerSelector scanWorkerSelector = Mockito.mock(ScanWorkerSelector.class); + Mockito.when(scanWorkerSelector.selectReplicaAndWorkerWithoutBucket( + Mockito.eq(olapScanNode), Mockito.eq(connectContext) + )).thenReturn(workerToScanSources); + + UnassignedScanSingleOlapTableJob unassignedJob = new UnassignedScanSingleOlapTableJob( + statementContext, + fragment, + olapScanNode, + ArrayListMultimap.create(), + scanWorkerSelector + ); + DistributeContext distributeContext = new DistributeContext( + Mockito.mock(DistributedPlanWorkerManager.class), + true + ); + + List assignedJobs = unassignedJob.computeAssignedJobs( + distributeContext, ArrayListMultimap.create()); + + Assertions.assertEquals(4, assignedJobs.size()); + + Map>> workerToInstanceTablets = new HashMap<>(); + for (AssignedJob assignedJob : assignedJobs) { + DefaultScanSource defaultScanSource = (DefaultScanSource) assignedJob.getScanSource(); + ScanRanges ranges = defaultScanSource.scanNodeToScanRanges.get(olapScanNode); + Set tabletIds = ranges.params.stream() + .map(param -> param.getScanRange().getPaloScanRange().getTabletId()) + .collect(Collectors.toCollection(HashSet::new)); + Set partitionIds = tabletIds.stream() + .map(tabletToPartition::get) + .collect(Collectors.toSet()); + + // Every instance must only contain tablets from one partition. + Assertions.assertEquals(1, partitionIds.size()); + + workerToInstanceTablets.computeIfAbsent( + assignedJob.getAssignedWorker().id(), k -> new HashSet<>() + ).add(tabletIds); + } + + Map>> expected = new HashMap<>(); + expected.put(1L, new HashSet<>(Arrays.asList( + new HashSet<>(Arrays.asList(1L, 2L)), + new HashSet<>(Arrays.asList(4L)) + ))); + expected.put(2L, new HashSet<>(Arrays.asList( + new HashSet<>(Arrays.asList(3L)), + new HashSet<>(Arrays.asList(5L, 6L)) + ))); + + // Different partitions are split into different instances on each BE. + Assertions.assertEquals(expected, workerToInstanceTablets); + } + + @Test + public void testQueryCacheFallbackToDefaultWhenPartitionMappingIncomplete() { + ConnectContext connectContext = new ConnectContext(); + connectContext.setThreadLocalInfo(); + connectContext.setQueryId(new TUniqueId(2, 2)); + connectContext.getSessionVariable().parallelPipelineTaskNum = 1; + StatementContext statementContext = new StatementContext( + connectContext, new OriginStatement("select * from t", 0)); + connectContext.setStatementContext(statementContext); + + long partitionOne = 100L; + long selectedIndexId = 10L; + + OlapScanNode olapScanNode = Mockito.mock(OlapScanNode.class); + OlapTable olapTable = Mockito.mock(OlapTable.class); + // Intentionally miss partitionTwo to trigger fallback. + Mockito.when(olapScanNode.getSelectedPartitionIds()) + .thenReturn(ImmutableList.of(partitionOne)); + Mockito.when(olapScanNode.getSelectedIndexId()).thenReturn(selectedIndexId); + Mockito.when(olapScanNode.getOlapTable()).thenReturn(olapTable); + Mockito.when(olapScanNode.getScanTabletIds()) + .thenReturn(new ArrayList<>(ImmutableList.of(1L, 2L, 3L, 4L, 5L, 6L))); + + Partition firstPartition = Mockito.mock(Partition.class); + MaterializedIndex firstIndex = Mockito.mock(MaterializedIndex.class); + Mockito.when(olapTable.getPartition(partitionOne)).thenReturn(firstPartition); + Mockito.when(firstPartition.getIndex(selectedIndexId)).thenReturn(firstIndex); + Mockito.when(firstIndex.getTablets()) + .thenReturn(ImmutableList.of(tablet(1L), tablet(2L), tablet(3L))); + + PlanFragment fragment = new PlanFragment(new PlanFragmentId(0), null, DataPartition.RANDOM); + fragment.queryCacheParam = new TQueryCacheParam(); + + DistributedPlanWorker worker1 = new TestWorker(1L, "be1"); + DistributedPlanWorker worker2 = new TestWorker(2L, "be2"); + Map workerToScanSources + = new LinkedHashMap<>(); + workerToScanSources.put(worker1, new UninstancedScanSource(new DefaultScanSource( + ImmutableMap.of(olapScanNode, scanRanges(1L, 2L, 4L))))); + workerToScanSources.put(worker2, new UninstancedScanSource(new DefaultScanSource( + ImmutableMap.of(olapScanNode, scanRanges(3L, 5L, 6L))))); + + ScanWorkerSelector scanWorkerSelector = Mockito.mock(ScanWorkerSelector.class); + Mockito.when(scanWorkerSelector.selectReplicaAndWorkerWithoutBucket( + Mockito.eq(olapScanNode), Mockito.eq(connectContext) + )).thenReturn(workerToScanSources); + + UnassignedScanSingleOlapTableJob unassignedJob = new UnassignedScanSingleOlapTableJob( + statementContext, + fragment, + olapScanNode, + ArrayListMultimap.create(), + scanWorkerSelector + ); + + List assignedJobs = unassignedJob.computeAssignedJobs( + new DistributeContext(Mockito.mock(DistributedPlanWorkerManager.class), true), + ArrayListMultimap.create()); + + // query cache default planning uses one instance per tablet. + Assertions.assertEquals(6, assignedJobs.size()); + } + + @Test + public void testNonQueryCacheUseDefaultPlanning() { + ConnectContext connectContext = new ConnectContext(); + connectContext.setThreadLocalInfo(); + connectContext.setQueryId(new TUniqueId(3, 3)); + connectContext.getSessionVariable().parallelPipelineTaskNum = 1; + StatementContext statementContext = new StatementContext( + connectContext, new OriginStatement("select * from t", 0)); + connectContext.setStatementContext(statementContext); + + long partitionOne = 100L; + long partitionTwo = 200L; + long selectedIndexId = 10L; + + OlapScanNode olapScanNode = Mockito.mock(OlapScanNode.class); + OlapTable olapTable = Mockito.mock(OlapTable.class); + Mockito.when(olapScanNode.getSelectedPartitionIds()) + .thenReturn(Arrays.asList(partitionOne, partitionTwo)); + Mockito.when(olapScanNode.getSelectedIndexId()).thenReturn(selectedIndexId); + Mockito.when(olapScanNode.getOlapTable()).thenReturn(olapTable); + Mockito.when(olapScanNode.getScanTabletIds()) + .thenReturn(new ArrayList<>(ImmutableList.of(1L, 2L, 3L, 4L, 5L, 6L))); + + Partition firstPartition = Mockito.mock(Partition.class); + MaterializedIndex firstIndex = Mockito.mock(MaterializedIndex.class); + Mockito.when(olapTable.getPartition(partitionOne)).thenReturn(firstPartition); + Mockito.when(firstPartition.getIndex(selectedIndexId)).thenReturn(firstIndex); + Mockito.when(firstIndex.getTablets()) + .thenReturn(ImmutableList.of(tablet(1L), tablet(2L), tablet(3L))); + + Partition secondPartition = Mockito.mock(Partition.class); + MaterializedIndex secondIndex = Mockito.mock(MaterializedIndex.class); + Mockito.when(olapTable.getPartition(partitionTwo)).thenReturn(secondPartition); + Mockito.when(secondPartition.getIndex(selectedIndexId)).thenReturn(secondIndex); + Mockito.when(secondIndex.getTablets()) + .thenReturn(ImmutableList.of(tablet(4L), tablet(5L), tablet(6L))); + + PlanFragment fragment = new PlanFragment(new PlanFragmentId(0), null, DataPartition.RANDOM); + // No query cache param, must use default planning. + fragment.setParallelExecNum(10); + + DistributedPlanWorker worker1 = new TestWorker(1L, "be1"); + DistributedPlanWorker worker2 = new TestWorker(2L, "be2"); + Map workerToScanSources + = new LinkedHashMap<>(); + workerToScanSources.put(worker1, new UninstancedScanSource(new DefaultScanSource( + ImmutableMap.of(olapScanNode, scanRanges(1L, 2L, 4L))))); + workerToScanSources.put(worker2, new UninstancedScanSource(new DefaultScanSource( + ImmutableMap.of(olapScanNode, scanRanges(3L, 5L, 6L))))); + + ScanWorkerSelector scanWorkerSelector = Mockito.mock(ScanWorkerSelector.class); + Mockito.when(scanWorkerSelector.selectReplicaAndWorkerWithoutBucket( + Mockito.eq(olapScanNode), Mockito.eq(connectContext) + )).thenReturn(workerToScanSources); + + UnassignedScanSingleOlapTableJob unassignedJob = new UnassignedScanSingleOlapTableJob( + statementContext, + fragment, + olapScanNode, + ArrayListMultimap.create(), + scanWorkerSelector + ); + + List assignedJobs = unassignedJob.computeAssignedJobs( + new DistributeContext(Mockito.mock(DistributedPlanWorkerManager.class), true), + ArrayListMultimap.create()); + + // default planning splits by tablet count when parallelExecNum is large enough. + Assertions.assertEquals(6, assignedJobs.size()); + } + + private static Tablet tablet(long tabletId) { + return new LocalTablet(tabletId); + } + + private static ScanRanges scanRanges(long... tabletIds) { + ScanRanges scanRanges = new ScanRanges(); + for (long tabletId : tabletIds) { + TPaloScanRange paloScanRange = new TPaloScanRange(); + paloScanRange.setTabletId(tabletId); + TScanRange scanRange = new TScanRange(); + scanRange.setPaloScanRange(paloScanRange); + TScanRangeParams scanRangeParams = new TScanRangeParams(); + scanRangeParams.setScanRange(scanRange); + scanRanges.addScanRange(scanRangeParams, 1L); + } + return scanRanges; + } + + private static class TestWorker implements DistributedPlanWorker { + private final long id; + private final String address; + + private TestWorker(long id, String address) { + this.id = id; + this.address = address; + } + + @Override + public long getCatalogId() { + return 0; + } + + @Override + public long id() { + return id; + } + + @Override + public String address() { + return address; + } + + @Override + public String host() { + return address; + } + + @Override + public int port() { + return 0; + } + + @Override + public String brpcAddress() { + return address; + } + + @Override + public int brpcPort() { + return 0; + } + + @Override + public boolean available() { + return true; + } + } +}