In [2]:
from datetime import datetime, timedelta
import pendulum
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
from airflow_clickhouse_plugin.hooks.clickhouse import ClickHouseHook
from airflow.models import Variable
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, when, isnan, isnull, length, current_timestamp, current_date, date_diff, round, sum as spark_sum, lit, regexp_replace, trim, expr, count_distinct
from pyspark.sql.types import *
import logging
import os
import json

# Configure the root logger at INFO level and create the module-level logger
# used by the ETL functions defined in later cells.
logging.basicConfig(level = logging.INFO)
logger = logging.getLogger(__name__)


In [2]:
def create_spark_session(app_name: str, additional_jars: list = None):
    """Create a fresh SparkSession with the JDBC driver JARs attached.

    Any currently active session is stopped first, so the returned session
    is built with the configuration assembled here.

    Args:
        app_name: Name given to the Spark application.
        additional_jars: Optional extra JAR paths appended after the bundled
            ClickHouse and PostgreSQL JDBC drivers.

    Returns:
        The SparkSession returned by ``builder.getOrCreate()``.
    """
    active_session = SparkSession.getActiveSession()
    if active_session:
        active_session.stop()

    # The ClickHouse and PostgreSQL JDBC drivers are always included; their
    # paths are resolved relative to the current working directory.
    jar_paths = [
        os.path.abspath("driver/clickhouse-jdbc-0.4.6.jar"),
        os.path.abspath("driver/postgresql-42.7.5.jar"),
    ]
    if additional_jars:
        jar_paths.extend(additional_jars)

    # spark.jars takes a comma-separated list, whereas the extraClassPath
    # settings are JVM classpaths and must use the OS path separator
    # (":" on POSIX, ";" on Windows). A comma-joined value would be read as
    # one non-existent classpath entry.
    jars = ",".join(jar_paths)
    classpath = os.pathsep.join(jar_paths)

    logging.info(f"Adding JARs to Spark session: {jars}")

    builder = (
        SparkSession.builder.appName(app_name)
        .config("spark.jars", jars)
        .config("spark.driver.extraClassPath", classpath)
        .config("spark.executor.extraClassPath", classpath)
        .config("spark.driver.userClassPathFirst", "true")
        .config("spark.executor.userClassPathFirst", "true")
    )

    return builder.getOrCreate()


In [4]:
# DAG default arguments.
# NOTE(review): start_date is computed relative to "today" each time this
# cell/module is evaluated - confirm a moving start date is intended.
default_args = dict(
    owner='airflow',
    depends_on_past=False,
    start_date=pendulum.today().add(days = -1),
    email_on_failure=True,
    email_on_retry=False,
    retries=1,
    retry_delay=pendulum.duration(minutes = 5),
)

# Date stamp (e.g. "2025-03-01") embedded in the staging, metrics and
# transformed file names. It is fixed when this cell runs, not when each
# ETL function is called.
current_date_str = pendulum.today().to_date_string()

# (schema, table) pairs pulled from PostgreSQL by the extract step.
table_source = [
    ("public", source_table)
    for source_table in ("fct_transactions", "dim_item", "dim_time")
]

# Expected columns (with their intended Spark types) for every source table.
# Within this notebook only the column names are consulted, by
# stg_dq_checks; the type objects are not compared against the real schema.
expected_schema = {
    # Fact table: one row per transaction line.
    "fct_transactions": {
        "payment_key": StringType(),
        "customer_key": StringType(),
        "time_key": StringType(),
        "item_key": StringType(),
        "store_key": StringType(),
        "quantity": IntegerType(),
        "unit": StringType(),
        "unit_price": IntegerType(),
        "total_price": IntegerType(),
    },
    # Item dimension.
    "dim_item": {
        "item_key": StringType(),
        "item_name": StringType(),
        "desc": StringType(),
        "unit_price": FloatType(),
        "man_country": StringType(),
        "supplier": StringType(),
        "unit": StringType(),
    },
    # Time dimension.
    "dim_time": {
        "time_key": StringType(),
        "date": StringType(),
        "hour": IntegerType(),
        "day": IntegerType(),
        "week": StringType(),
        "month": IntegerType(),
        "quarter": StringType(),
        "year": IntegerType(),
    },
}

