In [0]:
%run ./CommonUtils

In [0]:
%run ./CreateOrReplaceTables

In [0]:
from pyspark.sql.functions import current_timestamp

In [0]:
class SCD1:
    """SCD type 1 loader: upsert a landing-zone parquet extract into its target table.

    Source rows whose primary-key columns match a target row overwrite that row
    (every source column is copied); unmatched source rows are inserted. No history
    is kept.

    Example:
        SCD1('Celeste_incremental', 'TransactionId', ingestion_layer='bronze').execute_scd_1()
    """

    def __init__(self, entity_name, pk, ingestion_layer):
        """
        Args:
            entity_name: Entity / landing folder name, e.g. 'Celeste_incremental'. Normalised
                with str.capitalize() (first letter upper-case, the rest lower-case).
            pk: Comma-separated primary-key column names, e.g. 'TransactionId' or 'Id, Region'.
            ingestion_layer: Layer name forwarded to get_catalog_name / crt_execute, e.g. 'bronze'.

        Raises:
            ValueError: if `pk` names no columns (the MERGE would have an empty ON clause).
        """
        self.entity_name = entity_name.capitalize()
        # Strip whitespace so 'Id, Region' yields ['Id', 'Region'], and drop empty entries.
        self.pk = [col.strip() for col in pk.split(",") if col.strip()]
        if not self.pk:
            raise ValueError(f"pk must name at least one column, got {pk!r}")
        self.ingestion_layer = ingestion_layer

    def _target_table_name(self):
        """Target table name: the entity name cut at its first '_inc' ('Celeste_incremental' -> 'Celeste').

        NOTE(review): this also truncates names that merely contain '_inc' (e.g. 'Sales_income' -> 'Sales').
        """
        return self.entity_name.split('_inc')[0]

    def create_src_view(self):
        """Read this entity's landing parquet, stamp it, and register it as temp view 'source_tbl'.

        Returns:
            (src_df, 'source_tbl'): the source DataFrame with an added `processing_dttm`
            column (current_timestamp()), and the name of the temp view registered for it.
        """
        src_path = f"{get_mnt('landing')}/{self.entity_name}/{self.entity_name}.parquet"
        src_df = spark.read.parquet(src_path).withColumn("processing_dttm", current_timestamp())
        # createOrReplaceTempView is the non-deprecated equivalent of registerTempTable.
        src_df.createOrReplaceTempView("source_tbl")
        return src_df, "source_tbl"

    @staticmethod
    def _table_listed(rows, table):
        """True if a non-temporary entry of SHOW TABLES output (`rows`) is named `table`, ignoring case."""
        wanted = table.lower()
        return any(not row.isTemporary and row.tableName.lower() == wanted for row in rows)

    def check_or_create_table(self, src_df):
        """Ensure the target table exists, creating it with crt_execute only when it is missing.

        Args:
            src_df: Source DataFrame, forwarded to crt_execute (presumably to derive the
                target schema; confirm in CreateOrReplaceTables).

        Returns:
            Fully qualified target table name: 'hive_metastore.<schema>.<table>'.
        """
        schema = get_catalog_name(self.entity_name, self.ingestion_layer)
        table = self._target_table_name()
        listed = spark.sql(f"show tables in {schema}").collect()
        # Metastore table names are normally stored lower-case, while `table` is capitalised,
        # so the lookup must be case-insensitive or an existing table is never found.
        if not self._table_listed(listed, table):
            crt_execute(self.entity_name, self.ingestion_layer, truncate_flag=True, src_df=src_df)
        return f"hive_metastore.{schema}.{table}"

    @staticmethod
    def _quote_ident(name):
        """Back-tick quote a column name for Spark SQL, escaping any embedded back-ticks."""
        return "`" + name.replace("`", "``") + "`"

    def _build_merge_query(self, target_tbl, source_tbl, columns):
        """Build the SCD1 MERGE: on a key match update every column, otherwise insert the row.

        Args:
            target_tbl: Fully qualified target table name.
            source_tbl: Name of the source temp view.
            columns: Source column names to update/insert.
        """
        q = self._quote_ident
        on_clause = " AND ".join(f"target.{q(c)} = source.{q(c)}" for c in self.pk)
        set_clause = ", ".join(f"target.{q(c)} = source.{q(c)}" for c in columns)
        insert_cols = ", ".join(q(c) for c in columns)
        insert_vals = ", ".join(f"source.{q(c)}" for c in columns)
        return (
            f"MERGE INTO {target_tbl} AS target\n"
            f"USING {source_tbl} AS source\n"
            f"ON {on_clause}\n"
            f"WHEN MATCHED THEN UPDATE SET {set_clause}\n"
            f"WHEN NOT MATCHED THEN INSERT ({insert_cols}) VALUES ({insert_vals})"
        )

    def execute_scd_1(self):
        """Run the SCD1 MERGE of the landing extract into the target table and display the MERGE result."""
        src_df, source_tbl = self.create_src_view()
        target_tbl = self.check_or_create_table(src_df)
        spark.sql(self._build_merge_query(target_tbl, source_tbl, src_df.columns)).display()

# SCD1('Celeste_incremental','TransactionId',ingestion_layer='bronze').execute_scd_1()

