In [1]:
import findspark
findspark.init()

import pyspark
from delta import *

In [2]:

# Session builder for the 'Invoice Reader' app with the Delta Lake SQL
# extension and the Delta catalog registered as `spark_catalog`.
builder = (
    pyspark.sql.SparkSession.builder
    .appName('Invoice Reader')
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
)

# `configure_spark_with_delta_pip` is expected to come from the `delta` star
# import at the top of the notebook; it wraps the builder before the session
# is created (or an already-running one is reused).
spark = configure_spark_with_delta_pip(builder).getOrCreate()

In [3]:
class Bronze():

    def __init__(self):
        
        self.BASE_DIR = '..'

    
    def get_schema(self):

        return '''
               InvoiceNumber string, CreatedTime bigint, StoreID string, PosID string, CashierID string,
               CustomerType string, CustomerCardNo string, TotalAmount double, NumberOfItems bigint,
               PaymentMethod string, TaxableAmount double, CGST double, SGST double, CESS double, DeliveryType string,
               DeliveryAddress struct<
                    AddressLine string,
                    City string,
                    ContactNumber string,
                    PinCode string,
                    State string
               >,
               InvoiceLineItems array<struct<
                    ItemCode string,
                    ItemDescription string,
                    ItemPrice double,
                    ItemQty bigint,
                    TotalValue double
               >>
               '''


    def read_invoices(self):
        
        from pyspark.sql.functions import input_file_name
        return spark.readStream \
                    .format('json') \
                    .schema(self.get_schema()) \
                    .load(f'{self.BASE_DIR}/test_data/invoices') \
                    .withColumn('FileName', input_file_name())
    

    def process(self):

        print('Starting Bronze data extracting stream...', end='')

        raw_invoice_df = self.read_invoices()
        streaming_query = raw_invoice_df.writeStream \
                                        .queryName('bronze_ingestion') \
                                        .option('checkpointLocation', f'{self.BASE_DIR}/checkpoint/invoices_bz') \
                                        .outputMode('append') \
                                        .toTable('invoices_bz')

        print(' Done.')

        return streaming_query

In [4]:
class Gold():
    """Gold layer: fold each micro-batch of `invoices_bz` into `customer_rewards`."""

    def __init__(self):
        # Root directory that the checkpoint path is built from.
        self.BASE_DIR = '..'


    def read_invoices(self):
        """Return a streaming DataFrame over the `invoices_bz` table."""

        return spark.readStream.table('invoices_bz')


    def get_points_per_customer(self, invoices_df):
        """Aggregate invoices per `CustomerCardNo`.

        Args:
            invoices_df: DataFrame exposing `CustomerCardNo` and `TotalAmount`.

        Returns:
            One row per `CustomerCardNo` with the summed `TotalAmount` and
            `TotalPoints` (the sum of `TotalAmount * 0.02`).
        """
        # Namespaced import so Spark's `sum` does not shadow the Python builtin.
        from pyspark.sql import functions as F

        return invoices_df.groupBy('CustomerCardNo') \
                          .agg(
                              F.sum('TotalAmount').alias('TotalAmount'),
                              F.sum(F.expr('TotalAmount * 0.02')).alias('TotalPoints')
                            )


    @staticmethod
    def _session_for(batch_df):
        """Return the session that owns `batch_df`.

        Prefers the public `DataFrame.sparkSession` attribute and only falls
        back to the private `_jdf` handle when that attribute is missing, so
        the temp view and the MERGE statement run in the same session.
        """
        session = getattr(batch_df, 'sparkSession', None)
        if session is None:
            session = batch_df._jdf.sparkSession()
        return session


    def aggregated_upsert(self, raw_invoice_df, batch_id):
        """`foreachBatch` callback: merge one micro-batch into `customer_rewards`.

        Args:
            raw_invoice_df: The micro-batch DataFrame handed over by the stream.
            batch_id: Micro-batch identifier (unused).
        """
        rewards_df = self.get_points_per_customer(raw_invoice_df)
        rewards_df.createOrReplaceTempView('customer_rewards_temp_view')
        # NOTE(review): the join uses plain `=`, so a NULL CustomerCardNo group
        # never matches and would be inserted again on every batch — confirm
        # whether such rows can occur and should be accumulated or filtered.
        merge_statement = '''
            MERGE INTO customer_rewards t
            USING customer_rewards_temp_view s
            ON s.CustomerCardNo = t.CustomerCardNo
            WHEN MATCHED THEN
            UPDATE SET t.TotalAmount = t.TotalAmount + s.TotalAmount,
                       t.TotalPoints = t.TotalPoints + s.TotalPoints
            WHEN NOT MATCHED THEN
            INSERT *
        '''

        self._session_for(rewards_df).sql(merge_statement)

    def save_aggregation(self, aggregated_df):
        """Start the `gold_update` streaming query and return its handle.

        Every micro-batch of `aggregated_df` is passed to `aggregated_upsert`;
        the query checkpoints under `<BASE_DIR>/checkpoint/customer_rewards`.
        """
        return aggregated_df.writeStream \
                            .queryName('gold_update') \
                            .format('delta') \
                            .option('checkpointLocation', f'{self.BASE_DIR}/checkpoint/customer_rewards') \
                            .outputMode('update') \
                            .foreachBatch(self.aggregated_upsert) \
                            .start()


    def process(self):
        """Wire the `invoices_bz` stream into the upsert sink and return the query."""

        print('Starting Gold updating stream...', end='')

        raw_invoice_df = self.read_invoices()
        streaming_query = self.save_aggregation(raw_invoice_df)

        print(' Done.\n')
        return streaming_query

