In [None]:
import os

# Locate the bundled Spark distribution and (for Windows) the winutils-based
# Hadoop directory relative to the notebook's working directory.
spark_home = os.path.abspath(
    os.path.join(os.getcwd(), "..", "Assignment_3", "spark", "spark-3.5.5-bin-hadoop3")
)
hadoop_home = os.path.abspath(
    os.path.join(os.getcwd(), "..", "Assignment_3", "spark", "winutils")
)
print(f"I am using the following SPARK_HOME: {spark_home}")
if os.name == 'nt':
    # On Windows, expose HADOOP_HOME and put its bin directory on PATH
    # (presumably so Spark can find winutils.exe there).
    os.environ["HADOOP_HOME"] = hadoop_home
    print(f"Windows detected: set HADOOP_HOME to: {os.environ['HADOOP_HOME']}")
    hadoop_bin = os.path.join(hadoop_home, "bin")
    os.environ["PATH"] = hadoop_bin + os.pathsep + os.environ.get("PATH", "")
    print(f"  Also added Hadoop bin directory to PATH: {hadoop_bin}")

import findspark

# findspark.init must run BEFORE pyspark is imported: it adds the pyspark/py4j
# sources of spark_home to sys.path. Importing first would pick up whatever
# pyspark is already installed (possibly a different version) or fail outright.
findspark.init(spark_home)

import pyspark
from pyspark.sql import SparkSession
from pyspark.streaming import StreamingContext

sc = pyspark.SparkContext()
spark = SparkSession.builder.getOrCreate()  # reuses the SparkContext created above

In [None]:
import random
import numpy as np
import pandas as pd
from pyspark.sql import Row
from pyspark.sql.functions import udf, struct, array, col, lit, collect_set
import torch 
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [None]:
import threading

# Helper thread that keeps the Spark StreamingContext from blocking the Jupyter kernel
        
class StreamingThread(threading.Thread):
    """Background thread that drives a Spark StreamingContext.

    ``awaitTermination()`` blocks until the stream is stopped, so running it on
    a separate thread leaves the notebook free to execute other cells.
    """

    def __init__(self, ssc):
        """Remember the StreamingContext that ``run`` will start."""
        threading.Thread.__init__(self)
        self.ssc = ssc

    def run(self):
        """Start streaming and block this thread until the context terminates."""
        context = self.ssc
        context.start()
        context.awaitTermination()

    def stop(self):
        """Gracefully stop the stream while leaving the SparkContext running."""
        print('----- Stopping... this may take a few seconds -----')
        self.ssc.stop(stopGraceFully=True, stopSparkContext=False)

In [None]:
import json

# Class-id -> label mapping saved next to the fine-tuned model. JSON object keys
# are always strings, so they are converted back to ints to match the argmax
# indices the classifier produces at inference time. The file is read as UTF-8
# explicitly (the JSON standard encoding) rather than the platform locale
# codec, which is e.g. cp1252 on Windows.
with open(
    "models/finetuned_scibert_scivocab_uncased_weighted_17cats/id2label.json",
    "r",
    encoding="utf-8",
) as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

id2label

In [None]:
# Label names ordered by class id: element i is the label of class i
# (raises KeyError if the ids are not exactly 0..n-1).
candidate_labels = list(id2label[class_id] for class_id in range(len(id2label)))

In [None]:
# Lookup table joined (on ``level1_category``) in classify_partition to turn each
# record's abbreviated level-1 category into its full label.
df_level1_categories = pd.read_parquet(path="data/df_level1_categories.gzip")

In [None]:
# Load the fine-tuned sequence-classification checkpoint and its tokenizer from
# the same local directory (SciBERT-based, judging by the directory name).
model_path = "models/finetuned_scibert_scivocab_uncased_weighted_17cats"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

In [None]:
max_len = 256  # maximum number of tokens of title+summary fed to the classifier


def _predict_label(text):
    """Return the label name the fine-tuned classifier scores highest for ``text``.

    Uses the module-level ``tokenizer``, ``model`` and ``id2label``; inputs
    longer than ``max_len`` tokens are truncated.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_len)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits
    return id2label[torch.argmax(logits, dim=1).item()]


def classify_partition(time, rdd):
    """Classify one micro-batch of streamed records and display the results.

    Registered with ``DStream.foreachRDD``, so it runs on the driver once per
    batch and can use the driver-side ``model``/``tokenizer`` directly.

    Args:
        time: batch timestamp supplied by Spark Streaming (only printed).
        rdd: the batch's RDD of JSON strings; each record is assumed to carry
            ``title``, ``summary`` and ``main_category`` fields.

    Returns None. Displays the batch as a DataFrame with the true full label
    (``subgroup``), the predicted label (``pred``) and a ``correct`` column of
    ``'+'``/``'-'``. Records whose level-1 category is absent from
    ``df_level1_categories`` are dropped by the inner merge.
    """
    if rdd.isEmpty():
        return

    print("========= %s =========" % str(time))

    dfp = spark.read.json(rdd).toPandas()
    # Level-1 category = the part of main_category before the first '.';
    # used to look up the true category's full (non-abbreviated) label.
    dfp["level1_category"] = dfp["main_category"].str.split(".").str[0]
    dfp = pd.merge(dfp, df_level1_categories, on="level1_category").drop(
        ["level1_category", "group"], axis=1
    )

    # fillna("") so a record with a missing title or summary is still
    # classified instead of raising TypeError on None + str.
    texts = dfp["title"].fillna("") + ". " + dfp["summary"].fillna("")
    # Assign predictions positionally rather than concatenating on the index,
    # so they can never be misaligned with their rows.
    dfp["pred"] = [_predict_label(text) for text in texts]
    dfp["correct"] = np.where(dfp["subgroup"] == dfp["pred"], "+", "-")

    display(dfp)


In [None]:
# Streaming context on the existing SparkContext, cutting the input into
# micro-batches every 10 seconds.
ssc = StreamingContext(sc, batchDuration=10)

In [None]:
# Text stream from the socket source; every micro-batch RDD of lines is handed
# to classify_partition, which parses each line as a JSON record.
lines = ssc.socketTextStream(hostname="seppe.net", port=7778)
lines.foreachRDD(classify_partition)

In [None]:
# Start streaming on a background thread so this cell returns immediately and
# the kernel stays free for other cells (e.g. the stop cell below).
ssc_t = StreamingThread(ssc)
ssc_t.start()

In [None]:
# Gracefully stop the stream; the SparkContext stays alive. A stopped
# StreamingContext cannot be restarted, so re-create it before streaming again.
ssc_t.stop()