In [1]:
from pyspark.sql import SparkSession
import os

spark = (
    SparkSession.builder
    .appName("S3AvroAnalytics")
    .config("spark.jars", "/drivers/postgresql-42.5.0.jar")
    .getOrCreate()
)

In [2]:
from pyspark.sql import functions as F, Window

In [3]:
# Configure S3
hadoop_conf = spark._jsc.hadoopConfiguration()
hadoop_conf.set("fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID"))
hadoop_conf.set("fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY"))
hadoop_conf.set("fs.s3a.endpoint", "s3.amazonaws.com")
hadoop_conf.set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")

In [4]:
# Read Avro files from S3
input_path = "s3a://mlops-dbz-sink/topics/dbserver1.stroke_predictions.predictions/"
df = spark.read.format("avro").load(input_path)

In [5]:
df.printSchema()
df.show(1, truncate=False)

root
 |-- before: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- timestamp: long (nullable = true)
 |    |-- gender: string (nullable = true)
 |    |-- age: double (nullable = true)
 |    |-- hypertension: integer (nullable = true)
 |    |-- heart_disease: integer (nullable = true)
 |    |-- avg_glucose_level: double (nullable = true)
 |    |-- bmi: double (nullable = true)
 |    |-- smoking_status: string (nullable = true)
 |    |-- name: string (nullable = true)
 |    |-- country: string (nullable = true)
 |    |-- province: string (nullable = true)
 |    |-- probability: double (nullable = true)
 |    |-- risk_category: string (nullable = true)
 |    |-- contributing_factors: string (nullable = true)
 |    |-- prediction_data: string (nullable = true)
 |-- after: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- timestamp: long (nullable = true)
 |    |-- gender: string (nullable = true)
 |    |-- age: double (nullable = true)
 | 

In [6]:
df.columns

['before', 'after', 'source', 'op', 'ts_ms', 'transaction', 'partition']

In [7]:
# 1️ Analytic: Count events by operation type
op_counts = df.groupBy("op").count()
op_counts.show()

+---+-----+
| op|count|
+---+-----+
|  r| 4608|
+---+-----+



In [8]:
# 2️ Analytic: Get final state of each prediction (latest by ts_ms)
# Flatten the 'after' struct for easier use
after_df = df.filter(df.after.isNotNull()).select(
    df.after.id.alias("id"),
    df.after.age.alias("age"),
    df.after.bmi.alias("bmi"),
    df.after.gender.alias("gender"),
    df.after.country.alias("country"),
    df.after.province.alias("province"),
    df.after.hypertension.alias("hypertension"),
    df.after.heart_disease.alias("heart_disease"),
    df.after.smoking_status.alias("smoking_status"),
    df.after.avg_glucose_level.alias("avg_glucose_level"),
    df.after.probability.alias("probability"),
    df.after.risk_category.alias("risk_category"),
    df.op,
    df.ts_ms
)

In [9]:
# Window function to get latest ts_ms per id
window = Window.partitionBy("id").orderBy(F.col("ts_ms").desc())

In [10]:
latest_stroke_predictions = (
    after_df.withColumn("rn", F.row_number().over(window))
    .filter(F.col("rn") == 1)
    .drop("rn")
)

In [11]:
latest_stroke_predictions.show(5, truncate=False)

+---+----+------------------+------+--------------+------------+------------+-------------+---------------+------------------+-------------------+-------------+---+-------------+
|id |age |bmi               |gender|country       |province    |hypertension|heart_disease|smoking_status |avg_glucose_level |probability        |risk_category|op |ts_ms        |
+---+----+------------------+------+--------------+------------+------------+-------------+---------------+------------------+-------------------+-------------+---+-------------+
|1  |45.5|26.799999237060547|Male  |United States |California  |0           |0            |never smoked   |95.19999694824219 |0.27587899565696716|Low          |r  |1758131396734|
|2  |67.0|32.099998474121094|Female|Canada        |Ontario     |1           |1            |formerly smoked|145.8000030517578 |0.772845983505249  |High         |r  |1758131396804|
|3  |35.0|22.299999237060547|Other |United Kingdom|London      |0           |0            |never smoked  

