In [1]:
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

import pandas as pd
import matplotlib.pyplot as plt


In [2]:
# Start Spark session 
spark = SparkSession.builder.appName("GeoHealthAI").getOrCreate()

25/07/08 11:38:26 WARN Utils: Your hostname, Remis-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 10.1.10.180 instead (on interface en0)
25/07/08 11:38:26 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/08 11:38:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Define file paths
atlas_path = "/Users/remihendershott/Desktop/Food Access Research Atlas.csv"
places_path = "/Users/remihendershott/Desktop/PLACES__Census_Tract_Data__GIS_Friendly_Format___2020_release_20250623.csv"

# Read each file & infer schema
atlas_df = spark.read.csv(atlas_path, header=True, inferSchema=True)
places_df = spark.read.csv(places_path, header=True, inferSchema=True)

# Show basic info
atlas_df.show(5)
places_df.show(5)

# Check schema
atlas_df.printSchema()
places_df.printSchema()


25/07/08 11:38:31 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+-----------+-------+--------------+-----+-------+-------+-----------------+--------+--------+-----------------+--------------------+-----------------+------------------+--------+---------------+-----------+------------------+--------+-----------+--------+-------------+---------+----------+----------+------------------+---------+----------+---------+----------+-----------+----------+---------+--------------+----------+---------------+----------+---------------+-------------+------------------+-----------+----------------+-----------+----------------+-----------+----------------+-----------+----------------+----------+---------------+-------------+------------------+----------+---------------+----------+---------------+----------+---------------+------+-----------+-------+------------+-------+------------+----------+---------------+--------+-------------+--------+-------------+--------+-------------+--------+-------------+-------+------------+----------+---------------+-------+---------

In [4]:
# Filter to AZ
atlas_az_df = atlas_df.filter(atlas_df.State == "Arizona")

# Check AZ rows
print("Rows for Arizona:", atlas_az_df.count())

# Preview
atlas_az_df.show(5)

# Filter to AZ
places_az_df = places_df.filter(places_df.StateAbbr == "AZ")

# Check number of AZ rows
print("Rows for Arizona:", places_az_df.count())

# Preview
places_az_df.show(5)


Rows for Arizona: 1520
+-----------+-------+-------------+-----+-------+-------+-----------------+--------+--------+-----------------+--------------------+-----------------+------------------+--------+---------------+-----------+------------------+--------+-----------+--------+-------------+---------+----------+----------+------------------+---------+----------+---------+----------+-----------+----------+---------+--------------+----------+---------------+----------+---------------+-------------+------------------+-----------+----------------+-----------+----------------+-----------+----------------+-----------+----------------+----------+---------------+-------------+------------------+----------+---------------+----------+---------------+----------+---------------+------+-----------+-------+------------+-------+------------+----------+---------------+--------+-------------+--------+-------------+--------+-------------+--------+-------------+-------+------------+----------+-----------

In [5]:
# EDA for atlas 
atlas_az_df.describe().show()

# Distinct counties
atlas_az_df.select("County").distinct().show(20)

25/07/08 11:38:34 WARN DAGScheduler: Broadcasting large task binary with size 1395.6 KiB
                                                                                

+-------+-------------------+-------+-------------+------------------+------------------+------------------+--------------------+-----------------+-----------------+-------------------+--------------------+-------------------+-------------------+------------------+-------------------+-----------------+------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+------------------+------------------+-----------------+------------------+-----------------+-----------------+-----------------+------------------+------------------+-----------------+------------------+-----------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-----------------+------------------+-----------------+-------------------+-----------------+------------------+------------------+------------------+-----------------+---------------

In [6]:
# More atlas EDA 
# How many urban vs rural tracts?
atlas_az_df.groupBy("Urban").count().show()

# Tracts per county
atlas_az_df.groupBy("County").count().orderBy("count", ascending=False).show(20)

# Basic stats
atlas_az_df.describe(["PovertyRate", "MedianFamilyIncome"]).show()

