1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -9,6 +9,7 @@ on:
types: [ opened, synchronize, reopened ]
branches:
- main
- devel

jobs:

23 changes: 18 additions & 5 deletions README.md
@@ -62,9 +62,9 @@ To use in external Spark cluster, submit your application with the following par
### SSL

To use TLS secured connections to ArangoDB, set `ssl.enabled` to `true` and either:
- provide a base64-encoded certificate as the `ssl.cert.value` configuration entry and optionally set `ssl.*` (see the sketch after this list), or
- start Spark driver and workers with a properly configured JVM default TrustStore, see
  [link](https://spark.apache.org/docs/latest/security.html#ssl-configuration)
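
As an illustration, a minimal sketch of the first option is shown below; it assumes a `SparkSession` named `spark`, a
data source registered as `com.arangodb.spark`, and a `table` option for the target collection (assumptions not shown
in this diff), while `ssl.enabled` and `ssl.cert.value` are the options described above:

```scala
import java.nio.file.{Files, Paths}
import java.util.Base64

// Base64-encode the CA certificate file content and hand it to the connector.
val caCertB64 = Base64.getEncoder.encodeToString(Files.readAllBytes(Paths.get("ca.crt")))

val usersDF = spark.read
  .format("com.arangodb.spark")
  .options(Map(
    "table" -> "users",
    "ssl.enabled" -> "true",
    "ssl.cert.value" -> caCertB64
  ))
  .load()
```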

### Supported deployment topologies

@@ -140,7 +140,15 @@ usersDF.filter(col("birthday") === "1982-12-15").show()
- `batch.size`: reading batch size, default `1000`
- `fill.cache`: whether the query should store the data it reads in the RocksDB block cache (`true`|`false`)
- `stream`: whether the query should be executed lazily, default `true`

- `mode`: allows a mode for dealing with corrupt records during parsing (a usage sketch follows this list):
  - `PERMISSIVE`: when a corrupted record is encountered, puts the malformed string into the field configured by
    `columnNameOfCorruptRecord` and sets the malformed fields to `null`. To keep corrupt records, the user can add a
    string type field named `columnNameOfCorruptRecord` to the user-defined schema. If the schema does not have this
    field, corrupt records are dropped during parsing. When inferring a schema, a `columnNameOfCorruptRecord` field is
    implicitly added to the output schema
  - `DROPMALFORMED`: ignores corrupted records entirely
  - `FAILFAST`: throws an exception when a corrupted record is encountered
- `columnNameOfCorruptRecord`: allows renaming the field holding the malformed string created by `PERMISSIVE` mode
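
A minimal read sketch combining these options, assuming a `SparkSession` named `spark` is in scope and that the data
source is registered as `com.arangodb.spark` with the target collection passed via the `table` option (both taken as
assumptions here, since they are not part of this diff):

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Corrupt documents land in the "_corrupt_record" column instead of failing the job.
// The corrupt-record column must be a nullable string field in the user-provided schema.
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("birthday", StringType),
  StructField("_corrupt_record", StringType, nullable = true)
))

val usersDF = spark.read
  .format("com.arangodb.spark")
  .schema(schema)
  .options(Map(
    "table" -> "users",
    "mode" -> "PERMISSIVE",
    "columnNameOfCorruptRecord" -> "_corrupt_record"
  ))
  .load()

// Inspect the documents that could not be parsed against the declared schema.
usersDF.filter(col("_corrupt_record").isNotNull).show()
```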

### Predicate and Projection Pushdown

@@ -255,7 +263,8 @@ fail. To make the job more resilient to temporary errors (i.e. connectivity pro
will be retried (with another coordinator) if the configured `overwrite.mode` allows for idempotent requests, namely:
- `replace`
- `ignore`
- `update` with `keep.null=true`

These configurations of `overwrite.mode` would also be compatible with speculative execution of tasks.
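
As an illustration, a write job that stays retry-safe could be configured as sketched below; the `com.arangodb.spark`
format name and the `table` option key are assumptions here, not shown in this diff:

```scala
import org.apache.spark.sql.{DataFrame, SaveMode}

// With overwrite.mode=replace the save request is idempotent, so a failed batch
// may be retried against another coordinator without creating duplicate documents.
def saveUsers(usersDF: DataFrame): Unit =
  usersDF.write
    .format("com.arangodb.spark")
    .mode(SaveMode.Append)
    .options(Map(
      "table" -> "users",
      "overwrite.mode" -> "replace"
    ))
    .save()
```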

A failing batch-saving request is retried at most once for every coordinator. After that, if still failing, the write
@@ -343,8 +352,12 @@ df.write

## Current limitations

- in Spark 2.4, on corrupted records in batch reading, partial results are not supported: all fields other than the
field configured by `columnNameOfCorruptRecord` are set to `null`
- in read jobs using `stream=true` (default), possible AQL warnings are only logged at the end of each read task (BTS-671)
- for `content-type=vpack`, implicit deserialization casts do not work well, e.g. reading a document that has a numeric
value in a field for which the read schema requires a string
- date and timestamp fields are interpreted as being in the UTC time zone

## Demo

@@ -12,10 +12,10 @@ import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.spark.internal.Logging
import org.apache.spark.sql.arangodb.commons.exceptions.ArangoDBMultiException
import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
import org.apache.spark.sql.arangodb.commons.filter.PushableFilter
import org.apache.spark.sql.types.StructType

import java.util
import scala.collection.JavaConverters.mapAsJavaMapConverter

class ArangoClient(options: ArangoOptions) extends Logging {

@@ -38,12 +38,12 @@ class ArangoClient(options: ArangoOptions) extends Logging {

def shutdown(): Unit = arangoDB.shutdown()

def readCollectionPartition(shardId: String, filters: Array[PushableFilter], schema: StructType): ArangoCursor[VPackSlice] = {
val query =
s"""
|FOR d IN @@col
|${PushdownUtils.generateFilterClause(filters)}
|RETURN ${PushdownUtils.generateColumnsFilter(schema, "d")}"""
.stripMargin
.replaceAll("\n", " ")
val params = Map[String, AnyRef]("@col" -> options.readOptions.collection.get)
@@ -65,7 +65,7 @@ class ArangoClient(options: ArangoOptions) extends Logging {
classOf[VPackSlice])
}

def readCollectionSample(): Seq[String] = {
val query = "FOR d IN @@col LIMIT @size RETURN d"
val params = Map(
"@col" -> options.readOptions.collection.get,
@@ -74,22 +74,28 @@ class ArangoClient(options: ArangoOptions) extends Logging {
.asInstanceOf[Map[String, AnyRef]]
val opts = aqlOptions()
logDebug(s"""Executing AQL query: \n\t$query ${if (params.nonEmpty) s"\n\t with params: $params" else ""}""")

import scala.collection.JavaConverters.iterableAsScalaIterableConverter
arangoDB
.db(options.readOptions.db)
.query(query, params.asJava, opts, classOf[String])
.asListRemaining()
.asScala
.toSeq
}

def readQuerySample(): Seq[String] = {
val query = options.readOptions.query.get
logDebug(s"Executing AQL query: \n\t$query")
val cursor = arangoDB
.db(options.readOptions.db)
.query(
query,
aqlOptions(),
classOf[String])

import scala.collection.JavaConverters.asScalaIteratorConverter
cursor.asScala.take(options.readOptions.sampleSize).toSeq
}

def collectionExists(): Boolean = arangoDB
@@ -137,6 +143,7 @@ class ArangoClient(options: ArangoOptions) extends Logging {
request.setBody(data)
val response = arangoDB.execute(request)

import scala.collection.JavaConverters.asScalaIteratorConverter
// FIXME
// in case there are no errors, response body is an empty object
// In cluster 3.8.1 this is not true due to: https://arangodb.atlassian.net/browse/BTS-592
@@ -22,20 +22,23 @@ package org.apache.spark.sql.arangodb.commons

import com.arangodb.{ArangoDB, entity}
import com.arangodb.model.OverwriteMode
import org.apache.spark.sql.catalyst.util.{ParseMode, PermissiveMode}

import java.io.ByteArrayInputStream
import java.security.KeyStore
import java.security.cert.CertificateFactory
import java.util
import java.util.{Base64, Locale}
import javax.net.ssl.{SSLContext, TrustManagerFactory}
import scala.collection.JavaConverters.mapAsScalaMapConverter


/**
* @author Michele Rastelli
*/
class ArangoOptions(opts: Map[String, String]) extends Serializable {
private val options: Map[String, String] = opts.map(e => (e._1.toLowerCase(Locale.US), e._2))

lazy val driverOptions: ArangoDriverOptions = new ArangoDriverOptions(options)
lazy val readOptions: ArangoReadOptions = new ArangoReadOptions(options)
lazy val writeOptions: ArangoWriteOptions = new ArangoWriteOptions(options)
@@ -85,6 +88,8 @@ object ArangoOptions {
val SAMPLE_SIZE = "sample.size"
val FILL_BLOCK_CACHE = "fill.cache"
val STREAM = "stream"
val PARSE_MODE = "mode"
val CORRUPT_RECORDS_COLUMN = "columnnameofcorruptrecord"

// write options
val NUMBER_OF_SHARDS = "table.shards"
@@ -175,6 +180,8 @@ class ArangoReadOptions(options: Map[String, String]) extends CommonOptions(opti
else throw new IllegalArgumentException("Either collection or query must be defined")
val fillBlockCache: Option[Boolean] = options.get(ArangoOptions.FILL_BLOCK_CACHE).map(_.toBoolean)
val stream: Boolean = options.getOrElse(ArangoOptions.STREAM, "true").toBoolean
val parseMode: ParseMode = options.get(ArangoOptions.PARSE_MODE).map(ParseMode.fromString).getOrElse(PermissiveMode)
val columnNameOfCorruptRecord: String = options.getOrElse(ArangoOptions.CORRUPT_RECORDS_COLUMN, "")
}

class ArangoWriteOptions(options: Map[String, String]) extends CommonOptions(options) {
@@ -1,6 +1,6 @@
package org.apache.spark.sql.arangodb.commons

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Encoders, SparkSession}

/**
@@ -17,10 +17,15 @@ object ArangoUtils {
client.shutdown()

val spark = SparkSession.getActiveSession.get
val schema = spark
.read
.json(spark.createDataset(sampleEntries)(Encoders.STRING))
.schema

if (options.readOptions.columnNameOfCorruptRecord.isEmpty)
schema
else
schema.add(StructField(options.readOptions.columnNameOfCorruptRecord, StringType, nullable = true))
}

}
@@ -42,5 +42,5 @@ class VPackArangoParser(schema: DataType)
extends ArangoParserImpl(
schema,
createOptions(new VPackFactory()),
(bytes: Array[Byte]) => UTF8String.fromString(new VPackParser.Builder().build().toJson(new VPackSlice(bytes), true))
)
@@ -1,15 +1,17 @@
package org.apache.spark.sql.arangodb.datasource.reader

import com.arangodb.entity.CursorEntity.Warning
import com.arangodb.velocypack.VPackSlice
import org.apache.spark.internal.Logging
import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ContentType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.FailureSafeParser
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
import org.apache.spark.sql.types.StructType

import java.nio.charset.StandardCharsets
import scala.annotation.tailrec
import scala.collection.JavaConverters.iterableAsScalaIterableConverter


@@ -21,30 +23,39 @@ class ArangoCollectionPartitionReader(

// override endpoints with partition endpoint
private val options = opts.updated(ArangoOptions.ENDPOINTS, inputPartition.endpoint)
private val actualSchema = StructType(ctx.requiredSchema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord))
private val parser = ArangoParserProvider().of(options.readOptions.contentType, actualSchema)
private val safeParser = new FailureSafeParser[Array[Byte]](
parser.parse(_).toSeq,
options.readOptions.parseMode,
ctx.requiredSchema,
options.readOptions.columnNameOfCorruptRecord)
private val client = ArangoClient(options)
private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx.filters, actualSchema)

var rowIterator: Iterator[InternalRow] = _

// warnings of non stream AQL cursors are all returned along with the first batch
if (!options.readOptions.stream) logWarns()

@tailrec
final override def next: Boolean =
if (iterator.hasNext) {
val current = iterator.next()
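// safeParser may yield no rows for a malformed document (e.g. in DROPMALFORMED mode): in that case recurse to the next document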
rowIterator = safeParser.parse(options.readOptions.contentType match {
case ContentType.VPack => current.toByteArray
case ContentType.Json => current.toString.getBytes(StandardCharsets.UTF_8)
})
if (rowIterator.hasNext) true
else next
} else {
// FIXME: https://arangodb.atlassian.net/browse/BTS-671
// stream AQL cursors' warnings are only returned along with the final batch
if (options.readOptions.stream) logWarns()
false
}

override def get: InternalRow = rowIterator.next()

override def close(): Unit = {
iterator.close()
@@ -1,13 +1,14 @@
package org.apache.spark.sql.arangodb.datasource.reader

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter}
import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoOptions, ReadMode}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.{StringType, StructType}

import java.util
import scala.collection.JavaConverters.seqAsJavaListConverter
@@ -17,6 +18,8 @@ class ArangoDataSourceReader(tableSchema: StructType, options: ArangoOptions) ex
with SupportsPushDownRequiredColumns
with Logging {

verifyColumnNameOfCorruptRecord(tableSchema, options.readOptions.columnNameOfCorruptRecord)

// fully or partially applied filters
private var appliedPushableFilters: Array[PushableFilter] = Array()
private var appliedSparkFilters: Array[Filter] = Array()
@@ -36,26 +39,30 @@ class ArangoDataSourceReader(tableSchema: StructType, options: ArangoOptions) ex
.map(it => new ArangoCollectionPartition(it._1, it._2, new PushDownCtx(readSchema(), appliedPushableFilters), options))

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
// filters related to columnNameOfCorruptRecord are not pushed down
val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord)
val ignoredFilters = filters.filter(isCorruptRecordFilter)
val filtersBySupport = filters
.filterNot(isCorruptRecordFilter)
.map(f => (f, PushableFilter(f, tableSchema)))
.groupBy(_._2.support())

val fullSupp = filtersBySupport.getOrElse(FilterSupport.FULL, Array())
val partialSupp = filtersBySupport.getOrElse(FilterSupport.PARTIAL, Array())
val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()).map(_._1) ++ ignoredFilters

val appliedFilters = fullSupp ++ partialSupp
appliedPushableFilters = appliedFilters.map(_._2)
appliedSparkFilters = appliedFilters.map(_._1)

if (fullSupp.nonEmpty)
logInfo(s"Fully supported filters (applied in AQL):\n\t${fullSupp.map(_._1).mkString("\n\t")}")
logInfo(s"Filters fully applied in AQL:\n\t${fullSupp.map(_._1).mkString("\n\t")}")
if (partialSupp.nonEmpty)
logInfo(s"Partially supported filters (applied in AQL and Spark):\n\t${partialSupp.map(_._1).mkString("\n\t")}")
logInfo(s"Filters partially applied in AQL:\n\t${partialSupp.map(_._1).mkString("\n\t")}")
if (noneSupp.nonEmpty)
logInfo(s"Not supported filters (applied in Spark):\n\t${noneSupp.map(_._1).mkString("\n\t")}")
logInfo(s"Filters not applied in AQL:\n\t${noneSupp.mkString("\n\t")}")

partialSupp.map(_._1) ++ noneSupp
}

override def pushedFilters(): Array[Filter] = appliedSparkFilters
@@ -64,4 +71,20 @@ class ArangoDataSourceReader(tableSchema: StructType, options: ArangoOptions) ex
this.requiredSchema = requiredSchema
}

/**
* A convenient function for schema validation in datasources supporting
* `columnNameOfCorruptRecord` as an option.
*/
private def verifyColumnNameOfCorruptRecord(
schema: StructType,
columnNameOfCorruptRecord: String): Unit = {
schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
val f = schema(corruptFieldIndex)
if (f.dataType != StringType || !f.nullable) {
throw new AnalysisException(
"The field for corrupt records must be string type and nullable")
}
}
}

}