In [12]:
output_counts = "s3a://mlops-dbz-sink/analytics/stroke_predictions_op_counts/"
output_latest = "s3a://mlops-dbz-sink/analytics/stroke_predictions_latest/"

In [13]:
op_counts.write.mode("overwrite").parquet(output_counts)
latest_stroke_predictions.write.mode("overwrite").parquet(output_latest)

In [14]:
print(f"Operation counts written to {output_counts}")
print(f"Latest customers snapshot written to {output_latest}")

Operation counts written to s3a://mlops-dbz-sink/analytics/stroke_predictions_op_counts/
Latest customers snapshot written to s3a://mlops-dbz-sink/analytics/stroke_predictions_latest/


In [15]:
latest_stroke_predictions.createOrReplaceTempView("stroke_predictions")

In [16]:
overall_distribution_by_risk_category = spark.sql("""
SELECT risk_category, 
       ROUND(100.0 * COUNT(*) / SUM(COUNT(*)) OVER(), 2) AS pct
FROM stroke_predictions
GROUP BY risk_category;
""")

In [17]:
overall_distribution_by_risk_category.show(5, truncate=False)

+-------------+-----+
|risk_category|pct  |
+-------------+-----+
|High         |27.91|
|Low          |38.72|
|Medium       |33.38|
+-------------+-----+



In [18]:
stroke_probability_bins = spark.sql("""
SELECT 
    CASE 
        WHEN probability < 0.2 THEN '0.0 - 0.2'
        WHEN probability < 0.4 THEN '0.2 - 0.4'
        WHEN probability < 0.6 THEN '0.4 - 0.6'
        WHEN probability < 0.8 THEN '0.6 - 0.8'
        ELSE '0.8 - 1.0'
    END AS prob_bucket,
    COUNT(*) AS total
FROM stroke_predictions
GROUP BY CASE 
            WHEN probability < 0.2 THEN '0.0 - 0.2'
            WHEN probability < 0.4 THEN '0.2 - 0.4'
            WHEN probability < 0.6 THEN '0.4 - 0.6'
            WHEN probability < 0.8 THEN '0.6 - 0.8'
            ELSE '0.8 - 1.0'
         END;
""")

In [19]:
stroke_probability_bins.show(5, truncate=False)

+-----------+-----+
|prob_bucket|total|
+-----------+-----+
|0.4 - 0.6  |651  |
|0.2 - 0.4  |985  |
|0.8 - 1.0  |723  |
|0.6 - 0.8  |1059 |
|0.0 - 0.2  |1190 |
+-----------+-----+



In [20]:
risk_by_gender = spark.sql("""
SELECT 
    gender,
    SUM(CASE WHEN risk_category = 'High' THEN 1 ELSE 0 END) AS high_risk,
    SUM(CASE WHEN risk_category = 'Medium' THEN 1 ELSE 0 END) AS medium_risk,
    SUM(CASE WHEN risk_category = 'Low' THEN 1 ELSE 0 END) AS low_risk
FROM stroke_predictions
WHERE gender!='Other'
GROUP BY gender
ORDER BY gender;
""")

In [21]:
risk_by_gender.show(5, truncate=False)

+------+---------+-----------+--------+
|gender|high_risk|medium_risk|low_risk|
+------+---------+-----------+--------+
|Female|633      |784        |883     |
|Male  |653      |754        |900     |
+------+---------+-----------+--------+



In [22]:
risk_by_age_group = spark.sql("""
SELECT 
    CASE 
        WHEN age < 30 THEN 'Under 30'
        WHEN age BETWEEN 30 AND 45 THEN '30-45'
        WHEN age BETWEEN 46 AND 60 THEN '46-60'
        ELSE '60+'
    END AS age_group,
    risk_category,
    COUNT(*) AS total
FROM stroke_predictions
GROUP BY 
    CASE 
        WHEN age < 30 THEN 'Under 30'
        WHEN age BETWEEN 30 AND 45 THEN '30-45'
        WHEN age BETWEEN 46 AND 60 THEN '46-60'
        ELSE '60+'
    END,
    risk_category;

""")

