In [0]:
class SCD1:
    """Type-1 slowly-changing-dimension loader.

    Reads the landing parquet extract for an entity and MERGEs it into the
    target table. Rows whose primary key already exists are overwritten in
    place; rows with new keys are inserted. No history is kept.

    Depends on notebook-level names defined outside this class: ``spark``,
    ``get_mnt``, ``get_catalog_name``, ``crt_execute`` and
    ``current_timestamp``.
    """

    def __init__(self, entity_name, pk, ingestion_layer):
        """Configure the loader.

        Args:
            entity_name: Entity name. It is normalised with ``str.capitalize``
                (first character upper-case, the rest lower-case), and that
                form is used for the landing path and the target table name.
            pk: Comma-separated primary-key column names, e.g. ``"Id"`` or
                ``"Id, Region"``. Whitespace around each name is stripped and
                empty entries are ignored.
            ingestion_layer: Layer name forwarded to ``get_catalog_name`` and
                ``crt_execute``.

        Raises:
            ValueError: If ``pk`` contains no column names.
        """
        self.entity_name = entity_name.capitalize()
        self.pk = [c.strip() for c in pk.split(",") if c.strip()]
        if not self.pk:
            # An empty key list would otherwise only surface later, as a
            # malformed ON clause in the generated MERGE statement.
            raise ValueError(f"No primary-key columns given in pk={pk!r}")
        self.ingestion_layer = ingestion_layer

    def create_src_view(self):
        """Read the landing parquet for this entity and register it as a temp view.

        A ``processing_dttm`` column holding the current timestamp is added.

        Returns:
            tuple: ``(src_df, "source_tbl")``, i.e. the source DataFrame and
            the name of the temp view registered for it.
        """
        src_path = f"{get_mnt('landing')}/{self.entity_name}/{self.entity_name}.parquet"
        src_df = spark.read.parquet(src_path).withColumn("processing_dttm", current_timestamp())
        # createOrReplaceTempView is the non-deprecated equivalent of registerTempTable.
        src_df.createOrReplaceTempView("source_tbl")
        return src_df, "source_tbl"

    def check_or_create_table(self, src_df):
        """Ensure the target table exists, creating it from ``src_df`` if not.

        Args:
            src_df: Source DataFrame, passed to ``crt_execute`` when the table
                has to be created.

        Returns:
            str: Fully qualified target table name
            (``hive_metastore.<schema>.<entity up to "_inc">``).
        """
        table_count = (
            spark.sql(f"show tables in {get_catalog_name(self.entity_name, self.ingestion_layer)}")
            .select("tableName")
            .filter(f"tableName = '{self.entity_name.lower().split('_inc')[0]}'")
            .count()
        )
        print(table_count)
        if table_count == 0:
            print("Table doesnt exist. Creating schema")
            crt_execute(self.entity_name, self.ingestion_layer, truncate_flag=True, src_df=src_df)

        return (
            f"hive_metastore.{get_catalog_name(self.entity_name, self.ingestion_layer)}"
            f".{self.entity_name.split('_inc')[0]}"
        )

    @staticmethod
    def _build_merge_query(target_tbl, source_tbl, pk_cols, columns):
        """Build the SCD1 MERGE statement (a pure string helper, no Spark needed).

        Args:
            target_tbl: Table to merge into.
            source_tbl: Table or view to merge from.
            pk_cols: Key columns used in the ON clause.
            columns: Columns that are updated on match and inserted otherwise.

        Returns:
            str: The MERGE INTO statement.
        """
        pk_cond = " AND ".join(f"target.{c} = source.{c}" for c in pk_cols)
        update_cond = " , ".join(f"target.{c} = source.{c}" for c in columns)
        insert_cols = ", ".join(columns)
        insert_values = ", ".join(f"source.{c}" for c in columns)
        return f'''
        MERGE INTO {target_tbl} AS target
        USING {source_tbl} AS source
        ON {pk_cond}
        WHEN MATCHED
            THEN 
                UPDATE SET 
                    {update_cond}
        WHEN NOT MATCHED THEN
            INSERT ({insert_cols})
            VALUES ({insert_values});

        '''

    def execute_scd_1(self):
        """Load the landing extract and MERGE it into the target table.

        Every source column is written. Matched rows have all columns
        overwritten and unmatched rows are inserted. The MERGE result is shown
        with ``display()``.
        """
        src_df, source_tbl = self.create_src_view()
        target_tbl = self.check_or_create_table(src_df)
        query = self._build_merge_query(target_tbl, source_tbl, self.pk, src_df.columns)
        spark.sql(query).display()