def stg_dq_checks(df, tablename):
    """Compute staging data-quality metrics (percentages) for one table.

    All row-level metrics are computed in a single aggregation pass over
    ``df`` instead of several ``count()`` jobs per metric. This matters
    because ``df`` is typically a lazily evaluated JDBC read.

    Metrics produced (keys are kept unchanged because they are persisted to
    the metrics JSON file):

    * ``{tablename}_schema_validity``: 100 when the set of column names in
      ``df`` equals the set of names in ``expected_schema[tablename]``,
      otherwise 0. Column types are not compared.
    * ``{tablename}_null_{column}``: percentage of rows where the column is
      NOT null (higher is better, despite the name).
    * ``{tablename}_negative_{column}``: percentage of rows where the column
      is strictly greater than 0 (null values do not count as positive).

    Args:
        df: Spark DataFrame holding the extracted table.
        tablename: Key into ``expected_schema``; also selects which column
            checks run.

    Returns:
        dict mapping metric name to a percentage in the range 0-100. For an
        empty DataFrame every row-level metric is 0.

    Raises:
        KeyError: if ``tablename`` is not present in ``expected_schema``.
    """
    # Per table: (columns that must be non-null, columns that must be > 0).
    checks_by_table = {
        "fct_transactions": (
            ["customer_key", "item_key", "time_key",
             "quantity", "unit_price", "total_price"],
            ["quantity", "unit_price", "total_price"],
        ),
        "dim_item": (
            ["item_key", "desc", "item_name", "unit_price"],
            ["unit_price"],
        ),
        "dim_time": (
            ["time_key"],
            [],
        ),
    }

    expected_table_schema = expected_schema[tablename]
    actual_columns = {field.name for field in df.schema.fields}

    quality_metrics = {}
    quality_metrics[f"{tablename}_schema_validity"] = (
        int(set(expected_table_schema.keys()) == actual_columns) * 100
    )

    not_null_columns, positive_columns = checks_by_table.get(tablename, ([], []))

    # Metric name -> boolean column expression. Insertion order matches the
    # order in which the metrics are reported.
    conditions = {}
    for column_name in not_null_columns:
        conditions[f"{tablename}_null_{column_name}"] = col(column_name).isNotNull()
    for column_name in positive_columns:
        conditions[f"{tablename}_negative_{column_name}"] = col(column_name) > 0

    if not conditions:
        return quality_metrics

    # count(when(cond, 1)) counts only the rows where cond is true, because
    # when() yields NULL otherwise and count() skips NULLs.
    total_alias = "__total_rows"
    aggregated = df.agg(
        count(lit(1)).alias(total_alias),
        *[
            count(when(condition, 1)).alias(metric_name)
            for metric_name, condition in conditions.items()
        ]
    ).first()

    total_rows = aggregated[total_alias]
    for metric_name in conditions:
        quality_metrics[metric_name] = (
            aggregated[metric_name] / total_rows * 100 if total_rows > 0 else 0
        )
    return quality_metrics

