Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.parquet.hadoop.metadata.FileMetaData
import org.apache.spark.sql.HoodieSchemaUtils
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, Attribute, Cast, CreateNamedStruct, CreateStruct, Expression, GetStructField, LambdaFunction, Literal, MapEntries, MapFromEntries, NamedLambdaVariable, UnsafeProjection}
import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, FloatType, MapType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, StringType, StructField, StructType}

object HoodieParquetFileFormatHelper {

Expand Down Expand Up @@ -104,18 +104,24 @@ object HoodieParquetFileFormatHelper {
requiredSchema: StructType,
partitionSchema: StructType,
schemaUtils: HoodieSchemaUtils): UnsafeProjection = {
val floatToDoubleCache = scala.collection.mutable.HashMap.empty[(DataType, DataType), Boolean]
val addedCastCache = scala.collection.mutable.HashMap.empty[(DataType, DataType), Boolean]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any test coverage on these type promotion cases?


def hasFloatToDoubleConversion(src: DataType, dst: DataType): Boolean = {
floatToDoubleCache.getOrElseUpdate((src, dst), {
def hasUnsupportedConversion(src: DataType, dst: DataType): Boolean = {
addedCastCache.getOrElseUpdate((src, dst), {
(src, dst) match {
case (FloatType, DoubleType) => true
case (IntegerType, DecimalType()) => true
case (LongType, DecimalType()) => true
case (FloatType, DecimalType()) => true
case (DoubleType, DecimalType()) => true
case (StringType, DecimalType()) => true
case (StringType, DateType) => true
case (StructType(srcFields), StructType(dstFields)) =>
srcFields.zip(dstFields).exists { case (sf, df) => hasFloatToDoubleConversion(sf.dataType, df.dataType) }
srcFields.zip(dstFields).exists { case (sf, df) => hasUnsupportedConversion(sf.dataType, df.dataType) }
case (ArrayType(sElem, _), ArrayType(dElem, _)) =>
hasFloatToDoubleConversion(sElem, dElem)
hasUnsupportedConversion(sElem, dElem)
case (MapType(sKey, sVal, _), MapType(dKey, dVal, _)) =>
hasFloatToDoubleConversion(sKey, dKey) || hasFloatToDoubleConversion(sVal, dVal)
hasUnsupportedConversion(sKey, dKey) || hasUnsupportedConversion(sVal, dVal)
case _ => false
}
})
Expand All @@ -127,7 +133,14 @@ object HoodieParquetFileFormatHelper {
case (FloatType, DoubleType) =>
val toStr = Cast(expr, StringType, if (needTimeZone) timeZoneId else None)
Cast(toStr, dstType, if (needTimeZone) timeZoneId else None)
case (s: StructType, d: StructType) if hasFloatToDoubleConversion(s, d) =>
case (IntegerType | LongType | FloatType | DoubleType, dec: DecimalType) =>
val toStr = Cast(expr, StringType, if (needTimeZone) timeZoneId else None)
Cast(toStr, dec, if (needTimeZone) timeZoneId else None)
case (StringType, dec: DecimalType) =>
Cast(expr, dec, if (needTimeZone) timeZoneId else None)
case (StringType, DateType) =>
Cast(expr, DateType, if (needTimeZone) timeZoneId else None)
case (s: StructType, d: StructType) if hasUnsupportedConversion(s, d) =>
val structFields = s.fields.zip(d.fields).zipWithIndex.map {
case ((srcField, dstField), i) =>
val child = GetStructField(expr, i, Some(dstField.name))
Expand All @@ -136,13 +149,13 @@ object HoodieParquetFileFormatHelper {
CreateNamedStruct(d.fields.zip(structFields).flatMap {
case (f, c) => Seq(Literal(f.name), c)
})
case (ArrayType(sElementType, containsNull), ArrayType(dElementType, _)) if hasFloatToDoubleConversion(sElementType, dElementType) =>
case (ArrayType(sElementType, containsNull), ArrayType(dElementType, _)) if hasUnsupportedConversion(sElementType, dElementType) =>
val lambdaVar = NamedLambdaVariable("element", sElementType, containsNull)
val body = recursivelyCastExpressions(lambdaVar, sElementType, dElementType)
val func = LambdaFunction(body, Seq(lambdaVar))
ArrayTransform(expr, func)
case (MapType(sKeyType, sValType, vnull), MapType(dKeyType, dValType, _))
if hasFloatToDoubleConversion(sKeyType, dKeyType) || hasFloatToDoubleConversion(sValType, dValType) =>
if hasUnsupportedConversion(sKeyType, dKeyType) || hasUnsupportedConversion(sValType, dValType) =>
val kv = NamedLambdaVariable("kv", new StructType()
.add("key", sKeyType, nullable = false)
.add("value", sValType, nullable = vnull), nullable = false)
Expand Down
Loading