# 1. Start Spark Session and Import Libraries

In [None]:
# Import system and Spark initialization tools
import findspark
findspark.init()

# spark libraries
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, DateType, BooleanType

# ml libraries
from pyspark.ml import PipelineModel
from pyspark.ml.classification import GBTClassificationModel

# Maven coordinates of the connector jars Spark has to pull in at startup.
# The Mongo connector dates from an attempt to get a MongoDB sink working;
# per the original authors' note, parquet ended up being used instead.
kafka_package = "org.apache.spark:spark-sql-kafka-0-10_2.13:3.5.1"
mongo_package = "org.mongodb.spark:mongo-spark-connector_2.13:10.2.1"

# Create (or reuse) the SparkSession:
#  - application name
#  - 4g of memory each for the driver and the executors
#  - both connector packages, passed as one comma-separated string
spark = (
    SparkSession.builder
    .appName("FlightDelayStreamingPrediction")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .config("spark.jars.packages", ",".join((kafka_package, mongo_package)))
    .getOrCreate()
)

# 2. Load Models & Pipeline

In [None]:
# Preprocessing pipeline: where it was saved on disk, then the fitted
# PipelineModel loaded back from that location.
pipeline_path = "./flight_delay_gbt_pipeline_model"
loaded_pipeline_model = PipelineModel.load(pipeline_path)

# Classifier: where it was saved on disk, then the trained model itself.
# NOTE: despite the "rf" in its name, `loaded_rf_model` holds a
# GBTClassificationModel; the name is kept because later code refers to it.
model_path = "./flight_delay_gbt_model"
loaded_rf_model = GBTClassificationModel.load(model_path)

                                                                                

# 3. Read from the Kafka Topic

In [None]:
# Kafka connection settings: the broker address and the topic to subscribe to.
kafka_bootstrap_servers = "localhost:9092"
kafka_topic = "flight_data_stream"

# Streaming DataFrame over the topic, starting from the latest offsets.
# Kafka delivers the "value" column as binary, so the final step casts it to
# a string; that leaves a single string column ready for JSON parsing.
kafka_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", kafka_bootstrap_servers)
    .option("subscribe", kafka_topic)
    .option("startingOffsets", "latest")
    .load()
    .selectExpr("CAST(value AS STRING)")
)

# Show the resulting schema as a sanity check.
kafka_df.printSchema()

root
 |-- value: string (nullable = true)



# 4. Prepare for Streaming

In [8]:
# 1. Fields expected in each Kafka JSON message, as (name, Spark type) pairs.
#    Every field is declared nullable.
message_fields = [
    ("FL_DATE", StringType()),
    ("AIRLINE", StringType()),
    ("AIRLINE_CODE", StringType()),
    ("ORIGIN", StringType()),
    ("DEST", StringType()),
    ("CRS_DEP_TIME", IntegerType()),
    ("CRS_ARR_TIME", IntegerType()),
    ("CRS_ELAPSED_TIME", DoubleType()),
    ("DISTANCE", DoubleType()),
]
json_schema = StructType(
    [StructField(field_name, field_type, True) for field_name, field_type in message_fields]
)

# 2. Decode the JSON text in the Kafka 'value' column into a struct named
#    "data", then flatten that struct into one top-level column per field.
decoded_payload = F.from_json(F.col("value"), json_schema).alias("data")
parsed_stream_df = kafka_df.select(decoded_payload).select("data.*")

# 3. Feature engineering on the streaming DataFrame.
#    The columns are added one at a time, in this order, because later
#    expressions read columns produced by earlier ones: the date parts need
#    FL_DATE already converted to a date, and IS_WEEKEND needs DEP_DAY_OF_WEEK.
flight_date = F.col("FL_DATE")
sched_departure = F.col("CRS_DEP_TIME")
sched_arrival = F.col("CRS_ARR_TIME")
derived_columns = [
    # Replace the string date with a real date column.
    ("FL_DATE", F.to_date(flight_date, "yyyy-MM-dd")),
    # Split the scheduled times into an hour part (value / 100, truncated by
    # the integer cast) and a minute part (value % 100).
    ("DEP_HOUR", (sched_departure / 100).cast("integer")),
    ("DEP_MINUTE", (sched_departure % 100).cast("integer")),
    ("ARR_HOUR", (sched_arrival / 100).cast("integer")),
    ("ARR_MINUTE", (sched_arrival % 100).cast("integer")),
    # Calendar features taken from the flight date.
    ("DEP_DAY_OF_WEEK", F.dayofweek(flight_date)),
    ("DEP_MONTH", F.month(flight_date)),
    ("DEP_DAY_OF_MONTH", F.dayofmonth(flight_date)),
    ("DEP_WEEK_OF_YEAR", F.weekofyear(flight_date)),
    # Spark's dayofweek numbers Sunday as 1 and Saturday as 7.
    ("IS_WEEKEND", F.when(F.col("DEP_DAY_OF_WEEK").isin([1, 7]), 1).otherwise(0)),
    # The 1e-6 added to the denominator keeps it non-zero when
    # CRS_ELAPSED_TIME is 0.
    ("DISTANCE_PER_MINUTE", F.col("DISTANCE") / (F.col("CRS_ELAPSED_TIME") + 1e-6)),
]
for column_name, column_expr in derived_columns:
    parsed_stream_df = parsed_stream_df.withColumn(column_name, column_expr)