In [23]:
risk_by_age_group.show(5, truncate=False)

+---------+-------------+-----+
|age_group|risk_category|total|
+---------+-------------+-----+
|60+      |High         |1285 |
|46-60    |Low          |44   |
|46-60    |Medium       |951  |
|Under 30 |Low          |836  |
|30-45    |Low          |903  |
+---------+-------------+-----+
only showing top 5 rows



In [24]:
province_hotspots = spark.sql("""
SELECT 
    province,
    SUM(CASE WHEN risk_category = 'High' THEN 1 ELSE 0 END) AS high_risk,
    SUM(CASE WHEN risk_category = 'Medium' THEN 1 ELSE 0 END) AS medium_risk,
    SUM(CASE WHEN risk_category = 'Low' THEN 1 ELSE 0 END) AS low_risk
FROM stroke_predictions
WHERE country = 'South Africa'
GROUP BY province
ORDER BY province;
""")

In [25]:
province_hotspots.show(5, truncate=False)

+-------------+---------+-----------+--------+
|province     |high_risk|medium_risk|low_risk|
+-------------+---------+-----------+--------+
|Eastern Cape |147      |170        |222     |
|Free State   |142      |195        |180     |
|Gauteng      |144      |166        |203     |
|KwaZulu-Natal|141      |159        |172     |
|Limpopo      |126      |154        |212     |
+-------------+---------+-----------+--------+
only showing top 5 rows



In [26]:
hypertension_heart_disease_correlation=spark.sql("""
SELECT 
    hypertension,
    heart_disease,
    SUM(CASE WHEN risk_category = 'Low' THEN 1 ELSE 0 END)    AS low_risk,
    CAST(100.0 * SUM(CASE WHEN risk_category = 'Low' THEN 1 ELSE 0 END) 
         / COUNT(*) AS DECIMAL(5,2)) AS low_pct,
    
    SUM(CASE WHEN risk_category = 'Medium' THEN 1 ELSE 0 END) AS medium_risk,
    CAST(100.0 * SUM(CASE WHEN risk_category = 'Medium' THEN 1 ELSE 0 END) 
         / COUNT(*) AS DECIMAL(5,2)) AS medium_pct,
    
    SUM(CASE WHEN risk_category = 'High' THEN 1 ELSE 0 END)   AS high_risk,
    CAST(100.0 * SUM(CASE WHEN risk_category = 'High' THEN 1 ELSE 0 END) 
         / COUNT(*) AS DECIMAL(5,2)) AS high_pct,
    
    COUNT(*) AS total
FROM stroke_predictions
GROUP BY hypertension, heart_disease
ORDER BY hypertension, heart_disease;
""")

In [27]:
hypertension_heart_disease_correlation.show(5, truncate=False)

+------------+-------------+--------+-------+-----------+----------+---------+--------+-----+
|hypertension|heart_disease|low_risk|low_pct|medium_risk|medium_pct|high_risk|high_pct|total|
+------------+-------------+--------+-------+-----------+----------+---------+--------+-----+
|0           |0            |531     |48.14  |348        |31.55     |224      |20.31   |1103 |
|0           |1            |435     |37.40  |388        |33.36     |340      |29.23   |1163 |
|1           |0            |438     |37.47  |409        |34.99     |322      |27.54   |1169 |
|1           |1            |380     |32.40  |393        |33.50     |400      |34.10   |1173 |
+------------+-------------+--------+-------+-----------+----------+---------+--------+-----+



