In [1]:
import logging
import os
from math import radians, cos, sin, asin, sqrt

from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, expr, to_timestamp, lit, when, unix_timestamp, to_date, from_unixtime, abs, row_number
from pyspark.sql.types import *
from pyspark.sql import Window
from pyspark.sql.functions import udf
from pyspark.sql.functions import broadcast
from pyspark.storagelevel import StorageLevel

In [2]:
# Named logger used by the batch-processing code below; the root logger is
# configured at INFO so its progress messages are emitted.
logger = logging.getLogger("GTFSRealtimeVPProcessing")
logging.basicConfig(level=logging.INFO)

In [3]:
# Build (or reuse) the Spark session for this notebook, running on a local master.
_session_builder = (
    SparkSession.builder
    .appName("GTFSRealtimeMonitoring")
    .master("local[*]")
    .config("spark.sql.session.timeZone", "America/New_York")
    .config("spark.executor.memory", "2g")
)
spark = _session_builder.getOrCreate()

# Runtime SQL settings applied once the session exists.
spark.conf.set("spark.sql.shuffle.partitions", "10")  # Adjust based on cluster size
spark.conf.set("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/21 09:56:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/08/21 09:56:28 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/08/21 09:56:28 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
25/08/21 09:56:28 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.
25/08/21 09:56:28 WARN Utils: Service 'SparkUI' could not bind on port 4043. Attempting port 4044.


In [4]:
# Schema used to parse the JSON feed message carried in each Kafka record.
# Both timestamp fields are declared as strings; they are cast to long downstream.
_header_struct = StructType([
    StructField("gtfsRealtimeVersion", StringType()),
    StructField("timestamp", StringType()),
])

_trip_struct = StructType([
    StructField("tripId", StringType()),
    StructField("routeId", StringType()),
    StructField("startDate", StringType()),
])

_position_struct = StructType([
    StructField("latitude", DoubleType()),
    StructField("longitude", DoubleType()),
])

_vehicle_struct = StructType([
    StructField("trip", _trip_struct),
    StructField("position", _position_struct),
    StructField("timestamp", StringType()),
])

_entity_struct = StructType([
    StructField("id", StringType()),
    StructField("vehicle", _vehicle_struct),
])

vehicle_schema = StructType([
    StructField("header", _header_struct),
    StructField("entity", ArrayType(_entity_struct)),
])

In [5]:
# ClickHouse JDBC connection settings.
# The user and password are read from the CLICKHOUSE_USER / CLICKHOUSE_PASSWORD
# environment variables when they are set, so real credentials never have to be
# committed with the notebook. The literals are local-development fallbacks only.
clickhouse_url = "jdbc:clickhouse://clickhouse:8123"
clickhouse_properties = {
    "user": os.environ.get("CLICKHOUSE_USER", "default"),
    "password": os.environ.get("CLICKHOUSE_PASSWORD", "123"),  # TODO: remove the fallback once secrets are provisioned
    "driver": "com.clickhouse.jdbc.ClickHouseDriver",
    "isolationLevel": "NONE"
}

In [6]:
def read_from_clickhouse(table_name):
    """Load a table from the ``gtfs_batch`` ClickHouse database over JDBC.

    Args:
        table_name: Value passed to the JDBC reader as its ``dbtable`` option.

    Returns:
        The DataFrame produced by the JDBC reader.
    """
    jdbc_options = {
        "url": f"{clickhouse_url}/gtfs_batch",
        "dbtable": table_name,
        "user": clickhouse_properties["user"],
        "password": clickhouse_properties["password"],
        "driver": clickhouse_properties["driver"],
    }
    return spark.read.format("jdbc").options(**jdbc_options).load()

In [7]:
from pyspark.sql import functions as F

# Depot code -> borough lookup, one entry per line.
# NOTE(review): "CSg" contains a lowercase letter, so the [A-Z]{2,3} depot regex
# used in process_batch_to_dashboard can never produce it — confirm whether it is needed.
depot_to_borough = {
    "CS": "Queens",
    "QV": "Queens",
    "JA": "Queens",
    "FP": "Brooklyn",
    "CSg": "Queens",
    "BP": "Queens",
    "LG": "Queens",
    "FR": "Queens",
    "GA": "Brooklyn",
    "EN": "Brooklyn",
    "JG": "Brooklyn",
    "FB": "Brooklyn",
    "UP": "Brooklyn",
    "SC": "Brooklyn",
    "MQ": "Manhattan",
    "QVH": "Manhattan",
    "MV": "Manhattan",
    "QU": "Manhattan",
    "KB": "Manhattan",
    "GH": "Bronx",
    "WF": "Bronx",
    "EC": "Bronx",
    "CA": "Staten Island",
    "CH": "Staten Island",
    "MD": "Staten Island",
    "YU": "Staten Island",
    "JK": "Brooklyn",  # For 43206923-JKPC5-JK_C5-Weekday-03
    "CP": "Brooklyn",  # For 43120117-CPPC5-CP_C5-Weekday-03
    "OF": "Manhattan",  # For OF_C5-Weekday-...
    "OH": "Manhattan",  # For OH_C5-Weekday-...
    "YO": "Manhattan",  # For 43139148-YOPC5-YO_C5-Weekday-03
}

# create_map expects a flat key, value, key, value, ... sequence of literals.
_map_literals = [F.lit(item) for pair in depot_to_borough.items() for item in pair]
mapping_expr = F.create_map(_map_literals)

In [8]:
# Static reference tables read from ClickHouse. For stops and stop_times only
# rows flagged is_current are kept, trimmed to the columns used downstream.
_is_current = col("is_current") == True

stops = (
    read_from_clickhouse("stops")
    .filter(_is_current)
    .select("stop_id", "stop_name", "stop_lat", "stop_lon")
)
stop_times = (
    read_from_clickhouse("stop_times")
    .filter(_is_current)
    .select("trip_id", "arrival_time", "stop_id")
)
stop_boroughs = read_from_clickhouse("stop_boroughs")

In [9]:
# Mark the static tables for caching so each micro-batch does not re-read them over JDBC.
_cache_level = StorageLevel.MEMORY_AND_DISK
stops.persist(_cache_level)
stop_times.persist(_cache_level)
stop_boroughs.persist(_cache_level)

# Schedule rows enriched with the stop's name and coordinates; this is the
# lookup table every micro-batch is joined against.
stop_times_with_stops = stop_times.join(stops, "stop_id", "inner")
stop_times_with_stops.persist(_cache_level)

DataFrame[stop_id: decimal(20,0), trip_id: string, arrival_time: string, stop_name: string, stop_lat: double, stop_lon: double]

In [10]:
def haversine(lon1, lat1, lon2, lat2):
    """Return the great-circle distance in meters between two points.

    Args:
        lon1: Longitude of the first point, in degrees.
        lat1: Latitude of the first point, in degrees.
        lon2: Longitude of the second point, in degrees.
        lat2: Latitude of the second point, in degrees.

    Returns:
        The haversine distance in meters on a sphere of radius 6,371,000 m,
        or None when any coordinate is None.
    """
    if None in (lon1, lat1, lon2, lat2):
        return None
    lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    # Rounding can leave `a` marginally above 1.0 for (near-)antipodal points,
    # which would make asin() raise ValueError. A plain comparison is used
    # instead of min() so that a NaN input still propagates as NaN.
    if a > 1.0:
        a = 1.0
    c = 2 * asin(sqrt(a))
    r = 6371000  # Radius of Earth in meters
    return c * r

In [11]:
# Expose haversine() to Spark as a column function returning a nullable double.
haversine_udf = udf(f=haversine, returnType=DoubleType())

In [12]:
# Streaming source: the vehicle-positions Kafka topic, read from the earliest
# available offsets when no checkpoint exists yet.
vp_raw_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "broker:29092")
    .option("subscribe", "gtfs-vehicle-positions")
    .option("startingOffsets", "earliest")
    .load()
)

