In [1]:
# Notebook parameters.
# The '<...>' values look like placeholders to be filled in before running - TODO confirm.
# servername and dbname are used verbatim in both the JDBC URL and the ODBC connection string.
entity = '<entity>'                     # Name of the entity to process
target_schema = '<target schema>'       # Schema of the target table

storage_account = '<storage account>'   # Storage account name (no FQDN)
container = '<container>'               # Container for Synapse Link for Dataverse
servername = '<server>'                 # Azure SQL Database server name
dbname = '<database>'                   # Azure SQL Database name

In [2]:
from pyspark.sql.functions import to_timestamp, col, dense_rank, desc, rank,row_number, coalesce
from pyspark.sql.window import Window
import json
import pyspark.sql.types as types
import pyodbc
import re
import struct

In [3]:
# Functions

def get_attribute_spark_datatype(attribute):
    """Map a CDM model attribute to the Spark data type used to read it.

    Args:
        attribute: attribute entry of the CDM model; its 'dataType' key
            selects the Spark type. Decimal attributes must also carry the
            "is.dataFormat.numeric.shaped" trait with "precision" and
            "scale" arguments.

    Returns:
        An instance of a pyspark.sql.types data type.

    Raises:
        Exception: if the CDM data type is not one of the handled types.
    """
    cdm_type = attribute['dataType']

    if cdm_type == 'boolean':
        return types.BooleanType()

    # dateTime values are read as strings (timestamp parsing is not done here);
    # guid values are kept as plain strings as well.
    if cdm_type in ('dateTime', 'guid', 'string'):
        return types.StringType()

    if cdm_type == 'double':
        return types.DoubleType()

    if cdm_type == 'int64':
        return types.LongType()

    if cdm_type == 'decimal':
        # Precision and scale are taken from the numeric-shape trait
        shaped_traits = list(filter(
            lambda trait: trait["traitReference"] == "is.dataFormat.numeric.shaped",
            attribute["cdm:traits"]))
        numeric_trait = shaped_traits[0]

        def argument_as_int(argument_name):
            matching = [arg for arg in numeric_trait["arguments"] if arg["name"] == argument_name]
            return int(matching[0]["value"])

        precision = argument_as_int("precision")
        scale = argument_as_int("scale")
        return types.DecimalType(precision, scale)

    raise Exception(f"Unsupported CDM data type: {attribute['dataType']}")

def get_file_content(path):
    """Read the content of the file located at ``path``.

    The file is looked up with ``mssparkutils.fs.ls`` to obtain its size, and
    then read in full with ``mssparkutils.fs.head``.

    Args:
        path: path of the file to read.

    Returns:
        The value returned by ``mssparkutils.fs.head`` for the first
        ``size`` bytes of the file.

    Raises:
        ValueError: if the path lists no entries or does not refer to a file.
    """
    # Use the caller-supplied path rather than any notebook-level global.
    entries = mssparkutils.fs.ls(path)

    if not entries:
        raise ValueError(f"The specified path is not a file: {path}")

    # The last listed entry supplies the metadata; listing a file path is
    # expected to yield a single entry describing that file - TODO confirm.
    file_info = entries[-1]

    if not file_info.isFile:
        raise ValueError(f"The specified path is not a file: {path}")

    return mssparkutils.fs.head(path, file_info.size)


def get_sql_token():
    """Return the AAD access token used to connect to Azure SQL Database.

    The token is requested from mssparkutils for the 'DW' audience.
    """
    token_audience = 'DW'
    return mssparkutils.credentials.getToken(token_audience)

