In [None]:
# import findspark
# findspark.init('/opt/spark')

# Load config variables
from config import BRONZE_PATH, SILVER_PATH, GOLD_FEATURES_PATH, GOLD_PREDICTIONS_PATH, BRONZE_CHECKPOINT, SILVER_CHECKPOINT, GOLD_PROCESSING_CHECKPOINT, KAFKA_BOOTSTRAP_SERVERS, KAFKA_TOPIC, MODEL_PATH

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import StructType , StructField, DoubleType, StringType, ArrayType, LongType
from pyspark.sql.functions import col, from_json, avg, window, stddev, lead
from pyspark.sql.window import Window

In [None]:
# Create (or reuse) the Spark session used by every cell below.
# It runs on the "local" master, pulls the Kafka source and Delta Lake jars via
# spark.jars.packages, and registers the Delta SQL extension and catalog.
spark = (
    SparkSession.builder
    .master("local")
    .appName("mini-projet")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config(
        "spark.jars.packages",
        "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.2,"
        "io.delta:delta-spark_2.12:3.3.0,"
        "io.delta:delta-storage:3.3.0",
    )
    .getOrCreate()
)

In [None]:
# Short aliases for the storage locations imported from config.
bronze, silver = BRONZE_PATH, SILVER_PATH
gold_features_path, gold_predictions_path = GOLD_FEATURES_PATH, GOLD_PREDICTIONS_PATH

In [None]:
# Schema applied later when the Kafka message value is parsed as JSON:
# "s" is a string, "p" and "v" are doubles, "t" is a long; all fields are nullable.
schema = (
    StructType()
    .add("s", StringType(), True)
    .add("p", DoubleType(), True)
    .add("v", DoubleType(), True)
    .add("t", LongType())
)

# Streaming read of the configured Kafka topic, starting from the latest offsets.
df_raw = (
    spark.readStream
    .format("kafka")
    .option("startingOffsets", "latest")
    .option("subscribe", KAFKA_TOPIC)
    .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP_SERVERS)
    .load()
)
df_raw

In [None]:
# Bronze layer: append the raw Kafka records, untouched, to the bronze Delta location.
(
    df_raw.writeStream
    .format("delta")
    .outputMode("append")
    .option("checkpointLocation", BRONZE_CHECKPOINT)
    .start(bronze)
)

In [None]:
# Stream the bronze Delta table back in.
df_bronze = spark.readStream.format("delta").load(bronze)

# Decode the payload: cast the Kafka value to a string, parse it as JSON with
# `schema`, flatten the struct into named columns, then derive `event_time`
# from `timestamp` / 1000 (presumably epoch milliseconds; confirm against the
# producer).
data_parsed = (
    df_bronze
    .select(from_json(col("value").cast("string"), schema).alias("data"))
    .select(
        col("data.p").alias("price"),
        col("data.s").alias("symbol"),
        col("data.v").alias("volume"),
        col("data.t").alias("timestamp"),
    )
    .withColumn("event_time", (col("timestamp") / 1000).cast("timestamp"))
)

data_parsed

In [None]:
# Silver layer: append the parsed records to the silver Delta location.
(
    data_parsed.writeStream
    .option("checkpointLocation", SILVER_CHECKPOINT)
    .outputMode("append")
    .format("delta")
    .start(silver)
)

In [None]:
# Stream the silver Delta table back in.
df_silver = spark.readStream.format("delta").load(silver)

# Per-symbol features over 20-second event-time windows, with a 10-second
# watermark on event_time: mean price, mean volume, and the standard deviation
# of price (exposed as "volatility").
df_features = (
    df_silver
    .withWatermark("event_time", "10 seconds")
    .groupBy(
        window(col("event_time"), "20 seconds"),
        col("symbol"),
    )
    .agg(
        avg(col("price")).alias("avg_price"),
        avg(col("volume")).alias("avg_volume"),
        stddev(col("price")).alias("volatility"),
    )
)
df_features

In [None]:
from pyspark.ml.classification import RandomForestClassificationModel
from pyspark.ml.feature import VectorAssembler

# Replace nulls in the three feature columns with 0 before they are assembled.
df_features_clean = df_features.na.fill(
    {name: 0 for name in ("avg_price", "avg_volume", "volatility")}
)

# Packs the three numeric features into a single "features" vector column.
assembler = VectorAssembler(
    outputCol="features",
    inputCols=["avg_price", "avg_volume", "volatility"],
)

model_path = MODEL_PATH

# Start in the "no model" state; only a successful load changes it, so the
# rest of the notebook keeps running without predictions when loading fails.
model, model_loaded = None, False

try:
    model = RandomForestClassificationModel.load(model_path)
except Exception as e:
    # Best effort: report the problem and continue without a model.
    print(f"WARNING: No model found ({e}).")
else:
    model_loaded = True
    print("SUCCESS: Model loaded!")

In [None]:
def process_batch(df, epoch_id):
    """Persist one micro-batch of features and, when a model is loaded, its predictions.

    Meant to be used as the ``foreachBatch`` callback of the gold stream.
    Empty batches are skipped. For non-empty batches the rows are appended to
    the Delta location ``gold_features_path``; if ``model_loaded`` is true the
    batch is also vectorized with ``assembler``, scored with ``model`` and the
    result appended to the ``default.live_predictions`` table stored at
    ``gold_predictions_path``.

    Args:
        df: Micro-batch DataFrame. It must expose the columns read by
            ``assembler`` plus ``window`` and ``symbol``.
        epoch_id: Micro-batch identifier; only used in the progress message.

    Returns:
        None.

    Raises:
        Any exception raised by counting, writing or scoring the batch is
        propagated unchanged, after the cached batch has been released.
    """
    # The batch is evaluated several times (count, feature write, prediction
    # write), so cache it once up front.
    df.cache()
    try:
        count = df.count()
        if count == 0:
            # Nothing to store or score for an empty micro-batch.
            return

        print(f"Batch {epoch_id}: {count} rows to process.")

        # SAVE HISTORY: keep the raw aggregates, without any assembled
        # "features" vector column.
        df.drop("features").write \
            .format("delta") \
            .mode("append") \
            .save(gold_features_path)

        # RUN PREDICTION: only possible when the model was loaded successfully.
        if model_loaded:
            df_vec = assembler.transform(df)
            df_pred = model.transform(df_vec)
            df_pred.select("window", "symbol", "prediction", "probability", "avg_price") \
                .write \
                .format("delta") \
                .mode("append") \
                .option("path", gold_predictions_path) \
                .saveAsTable("default.live_predictions")
    finally:
        # Always release the cached batch, including when a write or the model
        # raises; otherwise the cached data stays in the long-lived notebook
        # session.
        df.unpersist()

In [None]:
# Start the final (gold) stream: each micro-batch of cleaned features is handed
# to process_batch, using the "update" output mode and its own checkpoint.
query = (
    df_features_clean.writeStream
    .outputMode("update")
    .option("checkpointLocation", GOLD_PROCESSING_CHECKPOINT)
    .foreachBatch(process_batch)
    .start()
)

# Block this cell until the query stops or fails.
query.awaitTermination()

In [None]:
# Quick verification: print the predictions table without truncating cell values.
print("Contents of the live_predictions table:")
(
    spark
    .sql("select * from default.live_predictions;")
    .show(truncate=False)
)