In [0]:
# ==============================
# WIDGET
# ==============================
# Register the notebook text widget that supplies the JSON config location.
# Arguments are, in order: widget name, default value, display label.
dbutils.widgets.text(
    "config_path",
    "/Workspace/Users/ud3041@gmail.com/end-to-end-ETL-pipeline/medallion/silver/config_yfinance.json",
    "Config File Path",
)


# =========================
# IMPORTS
# =========================
import json
from pyspark.sql import functions as F
from pyspark.sql.types import DateType, BooleanType, StringType
from delta.tables import DeltaTable

from utils.logger import get_logger
from utils.sparksession import create_spark_session

# =========================
# INIT LOGGER & SPARK
# =========================
# Logger and Spark session used by every step below.
logger = get_logger("b2s_yfinance_scd2_safe")
spark = create_spark_session("B2S | YFinance SCD2 SAFE")

# =========================
# LOAD CONFIG
# =========================
# The config location is read from the "config_path" notebook widget.
config_path = dbutils.widgets.get("config_path")

# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# default locale encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Required top-level keys; a missing key raises KeyError here rather than
# failing later in the middle of a table load.
CATALOG = config["catalog"]
BRONZE_SCHEMA = config["bronze_schema"]
SILVER_SCHEMA = config["silver_schema"]
TABLES = config["tables"]

# =========================
# UTILS
# =========================
def table_exists(table_name: str) -> bool:
    """
    Report whether ``table_name`` can be resolved by the Spark session.

    Args:
        table_name: Table identifier passed unchanged to ``spark.table``.

    Returns:
        True if ``spark.table`` resolves the name, False if Spark raises an
        ``AnalysisException`` while resolving it.

    Raises:
        Any exception other than ``AnalysisException`` propagates. A False
        result is used in this script to decide to (re)create the table in
        overwrite mode, so an unrelated failure must not be reported as
        "table missing".
    """
    # Local import keeps this function self-contained; the file's top-level
    # imports do not bring AnalysisException into scope.
    from pyspark.sql.utils import AnalysisException

    try:
        spark.table(table_name)
    except AnalysisException:
        return False
    return True


def ensure_scd_columns(table_name: str) -> None:
    """
    Add whichever SCD bookkeeping columns ``table_name`` is missing.

    The columns checked are ``row_hash`` (STRING), ``effective_from`` (DATE),
    ``effective_to`` (DATE) and ``is_current`` (BOOLEAN). All missing columns
    are added with one ``ALTER TABLE ... ADD COLUMNS (...)`` statement; if
    none are missing, no SQL is issued.

    Args:
        table_name: Table identifier, interpolated as-is into the SQL
            statement (callers pass an already-quoted name).

    NOTE(review): newly added columns are NULL for existing rows; nothing
    here backfills ``is_current``/``effective_from`` -- confirm whether a
    backfill is needed for tables created before SCD tracking.
    """
    # Insertion order is preserved, so columns are added in this order.
    required_columns = {
        "row_hash": "STRING",
        "effective_from": "DATE",
        "effective_to": "DATE",
        "is_current": "BOOLEAN",
    }

    # Compare case-insensitively: Delta preserves column-name case but does
    # not allow two columns that differ only by case.
    existing_cols = {
        field.name.lower() for field in spark.table(table_name).schema
    }

    missing_defs = [
        f"{col_name} {sql_type}"
        for col_name, sql_type in required_columns.items()
        if col_name not in existing_cols
    ]
    if not missing_defs:
        return

    # A single ADD COLUMNS clause with a comma-separated list; repeating
    # "ADD COLUMN ..." several times in one statement is not valid SQL.
    stmt = f"ALTER TABLE {table_name} ADD COLUMNS ({', '.join(missing_defs)})"
    spark.sql(stmt)
    logger.info(f"Added missing SCD columns to {table_name}")


