In [0]:
%run "../Configuration/config file"

In [0]:
class UpsertCls():
    """Callable holder for a SQL statement that is run once per micro-batch.

    The ``upsert`` method has the ``(DataFrame, batch_id)`` signature expected
    by ``DataStreamWriter.foreachBatch``.
    """

    def __init__(self, merge_query, view_name):
        """Store the SQL text and the temp-view name it reads from.

        Args:
            merge_query: SQL statement executed for every micro-batch.
            view_name: Name under which each micro-batch is registered as a
                temporary view before ``merge_query`` runs.
        """
        self.query = merge_query
        self.view_name = view_name

    def upsert(self, df, batch_id):
        """Register ``df`` as a temp view and execute the stored query.

        Args:
            df: The micro-batch DataFrame.
            batch_id: Micro-batch identifier; accepted for signature
                compatibility and not used.
        """
        df.createOrReplaceTempView(self.view_name)
        # Run the query on the session that owns the micro-batch so the temp
        # view registered above is visible to it. The public `sparkSession`
        # property is used instead of the private `_jdf` py4j handle.
        df.sparkSession.sql(self.query)

In [0]:
from pyspark.sql.functions import to_date, min, avg, max, col

class Gold():
    """Gold-layer streaming job that maintains the gd_workout_bpm_summary table."""

    def __init__(self):
        """Read checkpoint location, catalog and schema from ``ConfigModule``."""
        obj_conf = ConfigModule()
        self.checkpoint_dir = f"{obj_conf.base_checkpoint_path}/checkpoint"
        self.catalog = obj_conf.environment
        self.schema = obj_conf.db_name
    
    def get_gd_complete_bpm_summary(self, version=0, once=True, processingtime='15 seconds'):
        """Start the stream that merges per-workout heart-rate stats into gd_workout_bpm_summary.

        The streaming source ``sl_complete_workout`` is left-joined to a static
        frame built from ``sl_heart_rate`` and ``sl_user_bins``, aggregated to
        min/avg/max ``heart_rate`` per (user_id, session_id, workout_id, date),
        enriched with user attributes and merged into the target table for
        every micro-batch.

        Args:
            version: Value passed as the ``startingVersion`` read option.
            once: If truthy, use an ``availableNow`` trigger; otherwise use a
                processing-time trigger.
            processingtime: Interval for the processing-time trigger; only
                used when ``once`` is falsy.

        Returns:
            The object returned by ``DataStreamWriter.start()`` so callers can
            wait on or stop the stream.
        """
        # Single source of truth for the temp view used by both the MERGE text
        # and the per-batch upsert handler.
        view_name = 'complete_bpm_summary_view'
        print(f"Streaming started for {self.catalog}.{self.schema}.gd_workout_bpm_summary table...", end='')
        merge_query = f""" MERGE INTO {self.catalog}.{self.schema}.gd_workout_bpm_summary AS target
                            USING {view_name} AS source ON 
                            source.user_id=target.user_id AND source.date=target.date 
                            AND source.workout_id = target.workout_id AND source.session_id = target.session_id
                            WHEN MATCHED THEN UPDATE SET *
                            WHEN NOT MATCHED THEN INSERT *;
                            """
        # Static (batch) side: heart-rate readings tagged with user info and a
        # calendar date derived from the reading time.
        heartrate_df = spark.read.table(f"{self.catalog}.{self.schema}.sl_heart_rate")
        user_bin_df = spark.read.table(f"{self.catalog}.{self.schema}.sl_user_bins")
        combine_df = heartrate_df.join(user_bin_df, on="device_id")\
                                .withColumn("date", to_date("time"))
        
        # Streaming side: completed workouts.
        read_df = spark.readStream.option("startingVersion", version)\
                                    .option("ignoreDeletes", True)\
                                    .table(f"{self.catalog}.{self.schema}.sl_complete_workout")
        # Keep every workout (left join) and attach the readings whose time
        # falls inside the workout's [start_time, end_time] window.
        # min/avg/max below are the pyspark.sql.functions aggregates imported
        # at file level, not the Python builtins.
        join_agg_df = read_df.join(combine_df, ((read_df.user_id==combine_df.user_id) & 
                                            (read_df.start_time<=combine_df.time) & 
                                            (read_df.end_time>=combine_df.time)), how="left")\
                        .groupBy(read_df.user_id, "session_id", "workout_id", "date")\
                        .agg(min("heart_rate").alias("min_bpm"),
                             avg("heart_rate").alias("avg_bpm"),
                             max("heart_rate").alias("max_bpm"))
        # Add user attributes to the aggregated rows; gender is exposed as "sex".
        select_df = join_agg_df.join(user_bin_df, on="user_id")\
                                .select(*join_agg_df.columns, "age", col("gender").alias("sex"), "city", "state")
        cls_obj = UpsertCls(merge_query, view_name)
        writer = select_df.writeStream.queryName("gd_complete_bpm_summary_stream")\
                                    .format("delta")\
                                    .option("checkpointLocation", f"{self.checkpoint_dir}/cp_gd_complete_bpm_summary")\
                                    .outputMode('update')\
                                    .foreachBatch(cls_obj.upsert)
        if once:
            query = writer.trigger(availableNow=True).start()
        else:
            query = writer.trigger(processingTime=processingtime).start()
        print("Started.")
        return query
    
    def launcher(self, once=True, processingtime='15 seconds'):
        """Start the Gold stream and return the started query.

        Args:
            once: Forwarded to ``get_gd_complete_bpm_summary``.
            processingtime: Forwarded to ``get_gd_complete_bpm_summary``.
        """
        return self.get_gd_complete_bpm_summary(once=once, processingtime=processingtime)