+-----+-----+
|Urban|count|
+-----+-----+
|    1| 1292|
|    0|  228|
+-----+-----+

+-----------------+-----+
|           County|count|
+-----------------+-----+
|  Maricopa County|  913|
|      Pima County|  241|
|     Pinal County|   75|
|      Yuma County|   53|
|    Mohave County|   43|
|   Yavapai County|   42|
|   Cochise County|   32|
|    Navajo County|   31|
|  Coconino County|   28|
|    Apache County|   16|
|      Gila County|   16|
|Santa Cruz County|   10|
|    Graham County|    9|
|    La Paz County|    8|
|  Greenlee County|    3|
+-----------------+-----+

+-------+-----------------+------------------+
|summary|      PovertyRate|MedianFamilyIncome|
+-------+-----------------+------------------+
|  count|             1520|              1520|
|   mean|16.67473684210528| 69963.00597213006|
| stddev|12.63485960943205|32213.432015642065|
|    min|                0|            100179|
|    max|              9.9|              NULL|
+-------+-----------------+-----------------

In [7]:
# How many tracts are low income?
atlas_az_df.groupBy("LowIncomeTracts").count().show()

# Who meets food desert criteria?
for col in ["LILATracts_1And10", "LILATracts_halfAnd10", "LILATracts_1And20", "LILATracts_Vehicle"]:
    atlas_az_df.groupBy(col).count().show()

atlas_az_df.describe(["Pop2010", "OHU2010"]).show()

# Compare pop for flagged tracts
atlas_az_df.filter(atlas_az_df.LILATracts_1And10 == 1).describe(["Pop2010"]).show()

# Check share of pop low-income within 1 mile urban
atlas_az_df.describe(["lapop1share", "lalowi1share"]).show()

# Nulls in key columns
atlas_az_df.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in ["CensusTract", "PovertyRate", "MedianFamilyIncome"]]).show()

# Duplicates on CensusTract
atlas_az_df.groupBy("CensusTract").count().filter("count > 1").show()

+---------------+-----+
|LowIncomeTracts|count|
+---------------+-----+
|              1|  651|
|              0|  869|
+---------------+-----+

+-----------------+-----+
|LILATracts_1And10|count|
+-----------------+-----+
|                1|  257|
|                0| 1263|
+-----------------+-----+

+--------------------+-----+
|LILATracts_halfAnd10|count|
+--------------------+-----+
|                   1|  546|
|                   0|  974|
+--------------------+-----+

+-----------------+-----+
|LILATracts_1And20|count|
+-----------------+-----+
|                1|  226|
|                0| 1294|
+-----------------+-----+

+------------------+-----+
|LILATracts_Vehicle|count|
+------------------+-----+
|                 1|  237|
|                 0| 1283|
+------------------+-----+

+-------+------------------+------------------+
|summary|           Pop2010|           OHU2010|
+-------+------------------+------------------+
|  count|              1520|              1520|
|   mean| 4

In [8]:
# EDA for places 
# Basics 
places_az_df.columns

places_az_df.describe().show()

places_az_df.describe(["DIABETES_CrudePrev", "OBESITY_CrudePrev"]).show()

for col in ["DIABETES_CrudePrev", "OBESITY_CrudePrev"]:
    nulls = places_az_df.filter(F.col(col).isNull()).count()
    print(f"Missing in {col}: {nulls}")

# How many unique tracts?
places_az_df.select("TractFIPS").distinct().count()

# How many health records per tract?
places_az_df.groupBy("TractFIPS").count().orderBy("count", ascending=False).show(10)

# Distinct counties
places_az_df.select("CountyName").distinct().show(20)

# Tracts per county
places_az_df.groupBy("CountyName").count().orderBy("count", ascending=False).show(20)

