Spark 3.5: Support executor cache locality (#9563)
aokolnychyi committed Feb 5, 2024
1 parent c4cb0fb commit c745ac3
Showing 10 changed files with 444 additions and 42 deletions.
11 changes: 11 additions & 0 deletions core/src/test/java/org/apache/iceberg/MockFileScanTask.java
@@ -44,6 +44,17 @@ public MockFileScanTask(DataFile file, String schemaString, String specString) {
this.length = file.fileSizeInBytes();
}

public MockFileScanTask(DataFile file, Schema schema, PartitionSpec spec) {
super(file, null, SchemaParser.toJson(schema), PartitionSpecParser.toJson(spec), null);
this.length = file.fileSizeInBytes();
}

public MockFileScanTask(
DataFile file, DeleteFile[] deleteFiles, Schema schema, PartitionSpec spec) {
super(file, deleteFiles, SchemaParser.toJson(schema), PartitionSpecParser.toJson(spec), null);
this.length = file.fileSizeInBytes();
}

public static MockFileScanTask mockTask(long length, int sortOrderId) {
DataFile mockFile = Mockito.mock(DataFile.class);
Mockito.when(mockFile.fileSizeInBytes()).thenReturn(length);
@@ -37,6 +37,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.iceberg.spark.source.SparkTable;
import org.apache.iceberg.spark.source.TestSparkCatalog;
import org.apache.iceberg.util.SnapshotUtil;
@@ -85,6 +86,30 @@ public static void clearTestSparkCatalogCache() {
TestSparkCatalog.clearTables();
}

@Test
public void testDeleteWithExecutorCacheLocality() throws NoSuchTableException {
createAndInitPartitionedTable();

append(tableName, new Employee(1, "hr"), new Employee(2, "hr"));
append(tableName, new Employee(3, "hr"), new Employee(4, "hr"));
append(tableName, new Employee(1, "hardware"), new Employee(2, "hardware"));
append(tableName, new Employee(3, "hardware"), new Employee(4, "hardware"));

createBranchIfNeeded();

withSQLConf(
ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED, "true"),
() -> {
sql("DELETE FROM %s WHERE id = 1", commitTarget());
sql("DELETE FROM %s WHERE id = 3", commitTarget());

assertEquals(
"Should have expected rows",
ImmutableList.of(row(2, "hardware"), row(2, "hr"), row(4, "hardware"), row(4, "hr")),
sql("SELECT * FROM %s ORDER BY id ASC, dep ASC", selectTarget()));
});
}

@Test
public void testDeleteFileGranularity() throws NoSuchTableException {
checkDeleteFileGranularity(DeleteGranularity.FILE);
@@ -331,4 +331,24 @@ private long driverMaxResultSize() {
SparkConf sparkConf = spark.sparkContext().conf();
return sparkConf.getSizeAsBytes(DRIVER_MAX_RESULT_SIZE, DRIVER_MAX_RESULT_SIZE_DEFAULT);
}

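// Locality is only honored when the executor cache itself is enabled.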
public boolean executorCacheLocalityEnabled() {
return executorCacheEnabled() && executorCacheLocalityEnabledInternal();
}

private boolean executorCacheEnabled() {
return confParser
.booleanConf()
.sessionConf(SparkSQLProperties.EXECUTOR_CACHE_ENABLED)
.defaultValue(SparkSQLProperties.EXECUTOR_CACHE_ENABLED_DEFAULT)
.parse();
}

private boolean executorCacheLocalityEnabledInternal() {
return confParser
.booleanConf()
.sessionConf(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED)
.defaultValue(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT)
.parse();
}
}
@@ -86,4 +86,8 @@ private SparkSQLProperties() {}
public static final String EXECUTOR_CACHE_MAX_TOTAL_SIZE =
"spark.sql.iceberg.executor-cache.max-total-size";
public static final long EXECUTOR_CACHE_MAX_TOTAL_SIZE_DEFAULT = 128 * 1024 * 1024; // 128 MB

public static final String EXECUTOR_CACHE_LOCALITY_ENABLED =
"spark.sql.iceberg.executor-cache.locality.enabled";
public static final boolean EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT = false;
}
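
As a rough illustration (not part of this commit), the new flag could be enabled per session as shown below. The executor-cache.enabled key is assumed to follow the same prefix as the keys above, since its definition sits outside this excerpt, and locality only takes effect when the executor cache itself is enabled (see SparkReadConf above).

import org.apache.spark.sql.SparkSession;

// Hypothetical example class, not part of the commit.
public class ExecutorCacheLocalityExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .master("local[2]")
            // Locality is ignored unless the executor cache itself is enabled.
            .config("spark.sql.iceberg.executor-cache.enabled", "true")
            .config("spark.sql.iceberg.executor-cache.locality.enabled", "true")
            .getOrCreate();
    // Any Iceberg scan that reads cached delete files can now prefer the executor holding them.
    spark.sql("DELETE FROM demo.db.employees WHERE id = 1"); // hypothetical catalog/table
    spark.stop();
  }
}
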
@@ -34,6 +34,8 @@
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.UnknownTransform;
import org.apache.iceberg.util.Pair;
import org.apache.spark.SparkEnv;
import org.apache.spark.scheduler.ExecutorCacheTaskLocation;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.BoundReference;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
@@ -43,7 +45,12 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.storage.BlockManagerMaster;
import org.joda.time.DateTime;
import scala.collection.JavaConverters;
import scala.collection.Seq;

public class SparkUtil {
private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog";
@@ -238,4 +245,27 @@ public static String toColumnName(NamedReference ref) {
public static boolean caseSensitive(SparkSession spark) {
return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive"));
}

public static List<String> executorLocations() {
BlockManager driverBlockManager = SparkEnv.get().blockManager();
List<BlockManagerId> executorBlockManagerIds = fetchPeers(driverBlockManager);
return executorBlockManagerIds.stream()
.map(SparkUtil::toExecutorLocation)
.sorted()
.collect(Collectors.toList());
}

private static List<BlockManagerId> fetchPeers(BlockManager blockManager) {
BlockManagerMaster master = blockManager.master();
BlockManagerId id = blockManager.blockManagerId();
return toJavaList(master.getPeers(id));
}

private static <T> List<T> toJavaList(Seq<T> seq) {
return JavaConverters.seqAsJavaListConverter(seq).asJava();
}

private static String toExecutorLocation(BlockManagerId id) {
return ExecutorCacheTaskLocation.apply(id.host(), id.executorId()).toString();
}
}
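
executorLocations() asks the driver's BlockManagerMaster for its peer block managers (one per live executor) and renders each as a Spark task location string. A minimal usage sketch follows, assuming a SparkSession is already running; the host names and executor ids are made up:

import java.util.List;
import org.apache.iceberg.spark.SparkUtil;

// Hypothetical example class, not part of the commit.
public class ExecutorLocationsExample {
  public static void main(String[] args) {
    // Requires an active SparkEnv, i.e. a SparkSession/SparkContext running on the driver.
    List<String> locations = SparkUtil.executorLocations();
    // ExecutorCacheTaskLocation renders as "executor_<host>_<executorId>", so on a two-executor
    // cluster this might print [executor_10.0.0.12_1, executor_10.0.0.13_2] (illustrative values).
    System.out.println(locations);
  }
}
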
@@ -29,9 +29,8 @@
import org.apache.iceberg.SchemaParser;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.SparkReadConf;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.Tasks;
import org.apache.iceberg.util.ThreadPools;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.connector.read.Batch;
@@ -49,6 +48,7 @@ class SparkBatch implements Batch {
private final Schema expectedSchema;
private final boolean caseSensitive;
private final boolean localityEnabled;
private final boolean executorCacheLocalityEnabled;
private final int scanHashCode;

SparkBatch(
@@ -68,6 +68,7 @@
this.expectedSchema = expectedSchema;
this.caseSensitive = readConf.caseSensitive();
this.localityEnabled = readConf.localityEnabled();
this.executorCacheLocalityEnabled = readConf.executorCacheLocalityEnabled();
this.scanHashCode = scanHashCode;
}

@@ -77,27 +78,39 @@ public InputPartition[] planInputPartitions() {
Broadcast<Table> tableBroadcast =
sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
String expectedSchemaString = SchemaParser.toJson(expectedSchema);
String[][] locations = computePreferredLocations();

InputPartition[] partitions = new InputPartition[taskGroups.size()];

Tasks.range(partitions.length)
.stopOnFailure()
.executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null)
.run(
index ->
partitions[index] =
new SparkInputPartition(
groupingKeyType,
taskGroups.get(index),
tableBroadcast,
branch,
expectedSchemaString,
caseSensitive,
localityEnabled));
for (int index = 0; index < taskGroups.size(); index++) {
partitions[index] =
new SparkInputPartition(
groupingKeyType,
taskGroups.get(index),
tableBroadcast,
branch,
expectedSchemaString,
caseSensitive,
locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE);
}

return partitions;
}

private String[][] computePreferredLocations() {
if (localityEnabled) {
return SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups);

} else if (executorCacheLocalityEnabled) {
List<String> executorLocations = SparkUtil.executorLocations();
if (!executorLocations.isEmpty()) {
return SparkPlanningUtil.assignExecutors(taskGroups, executorLocations);
}
}

return null;
}
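
// Note: SparkPlanningUtil is added by this commit but its hunk is not reproduced in this excerpt.
// Judging from its usage here, fetchBlockLocations() resolves HDFS-style block hosts per task
// group, assignExecutors() deterministically maps each task group onto one of the reported
// executor locations (so repeated scans of the same files can favor the executor that already
// cached their delete files), and NO_LOCATION_PREFERENCE presumably denotes an empty location
// array, i.e. no preference. With two executors, the locations for three task groups might look
// like {{"executor_10.0.0.12_1"}, {"executor_10.0.0.13_2"}, {"executor_10.0.0.12_1"}} (illustrative).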

@Override
public PartitionReaderFactory createReaderFactory() {
if (useParquetBatchReads()) {
@@ -24,8 +24,6 @@
import org.apache.iceberg.Schema;
import org.apache.iceberg.SchemaParser;
import org.apache.iceberg.Table;
import org.apache.iceberg.hadoop.HadoopInputFile;
import org.apache.iceberg.hadoop.Util;
import org.apache.iceberg.types.Types;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -39,9 +37,9 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializable {
private final String branch;
private final String expectedSchemaString;
private final boolean caseSensitive;
private final transient String[] preferredLocations;

private transient Schema expectedSchema = null;
private transient String[] preferredLocations = null;

SparkInputPartition(
Types.StructType groupingKeyType,
@@ -50,19 +48,14 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializable {
String branch,
String expectedSchemaString,
boolean caseSensitive,
boolean localityPreferred) {
String[] preferredLocations) {
this.groupingKeyType = groupingKeyType;
this.taskGroup = taskGroup;
this.tableBroadcast = tableBroadcast;
this.branch = branch;
this.expectedSchemaString = expectedSchemaString;
this.caseSensitive = caseSensitive;
if (localityPreferred) {
Table table = tableBroadcast.value();
this.preferredLocations = Util.blockLocations(table.io(), taskGroup);
} else {
this.preferredLocations = HadoopInputFile.NO_LOCATION_PREFERENCE;
}
this.preferredLocations = preferredLocations;
}

@Override
@@ -54,8 +54,6 @@
import org.apache.iceberg.util.PropertyUtil;
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.iceberg.util.TableScanUtil;
import org.apache.iceberg.util.Tasks;
import org.apache.iceberg.util.ThreadPools;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.connector.read.InputPartition;
@@ -154,27 +152,29 @@ public InputPartition[] planInputPartitions(Offset start, Offset end) {
List<CombinedScanTask> combinedScanTasks =
Lists.newArrayList(
TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost));
String[][] locations = computePreferredLocations(combinedScanTasks);

InputPartition[] partitions = new InputPartition[combinedScanTasks.size()];

Tasks.range(partitions.length)
.stopOnFailure()
.executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
.run(
index ->
partitions[index] =
new SparkInputPartition(
EMPTY_GROUPING_KEY_TYPE,
combinedScanTasks.get(index),
tableBroadcast,
branch,
expectedSchema,
caseSensitive,
localityPreferred));
for (int index = 0; index < combinedScanTasks.size(); index++) {
partitions[index] =
new SparkInputPartition(
EMPTY_GROUPING_KEY_TYPE,
combinedScanTasks.get(index),
tableBroadcast,
branch,
expectedSchema,
caseSensitive,
locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE);
}

return partitions;
}

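// Unlike SparkBatch, the streaming micro-batch path only computes HDFS-style block locations;
// executor cache locality is not applied here.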
private String[][] computePreferredLocations(List<CombinedScanTask> taskGroups) {
return localityPreferred ? SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups) : null;
}

@Override
public PartitionReaderFactory createReaderFactory() {
return new SparkRowReaderFactory();