#obj = Gold()
#obj.launcher()

In [0]:
class GoldTestSuite():
    """Row-count and data-equality checks for Gold-layer tables."""

    def __init__(self):
        """Pick up the catalog and schema names from ``ConfigModule``."""
        config = ConfigModule()
        self.catalog = config.environment
        self.schema = config.db_name

    def assert_fn(self, table_name, filter, expected_count):
        """Assert that ``table_name`` has ``expected_count`` rows matching ``filter``.

        Args:
            table_name: Table name inside the configured catalog and schema.
            filter: SQL predicate placed after ``where`` in the count query.
            expected_count: Row count the query must return.

        Raises:
            AssertionError: If the actual count differs from ``expected_count``.
        """
        qualified_name = f"{self.catalog}.{self.schema}.{table_name}"
        print(f'Testing Gold layer - {qualified_name} table...', end='')
        count_sql = f"select count(*) from {qualified_name} where {filter}"
        first_row = spark.sql(count_sql).collect()[0]
        actual_count = first_row[0]
        assert actual_count==expected_count, f"Test case failed, actual count is {actual_count}"
        print('Test Passed.')
    
    def assert_data_fn(self, table_name, test_data_path):
        """Assert that ``table_name`` holds exactly the rows stored at ``test_data_path``.

        Args:
            table_name: Table name inside the configured catalog and schema.
            test_data_path: Location of the parquet data holding the expected rows.

        Raises:
            AssertionError: If either side has rows the other lacks.
        """
        qualified_name = f"{self.catalog}.{self.schema}.{table_name}"
        print(f"Data compare for {qualified_name} table...", end='')
        test_df = spark.read.format('parquet').load(test_data_path)
        table_df = spark.read.format("delta").table(qualified_name)
        # Both directions are checked so missing and unexpected rows are
        # caught, with duplicates counted (exceptAll keeps multiplicity).
        missing_from_table = test_df.exceptAll(table_df)
        unexpected_in_table = table_df.exceptAll(test_df)
        assert missing_from_table.isEmpty() and unexpected_in_table.isEmpty(), "Data mismatch, testcase failed."
        print("Passed.")
    
    def testcases(self):
        """Run the count checks, then the data comparisons, for the Gold tables."""
        count_checks = (
            ('gd_workout_bpm_summary', 1),
            ('gd_gym_summary', 8),
        )
        data_checks = (
            ('gd_workout_bpm_summary', '/FileStore/tables/test_workout_bpm_summary'),
            ('gd_gym_summary', '/FileStore/tables/test_gd_gym_summary'),
        )
        for table_name, expected_count in count_checks:
            self.assert_fn(table_name, 'true', expected_count)
        for table_name, test_data_path in data_checks:
            self.assert_data_fn(table_name, test_data_path)

#obj = GoldTestSuite()
#obj.testcases()