In [1]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, monotonically_increasing_id, row_number

# Connection settings for the PostgreSQL instance that run() reads SRC_TABLE
# from and that build_compact() writes the dimension and fact tables to.
JDBC_HOST = "doc_postgres"
JDBC_PORT = 5432
JDBC_DB = "postgres"
JDBC_USER = "postgres"
# NOTE(review): the password is hard-coded in source; consider loading it from
# the environment or a secrets store instead.
JDBC_PASSWORD = "mypassword"
# Full JDBC URL assembled from the host, port and database above.
JDBC_URL = f"jdbc:postgresql://{JDBC_HOST}:{JDBC_PORT}/{JDBC_DB}"
JDBC_DRIVER = "org.postgresql.Driver"
# Name of the flat source table that is split into dimensions and a fact table.
SRC_TABLE = "mock_data"

def add_sequential_key(df, key_name, order_cols=None, start=1):
    """Append a dense, sequential integer key column to ``df``.

    Args:
        df: Spark DataFrame to extend.
        key_name: Name of the key column to add.
        order_cols: Column name, or list of column names, defining the order
            in which rows are numbered. ``None`` or an empty list falls back
            to ordering by ``monotonically_increasing_id()``.
        start: Value assigned to the first row (default 1).

    Returns:
        A new DataFrame with ``key_name`` holding ``start``, ``start + 1``, ...
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(order_cols, str):
        order_cols = [order_cols]

    offset = start - 1

    if not order_cols:
        # row_number() requires an ordered window, so an empty list is treated
        # like None: number the rows by a temporary generated id instead.
        tmp_col = "__tmp_mon"
        w = Window.orderBy(col(tmp_col))
        return (
            df.withColumn(tmp_col, monotonically_increasing_id())
            .withColumn(key_name, row_number().over(w) + offset)
            .drop(tmp_col)
        )

    # The window has no partitionBy, so numbering is global across the frame.
    w = Window.orderBy(*[col(c) for c in order_cols])
    return df.withColumn(key_name, row_number().over(w) + offset)

def _write_dimension(df, jdbc_url, props, table, columns, business_keys, surrogate_key):
    """Project, de-duplicate, key and persist a single dimension table.

    Args:
        df: Source DataFrame (the flat input table).
        jdbc_url: JDBC URL of the target database.
        props: JDBC connection properties.
        table: Name of the dimension table to overwrite.
        columns: List of ``(source_column, dimension_column)`` pairs.
        business_keys: Dimension column names (after renaming) that identify a
            row; rows with a null in any of them are dropped.
        surrogate_key: Name of the sequential key column to add.
    """
    dim = df.select(*[col(src).alias(dst) for src, dst in columns])
    # Filter on the projected (aliased) names so the filter never depends on
    # columns that are no longer part of the projection.
    for key in business_keys:
        dim = dim.filter(col(key).isNotNull())
    dim = dim.dropDuplicates(business_keys)
    dim = add_sequential_key(dim, surrogate_key, order_cols=business_keys)
    dim.write.jdbc(url=jdbc_url, table=table, mode="overwrite", properties=props)


def build_compact(session, df, jdbc_url, props):
    """Split the flat source frame into five dimensions and a fact table.

    Writes ``dim_customer``, ``dim_product``, ``dim_seller``, ``dim_store`` and
    ``dim_supplier`` (each overwritten), reads them back, left-joins them to
    the source rows and overwrites ``fact_sale`` with the resulting keys and
    measures. Source rows that match no dimension row keep a null key.

    Args:
        session: SparkSession used to read the dimension tables back.
        df: Source DataFrame holding the flat sales data.
        jdbc_url: JDBC URL of the target database.
        props: JDBC connection properties (user, password, driver).
    """
    _write_dimension(
        df, jdbc_url, props,
        table="dim_customer",
        columns=[
            ("sale_customer_id", "customer_business_id"),
            ("customer_first_name", "first_name"),
            ("customer_last_name", "last_name"),
            ("customer_age", "age"),
            ("customer_email", "email"),
            ("customer_country", "country"),
            ("customer_postal_code", "postal_code"),
            ("customer_pet_type", "pet_type"),
            ("customer_pet_name", "pet_name"),
            ("customer_pet_breed", "pet_breed"),
        ],
        business_keys=["customer_business_id"],
        surrogate_key="customer_key",
    )

    _write_dimension(
        df, jdbc_url, props,
        table="dim_product",
        columns=[
            ("sale_product_id", "product_business_id"),
            ("product_name", "product_name"),
            ("product_category", "category"),
            ("product_price", "price"),
            ("product_weight", "weight"),
            ("product_color", "color"),
            ("product_size", "size"),
            ("product_brand", "brand"),
            ("product_material", "material"),
            ("product_description", "description"),
            ("product_rating", "rating"),
            ("product_reviews", "reviews"),
            ("product_release_date", "release_date"),
            ("product_expiry_date", "expiry_date"),
            ("pet_category", "pet_category"),
        ],
        business_keys=["product_business_id"],
        surrogate_key="product_key",
    )

    _write_dimension(
        df, jdbc_url, props,
        table="dim_seller",
        columns=[
            ("sale_seller_id", "seller_business_id"),
            ("seller_first_name", "first_name"),
            ("seller_last_name", "last_name"),
            ("seller_email", "email"),
            ("seller_country", "country"),
            ("seller_postal_code", "postal_code"),
        ],
        business_keys=["seller_business_id"],
        surrogate_key="seller_key",
    )

    # Stores have no id in the source; name + location act as the business key.
    _write_dimension(
        df, jdbc_url, props,
        table="dim_store",
        columns=[
            ("store_name", "store_name"),
            ("store_location", "location"),
            ("store_city", "city"),
            ("store_state", "state"),
            ("store_country", "country"),
            ("store_phone", "phone"),
            ("store_email", "email"),
        ],
        business_keys=["store_name", "location"],
        surrogate_key="store_key",
    )

    _write_dimension(
        df, jdbc_url, props,
        table="dim_supplier",
        columns=[
            ("supplier_name", "supplier_name"),
            ("supplier_contact", "contact"),
            ("supplier_email", "email"),
            ("supplier_phone", "phone"),
            ("supplier_address", "address"),
            ("supplier_city", "city"),
            ("supplier_country", "country"),
        ],
        business_keys=["supplier_name"],
        surrogate_key="supplier_key",
    )

    # Read the persisted dimensions back so the fact table uses exactly the
    # surrogate keys that were stored.
    dim_customer = session.read.jdbc(url=jdbc_url, table="dim_customer", properties=props).alias("c")
    dim_product = session.read.jdbc(url=jdbc_url, table="dim_product", properties=props).alias("p")
    dim_seller = session.read.jdbc(url=jdbc_url, table="dim_seller", properties=props).alias("s")
    dim_store = session.read.jdbc(url=jdbc_url, table="dim_store", properties=props).alias("st")
    dim_supplier = session.read.jdbc(url=jdbc_url, table="dim_supplier", properties=props).alias("sup")

    m = df.alias("m")
    store_match = (
        (col("m.store_name") == col("st.store_name"))
        & (col("m.store_location") == col("st.location"))
    )
    # Left joins keep every source row even when a dimension lookup misses.
    joined = (
        m.join(dim_customer, col("m.sale_customer_id") == col("c.customer_business_id"), "left")
        .join(dim_seller, col("m.sale_seller_id") == col("s.seller_business_id"), "left")
        .join(dim_product, col("m.sale_product_id") == col("p.product_business_id"), "left")
        .join(dim_store, store_match, "left")
        .join(dim_supplier, col("m.supplier_name") == col("sup.supplier_name"), "left")
    )

    fact = joined.select(
        col("c.customer_key").alias("customer_key"),
        col("s.seller_key").alias("seller_key"),
        col("p.product_key").alias("product_key"),
        col("st.store_key").alias("store_key"),
        col("sup.supplier_key").alias("supplier_key"),
        col("m.sale_date").alias("sale_date"),
        col("m.sale_quantity").alias("quantity"),
        col("m.sale_total_price").alias("total_price"),
        col("m.id").alias("original_id"),
    )

    fact.write.jdbc(url=jdbc_url, table="fact_sale", mode="overwrite", properties=props)

def run():
    """Read the source table from PostgreSQL and build the star-schema tables.

    Creates (or reuses) a Spark session with the PostgreSQL JDBC driver on the
    classpath, loads ``SRC_TABLE`` and passes it to ``build_compact``. The
    session is stopped even when the build raises.
    """
    spark = (
        SparkSession.builder
        .appName("compact_snowflake_spark")
        .config("spark.jars.packages", "org.postgresql:postgresql:42.6.0")
        .getOrCreate()
    )
    try:
        props = {"user": JDBC_USER, "password": JDBC_PASSWORD, "driver": JDBC_DRIVER}
        src = spark.read.jdbc(url=JDBC_URL, table=SRC_TABLE, properties=props)
        build_compact(spark, src, JDBC_URL, props)
    finally:
        # Release the session on failure too, so a failed run does not leave
        # it running.
        spark.stop()

# Run the build only when this code is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    run()


In [2]:
!pip install clickhouse_connect

Collecting clickhouse_connect
  Downloading clickhouse_connect-0.10.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (4.2 kB)
Downloading clickhouse_connect-0.10.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (1.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 2.7 MB/s eta 0:00:00
Installing collected packages: clickhouse_connect
Successfully installed clickhouse_connect-0.10.0


In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, sum as _sum, avg as _avg, month, year, corr,
    to_date, coalesce, when, lit, from_unixtime, length, lag
)
from pyspark.sql.types import StructType
from clickhouse_connect import get_client
from datetime import date, datetime
from decimal import Decimal
from pyspark.sql.window import Window

# PostgreSQL source: JDBC URL and connection properties used by read_from_postgres().
# NOTE(review): credentials are hard-coded in source; consider loading them
# from the environment instead.
POSTGRES_JDBC_URL = "jdbc:postgresql://doc_postgres:5432/postgres"
PG_PROPS = {"user": "postgres", "password": "mypassword", "driver": "org.postgresql.Driver"}

# ClickHouse target: connection settings passed to get_client() below.
CLICKHOUSE_HOST = "clickhouse-server"
CLICKHOUSE_PORT = 8123
CH_USER = "default"
CH_PASSWORD = ""

# Spark session with the PostgreSQL JDBC driver and the legacy date/time parser.
spark = (
    SparkSession.builder
    .appName("Postgres->ClickHouse ETL (date-fix, nullable monthly)")
    .config("spark.jars.packages", "org.postgresql:postgresql:42.6.0")
    .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
    .getOrCreate()
)
# The parser policy is also set on the live session, repeating the builder option above.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# ClickHouse client shared by the table-creation and insert helpers.
ch = get_client(
    host=CLICKHOUSE_HOST,
    port=CLICKHOUSE_PORT,
    username=CH_USER,
    password=CH_PASSWORD,
)

def read_from_postgres(table):
    """Load ``table`` from PostgreSQL over JDBC and return it marked as cached."""
    return spark.read.jdbc(
        url=POSTGRES_JDBC_URL,
        table=table,
        properties=PG_PROPS,
    ).cache()

def recreate_ch_table(name, ddl_body):
    """Drop ClickHouse table ``name`` if present and create it again.

    ``ddl_body`` is the column definition list placed between the parentheses
    of the CREATE TABLE statement. The table uses the MergeTree engine with an
    empty sort key.
    """
    drop_stmt = f"DROP TABLE IF EXISTS {name}"
    create_stmt = f"CREATE TABLE {name} ({ddl_body}) ENGINE = MergeTree() ORDER BY tuple()"
    for stmt in (drop_stmt, create_stmt):
        ch.command(stmt)

def _normalize_value(v):
    """Convert a single cell value into a plain Python value for insertion.

    ``None`` passes through, dates and datetimes become ISO-8601 strings,
    ``Decimal`` becomes ``float``, and any object exposing ``item()`` is
    replaced by the result of calling it. Everything else, including objects
    whose ``item()`` call fails, is returned unchanged.
    """
    if v is None:
        return None
    if isinstance(v, Decimal):
        return float(v)
    if isinstance(v, (date, datetime)):
        return v.isoformat()
    # Best effort: unwrap scalar wrappers via item(); on any failure
    # (including a missing attribute) keep the value as it is.
    try:
        return v.item()
    except Exception:
        return v

def insert_spark_df_to_ch(df, table_name, batch_size=1000):
    """Stream a Spark DataFrame into a ClickHouse table in batches.

    Rows are pulled to the driver with ``toLocalIterator()``, each value is
    passed through ``_normalize_value`` and the rows are sent with
    ``ch.insert`` in groups of up to ``batch_size``.

    Args:
        df: Spark DataFrame whose column names match the target table.
        table_name: Name of the ClickHouse table to insert into.
        batch_size: Maximum number of rows per insert call.

    Returns:
        The number of rows inserted (0 when ``df`` has no columns).
    """
    cols = df.columns
    if not cols:
        return 0

    total = 0
    batch = []
    for row in df.toLocalIterator():
        batch.append(tuple(_normalize_value(value) for value in row))
        if len(batch) >= batch_size:
            ch.insert(table_name, batch, column_names=cols)
            total += len(batch)
            batch = []

    # Flush the final, partially filled batch.
    if batch:
        ch.insert(table_name, batch, column_names=cols)
        total += len(batch)
    return total

# Load the fact table and the dimensions used by the marts below.
fact_sale, dim_product, dim_customer, dim_store, dim_supplier = (
    read_from_postgres(table_name)
    for table_name in ("fact_sale", "dim_product", "dim_customer", "dim_store", "dim_supplier")
)

# Columns that may carry the sale date, in the order they are tried.
possible_date_cols = [
    "sale_date",
    "sale_ts",
    "sale_datetime",
    "created_at",
    "date",
    "timestamp",
]
# Date patterns tried in order by build_parsed_date_expr().
formats = [
    "M/d/yyyy",
    "MM/dd/yyyy",
    "yyyy-MM-dd HH:mm:ss",
    "yyyy-MM-dd",
    "dd.MM.yyyy",
    "yyyy/MM/dd",
    "MMM d, yyyy",
]

def build_parsed_date_expr(col_name):
    """Build a Column expression that parses ``col_name`` into a date.

    The column is cast to string and the first non-null result of the
    following attempts is used, in this order: each pattern in ``formats``;
    all-digit text of at most 10 characters read as epoch seconds; longer
    all-digit text read as epoch milliseconds; ``to_date`` with no pattern.
    """
    text = col(col_name).cast("string")
    all_digits = text.rlike('^[0-9]+$')

    candidates = [to_date(text, fmt) for fmt in formats]
    # Up to 10 digits: epoch seconds.
    candidates.append(
        when(all_digits & (length(text) <= 10),
             to_date(from_unixtime(text.cast("long"))))
    )
    # More than 10 digits: epoch milliseconds, reduced to whole seconds.
    candidates.append(
        when(all_digits & (length(text) > 10),
             to_date(from_unixtime((text.cast("double")/1000).cast("long"))))
    )
    # Last resort: to_date without an explicit pattern.
    candidates.append(to_date(text))
    return coalesce(*candidates)

# Combine every candidate date column present in fact_sale into one parsed
# date; earlier entries of possible_date_cols win when several parse.
parsed_expr = None
for date_col_name in [name for name in possible_date_cols if name in fact_sale.columns]:
    candidate_expr = build_parsed_date_expr(date_col_name)
    if parsed_expr is None:
        parsed_expr = candidate_expr
    else:
        parsed_expr = coalesce(parsed_expr, candidate_expr)

# Store the parsed result as sale_date, replacing the raw column if it exists.
if parsed_expr is not None:
    fact_sale = fact_sale.withColumn("sale_date", parsed_expr)

# Product marts: top sellers, revenue per category, rating/review summary.
if fact_sale.columns and dim_product.columns:
    product_sales = fact_sale.join(dim_product, "product_key", "inner")

    # Ten products with the highest total quantity sold.
    top_products = (
        product_sales
        .groupBy("product_name")
        .agg(_sum(col("quantity")).alias("total_quantity_sold"))
        .orderBy(col("total_quantity_sold").desc())
        .limit(10)
    )
    recreate_ch_table("top_products_mart", "product_name String, total_quantity_sold Float64")
    insert_spark_df_to_ch(
        top_products.select("product_name", "total_quantity_sold"),
        "top_products_mart",
    )

    # Total revenue per product category.
    category_revenue = (
        product_sales
        .groupBy("category")
        .agg(_sum(col("total_price")).alias("total_revenue"))
    )
    recreate_ch_table("category_revenue_mart", "category String, total_revenue Float64")
    insert_spark_df_to_ch(
        category_revenue.select("category", "total_revenue"),
        "category_revenue_mart",
    )

    # Average rating and summed review count per product, over sale rows.
    product_ratings = (
        product_sales
        .groupBy("product_name")
        .agg(
            _avg(col("rating")).alias("avg_rating"),
            _sum(col("reviews")).alias("total_reviews"),
        )
    )
    recreate_ch_table("product_ratings_mart", "product_name String, avg_rating Float64, total_reviews Int64")
    insert_spark_df_to_ch(
        product_ratings.select("product_name", "avg_rating", "total_reviews"),
        "product_ratings_mart",
    )

# Customer marts: top spenders, revenue by country, average check.
if fact_sale.columns and dim_customer.columns:
    customer_sales = fact_sale.join(dim_customer, "customer_key", "inner")

    # Ten (first_name, last_name) groups with the highest total spend.
    top_customers = (
        customer_sales
        .groupBy("first_name", "last_name")
        .agg(_sum(col("total_price")).alias("total_spent"))
        .orderBy(col("total_spent").desc())
        .limit(10)
    )
    recreate_ch_table("top_customers_mart", "first_name String, last_name String, total_spent Float64")
    insert_spark_df_to_ch(
        top_customers.select("first_name", "last_name", "total_spent"),
        "top_customers_mart",
    )

    # Revenue grouped by the customer's country.
    customers_by_country = (
        customer_sales
        .groupBy("country")
        .agg(_sum(col("total_price")).alias("country_revenue"))
    )
    recreate_ch_table("customers_by_country_mart", "country String, country_revenue Float64")
    insert_spark_df_to_ch(
        customers_by_country.select("country", "country_revenue"),
        "customers_by_country_mart",
    )

    # Average sale value per (first_name, last_name) group.
    customer_avg_check = (
        customer_sales
        .groupBy("first_name", "last_name")
        .agg(_avg(col("total_price")).alias("avg_check"))
    )
    recreate_ch_table("customer_avg_check_mart", "first_name String, last_name String, avg_check Float64")
    insert_spark_df_to_ch(
        customer_avg_check.select("first_name", "last_name", "avg_check"),
        "customer_avg_check_mart",
    )

# Time-based marts; built only when at least one sale has a parsed date.
if "sale_date" in fact_sale.columns and fact_sale.filter(col("sale_date").isNotNull()).count() > 0:
    fact_with_date = fact_sale.filter(col("sale_date").isNotNull())

    # Year and month derived once and shared by the monthly aggregations.
    sales_with_period = (
        fact_with_date
        .withColumn("sale_year", year(col("sale_date")).cast("int"))
        .withColumn("sale_month", month(col("sale_date")).cast("int"))
    )

    monthly = (
        sales_with_period
        .groupBy("sale_year", "sale_month")
        .agg(_sum(col("total_price")).alias("monthly_revenue"))
        .orderBy("sale_year", "sale_month")
    )

    # Previous month's revenue via a lag over the chronologically ordered months.
    w = Window.orderBy("sale_year", "sale_month")
    monthly = monthly.withColumn("prev_month_revenue", lag("monthly_revenue").over(w))

    # Relative change versus the previous month (a fraction, not multiplied by
    # 100); null for the first month or when the previous revenue is zero.
    has_usable_prev = col("prev_month_revenue").isNotNull() & (col("prev_month_revenue") != 0)
    relative_change = (col("monthly_revenue") - col("prev_month_revenue")) / col("prev_month_revenue")
    monthly = monthly.withColumn(
        "mom_change_pct",
        when(has_usable_prev, relative_change).otherwise(lit(None)),
    )

    recreate_ch_table(
        "monthly_trends_mart",
        "sale_year Int32, sale_month Int32, monthly_revenue Float64, prev_month_revenue Nullable(Float64), mom_change_pct Nullable(Float64)"
    )
    insert_spark_df_to_ch(
        monthly.select("sale_year", "sale_month", "monthly_revenue", "prev_month_revenue", "mom_change_pct"),
        "monthly_trends_mart",
    )

    # Yearly revenue rolled up from the monthly totals.
    yearly = (
        monthly
        .groupBy("sale_year")
        .agg(_sum("monthly_revenue").alias("yearly_revenue"))
        .orderBy("sale_year")
    )
    recreate_ch_table("yearly_trends_mart", "sale_year Int32, yearly_revenue Float64")
    insert_spark_df_to_ch(yearly.select("sale_year", "yearly_revenue"), "yearly_trends_mart")

    # Average sale value per month.
    avg_order_monthly = (
        sales_with_period
        .groupBy("sale_year", "sale_month")
        .agg(_avg(col("total_price")).alias("avg_order_value"))
    )
    recreate_ch_table("avg_order_monthly_mart", "sale_year Int32, sale_month Int32, avg_order_value Nullable(Float64)")
    insert_spark_df_to_ch(
        avg_order_monthly.select("sale_year", "sale_month", "avg_order_value"),
        "avg_order_monthly_mart",
    )

# Store marts: top stores, revenue by city, average receipt.
if fact_sale.columns and dim_store.columns:
    store_sales = fact_sale.join(dim_store, "store_key", "inner")

    # Five stores with the highest total revenue.
    top_stores = (
        store_sales
        .groupBy("store_name")
        .agg(_sum(col("total_price")).alias("total_revenue"))
        .orderBy(col("total_revenue").desc())
        .limit(5)
    )
    recreate_ch_table("top_stores_mart", "store_name String, total_revenue Float64")
    insert_spark_df_to_ch(top_stores.select("store_name", "total_revenue"), "top_stores_mart")

    # Revenue per store city and country.
    sales_by_city = (
        store_sales
        .groupBy("city", "country")
        .agg(_sum(col("total_price")).alias("city_revenue"))
    )
    recreate_ch_table("sales_by_city_mart", "city String, country String, city_revenue Float64")
    insert_spark_df_to_ch(
        sales_by_city.select("city", "country", "city_revenue"),
        "sales_by_city_mart",
    )

    # Average sale value per store.
    store_avg_check = (
        store_sales
        .groupBy("store_name")
        .agg(_avg(col("total_price")).alias("avg_receipt"))
    )
    recreate_ch_table("store_avg_check_mart", "store_name String, avg_receipt Float64")
    insert_spark_df_to_ch(
        store_avg_check.select("store_name", "avg_receipt"),
        "store_avg_check_mart",
    )

# Supplier marts: top suppliers, average product price, revenue by country.
if fact_sale.columns and dim_supplier.columns:
    supplier_sales = fact_sale.join(dim_supplier, "supplier_key", "inner")

    # Five suppliers with the highest total revenue.
    top_suppliers = (
        supplier_sales
        .groupBy("supplier_name")
        .agg(_sum(col("total_price")).alias("total_revenue"))
        .orderBy(col("total_revenue").desc())
        .limit(5)
    )
    recreate_ch_table("top_suppliers_mart", "supplier_name String, total_revenue Float64")
    insert_spark_df_to_ch(
        top_suppliers.select("supplier_name", "total_revenue"),
        "top_suppliers_mart",
    )

    # Average product price per supplier, averaged over sale rows.
    supplier_avg_price = (
        supplier_sales
        .join(dim_product, "product_key", "inner")
        .groupBy("supplier_name")
        .agg(_avg(col("price")).alias("avg_price"))
    )
    recreate_ch_table("supplier_avg_price_mart", "supplier_name String, avg_price Float64")
    insert_spark_df_to_ch(
        supplier_avg_price.select("supplier_name", "avg_price"),
        "supplier_avg_price_mart",
    )

    # Revenue grouped by the supplier's country.
    supplier_by_country = (
        supplier_sales
        .groupBy("country")
        .agg(_sum(col("total_price")).alias("country_revenue"))
    )
    recreate_ch_table("supplier_by_country_mart", "country String, country_revenue Float64")
    insert_spark_df_to_ch(
        supplier_by_country.select("country", "country_revenue"),
        "supplier_by_country_mart",
    )

# Product quality marts: best/worst rated, most reviewed, rating correlations.
if fact_sale.columns and dim_product.columns:
    rated_sales = fact_sale.join(dim_product, "product_key", "inner")

    # Five products with the highest average rating.
    best_products = (
        rated_sales
        .groupBy("product_name")
        .agg(_avg(col("rating")).alias("avg_rating"))
        .orderBy(col("avg_rating").desc())
        .limit(5)
    )
    recreate_ch_table("best_products_mart", "product_name String, avg_rating Float64")
    insert_spark_df_to_ch(best_products.select("product_name", "avg_rating"), "best_products_mart")

    # Five products with the lowest average rating.
    worst_products = (
        rated_sales
        .groupBy("product_name")
        .agg(_avg(col("rating")).alias("avg_rating"))
        .orderBy(col("avg_rating").asc())
        .limit(5)
    )
    recreate_ch_table("worst_products_mart", "product_name String, avg_rating Float64")
    insert_spark_df_to_ch(worst_products.select("product_name", "avg_rating"), "worst_products_mart")

    # Five products with the largest summed review count.
    top_reviewed = (
        rated_sales
        .groupBy("product_name")
        .agg(_sum(col("reviews")).alias("total_reviews"))
        .orderBy(col("total_reviews").desc())
        .limit(5)
    )
    recreate_ch_table("top_reviewed_mart", "product_name String, total_reviews Int64")
    insert_spark_df_to_ch(top_reviewed.select("product_name", "total_reviews"), "top_reviewed_mart")

    # Per-product aggregates used as input for the correlation figures.
    product_sales_correlation = (
        rated_sales
        .groupBy("product_name")
        .agg(
            _avg(col("rating")).alias("avg_rating"),
            _sum(col("quantity")).alias("total_quantity_sold"),
            _sum(col("total_price")).alias("total_revenue"),
        )
    )
    corr_row = product_sales_correlation.select(
        corr("avg_rating", "total_quantity_sold").alias("corr_q"),
        corr("avg_rating", "total_revenue").alias("corr_r"),
    ).collect()

    # A missing row or a null correlation is reported as 0.0.
    corr_q = 0.0
    corr_r = 0.0
    if corr_row:
        if corr_row[0]["corr_q"] is not None:
            corr_q = corr_row[0]["corr_q"]
        if corr_row[0]["corr_r"] is not None:
            corr_r = corr_row[0]["corr_r"]

    recreate_ch_table("product_rating_correlation_mart", "metric String, correlation_value Float64")
    corr_df = spark.createDataFrame([
        ("rating_vs_quantity", float(corr_q)),
        ("rating_vs_revenue", float(corr_r)),
    ])
    insert_spark_df_to_ch(
        corr_df.selectExpr("_1 as metric", "_2 as correlation_value"),
        "product_rating_correlation_mart",
    )

# All marts are written; shut the Spark session down.
spark.stop()