# SCD1('Arancione_incremental','ArancioneId',ingestion_layer='bronze').execute_scd_1()

In [0]:
class SCD2:
    """Type-2 slowly-changing-dimension loader.

    Compares the latest incremental batch with the open rows
    (``end_date IS NULL``) of the target table. Changed rows are closed and
    new versions are inserted. The resulting full row set is then MERGEd back
    into the target keyed on the primary key plus ``start_date``.

    Depends on notebook-level names defined outside this class: ``spark``,
    ``get_catalog_name``, ``crt_execute``, ``current_timestamp``, ``lit``,
    ``col``, ``md5``, ``concat_ws`` and ``TimestampType``.
    """

    # Placeholder substituted for NULLs before hashing, so that a NULL keeps
    # its column position in the hashed string. A NUL character is assumed
    # not to occur in real data.
    _HASH_NULL_MARKER = "\u0000"

    def __init__(self, entity_name, pk, ingestion_layer):
        """Configure the loader.

        Args:
            entity_name: Entity name. It is normalised with ``str.capitalize``
                (first character upper-case, the rest lower-case).
            pk: Comma-separated business-key column names, e.g. ``"Id"`` or
                ``"Id, Region"``. Whitespace around each name is stripped
                (otherwise ``source_df[" Region"]`` would fail to resolve)
                and empty entries are ignored.
            ingestion_layer: Layer name forwarded to ``get_catalog_name`` and
                ``crt_execute``.

        Raises:
            ValueError: If ``pk`` contains no column names.
        """
        self.entity_name = entity_name.capitalize()
        self.pk = [c.strip() for c in pk.split(",") if c.strip()]
        if not self.pk:
            raise ValueError(f"No primary-key columns given in pk={pk!r}")
        self.ingestion_layer = ingestion_layer

    def table_check_count(self):
        """Return how many tables in the target schema match this entity (0 or 1)."""
        table_count = (
            spark.sql(
                f"show tables in {get_catalog_name(self.entity_name,self.ingestion_layer)}"
            )
            .select("tableName")
            .filter(f"tableName = '{self.entity_name.lower().split('_inc')[0]}'")
            .count()
        )
        return table_count

    def check_or_create_table(self, src_df):
        """Ensure the target table exists, creating it from ``src_df`` if not.

        When the table is missing, null ``start_date``, ``end_date`` and
        ``processing_dttm`` timestamp columns are added to ``src_df`` so that
        ``crt_execute`` creates the table with the SCD2 metadata columns.

        Returns:
            str: Fully qualified target table name.
        """
        table_count = self.table_check_count()
        print(table_count)
        if table_count == 0:
            print("Table doesnt exist. Creating schema")
            src_df = src_df.withColumn("start_date", lit(None).cast(TimestampType()))
            src_df = src_df.withColumn("end_date", lit(None).cast(TimestampType()))
            src_df = src_df.withColumn("processing_dttm", lit(None).cast(TimestampType()))
            crt_execute(self.entity_name, self.ingestion_layer, truncate_flag=True, src_df=src_df)

        return (
            f"hive_metastore.{get_catalog_name(self.entity_name, self.ingestion_layer)}"
            f".{self.entity_name.split('_inc')[0]}"
        )

    def check_columns_presence(self, source_df, target_df, metadata_cols):
        """
        Check that every non-metadata target column is present in the source.

        Args:
            source_df (pyspark.sql.DataFrame): Source DataFrame.
            target_df (pyspark.sql.DataFrame): Target DataFrame.
            metadata_cols (list): Target-only columns (e.g. SCD2 dates) that
                the source is not expected to carry.

        Raises:
            Exception: If columns are missing in the source DataFrame.

        Returns:
            None
        """
        cols_missing = {c for c in target_df.columns if c not in source_df.columns} - set(metadata_cols)
        if cols_missing:
            raise Exception(f"Cols missing in source DataFrame: {cols_missing}")

    def apply_hash_and_alias(self, source_df, target_df, metadata_cols) -> "tuple[DataFrame, DataFrame]":
        """
        Add a ``hash_value`` row fingerprint to both DataFrames and alias them.

        The hash covers the target's non-metadata columns, so it is comparable
        between source and target rows.

        Args:
            source_df (pyspark.sql.DataFrame): Source DataFrame.
            target_df (pyspark.sql.DataFrame): Target DataFrame.
            metadata_cols (list): Metadata columns excluded from the hash.

        Returns:
            tuple: ``(source_df, target_df)`` aliased as ``"source_df"`` and
            ``"target_df"`` respectively.
        """
        from pyspark.sql.functions import coalesce

        tgt_cols = [c for c in target_df.columns if c not in metadata_cols]

        # concat_ws silently skips NULLs, so (NULL, 'x') and ('x', NULL) would
        # produce the same string and the change would go undetected. Each
        # column is cast to string (the same cast concat_ws applies implicitly)
        # and NULLs are replaced with a marker so positions are preserved.
        hash_expr = md5(
            concat_ws(
                "|",
                *[coalesce(col(c).cast("string"), lit(self._HASH_NULL_MARKER)) for c in tgt_cols],
            )
        )

        source_df = source_df.withColumn("hash_value", hash_expr).alias("source_df")
        target_df = target_df.withColumn("hash_value", hash_expr).alias("target_df")
        return source_df, target_df

    def create_src_df(self):
        """Read the incremental bronze table for this entity.

        ``lastUpdateDate`` is renamed to ``sourceLastUpdateDate`` when present.
        If the target table already exists, only the rows of the most recent
        ``processing_dttm`` batch are kept. ``processing_dttm`` is dropped
        from the result in all cases.
        """
        src_df = spark.read.table(
            f"hive_metastore.bronze_incremental_schema.{self.entity_name.lower().split('_inc')[0]}")
        if 'lastUpdateDate' in src_df.columns:
            src_df = src_df.withColumnRenamed("lastUpdateDate", "sourceLastUpdateDate")
        if self.table_check_count() != 0:
            max_ts = src_df.selectExpr("max(processing_dttm)").collect()[0][0]
            # Compare against a typed literal rather than interpolating
            # str(max_ts) into SQL. The string form is re-parsed in the Spark
            # session time zone, which may differ from the driver's, and a
            # None max would become the literal "None".
            src_df = src_df.filter(col("processing_dttm") == lit(max_ts))
        return src_df.drop('processing_dttm')

    def scd_2(self, source_df, tgt_df, join_keys, metadata_cols=None) -> "DataFrame":
        """Compute the full post-load SCD2 row set for the target table.

        Args:
            source_df: Latest source batch (without SCD2 metadata columns).
            tgt_df: Current contents of the target table.
            join_keys: Business-key columns.
            metadata_cols: Target-only columns excluded from comparison.
                Defaults to ``['start_date', 'end_date', 'processing_dttm']``.

        Returns:
            DataFrame with the target's columns, made up of:
            - open rows that are unchanged or absent from the source;
            - new versions of new or changed keys (fresh ``start_date``,
              NULL ``end_date``);
            - the previous versions of changed keys (``end_date`` set now);
            - already-closed history rows, untouched.
        """
        if metadata_cols is None:
            metadata_cols = ['start_date', 'end_date', 'processing_dttm']
        tgt_cols = list(tgt_df.columns)
        self.check_columns_presence(source_df, tgt_df, metadata_cols)

        # Closed history is passed through as-is; only open rows are compared.
        tgt_untouched_df = tgt_df.filter('end_date is not null')
        target_df = tgt_df.filter('end_date is null')

        source_df, target_df = self.apply_hash_and_alias(source_df, target_df, metadata_cols)

        join_cond = [source_df[k] == target_df[k] for k in join_keys]
        # Keys present in the source only: brand-new records.
        new_df = source_df.join(target_df, join_cond, 'left_anti')
        base_df = target_df.join(source_df, join_cond, 'left')

        # Open target rows with no source counterpart, or with an identical hash.
        unchanged_filter_expr = " AND ".join(f"source_df.{k} IS NULL" for k in join_keys)
        unchanged_df = base_df.filter(
            f"({unchanged_filter_expr}) OR (source_df.hash_value = target_df.hash_value)"
        ).select("target_df.*")

        # Matched keys whose content differs.
        delta_filter_expr = " and ".join(f"source_df.{k} IS NOT NULL" for k in join_keys)
        updated_df = base_df.filter(
            f"{delta_filter_expr} AND source_df.hash_value != target_df.hash_value"
        )

        updated_new_df = updated_df.select("source_df.*")
        obsolete_df = updated_df.select("target_df.*").withColumn("end_date", current_timestamp())

        delta_df = (
            new_df.unionByName(updated_new_df)
            .withColumn("start_date", current_timestamp())
            .withColumn("end_date", lit(None).cast("timestamp"))
            .withColumn("processing_dttm", current_timestamp())
        )

        return (
            unchanged_df.select(tgt_cols)
            .unionByName(delta_df.select(tgt_cols))
            .unionByName(obsolete_df.select(tgt_cols))
            .unionByName(tgt_untouched_df.select(tgt_cols))
        )

    def _build_merge_query(self, target_tbl, source_tbl, columns):
        """Build the SCD2 MERGE statement keyed on ``self.pk + ['start_date']``.

        Matched rows that the source closes (source ``end_date`` set while the
        target's is NULL) get ``end_date = CURRENT_TIMESTAMP``. Unmatched rows
        are inserted with all ``columns``.
        """
        merge_keys = self.pk + ['start_date']
        # The source re-emits every existing target row. Null-safe equality
        # (<=>) lets rows with a NULL key or start_date match themselves; with
        # '=' they would be re-inserted as duplicates on every run.
        pk_cond = " AND ".join(f"target.{k} <=> source.{k}" for k in merge_keys)
        insert_cols = ", ".join(columns)
        insert_values = ", ".join(f"source.{c}" for c in columns)
        return f'''
        MERGE INTO {target_tbl} AS target
        USING {source_tbl} AS source
        ON {pk_cond}
        WHEN MATCHED AND (source.end_date is not NULL and target.end_date is NULL)
            THEN 
                UPDATE SET 
                    target.end_date = CURRENT_TIMESTAMP
        WHEN NOT MATCHED THEN
            INSERT ({insert_cols})
            VALUES ({insert_values});

        '''

    def execute_scd_2(self):
        """Run the full SCD2 load: read source, compute changes, MERGE into target.

        The MERGE result is shown with ``display()``.
        """
        source_df = self.create_src_df()
        silver_table_name = self.check_or_create_table(source_df)

        target_df = spark.read.table(silver_table_name)
        scd2_df = self.scd_2(source_df, target_df, join_keys=self.pk, metadata_cols=None)

        source_tbl = 'scd2_table'
        # createOrReplaceTempView is the non-deprecated equivalent of registerTempTable.
        scd2_df.createOrReplaceTempView(source_tbl)

        query = self._build_merge_query(silver_table_name, source_tbl, scd2_df.columns)
        spark.sql(query).display()

