In [0]:
# Refactor the code into a class with small methods so each stage can be tested
class StreamWordCounts():
    """Streaming word-count pipeline: text files -> cleaned words -> Delta table.

    Every stage is a separate method so it can be exercised on its own.
    The class relies on the notebook-provided global ``spark`` session.
    """

    def __init__(self,
                 data_dir='dbfs:/FileStore/test_data/text',
                 checkpoint_dir='dbfs:/FileStore/checkpoint/word_counts',
                 table_name='word_counts'):
        """Configure where the stream reads from and writes to.

        Args:
            data_dir: Directory the text stream is loaded from.
            checkpoint_dir: Checkpoint location passed to the stream writer.
            table_name: Name of the table the counts are written to.

        All defaults reproduce the previously hard-coded values, so
        ``StreamWordCounts()`` behaves exactly as before.
        """
        self.DATA_DIR = data_dir
        self.CHECKPOINT_DIR = checkpoint_dir
        self.TABLE_NAME = table_name


    def get_raw_data(self):
        """Return a streaming DataFrame with one row per raw token (column ``word``).

        Files under ``DATA_DIR`` are read with ``.`` as the record separator and
        every record is split on single spaces.
        """
        from pyspark.sql.functions import explode, split

        # NOTE(review): with lineSep='.', newline characters presumably stay
        # inside each record, and splitting on a single space would then leave
        # them inside tokens - confirm this is intended.
        text_data = spark.readStream \
                         .format('text') \
                         .option('lineSep', '.') \
                         .load(self.DATA_DIR)

        return text_data.select(explode(split(text_data.value, ' ')).alias('word'))


    def get_quality_data(self, raw_words):
        """Normalise the ``word`` column and drop unusable tokens.

        Args:
            raw_words: DataFrame with a ``word`` column.

        Returns:
            DataFrame with a single ``cleaned_words`` column holding the
            trimmed, lower-cased word; rows that are null or contain no
            ``a``-``z`` character are removed.
        """
        from pyspark.sql.functions import lower, trim

        return raw_words.select(lower(trim(raw_words.word)).alias('cleaned_words')) \
                        .where('cleaned_words is not null') \
                        .where("cleaned_words rlike '[a-z]'")


    def get_word_counts(self, quality_words):
        """Group ``quality_words`` by ``cleaned_words`` and count each group."""
        return quality_words.groupBy('cleaned_words').count()


    def overwrite_word_counts(self, word_counts):
        """Start writing ``word_counts`` to the target table in ``complete`` mode.

        Args:
            word_counts: Streaming DataFrame produced by ``get_word_counts``.

        Returns:
            The value returned by ``toTable`` (a streaming query).
        """
        return word_counts.writeStream \
                          .format('delta') \
                          .option('checkpointLocation', self.CHECKPOINT_DIR) \
                          .outputMode('complete') \
                          .toTable(self.TABLE_NAME) # Returns a streaming query


    def execute(self):
        """Wire the stages together, start the stream and return the streaming query."""
        print('\tExecuting Word Count...', end='')

        raw_words = self.get_raw_data()
        quality_words = self.get_quality_data(raw_words)
        word_counts = self.get_word_counts(quality_words)
        streaming_query = self.overwrite_word_counts(word_counts)

        print(' Done.')
        return streaming_query
        

In [0]:
class StreamWordCountsTestSuite():
    """End-to-end checks for ``StreamWordCounts``.

    Uses the notebook-provided globals ``spark`` and ``dbutils`` and the
    ``StreamWordCounts`` class defined elsewhere in this notebook.
    """

    def __init__(self):
        self.BASE_DIR = 'dbfs:/FileStore'
        # Landing directory that test files are copied into.
        self.DATA_DIR = f'{self.BASE_DIR}/test_data/text/'


    def clean_up_for_testing(self):
        """Drop the result table and reset the checkpoint and landing directories."""
        print('Starting cleaning...', end='')

        spark.sql('DROP TABLE IF EXISTS word_counts')
        dbutils.fs.rm('/user/hive/warehouse/word_counts', recurse=True)
        dbutils.fs.rm(f'{self.BASE_DIR}/checkpoint', recurse=True)
        dbutils.fs.rm(self.DATA_DIR, recurse=True)
        dbutils.fs.mkdirs(self.DATA_DIR)

        print(' Done.')


    def get_data(self, file_num):
        """Copy ``text_data_<file_num>.txt`` from the data folder into the landing directory.

        Args:
            file_num: Number used to build the source file name.
        """
        print('Getting data...', end='')

        dbutils.fs.mkdirs(self.DATA_DIR)
        dbutils.fs.cp(f'{self.BASE_DIR}/data/text/text_data_{file_num}.txt',
                      self.DATA_DIR)

        print(' Done.')


    def assert_result(self, expected_result):
        """Check the summed count of words starting with ``s`` in ``word_counts``.

        Args:
            expected_result: Value the summed count must equal.

        Raises:
            AssertionError: If the value read from the table differs.
        """
        actual_result = spark.sql(
            '''
            SELECT SUM(count)
            FROM word_counts
            WHERE SUBSTR(cleaned_words, 1, 1) == 's'
            '''
        ).collect()[0][0]

        assert expected_result == actual_result, f'Test failed! Expected result is {expected_result}. Got {actual_result} instead.'


    def run_tests(self, sleep_time=10):
        """Run the streaming job against three files and verify the counts after each.

        Args:
            sleep_time: Seconds to wait after copying each file before the
                result table is checked. Defaults to the previous fixed
                value of 10.

        Raises:
            AssertionError: If any check fails. The streaming query is
                stopped before the error propagates.
        """
        import time

        self.clean_up_for_testing()
        word_counter = StreamWordCounts()
        streaming_query = word_counter.execute()

        # Expected totals after file 1, 2 and 3 respectively (presumably
        # cumulative, since the table is not reset between files).
        expected_results = [25, 32, 37]
        try:
            for file_num, expected in enumerate(expected_results, start=1):

                print(f'Testing file No.{file_num}...')

                self.get_data(file_num)
                print(f'\tWaiting for {sleep_time} seconds...')
                time.sleep(sleep_time)

                self.assert_result(expected)

                print(f'File No.{file_num} test completed.\n')
        finally:
            # Always stop the stream, even when a check fails, so a failed
            # run does not leave the query running in the background.
            streaming_query.stop()


In [0]:
# Build the test suite and run the end-to-end streaming checks; an
# AssertionError is raised if any expected count does not match.
stream_word_counts_tester = StreamWordCountsTestSuite()
stream_word_counts_tester.run_tests()

Starting cleaning... Done.
	Executing Word Count... Done.
Testing file No.1...
Getting data... Done.
	Waiting for 30 seconds...
File No.1 test completed.

Testing file No.2...
Getting data... Done.
	Waiting for 30 seconds...
File No.2 test completed.

Testing file No.3...
Getting data... Done.
	Waiting for 30 seconds...
File No.3 test completed.

