In [1]:
// Declare the "partitionKey" text widget (default value "0", label "Partition Key");
// its value is read in the next cell to decide which partition this run loads.
dbutils.widgets.text("partitionKey", "0", "Partition Key")

In [2]:
// Partition to load, taken from the "partitionKey" widget; a non-numeric value fails on toInt.
val partitionKey = dbutils.widgets.get("partitionKey").toInt
val prevPartitionKey = partitionKey

// Secret scope that holds the storage key and the database connection details.
val scope = "key-vault-secrets"

// Reads a single secret from the scope above.
def secret(name: String): String = dbutils.secrets.get(scope, name)

// Blob storage account and its access key.
val storageAccount = "dmstore2"
val storageKey = secret("dmstore2-2")

// SQL Server connection details; the host name is the secret value plus the fixed domain suffix.
val server = secret("srv001").concat(".database.windows.net")
val database = secret("db001")
val user = secret("dbuser001")
val password = secret("dbpwd001")

// Target table; staging tables in later cells are named after it.
val table = "dbo.LINEITEM_LOADTEST"

// JDBC URL without credentials (those are supplied separately wherever the URL is used).
val url = s"jdbc:sqlserver://$server;databaseName=$database;"

In [3]:
// Register the storage account key in the Spark session configuration under
// fs.azure.account.key.<account>.blob.core.windows.net, ahead of the wasbs:// read in the next cell.
spark.conf.set(s"fs.azure.account.key.$storageAccount.blob.core.windows.net", storageKey);

In [4]:
// Location of the lineitem Parquet data in the "tpch" container of the storage account.
val lineitemPath = s"wasbs://tpch@$storageAccount.blob.core.windows.net/10GB/parquet/lineitem"

// Source rows restricted to the partition requested for this run.
val li = spark.read
  .parquet(lineitemPath)
  .where($"L_PARTITION_KEY" === partitionKey)

In [5]:
import org.apache.spark.sql.types._

// Column names and Spark types of the lineitem data, in column order.
val schema = StructType(
  Seq[(String, DataType)](
    "L_ORDERKEY"      -> IntegerType,
    "L_PARTKEY"       -> IntegerType,
    "L_SUPPKEY"       -> IntegerType,
    "L_LINENUMBER"    -> IntegerType,
    "L_QUANTITY"      -> DecimalType(15, 2),
    "L_EXTENDEDPRICE" -> DecimalType(15, 2),
    "L_DISCOUNT"      -> DecimalType(15, 2),
    "L_TAX"           -> DecimalType(15, 2),
    "L_RETURNFLAG"    -> StringType,
    "L_LINESTATUS"    -> StringType,
    "L_SHIPDATE"      -> DateType,
    "L_COMMITDATE"    -> DateType,
    "L_RECEIPTDATE"   -> DateType,
    "L_SHIPINSTRUCT"  -> StringType,
    "L_SHIPMODE"      -> StringType,
    "L_COMMENT"       -> StringType,
    "L_PARTITION_KEY" -> IntegerType
  ).map { case (columnName, columnType) =>
    // Every column is declared non-nullable.
    StructField(columnName, columnType, nullable = false)
  }
)

// Rebuild the DataFrame from the filtered rows so that it carries the schema declared above.
val li2 = spark.createDataFrame(li.rdd, schema)

In [6]:
// T-SQL that looks up the partition boundary of `table` whose value equals `partitionKey`.
// The inner query walks sys.indexes -> sys.data_spaces -> sys.partition_schemes ->
// sys.partition_range_values for the table's heap or clustered index (index_id 0 or 1) and
// returns, per boundary: its boundary_id (as partitionId), its value, and the previous/next
// boundary values obtained with LAG/LEAD ordered by boundary_id (NULL at either end).
// The outer query keeps only the boundary matching `partitionKey`.
// NOTE(review): later cells use partitionId as the partition number and treat the boundary
// value as the inclusive upper bound -- presumably the partition function is RANGE LEFT; confirm.
val sqlPartitionValueInfo = 
s"""
SELECT
	*
FROM
(
	SELECT
		prv.[boundary_id] AS partitionId,
		CAST(prv.[value] AS INT) AS [value],
		CAST(LAG(prv.[value]) OVER (ORDER BY prv.[boundary_id]) AS INT) AS [prevValue],
		CAST(LEAD(prv.[value]) OVER (ORDER BY prv.[boundary_id]) AS INT) AS [nextValue]
	FROM
		sys.[indexes] i
	INNER JOIN
		sys.[data_spaces] dp ON i.[data_space_id] = dp.[data_space_id]
	INNER JOIN
		sys.[partition_schemes] ps ON dp.[data_space_id] = ps.[data_space_id]
	INNER JOIN
		sys.[partition_range_values] prv ON [prv].[function_id] = [ps].[function_id]
	WHERE
		i.[object_id] = OBJECT_ID('${table}')
	AND
		i.[index_id] IN (0,1)
) AS [pi]
WHERE
	[value] = ${partitionKey}
"""

In [7]:
import java.sql.DriverManager
import java.util.Properties

