Skip to content

Commit

Permalink
[SPARK-13442][SQL] Make type inference recognize boolean types
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13442

This PR adds support for inferring `BooleanType` during CSV schema inference.
Case-insensitive `true` / `false` values are inferred as `BooleanType`.

Unit tests were added to `CSVInferSchemaSuite`, and an end-to-end test was added to `CSVSuite`.

## How was this patch tested?

This was tested with unit tests, and with `dev/run_tests` for coding style.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #11315 from HyukjinKwon/SPARK-13442.
  • Loading branch information
HyukjinKwon authored and rxin committed Mar 7, 2016
1 parent e1fb857 commit 8577260
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 0 deletions.
Expand Up @@ -87,6 +87,7 @@ private[csv] object CSVInferSchema {
case LongType => tryParseLong(field)
case DoubleType => tryParseDouble(field)
case TimestampType => tryParseTimestamp(field)
case BooleanType => tryParseBoolean(field)
case StringType => StringType
case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
Expand Down Expand Up @@ -117,6 +118,14 @@ private[csv] object CSVInferSchema {
/**
 * Attempts to interpret `field` as a timestamp.
 *
 * @param field the raw cell text being inspected
 * @return `TimestampType` when `Timestamp.valueOf` accepts the text; otherwise whatever
 *         `tryParseBoolean` infers for the same text
 */
def tryParseTimestamp(field: String): DataType = {
  // Any failure thrown by Timestamp.valueOf is absorbed and treated as "not a timestamp".
  val parsed = allCatch.opt(Timestamp.valueOf(field))
  parsed match {
    case Some(_) => TimestampType
    // Not a timestamp: hand over to the next candidate type in the chain.
    case None => tryParseBoolean(field)
  }
}

def tryParseBoolean(field: String): DataType = {
if ((allCatch opt field.toBoolean).isDefined) {
BooleanType
} else {
stringType()
}
Expand Down
5 changes: 5 additions & 0 deletions sql/core/src/test/resources/bool.csv
@@ -0,0 +1,5 @@
bool
"True"
"False"

"true"
Expand Up @@ -30,6 +30,8 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
assert(CSVInferSchema.inferField(NullType, "test") == StringType)
assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(NullType, "True") == BooleanType)
assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType)
}

test("String fields types are inferred correctly from other types") {
Expand All @@ -40,6 +42,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(LongType, "True") == BooleanType)
assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType)
assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType)
}

test("Timestamp field types are inferred correctly from other types") {
Expand All @@ -48,6 +53,11 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
}

test("Boolean fields types are inferred correctly from other types") {
  // Near-miss spellings of true/false must not be accepted as booleans;
  // they are expected to end up as StringType.
  val misspelledFalse = CSVInferSchema.inferField(LongType, "Fale")
  assert(misspelledFalse == StringType)

  val trueWithTrailingChar = CSVInferSchema.inferField(DoubleType, "TRUEe")
  assert(trueWithTrailingChar == StringType)
}

test("Type arrays are merged to highest common type") {
assert(
CSVInferSchema.mergeRowTypes(Array(StringType),
Expand All @@ -67,6 +77,7 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType)
}

test("Merging Nulltypes should yeild Nulltype.") {
Expand Down
Expand Up @@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val simpleSparseFile = "simple_sparse.csv"

private def testFile(fileName: String): String = {
Expand Down Expand Up @@ -118,6 +119,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = true, checkTypes = true)
}

test("test inferring booleans") {
  // The fixture has a single header column named "bool"; with schema inference on,
  // that column is expected to come back as a nullable BooleanType.
  val boolColumn = StructField("bool", BooleanType, nullable = true)
  val expectedSchema = StructType(List(boolColumn))

  val reader = sqlContext.read
    .format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
  val loaded = reader.load(testFile(boolFile))

  assert(loaded.schema === expectedSchema)
}

test("test with alternative delimiter and quote") {
val cars = sqlContext.read
.format("csv")
Expand Down

0 comments on commit 8577260

Please sign in to comment.