From df548666ec758d96a19435d890bd68a3a147d046 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 21 Jan 2017 23:20:25 +0900 Subject: [PATCH 1/2] Refactoring CSV read path to be consistent with JSON data source --- .../datasources/csv/CSVFileFormat.scala | 23 +- .../datasources/csv/CSVInferSchema.scala | 118 --------- .../datasources/csv/CSVOptions.scala | 20 +- .../execution/datasources/csv/CSVParser.scala | 60 ----- .../datasources/csv/CSVRelation.scala | 98 +------- .../datasources/csv/UnivocityParser.scala | 234 ++++++++++++++++++ ...Suite.scala => UnivocityParserSuite.scala} | 54 ++-- 7 files changed, 291 insertions(+), 316 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/{CSVTypeCastSuite.scala => UnivocityParserSuite.scala} (77%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index eec2ba8068d5d..9b0275170a574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.{Charset, StandardCharsets} +import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, Text} @@ -28,7 +29,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ @@ -61,7 +62,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val paths = files.map(_.getPath.toString) val lines: Dataset[String] = readText(sparkSession, csvOptions, paths) val firstLine: String = findFirstLine(csvOptions, lines) - val firstRow = new CsvReader(csvOptions).parseLine(firstLine) + val firstRow = new CsvParser(csvOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, csvOptions, caseSensitive) @@ -155,7 +156,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { val csvOptions = new CSVOptions(options) val commentPrefix = csvOptions.comment.toString - val headers = requiredSchema.fields.map(_.name) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -172,21 +172,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { CSVRelation.dropHeaderLine(file, lineIterator, csvOptions) - val csvParser = new CsvReader(csvOptions) - val tokenizedIterator = lineIterator.filter { line => + val linesWithoutHeader = lineIterator.filter { line => line.trim.nonEmpty && !line.startsWith(commentPrefix) - }.map { line => - 
csvParser.parseLine(line) - } - val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions) - var numMalformedRecords = 0 - tokenizedIterator.flatMap { recordTokens => - val row = parser(recordTokens, numMalformedRecords) - if (row.isEmpty) { - numMalformedRecords += 1 - } - row } + + val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions) + linesWithoutHeader.flatMap(parser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index adc92fe5a31e6..065bf53574366 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -217,124 +217,6 @@ private[csv] object CSVInferSchema { } private[csv] object CSVTypeCast { - // A `ValueConverter` is responsible for converting the given value to a desired type. - private type ValueConverter = String => Any - - /** - * Create converters which cast each given string datum to each specified type in given schema. - * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). - * - * For string types, this is simply the datum. - * For other types, this is converted into the value according to the type. - * For other nullable types, returns null if it is null or equals to the value specified - * in `nullValue` option. - * - * @param schema schema that contains data types to cast the given value into. - * @param options CSV options. - */ - def makeConverters( - schema: StructType, - options: CSVOptions = CSVOptions()): Array[ValueConverter] = { - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - } - - /** - * Create a converter which converts the string value to a value according to a desired type. 
- */ - def makeConverter( - name: String, - dataType: DataType, - nullable: Boolean = true, - options: CSVOptions = CSVOptions()): ValueConverter = dataType match { - case _: ByteType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toByte) - - case _: ShortType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toShort) - - case _: IntegerType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toInt) - - case _: LongType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toLong) - - case _: FloatType => (d: String) => - nullSafeDatum(d, name, nullable, options) { - case options.nanValue => Float.NaN - case options.negativeInf => Float.NegativeInfinity - case options.positiveInf => Float.PositiveInfinity - case datum => - Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) - } - - case _: DoubleType => (d: String) => - nullSafeDatum(d, name, nullable, options) { - case options.nanValue => Double.NaN - case options.negativeInf => Double.NegativeInfinity - case options.positiveInf => Double.PositiveInfinity - case datum => - Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) - } - - case _: BooleanType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toBoolean) - - case dt: DecimalType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - val value = new BigDecimal(datum.replaceAll(",", "")) - Decimal(value, dt.precision, dt.scale) - } - - case _: TimestampType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - Try(options.timestampFormat.parse(datum).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.stringToTime(datum).getTime * 1000L - } - } - - case _: DateType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681.x - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) - } - } - - case _: StringType => (d: String) => - nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) - - case udt: UserDefinedType[_] => (datum: String) => - makeConverter(name, udt.sqlType, nullable, options) - - case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") - } - - private def nullSafeDatum( - datum: String, - name: String, - nullable: Boolean, - options: CSVOptions)(converter: ValueConverter): Any = { - if (datum == options.nullValue || datum == null) { - if (!nullable) { - throw new RuntimeException(s"null value found but field $name is not nullable.") - } - null - } else { - converter.apply(datum) - } - } - /** * Helper method that converts string representation of a character to actual character. 
* It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 51213e3f36bef..140ce23958dc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets import java.util.Locale -import com.univocity.parsers.csv.CsvWriterSettings +import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging @@ -142,6 +142,24 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive writerSettings.setQuoteEscapingEnabled(escapeQuotes) writerSettings } + + def asParserSettings: CsvParserSettings = { + val settings = new CsvParserSettings() + val format = settings.getFormat + format.setDelimiter(delimiter) + format.setQuote(quote) + format.setQuoteEscape(escape) + format.setComment(comment) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlag) + settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlag) + settings.setReadInputOnSeparateThread(false) + settings.setInputBufferSize(inputBufferSize) + settings.setMaxColumns(maxColumns) + settings.setNullValue(nullValue) + settings.setMaxCharsPerColumn(maxCharsPerColumn) + settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) + settings + } } object CSVOptions { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala deleted file mode 100644 index 4caf72463d607..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources.csv - -import java.io.{CharArrayWriter, OutputStream, StringReader} -import java.nio.charset.StandardCharsets - -import com.univocity.parsers.csv._ - -import org.apache.spark.internal.Logging - -/** - * Read and parse CSV-like input - * - * @param params Parameters object - */ -private[csv] class CsvReader(params: CSVOptions) { - - private val parser: CsvParser = { - val settings = new CsvParserSettings() - val format = settings.getFormat - format.setDelimiter(params.delimiter) - format.setQuote(params.quote) - format.setQuoteEscape(params.escape) - format.setComment(params.comment) - settings.setIgnoreLeadingWhitespaces(params.ignoreLeadingWhiteSpaceFlag) - settings.setIgnoreTrailingWhitespaces(params.ignoreTrailingWhiteSpaceFlag) - settings.setReadInputOnSeparateThread(false) - settings.setInputBufferSize(params.inputBufferSize) - settings.setMaxColumns(params.maxColumns) - settings.setNullValue(params.nullValue) - settings.setMaxCharsPerColumn(params.maxCharsPerColumn) - settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) - - new CsvParser(settings) - } - - /** - * parse a line - * - * @param line a String with no newline at the end - * @return array of strings where each string is a field in the CSV record - */ - def parseLine(line: String): Array[String] = parser.parseLine(line) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 9679e4285e536..19058c23abe75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -17,19 +17,12 @@ package org.apache.spark.sql.execution.datasources.csv -import scala.util.control.NonFatal - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.TaskAttemptContext +import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory, PartitionedFile} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.execution.datasources.PartitionedFile object CSVRelation extends Logging { @@ -40,7 +33,7 @@ object CSVRelation extends Logging { // If header is set, make sure firstLine is materialized before sending to executors. val commentPrefix = params.comment.toString file.rdd.mapPartitions { iter => - val parser = new CsvReader(params) + val parser = new CsvParser(params.asParserSettings) val filteredIter = iter.filter { line => line.trim.nonEmpty && !line.startsWith(commentPrefix) } @@ -56,91 +49,6 @@ object CSVRelation extends Logging { } } - /** - * Returns a function that parses a single CSV record (in the form of an array of strings in which - * each element represents a column) and turns it into either one resulting row or no row (if the - * the record is malformed). - * - * The 2nd argument in the returned function represents the total number of malformed rows - * observed so far. - */ - // This is pretty convoluted and we should probably rewrite the entire CSV parsing soon. 
- def csvParser( - schema: StructType, - requiredColumns: Array[String], - params: CSVOptions): (Array[String], Int) => Option[InternalRow] = { - val requiredFields = StructType(requiredColumns.map(schema(_))).fields - val safeRequiredFields = if (params.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredFields ++ schema.filterNot(requiredFields.contains(_)) - } else { - requiredFields - } - val safeRequiredIndices = new Array[Int](safeRequiredFields.length) - schema.zipWithIndex.filter { case (field, _) => - safeRequiredFields.contains(field) - }.foreach { case (field, index) => - safeRequiredIndices(safeRequiredFields.indexOf(field)) = index - } - val requiredSize = requiredFields.length - val row = new GenericInternalRow(requiredSize) - val converters = CSVTypeCast.makeConverters(schema, params) - - (tokens: Array[String], numMalformedRows) => { - if (params.dropMalformed && schema.length != tokens.length) { - if (numMalformedRows < params.maxMalformedLogPerPartition) { - logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") - } - if (numMalformedRows == params.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${params.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. Malformed records from now on will not be logged.") - } - None - } else if (params.failFast && schema.length != tokens.length) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(params.delimiter.toString)}") - } else { - val indexSafeTokens = if (params.permissive && schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) - } else if (params.permissive && schema.length < tokens.length) { - tokens.take(schema.length) - } else { - tokens - } - try { - var index: Int = 0 - var subIndex: Int = 0 - while (subIndex < safeRequiredIndices.length) { - index = safeRequiredIndices(subIndex) - // It anyway needs to try to parse since it decides if this row is malformed - // or not after trying to cast in `DROPMALFORMED` mode even if the casted - // value is not stored in the row. - val value = converters(index).apply(indexSafeTokens(index)) - if (subIndex < requiredSize) { - row(subIndex) = value - } - subIndex += 1 - } - Some(row) - } catch { - case NonFatal(e) if params.dropMalformed => - if (numMalformedRows < params.maxMalformedLogPerPartition) { - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") - } - if (numMalformedRows == params.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${params.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. Malformed records from now on will not be logged.") - } - None - } - } - } - } - // Skips the header line of each file if the `header` option is set to true. 
def dropHeaderLine( file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala new file mode 100644 index 0000000000000..c60208139259a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.text.NumberFormat +import java.util.Locale + +import scala.util.Try +import scala.util.control.NonFatal + +import com.univocity.parsers.csv.CsvParser + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +private[csv] class UnivocityParser( + schema: StructType, + requiredSchema: StructType, + options: CSVOptions) extends Logging { + def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) + + private val valueConverters = + schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + private val parser = new CsvParser(options.asParserSettings) + + // A `ValueConverter` is responsible for converting the given value to a desired type. + private type ValueConverter = String => Any + + private var numMalformedRecords = 0 + private val row = new GenericInternalRow(requiredSchema.length) + private val indexArr: Array[Int] = { + val fields = if (options.dropMalformed) { + // If `dropMalformed` is enabled, then it needs to parse all the values + // so that we can decide which row is malformed. + requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) + } else { + requiredSchema + } + fields.filter(schema.contains).map(schema.indexOf).toArray + } + + /** + * Create a converter which converts the string value to a value according to a desired type. + * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). + * + * For other nullable types, returns null if it is null or equals to the value specified + * in `nullValue` option. 
+ */ + def makeConverter( + name: String, + dataType: DataType, + nullable: Boolean = true, + options: CSVOptions = CSVOptions()): ValueConverter = dataType match { + case _: ByteType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toByte) + + case _: ShortType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toShort) + + case _: IntegerType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toInt) + + case _: LongType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toLong) + + case _: FloatType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Float.NaN + case options.negativeInf => Float.NegativeInfinity + case options.positiveInf => Float.PositiveInfinity + case datum => + Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) + } + + case _: DoubleType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Double.NaN + case options.negativeInf => Double.NegativeInfinity + case options.positiveInf => Double.PositiveInfinity + case datum => + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) + } + + case _: BooleanType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toBoolean) + + case dt: DecimalType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + val value = new BigDecimal(datum.replaceAll(",", "")) + Decimal(value, dt.precision, dt.scale) + } + + case _: TimestampType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + Try(options.timestampFormat.parse(datum).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(datum).getTime * 1000L + } + } + + case _: DateType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + } + } + + case _: StringType => (d: String) => + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) + + case udt: UserDefinedType[_] => (datum: String) => + makeConverter(name, udt.sqlType, nullable, options) + + case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") + } + + private def nullSafeDatum( + datum: String, + name: String, + nullable: Boolean, + options: CSVOptions)(converter: ValueConverter): Any = { + if (datum == options.nullValue || datum == null) { + if (!nullable) { + throw new RuntimeException(s"null value found but field $name is not nullable.") + } + null + } else { + converter.apply(datum) + } + } + + /** + * Parses a single CSV record (in the form of an array of strings in which + * each element represents a column) and turns it into either one resulting row or no row (if the + * the record is malformed). 
+ */ + def parse(input: String): Option[InternalRow] = { + tokenizeWithParseMode(input) { tokens => + var i: Int = 0 + while (i < indexArr.length) { + val pos = indexArr(i) + // It anyway needs to try to parse since it decides if this row is malformed + // or not after trying to cast in `DROPMALFORMED` mode even if the casted + // value is not stored in the row. + val value = valueConverters(pos).apply(tokens(pos)) + if (i < requiredSchema.length) { + row(i) = value + } + i += 1 + } + row + } + } + + /** + * Tokenize the input string into the array of strings with the given parse mode. + */ + private def tokenizeWithParseMode( + input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = { + val tokens = parser.parseLine(input) + if (options.dropMalformed && schema.length != tokens.length) { + if (numMalformedRecords < options.maxMalformedLogPerPartition) { + logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") + } + if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { + logWarning( + s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + + "found on this partition. Malformed records from now on will not be logged.") + } + numMalformedRecords += 1 + None + } else if (options.failFast && schema.length != tokens.length) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: " + + s"${tokens.mkString(options.delimiter.toString)}") + } else { + val checkedTokens = if (options.permissive && schema.length > tokens.length) { + tokens ++ new Array[String](schema.length - tokens.length) + } else if (options.permissive && schema.length < tokens.length) { + tokens.take(schema.length) + } else { + tokens + } + + try { + Some(convert(checkedTokens)) + } catch { + case NonFatal(e) if options.dropMalformed => + if (numMalformedRecords < options.maxMalformedLogPerPartition) { + logWarning("Parse exception. " + + s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") + } + if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { + logWarning( + s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + + "found on this partition. 
Malformed records from now on will not be logged.") + } + numMalformedRecords += 1 + None + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala similarity index 77% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index ffd3d260bcb40..2ca6308852a7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class CSVTypeCastSuite extends SparkFunSuite { +class UnivocityParserSuite extends SparkFunSuite { + private val parser = + new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String])) private def assertNull(v: Any) = assert(v == null) @@ -36,7 +38,7 @@ class CSVTypeCastSuite extends SparkFunSuite { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) - assert(CSVTypeCast.makeConverter("_1", decimalType).apply(strVal) === + assert(parser.makeConverter("_1", decimalType).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } } @@ -73,19 +75,19 @@ class CSVTypeCastSuite extends SparkFunSuite { types.foreach { t => // Tests that a custom nullValue. val converter = - CSVTypeCast.makeConverter("_1", t, nullable = true, CSVOptions("nullValue", "-")) + parser.makeConverter("_1", t, nullable = true, CSVOptions("nullValue", "-")) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. - assertNull(CSVTypeCast.makeConverter("_1", t, nullable = true).apply("")) + assertNull(parser.makeConverter("_1", t, nullable = true).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. val converter = - CSVTypeCast.makeConverter("_1", t, nullable = false, CSVOptions("nullValue", "-")) + parser.makeConverter("_1", t, nullable = false, CSVOptions("nullValue", "-")) var message = intercept[RuntimeException] { converter.apply("-") }.getMessage @@ -100,32 +102,32 @@ class CSVTypeCastSuite extends SparkFunSuite { // null. 
Seq(true, false).foreach { b => val converter = - CSVTypeCast.makeConverter("_1", StringType, nullable = b, CSVOptions("nullValue", "null")) + parser.makeConverter("_1", StringType, nullable = b, CSVOptions("nullValue", "null")) assert(converter.apply("") == UTF8String.fromString("")) } } test("Throws exception for empty string with non null type") { val exception = intercept[RuntimeException]{ - CSVTypeCast.makeConverter("_1", IntegerType, nullable = false, CSVOptions()).apply("") + parser.makeConverter("_1", IntegerType, nullable = false, CSVOptions()).apply("") } assert(exception.getMessage.contains("null value found but field _1 is not nullable.")) } test("Types are cast correctly") { - assert(CSVTypeCast.makeConverter("_1", ByteType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", ShortType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", IntegerType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", LongType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1.00") == 1.0) - assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1.00") == 1.0) - assert(CSVTypeCast.makeConverter("_1", BooleanType).apply("true") == true) + assert(parser.makeConverter("_1", ByteType).apply("10") == 10) + assert(parser.makeConverter("_1", ShortType).apply("10") == 10) + assert(parser.makeConverter("_1", IntegerType).apply("10") == 10) + assert(parser.makeConverter("_1", LongType).apply("10") == 10) + assert(parser.makeConverter("_1", FloatType).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", DoubleType).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", BooleanType).apply("true") == true) val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm") val customTimestamp = "31/01/2015 00:00" val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime val castedTimestamp = - CSVTypeCast.makeConverter("_1", TimestampType, nullable = true, timestampsOptions) + parser.makeConverter("_1", TimestampType, nullable = true, timestampsOptions) .apply(customTimestamp) assert(castedTimestamp == expectedTime * 1000L) @@ -133,14 +135,14 @@ class CSVTypeCastSuite extends SparkFunSuite { val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy") val expectedDate = dateOptions.dateFormat.parse(customDate).getTime val castedDate = - CSVTypeCast.makeConverter("_1", DateType, nullable = true, dateOptions) + parser.makeConverter("_1", DateType, nullable = true, dateOptions) .apply(customTimestamp) assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) val timestamp = "2015-01-01 00:00:00" - assert(CSVTypeCast.makeConverter("_1", TimestampType).apply(timestamp) == + assert(parser.makeConverter("_1", TimestampType).apply(timestamp) == DateTimeUtils.stringToTime(timestamp).getTime * 1000L) - assert(CSVTypeCast.makeConverter("_1", DateType).apply("2015-01-01") == + assert(parser.makeConverter("_1", DateType).apply("2015-01-01") == DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) } @@ -149,15 +151,15 @@ class CSVTypeCastSuite extends SparkFunSuite { try { Locale.setDefault(new Locale("fr", "FR")) // Would parse as 1.0 in fr-FR - assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1,00") == 100.0) - assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1,00") == 100.0) + assert(parser.makeConverter("_1", FloatType).apply("1,00") == 100.0) + assert(parser.makeConverter("_1", DoubleType).apply("1,00") == 100.0) } finally { Locale.setDefault(originalLocale) } } test("Float NaN 
values are parsed correctly") { - val floatVal: Float = CSVTypeCast.makeConverter( + val floatVal: Float = parser.makeConverter( "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn") ).apply("nn").asInstanceOf[Float] @@ -167,7 +169,7 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Double NaN values are parsed correctly") { - val doubleVal: Double = CSVTypeCast.makeConverter( + val doubleVal: Double = parser.makeConverter( "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-") ).apply("-").asInstanceOf[Double] @@ -175,13 +177,13 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Float infinite values can be parsed") { - val floatVal1 = CSVTypeCast.makeConverter( + val floatVal1 = parser.makeConverter( "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max") ).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) - val floatVal2 = CSVTypeCast.makeConverter( + val floatVal2 = parser.makeConverter( "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max") ).apply("max").asInstanceOf[Float] @@ -189,13 +191,13 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Double infinite values can be parsed") { - val doubleVal1 = CSVTypeCast.makeConverter( + val doubleVal1 = parser.makeConverter( "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max") ).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) - val doubleVal2 = CSVTypeCast.makeConverter( + val doubleVal2 = parser.makeConverter( "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max") ).apply("max").asInstanceOf[Double] From b2938ae080ee7c36ef751b0bca57c2bfbdf99b43 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 22 Jan 2017 02:05:03 +0900 Subject: [PATCH 2/2] Add some comments and make it cleaner --- .../datasources/csv/CSVFileFormat.scala | 5 ++-- .../datasources/csv/UnivocityParser.scala | 25 ++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 9b0275170a574..38970160d5fb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -170,14 +170,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } } + // Consumes the header in the iterator. 
CSVRelation.dropHeaderLine(file, lineIterator, csvOptions) - val linesWithoutHeader = lineIterator.filter { line => + val filteredIter = lineIterator.filter { line => line.trim.nonEmpty && !line.startsWith(commentPrefix) } val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions) - linesWithoutHeader.flatMap(parser.parse) + filteredIter.flatMap(parser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index c60208139259a..8bd1ba6959a28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -33,22 +33,27 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - private[csv] class UnivocityParser( schema: StructType, requiredSchema: StructType, options: CSVOptions) extends Logging { + require(requiredSchema.toSet.subsetOf(schema.toSet), + "requiredSchema should be the subset of schema.") + def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) + // A `ValueConverter` is responsible for converting the given value to a desired type. + private type ValueConverter = String => Any + private val valueConverters = schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - private val parser = new CsvParser(options.asParserSettings) - // A `ValueConverter` is responsible for converting the given value to a desired type. - private type ValueConverter = String => Any + private val parser = new CsvParser(options.asParserSettings) private var numMalformedRecords = 0 + private val row = new GenericInternalRow(requiredSchema.length) + private val indexArr: Array[Int] = { val fields = if (options.dropMalformed) { // If `dropMalformed` is enabled, then it needs to parse all the values @@ -57,7 +62,7 @@ private[csv] class UnivocityParser( } else { requiredSchema } - fields.filter(schema.contains).map(schema.indexOf).toArray + fields.map(schema.indexOf).toArray } /** @@ -167,7 +172,7 @@ private[csv] class UnivocityParser( * the record is malformed). */ def parse(input: String): Option[InternalRow] = { - tokenizeWithParseMode(input) { tokens => + convertWithParseMode(parser.parseLine(input)) { tokens => var i: Int = 0 while (i < indexArr.length) { val pos = indexArr(i) @@ -184,12 +189,8 @@ private[csv] class UnivocityParser( } } - /** - * Tokenize the input string into the array of strings with the given parse mode. - */ - private def tokenizeWithParseMode( - input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = { - val tokens = parser.parseLine(input) + private def convertWithParseMode( + tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { if (options.dropMalformed && schema.length != tokens.length) { if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
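
A minimal usage sketch of the refactored read path (illustrative only, not part of the patch): it assumes code living inside org.apache.spark.sql.execution.datasources.csv, since CSVOptions and UnivocityParser are private[csv], and the column names ("id", "name") and the "nullValue" setting are made up for the example.

    import org.apache.spark.sql.types._

    // Hypothetical two-column schema; only "name" is required by the query.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = true),
      StructField("name", StringType, nullable = true)))
    val requiredSchema = StructType(Seq(schema("name")))

    // CSVOptions accepts a plain Map, as in UnivocityParserSuite above.
    val options = new CSVOptions(Map("nullValue" -> "-"))

    // One parser per partition: parse() tokenizes a line with univocity and
    // converts only the required columns, returning None for dropped malformed lines.
    val parser = new UnivocityParser(schema, requiredSchema, options)

    // The parser reuses a single mutable InternalRow, so copy (or fully consume)
    // each result before parsing the next line.
    val row1 = parser.parse("1,alice").map(_.copy()) // Some(row) with "name" = "alice"
    val row2 = parser.parse("2,-").map(_.copy())     // Some(row) with a null "name" ("-" matches nullValue)

Keeping the type converters on the parser and the univocity settings on CSVOptions mirrors how the JSON data source pairs its options with JacksonParser, which is the consistency the patch title refers to.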