def extract(table_source: list, min_pass_pct: float = 90):
    """Extract tables from PostgreSQL, run quality checks, and stage them.

    Steps:
      1. Read every ``(schema, table)`` pair over JDBC.
      2. Compute quality metrics per table with ``stg_dq_checks``.
      3. Write all metrics to
         ``dq_metrics/{current_date_str}_quality_metrics.json``.
      4. Raise if any metric is below ``min_pass_pct``; otherwise write each
         table to ``staging/`` as Parquet (overwriting existing output).

    Args:
        table_source: Iterable of ``(schemaname, tablename)`` pairs.
        min_pass_pct: Minimum percentage every metric must reach.

    Returns:
        dict mapping table name to its metrics dict.

    Raises:
        Exception: if at least one metric is below ``min_pass_pct``. The
            metrics file is still written before the exception is raised.
    """
    spark = create_spark_session('PostgreSQL-Extract')

    # Look the connection settings up once rather than once per table.
    jdbc_url = Variable.get("POSTGRES_JDBC_URL")
    jdbc_password = Variable.get("POSTGRES_PASSWORD")

    dataframes = {}
    for schemaname, tablename in table_source:
        dataframes[tablename] = (
            spark.read
            .format("jdbc")
            .option("url", jdbc_url)
            .option("dbtable", f"{schemaname}.{tablename}")
            .option("user", "spark")
            .option("password", jdbc_password)
            .option("driver", "org.postgresql.Driver")
            .load()
        )

    all_quality_metrics = {
        tablename: stg_dq_checks(df, tablename)
        for tablename, df in dataframes.items()
    }

    # Make sure the output directory exists so a first run on a clean
    # checkout does not fail with FileNotFoundError.
    os.makedirs("dq_metrics", exist_ok=True)
    with open(f'dq_metrics/{current_date_str}_quality_metrics.json', 'w') as f:
        json.dump(all_quality_metrics, f)

    failed_checks = {}
    for table, metrics in all_quality_metrics.items():
        failed_metrics = {
            metric: val for metric, val in metrics.items() if val < min_pass_pct
        }
        if failed_metrics:
            failed_checks[table] = failed_metrics

    if failed_checks:
        logger.error("Data Quality Check Failed for the following tables:")
        for table, metrics in failed_checks.items():
            for metric, val in metrics.items():
                logger.error(f"{table}.{metric} = {val}% (Expected ≥ {min_pass_pct}%)")
        raise Exception(f"Data Quality Check Failed: {json.dumps(failed_checks, indent=4)}")

    logger.info("All Data Quality Checks Passed!")
    for schemaname, tablename in table_source:
        staging_path_write = f"staging/{current_date_str}.{schemaname}.{tablename}.parquet"
        dataframes[tablename].write.parquet(staging_path_write, mode="overwrite")

    return all_quality_metrics

In [5]:
# Run the extract step. Because this is the last expression in the cell, the
# returned per-table metrics dict is displayed as the cell output.
extract(table_source = table_source)

