Cosmos DB Spark Connector - query optimizations for queries targeting logical partitions #25889

Merged
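
The changes below add a new read option, spark.cosmos.partitioning.feedRangeFilter, a per-container feed-range cache, and bridge helpers that translate a partition key value into its normalized EPK range, so that queries targeting a logical partition only plan Spark partitions for the matching feed range.

A minimal usage sketch; the endpoint, key, database and container values are placeholders, and only the feedRangeFilter option key and its "min-max" value format come from this PR:

// Hypothetical read scoped to one EPK sub-range ("cosmos.oltp" is the
// connector's data source short name, see CosmosConstants.Names below).
val df = spark.read
  .format("cosmos.oltp")
  .option("spark.cosmos.accountEndpoint", "https://<account>.documents.azure.com:443/")
  .option("spark.cosmos.accountKey", "<account-key>")
  .option("spark.cosmos.database", "SampleDB")
  .option("spark.cosmos.container", "SampleContainer")
  .option("spark.cosmos.partitioning.feedRangeFilter", "00-7F") // "min-max" EPK fragment
  .load()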
SparkBridgeInternal.scala
@@ -4,8 +4,10 @@
package com.azure.cosmos

import com.azure.cosmos.implementation.SparkBridgeImplementationInternal
import com.azure.cosmos.implementation.SparkBridgeImplementationInternal.rangeToNormalizedRange
import com.azure.cosmos.implementation.feedranges.FeedRangeEpkImpl
import com.azure.cosmos.implementation.routing.Range
import com.azure.cosmos.models.FeedRange
import com.azure.cosmos.spark.NormalizedRange

// scalastyle:off underscore.import
@@ -30,4 +32,20 @@ private[cosmos] object SparkBridgeInternal {
  private[this] def toCosmosRange(range: NormalizedRange): Range[String] = {
    new Range[String](range.min, range.max, true, false)
  }

  private[cosmos] def getCacheKeyForContainer(container: CosmosAsyncContainer): String = {
    val database = container.getDatabase
    s"${database.getClient.getServiceEndpoint}|${database.getId}|${container.getId}"
  }

  private[cosmos] def getNormalizedEffectiveRange
  (
    container: CosmosAsyncContainer,
    feedRange: FeedRange
  ) : NormalizedRange = {

    SparkBridgeImplementationInternal
      .rangeToNormalizedRange(
        container.getNormalizedEffectiveRange(feedRange).block)
  }
}
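
getCacheKeyForContainer scopes cached feed ranges per service endpoint, database and container; getNormalizedEffectiveRange resolves the EPK range backing an arbitrary feed range. A sketch with illustrative values (container and feedRange are assumed to exist):

// The normalized range for a feed range, via the implementation bridge:
val normalized: NormalizedRange =
  SparkBridgeInternal.getNormalizedEffectiveRange(container, feedRange)

// Shape of the cache key built above (illustrative endpoint and names):
// "https://contoso.documents.azure.com:443/|SampleDB|SampleContainer"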
SparkBridgeImplementationInternal.scala
@@ -3,16 +3,16 @@

package com.azure.cosmos.implementation

-import com.azure.cosmos.CosmosClientBuilder
+import com.azure.cosmos.{CosmosAsyncContainer, CosmosClientBuilder}
import com.azure.cosmos.implementation.ImplementationBridgeHelpers.CosmosClientBuilderHelper
import com.azure.cosmos.implementation.changefeed.implementation.{ChangeFeedState, ChangeFeedStateV1}
import com.azure.cosmos.implementation.feedranges.{FeedRangeContinuation, FeedRangeEpkImpl, FeedRangeInternal}
import com.azure.cosmos.implementation.query.CompositeContinuationToken
import com.azure.cosmos.implementation.routing.Range
-import com.azure.cosmos.models.FeedRange
+import com.azure.cosmos.models.{FeedRange, PartitionKey, PartitionKeyDefinition, SparkModelBridgeInternal}
import com.azure.cosmos.spark.NormalizedRange

// scalastyle:off underscore.import
import com.azure.cosmos.implementation.feedranges._
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

@@ -94,7 +94,7 @@ private[cosmos] object SparkBridgeInternal {
      .toArray
  }

-  private[this] def rangeToNormalizedRange(rangeInput: Range[String]) = {
+  private[cosmos] def rangeToNormalizedRange(rangeInput: Range[String]) = {
    val range = FeedRangeInternal.normalizeRange(rangeInput)
    assert(range != null, "Argument 'range' must not be null.")
    assert(range.isMinInclusive, "Argument 'range' must be minInclusive")
@@ -145,4 +145,18 @@ private[cosmos] object SparkBridgeImplementationInternal {
    val epk = feedRange.asInstanceOf[FeedRangeEpkImpl]
    rangeToNormalizedRange(epk.getRange)
  }

  private[cosmos] def partitionKeyValueToNormalizedRange
  (
    partitionKeyValue: Object,
    partitionKeyDefinitionJson: String
  ): NormalizedRange = {

    val feedRange = FeedRange
      .forLogicalPartition(new PartitionKey(partitionKeyValue))
      .asInstanceOf[FeedRangePartitionKeyImpl]

    val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
    rangeToNormalizedRange(feedRange.getEffectiveRange(pkDefinition))
  }
}
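
A usage sketch for the new helper; the partition key value, path and JSON are illustrative:

// Map the logical partition "tenant-42" to its normalized EPK sub-range,
// given the container's partition key definition as JSON.
val pkDefinitionJson = """{"paths":["/tenantId"],"kind":"Hash"}"""
val range: NormalizedRange =
  SparkBridgeImplementationInternal.partitionKeyValueToNormalizedRange(
    "tenant-42",
    pkDefinitionJson)
// range now bounds exactly the documents of that logical partition.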
SparkModelBridgeInternal.scala
@@ -7,4 +7,8 @@ private[cosmos] object SparkModelBridgeInternal {
  def createIndexingPolicyFromJson(json: String): IndexingPolicy = {
    new IndexingPolicy(json)
  }

  def createPartitionKeyDefinitionFromJson(json: String): PartitionKeyDefinition = {
    new PartitionKeyDefinition(json)
  }
}
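
The JSON accepted here is the standard Cosmos DB partition key definition shape; a sketch with an assumed /pk path:

// "kind":"Hash" is the common single-path case; containers created with
// large partition keys additionally carry "version":2.
val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(
  """{"paths":["/pk"],"kind":"Hash"}""")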
ContainerFeedRangesCache.scala (new file)
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.{CosmosAsyncContainer, SparkBridgeInternal}
import com.azure.cosmos.models.FeedRange
import reactor.core.scala.publisher.SMono
import reactor.core.scala.publisher.SMono.PimpJMono

import java.time.Instant
import java.time.temporal.ChronoUnit
import scala.collection.concurrent.TrieMap
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

private[spark] object ContainerFeedRangesCache {

  private val cache = new TrieMap[String, CachedFeedRanges]

  def getFeedRanges
  (
    container: CosmosAsyncContainer
  ): SMono[List[FeedRange]] = {

    val key = SparkBridgeInternal.getCacheKeyForContainer(container)

    cache.get(key) match {
      case Some(cached) =>
        if (cached
          .retrievedAt
          .compareTo(Instant.now.minus(CosmosConstants.feedRangesCacheIntervalInMinutes, ChronoUnit.MINUTES)) >= 0) {

          SMono.just(cached.feedRanges)
        } else {
          refreshFeedRanges(key, container)
        }
      case None => refreshFeedRanges(key, container)
    }
  }

  private[this] def refreshFeedRanges(key: String, container: CosmosAsyncContainer): SMono[List[FeedRange]] = {
    container
      .getFeedRanges
      .map[List[FeedRange]](javaList => {
        val scalaList = javaList.asScala.toList
        cache.put(key, CachedFeedRanges(scalaList, Instant.now))
        scalaList
      })
      .asScala
  }

  private case class CachedFeedRanges(feedRanges: List[FeedRange], retrievedAt: Instant)
}
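
Lookups are served from the TrieMap while the cached entry is younger than CosmosConstants.feedRangesCacheIntervalInMinutes (one minute, see below); older entries trigger a refresh against the service. A blocking usage sketch mirroring the call sites in this PR:

// A second call within a minute is answered from the cache
// without another getFeedRanges service call.
val feedRanges: List[FeedRange] =
  ContainerFeedRangesCache
    .getFeedRanges(container)
    .block()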
CosmosCatalog.scala
@@ -559,32 +559,27 @@ class CosmosCatalog
        CosmosClientConfiguration(config, readConfig.forceEventualConsistency),
        None,
        s"CosmosCatalog(name $catalogName).tryGetContainerMetadata($databaseName, $containerName)"))
-        .to(cosmosClientCacheItem =>
+        .to(cosmosClientCacheItem => {
+
+          val container = cosmosClientCacheItem
+            .client
+            .getDatabase(databaseName)
+            .getContainer(containerName)
+
           (
-            cosmosClientCacheItem
-              .client
-              .getDatabase(databaseName)
-              .getContainer(containerName)
+            container
              .read()
              .block()
              .getProperties,

-            cosmosClientCacheItem
-              .client
-              .getDatabase(databaseName)
-              .getContainer(containerName)
-              .getFeedRanges
-              .block()
-              .asScala
-              .toList,
+            ContainerFeedRangesCache
+              .getFeedRanges(container)
+              .block(),

            try {
              Some(
                (
-                  cosmosClientCacheItem
-                    .client
-                    .getDatabase(databaseName)
-                    .getContainer(containerName)
+                  container
                    .readThroughput()
                    .block()
                    .getProperties,
@@ -599,9 +594,8 @@ class CosmosCatalog
            try {
              Some(
                (
-                  cosmosClientCacheItem
-                    .client
-                    .getDatabase(databaseName)
+                  container
+                    .getDatabase
                    .readThroughput()
                    .block()
                    .getProperties,
@@ -619,7 +613,7 @@ class CosmosCatalog
          }
        }
      )
-      ))
+      }))
    } catch {
      case e: CosmosException if isNotFound(e) =>
        None
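
Net effect of this hunk: the container reference is resolved once and reused for the properties, feed-range and throughput lookups, with feed ranges now read through the cache. Condensed shape (illustrative; tuple construction and throughput fallbacks omitted):

val container = cosmosClientCacheItem.client
  .getDatabase(databaseName)
  .getContainer(containerName)
val properties = container.read().block().getProperties
val feedRanges = ContainerFeedRangesCache.getFeedRanges(container).block()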
CosmosConfig.scala
@@ -7,6 +7,7 @@ import com.azure.cosmos.implementation.routing.LocationHelper
import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosParameterizedQuery, FeedRange}
import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode
import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime}
import com.azure.cosmos.spark.CosmosPredicates.requireNotNullOrEmpty
import com.azure.cosmos.spark.ItemWriteStrategy.{ItemWriteStrategy, values}
import com.azure.cosmos.spark.PartitioningStrategies.PartitioningStrategy
import com.azure.cosmos.spark.SchemaConversionModes.SchemaConversionMode
@@ -22,6 +23,7 @@ import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant}
import java.util.{Locale, ServiceLoader}
import scala.collection.immutable.{HashSet, Map}
import scala.collection.mutable

// scalastyle:off underscore.import
import scala.collection.JavaConverters._
@@ -52,6 +54,7 @@ private object CosmosConfigNames {
  val ReadInferSchemaQuery = "spark.cosmos.read.inferSchema.query"
  val ReadPartitioningStrategy = "spark.cosmos.read.partitioning.strategy"
  val ReadPartitioningTargetedCount = "spark.cosmos.partitioning.targetedCount"
  val ReadPartitioningFeedRangeFilter = "spark.cosmos.partitioning.feedRangeFilter"
  val ViewsRepositoryPath = "spark.cosmos.views.repositoryPath"
  val DiagnosticsMode = "spark.cosmos.diagnostics"
  val WriteBulkEnabled = "spark.cosmos.write.bulk.enabled"
@@ -98,6 +101,7 @@
    ReadInferSchemaQuery,
    ReadPartitioningStrategy,
    ReadPartitioningTargetedCount,
    ReadPartitioningFeedRangeFilter,
    ViewsRepositoryPath,
    DiagnosticsMode,
    WriteBulkEnabled,
@@ -137,7 +141,7 @@ private object CosmosConfig {
  (
    databaseName: Option[String],
    containerName: Option[String],
-    sparkConf: SparkConf,
+    sparkConf: Option[SparkConf],
    // spark application config
    userProvidedOptions: Map[String, String] // user provided config
  ) : Map[String, String] = {
@@ -162,8 +166,13 @@
      effectiveUserConfig += (CosmosContainerConfig.CONTAINER_NAME_KEY -> containerName.get)
    }

-    val conf = sparkConf.clone()
-    val returnValue = conf.setAll(effectiveUserConfig.toMap).getAll.toMap
+    val returnValue = sparkConf match {
+      case Some(sparkConfig) => {
+        val conf = sparkConf.get.clone()
+        conf.setAll(effectiveUserConfig.toMap).getAll.toMap
+      }
+      case None => effectiveUserConfig.toMap
+    }

    returnValue.foreach((configProperty) => CosmosConfigNames.validateConfigName(configProperty._1))

@@ -185,7 +194,23 @@
    getEffectiveConfig(
      databaseName,
      containerName,
-      session.sparkContext.getConf, // spark application config
+      Some(session.sparkContext.getConf), // spark application config
      userProvidedOptions) // user provided config
  }

  def getEffectiveConfigIgnoringSessionConfig
  (
    databaseName: Option[String],
    containerName: Option[String],
    userProvidedOptions: Map[String, String] = Map().empty
  ) : Map[String, String] = {

    // TODO: moderakh we should investigate how spark sql config should be merged:
    // TODO: session.conf.getAll, // spark sql runtime config
    getEffectiveConfig(
      databaseName,
      containerName,
      None,
      userProvidedOptions) // user provided config
  }
}
@@ -644,7 +669,8 @@ private object PartitioningStrategies extends Enumeration {
private case class CosmosPartitioningConfig
(
  partitioningStrategy: PartitioningStrategy,
-  targetedPartitionCount: Option[Int]
+  targetedPartitionCount: Option[Int],
+  feedRangeFiler: Option[Array[NormalizedRange]]
)

private object CosmosPartitioningConfig {
@@ -666,11 +692,32 @@
    parseFromStringFunction = strategyNotYetParsed => CosmosConfigEntry.parseEnumeration(strategyNotYetParsed, PartitioningStrategies),
    helpMessage = "The partitioning strategy used (Default, Custom, Restrictive or Aggressive)")

  private val partitioningFeedRangeFilter = CosmosConfigEntry[Array[NormalizedRange]](
    key = CosmosConfigNames.ReadPartitioningFeedRangeFilter,
    defaultValue = None,
    mandatory = false,
    parseFromStringFunction = filter => {
      requireNotNullOrEmpty(filter, CosmosConfigNames.ReadPartitioningFeedRangeFilter)

      val epkRanges = mutable.Buffer[NormalizedRange]()
      val fragments = filter.split(",")
      for (fragment <- fragments) {
        val minAndMax = fragment.trim.split("-")
        epkRanges += (NormalizedRange(minAndMax(0), minAndMax(1)))
      }

      epkRanges.toArray
    },
    helpMessage = "The feed ranges this query should be scoped to")

  def parseCosmosPartitioningConfig(cfg: Map[String, String]): CosmosPartitioningConfig = {
    val partitioningStrategyParsed = CosmosConfigEntry
      .parse(cfg, partitioningStrategy)
      .getOrElse(DefaultPartitioningStrategy)

    val partitioningFeedRangeFilterParsed = CosmosConfigEntry
      .parse(cfg, partitioningFeedRangeFilter)

    val targetedPartitionCountParsed = if (partitioningStrategyParsed == PartitioningStrategies.Custom) {
      CosmosConfigEntry.parse(cfg, targetedPartitionCount)
    } else {
@@ -679,7 +726,8 @@

    CosmosPartitioningConfig(
      partitioningStrategyParsed,
-      targetedPartitionCountParsed
+      targetedPartitionCountParsed,
+      partitioningFeedRangeFilterParsed
    )
  }
}
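
Per the parseFromStringFunction above, the filter value is a comma-separated list of "min-max" EPK fragments. A worked sketch of the parsing (hex bounds are illustrative):

// Mirrors the parsing logic above:
val filter = "00-7F,7F-FF"
val ranges: Array[NormalizedRange] = filter.split(",").map { fragment =>
  val minAndMax = fragment.trim.split("-")
  NormalizedRange(minAndMax(0), minAndMax(1))
}
// ranges == Array(NormalizedRange("00", "7F"), NormalizedRange("7F", "FF"))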
CosmosConstants.scala
@@ -17,6 +17,7 @@ private object CosmosConstants {
  val maxRetryIntervalForTransientFailuresInMs = 5000
  val maxRetryCountForTransientFailures = 100
  val defaultDirectRequestTimeoutInSeconds = 10L
  val feedRangesCacheIntervalInMinutes = 1

  object Names {
    val ItemsDataSourceShortName = "cosmos.oltp"
CosmosPartitionPlanner.scala
@@ -99,10 +99,10 @@ private object CosmosPartitionPlanner extends BasicLoggingTrait {

    assertOnSparkDriver()
    val lastContinuationTokens: ConcurrentMap[FeedRange, String] = new ConcurrentHashMap[FeedRange, String]()
-    container
-      .getFeedRanges
-      .asScala
-      .flatMapMany(feedRanges => SFlux.fromIterable(feedRanges.asScala))
+
+    ContainerFeedRangesCache
+      .getFeedRanges(container)
+      .flatMapMany(feedRanges => SFlux.fromIterable(feedRanges))
      .flatMap(feedRange => {
        val requestOptions = changeFeedConfig.toRequestOptions(feedRange)
        requestOptions.setMaxItemCount(1)
@@ -462,11 +462,9 @@ private object CosmosPartitionPlanner extends BasicLoggingTrait {
        val container = ThroughputControlHelper.getContainer(userConfig, cosmosContainerConfig, clientCacheItem.client)
        SparkUtils.safeOpenConnectionInitCaches(container, (msg, e) => logWarning(msg, e))

-        container
-          .getFeedRanges
-          .asScala
+        ContainerFeedRangesCache
+          .getFeedRanges(container)
          .map(feedRanges => feedRanges
-            .asScala
            .map(feedRange => SparkBridgeImplementationInternal.toNormalizedRange(feedRange))
            .toArray)
      })