In [None]:
import os
# Locate the bundled Spark distribution and the winutils directory relative to
# the directory the notebook is running from.
_cwd = os.getcwd()
spark_home = os.path.abspath(_cwd + "/../Assignment_3/spark/spark-3.5.5-bin-hadoop3")
hadoop_home = os.path.abspath(_cwd + "/../Assignment_3/spark/winutils")
print(f"I am using the following SPARK_HOME: {spark_home}")

_on_windows = os.name == 'nt'
if _on_windows:
    # On Windows, expose the winutils directory through HADOOP_HOME and put
    # its bin/ folder at the front of PATH.
    hadoop_bin = os.path.join(hadoop_home, "bin")
    os.environ["HADOOP_HOME"] = hadoop_home
    print(f"Windows detected: set HADOOP_HOME to: {os.environ['HADOOP_HOME']}")
    os.environ["PATH"] = f"{hadoop_bin};{os.environ['PATH']}"
    print(f"  Also added Hadoop bin directory to PATH: {hadoop_bin}")

import findspark
import pyspark
from pyspark.streaming import StreamingContext

# Point findspark at the local Spark distribution resolved above.
findspark.init(spark_home)

# Use getOrCreate() rather than the bare constructor: only one SparkContext may
# be active at a time, so constructing a second one when this cell is re-run in
# a live kernel raises an error. getOrCreate() returns the existing context.
sc = pyspark.SparkContext.getOrCreate()
spark = pyspark.sql.SparkSession.builder.getOrCreate()

In [None]:
import threading

# Helper thread that keeps the Spark StreamingContext from blocking Jupyter
        
class StreamingThread(threading.Thread):
    """Thread wrapper that runs a streaming context off the calling thread.

    ``run`` starts the wrapped context and waits for its termination;
    ``stop`` asks the context to shut down gracefully without stopping the
    underlying SparkContext.
    """

    def __init__(self, ssc):
        """Store the streaming context (``ssc``) this thread will drive."""
        threading.Thread.__init__(self)
        self.ssc = ssc

    def run(self):
        """Start the streaming context, then wait until it terminates."""
        context = self.ssc
        context.start()
        context.awaitTermination()

    def stop(self):
        """Print a notice and request a graceful stop of the streaming context."""
        print('----- Stopping... this may take a few seconds -----')
        # Keep the SparkContext alive so the notebook can keep using it.
        stop_options = {"stopSparkContext": False, "stopGraceFully": True}
        self.ssc.stop(**stop_options)

In [None]:
import random
import pandas as pd
from pyspark.streaming import StreamingContext
from pyspark.sql import Row
from pyspark.sql.functions import udf, struct, array, col, lit, collect_set
from pyspark.sql.types import StringType

In [None]:
import torch 
from transformers import pipeline

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, FloatType

In [None]:
# Load the arXiv category table and gather the distinct values of its
# "group" column into a single list.
df_categories = spark.read.parquet("data/arxiv_categories.parquet.gzip")
_groups_agg = collect_set("group").alias("groups")
row = df_categories.agg(_groups_agg).first()

categories = row["groups"]
print(categories)

In [None]:

# Labels offered to the zero-shot classifier: the distinct category groups
# collected from the arXiv categories table.
candidate_labels = categories

# schema = StructType([
#     StructField("summary", StringType(), True),
#     StructField("predicted_label", StringType(), True),
#     StructField("confidence_score", FloatType(), True)
# ])

def _get_classifier():
    """Return the shared zero-shot classification pipeline, building it on first use.

    The pipeline and its ``models_loaded`` flag live in the module's global
    namespace (not in a local cache) so that re-running this cell does not
    force the model to be loaded again.
    """
    namespace = globals()
    if not namespace.get("models_loaded", False):
        namespace["classifier"] = pipeline(
            "zero-shot-classification",
            model="valhalla/distilbart-mnli-12-3",
            device=-1,
        )
        namespace["models_loaded"] = True
    return namespace["classifier"]


def classify_partition(time, rdd):
    """Classify the ``summary`` field of every record in one streaming micro-batch.

    Intended as a ``foreachRDD`` callback. The RDD is parsed as JSON, each
    record's ``summary`` is scored against ``candidate_labels``, and the
    records are displayed as a pandas DataFrame with two extra columns:
    ``predicted_label`` (top-ranked label) and ``confidence_score`` (its score).

    Args:
        time: Batch time supplied by the streaming context; only printed.
        rdd: RDD holding the records of this batch as JSON strings.

    Returns:
        None. Results are shown with ``display``.

    Empty batches and batches without a ``summary`` field are skipped.
    Records whose summary is missing or blank are kept, with
    ``predicted_label`` and ``confidence_score`` left empty.
    """
    # Check for an empty batch first so the model is never loaded needlessly.
    if rdd.isEmpty():
        return

    print("========= %s =========" % str(time))

    # Parse the JSON records and pull them into pandas.
    batch = spark.read.json(rdd).toPandas()

    # If no record carried a "summary" field (for instance when the JSON could
    # not be parsed), there is nothing to classify; skip instead of raising.
    if "summary" not in batch.columns:
        print("Skipping batch: no 'summary' field in the received records")
        return

    classifier = _get_classifier()

    results = []
    for record in batch.to_dict(orient="records"):
        text = record.get("summary")
        if isinstance(text, str) and text.strip():
            res = classifier(text, candidate_labels)
            # The first entry of "labels"/"scores" is taken as the prediction.
            record["predicted_label"] = res["labels"][0]
            record["confidence_score"] = float(res["scores"][0])
        else:
            # A missing or blank summary cannot be classified; keep the row.
            record["predicted_label"] = None
            record["confidence_score"] = None
        results.append(record)

    df_result = pd.DataFrame(results)
    display(df_result)

In [None]:
# Streaming context on top of the existing SparkContext, with a batch duration of 10.
ssc = StreamingContext(sparkContext=sc, batchDuration=10)

In [None]:
# Read text from the remote socket and hand every batch to classify_partition.
lines = ssc.socketTextStream(hostname="seppe.net", port=7778)
lines.foreachRDD(func=classify_partition)

In [None]:
# Run the streaming context in a background thread so this cell returns immediately.
ssc_t = StreamingThread(ssc=ssc)
ssc_t.start()

In [None]:
# Start a streaming thread only when none is currently running. Starting
# unconditionally would replace `ssc_t` while the earlier thread is still
# alive, leaving that thread without a handle, and would call start() a second
# time on the same streaming context.
if "ssc_t" not in globals() or not ssc_t.is_alive():
    ssc_t = StreamingThread(ssc)
    ssc_t.start()

In [None]:
# Gracefully stop the streaming context. StreamingThread.stop() passes
# stopSparkContext=False, so the SparkContext itself is not stopped.
ssc_t.stop()