In [0]:
%run ./01-config

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.functions import col, expr, broadcast, when, split, to_json, to_timestamp, get, explode_outer

# Root folder for the silver streaming checkpoints; each stream in this
# notebook appends its own sub-folder to it. `checkpoint_base_path` is not
# defined here -- presumably it comes from the `%run ./01-config` cell (confirm).
silver_checkpoint = f"{checkpoint_base_path}/silver"

# Ensure all required silver tables exist in the catalog
def ensure_silver_tables(): 
    """Create every silver-layer table under ``schema_silver`` if it is missing.

    Each statement is ``CREATE TABLE IF NOT EXISTS``, so calling this function
    repeatedly is harmless and an existing table is never altered. Prints a
    confirmation line once all statements have run.
    """
    # Target of upsert_line_status().
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {schema_silver}.line_status(
        line_id STRING,
        service_type STRING,
        severity_code BIGINT,
        severity_description STRING,
        disruption_category STRING,
        disruption_description STRING,
        disruption_from_date TIMESTAMP,
        disruption_to_date TIMESTAMP,
        is_service_disrupted BOOLEAN,
        event_timestamp TIMESTAMP
    )
    """)

    # Target of upsert_bus_arrivals().
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {schema_silver}.bus_arrivals(
        arrival_id STRING,
        operation_type BIGINT,
        vehicle_id STRING,
        naptan_id STRING,
        station_name STRING,
        line_id STRING,
        platform_name STRING,
        direction STRING,
        bearing BIGINT,
        trip_id BIGINT,
        base_version BIGINT,
        destination_naptan_id STRING,
        destination_name STRING,
        event_timestamp TIMESTAMP,
        time_to_station BIGINT,
        current_location STRING,
        towards STRING,
        expected_arrival TIMESTAMP,
        time_to_live TIMESTAMP
    )
    """)

    # Target of write_london_boroughs(), which overwrites the table (and its
    # schema) on every run.
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {schema_silver}.london_boroughs(
        borough_code STRING,
        borough_name STRING,
        hectares DOUBLE,
        shape_area DOUBLE,
        shape_length DOUBLE,
        geometry_geojson STRING
    )
    """)

    # Target of upsert_stop_points().
    # NOTE(review): upsert_stop_points() also selects `hub_naptan_code`, which
    # has no column here -- confirm whether the column or the select is wrong.
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {schema_silver}.stop_points(
        naptan_id STRING,
        indicator STRING,
        ics_code BIGINT,
        stop_type STRING,
        common_name STRING,
        longitude DOUBLE,
        latitude DOUBLE
    )    
    """)

    # Target of upsert_bus_arrival_events_batch().
    # NOTE(review): arrival_event_id is BIGINT but is populated from
    # bus_arrivals.arrival_id, which is STRING -- confirm the intended type.
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {schema_silver}.bus_arrival_events(
        arrival_event_id BIGINT, 
        line_id STRING,
        vehicle_id STRING,
        naptan_id STRING,
        station_name STRING,
        platform_name STRING,
        direction STRING,
        destination_name STRING,
        time_to_station BIGINT,
        expected_arrival TIMESTAMP,
        time_to_live TIMESTAMP,
        is_service_disrupted BOOLEAN,
        severity_code BIGINT, 
        severity_description STRING, 
        event_timestamp TIMESTAMP
    )     
    """)

    # Target of upsert_bus_stops_geo_batch().
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {schema_silver}.bus_stops_geo(
        naptan_id STRING,
        stop_name STRING,
        stop_type STRING,
        borough_code STRING,
        borough_name STRING,
        longitude DOUBLE,
        latitude DOUBLE
    )
    """)

    # Target of upsert_line_disruption_geo_batch().
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {schema_silver}.line_disruption_geo(
        line_id STRING,
        service_type STRING,
        severity_code BIGINT,
        severity_description STRING,
        disruption_category STRING,
        disruption_description STRING,
        disruption_from_date TIMESTAMP,
        disruption_to_date TIMESTAMP,
        is_service_disrupted BOOLEAN,
        borough_code STRING,
        borough_name STRING,
        longitude DOUBLE,
        latitude DOUBLE,
        event_timestamp TIMESTAMP
    )
    """)
    print("\n✅ Ensured all silver tables exist")


# Helper function to perform upsert using a merge query in foreachBatch
def upserter(df_micro_batch, batch_id, merge_query, temp_view):
    """Apply ``merge_query`` to a single streaming micro-batch.

    Meant to be called from a ``foreachBatch`` callback.

    Args:
        df_micro_batch: DataFrame holding the rows of the current micro-batch.
        batch_id: Identifier of the micro-batch; only used in the log line.
        merge_query: SQL text to execute; it is expected to read from
            ``temp_view``.
        temp_view: Name under which the micro-batch is exposed to the SQL.
    """
    df_micro_batch.createOrReplaceTempView(temp_view)
    # Run the SQL on the session that owns the micro-batch, because that is
    # where the temp view was just registered. The public `sparkSession`
    # property is used instead of the private `_jdf` handle, which is an
    # internal JVM bridge and is not available under Spark Connect.
    df_micro_batch.sparkSession.sql(merge_query)
    print(f"Batch {batch_id} for {temp_view} processed.")


# Stream upsert for line_status table
def upsert_line_status(once=True, processing_time="15 seconds", startingVersion=0):
    """Stream bronze line statuses into the silver ``line_status`` table.

    Reads ``{schema_bronze}.line_status_bz`` as a stream, flattens the first
    entry of ``serviceTypes``, ``lineStatuses`` and its first
    ``validityPeriods`` element, drops duplicates on
    ``(line_id, event_timestamp)`` under a 30-second watermark, and inserts
    rows whose key is not yet present in the target.

    Args:
        once: When truthy the stream uses an ``availableNow`` trigger;
            otherwise it runs with a ``processingTime`` trigger.
        processing_time: Trigger interval used when ``once`` is falsy.
        startingVersion: Value passed to the source's ``startingVersion``
            option.
    """
    merge_query = f"""
    MERGE INTO {schema_silver}.line_status AS target
    USING line_status_delta AS source
    ON target.line_id = source.line_id 
    AND target.event_timestamp = source.event_timestamp
    WHEN NOT MATCHED THEN INSERT *
    """

    # Only the first status entry and its first validity period are used.
    first_status = get(col("lineStatuses"), 0)
    first_validity = get(first_status["validityPeriods"], 0)
    disruption = first_status["disruption"]

    projected_columns = [
        col("id").alias("line_id"),
        get(col("serviceTypes"), 0)["name"].alias("service_type"),
        first_status["statusSeverity"].alias("severity_code"),
        first_status["statusSeverityDescription"].alias("severity_description"),
        disruption["category"].alias("disruption_category"),
        disruption["description"].alias("disruption_description"),
        to_timestamp(first_validity["fromDate"]).alias("disruption_from_date"),
        to_timestamp(first_validity["toDate"]).alias("disruption_to_date"),
        first_validity["isNow"].alias("is_service_disrupted"),
        col("created").cast("timestamp").alias("event_timestamp"),
    ]

    bronze_stream = (
        spark.readStream
        .option("startingVersion", startingVersion)
        .option("ignoreDeletes", True)
        .table(f"{schema_bronze}.line_status_bz")
    )
    df_silver = (
        bronze_stream
        .select(*projected_columns)
        .withWatermark("event_timestamp", "30 seconds")
        .dropDuplicates(["line_id", "event_timestamp"])
    )

    def merge_batch(batch_df, batch_id):
        upserter(batch_df, batch_id, merge_query, "line_status_delta")

    stream_writer = (
        df_silver.writeStream
        .foreachBatch(merge_batch)
        .outputMode("update")
        .option("checkpointLocation", f"{silver_checkpoint}/line_status")
        .queryName("line_status_upsert_stream")
    )

    trigger_options = (
        {"availableNow": True} if once else {"processingTime": processing_time}
    )
    stream_writer.trigger(**trigger_options).start()



# Stream upsert for bus_arrivals table
def upsert_bus_arrivals(once=True, processing_time="15 seconds", startingVersion=0):
    """Stream bronze bus arrivals into the silver ``bus_arrivals`` table.

    Reads ``{schema_bronze}.bus_arrivals_bz`` as a stream, renames the
    camelCase columns to snake_case, casts the time columns to timestamps,
    drops duplicates on ``(arrival_id, event_timestamp)`` under a 30-second
    watermark, and inserts rows whose key is not yet present in the target.

    Args:
        once: When truthy the stream uses an ``availableNow`` trigger;
            otherwise it runs with a ``processingTime`` trigger.
        processing_time: Trigger interval used when ``once`` is falsy.
        startingVersion: Value passed to the source's ``startingVersion``
            option.
    """
    merge_query = f"""
    MERGE INTO {schema_silver}.bus_arrivals AS target 
    USING bus_arrivals_delta AS source
    ON target.arrival_id = source.arrival_id
    AND target.event_timestamp = source.event_timestamp
    WHEN NOT MATCHED THEN INSERT *
    """

    # Plain strings are columns that keep their bronze name unchanged.
    projected_columns = [
        col("id").alias("arrival_id"),
        col("operationType").alias("operation_type"),
        col("vehicleId").alias("vehicle_id"),
        col("naptanId").alias("naptan_id"),
        col("stationName").alias("station_name"),
        col("lineId").alias("line_id"),
        col("platformName").alias("platform_name"),
        "direction",
        "bearing",
        col("tripId").alias("trip_id"),
        col("baseVersion").alias("base_version"),
        col("destinationNaptanId").alias("destination_naptan_id"),
        col("destinationName").alias("destination_name"),
        col("timestamp").cast("timestamp").alias("event_timestamp"),
        col("timeToStation").alias("time_to_station"),
        col("currentLocation").alias("current_location"),
        "towards",
        col("expectedArrival").cast("timestamp").alias("expected_arrival"),
        col("timeToLive").cast("timestamp").alias("time_to_live"),
    ]

    bronze_stream = (
        spark.readStream
        .option("startingVersion", startingVersion)
        .option("ignoreDeletes", True)
        .table(f"{schema_bronze}.bus_arrivals_bz")
    )
    df_silver = (
        bronze_stream
        .select(*projected_columns)
        .withWatermark("event_timestamp", "30 seconds")
        .dropDuplicates(["arrival_id", "event_timestamp"])
    )

    def merge_batch(batch_df, batch_id):
        upserter(batch_df, batch_id, merge_query, "bus_arrivals_delta")

    stream_writer = (
        df_silver.writeStream
        .foreachBatch(merge_batch)
        .outputMode("update")
        .option("checkpointLocation", f"{silver_checkpoint}/bus_arrivals")
        .queryName("bus_arrival_upsert_stream")
    )

    trigger_options = (
        {"availableNow": True} if once else {"processingTime": processing_time}
    )
    stream_writer.trigger(**trigger_options).start()


# Stream upsert for stop_points table
def upsert_stop_points(once=True, processing_time="15 seconds", startingVersion=0):
    """Stream bronze stop points into the silver ``stop_points`` table.

    Reads ``{schema_bronze}.stop_points_bz`` as a stream, explodes the
    ``stopPoints`` array into one row per stop, keeps the first
    comma-separated token of ``indicator``, ``stopType`` and ``commonName``,
    and inserts stops whose ``naptan_id`` is not yet present in the target.

    Args:
        once: When truthy the stream uses an ``availableNow`` trigger;
            otherwise it runs with a ``processingTime`` trigger.
        processing_time: Trigger interval used when ``once`` is falsy.
        startingVersion: Value passed to the source's ``startingVersion``
            option.
    """
    merge_query = f"""
    MERGE INTO {schema_silver}.stop_points AS target
    USING stop_points_delta AS source
    ON target.naptan_id = source.naptan_id
    WHEN NOT MATCHED THEN INSERT *
    """

    df_silver = (
        spark.readStream
            .option("startingVersion", startingVersion)
            # The option name is `ignoreDeletes` (plural), as used by the
            # other silver streams; the singular spelling is not that option.
            .option("ignoreDeletes", True)
            .table(f"{schema_bronze}.stop_points_bz")
            .select(explode_outer(col("stopPoints")).alias("stop"))
            .select(
                col("stop.naptanId").alias("naptan_id"),
                split(col("stop.indicator"), ",")[0].alias("indicator"),
                col("stop.icsCode").cast("bigint").alias("ics_code"),
                split(col("stop.stopType"), ",")[0].alias("stop_type"),
                # NOTE(review): the silver stop_points DDL in this notebook has
                # no hub_naptan_code column -- confirm whether it should.
                col("stop.hubNaptanCode").alias("hub_naptan_code"),
                split(col("stop.commonName"), ",")[0].alias("common_name"),
                col("stop.lon").alias("longitude"),
                col("stop.lat").alias("latitude")
            )
            # explode_outer emits a NULL stop for empty/missing arrays. A NULL
            # key never satisfies the MERGE condition, so such rows would be
            # inserted again on every batch.
            .where(col("naptan_id").isNotNull())
    )

    def merge_batch(batch_df, batch_id):
        # An insert-only MERGE inserts every unmatched source row, so repeated
        # stops inside one micro-batch must be collapsed first. Doing it per
        # batch keeps the stream free of deduplication state.
        deduplicated = batch_df.dropDuplicates(["naptan_id"])
        upserter(deduplicated, batch_id, merge_query, "stop_points_delta")

    stream_writer = (
        df_silver.writeStream
            .foreachBatch(merge_batch)
            .outputMode("update")
            .option("checkpointLocation", f"{silver_checkpoint}/stop_points")
            .queryName("stop_points_upsert_stream")
    )

    if once:
        stream_writer.trigger(availableNow=True).start()
    else:
        stream_writer.trigger(processingTime=processing_time).start()



# Batch write for london_boroughs table
def write_london_boroughs():
    """Rebuild the silver ``london_boroughs`` table from its bronze source.

    Reads ``{schema_bronze}.london_boroughs_bz`` as a batch, projects the
    borough attributes out of ``properties``, serialises the geometry into a
    JSON string, drops rows without a borough code, keeps one row per borough
    code, and overwrites the silver Delta table (schema included).
    """
    # bronze field under `properties` -> silver column name
    property_aliases = {
        "CODE": "borough_code",
        "BOROUGH": "borough_name",
        "HECTARES": "hectares",
        "Shape__Area": "shape_area",
        "Shape__Length": "shape_length",
    }
    property_columns = [
        col(f"properties.{source_field}").alias(silver_name)
        for source_field, silver_name in property_aliases.items()
    ]
    geometry_column = to_json(
        expr("named_struct('type','Polygon','coordinates', geometry)")
    ).alias("geometry_geojson")

    bronze_boroughs = spark.read.table(f"{schema_bronze}.london_boroughs_bz")
    df_silver = (
        bronze_boroughs
        .select(*property_columns, geometry_column)
        .dropna(subset=["borough_code"])
        .dropDuplicates(["borough_code"])
    )

    delta_writer = (
        df_silver.write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
    )
    delta_writer.saveAsTable(f"{schema_silver}.london_boroughs")



# Batch upsert for bus_arrival_events derived table
def upsert_bus_arrival_events_batch():
    """Insert new rows into the silver ``bus_arrival_events`` table.

    Each row of silver ``bus_arrivals`` is enriched with the ``line_status``
    snapshot of the same line whose ``event_timestamp`` lies within 30 seconds
    of the arrival's. Arrivals without such a snapshot are kept with
    ``is_service_disrupted = FALSE`` and NULL severity columns. Rows are
    inserted only when ``(arrival_event_id, event_timestamp)`` is not already
    in the target.
    """
    # The +/-30 second window can match several status snapshots for the same
    # arrival. An insert-only MERGE inserts every unmatched source row, so the
    # source is reduced to the nearest snapshot per merge key first. The outer
    # SELECT lists columns explicitly so the ranking column never reaches
    # INSERT *.
    spark.sql(f"""
        MERGE INTO {schema_silver}.bus_arrival_events AS target
        USING (
            SELECT
                arrival_event_id,
                line_id,
                vehicle_id,
                naptan_id,
                station_name,
                platform_name,
                direction,
                destination_name,
                time_to_station,
                expected_arrival,
                time_to_live,
                is_service_disrupted,
                severity_code,
                severity_description,
                event_timestamp
            FROM (
                SELECT
                    ba.arrival_id AS arrival_event_id,
                    ba.line_id,
                    ba.vehicle_id,
                    ba.naptan_id,
                    ba.station_name,
                    ba.platform_name,
                    ba.direction,
                    ba.destination_name,
                    ba.time_to_station,
                    ba.expected_arrival,
                    ba.time_to_live,
                    COALESCE(ls.is_service_disrupted, FALSE) AS is_service_disrupted,
                    ls.severity_code,
                    ls.severity_description,
                    ba.event_timestamp,
                    ROW_NUMBER() OVER (
                        PARTITION BY ba.arrival_id, ba.event_timestamp
                        ORDER BY
                            ABS(unix_micros(ba.event_timestamp) - unix_micros(ls.event_timestamp)),
                            ls.event_timestamp DESC
                    ) AS status_rank
                FROM {schema_silver}.bus_arrivals ba
                LEFT JOIN {schema_silver}.line_status ls
                  ON ba.line_id = ls.line_id
                 AND ba.event_timestamp BETWEEN
                     ls.event_timestamp - INTERVAL 30 SECONDS
                 AND ls.event_timestamp + INTERVAL 30 SECONDS
            ) AS ranked
            WHERE status_rank = 1
        ) AS source
        ON target.arrival_event_id = source.arrival_event_id
       AND target.event_timestamp   = source.event_timestamp
        WHEN NOT MATCHED THEN INSERT *
    """)


# Batch upsert for bus_stops_geo derived table (spatial enrichment)
def upsert_bus_stops_geo_batch():
    """Assign a borough to each stop and insert new stops into ``bus_stops_geo``.

    Builds a point geometry for every row of silver ``stop_points`` and a
    geometry for every row of silver ``london_boroughs``, keeps the
    (stop, borough) pairs where the borough geometry contains the stop, and
    inserts the stops whose ``naptan_id`` is not yet present in the target.
    """
    # Stops, each with a point geometry tagged with SRID 4326.
    stops_with_geom = spark.read.table(f"{schema_silver}.stop_points").select(
        "naptan_id",
        col("common_name").alias("stop_name"),
        "stop_type",
        "longitude",
        "latitude",
        expr("ST_SetSRID(ST_Point(longitude, latitude), 4326)").alias("stop_geom"),
    )

    # Boroughs: unwrap the stored JSON and rebuild a geometry from the
    # coordinates nested inside it.
    borough_geom_sql = (
        "ST_GeomFromGeoJSON("
        "to_json(named_struct('type','Polygon','coordinates',"
        "from_json(geometry_geojson,"
        "'struct<coordinates:struct<coordinates:array<array<array<double>>>>>')"
        ".coordinates.coordinates)))"
    )
    boroughs_with_geom = spark.read.table(f"{schema_silver}.london_boroughs").select(
        "borough_code",
        "borough_name",
        expr(borough_geom_sql).alias("borough_geom"),
    )

    # Spatial join; the borough side is broadcast to every executor.
    stop_in_borough = expr("ST_Contains(borough_geom, stop_geom)")
    df_silver = stops_with_geom.join(
        broadcast(boroughs_with_geom), on=stop_in_borough, how="inner"
    ).select(
        "naptan_id",
        "stop_name",
        "stop_type",
        "borough_code",
        "borough_name",
        "longitude",
        "latitude",
    )

    df_silver.createOrReplaceTempView("bus_stops_geo_delta")
    spark.sql(f"""
    MERGE INTO {schema_silver}.bus_stops_geo AS target
    USING bus_stops_geo_delta AS source
    ON target.naptan_id = source.naptan_id
    WHEN NOT MATCHED THEN INSERT *
    """)


# Batch upsert for line_disruption_geo table (spatial enrichment)
def upsert_line_disruption_geo_batch():
    """Insert new rows into the silver ``line_disruption_geo`` table.

    Combines silver ``line_status`` rows with the borough and coordinates of
    stops from silver ``bus_stops_geo``, and inserts the rows whose
    ``(line_id, disruption_description, borough_name)`` key is not yet in the
    target.
    """
    # The join to bus_stops_geo needs a condition; without one it is a
    # Cartesian product that pairs every line status with every stop.
    # NOTE(review): the line -> stop link is taken from the distinct
    # (line_id, naptan_id) pairs seen in silver bus_arrivals, the only table in
    # this notebook that relates the two -- confirm this is the intended link.
    #
    # disruption_description is NULL when a status carries no disruption, so it
    # is compared with the null-safe `<=>`; a plain `=` never matches NULL and
    # would re-insert those rows on every run.
    spark.sql(f"""
        MERGE INTO {schema_silver}.line_disruption_geo AS target
        USING (
            SELECT
                ls.line_id,
                ls.service_type,
                ls.severity_code,
                ls.severity_description,
                ls.disruption_category,
                ls.disruption_description,
                ls.disruption_from_date,
                ls.disruption_to_date,
                ls.is_service_disrupted,
                bs.borough_code,
                bs.borough_name,
                bs.longitude,
                bs.latitude,
                ls.event_timestamp
            FROM {schema_silver}.line_status ls
            JOIN (
                SELECT DISTINCT line_id, naptan_id
                FROM {schema_silver}.bus_arrivals
            ) AS served
              ON ls.line_id = served.line_id
            JOIN {schema_silver}.bus_stops_geo bs
              ON served.naptan_id = bs.naptan_id
        ) AS source
        ON  target.line_id = source.line_id
        AND target.disruption_description <=> source.disruption_description
        AND target.borough_name = source.borough_name
        WHEN NOT MATCHED THEN INSERT *
    """)


# Await all active streaming queries if once=True
def await_queries(once):
    """Block until every active streaming query has terminated.

    Does nothing when ``once`` is falsy.

    Args:
        once: Whether to wait for the queries.
    """
    if not once:
        return
    for active_query in spark.streams.active:
        active_query.awaitTermination()


# Orchestrate all silver streaming and batch upserts
def upsert_silver(once=True, processing_time="5 seconds"):
    """Run the whole silver layer: table creation, streams, then batch steps.

    Order of work: ensure the tables exist, start the three streaming upserts,
    wait for them when ``once`` is truthy, rewrite ``london_boroughs``, and
    finally run the three derived-table batch upserts. Both elapsed-time
    messages are measured from the start of this function.

    Args:
        once: Passed to each streaming upsert and to ``await_queries``.
        processing_time: Passed to each streaming upsert as its trigger
            interval.
    """
    import time

    started_at = int(time.time())

    def elapsed_seconds():
        return int(time.time()) - started_at

    ensure_silver_tables()

    print("Running Silver streaming layer...")
    for start_stream in (upsert_line_status, upsert_bus_arrivals, upsert_stop_points):
        start_stream(once, processing_time)

    await_queries(once)
    print(f"Completed silver streaming layer in {elapsed_seconds()} seconds")

    write_london_boroughs()

    print("Processing derived table...")
    batch_steps = (
        upsert_bus_arrival_events_batch,
        upsert_bus_stops_geo_batch,
        upsert_line_disruption_geo_batch,
    )
    for run_batch_step in batch_steps:
        run_batch_step()
    print(f"✅ Completed silver batch enrichment layer in {elapsed_seconds()} seconds")

    

# Assert that a table has at least min_count records
def assert_count(schema_silver, table_name, min_count=1, filter_expr=None):
    """Check that a table holds at least ``min_count`` rows.

    Args:
        schema_silver: Schema (database) that contains the table.
        table_name: Name of the table to count.
        min_count: Smallest acceptable row count.
        filter_expr: Optional filter condition; when given, only rows that
            satisfy it are counted.

    Raises:
        AssertionError: If fewer than ``min_count`` rows are found.
    """
    print(f"Validating record counts in {table_name}...", end="")
    rows = spark.read.table(f"{schema_silver}.{table_name}")
    if filter_expr is not None:
        rows = rows.where(filter_expr)
    actual_count = rows.count()
    # Raised explicitly so the check still runs when Python is started with -O,
    # which strips `assert` statements.
    if actual_count < min_count:
        raise AssertionError(
            f"{table_name} has {actual_count} records, expected >= {min_count}"
        )
    print("Success")

# Validate all silver tables for minimum record counts
def validate_silver():
    """Check that every silver table contains at least one row.

    Calls ``assert_count`` for each table in turn, so the first table that
    fails the check stops the validation. Prints the elapsed time at the end.
    """
    import time

    start = int(time.time())
    print("\nStarting Silver Layer Validation...")

    tables_to_check = (
        "line_status",
        "bus_arrivals",
        "london_boroughs",
        "stop_points",
        "bus_arrival_events",
        "bus_stops_geo",
        "line_disruption_geo",
    )
    for table_name in tables_to_check:
        assert_count(schema_silver, table_name)

    print(f"✅ Silver Layer Validation completed in {int(time.time()) - start} seconds")