In [0]:
# class SCD2:
#     def __init__(self, entity_name, pk, ingestion_layer):
#         self.entity_name = entity_name.capitalize()
#         self.pk = pk.split(",")
#         self.ingestion_layer = ingestion_layer

#     def table_check_count(self):
#         table_count = (
#             spark.sql(
#                 f"show tables in {get_catalog_name(self.entity_name,self.ingestion_layer)}"
#             )
#             .select("tableName")
#             .filter(f"tableName = '{self.entity_name.lower().split('_inc')[0]}'")
#             .count()
#         )
#         return table_count

#     def create_src_view(self):
#         src_df = spark.read.table(
#             f"hive_metastore.bronze_incremental_schema.{self.entity_name.lower().split('_inc')[0]}"
#         ).withColumnRenamed("lastUpdateDate", "sourceLastUpdateDate")
#         if self.table_check_count() != 0:
#             max_ts = src_df.selectExpr("max(processing_dttm)").collect()[0][0]
#             src_df = src_df.filter(f'processing_dttm = "{max_ts}"')
#         src_df = src_df.drop('processing_dttm')
#         return src_df

#     def check_or_create_table(self, src_df):
#         if self.table_check_count() == 0:
#             print("Table doesnt exist. Creating schema")
#             crt_execute(
#                 self.entity_name,
#                 self.ingestion_layer,
#                 truncate_flag=True,
#                 src_df=src_df,
#             )