[[34m2025-03-01T16:18:47.129+0700[0m] {[34m1582760252.py:[0m20} INFO[0m - Adding JARs to Spark session: /Users/sawitpro/Documents/snippets/porto/etl_project/driver/clickhouse-jdbc-0.4.6.jar,/Users/sawitpro/Documents/snippets/porto/etl_project/driver/postgresql-42.7.5.jar[0m


                                                                                

[[34m2025-03-01T16:18:54.514+0700[0m] {[34m2371406242.py:[0m146} INFO[0m - All Data Quality Checks Passed![0m


                                                                                

{'fct_transactions': {'fct_transactions_schema_validity': 100,
  'fct_transactions_null_customer_key': 100.0,
  'fct_transactions_null_item_key': 100.0,
  'fct_transactions_null_time_key': 100.0,
  'fct_transactions_null_quantity': 100.0,
  'fct_transactions_null_unit_price': 100.0,
  'fct_transactions_null_total_price': 100.0,
  'fct_transactions_negative_quantity': 100.0,
  'fct_transactions_negative_unit_price': 100.0,
  'fct_transactions_negative_total_price': 100.0},
 'dim_item': {'dim_item_schema_validity': 100,
  'dim_item_null_item_key': 100.0,
  'dim_item_null_desc': 100.0,
  'dim_item_null_item_name': 100.0,
  'dim_item_null_unit_price': 100.0,
  'dim_item_negative_unit_price': 100.0},
 'dim_time': {'dim_time_schema_validity': 100,
  'dim_time_null_time_key': 100.0}}

In [6]:
def transform(table_source: list):
    """Build the daily transaction summary from the staged Parquet files.

    Reads ``staging/{current_date_str}.{schema}.{table}.parquet`` for every
    pair in ``table_source``, left-joins the fact table to the item and time
    dimensions, and aggregates per (transaction_date, item_category):

    * ``total_transaction_value``: sum of ``total_price``
    * ``total_goods_sold``: sum of ``quantity``
    * ``count_transacting_customer``: distinct ``customer_key`` count

    The result is written (overwriting existing output) to
    ``transformed/{current_date_str}.daily_transaction_summary.parquet``.

    Args:
        table_source: Iterable of ``(schemaname, tablename)`` pairs; it must
            include ``fct_transactions``, ``dim_item`` and ``dim_time``.

    Returns:
        None.
    """
    spark = create_spark_session("Spark-Transform")

    dataframes = {}
    for schemaname, tablename in table_source:
        staging_path = f"staging/{current_date_str}.{schemaname}.{tablename}.parquet"
        dataframes[tablename] = spark.read.format("parquet").load(staging_path)

    fct_transactions = dataframes["fct_transactions"].alias("ft")
    dim_item = dataframes["dim_item"].alias("di")
    dim_time = dataframes["dim_time"].alias("dt")

    # item_category is derived from the item description by stripping a
    # leading "<lowercase letter>. " prefix and replacing " - " with a space.
    # The dot is written as [.] so it matches a literal dot only; a bare "."
    # would match any character and also strip prefixes such as "an ".
    cte_df = (
        fct_transactions
            .join(dim_item, col("ft.item_key") == col("di.item_key"), "left")
            .join(dim_time, col("ft.time_key") == col("dt.time_key"), "left")
            .selectExpr(
                "MAKE_DATE(year, month, day) AS transaction_date",
                "quantity",
                "total_price",
                "customer_key",
                "REGEXP_REPLACE(REGEXP_REPLACE(TRIM(desc), '^[a-z][.] ', ''), ' - ', ' ') AS item_category"
            )
    )

    transformed_df = (
        cte_df
        .groupBy("transaction_date", "item_category")
        .agg(
            spark_sum("total_price").alias("total_transaction_value"),
            spark_sum("quantity").alias("total_goods_sold"),
            count_distinct("customer_key").alias("count_transacting_customer")
        )
    )

    transformed_path = f"transformed/{current_date_str}.daily_transaction_summary.parquet"
    transformed_df.write.parquet(transformed_path, mode="overwrite")
    logger.info(f"Transformed data succesfully saved to: {transformed_path}")

In [7]:
# Run the transform step against the staged Parquet files for the same
# (schema, table) pairs; transform returns nothing, so only log output appears.
transform(table_source = table_source)

[[34m2025-03-01T16:18:57.720+0700[0m] {[34m1582760252.py:[0m20} INFO[0m - Adding JARs to Spark session: /Users/sawitpro/Documents/snippets/porto/etl_project/driver/clickhouse-jdbc-0.4.6.jar,/Users/sawitpro/Documents/snippets/porto/etl_project/driver/postgresql-42.7.5.jar[0m


[Stage 7:>                                                          (0 + 8) / 8]

[[34m2025-03-01T16:19:01.773+0700[0m] {[34m3267859770.py:[0m43} INFO[0m - Transformed data succesfully saved to: transformed/2025-03-01.daily_transaction_summary.parquet[0m


25/03/01 16:19:01 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [8]:
# Airflow Variables for parameterization
AIRFLOW_PATH = Variable.get("LOCAL_AIRFLOW_PATH")
POSTGRES_JDBC_URL = Variable.get("POSTGRES_JDBC_URL")
POSTGRES_PASSWORD = Variable.get("POSTGRES_PASSWORD")
CLICKHOUSE_CONN_ID = "clickhouse"
# Original (misspelled) name kept as an alias so existing references keep working.
CLIKCHOUSE_CONN_ID = CLICKHOUSE_CONN_ID
POSTGRES_SCHEMA = "public"
CLICKHOUSE_SCHEMA = "default"
# NOTE(review): "person_detail" is not referenced by the ETL functions in this
# notebook - confirm whether this constant is still needed.
TABLENAME = "person_detail"

In [29]:
def load(sink_tablename : str):
    """
    Load transformed data from Parquet into ClickHouse database using JDBC.

    Steps:
      1. Ensure the staging table ``intermediate.{sink_tablename}_stg`` and
         the production table ``default.{sink_tablename}`` exist.
      2. Empty the staging table so it only ever holds the current batch.
      3. Append ``transformed/{current_date_str}.{sink_tablename}.parquet``
         to the staging table through Spark JDBC.
      4. Copy the staging rows into the production table and force a merge
         with ``OPTIMIZE ... FINAL``. The production table uses
         ReplacingMergeTree versioned by ``updated_at``.

    Args:
        sink_tablename: Base name of the target table; also the name of the
            transformed Parquet dataset to read.

    Returns:
        None.
    """
    import clickhouse_connect

    spark = create_spark_session("Load-to-ClickHouse")

    # Use one credential source for both the native client and the JDBC
    # writer instead of hard-coding the password for one of them.
    clickhouse_password = Variable.get("CLICKHOUSE_PASSWORD")
    ch_client = clickhouse_connect.get_client(
        host = "127.0.0.1",
        port = 8123,
        username = "spark",
        password = clickhouse_password)

    stg_table = f"`intermediate`.{sink_tablename}_stg"
    prod_table = f"`default`.{sink_tablename}"

    ch_stg_ddl = f"""
    CREATE TABLE IF NOT EXISTS {stg_table} (
        transaction_date Date,
        item_category String,
        total_transaction_value Float64,
        total_goods_sold Int32,
        count_transacting_customer Int32,
        updated_at UInt32 DEFAULT toUnixTimestamp(now())
    )
    ENGINE = MergeTree()
    ORDER BY (transaction_date, item_category)
    """
    ch_prod_ddl = f"""
    CREATE TABLE IF NOT EXISTS {prod_table} (
        transaction_date Date,
        item_category String,
        total_transaction_value Float64,
        total_goods_sold Int32,
        count_transacting_customer Int32,
        updated_at UInt32 DEFAULT toUnixTimestamp(now())
    )
    ENGINE = ReplacingMergeTree(updated_at)
    ORDER BY (transaction_date, item_category);
    """

    try:
        # These statements return no result set, so they are sent with
        # command() rather than query().
        ch_client.command(ch_stg_ddl)
        # Without this, every run would append to the rows left by earlier
        # runs and re-copy all of them into production.
        ch_client.command(f"TRUNCATE TABLE IF EXISTS {stg_table}")
        ch_client.command(ch_prod_ddl)

        transformed_path = f"transformed/{current_date_str}.{sink_tablename}.parquet"
        transformed_df = spark.read.parquet(transformed_path)

        transformed_df.write \
            .format("jdbc") \
            .option("url", "jdbc:clickhouse://127.0.0.1:8123/default") \
            .option("dbtable", stg_table) \
            .option("user", "spark") \
            .option("password", clickhouse_password) \
            .option("driver", "com.clickhouse.jdbc.ClickHouseDriver") \
            .option("batchsize", "50000") \
            .mode("append") \
            .save()

        # Target the fully qualified production table rather than relying on
        # the connection's default database.
        ch_prod_dml = f"""
        INSERT INTO {prod_table}
        SELECT * FROM {stg_table}
        """
        ch_client.command(ch_prod_dml)

        ch_prod_merge = f"""
        OPTIMIZE TABLE {prod_table} FINAL
        """
        ch_client.command(ch_prod_merge)
    finally:
        # Release the HTTP connection even when a step above fails.
        ch_client.close()

    logger.info(f"Data successfully loaded into ClickHouse table: {sink_tablename}")

# Run the load step for the summary dataset that transform writes to
# transformed/{current_date_str}.daily_transaction_summary.parquet.
load("daily_transaction_summary")

[[34m2025-03-01T17:41:59.054+0700[0m] {[34m1582760252.py:[0m20} INFO[0m - Adding JARs to Spark session: /Users/sawitpro/Documents/snippets/porto/etl_project/driver/clickhouse-jdbc-0.4.6.jar,/Users/sawitpro/Documents/snippets/porto/etl_project/driver/postgresql-42.7.5.jar[0m
[[34m2025-03-01T17:41:59.766+0700[0m] {[34m3900943316.py:[0m66} INFO[0m - Data successfully loaded into ClickHouse table: daily_transaction_summary[0m
