# In [None]:
#Author: Tan Xin Hui
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import when
from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType, StringType

# Explicit schema applied to the annotated CSV input. Columns are declared
# in order, each one nullable; the raw categorical columns and price_label
# are strings, everything else is a double.
schema = (
    StructType()
    .add("Carat", DoubleType(), True)
    .add("Clarity", StringType(), True)
    .add("Color", StringType(), True)
    .add("Fluorescence", StringType(), True)
    .add("Polish", StringType(), True)
    .add("Price", DoubleType(), True)
    .add("Shape", StringType(), True)
    .add("Symmetry", StringType(), True)
    .add("Shape_encoded", DoubleType(), True)
    .add("Clarity_encoded", DoubleType(), True)
    .add("Color_encoded", DoubleType(), True)
    .add("Polish_encoded", DoubleType(), True)
    .add("Symmetry_encoded", DoubleType(), True)
    .add("Fluorescence_encoded", DoubleType(), True)
    .add("Price per Carat", DoubleType(), True)
    .add("price_label", StringType(), True)
)

def main_stream():
    """Stream annotated CSV data from HDFS, add a price category, and write it back.

    Reads CSV files (with a header row) from ``/user/student/annotated_data``
    using the module-level ``schema``, derives a ``price_category`` column
    from ``price_label`` (2 -> "high", 1 -> "medium", 0 -> "low", anything
    else -> "unknown"), and appends the result as CSV to
    ``hdfs://localhost:9000/user/student/streamed_data/`` once per minute.

    While the query runs, its latest progress report is printed about every
    10 seconds. The function returns once the query has terminated, failed,
    or been interrupted with Ctrl-C; in every case the query is stopped
    before returning.
    """
    spark = (
        SparkSession.builder
        .appName("AnnotatedDataPipeline")
        .config("spark.hadoop.fs.defaultFS", "hdfs://localhost:9000")
        .getOrCreate()
    )

    annotated_data_path = "/user/student/annotated_data"
    annotated_df = (
        spark.readStream
        .format("csv")
        .option("header", "true")
        .schema(schema)
        .load(annotated_data_path)
    )

    # NOTE(review): the schema declares price_label as a string while the
    # literals below are integers, so these comparisons depend on Spark's
    # implicit string-to-number cast -- confirm this matches the input data.
    price_label = annotated_df["price_label"]
    converted_df = annotated_df.withColumn(
        "price_category",
        when(price_label == 2, "high")
        .when(price_label == 1, "medium")
        .when(price_label == 0, "low")
        .otherwise("unknown"),
    )

    query = (
        converted_df.writeStream
        .outputMode("append")
        .format("csv")
        .option("path", "hdfs://localhost:9000/user/student/streamed_data/")
        .option("checkpointLocation", "hdfs://localhost:9000/user/student/checkpoints")
        .trigger(processingTime='1 minute')
        .start()
    )

    try:
        while True:
            progress = query.lastProgress
            print("Latest progress update:")
            print(progress)
            print("********************************")
            # Wait up to 10 seconds for the query to end. Unlike polling
            # isActive with a plain sleep, awaitTermination re-raises the
            # exception of a failed query, so failures reach the handler
            # below instead of ending the loop silently.
            if query.awaitTermination(10):
                break

    except KeyboardInterrupt:
        print("Streaming job was manually stopped.")

    except Exception as e:
        print(f"Streaming job terminated with an error: {e}")

    finally:
        query.stop()
        print("Streaming job has finished.")

# Start the streaming pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main_stream()