# =========================
# PROCESS TABLES
# =========================
# For every configured table: hash the tracked columns of the Bronze data,
# then either create the Silver table (first run) or apply a two-step SCD2
# load -- (1) MERGE to close out current rows whose hash changed, then
# (2) append a fresh current row for every business key that has no current
# row in Silver. The two steps are separate writes, not one transaction.
for table in TABLES:

    # Per-table settings read from the config entry.
    table_name = table["name"]
    business_keys = table["business_key"]
    tracked_columns = table["tracked_columns"]
    hash_column = table["hash_column"]

    logger.info(f"Processing table: {table_name}")

    bronze_df = spark.table(f"`{CATALOG}`.{BRONZE_SCHEMA}.{table_name}")

    # -----------------------------------------
    # HASH
    # -----------------------------------------
    # SHA-256 over the tracked columns, cast to string and joined with "||".
    # NOTE(review): concat_ws skips NULL inputs, so two rows that differ only
    # in which tracked column is NULL can produce the same hash. Changing the
    # formula would alter every stored hash, so confirm before touching it.
    bronze_df = bronze_df.withColumn(
        hash_column,
        F.sha2(
            F.concat_ws(
                "||",
                *[F.col(c).cast("string") for c in tracked_columns]
            ),
            256
        )
    )

    target_table = f"`{CATALOG}`.{SILVER_SCHEMA}.{table_name}"

    # -----------------------------------------
    # FIRST RUN -> CREATE SILVER TABLE
    # -----------------------------------------
    if not table_exists(target_table):
        logger.info(f"Creating Silver table: {target_table}")

        # Every Bronze row becomes an open-ended current version.
        (
            bronze_df
            .withColumn("effective_from", F.current_date())
            .withColumn("effective_to", F.lit(None).cast(DateType()))
            .withColumn("is_current", F.lit(True))
            .write
            .format("delta")
            .mode("overwrite")
            .saveAsTable(target_table)
        )

        continue  # move to next table

    # -----------------------------------------
    # ENSURE SCD COLUMNS EXIST
    # -----------------------------------------
    ensure_scd_columns(target_table)

    delta_table = DeltaTable.forName(spark, target_table)

    # -----------------------------------------
    # BUILD MERGE CONDITION (COMPOSITE KEYS)
    # -----------------------------------------
    # Match on every business key, restricted to the current Silver version.
    merge_condition = " AND ".join(
        [f"t.{k} = s.{k}" for k in business_keys]
    ) + " AND t.is_current = true"

    # -----------------------------------------
    # 1) EXPIRE CHANGED ROWS
    # -----------------------------------------
    # Null-safe inequality: with a plain "<>" a NULL stored hash makes the
    # condition NULL, so the row is never expired, and because it remains
    # current the anti-join in step 2 never re-inserts it either.
    delta_table.alias("t") \
        .merge(
            bronze_df.alias("s"),
            merge_condition
        ) \
        .whenMatchedUpdate(
            condition=f"NOT (t.{hash_column} <=> s.{hash_column})",
            set={
                "effective_to": "current_date()",
                "is_current": "false"
            }
        ) \
        .execute()

    # -----------------------------------------
    # 2) INSERT NEW + CHANGED ROWS
    # -----------------------------------------
    # Re-read Silver so the rows expired in step 1 are seen as non-current.
    silver_df = spark.table(target_table)

    # "== True" is intentional: this builds a Spark Column expression, it is
    # not a Python truthiness test.
    join_expr = [
        bronze_df[k] == silver_df[k] for k in business_keys
    ] + [silver_df.is_current == True]

    # Left-anti join keeps Bronze rows with no current Silver row: brand-new
    # business keys plus the keys whose current row was just expired.
    new_rows_df = (
        bronze_df.alias("s")
        .join(
            silver_df.alias("t"),
            join_expr,
            "left_anti"
        )
    )

    (
        new_rows_df
        .withColumn("effective_from", F.current_date())
        .withColumn("effective_to", F.lit(None).cast(DateType()))
        .withColumn("is_current", F.lit(True))
        .write
        .format("delta")
        .mode("append")
        .saveAsTable(target_table)
    )

    logger.info(f"SCD2 completed for {table_name}")

logger.info("B2S YFinance SCD2 pipeline completed successfully")