In [21]:
import findspark
# Must run before the pyspark import below so the local Spark installation is importable.
findspark.init()

from pyspark.sql import SparkSession

# One shared session for the whole notebook; getOrCreate() reuses an existing one if present.
spark = (
    SparkSession.builder
    .appName('Invoice Reader')
    .getOrCreate()
)

In [22]:
class InvoiceStreamBatch():

    def __init__(self):
        
        self.BASE_DIR = '..'

    
    def get_schema(self):

        return '''
               InvoiceNumber string, CreatedTime bigint, StoreID string, PosID string, CashierID string,
               CustomerType string, CustomerCardNo string, TotalAmount double, NumberOfItems bigint,
               PaymentMethod string, TaxableAmount double, CGST double, SGST double, CESS double, DeliveryType string,
               DeliveryAddress struct<
                    AddressLine string,
                    City string,
                    ContactNumber string,
                    PinCode string,
                    State string
               >,
               InvoiceLineItems array<struct<
                    ItemCode string,
                    ItemDescription string,
                    ItemPrice double,
                    ItemQty bigint,
                    TotalValue double
               >>
               '''


    def read_invoices(self):

        return spark.readStream \
                    .format('json') \
                    .schema(self.get_schema()) \
                    .load(f'{self.BASE_DIR}/test_data/invoices')
    

    def explode_invoices(self, invoice_df):

        return invoice_df.selectExpr(
            'InvoiceNumber', 'CreatedTime', 'StoreID', 'PosID',
            'CustomerType', 'PaymentMethod', 'DeliveryType',
            'DeliveryAddress.City', 'DeliveryAddress.State', 'DeliveryAddress.PinCode',
            'explode(InvoiceLineItems) as LineItem'
        )
    

    def flatten_invoices(self, exploded_df):

        from pyspark.sql.functions import expr

        flattened_df = exploded_df.withColumn('ItemCode', expr('LineItem.ItemCode')) \
                                  .withColumn('ItemDescription', expr('LineItem.ItemDescription')) \
                                  .withColumn('ItemPrice', expr('LineItem.ItemPrice')) \
                                  .withColumn('ItemQty', expr('LineItem.ItemQty')) \
                                  .withColumn('TotalValue', expr('LineItem.TotalValue')) \
                                  .drop('LineItem')
        
        return flattened_df
    

    def append_invoices(self, flattened_df, trigger='batch'):
        
        streaming_query = flattened_df.writeStream \
                                      .format("memory") \
                                      .outputMode('append') \
                                      .option('maxFilePerTrigger', 1)
        
        if trigger == 'batch':
            return streaming_query.trigger(availableNow = True) \
                                  .queryName('invoice_line_items') \
                                  .start(f'{self.BASE_DIR}/stream_tester')

        else:
            return streaming_query.trigger(processingTime = trigger) \
                                  .queryName('invoice_line_items') \
                                  .option('checkpointLocation', f'{self.BASE_DIR}/checkpoint/invoices') \
                                  .start(f'{self.BASE_DIR}/stream_tester')
        

    def process(self, trigger='batch'):

        print('Starting invoice processing stream...', end='')

        raw_invoice_df = self.read_invoices()
        exploded_df = self.explode_invoices(raw_invoice_df)
        flattened_df = self.flatten_invoices(exploded_df)
        streaming_query = self.append_invoices(flattened_df, trigger=trigger)

        print(' Done.')

        return streaming_query

