In [0]:
### Import libraries, required
from multiprocessing.pool import ThreadPool
from pyspark.sql.window import Window
from pyspark.sql import DataFrameWriter
from pyspark.sql.functions import lit, monotonically_increasing_id, row_number, max
from datetime import datetime, timedelta
import pandas as pd
import json

### Worker pool with 10 threads (not referenced in this section of the notebook).
pool = ThreadPool(10)

### Set parameters / arguments
def _text_widget(name, label):
  """Declare Databricks text widget `name` (empty default, shown as `label`) and return its current value."""
  dbutils.widgets.text(name, "", label)
  return dbutils.widgets.get(name)

p_SourceSystem = _text_widget("p_SourceSystem", "Source System Name")
p_SubSystem    = _text_widget("p_SubSystem", "Sub System Name")
p_Phase        = _text_widget("p_Phase", "PhaseName/StageName")

### Constant arguments
c_stg_dbname      = "raw"     # schema prefix used for staging table names
c_src_file_format = "delta"   # storage format for reads and full-load writes
c_delta_load      = "DELTA"
c_full_load       = "FULL"
c_ref_load        = "REF"
c_snapshot_load   = "SNAPSHOT"
c_partial_load    = "PARTIAL"
c_max_num_rec     = 100000

### Environment / run context
# UTC start time, formatted "YYYY-MM-DD HH:MM:SS".
v_process_start_dtm = datetime.utcnow().isoformat(sep=' ', timespec='seconds')
v_job_loc = (dbutils.notebook.entry_point.getDbutils().notebook()
             .getContext().notebookPath().get())
v_user = (dbutils.notebook.entry_point.getDbutils().notebook()
          .getContext().tags().apply('user'))
# Notebook name = last segment of the notebook path.
v_job_name = v_job_loc.rsplit('/', 1)[1]
v_ActivityName = 'DeltaLake: ' + spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
v_ActivityType = 'DeltaMerge'

### Run sequence number: local wall-clock time as an integer yyyymmddHHMMSS
dt = datetime.now()
seq_Id = int(dt.strftime("%Y%m%d%H%M%S"))

In [0]:
jdbcConnStrUrl = dbutils.secrets.get(scope = "bpdevtestie1kv001", key = "bp-asql-bpdevtestie1sqldb004-jdbc-secret")

# The widget values are spliced into T-SQL string literals below. Doubling single quotes keeps a
# quote inside a parameter from terminating the literal (and from injecting SQL into the query).
# NOTE: the filters compare against UPPER(...) of the columns, so the widgets only match rows
# when they are supplied in upper case.
_cfg_src_sys = p_SourceSystem.replace("'", "''")
_cfg_src_subsys = p_SubSystem.replace("'", "''")
config_query = """(SELECT 
  [SourceName], 
  [System], 
  [SubSystem], 
  [TableName], 
  [LoadType], 
  [SourceDBName], 
  [SourceLocation], 
  [DestinationLocation],
  ADLMergeKey,
  ADLMergeSQL,
  [WatermarkColumn],
  CONVERT(VARCHAR(25), [WatermarkValue], 120) AS  [WatermarkValue],
  [SourceSQL],
  [StgActive]
FROM 
  [config].[IngestTableList]  
   INNER JOIN [audit].[Watermark] 
   ON LOWER([SubSystem])=LOWER([WatermarkSubSystem])
   AND LOWER([TableName])=LOWER([WatermarkTableName])
   AND [Active] = 1
   AND UPPER([SourceName]) = '{SRC_SYS}'
   AND UPPER([SubSystem]) = '{SRC_SUBSYS}') as IngestTableList """.format(SRC_SYS=_cfg_src_sys, SRC_SUBSYS=_cfg_src_subsys)

#Read metadata configuration table list from SQL database
config_table_df = spark.read.jdbc(url=jdbcConnStrUrl, table=config_query)

In [0]:

