diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index f0df18da8eed6..56677d7d97af2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -102,13 +102,11 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
     if (field == null || field.isEmpty || field == options.nullValue) {
       typeSoFar
     } else {
-      typeSoFar match {
+      val typeElemInfer = typeSoFar match {
         case NullType => tryParseInteger(field)
         case IntegerType => tryParseInteger(field)
         case LongType => tryParseLong(field)
-        case _: DecimalType =>
-          // DecimalTypes have different precisions and scales, so we try to find the common type.
-          compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType)
+        case _: DecimalType => tryParseDecimal(field)
         case DoubleType => tryParseDouble(field)
         case TimestampType => tryParseTimestamp(field)
         case BooleanType => tryParseBoolean(field)
@@ -116,6 +114,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
         case other: DataType =>
           throw new UnsupportedOperationException(s"Unexpected data type $other")
       }
+      compatibleType(typeSoFar, typeElemInfer).getOrElse(StringType)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
index b014eb92fae50..d268f8c2e7210 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
@@ -56,11 +56,11 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
     assert(inferSchema.inferField(IntegerType, "1.0") == DoubleType)
     assert(inferSchema.inferField(DoubleType, null) == DoubleType)
     assert(inferSchema.inferField(DoubleType, "test") == StringType)
-    assert(inferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
-    assert(inferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
-    assert(inferSchema.inferField(LongType, "True") == BooleanType)
-    assert(inferSchema.inferField(IntegerType, "FALSE") == BooleanType)
-    assert(inferSchema.inferField(TimestampType, "FALSE") == BooleanType)
+    assert(inferSchema.inferField(LongType, "2015-08-20 14:57:00") == StringType)
+    assert(inferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == StringType)
+    assert(inferSchema.inferField(LongType, "True") == StringType)
+    assert(inferSchema.inferField(IntegerType, "FALSE") == StringType)
+    assert(inferSchema.inferField(TimestampType, "FALSE") == StringType)
 
     val textValueOne = Long.MaxValue.toString + "0"
     val decimalValueOne = new java.math.BigDecimal(textValueOne)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 366cf11871fa0..fcb7bdc25f08f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -2341,6 +2341,18 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa
       checkAnswer(csv, Row(null))
     }
   }
+
+  test("SPARK-32025: infer the schema from mixed-type values") {
+    withTempPath { path =>
+      Seq("col_mixed_types", "2012", "1997", "True").toDS.write.text(path.getCanonicalPath)
+      val df = spark.read.format("csv")
+        .option("header", "true")
+        .option("inferSchema", "true")
+        .load(path.getCanonicalPath)
+
+      assert(df.schema.last == StructField("col_mixed_types", StringType, true))
+    }
+  }
 }
 
 class CSVv1Suite extends CSVSuite {