In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import Window
from pyspark.sql.types import StructType, StringType, IntegerType, DoubleType, StructField,TimestampType
import redis
import json
import time
import os

In [0]:
# Namespace-level Event Hubs connection string; the per-hub EntityPath is appended below.
base_connection_string = os.getenv("EVENT_HUB_CONNECTION_STR")
if not base_connection_string:
    # Without this guard the f-string below would build the literal text
    # "None;EntityPath=..." and hand that to the encryptor.
    raise RuntimeError(
        "Environment variable EVENT_HUB_CONNECTION_STR is not set; "
        "cannot build the Event Hubs connection strings."
    )

# Event Hubs this notebook reads from (raw glucose readings and raw device feeds).
entity_paths = ['raw_glucose_readings', 'raw_device_feeds']
# Event Hubs this notebook writes alerts to.
entity_paths_target = ['above_max_glucose_threshold', 'below_min_glucose_threshold' ,'missed_readings', 'device_error']

# Maps each Event Hub name to the option dict passed to the "eventhubs" reader/writer.
configs = {}

# Build one encrypted connection string per Event Hub.
# NOTE(review): `spark` is used here although the cell that assigns it appears further
# down; this relies on a session that already exists when the cell runs — confirm.
for entity_path in entity_paths + entity_paths_target:
    entity_conn_string = f"{base_connection_string};EntityPath={entity_path}"
    encrypted_conn_string = spark._jvm.org.apache.spark.eventhubs.EventHubsUtils.encrypt(entity_conn_string)
    configs[entity_path] = {
        "eventhubs.connectionString": encrypted_conn_string
    }
    # Only the source hubs get a consumer group.
    if entity_path in entity_paths:
        configs[entity_path]["eventhubs.consumerGroup"] = "spark"

# Named handles for each hub's configuration, used by the readers and writers below.
ehConfGlucose = configs['raw_glucose_readings']
ehConfDevice = configs['raw_device_feeds']
ehConfHighGlucose = configs['above_max_glucose_threshold']
ehConfigLowGlucose = configs['below_min_glucose_threshold']
ehConfMS = configs['missed_readings']
ehConfLostConn = configs['device_error']


In [0]:
class RedisConnection:
    """
    Holder for one shared Redis client.

    The first call to `get_instance` builds the client and stores it on the class;
    every later call hands back that same object.

    Attributes:
        _instance: The cached Redis client, or None until the first call.

    Methods:
        get_instance: Build (once) and return the shared Redis client.
    """

    _instance = None

    @staticmethod
    def get_instance(host, port, password):
        """
        Return the shared Redis client, creating it on first use.

        The arguments are only used when no client has been created yet;
        once one is cached they are ignored.

        Args:
            host: Redis host name.
            port: Redis port.
            password: Redis password / access key.

        Returns:
            The cached Redis client.
        """
        cached_client = RedisConnection._instance
        if cached_client is not None:
            return cached_client

        # TLS is always on, and responses are decoded to str rather than bytes.
        client_options = {
            "host": host,
            "port": port,
            "password": password,
            "ssl": True,
            "decode_responses": True,
        }
        RedisConnection._instance = redis.Redis(**client_options)
        return RedisConnection._instance

