# In [0]:
import unittest
import time
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.streaming import StreamingQueryListener

class TestStreamListener(unittest.TestCase):
    """Check that a StreamingQueryListener receives events from a running query."""

    # Upper bound on how long to wait for the asynchronous listener callbacks.
    # The test returns as soon as the events arrive, so this only matters on failure.
    EVENT_TIMEOUT_SECONDS = 60.0
    # Pause between checks of the listener's recorded state.
    POLL_INTERVAL_SECONDS = 0.2
    # Name of the memory-sink query; the sink also exposes its rows under this name.
    QUERY_NAME = "test_query"

    @classmethod
    def setUpClass(cls):
        """Obtain the Spark session shared by every test in this class.

        ``getOrCreate()`` returns the already-active session when one exists
        (e.g. a notebook's), so the session is deliberately not stopped here.
        """
        cls.spark = SparkSession.builder.appName("TestingListener").getOrCreate()

    def test_listener_captures_progress(self):
        """A registered listener sees the query start and at least one progress event."""

        class TestListener(StreamingQueryListener):
            """Record whether the query started and count progress events.

            NOTE(review): assumes a classic (non-Spark-Connect) session; under
            Spark Connect, listeners execute server-side and these attributes
            would not change in this process -- confirm for the target cluster.
            """

            def __init__(self):
                super().__init__()
                self.batch_count = 0
                self.started = False

            def onQueryStarted(self, event):
                self.started = True

            def onQueryProgress(self, event):
                self.batch_count += 1

            def onQueryTerminated(self, event):
                pass

        test_listener = TestListener()
        self.spark.streams.addListener(test_listener)
        # Cleanups run LIFO and also when the test fails or errors, so a failed
        # run cannot leak a listener or a running query into the next run.
        self.addCleanup(self.spark.streams.removeListener, test_listener)

        # The rate source generates rows continuously; the memory sink stores
        # them in an in-memory table named after the query.
        input_df = self.spark.readStream.format("rate").option("rowsPerSecond", 5).load()
        query = (
            input_df.writeStream
            .format("memory")
            .queryName(self.QUERY_NAME)
            .start()
        )
        self.addCleanup(self.spark.catalog.dropTempView, self.QUERY_NAME)
        self.addCleanup(query.stop)  # registered last, so it runs first

        # Progress callbacks are delivered asynchronously: poll with a deadline
        # instead of sleeping a fixed time (slow when fast, flaky when slow).
        deadline = time.monotonic() + self.EVENT_TIMEOUT_SECONDS
        while time.monotonic() < deadline and not (
            test_listener.started and test_listener.batch_count > 0
        ):
            time.sleep(self.POLL_INTERVAL_SECONDS)
        query.stop()

        self.assertTrue(test_listener.started, "Listener should have detected query start")
        self.assertGreater(
            test_listener.batch_count,
            0,
            "Listener should have recorded at least one batch",
        )

# Run the tests programmatically to avoid SystemExit
# unittest.main() finishes with sys.exit(), which raises SystemExit inside a
# notebook; build the suite explicitly and hand it to a runner instead.
loader = unittest.TestLoader()
suite = loader.loadTestsFromTestCase(TestStreamListener)
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)