In [28]:
bmi_vs_risk=spark.sql("""
SELECT 
    CASE 
        WHEN bmi < 18.5 THEN 'Underweight'
        WHEN bmi < 25 THEN 'Healthy'
        WHEN bmi < 30 THEN 'Overweight'
        ELSE 'Obese'
    END AS bmi_category,
    risk_category,
    COUNT(*) AS total,
    CAST(
        100.0 * COUNT(*) 
        / SUM(COUNT(*)) OVER (
            PARTITION BY 
                CASE 
                    WHEN bmi < 18.5 THEN 'Underweight'
                    WHEN bmi < 25 THEN 'Healthy'
                    WHEN bmi < 30 THEN 'Overweight'
                    ELSE 'Obese'
                END
        ) 
        AS DECIMAL(5,2)
    ) AS pct
FROM stroke_predictions
GROUP BY 
    CASE 
        WHEN bmi < 18.5 THEN 'Underweight'
        WHEN bmi < 25 THEN 'Healthy'
        WHEN bmi < 30 THEN 'Overweight'
        ELSE 'Obese'
    END,
    risk_category
ORDER BY bmi_category, risk_category;

""")

In [29]:
bmi_vs_risk.show(5 , truncate=False)

+------------+-------------+-----+-----+
|bmi_category|risk_category|total|pct  |
+------------+-------------+-----+-----+
|Healthy     |High         |402  |26.19|
|Healthy     |Low          |635  |41.37|
|Healthy     |Medium       |498  |32.44|
|Obese       |High         |438  |30.44|
|Obese       |Low          |500  |34.75|
+------------+-------------+-----+-----+
only showing top 5 rows



In [30]:
glucose_level_risk_bands = spark.sql("""
SELECT 
    CASE 
        WHEN avg_glucose_level < 100 THEN 'Normal'
        WHEN avg_glucose_level < 126 THEN 'Prediabetic'
        ELSE 'Diabetic'
    END AS glucose_category,
    risk_category,
    COUNT(*) AS total,
    CAST(
        100.0 * COUNT(*) 
        / SUM(COUNT(*)) OVER (
            PARTITION BY 
                CASE 
                    WHEN avg_glucose_level < 100 THEN 'Normal'
                    WHEN avg_glucose_level < 126 THEN 'Prediabetic'
                    ELSE 'Diabetic'
                END
        )
        AS DECIMAL(5,2)
    ) AS pct
FROM stroke_predictions
GROUP BY 
    CASE 
        WHEN avg_glucose_level < 100 THEN 'Normal'
        WHEN avg_glucose_level < 126 THEN 'Prediabetic'
        ELSE 'Diabetic'
    END,
    risk_category
ORDER BY glucose_category, risk_category;  
""")

In [31]:
glucose_level_risk_bands.show(5, truncate=False)

+----------------+-------------+-----+-----+
|glucose_category|risk_category|total|pct  |
+----------------+-------------+-----+-----+
|Diabetic        |High         |478  |39.18|
|Diabetic        |Low          |343  |28.11|
|Diabetic        |Medium       |399  |32.70|
|Normal          |High         |201  |17.12|
|Normal          |Low          |601  |51.19|
+----------------+-------------+-----+-----+
only showing top 5 rows



In [32]:
extreme_bmi_or_glucose_with_low_risk_potential_misclassification_flag = spark.sql("""
SELECT id, age, bmi, avg_glucose_level, probability, risk_category
FROM stroke_predictions
WHERE (bmi > 35 OR avg_glucose_level > 150)
  AND risk_category = 'Low';
""")

In [33]:
extreme_bmi_or_glucose_with_low_risk_potential_misclassification_flag.show(5, truncate=False)

+---+----+------------------+------------------+-------------------+-------------+
|id |age |bmi               |avg_glucose_level |probability        |risk_category|
+---+----+------------------+------------------+-------------------+-------------+
|255|20.0|21.799999237060547|159.8000030517578 |0.12815499305725098|Low          |
|260|36.0|24.200000762939453|154.60000610351562|0.2779709994792938 |Low          |
|295|20.0|37.29999923706055 |110.0             |0.07099340111017227|Low          |
|529|19.0|35.5              |134.60000610351562|0.11188499629497528|Low          |
|552|30.0|37.0              |72.9000015258789  |0.10963000357151031|Low          |
+---+----+------------------+------------------+-------------------+-------------+
only showing top 5 rows



In [34]:
top_10_highest_probability_cases=spark.sql("""
SELECT id, age, gender, probability, risk_category
FROM stroke_predictions
ORDER BY probability DESC
LIMIT 10;  
""")