+-------+---------+---------+----------+------------------+--------------------+------------------+------------------+-----------------+-------------------+-------------------+------------------+---------------+-----------------+----------------+-----------------+---------------+-----------------+----------------+------------------+-----------------+------------------+------------------+-----------------+-------------+-----------------+-----------------+--------------------+--------------------+----------------------+----------------------+-----------------+--------------+------------------+---------------+------------------+---------------+------------------+------------------+------------------+----------------+------------------+------------------+------------------+------------------+------------------+----------------+------------------+-------------+------------------+------------------+------------------+---------------+------------------+-----------------+------------------+---

In [9]:
# List all columns that contain 'CANCER'
places_az_df.select([col for col in places_az_df.columns if "CANCER" in col]).show(5)


+----------------+----------------+
|CANCER_CrudePrev|CANCER_Crude95CI|
+----------------+----------------+
|             6.6|    ( 6.4,  6.8)|
|             6.7|    ( 6.6,  6.9)|
|             6.2|    ( 6.0,  6.4)|
|             5.7|    ( 5.6,  5.8)|
|             5.9|    ( 5.7,  6.1)|
+----------------+----------------+
only showing top 5 rows



In [10]:
# EDA for cancer info
places_az_df.describe(["CANCER_CrudePrev"]).show()

places_az_df.select("TractFIPS", "CountyName", "CANCER_CrudePrev").orderBy(F.desc("CANCER_CrudePrev")).show(10)

+-------+-----------------+
|summary| CANCER_CrudePrev|
+-------+-----------------+
|  count|             1516|
|   mean|6.719459102902372|
| stddev|2.839655066263512|
|    min|              0.9|
|    max|             18.7|
+-------+-----------------+

+----------+----------+----------------+
| TractFIPS|CountyName|CANCER_CrudePrev|
+----------+----------+----------------+
|4013040507|  Maricopa|            18.7|
|4013040513|  Maricopa|            18.4|
|4013040512|  Maricopa|            18.3|
|4013071503|  Maricopa|            18.2|
|4013422618|  Maricopa|            18.2|
|4013040522|  Maricopa|            17.9|
|4013040506|  Maricopa|            17.7|
|4013040514|  Maricopa|            17.6|
|4019004328|      Pima|            17.5|
|4013071504|  Maricopa|            17.2|
+----------+----------+----------------+
only showing top 10 rows



25/07/08 11:38:39 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors


In [11]:
places_az_df.groupBy("CountyName").agg(
    F.mean("CANCER_CrudePrev").alias("Mean_Cancer_Rate"),
    F.stddev("CANCER_CrudePrev").alias("StdDev_Cancer_Rate")
).orderBy(F.desc("Mean_Cancer_Rate")).show(10)

+----------+------------------+------------------+
|CountyName|  Mean_Cancer_Rate|StdDev_Cancer_Rate|
+----------+------------------+------------------+
|    La Paz|10.075000000000001| 3.105180003984126|
|   Yavapai| 9.507142857142856|1.5097072656228971|
|    Mohave| 8.930232558139531|1.3516365536639068|
|      Gila| 8.900000000000002|1.8504053609952607|
|   Cochise|          7.359375| 2.023707771680805|
|      Pima| 7.212863070539423| 2.762992131489376|
|     Pinal|7.2093333333333325|2.8117116712996237|
|    Navajo| 6.874193548387097|1.3616087891254434|
|      Yuma| 6.845283018867924|3.2162188763620603|
|    Apache|           6.61875| 0.798514245333169|
+----------+------------------+------------------+
only showing top 10 rows



In [12]:
# Correlations testing (initial)
places_az_df.select("CANCER_CrudePrev", "CSMOKING_CrudePrev", "OBESITY_CrudePrev").describe().show()

places_az_df.corr("CANCER_CrudePrev", "CSMOKING_CrudePrev")
places_az_df.corr("CANCER_CrudePrev", "OBESITY_CrudePrev")