# 4. Columns used as model features; any row with a null in one of them
#    (or in FL_DATE) is discarded.
feature_columns = [
    "AIRLINE_CODE", "ORIGIN", "DEST", "CRS_ELAPSED_TIME", "DISTANCE",
    "DEP_HOUR", "DEP_MINUTE", "ARR_HOUR", "ARR_MINUTE",
    "DEP_DAY_OF_WEEK", "DEP_MONTH", "DEP_DAY_OF_MONTH", "DEP_WEEK_OF_YEAR",
    "IS_WEEKEND", "DISTANCE_PER_MINUTE",
]
required_columns = feature_columns + ["FL_DATE"]
parsed_stream_df = parsed_stream_df.na.drop(subset=required_columns)

# 5. Show the schema so field names and types can be checked.
parsed_stream_df.printSchema()

# 6. Run the loaded preprocessing PipelineModel over the engineered stream.
processed_stream_df = loaded_pipeline_model.transform(parsed_stream_df)

# 7. Score the preprocessed rows with the loaded classifier. `loaded_rf_model`
#    is the GBTClassificationModel loaded earlier, not a random forest.
predictions_df = loaded_rf_model.transform(processed_stream_df)

# 8. Columns kept for downstream consumers.
downstream_columns = [
    "FL_DATE",
    "AIRLINE_CODE",
    "ORIGIN",
    "DEST",
    "CRS_DEP_TIME",
    "CRS_ARR_TIME",
    "DISTANCE",
    "prediction",
    "probability",
]

# 9. Human-readable label: a prediction equal to 1 means a severe delay,
#    anything else is reported as no severe delay.
severe_delay_label = (
    F.when(F.col("prediction") == 1, "Severe Delay Predicted")
    .otherwise("No Severe Delay Predicted")
)
output_df = (
    predictions_df
    .select(*downstream_columns)
    .withColumn("Prediction_Label", severe_delay_label)
)

root
 |-- FL_DATE: date (nullable = true)
 |-- AIRLINE: string (nullable = true)
 |-- AIRLINE_CODE: string (nullable = true)
 |-- ORIGIN: string (nullable = true)
 |-- DEST: string (nullable = true)
 |-- CRS_DEP_TIME: integer (nullable = true)
 |-- CRS_ARR_TIME: integer (nullable = true)
 |-- CRS_ELAPSED_TIME: double (nullable = true)
 |-- DISTANCE: double (nullable = true)
 |-- DEP_HOUR: integer (nullable = true)
 |-- DEP_MINUTE: integer (nullable = true)
 |-- ARR_HOUR: integer (nullable = true)
 |-- ARR_MINUTE: integer (nullable = true)
 |-- DEP_DAY_OF_WEEK: integer (nullable = true)
 |-- DEP_MONTH: integer (nullable = true)
 |-- DEP_DAY_OF_MONTH: integer (nullable = true)
 |-- DEP_WEEK_OF_YEAR: integer (nullable = true)
 |-- IS_WEEKEND: integer (nullable = false)
 |-- DISTANCE_PER_MINUTE: double (nullable = true)



In [None]:
# after you run this, it will continue running while streaming the data to the website
# make sure kafka producer and flask app are running in two terminals

# 1. Extract the probability of the "severe delay" class (index 1 of the
#    probability vector). vector_to_array is a built-in Spark function, so the
#    extraction no longer ships every row through a Python UDF worker, and a
#    null vector yields null instead of raising inside the UDF.
from pyspark.ml.functions import vector_to_array


def prob_udf(probability_col):
    """Return a double Column holding element 1 of an ML vector column.

    The name is kept from the earlier Python-UDF implementation so existing
    `prob_udf(col)` calls keep working; this is now a plain column expression.
    """
    return vector_to_array(probability_col)[1]


# 2. Add a new column with the extracted probability
output_df = output_df.withColumn("Probability_Severe_Delay", prob_udf(F.col("probability")))

# 3. Print the schema to verify that Probability_Severe_Delay was added correctly
output_df.printSchema()

# Columns delivered to every sink below.
sink_columns = (
    "FL_DATE",
    "AIRLINE_CODE",
    "ORIGIN",
    "DEST",
    "CRS_DEP_TIME",
    "Prediction_Label",
    "Probability_Severe_Delay",
)

# 4. MongoDB connection string; its path names the target as
#    flightdb.flight_predictions.
mongo_uri = "mongodb://127.0.0.1/flightdb.flight_predictions"

# 5. MongoDB sink: append-mode streaming write with its own checkpoint directory.
mongo_query = (
    output_df
    .select(*sink_columns)
    .writeStream
    .format("mongodb")
    .option("spark.mongodb.connection.uri", mongo_uri)
    .option("checkpointLocation", "./mongodb_checkpoint")
    .outputMode("append")
    .start()
)

# Console sink: prints each appended batch without truncating long values.
console_query = (
    output_df
    .select(*sink_columns)
    .writeStream
    .outputMode("append")
    .format("console")
    .option("truncate", "false")
    .start()
)

# File sink: appends the predictions as parquet files under file_output_path,
# with a dedicated checkpoint directory.
file_output_path = "./streaming_predictions_output"
file_query = (
    output_df
    .select(*sink_columns)
    .writeStream
    .outputMode("append")
    .format("parquet")
    .option("path", file_output_path)
    .option("checkpointLocation", "./streaming_predictions_checkpoint")
    .start()
)

# Wait for all queries
import time

# Keep this cell alive for as long as at least one streaming query is running.
# Sleep between checks: polling isActive in a tight loop would keep a CPU core
# fully busy for the whole lifetime of the stream.
running_queries = [console_query, file_query, mongo_query]
try:
    while any(q.isActive for q in running_queries):
        time.sleep(1)
except KeyboardInterrupt:
    # Interrupting the cell (Ctrl+C / kernel interrupt) shuts every query down.
    for q in running_queries:
        q.stop()
    print("Streaming query interrupted by user.")
