diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
index ee4674f27bb3..a9a8f95691b0 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
@@ -34,26 +34,24 @@
 import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.read.Batch;
 import org.apache.spark.sql.connector.read.InputPartition;
 import org.apache.spark.sql.connector.read.PartitionReaderFactory;
 
-class SparkBatch implements Batch {
+abstract class SparkBatch implements Batch {
 
   private final JavaSparkContext sparkContext;
   private final Table table;
   private final SparkReadConf readConf;
-  private final List<CombinedScanTask> tasks;
   private final Schema expectedSchema;
   private final boolean caseSensitive;
   private final boolean localityEnabled;
 
-  SparkBatch(JavaSparkContext sparkContext, Table table, SparkReadConf readConf,
-             List<CombinedScanTask> tasks, Schema expectedSchema) {
-    this.sparkContext = sparkContext;
+  SparkBatch(SparkSession spark, Table table, SparkReadConf readConf, Schema expectedSchema) {
+    this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
     this.table = table;
     this.readConf = readConf;
-    this.tasks = tasks;
     this.expectedSchema = expectedSchema;
     this.caseSensitive = readConf.caseSensitive();
     this.localityEnabled = readConf.localityEnabled();
@@ -65,18 +63,24 @@ public InputPartition[] planInputPartitions() {
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
     String expectedSchemaString = SchemaParser.toJson(expectedSchema);
 
-    InputPartition[] readTasks = new InputPartition[tasks.size()];
+    InputPartition[] readTasks = new InputPartition[tasks().size()];
 
     Tasks.range(readTasks.length)
         .stopOnFailure()
         .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null)
         .run(index -> readTasks[index] = new ReadTask(
-            tasks.get(index), tableBroadcast, expectedSchemaString,
+            tasks().get(index), tableBroadcast, expectedSchemaString,
             caseSensitive, localityEnabled));
 
     return readTasks;
   }
 
+  protected abstract List<CombinedScanTask> tasks();
+
+  protected JavaSparkContext sparkContext() {
+    return sparkContext;
+  }
+
   @Override
   public PartitionReaderFactory createReaderFactory() {
     return new ReaderFactory(batchSize());
@@ -93,7 +97,7 @@ private int batchSize() {
   }
 
   private boolean parquetOnly() {
-    return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.PARQUET));
+    return tasks().stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.PARQUET));
   }
 
   private boolean parquetBatchReadsEnabled() {
@@ -103,12 +107,12 @@ private boolean parquetBatchReadsEnabled() {
   }
 
   private boolean orcOnly() {
-    return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.ORC));
+    return tasks().stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.ORC));
   }
 
   private boolean orcBatchReadsEnabled() {
     return readConf.orcVectorizationEnabled() && // vectorization enabled
-        tasks.stream().noneMatch(TableScanUtil::hasDeletes); // no delete files
+        tasks().stream().noneMatch(TableScanUtil::hasDeletes); // no delete files
   }
 
   private boolean onlyFileFormat(CombinedScanTask task, FileFormat fileFormat) {
diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
index 7d93ad66e1e8..b9292541eaeb 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
@@ -40,7 +40,6 @@
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.util.PropertyUtil;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -57,10 +56,9 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-abstract class SparkScan implements Scan, SupportsReportStatistics {
+abstract class SparkScan extends SparkBatch implements Scan, SupportsReportStatistics {
   private static final Logger LOG = LoggerFactory.getLogger(SparkScan.class);
 
-  private final JavaSparkContext sparkContext;
   private final Table table;
   private final SparkReadConf readConf;
   private final boolean caseSensitive;
@@ -69,14 +67,14 @@ abstract class SparkScan implements Scan, SupportsReportStatistics {
   private final boolean readTimestampWithoutZone;
 
   // lazy variables
-  private StructType readSchema = null;
+  private StructType readSchema;
 
   SparkScan(SparkSession spark, Table table, SparkReadConf readConf,
             Schema expectedSchema, List<Expression> filters) {
+    super(spark, table, readConf, expectedSchema);
 
     SparkSchemaUtil.validateMetadataColumnReferences(table.schema(), expectedSchema);
 
-    this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
     this.table = table;
     this.readConf = readConf;
     this.caseSensitive = readConf.caseSensitive();
@@ -101,16 +99,14 @@ protected List<Expression> filterExpressions() {
     return filterExpressions;
   }
 
-  protected abstract List<CombinedScanTask> tasks();
-
   @Override
   public Batch toBatch() {
-    return new SparkBatch(sparkContext, table, readConf, tasks(), expectedSchema);
+    return this;
   }
 
   @Override
   public MicroBatchStream toMicroBatchStream(String checkpointLocation) {
-    return new SparkMicroBatchStream(sparkContext, table, readConf, expectedSchema, checkpointLocation);
+    return new SparkMicroBatchStream(sparkContext(), table, readConf, expectedSchema, checkpointLocation);
   }
 
   @Override
diff --git a/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java b/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java
index b278bd08a608..20b399e12580 100644
--- a/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java
+++ b/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java
@@ -1147,7 +1147,7 @@ public void testUseLocalIterator() {
       checkExpirationResults(1L, 0L, 0L, 1L, 2L, results);
 
       Assert.assertEquals("Expected total number of jobs with stream-results should match the expected number",
-          5L, jobsRunDuringStreamResults);
+          4L, jobsRunDuringStreamResults);
     });
   }
 }
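Note for reviewers: with this change SparkScan extends SparkBatch, toBatch() simply returns this, and task planning moves behind the abstract tasks() accessor that concrete scans implement. The class below is only a sketch of what a concrete scan could look like under the new contract; ExampleSparkScan, its eagerly supplied task list, and the assumption that tasks() is the only abstract member a subclass must provide are illustrative and not part of this patch (the real subclasses, e.g. SparkBatchQueryScan, plan their CombinedScanTasks themselves).

package org.apache.iceberg.spark.source;

import java.util.List;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.spark.SparkReadConf;
import org.apache.spark.sql.SparkSession;

// Hypothetical example, not part of the patch: a minimal concrete scan under the
// refactored hierarchy, with its tasks supplied eagerly to keep the sketch short.
class ExampleSparkScan extends SparkScan {

  private final List<CombinedScanTask> tasks;

  ExampleSparkScan(SparkSession spark, Table table, SparkReadConf readConf,
                   Schema expectedSchema, List<Expression> filters,
                   List<CombinedScanTask> tasks) {
    // SparkScan now forwards these arguments to SparkBatch, which derives the
    // JavaSparkContext from the session instead of receiving one directly.
    super(spark, table, readConf, expectedSchema, filters);
    this.tasks = tasks;
  }

  // Single hook shared by scan and batch: planInputPartitions() and the
  // Parquet/ORC vectorization checks in SparkBatch all read from tasks().
  @Override
  protected List<CombinedScanTask> tasks() {
    return tasks;
  }
}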