In [0]:
%run ./CommonUtils

In [0]:
%run ./CreateOrReplaceTables

In [0]:
from pyspark.sql.functions import current_timestamp

In [0]:
class SCD1:
    """Upsert a landing-zone parquet extract into a target table (SCD type 1).

    A single MERGE overwrites every column of rows whose primary key already
    exists in the target and inserts the rows that do not.

    Relies on names supplied by the notebook environment: ``spark`` plus the
    helpers ``get_mnt``, ``get_catalog_name`` and ``crt_execute`` (presumably
    defined by the ``%run`` notebooks -- verify).
    """

    def __init__(self, entity_name, pk, ingestion_layer):
        """Store the merge configuration.

        Args:
            entity_name: Entity name; normalised with ``str.capitalize()``
                (first character upper-cased, the rest lower-cased).
            pk: Comma-separated primary-key column names.
            ingestion_layer: Layer identifier forwarded to
                ``get_catalog_name`` and ``crt_execute``.
        """
        self.entity_name = entity_name.capitalize()
        # Strip whitespace and drop empty fragments so "a, b" or a trailing
        # comma does not leak malformed column names into the generated SQL.
        self.pk = [col.strip() for col in pk.split(",") if col.strip()]
        self.ingestion_layer = ingestion_layer

    def create_src_view(self):
        """Load the landing parquet file and expose it as temp view ``source_tbl``.

        A ``processing_dttm`` column holding the current timestamp is added
        before the view is registered.

        Returns:
            Tuple of (source DataFrame, temp view name).
        """
        src_df = spark.read.parquet(f"{get_mnt('landing')}/{self.entity_name}/{self.entity_name}.parquet")
        src_df = src_df.withColumn("processing_dttm", current_timestamp())
        # createOrReplaceTempView is the supported replacement for the
        # deprecated registerTempTable; both register the same temp view.
        src_df.createOrReplaceTempView("source_tbl")
        return src_df, 'source_tbl'

    def check_or_create_table(self, src_df):
        """Ensure the target table exists and return its qualified name.

        The lookup name is the lower-cased entity name cut at the first
        ``'_inc'``. When no table with that name is listed in the schema
        returned by ``get_catalog_name``, ``crt_execute`` is called with
        ``truncate_flag=True`` and the source DataFrame.

        Args:
            src_df: Source DataFrame handed to ``crt_execute``.

        Returns:
            Target table name of the form ``hive_metastore.<schema>.<entity>``.
        """
        schema_name = get_catalog_name(self.entity_name, self.ingestion_layer)
        lookup_name = self.entity_name.lower().split('_inc')[0]
        table_count = (
            spark.sql(f"show tables in {schema_name}")
            .select('tableName')
            .filter(f"tableName = '{lookup_name}'")
            .count()
        )
        print(table_count)
        if table_count == 0:
            print("Table doesnt exist. Creating schema")
            crt_execute(self.entity_name, self.ingestion_layer, truncate_flag=True, src_df=src_df)

        target_tbl = f"hive_metastore.{get_catalog_name(self.entity_name,self.ingestion_layer)}.{self.entity_name.split('_inc')[0]}"
        return target_tbl

    def execute_scd_1(self):
        """Build and run the MERGE statement, then display its result."""
        src_df, source_tbl = self.create_src_view()
        target_tbl = self.check_or_create_table(src_df)

        columns = src_df.columns
        # Join condition: equality on every primary-key column.
        pk_cond = " AND ".join([f"target.{pk_col} = source.{pk_col}" for pk_col in self.pk])
        # Matched rows are overwritten with every source column.
        update_cond = " , ".join([f"target.{u_col} = source.{u_col}" for u_col in columns])
        insert_cols = ", ".join(columns)
        insert_values = ", ".join([f"source.{icol}" for icol in columns])
        query = f'''
        MERGE INTO {target_tbl} AS target
        USING {source_tbl} AS source
        ON {pk_cond}
        WHEN MATCHED
            THEN 
                UPDATE SET 
                    {update_cond}
        WHEN NOT MATCHED THEN
            INSERT ({insert_cols})
            VALUES ({insert_values});

        '''

        spark.sql(query).display()
# SCD1('Arancione_incremental','ArancioneId',ingestion_layer='bronze').execute_scd_1()

