diff --git a/.github/workflows/maven-deploy.yml b/.github/workflows/maven-deploy.yml
index 76c6a040..4afab33b 100644
--- a/.github/workflows/maven-deploy.yml
+++ b/.github/workflows/maven-deploy.yml
@@ -12,6 +12,26 @@ jobs:
strategy:
fail-fast: false
+ matrix:
+ include:
+ - scala-version: 2.11
+ spark-version: 2.4
+ - scala-version: 2.12
+ spark-version: 2.4
+ - scala-version: 2.12
+ spark-version: 3.1
+ - scala-version: 2.12
+ spark-version: 3.2
+ - scala-version: 2.13
+ spark-version: 3.2
+ - scala-version: 2.12
+ spark-version: 3.3
+ - scala-version: 2.13
+ spark-version: 3.3
+ - scala-version: 2.12
+ spark-version: 3.4
+ - scala-version: 2.13
+ spark-version: 3.4
steps:
- uses: actions/checkout@v2
diff --git a/.github/workflows/maven-release.yml b/.github/workflows/maven-release.yml
index 60ce229d..094fea9a 100644
--- a/.github/workflows/maven-release.yml
+++ b/.github/workflows/maven-release.yml
@@ -24,6 +24,14 @@ jobs:
spark-version: 3.2
- scala-version: 2.13
spark-version: 3.2
+ - scala-version: 2.12
+ spark-version: 3.3
+ - scala-version: 2.13
+ spark-version: 3.3
+ - scala-version: 2.12
+ spark-version: 3.4
+ - scala-version: 2.13
+ spark-version: 3.4
steps:
- uses: actions/checkout@v2
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4462c105..e0c3d926 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -38,6 +38,8 @@ jobs:
- 2.4
- 3.1
- 3.2
+ - 3.3
+ - 3.4
topology:
- single
- cluster
@@ -53,12 +55,18 @@ jobs:
spark-version: 3.1
- scala-version: 2.11
spark-version: 3.2
+ - scala-version: 2.11
+ spark-version: 3.3
+ - scala-version: 2.11
+ spark-version: 3.4
- scala-version: 2.11
java-version: 11
- scala-version: 2.13
spark-version: 2.4
- scala-version: 2.13
spark-version: 3.1
+ - docker-img: docker.io/arangodb/arangodb:3.9.10
+ java-version: 8
- docker-img: docker.io/arangodb/arangodb:3.10.6
java-version: 8
- docker-img: docker.io/arangodb/arangodb:3.11.0
@@ -96,6 +104,8 @@ jobs:
- 2.4
- 3.1
- 3.2
+ - 3.3
+ - 3.4
topology:
- cluster
java-version:
@@ -107,6 +117,10 @@ jobs:
spark-version: 3.1
- scala-version: 2.11
spark-version: 3.2
+ - scala-version: 2.11
+ spark-version: 3.3
+ - scala-version: 2.11
+ spark-version: 3.4
- scala-version: 2.13
spark-version: 2.4
- scala-version: 2.13
@@ -140,10 +154,15 @@ jobs:
matrix:
python-version: [3.9]
scala-version: [2.12]
- spark-version: [3.1, 3.2]
+ spark-version: [3.1, 3.2, 3.3, 3.4]
topology: [single, cluster]
java-version: [8, 11]
docker-img: ["docker.io/arangodb/arangodb:3.11.0"]
+ exclude:
+ - topology: cluster
+ java-version: 8
+ - topology: single
+ java-version: 11
steps:
- uses: actions/checkout@v2
@@ -191,6 +210,8 @@ jobs:
- 2.4
- 3.1
- 3.2
+ - 3.3
+ - 3.4
topology:
- single
java-version:
@@ -203,6 +224,10 @@ jobs:
spark-version: 3.1
- scala-version: 2.11
spark-version: 3.2
+ - scala-version: 2.11
+ spark-version: 3.3
+ - scala-version: 2.11
+ spark-version: 3.4
- scala-version: 2.13
spark-version: 2.4
- scala-version: 2.13
@@ -301,6 +326,12 @@ jobs:
- spark-version: 3.3
scala-version: 2.13
spark-full-version: 3.3.2
+ - spark-version: 3.4
+ scala-version: 2.12
+ spark-full-version: 3.4.0
+ - spark-version: 3.4
+ scala-version: 2.13
+ spark-full-version: 3.4.0
steps:
- uses: actions/checkout@v2
@@ -331,7 +362,7 @@ jobs:
scala-version:
- 2.12
spark-version:
- - 3.2
+ - 3.4
topology:
- single
java-version:
diff --git a/arangodb-spark-datasource-3.4/pom.xml b/arangodb-spark-datasource-3.4/pom.xml
new file mode 100644
index 00000000..4cbc5c6a
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/pom.xml
@@ -0,0 +1,79 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+
+    <parent>
+        <artifactId>arangodb-spark-datasource</artifactId>
+        <groupId>com.arangodb</groupId>
+        <version>1.4.3</version>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>arangodb-spark-datasource-3.4_${scala.compat.version}</artifactId>
+
+    <name>arangodb-spark-datasource-3.4</name>
+    <description>ArangoDB Datasource for Apache Spark 3.4</description>
+    <url>https://github.com/arangodb/arangodb-spark-datasource</url>
+
+    <developers>
+        <developer>
+            <name>Michele Rastelli</name>
+            <url>https://github.com/rashtao</url>
+        </developer>
+    </developers>
+
+    <scm>
+        <url>https://github.com/arangodb/arangodb-spark-datasource</url>
+    </scm>
+
+    <properties>
+        <maven.deploy.skip>false</maven.deploy.skip>
+        <sonar.coverage.jacoco.xmlReportPaths>../integration-tests/target/site/jacoco-aggregate/jacoco.xml</sonar.coverage.jacoco.xmlReportPaths>
+        <sonar.exclusions>src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/*</sonar.exclusions>
+        <sonar.coverage.exclusions>src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/*</sonar.coverage.exclusions>
+        <!-- additional module property (element name not recoverable): false -->
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.arangodb</groupId>
+            <artifactId>arangodb-spark-commons-${spark.compat.version}_${scala.compat.version}</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.httpcomponents</groupId>
+            <artifactId>httpclient</artifactId>
+            <version>4.5.13</version>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <artifactId>maven-assembly-plugin</artifactId>
+                <configuration>
+                    <descriptorRefs>
+                        <descriptorRef>jar-with-dependencies</descriptorRef>
+                    </descriptorRefs>
+                </configuration>
+                <executions>
+                    <execution>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>single</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.sonatype.plugins</groupId>
+                <artifactId>nexus-staging-maven-plugin</artifactId>
+                <extensions>true</extensions>
+                <configuration>
+                    <skipNexusStagingDeployMojo>false</skipNexusStagingDeployMojo>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
diff --git a/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider
new file mode 100644
index 00000000..477374e3
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider
@@ -0,0 +1 @@
+org.apache.spark.sql.arangodb.datasource.mapping.ArangoGeneratorProviderImpl
\ No newline at end of file
diff --git a/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
new file mode 100644
index 00000000..3e6a6b92
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
@@ -0,0 +1 @@
+org.apache.spark.sql.arangodb.datasource.mapping.ArangoParserProviderImpl
diff --git a/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 00000000..5a634481
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1 @@
+com.arangodb.spark.DefaultSource
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/com/arangodb/spark/DefaultSource.scala b/arangodb-spark-datasource-3.4/src/main/scala/com/arangodb/spark/DefaultSource.scala
new file mode 100644
index 00000000..38c0925c
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/com/arangodb/spark/DefaultSource.scala
@@ -0,0 +1,40 @@
+package com.arangodb.spark
+
+import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
+import org.apache.spark.sql.arangodb.datasource.ArangoTable
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.util
+
+class DefaultSource extends TableProvider with DataSourceRegister {
+
+ private def extractOptions(options: util.Map[String, String]): ArangoDBConf = {
+ val opts: ArangoDBConf = ArangoDBConf(options)
+ if (opts.driverOptions.acquireHostList) {
+ val hosts = ArangoClient.acquireHostList(opts)
+ opts.updated(ArangoDBConf.ENDPOINTS, hosts.mkString(","))
+ } else {
+ opts
+ }
+ }
+
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType = getTable(options).schema()
+
+ private def getTable(options: CaseInsensitiveStringMap): Table =
+ getTable(None, options.asCaseSensitiveMap()) // scalastyle:ignore null
+
+ override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table =
+ getTable(Option(schema), properties)
+
+ override def supportsExternalMetadata(): Boolean = true
+
+ override def shortName(): String = "arangodb"
+
+ private def getTable(schema: Option[StructType], properties: util.Map[String, String]) =
+ new ArangoTable(schema, extractOptions(properties))
+
+}
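
Since `DefaultSource` is registered under the short name "arangodb" (via the `DataSourceRegister` service file above), it can be addressed directly from the DataFrame reader. A minimal usage sketch; the option keys and endpoint values below are illustrative assumptions, not taken from this diff:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Read an ArangoDB collection through the "arangodb" short name registered above.
    val users = spark.read
      .format("arangodb")
      .option("endpoints", "localhost:8529") // assumed option keys; see the project docs
      .option("database", "_system")
      .option("table", "users")
      .load()

    // The schema comes from TableProvider.inferSchema unless supplied explicitly, and
    // acquireHostList (if enabled) rewrites the endpoints before connecting.
    users.printSchema()
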
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala
new file mode 100644
index 00000000..e8f4d5a8
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala
@@ -0,0 +1,37 @@
+package org.apache.spark.sql.arangodb.datasource
+
+import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ArangoUtils}
+import org.apache.spark.sql.arangodb.datasource.reader.ArangoScanBuilder
+import org.apache.spark.sql.arangodb.datasource.writer.ArangoWriterBuilder
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.util
+import scala.collection.JavaConverters.setAsJavaSetConverter
+
+class ArangoTable(private var schemaOpt: Option[StructType], options: ArangoDBConf) extends Table with SupportsRead with SupportsWrite {
+ private lazy val tableSchema = schemaOpt.getOrElse(ArangoUtils.inferSchema(options))
+
+ override def name(): String = this.getClass.toString
+
+ override def schema(): StructType = tableSchema
+
+ override def capabilities(): util.Set[TableCapability] = Set(
+ TableCapability.BATCH_READ,
+ TableCapability.BATCH_WRITE,
+ // TableCapability.STREAMING_WRITE,
+ TableCapability.ACCEPT_ANY_SCHEMA,
+ TableCapability.TRUNCATE
+ // TableCapability.OVERWRITE_BY_FILTER,
+ // TableCapability.OVERWRITE_DYNAMIC,
+ ).asJava
+
+ override def newScanBuilder(scanOptions: CaseInsensitiveStringMap): ScanBuilder =
+ new ArangoScanBuilder(options.updated(ArangoDBConf(scanOptions)), schema())
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
+ new ArangoWriterBuilder(info.schema(), options.updated(ArangoDBConf(info.options())))
+}
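
Because the table advertises BATCH_WRITE, ACCEPT_ANY_SCHEMA and TRUNCATE, batch writes (including overwrite) flow through `ArangoWriterBuilder`. A hedged write sketch, given some DataFrame `df` and the same illustrative option keys as above; the real datasource may require extra safeguards (such as a truncate confirmation option) that are not shown here:

    import org.apache.spark.sql.SaveMode

    // SaveMode.Overwrite relies on the TRUNCATE capability declared in capabilities().
    df.write
      .format("arangodb")
      .option("endpoints", "localhost:8529") // assumed option keys
      .option("database", "_system")
      .option("table", "users")
      .mode(SaveMode.Overwrite)
      .save()
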
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala
new file mode 100644
index 00000000..b4c07882
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala
@@ -0,0 +1,46 @@
+package org.apache.spark.sql.arangodb.datasource.mapping
+
+import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder
+import com.fasterxml.jackson.core.JsonFactoryBuilder
+import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
+import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider}
+import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonGenerator}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+import java.io.OutputStream
+
+abstract sealed class ArangoGeneratorImpl(
+ schema: DataType,
+ writer: OutputStream,
+ options: JSONOptions)
+ extends JacksonGenerator(
+ schema,
+ options.buildJsonFactory().createGenerator(writer),
+ options) with ArangoGenerator
+
+class ArangoGeneratorProviderImpl extends ArangoGeneratorProvider {
+ override def of(
+ contentType: ContentType,
+ schema: StructType,
+ outputStream: OutputStream,
+ conf: ArangoDBConf
+ ): ArangoGeneratorImpl = contentType match {
+ case ContentType.JSON => new JsonArangoGenerator(schema, outputStream, conf)
+ case ContentType.VPACK => new VPackArangoGenerator(schema, outputStream, conf)
+ case _ => throw new IllegalArgumentException
+ }
+}
+
+class JsonArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf)
+ extends ArangoGeneratorImpl(
+ schema,
+ outputStream,
+ createOptions(new JsonFactoryBuilder().build(), conf)
+ )
+
+class VPackArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf)
+ extends ArangoGeneratorImpl(
+ schema,
+ outputStream,
+ createOptions(new VPackFactoryBuilder().build(), conf)
+ )
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala
new file mode 100644
index 00000000..dad564ce
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala
@@ -0,0 +1,47 @@
+package org.apache.spark.sql.arangodb.datasource.mapping
+
+import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder
+import com.fasterxml.jackson.core.json.JsonReadFeature
+import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder}
+import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
+import org.apache.spark.sql.arangodb.commons.mapping.{ArangoParser, ArangoParserProvider, MappingUtils}
+import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonParser}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.unsafe.types.UTF8String
+
+abstract sealed class ArangoParserImpl(
+ schema: DataType,
+ options: JSONOptions,
+ recordLiteral: Array[Byte] => UTF8String)
+ extends JacksonParser(schema, options) with ArangoParser {
+ override def parse(data: Array[Byte]): Iterable[InternalRow] = super.parse(
+ data,
+ (jsonFactory: JsonFactory, record: Array[Byte]) => jsonFactory.createParser(record),
+ recordLiteral
+ )
+}
+
+class ArangoParserProviderImpl extends ArangoParserProvider {
+ override def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParserImpl = contentType match {
+ case ContentType.JSON => new JsonArangoParser(schema, conf)
+ case ContentType.VPACK => new VPackArangoParser(schema, conf)
+ case _ => throw new IllegalArgumentException
+ }
+}
+
+class JsonArangoParser(schema: DataType, conf: ArangoDBConf)
+ extends ArangoParserImpl(
+ schema,
+ createOptions(new JsonFactoryBuilder()
+ .configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, true)
+ .build(), conf),
+ (bytes: Array[Byte]) => UTF8String.fromBytes(bytes)
+ )
+
+class VPackArangoParser(schema: DataType, conf: ArangoDBConf)
+ extends ArangoParserImpl(
+ schema,
+ createOptions(new VPackFactoryBuilder().build(), conf),
+ (bytes: Array[Byte]) => UTF8String.fromString(MappingUtils.vpackToJson(bytes))
+ )
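
A minimal sketch of exercising the parser provider above on its own; the schema and the empty configuration are illustrative assumptions (ContentType and ArangoDBConf come from the commons module):

    import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    val schema = new StructType().add("name", StringType).add("age", IntegerType)
    val conf = ArangoDBConf(new java.util.HashMap[String, String]()) // assumed: defaults suffice here

    // JSON bytes -> Catalyst InternalRows, as done for documents fetched from ArangoDB.
    val parser = new ArangoParserProviderImpl().of(ContentType.JSON, schema, conf)
    val rows = parser.parse("""{"name":"alice","age":42}""".getBytes("UTF-8"))
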
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala
new file mode 100644
index 00000000..0fa095f1
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off
+
+package org.apache.spark.sql.arangodb.datasource.mapping.json
+
+import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
+import org.apache.hadoop.io.Text
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+import sun.nio.cs.StreamDecoder
+
+import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}
+import java.nio.channels.Channels
+import java.nio.charset.{Charset, StandardCharsets}
+
+private[sql] object CreateJacksonParser extends Serializable {
+ def string(jsonFactory: JsonFactory, record: String): JsonParser = {
+ jsonFactory.createParser(record)
+ }
+
+ def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
+ val bb = record.getByteBuffer
+ assert(bb.hasArray)
+
+ val bain = new ByteArrayInputStream(
+ bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+
+ jsonFactory.createParser(new InputStreamReader(bain, StandardCharsets.UTF_8))
+ }
+
+ def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
+ jsonFactory.createParser(record.getBytes, 0, record.getLength)
+ }
+
+ // Jackson parsers can be ranked according to their performance:
+ // 1. Array based, with the actual encoding being UTF-8 in the array. This is the fastest
+ // parser, but it doesn't allow setting the encoding explicitly. The actual encoding is
+ // detected automatically by checking the leading bytes of the array.
+ // 2. InputStream based, with the actual encoding being UTF-8 in the stream. The encoding is
+ // detected automatically by analyzing the first bytes of the input stream.
+ // 3. Reader based parser. This is the slowest parser used here, but it allows creating
+ // a reader with a specific encoding.
+ // The method creates a reader for an array with the given encoding and sizes the internal
+ // decoding buffer according to the size of the input array.
+ private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = {
+ val bais = new ByteArrayInputStream(in, 0, length)
+ val byteChannel = Channels.newChannel(bais)
+ val decodingBufferSize = Math.min(length, 8192)
+ val decoder = Charset.forName(enc).newDecoder()
+
+ StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize)
+ }
+
+ def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = {
+ val sd = getStreamDecoder(enc, record.getBytes, record.getLength)
+ jsonFactory.createParser(sd)
+ }
+
+ def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = {
+ jsonFactory.createParser(is)
+ }
+
+ def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
+ jsonFactory.createParser(new InputStreamReader(is, enc))
+ }
+
+ def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
+ val ba = row.getBinary(0)
+
+ jsonFactory.createParser(ba, 0, ba.length)
+ }
+
+ def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
+ val binary = row.getBinary(0)
+ val sd = getStreamDecoder(enc, binary, binary.length)
+
+ jsonFactory.createParser(sd)
+ }
+}
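
To make the performance ranking in the comment above concrete, a small sketch contrasting the byte-array based path (encoding auto-detected by Jackson) with the reader-based path that pins an explicit charset; the inputs are illustrative:

    import com.fasterxml.jackson.core.JsonFactory
    import org.apache.hadoop.io.Text
    import java.io.ByteArrayInputStream
    import java.nio.charset.StandardCharsets

    val factory = new JsonFactory()

    // 1. Fastest: parse straight from the backing byte array (UTF-8 detected from leading bytes).
    val fast = CreateJacksonParser.text(factory, new Text("""{"a":1}"""))

    // 3. Slowest but charset-aware: wrap the stream in a reader with an explicit encoding.
    val utf16Bytes = """{"a":1}""".getBytes(StandardCharsets.UTF_16LE)
    val slow = CreateJacksonParser.inputStream("UTF-16LE", factory, new ByteArrayInputStream(utf16Bytes))
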
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JSONOptions.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JSONOptions.scala
new file mode 100644
index 00000000..a9034cbe
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JSONOptions.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off
+
+package org.apache.spark.sql.arangodb.datasource.mapping.json
+
+import com.fasterxml.jackson.core.json.JsonReadFeature
+import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+
+import java.nio.charset.{Charset, StandardCharsets}
+import java.time.ZoneId
+import java.util.Locale
+
+/**
+ * Options for parsing JSON data into Spark SQL rows.
+ *
+ * Most of these map directly to Jackson's internal options, specified in [[JsonReadFeature]].
+ */
+private[sql] class JSONOptions(
+ @transient val parameters: CaseInsensitiveMap[String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String)
+ extends FileSourceOptions(parameters) with Logging {
+
+ import JSONOptions._
+
+ def this(
+ parameters: Map[String, String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String = "") = {
+ this(
+ CaseInsensitiveMap(parameters),
+ defaultTimeZoneId,
+ defaultColumnNameOfCorruptRecord)
+ }
+
+ val samplingRatio =
+ parameters.get(SAMPLING_RATIO).map(_.toDouble).getOrElse(1.0)
+ val primitivesAsString =
+ parameters.get(PRIMITIVES_AS_STRING).map(_.toBoolean).getOrElse(false)
+ val prefersDecimal =
+ parameters.get(PREFERS_DECIMAL).map(_.toBoolean).getOrElse(false)
+ val allowComments =
+ parameters.get(ALLOW_COMMENTS).map(_.toBoolean).getOrElse(false)
+ val allowUnquotedFieldNames =
+ parameters.get(ALLOW_UNQUOTED_FIELD_NAMES).map(_.toBoolean).getOrElse(false)
+ val allowSingleQuotes =
+ parameters.get(ALLOW_SINGLE_QUOTES).map(_.toBoolean).getOrElse(true)
+ val allowNumericLeadingZeros =
+ parameters.get(ALLOW_NUMERIC_LEADING_ZEROS).map(_.toBoolean).getOrElse(false)
+ val allowNonNumericNumbers =
+ parameters.get(ALLOW_NON_NUMERIC_NUMBERS).map(_.toBoolean).getOrElse(true)
+ val allowBackslashEscapingAnyCharacter =
+ parameters.get(ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER).map(_.toBoolean).getOrElse(false)
+ private val allowUnquotedControlChars =
+ parameters.get(ALLOW_UNQUOTED_CONTROL_CHARS).map(_.toBoolean).getOrElse(false)
+ val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
+ val parseMode: ParseMode =
+ parameters.get(MODE).map(ParseMode.fromString).getOrElse(PermissiveMode)
+ val columnNameOfCorruptRecord =
+ parameters.getOrElse(COLUMN_NAME_OF_CORRUPTED_RECORD, defaultColumnNameOfCorruptRecord)
+
+ // Whether to ignore columns of all null values or empty arrays/structs during schema inference
+ val dropFieldIfAllNull = parameters.get(DROP_FIELD_IF_ALL_NULL).map(_.toBoolean).getOrElse(false)
+
+ // Whether to ignore null fields during JSON generation
+ val ignoreNullFields = parameters.get(IGNORE_NULL_FIELDS).map(_.toBoolean)
+ .getOrElse(SQLConf.get.jsonGeneratorIgnoreNullFields)
+
+ // If this is true, when writing NULL values to columns of JSON tables with explicit DEFAULT
+ // values, never skip writing the NULL values to storage, overriding 'ignoreNullFields' above.
+ // This can be useful to enforce that inserted NULL values are present in storage to differentiate
+ // from missing data.
+ val writeNullIfWithDefaultValue = SQLConf.get.jsonWriteNullIfWithDefaultValue
+
+ // A language tag in IETF BCP 47 format
+ val locale: Locale = parameters.get(LOCALE).map(Locale.forLanguageTag).getOrElse(Locale.US)
+
+ val zoneId: ZoneId = DateTimeUtils.getZoneId(
+ parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
+
+ val dateFormatInRead: Option[String] = parameters.get(DATE_FORMAT)
+ val dateFormatInWrite: String = parameters.getOrElse(DATE_FORMAT, DateFormatter.defaultPattern)
+
+ val timestampFormatInRead: Option[String] =
+ if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) {
+ Some(parameters.getOrElse(TIMESTAMP_FORMAT,
+ s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX"))
+ } else {
+ parameters.get(TIMESTAMP_FORMAT)
+ }
+ val timestampFormatInWrite: String = parameters.getOrElse(TIMESTAMP_FORMAT,
+ if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) {
+ s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX"
+ } else {
+ s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]"
+ })
+
+ val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT)
+ val timestampNTZFormatInWrite: String =
+ parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]")
+
+ // SPARK-39731: Enables the backward compatible parsing behavior.
+ // Generally, this config should be set to false to avoid producing potentially incorrect results
+ // which is the current default (see JacksonParser).
+ //
+ // If enabled and the date cannot be parsed, we will fall back to `DateTimeUtils.stringToDate`.
+ // If enabled and the timestamp cannot be parsed, `DateTimeUtils.stringToTimestamp` will be used.
+ // Otherwise, depending on the parser policy and a custom pattern, an exception may be thrown and
+ // the value will be parsed as null.
+ val enableDateTimeParsingFallback: Option[Boolean] =
+ parameters.get(ENABLE_DATETIME_PARSING_FALLBACK).map(_.toBoolean)
+
+ val multiLine = parameters.get(MULTI_LINE).map(_.toBoolean).getOrElse(false)
+
+ /**
+ * A string between two consecutive JSON records.
+ */
+ val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { sep =>
+ require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
+ sep
+ }
+
+ protected def checkedEncoding(enc: String): String = enc
+
+ /**
+ * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE.
+ * If the encoding is not specified (None) in read, it will be detected automatically
+ * when the multiLine option is set to `true`. If encoding is not specified in write,
+ * UTF-8 is used by default.
+ */
+ val encoding: Option[String] = parameters.get(ENCODING)
+ .orElse(parameters.get(CHARSET)).map(checkedEncoding)
+
+ val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+ lineSep.getBytes(encoding.getOrElse(StandardCharsets.UTF_8.name()))
+ }
+ val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n")
+
+ /**
+ * Generating JSON strings in pretty representation if the parameter is enabled.
+ */
+ val pretty: Boolean = parameters.get(PRETTY).map(_.toBoolean).getOrElse(false)
+
+ /**
+ * Enables inferring of TimestampType and TimestampNTZType from strings matched to the
+ * corresponding timestamp pattern defined by the timestampFormat and timestampNTZFormat options
+ * respectively.
+ */
+ val inferTimestamp: Boolean = parameters.get(INFER_TIMESTAMP).map(_.toBoolean).getOrElse(false)
+
+ /**
+ * Generating \u0000 style codepoints for non-ASCII characters if the parameter is enabled.
+ */
+ val writeNonAsciiCharacterAsCodePoint: Boolean =
+ parameters.get(WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT).map(_.toBoolean).getOrElse(false)
+
+ /** Build a Jackson [[JsonFactory]] using JSON options. */
+ def buildJsonFactory(): JsonFactory = {
+ new JsonFactoryBuilder()
+ .configure(JsonReadFeature.ALLOW_JAVA_COMMENTS, allowComments)
+ .configure(JsonReadFeature.ALLOW_UNQUOTED_FIELD_NAMES, allowUnquotedFieldNames)
+ .configure(JsonReadFeature.ALLOW_SINGLE_QUOTES, allowSingleQuotes)
+ .configure(JsonReadFeature.ALLOW_LEADING_ZEROS_FOR_NUMBERS, allowNumericLeadingZeros)
+ .configure(JsonReadFeature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers)
+ .configure(
+ JsonReadFeature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER,
+ allowBackslashEscapingAnyCharacter)
+ .configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, allowUnquotedControlChars)
+ .build()
+ }
+}
+
+private[sql] class JSONOptionsInRead(
+ @transient override val parameters: CaseInsensitiveMap[String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String)
+ extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) {
+
+ def this(
+ parameters: Map[String, String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String = "") = {
+ this(
+ CaseInsensitiveMap(parameters),
+ defaultTimeZoneId,
+ defaultColumnNameOfCorruptRecord)
+ }
+
+ protected override def checkedEncoding(enc: String): String = {
+ val isDenied = JSONOptionsInRead.denyList.contains(Charset.forName(enc))
+ require(multiLine || !isDenied,
+ s"""The $enc encoding must not be included in the denyList when multiLine is disabled:
+ |denylist: ${JSONOptionsInRead.denyList.mkString(", ")}""".stripMargin)
+
+ val isLineSepRequired =
+ multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty
+ require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding")
+
+ enc
+ }
+}
+
+private[sql] object JSONOptionsInRead {
+ // The following encodings are not supported in per-line mode (multiline is false)
+ // because they cause problems when reading files with a BOM, which is supposed to be
+ // present in files with such encodings. After splitting input files by lines,
+ // only the first lines will have the BOM, which makes the remaining lines impossible
+ // to read. Besides that, the lineSep option would have to contain the BOM in such
+ // encodings, which can never appear between lines.
+ val denyList = Seq(
+ Charset.forName("UTF-16"),
+ Charset.forName("UTF-32")
+ )
+}
+
+object JSONOptions extends DataSourceOptions {
+ val SAMPLING_RATIO = newOption("samplingRatio")
+ val PRIMITIVES_AS_STRING = newOption("primitivesAsString")
+ val PREFERS_DECIMAL = newOption("prefersDecimal")
+ val ALLOW_COMMENTS = newOption("allowComments")
+ val ALLOW_UNQUOTED_FIELD_NAMES = newOption("allowUnquotedFieldNames")
+ val ALLOW_SINGLE_QUOTES = newOption("allowSingleQuotes")
+ val ALLOW_NUMERIC_LEADING_ZEROS = newOption("allowNumericLeadingZeros")
+ val ALLOW_NON_NUMERIC_NUMBERS = newOption("allowNonNumericNumbers")
+ val ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER = newOption("allowBackslashEscapingAnyCharacter")
+ val ALLOW_UNQUOTED_CONTROL_CHARS = newOption("allowUnquotedControlChars")
+ val COMPRESSION = newOption("compression")
+ val MODE = newOption("mode")
+ val DROP_FIELD_IF_ALL_NULL = newOption("dropFieldIfAllNull")
+ val IGNORE_NULL_FIELDS = newOption("ignoreNullFields")
+ val LOCALE = newOption("locale")
+ val DATE_FORMAT = newOption("dateFormat")
+ val TIMESTAMP_FORMAT = newOption("timestampFormat")
+ val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat")
+ val ENABLE_DATETIME_PARSING_FALLBACK = newOption("enableDateTimeParsingFallback")
+ val MULTI_LINE = newOption("multiLine")
+ val LINE_SEP = newOption("lineSep")
+ val PRETTY = newOption("pretty")
+ val INFER_TIMESTAMP = newOption("inferTimestamp")
+ val COLUMN_NAME_OF_CORRUPTED_RECORD = newOption("columnNameOfCorruptRecord")
+ val TIME_ZONE = newOption("timeZone")
+ val WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT = newOption("writeNonAsciiCharacterAsCodePoint")
+ // Options with alternative
+ val ENCODING = "encoding"
+ val CHARSET = "charset"
+ newOption(ENCODING, CHARSET)
+}
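
As the scaladoc above notes, most of these options map one-to-one onto Jackson's JsonReadFeature flags through buildJsonFactory(). A small sketch; since the class is private[sql], this would have to live in the same package, and the option values are illustrative:

    // Spark-style read options turned into a configured Jackson factory.
    val opts = new JSONOptions(
      Map("allowComments" -> "true", "allowSingleQuotes" -> "false"),
      defaultTimeZoneId = "UTC")

    val jsonFactory = opts.buildJsonFactory()
    // ALLOW_JAVA_COMMENTS is now enabled and ALLOW_SINGLE_QUOTES disabled, mirroring the options.
    val jsonParser = jsonFactory.createParser("""{ /* comment */ "a": 1 }""")
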
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonGenerator.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonGenerator.scala
new file mode 100644
index 00000000..c28f788b
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonGenerator.scala
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off
+
+package org.apache.spark.sql.arangodb.datasource.mapping.json
+
+import com.fasterxml.jackson.core._
+import com.fasterxml.jackson.core.util.DefaultPrettyPrinter
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+import java.io.Writer
+
+/**
+ * `JacksonGenerator` can only be initialized with a `StructType`, a `MapType` or an `ArrayType`.
+ * Once initialized with a `StructType`, it can be used to write out a struct or an array of
+ * structs. Once initialized with a `MapType`, it can be used to write out a map or an array
+ * of maps. An exception will be thrown if it is asked to write out a struct while it was
+ * initialized with a `MapType`, and vice versa.
+ */
+private[sql] class JacksonGenerator(
+ dataType: DataType,
+ generator: JsonGenerator,
+ options: JSONOptions) {
+
+ def this(dataType: DataType,
+ writer: Writer,
+ options: JSONOptions) {
+ this(
+ dataType,
+ options.buildJsonFactory().createGenerator(writer).setRootValueSeparator(null),
+ options)
+ }
+
+ // A `ValueWriter` is responsible for writing a field of an `InternalRow` to appropriate
+ // JSON data. Here we are using `SpecializedGetters` rather than `InternalRow` so that
+ // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`.
+ private type ValueWriter = (SpecializedGetters, Int) => Unit
+
+ // `JacksonGenerator` can only be initialized with a `StructType`, a `MapType` or an `ArrayType`.
+ require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType]
+ || dataType.isInstanceOf[ArrayType],
+ s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString}, " +
+ s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}")
+
+ // `ValueWriter`s for all fields of the schema
+ private lazy val rootFieldWriters: Array[ValueWriter] = dataType match {
+ case st: StructType => st.map(_.dataType).map(makeWriter).toArray
+ case _ => throw QueryExecutionErrors.initialTypeNotTargetDataTypeError(
+ dataType, StructType.simpleString)
+ }
+
+ // `ValueWriter` for array data storing rows of the schema.
+ private lazy val arrElementWriter: ValueWriter = dataType match {
+ case at: ArrayType => makeWriter(at.elementType)
+ case _: StructType | _: MapType => makeWriter(dataType)
+ case _ => throw QueryExecutionErrors.initialTypeNotTargetDataTypesError(dataType)
+ }
+
+ private lazy val mapElementWriter: ValueWriter = dataType match {
+ case mt: MapType => makeWriter(mt.valueType)
+ case _ => throw QueryExecutionErrors.initialTypeNotTargetDataTypeError(
+ dataType, MapType.simpleString)
+ }
+
+ private val gen = {
+ if (options.pretty) generator.setPrettyPrinter(new DefaultPrettyPrinter("")) else generator
+ }
+
+ private val lineSeparator: String = options.lineSeparatorInWrite
+
+ private val timestampFormatter = TimestampFormatter(
+ options.timestampFormatInWrite,
+ options.zoneId,
+ options.locale,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = false)
+ private val timestampNTZFormatter = TimestampFormatter(
+ options.timestampNTZFormatInWrite,
+ options.zoneId,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = false,
+ forTimestampNTZ = true)
+ private val dateFormatter = DateFormatter(
+ options.dateFormatInWrite,
+ options.locale,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = false)
+
+ private def makeWriter(dataType: DataType): ValueWriter = dataType match {
+ case NullType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNull()
+
+ case BooleanType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeBoolean(row.getBoolean(ordinal))
+
+ case ByteType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNumber(row.getByte(ordinal))
+
+ case ShortType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNumber(row.getShort(ordinal))
+
+ case IntegerType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNumber(row.getInt(ordinal))
+
+ case LongType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNumber(row.getLong(ordinal))
+
+ case FloatType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNumber(row.getFloat(ordinal))
+
+ case DoubleType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNumber(row.getDouble(ordinal))
+
+ case StringType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeString(row.getUTF8String(ordinal).toString)
+
+ case TimestampType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ val timestampString = timestampFormatter.format(row.getLong(ordinal))
+ gen.writeString(timestampString)
+
+ case TimestampNTZType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ val timestampString =
+ timestampNTZFormatter.format(DateTimeUtils.microsToLocalDateTime(row.getLong(ordinal)))
+ gen.writeString(timestampString)
+
+ case DateType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ val dateString = dateFormatter.format(row.getInt(ordinal))
+ gen.writeString(dateString)
+
+ case CalendarIntervalType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeString(row.getInterval(ordinal).toString)
+
+ case YearMonthIntervalType(start, end) =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ val ymString = IntervalUtils.toYearMonthIntervalString(
+ row.getInt(ordinal),
+ IntervalStringStyles.ANSI_STYLE,
+ start,
+ end)
+ gen.writeString(ymString)
+
+ case DayTimeIntervalType(start, end) =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ val dtString = IntervalUtils.toDayTimeIntervalString(
+ row.getLong(ordinal),
+ IntervalStringStyles.ANSI_STYLE,
+ start,
+ end)
+ gen.writeString(dtString)
+
+ case BinaryType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeBinary(row.getBinary(ordinal))
+
+ case dt: DecimalType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ gen.writeNumber(row.getDecimal(ordinal, dt.precision, dt.scale).toJavaBigDecimal)
+
+ case st: StructType =>
+ val fieldWriters = st.map(_.dataType).map(makeWriter)
+ (row: SpecializedGetters, ordinal: Int) =>
+ writeObject(writeFields(row.getStruct(ordinal, st.length), st, fieldWriters))
+
+ case at: ArrayType =>
+ val elementWriter = makeWriter(at.elementType)
+ (row: SpecializedGetters, ordinal: Int) =>
+ writeArray(writeArrayData(row.getArray(ordinal), elementWriter))
+
+ case mt: MapType =>
+ val valueWriter = makeWriter(mt.valueType)
+ (row: SpecializedGetters, ordinal: Int) =>
+ writeObject(writeMapData(row.getMap(ordinal), mt, valueWriter))
+
+ // For UDT values, they should be in the SQL type's corresponding value type.
+ // We should not see values of the user-defined class here.
+ // For example, VectorUDT's SQL type is an array of double, so we should expect that v is
+ // an ArrayData here, instead of a Vector.
+ case t: UserDefinedType[_] =>
+ makeWriter(t.sqlType)
+
+ case _ =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ val v = row.get(ordinal, dataType)
+ throw QueryExecutionErrors.failToConvertValueToJsonError(v, v.getClass, dataType)
+ }
+
+ private def writeObject(f: => Unit): Unit = {
+ gen.writeStartObject()
+ f
+ gen.writeEndObject()
+ }
+
+ private def writeFields(
+ row: InternalRow, schema: StructType, fieldWriters: Seq[ValueWriter]): Unit = {
+ var i = 0
+ while (i < row.numFields) {
+ val field = schema(i)
+ if (!row.isNullAt(i)) {
+ gen.writeFieldName(field.name)
+ fieldWriters(i).apply(row, i)
+ } else if ((!options.ignoreNullFields ||
+ (options.writeNullIfWithDefaultValue && field.getExistenceDefaultValue().isDefined)) && field.name != "_key") {
+ gen.writeFieldName(field.name)
+ gen.writeNull()
+ }
+ i += 1
+ }
+ }
+
+ private def writeArray(f: => Unit): Unit = {
+ gen.writeStartArray()
+ f
+ gen.writeEndArray()
+ }
+
+ private def writeArrayData(
+ array: ArrayData, fieldWriter: ValueWriter): Unit = {
+ var i = 0
+ while (i < array.numElements()) {
+ if (!array.isNullAt(i)) {
+ fieldWriter.apply(array, i)
+ } else {
+ gen.writeNull()
+ }
+ i += 1
+ }
+ }
+
+ private def writeMapData(
+ map: MapData, mapType: MapType, fieldWriter: ValueWriter): Unit = {
+ val keyArray = map.keyArray()
+ val valueArray = map.valueArray()
+ var i = 0
+ while (i < map.numElements()) {
+ gen.writeFieldName(keyArray.get(i, mapType.keyType).toString)
+ if (!valueArray.isNullAt(i)) {
+ fieldWriter.apply(valueArray, i)
+ } else {
+ gen.writeNull()
+ }
+ i += 1
+ }
+ }
+
+ def close(): Unit = gen.close()
+
+ def flush(): Unit = gen.flush()
+
+ def writeStartArray(): Unit = {
+ gen.writeStartArray()
+ }
+
+ def writeEndArray(): Unit = {
+ gen.writeEndArray()
+ }
+
+ /**
+ * Transforms a single `InternalRow` into a JSON object using Jackson.
+ * This call is validated by accessing `rootFieldWriters`.
+ *
+ * @param row The row to convert
+ */
+ def write(row: InternalRow): Unit = {
+ writeObject(writeFields(
+ fieldWriters = rootFieldWriters,
+ row = row,
+ schema = dataType.asInstanceOf[StructType]))
+ }
+
+ /**
+ * Transforms multiple `InternalRow`s or `MapData`s into a JSON array using Jackson.
+ *
+ * @param array The array of rows or maps to convert
+ */
+ def write(array: ArrayData): Unit = writeArray(writeArrayData(array, arrElementWriter))
+
+ /**
+ * Transforms a single `MapData` into a JSON object using Jackson.
+ * This call is validated by accessing `mapElementWriter`.
+ *
+ * @param map a map to convert
+ */
+ def write(map: MapData): Unit = {
+ writeObject(writeMapData(
+ fieldWriter = mapElementWriter,
+ map = map,
+ mapType = dataType.asInstanceOf[MapType]))
+ }
+
+ def writeLineEnding(): Unit = {
+ // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8.
+ gen.writeRaw(lineSeparator)
+ }
+}
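
A hedged sketch of the generator in isolation, using the Writer-based auxiliary constructor above; it would have to sit in the same package as these package-private classes, and the schema and timezone are illustrative:

    import java.io.CharArrayWriter
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    val schema = new StructType().add("name", StringType).add("age", IntegerType)
    val out = new CharArrayWriter()
    val gen = new JacksonGenerator(schema, out, new JSONOptions(Map.empty, "UTC"))

    // Catalyst rows carry UTF8String values, not java.lang.String.
    gen.write(InternalRow(UTF8String.fromString("alice"), 42))
    gen.flush()
    // out.toString is expected to yield {"name":"alice","age":42}
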
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonParser.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonParser.scala
new file mode 100644
index 00000000..cdafc938
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonParser.scala
@@ -0,0 +1,589 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off
+
+package org.apache.spark.sql.arangodb.datasource.mapping.json
+
+import com.fasterxml.jackson.core._
+import org.apache.spark.SparkUpgradeException
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils
+
+import java.io.{ByteArrayOutputStream, CharConversionException}
+import java.nio.charset.MalformedInputException
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
+
+/**
+ * Constructs a parser for a given schema that translates a json string to an [[InternalRow]].
+ */
+class JacksonParser(
+ schema: DataType,
+ val options: JSONOptions,
+ allowArrayAsStructs: Boolean = false,
+ filters: Seq[Filter] = Seq.empty) extends Logging {
+
+ import JacksonUtils._
+ import com.fasterxml.jackson.core.JsonToken._
+
+ // A `ValueConverter` is responsible for converting a value from `JsonParser`
+ // to a value in a field for `InternalRow`.
+ private type ValueConverter = JsonParser => AnyRef
+
+ // `ValueConverter`s for the root schema for all fields in the schema
+ private val rootConverter = makeRootConverter(schema)
+
+ private val factory = options.buildJsonFactory()
+
+ private lazy val timestampFormatter = TimestampFormatter(
+ options.timestampFormatInRead,
+ options.zoneId,
+ options.locale,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = true)
+ private lazy val timestampNTZFormatter = TimestampFormatter(
+ options.timestampNTZFormatInRead,
+ options.zoneId,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = true,
+ forTimestampNTZ = true)
+ private lazy val dateFormatter = DateFormatter(
+ options.dateFormatInRead,
+ options.locale,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = true)
+
+ // Flags to signal if we need to fall back to the backward compatible behavior of parsing
+ // dates and timestamps.
+ // For more information, see comments for "enableDateTimeParsingFallback" option in JSONOptions.
+ private val enableParsingFallbackForTimestampType =
+ options.enableDateTimeParsingFallback
+ .orElse(SQLConf.get.jsonEnableDateTimeParsingFallback)
+ .getOrElse {
+ SQLConf.get.legacyTimeParserPolicy == SQLConf.LegacyBehaviorPolicy.LEGACY ||
+ options.timestampFormatInRead.isEmpty
+ }
+ private val enableParsingFallbackForDateType =
+ options.enableDateTimeParsingFallback
+ .orElse(SQLConf.get.jsonEnableDateTimeParsingFallback)
+ .getOrElse {
+ SQLConf.get.legacyTimeParserPolicy == SQLConf.LegacyBehaviorPolicy.LEGACY ||
+ options.dateFormatInRead.isEmpty
+ }
+
+ private val enablePartialResults = SQLConf.get.jsonEnablePartialResults
+
+ /**
+ * Create a converter which converts the JSON documents held by the `JsonParser`
+ * to a value according to a desired schema. This is a wrapper for the method
+ * `makeConverter()` to handle a row wrapped with an array.
+ */
+ private def makeRootConverter(dt: DataType): JsonParser => Iterable[InternalRow] = {
+ dt match {
+ case st: StructType => makeStructRootConverter(st)
+ case mt: MapType => makeMapRootConverter(mt)
+ case at: ArrayType => makeArrayRootConverter(at)
+ }
+ }
+
+ private def makeStructRootConverter(st: StructType): JsonParser => Iterable[InternalRow] = {
+ val elementConverter = makeConverter(st)
+ val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
+ val jsonFilters = if (SQLConf.get.jsonFilterPushDown) {
+ new JsonFilters(filters, st)
+ } else {
+ new NoopFilters
+ }
+ (parser: JsonParser) => parseJsonToken[Iterable[InternalRow]](parser, st) {
+ case START_OBJECT => convertObject(parser, st, fieldConverters, jsonFilters, isRoot = true)
+ // SPARK-3308: support reading top level JSON arrays and take every element
+ // in such an array as a row
+ //
+ // For example, we support, the JSON data as below:
+ //
+ // [{"a":"str_a_1"}]
+ // [{"a":"str_a_2"}, {"b":"str_b_3"}]
+ //
+ // resulting in:
+ //
+ // List([str_a_1,null])
+ // List([str_a_2,null], [null,str_b_3])
+ //
+ case START_ARRAY if allowArrayAsStructs =>
+ val array = convertArray(parser, elementConverter, isRoot = true)
+ // Here, as we support reading top level JSON arrays and take every element
+ // in such an array as a row, this case is possible.
+ if (array.numElements() == 0) {
+ Array.empty[InternalRow]
+ } else {
+ array.toArray[InternalRow](schema)
+ }
+ case START_ARRAY =>
+ throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError()
+ }
+ }
+
+ private def makeMapRootConverter(mt: MapType): JsonParser => Iterable[InternalRow] = {
+ val fieldConverter = makeConverter(mt.valueType)
+ (parser: JsonParser) => parseJsonToken[Iterable[InternalRow]](parser, mt) {
+ case START_OBJECT => Some(InternalRow(convertMap(parser, fieldConverter)))
+ }
+ }
+
+ private def makeArrayRootConverter(at: ArrayType): JsonParser => Iterable[InternalRow] = {
+ val elemConverter = makeConverter(at.elementType)
+ (parser: JsonParser) => parseJsonToken[Iterable[InternalRow]](parser, at) {
+ case START_ARRAY => Some(InternalRow(convertArray(parser, elemConverter)))
+ case START_OBJECT if at.elementType.isInstanceOf[StructType] =>
+ // This handles the case when an input JSON object is a structure but
+ // the specified schema is an array of structures. In that case, the input JSON is
+ // considered as an array of only one element of struct type.
+ // This behavior was introduced by changes for SPARK-19595.
+ //
+ // For example, if the specified schema is ArrayType(new StructType().add("i", IntegerType))
+ // and JSON input as below:
+ //
+ // [{"i": 1}, {"i": 2}]
+ // [{"i": 3}]
+ // {"i": 4}
+ //
+ // The last row is considered as an array with one element, and result of conversion:
+ //
+ // Seq(Row(1), Row(2))
+ // Seq(Row(3))
+ // Seq(Row(4))
+ //
+ val st = at.elementType.asInstanceOf[StructType]
+ val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
+ Some(InternalRow(new GenericArrayData(convertObject(parser, st, fieldConverters).toArray)))
+ }
+ }
+
+ private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
+ /**
+ * Create a converter which converts the JSON documents held by the `JsonParser`
+ * to a value according to a desired schema.
+ */
+ def makeConverter(dataType: DataType): ValueConverter = dataType match {
+ case BooleanType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Boolean](parser, dataType) {
+ case VALUE_TRUE => true
+ case VALUE_FALSE => false
+ }
+
+ case ByteType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Byte](parser, dataType) {
+ case VALUE_NUMBER_INT => parser.getByteValue
+ }
+
+ case ShortType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Short](parser, dataType) {
+ case VALUE_NUMBER_INT => parser.getShortValue
+ }
+
+ case IntegerType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) {
+ case VALUE_NUMBER_INT => parser.getIntValue
+ }
+
+ case LongType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) {
+ case VALUE_NUMBER_INT => parser.getLongValue
+ }
+
+ case FloatType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Float](parser, dataType) {
+ case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+ parser.getFloatValue
+
+ case VALUE_STRING if parser.getTextLength >= 1 =>
+ // Special case handling for NaN and Infinity.
+ parser.getText match {
+ case "NaN" if options.allowNonNumericNumbers =>
+ Float.NaN
+ case "+INF" | "+Infinity" | "Infinity" if options.allowNonNumericNumbers =>
+ Float.PositiveInfinity
+ case "-INF" | "-Infinity" if options.allowNonNumericNumbers =>
+ Float.NegativeInfinity
+ case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError(
+ parser, VALUE_STRING, FloatType)
+ }
+ }
+
+ case DoubleType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Double](parser, dataType) {
+ case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+ parser.getDoubleValue
+
+ case VALUE_STRING if parser.getTextLength >= 1 =>
+ // Special case handling for NaN and Infinity.
+ parser.getText match {
+ case "NaN" if options.allowNonNumericNumbers =>
+ Double.NaN
+ case "+INF" | "+Infinity" | "Infinity" if options.allowNonNumericNumbers =>
+ Double.PositiveInfinity
+ case "-INF" | "-Infinity" if options.allowNonNumericNumbers =>
+ Double.NegativeInfinity
+ case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError(
+ parser, VALUE_STRING, DoubleType)
+ }
+ }
+
+ case StringType =>
+ (parser: JsonParser) => parseJsonToken[UTF8String](parser, dataType) {
+ case VALUE_STRING =>
+ UTF8String.fromString(parser.getText)
+
+ case _ =>
+ // Note that this always converts the data to its string representation and does not fail.
+ val writer = new ByteArrayOutputStream()
+ Utils.tryWithResource(factory.createGenerator(writer, JsonEncoding.UTF8)) {
+ generator => generator.copyCurrentStructure(parser)
+ }
+ UTF8String.fromBytes(writer.toByteArray)
+ }
+
+ case TimestampType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) {
+ case VALUE_STRING if parser.getTextLength >= 1 =>
+ try {
+ timestampFormatter.parse(parser.getText)
+ } catch {
+ case NonFatal(e) =>
+ // If fails to parse, then tries the way used in 2.0 and 1.x for backwards
+ // compatibility if enabled.
+ if (!enableParsingFallbackForTimestampType) {
+ throw e
+ }
+ val str = DateTimeUtils.cleanLegacyTimestampStr(UTF8String.fromString(parser.getText))
+ DateTimeUtils.stringToTimestamp(str, options.zoneId).getOrElse(throw e)
+ }
+
+ case VALUE_NUMBER_INT =>
+ parser.getLongValue * 1000L
+ }
+
+ case TimestampNTZType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) {
+ case VALUE_STRING if parser.getTextLength >= 1 =>
+ timestampNTZFormatter.parseWithoutTimeZone(parser.getText, false)
+ }
+
+ case DateType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) {
+ case VALUE_STRING if parser.getTextLength >= 1 =>
+ try {
+ dateFormatter.parse(parser.getText)
+ } catch {
+ case NonFatal(e) =>
+ // If fails to parse, then tries the way used in 2.0 and 1.x for backwards
+ // compatibility if enabled.
+ if (!enableParsingFallbackForDateType) {
+ throw e
+ }
+ val str = DateTimeUtils.cleanLegacyTimestampStr(UTF8String.fromString(parser.getText))
+ DateTimeUtils.stringToDate(str).getOrElse {
+ // In Spark 1.5.0, we store the data as number of days since epoch in string.
+ // So, we just convert it to Int.
+ try {
+ RebaseDateTime.rebaseJulianToGregorianDays(parser.getText.toInt)
+ } catch {
+ case _: NumberFormatException => throw e
+ }
+ }.asInstanceOf[Integer]
+ }
+ }
+
+ case BinaryType =>
+ (parser: JsonParser) => parseJsonToken[Array[Byte]](parser, dataType) {
+ case VALUE_STRING => parser.getBinaryValue
+ }
+
+ case dt: DecimalType =>
+ (parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) {
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) =>
+ Decimal(parser.getDecimalValue, dt.precision, dt.scale)
+ case VALUE_STRING if parser.getTextLength >= 1 =>
+ val bigDecimal = decimalParser(parser.getText)
+ Decimal(bigDecimal, dt.precision, dt.scale)
+ }
+
+ case CalendarIntervalType => (parser: JsonParser) =>
+ parseJsonToken[CalendarInterval](parser, dataType) {
+ case VALUE_STRING =>
+ IntervalUtils.safeStringToInterval(UTF8String.fromString(parser.getText))
+ }
+
+ case ym: YearMonthIntervalType => (parser: JsonParser) =>
+ parseJsonToken[Integer](parser, dataType) {
+ case VALUE_STRING =>
+ val expr = Cast(Literal(parser.getText), ym)
+ Integer.valueOf(expr.eval(EmptyRow).asInstanceOf[Int])
+ }
+
+ case dt: DayTimeIntervalType => (parser: JsonParser) =>
+ parseJsonToken[java.lang.Long](parser, dataType) {
+ case VALUE_STRING =>
+ val expr = Cast(Literal(parser.getText), dt)
+ java.lang.Long.valueOf(expr.eval(EmptyRow).asInstanceOf[Long])
+ }
+
+ case st: StructType =>
+ val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
+ (parser: JsonParser) => parseJsonToken[InternalRow](parser, dataType) {
+ case START_OBJECT => convertObject(parser, st, fieldConverters).get
+ }
+
+ case at: ArrayType =>
+ val elementConverter = makeConverter(at.elementType)
+ (parser: JsonParser) => parseJsonToken[ArrayData](parser, dataType) {
+ case START_ARRAY => convertArray(parser, elementConverter)
+ }
+
+ case mt: MapType =>
+ val valueConverter = makeConverter(mt.valueType)
+ (parser: JsonParser) => parseJsonToken[MapData](parser, dataType) {
+ case START_OBJECT => convertMap(parser, valueConverter)
+ }
+
+ case udt: UserDefinedType[_] =>
+ makeConverter(udt.sqlType)
+
+ case _: NullType =>
+ (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) {
+ case _ => null
+ }
+
+ // We don't actually hit this exception, but we keep it for understandability
+ case _ => throw QueryExecutionErrors.unsupportedTypeError(dataType)
+ }
+
+ /**
+ * This method skips `FIELD_NAME`s at the beginning and handles nulls before trying
+ * to parse the JSON token using the given function `f`. If `f` fails to parse and convert the
+ * token, `failedConversion` is called to handle the token.
+ */
+ @scala.annotation.tailrec
+ private def parseJsonToken[R >: Null](
+ parser: JsonParser,
+ dataType: DataType)(f: PartialFunction[JsonToken, R]): R = {
+ parser.getCurrentToken match {
+ case FIELD_NAME =>
+ // There are useless FIELD_NAMEs between START_OBJECT and END_OBJECT tokens
+ parser.nextToken()
+ parseJsonToken[R](parser, dataType)(f)
+
+ case null | VALUE_NULL => null
+
+ case other => f.applyOrElse(other, failedConversion(parser, dataType))
+ }
+ }
+
+ private val allowEmptyString = SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_EMPTY_STRING_IN_JSON)
+
+ /**
+ * This function throws an exception on a failed conversion. For empty strings on data types
+ * other than string and binary, it also throws an exception.
+ */
+ private def failedConversion[R >: Null](
+ parser: JsonParser,
+ dataType: DataType): PartialFunction[JsonToken, R] = {
+
+ // SPARK-25040: Disallows empty strings for data types except for string and binary types.
+ // But treats empty strings as null for certain types if the legacy config is enabled.
+ case VALUE_STRING if parser.getTextLength < 1 && allowEmptyString =>
+ dataType match {
+ case FloatType | DoubleType | TimestampType | DateType =>
+ throw QueryExecutionErrors.emptyJsonFieldValueError(dataType)
+ case _ => null
+ }
+
+ case VALUE_STRING if parser.getTextLength < 1 =>
+ throw QueryExecutionErrors.emptyJsonFieldValueError(dataType)
+
+ case token =>
+ // We cannot parse this token based on the given data type. So, we throw a
+ // RuntimeException and this exception will be caught by `parse` method.
+ throw QueryExecutionErrors.cannotParseJSONFieldError(parser, token, dataType)
+ }
+
+ /**
+ * Parse an object from the token stream into a new Row representing the schema.
+ * Fields in the json that are not defined in the requested schema will be dropped.
+ */
+ private def convertObject(
+ parser: JsonParser,
+ schema: StructType,
+ fieldConverters: Array[ValueConverter],
+ structFilters: StructFilters = new NoopFilters(),
+ isRoot: Boolean = false): Option[InternalRow] = {
+ val row = new GenericInternalRow(schema.length)
+ var badRecordException: Option[Throwable] = None
+ var skipRow = false
+
+ structFilters.reset()
+ resetExistenceDefaultsBitmask(schema)
+ while (!skipRow && nextUntil(parser, JsonToken.END_OBJECT)) {
+ schema.getFieldIndex(parser.getCurrentName) match {
+ case Some(index) =>
+ try {
+ row.update(index, fieldConverters(index).apply(parser))
+ skipRow = structFilters.skipRow(row, index)
+ schema.existenceDefaultsBitmask(index) = false
+ } catch {
+ case e: SparkUpgradeException => throw e
+ case NonFatal(e) if isRoot || enablePartialResults =>
+ badRecordException = badRecordException.orElse(Some(e))
+ parser.skipChildren()
+ }
+ case None =>
+ parser.skipChildren()
+ }
+ }
+ if (skipRow) {
+ None
+ } else if (badRecordException.isEmpty) {
+ applyExistenceDefaultValuesToRow(schema, row)
+ Some(row)
+ } else {
+ throw PartialResultException(row, badRecordException.get)
+ }
+ }
+
+ /**
+ * Parse an object as a Map, preserving all fields.
+ */
+ private def convertMap(
+ parser: JsonParser,
+ fieldConverter: ValueConverter): MapData = {
+ val keys = ArrayBuffer.empty[UTF8String]
+ val values = ArrayBuffer.empty[Any]
+ var badRecordException: Option[Throwable] = None
+
+ while (nextUntil(parser, JsonToken.END_OBJECT)) {
+ keys += UTF8String.fromString(parser.getCurrentName)
+ try {
+ values += fieldConverter.apply(parser)
+ } catch {
+ case PartialResultException(row, cause) if enablePartialResults =>
+ badRecordException = badRecordException.orElse(Some(cause))
+ values += row
+ case NonFatal(e) if enablePartialResults =>
+ badRecordException = badRecordException.orElse(Some(e))
+ parser.skipChildren()
+ }
+ }
+
+ // The JSON map will never have null or duplicated map keys, so it's safe to create an
+ // ArrayBasedMapData directly here.
+ val mapData = ArrayBasedMapData(keys.toArray, values.toArray)
+
+ if (badRecordException.isEmpty) {
+ mapData
+ } else {
+ throw PartialResultException(InternalRow(mapData), badRecordException.get)
+ }
+ }
+
+ /**
+ * Parse a JSON array into an [[ArrayData]].
+ */
+ private def convertArray(
+ parser: JsonParser,
+ fieldConverter: ValueConverter,
+ isRoot: Boolean = false): ArrayData = {
+ val values = ArrayBuffer.empty[Any]
+ var badRecordException: Option[Throwable] = None
+
+ while (nextUntil(parser, JsonToken.END_ARRAY)) {
+ try {
+ val v = fieldConverter.apply(parser)
+ if (isRoot && v == null) throw QueryExecutionErrors.rootConverterReturnNullError()
+ values += v
+ } catch {
+ case PartialResultException(row, cause) if enablePartialResults =>
+ badRecordException = badRecordException.orElse(Some(cause))
+ values += row
+ }
+ }
+
+ val arrayData = new GenericArrayData(values.toArray)
+
+ if (badRecordException.isEmpty) {
+ arrayData
+ } else {
+ throw PartialResultException(InternalRow(arrayData), badRecordException.get)
+ }
+ }
+
+ /**
+ * Parse the JSON input to the set of [[InternalRow]]s.
+ *
+ * @param recordLiteral an optional function that will be used to generate
+ * the corrupt record text instead of record.toString
+ */
+ def parse[T](
+ record: T,
+ createParser: (JsonFactory, T) => JsonParser,
+ recordLiteral: T => UTF8String): Iterable[InternalRow] = {
+ try {
+ Utils.tryWithResource(createParser(factory, record)) { parser =>
+ // a null first token is equivalent to testing for input.trim.isEmpty
+ // but it works on any token stream and not just strings
+ parser.nextToken() match {
+ case null => None
+ case _ => rootConverter.apply(parser) match {
+ case null => throw QueryExecutionErrors.rootConverterReturnNullError()
+ case rows => rows.toSeq
+ }
+ }
+ }
+ } catch {
+ case e: SparkUpgradeException => throw e
+ case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) =>
+ // JSON parser currently doesn't support partial results for corrupted records.
+ // For such records, all fields other than the field configured by
+ // `columnNameOfCorruptRecord` are set to `null`.
+ throw BadRecordException(() => recordLiteral(record), () => None, e)
+ case e: CharConversionException if options.encoding.isEmpty =>
+ val msg =
+ """JSON parser cannot handle a character in its input.
+ |Specifying encoding as an input option explicitly might help to resolve the issue.
+ |""".stripMargin + e.getMessage
+ val wrappedCharException = new CharConversionException(msg)
+ wrappedCharException.initCause(e)
+ throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException)
+ case PartialResultException(row, cause) =>
+ throw BadRecordException(
+ record = () => recordLiteral(record),
+ partialResult = () => Some(row),
+ cause)
+ }
+ }
+}
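For orientation, here is a minimal, hypothetical sketch of driving the `parse` method above for a single JSON string. The `jacksonParser` instance (and the schema it was built with) is assumed to exist already; the two lambdas simply follow the `parse[T]` signature shown above, and `JsonFactory.createParser(String)` is standard Jackson API.

import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

val record = """{"name": "apple", "qty": 3}"""
val rows: Iterable[InternalRow] = jacksonParser.parse[String](
  record,
  (factory: JsonFactory, rec: String) => factory.createParser(rec), // how to turn the record into a JsonParser
  (rec: String) => UTF8String.fromString(rec))                      // used only to render the corrupt-record text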
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala
new file mode 100644
index 00000000..122800b0
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off
+
+package org.apache.spark.sql.arangodb.datasource.mapping.json
+
+import com.fasterxml.jackson.core.{JsonParser, JsonToken}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types._
+
+object JacksonUtils extends QueryErrorsBase {
+ /**
+ * Advance the parser until a null or a specific token is found
+ */
+ def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = {
+ parser.nextToken() match {
+ case null => false
+ case x => x != stopOn
+ }
+ }
+
+ def verifyType(name: String, dataType: DataType): TypeCheckResult = {
+ dataType match {
+ case NullType | _: AtomicType | CalendarIntervalType => TypeCheckSuccess
+
+ case st: StructType =>
+ st.foldLeft(TypeCheckSuccess: TypeCheckResult) { case (currResult, field) =>
+ if (currResult.isFailure) currResult else verifyType(field.name, field.dataType)
+ }
+
+ case at: ArrayType => verifyType(name, at.elementType)
+
+ // For MapType, keys are basically treated as strings (i.e. via `toString`) when
+ // generating JSON, so we only care whether the values are valid for JSON.
+ case mt: MapType => verifyType(name, mt.valueType)
+
+ case udt: UserDefinedType[_] => verifyType(name, udt.sqlType)
+
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "CANNOT_CONVERT_TO_JSON",
+ messageParameters = Map(
+ "name" -> toSQLId(name),
+ "type" -> toSQLType(dataType)))
+ }
+ }
+}
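To illustrate the `verifyType` rules above, a few hypothetical calls (the concrete types are examples chosen for illustration, not calls made by this connector):

import org.apache.spark.sql.types._

JacksonUtils.verifyType("m", MapType(IntegerType, StringType)) // TypeCheckSuccess: only map values are checked, keys are stringified
JacksonUtils.verifyType("i", CalendarIntervalType)             // TypeCheckSuccess
JacksonUtils.verifyType("o", ObjectType(classOf[Object]))      // DataTypeMismatch with subclass CANNOT_CONVERT_TO_JSON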
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonFilters.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonFilters.scala
new file mode 100644
index 00000000..1665787a
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonFilters.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off
+
+package org.apache.spark.sql.arangodb.datasource.mapping.json
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.{InternalRow, StructFilters}
+import org.apache.spark.sql.sources
+import org.apache.spark.sql.types.StructType
+
+/**
+ * The class provides an API for applying pushed-down source filters to rows with
+ * a struct schema parsed from JSON records. The class should be used in this way:
+ * 1. Before processing the next row, `JacksonParser` (parser for short) resets the internal
+ * state of `JsonFilters` by calling the `reset()` method.
+ * 2. The parser reads JSON fields one-by-one in streaming fashion. It converts an incoming
+ * field value to the desired type from the schema. After that, it sets the value in an instance
+ * of `InternalRow` at the position according to the schema. The order of parsed JSON fields can
+ * differ from the order in the schema.
+ * 3. For every JSON field of the top-level JSON object, the parser calls `skipRow`, passing
+ * an `InternalRow` in which some of the fields may already be set, and the position of the JSON
+ * field according to the schema.
+ * 3.1 `skipRow` finds the group of predicates that refer to this JSON field.
+ * 3.2 For each predicate of the group, `skipRow` decrements its reference counter.
+ * 3.2.1 If a predicate's reference counter becomes 0, all of the predicate's attributes have
+ * already been set in the internal row, and the predicate can be applied to it. `skipRow`
+ * invokes the predicate for the row.
+ * 3.3 `skipRow` applies predicates until one of them returns `false`. In that case, the method
+ * returns `true` to the parser.
+ * 3.4 If all predicates with a zero reference counter return `true`, the final result of
+ * the method is `false`, which tells the parser not to skip the row.
+ * 4. If the parser gets `true` from `JsonFilters.skipRow`, it must not call the method anymore
+ * for this internal row, and should go back to step 1.
+ *
+ * Besides the `StructFilters` assumptions, `JsonFilters` assumes that:
+ * - `skipRow()` can be called for any valid index of the struct fields,
+ * and in any order.
+ * - After `skipRow()` returns `true`, the internal state of `JsonFilters` can be inconsistent,
+ * so `skipRow()` must not be called for the current row anymore without `reset()`.
+ *
+ * @param pushedFilters The pushed down source filters. The filters should refer to
+ * the fields of the provided schema.
+ * @param schema The required schema of records from datasource files.
+ */
+class JsonFilters(pushedFilters: Seq[sources.Filter], schema: StructType)
+ extends StructFilters(pushedFilters, schema) {
+
+ /**
+ * Stateful JSON predicate that keeps track of its dependent references in the
+ * current row via `refCount`.
+ *
+ * @param predicate The predicate compiled from pushed down source filters.
+ * @param totalRefs The total number of references of all the filters from which the predicate
+ * was compiled.
+ */
+ case class JsonPredicate(predicate: BasePredicate, totalRefs: Int) {
+ // The current number of predicate references in the row that have not been set yet.
+ // When `refCount` reaches zero, all of the predicate's dependencies are set, and it can
+ // be applied to the row.
+ var refCount: Int = totalRefs
+
+ def reset(): Unit = {
+ refCount = totalRefs
+ }
+ }
+
+ // Predicates compiled from the pushed down filters. The predicates are grouped by their
+ // attributes. The i-th group contains predicates that refer to the i-th field of the given
+ // schema. A predicate can be placed into multiple groups if it has multiple attributes. For example:
+ // schema: i INTEGER, s STRING
+ // filters: IsNotNull("i"), AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc"))
+ // predicates:
+ // 0: Array(IsNotNull("i"), AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc")))
+ // 1: Array(AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc")))
+ private val predicates: Array[Array[JsonPredicate]] = {
+ val groupedPredicates = Array.fill(schema.length)(Array.empty[JsonPredicate])
+ val groupedByRefSet: Map[Set[String], JsonPredicate] = filters
+ // Group filters that have the same set of references. For example:
+ // IsNotNull("i") -> Set("i"), AlwaysTrue -> Set(),
+ // Or(EqualTo("i", 0), StringStartsWith("s", "abc")) -> Set("i", "s")
+ // By grouping filters we could avoid tracking their state of references in the
+ // current row separately.
+ .groupBy(_.references.toSet)
+ // Combine all filters from the same group with `And` because all filters must
+ // return `true` for the row not to be skipped. The result is compiled to a predicate.
+ .map { case (refSet, refsFilters) =>
+ (refSet, JsonPredicate(toPredicate(refsFilters), refSet.size))
+ }
+ // Apply predicates w/o references like `AlwaysTrue` and `AlwaysFalse` to all fields.
+ // We cannot set such predicates to a particular position because skipRow() can
+ // be invoked for any index due to unpredictable order of JSON fields in JSON records.
+ val withLiterals: Map[Set[String], JsonPredicate] = groupedByRefSet.map {
+ case (refSet, pred) if refSet.isEmpty =>
+ (schema.fields.map(_.name).toSet, pred.copy(totalRefs = 1))
+ case others => others
+ }
+ // Build a map where each key is a single field and each value is the seq of predicates that refer to it:
+ // "i" -> Seq(AlwaysTrue, IsNotNull("i"), Or(EqualTo("i", 0), StringStartsWith("s", "abc")))
+ // "s" -> Seq(AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc")))
+ val groupedByFields: Map[String, Seq[(String, JsonPredicate)]] = withLiterals.toSeq
+ .flatMap { case (refSet, pred) => refSet.map((_, pred)) }
+ .groupBy(_._1)
+ // Build the final array by converting keys of `groupedByFields` to their
+ // indexes in the provided schema.
+ groupedByFields.foreach { case (fieldName, fieldPredicates) =>
+ val fieldIndex = schema.fieldIndex(fieldName)
+ groupedPredicates(fieldIndex) = fieldPredicates.map(_._2).toArray
+ }
+ groupedPredicates
+ }
+
+ /**
+ * Applies the predicates (compiled filters) associated with the row field value
+ * at position `index`, but only those whose other dependencies are already
+ * set in the given row.
+ *
+ * Note: If the function returns `true`, the `refCount` of some predicates may not have been decremented.
+ *
+ * @param row The row with fully or partially set values.
+ * @param index The index of the value that has just been set.
+ * @return `true` if at least one of the applicable predicates (those whose dependent row values
+ * are all set) returns `false`; `false` if all applicable predicates return `true`.
+ */
+ def skipRow(row: InternalRow, index: Int): Boolean = {
+ assert(0 <= index && index < schema.fields.length,
+ s"The index $index is out of the valid range [0, ${schema.fields.length}). " +
+ s"It must point out to a field of the schema: ${schema.catalogString}.")
+ var skip = false
+ for (pred <- predicates(index) if !skip) {
+ pred.refCount -= 1
+ assert(pred.refCount >= 0,
+ s"Predicate reference counter cannot be negative but got ${pred.refCount}.")
+ skip = pred.refCount == 0 && !pred.predicate.eval(row)
+ }
+ skip
+ }
+
+ /**
+ * Reset states of all predicates by re-initializing reference counters.
+ */
+ override def reset(): Unit = predicates.foreach(_.foreach(_.reset))
+}
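A minimal usage sketch of the grouping and `skipRow` contract described above. This is hypothetical standalone code (it assumes a usable default `SQLConf`), and the schema, filters, and row values are invented for illustration:

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._

val schema  = StructType(Seq(StructField("i", IntegerType), StructField("s", StringType)))
val filters = Seq(sources.IsNotNull("i"), sources.EqualTo("i", 0))
val jsonFilters = new JsonFilters(filters, schema)

jsonFilters.reset()                        // step 1: reset per-row state
val row = new GenericInternalRow(2)
row.update(0, 1)                           // field "i" parsed as 1
// Both filters reference only "i", so they are combined into a single predicate with refCount = 1;
// setting index 0 drives refCount to 0, the predicate is evaluated, and EqualTo("i", 0) fails.
val skip = jsonFilters.skipRow(row, 0)     // true: the parser may drop this row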
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonInferSchema.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonInferSchema.scala
new file mode 100644
index 00000000..d2c70f71
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonInferSchema.scala
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off
+
+package org.apache.spark.sql.arangodb.datasource.mapping.json
+
+import com.fasterxml.jackson.core._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+import java.io.CharConversionException
+import java.nio.charset.MalformedInputException
+import java.util.Comparator
+import scala.util.control.Exception.allCatch
+
+private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
+
+ private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
+ private val timestampFormatter = TimestampFormatter(
+ options.timestampFormatInRead,
+ options.zoneId,
+ options.locale,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = true)
+ private val timestampNTZFormatter = TimestampFormatter(
+ options.timestampNTZFormatInRead,
+ options.zoneId,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = true,
+ forTimestampNTZ = true)
+
+ private def handleJsonErrorsByParseMode(parseMode: ParseMode,
+ columnNameOfCorruptRecord: String, e: Throwable): Option[StructType] = {
+ parseMode match {
+ case PermissiveMode =>
+ Some(StructType(Array(StructField(columnNameOfCorruptRecord, StringType))))
+ case DropMalformedMode =>
+ None
+ case FailFastMode =>
+ throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError(e)
+ }
+ }
+
+ /**
+ * Infer the type of a collection of json records in three stages:
+ * 1. Infer the type of each record
+ * 2. Merge types by choosing the lowest type necessary to cover equal keys
+ * 3. Replace any remaining null fields with string, the top type
+ */
+ def infer[T](
+ json: RDD[T],
+ createParser: (JsonFactory, T) => JsonParser): StructType = {
+ val parseMode = options.parseMode
+ val columnNameOfCorruptRecord = options.columnNameOfCorruptRecord
+
+ // In each RDD partition, perform schema inference on each row and merge afterwards.
+ val typeMerger = JsonInferSchema.compatibleRootType(columnNameOfCorruptRecord, parseMode)
+ val mergedTypesFromPartitions = json.mapPartitions { iter =>
+ val factory = options.buildJsonFactory()
+ iter.flatMap { row =>
+ try {
+ Utils.tryWithResource(createParser(factory, row)) { parser =>
+ parser.nextToken()
+ Some(inferField(parser))
+ }
+ } catch {
+ case e @ (_: RuntimeException | _: JsonProcessingException |
+ _: MalformedInputException) =>
+ handleJsonErrorsByParseMode(parseMode, columnNameOfCorruptRecord, e)
+ case e: CharConversionException if options.encoding.isEmpty =>
+ val msg =
+ """JSON parser cannot handle a character in its input.
+ |Specifying encoding as an input option explicitly might help to resolve the issue.
+ |""".stripMargin + e.getMessage
+ val wrappedCharException = new CharConversionException(msg)
+ wrappedCharException.initCause(e)
+ handleJsonErrorsByParseMode(parseMode, columnNameOfCorruptRecord, wrappedCharException)
+ }
+ }.reduceOption(typeMerger).iterator
+ }
+
+ // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running
+ // the fold functions in the scheduler event loop thread.
+ val existingConf = SQLConf.get
+ var rootType: DataType = StructType(Nil)
+ val foldPartition = (iter: Iterator[DataType]) => iter.fold(StructType(Nil))(typeMerger)
+ val mergeResult = (index: Int, taskResult: DataType) => {
+ rootType = SQLConf.withExistingConf(existingConf) {
+ typeMerger(rootType, taskResult)
+ }
+ }
+ json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult)
+
+ canonicalizeType(rootType, options)
+ .find(_.isInstanceOf[StructType])
+ // canonicalizeType erases all empty structs, including the only one we want to keep
+ .getOrElse(StructType(Nil)).asInstanceOf[StructType]
+ }
+
+ /**
+ * Infer the type of a json document from the parser's token stream
+ */
+ def inferField(parser: JsonParser): DataType = {
+ import com.fasterxml.jackson.core.JsonToken._
+ parser.getCurrentToken match {
+ case null | VALUE_NULL => NullType
+
+ case FIELD_NAME =>
+ parser.nextToken()
+ inferField(parser)
+
+ case VALUE_STRING if parser.getTextLength < 1 =>
+ // Zero length strings and nulls have special handling to deal
+ // with JSON generators that do not distinguish between the two.
+ // To accurately infer types for empty strings that are really
+ // meant to represent nulls we assume that the two are isomorphic
+ // but will defer treating null fields as strings until all the
+ // record fields' types have been combined.
+ NullType
+
+ case VALUE_STRING =>
+ val field = parser.getText
+ lazy val decimalTry = allCatch opt {
+ val bigDecimal = decimalParser(field)
+ DecimalType(bigDecimal.precision, bigDecimal.scale)
+ }
+ if (options.prefersDecimal && decimalTry.isDefined) {
+ decimalTry.get
+ } else if (options.inferTimestamp &&
+ timestampNTZFormatter.parseWithoutTimeZoneOptional(field, false).isDefined) {
+ SQLConf.get.timestampType
+ } else if (options.inferTimestamp &&
+ timestampFormatter.parseOptional(field).isDefined) {
+ TimestampType
+ } else {
+ StringType
+ }
+
+ case START_OBJECT =>
+ val builder = Array.newBuilder[StructField]
+ while (nextUntil(parser, END_OBJECT)) {
+ builder += StructField(
+ parser.getCurrentName,
+ inferField(parser),
+ nullable = true)
+ }
+ val fields: Array[StructField] = builder.result()
+ // Note: other code relies on this sorting for correctness, so don't remove it!
+ java.util.Arrays.sort(fields, JsonInferSchema.structFieldComparator)
+ StructType(fields)
+
+ case START_ARRAY =>
+ // If this JSON array is empty, we use NullType as a placeholder.
+ // If this array is not empty in other JSON objects, we can resolve
+ // the type as we pass through all JSON objects.
+ var elementType: DataType = NullType
+ while (nextUntil(parser, END_ARRAY)) {
+ elementType = JsonInferSchema.compatibleType(
+ elementType, inferField(parser))
+ }
+
+ ArrayType(elementType)
+
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if options.primitivesAsString => StringType
+
+ case (VALUE_TRUE | VALUE_FALSE) if options.primitivesAsString => StringType
+
+ case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+ import JsonParser.NumberType._
+ parser.getNumberType match {
+ // For Integer values, use LongType by default.
+ case INT | LONG => LongType
+ // Since we do not have a data type backed by BigInteger,
+ // when we see a Java BigInteger, we use DecimalType.
+ case BIG_INTEGER | BIG_DECIMAL =>
+ val v = parser.getDecimalValue
+ if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+ DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+ } else {
+ DoubleType
+ }
+ case FLOAT | DOUBLE if options.prefersDecimal =>
+ val v = parser.getDecimalValue
+ if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+ DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+ } else {
+ DoubleType
+ }
+ case FLOAT | DOUBLE =>
+ DoubleType
+ }
+
+ case VALUE_TRUE | VALUE_FALSE => BooleanType
+
+ case _ =>
+ throw QueryExecutionErrors.malformedJSONError()
+ }
+ }
+
+ /**
+ * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields,
+ * drops NullTypes or converts them to StringType based on provided options.
+ */
+ private[json] def canonicalizeType(
+ tpe: DataType, options: JSONOptions): Option[DataType] = tpe match {
+ case at: ArrayType =>
+ canonicalizeType(at.elementType, options)
+ .map(t => at.copy(elementType = t))
+
+ case StructType(fields) =>
+ val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f =>
+ canonicalizeType(f.dataType, options)
+ .map(t => f.copy(dataType = t))
+ }
+ // SPARK-8093: empty structs should be deleted
+ if (canonicalFields.isEmpty) {
+ None
+ } else {
+ Some(StructType(canonicalFields))
+ }
+
+ case NullType =>
+ if (options.dropFieldIfAllNull) {
+ None
+ } else {
+ Some(StringType)
+ }
+
+ case other => Some(other)
+ }
+}
+
+object JsonInferSchema {
+ val structFieldComparator = new Comparator[StructField] {
+ override def compare(o1: StructField, o2: StructField): Int = {
+ o1.name.compareTo(o2.name)
+ }
+ }
+
+ def isSorted(arr: Array[StructField]): Boolean = {
+ var i: Int = 0
+ while (i < arr.length - 1) {
+ if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
+ return false
+ }
+ i += 1
+ }
+ true
+ }
+
+ def withCorruptField(
+ struct: StructType,
+ other: DataType,
+ columnNameOfCorruptRecords: String,
+ parseMode: ParseMode): StructType = parseMode match {
+ case PermissiveMode =>
+ // If we see any other data type at the root level, we get records that cannot be
+ // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+ if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
+ // If this given struct does not have a column used for corrupt records,
+ // add this field.
+ val newFields: Array[StructField] =
+ StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+ // Note: other code relies on this sorting for correctness, so don't remove it!
+ java.util.Arrays.sort(newFields, structFieldComparator)
+ StructType(newFields)
+ } else {
+ // Otherwise, just return this struct.
+ struct
+ }
+
+ case DropMalformedMode =>
+ // If corrupt record handling is disabled we retain the valid schema and discard the other.
+ struct
+
+ case FailFastMode =>
+ // If `other` is not a struct type, consider it malformed and throw an exception.
+ throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError(other)
+ }
+
+ /**
+ * Remove top-level ArrayType wrappers and merge the remaining schemas
+ */
+ def compatibleRootType(
+ columnNameOfCorruptRecords: String,
+ parseMode: ParseMode): (DataType, DataType) => DataType = {
+ // Since we support array of json objects at the top level,
+ // we need to check the element type and find the root level data type.
+ case (ArrayType(ty1, _), ty2) =>
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+ case (ty1, ArrayType(ty2, _)) =>
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+ // Discard null/empty documents
+ case (struct: StructType, NullType) => struct
+ case (NullType, struct: StructType) => struct
+ case (struct: StructType, o) if !o.isInstanceOf[StructType] =>
+ withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
+ case (o, struct: StructType) if !o.isInstanceOf[StructType] =>
+ withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
+ // If we get anything else, we call compatibleType.
+ // Usually, when we reach here, ty1 and ty2 are two StructTypes.
+ case (ty1, ty2) => compatibleType(ty1, ty2)
+ }
+
+ private[this] val emptyStructFieldArray = Array.empty[StructField]
+
+ /**
+ * Returns the most general data type for two given data types.
+ */
+ def compatibleType(t1: DataType, t2: DataType): DataType = {
+ TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+ // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+ (t1, t2) match {
+ // Double supports a larger range than fixed decimal; DecimalType.MAX_PRECISION should be enough
+ // in most cases, and it also has better precision.
+ case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+ DoubleType
+
+ case (t1: DecimalType, t2: DecimalType) =>
+ val scale = math.max(t1.scale, t2.scale)
+ val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+ if (range + scale > 38) {
+ // DecimalType can't support precision > 38
+ DoubleType
+ } else {
+ DecimalType(range + scale, scale)
+ }
+
+ case (StructType(fields1), StructType(fields2)) =>
+ // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
+ // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
+ // building a hash map or performing additional sorting.
+ assert(isSorted(fields1),
+ s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}")
+ assert(isSorted(fields2),
+ s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}")
+
+ val newFields = new java.util.ArrayList[StructField]()
+
+ var f1Idx = 0
+ var f2Idx = 0
+
+ while (f1Idx < fields1.length && f2Idx < fields2.length) {
+ val f1Name = fields1(f1Idx).name
+ val f2Name = fields2(f2Idx).name
+ val comp = f1Name.compareTo(f2Name)
+ if (comp == 0) {
+ val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+ newFields.add(StructField(f1Name, dataType, nullable = true))
+ f1Idx += 1
+ f2Idx += 1
+ } else if (comp < 0) { // f1Name < f2Name
+ newFields.add(fields1(f1Idx))
+ f1Idx += 1
+ } else { // f1Name > f2Name
+ newFields.add(fields2(f2Idx))
+ f2Idx += 1
+ }
+ }
+ while (f1Idx < fields1.length) {
+ newFields.add(fields1(f1Idx))
+ f1Idx += 1
+ }
+ while (f2Idx < fields2.length) {
+ newFields.add(fields2(f2Idx))
+ f2Idx += 1
+ }
+ StructType(newFields.toArray(emptyStructFieldArray))
+
+ case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+ ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+
+ // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+ // `findTightestCommonType`. Both cases below will be executed only when the given
+ // `DecimalType` is not capable of the given `IntegralType`.
+ case (t1: IntegralType, t2: DecimalType) =>
+ compatibleType(DecimalType.forType(t1), t2)
+ case (t1: DecimalType, t2: IntegralType) =>
+ compatibleType(t1, DecimalType.forType(t2))
+
+ case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
+ TimestampType
+
+ // Fall back to StringType for everything else: any JSON value can be represented as a string.
+ case (_, _) => StringType
+ }
+ }
+ }
+}
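A sketch of invoking schema inference on an RDD of JSON strings. This is hypothetical driver code: `spark` is assumed to be an existing SparkSession, and the `JSONOptions` constructor mirrors the one used in the `mapping` package object below.

import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.rdd.RDD

val options  = new JSONOptions(Map.empty[String, String], "UTC")
val inferrer = new JsonInferSchema(options)
val json: RDD[String] = spark.sparkContext.parallelize(Seq(
  """{"a": 1}""",
  """{"a": 2, "b": "x"}"""))

// Per the merge rules above, the expected result is roughly: struct<a: bigint, b: string>
val schema = inferrer.infer[String](json, (factory: JsonFactory, rec: String) => factory.createParser(rec))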
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala
new file mode 100644
index 00000000..90270213
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala
@@ -0,0 +1,14 @@
+package org.apache.spark.sql.arangodb.datasource
+
+import com.fasterxml.jackson.core.JsonFactory
+import org.apache.spark.sql.arangodb.commons.ArangoDBConf
+import org.apache.spark.sql.arangodb.datasource.mapping.json.JSONOptions
+
+package object mapping {
+ private[mapping] def createOptions(jsonFactory: JsonFactory, conf: ArangoDBConf) =
+ new JSONOptions(Map.empty[String, String], "UTC") {
+ override def buildJsonFactory(): JsonFactory = jsonFactory
+
+ override val ignoreNullFields: Boolean = conf.mappingOptions.ignoreNullFields
+ }
+}
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala
new file mode 100644
index 00000000..ce354e0e
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala
@@ -0,0 +1,15 @@
+package org.apache.spark.sql.arangodb.datasource.reader
+
+import org.apache.spark.sql.connector.read.InputPartition
+
+/**
+ * Partition corresponding to an Arango collection shard
+ * @param shardId collection shard id
+ * @param endpoint db endpoint to use to query the partition
+ */
+class ArangoCollectionPartition(val shardId: String, val endpoint: String) extends InputPartition
+
+/**
+ * Custom user queries will not be partitioned (e.g. AQL traversals)
+ */
+object SingletonPartition extends InputPartition
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala
new file mode 100644
index 00000000..d24ec187
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala
@@ -0,0 +1,65 @@
+package org.apache.spark.sql.arangodb.datasource.reader
+
+import com.arangodb.entity.CursorWarning
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
+import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
+import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.FailureSafeParser
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.types.StructType
+
+import scala.annotation.tailrec
+import scala.collection.JavaConverters.iterableAsScalaIterableConverter
+
+
+class ArangoCollectionPartitionReader(inputPartition: ArangoCollectionPartition, ctx: PushDownCtx, opts: ArangoDBConf)
+ extends PartitionReader[InternalRow] with Logging {
+
+ // override endpoints with partition endpoint
+ private val options = opts.updated(ArangoDBConf.ENDPOINTS, inputPartition.endpoint)
+ private val actualSchema = StructType(ctx.requiredSchema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord))
+ private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options)
+ private val safeParser = new FailureSafeParser[Array[Byte]](
+ parser.parse,
+ options.readOptions.parseMode,
+ ctx.requiredSchema,
+ options.readOptions.columnNameOfCorruptRecord)
+ private val client = ArangoClient(options)
+ private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx.filters, actualSchema)
+
+ var rowIterator: Iterator[InternalRow] = _
+
+ // warnings of non-stream AQL cursors are all returned along with the first batch
+ if (!options.readOptions.stream) logWarns()
+
+ @tailrec
+ final override def next: Boolean =
+ if (iterator.hasNext) {
+ val current = iterator.next()
+ rowIterator = safeParser.parse(current.get)
+ if (rowIterator.hasNext) {
+ true
+ } else {
+ next
+ }
+ } else {
+ // FIXME: https://arangodb.atlassian.net/browse/BTS-671
+ // stream AQL cursors' warnings are only returned along with the final batch
+ if (options.readOptions.stream) logWarns()
+ false
+ }
+
+ override def get: InternalRow = rowIterator.next()
+
+ override def close(): Unit = {
+ iterator.close()
+ client.shutdown()
+ }
+
+ private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) =>
+ logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}")
+ ))
+
+}
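The `next`/`get` pair above follows the standard DataSource V2 `PartitionReader` contract, which Spark drives roughly like the following framework-side sketch (not part of this connector; `readerFactory` and `partition` are assumed to come from the planning step):

val reader: PartitionReader[InternalRow] = readerFactory.createReader(partition)
try {
  while (reader.next) {
    val row: InternalRow = reader.get
    // ... hand the row to the downstream operators ...
  }
} finally {
  reader.close()
}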
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala
new file mode 100644
index 00000000..feacc04b
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala
@@ -0,0 +1,13 @@
+package org.apache.spark.sql.arangodb.datasource.reader
+
+import org.apache.spark.sql.arangodb.commons.ArangoDBConf
+import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+
+class ArangoPartitionReaderFactory(ctx: PushDownCtx, options: ArangoDBConf) extends PartitionReaderFactory {
+ override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match {
+ case p: ArangoCollectionPartition => new ArangoCollectionPartitionReader(p, ctx, options)
+ case SingletonPartition => new ArangoQueryReader(ctx.requiredSchema, options)
+ }
+}
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala
new file mode 100644
index 00000000..8d975f59
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala
@@ -0,0 +1,63 @@
+package org.apache.spark.sql.arangodb.datasource.reader
+
+import com.arangodb.entity.CursorWarning
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
+import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.FailureSafeParser
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.types._
+
+import scala.annotation.tailrec
+import scala.collection.JavaConverters.iterableAsScalaIterableConverter
+
+
+class ArangoQueryReader(schema: StructType, options: ArangoDBConf) extends PartitionReader[InternalRow] with Logging {
+
+ private val actualSchema = StructType(schema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord))
+ private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options)
+ private val safeParser = new FailureSafeParser[Array[Byte]](
+ parser.parse,
+ options.readOptions.parseMode,
+ schema,
+ options.readOptions.columnNameOfCorruptRecord)
+ private val client = ArangoClient(options)
+ private val iterator = client.readQuery()
+
+ var rowIterator: Iterator[InternalRow] = _
+
+ // warnings of non-stream AQL cursors are all returned along with the first batch
+ if (!options.readOptions.stream) logWarns()
+
+ @tailrec
+ final override def next: Boolean =
+ if (iterator.hasNext) {
+ val current = iterator.next()
+ rowIterator = safeParser.parse(current.get)
+ if (rowIterator.hasNext) {
+ true
+ } else {
+ next
+ }
+ } else {
+ // FIXME: https://arangodb.atlassian.net/browse/BTS-671
+ // stream AQL cursors' warnings are only returned along with the final batch
+ if (options.readOptions.stream) logWarns()
+ false
+ }
+
+ override def get: InternalRow = rowIterator.next()
+
+ override def close(): Unit = {
+ iterator.close()
+ client.shutdown()
+ }
+
+ private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) =>
+ logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}")
+ ))
+
+}
+
+
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala
new file mode 100644
index 00000000..3feedac5
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala
@@ -0,0 +1,28 @@
+package org.apache.spark.sql.arangodb.datasource.reader
+
+import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ReadMode}
+import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}
+import org.apache.spark.sql.types.StructType
+
+class ArangoScan(ctx: PushDownCtx, options: ArangoDBConf) extends Scan with Batch {
+ ExprUtils.verifyColumnNameOfCorruptRecord(ctx.requiredSchema, options.readOptions.columnNameOfCorruptRecord)
+
+ override def readSchema(): StructType = ctx.requiredSchema
+
+ override def toBatch: Batch = this
+
+ override def planInputPartitions(): Array[InputPartition] = options.readOptions.readMode match {
+ case ReadMode.Query => Array(SingletonPartition)
+ case ReadMode.Collection => planCollectionPartitions()
+ }
+
+ override def createReaderFactory(): PartitionReaderFactory = new ArangoPartitionReaderFactory(ctx, options)
+
+ private def planCollectionPartitions(): Array[InputPartition] =
+ ArangoClient.getCollectionShardIds(options)
+ .zip(Stream.continually(options.driverOptions.endpoints).flatten)
+ .map(it => new ArangoCollectionPartition(it._1, it._2))
+
+}
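The round-robin shard-to-endpoint assignment inside `planCollectionPartitions` can be illustrated in isolation; the shard ids and endpoints below are made up:

val shardIds  = Seq("s1001", "s1002", "s1003", "s1004", "s1005")
val endpoints = Seq("coordinator1:8529", "coordinator2:8529")
val partitions = shardIds.zip(Stream.continually(endpoints).flatten)
// Seq((s1001,coordinator1:8529), (s1002,coordinator2:8529), (s1003,coordinator1:8529),
//     (s1004,coordinator2:8529), (s1005,coordinator1:8529))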
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala
new file mode 100644
index 00000000..5b439438
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala
@@ -0,0 +1,66 @@
+package org.apache.spark.sql.arangodb.datasource.reader
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ReadMode}
+import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter}
+import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+class ArangoScanBuilder(options: ArangoDBConf, tableSchema: StructType) extends ScanBuilder
+ with SupportsPushDownFilters
+ with SupportsPushDownRequiredColumns
+ with Logging {
+
+ private var readSchema: StructType = _
+
+ // fully or partially applied filters
+ private var appliedPushableFilters: Array[PushableFilter] = Array()
+ private var appliedSparkFilters: Array[Filter] = Array()
+
+ override def build(): Scan = new ArangoScan(new PushDownCtx(readSchema, appliedPushableFilters), options)
+
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ options.readOptions.readMode match {
+ case ReadMode.Collection => pushFiltersReadModeCollection(filters)
+ case ReadMode.Query => filters
+ }
+ }
+
+ private def pushFiltersReadModeCollection(filters: Array[Filter]): Array[Filter] = {
+ // filters related to columnNameOfCorruptRecord are not pushed down
+ val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord)
+ val ignoredFilters = filters.filter(isCorruptRecordFilter)
+ val filtersBySupport = filters
+ .filterNot(isCorruptRecordFilter)
+ .map(f => (f, PushableFilter(f, tableSchema)))
+ .groupBy(_._2.support())
+
+ val fullSupp = filtersBySupport.getOrElse(FilterSupport.FULL, Array())
+ val partialSupp = filtersBySupport.getOrElse(FilterSupport.PARTIAL, Array())
+ val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()).map(_._1) ++ ignoredFilters
+
+ val appliedFilters = fullSupp ++ partialSupp
+ appliedPushableFilters = appliedFilters.map(_._2)
+ appliedSparkFilters = appliedFilters.map(_._1)
+
+ if (fullSupp.nonEmpty) {
+ logInfo(s"Filters fully applied in AQL:\n\t${fullSupp.map(_._1).mkString("\n\t")}")
+ }
+ if (partialSupp.nonEmpty) {
+ logInfo(s"Filters partially applied in AQL:\n\t${partialSupp.map(_._1).mkString("\n\t")}")
+ }
+ if (noneSupp.nonEmpty) {
+ logInfo(s"Filters not applied in AQL:\n\t${noneSupp.mkString("\n\t")}")
+ }
+
+ partialSupp.map(_._1) ++ noneSupp
+ }
+
+ override def pushedFilters(): Array[Filter] = appliedSparkFilters
+
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ this.readSchema = requiredSchema
+ }
+}
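To make the `pushFilters` contract above concrete: FULL-supported filters are evaluated only in AQL, PARTIAL ones both in AQL and again by Spark, and NONE ones only by Spark. A hypothetical sketch, assuming read mode 'collection', an `options`/`tableSchema` pair built elsewhere, and filters that reference fields of `tableSchema`:

import org.apache.spark.sql.sources

val builder = new ArangoScanBuilder(options, tableSchema)
builder.pruneColumns(tableSchema)

val toPush   = Array[sources.Filter](sources.IsNotNull("name"), sources.GreaterThan("age", 18))
val postScan = builder.pushFilters(toPush)  // filters Spark must still evaluate (PARTIAL + NONE support)
val pushed   = builder.pushedFilters()      // filters translated to AQL (FULL + PARTIAL support)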
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala
new file mode 100644
index 00000000..e7680f0e
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala
@@ -0,0 +1,30 @@
+package org.apache.spark.sql.arangodb.datasource.writer
+
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
+import org.apache.spark.sql.arangodb.commons.exceptions.DataWriteAbortException
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType
+
+class ArangoBatchWriter(schema: StructType, options: ArangoDBConf, mode: SaveMode) extends BatchWrite {
+
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
+ new ArangoDataWriterFactory(schema, options)
+
+ override def commit(messages: Array[WriterCommitMessage]): Unit = {
+ // nothing to do here
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {
+ val client = ArangoClient(options)
+ mode match {
+ case SaveMode.Append => throw new DataWriteAbortException(
+ "Cannot abort with SaveMode.Append: the underlying data source may require manual cleanup.")
+ case SaveMode.Overwrite => client.truncate()
+ case SaveMode.ErrorIfExists => ???
+ case SaveMode.Ignore => ???
+ }
+ client.shutdown()
+ }
+
+}
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriter.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriter.scala
new file mode 100644
index 00000000..b6577254
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriter.scala
@@ -0,0 +1,136 @@
+package org.apache.spark.sql.arangodb.datasource.writer
+
+import com.arangodb.{ArangoDBException, ArangoDBMultipleException}
+import com.arangodb.model.OverwriteMode
+import com.arangodb.util.RawBytes
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.arangodb.commons.exceptions.{ArangoDBDataWriterException, DataWriteAbortException}
+import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider}
+import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType
+
+import java.io.ByteArrayOutputStream
+import java.net.{ConnectException, UnknownHostException}
+import scala.annotation.tailrec
+import scala.collection.JavaConverters.iterableAsScalaIterableConverter
+import scala.util.Random
+
+class ArangoDataWriter(schema: StructType, options: ArangoDBConf, partitionId: Int)
+ extends DataWriter[InternalRow] with Logging {
+
+ private var failures = 0
+ private var exceptions: List[Exception] = List()
+ private var requestCount = 0L
+ private var endpointIdx = partitionId
+ private val endpoints = Stream.continually(options.driverOptions.endpoints).flatten
+ private val rnd = new Random()
+ private var client: ArangoClient = createClient()
+ private var batchCount: Int = _
+ private var outStream: ByteArrayOutputStream = _
+ private var vpackGenerator: ArangoGenerator = _
+
+ initBatch()
+
+ override def write(record: InternalRow): Unit = {
+ vpackGenerator.write(record)
+ vpackGenerator.flush()
+ batchCount += 1
+ if (batchCount == options.writeOptions.batchSize || outStream.size() > options.writeOptions.byteBatchSize) {
+ flushBatch()
+ initBatch()
+ }
+ }
+
+ override def commit(): WriterCommitMessage = {
+ flushBatch()
+ null // scalastyle:ignore null
+ }
+
+ /**
+ * Data cleanup will happen in [[ArangoBatchWriter.abort()]]
+ */
+ override def abort(): Unit = if (!canRetry) {
+ client.shutdown()
+ throw new DataWriteAbortException(
+ "Task cannot be retried. To make batch writes idempotent, so that they can be retried, consider using " +
+ "'keep.null=true' (default) and 'overwrite.mode=(ignore|replace|update)'.")
+ }
+
+ override def close(): Unit = {
+ client.shutdown()
+ }
+
+ private def createClient() = ArangoClient(options.updated(ArangoDBConf.ENDPOINTS, endpoints(endpointIdx)))
+
+ private def canRetry: Boolean = ArangoDataWriter.canRetry(schema, options)
+
+ private def initBatch(): Unit = {
+ batchCount = 0
+ outStream = new ByteArrayOutputStream()
+ vpackGenerator = ArangoGeneratorProvider().of(options.driverOptions.contentType, schema, outStream, options)
+ vpackGenerator.writeStartArray()
+ }
+
+ private def flushBatch(): Unit = {
+ vpackGenerator.writeEndArray()
+ vpackGenerator.close()
+ vpackGenerator.flush()
+ logDebug(s"flushBatch(), bufferSize: ${outStream.size()}")
+ saveDocuments(RawBytes.of(outStream.toByteArray))
+ }
+
+ @tailrec private def saveDocuments(payload: RawBytes): Unit = {
+ try {
+ requestCount += 1
+ logDebug(s"Sending request #$requestCount for partition $partitionId")
+ client.saveDocuments(payload)
+ logDebug(s"Received response #$requestCount for partition $partitionId")
+ failures = 0
+ exceptions = List()
+ } catch {
+ case e: Exception =>
+ client.shutdown()
+ failures += 1
+ exceptions = e :: exceptions
+ endpointIdx += 1
+ if ((canRetry || isConnectionException(e)) && failures < options.writeOptions.maxAttempts) {
+ val delay = computeDelay()
+ logWarning(s"Got exception while saving documents, retrying in $delay ms:", e)
+ Thread.sleep(delay)
+ client = createClient()
+ saveDocuments(payload)
+ } else {
+ throw new ArangoDBDataWriterException(exceptions.reverse.toArray)
+ }
+ }
+ }
+
+ private def computeDelay(): Int = {
+ val min = options.writeOptions.minRetryDelay
+ val max = options.writeOptions.maxRetryDelay
+ val diff = max - min
+ val delta = if (diff <= 0) 0 else rnd.nextInt(diff)
+ min + delta
+ }
+
+ private def isConnectionException(e: Throwable): Boolean = e match {
+ case ae: ArangoDBException => isConnectionException(ae.getCause)
+ case me: ArangoDBMultipleException => me.getExceptions.asScala.forall(isConnectionException)
+ case _: ConnectException => true
+ case _: UnknownHostException => true
+ case _ => false
+ }
+
+}
+
+object ArangoDataWriter {
+ def canRetry(schema: StructType, options: ArangoDBConf): Boolean =
+ schema.exists(p => p.name == "_key" && !p.nullable) && (options.writeOptions.overwriteMode match {
+ case OverwriteMode.ignore => true
+ case OverwriteMode.replace => true
+ case OverwriteMode.update => options.writeOptions.keepNull
+ case OverwriteMode.conflict => false
+ })
+}
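The retry rule in `ArangoDataWriter.canRetry` can be read as: the schema must carry a non-nullable `_key` column and the overwrite mode must make re-sending the same batch idempotent. A schema-level sketch (the `value` column is just an invented example):

import org.apache.spark.sql.types._

val retriableSchema = StructType(Seq(
  StructField("_key",  StringType, nullable = false), // re-sent documents overwrite themselves by key
  StructField("value", IntegerType)))

// With this schema, canRetry is true for overwrite mode ignore/replace,
// true for update only when keepNull is enabled, and false for conflict.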
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala
new file mode 100644
index 00000000..d4513fd7
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala
@@ -0,0 +1,12 @@
+package org.apache.spark.sql.arangodb.datasource.writer
+
+import org.apache.spark.sql.arangodb.commons.ArangoDBConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
+import org.apache.spark.sql.types.StructType
+
+class ArangoDataWriterFactory(schema: StructType, options: ArangoDBConf) extends DataWriterFactory {
+ override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+ new ArangoDataWriter(schema, options, partitionId)
+ }
+}
diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala
new file mode 100644
index 00000000..4336e88f
--- /dev/null
+++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala
@@ -0,0 +1,93 @@
+package org.apache.spark.sql.arangodb.datasource.writer
+
+import com.arangodb.entity.CollectionType
+import com.arangodb.model.OverwriteMode
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ContentType}
+import org.apache.spark.sql.connector.write.{BatchWrite, SupportsTruncate, WriteBuilder}
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
+import org.apache.spark.sql.{AnalysisException, SaveMode}
+
+class ArangoWriterBuilder(schema: StructType, options: ArangoDBConf)
+ extends WriteBuilder with SupportsTruncate with Logging {
+
+ private var mode: SaveMode = SaveMode.Append
+ validateConfig()
+
+ override def buildForBatch(): BatchWrite = {
+ val client = ArangoClient(options)
+ if (!client.collectionExists()) {
+ client.createCollection()
+ }
+ client.shutdown()
+
+ val updatedOptions = options.updated(ArangoDBConf.OVERWRITE_MODE, mode match {
+ case SaveMode.Append => options.writeOptions.overwriteMode.getValue
+ case _ => OverwriteMode.ignore.getValue
+ })
+
+ logSummary(updatedOptions)
+ new ArangoBatchWriter(schema, updatedOptions, mode)
+ }
+
+ override def truncate(): WriteBuilder = {
+ mode = SaveMode.Overwrite
+ if (options.writeOptions.confirmTruncate) {
+ val client = ArangoClient(options)
+ if (client.collectionExists()) {
+ client.truncate()
+ } else {
+ client.createCollection()
+ }
+ client.shutdown()
+ this
+ } else {
+ throw new AnalysisException(
+ "You are attempting to use overwrite mode which will truncate this collection prior to inserting data. If " +
+ "you just want to change data already in the collection set save mode 'append' and " +
+ s"'overwrite.mode=(replace|update)'. To actually truncate set '${ArangoDBConf.CONFIRM_TRUNCATE}=true'.")
+ }
+ }
+
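+ // Fails fast on configurations that cannot work: DecimalType fields cannot be written with
+ // contentType=json, and edge collections require non-nullable string "_from" and "_to" fields.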
+ private def validateConfig(): Unit = {
+ if (options.driverOptions.contentType == ContentType.JSON && hasDecimalTypeFields) {
+ throw new UnsupportedOperationException("Cannot write DecimalType when using contentType=json")
+ }
+
+ if (options.writeOptions.collectionType == CollectionType.EDGES &&
+ !schema.exists(p => p.name == "_from" && p.dataType == StringType && !p.nullable)
+ ) {
+ throw new IllegalArgumentException("Writing edge collection requires non nullable string field named _from.")
+ }
+
+ if (options.writeOptions.collectionType == CollectionType.EDGES &&
+ !schema.exists(p => p.name == "_to" && p.dataType == StringType && !p.nullable)
+ ) {
+ throw new IllegalArgumentException("Writing edge collection requires non nullable string field named _to.")
+ }
+ }
+
+ private def hasDecimalTypeFields: Boolean =
+ schema.existsRecursively {
+ case _: DecimalType => true
+ case _ => false
+ }
+
+ private def logSummary(updatedOptions: ArangoDBConf): Unit = {
+ val canRetry = ArangoDataWriter.canRetry(schema, updatedOptions)
+
+ logInfo(s"Using save mode: $mode")
+ logInfo(s"Using write configuration: ${updatedOptions.writeOptions}")
+ logInfo(s"Using mapping configuration: ${updatedOptions.mappingOptions}")
+ logInfo(s"Can retry: $canRetry")
+
+ if (!canRetry) {
+ logWarning(
+ """The provided configuration does not allow idempotent requests: write failures will not be retried and lead
+ |to task failure. Speculative task executions could fail or write incorrect data."""
+ .stripMargin.replaceAll("\n", "")
+ )
+ }
+ }
+
+}
diff --git a/bin/clean.sh b/bin/clean.sh
index be538923..9a8398b0 100755
--- a/bin/clean.sh
+++ b/bin/clean.sh
@@ -7,3 +7,5 @@ mvn clean -Pspark-3.2 -Pscala-2.12
mvn clean -Pspark-3.2 -Pscala-2.13
mvn clean -Pspark-3.3 -Pscala-2.12
mvn clean -Pspark-3.3 -Pscala-2.13
+mvn clean -Pspark-3.4 -Pscala-2.12
+mvn clean -Pspark-3.4 -Pscala-2.13
diff --git a/bin/test.sh b/bin/test.sh
index 498635c6..fe1f78be 100755
--- a/bin/test.sh
+++ b/bin/test.sh
@@ -23,3 +23,9 @@ mvn test -Pspark-3.3 -Pscala-2.12
mvn clean -Pspark-3.3 -Pscala-2.13
mvn test -Pspark-3.3 -Pscala-2.13
+
+mvn clean -Pspark-3.4 -Pscala-2.12
+mvn test -Pspark-3.4 -Pscala-2.12
+
+mvn clean -Pspark-3.4 -Pscala-2.13
+mvn test -Pspark-3.4 -Pscala-2.13
diff --git a/demo/README.md b/demo/README.md
index 9aced3c6..b82a5ee8 100644
--- a/demo/README.md
+++ b/demo/README.md
@@ -79,7 +79,7 @@ docker run -it --rm \
-v $(pwd):/demo \
-v $(pwd)/docker/.ivy2:/opt/bitnami/spark/.ivy2 \
--network arangodb \
- docker.io/bitnami/spark:3.2.1 \
+ docker.io/bitnami/spark:3.2.4 \
./bin/spark-submit --master spark://spark-master:7077 \
--packages="com.arangodb:arangodb-spark-datasource-3.2_2.12:$ARANGO_SPARK_VERSION" \
--class Demo /demo/target/demo-$ARANGO_SPARK_VERSION.jar
diff --git a/demo/docker/start_spark_3.2.sh b/demo/docker/start_spark_3.2.sh
index c64ecad7..1dba1f70 100755
--- a/demo/docker/start_spark_3.2.sh
+++ b/demo/docker/start_spark_3.2.sh
@@ -9,7 +9,7 @@ docker run -d --network arangodb --ip 172.28.10.1 --name spark-master -h spark-m
-e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
-e SPARK_SSL_ENABLED=no \
-v $(pwd)/docker/import:/import \
- docker.io/bitnami/spark:3.2.1
+ docker.io/bitnami/spark:3.2.4
docker run -d --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 \
-e SPARK_MODE=worker \
@@ -21,7 +21,7 @@ docker run -d --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spar
-e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
-e SPARK_SSL_ENABLED=no \
-v $(pwd)/docker/import:/import \
- docker.io/bitnami/spark:3.2.1
+ docker.io/bitnami/spark:3.2.4
docker run -d --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 \
-e SPARK_MODE=worker \
@@ -33,7 +33,7 @@ docker run -d --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spar
-e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
-e SPARK_SSL_ENABLED=no \
-v $(pwd)/docker/import:/import \
- docker.io/bitnami/spark:3.2.1
+ docker.io/bitnami/spark:3.2.4
docker run -d --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 \
-e SPARK_MODE=worker \
@@ -45,4 +45,4 @@ docker run -d --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spar
-e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
-e SPARK_SSL_ENABLED=no \
-v $(pwd)/docker/import:/import \
- docker.io/bitnami/spark:3.2.1
+ docker.io/bitnami/spark:3.2.4
diff --git a/demo/pom.xml b/demo/pom.xml
index ceb0609c..e9a43656 100644
--- a/demo/pom.xml
+++ b/demo/pom.xml
@@ -55,21 +55,28 @@
<id>spark-3.2</id>
+ <activation>
+ <activeByDefault>true</activeByDefault>
+ </activation>
<properties>
- <spark.version>3.2.1</spark.version>
+ <spark.version>3.2.4</spark.version>
<spark.compat.version>3.2</spark.compat.version>
</properties>
</profile>
<profile>
<id>spark-3.3</id>
- <activation>
- <activeByDefault>true</activeByDefault>
- </activation>
<properties>
<spark.version>3.3.2</spark.version>
<spark.compat.version>3.3</spark.compat.version>
</properties>
</profile>
+ <profile>
+ <id>spark-3.4</id>
+ <properties>
+ <spark.version>3.4.0</spark.version>
+ <spark.compat.version>3.4</spark.compat.version>
+ </properties>
+ </profile>
diff --git a/docker/start_spark_2.4.sh b/docker/start_spark_2.4.sh
deleted file mode 100755
index 09999280..00000000
--- a/docker/start_spark_2.4.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-#!/bin/bash
-
-docker network create arangodb --subnet 172.28.0.0/16
-docker run --network arangodb --ip 172.28.10.1 --name spark-master -h spark-master -e ENABLE_INIT_DAEMON=false -d bde2020/spark-master:2.4.5-hadoop2.7
-docker run --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:2.4.5-hadoop2.7
-docker run --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:2.4.5-hadoop2.7
-docker run --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:2.4.5-hadoop2.7
diff --git a/docker/start_spark_3.1.sh b/docker/start_spark_3.1.sh
deleted file mode 100755
index e3c3bb06..00000000
--- a/docker/start_spark_3.1.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-#!/bin/bash
-
-docker network create arangodb --subnet 172.28.0.0/16
-docker run --network arangodb --ip 172.28.10.1 --name spark-master -h spark-master -e ENABLE_INIT_DAEMON=false -d bde2020/spark-master:3.1.1-hadoop3.2
-docker run --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:3.1.1-hadoop3.2
-docker run --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:3.1.1-hadoop3.2
-docker run --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:3.1.1-hadoop3.2
diff --git a/docker/stop.sh b/docker/stop.sh
deleted file mode 100755
index c3f9aa22..00000000
--- a/docker/stop.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/bin/bash
-
-docker exec adb /app/arangodb stop
-sleep 1
-docker rm -f \
- adb \
- spark-master \
- spark-worker-1 \
- spark-worker-2 \
- spark-worker-3
diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala
index 6fe0e22f..a56533eb 100644
--- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala
+++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala
@@ -62,6 +62,7 @@ class DeserializationCastTest extends BaseSparkTest {
def nullToIntegerCast(contentType: String): Unit = {
// FIXME: DE-599
assumeTrue(!SPARK_VERSION.startsWith("3.3"))
+ assumeTrue(!SPARK_VERSION.startsWith("3.4"))
doTestImplicitCast(
StructType(Array(StructField("a", IntegerType, nullable = false))),
@@ -76,6 +77,7 @@ class DeserializationCastTest extends BaseSparkTest {
def nullToDoubleCast(contentType: String): Unit = {
// FIXME: DE-599
assumeTrue(!SPARK_VERSION.startsWith("3.3"))
+ assumeTrue(!SPARK_VERSION.startsWith("3.4"))
doTestImplicitCast(
StructType(Array(StructField("a", DoubleType, nullable = false))),
@@ -90,6 +92,7 @@ class DeserializationCastTest extends BaseSparkTest {
def nullAsBoolean(contentType: String): Unit = {
// FIXME: DE-599
assumeTrue(!SPARK_VERSION.startsWith("3.3"))
+ assumeTrue(!SPARK_VERSION.startsWith("3.4"))
doTestImplicitCast(
StructType(Array(StructField("a", BooleanType, nullable = false))),
diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala
index 242b4651..6386ef5b 100644
--- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala
+++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala
@@ -79,7 +79,8 @@ class AbortTest extends BaseSparkTest {
ArangoDBConf.PROTOCOL -> protocol,
ArangoDBConf.CONTENT_TYPE -> contentType,
ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue,
- ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name
+ ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name,
+ ArangoDBConf.BATCH_SIZE -> "9"
))
.save()
})
@@ -113,7 +114,8 @@ class AbortTest extends BaseSparkTest {
ArangoDBConf.PROTOCOL -> protocol,
ArangoDBConf.CONTENT_TYPE -> contentType,
ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue,
- ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name
+ ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name,
+ ArangoDBConf.BATCH_SIZE -> "9"
))
.save()
})
@@ -150,7 +152,8 @@ class AbortTest extends BaseSparkTest {
ArangoDBConf.COLLECTION -> collectionName,
ArangoDBConf.PROTOCOL -> protocol,
ArangoDBConf.CONTENT_TYPE -> contentType,
- ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue
+ ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue,
+ ArangoDBConf.BATCH_SIZE -> "9"
))
.save()
})
@@ -181,14 +184,22 @@ class AbortTest extends BaseSparkTest {
ArangoDBConf.PROTOCOL -> protocol,
ArangoDBConf.CONTENT_TYPE -> contentType,
ArangoDBConf.CONFIRM_TRUNCATE -> "true",
- ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue
+ ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue,
+ ArangoDBConf.BATCH_SIZE -> "1"
))
.save()
})
assertThat(thrown).isInstanceOf(classOf[SparkException])
- assertThat(thrown.getCause.getCause).isInstanceOf(classOf[ArangoDBDataWriterException])
- val rootEx = thrown.getCause.getCause.getCause
+
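+ // Spark 3.4 wraps the data writer exception one level less deeply than earlier versions.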
+ val cause = if (SPARK_VERSION.startsWith("3.4")) {
+ thrown.getCause
+ } else {
+ thrown.getCause.getCause
+ }
+
+ assertThat(cause).isInstanceOf(classOf[ArangoDBDataWriterException])
+ val rootEx = cause.getCause
assertThat(rootEx).isInstanceOf(classOf[ArangoDBMultiException])
val errors = rootEx.asInstanceOf[ArangoDBMultiException].errors
assertThat(errors.length).isEqualTo(1)
diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala
index d06bd78e..3cd8c3f4 100644
--- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala
+++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala
@@ -58,7 +58,8 @@ class OverwriteModeTest extends BaseSparkTest {
ArangoDBConf.COLLECTION -> collectionName,
ArangoDBConf.PROTOCOL -> protocol,
ArangoDBConf.CONTENT_TYPE -> contentType,
- ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.conflict.getValue
+ ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.conflict.getValue,
+ ArangoDBConf.BATCH_SIZE -> "3"
))
.save()
})
diff --git a/pom.xml b/pom.xml
index 89ca2ba2..bc24f921 100644
--- a/pom.xml
+++ b/pom.xml
@@ -117,6 +117,15 @@
<jackson.vpack.version>4.1.0</jackson.vpack.version>
</properties>
</profile>
+ <profile>
+ <id>spark-3.4</id>
+ <properties>
+ <spark.version>3.4.0</spark.version>
+ <spark.compat.version>3.4</spark.compat.version>
+ <jackson.vpack.version>4.1.0</jackson.vpack.version>
+ </properties>
+ </profile>
<profile>
<id>no-deploy</id>