In [0]:

%run ./01-config

In [0]:
class Upserter:
    """foreachBatch helper: expose a micro-batch as a temp view, then run a SQL statement.

    Args:
        merge_query: SQL statement (typically a MERGE) that reads from ``temp_view``.
        temp_view: Name under which each micro-batch is registered as a temp view.
    """

    def __init__(self, merge_query, temp_view):
        self.merge_query = merge_query
        self.temp_view = temp_view

    def upsert(self, df_micro_batch, batch_id):
        """Register ``df_micro_batch`` as ``temp_view`` and execute ``merge_query``.

        Args:
            df_micro_batch: The micro-batch DataFrame supplied by ``foreachBatch``.
            batch_id: Micro-batch id supplied by ``foreachBatch`` (not used here).
        """
        df_micro_batch.createOrReplaceTempView(self.temp_view)
        # Temp views are session-scoped, so run the query through the session that
        # owns the batch DataFrame (previously misspelled `df_microbatch` -> NameError).
        df_micro_batch.sparkSession.sql(self.merge_query)

In [0]:
from pyspark.sql import functions as F
from pyspark.sql import functions as f
from pyspark.sql.window import Window

class CDCUpserter:
    """foreachBatch helper for CDC feeds: keep the latest row per key, then run a MERGE.

    Each micro-batch is filtered to ``row_status`` in ("insert", "update"), reduced
    to a single row per ``id_col`` (the one with the largest ``sort_col``), registered
    as the temp view ``temp_view``, and then ``merge_query`` is executed. Rows with
    any other ``row_status`` (e.g. deletes) are dropped and not applied.

    Args:
        merge_query: SQL statement (typically a MERGE) that reads from ``temp_view``.
        temp_view: Name of the temp view the deduplicated batch is registered as.
        id_col: Key column; exactly one row per value survives deduplication.
        sort_col: Ordering column; the row with the largest value per key wins.
    """

    def __init__(self, merge_query, temp_view, id_col, sort_col):
        self.merge_query = merge_query
        self.temp_view = temp_view
        self.id_col = id_col
        self.sort_col = sort_col

    def upsert(self, df_micro_batch, batch_id):
        """Deduplicate ``df_micro_batch`` to the latest row per key and run ``merge_query``.

        Args:
            df_micro_batch: The micro-batch DataFrame supplied by ``foreachBatch``.
            batch_id: Micro-batch id supplied by ``foreachBatch`` (not used here).
        """
        # PySpark's API is camelCase `partitionBy`; `partition_by` does not exist.
        window = Window.partitionBy(self.id_col).orderBy(F.col(self.sort_col).desc())

        # row_number (not rank): with ties on sort_col, rank() yields several
        # rank-1 rows for the same key, and MERGE rejects multiple source rows
        # matching one target row.
        latest_per_key = (
            df_micro_batch
            .filter(F.col("row_status").isin(["insert", "update"]))
            .withColumn("rank", F.row_number().over(window))
            .filter("rank == 1")
            .drop("rank")
        )
        latest_per_key.createOrReplaceTempView(self.temp_view)

        # Temp views are session-scoped, so query through the batch's own session.
        df_micro_batch.sparkSession.sql(self.merge_query)

In [0]:
class Silver:
    """Builds silver-layer tables from the ``bronze`` table using streaming upserts.

    Args:
        env: Catalog name. Combined with ``Config().db_name`` it selects the schema
            that the bronze source and silver targets live in.
    """

    def __init__(self, env):
        self.Conf = Config()
        self.checkpoint_dir = self.Conf.base_checkpoint_dir + "/checkpoints"
        self.landing_dir = self.Conf.base_data_dir + "/raw"
        self.catalog = env
        self.db_name = self.Conf.db_name
        # Stored from config; not referenced by the methods of this class shown here.
        self.maxFilesPerTrigger = self.Conf.maxFilesPerTrigger
        # Make <catalog>.<db_name> the session default for unqualified table names.
        spark.sql(f"USE {self.catalog}.{self.db_name}")

    def upsert_customers(self, once=True, processing_time="10 seconds"):
        """Stream customer CDC records from ``bronze`` into ``silver_customers``.

        Reads ``bronze`` as a stream and keeps ``topic = 'customers'`` records. It
        parses the JSON ``value`` payload with a fixed schema and enriches it with
        the country lookup (broadcast join on ``country_code == code``). Each
        micro-batch is then MERGEd into ``silver_customers``, keeping only the
        newest ``row_time`` per ``customer_id``.

        Args:
            once: If True, use an ``availableNow`` trigger (process what is
                available, then stop); otherwise run continuously.
            processing_time: Trigger interval used when ``once`` is False.

        Returns:
            The started ``StreamingQuery``.
        """
        query = f"""
            MERGE INTO {self.catalog}.{self.db_name}.silver_customers c
            USING ranked_updates r
            ON c.customer_id=r.customer_id
            WHEN MATCHED AND c.row_time < r.row_time
              THEN UPDATE SET *
            WHEN NOT MATCHED
              THEN INSERT *
        """

        # The feed is CDC (row_status/row_time): one micro-batch can hold several
        # versions of the same customer. They must be reduced to the latest per
        # customer_id before MERGE, which is what CDCUpserter does; the plain
        # Upserter would hand MERGE duplicate source rows.
        data_upserter = CDCUpserter(query, "ranked_updates", "customer_id", "row_time")

        schema = """customer_id STRING, email STRING, first_name STRING, last_name STRING, gender STRING, street STRING, city STRING,
            country_code STRING, row_status STRING, row_time timestamp"""

        df_country_lookup = spark.read.json(self.landing_dir + "/country_lookup")

        # NOTE(review): default join type is inner, so customers whose country_code
        # has no lookup match are dropped. The lookup columns (e.g. `code`) also
        # flow into the MERGE source; `UPDATE SET *` / `INSERT *` presumably rely on
        # silver_customers having matching columns. Confirm.
        df_stream = (spark.readStream
                          .table(f"{self.catalog}.{self.db_name}.bronze")
                          .filter("topic = 'customers'")
                          .select(F.from_json(F.col("value").cast("string"), schema).alias("c"))
                          .select("c.*")
                          # A Column expression is required; `"country_code" == "code"`
                          # is a plain Python comparison that evaluates to False.
                          .join(F.broadcast(df_country_lookup), F.col("country_code") == F.col("code"))
        )

        stream_writer = (df_stream.writeStream
                                  .foreachBatch(data_upserter.upsert)
                                  .option("checkpointLocation", self.checkpoint_dir + "/silver_customers")
                         )

        if once:
            return stream_writer.trigger(availableNow=True).start()
        return stream_writer.trigger(processingTime=processing_time).start()