class SqlDB:
  """Single-use JDBC helper for calling SQL stored procedures from the Spark driver.

  A connection is opened through the Spark JVM gateway's java.sql.DriverManager when the
  object is created. Each ``execute_spoc_*`` method closes both its statement and that
  connection when it returns or raises, so create a new instance per call.
  """
  # Connection string captured at class-definition time. The methods below read the
  # module-level jdbcConnStrUrl directly.
  db_conn_string= jdbcConnStrUrl

  def __init__(self):
    self.driver_manager = spark._sc._gateway.jvm.java.sql.DriverManager
    self.con = self.driver_manager.getConnection(jdbcConnStrUrl)

  def execute_spoc_without_output(self,spoc_statement):
    """Run `spoc_statement` (no OUT parameters) and return the JDBC ``execute()`` flag.

    The callable statement and the connection are always closed, even if preparing or
    executing the statement fails.
    """
    try:
      exec_statement = self.con.prepareCall(spoc_statement)
      try:
        return exec_statement.execute()
      finally:
        exec_statement.close()
    finally:
      self.con.close()

  def execute_spoc_with_output(self,spoc_statement):
    """Run `spoc_statement` and return its first OUT parameter as an int.

    The first '?' placeholder in the statement is registered as an INTEGER OUT parameter
    (JDBC parameter indexes start at 1). Statement and connection are always closed.
    """
    try:
      exec_statement = self.con.prepareCall(spoc_statement)
      try:
        exec_statement.registerOutParameter(1, spark._sc._gateway.jvm.java.sql.Types.INTEGER)
        exec_statement.execute()
        # Fetch the OUT value with the getter matching its Java type.
        return exec_statement.getInt(1)
      finally:
        exec_statement.close()
    finally:
      self.con.close()

  @staticmethod
  def execute_query(sql_statement):
    """Read `sql_statement` (a table name or a parenthesised, aliased subquery) over JDBC into a Spark DataFrame.

    Declared static because it uses no instance state. The original had no ``self``, so it only
    worked as ``SqlDB.execute_query(...)``; it now also works when called on an instance.
    """
    return spark.read.jdbc(url= jdbcConnStrUrl , table=sql_statement)

In [0]:
def file_exists(path):
  """Return True if `path` exists, False if it is missing.

  Paths starting with "/dbfs" are checked on the local filesystem with os.path.exists.
  Anything else is probed with dbutils.fs.ls. An error mentioning
  java.io.FileNotFoundException means "missing"; any other error is re-raised.
  """
  if path[:5] == "/dbfs":
    import os
    return os.path.exists(path)

  try:
    dbutils.fs.ls(path)
  except Exception as err:
    # dbutils only exposes the JVM exception as text, so match on the class name.
    if 'java.io.FileNotFoundException' not in str(err):
      raise
    return False
  return True

In [0]:
def search_database(database):
  """Return True if a database whose name equals str(`database`) is listed in the Spark catalog."""
  wanted = str(database)
  return any(db.name == wanted for db in spark.catalog.listDatabases())

In [0]:
def search_object(database, table):
  """Return True if `database` contains a table whose name equals str(`table`) in the Spark catalog."""
  wanted = str(table)
  return any(tbl.name == wanted for tbl in spark.catalog.listTables(database))

In [0]:
def get_row_processed_cnt(path , activityID):
  """Close an audit activity using the row counters of the latest commit on a Delta table.

  Reads the newest entry of the Delta history at `path`, takes the row counters from its
  ``operationMetrics`` map and passes them to [audit].[uspActivityEnd] for `activityID`.
  Counters the latest operation did not report are sent as 0.

  Args:
    path: storage path of the Delta table that was just written.
    activityID: audit activity id (embedded unquoted in the EXEC statement).

  Returns:
    The JDBC ``execute()`` flag returned by SqlDB.execute_spoc_without_output.
  """
  from delta.tables import DeltaTable

  # Only the newest commit matters. Read its metrics map directly instead of converting
  # through Koalas (databricks.koalas is deprecated and absent from newer runtimes).
  latest_commit = DeltaTable.forPath(spark, path).history(1).select("operationMetrics").first()
  metrics = (latest_commit["operationMetrics"] if latest_commit is not None else None) or {}

  def _metric(name):
    # Metric values may arrive as strings; a missing metric counts as 0.
    value = metrics.get(name)
    return int(value) if value is not None else 0

  numTargetRowsInserted = _metric("numTargetRowsInserted")
  numTargetRowsUpdated = _metric("numTargetRowsUpdated")
  numTargetRowsDeleted = _metric("numTargetRowsDeleted")
  numSourceRows = _metric("numSourceRows")
  numOutputRows = _metric("numOutputRows")

  # invoke the activity end stored procedure
  activity_end_spoc_statement = f"""
    EXEC [audit].[uspActivityEnd] 
    {activityID}, {numSourceRows}, {numTargetRowsInserted}, 
    {numTargetRowsUpdated},{numTargetRowsDeleted}, {numOutputRows}"""

  return SqlDB().execute_spoc_without_output(activity_end_spoc_statement)


def upd_activity_erorlog(activityID, errorMsg, batchId, addlInfo):
  """Write an error entry for an audit activity via [audit].[uspErrorLog].

  Args:
    activityID: audit activity id (embedded unquoted).
    errorMsg: error text. Single quotes are doubled so exception messages such as
      "cannot resolve 'col'" do not terminate the T-SQL literal and break the call.
    batchId: audit batch id (embedded unquoted).
    addlInfo: extra text, quoted the same way. None is rendered as the text 'None', as before.

  Returns:
    The JDBC ``execute()`` flag returned by SqlDB.execute_spoc_without_output.
  """
  error_lit = str(errorMsg).replace("'", "''")
  addl_lit = str(addlInfo).replace("'", "''")

  # invoke the error-log stored procedure
  activity_failed_spoc_statement = f"""
    EXEC [audit].[uspErrorLog]
    {activityID}, '{error_lit}', {batchId}, '{addl_lit}'"""

  return SqlDB().execute_spoc_without_output(activity_failed_spoc_statement)