// JDBC properties (credentials and driver class) used for both the direct connection below
// and the Spark JDBC read in a later cell.
val connectionProperties = new Properties()
connectionProperties.setProperty("user", user)
connectionProperties.setProperty("password", password)
connectionProperties.setProperty("Driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")

// Direct connection and statement used by later cells to run DDL against the database.
val conn = DriverManager.getConnection(url, connectionProperties)
val st = conn.createStatement()

In [8]:
// One row of the partition-boundary lookup: the boundary id, its value, and the neighbouring
// boundary values (absent when the boundary is the first or the last one).
case class PartitionInfo(partitionId: Int, value: Int, prevValue: Option[Int], nextValue: Option[Int])

// Run the lookup query through Spark's JDBC reader, wrapped as a derived table.
val piDF = spark.read.jdbc(url, s"($sqlPartitionValueInfo) AS t", connectionProperties)

// The query returns no row when no boundary equals partitionKey; fail with an explicit message
// instead of an index-out-of-bounds error on the empty result.
val pi = piDF.as[PartitionInfo].collect().headOption.getOrElse(
  throw new IllegalStateException(
    s"No partition boundary with value $partitionKey found for table $table"
  )
)

In [9]:
// (Re)create an empty staging table for this partition: drop any leftover from a previous run,
// then copy the target table's column layout without copying any rows (TOP (0)).
Seq(
  s"DROP TABLE IF EXISTS ${table}_STG_${partitionKey}",
  s"SELECT TOP (0) * INTO ${table}_STG_${partitionKey} FROM ${table}"
).foreach(sql => st.execute(sql))

In [10]:
// Indexes created on the staging table, as (index kind and name, key columns), in creation order.
// NOTE(review): presumably these mirror the indexes on the target table so the later
// partition switch is accepted -- confirm against the target's definition.
Seq(
  "CLUSTERED INDEX IXC"           -> "[L_COMMITDATE], [L_PARTITION_KEY]",
  "UNIQUE NONCLUSTERED INDEX IX1" -> "[L_ORDERKEY], [L_LINENUMBER], [L_PARTITION_KEY]",
  "NONCLUSTERED INDEX IX2"        -> "[L_PARTKEY], [L_PARTITION_KEY]"
).foreach { case (indexSpec, keyColumns) =>
  st.execute(s"CREATE $indexSpec ON ${table}_STG_${partitionKey} ($keyColumns)")
}

In [11]:
// Options handed to the SQL Server Spark connector for the load into the staging table.
val writeOptions = Map(
  "truncate"         -> "true",
  "url"              -> url,
  "dbtable"          -> s"${table}_STG_${partitionKey}",
  "user"             -> user,
  "password"         -> password,
  "reliabilityLevel" -> "BEST_EFFORT",
  "tableLock"        -> "false",
  "batchsize"        -> "100000"
)

// Load the partition's rows into the staging table in "overwrite" mode with truncate enabled.
li2.write
  .format("com.microsoft.sqlserver.jdbc.spark")
  .mode("overwrite")
  .options(writeOptions)
  .save()

In [12]:
// Constrain the staging table to the key range of the looked-up boundary: upper bound is the
// boundary value (inclusive); the lower bound is the previous boundary value (exclusive) when
// there is one, otherwise the range is open below.
pi.prevValue match {
  case None =>
    st.execute(s"ALTER TABLE ${table}_STG_${partitionKey} ADD CONSTRAINT ck_partition_${partitionKey} CHECK (L_PARTITION_KEY <= ${pi.value})")
  case Some(lowerExclusive) =>
    st.execute(s"ALTER TABLE ${table}_STG_${partitionKey} ADD CONSTRAINT ck_partition_${partitionKey} CHECK (L_PARTITION_KEY > ${lowerExclusive} AND L_PARTITION_KEY <= ${pi.value})")
}

In [13]:
import scala.util.control.NonFatal

// Replace the target partition with the staged data.
// TRUNCATE and SWITCH run inside one transaction: in auto-commit mode a failed SWITCH would
// leave the target partition already emptied, losing its previous contents.
conn.setAutoCommit(false)
try {
  st.execute(s"TRUNCATE TABLE ${table} WITH (PARTITIONS (${pi.partitionId}))")
  st.execute(s"ALTER TABLE ${table}_STG_${partitionKey} SWITCH TO ${table} PARTITION ${pi.partitionId}")
  conn.commit()
} catch {
  case NonFatal(e) =>
    // Undo the truncate; keep the original failure as the reported error.
    try conn.rollback()
    catch { case NonFatal(rollbackError) => e.addSuppressed(rollbackError) }
    throw e
} finally {
  // Restore the connection's default mode for any statements run afterwards.
  conn.setAutoCommit(true)
}

// The staging table is empty after a successful switch; remove it.
st.execute(s"DROP TABLE ${table}_STG_${partitionKey}")

In [14]:
// Release the JDBC statement and connection opened earlier in the notebook; they were never
// closed. Closing an already-closed Statement/Connection is a no-op, so re-running this cell
// is harmless.
try st.close()
finally conn.close()

// Hand the processed partition key back to the caller of this notebook.
dbutils.notebook.exit(partitionKey.toString)

199810