@@ -17,6 +17,7 @@ import org.apache.spark.sql.types.StructType
 
 import scala.collection.JavaConverters.mapAsJavaMapConverter
 
+@SuppressWarnings(Array("OptionGet"))
 class ArangoClient(options: ArangoDBConf) extends Logging {
 
   private def aqlOptions(): AqlQueryOptions = {
@@ -68,11 +69,10 @@ class ArangoClient(options: ArangoDBConf) extends Logging {
 
   def readCollectionSample(): Seq[String] = {
     val query = "FOR d IN @@col LIMIT @size RETURN d"
-    val params = Map(
+    val params: Map[String, AnyRef] = Map(
       "@col" -> options.readOptions.collection.get,
-      "size" -> options.readOptions.sampleSize
+      "size" -> (options.readOptions.sampleSize: java.lang.Integer)
     )
-      .asInstanceOf[Map[String, AnyRef]]
     val opts = aqlOptions()
     logDebug(s"""Executing AQL query: \n\t$query ${if (params.nonEmpty) s"\n\t with params: $params" else ""}""")
 
@@ -147,7 +147,7 @@ class ArangoClient(options: ArangoDBConf) extends Logging {
     if (response.getBody.isArray) {
       val errors = response.getBody.arrayIterator.asScala
         .filter(it => it.get(ArangoResponseField.ERROR).isTrue)
-        .map(arangoDB.util().deserialize(_, classOf[ErrorEntity]).asInstanceOf[ErrorEntity])
+        .map(arangoDB.util().deserialize[ErrorEntity](_, classOf[ErrorEntity]))
         .toIterable
       if (errors.nonEmpty) {
         throw new ArangoDBMultiException(errors)
@@ -157,7 +157,10 @@ class ArangoClient(options: ArangoDBConf) extends Logging {
 }
 
 
+@SuppressWarnings(Array("OptionGet"))
 object ArangoClient {
+  private val INTERNAL_ERROR_CODE = 4
+  private val SHARDS_API_UNAVAILABLE_CODE = 9
 
   def apply(options: ArangoDBConf): ArangoClient = new ArangoClient(options)
 
@@ -175,8 +178,11 @@ object ArangoClient {
       case e: ArangoDBException =>
         // single server < 3.8 returns Response: 500, Error: 4 - internal error
        // single server >= 3.8 returns Response: 501, Error: 9 - shards API is only available in a cluster
-        if (e.getErrorNum == 9 || e.getErrorNum == 4) Array("")
-        else throw e
+        if (INTERNAL_ERROR_CODE.equals(e.getErrorNum) || SHARDS_API_UNAVAILABLE_CODE.equals(e.getErrorNum)) {
+          Array("")
+        } else {
+          throw e
+        }
     }
   }
 
@@ -185,11 +191,10 @@ object ArangoClient {
     val response = client.execute(new Request(ArangoRequestParam.SYSTEM, RequestType.GET, "/_api/cluster/endpoints"))
     val field = response.getBody.get("endpoints")
     val res = client.util(Serializer.CUSTOM)
-      .deserialize(field, classOf[Seq[Map[String, String]]])
-      .asInstanceOf[Seq[Map[String, String]]]
+      .deserialize[Seq[Map[String, String]]](field, classOf[Seq[Map[String, String]]])
       .map(it => it("endpoint").replaceFirst(".*://", ""))
     client.shutdown()
     res
   }
 
-}
\ No newline at end of file
+}
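
The `readCollectionSample` hunk above swaps a blanket `asInstanceOf[Map[String, AnyRef]]` cast for a declared type plus an explicit boxing ascription. A minimal standalone sketch (not part of the PR) of why the ascription is needed in Scala 2:

```scala
// Scala's Int is a value type and does not conform to AnyRef, so an unboxed Int
// cannot be stored in a Map[String, AnyRef]; the old .asInstanceOf cast merely
// silenced the compiler. The ascription below triggers Predef's int2Integer
// conversion, boxing the value to java.lang.Integer up front.
object BoxingSketch extends App {
  val sampleSize: Int = 1000

  // val bad: Map[String, AnyRef] = Map("size" -> sampleSize)
  // ...does not compile: "the result type of an implicit conversion must be
  // more specific than AnyRef"

  val params: Map[String, AnyRef] = Map("size" -> (sampleSize: java.lang.Integer))
  println(params("size").getClass) // class java.lang.Integer
}
```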
@@ -110,10 +110,11 @@ object ArangoDBConf {
     .createOptional
 
   val BATCH_SIZE = "batchSize"
+  val DEFAULT_BATCH_SIZE = 10000
   val batchSizeConf: ConfigEntry[Int] = ConfigBuilder(BATCH_SIZE)
     .doc("batch size")
     .intConf
-    .createWithDefault(10000)
+    .createWithDefault(DEFAULT_BATCH_SIZE)
 
   val QUERY = "query"
   val queryConf: OptionalConfigEntry[String] = ConfigBuilder(QUERY)
@@ -122,10 +123,11 @@ object ArangoDBConf {
     .createOptional
 
   val SAMPLE_SIZE = "sampleSize"
+  val DEFAULT_SAMPLE_SIZE = 1000
   val sampleSizeConf: ConfigEntry[Int] = ConfigBuilder(SAMPLE_SIZE)
     .doc("sample size prefetched for schema inference")
     .intConf
-    .createWithDefault(1000)
+    .createWithDefault(DEFAULT_SAMPLE_SIZE)
 
   val FILL_BLOCK_CACHE = "fillBlockCache"
   val fillBlockCacheConf: ConfigEntry[Boolean] = ConfigBuilder(FILL_BLOCK_CACHE)
@@ -153,10 +155,11 @@ object ArangoDBConf {
     .createOptional
 
   val NUMBER_OF_SHARDS = "table.shards"
+  val DEFAULT_NUMBER_OF_SHARDS = 1
   val numberOfShardsConf: ConfigEntry[Int] = ConfigBuilder(NUMBER_OF_SHARDS)
     .doc("number of shards of the created collection (in case of SaveMode Append or Overwrite)")
     .intConf
-    .createWithDefault(1)
+    .createWithDefault(DEFAULT_NUMBER_OF_SHARDS)
 
   val COLLECTION_TYPE = "table.type"
   val collectionTypeConf: ConfigEntry[String] = ConfigBuilder(COLLECTION_TYPE)
@@ -351,7 +354,7 @@ class ArangoDBConf(opts: Map[String, String]) extends Serializable with Logging
    */
   def getAllDefinedConfigs: Seq[(String, String, String)] =
     confEntries.values.filter(_.isPublic).map { entry =>
-      val displayValue = Option(getConfString(entry.key, null)).getOrElse(entry.defaultValueString)
+      val displayValue = settings.get(entry.key).getOrElse(entry.defaultValueString)
       (entry.key, displayValue, entry.doc)
     }.toSeq
 
@@ -443,12 +446,12 @@ class ArangoDBDriverConf(opts: Map[String, String]) extends ArangoDBConf(opts) {
         val is = new ByteArrayInputStream(Base64.getDecoder.decode(b64cert))
         val cert = CertificateFactory.getInstance(sslCertType).generateCertificate(is)
         val ks = KeyStore.getInstance(sslKeystoreType)
-        ks.load(null)
+        ks.load(null) // scalastyle:ignore null
         ks.setCertificateEntry(sslCertAlias, cert)
         val tmf = TrustManagerFactory.getInstance(sslAlgorithm)
         tmf.init(ks)
         val sc = SSLContext.getInstance(sslProtocol)
-        sc.init(null, tmf.getTrustManagers, null)
+        sc.init(null, tmf.getTrustManagers, null) // scalastyle:ignore null
         sc
       case None => SSLContext.getDefault
     }
@@ -480,9 +483,13 @@ class ArangoDBReadConf(opts: Map[String, String]) extends ArangoDBConf(opts) {
   val columnNameOfCorruptRecord: String = getConf(columnNameOfCorruptRecordConf).getOrElse("")
 
   val readMode: ReadMode =
-    if (query.isDefined) ReadMode.Query
-    else if (collection.isDefined) ReadMode.Collection
-    else throw new IllegalArgumentException("Either collection or query must be defined")
+    if (query.isDefined) {
+      ReadMode.Query
+    } else if (collection.isDefined) {
+      ReadMode.Collection
+    } else {
+      throw new IllegalArgumentException("Either collection or query must be defined")
+    }
 
 }
 
@@ -511,4 +518,4 @@ class ArangoDBWriteConf(opts: Map[String, String]) extends ArangoDBConf(opts) {
 
   val keepNull: Boolean = getConf(keepNullConf)
 
-}
\ No newline at end of file
+}
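
The `getAllDefinedConfigs` hunk above replaces a null-sentinel lookup with an `Option`-based one. A trimmed-down sketch of the same pattern — `Entry` and the `settings` map are hypothetical stand-ins for the connector's internals, not its API:

```scala
object ConfigDisplaySketch {
  // Hypothetical stand-in: the real ConfigEntry lives in ArangoDBConf; only the
  // lookup pattern is illustrated here.
  final case class Entry(key: String, defaultValueString: String, doc: String)

  def displayValue(settings: Map[String, String], entry: Entry): String =
    // before: Option(getConfString(entry.key, null)).getOrElse(entry.defaultValueString)
    // after: the explicitly set value, else the default, with no null sentinel involved
    settings.get(entry.key).getOrElse(entry.defaultValueString)
}
```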
@@ -22,10 +22,11 @@ object ArangoUtils {
       .json(spark.createDataset(sampleEntries)(Encoders.STRING))
       .schema
 
-    if (options.readOptions.columnNameOfCorruptRecord.isEmpty)
+    if (options.readOptions.columnNameOfCorruptRecord.isEmpty) {
       schema
-    else
+    } else {
       schema.add(StructField(options.readOptions.columnNameOfCorruptRecord, StringType, nullable = true))
+    }
   }
 
 }
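
The `ArangoUtils` hunk only adds braces, but the behavior it wraps is worth spelling out: when a corrupt-record column is configured, the schema inferred from the sampled JSON documents gains one extra nullable string field. A hedged sketch of that post-processing step, using Spark's public types API (the helper name is illustrative):

```scala
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object SchemaSketch {
  // Mirrors the branch above: an empty name means "no corrupt-record column",
  // otherwise the inferred schema is extended with a nullable string field that
  // Spark's JSON parser can populate for malformed documents.
  def withCorruptRecordColumn(schema: StructType, corruptCol: String): StructType =
    if (corruptCol.isEmpty) schema
    else schema.add(StructField(corruptCol, StringType, nullable = true))
}
```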
@@ -2,7 +2,7 @@ package org.apache.spark.sql.arangodb.commons.filter
 
 import org.apache.spark.sql.arangodb.commons.PushdownUtils.getStructField
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DateType, StringType, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataType, DateType, StringType, StructType, TimestampType}
 
 sealed trait PushableFilter extends Serializable {
   def support(): FilterSupport
@@ -11,6 +11,7 @@ sealed trait PushableFilter extends Serializable {
 }
 
 object PushableFilter {
+  // scalastyle:off cyclomatic.complexity
   def apply(filter: Filter, schema: StructType): PushableFilter = filter match {
     // @formatter:off
     case f: And => new AndFilter(apply(f.left, schema), apply(f.right, schema))
@@ -34,6 +35,7 @@ object PushableFilter {
     }
     // @formatter:on
   }
+  // scalastyle:on cyclomatic.complexity
 }
 
 private class OrFilter(parts: PushableFilter*) extends PushableFilter {
@@ -48,9 +50,13 @@ private class OrFilter(parts: PushableFilter*) extends PushableFilter {
    * +---------++---------+---------+------+
    */
   override def support(): FilterSupport =
-    if (parts.exists(_.support == FilterSupport.NONE)) FilterSupport.NONE
-    else if (parts.forall(_.support == FilterSupport.FULL)) FilterSupport.FULL
-    else FilterSupport.PARTIAL
+    if (parts.exists(_.support == FilterSupport.NONE)) {
+      FilterSupport.NONE
+    } else if (parts.forall(_.support == FilterSupport.FULL)) {
+      FilterSupport.FULL
+    } else {
+      FilterSupport.PARTIAL
+    }
 
   override def aql(v: String): String = parts
     .map(_.aql(v))
@@ -69,9 +75,13 @@ private class AndFilter(parts: PushableFilter*) extends PushableFilter {
    * +---------++---------+---------+---------+
    */
   override def support(): FilterSupport =
-    if (parts.forall(_.support == FilterSupport.NONE)) FilterSupport.NONE
-    else if (parts.forall(_.support == FilterSupport.FULL)) FilterSupport.FULL
-    else FilterSupport.PARTIAL
+    if (parts.forall(_.support == FilterSupport.NONE)) {
+      FilterSupport.NONE
+    } else if (parts.forall(_.support == FilterSupport.FULL)) {
+      FilterSupport.FULL
+    } else {
+      FilterSupport.PARTIAL
+    }
 
   override def aql(v: String): String = parts
     .filter(_.support() != FilterSupport.NONE)
@@ -91,8 +101,11 @@ private class NotFilter(child: PushableFilter) extends PushableFilter {
    * +---------++---------+
    */
   override def support(): FilterSupport =
-    if (child.support() == FilterSupport.FULL) FilterSupport.FULL
-    else FilterSupport.NONE
+    if (child.support() == FilterSupport.FULL) {
+      FilterSupport.FULL
+    } else {
+      FilterSupport.NONE
+    }
 
   override def aql(v: String): String = s"NOT (${child.aql(v)})"
 }
@@ -106,14 +119,14 @@ private class EqualToFilter(attribute: String, value: Any, schema: StructType) e
   override def support(): FilterSupport = dataType match {
     case _: DateType => FilterSupport.FULL
     case _: TimestampType => FilterSupport.PARTIAL // microseconds are ignored in AQL
-    case t if isTypeAqlCompatible(t) => FilterSupport.FULL
+    case t: DataType if isTypeAqlCompatible(t) => FilterSupport.FULL
     case _ => FilterSupport.NONE
   }
 
   override def aql(v: String): String = dataType match {
     case t: DateType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) == DATE_TIMESTAMP(${getValue(t, value)})"""
     case t: TimestampType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) == DATE_TIMESTAMP(${getValue(t, value)})"""
-    case t => s"""`$v`.$escapedFieldName == ${getValue(t, value)}"""
+    case t: DataType => s"""`$v`.$escapedFieldName == ${getValue(t, value)}"""
   }
 }
 
@@ -126,14 +139,14 @@ private class GreaterThanFilter(attribute: String, value: Any, schema: StructTyp
   override def support(): FilterSupport = dataType match {
     case _: DateType => FilterSupport.FULL
     case _: TimestampType => FilterSupport.PARTIAL // microseconds are ignored in AQL
-    case t if isTypeAqlCompatible(t) => FilterSupport.FULL
+    case t: DataType if isTypeAqlCompatible(t) => FilterSupport.FULL
     case _ => FilterSupport.NONE
   }
 
   override def aql(v: String): String = dataType match {
     case t: DateType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) > DATE_TIMESTAMP(${getValue(t, value)})"""
     case t: TimestampType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) >= DATE_TIMESTAMP(${getValue(t, value)})""" // microseconds are ignored in AQL
-    case t => s"""`$v`.$escapedFieldName > ${getValue(t, value)}"""
+    case t: DataType => s"""`$v`.$escapedFieldName > ${getValue(t, value)}"""
   }
 }
 
@@ -146,14 +159,14 @@ private class GreaterThanOrEqualFilter(attribute: String, value: Any, schema: St
   override def support(): FilterSupport = dataType match {
     case _: DateType => FilterSupport.FULL
     case _: TimestampType => FilterSupport.PARTIAL // microseconds are ignored in AQL
-    case t if isTypeAqlCompatible(t) => FilterSupport.FULL
+    case t: DataType if isTypeAqlCompatible(t) => FilterSupport.FULL
     case _ => FilterSupport.NONE
   }
 
   override def aql(v: String): String = dataType match {
     case t: DateType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) >= DATE_TIMESTAMP(${getValue(t, value)})"""
     case t: TimestampType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) >= DATE_TIMESTAMP(${getValue(t, value)})"""
-    case t => s"""`$v`.$escapedFieldName >= ${getValue(t, value)}"""
+    case t: DataType => s"""`$v`.$escapedFieldName >= ${getValue(t, value)}"""
   }
 }
 
@@ -166,14 +179,14 @@ private class LessThanFilter(attribute: String, value: Any, schema: StructType)
   override def support(): FilterSupport = dataType match {
     case _: DateType => FilterSupport.FULL
     case _: TimestampType => FilterSupport.PARTIAL // microseconds are ignored in AQL
-    case t if isTypeAqlCompatible(t) => FilterSupport.FULL
+    case t: DataType if isTypeAqlCompatible(t) => FilterSupport.FULL
     case _ => FilterSupport.NONE
   }
 
   override def aql(v: String): String = dataType match {
     case t: DateType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) < DATE_TIMESTAMP(${getValue(t, value)})"""
     case t: TimestampType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) <= DATE_TIMESTAMP(${getValue(t, value)})""" // microseconds are ignored in AQL
-    case t => s"""`$v`.$escapedFieldName < ${getValue(t, value)}"""
+    case t: DataType => s"""`$v`.$escapedFieldName < ${getValue(t, value)}"""
   }
 }
 
@@ -186,14 +199,14 @@ private class LessThanOrEqualFilter(attribute: String, value: Any, schema: Struc
   override def support(): FilterSupport = dataType match {
     case _: DateType => FilterSupport.FULL
     case _: TimestampType => FilterSupport.PARTIAL // microseconds are ignored in AQL
-    case t if isTypeAqlCompatible(t) => FilterSupport.FULL
+    case t: DataType if isTypeAqlCompatible(t) => FilterSupport.FULL
     case _ => FilterSupport.NONE
   }
 
   override def aql(v: String): String = dataType match {
     case t: DateType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) <= DATE_TIMESTAMP(${getValue(t, value)})"""
     case t: TimestampType => s"""DATE_TIMESTAMP(`$v`.$escapedFieldName) <= DATE_TIMESTAMP(${getValue(t, value)})"""
-    case t => s"""`$v`.$escapedFieldName <= ${getValue(t, value)}"""
+    case t: DataType => s"""`$v`.$escapedFieldName <= ${getValue(t, value)}"""
   }
 }
 
@@ -274,12 +287,16 @@ private class InFilter(attribute: String, values: Array[Any], schema: StructType
   override def support(): FilterSupport = dataType match {
     case _: DateType => FilterSupport.FULL
     case _: TimestampType => FilterSupport.PARTIAL // microseconds are ignored in AQL
-    case t if isTypeAqlCompatible(t) => FilterSupport.FULL
+    case t: DataType if isTypeAqlCompatible(t) => FilterSupport.FULL
     case _ => FilterSupport.NONE
   }
 
   override def aql(v: String): String = dataType match {
-    case _: TimestampType | DateType => s"""LENGTH([${values.map(getValue(dataType, _)).mkString(",")}][* FILTER DATE_TIMESTAMP(`$v`.$escapedFieldName) == DATE_TIMESTAMP(CURRENT)]) > 0"""
-    case _ => s"""LENGTH([${values.map(getValue(dataType, _)).mkString(",")}][* FILTER `$v`.$escapedFieldName == CURRENT]) > 0"""
+    case _: TimestampType | DateType => s"""LENGTH([${
+      values.map(getValue(dataType, _)).mkString(",")
+    }][* FILTER DATE_TIMESTAMP(`$v`.$escapedFieldName) == DATE_TIMESTAMP(CURRENT)]) > 0"""
+    case _ => s"""LENGTH([${
+      values.map(getValue(dataType, _)).mkString(",")
+    }][* FILTER `$v`.$escapedFieldName == CURRENT]) > 0"""
   }
 }
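
The `OrFilter`/`AndFilter`/`NotFilter` rewrites above are purely syntactic (braced branches for scalastyle), but the truth tables in their doc comments encode the core of the pushdown logic. A self-contained sketch of how the three support levels combine — `FilterSupport` and the object names here are stand-ins for the connector's own types:

```scala
sealed trait FilterSupport
object FilterSupport {
  case object FULL extends FilterSupport    // fully evaluable in AQL
  case object PARTIAL extends FilterSupport // pushed down, but re-checked by Spark
  case object NONE extends FilterSupport    // must be evaluated entirely by Spark
}

object SupportAlgebra {
  import FilterSupport._

  // OR is as weak as its weakest child: a single non-pushable branch could match
  // rows the pushed-down predicate would drop, so the whole disjunction stays in Spark.
  def orSupport(parts: Seq[FilterSupport]): FilterSupport =
    if (parts.contains(NONE)) NONE
    else if (parts.forall(_ == FULL)) FULL
    else PARTIAL

  // AND can drop non-pushable children (Spark re-applies them on top), so it
  // degrades to NONE only when every child is NONE.
  def andSupport(parts: Seq[FilterSupport]): FilterSupport =
    if (parts.forall(_ == NONE)) NONE
    else if (parts.forall(_ == FULL)) FULL
    else PARTIAL
}
```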
@@ -86,18 +86,18 @@ object CollectionType {
   case object DOCUMENT extends CollectionType {
     override val name: String = "document"
 
-    override def get() = entity.CollectionType.DOCUMENT
+    override def get(): entity.CollectionType = entity.CollectionType.DOCUMENT
   }
 
   case object EDGE extends CollectionType {
     override val name: String = "edge"
 
-    override def get() = entity.CollectionType.EDGES
+    override def get(): entity.CollectionType = entity.CollectionType.EDGES
   }
 
   def apply(value: String): CollectionType = value match {
     case DOCUMENT.name => DOCUMENT
     case EDGE.name => EDGE
     case _ => throw new IllegalArgumentException(s"${ArangoDBConf.COLLECTION_TYPE}: $value")
   }
-}
\ No newline at end of file
+}
@@ -10,4 +10,4 @@ class PushDownCtx(
     // filters to push down
     val filters: Array[PushableFilter]
   )
-  extends Serializable
\ No newline at end of file
+  extends Serializable