def get_batchId(SourceSystem, SubSystem):
  """Open an audit batch via [audit].[uspBatchStart] and return the new batch id.

  Both names are embedded as T-SQL string literals with single quotes doubled, so a value
  containing a quote neither breaks nor alters the statement. The trailing '?' is the
  INTEGER OUT parameter read back by SqlDB.execute_spoc_with_output.
  """
  source_lit = str(SourceSystem).replace("'", "''")
  subsystem_lit = str(SubSystem).replace("'", "''")
  batch_spoc_statement = f""" EXEC [audit].[uspBatchStart] '{source_lit}', '{subsystem_lit}',? """

  return SqlDB().execute_spoc_with_output(batch_spoc_statement)

def get_activityId(PipelineID, PipelineName, BatchID, Phase,
                   ActivityName, ActivityType, SourceObjectName, UserID):
  """Start an audit activity via [audit].[uspActivityStart] and return the new activity id.

  String arguments are embedded as T-SQL literals with single quotes doubled, so values such
  as a user e-mail or notebook name containing a quote cannot break the statement.
  BatchID is embedded unquoted (presumably the numeric id from get_batchId). The trailing
  '?' is the INTEGER OUT parameter read back by SqlDB.execute_spoc_with_output.
  """
  def _sql_text(value):
    # Body of a T-SQL string literal: '' escapes an embedded quote.
    return str(value).replace("'", "''")

  # invoke the activity start stored procedure
  activity_start_spoc_statement = f"""
    EXEC [audit].[uspActivityStart] 
    '{_sql_text(PipelineID)}', '{_sql_text(PipelineName)}', {BatchID}, '{_sql_text(Phase)}','{_sql_text(ActivityName)}', '{_sql_text(ActivityType)}', '{_sql_text(SourceObjectName)}','{_sql_text(UserID)}',?"""

  return SqlDB().execute_spoc_with_output(activity_start_spoc_statement)

def end_batchId(BatchId):
  """Close audit batch `BatchId` via [audit].[uspBatchEnd].

  The id is embedded unquoted. Returns the JDBC ``execute()`` flag from
  SqlDB.execute_spoc_without_output.
  """
  end_statement = " EXEC [audit].[uspBatchEnd] {} ".format(BatchId)
  return SqlDB().execute_spoc_without_output(end_statement)

def end_waterMark(WatermarkSubSystem, WatermarkTableName, WatermarkColumn):
  """Advance the stored watermark for a table via [audit].[uspWatermarkEnd].

  All three arguments are embedded as T-SQL string literals with single quotes doubled, so
  a name containing a quote cannot break the statement. Returns the JDBC ``execute()`` flag
  from SqlDB.execute_spoc_without_output.
  """
  subsystem_lit = str(WatermarkSubSystem).replace("'", "''")
  table_lit = str(WatermarkTableName).replace("'", "''")
  column_lit = str(WatermarkColumn).replace("'", "''")
  watermark_spoc_statement = f""" EXEC [audit].[uspWatermarkEnd] '{subsystem_lit}', '{table_lit}','{column_lit}' """

  return SqlDB().execute_spoc_without_output(watermark_spoc_statement)

In [0]:
from pyspark.sql.functions  import date_format
from delta.tables import *

# get BatchId for auditing.
audit_batchId = get_batchId(p_SourceSystem, p_SubSystem)


def _build_merge_condition(merge_key):
  """Build the MERGE join condition from a comma-separated key list.

  "id, code" -> "t.id = s.id AND t.code = s.code". Whitespace around each key is stripped
  and empty entries (e.g. from a trailing comma) are ignored.
  """
  key_cols = [col.strip() for col in merge_key.split(",") if col.strip()]
  return " AND ".join("t.{0} = s.{0}".format(col) for col in key_cols)


def _refresh_staging_table(row):
  """If staging is active for `row`, overwrite <c_stg_dbname>.<TableName> over JDBC with the result of its SourceSQL."""
  # NOTE(review): only the string '1' enables staging. If StgActive is a SQL bit column it may
  # arrive as a Python bool and never equal '1' -- confirm the column type.
  if row["StgActive"] != '1':
    return
  # Drop any trailing CR/LF combination ("\r\n" included) before handing the text to Spark SQL.
  src_query = row["SourceSQL"].rstrip('\r\n')
  stg_tbl_nm = c_stg_dbname + '.' + row["TableName"]
  spark.sql(src_query).write.jdbc(url=jdbcConnStrUrl, table=stg_tbl_nm, mode="overwrite")