#         target_tbl = f"hive_metastore.{get_catalog_name(self.entity_name,self.ingestion_layer)}.{self.entity_name.split('_inc')[0]}"
#         return target_tbl

#     def execute_scd_2(self):
#         src_df, source_tbl = self.create_src_view()
#         # print("=======printing target before merge======")
#         # display(spark.read.table(f"hive_metastore.{get_catalog_name(self.entity_name,self.ingestion_layer)}.{self.entity_name.split('_inc')[0]}"))
#         target_tbl = self.check_or_create_table(src_df)
#         pk_cond = " AND ".join(
#             [f"target.{pk_col} = source.{pk_col}" for pk_col in self.pk]
#         )
#         update_cond = " , ".join(
#             [f"target.{u_col} = source.{u_col}" for u_col in src_df.columns]
#         )
#         scd2_cond = " OR ".join(
#             [
#                 f"target.{u_col} != source.{u_col}"
#                 for u_col in src_df.columns
#                 if u_col not in ["processing_dttm","start_date",'end_date'] + self.pk
#             ]
#         )
#         # print(update_cond)
#         insert_cols = ", ".join(
#             [i for i in src_df.columns if i not in ("start_date", "end_date")]
#         )
#         insert_values = ", ".join(
#             [
#                 f"{i}"
#                 for i in src_df.columns
#                 if i not in ("start_date", "end_date")
#             ]
#         )