+-------+-----------------+------------------+------------------+
|summary| CANCER_CrudePrev|CSMOKING_CrudePrev| OBESITY_CrudePrev|
+-------+-----------------+------------------+------------------+
|  count|             1516|              1516|              1516|
|   mean|6.719459102902372|17.002770448548798|30.731596306068653|
| stddev|2.839655066263512| 5.576327166046027| 5.144132261524036|
|    min|              0.9|               6.3|              20.3|
|    max|             18.7|              46.7|              51.7|
+-------+-----------------+------------------+------------------+



-0.3692211474776091

In [13]:
print("Cancer vs Smoking:", places_az_df.stat.corr("CANCER_CrudePrev", "CSMOKING_CrudePrev"))
print("Cancer vs Obesity:", places_az_df.stat.corr("CANCER_CrudePrev", "OBESITY_CrudePrev"))
print("Cancer vs Poverty:", places_az_df.stat.corr("CANCER_CrudePrev", "ACCESS2_CrudePrev"))

places_az_df.select("TractFIPS", "CountyName", "CANCER_CrudePrev") \
    .orderBy(F.desc("CANCER_CrudePrev")) \
    .show(10)


Cancer vs Smoking: -0.42326719878951924
Cancer vs Obesity: -0.3692211474776091
Cancer vs Poverty: -0.5062624798821899
+----------+----------+----------------+
| TractFIPS|CountyName|CANCER_CrudePrev|
+----------+----------+----------------+
|4013040507|  Maricopa|            18.7|
|4013040513|  Maricopa|            18.4|
|4013040512|  Maricopa|            18.3|
|4013071503|  Maricopa|            18.2|
|4013422618|  Maricopa|            18.2|
|4013040522|  Maricopa|            17.9|
|4013040506|  Maricopa|            17.7|
|4013040514|  Maricopa|            17.6|
|4019004328|      Pima|            17.5|
|4013071504|  Maricopa|            17.2|
+----------+----------+----------------+
only showing top 10 rows



In [14]:
# Merge & some cleaning for the atlas and places data frames

from pyspark.sql.functions import lpad, col

# Standardize both df's by converting FIPs to string and ensure values are 11 digits 
places_az_df = places_az_df.withColumn("TractFIPS_str", lpad(col("TractFIPS").cast("string"), 11, "0"))
atlas_az_df = atlas_az_df.withColumn("CensusTract", lpad(col("CensusTract").cast("string"), 11, "0"))

# Rename for merging
places_az_df = places_az_df.withColumnRenamed("TractFIPS_str", "CensusTract")

# Merge
joined_df = places_az_df.join(atlas_az_df, on="CensusTract", how="inner")

joined_df.printSchema()
joined_df.show(5, truncate=False)

root
 |-- CensusTract: string (nullable = true)
 |-- StateAbbr: string (nullable = true)
 |-- StateDesc: string (nullable = true)
 |-- CountyName: string (nullable = true)
 |-- CountyFIPS: integer (nullable = true)
 |-- TractFIPS: long (nullable = true)
 |-- TotalPopulation: integer (nullable = true)
 |-- ACCESS2_CrudePrev: double (nullable = true)
 |-- ACCESS2_Crude95CI: string (nullable = true)
 |-- ARTHRITIS_CrudePrev: double (nullable = true)
 |-- ARTHRITIS_Crude95CI: string (nullable = true)
 |-- BINGE_CrudePrev: double (nullable = true)
 |-- BINGE_Crude95CI: string (nullable = true)
 |-- BPHIGH_CrudePrev: double (nullable = true)
 |-- BPHIGH_Crude95CI: string (nullable = true)
 |-- BPMED_CrudePrev: double (nullable = true)
 |-- BPMED_Crude95CI: string (nullable = true)
 |-- CANCER_CrudePrev: double (nullable = true)
 |-- CANCER_Crude95CI: string (nullable = true)
 |-- CASTHMA_CrudePrev: double (nullable = true)
 |-- CASTHMA_Crude95CI: string (nullable = true)
 |-- CERVICAL_CrudeP

In [15]:
# Basic EDA on joined set
print(f"Total Rows: {joined_df.count()}")
print(f"Total Columns: {len(joined_df.columns)}")

