diff --git a/core/src/test/java/org/apache/iceberg/MockFileScanTask.java b/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
index 58275ad3f0c2..565433c82cb1 100644
--- a/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
+++ b/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
@@ -44,6 +44,17 @@ public MockFileScanTask(DataFile file, String schemaString, String specString) {
     this.length = file.fileSizeInBytes();
   }
 
+  public MockFileScanTask(DataFile file, Schema schema, PartitionSpec spec) {
+    super(file, null, SchemaParser.toJson(schema), PartitionSpecParser.toJson(spec), null);
+    this.length = file.fileSizeInBytes();
+  }
+
+  public MockFileScanTask(
+      DataFile file, DeleteFile[] deleteFiles, Schema schema, PartitionSpec spec) {
+    super(file, deleteFiles, SchemaParser.toJson(schema), PartitionSpecParser.toJson(spec), null);
+    this.length = file.fileSizeInBytes();
+  }
+
   public static MockFileScanTask mockTask(long length, int sortOrderId) {
     DataFile mockFile = Mockito.mock(DataFile.class);
     Mockito.when(mockFile.fileSizeInBytes()).thenReturn(length);
diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
index 01f24c4dfe04..91600d4df08d 100644
--- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
+++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
@@ -37,6 +37,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.SparkSQLProperties;
 import org.apache.iceberg.spark.source.SparkTable;
 import org.apache.iceberg.spark.source.TestSparkCatalog;
 import org.apache.iceberg.util.SnapshotUtil;
@@ -85,6 +86,30 @@ public static void clearTestSparkCatalogCache() {
     TestSparkCatalog.clearTables();
   }
 
+  @Test
+  public void testDeleteWithExecutorCacheLocality() throws NoSuchTableException {
+    createAndInitPartitionedTable();
+
+    append(tableName, new Employee(1, "hr"), new Employee(2, "hr"));
+    append(tableName, new Employee(3, "hr"), new Employee(4, "hr"));
+    append(tableName, new Employee(1, "hardware"), new Employee(2, "hardware"));
+    append(tableName, new Employee(3, "hardware"), new Employee(4, "hardware"));
+
+    createBranchIfNeeded();
+
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED, "true"),
+        () -> {
+          sql("DELETE FROM %s WHERE id = 1", commitTarget());
+          sql("DELETE FROM %s WHERE id = 3", commitTarget());
+
+          assertEquals(
+              "Should have expected rows",
+              ImmutableList.of(row(2, "hardware"), row(2, "hr"), row(4, "hardware"), row(4, "hr")),
+              sql("SELECT * FROM %s ORDER BY id ASC, dep ASC", selectTarget()));
+        });
+  }
+
   @Test
   public void testDeleteFileGranularity() throws NoSuchTableException {
     checkDeleteFileGranularity(DeleteGranularity.FILE);
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
index 984e2bce1efc..2990d981d009 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
@@ -331,4 +331,24 @@ private long driverMaxResultSize() {
     SparkConf sparkConf = spark.sparkContext().conf();
     return sparkConf.getSizeAsBytes(DRIVER_MAX_RESULT_SIZE, DRIVER_MAX_RESULT_SIZE_DEFAULT);
   }
+
+  public boolean executorCacheLocalityEnabled() {
+    return executorCacheEnabled() && executorCacheLocalityEnabledInternal();
+  }
+
+  private boolean executorCacheEnabled() {
+    return confParser
+        .booleanConf()
+        .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_ENABLED)
+        .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_ENABLED_DEFAULT)
+        .parse();
+  }
+
+  private boolean executorCacheLocalityEnabledInternal() {
+    return confParser
+        .booleanConf()
+        .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED)
+        .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT)
+        .parse();
+  }
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
index 4a665202317b..ea8f6fe0718b 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
@@ -86,4 +86,8 @@ private SparkSQLProperties() {}
   public static final String EXECUTOR_CACHE_MAX_TOTAL_SIZE =
       "spark.sql.iceberg.executor-cache.max-total-size";
   public static final long EXECUTOR_CACHE_MAX_TOTAL_SIZE_DEFAULT = 128 * 1024 * 1024; // 128 MB
+
+  public static final String EXECUTOR_CACHE_LOCALITY_ENABLED =
+      "spark.sql.iceberg.executor-cache.locality.enabled";
+  public static final boolean EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT = false;
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
index 2357ca0441fc..de06cceb2677 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
@@ -34,6 +34,8 @@
 import org.apache.iceberg.transforms.Transform;
 import org.apache.iceberg.transforms.UnknownTransform;
 import org.apache.iceberg.util.Pair;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.expressions.BoundReference;
 import org.apache.spark.sql.catalyst.expressions.EqualTo;
@@ -43,7 +45,12 @@
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockManagerId;
+import org.apache.spark.storage.BlockManagerMaster;
 import org.joda.time.DateTime;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
 
 public class SparkUtil {
   private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog";
@@ -238,4 +245,27 @@ public static String toColumnName(NamedReference ref) {
   public static boolean caseSensitive(SparkSession spark) {
     return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive"));
   }
+
+  public static List<String> executorLocations() {
+    BlockManager driverBlockManager = SparkEnv.get().blockManager();
+    List<BlockManagerId> executorBlockManagerIds = fetchPeers(driverBlockManager);
+    return executorBlockManagerIds.stream()
+        .map(SparkUtil::toExecutorLocation)
+        .sorted()
+        .collect(Collectors.toList());
+  }
+
+  private static List<BlockManagerId> fetchPeers(BlockManager blockManager) {
+    BlockManagerMaster master = blockManager.master();
+    BlockManagerId id = blockManager.blockManagerId();
+    return toJavaList(master.getPeers(id));
+  }
+
+  private static <T> List<T> toJavaList(Seq<T> seq) {
+    return JavaConverters.seqAsJavaListConverter(seq).asJava();
+  }
+
+  private static String toExecutorLocation(BlockManagerId id) {
+    return ExecutorCacheTaskLocation.apply(id.host(), id.executorId()).toString();
+  }
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
index 4ed37a9f3d24..fd6783f3e1f7 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
@@ -29,9 +29,8 @@
 import org.apache.iceberg.SchemaParser;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.spark.SparkReadConf;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.types.Types;
-import org.apache.iceberg.util.Tasks;
-import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.connector.read.Batch;
@@ -49,6 +48,7 @@ class SparkBatch implements Batch {
   private final Schema expectedSchema;
   private final boolean caseSensitive;
   private final boolean localityEnabled;
+  private final boolean executorCacheLocalityEnabled;
   private final int scanHashCode;
 
   SparkBatch(
@@ -68,6 +68,7 @@ class SparkBatch implements Batch {
     this.expectedSchema = expectedSchema;
     this.caseSensitive = readConf.caseSensitive();
     this.localityEnabled = readConf.localityEnabled();
+    this.executorCacheLocalityEnabled = readConf.executorCacheLocalityEnabled();
     this.scanHashCode = scanHashCode;
   }
 
@@ -77,27 +78,39 @@ public InputPartition[] planInputPartitions() {
     Broadcast<Table> tableBroadcast =
         sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
     String expectedSchemaString = SchemaParser.toJson(expectedSchema);
+    String[][] locations = computePreferredLocations();
 
     InputPartition[] partitions = new InputPartition[taskGroups.size()];
 
-    Tasks.range(partitions.length)
-        .stopOnFailure()
-        .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null)
-        .run(
-            index ->
-                partitions[index] =
-                    new SparkInputPartition(
-                        groupingKeyType,
-                        taskGroups.get(index),
-                        tableBroadcast,
-                        branch,
-                        expectedSchemaString,
-                        caseSensitive,
-                        localityEnabled));
+    for (int index = 0; index < taskGroups.size(); index++) {
+      partitions[index] =
+          new SparkInputPartition(
+              groupingKeyType,
+              taskGroups.get(index),
+              tableBroadcast,
+              branch,
+              expectedSchemaString,
+              caseSensitive,
+              locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE);
+    }
 
     return partitions;
   }
 
+  private String[][] computePreferredLocations() {
+    if (localityEnabled) {
+      return SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups);
+
+    } else if (executorCacheLocalityEnabled) {
+      List<String> executorLocations = SparkUtil.executorLocations();
+      if (!executorLocations.isEmpty()) {
+        return SparkPlanningUtil.assignExecutors(taskGroups, executorLocations);
+      }
+    }
+
+    return null;
+  }
+
   @Override
   public PartitionReaderFactory createReaderFactory() {
     if (useParquetBatchReads()) {
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
index 0394b691e152..7826322be7de 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
@@ -24,8 +24,6 @@
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.SchemaParser;
 import org.apache.iceberg.Table;
-import org.apache.iceberg.hadoop.HadoopInputFile;
-import org.apache.iceberg.hadoop.Util;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -39,9 +37,9 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializab
   private final String branch;
   private final String expectedSchemaString;
   private final boolean caseSensitive;
+  private final transient String[] preferredLocations;
 
   private transient Schema expectedSchema = null;
-  private transient String[] preferredLocations = null;
 
   SparkInputPartition(
       Types.StructType groupingKeyType,
@@ -50,19 +48,14 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializab
       String branch,
       String expectedSchemaString,
       boolean caseSensitive,
-      boolean localityPreferred) {
+      String[] preferredLocations) {
     this.groupingKeyType = groupingKeyType;
     this.taskGroup = taskGroup;
     this.tableBroadcast = tableBroadcast;
     this.branch = branch;
     this.expectedSchemaString = expectedSchemaString;
     this.caseSensitive = caseSensitive;
-    if (localityPreferred) {
-      Table table = tableBroadcast.value();
-      this.preferredLocations = Util.blockLocations(table.io(), taskGroup);
-    } else {
-      this.preferredLocations = HadoopInputFile.NO_LOCATION_PREFERENCE;
-    }
+    this.preferredLocations = preferredLocations;
   }
 
   @Override
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
index 3ffd9904bbf3..320d2e14adc9 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
@@ -54,8 +54,6 @@
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.iceberg.util.TableScanUtil;
-import org.apache.iceberg.util.Tasks;
-import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.connector.read.InputPartition;
@@ -154,27 +152,29 @@ public InputPartition[] planInputPartitions(Offset start, Offset end) {
     List<CombinedScanTask> combinedScanTasks =
         Lists.newArrayList(
            TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost));
+    String[][] locations = computePreferredLocations(combinedScanTasks);
 
     InputPartition[] partitions = new InputPartition[combinedScanTasks.size()];
 
-    Tasks.range(partitions.length)
-        .stopOnFailure()
-        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
-        .run(
-            index ->
-                partitions[index] =
-                    new SparkInputPartition(
-                        EMPTY_GROUPING_KEY_TYPE,
-                        combinedScanTasks.get(index),
-                        tableBroadcast,
-                        branch,
-                        expectedSchema,
-                        caseSensitive,
-                        localityPreferred));
+    for (int index = 0; index < combinedScanTasks.size(); index++) {
+      partitions[index] =
+          new SparkInputPartition(
+              EMPTY_GROUPING_KEY_TYPE,
+              combinedScanTasks.get(index),
+              tableBroadcast,
+              branch,
+              expectedSchema,
+              caseSensitive,
+              locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE);
+    }
 
     return partitions;
   }
 
+  private String[][] computePreferredLocations(List<CombinedScanTask> taskGroups) {
+    return localityPreferred ? SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups) : null;
+  }
+
   @Override
   public PartitionReaderFactory createReaderFactory() {
     return new SparkRowReaderFactory();
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java
new file mode 100644
index 000000000000..9cdec2c8f463
--- /dev/null
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.hadoop.Util;
+import org.apache.iceberg.io.FileIO;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.types.JavaHash;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
+
+class SparkPlanningUtil {
+
+  public static final String[] NO_LOCATION_PREFERENCE = new String[0];
+
+  private SparkPlanningUtil() {}
+
+  public static String[][] fetchBlockLocations(
+      FileIO io, List<? extends ScanTaskGroup<?>> taskGroups) {
+    String[][] locations = new String[taskGroups.size()][];
+
+    Tasks.range(taskGroups.size())
+        .stopOnFailure()
+        .executeWith(ThreadPools.getWorkerPool())
+        .run(index -> locations[index] = Util.blockLocations(io, taskGroups.get(index)));
+
+    return locations;
+  }
+
+  public static String[][] assignExecutors(
+      List<? extends ScanTaskGroup<?>> taskGroups, List<String> executorLocations) {
+    Map<Integer, JavaHash<StructLike>> partitionHashes = Maps.newHashMap();
+    String[][] locations = new String[taskGroups.size()][];
+
+    for (int index = 0; index < taskGroups.size(); index++) {
+      locations[index] = assign(taskGroups.get(index), executorLocations, partitionHashes);
+    }
+
+    return locations;
+  }
+
+  private static String[] assign(
+      ScanTaskGroup<?> taskGroup,
+      List<String> executorLocations,
+      Map<Integer, JavaHash<StructLike>> partitionHashes) {
+    List<String> locations = Lists.newArrayList();
+
+    for (ScanTask task : taskGroup.tasks()) {
+      if (task.isFileScanTask()) {
+        FileScanTask fileTask = task.asFileScanTask();
+        PartitionSpec spec = fileTask.spec();
+        if (spec.isPartitioned() && !fileTask.deletes().isEmpty()) {
+          JavaHash<StructLike> partitionHash =
+              partitionHashes.computeIfAbsent(spec.specId(), key -> partitionHash(spec));
+          int partitionHashCode = partitionHash.hash(fileTask.partition());
+          int index = Math.floorMod(partitionHashCode, executorLocations.size());
+          String executorLocation = executorLocations.get(index);
+          locations.add(executorLocation);
+        }
+      }
+    }
+
+    return locations.toArray(NO_LOCATION_PREFERENCE);
+  }
+
+  private static JavaHash<StructLike> partitionHash(PartitionSpec spec) {
+    return JavaHash.forType(spec.partitionType());
+  }
+}
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java
new file mode 100644
index 000000000000..65c6790e5b49
--- /dev/null
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import static org.apache.iceberg.types.Types.NestedField.required;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.when;
+
+import java.util.List;
+import org.apache.iceberg.BaseScanTaskGroup;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.DataTask;
+import org.apache.iceberg.DeleteFile;
+import org.apache.iceberg.MockFileScanTask;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.TestHelpers.Row;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.TestBaseWithCatalog;
+import org.apache.iceberg.types.Types;
+import org.junit.jupiter.api.TestTemplate;
+import org.mockito.Mockito;
+
+public class TestSparkPlanningUtil extends TestBaseWithCatalog {
+
+  private static final Schema SCHEMA =
+      new Schema(
+          required(1, "id", Types.IntegerType.get()),
+          required(2, "data", Types.StringType.get()),
+          required(3, "category", Types.StringType.get()));
+  private static final PartitionSpec SPEC_1 =
+      PartitionSpec.builderFor(SCHEMA).withSpecId(1).bucket("id", 16).identity("data").build();
+  private static final PartitionSpec SPEC_2 =
+      PartitionSpec.builderFor(SCHEMA).withSpecId(2).identity("data").build();
+  private static final List<String> EXECUTOR_LOCATIONS =
+      ImmutableList.of("host1_exec1", "host1_exec2", "host1_exec3", "host2_exec1", "host2_exec2");
+
+  @TestTemplate
+  public void testFileScanTaskWithoutDeletes() {
+    List<ScanTask> tasks =
+        ImmutableList.of(
+            new MockFileScanTask(mockDataFile(Row.of(1, "a")), SCHEMA, SPEC_1),
+            new MockFileScanTask(mockDataFile(Row.of(2, "b")), SCHEMA, SPEC_1),
+            new MockFileScanTask(mockDataFile(Row.of(3, "c")), SCHEMA, SPEC_1));
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors if there are no deletes
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0]).isEmpty();
+  }
+
+  @TestTemplate
+  public void testFileScanTaskWithDeletes() {
+    StructLike partition1 = Row.of("k2", null);
+    StructLike partition2 = Row.of("k1");
+    List<ScanTask> tasks =
+        ImmutableList.of(
+            new MockFileScanTask(
+                mockDataFile(partition1), mockDeleteFiles(1, partition1), SCHEMA, SPEC_1),
+            new MockFileScanTask(
+                mockDataFile(partition2), mockDeleteFiles(3, partition2), SCHEMA, SPEC_2),
+            new MockFileScanTask(
+                mockDataFile(partition1), mockDeleteFiles(2, partition1), SCHEMA, SPEC_1));
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should assign executors and handle different size of partitions
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0].length).isGreaterThanOrEqualTo(1);
+  }
+
+  @TestTemplate
+  public void testFileScanTaskWithUnpartitionedDeletes() {
+    List<ScanTask> tasks1 =
+        ImmutableList.of(
+            new MockFileScanTask(
+                mockDataFile(Row.of()),
+                mockDeleteFiles(2, Row.of()),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(Row.of()),
+                mockDeleteFiles(2, Row.of()),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(Row.of()),
+                mockDeleteFiles(2, Row.of()),
+                SCHEMA,
+                PartitionSpec.unpartitioned()));
+    ScanTaskGroup<ScanTask> taskGroup1 = new BaseScanTaskGroup<>(tasks1);
+    List<ScanTask> tasks2 =
+        ImmutableList.of(
+            new MockFileScanTask(
+                mockDataFile(null),
+                mockDeleteFiles(2, null),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(null),
+                mockDeleteFiles(2, null),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(null),
+                mockDeleteFiles(2, null),
+                SCHEMA,
+                PartitionSpec.unpartitioned()));
+    ScanTaskGroup<ScanTask> taskGroup2 = new BaseScanTaskGroup<>(tasks2);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup1, taskGroup2);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors if the table is unpartitioned
+    assertThat(locations.length).isEqualTo(2);
+    assertThat(locations[0]).isEmpty();
+    assertThat(locations[1]).isEmpty();
+  }
+
+  @TestTemplate
+  public void testDataTasks() {
+    List<ScanTask> tasks =
+        ImmutableList.of(
+            new MockDataTask(mockDataFile(Row.of(1, "a"))),
+            new MockDataTask(mockDataFile(Row.of(2, "b"))),
+            new MockDataTask(mockDataFile(Row.of(3, "c"))));
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors for data tasks
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0]).isEmpty();
+  }
+
+  @TestTemplate
+  public void testUnknownTasks() {
+    List<ScanTask> tasks = ImmutableList.of(new UnknownScanTask(), new UnknownScanTask());
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors for unknown tasks
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0]).isEmpty();
+  }
+
+  private static DataFile mockDataFile(StructLike partition) {
+    DataFile file = Mockito.mock(DataFile.class);
+    when(file.partition()).thenReturn(partition);
+    return file;
+  }
+
+  private static DeleteFile[] mockDeleteFiles(int count, StructLike partition) {
+    DeleteFile[] files = new DeleteFile[count];
+    for (int index = 0; index < count; index++) {
+      files[index] = mockDeleteFile(partition);
+    }
+    return files;
+  }
+
+  private static DeleteFile mockDeleteFile(StructLike partition) {
+    DeleteFile file = Mockito.mock(DeleteFile.class);
+    when(file.partition()).thenReturn(partition);
+    return file;
+  }
+
+  private static class MockDataTask extends MockFileScanTask implements DataTask {
+
+    MockDataTask(DataFile file) {
+      super(file);
+    }
+
+    @Override
+    public PartitionSpec spec() {
+      return PartitionSpec.unpartitioned();
+    }
+
+    @Override
+    public CloseableIterable<StructLike> rows() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  private static class UnknownScanTask implements ScanTask {}
+}
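For reference, the locality preference introduced above is opt-in and is only honored when the Iceberg executor cache itself is enabled: SparkReadConf#executorCacheLocalityEnabled() returns true only if both executorCacheEnabled() and the new locality flag resolve to true. Below is a minimal sketch of how a user might enable it from the driver; the class name and session setup are illustrative, and only the property names come from SparkSQLProperties in this diff.

import org.apache.spark.sql.SparkSession;

// Hypothetical driver-side setup; the flags can equally be passed via --conf at submit time.
public class ExecutorCacheLocalityExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().appName("executor-cache-locality-demo").getOrCreate();

    // Both settings are shown explicitly: the locality preference is only applied
    // when the executor cache is enabled as well.
    spark.conf().set("spark.sql.iceberg.executor-cache.enabled", "true");
    spark.conf().set("spark.sql.iceberg.executor-cache.locality.enabled", "true");
  }
}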