def _log_activity_error(activity_id, error_msg):
  """Record `error_msg` against `activity_id` in the audit error log for the current batch."""
  return upd_activity_erorlog(activityID=activity_id, errorMsg=error_msg, batchId=audit_batchId, addlInfo=None)


try:
  data_itr = config_table_df.rdd.toLocalIterator()
  # looping through each row of the configuration dataframe
  for row in data_itr:
    v_SourceObjectName = row["SourceDBName"] + "." + row["TableName"]
    # Start the activity before any check so that every branch below has an activity id to
    # log against (the unsupported-source branch previously used an undefined or stale id).
    audit_activityId = get_activityId(v_process_start_dtm, v_job_name, audit_batchId, p_Phase,
                                      v_ActivityName, v_ActivityType, v_SourceObjectName, v_user)

    if row["SourceName"].upper() != p_SourceSystem.upper():
      _log_activity_error(audit_activityId, "ERROR: Except source {expct_source} ,others like-{source} isn't supported.".format(expct_source=p_SourceSystem, source=row["SourceName"]))
      continue

    if not file_exists(row["SourceLocation"]):
      _log_activity_error(audit_activityId, "ERROR: Source File Location-{fileloc} doesn't exist.".format(fileloc=row["SourceLocation"]))
      continue

    is_incremental = (row["WatermarkValue"] is not None
                      and row["ADLMergeKey"] is not None
                      and row["LoadType"].strip().upper() != c_full_load)

    if is_incremental:
      try:
        print('********  Incremental load has begun for ' + v_SourceObjectName + ' ********')
        # Read data from Code orange datalake trusted layer
        df = spark.read.format(c_src_file_format).load(row["SourceLocation"])
        # Resolve the watermark column by exact name. If it is absent this yields '' and the
        # df[...] lookup below raises, which is logged as a failed incremental load.
        load_state_col = ' '.join(col.strip() for col in df.columns if col == row["WatermarkColumn"])
        # WatermarkValue is non-null on this branch, so it is the filter's lower bound.
        increment_df = df.where(df[load_state_col] > row["WatermarkValue"])

        # Upsert the increment on the configured key columns.
        # NOTE(review): schema evolution during MERGE additionally needs
        # spark.databricks.delta.schema.autoMerge.enabled, which is not set in this section.
        deltaTable = DeltaTable.forPath(spark, row["DestinationLocation"])
        (deltaTable.alias("t")
           .merge(increment_df.alias("s"), _build_merge_condition(row["ADLMergeKey"]))
           .whenMatchedUpdateAll()
           .whenNotMatchedInsertAll()
           .execute())

        _refresh_staging_table(row)

        # write audit details
        spoc_end_watermark = end_waterMark(row["SubSystem"], row["TableName"], row["WatermarkColumn"])
        activity_end_activity = get_row_processed_cnt(row["DestinationLocation"], audit_activityId)
      except Exception as e:
        # Format only the fixed template, then append the exception text, so braces inside the
        # message cannot break .format() and the placeholders are actually filled in.
        error_msg = "ERROR: error in Incremental loading for {subsystem} - {table} table".format(table=row["TableName"], subsystem=row["SubSystem"]) + ":  " + str(e)
        _log_activity_error(audit_activityId, error_msg)
    else:
      try:
        print('********  FULL load has begun for ' + v_SourceObjectName + ' ********')
        # Read data from Code orange datalake trusted layer
        df = spark.read.format(c_src_file_format).load(row["SourceLocation"])

        # Create/replace the table at DestinationLocation using the DataFrame's schema.
        # NOTE(review): the metastore name is built from SourceDBName -- confirm that is intended.
        target_table_name = row["SourceDBName"] + "." + row["TableName"]
        df.write.format(c_src_file_format).mode("overwrite").option("path", row["DestinationLocation"]).saveAsTable(target_table_name)

        _refresh_staging_table(row)

        # write audit details
        spoc_end_watermark = end_waterMark(row["SubSystem"], row["TableName"], row["WatermarkColumn"])
        activity_end_activity = get_row_processed_cnt(row["DestinationLocation"], audit_activityId)
      except Exception as e:
        error_msg = "ERROR: error in FULL loading for {subsystem} - {table} table ".format(table=row["TableName"], subsystem=row["SubSystem"]) + ":" + str(e)
        _log_activity_error(audit_activityId, error_msg)
finally:
  # Always close the audit batch, even if an unexpected error escapes the loop.
  close_batchId = end_batchId(audit_batchId)