# Missing/NULL values?
key_cols = ["CANCER_CrudePrev", "PovertyRate", "LILATracts_1And10", "TEETHLOST_CrudePrev"]
for col_name in key_cols:
    nulls = joined_df.filter(col(col_name).isNull()).count()
    print(f"Missing values in {col_name}: {nulls}")

# Summary stats
joined_df.select("CANCER_CrudePrev", "PovertyRate", "TEETHLOST_CrudePrev").describe().show()

Total Rows: 1516
Total Columns: 210
Missing values in CANCER_CrudePrev: 0
Missing values in PovertyRate: 0
Missing values in LILATracts_1And10: 0
Missing values in TEETHLOST_CrudePrev: 1
+-------+-----------------+-----------------+-------------------+
|summary| CANCER_CrudePrev|      PovertyRate|TEETHLOST_CrudePrev|
+-------+-----------------+-----------------+-------------------+
|  count|             1516|             1516|               1515|
|   mean|6.719459102902372|16.65277044854883| 14.291683168316835|
| stddev|2.839655066263512|12.44697528075318|  7.581534579742508|
|    min|              0.9|                0|                4.5|
|    max|             18.7|              9.9|               49.3|
+-------+-----------------+-----------------+-------------------+



In [16]:
# Correlation for Teeth Lost and Cancer
joined_df.stat.corr("CANCER_CrudePrev", "TEETHLOST_CrudePrev")


-0.2533784937087673

In [17]:
from pyspark.sql.functions import col, regexp_replace

# Clean income/poverty fields 
clean_df = joined_df.withColumn("PovertyRate", regexp_replace("PovertyRate", ",", "").cast("double"))
clean_df = clean_df.withColumn("MedianFamilyIncome", regexp_replace("MedianFamilyIncome", ",", "").cast("double"))

# Select relevant columns and drop nulls
analysis_df = clean_df.select(
    "OBESITY_CrudePrev",
    "DIABETES_CrudePrev",
    "CANCER_CrudePrev",
    "LILATracts_1And10",
    "LILATracts_halfAnd10",
    "LILATracts_1And20",
    "LILATracts_Vehicle",
    "LowIncomeTracts",
    "LATracts1",
    "LATracts10",
    "LATracts20",
    "PovertyRate",
    "MedianFamilyIncome",
    "CountyName"
).na.drop()

In [18]:
for col_name in [
    "LILATracts_1And10", "LILATracts_halfAnd10", "LILATracts_1And20",
    "LILATracts_Vehicle", "LowIncomeTracts", "LATracts1", "LATracts10",
    "LATracts20", "PovertyRate", "MedianFamilyIncome"
]:
    corr_obesity = analysis_df.stat.corr("OBESITY_CrudePrev", col_name)
    corr_diabetes = analysis_df.stat.corr("DIABETES_CrudePrev", col_name)
    corr_cancer = analysis_df.stat.corr("CANCER_CrudePrev", col_name)
    print(f"Correlation with {col_name}:\n  Obesity: {corr_obesity:.3f} | Diabetes: {corr_diabetes:.3f} | Cancer: {corr_cancer:.3f}")

Correlation with LILATracts_1And10:
  Obesity: 0.331 | Diabetes: 0.435 | Cancer: 0.047
Correlation with LILATracts_halfAnd10:
  Obesity: 0.539 | Diabetes: 0.472 | Cancer: -0.157
Correlation with LILATracts_1And20:
  Obesity: 0.309 | Diabetes: 0.384 | Cancer: 0.031
Correlation with LILATracts_Vehicle:
  Obesity: 0.363 | Diabetes: 0.373 | Cancer: -0.085
Correlation with LowIncomeTracts:
  Obesity: 0.657 | Diabetes: 0.550 | Cancer: -0.199
Correlation with LATracts1:
  Obesity: -0.146 | Diabetes: -0.069 | Cancer: 0.136
