In [0]:
%run ./01-config

In [0]:
from pyspark.sql import functions as F

def consume_autoloader(*, source_subdir, target_table, once=True, processing_time="5 seconds"):
    """Stream JSON files from a raw landing path into a bronze Delta table via Auto Loader.

    Each ingested row is stamped with ``_ingest_time`` (the current timestamp at
    processing) and ``_source_file`` (the path of the file it was read from).
    Column types are inferred, and new columns are added to the stored schema as
    they appear (``addNewColumns`` evolution).

    Relies on ``raw_base_path``, ``checkpoint_base_path``, ``schema_bronze`` and
    ``spark`` being defined in the notebook scope (presumably by ``./01-config``).

    Args:
        source_subdir: Sub-directory of ``raw_base_path`` to read from.
        target_table: Bronze table name (without schema). It is also used as the
            streaming query name and to derive the checkpoint and schema paths.
        once: If True, process all currently available files and then stop
            (``availableNow`` trigger). Otherwise run a micro-batch stream every
            ``processing_time``.
        processing_time: Trigger interval used when ``once`` is False.

    Returns:
        The started ``StreamingQuery``, so callers can await or stop this query.
    """
    # Raw input location, plus per-table checkpoint and inferred-schema locations
    file_path = f"{raw_base_path}/{source_subdir}"
    bronze_checkpoint = f"{checkpoint_base_path}/bronze/{target_table}"
    schema_path = f"{checkpoint_base_path}/schemas/{target_table}"

    df_stream = (
        spark.readStream
            .format("cloudFiles")
            .option("cloudFiles.format", "json")
            .option("cloudFiles.inferColumnTypes", "true")
            .option("cloudFiles.schemaLocation", schema_path)
            .option("cloudFiles.schemaEvolutionMode", "addNewColumns")
            .option("recursiveFileLookup", "true")
            # NOTE(review): Auto Loader documents this limit as
            # `cloudFiles.maxFilesPerTrigger`; confirm the un-prefixed option is honoured.
            .option("maxFilesPerTrigger", 1)
            .load(file_path)
            .withColumn("_ingest_time", F.current_timestamp())
            # Use the `_metadata` file column: `input_file_name()` is deprecated and
            # is not supported on Unity Catalog shared-access clusters.
            .withColumn("_source_file", F.col("_metadata.file_path"))
    )

    stream_writer = (
        df_stream.writeStream
            .format("delta")
            .outputMode("append")
            .option("checkpointLocation", bronze_checkpoint)
            .queryName(target_table)
    )

    # availableNow drains the backlog and stops; processingTime keeps running.
    trigger = {"availableNow": True} if once else {"processingTime": processing_time}
    return stream_writer.trigger(**trigger).toTable(f"{schema_bronze}.{target_table}")



def consume_bus_arrivals(once=True, processing_time="5 seconds"):
    """Ingest raw ``arrivals`` files into the ``bus_arrivals_bz`` bronze table.

    Args:
        once: Passed through to ``consume_autoloader`` (availableNow vs. continuous).
        processing_time: Trigger interval used when ``once`` is False.

    Returns:
        Whatever ``consume_autoloader`` returns (the started streaming query,
        if it returns one).
    """
    return consume_autoloader(
        source_subdir="arrivals",
        target_table="bus_arrivals_bz",
        once=once,
        processing_time=processing_time,
    )


def consume_line_status(once=True, processing_time="5 seconds"):
    """Ingest raw ``line_status`` files into the ``line_status_bz`` bronze table.

    Args:
        once: Passed through to ``consume_autoloader`` (availableNow vs. continuous).
        processing_time: Trigger interval used when ``once`` is False.

    Returns:
        Whatever ``consume_autoloader`` returns (the started streaming query,
        if it returns one).
    """
    return consume_autoloader(
        source_subdir="line_status",
        target_table="line_status_bz",
        once=once,
        processing_time=processing_time,
    )