In [0]:
class SCD2:
    """Loader that MERGEs a bronze table into its target table, stamping load timestamps.

    Incremental entities (name ending in '_incremental') are read from
    `bronze_incremental_schema` and get `start_date` / `end_date` columns; others are read
    from `bronze_schema` and get a `processing_dttm` column.

    NOTE(review): despite the class name, the MERGE has SCD type-1 semantics. A matched
    target row is overwritten in place (all columns, including start_date/end_date), so no
    history is retained. A true SCD2 load would close the matched row and insert the new
    version. TODO: confirm the intended behaviour against the target table's schema.
    """

    def __init__(self, entity_name, pk, ingestion_layer):
        """
        Args:
            entity_name: Entity name, e.g. 'Celeste_incremental'. Normalised with
                str.capitalize() (first letter upper-case, the rest lower-case).
            pk: Comma-separated primary-key column names, e.g. 'TransactionId' or 'Id, Region'.
            ingestion_layer: Layer name forwarded to get_catalog_name / crt_execute.

        Raises:
            ValueError: if `pk` names no columns (the MERGE would have an empty ON clause).
        """
        self.entity_name = entity_name.capitalize()
        # Strip whitespace so 'Id, Region' yields ['Id', 'Region'], and drop empty entries.
        self.pk = [col.strip() for col in pk.split(",") if col.strip()]
        if not self.pk:
            raise ValueError(f"pk must name at least one column, got {pk!r}")
        self.ingestion_layer = ingestion_layer

    def _target_table_name(self):
        """Table name: the entity name cut at its first '_inc' ('Celeste_incremental' -> 'Celeste').

        NOTE(review): this also truncates names that merely contain '_inc' (e.g. 'Sales_income' -> 'Sales').
        """
        return self.entity_name.split('_inc')[0]

    def create_src_view(self):
        """Read the bronze table for this entity, stamp it, and register it as temp view 'source_tbl'.

        Returns:
            (src_df, 'source_tbl'): the stamped source DataFrame and the temp view name.
        """
        is_incremental = self.entity_name.endswith('_incremental')
        source_schema = "bronze_incremental_schema" if is_incremental else "bronze_schema"
        src_df = spark.read.table(f"{source_schema}.{self._target_table_name()}")
        if is_incremental:
            # NOTE(review): end_date is stamped with the same value as start_date, so nothing
            # distinguishes the current version of a row; confirm this is intended.
            src_df = (
                src_df.withColumn("start_date", current_timestamp())
                      .withColumn("end_date", current_timestamp())
            )
        else:
            src_df = src_df.withColumn("processing_dttm", current_timestamp())
        # createOrReplaceTempView is the non-deprecated equivalent of registerTempTable.
        src_df.createOrReplaceTempView("source_tbl")
        return src_df, "source_tbl"

    @staticmethod
    def _table_listed(rows, table):
        """True if a non-temporary entry of SHOW TABLES output (`rows`) is named `table`, ignoring case."""
        wanted = table.lower()
        return any(not row.isTemporary and row.tableName.lower() == wanted for row in rows)

    def check_or_create_table(self, src_df):
        """Ensure the target table exists, creating it with crt_execute only when it is missing.

        Args:
            src_df: Source DataFrame, forwarded to crt_execute (presumably to derive the
                target schema; confirm in CreateOrReplaceTables).

        Returns:
            Fully qualified target table name: 'hive_metastore.<schema>.<table>'.
        """
        schema = get_catalog_name(self.entity_name, self.ingestion_layer)
        table = self._target_table_name()
        listed = spark.sql(f"show tables in {schema}").collect()
        # Metastore table names are normally stored lower-case, while `table` is capitalised.
        # A case-sensitive lookup never finds the table, which makes crt_execute
        # (truncate_flag=True) run on every load.
        if not self._table_listed(listed, table):
            crt_execute(self.entity_name, self.ingestion_layer, truncate_flag=True, src_df=src_df)
        return f"hive_metastore.{schema}.{table}"

    @staticmethod
    def _quote_ident(name):
        """Back-tick quote a column name for Spark SQL, escaping any embedded back-ticks."""
        return "`" + name.replace("`", "``") + "`"

    def _build_merge_query(self, target_tbl, source_tbl, columns):
        """Build the MERGE: on a key match update every column, otherwise insert the row.

        Args:
            target_tbl: Fully qualified target table name.
            source_tbl: Name of the source temp view.
            columns: Source column names to update/insert.
        """
        q = self._quote_ident
        on_clause = " AND ".join(f"target.{q(c)} = source.{q(c)}" for c in self.pk)
        set_clause = ", ".join(f"target.{q(c)} = source.{q(c)}" for c in columns)
        insert_cols = ", ".join(q(c) for c in columns)
        insert_vals = ", ".join(f"source.{q(c)}" for c in columns)
        return (
            f"MERGE INTO {target_tbl} AS target\n"
            f"USING {source_tbl} AS source\n"
            f"ON {on_clause}\n"
            f"WHEN MATCHED THEN UPDATE SET {set_clause}\n"
            f"WHEN NOT MATCHED THEN INSERT ({insert_cols}) VALUES ({insert_vals})"
        )

    def execute_scd_2(self):
        """Run the MERGE of the bronze source into the target table and display the MERGE result."""
        src_df, source_tbl = self.create_src_view()
        target_tbl = self.check_or_create_table(src_df)
        spark.sql(self._build_merge_query(target_tbl, source_tbl, src_df.columns)).display()

In [0]:
# Scratch check: splitting a single-column pk string on ',' yields a one-element list (['TransactionId']).
'TransactionId'.split(',')