In [0]:
import json
import dlt
import pyspark.sql.functions as fn
from pyspark.sql.types import StringType
from pyspark.sql.functions import lit
from datetime import datetime
#from pyspark.sql.types import *
def _bytes_to_decimal_string(raw):
    """Interpret ``raw`` as a big-endian unsigned integer and return it as a decimal string.

    Returns None for a NULL (None) input so Spark yields NULL instead of failing the task
    with a TypeError from ``int.from_bytes``.
    """
    if raw is None:
        return None
    return str(int.from_bytes(raw, byteorder='big'))


# Spark UDF: binary column -> decimal string (NULL in, NULL out).
binary_to_string = fn.udf(_bytes_to_decimal_string, StringType())

def _parse_dq_rules(rules):
    """Collapse a sequence of ``{"name": ..., "expectation": ...}`` entries into ``{name: expectation}``."""
    # A NULL dq_rules column arrives as None; treat it as "no rules".
    return {rule["name"]: rule["expectation"] for rule in (rules or [])}


def _resolve_inputs(sql_stmt, inputs, out_type):
    """Replace ``{key}`` placeholders in ``sql_stmt`` with ``LIVE.<table>`` references.

    :param sql_stmt: SQL template containing ``{key}`` placeholders
    :param inputs: JSON object text mapping placeholder key -> table name; None/empty means no inputs
    :param out_type: when equal to "TABLE" the full table name is substituted; otherwise only the
        part after the last '.' is used
    :return: tuple of (rewritten sql_stmt, list of input table names in JSON key order)
    """
    inputs_json = json.loads(inputs) if inputs else {}
    input_list = []
    for key, table in inputs_json.items():
        input_list.append(table)
        # The table name is split to drop catalog and schema, as the input is a temporary view.
        table_name = table if out_type == "TABLE" else table.split('.')[-1]
        sql_stmt = sql_stmt.replace("{" + key + "}", "LIVE." + table_name)
    return sql_stmt, input_list


def get_config(filegroup, layer):
    """
    Load pipeline configuration rows from ``rt_config`` for one file group and target layer.

    :param filegroup: value compared against ``filegroup_id``; it is interpolated unquoted into
        the SQL text, so it must be a trusted (numeric) constant
    :param layer: value compared against ``target_layer``; it is interpolated into a quoted SQL
        string literal, so it must be a trusted constant
    :return: list with one dict per matching row containing ``kafka_server``, ``kafka_key``,
        ``kv_kafka_secret_key``, ``out_table``, ``topic``, ``dq_rules`` (rule name -> expectation),
        ``input_tables`` (input table names) and ``sql_stmt`` (placeholders resolved to
        ``LIVE.<table>``)
    """
    # Single quotes are the portable SQL string-literal syntax; double quotes become identifiers
    # when Spark's ANSI double-quoted-identifier mode is enabled.
    query = f"SELECT * FROM rt_config WHERE target_layer = '{layer}' and filegroup_id = {filegroup}"

    config_list = []
    for row in spark.sql(query).collect():
        sql_stmt, input_list = _resolve_inputs(row['sql_stmt'], row['inputs'], row['out_type'])
        config_list.append({
            "kafka_server": row['kafka_server'],
            "kafka_key": row['kafka_key'],
            "kv_kafka_secret_key": row['kv_kafka_secret_key'],
            "out_table": row['out_table'],
            "topic": row['topic'],
            # Accumulate every rule; each entry adds one name -> expectation pair.
            "dq_rules": _parse_dq_rules(row["dq_rules"]),
            "input_tables": input_list,
            "sql_stmt": sql_stmt,
        })

    return config_list

