Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-13442][SQL] Make type inference recognize boolean types #11315

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -87,6 +87,7 @@ private[csv] object CSVInferSchema {
case LongType => tryParseLong(field)
case DoubleType => tryParseDouble(field)
case TimestampType => tryParseTimestamp(field)
case BooleanType => tryParseBoolean(field)
case StringType => StringType
case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
Expand Down Expand Up @@ -117,6 +118,14 @@ private[csv] object CSVInferSchema {
def tryParseTimestamp(field: String): DataType = {
if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field)
}
}

def tryParseBoolean(field: String): DataType = {
if ((allCatch opt field.toBoolean).isDefined) {
BooleanType
} else {
stringType()
}
Expand Down
5 changes: 5 additions & 0 deletions sql/core/src/test/resources/bool.csv
@@ -0,0 +1,5 @@
bool
"True"
"False"

"true"
Expand Up @@ -30,6 +30,8 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
assert(CSVInferSchema.inferField(NullType, "test") == StringType)
assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(NullType, "True") == BooleanType)
assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType)
}

test("String fields types are inferred correctly from other types") {
Expand All @@ -40,6 +42,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(LongType, "True") == BooleanType)
assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType)
assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType)
}

test("Timestamp field types are inferred correctly from other types") {
Expand All @@ -48,6 +53,11 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
}

test("Boolean fields types are inferred correctly from other types") {
assert(CSVInferSchema.inferField(LongType, "Fale") == StringType)
assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType)
}

test("Type arrays are merged to highest common type") {
assert(
CSVInferSchema.mergeRowTypes(Array(StringType),
Expand All @@ -67,6 +77,7 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType)
}

test("Merging Nulltypes should yeild Nulltype.") {
Expand Down
Expand Up @@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val simpleSparseFile = "simple_sparse.csv"

private def testFile(fileName: String): String = {
Expand Down Expand Up @@ -112,6 +113,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = true, checkTypes = true)
}

test("test inferring booleans") {
val result = sqlContext.read
.format("csv")
.option("header", "true")
.option("inferSchema", "true")
.load(testFile(boolFile))

val expectedSchema = StructType(List(
StructField("bool", BooleanType, nullable = true)))
assert(result.schema === expectedSchema)
}

test("test with alternative delimiter and quote") {
val cars = sqlContext.read
.format("csv")
Expand Down