Correlation with LATracts10:
  Obesity: 0.169 | Diabetes: 0.339 | Cancer: 0.100
Correlation with LATracts20:
  Obesity: 0.160 | Diabetes: 0.307 | Cancer: 0.052
Correlation with PovertyRate:
  Obesity: 0.737 | Diabetes: 0.551 | Cancer: -0.350
Correlation with MedianFamilyIncome:
  Obesity: -0.681 | Diabetes: -0.584 | Cancer: 0.152


In [None]:
# Group agg to compare some key factors with chronic diseases
binary_flags = [
    "LILATracts_1And10",
    "LILATracts_halfAnd10",
    "LILATracts_1And20",
    "LILATracts_Vehicle",
    "LowIncomeTracts",
    "LATracts1",
    "LATracts10",
    "LATracts20"
]

# Group by each flag, compute avg rates
for flag in binary_flags:
    print(f"\n=== Average Cancer, Obesity, & Diabetes Rates by {flag} ===")
    analysis_df.groupBy(flag).agg(
        {"OBESITY_CrudePrev": "avg", "DIABETES_CrudePrev": "avg", "CANCER_CrudePrev": "avg"}
    ).show()


=== Average Cancer, Obesity, & Diabetes Rates by LILATracts_1And10 ===
+-----------------+-----------------------+----------------------+---------------------+
|LILATracts_1And10|avg(DIABETES_CrudePrev)|avg(OBESITY_CrudePrev)|avg(CANCER_CrudePrev)|
+-----------------+-----------------------+----------------------+---------------------+
|                1|     13.309765625000004|     34.45664062499999|    7.026953124999998|
|                0|      9.287370103916865|     29.93860911270983|    6.673780975219824|
+-----------------+-----------------------+----------------------+---------------------+


=== Average Cancer, Obesity, & Diabetes Rates by LILATracts_halfAnd10 ===
+--------------------+-----------------------+----------------------+---------------------+
|LILATracts_halfAnd10|avg(DIABETES_CrudePrev)|avg(OBESITY_CrudePrev)|avg(CANCER_CrudePrev)|
+--------------------+-----------------------+----------------------+---------------------+
|                   1|     12.151376146789

In [20]:
# Create new col: FoodDesertScore as sum of binary flags above
food_desert_flags = [
    "LILATracts_1And10", "LILATracts_halfAnd10", "LILATracts_1And20",
    "LILATracts_Vehicle", "LowIncomeTracts", "LATracts1", "LATracts10", "LATracts20"
]

analysis_df = analysis_df.withColumn("FoodDesertScore", sum([col(f) for f in food_desert_flags]))

# Group by FoodDesertScore & compute avg's
analysis_df.groupBy("FoodDesertScore").agg(
    {"OBESITY_CrudePrev": "avg", "DIABETES_CrudePrev": "avg", "CANCER_CrudePrev": "avg"}
).orderBy("FoodDesertScore").show()

+---------------+-----------------------+----------------------+---------------------+
|FoodDesertScore|avg(DIABETES_CrudePrev)|avg(OBESITY_CrudePrev)|avg(CANCER_CrudePrev)|
+---------------+-----------------------+----------------------+---------------------+
|              0|      8.128699551569502|    27.756502242152447|    6.982286995515691|
|              1|      9.040606060606065|    29.112323232323224|    7.152121212121213|
|              2|     11.573873873873877|     34.48648648648647|   5.6545045045045015|
|              3|     10.779545454545453|     34.17272727272727|     4.99090909090909|
|              4|     13.622727272727273|                 33.35|    7.800000000000001|
|              5|     12.548598130841114|    34.026168224299056|     7.01121495327103|
|              6|     12.568478260869565|     34.25652173913044|    6.676086956521738|
|              7|     17.388571428571428|    36.994285714285716|   7.5114285714285725|
+---------------+-----------------------+--

In [21]:
# Regression with Spark MLlib set up 
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml import Pipeline

