diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
index bf0f90aab21d2..b1bf5e9a2a12b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
@@ -283,15 +283,15 @@ class UnivocityParser(
       i += 1
     }
 
-    var skipValueConversion = false
+    var skipRow = false
     var badRecordException: Option[Throwable] = None
     i = 0
-    while (!skipValueConversion && i < parsedSchema.length) {
+    while (!skipRow && i < parsedSchema.length) {
       try {
         val convertedValue = valueConverters(i).apply(getToken(tokens, i))
         parsedRow(i) = convertedValue
         if (csvFilters.skipRow(parsedRow, i)) {
-          skipValueConversion = true
+          skipRow = true
         } else {
           val requiredIndex = parsedToRequiredIndex(i)
           if (requiredIndex != -1) {
@@ -300,20 +300,20 @@ class UnivocityParser(
         }
       } catch {
         case NonFatal(e) =>
-          badRecordException = Some(e)
-          skipValueConversion = true
+          badRecordException = badRecordException.orElse(Some(e))
+          requiredSingleRow.setNullAt(i)
       }
       i += 1
     }
-    if (skipValueConversion) {
+    if (skipRow) {
+      noRows
+    } else {
       if (badRecordException.isDefined) {
         throw BadRecordException(
           () => getCurrentInput, () => requiredRow.headOption, badRecordException.get)
       } else {
-        noRows
+        requiredRow
       }
-    } else {
-      requiredRow
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index a5bf47337f800..80f5bb4c86377 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -36,7 +36,7 @@ import org.apache.log4j.{AppenderSkeleton, LogManager}
 import org.apache.log4j.spi.LoggingEvent
 
 import org.apache.spark.{SparkException, TestUtils}
-import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -2242,4 +2242,41 @@ class CSVSuite extends QueryTest with SharedSparkSession with TestCsvData {
       }
     }
   }
+
+  test("filters push down - malformed input in PERMISSIVE mode") {
+    val invalidTs = "2019-123-14 20:35:30"
+    val invalidRow = s"0,$invalidTs,999"
+    val validTs = "2019-12-14 20:35:30"
+    Seq(true, false).foreach { filterPushdown =>
+      withSQLConf(SQLConf.CSV_FILTER_PUSHDOWN_ENABLED.key -> filterPushdown.toString) {
+        withTempPath { path =>
+          Seq(
+            "c0,c1,c2",
+            invalidRow,
+            s"1,$validTs,999").toDF("data")
+            .repartition(1)
+            .write.text(path.getAbsolutePath)
+          def checkReadback(condition: Column, expected: Seq[Row]): Unit = {
+            val readback = spark.read
+              .option("mode", "PERMISSIVE")
+              .option("columnNameOfCorruptRecord", "c3")
+              .option("header", true)
+              .option("timestampFormat", "uuuu-MM-dd HH:mm:ss")
+              .schema("c0 integer, c1 timestamp, c2 integer, c3 string")
+              .csv(path.getAbsolutePath)
+              .where(condition)
+              .select($"c0", $"c1", $"c3")
+            checkAnswer(readback, expected)
+          }
+
+          checkReadback(
+            condition = $"c2" === 999,
+            expected = Seq(Row(0, null, invalidRow), Row(1, Timestamp.valueOf(validTs), null)))
+          checkReadback(
+            condition = $"c2" === 999 && $"c1" > "1970-01-01 00:00:00",
+            expected = Seq(Row(1, Timestamp.valueOf(validTs), null)))
+        }
+      }
+    }
+  }
 }