In [0]:
import json

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Account and target settings. The REPLACEME values are placeholders that must be
# substituted with a real account endpoint and key before running the notebook.
cosmosEndpoint = "https://REPLACEME.documents.azure.com:443/"
cosmosMasterKey = "REPLACEME"
cosmosDatabaseName = "sampleDB"
cosmosContainerName = "sampleContainer"

# Register the Cosmos catalog under the name "cosmoscatalog" and hand it the
# account endpoint and key, so databases/containers can be managed via spark.sql.
for catalogConfKey, catalogConfValue in [
    ("spark.sql.catalog.cosmoscatalog", "com.azure.cosmos.spark.CosmosCatalog"),
    ("spark.sql.catalog.cosmoscatalog.spark.cosmos.accountEndpoint", cosmosEndpoint),
    ("spark.sql.catalog.cosmoscatalog.spark.cosmos.accountKey", cosmosMasterKey),
]:
    spark.conf.set(catalogConfKey, catalogConfValue)

# Connector options passed to the cosmos.oltp reads and writes in this notebook.
cfg = dict([
    ("spark.cosmos.accountEndpoint", cosmosEndpoint),
    ("spark.cosmos.accountKey", cosmosMasterKey),
    ("spark.cosmos.database", cosmosDatabaseName),
    ("spark.cosmos.container", cosmosContainerName),
    ("spark.cosmos.read.partitioning.strategy", "Restrictive"),
])

In [0]:
# Create the database through the catalog API (no-op when it already exists).
spark.sql(f"CREATE DATABASE IF NOT EXISTS cosmoscatalog.{cosmosDatabaseName};")

# Create the container through the catalog API with a three-level hierarchical
# partition key (/tenantId, /userId, /sessionId) and manual throughput of 1100.
createContainerSql = (
    f"CREATE TABLE IF NOT EXISTS cosmoscatalog.{cosmosDatabaseName}.{cosmosContainerName} "
    "using cosmos.oltp "
    "TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId,/sessionId', manualThroughput = '1100')"
)
spark.sql(createContainerSql)

# Ingest three sample documents: two for "tenant 1" and one for "tenant 2".
sampleColumns = ("id", "tenantId", "userId", "sessionId")
sampleRows = (
    ("id1", "tenant 1", "User 1", "session 1"),
    ("id2", "tenant 1", "User 1", "session 1"),
    ("id3", "tenant 2", "User 1", "session 1"),
)
sampleDf = spark.createDataFrame(sampleRows).toDF(*sampleColumns)
(
    sampleDf.write
    .format("cosmos.oltp")
    .options(**cfg)
    .mode("APPEND")
    .save()
)

In [0]:
# Filter on the first two levels of the hierarchy using only a custom query, i.e.
# without a feedRangeFilter. This is the less efficient variant because it goes
# through all physical partitions.
firstTwoLevelsQuery = "SELECT * from c where c.tenantId = 'tenant 1' and c.userId = 'User 1'"
query_df = (
    spark.read
    .format("cosmos.oltp")
    .options(**cfg)
    .option("spark.cosmos.read.customQuery", firstTwoLevelsQuery)
    .load()
)
query_df.show()

In [0]:
# Prepare a feed range that covers the first two levels of the partition key hierarchy.
# The connector ships a Java UDF for this; register it under a SQL-callable name.
spark.udf.registerJavaFunction("GetFeedRangeForPartitionKey", "com.azure.cosmos.spark.udf.GetFeedRangeForHierarchicalPartitionKeyValues", StringType())

# Build both JSON documents with json.dumps instead of hand-escaped string literals,
# so values containing quotes or backslashes are encoded correctly.
# The partition key definition mirrors the paths the container was created with.
pkDefinition = json.dumps(
    {"paths": ["/tenantId", "/userId", "/sessionId"], "kind": "MultiHash"},
    separators=(",", ":"),
)
# Only the leading levels to filter on are supplied (tenantId, userId).
pkValues = json.dumps(["tenant 1", "User 1"])

# Pass the JSON to the UDF as column values rather than interpolating it into the SQL
# text: a SQL string literal would break on an embedded single quote and Spark would
# unescape backslashes inside it, corrupting the JSON.
feedRangeDf = spark.createDataFrame(
    [(pkDefinition, pkValues)], ["pkDefinition", "pkValues"]
).selectExpr("GetFeedRangeForPartitionKey(pkDefinition, pkValues) AS feedRange")
feedRange = feedRangeDf.first()[0]

# Query the first two levels in the hierarchy using feedRangeFilter (will target the
# physical partition in which all sub-partitions are co-located).
query_df = (
    spark.read
    .format("cosmos.oltp")
    .options(**cfg)
    .option("spark.cosmos.partitioning.feedRangeFilter", feedRange)
    .load()
)
query_df.show()