In [22]:
# Regression on Obesity 
df_obesity = VectorAssembler(
    inputCols=["FoodDesertScore", "PovertyRate", "MedianFamilyIncome"],
    outputCol="features"
).transform(analysis_df).withColumnRenamed("OBESITY_CrudePrev", "label")

lr_obesity = LinearRegression(featuresCol="features", labelCol="label")
model_obesity = lr_obesity.fit(df_obesity)

print("=== OBESITY MODEL ===")
print(f"Coefficients: {model_obesity.coefficients}")
print(f"Intercept: {model_obesity.intercept}")
print(f"R^2: {model_obesity.summary.r2}")
print(f"RMSE: {model_obesity.summary.rootMeanSquaredError}")

25/07/08 11:38:50 WARN Instrumentation: [91369542] regParam is zero, which might cause numerical instability and overfitting.
25/07/08 11:38:51 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
25/07/08 11:38:51 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
25/07/08 11:38:51 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


=== OBESITY MODEL ===
Coefficients: [0.16018736945196185,0.19992849065188942,-4.90089508252292e-05]
Intercept: 30.514911386366105
R^2: 0.5969764965032665
RMSE: 3.253972197928731


In [23]:
# Regression on Diabetes
df_diabetes = VectorAssembler(
    inputCols=["FoodDesertScore", "PovertyRate", "MedianFamilyIncome"],
    outputCol="features"
).transform(analysis_df).withColumnRenamed("DIABETES_CrudePrev", "label")

lr_diabetes = LinearRegression(featuresCol="features", labelCol="label")
model_diabetes = lr_diabetes.fit(df_diabetes)

print("=== DIABETES MODEL ===")
print(f"Coefficients: {model_diabetes.coefficients}")
print(f"Intercept: {model_diabetes.intercept}")
print(f"R^2: {model_diabetes.summary.r2}")
print(f"RMSE: {model_diabetes.summary.rootMeanSquaredError}")

25/07/08 11:38:51 WARN Instrumentation: [bf118b83] regParam is zero, which might cause numerical instability and overfitting.


=== DIABETES MODEL ===
Coefficients: [0.491940751267148,0.046223395636333035,-3.5440337785561044e-05]
Intercept: 10.821455372814905
R^2: 0.4295764762284737
RMSE: 2.624508566572328


In [24]:
# Regression on Cancer
df_cancer = VectorAssembler(
    inputCols=["FoodDesertScore", "PovertyRate", "MedianFamilyIncome"],
    outputCol="features"
).transform(analysis_df).withColumnRenamed("CANCER_CrudePrev", "label")

lr_cancer = LinearRegression(featuresCol="features", labelCol="label")
model_cancer = lr_cancer.fit(df_cancer)

print("=== CANCER MODEL ===")
print(f"Coefficients: {model_cancer.coefficients}")
print(f"Intercept: {model_cancer.intercept}")
print(f"R^2: {model_cancer.summary.r2}")
print(f"RMSE: {model_cancer.summary.rootMeanSquaredError}")

25/07/08 11:38:52 WARN Instrumentation: [89f3642c] regParam is zero, which might cause numerical instability and overfitting.


=== CANCER MODEL ===
Coefficients: [0.29842071625028505,-0.12905686295319685,-1.2707837383349185e-05]
Intercept: 9.260230890976013
R^2: 0.16810544114767267
RMSE: 2.5883948232145615


In [None]:
# Rank by FoodDesertScore col
ranked_df = analysis_df.orderBy(col("FoodDesertScore").desc())

# Add in location col for ranking
ranked_df.select(
    "CountyName",   
    "FoodDesertScore",
    "OBESITY_CrudePrev",
    "DIABETES_CrudePrev",
    "CANCER_CrudePrev"
).show(10, truncate=False)

