diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 7f1ed28046b1d..edead9b21b21c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -87,6 +87,7 @@ private[csv] object CSVInferSchema {
         case LongType => tryParseLong(field)
         case DoubleType => tryParseDouble(field)
         case TimestampType => tryParseTimestamp(field)
+        case BooleanType => tryParseBoolean(field)
         case StringType => StringType
         case other: DataType =>
           throw new UnsupportedOperationException(s"Unexpected data type $other")
@@ -117,6 +118,14 @@ private[csv] object CSVInferSchema {
   def tryParseTimestamp(field: String): DataType = {
     if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
       TimestampType
+    } else {
+      tryParseBoolean(field)
+    }
+  }
+
+  def tryParseBoolean(field: String): DataType = {
+    if ((allCatch opt field.toBoolean).isDefined) {
+      BooleanType
     } else {
       stringType()
     }
diff --git a/sql/core/src/test/resources/bool.csv b/sql/core/src/test/resources/bool.csv
new file mode 100644
index 0000000000000..94b2d49506e0d
--- /dev/null
+++ b/sql/core/src/test/resources/bool.csv
@@ -0,0 +1,5 @@
+bool
+"True"
+"False"
+
+"true"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 412f1b89beee7..7af3f94aefea2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -30,6 +30,8 @@ class InferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
     assert(CSVInferSchema.inferField(NullType, "test") == StringType)
     assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
+    assert(CSVInferSchema.inferField(NullType, "True") == BooleanType)
+    assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType)
   }
 
   test("String fields types are inferred correctly from other types") {
@@ -40,6 +42,9 @@ class InferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
     assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
     assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
+    assert(CSVInferSchema.inferField(LongType, "True") == BooleanType)
+    assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType)
+    assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType)
   }
 
   test("Timestamp field types are inferred correctly from other types") {
@@ -48,6 +53,11 @@ class InferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
   }
 
+  test("Boolean fields types are inferred correctly from other types") {
+    assert(CSVInferSchema.inferField(LongType, "Fale") == StringType)
+    assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType)
+  }
+
   test("Type arrays are merged to highest common type") {
     assert(
       CSVInferSchema.mergeRowTypes(Array(StringType),
@@ -67,6 +77,7 @@ class InferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
     assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
     assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
+    assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType)
   }
 
   test("Merging Nulltypes should yeild Nulltype.") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 9cd3a9ab952b4..53027bb698bf8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   private val emptyFile = "empty.csv"
   private val commentsFile = "comments.csv"
   private val disableCommentsFile = "disable_comments.csv"
+  private val boolFile = "bool.csv"
   private val simpleSparseFile = "simple_sparse.csv"
 
   private def testFile(fileName: String): String = {
@@ -118,6 +119,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     verifyCars(cars, withHeader = true, checkTypes = true)
   }
 
+  test("test inferring booleans") {
+    val result = sqlContext.read
+      .format("csv")
+      .option("header", "true")
+      .option("inferSchema", "true")
+      .load(testFile(boolFile))
+
+    val expectedSchema = StructType(List(
+      StructField("bool", BooleanType, nullable = true)))
+    assert(result.schema === expectedSchema)
+  }
+
   test("test with alternative delimiter and quote") {
     val cars = sqlContext.read
       .format("csv")