#         query = f"""
#         MERGE INTO {target_tbl} AS target
#         USING {source_tbl} AS source
#         ON {pk_cond}
#         WHEN MATCHED AND (target.end_date is NULL) AND ({scd2_cond}) THEN
#             UPDATE SET 
#         target.end_date = CURRENT_TIMESTAMP
    
#         WHEN NOT MATCHED THEN
#             INSERT ({insert_cols}, target.start_date, target.end_date)
#                 VALUES ({insert_values}, CURRENT_TIMESTAMP, NULL)
#         """

#         spark.sql(query).display()

#         tgt_df = spark.read.table(f"hive_metastore.silver_incremental_schema.{self.entity_name.split('_inc')[0]}").filter('end_date is null')

#         tgt_df.registerTempTable('target_table')
#         src_df.registerTempTable('source_table')

#         ingest_df = spark.sql(
#             f'''
#             select source.* from source_table source inner join target_table target on {pk_cond} and
#             ({scd2_cond})
#             '''
#         )
#         ingest_df = ingest_df.withColumn('start_date',current_timestamp())
#         ingest_df = ingest_df.union(tgt_df).distinct()
#         ingest_df.write.mode("append").saveAsTable(f"hive_metastore.silver_incremental_schema.{self.entity_name.split('_inc')[0]}")

# # SCD2("Arancione_incremental", "ArancioneID", ingestion_layer="silver").execute_scd_2()