# In [None]:
# spark_scripts/03_analytic_layer_transformations.py
# This script reads from DLT-generated Silver Delta tables, applies
# analytic transformations including CDC/SCD logic, and writes
# the final analytical tables to GCS (S3 equivalent) as Parquet.

import sys

# from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    lit,
    current_timestamp,
    date_format,
    to_date,
    sha2,
    concat_ws,
    when,
    max,
    min,
)
from pyspark.sql.window import Window
from pyspark.sql.types import *
import pyspark.sql.functions as F


def _read_parquet_if_exists(path):
    """Load the Parquet dataset at ``path``, or return None if the path is missing.

    Only the read itself is guarded, so an error raised later (for example
    while merging or writing) can never be mistaken for "first run" and lead
    to an initial-load overwrite of an existing dimension.

    Relies on the global ``spark`` session.
    """
    try:
        return spark.read.parquet(path)
    except Exception as e:
        # First-run detection is done by matching the missing-path message.
        if "Path does not exist" in str(e):
            return None
        raise


def _merge_scd2(df_incoming, df_existing, key_col, hash_col):
    """Build the next full state of an SCD Type 2 dimension.

    Args:
        df_incoming: Incoming snapshot rows; must carry ``key_col``,
            ``hash_col`` and ``effective_start_date``.
        df_existing: Current dimension contents; must carry the incoming
            columns plus ``effective_end_date`` and ``current_flag``.
        key_col: Name of the business-key column.
        hash_col: Name of the attribute-hash column used for change detection.

    Returns:
        A DataFrame with ``df_existing``'s columns made of: historical rows,
        active rows expired by this run, active rows left untouched, and the
        new active versions (brand-new keys and changed keys only).
    """
    target_cols = df_existing.columns
    active = df_existing.filter(col("current_flag") == True)
    history = df_existing.filter(col("current_flag") == False)

    inc_matches_act = col(f"inc.{key_col}") == col(f"act.{key_col}")
    act_matches_chg = col(f"act.{key_col}") == col(f"chg.{key_col}")

    # Keys present in the snapshot but without an active dimension row.
    brand_new = df_incoming.alias("inc").join(
        active.alias("act"), inc_matches_act, "left_anti"
    )

    # Keys whose attribute hash differs from the active dimension row.
    changed = (
        df_incoming.alias("inc")
        .join(active.alias("act"), inc_matches_act, "inner")
        .where(col(f"inc.{hash_col}") != col(f"act.{hash_col}"))
        .select(col("inc.*"))
    )

    # Close the active row of every changed key: keep its attributes, end it
    # one day before the new version starts, and clear the current flag.
    carried_cols = [
        col(f"act.{name}")
        for name in target_cols
        if name not in ("effective_end_date", "current_flag")
    ]
    expired = (
        active.alias("act")
        .join(changed.alias("chg"), act_matches_chg, "inner")
        .select(
            *carried_cols,
            (col("chg.effective_start_date") - F.expr("INTERVAL 1 DAY")).alias(
                "effective_end_date"
            ),
            lit(False).alias("current_flag"),
        )
    )

    # Active rows that were not changed stay exactly as they are. This also
    # keeps active rows whose key is absent from the incoming snapshot.
    untouched = active.alias("act").join(
        changed.alias("chg"), act_matches_chg, "left_anti"
    )

    # Only brand-new and changed keys get a new active version. Building this
    # from the whole snapshot would duplicate every unchanged key on each run.
    new_versions = (
        brand_new.unionByName(changed)
        .withColumn("effective_end_date", lit(None).cast(DateType()))
        .withColumn("current_flag", lit(True).cast(BooleanType()))
        .select(target_cols)
    )

    return (
        history.select(target_cols)
        .unionByName(expired.select(target_cols))
        .unionByName(untouched.select(target_cols))
        .unionByName(new_versions)
    )


