
Check header on CSV parsing from a dataset of strings

MaxGekk committed Apr 13, 2018
1 parent 78d9f66 commit b43a7c7ec50e03aaf4990e9bbb6989cdb2c076ef
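
A minimal, self-contained sketch of the user-facing behavior this commit adds (the object name and local-mode session setup are illustrative, not part of the commit): parsing CSV from a Dataset[String] with header=true and enforceSchema=false now validates the first line against the user-specified schema and fails fast on a mismatch, mirroring the new test at the bottom of the diff.

    // Illustrative sketch, assuming a Spark 2.x SparkSession; a header that
    // disagrees with the user-specified schema now raises IllegalArgumentException.
    import scala.util.Try

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.{DoubleType, StructType}

    object HeaderCheckExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("csv-header-check")
          .getOrCreate()
        import spark.implicits._

        // The header names the columns in the opposite order of the schema.
        val ds = Seq("columnA,columnB", "1.0,1000.0").toDS()
        val schema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType)

        val result = Try {
          spark.read
            .schema(schema)
            .option("header", true)
            .option("enforceSchema", false)
            .csv(ds)
        }
        // Expected to print the header-mismatch error produced by checkHeaderColumnNames.
        result.failed.foreach(e => println(e.getMessage))

        spark.stop()
      }
    }
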
sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -21,6 +21,8 @@ import java.util.{Locale, Properties}
 import scala.collection.JavaConverters._

+import com.univocity.parsers.csv.CsvParser
+
 import org.apache.spark.Partition
 import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.api.java.JavaRDD
@@ -486,6 +488,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+      if (parsedOptions.enforceSchema == false) {
+        CSVDataSource.checkHeader(firstLine, new CsvParser(parsedOptions.asParserSettings),
+          actualSchema, csvDataset.getClass.getCanonicalName, checkHeaderFlag = true,
+          sparkSession.sessionState.conf.caseSensitiveAnalysis)
+      }
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
     }.getOrElse(filteredLines.rdd)
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -143,7 +143,7 @@ object CSVDataSource {
           s"""|CSV file header does not contain the expected fields.
               | Header: ${columnNames.mkString(", ")}
               | Schema: ${fieldNames.mkString(", ")}
-              |Expected: $nameInSchema but found: $nameInHeader
+              |Expected: ${columnNames(i)} but found: ${fieldNames(i)}
              |CSV file: $fileName""".stripMargin
        )
      }
@@ -162,6 +162,14 @@ object CSVDataSource {
       }
     }
   }
+
+  def checkHeader(header: String, parser: CsvParser, schema: StructType, fileName: String,
+      checkHeaderFlag: Boolean, caseSensitive: Boolean): Unit = {
+    if (checkHeaderFlag) {
+      checkHeaderColumnNames(schema, parser.parseLine(header), fileName, checkHeaderFlag,
+        caseSensitive)
+    }
+  }
 }
object TextInputCSVDataSource extends CSVDataSource {
@@ -185,7 +193,7 @@ object TextInputCSVDataSource extends CSVDataSource {
         // Note: if there are only comments in the first block, the header might not
         // be extracted.
         CSVUtils.extractHeader(lines, parser.options).foreach { header =>
-          checkHeader(header, parser.tokenizer, dataSchema, file.filePath,
+          CSVDataSource.checkHeader(header, parser.tokenizer, dataSchema, file.filePath,
             checkHeaderFlag = !parser.options.enforceSchema, caseSensitive)
         }
       }
@@ -248,14 +256,6 @@ object TextInputCSVDataSource extends CSVDataSource {
       sparkSession.createDataset(rdd)(Encoders.STRING)
     }
   }
-
-  def checkHeader(header: String, parser: CsvParser, schema: StructType, fileName: String,
-      checkHeaderFlag: Boolean, caseSensitive: Boolean): Unit = {
-    if (checkHeaderFlag) {
-      checkHeaderColumnNames(schema, parser.parseLine(header), fileName, checkHeaderFlag,
-        caseSensitive)
-    }
-  }
 }
object MultiLineCSVDataSource extends CSVDataSource {
@@ -1374,7 +1374,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}
test("Ignore column name case if spark.sql.caseSensitive is false") {
test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
withTempPath { path =>
import collection.JavaConverters._
@@ -1390,4 +1390,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-23786: check header on parsing of dataset of strings") {
+    val ds = Seq("columnA,columnB", "1.0,1000.0").toDS()
+    val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType)
+    val exception = intercept[IllegalArgumentException] {
+      spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds)
+    }
+    assert(exception.getMessage.contains("CSV file header does not contain the expected fields"))
+  }
 }
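
For the mismatched header in the new test, the resulting message, following the error template in checkHeaderColumnNames above, should come out roughly as below; the trailing name is the dataset's class name rather than a file path, since no file is involved (exact wording may differ in this intermediate commit):

    CSV file header does not contain the expected fields.
     Header: columnA, columnB
     Schema: columnB, columnA
    Expected: columnA but found: columnB
    CSV file: org.apache.spark.sql.Dataset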
