@@ -44,6 +44,14 @@ public interface MigrateTable extends Action<MigrateTable, MigrateTable.Result>
*/
MigrateTable tableProperty(String name, String value);

/**
* Sets the maximum number of concurrent file read operations to use per partition while indexing the source table.
*
* @param numReaders the number of concurrent file read operations to use per partition
* @return this for method chaining
*/
default MigrateTable withParallelReads(int numReaders) {
throw new UnsupportedOperationException(this.getClass().getName() + " does not implement withParallelReads");
}

/**
* The action result that contains a summary of the execution.
*/
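For reference, a minimal usage sketch of the withParallelReads hook added above (not itself part of the diff). It assumes the usual SparkActions entry point from the Spark actions module; db.source_table is a placeholder identifier and the value 4 is an arbitrary parallelism choice.

import org.apache.iceberg.actions.MigrateTable;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class MigrateWithParallelReadsExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("migrate-example").getOrCreate();

    // Convert the existing Hive/Spark table in place, reading up to
    // 4 data files per partition concurrently while building Iceberg metadata.
    MigrateTable.Result result = SparkActions.get(spark)
        .migrateTable("db.source_table")
        .withParallelReads(4)
        .execute();

    System.out.println("Migrated " + result.migratedDataFilesCount() + " data files");
    spark.stop();
  }
}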
@@ -60,6 +60,14 @@ public interface SnapshotTable extends Action<SnapshotTable, SnapshotTable.Resul
*/
SnapshotTable tableProperty(String key, String value);

/**
* Sets the maximum number of concurrent file read operations to use per partition while indexing the source table.
*
* @param numReaders the number of concurrent file read operations to use per partition
* @return this for method chaining
*/
default SnapshotTable withParallelReads(int numReaders) {
throw new UnsupportedOperationException(this.getClass().getName() + " does not implement withParallelReads");
}

/**
* The action result that contains a summary of the execution.
*/
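The matching sketch for the snapshot action, under the same assumptions (SparkActions entry point, placeholder table names, arbitrary parallelism of 4); unlike migrate, snapshot leaves the source table in place and writes an independent Iceberg copy of its metadata.

import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class SnapshotWithParallelReadsExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("snapshot-example").getOrCreate();

    // Snapshot the source table into a new Iceberg table, reading up to
    // 4 data files per partition concurrently while building Iceberg metadata.
    SnapshotTable.Result result = SparkActions.get(spark)
        .snapshotTable("db.source_table")
        .as("db.iceberg_snapshot")
        .withParallelReads(4)
        .execute();

    System.out.println("Imported " + result.importedDataFilesCount() + " data files");
    spark.stop();
  }
}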
@@ -343,6 +343,52 @@ public void addDataPartitionedHive() {
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
}

@Test
public void addDataPartitionedHiveInParallel() {
createPartitionedHiveTable();

String createIceberg =
"CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";

sql(createIceberg, tableName);

Object result = scalarSql("CALL %s.system.add_files(" +
"table => '%s', " +
"source_table => '%s', " +
"max_concurrent_read_datafiles => 3)",
catalogName, tableName, sourceTableName);

Assert.assertEquals(8L, result);

assertEquals(
"Iceberg table contains correct data",
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
}

@Test
public void addDataUnpartitionedHiveInParallel() {
createUnpartitionedHiveTable();

String createIceberg =
"CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";

sql(createIceberg, tableName);

Object result = scalarSql("CALL %s.system.add_files(" +
"table => '%s', " +
"source_table => '%s', " +
"max_concurrent_read_datafiles => 3)",
catalogName, tableName, sourceTableName);

Assert.assertEquals(2L, result);

assertEquals(
"Iceberg table contains correct data",
sql("SELECT * FROM %s ORDER BY id", sourceTableName),
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@Test
public void addPartitionToPartitioned() {
createPartitionedFileTable("parquet");
@@ -100,6 +100,34 @@ public void testMigrateWithOptions() throws IOException {
sql("DROP TABLE %s", tableName + "_BACKUP_");
}

@Test
public void testMigrateWithParallelism() throws IOException {
Assume.assumeTrue(catalogName.equals("spark_catalog"));
String location = temp.newFolder().toString();
sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName);

Object result =
scalarSql("CALL %s.system.migrate(table => '%s', max_concurrent_read_datafiles => 3)",
catalogName, tableName);

Assert.assertEquals("Should have added three files", 3L, result);

Table createdTable = validationCatalog.loadTable(tableIdent);

String tableLocation = createdTable.location().replace("file:", "");
Assert.assertEquals("Table should have original location", location, tableLocation);

assertEquals(
"Should have expected rows",
ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")),
sql("SELECT * FROM %s ORDER BY id", tableName));

sql("DROP TABLE %s", tableName + "_BACKUP_");
}

@Test
public void testMigrateWithInvalidMetricsConfig() throws IOException {
Assume.assumeTrue(catalogName.equals("spark_catalog"));
@@ -96,6 +96,30 @@ public void testSnapshotWithProperties() throws IOException {
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@Test
public void testSnapshotWithParallelism() throws IOException {
String location = temp.newFolder().toString();
sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);
sql("INSERT INTO TABLE %s VALUES (3, 'c')", sourceName);
Object result = scalarSql(
"CALL %s.system.snapshot(source_table => '%s', table => '%s', max_concurrent_read_datafiles => 3)",
catalogName, sourceName, tableName);

Assert.assertEquals("Should have added three file", 3L, result);

Table createdTable = validationCatalog.loadTable(tableIdent);

String tableLocation = createdTable.location();
Assert.assertNotEquals("Table should not have the original location", location, tableLocation);

assertEquals(
"Should have expected rows",
ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")),
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@Test
public void testSnapshotWithAlternateLocation() throws IOException {
Assume.assumeTrue("No Snapshoting with Alternate locations with Hadoop Catalogs", !catalogName.contains("hadoop"));
@@ -380,10 +380,11 @@ private static Iterator<ManifestFile> buildManifest(SerializableConfiguration co
* @param stagingDir a staging directory to store temporary manifest files
* @param partitionFilter only import partitions whose values match those in the map, can be partially defined
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param parallelism the maximum number of concurrent file reads per partition
*/
public static void importSparkTable(SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable,
String stagingDir, Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles, int parallelism) {
SessionCatalog catalog = spark.sessionState().catalog();

String db = sourceTableIdent.database().nonEmpty() ?
@@ -399,20 +400,41 @@ public static void importSparkTable(SparkSession spark, TableIdentifier sourceTa
PartitionSpec spec = SparkSchemaUtil.specForTable(spark, sourceTableIdentWithDB.unquotedString());

if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
importUnpartitionedSparkTable(spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles);
importUnpartitionedSparkTable(spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, parallelism);
} else {
List<SparkPartition> sourceTablePartitions = getPartitions(spark, sourceTableIdent,
partitionFilter);
Preconditions.checkArgument(!sourceTablePartitions.isEmpty(),
"Cannot find any partitions in table %s", sourceTableIdent);
importSparkPartitions(spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles);
importSparkPartitions(spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles,
parallelism);
}
} catch (AnalysisException e) {
throw SparkExceptionUtil.toUncheckedException(
e, "Unable to get partition spec for table: %s", sourceTableIdentWithDB);
}
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
* The import uses the Spark session to get table metadata. It assumes no
* operation is modifying the source or target table while the import runs and
* is therefore not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param partitionFilter only import partitions whose values match those in the map, can be partially defined
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
*/
public static void importSparkTable(SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable,
String stagingDir, Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
importSparkTable(spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
@@ -448,7 +470,7 @@ public static void importSparkTable(SparkSession spark, TableIdentifier sourceTa
}

private static void importUnpartitionedSparkTable(SparkSession spark, TableIdentifier sourceTableIdent,
Table targetTable, boolean checkDuplicateFiles) {
Table targetTable, boolean checkDuplicateFiles, int parallelism) {
try {
CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent);
Option<String> format =
@@ -463,7 +485,8 @@ private static void importUnpartitionedSparkTable(SparkSession spark, TableIdent
NameMapping nameMapping = nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null;

List<DataFile> files = TableMigrationUtil.listPartition(
partition, Util.uriToString(sourceTable.location()), format.get(), spec, conf, metricsConfig, nameMapping);
partition, Util.uriToString(sourceTable.location()), format.get(), spec, conf, metricsConfig, nameMapping,
parallelism);

if (checkDuplicateFiles) {
Dataset<Row> importedFiles = spark.createDataset(
@@ -497,9 +520,11 @@ private static void importUnpartitionedSparkTable(SparkSession spark, TableIdent
* @param spec a partition spec
* @param stagingDir a staging directory to store temporary manifest files
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param listPartitionParallelism the maximum number of concurrent file reads per partition while indexing the table
*/
public static void importSparkPartitions(SparkSession spark, List<SparkPartition> partitions, Table targetTable,
PartitionSpec spec, String stagingDir, boolean checkDuplicateFiles) {
PartitionSpec spec, String stagingDir, boolean checkDuplicateFiles,
int listPartitionParallelism) {
Configuration conf = spark.sessionState().newHadoopConf();
SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
int parallelism = Math.min(partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism());
@@ -516,8 +541,11 @@ public static void importSparkPartitions(SparkSession spark, List<SparkPartition
Encoders.javaSerialization(SparkPartition.class));

Dataset<DataFile> filesToImport = partitionDS
.flatMap((FlatMapFunction<SparkPartition, DataFile>) sparkPartition ->
listPartition(sparkPartition, spec, serializableConf, metricsConfig, nameMapping).iterator(),
.flatMap(
(FlatMapFunction<SparkPartition, DataFile>) sparkPartition ->
TableMigrationUtil.listPartition(sparkPartition.values, sparkPartition.uri,
sparkPartition.format, spec, serializableConf.get(), metricsConfig, nameMapping,
listPartitionParallelism).iterator(),
Encoders.javaSerialization(DataFile.class));

if (checkDuplicateFiles) {
@@ -564,6 +592,21 @@ public static void importSparkPartitions(SparkSession spark, List<SparkPartition
}
}

/**
* Import files from given partitions to an Iceberg table.
*
* @param spark a Spark session
* @param partitions partitions to import
* @param targetTable an Iceberg table where to import the data
* @param spec a partition spec
* @param stagingDir a staging directory to store temporary manifest files
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
*/
public static void importSparkPartitions(SparkSession spark, List<SparkPartition> partitions, Table targetTable,
PartitionSpec spec, String stagingDir, boolean checkDuplicateFiles) {
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1);
}

/**
* Import files from given partitions to an Iceberg table.
*
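The widened importSparkTable signature can also be driven directly from user code. A minimal sketch under stated assumptions: the target table is loaded with Spark3Util.loadIcebergTable (any other way of obtaining an org.apache.iceberg.Table works too), db.source_tbl, db.iceberg_target, and the staging path are placeholders, and the parallelism of 4 is arbitrary.

import java.util.Collections;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;
import scala.Option;

public class ImportWithParallelismExample {
  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder().appName("import-example").getOrCreate();

    // Load the destination Iceberg table (assumes the identifier resolves in the configured catalog).
    Table targetTable = Spark3Util.loadIcebergTable(spark, "db.iceberg_target");
    // Identifier of the source Spark/Hive table to index.
    TableIdentifier source = new TableIdentifier("source_tbl", Option.apply("db"));

    SparkTableUtil.importSparkTable(
        spark,
        source,
        targetTable,
        "/tmp/iceberg-staging",   // staging dir for temporary manifest files
        Collections.emptyMap(),   // no partition filter: import all partitions
        false,                    // do not fail on duplicate data files
        4);                       // read up to 4 data files per partition concurrently

    spark.stop();
  }
}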
@@ -19,6 +19,7 @@

package org.apache.iceberg.spark.actions;

import java.util.Collections;
import java.util.Map;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
@@ -60,6 +61,8 @@ public class BaseMigrateTableSparkAction
private final StagingTableCatalog destCatalog;
private final Identifier destTableIdent;
private final Identifier backupIdent;
// Max number of concurrent file reads per partition while indexing the table
private int readDatafileParallelism = 1;

public BaseMigrateTableSparkAction(SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) {
super(spark, sourceCatalog, sourceTableIdent);
@@ -69,6 +72,7 @@ public BaseMigrateTableSparkAction(SparkSession spark, CatalogPlugin sourceCatal
this.backupIdent = Identifier.of(sourceTableIdent.namespace(), backupName);
}


@Override
protected MigrateTable self() {
return this;
@@ -96,6 +100,12 @@ public MigrateTable tableProperty(String property, String value) {
return this;
}

@Override
public MigrateTable withParallelReads(int numReaders) {
this.readDatafileParallelism = numReaders;
return this;
}

@Override
public MigrateTable.Result execute() {
String desc = String.format("Migrating table %s", destTableIdent().toString());
@@ -125,7 +135,8 @@ private MigrateTable.Result doExecute() {
TableIdentifier v1BackupIdent = new TableIdentifier(backupIdent.name(), backupNamespace);
String stagingLocation = getMetadataLocation(icebergTable);
LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1BackupIdent, icebergTable, stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1BackupIdent, icebergTable, stagingLocation, Collections.emptyMap(),
false, readDatafileParallelism);

LOG.info("Committing staged changes to {}", destTableIdent());
stagedTable.commitStagedChanges();
@@ -19,6 +19,7 @@

package org.apache.iceberg.spark.actions;

import java.util.Collections;
import java.util.Map;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
@@ -57,6 +58,8 @@ public class BaseSnapshotTableSparkAction
private StagingTableCatalog destCatalog;
private Identifier destTableIdent;
private String destTableLocation = null;
// Max number of concurrent file reads per partition while indexing the table
private int readDatafileParallelism = 1;

BaseSnapshotTableSparkAction(SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) {
super(spark, sourceCatalog, sourceTableIdent);
@@ -107,6 +110,12 @@ public SnapshotTable tableProperty(String property, String value) {
return this;
}

@Override
public SnapshotTable withParallelReads(int numReaders) {
this.readDatafileParallelism = numReaders;
return this;
}

@Override
public SnapshotTable.Result execute() {
String desc = String.format("Snapshotting table %s as %s", sourceTableIdent(), destTableIdent);
@@ -133,7 +142,8 @@ private SnapshotTable.Result doExecute() {
TableIdentifier v1TableIdent = v1SourceTable().identifier();
String stagingLocation = getMetadataLocation(icebergTable);
LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1TableIdent, icebergTable, stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1TableIdent, icebergTable, stagingLocation,
Collections.emptyMap(), false, readDatafileParallelism);

LOG.info("Committing staged changes to {}", destTableIdent());
stagedTable.commitStagedChanges();
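The parallelism value threaded through these call sites ultimately reaches TableMigrationUtil.listPartition, whose implementation is not part of this diff. Purely to illustrate the technique being configured (bounding how many data-file metadata reads are in flight per partition), here is a minimal, self-contained sketch using a fixed-size thread pool; it is not the actual Iceberg implementation.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ParallelFileReadSketch {

  // Stand-in for the per-file work, e.g. reading a Parquet footer to collect metrics for a DataFile.
  interface FileMetadataReader<T> {
    T read(String path) throws Exception;
  }

  // Reads metadata for every file in a partition with at most `parallelism` reads in flight.
  static <T> List<T> readAll(List<String> paths, FileMetadataReader<T> reader, int parallelism)
      throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(parallelism);
    try {
      List<Future<T>> futures = new ArrayList<>();
      for (String path : paths) {
        futures.add(pool.submit(() -> reader.read(path)));
      }
      List<T> results = new ArrayList<>();
      for (Future<T> future : futures) {
        results.add(future.get()); // blocks; results keep the input order of paths
      }
      return results;
    } finally {
      pool.shutdownNow(); // cancel any remaining work if a read failed
    }
  }
}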