[HUDI-2958] Automatically set spark.sql.parquet.writelegacyformat, when using bulkinsert to insert data which contains decimalType (#4253)
xiarixiaoyao committed Dec 17, 2021
1 parent e4cfb42 commit 9246b16
Showing 6 changed files with 115 additions and 1 deletion.
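For context, a minimal usage sketch of the write path this commit targets (a sketch only: the table name, record key field, and save path are hypothetical, and other required write options are omitted). Bulk-inserting a DataFrame whose schema contains a small-precision DecimalType previously required turning on hoodie.parquet.writeLegacyFormat.enabled by hand so that Hudi's Avro-based Parquet readers could read the files back; with this change the flag is set automatically when such a column is detected.

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.functions.lit

// "shortDecimal" is written as DecimalType(8, 4), i.e. precision below Decimal.MAX_LONG_DIGITS()
val df = spark.range(0, 5).toDF("id")
  .withColumn("shortDecimal", lit(new java.math.BigDecimal("2090.0000")))

df.write.format("hudi")
  .option("hoodie.table.name", "decimal_tbl")                  // hypothetical table name
  .option("hoodie.datasource.write.recordkey.field", "id")
  .option("hoodie.datasource.write.operation", "bulk_insert")
  // Before this commit, the following had to be set manually; it is now enabled
  // automatically when the schema contains a small-precision DecimalType:
  .option("hoodie.parquet.writeLegacyFormat.enabled", "true")
  .mode(SaveMode.Overwrite)
  .save("/tmp/decimal_tbl")                                     // hypothetical path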
DataTypeUtils.java (org.apache.hudi.util)
@@ -18,12 +18,16 @@

package org.apache.hudi.util;

import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
@@ -119,4 +123,26 @@ private static boolean areCompatible(@Nonnull StructType left, @Nonnull StructTy
private static <T> HashSet<T> newHashSet(T... ts) {
return new HashSet<>(Arrays.asList(ts));
}

/**
* Checks whether the given Spark type contains a DecimalType whose precision is smaller than Decimal.MAX_LONG_DIGITS().
*
* @param sparkType Spark schema to inspect.
* @return true if such a DecimalType is found, false otherwise.
*/
public static boolean foundSmallPrecisionDecimalType(DataType sparkType) {
if (sparkType instanceof StructType) {
StructField[] fields = ((StructType) sparkType).fields();
return Arrays.stream(fields).anyMatch(f -> foundSmallPrecisionDecimalType(f.dataType()));
} else if (sparkType instanceof MapType) {
MapType map = (MapType) sparkType;
return foundSmallPrecisionDecimalType(map.keyType()) || foundSmallPrecisionDecimalType(map.valueType());
} else if (sparkType instanceof ArrayType) {
return foundSmallPrecisionDecimalType(((ArrayType) sparkType).elementType());
} else if (sparkType instanceof DecimalType) {
DecimalType decimalType = (DecimalType) sparkType;
return decimalType.precision() < Decimal.MAX_LONG_DIGITS();
}
return false;
}
}
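A minimal sketch of how the new helper behaves, assuming Spark's Decimal.MAX_LONG_DIGITS() is 18; the nested schema below is invented for illustration.

import org.apache.hudi.util.DataTypeUtils
import org.apache.spark.sql.types._

// The check recurses through structs, maps and arrays.
val nested = StructType(Seq(
  StructField("id", LongType),
  StructField("attrs", MapType(StringType, ArrayType(DecimalType(10, 2))))))

val hasSmallDecimal = DataTypeUtils.foundSmallPrecisionDecimalType(nested)               // true:  precision 10 < 18
val largeOnly       = DataTypeUtils.foundSmallPrecisionDecimalType(DecimalType(38, 10))  // false: precision 38 >= 18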
DataSourceUtils.java (org.apache.hudi)
@@ -46,12 +46,14 @@
import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor;
import org.apache.hudi.index.HoodieIndex.IndexType;
import org.apache.hudi.table.BulkInsertPartitioner;
import org.apache.hudi.util.DataTypeUtils;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;
import java.util.ArrayList;
@@ -309,4 +311,15 @@ public static HiveSyncConfig buildHiveSyncConfig(TypedProperties props, String b
DataSourceWriteOptions.HIVE_SUPPORT_TIMESTAMP_TYPE().defaultValue()));
return hiveSyncConfig;
}

// By default, Spark's ParquetWriteSupport writes a DecimalType to Parquet as int32/int64 when its precision is smaller than Decimal.MAX_LONG_DIGITS(),
// but the AvroParquetReader used by HoodieParquetReader cannot read int32/int64 back as a DecimalType.
// Check whether the current schema contains such a DecimalType and, if so, automatically set "hoodie.parquet.writeLegacyFormat.enabled" to true.
public static void mayBeOverwriteParquetWriteLegacyFormatProp(Map<String, String> properties, StructType schema) {
if (DataTypeUtils.foundSmallPrecisionDecimalType(schema)
&& !Boolean.parseBoolean(properties.getOrDefault("hoodie.parquet.writeLegacyFormat.enabled", "false"))) {
properties.put("hoodie.parquet.writeLegacyFormat.enabled", "true");
LOG.warn("Small Decimal Type found in current schema, auto set the value of hoodie.parquet.writeLegacyFormat.enabled to true");
}
}
}
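A minimal sketch of the property override in isolation, mirroring the unit test added below; the column name and precision are illustrative.

import java.util.HashMap
import org.apache.hudi.DataSourceUtils
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}

val schema = StructType(Seq(StructField("price", DecimalType(8, 4))))
val props = new HashMap[String, String]()   // flag not set, so it is treated as "false"

DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp(props, schema)
// props now maps "hoodie.parquet.writeLegacyFormat.enabled" -> "true",
// because DecimalType(8, 4) has precision below Decimal.MAX_LONG_DIGITS().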
Unit tests for DataSourceUtils
@@ -40,10 +40,16 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DecimalType$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
@@ -52,7 +58,12 @@

import java.math.BigDecimal;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
import static org.apache.hudi.common.model.HoodieFileFormat.PARQUET;
import static org.apache.hudi.hive.ddl.HiveSyncMode.HMS;
import static org.hamcrest.CoreMatchers.containsString;
@@ -274,4 +285,33 @@ public boolean arePartitionRecordsSorted() {
return false;
}
}

@ParameterizedTest
@CsvSource({"true, false", "true, true", "false, true", "false, false"})
public void testAutoModifyParquetWriteLegacyFormatParameter(boolean smallDecimal, boolean defaultWriteValue) {
// create test StructType
List<StructField> structFields = new ArrayList<>();
if (smallDecimal) {
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(10, 2), false, Metadata.empty()));
} else {
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(38, 10), false, Metadata.empty()));
}
StructType structType = StructType$.MODULE$.apply(structFields);
// create write options
Map<String, String> options = new HashMap<>();
options.put("hoodie.parquet.writeLegacyFormat.enabled", String.valueOf(defaultWriteValue));

// start test
mayBeOverwriteParquetWriteLegacyFormatProp(options, structType);

// check result
boolean res = Boolean.parseBoolean(options.get("hoodie.parquet.writeLegacyFormat.enabled"));
if (smallDecimal) {
// should automatically set "hoodie.parquet.writeLegacyFormat.enabled" to "true".
assertEquals(true, res);
} else {
// should not modify the value of "hoodie.parquet.writeLegacyFormat.enabled".
assertEquals(defaultWriteValue, res);
}
}
}
TestCOWDataSource.scala
@@ -723,4 +723,29 @@ class TestCOWDataSource extends HoodieClientTestBase {
val result = spark.sql("select * from tmptable limit 1").collect()(0)
result.schema.contains(new StructField("partition", StringType, true))
}

@Test
def testWriteSmallPrecisionDecimalTable(): Unit = {
val records1 = recordsToStrings(dataGen.generateInserts("001", 5)).toList
val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"2090.0000"))) // create decimalType(8, 4)
inputDF1.write.format("org.apache.hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL)
.mode(SaveMode.Overwrite)
.save(basePath)

// update the value of shortDecimal
val inputDF2 = inputDF1.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"3090.0000")))
inputDF2.write.format("org.apache.hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
.mode(SaveMode.Append)
.save(basePath)
val readResult = spark.read.format("hudi").load(basePath)
assert(readResult.count() == 5)
// compare the test result
assertEquals(inputDF2.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","),
readResult.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","))
}
}
Internal DataSource V2 writer (Spark 2 API)
@@ -33,8 +33,11 @@
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
import org.apache.spark.sql.types.StructType;

import java.util.Map;
import java.util.Optional;

import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;

/**
* DataSource V2 implementation for managing internal write logic. Only called internally.
*/
@@ -64,8 +67,11 @@ public Optional<DataSourceWriter> createWriter(String writeUUID, StructType sche
String tblName = options.get(HoodieWriteConfig.TBL_NAME.key()).get();
boolean populateMetaFields = options.getBoolean(HoodieTableConfig.POPULATE_META_FIELDS.key(),
Boolean.parseBoolean(HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
Map<String, String> properties = options.asMap();
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
// The 1st arg to createHoodieConfig is not really required to be set, but passing it anyway.
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, options.asMap());
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, properties);
boolean arePartitionRecordsSorted = HoodieInternalConfig.getBulkInsertIsPartitionRecordsSorted(
options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).isPresent()
? options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).get() : null);
Internal DataSource V2 table provider (Spark 3 API)
@@ -33,6 +33,8 @@

import java.util.Map;

import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;

/**
* DataSource V2 implementation for managing internal write logic. Only called internally.
* This class is only compatible with datasource V2 API in Spark 3.
@@ -53,6 +55,8 @@ public Table getTable(StructType schema, Transform[] partitioning, Map<String, S
HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
boolean arePartitionRecordsSorted = Boolean.parseBoolean(properties.getOrDefault(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED,
Boolean.toString(HoodieInternalConfig.DEFAULT_BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED)));
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
// The 1st arg to createHoodieConfig is not really required to be set, but passing it anyway.
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(properties.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()), path, tblName, properties);
return new HoodieDataSourceInternalTable(instantTime, config, schema, getSparkSession(),