In [0]:
def load_table_to_staging(kafka_server, kafka_key, kv_kafka_secret_key, out_table, topic,
                          secret_scope="jriscopekv", starting_offsets="earliest"):
    """
    Register a DLT table ``STG_<out_table>`` that streams raw records from a Kafka topic.

    The SASL password is read from a secret scope inside the table function and passed to the
    Kafka source with SASL/PLAIN over SASL_SSL.

    :param kafka_server: value for ``kafka.bootstrap.servers``
    :param kafka_key: SASL username
    :param kv_kafka_secret_key: key of the secret that holds the SASL password
    :param out_table: suffix of the staging table name (``STG_<out_table>``)
    :param topic: Kafka topic to subscribe to
    :param secret_scope: secret scope containing ``kv_kafka_secret_key``
    :param starting_offsets: value for the Kafka ``startingOffsets`` option
    """
    @dlt.table(name=f"STG_{out_table}",comment=f"This is a Staging table STG_{out_table}",)
    def real_time_stage():
        kafka_secret = dbutils.secrets.get(scope=secret_scope, key=kv_kafka_secret_key)
        sasl = f'kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required username="{kafka_key}" password="{kafka_secret}";'

        kafka_options = {
            "kafka.bootstrap.servers": kafka_server,
            "kafka.sasl.mechanism": "PLAIN",
            "kafka.security.protocol": "SASL_SSL",
            "kafka.request.timeout.ms": "60000",
            "kafka.session.timeout.ms": "30000",
            "startingOffsets": starting_offsets,
            "kafka.sasl.jaas.config": sasl,
            "subscribe": topic
        }
        return (
            spark.readStream.format("kafka").options(**kafka_options).load()
            .withColumn('key', fn.col("key").cast(StringType()))
            .withColumn('actualvalue', fn.col("value").cast(StringType()))
            # current_timestamp() is evaluated as each micro-batch runs; a Python
            # datetime.now() literal would freeze the time at which this query was defined.
            .withColumn("load_date", fn.current_timestamp())
        )

In [0]:
# Register one Kafka staging table for every STAGE-layer config row of file group 3.
stage_loading = get_config(3, "STAGE")
for config in stage_loading:
    load_table_to_staging(
        kafka_server=config["kafka_server"],
        kafka_key=config["kafka_key"],
        kv_kafka_secret_key=config["kv_kafka_secret_key"],
        out_table=config["out_table"],
        topic=config["topic"],
    )

In [0]:
@dlt.table(name="catalog", comment="This is a dim catalog table")
def catalog_silver():
    """Define the ``catalog`` dataset as a streaming read of the ``catalog`` table."""
    # NOTE(review): the dataset shares its name with its source table; confirm the pipeline's
    # target schema differs from the schema that holds the source table.
    return spark.sql("SELECT * FROM STREAM catalog")

@dlt.table(name="SILVER_TXN", comment="This is the formatted silver transaction table")
def transaction_silver():
    """Parse the Kafka ``value`` payload streamed from ``LIVE.STG_Txn`` into typed transaction columns."""
    # NOTE(review): STG_Txn exists only if rt_config has a STAGE row with out_table = "Txn".
    parse_sql = "SELECT v.* from (SELECT from_json(cast(value AS STRING), 'ordertime BIGINT, orderid INT, itemid STRING, billamount INT, country STRING') v FROM STREAM LIVE.STG_Txn)"
    return spark.sql(parse_sql)

@dlt.table(name="SILVER_JOIN", comment="This is a Joined table")
def stream_static_join():
    """
    Enrich streaming transactions from ``LIVE.SILVER_TXN`` with ``item_desc`` from ``LIVE.catalog``.

    Only ``SILVER_TXN`` is read as a stream; ``LIVE.catalog`` is read as a static snapshot, which
    makes this the stream-static join the name describes. Joining two streams without watermarks
    would make Spark keep join state for both sides indefinitely.

    Caveat: with an inner stream-static join, a transaction whose ``itemid`` is not yet in the
    catalog when its micro-batch runs is not emitted later.
    """
    return spark.sql(
        "SELECT A.itemid, A.ordertime, A.orderid, B.item_desc, A.billamount, A.country "
        "FROM STREAM LIVE.SILVER_TXN AS A "
        "INNER JOIN LIVE.catalog AS B ON A.itemid = B.itemid"
    )


@dlt.table(name="GOLD_AGG", comment="This is an aggregated table")
def aggregations():
    """Count transactions per ``country`` using a batch (non-STREAM) read of ``LIVE.SILVER_TXN``."""
    gold_query = "SELECT COUNT(*) AS COUNT, country FROM LIVE.SILVER_TXN GROUP BY country;"
    return spark.sql(gold_query)