In [None]:
## CONFIG
# Select runtime mode and I/O namespaces
# RUN_MODE options: "spark" (open-source Spark with Delta Lake) or "snowpark_connect"
RUN_MODE = "snowpark_connect"
SUPPORTED_RUN_MODES = ("spark", "snowpark_connect")

# Fail fast on a typo: an unrecognized mode would skip both session-setup cells and
# the first use of `spark` further down would die with a confusing NameError.
if RUN_MODE not in SUPPORTED_RUN_MODES:
    raise ValueError(
        f"Unsupported RUN_MODE {RUN_MODE!r}; expected one of {SUPPORTED_RUN_MODES}"
    )

# Input (source) and Output (target) namespaces
SOURCE_DATABASE = "SNOWFLAKE_SAMPLE_DATA"
SOURCE_SCHEMA   = "TPCH_SF100"
# NOTE(review): placeholder target. The dim_customer / fact_orders / agg_customer_revenue
# tables in this namespace are dropped and rebuilt on every run.
TARGET_DATABASE = "DELETETHIS"
TARGET_SCHEMA   = "TEST"

# CDC parameters
CDC_SAMPLE_SIZE = 10_000   # number of orders sampled into the simulated CDC batch
RANDOM_SEED = 42           # seed for the random sampling / CDC op assignment

print(f"RUN_MODE = {RUN_MODE}")
print(f"Source  = {SOURCE_DATABASE}.{SOURCE_SCHEMA}")
print(f"Target  = {TARGET_DATABASE}.{TARGET_SCHEMA}")


In [None]:
import sys
# Record which Python interpreter backs this kernel (handy when debugging package installs).
python_version = sys.version
print(python_version)


# TPCH → Delta — Sample Data Engineering Notebook (Open-Source Spark or Snowpark Connect)

Source and target table locations are configurable: set the **database & schema** for both in the CONFIG cell, and choose the execution engine with `RUN_MODE`.


## 0) Environment & Spark Session (Delta Lake)

