From 23c1c3e01b64879e5889d6d08c8f824283574574 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 14 Mar 2017 22:05:19 +0800 Subject: [PATCH] unify bad record handling in CSV and JSON --- .../sql/catalyst/json/JacksonParser.scala | 107 +------ .../catalyst/util/BadRecordException.scala | 26 ++ .../spark/sql/catalyst/util/ParseModes.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 18 +- .../sql/execution/DataSourceScanExec.scala | 26 +- .../datasources/DataSourceReader.scala | 269 ++++++++++++++++++ .../execution/datasources/FileScanRDD.scala | 21 +- .../datasources/csv/UnivocityParser.scala | 159 +++-------- .../datasources/FileSourceStrategySuite.scala | 4 +- .../execution/datasources/csv/CSVSuite.scala | 2 +- 10 files changed, 389 insertions(+), 245 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceReader.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 9b80c0fc87c93..734f390c22bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -42,7 +42,6 @@ class JacksonParser( options: JSONOptions) extends Logging { import JacksonUtils._ - import ParseModes._ import com.fasterxml.jackson.core.JsonToken._ // A `ValueConverter` is responsible for converting a value from `JsonParser` @@ -55,108 +54,6 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) - private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) - - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - @transient - private[this] var isWarningPrinted: Boolean = false - - @transient - private def printWarningForMalformedRecord(record: () => UTF8String): Unit = { - def sampleRecord: String = { - if (options.wholeFile) { - "" - } else { - s"Sample record: ${record()}\n" - } - } - - def footer: String = { - s"""Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}. - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin - } - - if (options.permissive) { - logWarning( - s"""Found at least one malformed record. The JSON reader will replace - |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. - |To find out which corrupted records have been replaced with null, please use the - |default inferred schema instead of providing a custom schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } else if (options.dropMalformed) { - logWarning( - s"""Found at least one malformed record. The JSON reader will drop - |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which - |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE - |mode and use the default inferred schema. 
- | - |${sampleRecord ++ footer} - | - """.stripMargin) - } - } - - @transient - private def printWarningIfWholeFile(): Unit = { - if (options.wholeFile && corruptFieldIndex.isDefined) { - logWarning( - s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result - |in very large allocations or OutOfMemoryExceptions being raised. - | - """.stripMargin) - } - } - - /** - * This function deals with the cases it fails to parse. This function will be called - * when exceptions are caught during converting. This functions also deals with `mode` option. - */ - private def failedRecord(record: () => UTF8String): Seq[InternalRow] = { - corruptFieldIndex match { - case _ if options.failFast => - if (options.wholeFile) { - throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode") - } else { - throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}") - } - - case _ if options.dropMalformed => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - Nil - - case None => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - emptyRow - - case Some(corruptIndex) => - if (!isWarningPrinted) { - printWarningIfWholeFile() - isWarningPrinted = true - } - val row = new GenericInternalRow(schema.length) - row.update(corruptIndex, record()) - Seq(row) - } - } - /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. This is a wrapper for the method @@ -472,8 +369,8 @@ class JacksonParser( } } } catch { - case _: JsonProcessingException | _: SparkSQLJsonProcessingException => - failedRecord(() => recordLiteral(record)) + case e @ (_: JsonProcessingException | _: SparkSQLJsonProcessingException) => + throw BadRecordException(() => recordLiteral(record), () => None, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala new file mode 100644 index 0000000000000..183c2c23931d5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String + +case class BadRecordException( + record: () => UTF8String, + partialResult: () => Option[InternalRow], + cause: Throwable) extends Exception(cause) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala index 0e466962b4678..fa6b6187dd117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala @@ -36,6 +36,6 @@ object ParseModes { def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { mode.toUpperCase == PERMISSIVE_MODE } else { - true // We default to permissive is the mode string is not valid + true // We default to permissive if the mode string is not valid } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 88fbfb4c92a00..0bc23ef2ec947 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -26,11 +26,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceReader} import org.apache.spark.sql.execution.datasources.csv._ -import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.types.{StringType, StructType} @@ -384,9 +385,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val createParser = CreateJacksonParser.string _ + val dataSourceReader = DataSourceReader( + schema, + extraOptions.toMap, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) - iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) + dataSourceReader.read(iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))) + .asInstanceOf[Iterator[InternalRow]] } Dataset.ofRows( @@ -439,10 +445,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) - + val dataSourceReader = DataSourceReader( + schema, + extraOptions.toMap, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) val parsed = linesWithoutHeader.mapPartitions { iter => val parser = new UnivocityParser(schema, parsedOptions) - iter.flatMap(line => parser.parse(line)) + dataSourceReader.read(iter.map(line => parser.parse(line))) + .asInstanceOf[Iterator[InternalRow]] } Dataset.ofRows( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 8ebad676ca310..3bbf5ce2aa98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -23,18 +23,19 @@ import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.ParseModes import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, Filter} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils trait DataSourceScanExec extends LeafExecNode with CodegenSupport { @@ -156,6 +157,15 @@ case class FileSourceScanExec( false } + val parseMode = relation.options.getOrElse("mode", ParseModes.PERMISSIVE_MODE) + val corruptFieldIndex: Option[Int] = outputSchema.getFieldIndex(relation.options.getOrElse( + "columnNameOfCorruptRecord", relation.sparkSession.sessionState.conf.columnNameOfCorruptRecord)) + val requiredSchema = if (corruptFieldIndex.isDefined) { + StructType(outputSchema.indices.filter(_ != corruptFieldIndex.get).map(outputSchema)) + } else { + outputSchema + } + @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters) override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { @@ -254,7 +264,7 @@ case class FileSourceScanExec( sparkSession = relation.sparkSession, dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, - requiredSchema = outputSchema, + requiredSchema = requiredSchema, filters = dataFilters, options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) @@ -370,7 +380,9 @@ case class FileSourceScanExec( FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) } - new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) + new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions, + StructType(outputSchema ++ relation.partitionSchema), + new DataSourceReader(parseMode, requiredSchema.length, corruptFieldIndex)) } /** @@ -444,7 +456,9 @@ case class FileSourceScanExec( } closePartition() - new FileScanRDD(fsRelation.sparkSession, readFile, partitions) + new FileScanRDD(fsRelation.sparkSession, readFile, partitions, + StructType(outputSchema ++ relation.partitionSchema), + new DataSourceReader(parseMode, requiredSchema.length, corruptFieldIndex)) } private def getBlockLocations(file: FileStatus): Array[BlockLocation] = file match { diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceReader.scala new file mode 100644 index 0000000000000..58433e3007449 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceReader.scala @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.util.NoSuchElementException + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.NextIterator + +class DataSourceReader(mode: String, numFields: Int, corruptFieldIndex: Option[Int]) + extends Serializable { + private val emptyRow = new GenericInternalRow(numFields) + + def read(data: Iterator[Object]): Iterator[Object] = { + new NextIterator[Object] { + private val getResultRow: (Object, () => UTF8String) => Object = { + if (corruptFieldIndex.isDefined) { + val resultRow = new RowWithBadRecord(null, corruptFieldIndex.get, null) + (row, badRecord) => { + resultRow.row = row.asInstanceOf[InternalRow] + resultRow.record = badRecord() + resultRow + } + } else { + (row, badRecord) => row + } + } + + override protected def getNext(): Object = { + try { + getResultRow(data.next(), () => null) + } catch { + case e: BadRecordException if ParseModes.isPermissiveMode(mode) => + getResultRow(e.partialResult().getOrElse(emptyRow), e.record) + case _: BadRecordException if ParseModes.isDropMalformedMode(mode) => + getNext() + case _: NoSuchElementException => + finished = true + null + } + } + + override protected def close(): Unit = {} + } + } +} + +object DataSourceReader { + def apply( + dataSchema: StructType, + options: Map[String, String], + defaultColumnNameOfCorruptRecord: String): DataSourceReader = { + val caseInsensitiveOptions = CaseInsensitiveMap[String](options) + val mode = caseInsensitiveOptions.getOrElse("mode", ParseModes.PERMISSIVE_MODE) + val corruptFieldIndex = dataSchema.getFieldIndex(caseInsensitiveOptions.getOrElse( + "columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)) + new DataSourceReader(mode, dataSchema.length - corruptFieldIndex.size, corruptFieldIndex) + } +} + +class RowWithBadRecord(var row: InternalRow, index: Int, var record: UTF8String) + extends InternalRow { + override def numFields: Int = row.numFields + 1 + + override def setNullAt(ordinal: Int): Unit = { + if (ordinal < index) { + row.setNullAt(ordinal) + } else if (ordinal == index) { 
+ record = null + } else { + row.setNullAt(ordinal - 1) + } + } + + override def update(i: Int, value: Any): Unit = { + throw new UnsupportedOperationException("update") + } + + override def copy(): InternalRow = new RowWithBadRecord(row.copy(), index, record) + + override def anyNull: Boolean = row.anyNull || record == null + + override def isNullAt(ordinal: Int): Boolean = { + if (ordinal < index) { + row.isNullAt(ordinal) + } else if (ordinal == index) { + record == null + } else { + row.isNullAt(ordinal - 1) + } + } + + private def fail() = { + throw new IllegalAccessError("This is a string field.") + } + + override def getBoolean(ordinal: Int): Boolean = { + if (ordinal < index) { + row.getBoolean(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getBoolean(ordinal - 1) + } + } + + override def getByte(ordinal: Int): Byte = { + if (ordinal < index) { + row.getByte(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getByte(ordinal - 1) + } + } + + override def getShort(ordinal: Int): Short = { + if (ordinal < index) { + row.getShort(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getShort(ordinal - 1) + } + } + + override def getInt(ordinal: Int): Int = { + if (ordinal < index) { + row.getInt(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getInt(ordinal - 1) + } + } + + override def getLong(ordinal: Int): Long = { + if (ordinal < index) { + row.getLong(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getLong(ordinal - 1) + } + } + + override def getFloat(ordinal: Int): Float = { + if (ordinal < index) { + row.getFloat(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getFloat(ordinal - 1) + } + } + + override def getDouble(ordinal: Int): Double = { + if (ordinal < index) { + row.getDouble(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getDouble(ordinal - 1) + } + } + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = { + if (ordinal < index) { + row.getDecimal(ordinal, precision, scale) + } else if (ordinal == index) { + fail() + } else { + row.getDecimal(ordinal - 1, precision, scale) + } + } + + override def getUTF8String(ordinal: Int): UTF8String = { + if (ordinal < index) { + row.getUTF8String(ordinal) + } else if (ordinal == index) { + record + } else { + row.getUTF8String(ordinal - 1) + } + } + + override def getBinary(ordinal: Int): Array[Byte] = { + if (ordinal < index) { + row.getBinary(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getBinary(ordinal - 1) + } + } + + override def getInterval(ordinal: Int): CalendarInterval = { + if (ordinal < index) { + row.getInterval(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getInterval(ordinal - 1) + } + } + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + if (ordinal < index) { + row.getStruct(ordinal, numFields) + } else if (ordinal == index) { + fail() + } else { + row.getStruct(ordinal - 1, numFields) + } + } + + override def getArray(ordinal: Int): ArrayData = { + if (ordinal < index) { + row.getArray(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getArray(ordinal - 1) + } + } + + override def getMap(ordinal: Int): MapData = { + if (ordinal < index) { + row.getMap(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getMap(ordinal - 1) + } + } + + override def get(ordinal: Int, dataType: DataType): AnyRef = { + if (ordinal < index) { + row.get(ordinal, dataType) + } else 
if (ordinal == index) { + if (dataType == StringType) { + record + } else { + fail() + } + } else { + row.get(ordinal - 1, dataType) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index a89d172a911ab..e630a6679e38b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,9 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.types.StructType import org.apache.spark.util.NextIterator /** @@ -62,7 +64,9 @@ case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends RDDPar class FileScanRDD( @transient private val sparkSession: SparkSession, readFunction: (PartitionedFile) => Iterator[InternalRow], - @transient val filePartitions: Seq[FilePartition]) + @transient val filePartitions: Seq[FilePartition], + fullSchema: StructType, + dataSourceReader: DataSourceReader) extends RDD[InternalRow](sparkSession.sparkContext, Nil) { private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles @@ -135,6 +139,8 @@ class FileScanRDD( } } + private val projection = UnsafeProjection.create(fullSchema) + /** Advances to the next file. Returns true if a new non-empty iterator is available. */ private def nextIterator(): Boolean = { updateBytesReadWithFileSize() @@ -144,8 +150,8 @@ class FileScanRDD( // Sets InputFileBlockHolder for the file block's information InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - if (ignoreCorruptFiles) { - currentIterator = new NextIterator[Object] { + val recordIterator = if (ignoreCorruptFiles) { + new NextIterator[Object] { // The readFunction may read some bytes before consuming the iterator, e.g., // vectorized Parquet reader. Here we use lazy val to delay the creation of // iterator so that we will throw exception in `getNext`. 
@@ -173,9 +179,16 @@ class FileScanRDD( override def close(): Unit = {} } } else { - currentIterator = readCurrentFile() + readCurrentFile() } + currentIterator = dataSourceReader.read(recordIterator).map { obj => + if (obj.isInstanceOf[InternalRow]) { + projection(obj.asInstanceOf[InternalRow]) + } else { + obj + } + } hasNext } else { currentFile = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index e42ea3fa391f5..e574dd0fac827 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -30,7 +30,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -46,39 +46,23 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val tokenizer = new CsvParser(options.asParserSettings) - private var numMalformedRecords = 0 - private val row = new GenericInternalRow(requiredSchema.length) - // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field - // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. + // Retrieve the raw malformed row. private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd // This parser loads an `tokenIndexArr`-th position value in input tokens, - // then put the value in `row(rowIndexArr)`. + // then put the value in row. // // For example, let's say there is CSV data as below: // // a,b,c // 1,2,A // - // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true` - // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need - // to map those values below: - // - // required schema - ["c", "b", "_unparsed", "a"] + // required schema - ["c", "b"] // CSV data schema - ["a", "b", "c"] - // required CSV data schema - ["c", "b", "a"] // // with the input tokens, // @@ -86,45 +70,22 @@ class UnivocityParser( // // Each input token is placed in each output row's position by mapping these. In this case, // - // output row - ["A", 2, null, 1] + // output row - ["A", 2] // // In more details, // - `valueConverters`, input tokens - CSV data schema - // `valueConverters` keeps the positions of input token indices (by its index) to each + // `valueConverters` keeps the positions of input token indices (by its index) to required // value's converter (by its value) in an order of CSV data schema. In this case, - // [string->int, string->int, string->string]. + // [string->string, string->int]. 
// // - `tokenIndexArr`, input tokens - required CSV data schema - // `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered - // fields given the required CSV data schema (by its value). In this case, [2, 1, 0]. - // - // - `rowIndexArr`, input tokens - required schema - // `rowIndexArr` keeps the positions of input token indices (by its index) to reordered - // field indices given the required schema (by its value). In this case, [0, 1, 3]. + // `tokenIndexArr` keeps the positions of input token indices (by its index) to required + // fields given the required CSV data schema (by its value). In this case, [2, 1]. private val valueConverters: Array[ValueConverter] = - dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means - // the fields that we should try to convert. - private val reorderedFields = if (options.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) - } else { - requiredSchema - } + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray private val tokenIndexArr: Array[Int] = { - reorderedFields - .filter(_.name != options.columnNameOfCorruptRecord) - .map(f => dataSchema.indexOf(f)).toArray - } - - private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { - val corrFieldIndex = corruptFieldIndex.get - reorderedFields.indices.filter(_ != corrFieldIndex).toArray - } else { - reorderedFields.indices.toArray + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -205,7 +166,7 @@ class UnivocityParser( } case _: StringType => (d: String) => - nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString) case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) @@ -233,82 +194,36 @@ class UnivocityParser( * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) - private def convert(tokens: Array[String]): Option[InternalRow] = { - convertWithParseMode(tokens) { tokens => - var i: Int = 0 - while (i < tokenIndexArr.length) { - // It anyway needs to try to parse since it decides if this row is malformed - // or not after trying to cast in `DROPMALFORMED` mode even if the casted - // value is not stored in the row. - val from = tokenIndexArr(i) - val to = rowIndexArr(i) - val value = valueConverters(from).apply(tokens(from)) - if (i < requiredSchema.length) { - row(to) = value - } - i += 1 - } - row - } - } - - private def convertWithParseMode( - tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && dataSchema.length != tokens.length) { - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. 
Malformed records from now on will not be logged.") - } - numMalformedRecords += 1 - None - } else if (options.failFast && dataSchema.length != tokens.length) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(options.delimiter.toString)}") - } else { + private def convert(tokens: Array[String]): InternalRow = { + if (tokens.length != schema.length) { // If a length of parsed tokens is not equal to expected one, it makes the length the same // with the expected. If the length is shorter, it adds extra tokens in the tail. // If longer, it drops extra tokens. - // - // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, - // we probably need to treat it as a malformed record. - // See an URL below for related discussions: - // https://github.com/apache/spark/pull/16928#discussion_r102657214 - val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { - if (dataSchema.length > tokens.length) { - tokens ++ new Array[String](dataSchema.length - tokens.length) - } else { - tokens.take(dataSchema.length) - } + val checkedTokens = if (schema.length > tokens.length) { + tokens ++ new Array[String](schema.length - tokens.length) } else { - tokens + tokens.take(schema.length) } - - try { - Some(convert(checkedTokens)) - } catch { - case NonFatal(e) if options.permissive => - val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) - Some(row) - case NonFatal(e) if options.dropMalformed => - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. 
Malformed records from now on will not be logged.") - } - numMalformedRecords += 1 - None + def getPartialResult(): Option[InternalRow] = { + try { + Some(convert(checkedTokens)) + } catch { + case NonFatal(e) => None + } } + throw BadRecordException( + () => UTF8String.fromString(getCurrentInput()), + getPartialResult, + new RuntimeException("Malformed CSV line")) + } else { + var i: Int = 0 + while (i < requiredSchema.length) { + row(i) = valueConverters(i).apply(tokens(tokenIndexArr(i))) + i += 1 + } + row } } } @@ -335,7 +250,7 @@ private[csv] object UnivocityParser { val tokenizer = parser.tokenizer convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => parser.convert(tokens) - }.flatten + } } private def convertStream[T]( @@ -381,6 +296,6 @@ private[csv] object UnivocityParser { val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) - filteredLines.flatMap(line => parser.parse(line)) + filteredLines.map(line => parser.parse(line)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index f36162858bf7a..519d14ec3229b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -285,8 +285,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fakeRDD = new FileScanRDD( spark, (file: PartitionedFile) => Iterator.empty, - Seq(partition) - ) + Seq(partition), + null, null) assertResult(Set("host0", "host1", "host2")) { fakeRDD.preferredLocations(partition).toSet diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 95dfdf5b298e6..db1a79cecc528 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -293,7 +293,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)).collect() } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed CSV line")) } }
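
Usage sketch (illustrative only, not part of the patch): the Scala snippet below shows how the unified per-mode behavior introduced above is expected to surface through the DataFrameReader API. The SparkSession, schema, and input path are made up for the example; "_corrupt_record" is the default value of spark.sql.columnNameOfCorruptRecord, and the "mode" / "columnNameOfCorruptRecord" option names mirror reader options already referenced elsewhere in this patch rather than anything new.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object BadRecordHandlingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("bad-record-handling-demo")
      .master("local[*]")
      .getOrCreate()

    // Corrupt-record column: a StringType field in the user-specified schema.
    val schema = new StructType()
      .add("year", IntegerType)
      .add("make", StringType)
      .add("_corrupt_record", StringType)

    // PERMISSIVE (default): a malformed line becomes a row whose data columns hold the
    // parser's partial result (or nulls) and whose corrupt-record column holds the raw line.
    val permissive = spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .csv("/path/to/cars.csv") // hypothetical input path

    // DROPMALFORMED: rows whose parse threw BadRecordException are silently skipped
    // by DataSourceReader.read.
    val dropped = spark.read
      .schema(schema)
      .option("mode", "DROPMALFORMED")
      .csv("/path/to/cars.csv") // hypothetical input path

    permissive.show()
    dropped.show()

    // FAILFAST: the BadRecordException raised by UnivocityParser/JacksonParser propagates
    // and fails the query; for CSV the error message now contains "Malformed CSV line".

    spark.stop()
  }
}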