diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8048f4b9..7bae1b63 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,7 @@ on: types: [ opened, synchronize, reopened ] branches: - main + - devel jobs: diff --git a/README.md b/README.md index c614930c..17a88b98 100644 --- a/README.md +++ b/README.md @@ -62,9 +62,9 @@ To use in external Spark cluster, submit your application with the following par ### SSL To use TLS secured connections to ArangoDB, set `ssl.enabled` to `true` and either: +- provide base64 encoded certificate as `ssl.cert.value` configuration entry and optionally set `ssl.*`, or - start Spark driver and workers with properly configured JVM default TrustStore, see [link](https://spark.apache.org/docs/latest/security.html#ssl-configuration) -- provide base64 encoded certificate as `ssl.cert.value` configuration entry and optionally set `ssl.*`, or ### Supported deployment topologies @@ -140,7 +140,15 @@ usersDF.filter(col("birthday") === "1982-12-15").show() - `batch.size`: reading batch size, default `1000` - `fill.cache`: whether the query should store the data it reads in the RocksDB block cache (`true`|`false`) - `stream`: whether the query should be executed lazily, default `true` - +- `mode`: allows a mode for dealing with corrupt records during parsing: + - `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a field configured by + `columnNameOfCorruptRecord`, and sets malformed fields to null. To keep corrupt records, a user can set a string + type field named columnNameOfCorruptRecord in a user-defined schema. If a schema does not have the field, it drops + corrupt records during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` field in + an output schema + - `DROPMALFORMED`: ignores the whole corrupted records + - `FAILFAST`: throws an exception when it meets corrupted records +- `columnNameOfCorruptRecord`: allows renaming the new field having malformed string created by `PERMISSIVE` mode ### Predicate and Projection Pushdown @@ -255,7 +263,8 @@ fail. To makes the job more resilient to temporary errors (i.e. connectivity pro will be retried (with another coordinator) if the configured `overwrite.mode` allows for idempotent requests, namely: - `replace` - `ignore` -- `update` +- `update` with `keep.null=true` + These configurations of `overwrite.mode` would also be compatible with speculative execution of tasks. A failing batch-saving request is retried at most once for every coordinator. After that, if still failing, the write @@ -343,8 +352,12 @@ df.write ## Current limitations -- on batch reading, bad records are not tolerated and will make the job fail -- in read jobs using `stream=true` (default), possible AQL warnings are only logged at the end of each the read task +- In Spark 2.4, on corrupted records in batch reading, partial results are not supported. All fields other than the + field configured by `columnNameOfCorruptRecord` are set to `null` +- in read jobs using `stream=true` (default), possible AQL warnings are only logged at the end of each read task (BTS-671) +- for `content-type=vpack`, implicit deserialization casts don't work well, i.e. 
reading a document having a field with + a numeric value whereas the related read schema requires a string value for such field +- dates and timestamps fields are interpreted to be in UTC time zone ## Demo diff --git a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoClient.scala b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoClient.scala index bd147943..94e99f71 100644 --- a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoClient.scala +++ b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoClient.scala @@ -12,10 +12,10 @@ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.spark.internal.Logging import org.apache.spark.sql.arangodb.commons.exceptions.ArangoDBMultiException -import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx +import org.apache.spark.sql.arangodb.commons.filter.PushableFilter +import org.apache.spark.sql.types.StructType -import java.util -import scala.collection.JavaConverters.{asScalaIteratorConverter, mapAsJavaMapConverter} +import scala.collection.JavaConverters.mapAsJavaMapConverter class ArangoClient(options: ArangoOptions) extends Logging { @@ -38,12 +38,12 @@ class ArangoClient(options: ArangoOptions) extends Logging { def shutdown(): Unit = arangoDB.shutdown() - def readCollectionPartition(shardId: String, ctx: PushDownCtx): ArangoCursor[VPackSlice] = { + def readCollectionPartition(shardId: String, filters: Array[PushableFilter], schema: StructType): ArangoCursor[VPackSlice] = { val query = s""" |FOR d IN @@col - |${PushdownUtils.generateFilterClause(ctx.filters)} - |RETURN ${PushdownUtils.generateColumnsFilter(ctx.requiredSchema, "d")}""" + |${PushdownUtils.generateFilterClause(filters)} + |RETURN ${PushdownUtils.generateColumnsFilter(schema, "d")}""" .stripMargin .replaceAll("\n", " ") val params = Map[String, AnyRef]("@col" -> options.readOptions.collection.get) @@ -65,7 +65,7 @@ class ArangoClient(options: ArangoOptions) extends Logging { classOf[VPackSlice]) } - def readCollectionSample(): util.List[String] = { + def readCollectionSample(): Seq[String] = { val query = "FOR d IN @@col LIMIT @size RETURN d" val params = Map( "@col" -> options.readOptions.collection.get, @@ -74,22 +74,28 @@ class ArangoClient(options: ArangoOptions) extends Logging { .asInstanceOf[Map[String, AnyRef]] val opts = aqlOptions() logDebug(s"""Executing AQL query: \n\t$query ${if (params.nonEmpty) s"\n\t with params: $params" else ""}""") + + import scala.collection.JavaConverters.iterableAsScalaIterableConverter arangoDB .db(options.readOptions.db) .query(query, params.asJava, opts, classOf[String]) .asListRemaining() + .asScala + .toSeq } - def readQuerySample(): util.List[String] = { + def readQuerySample(): Seq[String] = { val query = options.readOptions.query.get logDebug(s"Executing AQL query: \n\t$query") - arangoDB + val cursor = arangoDB .db(options.readOptions.db) .query( query, aqlOptions(), classOf[String]) - .asListRemaining() + + import scala.collection.JavaConverters.asScalaIteratorConverter + cursor.asScala.take(options.readOptions.sampleSize).toSeq } def collectionExists(): Boolean = arangoDB @@ -137,6 +143,7 @@ class ArangoClient(options: ArangoOptions) extends Logging { request.setBody(data) val response = arangoDB.execute(request) + import scala.collection.JavaConverters.asScalaIteratorConverter // FIXME // in case there are no errors, 
response body is an empty object // In cluster 3.8.1 this is not true due to: https://arangodb.atlassian.net/browse/BTS-592 diff --git a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoOptions.scala b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoOptions.scala index 7b1c2233..2874d795 100644 --- a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoOptions.scala +++ b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoOptions.scala @@ -22,12 +22,13 @@ package org.apache.spark.sql.arangodb.commons import com.arangodb.{ArangoDB, entity} import com.arangodb.model.OverwriteMode +import org.apache.spark.sql.catalyst.util.{ParseMode, PermissiveMode} import java.io.ByteArrayInputStream import java.security.KeyStore import java.security.cert.CertificateFactory import java.util -import java.util.Base64 +import java.util.{Base64, Locale} import javax.net.ssl.{SSLContext, TrustManagerFactory} import scala.collection.JavaConverters.mapAsScalaMapConverter @@ -35,7 +36,9 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter /** * @author Michele Rastelli */ -class ArangoOptions(private val options: Map[String, String]) extends Serializable { +class ArangoOptions(opts: Map[String, String]) extends Serializable { + private val options: Map[String, String] = opts.map(e => (e._1.toLowerCase(Locale.US), e._2)) + lazy val driverOptions: ArangoDriverOptions = new ArangoDriverOptions(options) lazy val readOptions: ArangoReadOptions = new ArangoReadOptions(options) lazy val writeOptions: ArangoWriteOptions = new ArangoWriteOptions(options) @@ -85,6 +88,8 @@ object ArangoOptions { val SAMPLE_SIZE = "sample.size" val FILL_BLOCK_CACHE = "fill.cache" val STREAM = "stream" + val PARSE_MODE = "mode" + val CORRUPT_RECORDS_COLUMN = "columnnameofcorruptrecord" // write options val NUMBER_OF_SHARDS = "table.shards" @@ -175,6 +180,8 @@ class ArangoReadOptions(options: Map[String, String]) extends CommonOptions(opti else throw new IllegalArgumentException("Either collection or query must be defined") val fillBlockCache: Option[Boolean] = options.get(ArangoOptions.FILL_BLOCK_CACHE).map(_.toBoolean) val stream: Boolean = options.getOrElse(ArangoOptions.STREAM, "true").toBoolean + val parseMode: ParseMode = options.get(ArangoOptions.PARSE_MODE).map(ParseMode.fromString).getOrElse(PermissiveMode) + val columnNameOfCorruptRecord: String = options.getOrElse(ArangoOptions.CORRUPT_RECORDS_COLUMN, "") } class ArangoWriteOptions(options: Map[String, String]) extends CommonOptions(options) { diff --git a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoUtils.scala b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoUtils.scala index efe65319..b486b9e4 100644 --- a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoUtils.scala +++ b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoUtils.scala @@ -1,6 +1,6 @@ package org.apache.spark.sql.arangodb.commons -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.{Encoders, SparkSession} /** @@ -17,10 +17,15 @@ object ArangoUtils { client.shutdown() val spark = SparkSession.getActiveSession.get - spark + val schema = spark .read .json(spark.createDataset(sampleEntries)(Encoders.STRING)) .schema + + if 
(options.readOptions.columnNameOfCorruptRecord.isEmpty) + schema + else + schema.add(StructField(options.readOptions.columnNameOfCorruptRecord, StringType, nullable = true)) } } diff --git a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala index 6b7016ec..d2fe44d0 100644 --- a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala +++ b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala @@ -42,5 +42,5 @@ class VPackArangoParser(schema: DataType) extends ArangoParserImpl( schema, createOptions(new VPackFactory()), - (bytes: Array[Byte]) => UTF8String.fromString(new VPackParser.Builder().build().toJson(new VPackSlice(bytes))) + (bytes: Array[Byte]) => UTF8String.fromString(new VPackParser.Builder().build().toJson(new VPackSlice(bytes), true)) ) diff --git a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala index 90ebca34..f8a4f993 100644 --- a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala +++ b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala @@ -1,15 +1,17 @@ package org.apache.spark.sql.arangodb.datasource.reader import com.arangodb.entity.CursorEntity.Warning -import com.arangodb.velocypack.VPackSlice import org.apache.spark.internal.Logging import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ContentType} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.sources.v2.reader.InputPartitionReader +import org.apache.spark.sql.types.StructType import java.nio.charset.StandardCharsets +import scala.annotation.tailrec import scala.collection.JavaConverters.iterableAsScalaIterableConverter @@ -21,19 +23,31 @@ class ArangoCollectionPartitionReader( // override endpoints with partition endpoint private val options = opts.updated(ArangoOptions.ENDPOINTS, inputPartition.endpoint) - private val parser = ArangoParserProvider().of(options.readOptions.contentType, ctx.requiredSchema) + private val actualSchema = StructType(ctx.requiredSchema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) + private val parser = ArangoParserProvider().of(options.readOptions.contentType, actualSchema) + private val safeParser = new FailureSafeParser[Array[Byte]]( + parser.parse(_).toSeq, + options.readOptions.parseMode, + ctx.requiredSchema, + options.readOptions.columnNameOfCorruptRecord) private val client = ArangoClient(options) - private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx) + private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx.filters, actualSchema) - private var current: VPackSlice = _ + var rowIterator: Iterator[InternalRow] = _ // warnings of non stream AQL cursors are all returned along with the 
first batch if (!options.readOptions.stream) logWarns() - override def next: Boolean = + @tailrec + final override def next: Boolean = if (iterator.hasNext) { - current = iterator.next() - true + val current = iterator.next() + rowIterator = safeParser.parse(options.readOptions.contentType match { + case ContentType.VPack => current.toByteArray + case ContentType.Json => current.toString.getBytes(StandardCharsets.UTF_8) + }) + if (rowIterator.hasNext) true + else next } else { // FIXME: https://arangodb.atlassian.net/browse/BTS-671 // stream AQL cursors' warnings are only returned along with the final batch @@ -41,10 +55,7 @@ class ArangoCollectionPartitionReader( false } - override def get: InternalRow = options.readOptions.contentType match { - case ContentType.VPack => parser.parse(current.toByteArray).head - case ContentType.Json => parser.parse(current.toString.getBytes(StandardCharsets.UTF_8)).head - } + override def get: InternalRow = rowIterator.next() override def close(): Unit = { iterator.close() diff --git a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoDataSourceReader.scala b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoDataSourceReader.scala index a8865123..fbd3ab9f 100644 --- a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoDataSourceReader.scala +++ b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoDataSourceReader.scala @@ -1,13 +1,14 @@ package org.apache.spark.sql.arangodb.datasource.reader import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter} import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ReadMode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import java.util import scala.collection.JavaConverters.seqAsJavaListConverter @@ -17,6 +18,8 @@ class ArangoDataSourceReader(tableSchema: StructType, options: ArangoOptions) ex with SupportsPushDownRequiredColumns with Logging { + verifyColumnNameOfCorruptRecord(tableSchema, options.readOptions.columnNameOfCorruptRecord) + // fully or partially applied filters private var appliedPushableFilters: Array[PushableFilter] = Array() private var appliedSparkFilters: Array[Filter] = Array() @@ -36,26 +39,30 @@ class ArangoDataSourceReader(tableSchema: StructType, options: ArangoOptions) ex .map(it => new ArangoCollectionPartition(it._1, it._2, new PushDownCtx(readSchema(), appliedPushableFilters), options)) override def pushFilters(filters: Array[Filter]): Array[Filter] = { + // filters related to columnNameOfCorruptRecord are not pushed down + val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord) + val ignoredFilters = filters.filter(isCorruptRecordFilter) val filtersBySupport = filters + .filterNot(isCorruptRecordFilter) .map(f => (f, PushableFilter(f, tableSchema))) .groupBy(_._2.support()) val fullSupp = filtersBySupport.getOrElse(FilterSupport.FULL, Array()) val partialSupp = 
filtersBySupport.getOrElse(FilterSupport.PARTIAL, Array()) - val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()) + val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()).map(_._1) ++ ignoredFilters val appliedFilters = fullSupp ++ partialSupp appliedPushableFilters = appliedFilters.map(_._2) appliedSparkFilters = appliedFilters.map(_._1) if (fullSupp.nonEmpty) - logInfo(s"Fully supported filters (applied in AQL):\n\t${fullSupp.map(_._1).mkString("\n\t")}") + logInfo(s"Filters fully applied in AQL:\n\t${fullSupp.map(_._1).mkString("\n\t")}") if (partialSupp.nonEmpty) - logInfo(s"Partially supported filters (applied in AQL and Spark):\n\t${partialSupp.map(_._1).mkString("\n\t")}") + logInfo(s"Filters partially applied in AQL:\n\t${partialSupp.map(_._1).mkString("\n\t")}") if (noneSupp.nonEmpty) - logInfo(s"Not supported filters (applied in Spark):\n\t${noneSupp.map(_._1).mkString("\n\t")}") + logInfo(s"Filters not applied in AQL:\n\t${noneSupp.mkString("\n\t")}") - (partialSupp ++ noneSupp).map(_._1) + partialSupp.map(_._1) ++ noneSupp } override def pushedFilters(): Array[Filter] = appliedSparkFilters @@ -64,4 +71,20 @@ class ArangoDataSourceReader(tableSchema: StructType, options: ArangoOptions) ex this.requiredSchema = requiredSchema } + /** + * A convenient function for schema validation in datasources supporting + * `columnNameOfCorruptRecord` as an option. + */ + private def verifyColumnNameOfCorruptRecord( + schema: StructType, + columnNameOfCorruptRecord: String): Unit = { + schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + } + } \ No newline at end of file diff --git a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala index 870f1fc5..d403786f 100644 --- a/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala +++ b/arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala @@ -1,34 +1,47 @@ package org.apache.spark.sql.arangodb.datasource.reader import com.arangodb.entity.CursorEntity.Warning -import com.arangodb.velocypack.VPackSlice import org.apache.spark.internal.Logging import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ContentType} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.sources.v2.reader.InputPartitionReader import org.apache.spark.sql.types._ import java.nio.charset.StandardCharsets +import scala.annotation.tailrec import scala.collection.JavaConverters.iterableAsScalaIterableConverter class ArangoQueryReader(schema: StructType, options: ArangoOptions) extends InputPartitionReader[InternalRow] with Logging { - private val parser = ArangoParserProvider().of(options.readOptions.contentType, schema) + private val actualSchema = StructType(schema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) + private val parser = ArangoParserProvider().of(options.readOptions.contentType, actualSchema) + private val 
safeParser = new FailureSafeParser[Array[Byte]]( + parser.parse(_).toSeq, + options.readOptions.parseMode, + schema, + options.readOptions.columnNameOfCorruptRecord) private val client = ArangoClient(options) private val iterator = client.readQuery() - private var current: VPackSlice = _ + var rowIterator: Iterator[InternalRow] = _ // warnings of non stream AQL cursors are all returned along with the first batch if (!options.readOptions.stream) logWarns() - override def next: Boolean = + @tailrec + final override def next: Boolean = if (iterator.hasNext) { - current = iterator.next() - true + val current = iterator.next() + rowIterator = safeParser.parse(options.readOptions.contentType match { + case ContentType.VPack => current.toByteArray + case ContentType.Json => current.toString.getBytes(StandardCharsets.UTF_8) + }) + if (rowIterator.hasNext) true + else next } else { // FIXME: https://arangodb.atlassian.net/browse/BTS-671 // stream AQL cursors' warnings are only returned along with the final batch @@ -36,10 +49,7 @@ class ArangoQueryReader(schema: StructType, options: ArangoOptions) extends Inpu false } - override def get: InternalRow = options.readOptions.contentType match { - case ContentType.VPack => parser.parse(current.toByteArray).head - case ContentType.Json => parser.parse(current.toString.getBytes(StandardCharsets.UTF_8)).head - } + override def get: InternalRow = rowIterator.next() override def close(): Unit = { iterator.close() diff --git a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala index 0393fb2b..6ed20c1b 100644 --- a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala +++ b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala @@ -44,5 +44,5 @@ class VPackArangoParser(schema: DataType) extends ArangoParserImpl( schema, createOptions(new VPackFactoryBuilder().build()), - (bytes: Array[Byte]) => UTF8String.fromString(new VPackParser.Builder().build().toJson(new VPackSlice(bytes))) + (bytes: Array[Byte]) => UTF8String.fromString(new VPackParser.Builder().build().toJson(new VPackSlice(bytes), true)) ) diff --git a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala index a523f00c..122c2317 100644 --- a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala +++ b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala @@ -1,15 +1,17 @@ package org.apache.spark.sql.arangodb.datasource.reader import com.arangodb.entity.CursorEntity.Warning -import com.arangodb.velocypack.VPackSlice import org.apache.spark.internal.Logging import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ContentType} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.FailureSafeParser import 
org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types.StructType import java.nio.charset.StandardCharsets +import scala.annotation.tailrec import scala.collection.JavaConverters.iterableAsScalaIterableConverter @@ -18,19 +20,31 @@ class ArangoCollectionPartitionReader(inputPartition: ArangoCollectionPartition, // override endpoints with partition endpoint private val options = opts.updated(ArangoOptions.ENDPOINTS, inputPartition.endpoint) - private val parser = ArangoParserProvider().of(options.readOptions.contentType, ctx.requiredSchema) + private val actualSchema = StructType(ctx.requiredSchema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) + private val parser = ArangoParserProvider().of(options.readOptions.contentType, actualSchema) + private val safeParser = new FailureSafeParser[Array[Byte]]( + parser.parse, + options.readOptions.parseMode, + ctx.requiredSchema, + options.readOptions.columnNameOfCorruptRecord) private val client = ArangoClient(options) - private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx) + private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx.filters, actualSchema) - private var current: VPackSlice = _ + var rowIterator: Iterator[InternalRow] = _ // warnings of non stream AQL cursors are all returned along with the first batch if (!options.readOptions.stream) logWarns() - override def next: Boolean = + @tailrec + final override def next: Boolean = if (iterator.hasNext) { - current = iterator.next() - true + val current = iterator.next() + rowIterator = safeParser.parse(options.readOptions.contentType match { + case ContentType.VPack => current.toByteArray + case ContentType.Json => current.toString.getBytes(StandardCharsets.UTF_8) + }) + if (rowIterator.hasNext) true + else next } else { // FIXME: https://arangodb.atlassian.net/browse/BTS-671 // stream AQL cursors' warnings are only returned along with the final batch @@ -38,10 +52,7 @@ class ArangoCollectionPartitionReader(inputPartition: ArangoCollectionPartition, false } - override def get: InternalRow = options.readOptions.contentType match { - case ContentType.VPack => parser.parse(current.toByteArray).head - case ContentType.Json => parser.parse(current.toString.getBytes(StandardCharsets.UTF_8)).head - } + override def get: InternalRow = rowIterator.next() override def close(): Unit = { iterator.close() diff --git a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala index 0365f2b0..a293fd4a 100644 --- a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala +++ b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala @@ -1,33 +1,46 @@ package org.apache.spark.sql.arangodb.datasource.reader import com.arangodb.entity.CursorEntity.Warning -import com.arangodb.velocypack.VPackSlice import org.apache.spark.internal.Logging import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ContentType} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.types._ import 
java.nio.charset.StandardCharsets +import scala.annotation.tailrec import scala.collection.JavaConverters.iterableAsScalaIterableConverter class ArangoQueryReader(schema: StructType, options: ArangoOptions) extends PartitionReader[InternalRow] with Logging { - private val parser = ArangoParserProvider().of(options.readOptions.contentType, schema) + private val actualSchema = StructType(schema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) + private val parser = ArangoParserProvider().of(options.readOptions.contentType, actualSchema) + private val safeParser = new FailureSafeParser[Array[Byte]]( + parser.parse, + options.readOptions.parseMode, + schema, + options.readOptions.columnNameOfCorruptRecord) private val client = ArangoClient(options) private val iterator = client.readQuery() - private var current: VPackSlice = _ + var rowIterator: Iterator[InternalRow] = _ // warnings of non stream AQL cursors are all returned along with the first batch if (!options.readOptions.stream) logWarns() - override def next: Boolean = + @tailrec + final override def next: Boolean = if (iterator.hasNext) { - current = iterator.next() - true + val current = iterator.next() + rowIterator = safeParser.parse(options.readOptions.contentType match { + case ContentType.VPack => current.toByteArray + case ContentType.Json => current.toString.getBytes(StandardCharsets.UTF_8) + }) + if (rowIterator.hasNext) true + else next } else { // FIXME: https://arangodb.atlassian.net/browse/BTS-671 // stream AQL cursors' warnings are only returned along with the final batch @@ -35,10 +48,7 @@ class ArangoQueryReader(schema: StructType, options: ArangoOptions) extends Part false } - override def get: InternalRow = options.readOptions.contentType match { - case ContentType.VPack => parser.parse(current.toByteArray).head - case ContentType.Json => parser.parse(current.toString.getBytes(StandardCharsets.UTF_8)).head - } + override def get: InternalRow = rowIterator.next() override def close(): Unit = { iterator.close() diff --git a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala index 7ecfd9f4..211c89b5 100644 --- a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala +++ b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala @@ -2,10 +2,12 @@ package org.apache.spark.sql.arangodb.datasource.reader import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ReadMode} import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.types.StructType class ArangoScan(ctx: PushDownCtx, options: ArangoOptions) extends Scan with Batch { + ExprUtils.verifyColumnNameOfCorruptRecord(ctx.requiredSchema, options.readOptions.columnNameOfCorruptRecord) override def readSchema(): StructType = ctx.requiredSchema diff --git a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala index 6928dbb3..521e430f 100644 --- 
a/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala +++ b/arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala @@ -22,26 +22,30 @@ class ArangoScanBuilder(options: ArangoOptions, tableSchema: StructType) extends override def build(): Scan = new ArangoScan(new PushDownCtx(requiredSchema, appliedPushableFilters), options) override def pushFilters(filters: Array[Filter]): Array[Filter] = { + // filters related to columnNameOfCorruptRecord are not pushed down + val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord) + val ignoredFilters = filters.filter(isCorruptRecordFilter) val filtersBySupport = filters + .filterNot(isCorruptRecordFilter) .map(f => (f, PushableFilter(f, tableSchema))) .groupBy(_._2.support()) val fullSupp = filtersBySupport.getOrElse(FilterSupport.FULL, Array()) val partialSupp = filtersBySupport.getOrElse(FilterSupport.PARTIAL, Array()) - val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()) + val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()).map(_._1) ++ ignoredFilters val appliedFilters = fullSupp ++ partialSupp appliedPushableFilters = appliedFilters.map(_._2) appliedSparkFilters = appliedFilters.map(_._1) if (fullSupp.nonEmpty) - logInfo(s"Fully supported filters (applied in AQL):\n\t${fullSupp.map(_._1).mkString("\n\t")}") + logInfo(s"Filters fully applied in AQL:\n\t${fullSupp.map(_._1).mkString("\n\t")}") if (partialSupp.nonEmpty) - logInfo(s"Partially supported filters (applied in AQL and Spark):\n\t${partialSupp.map(_._1).mkString("\n\t")}") + logInfo(s"Filters partially applied in AQL:\n\t${partialSupp.map(_._1).mkString("\n\t")}") if (noneSupp.nonEmpty) - logInfo(s"Not supported filters (applied in Spark):\n\t${noneSupp.map(_._1).mkString("\n\t")}") + logInfo(s"Filters not applied in AQL:\n\t${noneSupp.mkString("\n\t")}") - (partialSupp ++ noneSupp).map(_._1) + partialSupp.map(_._1) ++ noneSupp } override def pushedFilters(): Array[Filter] = appliedSparkFilters diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BadRecordsTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BadRecordsTest.scala new file mode 100644 index 00000000..c1e59cd0 --- /dev/null +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BadRecordsTest.scala @@ -0,0 +1,131 @@ +package org.apache.spark.sql.arangodb.datasource + +import org.apache.spark.{SPARK_VERSION, SparkException} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.arangodb.commons.ArangoOptions +import org.apache.spark.sql.catalyst.util.{BadRecordException, DropMalformedMode, FailFastMode, ParseMode} +import org.apache.spark.sql.types._ +import org.assertj.core.api.Assertions.{assertThat, catchThrowable} +import org.assertj.core.api.ThrowableAssert.ThrowingCallable +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource + +class BadRecordsTest extends BaseSparkTest { + private val collectionName = "deserializationCast" + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def stringAsInteger(contentType: String): Unit = testBadRecord( + StructType(Array(StructField("a", IntegerType))), + Seq(Map("a" -> "1")), + Seq("""{"a":"1"}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def 
booleanAsInteger(contentType: String): Unit = testBadRecord( + StructType(Array(StructField("a", IntegerType))), + Seq(Map("a" -> true)), + Seq("""{"a":true}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def stringAsDouble(contentType: String): Unit = testBadRecord( + StructType(Array(StructField("a", DoubleType))), + Seq(Map("a" -> "1")), + Seq("""{"a":"1"}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def booleanAsDouble(contentType: String): Unit = testBadRecord( + StructType(Array(StructField("a", DoubleType))), + Seq(Map("a" -> true)), + Seq("""{"a":true}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def stringAsBoolean(contentType: String): Unit = testBadRecord( + StructType(Array(StructField("a", BooleanType))), + Seq(Map("a" -> "true")), + Seq("""{"a":"true"}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def numberAsBoolean(contentType: String): Unit = testBadRecord( + StructType(Array(StructField("a", BooleanType))), + Seq(Map("a" -> 1)), + Seq("""{"a":1}"""), + contentType + ) + + private def testBadRecord( + schema: StructType, + data: Iterable[Map[String, Any]], + jsonData: Seq[String], + contentType: String + ) = { + // PERMISSIVE + doTestBadRecord(schema, data, jsonData, Map(ArangoOptions.CONTENT_TYPE -> contentType)) + + // PERMISSIVE with columnNameOfCorruptRecord + doTestBadRecord( + schema.add(StructField("corruptRecord", StringType)), + data, + jsonData, + Map( + ArangoOptions.CONTENT_TYPE -> contentType, + ArangoOptions.CORRUPT_RECORDS_COLUMN -> "corruptRecord" + ) + ) + + // DROPMALFORMED + doTestBadRecord(schema, data, jsonData, + Map( + ArangoOptions.CONTENT_TYPE -> contentType, + ArangoOptions.PARSE_MODE -> DropMalformedMode.name + ) + ) + + // FAILFAST + val df = BaseSparkTest.createDF(collectionName, data, schema, Map( + ArangoOptions.CONTENT_TYPE -> contentType, + ArangoOptions.PARSE_MODE -> FailFastMode.name + )) + val thrown = catchThrowable(new ThrowingCallable() { + override def call(): Unit = df.collect() + }) + + assertThat(thrown.getCause).isInstanceOf(classOf[SparkException]) + assertThat(thrown.getCause).hasMessageContaining("Malformed record") + if (!SPARK_VERSION.startsWith("2.4")) { // [SPARK-25886] + assertThat(thrown.getCause).hasCauseInstanceOf(classOf[BadRecordException]) + } + } + + private def doTestBadRecord( + schema: StructType, + data: Iterable[Map[String, Any]], + jsonData: Seq[String], + opts: Map[String, String] = Map.empty + ) = { + import spark.implicits._ + val dfFromJson: DataFrame = spark.read.schema(schema).options(opts).json(jsonData.toDS) + dfFromJson.show() + + val tableDF = BaseSparkTest.createDF(collectionName, data, schema, opts) + assertThat(tableDF.collect()).isEqualTo(dfFromJson.collect()) + + val queryDF = BaseSparkTest.createQueryDF(s"RETURN ${jsonData.head}", schema, opts) + assertThat(queryDF.collect()).isEqualTo(dfFromJson.collect()) + } + +} diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BaseSparkTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BaseSparkTest.scala index c7cd1076..70bc2d92 100644 --- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BaseSparkTest.scala +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BaseSparkTest.scala @@ -8,6 +8,7 @@ import 
com.fasterxml.jackson.core.JsonGenerator import com.fasterxml.jackson.databind.module.SimpleModule import com.fasterxml.jackson.databind.{JsonSerializer, ObjectMapper, SerializerProvider} import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.apache.spark.sql.arangodb.commons.ArangoOptions import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, SparkSession} import org.junit.jupiter.api.{AfterEach, BeforeAll} @@ -34,6 +35,7 @@ class BaseSparkTest { } def isSingle: Boolean = BaseSparkTest.isSingle + def isCluster: Boolean = !BaseSparkTest.isSingle } @@ -153,7 +155,7 @@ object BaseSparkTest { usersSchema ) - def createDF(name: String, docs: Iterable[Any], schema: StructType): DataFrame = { + def createDF(name: String, docs: Iterable[Any], schema: StructType, additionalOptions: Map[String, String] = Map.empty): DataFrame = { val col = db.collection(name) if (col.exists()) { col.truncate() @@ -164,13 +166,20 @@ object BaseSparkTest { val df = spark.read .format(arangoDatasource) - .options(options + ("table" -> name)) + .options(options ++ additionalOptions + (ArangoOptions.COLLECTION -> name)) .schema(schema) .load() df.createOrReplaceTempView(name) df } + def createQueryDF(query: String, schema: StructType, additionalOptions: Map[String, String] = Map.empty): DataFrame = + spark.read + .format(arangoDatasource) + .options(options ++ additionalOptions + (ArangoOptions.QUERY -> query)) + .schema(schema) + .load() + def dropTable(name: String): Unit = { db.collection(name).drop() } diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala new file mode 100644 index 00000000..65a7e2fb --- /dev/null +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala @@ -0,0 +1,102 @@ +package org.apache.spark.sql.arangodb.datasource + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.arangodb.commons.ArangoOptions +import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StringType, StructField, StructType} +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource + +/** + * FIXME: many vpack tests fail + */ +@Disabled +class DeserializationCastTest extends BaseSparkTest { + private val collectionName = "deserializationCast" + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def numberIntToStringCast(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", StringType))), + Seq(Map("a" -> 1)), + Seq("""{"a":1}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def numberDecToStringCast(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", StringType))), + Seq(Map("a" -> 1.1)), + Seq("""{"a":1.1}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def boolToStringCast(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", StringType))), + Seq(Map("a" -> true)), + Seq("""{"a":true}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def objectToStringCast(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", 
StringType))), + Seq(Map("a" -> Map("b" -> "c"))), + Seq("""{"a":{"b":"c"}}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def arrayToStringCast(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", StringType))), + Seq(Map("a" -> Array(1, 2))), + Seq("""{"a":[1,2]}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def nullToIntegerCast(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", IntegerType, nullable = false))), + Seq(Map("a" -> null)), + Seq("""{"a":null}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def nullToDoubleCast(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", DoubleType, nullable = false))), + Seq(Map("a" -> null)), + Seq("""{"a":null}"""), + contentType + ) + + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def nullAsBoolean(contentType: String): Unit = doTestImplicitCast( + StructType(Array(StructField("a", BooleanType, nullable = false))), + Seq(Map("a" -> null)), + Seq("""{"a":null}"""), + contentType + ) + + private def doTestImplicitCast( + schema: StructType, + data: Iterable[Map[String, Any]], + jsonData: Seq[String], + contentType: String + ) = { + import spark.implicits._ + val dfFromJson: DataFrame = spark.read.schema(schema).json(jsonData.toDS) + dfFromJson.show() + val df = BaseSparkTest.createDF(collectionName, data, schema, Map(ArangoOptions.CONTENT_TYPE -> contentType)) + assertThat(df.collect()).isEqualTo(dfFromJson.collect()) + } +} diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/ReadTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/ReadTest.scala index 425b087b..40021fe0 100644 --- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/ReadTest.scala +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/ReadTest.scala @@ -1,15 +1,14 @@ package org.apache.spark.sql.arangodb.datasource -import org.apache.spark.SparkException +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.arangodb.commons.ArangoOptions -import org.apache.spark.sql.catalyst.util.BadRecordException import org.apache.spark.sql.functions.col -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.assertj.core.api.Assertions.{assertThat, catchThrowable} -import org.assertj.core.api.ThrowableAssert.ThrowingCallable +import org.apache.spark.sql.types.{NumericType, StringType, StructField, StructType} +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assumptions.assumeTrue import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.MethodSource +import org.junit.jupiter.params.provider.{MethodSource, ValueSource} class ReadTest extends BaseSparkTest { @@ -42,31 +41,6 @@ class ReadTest extends BaseSparkTest { assertThat(litalien.birthday).isEqualTo("1944-06-19") } - @ParameterizedTest - @MethodSource(Array("provideProtocolAndContentType")) - def readCollectionWithBadRecords(protocol: String, contentType: String): Unit = { - val thrown = catchThrowable(new ThrowingCallable() { - override def call(): Unit = - spark.read - .format(BaseSparkTest.arangoDatasource) - .options(options + ( - ArangoOptions.COLLECTION -> "users", - ArangoOptions.PROTOCOL -> 
protocol, - ArangoOptions.CONTENT_TYPE -> contentType - )) - .schema(new StructType( - Array( - StructField("likes", IntegerType) - ) - )) - .load() - .show() - }) - - assertThat(thrown).isInstanceOf(classOf[SparkException]) - assertThat(thrown.getCause).isInstanceOf(classOf[BadRecordException]) - } - @Test def readCollectionSql(): Unit = { val litalien = spark.sql( @@ -111,6 +85,62 @@ class ReadTest extends BaseSparkTest { assertThat(lastNameSchema.nullable).isTrue } + @ParameterizedTest + @ValueSource(strings = Array("vpack", "json")) + def inferCollectionSchemaWithCorruptRecordColumn(contentType: String): Unit = { + assumeTrue(isSingle) + + val additionalOptions = Map( + ArangoOptions.CORRUPT_RECORDS_COLUMN -> "badRecord", + ArangoOptions.SAMPLE_SIZE -> "2", + ArangoOptions.CONTENT_TYPE -> contentType + ) + + doInferCollectionSchemaWithCorruptRecordColumn( + BaseSparkTest.createQueryDF( + """FOR d IN [{"v":1},{"v":2},{"v":"3"}] RETURN d""", + schema = null, + additionalOptions + ) + ) + + doInferCollectionSchemaWithCorruptRecordColumn( + BaseSparkTest.createDF( + "badData", + Seq( + Map("v" -> 1), + Map("v" -> 2), + Map("v" -> "3") + ), + schema = null, + additionalOptions + ) + ) + } + + def doInferCollectionSchemaWithCorruptRecordColumn(df: DataFrame): Unit = { + val vSchema = df.schema("v") + assertThat(vSchema).isInstanceOf(classOf[StructField]) + assertThat(vSchema.name).isEqualTo("v") + assertThat(vSchema.dataType).isInstanceOf(classOf[NumericType]) + assertThat(vSchema.nullable).isTrue + + val badRecordSchema = df.schema("badRecord") + assertThat(badRecordSchema).isInstanceOf(classOf[StructField]) + assertThat(badRecordSchema.name).isEqualTo("badRecord") + assertThat(badRecordSchema.dataType).isInstanceOf(classOf[StringType]) + assertThat(badRecordSchema.nullable).isTrue + + val badRecords = df.filter("badRecord IS NOT NULL").persist() + .select("badRecord") + .collect() + .map(_ (0).asInstanceOf[String]) + + assertThat(badRecords).hasSize(1) + assertThat(badRecords.head).contains(""""v":"3""") + } + + @ParameterizedTest @MethodSource(Array("provideProtocolAndContentType")) def readQuery(protocol: String, contentType: String): Unit = {
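
---

Editor's note: the README changes earlier in this diff document the new `mode` and `columnNameOfCorruptRecord` read options but stop short of a full example. The sketch below shows one way they could be combined with a user-defined schema. It is a hedged illustration only: the format name `com.arangodb.spark`, the endpoint, the collection name and the field names are assumptions, not taken from this diff; the requirement that the corrupt-record column be a nullable string field is taken from the `verifyColumnNameOfCorruptRecord` check added above.

```scala
// Hedged usage sketch; connector format name, endpoint and schema are illustrative assumptions.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val spark = SparkSession.builder().master("local[*]").appName("arangodb-read-sketch").getOrCreate()

val usersDF = spark.read
  .format("com.arangodb.spark")                 // assumed datasource name
  .options(Map(
    "endpoints" -> "localhost:8529",            // assumed single-server endpoint
    "table" -> "users",                         // collection to read
    "mode" -> "PERMISSIVE",                     // keep malformed records instead of failing the job
    "columnNameOfCorruptRecord" -> "badRecord"  // field that will hold the raw malformed document
  ))
  .schema(StructType(Array(
    StructField("name", StringType),
    StructField("age", IntegerType),
    StructField("badRecord", StringType, nullable = true) // must be string type and nullable
  )))
  .load()

// Records whose fields cannot be parsed against the schema end up with the raw
// document in "badRecord" and the remaining fields set to null.
usersDF.filter("badRecord IS NOT NULL").show()
```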
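Editor's note: the three parse modes wired into the readers via `FailureSafeParser` follow the semantics described in the README hunk above. The snippet below is a simplified illustration of those semantics in plain Scala, not Spark's internal `FailureSafeParser`; the helper name and the `Option`-based parse result are assumptions made for clarity. The mode objects themselves are the real ones imported elsewhere in this diff.

```scala
// Simplified illustration (not Spark's FailureSafeParser) of how the three
// parse modes treat a record that failed to parse against the schema.
import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}

// `parsed` is the result of parsing one raw record; `corruptRow` builds a row
// carrying the raw record in the corrupt-record column and nulls elsewhere.
def handleRecord[R](raw: String, parsed: Option[R], mode: ParseMode)(corruptRow: String => R): Option[R] =
  parsed match {
    case Some(row) => Some(row)                     // well-formed record: emitted unchanged
    case None => mode match {
      case PermissiveMode    => Some(corruptRow(raw))                                   // keep, flag as corrupt
      case DropMalformedMode => None                                                    // silently drop
      case FailFastMode      => throw new IllegalArgumentException(s"Malformed record: $raw") // abort the task
      case other             => throw new IllegalArgumentException(s"Unsupported parse mode: $other")
    }
  }
```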
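Editor's note: the README hunk above also tightens the retry rule for writes, stating that `update` is only retry-safe together with `keep.null=true`. The sketch below shows a write configuration consistent with that rule. It is an assumption-laden example: the format name, endpoints and `SaveMode.Append` are illustrative, while `overwrite.mode` and `keep.null` are the option names documented in this diff.

```scala
// Hedged sketch of a retry-compatible write configuration; format name,
// endpoints and save mode are assumptions.
import org.apache.spark.sql.SaveMode

usersDF.write
  .format("com.arangodb.spark")
  .mode(SaveMode.Append)
  .options(Map(
    "endpoints" -> "coordinator1:8529,coordinator2:8529", // assumed cluster coordinators
    "table" -> "users",
    "overwrite.mode" -> "update", // one of the modes allowing idempotent, retriable batch requests
    "keep.null" -> "true"         // per the README change, required for "update" to stay idempotent
  ))
  .save()
```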