def consume_stop_points(once=True, processing_time="5 seconds"):
    """Ingest raw ``stop_points`` files into the ``stop_points_bz`` bronze table.

    Args:
        once: Passed through to ``consume_autoloader`` (availableNow vs. continuous).
        processing_time: Trigger interval used when ``once`` is False.

    Returns:
        Whatever ``consume_autoloader`` returns (the started streaming query,
        if it returns one).
    """
    return consume_autoloader(
        source_subdir="stop_points",
        target_table="stop_points_bz",
        once=once,
        processing_time=processing_time,
    )


def consume_london_boroughs(once=True, processing_time="5 seconds"):
    """Ingest raw ``london_boroughs`` files into the ``london_boroughs_bz`` bronze table.

    Args:
        once: Passed through to ``consume_autoloader`` (availableNow vs. continuous).
        processing_time: Trigger interval used when ``once`` is False.

    Returns:
        Whatever ``consume_autoloader`` returns (the started streaming query,
        if it returns one).
    """
    return consume_autoloader(
        source_subdir="london_boroughs",
        target_table="london_boroughs_bz",
        once=once,
        processing_time=processing_time,
    )



# Orchestrate bronze layer consumption
def consume_bronze(once=True, processing_time="5 seconds"):
    """Start ingestion for all four bronze tables and report the elapsed time.

    Args:
        once: If True, each stream uses an availableNow trigger, and this call
            blocks until every bronze query has finished. If False, the streams
            are left running and the call returns immediately.
        processing_time: Trigger interval used when ``once`` is False.

    Raises:
        Any exception raised by ``awaitTermination`` for a bronze query that
        fails while being awaited.
    """
    import time
    start = int(time.time())
    print("\nStarting bronze layer consumption...")

    consume_bus_arrivals(once, processing_time)
    consume_line_status(once, processing_time)
    consume_stop_points(once, processing_time)
    consume_london_boroughs(once, processing_time)

    if once:
        # Streaming queries are named after their target tables. Wait only on
        # those: awaiting every active stream would block forever if an unrelated
        # continuous stream (e.g. a silver job) is running in this session.
        bronze_query_names = {
            "bus_arrivals_bz",
            "line_status_bz",
            "stop_points_bz",
            "london_boroughs_bz",
        }
        for stream in spark.streams.active:
            if stream.name in bronze_query_names:
                stream.awaitTermination()

    print(f"✅ Completed bronze layer consumption in {int(time.time()) - start} seconds")



# Bronze layer validation
def assert_bronze_table(schema_bronze, table_name):
    """Check that a bronze table is non-empty and that every row records its source file.

    Args:
        schema_bronze: Schema (database) containing the table.
        table_name: Table name within ``schema_bronze``.

    Raises:
        AssertionError: If the table has no rows, or if any row has a NULL
            ``_source_file``.
    """
    full_name = f"{schema_bronze}.{table_name}"
    print(f"Validating {full_name}...")
    stats = (
        spark.table(full_name)
        .selectExpr(
            "count(*) as total_rows",
            "sum(CASE WHEN _source_file IS NULL THEN 1 ELSE 0 END) as missing_source",
        )
        .collect()[0]
    )
    # Explicit raises rather than `assert` statements: these are data checks and
    # must not be stripped when Python runs with -O. The exception type is unchanged.
    if not stats.total_rows > 0:
        raise AssertionError(f"{table_name} is empty")
    if stats.missing_source != 0:
        raise AssertionError(f"{table_name} has missing _source_file")
    print(f"{table_name}: {stats.total_rows:,} records validated")


def validate_bronze():
    """Run the bronze-table checks against each bronze table in turn and report the time taken.

    Stops at the first table whose check raises; that exception propagates.
    """
    import time

    started_at = int(time.time())
    print("\nValidating bronze layer records...")

    for bronze_table in (
        "bus_arrivals_bz",
        "line_status_bz",
        "stop_points_bz",
        "london_boroughs_bz",
    ):
        assert_bronze_table(schema_bronze, bronze_table)

    elapsed_seconds = int(time.time()) - started_at
    print(f"✅ Bronze layer validation completed in {elapsed_seconds} seconds")