diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSourceStorage.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSourceStorage.scala
index af2bc6980474..bf616e2cb314 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSourceStorage.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSourceStorage.scala
@@ -26,25 +26,24 @@ import org.apache.hudi.common.table.timeline.{HoodieInstant, HoodieTimeline}
 import org.apache.hudi.common.testutils.HoodieTestDataGenerator
 import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings
 import org.apache.hudi.config.HoodieWriteConfig
+import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.Config
+import org.apache.hudi.keygen.{ComplexKeyGenerator, TimestampBasedKeyGenerator}
 import org.apache.hudi.testutils.SparkClientFunctionalTestHarness
 import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions.{col, lit}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
-import org.junit.jupiter.api.{Tag, Test}
+import org.junit.jupiter.api.Tag
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{Arguments, CsvSource, ValueSource}
+import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
 
-import java.util
-import java.util.Arrays
-import java.util.stream.Stream
 
 import scala.collection.JavaConversions._
 
 @Tag("functional")
 class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness {
 
-  val commonOpts = Map(
+  var commonOpts = Map(
     "hoodie.insert.shuffle.parallelism" -> "4",
     "hoodie.upsert.shuffle.parallelism" -> "4",
     "hoodie.bulkinsert.shuffle.parallelism" -> "2",
@@ -59,14 +58,26 @@ class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness {
   val updatedVerificationVal: String = "driver_update"
 
   @ParameterizedTest
-  @ValueSource(booleans = Array(true, false))
-  def testCopyOnWriteStorage(isMetadataEnabled: Boolean): Unit = {
+  @CsvSource(Array("true,org.apache.hudi.keygen.SimpleKeyGenerator", "true,org.apache.hudi.keygen.ComplexKeyGenerator",
+    "true,org.apache.hudi.keygen.TimestampBasedKeyGenerator", "false,org.apache.hudi.keygen.SimpleKeyGenerator",
+    "false,org.apache.hudi.keygen.ComplexKeyGenerator", "false,org.apache.hudi.keygen.TimestampBasedKeyGenerator"))
+  def testCopyOnWriteStorage(isMetadataEnabled: Boolean, keyGenClass: String): Unit = {
+    commonOpts += DataSourceWriteOptions.KEYGENERATOR_CLASS_NAME.key() -> keyGenClass
+    if (classOf[ComplexKeyGenerator].getName.equals(keyGenClass)) {
+      commonOpts += DataSourceWriteOptions.RECORDKEY_FIELD.key() -> "_row_key, pii_col"
+    }
+    if (classOf[TimestampBasedKeyGenerator].getName.equals(keyGenClass)) {
+      commonOpts += DataSourceWriteOptions.RECORDKEY_FIELD.key() -> "_row_key"
+      commonOpts += DataSourceWriteOptions.PARTITIONPATH_FIELD.key() -> "current_ts"
+      commonOpts += Config.TIMESTAMP_TYPE_FIELD_PROP -> "EPOCHMILLISECONDS"
+      commonOpts += Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP -> "yyyyMMdd"
+    }
     val dataGen = new HoodieTestDataGenerator()
     val fs = FSUtils.getFs(basePath, spark.sparkContext.hadoopConfiguration)
     // Insert Operation
-    val records1 = recordsToStrings(dataGen.generateInserts("000", 100)).toList
-    val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
-    inputDF1.write.format("org.apache.hudi")
inputDF1.write.format("org.apache.hudi") + val records0 = recordsToStrings(dataGen.generateInserts("000", 100)).toList + val inputDF0 = spark.read.json(spark.sparkContext.parallelize(records0, 2)) + inputDF0.write.format("org.apache.hudi") .options(commonOpts) .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL) .option(HoodieMetadataConfig.ENABLE.key, isMetadataEnabled) @@ -82,9 +93,18 @@ class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness { .load(basePath) assertEquals(100, snapshotDF1.count()) - // Upsert based on the written table with Hudi metadata columns - val verificationRowKey = snapshotDF1.limit(1).select("_row_key").first.getString(0) - val updateDf = snapshotDF1.filter(col("_row_key") === verificationRowKey).withColumn(verificationCol, lit(updatedVerificationVal)) + val records1 = recordsToStrings(dataGen.generateUpdates("001", 100)).toList + val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2)) + val verificationRowKey = inputDF1.limit(1).select("_row_key").first.getString(0) + var updateDf: DataFrame = null + if (classOf[TimestampBasedKeyGenerator].getName.equals(keyGenClass)) { + // update current_ts to be same as original record so that partition path does not change with timestamp based key gen + val orignalRow = inputDF1.filter(col("_row_key") === verificationRowKey).collectAsList().get(0) + updateDf = snapshotDF1.filter(col("_row_key") === verificationRowKey).withColumn(verificationCol, lit(updatedVerificationVal)) + .withColumn("current_ts", lit(orignalRow.getAs("current_ts"))) + } else { + updateDf = snapshotDF1.filter(col("_row_key") === verificationRowKey).withColumn(verificationCol, lit(updatedVerificationVal)) + } updateDf.write.format("org.apache.hudi") .options(commonOpts) @@ -100,8 +120,26 @@ class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness { assertEquals(updatedVerificationVal, snapshotDF2.filter(col("_row_key") === verificationRowKey).select(verificationCol).first.getString(0)) // Upsert Operation without Hudi metadata columns - val records2 = recordsToStrings(dataGen.generateUpdates("001", 100)).toList - val inputDF2 = spark.read.json(spark.sparkContext.parallelize(records2 , 2)) + val records2 = recordsToStrings(dataGen.generateUpdates("002", 100)).toList + var inputDF2 = spark.read.json(spark.sparkContext.parallelize(records2, 2)) + + if (classOf[TimestampBasedKeyGenerator].getName.equals(keyGenClass)) { + // incase of Timestamp based key gen, current_ts should not be updated. but dataGen.generateUpdates() would have updated + // the value of current_ts. So, we need to revert it back to original value. + // here is what we are going to do. Copy values to temp columns, join with original df and update the current_ts + // and drop the temp columns. 
+
+      val inputDF2WithTempCols = inputDF2.withColumn("current_ts_temp", col("current_ts"))
+        .withColumn("_row_key_temp", col("_row_key"))
+      val originalRowCurrentTsDf = inputDF0.select("_row_key", "current_ts")
+      // join with original df
+      val joinedDf = inputDF2WithTempCols.drop("_row_key", "current_ts").join(originalRowCurrentTsDf, (inputDF2WithTempCols("_row_key_temp") === originalRowCurrentTsDf("_row_key")))
+      // copy values from temp back to original cols and drop temp cols
+      inputDF2 = joinedDf.withColumn("current_ts_temp", col("current_ts"))
+        .drop("current_ts", "_row_key_temp").withColumn("current_ts", col("current_ts_temp"))
+        .drop("current_ts_temp")
+    }
+
     val uniqueKeyCnt = inputDF2.select("_row_key").distinct().count()
 
     inputDF2.write.format("org.apache.hudi")
@@ -136,12 +174,12 @@ class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness {
     val emptyIncDF = spark.read.format("org.apache.hudi")
       .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL)
      .option(DataSourceReadOptions.BEGIN_INSTANTTIME.key, "000")
-      .option(DataSourceReadOptions.END_INSTANTTIME.key, "001")
+      .option(DataSourceReadOptions.END_INSTANTTIME.key, "002")
       .load(basePath)
     assertEquals(0, emptyIncDF.count())
 
     // Upsert an empty dataFrame
-    val emptyRecords = recordsToStrings(dataGen.generateUpdates("002", 0)).toList
+    val emptyRecords = recordsToStrings(dataGen.generateUpdates("003", 0)).toList
     val emptyDF = spark.read.json(spark.sparkContext.parallelize(emptyRecords, 1))
     emptyDF.write.format("org.apache.hudi")
       .options(commonOpts)
@@ -195,7 +233,7 @@ class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness {
       .option("hoodie.keep.min.commits", "2")
       .option("hoodie.keep.max.commits", "3")
       .option("hoodie.cleaner.commits.retained", "1")
-      .option("hoodie.metadata.enable","false")
+      .option("hoodie.metadata.enable", "false")
       .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL)
       .mode(SaveMode.Overwrite)
       .save(basePath)
@@ -205,8 +243,7 @@ class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness {
     // issue delete partition to partition1
     writeRecords(2, dataGenPartition1, writeOperation, basePath)
 
-    val expectedRecCount = if (writeOperation.equals(DataSourceWriteOptions.INSERT_OVERWRITE_OPERATION_OPT_VAL))
-    {
+    val expectedRecCount = if (writeOperation.equals(DataSourceWriteOptions.INSERT_OVERWRITE_OPERATION_OPT_VAL)) {
       200 - partition1RecordCount
     } else {
       100 - partition1RecordCount
@@ -239,15 +276,15 @@ class TestCOWDataSourceStorage extends SparkClientFunctionalTestHarness {
       .option("hoodie.keep.min.commits", "2")
       .option("hoodie.keep.max.commits", "3")
       .option("hoodie.cleaner.commits.retained", "1")
-      .option("hoodie.metadata.enable","false")
+      .option("hoodie.metadata.enable", "false")
       .option(DataSourceWriteOptions.OPERATION.key, writeOperation)
       .mode(SaveMode.Append)
       .save(basePath)
   }
 
-  def assertRecordCount(basePath: String, expectedRecordCount: Long) : Unit = {
+  def assertRecordCount(basePath: String, expectedRecordCount: Long): Unit = {
     val snapshotDF = spark.read.format("org.apache.hudi")
-        .load(basePath + "/*/*/*/*")
+      .load(basePath + "/*/*/*/*")
     assertEquals(expectedRecordCount, snapshotDF.count())
   }
 }
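
Note on the `TimestampBasedKeyGenerator` rows of the `@CsvSource` matrix: the partition path is derived from `current_ts` (configured above as an `EPOCHMILLISECONDS` input with a `yyyyMMdd` output format), so an update that regenerates `current_ts` would route the record into a different partition and break the upsert assertions. That is why the test pins `current_ts` back to the originally inserted value. Below is a minimal standalone sketch of that epoch-millis-to-partition mapping; it uses plain `java.time` rather than Hudi's key generator, assumes a UTC zone, and the object/method names are illustrative only.

```scala
// Minimal sketch (not Hudi's TimestampBasedKeyGenerator): shows how an
// EPOCHMILLISECONDS current_ts maps to a yyyyMMdd partition path, and why a
// regenerated current_ts would land the record in a different partition.
import java.time.{Instant, ZoneOffset}
import java.time.format.DateTimeFormatter

object PartitionPathSketch {
  private val formatter = DateTimeFormatter.ofPattern("yyyyMMdd").withZone(ZoneOffset.UTC)

  // e.g. partitionPathFor(1626912000000L) == "20210722"
  def partitionPathFor(currentTsMillis: Long): String =
    formatter.format(Instant.ofEpochMilli(currentTsMillis))

  def main(args: Array[String]): Unit = {
    val originalTs = 1626912000000L            // value written by the first insert
    val regeneratedTs = originalTs + 86400000L // a later value produced by an update
    println(partitionPathFor(originalTs))      // 20210722
    println(partitionPathFor(regeneratedTs))   // 20210723 -- a different partition
  }
}
```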
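The temp-column join before the second upsert serves the same purpose for the whole batch: `dataGen.generateUpdates("002", 100)` regenerates `current_ts`, so the test joins back to the first insert batch (`inputDF0`) on `_row_key` to restore the inserted values. The following self-contained sketch shows an equivalent restore with hypothetical column values, using DataFrame aliases instead of the temp columns used in the test.

```scala
// Standalone sketch of the current_ts restore (hypothetical data, alias-based
// join rather than the temp-column approach in the test itself).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object RestoreCurrentTsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("restore-current_ts").getOrCreate()
    import spark.implicits._

    // inputDF0 stands in for the originally inserted batch, inputDF2 for the regenerated updates.
    val inputDF0 = Seq(("k1", 1626912000000L, "a"), ("k2", 1626912000000L, "b"))
      .toDF("_row_key", "current_ts", "driver")
    val inputDF2 = Seq(("k1", 1626998400000L, "a2"), ("k2", 1626998400000L, "b2"))
      .toDF("_row_key", "current_ts", "driver")

    // Replace the regenerated current_ts with the value from the original inserts, joining on _row_key.
    val restored = inputDF2.alias("upd")
      .join(inputDF0.select("_row_key", "current_ts").alias("orig"), Seq("_row_key"))
      .select(col("_row_key"), col("orig.current_ts").as("current_ts"), col("upd.driver"))

    restored.show(false)
    spark.stop()
  }
}
```

Either form yields one row per update carrying the original `current_ts`, which keeps timestamp-keyed writes in their original partitions.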