def get_pyodbc_sql_conn(server, database):
    """Open a new pyodbc SQL Server connection authenticated with an access token.

    The token from get_sql_token() is encoded as UTF-16-LE and packed as a
    little-endian 4-byte length followed by the encoded bytes; the packed
    structure is handed to the driver as a pre-connect attribute.

    Args:
        server: value for the SERVER keyword of the connection string.
        database: value for the DATABASE keyword of the connection string.

    Returns:
        The open pyodbc connection. The caller is responsible for closing it.
    """
    access_token_attr = 1256  # SQL_COPT_SS_ACCESS_TOKEN

    encoded_token = get_sql_token().encode("UTF-16-LE")
    packed_token = struct.pack(f'<I{len(encoded_token)}s', len(encoded_token), encoded_token)

    connection_string = (
        'DRIVER={ODBC Driver 18 for SQL Server};'
        f'SERVER={server};'
        f'DATABASE={database}'
    )

    return pyodbc.connect(connection_string, attrs_before={access_token_attr: packed_token})

# Splits a blob URL into (storage account, container, path inside the container).
# The dots of the host suffix are escaped so that they only match literal dots.
re_blob_path = re.compile(r'https://([^/]+)\.blob\.core\.windows\.net/([^/]+)/(.+)')

def get_entity_partitions_to_process(entity, server, database):
    """Get the list of partitions to process for the specified entity.

    Reads the rows of [DataverseToSql].[BlobsToIngest] for the entity with
    LoadType 0 that are not complete, and derives the abfss source (CSV
    wildcard) and target (parquet) paths from each row's BasePath, Timestamp
    and Partition.

    Args:
        entity: name of the entity.
        server: SQL server hosting the metadata database.
        database: name of the metadata database.

    Returns:
        A list of (job id, source CSV path pattern, target parquet path) tuples.

    Raises:
        ValueError: if a BasePath does not have the expected blob URL format.
    """
    conn = get_pyodbc_sql_conn(server, database)
    try:
        # Fragments end with a space so the concatenated statement keeps its
        # keywords separated from the preceding identifiers.
        cursor = conn.execute(
            'SELECT DISTINCT [Id], [BasePath], [Timestamp], [Partition] '
            'FROM [DataverseToSql].[BlobsToIngest] '
            'WHERE [EntityName] = ? AND [LoadType] = 0 AND [Complete] = 0',
            entity)
        # Fetch everything before the connection is closed
        rows = cursor.fetchall()
    finally:
        # Release the connection even if the query fails
        conn.close()

    ret = []
    for row in rows:
        jobid, base_path, timestamp, partition = row[0], row[1], row[2], row[3]

        path = re_blob_path.match(base_path)
        if path is None:
            raise ValueError(f"Unexpected BasePath format for job {jobid}: {base_path}")

        account_name, container_name, folder = path.group(1), path.group(2), path.group(3)
        partition_root = f'abfss://{container_name}@{account_name}.dfs.core.windows.net/{folder}/{timestamp}'

        ret.append((
            jobid,
            f'{partition_root}/{partition}_*{timestamp}.csv',
            f'{partition_root}/{partition}_{timestamp}.parquet',
        ))

    return ret

def get_entity_partitions_to_load(entity, server, database):
    """Get the list of partitions to load for the specified entity.

    Reads the rows of [DataverseToSql].[BlobsToIngest] for the entity with
    LoadType 2 that are not complete.

    Args:
        entity: name of the entity.
        server: SQL server hosting the metadata database.
        database: name of the metadata database.

    Returns:
        A list of rows, each holding the job id and the blob name to load.
    """
    conn = get_pyodbc_sql_conn(server, database)
    try:
        # Fragments end with a space so the concatenated statement keeps its
        # keywords separated from the preceding identifiers.
        cursor = conn.execute(
            'SELECT DISTINCT [Id], [BlobName] '
            'FROM [DataverseToSql].[BlobsToIngest] '
            'WHERE [EntityName] = ? AND [LoadType] = 2 AND [Complete] = 0',
            entity)

        # Materialize the rows before the connection is closed
        ret = list(cursor)
    finally:
        # Release the connection even if the query fails
        conn.close()

    return ret

def mark_job_complete(jobid, server, database):
    """Mark an ingestion job as complete.

    For initial ingestion the job corresponds to an entity partition.
    Executes the [DataverseToSql].[IngestionJobs_Complete] stored procedure
    for the job and commits.

    Args:
        jobid: identifier of the job to mark as complete.
        server: SQL server hosting the metadata database.
        database: name of the metadata database.
    """
    conn = get_pyodbc_sql_conn(server, database)
    try:
        conn.execute('exec [DataverseToSql].[IngestionJobs_Complete] ?', jobid)
        conn.commit()
    finally:
        # Release the connection even if the procedure call fails
        conn.close()

