**Secrets**

The secrets used below, such as the Cosmos account key, are retrieved from a secret scope. If you have not yet defined a secret scope for the Cosmos account you want to use with this sample, the following links explain how to create one:
- [Create a new secret scope](./#secrets/createScope) for the current Databricks workspace
  - See how to create an [Azure Key Vault backed secret scope](https://docs.microsoft.com/azure/databricks/security/secrets/secret-scopes#--create-an-azure-key-vault-backed-secret-scope)
  - See how to create a [Databricks backed secret scope](https://docs.microsoft.com/azure/databricks/security/secrets/secret-scopes#create-a-databricks-backed-secret-scope)
- See how to [add secrets to your Spark configuration](https://docs.microsoft.com/azure/databricks/security/secrets/secrets#read-a-secret)

If you don't want to use secrets at all, you can also assign the values in clear text below, but we recommend using secrets so that credentials are not stored in the notebook.

In [0]:
cosmosEndpoint = spark.conf.get("spark.cosmos.accountEndpoint")
cosmosMasterKey = spark.conf.get("spark.cosmos.accountKey")

# A different account can be used for throughput control. If so, set its accountEndpoint and accountKey below.
# throughputControlEndpoint = spark.conf.get("spark.cosmos.throughputControlAccountEndpoint")
# throughputControlMasterKey = spark.conf.get("spark.cosmos.throughputControlAccountKey")
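
The `spark.conf.get` calls above only return values if the cluster's Spark configuration maps those keys to secrets. A minimal sketch of generating such configuration lines, assuming a hypothetical secret scope `cosmos-scope` with secrets named `cosmos-endpoint` and `cosmos-key` (these names are placeholders, not part of this sample):

```python
def spark_secret_conf_line(conf_key: str, scope: str, secret_name: str) -> str:
    """Return one Spark config line that references a Databricks secret.

    Uses the secret-reference syntax {{secrets/<scope>/<name>}}.
    """
    return f"{conf_key} {{{{secrets/{scope}/{secret_name}}}}}"


# Hypothetical scope and secret names; replace them with your own.
lines = [
    spark_secret_conf_line("spark.cosmos.accountEndpoint", "cosmos-scope", "cosmos-endpoint"),
    spark_secret_conf_line("spark.cosmos.accountKey", "cosmos-scope", "cosmos-key"),
]
print("\n".join(lines))
```

The printed lines are meant to be pasted into the cluster's Spark config, so the secret values themselves never appear in the notebook.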

**Preparation - creating the Cosmos DB container to ingest the data into**

Configure the Catalog API to be used for the main workload.

In [0]:
import uuid
spark.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog")
spark.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", cosmosEndpoint)
spark.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", cosmosMasterKey)
spark.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.views.repositoryPath", "/viewDefinitions" + str(uuid.uuid4()))


Configure the Catalog API to be used for throughput control. This is only needed if a different account is used for throughput control.

In [0]:
# import uuid
# spark.conf.set("spark.sql.catalog.throughputControlCatalog", "com.azure.cosmos.spark.CosmosCatalog")
# spark.conf.set("spark.sql.catalog.throughputControlCatalog.spark.cosmos.accountEndpoint", throughputControlEndpoint)
# spark.conf.set("spark.sql.catalog.throughputControlCatalog.spark.cosmos.accountKey", throughputControlMasterKey)
# spark.conf.set("spark.sql.catalog.throughputControlCatalog.spark.cosmos.views.repositoryPath", "/viewDefinitions" + str(uuid.uuid4()))


Execute the command to create the new container with a throughput of up to 100,000 RU (autoscale, so 10,000 - 100,000 RU based on scale) and only system properties (like /id) being indexed. The cell also creates `GreenTaxiRecordsCFSink` with the same settings and a `ThroughputControl` container that stores the metadata for global throughput control.

In [0]:
%sql
CREATE DATABASE IF NOT EXISTS cosmosCatalog.SampleDatabase;

CREATE TABLE IF NOT EXISTS cosmosCatalog.SampleDatabase.GreenTaxiRecords
USING cosmos.oltp
TBLPROPERTIES(partitionKeyPath = '/id', autoScaleMaxThroughput = '100000', indexingPolicy = 'OnlySystemProperties');

CREATE TABLE IF NOT EXISTS cosmosCatalog.SampleDatabase.GreenTaxiRecordsCFSink
USING cosmos.oltp
TBLPROPERTIES(partitionKeyPath = '/id', autoScaleMaxThroughput = '100000', indexingPolicy = 'OnlySystemProperties');

/* NOTE: TTL must be enabled on the throughput control container (the default TTL can be -1, so items only expire when they set their own ttl) */
/* If you are using a different account for throughput control, use the commented-out examples further below instead */
CREATE TABLE IF NOT EXISTS cosmosCatalog.SampleDatabase.ThroughputControl
USING cosmos.oltp
OPTIONS(spark.cosmos.database = 'SampleDatabase')
TBLPROPERTIES(partitionKeyPath = '/groupId', autoScaleMaxThroughput = '4000', indexingPolicy = 'AllProperties', defaultTtlInSeconds = '-1');

-- /* If you are using a different account for throughput control, use the throughput control catalog to initialize the containers */
-- CREATE DATABASE IF NOT EXISTS throughputControlCatalog.SampleDatabase;

-- CREATE TABLE IF NOT EXISTS throughputControlCatalog.SampleDatabase.ThroughputControl
-- USING cosmos.oltp
-- OPTIONS(spark.cosmos.database = 'SampleDatabase')
-- TBLPROPERTIES(partitionKeyPath = '/groupId', autoScaleMaxThroughput = '4000', indexingPolicy = 'AllProperties', defaultTtlInSeconds = '-1');

-- /* NOTE: Below instructions can be used to modify provisioned throughput 
-- - either via ALTER TABLE (when throughput is provisioned dedicated for container) 
-- or ALTER DATABASE (when throughput is provisioned at the database level across multiple containers) */

-- /* Container-level, auto-scale */
-- ALTER TABLE cosmosCatalog.SampleDatabase.GreenTaxiRecords SET TBLPROPERTIES(autoScaleMaxThroughput = '100000')

-- /* Container-level, manual throughput */
-- ALTER TABLE cosmosCatalog.SampleDatabase.GreenTaxiRecords SET TBLPROPERTIES(manualThroughput = '100000')

-- /* DB-level, auto-scale */
-- ALTER DATABASE cosmosCatalog.SampleDatabase SET DBPROPERTIES(autoScaleMaxThroughput = '100000')

-- /* DB-level, manual throughput */
-- ALTER DATABASE cosmosCatalog.SampleDatabase SET DBPROPERTIES(manualThroughput = '100000')
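
The autoscale range mentioned above (10,000 - 100,000 RU for a maximum of 100,000) follows from the minimum being 10% of the configured maximum. A small sketch of that arithmetic for the containers created in this cell; the helper name `autoscale_range` is illustrative only:

```python
from typing import Tuple


def autoscale_range(max_throughput: int) -> Tuple[int, int]:
    """Return the (minimum, maximum) RU/s an autoscale container scales between.

    Assumes the minimum is 10% of the configured autoScaleMaxThroughput.
    """
    return max_throughput // 10, max_throughput


for name, max_ru in {"GreenTaxiRecords": 100_000, "ThroughputControl": 4_000}.items():
    low, high = autoscale_range(max_ru)
    print(f"{name}: {low:,} - {high:,} RU/s")
```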

**Preparation - loading data source "[NYC Taxi & Limousine Commission - green taxi trip records](https://azure.microsoft.com/services/open-datasets/catalog/nyc-taxi-limousine-commission-green-taxi-trip-records/)"**

The green taxi trip records include fields capturing pick-up and drop-off dates/times, pick-up and drop-off locations, trip distances, itemized fares, rate types, payment types, and driver-reported passenger counts. This data set has over 80 million records (>8 GB) of data and is available via a publicly accessible Azure Blob Storage Account located in the East-US Azure region.

In [0]:
import datetime
import time
import uuid
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, LongType

print("Starting preparation: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
# Azure storage access info
blob_account_name = "azureopendatastorage"
blob_container_name = "nyctlc"
blob_relative_path = "green"
blob_sas_token = r""
# Allow SPARK to read from Blob remotely
wasbs_path = 'wasbs://%s@%s.blob.core.windows.net/%s' % (blob_container_name, blob_account_name, blob_relative_path)
spark.conf.set(
  'fs.azure.sas.%s.%s.blob.core.windows.net' % (blob_container_name, blob_account_name),
  blob_sas_token)
print('Remote blob path: ' + wasbs_path)
# Spark only reads the Parquet metadata here; no data is loaded yet.
# NOTE - to experiment with larger dataset sizes, either switch to Option B (comment out the code
# for Option A and uncomment the code for Option B) or increase the value passed to the
# limit function that restricts the dataset size below.

#------------------------------------------------------------------------------------
# Option A - with limited dataset size
#------------------------------------------------------------------------------------
df_rawInputWithoutLimit = spark.read.parquet(wasbs_path)
partitionCount = df_rawInputWithoutLimit.rdd.getNumPartitions()
df_rawInput = df_rawInputWithoutLimit.limit(1_000_000).repartition(partitionCount)
df_rawInput.persist()

#------------------------------------------------------------------------------------
# Option B - entire dataset
#------------------------------------------------------------------------------------
#df_rawInput = spark.read.parquet(wasbs_path)

# Adding an id column with unique values
uuidUdf = udf(lambda: str(uuid.uuid4()), StringType())
nowUdf = udf(lambda: int(time.time() * 1000), LongType())
df_input_withId = df_rawInput \
  .withColumn("id", uuidUdf()) \
  .withColumn("insertedAt", nowUdf())

print('Register the DataFrame as a SQL temporary view: source')
df_input_withId.createOrReplaceTempView('source')
print("Finished preparation: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
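
The cell above derives both the `wasbs://` URL and the SAS configuration key from the same container and account names. A minimal sketch of that string construction on its own, so it can be checked without a Spark session; the helper names are illustrative and not part of the sample:

```python
def build_wasbs_path(container: str, account: str, relative_path: str) -> str:
    """Build the wasbs:// URL for a path inside a blob container."""
    return f"wasbs://{container}@{account}.blob.core.windows.net/{relative_path}"


def build_sas_conf_key(container: str, account: str) -> str:
    """Build the Spark config key under which the SAS token is registered."""
    return f"fs.azure.sas.{container}.{account}.blob.core.windows.net"


path = build_wasbs_path("nyctlc", "azureopendatastorage", "green")
conf_key = build_sas_conf_key("nyctlc", "azureopendatastorage")
print(path)
print(conf_key)
```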

**Sample - ingesting the NYC Green Taxi data into Cosmos DB**

By setting the target throughput threshold to 0.95 (95%) we reduce throttling but still allow the ingestion to consume most of the provisioned throughput. For scenarios where ingestion should only take a smaller subset of the available throughput this threshold can be reduced accordingly.

In [0]:
import uuid
import datetime

print("Starting ingestion: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))

writeCfg = {
  "spark.cosmos.accountEndpoint": cosmosEndpoint,
  "spark.cosmos.accountKey": cosmosMasterKey,
  "spark.cosmos.database": "SampleDatabase",
  "spark.cosmos.container": "GreenTaxiRecords",
  "spark.cosmos.write.strategy": "ItemOverwrite",
  "spark.cosmos.write.bulk.enabled": "true",
  "spark.cosmos.throughputControl.enabled": "true",
#   "spark.cosmos.throughputControl.accountEndpoint": throughputControlEndpoint, # Only needed if throughput control uses a different database account
#   "spark.cosmos.throughputControl.accountKey": throughputControlMasterKey, # Only needed if throughput control uses a different database account
  "spark.cosmos.throughputControl.name": "NYCGreenTaxiDataIngestion",
  "spark.cosmos.throughputControl.targetThroughputThreshold": "0.95",
  "spark.cosmos.throughputControl.globalControl.database": "SampleDatabase",
  "spark.cosmos.throughputControl.globalControl.container": "ThroughputControl",
}

df_NYCGreenTaxi_Input = spark.sql('SELECT * FROM source')

df_NYCGreenTaxi_Input \
  .write \
  .format("cosmos.oltp") \
  .mode("Append") \
  .options(**writeCfg) \
  .save()

print("Finished ingestion: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
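
To get a feel for what a given `targetThroughputThreshold` means in RU/s, the sketch below multiplies the threshold by the provisioned throughput. It assumes the threshold is applied to the container's maximum provisioned RU/s (100,000 in this sample); the helper name is illustrative:

```python
def target_ru_budget(provisioned_ru: int, threshold: float) -> int:
    """Return the RU/s budget that a throughput control group may consume.

    Assumption: the threshold is a fraction of the provisioned (maximum) RU/s.
    """
    if not 0 < threshold <= 1:
        raise ValueError("threshold must be in the range (0, 1]")
    return round(provisioned_ru * threshold)


full_speed = target_ru_budget(100_000, 0.95)   # setting used in this sample
background = target_ru_budget(100_000, 0.25)   # a more conservative ingestion
print(full_speed, background)
```

Lowering the threshold, as in the second call, leaves more of the provisioned throughput for other workloads on the same container at the cost of a slower ingestion.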

**Getting the reference record count**

In [0]:
count_source = spark.sql('SELECT * FROM source').count()
print("Number of records in source: ", count_source) 

**Sample - validating the record count via query**

In [0]:
from pyspark.sql.types import *
import pyspark.sql.functions as F

print("Starting validation via query: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
readCfg = {
  "spark.cosmos.accountEndpoint": cosmosEndpoint,
  "spark.cosmos.accountKey": cosmosMasterKey,
  "spark.cosmos.database": "SampleDatabase",
  "spark.cosmos.container": "GreenTaxiRecords",
  "spark.cosmos.read.partitioning.strategy": "Restrictive",  # IMPORTANT - any other partitioning strategy would prevent the index from being used for the count, so latency and RU consumption would spike
  "spark.cosmos.read.inferSchema.enabled" : "false",
  "spark.cosmos.read.customQuery" : "SELECT COUNT(0) AS Count FROM c"
}

count_query_schema=StructType(fields=[StructField("Count", LongType(), True)])
query_df = spark.read.format("cosmos.oltp").schema(count_query_schema).options(**readCfg).load()
count_query = query_df.select(F.sum("Count").alias("TotalCount")).first()["TotalCount"]
print("Number of records retrieved via query: ", count_query) 
print("Finished validation via query: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))

assert count_source == count_query
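
The validation cell sums the `Count` column because the custom query can return more than one row, presumably one partial count per Spark partition. A minimal sketch of that aggregation in plain Python, using made-up partial counts rather than real query results:

```python
# Hypothetical partial counts, one per partition; the values are invented for illustration.
partial_counts = [250_000, 249_500, 250_500, 250_000]

# Equivalent of F.sum("Count") over the rows returned by the query.
total_count = sum(partial_counts)
print("Total records across partitions:", total_count)
```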

**Sample - validating the record count via change feed**

In [0]:
print("Starting validation via change feed: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
changeFeedCfg = {
  "spark.cosmos.accountEndpoint": cosmosEndpoint,
  "spark.cosmos.accountKey": cosmosMasterKey,
  "spark.cosmos.database": "SampleDatabase",
  "spark.cosmos.container": "GreenTaxiRecords",
  "spark.cosmos.read.partitioning.strategy": "Default",
  "spark.cosmos.read.inferSchema.enabled" : "false",
  "spark.cosmos.changeFeed.startFrom" : "Beginning",
  "spark.cosmos.changeFeed.mode" : "Incremental"
}
changeFeed_df = spark.read.format("cosmos.oltp.changeFeed").options(**changeFeedCfg).load()
count_changeFeed = changeFeed_df.count()
print("Number of records retrieved via change feed: ", count_changeFeed) 
print("Finished validation via change feed: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))

assert count_source == count_changeFeed
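
The change feed count matches the source count here because every document was written exactly once. As a simplified model, assuming `Incremental` mode surfaces only the most recent version of each item, repeated writes to the same id would still count once:

```python
def latest_versions(events):
    """Collapse a sequence of (item_id, payload) writes to the last write per id."""
    latest = {}
    for item_id, payload in events:
        latest[item_id] = payload  # a later write replaces the earlier one
    return latest


# Invented write history: item "a" is written twice.
events = [("a", {"fare": 10}), ("b", {"fare": 7}), ("a", {"fare": 12})]
feed = latest_versions(events)
print(len(feed), feed["a"])
```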

**Sample - bulk deleting documents and validating document count afterwards**

In [0]:
print("Starting to identify documents to be deleted: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
readCfg = {
  "spark.cosmos.accountEndpoint": cosmosEndpoint,
  "spark.cosmos.accountKey": cosmosMasterKey,
  "spark.cosmos.database": "SampleDatabase",
  "spark.cosmos.container": "GreenTaxiRecords",
  "spark.cosmos.read.partitioning.strategy": "Default",
  "spark.cosmos.read.inferSchema.enabled" : "false",
}

toBeDeleted_df = spark.read.format("cosmos.oltp").options(**readCfg).load().limit(100_000)
print("Number of records to be deleted: ", toBeDeleted_df.count()) 

print("Starting to bulk delete documents: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
deleteCfg = writeCfg.copy()
deleteCfg["spark.cosmos.write.strategy"] = "ItemDelete"
toBeDeleted_df \
        .write \
        .format("cosmos.oltp") \
        .mode("Append") \
        .options(**deleteCfg) \
        .save()
print("Finished deleting documents: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))

print("Starting count validation via query: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))
count_query_schema=StructType(fields=[StructField("Count", LongType(), True)])
readCfg["spark.cosmos.read.customQuery"] = "SELECT COUNT(0) AS Count FROM c"
query_df = spark.read.format("cosmos.oltp").schema(count_query_schema).options(**readCfg).load()
count_query = query_df.select(F.sum("Count").alias("TotalCount")).first()["TotalCount"]
print("Number of records retrieved via query: ", count_query) 
print("Finished count validation via query: ", datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"))

assert max(0, count_source - 100_000) == count_query
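
Two details of the cell above can be checked in isolation: `writeCfg.copy()` keeps the original write configuration untouched when the strategy is switched to `ItemDelete`, and the expected remaining count is clamped at zero. A minimal sketch with a shortened, stand-in configuration dictionary:

```python
# Shortened stand-in for writeCfg; only the keys relevant to the example.
write_cfg = {
    "spark.cosmos.container": "GreenTaxiRecords",
    "spark.cosmos.write.strategy": "ItemOverwrite",
}

# A shallow copy is enough because all values are plain strings.
delete_cfg = write_cfg.copy()
delete_cfg["spark.cosmos.write.strategy"] = "ItemDelete"


def expected_remaining(source_count: int, deleted: int) -> int:
    """Number of documents expected after the bulk delete, never negative."""
    return max(0, source_count - deleted)


print(write_cfg["spark.cosmos.write.strategy"], delete_cfg["spark.cosmos.write.strategy"])
print(expected_remaining(1_000_000, 100_000))
```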

**Sample - showing the existing Containers**

In [0]:
%sql
SHOW TABLES FROM cosmosCatalog.SampleDatabase

In [0]:
df_Tables = spark.sql('SHOW TABLES FROM cosmosCatalog.SampleDatabase')
assert df_Tables.count() == 3

**Sample - querying a Cosmos Container via Spark Catalog**

In [0]:
%sql
SELECT * FROM cosmosCatalog.SampleDatabase.GreenTaxiRecords LIMIT 10

**Sample - querying a Cosmos Container with custom settings via Spark Catalog**

Create a view with custom settings (in this case adding a projection, disabling schema inference and switching to the aggressive partitioning strategy).

In [0]:
%sql
CREATE TABLE cosmosCatalog.SampleDatabase.GreenTaxiRecordsView 
  (id STRING, _ts TIMESTAMP, vendorID INT, totalAmount DOUBLE)
USING cosmos.oltp
TBLPROPERTIES(isCosmosView = 'True')
OPTIONS (
  spark.cosmos.database = 'SampleDatabase',
  spark.cosmos.container = 'GreenTaxiRecords',
  spark.cosmos.read.inferSchema.enabled = 'False',
  spark.cosmos.read.inferSchema.includeSystemProperties = 'True',
  spark.cosmos.read.partitioning.strategy = 'Aggressive');

SELECT * FROM cosmosCatalog.SampleDatabase.GreenTaxiRecordsView LIMIT 10

Create another view with custom settings (in this case enabling schema inference and switching to the restrictive partitioning strategy).

In [0]:
%sql
CREATE TABLE cosmosCatalog.SampleDatabase.GreenTaxiRecordsAnotherView 
USING cosmos.oltp
TBLPROPERTIES(isCosmosView = 'True')
OPTIONS (
  spark.cosmos.database = 'SampleDatabase',
  spark.cosmos.container = 'GreenTaxiRecords',
  spark.cosmos.read.inferSchema.enabled = 'True',
  spark.cosmos.read.inferSchema.includeSystemProperties = 'False',
  spark.cosmos.read.partitioning.strategy = 'Restrictive');

SELECT * FROM cosmosCatalog.SampleDatabase.GreenTaxiRecordsAnotherView LIMIT 10
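
Both view definitions share the same shape and differ only in name, optional column list and options. If several such views are needed, the statement could be generated from a dictionary. The helper below is a hypothetical sketch that only builds the SQL text in the same form as the cells above; it does not execute anything:

```python
def cosmos_view_ddl(view_name: str, options: dict, columns: str = "") -> str:
    """Build a CREATE TABLE statement for a Cosmos view in the form used above."""
    option_lines = ",\n".join(f"  {key} = '{value}'" for key, value in options.items())
    column_clause = f"\n  ({columns})" if columns else ""
    return (
        f"CREATE TABLE {view_name}{column_clause}\n"
        "USING cosmos.oltp\n"
        "TBLPROPERTIES(isCosmosView = 'True')\n"
        f"OPTIONS (\n{option_lines})"
    )


ddl = cosmos_view_ddl(
    "cosmosCatalog.SampleDatabase.GreenTaxiRecordsAnotherView",
    {
        "spark.cosmos.database": "SampleDatabase",
        "spark.cosmos.container": "GreenTaxiRecords",
        "spark.cosmos.read.inferSchema.enabled": "True",
        "spark.cosmos.read.inferSchema.includeSystemProperties": "False",
        "spark.cosmos.read.partitioning.strategy": "Restrictive",
    },
)
print(ddl)
```

In a notebook the resulting string could then be passed to `spark.sql(ddl)`.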

List all tables in the Cosmos catalog to show that both the "real" containers and the views show up.

In [0]:
%sql
SHOW TABLES FROM cosmosCatalog.SampleDatabase

In [0]:
df_Tables = spark.sql('SHOW TABLES FROM cosmosCatalog.SampleDatabase')
assert df_Tables.count() == 5

**Clean up the views again**

In [0]:
%sql
DROP TABLE IF EXISTS cosmosCatalog.SampleDatabase.GreenTaxiRecordsView;
DROP TABLE IF EXISTS cosmosCatalog.SampleDatabase.GreenTaxiRecordsAnotherView;
SHOW TABLES FROM cosmosCatalog.SampleDatabase

In [0]:
df_Tables = spark.sql('SHOW TABLES FROM cosmosCatalog.SampleDatabase')
assert df_Tables.count() == 3