In [13]:
# Decode the Kafka record value to a JSON string.
vp_kafka_df = vp_raw_df.selectExpr("CAST(value AS STRING) AS json_str")

# Parse the JSON with the feed schema; everything lives under the `data` struct.
_parsed_feed = vp_kafka_df.select(
    F.from_json(F.col("json_str"), vehicle_schema).alias("data")
)

# Per-vehicle timestamp: the string field cast to long, then to a timestamp;
# NULL when the value cannot be cast to long.
_vehicle_epoch = F.col("entity.vehicle.timestamp").cast("long")
_vehicle_ts = F.when(
    _vehicle_epoch.isNotNull(), F.to_timestamp(_vehicle_epoch)
).otherwise(F.lit(None))

# One output row per feed entity, with the header fields repeated on each row.
vp_df = (
    _parsed_feed
    .select(
        "data.header.gtfsRealtimeVersion",
        F.col("data.header.timestamp").cast("long").cast("timestamp").alias("header_timestamp"),
        F.expr("explode(data.entity) as entity"),
    )
    .withColumn("vehicle_timestamp", _vehicle_ts)
    .withWatermark("vehicle_timestamp", "2 minutes")
)

In [14]:
# (nested source path, flat output name) for the fields pulled out of each entity.
_entity_fields = [
    ("entity.id", "entity_id"),
    ("entity.vehicle.trip.tripId", "vp_trip_id"),
    ("entity.vehicle.trip.routeId", "route_id"),
    ("entity.vehicle.trip.startDate", "vp_start_date"),
    ("entity.vehicle.position.latitude", "latitude"),
    ("entity.vehicle.position.longitude", "longitude"),
]