def _write_scd2_dimension(df_incoming, output_path, key_col, hash_col, table_name):
    """Create or update the SCD Type 2 Parquet dimension stored at ``output_path``.

    On the first run (no dataset at ``output_path``) every incoming row is
    written as an active version. Otherwise the existing dataset is merged
    with the snapshot via ``_merge_scd2`` and the result overwrites the path.
    """
    df_existing = _read_parquet_if_exists(output_path)

    if df_existing is None:
        print(f"{table_name} not found. Initial load for SCD Type 2.")
        df_result = df_incoming.withColumn(
            "effective_end_date", lit(None).cast(DateType())
        ).withColumn("current_flag", lit(True).cast(BooleanType()))
        done_message = f"Initial {table_name} (SCD Type 2) loaded to {output_path}."
    else:
        print(f"Existing {table_name} found. Performing SCD Type 2 merge.")
        df_result = _merge_scd2(df_incoming, df_existing, key_col, hash_col)
        done_message = f"{table_name} (SCD Type 2) updated to {output_path}."

    # NOTE(review): in the merge case this overwrites the same path that
    # df_result is lazily read from; confirm this is safe on the target
    # runtime, or stage the result elsewhere before overwriting.
    # NOTE(review): expired rows get a timestamp-derived effective_end_date
    # while active rows use a DateType null; confirm the intended type of the
    # partition column.
    df_result.write.mode("overwrite").partitionBy("effective_end_date").parquet(
        output_path
    )
    print(done_message)


