diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/PartitionCompensator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/PartitionCompensator.java index 4c26ab47cecdee..fe2e88cd5046c3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/PartitionCompensator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/PartitionCompensator.java @@ -198,7 +198,8 @@ private static Pair>, Pair Assertions.assertEquals(expectedUnion, v)); } + @SuppressWarnings("unchecked") + @Test + public void testCalcInvalidPartitionsDoesNotCompensateBasePartitionsUnusedByQuery() + throws Exception { + DatabaseIf baseDb = mockDatabase("cat", 1L, "db", 2L); + MTMVRelatedTableIf relatedTable = mockRelatedTableIf( + "base_t", 10L, ImmutableList.of("cat", "db", "base_t"), baseDb); + BaseColInfo colInfo = new BaseColInfo("dt", new BaseTableInfo(relatedTable)); + + DatabaseIf mvDb = mockDatabase("internal", 3L, "mv_db", 4L); + MTMV mtmv = Mockito.mock(MTMV.class); + Mockito.when(mtmv.getName()).thenReturn("mv1"); + Mockito.when(mtmv.getId()).thenReturn(100L); + Mockito.when(mtmv.getDatabase()).thenReturn(mvDb); + Mockito.when(mtmv.selectNonEmptyPartitionIds(ArgumentMatchers.any())) + .thenReturn(ImmutableList.of(1L)); + + long mvP20260301Id = 101L; + long mvP20260401Id = 102L; + long mvP20260402Id = 103L; + long mvP20260403Id = 104L; + long mvP20260428Id = 105L; + Partition mvP20260301 = mockPartition(mvP20260301Id, "mv_p20260301"); + Partition mvP20260401 = mockPartition(mvP20260401Id, "mv_p20260401"); + Partition mvP20260402 = mockPartition(mvP20260402Id, "mv_p20260402"); + Partition mvP20260403 = mockPartition(mvP20260403Id, "mv_p20260403"); + Partition mvP20260428 = mockPartition(mvP20260428Id, "mv_p20260428"); + Mockito.when(mtmv.getPartition(mvP20260301Id)).thenReturn(mvP20260301); + Mockito.when(mtmv.getPartition(mvP20260401Id)).thenReturn(mvP20260401); + Mockito.when(mtmv.getPartition(mvP20260402Id)).thenReturn(mvP20260402); + Mockito.when(mtmv.getPartition(mvP20260403Id)).thenReturn(mvP20260403); + Mockito.when(mtmv.getPartition(mvP20260428Id)).thenReturn(mvP20260428); + + PartitionInfo mvPartitionInfo = Mockito.mock(PartitionInfo.class); + Mockito.when(mtmv.getPartitionInfo()).thenReturn(mvPartitionInfo); + Mockito.when(mvPartitionInfo.getType()).thenReturn(PartitionType.RANGE); + MTMVPartitionInfo mvPctInfo = Mockito.mock(MTMVPartitionInfo.class); + Mockito.when(mtmv.getMvPartitionInfo()).thenReturn(mvPctInfo); + Mockito.when(mvPctInfo.getPctTables()).thenReturn(ImmutableSet.of(relatedTable)); + Mockito.when(mvPctInfo.getPctInfos()).thenReturn(ImmutableList.of(colInfo)); + + Map> relatedPartitionMapping = new HashMap<>(); + relatedPartitionMapping.put("mv_p20260301", ImmutableSet.of("p20260301")); + relatedPartitionMapping.put("mv_p20260401", ImmutableSet.of("p20260401")); + relatedPartitionMapping.put("mv_p20260402", ImmutableSet.of("p20260402")); + relatedPartitionMapping.put("mv_p20260403", ImmutableSet.of("p20260403")); + relatedPartitionMapping.put("mv_p20260428", ImmutableSet.of("p20260428")); + Map>> partitionMappings = new HashMap<>(); + partitionMappings.put(relatedTable, relatedPartitionMapping); + + AsyncMaterializationContext matCtx = Mockito.mock(AsyncMaterializationContext.class); + Mockito.when(matCtx.getMtmv()).thenReturn(mtmv); + Mockito.when(matCtx.calculatePartitionMappings()).thenReturn(partitionMappings); + + Map> canRewriteMap = new HashMap<>(); + canRewriteMap.put(new BaseTableInfo(mtmv), + ImmutableList.of(mvP20260401, mvP20260402)); + StatementContext stmtCtx = Mockito.mock(StatementContext.class); + Mockito.when(stmtCtx.getMvCanRewritePartitionsMap()).thenReturn(canRewriteMap); + CascadesContext cascadesCtx = Mockito.mock(CascadesContext.class); + Mockito.when(cascadesCtx.getStatementContext()).thenReturn(stmtCtx); + + LogicalOlapScan selectedMvScan = Mockito.mock(LogicalOlapScan.class); + Mockito.when(selectedMvScan.getTable()).thenReturn(mtmv); + Mockito.when(selectedMvScan.getSelectedPartitionIds()) + .thenReturn(ImmutableList.of(mvP20260301Id, mvP20260401Id, mvP20260402Id, + mvP20260403Id, mvP20260428Id)); + Plan rewrittenPlan = Mockito.mock(Plan.class); + Mockito.when(rewrittenPlan.collectToList(ArgumentMatchers.any())) + .thenReturn(ImmutableList.of(selectedMvScan)); + + Map, Set> queryUsedPartitions = new HashMap<>(); + queryUsedPartitions.put(relatedTable.getFullQualifiers(), + ImmutableSet.of("p20260401", "p20260402", "p20260403")); + + Pair>, Map>> result = + PartitionCompensator.calcInvalidPartitions(queryUsedPartitions, rewrittenPlan, matCtx, cascadesCtx); + + Assertions.assertNotNull(result); + Assertions.assertEquals(ImmutableSet.of("mv_p20260301", "mv_p20260403", "mv_p20260428"), + result.key().get(new BaseTableInfo(mtmv))); + Assertions.assertEquals(ImmutableSet.of("p20260403"), result.value().get(colInfo)); + } + @SuppressWarnings("unchecked") private static MTMVRelatedTableIf mockRelatedTableIf( String tableName, long tableId, List qualifiers, DatabaseIf db) { @@ -496,6 +581,25 @@ private static MTMVRelatedTableIf mockRelatedTableIf( return table; } + private static DatabaseIf mockDatabase(String catalogName, long catalogId, String dbName, long dbId) { + CatalogIf catalog = Mockito.mock(CatalogIf.class); + Mockito.when(catalog.getId()).thenReturn(catalogId); + Mockito.when(catalog.getName()).thenReturn(catalogName); + + DatabaseIf db = Mockito.mock(DatabaseIf.class); + Mockito.when(db.getId()).thenReturn(dbId); + Mockito.when(db.getFullName()).thenReturn(dbName); + Mockito.when(db.getCatalog()).thenReturn(catalog); + return db; + } + + private static Partition mockPartition(long partitionId, String partitionName) { + Partition partition = Mockito.mock(Partition.class); + Mockito.when(partition.getId()).thenReturn(partitionId); + Mockito.when(partition.getName()).thenReturn(partitionName); + return partition; + } + private static BaseTableInfo newBaseTableInfo() { CatalogIf catalog = Mockito.mock(CatalogIf.class); Mockito.when(catalog.getId()).thenReturn(1L);