In [35]:
top_10_highest_probability_cases.show(10, truncate=False)

+----+----+------+------------------+-------------+
|id  |age |gender|probability       |risk_category|
+----+----+------+------------------+-------------+
|2567|85.0|Female|0.9199600219726562|High         |
|559 |85.0|Female|0.919743001461029 |High         |
|1181|85.0|Female|0.9187309741973877|High         |
|2788|84.0|Female|0.9177759885787964|High         |
|3826|85.0|Male  |0.9163039922714233|High         |
|4475|84.0|Female|0.9162330031394958|High         |
|3152|85.0|Male  |0.9137960076332092|High         |
|1197|85.0|Male  |0.9132140278816223|High         |
|53  |85.0|Female|0.9131709933280945|High         |
|1572|85.0|Male  |0.9124510288238525|High         |
+----+----+------+------------------+-------------+



In [48]:
# PostgreSQL connection configuration
postgres_host = "18.208.221.9"
postgres_port = "5432"
postgres_db = "analytics_db"
postgres_user = "analytics_user"
postgres_password = "analytics_pass" 

In [49]:
postgres_url = f"jdbc:postgresql://{postgres_host}:{postgres_port}/{postgres_db}"
postgres_properties = {
    "user": postgres_user,
    "password": postgres_password,
    "driver": "org.postgresql.Driver"
}

In [52]:
def write_spark_df_to_postgresql(spark_df, table_name, mode="replace"):
    """
    Write Spark DataFrame to PostgreSQL using psycopg2 and sqlalchemy
    Modes: 'replace', 'append'
    """
    try:
        # Convert Spark DataFrame to Pandas
        pandas_df = spark_df.toPandas()
        
        # Create SQLAlchemy engine
        from sqlalchemy import create_engine
        engine = create_engine(f'postgresql://{postgres_user}:{postgres_password}@{postgres_host}:{postgres_port}/{postgres_db}')
        
        # Write Pandas DataFrame to PostgreSQL
        pandas_df.to_sql(table_name, engine, if_exists=mode, index=False)
        
        print(f"Successfully wrote {pandas_df.shape[0]} rows to {table_name}")
    except Exception as e:
        print(f"Error writing to {table_name}: {str(e)}")
        raise

In [54]:
write_spark_df_to_postgresql(op_counts, "stroke_analytics_op_counts", mode="replace")

Successfully wrote 1 rows to stroke_analytics_op_counts


In [55]:
write_spark_df_to_postgresql(overall_distribution_by_risk_category, "stroke_analytics_risk_distribution", mode="replace")

Successfully wrote 3 rows to stroke_analytics_risk_distribution


In [56]:
write_spark_df_to_postgresql(stroke_probability_bins, "stroke_analytics_probability_bins", mode="replace")

Successfully wrote 5 rows to stroke_analytics_probability_bins


In [57]:
write_spark_df_to_postgresql(risk_by_gender, "stroke_analytics_risk_by_gender", mode="replace")

Successfully wrote 2 rows to stroke_analytics_risk_by_gender


In [58]:
write_spark_df_to_postgresql(risk_by_age_group, "stroke_analytics_risk_by_age_group", mode="replace")

Successfully wrote 9 rows to stroke_analytics_risk_by_age_group


In [59]:
write_spark_df_to_postgresql(province_hotspots, "stroke_analytics_province_hotspots", mode="replace")

Successfully wrote 9 rows to stroke_analytics_province_hotspots


In [60]:
write_spark_df_to_postgresql(hypertension_heart_disease_correlation, "stroke_analytics_hypertension_heart_correlation", mode="replace")

Successfully wrote 4 rows to stroke_analytics_hypertension_heart_correlation


In [61]:
write_spark_df_to_postgresql(bmi_vs_risk, "stroke_analytics_bmi_vs_risk", mode="replace")

Successfully wrote 12 rows to stroke_analytics_bmi_vs_risk


In [62]:
write_spark_df_to_postgresql(glucose_level_risk_bands, "stroke_analytics_glucose_risk_bands", mode="replace")

Successfully wrote 9 rows to stroke_analytics_glucose_risk_bands
