In [0]:
import json
from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.functions import md5, concat_ws, col, current_date, lit


In [0]:
import json
from pyspark.sql import SparkSession
from delta.tables import DeltaTable

class MergePipeline:
    """Load a staged dataset and merge it into a Delta table.

    Behaviour is driven by a JSON payload. Keys read by this class:
    ``StagingType``, ``Path``, ``CsvPath``, ``DatasetName``, ``SchemaName``,
    ``LoadType``, ``UpdateType``, ``PrimaryKeyFields`` and ``PartitionFields``.
    """

    # Trimmed PrimaryKeyFields values that mean "no key configured".
    _BLANK_KEY_VALUES = ('', 'N/A', 'null')

    def __init__(self, 
                 payload: str, 
                 run_id: str, 
                 adls_container_name: str = "nipun", 
                 adls_storage_account: str = "delearningstdfssandbox"):
        """Parse the payload and prepare the Spark session and storage prefix.

        Args:
            payload: JSON document describing the load (see class docstring).
            run_id: Identifier of the run; lower-cased and appended to the
                parquet staging path.
            adls_container_name: ADLS container used to build the abfss prefix.
            adls_storage_account: ADLS storage account used in the abfss prefix.
        """
        self.spark = SparkSession.builder.appName("Merge Pipeline").getOrCreate()
        self.catalog_name = '`nipun-catalog`'
        self.payload = json.loads(payload)
        self.run_id = run_id.lower()
        self.prefix_path = f"abfss://{adls_container_name}@{adls_storage_account}.dfs.core.windows.net"

    # ------------------------------------------------------------------
    # Pure helpers (no Spark access)
    # ------------------------------------------------------------------
    @staticmethod
    def _is_blank(value):
        """Return True when *value* means "no field configured".

        ``None``, empty collections and the placeholder strings in
        ``_BLANK_KEY_VALUES`` (compared after trimming whitespace) are blank.
        """
        if value is None:
            return True
        if isinstance(value, str):
            return value.strip() in MergePipeline._BLANK_KEY_VALUES
        return not value

    @staticmethod
    def _split_fields(fields):
        """Normalise a field spec into a list of trimmed, non-empty names.

        Accepts either a comma-separated string or an iterable of names, so a
        payload may supply ``"a, b"`` or ``["a", "b"]``.
        """
        if fields is None:
            return []
        if isinstance(fields, str):
            fields = fields.split(',')
        return [name.strip() for name in fields if name and name.strip()]

    @staticmethod
    def _merge_condition(primary_keys):
        """Build the ``target.k = source.k AND ...`` predicate for an upsert.

        Raises:
            ValueError: if *primary_keys* contains no usable column name.
        """
        keys = MergePipeline._split_fields(primary_keys)
        if not keys:
            raise ValueError("PrimaryKeyFields must name at least one column")
        return ' AND '.join(f"target.{key} = source.{key}" for key in keys)

    def _schema_name(self):
        """Return the schema name derived from the payload."""
        return self.payload['SchemaName'] + "_Nipun"

    def _table_name(self, dataset_name):
        """Return the catalog-qualified name for *dataset_name*.

        Always qualifying with the catalog keeps every read and write pointed
        at the same table regardless of the session's current catalog/schema.
        """
        return f"{self.catalog_name}.{self._schema_name()}.{dataset_name}"

    # ------------------------------------------------------------------
    # Reading
    # ------------------------------------------------------------------
    def read_data(self):
        """Read the staged data according to ``StagingType``.

        Returns:
            The DataFrame produced by ``read_parquet`` or ``read_csv``.

        Raises:
            ValueError: if ``StagingType`` is neither ``parquet`` nor ``csv``.
        """
        staging_type = self.payload['StagingType']

        if staging_type == 'parquet':
            return self.read_parquet()
        if staging_type == 'csv':
            return self.read_csv()
        raise ValueError("Unsupported StagingType")

    def read_parquet(self):
        """Read parquet files from ``<prefix><Path><run_id>``."""
        folder_path_parquet = self.payload['Path'] + self.run_id
        full_path = self.prefix_path + folder_path_parquet
        return self.spark.read.parquet(full_path)

    def read_csv(self):
        """Read ``<CsvPath><DatasetName>.csv`` with a header row.

        NOTE(review): unlike ``read_parquet`` the abfss prefix is not applied
        here, so ``CsvPath`` is presumably already a full path - confirm.
        """
        folder_path_csv = f"{self.payload['CsvPath']}{self.payload['DatasetName']}.csv"
        return self.spark.read.format("csv").option("header", "true").load(folder_path_csv)

    # ------------------------------------------------------------------
    # Catalog objects
    # ------------------------------------------------------------------
    def create_database(self):
        """Create the target schema if needed and make it the current one."""
        schema_name = self._schema_name()
        create_db_query = f"""CREATE SCHEMA IF NOT EXISTS {self.catalog_name}.{schema_name}"""
        self.spark.sql(create_db_query)
        self.spark.sql(f"USE {self.catalog_name}.{schema_name}")

    def create_table_if_not_exist(self, spark_df):
        """Create the target Delta table from *spark_df* when it is missing.

        NOTE(review): the table is created with the full contents of
        *spark_df*, so on a first run an insert-only load appends the same
        rows a second time - confirm whether an empty table was intended.
        """
        dataset_name = self.payload['DatasetName'].lower()

        # List the tables of the target schema explicitly rather than relying
        # on whatever schema happens to be current in the session.
        tables_df = self.spark.sql(f"SHOW TABLES IN {self.catalog_name}.{self._schema_name()}")
        table_exists = tables_df.filter(tables_df.tableName == dataset_name).count() > 0

        if not table_exists:
            spark_df.write.format('delta').saveAsTable(self._table_name(dataset_name))

    # ------------------------------------------------------------------
    # Updates
    # ------------------------------------------------------------------
    def apply_update(self, spark_df):
        """Apply *spark_df* to the target table using the configured strategy.

        Strategy selection, in order:
            * ``UpdateType == "scd2"`` - slowly changing dimension type 2.
            * ``UpdateType == "op"``   - delete matching partitions, then append.
            * ``LoadType == "F"``      - overwrite the table.
            * ``LoadType == "I"`` without primary keys - append.
            * ``LoadType == "I"`` with primary keys    - upsert (merge).

        Returns:
            The resulting DataFrame for ``scd2``; ``None`` for every other
            strategy (they only write to the table).

        Raises:
            ValueError: if no strategy matches the payload.
        """
        update_type = self.payload.get('UpdateType', '')
        load_type = self.payload['LoadType']
        dataset_name = self.payload['DatasetName']
        primary_key_fields = self.payload.get('PrimaryKeyFields', '')
        table_name = self._table_name(dataset_name)

        # SCD2 update
        if update_type == "scd2":
            return self.apply_scd2(spark_df)

        # OP update
        if update_type == "op":
            return self.apply_op(spark_df)

        # Overwrite: replacing the existing data with new data.
        if load_type == "F":
            spark_df.write.format('delta').mode("overwrite").saveAsTable(table_name)
            return None

        if load_type == "I":
            # Insert Only: adding new rows without affecting the existing rows.
            if self._is_blank(primary_key_fields):
                # NOTE(review): the explicit path differs from wherever
                # create_table_if_not_exist placed the table - confirm the
                # append targets the intended location.
                delta_table_path = f"{self.prefix_path}/deltademo/{dataset_name.lower()}"
                spark_df.write.mode("append").option('path', delta_table_path).saveAsTable(table_name)
                return None

            # Upsert: insert new rows and update existing rows matched on the
            # primary key. The table handle is resolved only here because no
            # other branch needs it.
            update_condition = self._merge_condition(primary_key_fields)
            delta_table = DeltaTable.forName(self.spark, table_name)
            delta_table.alias("target").merge(
                spark_df.alias("source"),
                update_condition
            ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
            return None

        raise ValueError("Invalid LoadType or UpdateType provided")

    def apply_scd2(self, spark_df):
        """Run the SCD2 computation and overwrite the target table with it.

        Returns:
            The DataFrame produced by ``SCDHandler.scd_2``.

        Raises:
            ValueError: if ``PrimaryKeyFields`` yields no join key; an empty
                key list would otherwise join every row to every row.
        """
        schema_name = self._schema_name()
        dataset_name = self.payload['DatasetName']
        join_keys = self._split_fields(self.payload['PrimaryKeyFields'])
        if not join_keys:
            raise ValueError("PrimaryKeyFields is required for an scd2 update")

        scd_handler = SCDHandler(catalog_name=self.catalog_name, database_name=schema_name)
        result_df = scd_handler.scd_2(
            source_df=spark_df,
            dataset_name=dataset_name,
            database_name=schema_name,
            catalog_name=self.catalog_name,
            join_keys=join_keys
        )
        result_df.write.format('delta').mode('overwrite').saveAsTable(self._table_name(dataset_name))
        return result_df

    def apply_op(self, spark_df):
        """Replace the partitions present in *spark_df*.

        Rows of the target whose ``PartitionFields`` values match (null-safe)
        a source row are deleted, then the source rows are appended.

        Raises:
            ValueError: if ``PartitionFields`` yields no column name.
        """
        partition_fields = self._split_fields(self.payload['PartitionFields'])
        if not partition_fields:
            raise ValueError("PartitionFields is required for an op update")

        table_name = self._table_name(self.payload['DatasetName'])
        # The table must NOT be dropped before the merge: the merge needs the
        # existing rows, and rows outside the incoming partitions must survive.
        delta_table = DeltaTable.forName(self.spark, table_name)

        delta_table.alias('tgt').merge(
            source=spark_df.alias('src'),
            condition='\nAND '.join([f'tgt.`{_col}` <=> src.`{_col}`' for _col in partition_fields])
        ).whenMatchedDelete().execute()

        spark_df.write.format('delta').mode('append').saveAsTable(table_name)

    def run_pipeline(self):
        """Read the staged data, ensure schema/table exist and apply the update."""
        spark_df = self.read_data()
        self.create_database()
        self.create_table_if_not_exist(spark_df)
        result_df = self.apply_update(spark_df)
        # Only the scd2 strategy returns a DataFrame; the others just write.
        if result_df is not None:
            result_df.show(truncate=False)


In [0]:
class SCDHandler:
    """Compute Slowly Changing Dimension type 2 results against a Delta table."""

    def __init__(self, catalog_name: str, database_name: str):
        """Obtain a Spark session and switch to the given catalog and schema.

        Args:
            catalog_name: Catalog passed to ``USE CATALOG``.
            database_name: Schema passed to ``USE``.
        """
        self.spark = SparkSession.builder \
            .appName("SCD2 Example") \
            .config("spark.sql.catalogImplementation", "hive") \
            .enableHiveSupport() \
            .getOrCreate()
        self.spark.sql(f"USE CATALOG {catalog_name}")
        self.spark.sql(f"USE {database_name}")

    def check_columns_presence(self, source_df, target_df, metadata_cols):
        """Verify that the source carries every non-metadata target column.

        Raises:
            ValueError: listing the target columns absent from *source_df*.
        """
        cols_missing = {name for name in target_df.columns if name not in source_df.columns} - set(metadata_cols)
        if cols_missing:
            raise ValueError(f"Cols missing in source DataFrame: {cols_missing}")

    def apply_hash_and_alias(self, source_df, target_df, metadata_cols):
        """Add a ``hash_value`` column to both frames and alias them.

        The hash covers every target column that is not a metadata column.
        Returns the pair ``(source_df, target_df)`` aliased as ``source_df``
        and ``target_df`` respectively.
        """
        # Local import: coalesce is not part of the file-level imports.
        from pyspark.sql.functions import coalesce

        tgt_cols = [x for x in target_df.columns if x not in metadata_cols]
        # concat_ws silently skips NULLs, so ('a', NULL) and (NULL, 'a') would
        # hash alike and a real change would go undetected. Replace NULL with
        # a marker so each column always contributes one field to the hash.
        hashed_fields = [coalesce(col(c).cast("string"), lit("__NULL__")) for c in tgt_cols]
        hash_expr = md5(concat_ws("|", *hashed_fields))
        source_df = source_df.withColumn("hash_value", hash_expr).alias("source_df")
        target_df = target_df.withColumn("hash_value", hash_expr).alias("target_df")
        return source_df, target_df

    def scd_2(self, source_df, dataset_name, database_name, catalog_name, join_keys, metadata_cols=None):
        """Return the full SCD2 image of the target after applying *source_df*.

        The result is the union of:
            * active target rows that are unchanged or absent from the source,
            * new and changed source rows, opened today with ``flag = 1``,
            * the superseded versions of changed rows, closed today with
              ``flag = 0``,
            * already-expired target rows, passed through untouched.

        Args:
            source_df: Incoming rows.
            dataset_name: Target table name.
            database_name: Target schema name.
            catalog_name: Target catalog name.
            join_keys: Columns identifying a business entity.
            metadata_cols: Columns excluded from the presence check and from
                change detection; defaults to
                ``['eff_start_date', 'eff_end_date', 'flag']``.

        Returns:
            A DataFrame with exactly the columns of the target table.

        Raises:
            ValueError: if *join_keys* is empty or the source lacks a
                non-metadata target column.
        """
        if metadata_cols is None:
            metadata_cols = ['eff_start_date', 'eff_end_date', 'flag']
        if not join_keys:
            # An empty condition list would pair every source row with every
            # target row.
            raise ValueError("join_keys must contain at least one column")

        stored_df = DeltaTable.forName(self.spark, f'{catalog_name}.{database_name}.{dataset_name}').toDF()
        self.check_columns_presence(source_df, stored_df, metadata_cols)
        output_cols = stored_df.columns

        # Only the active version of each entity takes part in change
        # detection. Comparing expired versions as well would re-open them,
        # duplicate the incoming row and overwrite their end date.
        if "flag" in output_cols:
            history_df = stored_df.filter(col("flag").isNull() | (col("flag") != 1))
            active_df = stored_df.filter(col("flag") == 1)
        else:
            history_df = None
            active_df = stored_df

        source_df, target_df = self.apply_hash_and_alias(source_df, active_df, metadata_cols)

        join_cond = [source_df[join_key] == target_df[join_key] for join_key in join_keys]
        # Source rows with no active counterpart are brand new.
        new_df = source_df.join(target_df, join_cond, 'left_anti')

        base_df = target_df.join(source_df, join_cond, 'left')
        unchanged_filter_expr = " AND ".join([f"source_df.`{key}` IS NULL" for key in join_keys])
        unchanged_df = base_df.filter(f"({unchanged_filter_expr}) OR (source_df.hash_value = target_df.hash_value)") \
            .select("target_df.*")

        delta_filter_expr = " and ".join([f"source_df.`{key}` IS NOT NULL" for key in join_keys])
        updated_df = base_df.filter(f"{delta_filter_expr} AND source_df.hash_value != target_df.hash_value")
        updated_new_df = updated_df.select("source_df.*")
        obsolete_df = updated_df.select("target_df.*") \
            .withColumn("eff_end_date", current_date()) \
            .withColumn("flag", lit(0))

        # Typed NULL keeps eff_end_date a date column even when this is the
        # only piece that contributes rows.
        delta_df = new_df.union(updated_new_df) \
            .withColumn("eff_start_date", current_date()) \
            .withColumn("eff_end_date", lit(None).cast("date")) \
            .withColumn("flag", lit(1))

        # Selecting the stored table's columns also discards the helper
        # 'hash_value' column.
        pieces = [unchanged_df, delta_df, obsolete_df]
        if history_df is not None:
            pieces.append(history_df)

        result_df = pieces[0].select(output_cols)
        for piece in pieces[1:]:
            result_df = result_df.unionByName(piece.select(output_cols))

        return result_df