# Flatten the entity struct into plain columns and drop rows without a trip id.
vp_exploded_df = (
    vp_df.select(
        col("gtfsRealtimeVersion").alias("gtfs_version"),
        col("header_timestamp"),
        *[col(path).alias(name) for path, name in _entity_fields],
        col("vehicle_timestamp"),
    )
    .where(col("vp_trip_id").isNotNull())
)

In [15]:
def process_batch_to_dashboard(batch_df, batch_id):
    """Enrich one micro-batch of vehicle positions and append it to ClickHouse.

    For every vehicle report, the latest scheduled stop of its trip whose
    scheduled arrival is not after the report's timestamp is selected. The
    distance to that stop, an arrival status, a delay and a borough are then
    derived, and the rows are appended to ``gtfs_dashboard.trip_monitoring2``.

    Args:
        batch_df: Micro-batch DataFrame handed over by ``foreachBatch``; it
            must provide the columns selected in ``vp_exploded_df``.
        batch_id: Identifier of the micro-batch; only used in log messages.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
            Swallowing it would let the streaming query record the batch as
            done and move on, silently dropping rows that were never written.
    """
    try:
        logger.info(f"Processing batch {batch_id} for dashboard")
        
        # Join streaming batch with static stop_times_with_stops
        joined_df = batch_df.join(stop_times_with_stops, batch_df["vp_trip_id"] == stop_times_with_stops["trip_id"], "inner")        

        
        # Compute scheduled_seconds from arrival_time (HH:MM:SS to seconds)
        joined_df = joined_df.withColumn("scheduled_seconds", 
                                         expr("cast(split(arrival_time, ':')[0] as int) * 3600 + cast(split(arrival_time, ':')[1] as int) * 60 + cast(split(arrival_time, ':')[2] as int)"))
        
        # Compute scheduled_arrival timestamp: midnight of the trip's start date plus scheduled_seconds
        joined_df = joined_df.withColumn("scheduled_arrival", 
                                         to_timestamp(from_unixtime(unix_timestamp(to_date(col("vp_start_date"), "yyyyMMdd")) + col("scheduled_seconds"))))
        
        # Filter to stops where scheduled_arrival <= vehicle_timestamp (stops the vehicle should have arrived at).
        # NOTE: the join above is an inner equi-join on trip_id, so trip_id is never NULL here and the
        # isNull branch is a no-op; it would only matter if the join were changed to a left join.
        filtered_df = joined_df.filter((col("scheduled_arrival") <= col("vehicle_timestamp")) | col("trip_id").isNull())
        
        # Window to select the latest scheduled stop (max scheduled_arrival <= vehicle_timestamp)
        window_spec = Window.partitionBy("entity_id", "vp_trip_id", "vehicle_timestamp").orderBy(col("scheduled_arrival").desc())
        selected_df = filtered_df.withColumn("row_num", row_number().over(window_spec)).filter(col("row_num") == 1).drop("row_num")
        
        # Compute distance (meters) from the reported position to the selected stop
        selected_df = selected_df.withColumn("distance", haversine_udf(col("longitude"), col("latitude"), col("stop_lon"), col("stop_lat")))
        
        # Determine status and delay: within 100 m counts as "arrived"; delay is only set for arrived rows
        selected_df = selected_df.withColumn("status", when(col("distance") <= 100, "arrived").otherwise("on_way"))
        selected_df = selected_df.withColumn("delay_seconds", when(col("status") == "arrived", unix_timestamp(col("vehicle_timestamp")) - unix_timestamp(col("scheduled_arrival"))).otherwise(lit(None)))
        
        # Derive the borough from the depot code embedded in the trip id; unmatched codes map to "Unknown"
        selected_df = selected_df.withColumn("depot_code", F.regexp_extract(F.col("vp_trip_id"), r"^(?:[0-9]+-)?([A-Z]{2,3})(?:PC5-[A-Z]{2,3})?_.*", 1)) \
                                 .withColumn("borough", F.coalesce(mapping_expr[F.col("depot_code")], F.lit("Unknown"))) \
                                 .drop("depot_code")
        
        # Fix the output column set and order before writing
        selected_df = selected_df.select(
            F.col("gtfs_version"),
            F.col("entity_id"),
            F.col("vp_trip_id"),
            F.col("route_id"),
            F.col("vp_start_date"),
            F.col("latitude"),
            F.col("longitude"),
            F.col("vehicle_timestamp"),
            F.col("header_timestamp"),
            F.col("stop_id"),
            F.col("stop_name"),
            F.col("stop_lat"),
            F.col("stop_lon"),
            F.col("scheduled_arrival"),
            F.col("distance"),
            F.col("status"),
            F.col("delay_seconds"),
            F.col("borough")
        )

        # Write to dashboard table
        selected_df.write \
            .format("jdbc") \
            .option("url", f"{clickhouse_url}/gtfs_dashboard") \
            .option("dbtable", "trip_monitoring2") \
            .option("user", clickhouse_properties["user"]) \
            .option("password", clickhouse_properties["password"]) \
            .option("driver", clickhouse_properties["driver"]) \
            .option("jdbcCompliant", "false") \
            .option("batchsize", 10000) \
            .mode("append") \
            .save()
        
        logger.info(f"Batch {batch_id} processed and written to dashboard")
    except Exception as e:
        # Log with the full traceback, then propagate: returning normally would
        # tell Spark the batch succeeded and its offsets would be checkpointed.
        logger.exception(f"Error processing batch {batch_id}: {e}")
        raise

