In [0]:
from pyspark.sql import DataFrame, Window, functions as F
from datetime import timedelta,datetime

In [0]:
data = [
        (1, 'James','Driver',15,datetime.today()-timedelta(days = 12)),
        (1, 'James','Teacher',18,datetime.today() - timedelta(days = 10)),
        (1 , 'James','Engineer',23,datetime.today()-timedelta(days = 8)),
        (4,'Abhishek','architect',28,datetime.today()-timedelta(days = 7)),
        (5,'Jeefron','CEO',67,datetime.today()-timedelta(days=6))
      ]


In [0]:
df = spark.createDataFrame(data = data ,schema = ['id','name','occupation','age','date_column'])

In [0]:
from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.sql import DataFrame

class SCD2:
    def __init__(self, primary_keys: list[str], order_column: str, delete_condition: str):
        self.primary_keys = primary_keys
        self.order_column = order_column
        self.delete_condition = delete_condition

    def apply_scd2(self, df: DataFrame) -> DataFrame:
        """
        Apply SCD2 transformations to the source DataFrame.
        """
        window_spec = Window.partitionBy(*self.primary_keys).orderBy(self.order_column)

        # Step 1: Add DateTimeValidFrom and DateTimeValidTo to a struct
        df = df.withColumn(
            "__versioned",
            F.struct(
                F.col(self.order_column).alias("DateTimeValidFrom"),
                F.lead(self.order_column).over(window_spec).alias("DateTimeValidTo")
            )
        )

        # Step 2: Adjust DateTimeValidTo with a 1-second lag
        df = df.withColumn(
            "__versioned",
            F.col("__versioned").withField(
                "DateTimeValidTo",
                F.when(F.col("__versioned.DateTimeValidTo").isNotNull(),
                       F.col("__versioned.DateTimeValidTo") - F.expr("INTERVAL 1 SECOND")
                      ).otherwise(F.lit(None))
            )
        )

        # Step 3: Add DeleteFlag
        df = df.withColumn(
            "__versioned",
            F.col("__versioned").withField(
                "DeleteFlag",
                F.when(F.expr(self.delete_condition), F.lit(True)).otherwise(F.lit(False))
            )
        )

        # Step 4: Update DateTimeValidTo for deleted records
        df = df.withColumn(
            "__versioned",
            F.col("__versioned").withField(
                "DateTimeValidTo",
                F.when(
                    (F.col("__versioned.DeleteFlag") == True) &
                    F.col("__versioned.DateTimeValidTo").isNull(),
                    F.current_timestamp()
                ).otherwise(F.col("__versioned.DateTimeValidTo"))
            )
        )

        return df

    def catch(self, target_table: str) -> bool:
        """
        Check if the target table exists in the catalog.
        """
        try:
            spark.read.table(target_table)
            return True
        except Exception:
            return False

    def execute(self, df: DataFrame, target_table: str) -> DataFrame:
        """
        Execute the SCD2 process, merging the source DataFrame with the target table.
        """
        # Apply SCD2 transformations to the source DataFrame
        source_df = self.apply_scd2(df)

        # Check if target table exists
        if self.catch(target_table):
            # Load the target table
            target_df = spark.read.table(target_table)
        else:
            # Create an empty DataFrame with the source DataFrame schema if table doesn't exist
            target_df = spark.createDataFrame(spark.sparkContext.emptyRDD(), source_df.schema)

        # Union source and target DataFrames
        unioned_df = target_df.unionByName(source_df, allowMissingColumns=True)

        # Reapply SCD2 transformations after the union
        final_df = self.apply_scd2(unioned_df)

        return final_df


In [0]:
# Create a DataFrame with the sample data
df = spark.createDataFrame(data, schema=['id', 'name', 'occupation', 'age', 'date_column'])

# Initialize the SCD2 object
scd2_object = SCD2(["id"], "date_column", "id = 5")

# Specify the target table name
target_table_name = "target_table"

# Execute the SCD2 process
df_final = scd2_object.execute(df, target_table_name)

# Save the final DataFrame to the target table
df_final.write.mode('overwrite').saveAsTable(target_table_name)

# Display the saved target table
spark.read.table(target_table_name).display()


id,name,occupation,age,date_column,__versioned
1,Alice,Engineer,30,2023-11-19 10:00:00,"List(2023-11-19 10:00:00, 2024-11-08 13:56:24.79351, false)"
1,James,Driver,15,2024-11-08 13:56:25.79351,"List(2024-11-08 13:56:25.79351, 2024-11-10 13:56:24.793532, false)"
1,James,Teacher,18,2024-11-10 13:56:25.793532,"List(2024-11-10 13:56:25.793532, 2024-11-12 13:56:24.793564, false)"
1,James,Engineer,23,2024-11-12 13:56:25.793564,"List(2024-11-12 13:56:25.793564, null, false)"
2,Bob,Doctor,35,2023-11-19 11:00:00,"List(2023-11-19 11:00:00, null, false)"
3,Charlie,Artist,25,2023-11-19 12:00:00,"List(2023-11-19 12:00:00, null, false)"
4,Abhishek,architect,28,2024-11-13 13:56:25.793571,"List(2024-11-13 13:56:25.793571, null, false)"
5,Eve,Scientist,40,2023-11-19 13:00:00,"List(2023-11-19 13:00:00, 2024-11-14 13:56:24.793575, true)"
5,Jeefron,CEO,67,2024-11-14 13:56:25.793575,"List(2024-11-14 13:56:25.793575, 2024-11-20 13:56:40.086, true)"
