In [1]:
import os
spark_home = os.path.abspath(os.getcwd() + "/../spark-3.5.5-bin-hadoop3")
hadoop_home = os.path.abspath(os.getcwd() + "/../winutils")
print(f"I am using the following SPARK_HOME: {spark_home}")
if os.name == 'nt':
    os.environ["HADOOP_HOME"] = f"{hadoop_home}"
    print(f"Windows detected: set HADOOP_HOME to: {os.environ['HADOOP_HOME']}")
    hadoop_bin = os.path.join(hadoop_home, "bin")
    os.environ["PATH"] = f"{hadoop_bin};{os.environ['PATH']}"
    print(f"  Also added Hadoop bin directory to PATH: {hadoop_bin}")

import findspark
import pyspark
from pyspark.streaming import StreamingContext

findspark.init(spark_home)
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession.builder.getOrCreate()


I am using the following SPARK_HOME: C:\Users\arthu\Desktop\spark\spark-3.5.5-bin-hadoop3
Windows detected: set HADOOP_HOME to: C:\Users\arthu\Desktop\spark\winutils
  Also added Hadoop bin directory to PATH: C:\Users\arthu\Desktop\spark\winutils\bin


In [2]:
import threading

# Helper thread to avoid the Spark StreamingContext from blocking Jupyter
        
class StreamingThread(threading.Thread):
    def __init__(self, ssc):
        super().__init__()
        self.ssc = ssc
    def run(self):
        self.ssc.start()
        self.ssc.awaitTermination()
    def stop(self):
        print('----- Stopping... this may take a few seconds -----')
        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)



In [3]:
import random
from pyspark.streaming import StreamingContext
from pyspark.sql import Row
from pyspark.sql.functions import udf, struct, array, col, lit
from pyspark.sql.types import StringType

In [4]:
from pyspark.sql.functions import concat_ws, col
from sentence_transformers import SentenceTransformer
import joblib
import os

# Load models once
globals()['models_loaded'] = False
globals()['embedder'] = None
globals()['clf'] = None

def process(time, rdd):
    if rdd.isEmpty():
        print(f"[{time}] — Empty batch.")
        return

    print(f"\n========= {str(time)} =========")
    try:
        # Step 1: Read JSON from RDD
        df = spark.read.json(rdd)

        # Step 2: Clean missing title/summary
        df = df.fillna({'title': '', 'summary': ''})

        # Step 3: Combine into a single text column
        df = df.withColumn("text", concat_ws(" ", col("title"), col("summary")))

        # Step 4: Get text to predict
        texts = [row["text"] for row in df.select("text").collect()]

        # Step 5: Load models on first run only
        if not globals()['models_loaded']:
            model_path = "C:/Users/arthu/Desktop/spark/notebooks/minilm_model"
            clf_path = "C:/Users/arthu/Desktop/spark/notebooks/logreg_model.pkl"
            globals()['embedder'] = SentenceTransformer(model_path)
            globals()['clf'] = joblib.load(clf_path)
            globals()['models_loaded'] = True
            print("✅ Model and embedder loaded.")

        # Step 6: Predict
        embeddings = globals()['embedder'].encode(texts)
        predictions = globals()['clf'].predict(embeddings)

        # Step 7: Assemble prediction DataFrame
        pred_df = spark.createDataFrame(zip(texts, [str(p) for p in predictions]), ["text", "pred"])

        # Step 8: Include true label (first category)
        df = df.withColumnRenamed("categories", "true_label")

        # Step 9: Merge and show
        result_df = df.join(pred_df, on="text", how="left")
        result_df.select("aid", "text", "true_label", "pred").show(truncate=False)

        # Step 10: Save results to disk
        timestamp = str(time).replace(" ", "_").replace(":", "-")
        output_path = f"./predictions_batch_{timestamp}.json"
        result_df.select("aid", "text", "true_label", "pred").write.json(output_path)
        print(f"📁 Saved predictions to {output_path}")

    except Exception as e:
        print(f"❌ Error processing batch: {e}")


In [5]:
import json

test_json = json.dumps({
    "aid": "test123",
    "title": "Deep learning in medicine",
    "summary": "AI is being used to improve diagnostics and treatment accuracy.",
    "categories": ["cs.AI", "stat.ML"]
})

test_rdd = sc.parallelize([test_json])



In [6]:
process("test_run", test_rdd)



✅ Model and embedder loaded.
+-------+-----------------------------------------------------------------------------------------+----------------+-------+
|aid    |text                                                                                     |true_label      |pred   |
+-------+-----------------------------------------------------------------------------------------+----------------+-------+
|test123|Deep learning in medicine AI is being used to improve diagnostics and treatment accuracy.|[cs.AI, stat.ML]|eess.IV|
+-------+-----------------------------------------------------------------------------------------+----------------+-------+

❌ Error processing batch: [PATH_ALREADY_EXISTS] Path file:/C:/Users/arthu/Desktop/spark/notebooks/predictions_batch_test_run.json already exists. Set mode as "overwrite" to overwrite the existing path.


In [7]:
# Create a StreamingContext that checks for new data every 10 seconds
ssc = StreamingContext(sc, 10)

# Connect to the live data stream from the instructor's server
lines = ssc.socketTextStream("seppe.net", 7778)

# Process each RDD of incoming lines using your model
lines.foreachRDD(process)

# Start the streaming thread (non-blocking)
ssc_t = StreamingThread(ssc)
ssc_t.start()

# Optional: manually stop after e.g. 60 seconds or a few batches
# time.sleep(60)
# ssc_t.stop()





+---------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
ssc_t.stop()


----- Stopping... this may take a few seconds -----