def insert_parquet_for_loading(entity, path, server, database):
    """Insert a parquet file in metadata for loading.

    Adds a row to [DataverseToSql].[BlobsToIngest] with the parquet path as
    BlobName, LoadType 2, Complete 0 and empty BasePath, Timestamp and
    Partition values, then commits.

    Args:
        entity: name of the entity the parquet file belongs to.
        path: path of the parquet file.
        server: SQL server hosting the metadata database.
        database: name of the metadata database.
    """
    conn = get_pyodbc_sql_conn(server, database)
    try:
        conn.execute(
            'INSERT [DataverseToSql].[BlobsToIngest]('
                'EntityName,'
                'BlobName,'
                'BasePath,'
                'Timestamp,'
                'Partition,'
                'LoadType,'
                'Complete'
            ') VALUES(?,?,?,?,?,?,?)',
            entity, path, "", "", "", 2, 0)
        conn.commit()
    finally:
        # Release the connection even if the insert fails
        conn.close()


def process_partition(entity, source_path, target_path, schema, timestamp_cols, target_columns, server, database):
    """Process a partition of an entity.

    Reads the CSV files of the partition, converts the timestamp columns,
    keeps the latest record for each Id (skipping deleted records), writes
    the result as parquet and registers the parquet path for later loading.

    Args:
        entity: name of the entity.
        source_path: path (pattern) of the CSV files to read.
        target_path: path of the parquet output; existing content is overwritten.
        schema: Spark schema applied to the header-less CSV files.
        timestamp_cols: iterable of (column name, list of timestamp formats);
            each column is parsed with the first format that yields a value.
        target_columns: names of the columns of the target table.
        server: SQL server hosting the metadata database.
        database: name of the metadata database.
    """
    df_source = (
        spark.read
        .schema(schema)
        .option("mode", "PERMISSIVE")
        .option("multiline", True)
        .option("header", False)
        .option("escape", '"')
        .csv(source_path)
        # Single partition; presumably so that one parquet file is produced - TODO confirm
        .repartition(1)
    )

    # Identify the columns that are common to the source file and the target table.
    # The target column order is preserved.
    source_columns = set(df_source.columns)
    common_columns = [name for name in target_columns if name in source_columns]

    # Convert datetime columns
    for colname, formats in timestamp_cols:
        df_source = df_source.withColumn(colname, coalesce(*[to_timestamp(col(colname), fmt) for fmt in formats]))

    # Deduplicate the records and write to the target path
    # Consider only the latest record for each Id
    # The latest record is identified based on the fields SinkModifiedOn and versionnumber
    order_by_columns = [desc(c) for c in ["SinkModifiedOn", "versionnumber"]]
    rownum_colname = "dv2sqlrownum"
    latest_first = Window.partitionBy("Id").orderBy(order_by_columns)

    # Deleted records are skipped; the check runs after deduplication, so an
    # Id whose latest record is a delete produces no output row.
    # "== False" builds a Spark column expression and must not be replaced by "not".
    (
        df_source
        .withColumn(rownum_colname, row_number().over(latest_first))
        .where(col(rownum_colname) == 1)
        .where((col('IsDelete').isNull()) | (col('IsDelete') == False))
        .select(common_columns)
        .write
        .mode("overwrite")
        .parquet(target_path)
    )

    # Insert parquet file for later loading
    insert_parquet_for_loading(entity, target_path, server, database)

