In [2]:
import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import col , current_timestamp , lit
from pyspark.sql.types import StringType

In [None]:
# Connection settings. Secrets and machine-specific paths are read from the
# environment so they never have to be committed with the notebook; the
# fallbacks are the original placeholder values.
SPARK_JDBC_JAR = os.environ.get("SPARK_JDBC_JAR", r"JARFILELOCATION")

spark = (
    SparkSession.builder
    .appName("ETL Customer")
    .config("spark.jars", SPARK_JDBC_JAR)
    .getOrCreate()
)

jdbc_url = os.environ.get("RETAIL_DB_URL", "jdbc:postgresql://localhost:5432/RETAIL_DB")
conn_props = {
    "user": os.environ.get("RETAIL_DB_USER", "UZR"),
    "password": os.environ.get("RETAIL_DB_PASSWORD", "PASSOWD"),
    "driver": "org.postgresql.Driver",
}

# Partitioned JDBC read on `customerid`. lowerBound/upperBound only decide the
# partition stride -- they do NOT filter rows: ids below/above the range all
# land in the first/last partition, so if real ids fall outside [1, 10] almost
# all of the work ends up in a single partition.
# TODO: set these to the actual MIN/MAX of customerid in the source table.
CUSTOMER_ID_LOWER_BOUND = 1
CUSTOMER_ID_UPPER_BOUND = 10
NUM_READ_PARTITIONS = 10

# Extract distinct (customerid, country) pairs with a non-null customer id.
# (Variable name kept as-is because later cells reference it.)
# Cached because the stage and dimension loads derive several actions from this
# frame; without it each action re-reads the whole source table over JDBC.
raw_cutomers = (
    spark.read.jdbc(
        url=jdbc_url,
        table="wh_source.online_retail_raw",
        column="customerid",
        lowerBound=CUSTOMER_ID_LOWER_BOUND,
        upperBound=CUSTOMER_ID_UPPER_BOUND,
        numPartitions=NUM_READ_PARTITIONS,
        properties=conn_props,
    )
    .select(col("customerid"), col("country"))
    .filter(col("customerid").isNotNull())
    .distinct()
    .cache()
)



In [None]:
# Build the stage frame: adopt the warehouse key name `customer_id` and
# stamp both ETL audit columns with the current load timestamp.
customer_stage = (
    raw_cutomers
    .withColumnRenamed("customerid", "customer_id")
    .withColumn("etl_insert_ts", current_timestamp())
    .withColumn("etl_update_ts", current_timestamp())
)


In [None]:
# Load the customer stage table. The load is insert-only: customer_ids already
# present in the table are left untouched; only unseen ids are appended.
STG_CUSTOMER_TABLE = "wh_stage.stg_customer"

# A failed read is treated as "table does not exist yet" (first run).
try:
    stage_customers = spark.read.jdbc(jdbc_url, STG_CUSTOMER_TABLE, properties=conn_props)
except Exception as exc:
    # Don't swallow the error silently: make the first-load assumption visible.
    print(f"{STG_CUSTOMER_TABLE} could not be read ({type(exc).__name__}); treating it as a first load")
    stage_customers = None

# NOTE(review): customer_stage is distinct on (customer_id, country), so a
# customer seen with several countries yields several rows here -- confirm
# whether duplicate customer_ids in the stage table are intended.
# Test for None explicitly: `if stage_customers:` relies on DataFrame truthiness,
# which is not an existence/emptiness check.
if stage_customers is not None:
    new_customers = customer_stage.join(stage_customers.select("customer_id"), "customer_id", "left_anti")
    # head(1) is enough to know whether anything is new; count() would compute
    # the full result only to compare it with zero.
    if new_customers.head(1):
        new_customers.write.jdbc(jdbc_url, STG_CUSTOMER_TABLE, mode="append", properties=conn_props)
else:
    # "append" creates the table when it is missing, but -- unlike "overwrite" --
    # never drops an existing table if the read above failed for an unrelated
    # reason (permissions, transient connection error, ...).
    customer_stage.write.jdbc(jdbc_url, STG_CUSTOMER_TABLE, mode="append", properties=conn_props)



In [None]:
# Load the customer dimension. New customer_ids get an open row
# (start_date = load time, end_date = NULL, current_flag = True).
# NOTE(review): this is insert-only -- existing rows are never closed or
# versioned, so a changed attribute (e.g. country) for a known customer_id is
# not captured. Full SCD Type 2 would need an update/close step here.
DIM_CUSTOMER_TABLE = "wh_core.dim_customer"

# A failed read is treated as "table does not exist yet" (first run).
try:
    dim_customers = spark.read.jdbc(jdbc_url, DIM_CUSTOMER_TABLE, properties=conn_props)
except Exception as exc:
    # Don't swallow the error silently: make the first-load assumption visible.
    print(f"{DIM_CUSTOMER_TABLE} could not be read ({type(exc).__name__}); treating it as a first load")
    dim_customers = None

# Validity columns shared by every newly inserted dimension row (previously
# duplicated in both branches below).
dim_candidates = (
    customer_stage
    .withColumn("start_date", current_timestamp())
    .withColumn("end_date", lit(None).cast("timestamp"))
    .withColumn("current_flag", lit(True))
)

# Test for None explicitly rather than relying on DataFrame truthiness.
if dim_customers is not None:
    new_dim_customers = dim_candidates.join(dim_customers.select("customer_id"), "customer_id", "left_anti")
    # head(1) checks for new rows without computing a full count().
    if new_dim_customers.head(1):
        new_dim_customers.write.jdbc(jdbc_url, DIM_CUSTOMER_TABLE, mode="append", properties=conn_props)
else:
    # "append" creates the table when missing but never drops an existing one
    # if the read above failed for an unrelated reason.
    dim_candidates.write.jdbc(jdbc_url, DIM_CUSTOMER_TABLE, mode="append", properties=conn_props)


# Releases the Spark session (and any cached data); re-running cells above
# requires re-executing the SparkSession setup first.
spark.stop()