In [1]:
import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, explode
from pyspark.sql.types import *

In [2]:
# Spark session: created (or fetched, if one already exists) with a local master
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("KafkaConsumerGTFSA")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/21 09:53:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/08/21 09:53:13 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
# ClickHouse connection details.
# Each setting can be overridden through an environment variable of the same
# name, so real credentials do not have to be stored in the notebook. The
# fallback values are the settings this notebook used before.
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "clickhouse")
# Converted to int so the constant keeps the type it had as a literal.
CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT", "8123"))
CLICKHOUSE_USER = os.environ.get("CLICKHOUSE_USER", "default")
CLICKHOUSE_PASS = os.environ.get("CLICKHOUSE_PASS", "123")
CLICKHOUSE_DB = os.environ.get("CLICKHOUSE_DB", "gtfs_streaming")
CLICKHOUSE_TABLE = os.environ.get("CLICKHOUSE_TABLE", "gtfs_alerts")

In [4]:
# Schema used to parse the GTFS Realtime alert JSON.
# Every field is declared nullable and every leaf value is read as a string.
def _string_fields(*names):
    """Return a StructType of nullable string fields with the given names."""
    return StructType([StructField(name, StringType(), True) for name in names])


def _translated_text():
    """Return the struct shape shared by headerText and descriptionText."""
    return StructType([
        StructField("translation", ArrayType(_string_fields("text", "language"), True), True)
    ])


schema = StructType([
    StructField("header", _string_fields("gtfsRealtimeVersion", "timestamp"), True),
    StructField("entity", ArrayType(StructType([
        StructField("id", StringType(), True),
        StructField("alert", StructType([
            StructField("activePeriod", ArrayType(_string_fields("start", "end"), True), True),
            StructField("informedEntity", ArrayType(_string_fields("agencyId", "routeId"), True), True),
            StructField("headerText", _translated_text(), True),
            StructField("descriptionText", _translated_text(), True),
        ]), True),
    ]), True), True),
])

In [5]:
# Streaming source: subscribe to the "gtfs-alerts" Kafka topic
kafka_source_options = {
    "kafka.bootstrap.servers": "broker:29092",
    "subscribe": "gtfs-alerts",
    "startingOffsets": "earliest",
}
raw_df = spark.readStream.format("kafka").options(**kafka_source_options).load()

In [6]:
# Decode the Kafka message value to text (the topic name is kept alongside it)
kafka_df = raw_df.selectExpr("CAST(value AS STRING) AS json_str", "topic")

# Parse the JSON text with the alert schema and promote its top-level fields
# (header, entity) to columns
alert_df = (
    kafka_df
    .select(from_json(col("json_str"), schema).alias("data"))
    .select("data.*")
)

In [7]:
# Flatten nested JSON, step 1: one row per element of the `entity` array,
# with the feed header carried along unchanged
df_entity = alert_df.select(
    col("header"),
    explode(col("entity")).alias("entity"),
)

In [8]:
# Step 2: lift the header and entity fields to top-level columns and emit one
# row per element of the alert's activePeriod array.
# NOTE(review): explode() is expected to emit nothing for alerts whose
# activePeriod array is null or empty - confirm dropping them is intended.
active_period_columns = [
    col("header.gtfsRealtimeVersion").alias("gtfsRealtimeVersion"),
    col("header.timestamp").alias("timestamp"),
    col("entity.id").alias("id"),
    explode(col("entity.alert.activePeriod")).alias("activePeriod"),
    col("entity.alert.informedEntity").alias("informedEntity"),
    col("entity.alert.headerText").alias("headerText"),
    col("entity.alert.descriptionText").alias("descriptionText"),
]
df_active_period = (
    df_entity
    .select(*active_period_columns)
    .where(col("activePeriod").isNotNull())
)

In [9]:
# Step 3: split the active period into start/end columns and emit one row per
# element of the informedEntity array; null elements are discarded.
informed_entity_columns = [
    col("gtfsRealtimeVersion"),
    col("timestamp"),
    col("id"),
    col("activePeriod.start").alias("activePeriod_start"),
    col("activePeriod.end").alias("activePeriod_end"),
    explode(col("informedEntity")).alias("informedEntity"),
    col("headerText"),
    col("descriptionText"),
]
df_informed_entity = (
    df_active_period
    .select(*informed_entity_columns)
    .where(col("informedEntity").isNotNull())
)

In [10]:
# Step 4: pull agencyId/routeId out of the informed entity and emit one row per
# header-text translation; null translations are discarded.
header_text_columns = [
    col("gtfsRealtimeVersion"),
    col("timestamp"),
    col("id"),
    col("activePeriod_start"),
    col("activePeriod_end"),
    col("informedEntity.agencyId").alias("agencyId"),
    col("informedEntity.routeId").alias("routeId"),
    explode(col("headerText.translation")).alias("headerTranslation"),
    col("descriptionText"),
]
df_header_text = (
    df_informed_entity
    .select(*header_text_columns)
    .where(col("headerTranslation").isNotNull())
)

In [11]:
# Step 5: split the header translation into text/language columns and emit one
# row per description-text translation; null translations are discarded.
# Each header translation is therefore paired with every description translation.
description_text_columns = [
    col("gtfsRealtimeVersion"),
    col("timestamp"),
    col("id"),
    col("activePeriod_start"),
    col("activePeriod_end"),
    col("agencyId"),
    col("routeId"),
    col("headerTranslation.text").alias("header_text"),
    col("headerTranslation.language").alias("header_language"),
    explode(col("descriptionText.translation")).alias("descriptionTranslation"),
]
df_final = (
    df_header_text
    .select(*description_text_columns)
    .where(col("descriptionTranslation").isNotNull())
)

In [12]:
# Final flat projection: the eleven string columns that are written to ClickHouse.
# It is derived from df_header_text rather than by rebinding df_final onto
# itself, so this cell can be re-run on its own; selecting from df_final a
# second time would fail because descriptionTranslation has already been
# projected away.
df_final = (
    df_header_text
    .select(
        "gtfsRealtimeVersion",
        "timestamp",
        "id",
        "activePeriod_start",
        "activePeriod_end",
        "agencyId",
        "routeId",
        col("headerTranslation.text").alias("header_text"),
        col("headerTranslation.language").alias("header_language"),
        # One row per description-text translation
        explode("descriptionText.translation").alias("descriptionTranslation"),
    )
    .filter(col("descriptionTranslation").isNotNull())
    .select(
        "gtfsRealtimeVersion",
        "timestamp",
        "id",
        "activePeriod_start",
        "activePeriod_end",
        "agencyId",
        "routeId",
        "header_text",
        "header_language",
        col("descriptionTranslation.text").alias("description_text"),
        col("descriptionTranslation.language").alias("description_language"),
    )
)

In [13]:
# JDBC connection settings, assembled from the ClickHouse constants above
clickhouse_url = "jdbc:clickhouse://{}:{}/{}".format(
    CLICKHOUSE_HOST, CLICKHOUSE_PORT, CLICKHOUSE_DB
)
clickhouse_properties = dict(
    user=CLICKHOUSE_USER,
    password=CLICKHOUSE_PASS,
    driver="com.clickhouse.jdbc.ClickHouseDriver",
)

In [14]:
# Sink callback: persists one micro-batch to ClickHouse
def write_to_clickhouse(batch_df, batch_id):
    """Append the rows of one micro-batch to the ClickHouse table over JDBC.

    Args:
        batch_df: DataFrame holding the rows of the current micro-batch.
        batch_id: Identifier of the micro-batch. It is accepted because the
            function is registered with ``foreachBatch``, but it is not used.

    Returns:
        None.
    """
    append_writer = batch_df.write.mode("append")
    append_writer.jdbc(clickhouse_url, CLICKHOUSE_TABLE, properties=clickhouse_properties)

In [16]:
# Start streaming to ClickHouse: micro-batches are triggered every 30 seconds
# and handed to write_to_clickhouse; the checkpoint path is relative to the
# working directory.
alert_writer = (
    df_final.writeStream
    .foreachBatch(write_to_clickhouse)
    .outputMode("append")
    .option("checkpointLocation", "check_points/alerts_checks")
    .trigger(processingTime="30 seconds")
)
alert_query = alert_writer.start()

25/08/21 09:53:34 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
25/08/21 09:53:35 WARN AdminClientConfig: These configurations '[key.deserializer, value.deserializer, enable.auto.commit, max.poll.records, auto.offset.reset]' were supplied but are not used yet.


In [None]:
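# # Create the ClickHouse database and table if they do not exist
# # NOTE: this setup cell is disabled. To use it, uncomment it and add
# # `import clickhouse_connect`, which this notebook does not import.
# # Presumably it should run before the streaming query is started - confirm.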
# client = clickhouse_connect.get_client(
#     host=CLICKHOUSE_HOST,
#     port=CLICKHOUSE_PORT,
#     username=CLICKHOUSE_USER,
#     password=CLICKHOUSE_PASS
# )

# client.command(f"""
# CREATE DATABASE IF NOT EXISTS {CLICKHOUSE_DB}
# """)

# client.command(f"""
# CREATE TABLE IF NOT EXISTS {CLICKHOUSE_DB}.{CLICKHOUSE_TABLE} (
#     gtfsRealtimeVersion String,
#     timestamp String,
#     id String,
#     activePeriod_start String,
#     activePeriod_end String,
#     agencyId String,
#     routeId String,
#     header_text String,
#     header_language String,
#     description_text String,
#     description_language String
# ) ENGINE = MergeTree()
# ORDER BY id
# """)