In [None]:
if RUN_MODE == "spark":
    # Open-source Spark path: register Delta Lake's SQL session extension and swap the
    # session catalog for Delta's, so Delta tables can be created/read by name.
    # NOTE(review): assumes the Delta Lake jars are already on the classpath; nothing here
    # (e.g. spark.jars.packages) puts them there — confirm for your environment.
    from pyspark.sql import SparkSession

    delta_session_conf = {
        "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
        "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    }

    session_builder = SparkSession.builder.appName("tpch-delta-pipeline-sample")
    for conf_key, conf_value in delta_session_conf.items():
        session_builder = session_builder.config(conf_key, conf_value)
    spark = session_builder.getOrCreate()


In [None]:
if RUN_MODE == "snowpark_connect":
    # Snowpark Connect path: install the client packages, then obtain a Spark-compatible
    # session from snowpark_connect.
    # NOTE(review): only snowpark-connect is version-pinned; jdk4py and deltalake float.
    # NOTE(review): `!pip` uses whichever pip is on PATH; `%pip` targets the kernel's env.
    !pip install jdk4py
    !pip install deltalake
    !pip install snowpark-connect==0.28.0
    
    # Restart the kernel via the UI in Snowflake Notebooks (SPCS) after the initial install,
    # then re-run the notebook from the top (the restart clears RUN_MODE and other config).

    from snowflake import snowpark_connect
    # NOTE(review): DeltaTable is not referenced by any later cell in this notebook.
    from deltalake import DeltaTable
    spark = snowpark_connect.server.init_spark_session()

In [None]:


# Report which Spark (or Spark-compatible) engine the session is running on.
spark_version = spark.version
print("Spark version:", spark_version)

# Start clean: evict every table/DataFrame cached in this session (e.g. by an earlier
# run of this notebook). Temporary view definitions themselves are not dropped.
spark.catalog.clearCache()


## 1) Initialize pipeline timing

In [None]:

from datetime import datetime

# Wall-clock start of the pipeline; the runtime cell at the end subtracts this from its end time.
pipeline_start_ts = datetime.now()
started_at = pipeline_start_ts.isoformat(timespec="seconds")
print("Pipeline started at:", started_at)


## 2) Helpers: quoting & fully-qualified names

In [None]:

from pyspark.sql import DataFrame

def q(ident: str) -> str:
    """Wrap an identifier in backticks for use in Spark SQL.

    Embedded backticks are escaped by doubling them. Dots need no escaping once the
    name is backtick-quoted, so they stay part of the single identifier.
    ``None`` and the empty string yield ``""`` (nothing to quote).
    """
    if ident in (None, ""):
        return ""
    escaped = ident.replace("`", "``")
    return "`" + escaped + "`"

def fqtn(database: str, schema: str, table: str) -> str:
    """Return the backtick-quoted, dot-joined name for ``table``.

    Each non-empty qualifier is kept, in order, ahead of the table name:
      - database and schema -> `database`.`schema`.`table`
      - only database       -> `database`.`table`
      - only schema         -> `schema`.`table`
      - neither             -> `table`
    """
    qualifiers = [q(part) for part in (database, schema) if part]
    return ".".join(qualifiers + [q(table)])

def ensure_namespace(database: str, schema: str):
    """Create the namespace that ``fqtn(database, schema, ...)`` qualifies tables with, if missing.

    - database and schema: two-level namespace `database`.`schema`
    - only database:       single-level (Hive-style) database `database`
    - only schema:         single-level database `schema`. This case was previously
                           skipped, even though fqtn() still prefixes tables with it.
    - neither:             nothing to create (tables are unqualified)
    """
    if database and schema:
        # Two-level namespace
        spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {q(database)}.{q(schema)}")
    elif database or schema:
        # Single-level (Hive-style): whichever qualifier was supplied
        spark.sql(f"CREATE DATABASE IF NOT EXISTS {q(database or schema)}")

def df_lower_columns(df: DataFrame) -> DataFrame:
    """Return ``df`` with every column renamed to its lower-case form, keeping column order."""
    lowered_names = list(map(str.lower, df.columns))
    return df.toDF(*lowered_names)


In [None]:
# Remove any target tables left behind by a previous run before rebuilding them.
for stale_table in ("dim_customer", "fact_orders", "agg_customer_revenue"):
    spark.sql(f"DROP TABLE IF EXISTS {fqtn(TARGET_DATABASE, TARGET_SCHEMA, stale_table)}")

## 3) Read TPCH Source Tables (Snowflake sample data by default) & Lower-case Column Names

In [None]:

from pyspark.sql import functions as F

required_tables = [
    "customer", "orders", "lineitem",
    "part", "supplier", "partsupp",
    "nation", "region",
]

# The output namespace must exist before any saveAsTable later in the notebook.
ensure_namespace(TARGET_DATABASE, TARGET_SCHEMA)

# Lazily bind each source table (column names lower-cased); no data is scanned here.
dfs = {}
for table_name in required_tables:
    source_name = fqtn(SOURCE_DATABASE, SOURCE_SCHEMA, table_name)
    print("Reading:", source_name)
    dfs[table_name] = df_lower_columns(spark.table(source_name))

# Cheap schema peek at the first few tables: only column names, no Spark job is triggered.
for table_name, table_df in list(dfs.items())[:3]:
    print(table_name, "→", table_df.columns[:5])


## 4) Transformations: Complex Joins + Window Functions

In [None]:
from pyspark.sql import functions as F
from pyspark.sql import Window

customer = dfs["customer"]
orders = dfs["orders"]
lineitem = dfs["lineitem"]
nation = dfs["nation"]
region = dfs["region"]

# Customer attributes enriched with nation and region names.
# Left joins keep customers whose nation/region lookup has no match.
cust_geo = (
    customer.alias("c")
    .join(nation.alias("n"), F.col("c.c_nationkey") == F.col("n.n_nationkey"), "left")
    .join(region.alias("r"), F.col("n.n_regionkey") == F.col("r.r_regionkey"), "left")
    .select(
        F.col("c.c_custkey").alias("custkey"),
        F.col("c.c_name").alias("name"),
        F.col("c.c_acctbal").alias("acctbal"),
        F.col("c.c_mktsegment").alias("mktsegment"),
        F.col("n.n_name").alias("nation"),
        F.col("r.r_name").alias("region")
    )
)

# Net revenue per order: sum of extendedprice * (1 - discount) over its line items.
order_revenue = (
    lineitem.groupBy("l_orderkey")
    .agg(F.sum(F.col("l_extendedprice") * (1 - F.col("l_discount"))).alias("order_revenue"))
)

# Orders with their net revenue. The left join keeps orders that have no line items
# (their order_revenue is null).
orders_enriched = (
    orders.alias("o")
    .join(order_revenue.alias("rev"), F.col("o.o_orderkey") == F.col("rev.l_orderkey"), "left")
    .select(
        F.col("o.o_orderkey").alias("orderkey"),
        F.col("o.o_custkey").alias("custkey"),
        F.col("o.o_orderstatus").alias("orderstatus"),
        F.col("o.o_totalprice").alias("totalprice"),
        F.col("o.o_orderdate").alias("orderdate"),
        F.col("o.o_orderpriority").alias("orderpriority"),
        F.col("o.o_clerk").alias("clerk"),
        F.col("o.o_shippriority").alias("shippriority"),
        F.col("rev.order_revenue").alias("order_revenue")
    )
)

# Per-customer ranking windows:
#  - w_latest: newest order first. orderkey breaks ties between orders placed on the
#    same date, so row_number() (rn_latest) is deterministic across runs.
#  - w_revenue_rank: highest-revenue order first, orders with null revenue last.
w_latest = (
    Window.partitionBy("custkey")
    .orderBy(F.col("orderdate").desc(), F.col("orderkey").desc())
)
w_revenue_rank = Window.partitionBy("custkey").orderBy(F.col("order_revenue").desc_nulls_last())

orders_w = (
    orders_enriched
    .withColumn("rn_latest", F.row_number().over(w_latest))
    .withColumn("revenue_rank", F.dense_rank().over(w_revenue_rank))
)

# Trailing 30-day revenue per customer. A RANGE frame needs a numeric ordering key.
# A day number (days since 1970-01-01) is used instead of timestamp seconds, because
# casting a DATE to TIMESTAMP depends on the session time zone. DST shifts would then
# move the 30-day boundary by an hour and include/exclude the edge day inconsistently.
# The frame covers the current order's date and the 30 days before it (same-day peers included).
order_day_number = F.datediff(F.col("orderdate"), F.to_date(F.lit("1970-01-01")))
w_rolling = (
    Window.partitionBy("custkey")
    .orderBy(order_day_number)
    .rangeBetween(-30, 0)
)

orders_w = orders_w.withColumn(
    "rolling_30d_revenue",
    F.sum("order_revenue").over(w_rolling)
)

## 5) Create Curated Target Delta Tables

In [None]:

from pyspark.sql import functions as F

DIM_CUSTOMER_TBL = fqtn(TARGET_DATABASE, TARGET_SCHEMA, "dim_customer")
FACT_ORDERS_TBL  = fqtn(TARGET_DATABASE, TARGET_SCHEMA, "fact_orders")
AGG_REVENUE_TBL  = fqtn(TARGET_DATABASE, TARGET_SCHEMA, "agg_customer_revenue")

# Customer dimension: one row per custkey.
dim_customer_df = cust_geo.dropDuplicates(["custkey"])

# Order fact: enriched order attributes plus the window-derived metrics.
fact_order_columns = [
    "orderkey", "custkey", "orderstatus", "totalprice", "orderdate",
    "orderpriority", "clerk", "shippriority", "order_revenue",
    "rn_latest", "revenue_rank", "rolling_30d_revenue",
]
fact_orders_df = orders_w.select(*fact_order_columns)

# Per-customer rollup derived from the fact DataFrame.
agg_revenue_df = fact_orders_df.groupBy("custkey").agg(
    F.countDistinct("orderkey").alias("order_cnt"),
    F.sum("order_revenue").alias("lifetime_revenue"),
    F.max("orderdate").alias("last_order_date"),
)

curated_outputs = [
    (DIM_CUSTOMER_TBL, dim_customer_df),
    (FACT_ORDERS_TBL, fact_orders_df),
    (AGG_REVENUE_TBL, agg_revenue_df),
]

# Drop every target first, then (re)write each one as a Delta table.
for table_name, _ in curated_outputs:
    spark.sql(f"DROP TABLE IF EXISTS {table_name}")
for table_name, frame in curated_outputs:
    frame.write.mode("overwrite").format("delta").saveAsTable(table_name)

print("Created/overwritten:")
for table_name, _ in curated_outputs:
    print(" -", table_name)


## 6) Simulate Incremental CDC (`CDC_SAMPLE_SIZE` rows, 10K by default) and Apply to FACT

In [None]:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Sample CDC_SAMPLE_SIZE random orders ("smallest random draw wins").
# orderBy(...).limit(...) selects the same rows as filtering a global row_number(), but
# lets Spark plan a top-N instead of pushing every order through a single
# un-partitioned window, which is one task over the whole orders table.
sim_src = (
    orders_enriched
    .withColumn("rnd", F.rand(seed=RANDOM_SEED))
    .orderBy("rnd")
    .limit(CDC_SAMPLE_SIZE)
    .drop("rnd")
)

# Assign a CDC operation to each sampled row: ~30% inserts, ~30% deletes, ~40% updates.
# A seed distinct from the sampling draw keeps the two random streams independent.
sim_cdc = (
    sim_src
    .withColumn("p", F.rand(seed=RANDOM_SEED + 1))
    .withColumn(
        "op",
        F.when(F.col("p") < 0.30, F.lit("I"))
         .when(F.col("p") < 0.60, F.lit("D"))
         .otherwise(F.lit("U"))
    )
    .drop("p")
)

# Synthetic inserts need keys that do not exist in the fact table. Otherwise MERGE sees
# them as matched rows and its "I" branch (WHEN NOT MATCHED) silently skips them.
# TPCH orderkeys grow with scale factor (up to ~600M at SF100), so a fixed 10M offset
# collides. Offset past the current maximum key instead.
max_orderkey = orders.agg(F.max("o_orderkey")).first()[0]
INSERT_KEY_OFFSET = int(max_orderkey or 0) + 1
sim_cdc = sim_cdc.withColumn(
    "merge_orderkey",
    F.when(F.col("op") == "I", F.col("orderkey") + F.lit(INSERT_KEY_OFFSET)).otherwise(F.col("orderkey"))
)

# Make updates observable: tag the priority of every "U" row with "-upd".
sim_cdc = sim_cdc.withColumn(
    "orderpriority",
    F.when(F.col("op") == "U", F.concat_ws("-", F.col("orderpriority"), F.lit("upd"))).otherwise(F.col("orderpriority"))
)

# Final CDC batch: fact data columns (with the remapped key) plus the op flag.
# Cached so the counts below and the MERGE reuse the same sampled rows.
sim_cdc = sim_cdc.select(
    F.col("merge_orderkey").alias("orderkey"),
    "custkey", "orderstatus", "totalprice", "orderdate",
    "orderpriority", "clerk", "shippriority", "order_revenue", "op"
).cache()

print("CDC counts:")
sim_cdc.groupBy("op").count().show()


### 6a) Apply CDC via Delta MERGE (note: no DataFrame fallback is implemented)

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.utils import AnalysisException


# Target of the CDC merge: the same fact table the curated-tables step wrote.
# Relies on spark, fqtn, TARGET_DATABASE, TARGET_SCHEMA and sim_cdc from earlier cells.
FACT_ORDERS_TBL = fqtn(TARGET_DATABASE, TARGET_SCHEMA, "fact_orders")

# Declare the Delta table only if no earlier step has already created it.
spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {FACT_ORDERS_TBL} (
        orderkey BIGINT,
        custkey BIGINT,
        orderstatus STRING,
        totalprice DOUBLE,
        orderdate DATE,
        orderpriority STRING,
        clerk STRING,
        shippriority INT,
        order_revenue DOUBLE,
        rn_latest INT,
        revenue_rank INT,
        rolling_30d_revenue DOUBLE
    ) USING DELTA
""")

# Force the text columns to STRING before merging.
# NOTE(review): presumably widens source-derived char/varchar types written by saveAsTable — confirm.
for string_column in ("orderstatus", "orderpriority", "clerk"):
    spark.sql(f"ALTER TABLE {FACT_ORDERS_TBL} ALTER COLUMN {string_column} TYPE STRING")

def apply_cdc_operations_merge(source_df, target_table_name, raise_on_error: bool = True):
    """
    Apply CDC operations (I, U, D) from ``source_df`` to a target table with a
    single MERGE INTO statement.

    Rows are matched on ``orderkey``:
      * op = 'D' and matched      -> the target row is deleted
      * op = 'U' and matched      -> every target data column is overwritten
      * op = 'I' and not matched  -> a new row is inserted
    Any other combination (an 'I' whose key already exists, or a 'U'/'D' for a
    missing key) leaves the target unchanged.

    The window-derived columns (rn_latest, revenue_rank, rolling_30d_revenue)
    are not recomputed here. Updated/inserted rows get the placeholders 0 / 0 / 0.0.

    Args:
        source_df (DataFrame): CDC records holding the fact_orders data columns plus an 'op' column.
        target_table_name (str): Fully-qualified name of the target Delta table.
        raise_on_error (bool): Re-raise a failed MERGE after printing it (default).
            Pass False for the previous print-only, best-effort behavior.

    Raises:
        Exception: Whatever ``spark.sql`` raised for the MERGE, when ``raise_on_error`` is True.
    """
    # 1. Align the source with the 12-column fact schema: add placeholder values for the
    #    window-derived metrics and keep only the leading characters of the text columns.
    #    NOTE(review): the 10/15-character limits presumably mirror target column lengths — confirm.
    cdc_data = (
        source_df
            .withColumn("rn_latest", F.lit(0).cast("int"))
            .withColumn("revenue_rank", F.lit(0).cast("int"))
            .withColumn("rolling_30d_revenue", F.lit(0.0).cast("double"))
            .withColumn("orderstatus", F.substring(F.col("orderstatus"), 1, 10))
            .withColumn("orderpriority", F.substring(F.col("orderpriority"), 1, 15))
            .withColumn("clerk", F.substring(F.col("clerk"), 1, 15))
    )

    # 2. Data columns = everything except the CDC 'op' flag, so the UPDATE/INSERT
    #    column lists line up with the target table's columns.
    target_data_columns = [col for col in cdc_data.columns if col != "op"]

    update_set_clause = ", ".join(f"target.{col} = source.{col}" for col in target_data_columns)
    insert_cols_clause = f"({', '.join(target_data_columns)})"
    insert_vals_clause = f"({', '.join(f'source.{col}' for col in target_data_columns)})"

    print("Applying CDC operations using MERGE INTO...")

    # The MERGE reads the CDC batch through a session-scoped temporary view.
    cdc_data.createOrReplaceTempView("cdc_source")

    # 3. One MERGE handles all three operation types.
    merge_sql = f"""
    MERGE INTO {target_table_name} AS target
    USING cdc_source AS source
    ON target.orderkey = source.orderkey
    WHEN MATCHED AND source.op = 'D' THEN
        DELETE
    WHEN MATCHED AND source.op = 'U' THEN
        UPDATE SET {update_set_clause}
    WHEN NOT MATCHED AND source.op = 'I' THEN
        INSERT {insert_cols_clause} VALUES {insert_vals_clause}
    """

    try:
        spark.sql(merge_sql)
        print("CDC operations completed successfully using MERGE!")
    except Exception as e:
        print(f"MERGE failed: {e}")
        # Surface the failure by default. Silently continuing would let the post-CDC
        # checks and the aggregate refresh run against an un-merged fact table.
        if raise_on_error:
            raise

# Merge the simulated CDC batch into the fact table.
apply_cdc_operations_merge(source_df=sim_cdc, target_table_name=FACT_ORDERS_TBL)

### 6b) Post-CDC Checks

In [None]:

# Sanity check: total rows in the fact table once the MERGE has been applied.
fact_orders_after_cdc = spark.table(fqtn(TARGET_DATABASE, TARGET_SCHEMA, "fact_orders"))
row_count = fact_orders_after_cdc.count()
print("Rows in fact_orders after CDC:", row_count)


## 7) Refresh Aggregates After CDC

In [None]:

# Recompute the per-customer rollup from the post-CDC fact table.
fact_orders_post = spark.table(fqtn(TARGET_DATABASE, TARGET_SCHEMA, "fact_orders"))

revenue_metrics = [
    F.countDistinct("orderkey").alias("order_cnt"),
    F.sum("order_revenue").alias("lifetime_revenue"),
    F.max("orderdate").alias("last_order_date"),
]
agg_revenue_post = fact_orders_post.groupBy("custkey").agg(*revenue_metrics)

# Replace the aggregate table's contents with the refreshed rollup.
agg_revenue_tbl = fqtn(TARGET_DATABASE, TARGET_SCHEMA, "agg_customer_revenue")
agg_revenue_writer = agg_revenue_post.write.mode("overwrite").format("delta")
agg_revenue_writer.saveAsTable(agg_revenue_tbl)
print("Refreshed:", agg_revenue_tbl)


## 8) Runtime (minutes & seconds)

In [None]:

from datetime import datetime

pipeline_end_ts = datetime.now()
elapsed = pipeline_end_ts - pipeline_start_ts

# Split the elapsed wall-clock time into whole minutes and leftover whole seconds.
whole_minutes, leftover_seconds = divmod(elapsed.total_seconds(), 60)
mins, secs = int(whole_minutes), int(leftover_seconds)

finished_at = pipeline_end_ts.isoformat(timespec='seconds')
print(f"Pipeline finished at: {finished_at}")
print(f"Total runtime: {mins} min {secs} sec")