def load_partition(source_path, jdbc_url, target_table):
    """Append the content of a parquet file to the target SQL table.

    The data is written with the com.microsoft.sqlserver.jdbc.spark format,
    authenticating with the token from get_sql_token(), with schema check
    disabled and table lock enabled.

    Args:
        source_path: path of the parquet data to load.
        jdbc_url: JDBC URL of the target database.
        target_table: name of the target table.
    """
    parquet_df = spark.read.parquet(source_path)

    writer = (
        parquet_df.write
        .mode("append")
        .format("com.microsoft.sqlserver.jdbc.spark")
        .option("url", jdbc_url)
        .option("dbtable", target_table)
    )

    # The token is requested only once the writer has been set up
    writer = writer.option("accessToken", get_sql_token())
    writer = writer.option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
    writer = writer.option("schemaCheckEnabled", False)
    writer = writer.option("tableLock", True)

    writer.save()

def truncate_table(table, servername, database):
    """Truncate a table of the target database.

    Args:
        table: name of the table to truncate. It is interpolated into the
            statement as-is (it is not passed as a bound parameter), so it
            must come from a trusted source.
        servername: SQL server hosting the database.
        database: name of the database.
    """
    conn = get_pyodbc_sql_conn(servername, database)
    try:
        conn.execute(f'TRUNCATE TABLE {table}')
        conn.commit()
    finally:
        # Release the connection even if the statement fails
        conn.close()


In [4]:
# Bracket-quoted name of the target table and JDBC URL of the target database
target_table = f"[{target_schema}].[{entity}]"
jdbc_url = f"jdbc:sqlserver://{servername};databaseName={dbname};"

# Read entity metadata (the model.json file of the entity)
model_path = (
    f'abfss://{container}@{storage_account}.dfs.core.windows.net'
    f'/Microsoft.Athena.TrickleFeedService/{entity}-model.json'
)
model_json = json.loads(get_file_content(model_path))

# Read target table schema
# Note: this step does not load data from the target table; it retrieves metadata only
target_df = (
    spark.read
    .format("jdbc")
    .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
    .option("url", jdbc_url)
    .option("accessToken", get_sql_token())
    .option("dbtable", target_table)
    .load()
)

# Target fields indexed by column name
target_field_map = {field.name: field for field in target_df.schema}

# Build the CSV schema and identify timestamp columns
schema = types.StructType()
timestamp_cols = []
timestamp_formats = [ "M/d/y h:m:s a", "y-M-d'T'H:m:s.SSSSSSSXXXXX", "y-M-d'T'H:m:sX"]

for attribute in model_json['entities'][0]['attributes']:
    attr_name = attribute['name']
    is_datetime = attribute['dataType'] == 'dateTime'

    if is_datetime or attr_name not in target_field_map:
        # dateTime columns and columns missing from the target table take
        # the type derived from the CDM model and are always nullable
        schema.add(
            field = attr_name,
            data_type = get_attribute_spark_datatype(attribute),
            nullable = True
        )
    else:
        # Align CSV data types to the target table to avoid conversion errors during bulk load
        # This could happen with optionset fields when the int datatype override is enabled
        target_field = target_field_map[attr_name]
        schema.add(
            field = attr_name,
            data_type = target_field.dataType,
            nullable = target_field.nullable
        )

    # Every dateTime column is later parsed with the same list of formats
    if is_datetime:
        timestamp_cols.append((attr_name, timestamp_formats))

# Step 1: convert the pending CSV partitions to parquet.
# Iterate over the table one partition at a time
# Partitions are ingested serially because table lock is used during load
# Each job is marked complete only after its partition has been processed.
for jobid, source_path, parquet_path in get_entity_partitions_to_process(entity, servername, dbname):
    process_partition(entity, source_path, parquet_path, schema, timestamp_cols, target_df.columns, servername, dbname)
    mark_job_complete(jobid, servername, dbname)

# Step 2: empty the target table before loading.
# NOTE(review): the truncate runs unconditionally, even when there is no
# parquet file left to load afterwards - confirm this is intended on re-runs.
truncate_table(target_table, servername, dbname)

# Step 3: append every pending parquet file to the target table,
# marking each load job complete once its file has been loaded.
for jobid, parquet_path in get_entity_partitions_to_load(entity, servername, dbname):
    load_partition(parquet_path, jdbc_url, target_table)
    mark_job_complete(jobid, servername, dbname)