def get_redis_connection():
    """
    Return the shared Redis client configured from environment variables.

    Reads REDIS_HOST_NAME, REDIS_SSL_PORT and REDIS_PRIMARY_ACCESS_KEY and passes
    them to `RedisConnection.get_instance`. The host and port are required; the
    access key is passed through as-is (None when the variable is unset).

    Returns:
        The Redis client returned by `RedisConnection.get_instance`.

    Raises:
        RuntimeError: If REDIS_HOST_NAME or REDIS_SSL_PORT is unset or empty.
        ValueError: If REDIS_SSL_PORT is not an integer.
    """
    host = os.getenv('REDIS_HOST_NAME')
    port = os.getenv('REDIS_SSL_PORT')
    password = os.getenv('REDIS_PRIMARY_ACCESS_KEY')

    # Fail with the variable names instead of the opaque TypeError that
    # int(None) would raise when the port is missing.
    missing = [
        name
        for name, value in (('REDIS_HOST_NAME', host), ('REDIS_SSL_PORT', port))
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    return RedisConnection.get_instance(host, int(port), password)

In [0]:
spark = SparkSession.builder.appName("EventHubGlucoseReadings").getOrCreate()

# Session-level settings applied to the streaming job.
streaming_session_settings = {
    # Small shuffle-partition count, chosen because this notebook performs joins.
    "spark.sql.shuffle.partitions": "4",
    # NOTE: "true" *enables* micro-batches that carry no new data; it does not skip them.
    # Per the Spark configuration docs these batches exist for eager state management
    # in stateful streaming queries.
    "spark.sql.streaming.noDataMicroBatches.enabled": "true",
}

for setting_name, setting_value in streaming_session_settings.items():
    spark.conf.set(setting_name, setting_value)


In [0]:
def _fetch_redis_hash_as_json(redis_key):
    """
    Read the Redis hash stored at `redis_key` and return it as a JSON string.

    The fetched record is deliberately not printed: these functions are wrapped as
    Spark UDFs and run once per row, and the records contain patient/device details
    that should not be written to logs.

    Args:
        redis_key: Full Redis key of the hash to read.

    Returns:
        JSON text of the hash fields, or JSON of the form {"error": "<message>"}
        if anything raised while connecting or reading.
    """
    try:
        redis_client = get_redis_connection()
        return json.dumps(redis_client.hgetall(redis_key))
    except Exception as e:
        # Best effort: report the failure in-band rather than failing the caller.
        return json.dumps({"error": str(e)})


def fetch_patient_details(user_id):
    """
    Look up the Redis hash `user:<user_id>` and return it as a JSON string.

    Args:
        user_id: Identifier interpolated into the Redis key.

    Returns:
        JSON text of the hash, or {"error": "<message>"} as JSON on failure.
    """
    return _fetch_redis_hash_as_json(f"user:{user_id}")


def fetch_device_details(device_id):
    """
    Look up the Redis hash `device:<device_id>` and return it as a JSON string.

    Args:
        device_id: Identifier interpolated into the Redis key.

    Returns:
        JSON text of the hash, or {"error": "<message>"} as JSON on failure.
    """
    return _fetch_redis_hash_as_json(f"device:{device_id}")

In [0]:
# JSON layout of one message body on the raw_glucose_readings hub.
# `timestamp` is declared as a string here and converted to a timestamp when the stream is read.
glucose_schema = StructType([
    StructField("user_id", IntegerType()),
    StructField("device_id", IntegerType()),
    StructField("timestamp", StringType()),
    StructField("glucose_reading", DoubleType()),
    StructField("latitude", DoubleType()),
    StructField("longitude", DoubleType()),
])


# Layout of the patient details JSON produced by the Redis lookup.
# Every field is declared as a string; numeric ones are cast where they are used.
patient_field_names = (
    "patient_name",
    "patient_age",
    "gender",
    "min_glucose",        # Minimum level at which glucose should be
    "max_glucose",        # Maximum it should be at
    "medical_condition",
)
patient_schema = StructType(
    [StructField(field_name, StringType()) for field_name in patient_field_names]
)



In [0]:
# JSON layout of one message body on the raw_device_feeds hub.
# `timestamp` is declared as a string here and converted to a timestamp when the stream is read.
device_feed_schema = StructType([
    StructField("device_id", IntegerType()),
    StructField("battery_level", IntegerType()),
    StructField("firmware_name", StringType()),
    StructField("firmware_version", StringType()),
    StructField("connection_status", StringType()),
    StructField("error_code", StringType()),
    StructField("timestamp", StringType()),
])


# Layout of the device details JSON produced by the Redis lookup (all fields are strings).
device_schema = (
    StructType()
    .add("owner_name", StringType())
    .add("device_model", StringType())
    # Rate at which the device is expected to send information about the user's glucose level.
    .add("data_transmission_interval", StringType())
    # Calculated from data_transmission_interval for a 15-minute window.
    .add("expected_transmissions", StringType())
    .add("manufacturer_name", StringType())
)




In [0]:
# Stream glucose readings from the raw_glucose_readings hub and flatten the JSON body.
glucose_body = from_json(col("body").cast("string"), glucose_schema).alias("data")

glucose_readings_df = (
    spark.readStream
    .format("eventhubs")
    .options(**ehConfGlucose)
    .load()
    .select(glucose_body)
    .select(
        "data.user_id",
        "data.device_id",
        to_timestamp("data.timestamp", "yyyy-MM-dd'T'HH:mm:ss").alias("timestamp"),
        "data.glucose_reading",
        "data.latitude",
        "data.longitude",
    )
)

# Per-row Redis lookup that returns the patient's details as a JSON string.
fetch_patient_udf = udf(fetch_patient_details, StringType())

# Flags: 1 when the reading is outside the patient's threshold, otherwise 0.
# `reading_over_min_glucose` marks readings *below* min_glucose despite its name.
above_max_flag = when(col("glucose_reading") > col("max_glucose"), 1).otherwise(0)
below_min_flag = when(col("glucose_reading") < col("min_glucose"), 1).otherwise(0)

enriched_df = (
    glucose_readings_df
    .withColumn("patient_details_json", fetch_patient_udf(col("user_id")))
    # Parse the JSON string into a struct, then expose its fields as top-level columns.
    .withColumn("patient_details", from_json(col("patient_details_json"), patient_schema))
    .select("*", "patient_details.*")
    # Thresholds arrive as strings; replace the expanded columns with integer versions.
    .withColumn("max_glucose", col("patient_details.max_glucose").cast(IntegerType()))
    .withColumn("min_glucose", col("patient_details.min_glucose").cast(IntegerType()))
    .withColumn("reading_over_max_glucose", above_max_flag)
    .withColumn("reading_over_min_glucose", below_min_flag)
)

# Final projection: the reading, the patient attributes, and the two threshold flags.
final_df = enriched_df.select(
    "user_id",
    "device_id",
    "timestamp",
    "glucose_reading",
    "latitude",
    "longitude",
    "patient_name",
    "patient_age",
    "gender",
    "max_glucose",
    "min_glucose",
    "medical_condition",
    "reading_over_max_glucose",
    "reading_over_min_glucose",
)





In [0]:
# Stream device telemetry from the raw_device_feeds hub and flatten the JSON body.
device_feeds_df = (
    spark.readStream
    .format("eventhubs")
    .options(**ehConfDevice)
    .load()
    .select(from_json(col("body").cast("string"), device_feed_schema).alias("data"))
    .select(
        col("data.device_id"),
        col("data.battery_level"),
        col("data.firmware_name"),
        col("data.firmware_version"),
        col("data.connection_status"),
        col("data.error_code"),
        to_timestamp("data.timestamp", "yyyy-MM-dd'T'HH:mm:ss").alias("timestamp"),
    )
)

# Per-row Redis lookup that returns the device's details as a JSON string.
fetch_device_udf = udf(fetch_device_details, StringType())

enriched_df = (
    device_feeds_df
    .withColumn("device_details_json", fetch_device_udf(col("device_id")))
    # Parse the JSON string into a struct, then expose its fields as top-level columns.
    .withColumn("device_details", from_json(col("device_details_json"), device_schema))
    .select("*", "device_details.*")
    # device_schema declares these two as strings; replace the expanded columns
    # with integer versions.
    .withColumn("data_transmission_interval", col("device_details.data_transmission_interval").cast(IntegerType()))
    .withColumn("expected_transmissions", col("device_details.expected_transmissions").cast(IntegerType()))
)

# Final projection of the device feed.
# The two numeric columns are taken from the top-level *cast* columns. Selecting
# `device_details.<field>` again would bring back the original strings and discard
# the casts above.
# NOTE(review): this changes the type of these two columns from string to integer for
# everything downstream (the windowed aggregation's grouping key and the JSON sent to
# the missed_readings hub) — an existing checkpoint for that query may need a reset.
final_device_df = enriched_df.select(
    col("device_id"),
    col("timestamp"),
    col("battery_level"),
    col("firmware_name"),
    col("firmware_version"),
    col("connection_status"),
    col("error_code"),
    col("device_details.owner_name").alias("owner_name"),
    col("device_details.device_model").alias("device_model"),
    col("data_transmission_interval"),
    col("expected_transmissions"),
    col("device_details.manufacturer_name").alias("manufacturer_name"),
)



In [0]:
# Combine glucose readings with device metadata for the missed-readings logic.
# The reading time is renamed first so it is not confused with the device feed's own
# `timestamp` column; later cells refer to it as `final_df_timestamp`.
final_df = final_df.withColumnRenamed("timestamp", "final_df_timestamp")

# NOTE(review): the join key is device_id only, with no time bound, so a reading pairs
# with every device-feed row for the same device — confirm this does not inflate the
# per-window count computed downstream.
combined_df = final_df.join(final_device_df, on="device_id", how="inner")

# Keep only the columns the transmission-count aggregation needs.
combined_feeds_df = combined_df.select(
    "device_id",
    "device_model",
    "user_id",
    "patient_name",
    "glucose_reading",
    "data_transmission_interval",
    "expected_transmissions",
    "final_df_timestamp",
)


In [0]:
"""
EVERY DEVICE HAS A RATE AT WHICH IT IS EXPECTED TO SEND UPDATES ON GLUCOSE LEVEL. FOR EXAMPLE, A data_transmission_interval OF 2 MEANS THAT EVERY 2 MINUTES
AN UPDATE ABOUT THE USER'S GLUCOSE LEVEL IS TRANSMITTED. THE COLUMN expected_transmissions WAS CALCULATED FOR EVERY DEVICE FOR A 15-MINUTE WINDOW BEFORE BEING SENT TO REDIS.

BELOW, THE CODE BLOCK CREATES A 15-MINUTE WINDOW IN WHICH INCOMING GLUCOSE READINGS ARE COUNTED AND, AT THE END, COMPARED TO expected_transmissions.
THIS CAN HELP INVESTIGATE WHETHER THE DEVICES ARE WELL CALIBRATED AND SEND UPDATES ABOUT THE USERS AT THE EXPECTED RATE. IT CAN ALSO PROMPT FURTHER INQUIRIES INTO THE QUALITY OF THE DEVICE, DISCONNECTIONS, OR WHETHER SOME USERS MIGHT BE NEGLECTING TO WEAR THE DEVICE.
"""


# Count rows per patient/device in 15-minute event-time windows and keep only the
# windows whose count falls short of the device's expected_transmissions.
reading_window = window(col("final_df_timestamp"), "15 minutes")
grouping_columns = ["user_id", "patient_name", "device_id", "device_model", "expected_transmissions"]

transmission_counts_df = (
    combined_feeds_df
    .withWatermark("final_df_timestamp", "1 minute")
    .groupBy(reading_window, *[col(column_name) for column_name in grouping_columns])
    .agg(count("*").alias("count_transmissions"))
)

# A group only exists when at least one row fell into the window, so a window with
# no rows at all for a device produces no output row here.
under_reporting_df = transmission_counts_df.filter(
    col("count_transmissions") < col("expected_transmissions")
)

# Flatten the window struct into explicit start/end columns.
result_df = under_reporting_df.select(
    col("window.start").alias("window_start"),
    col("window.end").alias("window_end"),
    col("user_id"),
    col("patient_name"),
    col("device_id"),
    col("device_model"),
    col("expected_transmissions"),
    col("count_transmissions"),
)

# Event Hubs message: user_id as the partition key, the whole row as a JSON body.
missed_readings_df = result_df.select(
    col("user_id").cast("string").alias("partitionKey"),
    to_json(struct(*[col(column_name) for column_name in result_df.columns])).alias("body"),
)


In [0]:
# Readings above the patient's max_glucose threshold.
# `final_df` is expected to already carry `final_df_timestamp` (renamed in the join cell).
high_glucose_df = (
    final_df
    .filter(col("reading_over_max_glucose") == 1)
    .select(
        "user_id",
        "device_id",
        "patient_name",
        "final_df_timestamp",
        "glucose_reading",
        "max_glucose",
    )
)

# Event Hubs message: user_id as the partition key, the selected columns as a JSON body.
high_glucose_json_df = high_glucose_df.select(
    col("user_id").cast("string").alias("partitionKey"),
    to_json(struct(*high_glucose_df.columns)).alias("body"),
)



In [0]:
# Readings below the patient's min_glucose threshold.
# The flag column is named `reading_over_min_glucose`, but it is set to 1 when the
# reading is *under* min_glucose.
# `final_df` is expected to already carry `final_df_timestamp` (renamed in the join cell).
low_glucose_df = (
    final_df
    .filter(col("reading_over_min_glucose") == 1)
    .select(
        "user_id",
        "device_id",
        "patient_name",
        "final_df_timestamp",
        "glucose_reading",
        "min_glucose",
    )
)

# Event Hubs message: user_id as the partition key, the selected columns as a JSON body.
low_glucose_json_df = low_glucose_df.select(
    col("user_id").cast("string").alias("partitionKey"),
    to_json(struct(*low_glucose_df.columns)).alias("body"),
)


In [0]:
# spark.conf.set("spark.sql.streaming.statefulOperator.checkCorrectness.enabled", True)

In [0]:
# Device feed rows whose connection_status is 'Disconnected'.
lost_connection_df = (
    final_device_df
    .filter(col("connection_status") == "Disconnected")
    .select(
        "device_id",
        "timestamp",
        "battery_level",
        "connection_status",
        "error_code",
        "owner_name",
        "manufacturer_name",
    )
)

# Event Hubs message: device_id as the partition key, the selected columns as a JSON body.
lost_connection_json_df = lost_connection_df.select(
    col("device_id").cast("string").alias("partitionKey"),
    to_json(struct(*lost_connection_df.columns)).alias("body"),
)



In [0]:
# Checkpoint directories, one per Event Hubs sink query.
# Each query stores its streaming checkpoint here for fault tolerance, so data is not
# lost during the streaming process.
checkpoint_root = "/FileStore/tables/streaming-data"

checkpoint_high_glucose = f"{checkpoint_root}/high-glucose-checkpoint-location"
checkpoint_low_glucose = f"{checkpoint_root}/low-glucose-checkpoint-location"
checkpoint_lost_connection = f"{checkpoint_root}/lost-connection-checkpoint-location"
checkpoint_transmission_quality = f"{checkpoint_root}/transmission-quality-checkpoint-location"

# Order matters only for the order of the log lines printed by the helpers below.
checkpoint_list = [
    checkpoint_high_glucose,
    checkpoint_low_glucose,
    checkpoint_lost_connection,
    checkpoint_transmission_quality,
]

def create_checkpoint_location(checkpoint_list):
    """
    Ensure each checkpoint directory exists, creating the ones that are missing.

    For every path a one-line status is printed: already exists, created, or an
    error message for any failure other than "not found".

    Args:
        checkpoint_list: Iterable of directory paths to check/create via `dbutils.fs`.
    """
    not_found_marker = "java.io.FileNotFoundException"

    for location in checkpoint_list:
        try:
            # Listing the path is the existence probe; it raises when the path is absent.
            dbutils.fs.ls(location)
        except Exception as err:
            message = str(err)
            if not_found_marker not in message:
                # Anything other than "not found" is reported and the path is skipped.
                print(f"Error processing {location}: {message}")
                continue
            dbutils.fs.mkdirs(location)
            print(f"Checkpoint location created at {location}")
        else:
            print(f"Checkpoint location already exists at {location}")

# Use when the streaming queries should be restarted from scratch.
def delete_checkpoint_locations(checkpoint_list):
    """
    Recursively delete each checkpoint directory that exists.

    For every path a one-line status is printed: deleted, does not exist, or an
    error message for any failure other than "not found".

    Args:
        checkpoint_list: Iterable of directory paths to remove via `dbutils.fs`.
    """
    not_found_marker = "java.io.FileNotFoundException"

    for location in checkpoint_list:
        try:
            # Listing first makes a missing directory surface as an exception.
            dbutils.fs.ls(location)
            dbutils.fs.rm(location, recurse=True)
        except Exception as err:
            message = str(err)
            if not_found_marker in message:
                print(f"Checkpoint location does not exist at {location}")
            else:
                print(f"Error processing {location}: {message}")
        else:
            print(f"Checkpoint location deleted at {location}")

Checkpoint location deleted at /FileStore/tables/streaming-data/high-glucose-checkpoint-location
Checkpoint location deleted at /FileStore/tables/streaming-data/low-glucose-checkpoint-location
Checkpoint location deleted at /FileStore/tables/streaming-data/lost-connection-checkpoint-location
Checkpoint location deleted at /FileStore/tables/streaming-data/transmission-quality-checkpoint-location


In [0]:
# Make sure every checkpoint directory exists before any streaming query is started.
create_checkpoint_location(checkpoint_list=checkpoint_list)

Checkpoint location created at /FileStore/tables/streaming-data/high-glucose-checkpoint-location
Checkpoint location created at /FileStore/tables/streaming-data/low-glucose-checkpoint-location
Checkpoint location created at /FileStore/tables/streaming-data/lost-connection-checkpoint-location
Checkpoint location created at /FileStore/tables/streaming-data/transmission-quality-checkpoint-location


In [0]:
# Show what currently sits under the /FileStore/tables/streaming-data directory.
streaming_data_entries = dbutils.fs.ls("/FileStore/tables/streaming-data/")
display(streaming_data_entries)


path,name,size,modificationTime
dbfs:/FileStore/tables/streaming-data/high-glucose-checkpoint-location/,high-glucose-checkpoint-location/,0,1720369904000
dbfs:/FileStore/tables/streaming-data/lost-connection-checkpoint-location/,lost-connection-checkpoint-location/,0,1720369904000
dbfs:/FileStore/tables/streaming-data/low-glucose-checkpoint-location/,low-glucose-checkpoint-location/,0,1720369904000
dbfs:/FileStore/tables/streaming-data/transmission-quality-checkpoint-location/,transmission-quality-checkpoint-location/,0,1720369905000


In [0]:
# Start the streaming queries: one in-memory table for inspection, four Event Hubs sinks.

def _start_eventhub_sink(source_df, eventhub_conf, checkpoint_path):
    """Start an append-mode Event Hubs writer for `source_df` on a 5-second trigger.

    Args:
        source_df: Streaming DataFrame holding `partitionKey` and `body` columns.
        eventhub_conf: Option dict for the target Event Hub.
        checkpoint_path: Directory used as the query's checkpointLocation.

    Returns:
        The handle returned by `.start()`.
    """
    return (
        source_df.writeStream
        # Append is the default output mode; it is spelled out here for every sink.
        .outputMode("append")
        .format("eventhubs")
        .options(**eventhub_conf)
        .option("checkpointLocation", checkpoint_path)
        .trigger(processingTime='5 seconds')
        .start()
    )


# In-memory sink, queryable with SQL as table `final_table`.
query = (
    final_df.writeStream
    .outputMode("append")
    .format("memory")
    .queryName("final_table")
    .start()
)

query_high_glucose = _start_eventhub_sink(
    high_glucose_json_df, ehConfHighGlucose, checkpoint_high_glucose
)

query_low_glucose = _start_eventhub_sink(
    low_glucose_json_df, ehConfigLowGlucose, checkpoint_low_glucose
)

query_lost_connection = _start_eventhub_sink(
    lost_connection_json_df, ehConfLostConn, checkpoint_lost_connection
)

query_transmission_quality = _start_eventhub_sink(
    missed_readings_df, ehConfMS, checkpoint_transmission_quality
)
    

In [0]:
# import time

# for _ in range(50):  # Number of iterations
#     print(query.status)
#     spark.sql("SELECT * FROM final_table ORDER BY final_df_timestamp DESC LIMIT 20").display()
#     time.sleep(5)

In [0]:
# Uncomment below to stop the queries

# query.stop()
# query_high_glucose.stop()
# query_lost_connection.stop()
# query_low_glucose.stop()
# query_transmission_quality.stop() 