def process_analytic_layer(
    # spark: SparkSession,
    dlt_storage_path: str,  # Base path of DLT pipeline storage (contains /tables)
    analytic_layer_gcs_path_f: str,  # Base path for writing final analytical tables to GCS
):
    """
    Build the analytic layer (one fact table, two SCD Type 2 dimensions) from
    the silver tables and write it to GCS as Parquet.

    The Spark session is not a parameter: the function uses a global named
    ``spark``, which must exist before it is called.

    The three steps (fact_daily_sales, dim_products, dim_customers) run
    independently; a failure in one is logged and the remaining steps still
    run, but the function raises at the end so the job does not report
    success after a partial failure.

    Args:
        dlt_storage_path (str): Base GCS path of the DLT pipeline's storage.
            Only logged here; the silver tables are read by catalog name.
        analytic_layer_gcs_path_f (str): Base GCS path under which
            ``fact_daily_sales``, ``dim_products`` and ``dim_customers`` are
            written.

    Raises:
        RuntimeError: If one or more of the three steps failed.
    """
    print("Starting analytic layer transformations.")
    print(f"Reading from DLT storage: {dlt_storage_path}")
    print(f"Writing to analytic layer GCS: {analytic_layer_gcs_path_f}")

    project_ = "batch_processing"
    schema = "e_com"

    # Fully qualified names of the silver tables
    silver_sales_path = f"{project_}.{schema}.silver_sales"
    silver_products_path = f"{project_}.{schema}.silver_products"
    silver_customers_path = f"{project_}.{schema}.silver_customers"

    # Output locations of the final analytic tables in GCS
    fact_sales_output_path = f"{analytic_layer_gcs_path_f}/fact_daily_sales"
    dim_products_output_path = f"{analytic_layer_gcs_path_f}/dim_products"
    dim_customers_output_path = f"{analytic_layer_gcs_path_f}/dim_customers"

    # Names of the steps that raised, reported together at the end.
    failed_steps = []

    # --- 1. Process Fact Table (fact_daily_sales) ---
    print("\nProcessing fact_daily_sales...")
    try:
        df_silver_sales = spark.read.table(silver_sales_path)

        if df_silver_sales.isEmpty():
            print("No new data in silver_sales. Skipping fact_daily_sales processing.")
        else:
            df_fact_sales = df_silver_sales.select(
                col("order_id"),
                col("customer_id"),
                # Partition column. Spark patterns are case-sensitive:
                # yyyy = year, MM = month, dd = day of month.
                to_date(col("order_date"), "yyyy-MM-dd").alias("order_date"),
                col("order_status"),
                col("original_total_amount")
                .cast(DecimalType(38, 4))
                .alias("original_total_amount"),
                col("calculated_order_total")
                .cast(DecimalType(38, 4))
                .alias("calculated_order_total"),
                col("total_products_in_order"),
                col("total_quantity_in_order"),
                col("silver_processed_timestamp").alias(
                    "audit_load_timestamp"
                ),  # Audit column
            )

            # NOTE(review): the whole silver_sales table is read and appended;
            # if it still holds rows exported by an earlier run they will be
            # duplicated in the output - confirm upstream guarantees.
            df_fact_sales.write.mode("append").partitionBy("order_date").parquet(
                fact_sales_output_path
            )
            print(
                f"fact_daily_sales written to {fact_sales_output_path} (appended, partitioned by order_date)."
            )

    except Exception as e:
        print(f"Error processing fact_daily_sales: {e}")
        failed_steps.append("fact_daily_sales")

    # --- 2. Process Dimension Table (dim_products) - SCD Type 2 ---
    print("\nProcessing dim_products (SCD Type 2)...")
    try:
        df_silver_products = spark.read.table(silver_products_path)

        # The hash over the tracked attributes drives change detection;
        # effective_start_date marks when a new version becomes active.
        df_silver_products_with_hash = df_silver_products.withColumn(
            "product_hash",
            sha2(
                concat_ws(
                    "||",
                    col("product_name"),
                    col("product_category"),
                    col("product_price"),
                ),
                256,
            ),
        ).select(
            "product_id",
            "product_name",
            "product_category",
            col("product_price").cast(DecimalType(38, 9)).alias("product_price"),
            "product_hash",
            current_timestamp().alias("effective_start_date"),  # New records start now
        )

        _write_scd2_dimension(
            df_silver_products_with_hash,
            dim_products_output_path,
            "product_id",
            "product_hash",
            "dim_products",
        )

    except Exception as e:
        print(f"Error processing dim_products: {e}")
        failed_steps.append("dim_products")

    # --- 3. Process Dimension Table (dim_customers) - SCD Type 2 ---
    print("\nProcessing dim_customers (SCD Type 2)...")
    try:
        df_silver_customers = spark.read.table(silver_customers_path)

        df_silver_customers_with_hash = df_silver_customers.withColumn(
            "customer_hash",
            sha2(
                concat_ws(
                    "||",
                    col("customer_first_name"),
                    col("customer_last_name"),
                    col("customer_email"),
                    col(
                        "customer_registration_date"
                    ),  # Include in hash if changes to this should trigger new SCD record
                    col("customer_country"),
                ),
                256,
            ),
        ).select(
            "customer_id",
            col("customer_first_name").alias("first_name"),
            col("customer_last_name").alias("last_name"),
            col("customer_email").alias("email"),
            col("customer_registration_date").alias("registration_date"),
            col("customer_country").alias("country"),
            "customer_hash",
            current_timestamp().alias("effective_start_date"),  # New records start now
        )

        _write_scd2_dimension(
            df_silver_customers_with_hash,
            dim_customers_output_path,
            "customer_id",
            "customer_hash",
            "dim_customers",
        )

    except Exception as e:
        print(f"Error processing dim_customers: {e}")
        failed_steps.append("dim_customers")

    # spark.stop()
    if failed_steps:
        # Every step was attempted; fail now so a scheduler sees the failure
        # instead of a job that only printed the errors.
        raise RuntimeError(
            "Analytic layer transformations failed for: " + ", ".join(failed_steps)
        )
    print("Analytic layer transformations complete.")


if __name__ == "__main__":
    # Script entry point. No SparkSession is built here (the builder call is
    # disabled), so a `spark` session must already exist when this runs.
    #
    # Command-line parsing is also disabled: the usage banner is always
    # printed and the hard-coded GCS locations below are used. When argument
    # parsing was enabled the two values came from sys.argv[1] and
    # sys.argv[2], e.g. 'gs://your-project-id-processed-data/dlt_storage' and
    # 'gs://your-project-id-analytic-layer'.
    print(
        "Usage: 03_analytic_layer_transformations.py "
        "<dlt_storage_path> <analytic_layer_gcs_path>"
    )

    dlt_storage_path, analytics_layer_gcs_path_f = (
        "gs://batch-processing-de_dlt_storage/",
        "gs://batch-processing-de_final_data/dlt/",
    )
    print(
        "using default parameters for the job "
        f"{dlt_storage_path} {analytics_layer_gcs_path_f}"
    )

    process_analytic_layer(dlt_storage_path, analytics_layer_gcs_path_f)