In [23]:
class StreamBatchInvoiceTestSuite():

    def __init__(self):
        
        self.BASE_DIR = '..'


    def clean_up_for_testing(self):
        
        import shutil
        import os

        print('Starting cleaning...', end='')

        spark.sql('DROP TABLE IF EXISTS word_counts')
        shutil.rmtree(f'{self.BASE_DIR}/checkpoint/invoices')
        os.makedirs(f'{self.BASE_DIR}/checkpoint/invoices')

        shutil.rmtree(f'{self.BASE_DIR}/test_data/invoices')
        os.makedirs(f'{self.BASE_DIR}/test_data/invoices')

        print(' Done.')

    
    def get_data(self, file_num):

        import shutil

        print('\tGetting data...', end='')

        shutil.copyfile(src=f'{self.BASE_DIR}/data/invoices/invoices-{file_num}.json', 
                        dst=f'{self.BASE_DIR}/test_data/invoices/invoices-{file_num}.json')
        
        print(' Done.')

    
    def assert_result(self, expected_result):
        
        print('\tStarting validation...', end='')

        actual_result = spark.sql(
            '''
            SELECT COUNT(*)
            FROM invoice_line_items
            '''
        ).collect()[0][0]

        assert expected_result == actual_result, f'Test failed! Expected result is {expected_result}. Got {actual_result} instead.'
        
        print(' Done.')


    def wait_for_microbatch(self, sleep_time=15):

        import time

        print(f'\tWaiting for {sleep_time} seconds...', end='')
        time.sleep(sleep_time)

        print(' Done.')


    def run_stream_tests(self, trigger='20 seconds'):

        sleep_time = 10
        self.clean_up_for_testing()
        invoice_stream = InvoiceStreamBatch()
        streaming_query = invoice_stream.process(trigger=trigger)

        expected_results = [1253, 2510, 3994]
        for i in range(len(expected_results)):

            print(f'Testing file No.{i + 1}...')

            self.get_data(i + 1)
            self.wait_for_microbatch(sleep_time=sleep_time) # Only works if sleep_time >= 5

            self.assert_result(expected_results[i])

            print(f'File No.{i + 1} test passed.\n')

        streaming_query.stop()

    
    def run_batch_tests(self):

        sleep_time = 10
        self.clean_up_for_testing()
        invoice_stream = InvoiceStreamBatch()

        print(f'Testing file 1 and 2...')
        self.get_data(1)
        self.get_data(2)
        invoice_stream.process(trigger='batch')
        self.wait_for_microbatch(sleep_time=sleep_time)
        self.assert_result(2510)
        print(f'File 1, 2 test passed.\n')

        print(f'Testing file 3...')
        self.get_data(3)
        invoice_stream.process(trigger='batch')
        self.wait_for_microbatch(sleep_time=sleep_time)
        self.assert_result(3994)
        print(f'File 3 test passed.\n')           


In [14]:
# Streaming-mode check. The short 3-second trigger lets each micro-batch fire well
# within the 10-second wait used by run_stream_tests.
invoice_stream_tester = StreamBatchInvoiceTestSuite()
invoice_stream_tester.run_stream_tests(trigger='3 seconds')

Starting cleaning... Done.
Starting invoice processing stream... Done.
Testing file No.1...
	Getting data... Done.
	Waiting for 10 seconds... Done.
	Starting validation... Done.
File No.1 test passed.

Testing file No.2...
	Getting data... Done.
	Waiting for 10 seconds... Done.
	Starting validation... Done.
File No.2 test passed.

Testing file No.3...
	Getting data... Done.
	Waiting for 10 seconds... Done.
	Starting validation... Done.
File No.3 test passed.



In [24]:
# Incremental-batch check: files 1-2, then file 3, each processed by a trigger='batch' run.
invoice_stream_tester.run_batch_tests()

Starting cleaning... Done.
Testing file 1 and 2...
	Getting data... Done.
	Getting data... Done.
Starting invoice processing stream... Done.
	Waiting for 10 seconds... Done.
	Starting validation... Done.
File 1, 2 test passed.

Testing file 3...
	Getting data... Done.
Starting invoice processing stream...

AnalysisException: This query does not support recovering from checkpoint location. Delete ../checkpoint/invoices/offsets to start over.

In [20]:
# Shut down the shared Spark session once all tests have run.
spark.stop()