In [None]:
# Configure the sink: every 30 seconds the pending rows are handed to
# process_batch_to_dashboard, with progress tracked in the checkpoint directory.
_dashboard_writer = (
    vp_exploded_df.writeStream
    .outputMode("append")
    .foreachBatch(process_batch_to_dashboard)
    .trigger(processingTime="30 seconds")
    .option("checkpointLocation", "check_points/trips_monitoring_checks")
)
vp_query = _dashboard_writer.start()

# Block this cell until the streaming query stops or fails.
vp_query.awaitTermination()

INFO:py4j.java_gateway:Callback Server Starting
INFO:py4j.java_gateway:Socket listening on ('127.0.0.1', 38177)
25/08/21 09:56:35 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
25/08/21 09:56:36 WARN AdminClientConfig: These configurations '[key.deserializer, value.deserializer, enable.auto.commit, max.poll.records, auto.offset.reset]' were supplied but are not used yet.
INFO:py4j.clientserver:Python Server ready to receive messages
INFO:py4j.clientserver:Received command c on object id p0
INFO:GTFSRealtimeVPProcessing:Processing batch 0 for dashboard
25/08/21 09:56:39 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
INFO:GTFSRealtimeVPProcessing:Batch 0 processed and written to dashboard        
25/08/21 09:57:08 WARN ProcessingTimeExecutor: Current batch is falling behind. The trigger interval i