[HUDI-2958] Automatically set spark.sql.parquet.writeLegacyFormat when using bulk insert to write data that contains DecimalType #4253

Merged · 4 commits · Dec 17, 2021
@@ -18,12 +18,16 @@

package org.apache.hudi.util;

import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
@@ -119,4 +123,26 @@ private static boolean areCompatible(@Nonnull StructType left, @Nonnull StructTy
private static <T> HashSet<T> newHashSet(T... ts) {
return new HashSet<>(Arrays.asList(ts));
}

/**
* Checks whether the given Spark type contains a DecimalType whose precision is smaller than
* Decimal.MAX_LONG_DIGITS(), i.e. a decimal that Spark may write to Parquet as int32/int64.
*
* @param sparkType the Spark data type (possibly nested) to inspect.
* @return true if such a small-precision DecimalType is found.
*/
public static boolean foundSmallPrecisionDecimalType(DataType sparkType) {
if (sparkType instanceof StructType) {
StructField[] fields = ((StructType) sparkType).fields();
return Arrays.stream(fields).anyMatch(f -> foundSmallPrecisionDecimalType(f.dataType()));
} else if (sparkType instanceof MapType) {
MapType map = (MapType) sparkType;
return foundSmallPrecisionDecimalType(map.keyType()) || foundSmallPrecisionDecimalType(map.valueType());
} else if (sparkType instanceof ArrayType) {
return foundSmallPrecisionDecimalType(((ArrayType) sparkType).elementType());
} else if (sparkType instanceof DecimalType) {
DecimalType decimalType = (DecimalType) sparkType;
return decimalType.precision() < Decimal.MAX_LONG_DIGITS();
}
return false;
}
}
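
For reference, a minimal sketch of how the new helper behaves on a nested schema; the schema and field names below are illustrative only and are not part of the patch (Decimal.MAX_LONG_DIGITS() is 18):

import org.apache.hudi.util.DataTypeUtils;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

StructType nested = new StructType()
    .add("id", DataTypes.LongType)
    .add("amounts", DataTypes.createArrayType(DataTypes.createDecimalType(10, 2)));
// decimal(10, 2) has precision 10, which is below Decimal.MAX_LONG_DIGITS() (18), so this returns true
boolean hasSmallDecimal = DataTypeUtils.foundSmallPrecisionDecimalType(nested);

StructType wideOnly = new StructType().add("amount", DataTypes.createDecimalType(38, 10));
// precision 38 is not below 18, so this returns false
boolean hasNoSmallDecimal = DataTypeUtils.foundSmallPrecisionDecimalType(wideOnly);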
@@ -46,12 +46,14 @@
import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor;
import org.apache.hudi.index.HoodieIndex.IndexType;
import org.apache.hudi.table.BulkInsertPartitioner;
import org.apache.hudi.util.DataTypeUtils;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;
import java.util.ArrayList;
@@ -309,4 +311,15 @@ public static HiveSyncConfig buildHiveSyncConfig(TypedProperties props, String b
DataSourceWriteOptions.HIVE_SUPPORT_TIMESTAMP_TYPE().defaultValue()));
return hiveSyncConfig;
}

// By default, Spark's ParquetWriteSupport writes a DecimalType to Parquet as int32/int64 when its precision is smaller than Decimal.MAX_LONG_DIGITS(),
// but the AvroParquetReader used by HoodieParquetReader cannot read that int32/int64 representation back as a DecimalType.
// If the incoming schema contains such a DecimalType, force "hoodie.parquet.writeLegacyFormat.enabled" to true.
public static void mayBeOverwriteParquetWriteLegacyFormatProp(Map<String, String> properties, StructType schema) {
if (DataTypeUtils.foundSmallPrecisionDecimalType(schema)
&& !Boolean.parseBoolean(properties.getOrDefault("hoodie.parquet.writeLegacyFormat.enabled", "false"))) {
properties.put("hoodie.parquet.writeLegacyFormat.enabled", "true");
LOG.warn("Small Decimal Type found in current schema, auto set the value of hoodie.parquet.writeLegacyFormat.enabled to true");
}
}
}
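
A minimal sketch of the intended effect of the new hook, mirroring the parameterized test added below; the field name d1 and the option map are illustrative:

import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

StructType schema = new StructType().add("d1", DataTypes.createDecimalType(10, 2));
Map<String, String> props = new HashMap<>();
props.put("hoodie.parquet.writeLegacyFormat.enabled", "false");

// decimal(10, 2) is a small-precision decimal, so the property is forced to "true"
mayBeOverwriteParquetWriteLegacyFormatProp(props, schema);
// props.get("hoodie.parquet.writeLegacyFormat.enabled") is now "true";
// with only decimal(38, 10) in the schema, the user-provided value would be left untouched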
@@ -40,10 +40,16 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DecimalType$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
@@ -52,7 +58,12 @@

import java.math.BigDecimal;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
import static org.apache.hudi.common.model.HoodieFileFormat.PARQUET;
import static org.apache.hudi.hive.ddl.HiveSyncMode.HMS;
import static org.hamcrest.CoreMatchers.containsString;
@@ -274,4 +285,33 @@ public boolean arePartitionRecordsSorted() {
return false;
}
}

@ParameterizedTest
@CsvSource({"true, false", "true, true", "false, true", "false, false"})
public void testAutoModifyParquetWriteLegacyFormatParameter(boolean smallDecimal, boolean defaultWriteValue) {
// create test StructType
List<StructField> structFields = new ArrayList<>();
if (smallDecimal) {
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(10, 2), false, Metadata.empty()));
} else {
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(38, 10), false, Metadata.empty()));
}
StructType structType = StructType$.MODULE$.apply(structFields);
// create write options
Map<String, String> options = new HashMap<>();
options.put("hoodie.parquet.writeLegacyFormat.enabled", String.valueOf(defaultWriteValue));

// start test
mayBeOverwriteParquetWriteLegacyFormatProp(options, structType);

// check result
boolean res = Boolean.parseBoolean(options.get("hoodie.parquet.writeLegacyFormat.enabled"));
if (smallDecimal) {
// "hoodie.parquet.writeLegacyFormat.enabled" should be automatically set to "true".
assertEquals(true, res);
} else {
// should not modify the value of "hoodie.parquet.writeLegacyFormat.enabled".
assertEquals(defaultWriteValue, res);
}
}
}
@@ -723,4 +723,29 @@ class TestCOWDataSource extends HoodieClientTestBase {
val result = spark.sql("select * from tmptable limit 1").collect()(0)
result.schema.contains(new StructField("partition", StringType, true))
}

@Test
def testWriteSmallPrecisionDecimalTable(): Unit = {
Contributor:
Can we add one more test: when the schema has no small-precision decimal and the user has set hoodie.parquet.writeLegacyFormat.enabled to false, it should not get overridden.

Contributor Author:
It is difficult to read the value of hoodie.parquet.writeLegacyFormat.enabled back from Spark directly, so I added a parameterized test for mayBeOverwriteParquetWriteLegacyFormatProp (testAutoModifyParquetWriteLegacyFormatParameter) that covers all of these scenarios.

val records1 = recordsToStrings(dataGen.generateInserts("001", 5)).toList
val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"2090.0000"))) // create decimalType(8, 4)
inputDF1.write.format("org.apache.hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL)
.mode(SaveMode.Overwrite)
.save(basePath)

// update the value of shortDecimal
val inputDF2 = inputDF1.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"3090.0000")))
inputDF2.write.format("org.apache.hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
.mode(SaveMode.Append)
.save(basePath)
val readResult = spark.read.format("hudi").load(basePath)
assert(readResult.count() == 5)
// compare the test result
assertEquals(inputDF2.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","),
readResult.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","))
}
}
@@ -33,8 +33,11 @@
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
import org.apache.spark.sql.types.StructType;

import java.util.Map;
import java.util.Optional;

import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;

/**
* DataSource V2 implementation for managing internal write logic. Only called internally.
*/
@@ -64,8 +67,11 @@ public Optional<DataSourceWriter> createWriter(String writeUUID, StructType sche
String tblName = options.get(HoodieWriteConfig.TBL_NAME.key()).get();
boolean populateMetaFields = options.getBoolean(HoodieTableConfig.POPULATE_META_FIELDS.key(),
Boolean.parseBoolean(HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
Map<String, String> properties = options.asMap();
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
// 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, options.asMap());
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, properties);
boolean arePartitionRecordsSorted = HoodieInternalConfig.getBulkInsertIsPartitionRecordsSorted(
options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).isPresent()
? options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).get() : null);
@@ -33,6 +33,8 @@

import java.util.Map;

import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;

/**
* DataSource V2 implementation for managing internal write logic. Only called internally.
* This class is only compatible with datasource V2 API in Spark 3.
@@ -53,6 +55,8 @@ public Table getTable(StructType schema, Transform[] partitioning, Map<String, S
HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
boolean arePartitionRecordsSorted = Boolean.parseBoolean(properties.getOrDefault(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED,
Boolean.toString(HoodieInternalConfig.DEFAULT_BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED)));
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
// 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(properties.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()), path, tblName, properties);
return new HoodieDataSourceInternalTable(instantTime, config, schema, getSparkSession(),