+----------+---------------+-----------------+------------------+----------------+
|CountyName|FoodDesertScore|OBESITY_CrudePrev|DIABETES_CrudePrev|CANCER_CrudePrev|
+----------+---------------+-----------------+------------------+----------------+
|Coconino  |7              |32.6             |15.2              |6.4             |
|Maricopa  |7              |33.9             |10.7              |6.2             |
|Coconino  |7              |35.0             |16.7              |6.3             |
|Apache    |7              |39.2             |20.4              |6.6             |
|Coconino  |7              |35.8             |17.9              |6.5             |
|Apache    |7              |36.9             |18.4              |6.7             |
|Coconino  |7              |34.3             |17.6              |7.3             |
|Apache    |7              |37.2             |18.6              |6.7             |
|Gila      |7              |31.9             |13.0              |9.8             |
|Apa

In [None]:
# Z-score setup 
from pyspark.sql.functions import mean, stddev
from pyspark.sql.functions import when

# Means & stdev
agg_stats = ranked_df.select(
    mean("OBESITY_CrudePrev"), stddev("OBESITY_CrudePrev"),
    mean("DIABETES_CrudePrev"), stddev("DIABETES_CrudePrev"),
    mean("CANCER_CrudePrev"), stddev("CANCER_CrudePrev"),
    mean("FoodDesertScore"), stddev("FoodDesertScore")
).collect()[0]

# Unpack found stats
mean_obesity, std_obesity = agg_stats[0], agg_stats[1]
mean_diabetes, std_diabetes = agg_stats[2], agg_stats[3]
mean_cancer, std_cancer = agg_stats[4], agg_stats[5]
mean_fds, std_fds = agg_stats[6], agg_stats[7]

# Z-score cols 
zscore_df = ranked_df.withColumn("Z_OBESITY", (col("OBESITY_CrudePrev") - mean_obesity) / std_obesity) \
                     .withColumn("Z_DIABETES", (col("DIABETES_CrudePrev") - mean_diabetes) / std_diabetes) \
                     .withColumn("Z_CANCER", (col("CANCER_CrudePrev") - mean_cancer) / std_cancer) \
                     .withColumn("Z_FoodDesertScore", (col("FoodDesertScore") - mean_fds) / std_fds)

# Flag tracts with higher burden value 
flagged_df = zscore_df.withColumn("Flag_High_Obesity", when(col("Z_OBESITY") > 2, 1).otherwise(0)) \
                      .withColumn("Flag_High_Diabetes", when(col("Z_DIABETES") > 2, 1).otherwise(0)) \
                      .withColumn("Flag_High_Cancer", when(col("Z_CANCER") > 2, 1).otherwise(0)) \
                      .withColumn("Flag_High_FoodDesert", when(col("Z_FoodDesertScore") > 2, 1).otherwise(0))

In [27]:
# Examine flagged data 
flagged_df.filter(
    (col("Flag_High_Obesity") == 1) |
    (col("Flag_High_Diabetes") == 1) |
    (col("Flag_High_Cancer") == 1) |
    (col("Flag_High_FoodDesert") == 1)
).select(
    "CountyName", "FoodDesertScore",
    "OBESITY_CrudePrev", "DIABETES_CrudePrev", "CANCER_CrudePrev",
    "Flag_High_Obesity", "Flag_High_Diabetes", "Flag_High_Cancer", "Flag_High_FoodDesert"
).show(truncate=False)

+----------+---------------+-----------------+------------------+----------------+-----------------+------------------+----------------+--------------------+
|CountyName|FoodDesertScore|OBESITY_CrudePrev|DIABETES_CrudePrev|CANCER_CrudePrev|Flag_High_Obesity|Flag_High_Diabetes|Flag_High_Cancer|Flag_High_FoodDesert|
+----------+---------------+-----------------+------------------+----------------+-----------------+------------------+----------------+--------------------+
|Coconino  |7              |34.3             |17.6              |7.3             |0                |1                 |0               |1                   |
|Apache    |7              |36.9             |18.4              |6.7             |0                |1                 |0               |1                   |
|Apache    |7              |38.0             |17.2              |5.7             |0                |1                 |0               |1                   |
|Apache    |7              |37.3             |17.8  