Cosmos DB Spark Connector - query optimizations for queries targeting logical partitions (#25889)

* Extending CosmosCatalog in Spark

Allows the following new options when creating containers:
- analyticalStorageTttlInSeconds - this is required to make the Catalog integration useful in Synapse
- partitionKeyVersion - previously the Catalog did not allow creating containers with partition key hash V2 (which is needed to support long partition keys and is a prerequisite for later supporting hierarchical partition keys)

Also allows retrieving more container properties, mapped into TBLPROPERTIES, when using the DESCRIBE TABLE EXTENDED command (a usage sketch follows below).

* Spark query perf improvement when only targeting single/few logical partitions

* Query optimizations without tests yet

* Adding unit test coverage

* Adding suffix to FeedRangeCacheInterval
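
A minimal sketch of how the new Catalog options described above could be used from a Spark session. The catalog registration and the partitionKeyPath/manualThroughput properties follow the connector's documented pattern; the account endpoint, key, database, container, and property values are placeholders, and the two new property names are taken verbatim from the description above, so their exact spelling should be verified against the connector reference.

// Hedged sketch - register the Cosmos Catalog and create a container with the new options.
// All endpoint/key/database/container names and property values are placeholders.
spark.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog")
spark.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", "https://<account>.documents.azure.com:443/")
spark.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", "<account-key>")

spark.sql(
  """CREATE TABLE IF NOT EXISTS cosmosCatalog.SampleDatabase.SampleContainer
    |USING cosmos.oltp
    |TBLPROPERTIES(
    |  partitionKeyPath = '/id',
    |  manualThroughput = '400',
    |  analyticalStorageTttlInSeconds = '-1', -- name as listed above; -1 keeps analytical store data indefinitely
    |  partitionKeyVersion = 'V2'             -- hash V2, needed for long partition keys; value shown is illustrative
    |)""".stripMargin)

// DESCRIBE TABLE EXTENDED now surfaces additional container properties as TBLPROPERTIES.
spark.sql("DESCRIBE TABLE EXTENDED cosmosCatalog.SampleDatabase.SampleContainer").show(truncate = false)
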
FabianMeiswinkel committed Dec 8, 2021
1 parent 933087a commit 84779c9
Showing 14 changed files with 314 additions and 50 deletions.
@@ -4,8 +4,10 @@
package com.azure.cosmos

import com.azure.cosmos.implementation.SparkBridgeImplementationInternal
import com.azure.cosmos.implementation.SparkBridgeImplementationInternal.rangeToNormalizedRange
import com.azure.cosmos.implementation.feedranges.FeedRangeEpkImpl
import com.azure.cosmos.implementation.routing.Range
import com.azure.cosmos.models.FeedRange
import com.azure.cosmos.spark.NormalizedRange

// scalastyle:off underscore.import
@@ -30,4 +32,20 @@ private[cosmos] object SparkBridgeInternal {
private[this] def toCosmosRange(range: NormalizedRange): Range[String] = {
new Range[String](range.min, range.max, true, false)
}

private[cosmos] def getCacheKeyForContainer(container: CosmosAsyncContainer): String = {
val database = container.getDatabase
s"${database.getClient.getServiceEndpoint}|${database.getId}|${container.getId}"
}

private[cosmos] def getNormalizedEffectiveRange
(
container: CosmosAsyncContainer,
feedRange: FeedRange
) : NormalizedRange = {

SparkBridgeImplementationInternal
.rangeToNormalizedRange(
container.getNormalizedEffectiveRange(feedRange).block)
}
}
@@ -3,16 +3,16 @@

package com.azure.cosmos.implementation

import com.azure.cosmos.CosmosClientBuilder
import com.azure.cosmos.{CosmosAsyncContainer, CosmosClientBuilder}
import com.azure.cosmos.implementation.ImplementationBridgeHelpers.CosmosClientBuilderHelper
import com.azure.cosmos.implementation.changefeed.implementation.{ChangeFeedState, ChangeFeedStateV1}
import com.azure.cosmos.implementation.feedranges.{FeedRangeContinuation, FeedRangeEpkImpl, FeedRangeInternal}
import com.azure.cosmos.implementation.query.CompositeContinuationToken
import com.azure.cosmos.implementation.routing.Range
import com.azure.cosmos.models.FeedRange
import com.azure.cosmos.models.{FeedRange, PartitionKey, PartitionKeyDefinition, SparkModelBridgeInternal}
import com.azure.cosmos.spark.NormalizedRange

// scalastyle:off underscore.import
import com.azure.cosmos.implementation.feedranges._
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

@@ -94,7 +94,7 @@ private[cosmos] object SparkBridgeImplementationInternal {
.toArray
}

private[this] def rangeToNormalizedRange(rangeInput: Range[String]) = {
private[cosmos] def rangeToNormalizedRange(rangeInput: Range[String]) = {
val range = FeedRangeInternal.normalizeRange(rangeInput)
assert(range != null, "Argument 'range' must not be null.")
assert(range.isMinInclusive, "Argument 'range' must be minInclusive")
@@ -145,4 +145,18 @@ private[cosmos] object SparkBridgeImplementationInternal {
val epk = feedRange.asInstanceOf[FeedRangeEpkImpl]
rangeToNormalizedRange(epk.getRange)
}

private[cosmos] def partitionKeyValueToNormalizedRange
(
partitionKeyValue: Object,
partitionKeyDefinitionJson: String
): NormalizedRange = {

val feedRange = FeedRange
.forLogicalPartition(new PartitionKey(partitionKeyValue))
.asInstanceOf[FeedRangePartitionKeyImpl]

val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
rangeToNormalizedRange(feedRange.getEffectiveRange(pkDefinition))
}
}
@@ -7,4 +7,8 @@ private[cosmos] object SparkModelBridgeInternal {
def createIndexingPolicyFromJson(json: String): IndexingPolicy = {
new IndexingPolicy(json)
}

def createPartitionKeyDefinitionFromJson(json: String): PartitionKeyDefinition = {
new PartitionKeyDefinition(json)
}
}
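
A minimal sketch of the new helper added here, assuming the standard Cosmos DB partition-key definition JSON shape; "/tenantId" is an illustrative path, and the call only compiles from within the com.azure.cosmos packages because the bridge object is private[cosmos].

import com.azure.cosmos.models.{PartitionKeyDefinition, SparkModelBridgeInternal}

// Standard Cosmos DB partition-key definition JSON; the path is an illustrative placeholder.
val pkDefinitionJson = """{"paths":["/tenantId"],"kind":"Hash"}"""
val pkDefinition: PartitionKeyDefinition =
  SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(pkDefinitionJson)
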
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.{CosmosAsyncContainer, SparkBridgeInternal}
import com.azure.cosmos.models.FeedRange
import reactor.core.scala.publisher.SMono
import reactor.core.scala.publisher.SMono.PimpJMono

import java.time.Instant
import java.time.temporal.ChronoUnit
import scala.collection.concurrent.TrieMap
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

private[spark] object ContainerFeedRangesCache {

private val cache = new TrieMap[String, CachedFeedRanges]

def getFeedRanges
(
container: CosmosAsyncContainer
): SMono[List[FeedRange]] = {

val key = SparkBridgeInternal.getCacheKeyForContainer(container)

cache.get(key) match {
case Some(cached) =>
if (cached
.retrievedAt
.compareTo(Instant.now.minus(CosmosConstants.feedRangesCacheIntervalInMinutes, ChronoUnit.MINUTES)) >= 0) {

SMono.just(cached.feedRanges)
} else {
refreshFeedRanges(key, container)
}
case None => refreshFeedRanges(key, container)
}
}

private[this] def refreshFeedRanges(key: String, container: CosmosAsyncContainer): SMono[List[FeedRange]] = {
container
.getFeedRanges
.map[List[FeedRange]](javaList => {
val scalaList = javaList.asScala.toList
cache.put(key, CachedFeedRanges(scalaList, Instant.now))
scalaList
})
.asScala
}

private case class CachedFeedRanges(feedRanges: List[FeedRange], retrievedAt: Instant)
}
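
A minimal usage sketch for the cache above, mirroring how it is consumed further down in CosmosCatalog and CosmosPartitionPlanner; container stands for any existing CosmosAsyncContainer, and the call site has to live in the connector's com.azure.cosmos.spark package because the object is private[spark].

import com.azure.cosmos.CosmosAsyncContainer
import com.azure.cosmos.models.FeedRange

def currentFeedRanges(container: CosmosAsyncContainer): List[FeedRange] = {
  // Served from the per-container cache while the entry is younger than
  // CosmosConstants.feedRangesCacheIntervalInMinutes; otherwise refreshed from the service.
  ContainerFeedRangesCache
    .getFeedRanges(container)
    .block()
}
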
@@ -559,32 +559,27 @@ class CosmosCatalog
CosmosClientConfiguration(config, readConfig.forceEventualConsistency),
None,
s"CosmosCatalog(name $catalogName).tryGetContainerMetadata($databaseName, $containerName)"))
.to(cosmosClientCacheItem =>
.to(cosmosClientCacheItem => {

val container = cosmosClientCacheItem
.client
.getDatabase(databaseName)
.getContainer(containerName)

(
cosmosClientCacheItem
.client
.getDatabase(databaseName)
.getContainer(containerName)
container
.read()
.block()
.getProperties,

cosmosClientCacheItem
.client
.getDatabase(databaseName)
.getContainer(containerName)
.getFeedRanges
.block()
.asScala
.toList,
ContainerFeedRangesCache
.getFeedRanges(container)
.block(),

try {
Some(
(
cosmosClientCacheItem
.client
.getDatabase(databaseName)
.getContainer(containerName)
container
.readThroughput()
.block()
.getProperties,
@@ -599,9 +594,8 @@
try {
Some(
(
cosmosClientCacheItem
.client
.getDatabase(databaseName)
container
.getDatabase
.readThroughput()
.block()
.getProperties,
@@ -619,7 +613,7 @@
}
}
)
))
}))
} catch {
case e: CosmosException if isNotFound(e) =>
None
@@ -7,6 +7,7 @@ import com.azure.cosmos.implementation.routing.LocationHelper
import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosParameterizedQuery, FeedRange}
import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode
import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime}
import com.azure.cosmos.spark.CosmosPredicates.requireNotNullOrEmpty
import com.azure.cosmos.spark.ItemWriteStrategy.{ItemWriteStrategy, values}
import com.azure.cosmos.spark.PartitioningStrategies.PartitioningStrategy
import com.azure.cosmos.spark.SchemaConversionModes.SchemaConversionMode
@@ -22,6 +23,7 @@ import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant}
import java.util.{Locale, ServiceLoader}
import scala.collection.immutable.{HashSet, Map}
import scala.collection.mutable

// scalastyle:off underscore.import
import scala.collection.JavaConverters._
@@ -52,6 +54,7 @@ private object CosmosConfigNames {
val ReadInferSchemaQuery = "spark.cosmos.read.inferSchema.query"
val ReadPartitioningStrategy = "spark.cosmos.read.partitioning.strategy"
val ReadPartitioningTargetedCount = "spark.cosmos.partitioning.targetedCount"
val ReadPartitioningFeedRangeFilter = "spark.cosmos.partitioning.feedRangeFilter"
val ViewsRepositoryPath = "spark.cosmos.views.repositoryPath"
val DiagnosticsMode = "spark.cosmos.diagnostics"
val WriteBulkEnabled = "spark.cosmos.write.bulk.enabled"
@@ -98,6 +101,7 @@ private object CosmosConfigNames {
ReadInferSchemaQuery,
ReadPartitioningStrategy,
ReadPartitioningTargetedCount,
ReadPartitioningFeedRangeFilter,
ViewsRepositoryPath,
DiagnosticsMode,
WriteBulkEnabled,
@@ -137,7 +141,7 @@ private object CosmosConfig {
(
databaseName: Option[String],
containerName: Option[String],
sparkConf: SparkConf,
sparkConf: Option[SparkConf],
// spark application config
userProvidedOptions: Map[String, String] // user provided config
) : Map[String, String] = {
@@ -162,8 +166,13 @@
effectiveUserConfig += (CosmosContainerConfig.CONTAINER_NAME_KEY -> containerName.get)
}

val conf = sparkConf.clone()
val returnValue = conf.setAll(effectiveUserConfig.toMap).getAll.toMap
val returnValue = sparkConf match {
case Some(sparkConfig) => {
val conf = sparkConfig.clone()
conf.setAll(effectiveUserConfig.toMap).getAll.toMap
}
case None => effectiveUserConfig.toMap
}

returnValue.foreach((configProperty) => CosmosConfigNames.validateConfigName(configProperty._1))

@@ -185,7 +194,23 @@
getEffectiveConfig(
databaseName,
containerName,
session.sparkContext.getConf, // spark application config
Some(session.sparkContext.getConf), // spark application config
userProvidedOptions) // user provided config
}

def getEffectiveConfigIgnoringSessionConfig
(
databaseName: Option[String],
containerName: Option[String],
userProvidedOptions: Map[String, String] = Map().empty
) : Map[String, String] = {

// TODO: moderakh we should investigate how spark sql config should be merged:
// TODO: session.conf.getAll, // spark sql runtime config
getEffectiveConfig(
databaseName,
containerName,
None,
userProvidedOptions) // user provided config
}
}
@@ -644,7 +669,8 @@ private object PartitioningStrategies extends Enumeration {
private case class CosmosPartitioningConfig
(
partitioningStrategy: PartitioningStrategy,
targetedPartitionCount: Option[Int]
targetedPartitionCount: Option[Int],
feedRangeFiler: Option[Array[NormalizedRange]]
)

private object CosmosPartitioningConfig {
@@ -666,11 +692,32 @@ private object CosmosPartitioningConfig {
parseFromStringFunction = strategyNotYetParsed => CosmosConfigEntry.parseEnumeration(strategyNotYetParsed, PartitioningStrategies),
helpMessage = "The partitioning strategy used (Default, Custom, Restrictive or Aggressive)")

private val partitioningFeedRangeFilter = CosmosConfigEntry[Array[NormalizedRange]](
key = CosmosConfigNames.ReadPartitioningFeedRangeFilter,
defaultValue = None,
mandatory = false,
parseFromStringFunction = filter => {
requireNotNullOrEmpty(filter, CosmosConfigNames.ReadPartitioningFeedRangeFilter)

val epkRanges = mutable.Buffer[NormalizedRange]()
val fragments = filter.split(",")
for (fragment <- fragments) {
val minAndMax = fragment.trim.split("-")
epkRanges += (NormalizedRange(minAndMax(0), minAndMax(1)))
}

epkRanges.toArray
},
helpMessage = "The feed ranges this query should be scoped to")

def parseCosmosPartitioningConfig(cfg: Map[String, String]): CosmosPartitioningConfig = {
val partitioningStrategyParsed = CosmosConfigEntry
.parse(cfg, partitioningStrategy)
.getOrElse(DefaultPartitioningStrategy)

val partitioningFeedRangeFilterParsed = CosmosConfigEntry
.parse(cfg, partitioningFeedRangeFilter)

val targetedPartitionCountParsed = if (partitioningStrategyParsed == PartitioningStrategies.Custom) {
CosmosConfigEntry.parse(cfg, targetedPartitionCount)
} else {
@@ -679,7 +726,8 @@

CosmosPartitioningConfig(
partitioningStrategyParsed,
targetedPartitionCountParsed
targetedPartitionCountParsed,
partitioningFeedRangeFilterParsed
)
}
}
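
The new spark.cosmos.partitioning.feedRangeFilter value is parsed above as a comma-separated list of <min>-<max> normalized EPK pairs. A self-contained sketch of that format follows; NormalizedRange is redefined locally to keep the snippet standalone, and the hex bounds are illustrative.

// Self-contained sketch of the feedRangeFilter string format parsed above.
// NormalizedRange is redefined locally; the hex bounds are illustrative values.
case class NormalizedRange(min: String, max: String)

def parseFeedRangeFilter(filter: String): Array[NormalizedRange] = {
  filter
    .split(",")
    .map { fragment =>
      val minAndMax = fragment.trim.split("-")
      NormalizedRange(minAndMax(0), minAndMax(1))
    }
}

// Scopes the read to two EPK sub-ranges.
val ranges = parseFeedRangeFilter("05C1C9CD673398-05C1D9CD673398,05C1E399CD6732-FF")
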
@@ -17,6 +17,7 @@ private object CosmosConstants {
val maxRetryIntervalForTransientFailuresInMs = 5000
val maxRetryCountForTransientFailures = 100
val defaultDirectRequestTimeoutInSeconds = 10L
val feedRangesCacheIntervalInMinutes = 1

object Names {
val ItemsDataSourceShortName = "cosmos.oltp"
@@ -99,10 +99,10 @@ private object CosmosPartitionPlanner extends BasicLoggingTrait {

assertOnSparkDriver()
val lastContinuationTokens: ConcurrentMap[FeedRange, String] = new ConcurrentHashMap[FeedRange, String]()
container
.getFeedRanges
.asScala
.flatMapMany(feedRanges => SFlux.fromIterable(feedRanges.asScala))

ContainerFeedRangesCache
.getFeedRanges(container)
.flatMapMany(feedRanges => SFlux.fromIterable(feedRanges))
.flatMap(feedRange => {
val requestOptions = changeFeedConfig.toRequestOptions(feedRange)
requestOptions.setMaxItemCount(1)
@@ -462,11 +462,9 @@ private object CosmosPartitionPlanner extends BasicLoggingTrait {
val container = ThroughputControlHelper.getContainer(userConfig, cosmosContainerConfig, clientCacheItem.client)
SparkUtils.safeOpenConnectionInitCaches(container, (msg, e) => logWarning(msg, e))

container
.getFeedRanges
.asScala
ContainerFeedRangesCache
.getFeedRanges(container)
.map(feedRanges => feedRanges
.asScala
.map(feedRange => SparkBridgeImplementationInternal.toNormalizedRange(feedRange))
.toArray)
})