In [0]:
class SCD2:
    """Merge the bronze incremental table for an entity into its target table.

    The source is read from ``hive_metastore.bronze_incremental_schema``,
    extended with NULL ``start_date`` / ``end_date`` timestamp columns and a
    fresh ``processing_dttm``, then merged into the target on the primary key.

    NOTE(review): the MERGE overwrites changed rows in place instead of
    closing the previous version and inserting a new one, and
    ``start_date`` / ``end_date`` are always written as NULL -- confirm this
    is the intended SCD type 2 behaviour.

    Relies on names supplied by the notebook environment: ``spark`` plus the
    helpers ``get_catalog_name`` and ``crt_execute`` (presumably defined by
    the ``%run`` notebooks -- verify).
    """

    def __init__(self, entity_name, pk, ingestion_layer):
        """Store the merge configuration.

        Args:
            entity_name: Entity name; normalised with ``str.capitalize()``
                (first character upper-cased, the rest lower-cased).
            pk: Comma-separated primary-key column names.
            ingestion_layer: Layer identifier forwarded to
                ``get_catalog_name`` and ``crt_execute``.
        """
        self.entity_name = entity_name.capitalize()
        # Strip whitespace and drop empty fragments so "a, b" or a trailing
        # comma does not leak malformed column names into the generated SQL
        # (and so key columns are reliably excluded from change detection).
        self.pk = [col.strip() for col in pk.split(",") if col.strip()]
        self.ingestion_layer = ingestion_layer

    def table_check_count(self, ingestion_layer):
        """Count tables named after this entity in the given layer's schema.

        The lookup name is the lower-cased entity name cut at the first
        ``'_inc'``.

        Args:
            ingestion_layer: Layer identifier passed to ``get_catalog_name``.

        Returns:
            Number of matching rows from ``show tables`` (0 when absent).
        """
        schema_name = get_catalog_name(self.entity_name, ingestion_layer)
        lookup_name = self.entity_name.lower().split('_inc')[0]
        table_count = (
            spark.sql(f"show tables in {schema_name}")
            .select('tableName')
            .filter(f"tableName = '{lookup_name}'")
            .count()
        )
        return table_count

    def create_src_view(self):
        """Read the bronze table and expose it as temp view ``source_tbl``.

        ``lastUpdateDate`` is renamed to ``sourceLastUpdateDate``. When the
        silver table already exists, only rows carrying the maximum
        ``processing_dttm`` are kept. NULL ``start_date`` / ``end_date``
        timestamp columns are added and ``processing_dttm`` is replaced with
        the current timestamp.

        Returns:
            Tuple of (source DataFrame, temp view name).
        """
        # Imported locally because the visible file only imports
        # current_timestamp; this keeps the method independent of whatever
        # the %run notebooks happen to bring into scope.
        from pyspark.sql.functions import lit
        from pyspark.sql.types import TimestampType

        bronze_name = self.entity_name.lower().split('_inc')[0]
        # Fix: the rename must be chained onto the DataFrame with '.'; the
        # original used ',' which built a tuple and raised NameError.
        src_df = spark.read.table(
            f"hive_metastore.bronze_incremental_schema.{bronze_name}"
        ).withColumnRenamed('lastUpdateDate', 'sourceLastUpdateDate')

        # NOTE(review): the existence check is hard-coded to 'silver' rather
        # than self.ingestion_layer -- confirm this is intentional.
        if self.table_check_count('silver') != 0:
            max_ts = src_df.selectExpr('max(processing_dttm)').collect()[0][0]
            src_df = src_df.filter(f'processing_dttm = "{max_ts}"')

        src_df = src_df.withColumn("start_date", lit(None).cast(TimestampType()))
        src_df = src_df.withColumn("end_date", lit(None).cast(TimestampType()))
        src_df = src_df.withColumn("processing_dttm", current_timestamp())
        # createOrReplaceTempView is the supported replacement for the
        # deprecated registerTempTable; both register the same temp view.
        src_df.createOrReplaceTempView("source_tbl")
        return src_df, 'source_tbl'

    def check_or_create_table(self, src_df):
        """Ensure the target table exists and return its qualified name.

        When the silver-layer lookup finds no table, ``crt_execute`` is
        called with ``truncate_flag=True`` and the source DataFrame.

        Args:
            src_df: Source DataFrame handed to ``crt_execute``.

        Returns:
            Target table name of the form ``hive_metastore.<schema>.<entity>``.
        """
        # NOTE(review): existence is checked against 'silver' while creation
        # and the returned name use self.ingestion_layer -- confirm.
        if self.table_check_count('silver') == 0:
            print("Table doesnt exist. Creating schema")
            crt_execute(self.entity_name, self.ingestion_layer, truncate_flag=True, src_df=src_df)

        target_tbl = f"hive_metastore.{get_catalog_name(self.entity_name,self.ingestion_layer)}.{self.entity_name.split('_inc')[0]}"
        return target_tbl

    def execute_scd_2(self):
        """Build and run the MERGE statement, then display its result."""
        src_df, source_tbl = self.create_src_view()
        target_tbl = self.check_or_create_table(src_df)

        columns = src_df.columns
        # Join condition: equality on every primary-key column.
        pk_cond = " AND ".join([f"target.{pk_col} = source.{pk_col}" for pk_col in self.pk])
        update_cond = " , ".join([f"target.{u_col} = source.{u_col}" for u_col in columns])
        # A matched row is updated only when some non-key column other than
        # processing_dttm differs.
        # NOTE(review): '!=' yields NULL when either side is NULL, so changes
        # to or from NULL are not detected -- consider a null-safe comparison.
        scd2_cond = " OR ".join([f"target.{u_col} != source.{u_col}" for u_col in columns if u_col not in ['processing_dttm']+self.pk])
        insert_cols = ", ".join(columns)
        insert_values = ", ".join([f"source.{icol}" for icol in columns])
        query = f'''
        MERGE INTO {target_tbl} AS target
        USING {source_tbl} AS source
        ON {pk_cond}
        WHEN MATCHED AND ({scd2_cond})
            THEN 
                UPDATE SET 
                    {update_cond}
        WHEN NOT MATCHED THEN
            INSERT ({insert_cols})
            VALUES ({insert_values});

        '''

        spark.sql(query).display()

# SCD2('Celeste_incremental','TransactionId',ingestion_layer='silver').execute_scd_2()