In [5]:
class AggregationTestSuite():
    """End-to-end check of the Bronze -> Gold streaming pipeline.

    Invoice files are copied into the streaming source directory one at a
    time, and the `invoices_bz` and `customer_rewards` tables are validated
    after each file.
    """

    def __init__(self):
        # Root directory that all data, checkpoint and warehouse paths are built from.
        self.BASE_DIR = '..'


    @staticmethod
    def _reset_dir(path):
        """Leave `path` as an existing, empty directory (creating it if needed)."""

        import os
        import shutil

        # The directory may legitimately be missing: fresh checkout, or a
        # DROP TABLE that already removed the table's directory.
        if os.path.isdir(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)


    def clean_up_for_testing(self):
        """Drop both tables, empty every working directory and recreate `customer_rewards`."""

        print('Starting cleaning...', end='')

        # The Bronze table is `invoices_bz` (the name written by Bronze and
        # queried by assert_bronze), not `invoice_bz`.
        spark.sql('DROP TABLE IF EXISTS invoices_bz')
        spark.sql('DROP TABLE IF EXISTS customer_rewards')

        working_dirs = (
            'notebooks/spark-warehouse/invoices_bz',
            'notebooks/spark-warehouse/customer_rewards',
            'checkpoint/invoices_bz',
            'checkpoint/customer_rewards',
            'test_data/invoices',
        )
        for relative_path in working_dirs:
            self._reset_dir(f'{self.BASE_DIR}/{relative_path}')

        spark.sql('CREATE TABLE customer_rewards(CustomerCardNo STRING, TotalAmount DOUBLE, TotalPoints DOUBLE) USING delta')

        print(' Done.')


    def get_data(self, file_num):
        """Copy `invoices-<file_num>.json` from `data/invoices` into `test_data/invoices`."""

        import shutil

        print('\tGetting data...', end='')

        shutil.copyfile(src=f'{self.BASE_DIR}/data/invoices/invoices-{file_num}.json',
                        dst=f'{self.BASE_DIR}/test_data/invoices/invoices-{file_num}.json')

        print(' Done.')


    def assert_bronze(self, expected_result):
        """Assert that `invoices_bz` holds exactly `expected_result` rows.

        Raises:
            AssertionError: If the row count differs.
        """

        print('\tStarting Bronze validation...', end='')

        actual_result = spark.sql(
            '''
            SELECT COUNT(*)
            FROM invoices_bz
            '''
        ).collect()[0][0]

        assert expected_result == actual_result, f'Test failed! Expected result is {expected_result}. Got {actual_result} instead.'

        print(' Done.')


    def assert_gold(self, expected_result):
        """Assert the accumulated `TotalAmount` of customer card '2262471989'.

        Raises:
            AssertionError: If the customer has no row yet, or the stored
                amount differs from `expected_result` beyond float tolerance.
        """

        import math

        print('\tStarting Gold validation...', end='')

        rows = spark.sql(
            '''
            SELECT TotalAmount
            FROM customer_rewards
            WHERE CustomerCardNo = '2262471989'
            '''
        ).collect()

        # A missing row is a test failure, not an IndexError.
        assert rows, f'Test failed! Expected result is {expected_result}. Got no row for the customer instead.'
        actual_result = rows[0][0]

        # TotalAmount is a DOUBLE built from repeated additions, so compare
        # with a tolerance instead of exact equality.
        matches = actual_result is not None and math.isclose(
            expected_result, actual_result, rel_tol=1e-9, abs_tol=1e-9
        )
        assert matches, f'Test failed! Expected result is {expected_result}. Got {actual_result} instead.'
        print(' Done.')


    def wait_for_microbatch(self, sleep_time=15):
        """Block for `sleep_time` seconds to give the streams time to run."""

        import time

        print(f'\tWaiting for {sleep_time} seconds...', end='')
        time.sleep(sleep_time)

        print(' Done.')


    def run_stream_tests(self):
        """Run the full scenario: clean up, start both streams, feed and validate three files.

        Both streaming queries are stopped even when an assertion fails.
        """

        # Sleep time between extract and transform operation
        sleep_time = 20
        self.clean_up_for_testing()

        expected_bronze_results = [501, 501 + 500, 501 + 500 + 590]
        expected_gold_results = [36859, 36859 + 20740, 36859 + 20740 + 31959]

        started_queries = []
        try:
            bronze_streaming_query = Bronze().process()
            started_queries.append(bronze_streaming_query)

            gold_streaming_query = Gold().process()
            started_queries.append(gold_streaming_query)

            for i in range(len(expected_bronze_results)):

                print(f'Testing file No.{i + 1}...')

                self.get_data(i + 1)
                self.wait_for_microbatch(sleep_time=sleep_time)

                # Drain the streams in pipeline order so the assertions do not
                # depend on the sleep alone being long enough.
                bronze_streaming_query.processAllAvailable()
                gold_streaming_query.processAllAvailable()

                self.assert_bronze(expected_bronze_results[i])
                self.assert_gold(expected_gold_results[i])

                print(f'File No.{i + 1} test passed.\n')
        finally:
            # Never leave a streaming query running after a failed test.
            for query in started_queries:
                query.stop()


In [6]:
# Instantiate the test suite and run the streaming tests; progress is printed
# to the cell output and a failed check raises AssertionError.
aggregated_stream_tester = AggregationTestSuite()
aggregated_stream_tester.run_stream_tests()

Starting cleaning... Done.
Starting Bronze data extracting stream... Done.
Starting Gold updating stream... Done.

Testing file No.1...
	Getting data... Done.
	Waiting for 20 seconds... Done.
	Starting Bronze validation... Done.
	Starting Gold validation... Done.
File No.1 test passed.

Testing file No.2...
	Getting data... Done.
	Waiting for 20 seconds...

In [None]:
# Stop the SparkSession created at the top of the notebook.
spark.stop()