The goal of this project is to generate five key recommendations for optimizing contractor-based nursing staff allocations for nursing homes. These recommendations will be based primarily on the PBJ Daily Nurse Staffing dataset and supported by additional datasets that provide context on penalties, provider information, citations, quality measures, and more. 


| **Term Name**        | **Definition**                                                                 |
|----------------------|-------------------------------------------------------------------------------|
| PROVNUM              | Medicare provider number                                                      |
| PROVNAME             | Provider name                                                                 |
| CITY                 | Provider City                                                                 |
| STATE                | Postal abbreviation for State                                                 |
| COUNTY_NAME          | Name of Provider County, unique within state                                  |
| COUNTY_FIPS          | FIPS Code for Provider County, unique within state                            |
| CY_Qtr               | Calendar Quarter (yyyyQq, e.g. 2018Q4)                                        |
| WorkDate             | Day for Reported Hours (yyyymmdd)                                              |
| MDScensus            | Resident Census from MDS                                                      |
| Hrs_RNDON            | Total Hours for RN Director of Nursing                                         |
| Hrs_RNDON_emp        | Employee Hours for RN Director of Nursing                                      |
| Hrs_RNDON_ctr        | Contract Hours for RN Director of Nursing                                      |
| Hrs_RNadmin          | Hours for RN with administrative duties                                       |
| Hrs_RNadmin_emp      | Employee Hours for RN with administrative duties                              |
| Hrs_RNadmin_ctr      | Contract Hours for RN with administrative duties                              |
| Hrs_RN               | Total Hours for RN                                                            |
| Hrs_RN_emp           | Employee Hours for RN                                                         |
| Hrs_RN_ctr           | Contract Hours for RN                                                         |
| Hrs_LPNadmin         | Total Hours for LPN w/ admin duties                                            |
| Hrs_LPNadmin_emp     | Employee Hours for LPN w/ admin duties                                         |
| Hrs_LPNadmin_ctr     | Contract Hours for LPN w/ admin duties                                         |
| Hrs_LPN              | Total Hours for LPN                                                           |
| Hrs_LPN_emp          | Employee Hours for LPN                                                        |
| Hrs_LPN_ctr          | Contract Hours for LPN                                                        |
| Hrs_CNA              | Total Hours for CNA                                                           |
| Hrs_CNA_emp          | Employee Hours for CNA                                                        |
| Hrs_CNA_ctr          | Contract Hours for CNA                                                        |
| Hrs_NAtrain          | Total Hours for Nurse aide in training                                         |
| Hrs_NAtrain_emp      | Employee Hours for Nurse aide in training                                      |
| Hrs_NAtrain_ctr      | Contract Hours for Nurse aide in training                                      |
| Hrs_MedAide          | Total Hours for Med Aide/Technician                                            |
| Hrs_MedAide_emp      | Employee Hours for Med Aide/Technician                                         |
| Hrs_MedAide_ctr      | Contract Hours for Med Aide/Technician 

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import pandas as pd
from pyspark.sql import functions as F

In [0]:
spark = SparkSession.builder.appName("NurseStaffing").getOrCreate()

file_path = "/FileStore/tables/PBJ_Daily_Nurse_Staffing_Q1_2024.csv"  # update this with the actual file path
nurse_staffing_df = spark.read.csv(file_path, header=True, inferSchema=True, encoding="UTF-8" )

nurse_staffing_df.printSchema()


root
 |-- PROVNUM: string (nullable = true)
 |-- PROVNAME: string (nullable = true)
 |-- CITY: string (nullable = true)
 |-- STATE: string (nullable = true)
 |-- COUNTY_NAME: string (nullable = true)
 |-- COUNTY_FIPS: integer (nullable = true)
 |-- CY_Qtr: string (nullable = true)
 |-- WorkDate: integer (nullable = true)
 |-- MDScensus: integer (nullable = true)
 |-- Hrs_RNDON: double (nullable = true)
 |-- Hrs_RNDON_emp: double (nullable = true)
 |-- Hrs_RNDON_ctr: double (nullable = true)
 |-- Hrs_RNadmin: double (nullable = true)
 |-- Hrs_RNadmin_emp: double (nullable = true)
 |-- Hrs_RNadmin_ctr: double (nullable = true)
 |-- Hrs_RN: double (nullable = true)
 |-- Hrs_RN_emp: double (nullable = true)
 |-- Hrs_RN_ctr: double (nullable = true)
 |-- Hrs_LPNadmin: double (nullable = true)
 |-- Hrs_LPNadmin_emp: double (nullable = true)
 |-- Hrs_LPNadmin_ctr: double (nullable = true)
 |-- Hrs_LPN: double (nullable = true)
 |-- Hrs_LPN_emp: double (nullable = true)
 |-- Hrs_LPN_ctr: doubl

In [0]:
top_10_df = nurse_staffing_df.limit(10)

top_10_pandas_df = top_10_df.toPandas()

#display(top_10_pandas_df)

In [0]:
# Summary statistics to understand the distribution of numerical columns
summary_df = nurse_staffing_df.describe()

summary_pandas_df = summary_df.toPandas()

pd.set_option('display.float_format', '{:.2f}'.format)  # Rounds floats to 2 decimal places

#display(summary_pandas_df)

In [0]:
missing_values = nurse_staffing_df.select([col(column).isNull().alias(column) for column in nurse_staffing_df.columns]).groupBy().sum()
missing_values.show()

++
||
++
||
++



In [0]:
## Get basic column statistics
summary_df = nurse_staffing_df.describe()

summary_pandas_df = summary_df.toPandas()

pd.set_option('display.float_format', '{:.2f}'.format) 

columns_to_display = ['summary'] + list(summary_pandas_df.columns[9:])

#display(summary_pandas_df[columns_to_display])

In [0]:
aggregated_pandas_df = None
aggregated_df = None

To better analyze the data and make reccomendations, I will be implementing a number of metrics by aggregating the data by different
categories, like State, City, Provider Number, and more. Some of these metrics include:

<strong style="font-size:26px;">1. Contractor Utilization Rate:</strong> an index that measures a nursing home's reliance on contractors  
<p style="font-size:20px;">
An index that measures a nursing home's reliance on contractors in relation to other workers at the facility<br>
<strong>-Contractor Utilization Rate</strong> = <sup>Total Contractor Hours (RN, LPN, CNA, NAtrn, MedAide)</sup> / <sub>Total Hours (RN, LPN, CNA, NAtrn, MedAide)</sub>
</p>


Where: <br>
- <strong>Total Contractor Hours</strong> refers to the total number of hours worked by contractors across all staff types (RN, LPN, CNA, NAtrn, MedAide).<br>
- <strong>Total Hours</strong> includes the total hours worked by both employees and contractors across the same staff types.<br>

<strong style="font-size:26px;">2. Hours per Resident per Day, Contractor and Employee performance:</strong> <br>
    <p style="font-size:20px;">
An index that measures how much time contractors and employees spend on each resident on average. An effective metric for comparing efficiency in care. <br>
<strong>-Hours per Resident per Day(Contractors)</strong> = <sup>Total Contractor Hours(RN, LPN, CNA, NAtrn, MedAide)</sup> / <sub>Resident Census(MDScensus)</sub><br>
<strong>-Hours per Resident per Day(Employees)</strong> = <sup>Total Employee Hours(RN, LPN, CNA, NAtrn, MedAide)</sup> / <sub>Resident Census(MDScensus)</sub><br>
</p>


Where: <br>
- <strong>Total Contractor Hours</strong> refers to the total number of hours worked by contractors across all staff types (RN, LPN, CNA, NAtrn, MedAide).<br>
- <strong>Total Employee Hours</strong> refers to the total number of hours worked by contractors across all staff types (RN, LPN, CNA, NAtrn, MedAide).<br>
- <strong>Resident Census (MDScensus)</strong> refers to the number of residents present or receiving care in a nursing home facility on a given day.<br>




In [0]:
nurse_staffing_df = nurse_staffing_df.withColumn(
    "Total_Nursing_Hours", 
    F.col("Hrs_RN") + F.col("Hrs_LPN") + F.col("Hrs_CNA") + F.col("Hrs_NAtrn") + F.col("Hrs_MedAide")
)
nurse_staffing_df = nurse_staffing_df.withColumn(
    "Total_contracting_Hours", 
    F.col("Hrs_RN_ctr") + F.col("Hrs_LPN_ctr") + F.col("Hrs_CNA_ctr") + F.col("Hrs_NAtrn_ctr") + F.col("Hrs_MedAide_ctr")
)
nurse_staffing_df = nurse_staffing_df.withColumn(
    "Total_employee_Hours", 
    F.col("Hrs_RN_emp") + F.col("Hrs_LPN_emp") + F.col("Hrs_CNA_emp") + F.col("Hrs_NAtrn_emp") + F.col("Hrs_MedAide_emp")
)
nurse_staffing_df = nurse_staffing_df.withColumn(
    'contractor_reliance_ratio', 
    F.col('Total_contracting_Hours') / F.col("Total_Nursing_Hours")
)

nurse_staffing_df = nurse_staffing_df.fillna({'contractor_reliance_ratio': 0})



In [0]:
# percentile based bins
quantiles = nurse_staffing_df.approxQuantile("contractor_reliance_ratio", [0.33, 0.66], 0.01)
low_quantile = quantiles[0]
high_quantile = quantiles[1]

df_agg_home = nurse_staffing_df.groupBy("PROVNUM").agg(
    F.sum("Total_contracting_Hours").alias("total_contractor_hours"),
    F.sum("Total_employee_Hours").alias("total_employee_hours"),
    F.sum("MDScensus").alias("total_resident_census"),
    F.sum("Total_Nursing_Hours").alias("Total_nursing_Hours"),
    F.countDistinct("WorkDate").alias("total_days") 
)
df_agg_home = df_agg_home.withColumn(
    'contractor_reliance_ratio', 
    F.col('total_contractor_hours') / F.col("Total_nursing_Hours")
)
df_agg_home = df_agg_home.withColumn(
    'Reliance on Contracting',
    F.when(F.col('contractor_reliance_ratio') <= low_quantile, 'Low Reliance')
    .when((F.col('contractor_reliance_ratio') > low_quantile) & (F.col('contractor_reliance_ratio') <= high_quantile), 'Medium Reliance')
    .otherwise('High Reliance')
)

category_counts = df_agg_home.groupBy("PROVNUM", "Reliance on Contracting").agg(F.count("*").alias("count"))


#category_counts.display()

In [0]:
df_agg_home.display()

In [0]:
import plotly.express as px
quantiles = nurse_staffing_df.approxQuantile("contractor_reliance_ratio", [0.33, 0.66], 0.01)
low_quantile = quantiles[0]
high_quantile = quantiles[1]

state_view = nurse_staffing_df.groupBy("STATE").agg(
    F.sum("Total_contracting_Hours").alias("total_contractor_hours"),
    F.sum("Total_employee_Hours").alias("total_employee_hours"),
    F.sum("MDScensus").alias("total_resident_census"),
    F.sum("Total_Nursing_Hours").alias("Total_nursing_Hours"),
    F.countDistinct("WorkDate").alias("total_days") 
)
state_view = state_view.withColumn(
    'contractor_reliance_ratio', 
    F.col('total_contractor_hours') / F.col("Total_nursing_Hours")
)
quantiles = state_view.approxQuantile("contractor_reliance_ratio", [0.33, 0.66], 0.01)
low_quantile = quantiles[0]
high_quantile = quantiles[1]

state_view = state_view.withColumn(
    'Reliance on Contracting',
    F.when(F.col('contractor_reliance_ratio') <= low_quantile, 'Low Reliance')
    .when((F.col('contractor_reliance_ratio') > low_quantile) & (F.col('contractor_reliance_ratio') <= high_quantile), 'Medium Reliance')
    .otherwise('High Reliance')
)

state_vals = state_view.select("STATE").collect()
crr_vals = state_view.select("contractor_reliance_ratio").collect()
data = {
    'State' : [row["STATE"] for row in state_vals],
    'Contractor Reliance' : [row["contractor_reliance_ratio"] for row in crr_vals]
}

df = pd.DataFrame(data)


fig = px.choropleth(df, 
                    locations='State', 
                    locationmode="USA-states", 
                    color='Contractor Reliance', 
                    color_continuous_scale="Viridis", 
                    scope="usa",
                    title="Contractor Reliance Ratio by State")

fig.show()
# category_counts = state_view.groupBy("STATE", "reliance_category").agg(F.count("*").alias("count"))
# category_counts.display()

In [0]:
state_view_sorted = state_view.orderBy(F.col("contractor_reliance_ratio").desc()).limit(10)
state_view_sorted.select("STATE", "contractor_reliance_ratio").display()

STATE,contractor_reliance_ratio
VT,0.3226349250671943
ME,0.186067706104239
NH,0.1773764769971579
PA,0.1719216624272976
NJ,0.1633561007562138
ND,0.1413256539804502
DE,0.1403628173861172
OR,0.1321779104794137
NY,0.1311192164619611
AK,0.1296771993661699


Databricks visualization. Run in Databricks to view.

In [0]:
df_agg_home = df_agg_home.withColumn(
    "HPRD", 
    F.col("Total_nursing_Hours") / F.col("total_resident_census")
)

df_agg_home = df_agg_home.withColumn(
    "Risk of Short Staffing",
    F.when(col("HPRD") < 3.5, "High-Risk")  
    .when((col("HPRD") >= 3.5) & (col("HPRD") < 4.0), "Moderate-Risk")  
    .otherwise("Low-Risk")
)
df_agg_home_market = df_agg_home.groupBy("Risk of Short Staffing", "Reliance on Contracting").agg(F.count("*").alias("# of Nursing Homes"))

df_agg_home_market = df_agg_home_market.withColumn(
    "Market Share (%)",
    (F.col("# of Nursing Homes") / df_agg_home.count())*100
)
#print(df_agg_home.count())
df_agg_home_market.select("Risk of Short Staffing", "Reliance on Contracting", "# of Nursing Homes", "Market Share (%)").sort("Risk of Short Staffing", "Reliance on Contracting").display()


Risk of Short Staffing,Reliance on Contracting,# of Nursing Homes,Market Share (%)
High-Risk,High Reliance,3125,21.366060440311774
High-Risk,Low Reliance,3773,25.796526733214822
High-Risk,Medium Reliance,1736,11.869273895801996
Low-Risk,High Reliance,1185,8.102010118966223
Low-Risk,Low Reliance,1101,7.527690414330644
Low-Risk,Medium Reliance,573,3.917680842335567
Moderate-Risk,High Reliance,1201,8.211404348420622
Moderate-Risk,Low Reliance,1311,8.963489675919595
Moderate-Risk,Medium Reliance,621,4.245863530698755


Databricks visualization. Run in Databricks to view.

In [0]:
file_path = "/FileStore/tables/NH_Penalties_Aug2024.csv"  # update this with the actual file path
nursing_home_penalties = spark.read.csv(file_path, header=True, inferSchema=True, encoding="UTF-8" )
nursing_home_penalties = nursing_home_penalties.withColumn("Penalty_Date", F.to_date(F.col("Penalty Date"), "yyyy-MM-dd"))


nursing_home_penalties = nursing_home_penalties.filter(
    (F.col("Penalty_Date") >= F.lit("2024-01-01")) & (F.col("Penalty_Date") <= F.lit("2024-03-31"))
)
nursing_home_penalties.printSchema()
#nursing_home_penalties.display()


root
 |-- CMS Certification Number (CCN): string (nullable = true)
 |-- Provider Name: string (nullable = true)
 |-- Provider Address: string (nullable = true)
 |-- City/Town: string (nullable = true)
 |-- State: string (nullable = true)
 |-- ZIP Code: integer (nullable = true)
 |-- Penalty Date: date (nullable = true)
 |-- Penalty Type: string (nullable = true)
 |-- Fine Amount: integer (nullable = true)
 |-- Payment Denial Start Date: date (nullable = true)
 |-- Payment Denial Length in Days: integer (nullable = true)
 |-- Location: string (nullable = true)
 |-- Processing Date: date (nullable = true)
 |-- Penalty_Date: date (nullable = true)



In [0]:

agg_penalties = nursing_home_penalties.groupBy("CMS Certification Number (CCN)").agg(
    F.count(F.when(F.col("Penalty Type") == "Fine", True)).alias("num_fines"),

    F.count(F.when(F.col("Penalty Type") == "Payment Denial", True)).alias("num_payment_denials"),

    F.sum(F.when(F.col("Penalty Type") == "Fine", F.col("Fine Amount")).otherwise(0)).alias("total_fine_amount"),

    F.sum(F.when(F.col("Penalty Type") == "Payment Denial", 
                 F.col("Payment Denial Length in Days")
                ).otherwise(0)).alias("total_payment_denial_length")
)
agg_penalties_filtered = agg_penalties.select(
    col("CMS Certification Number (CCN)").alias("PROVNUM"),
    "num_fines", 
    "num_payment_denials",
    "total_fine_amount", 
    "total_payment_denial_length"
    )
    
df_agg_home = df_agg_home.join(agg_penalties_filtered, on = "PROVNUM", how = "Left")
df_agg_home.printSchema()


root
 |-- PROVNUM: string (nullable = true)
 |-- total_contractor_hours: double (nullable = true)
 |-- total_employee_hours: double (nullable = true)
 |-- total_resident_census: long (nullable = true)
 |-- Total_nursing_Hours: double (nullable = true)
 |-- total_days: long (nullable = false)
 |-- contractor_reliance_ratio: double (nullable = true)
 |-- Reliance on Contracting: string (nullable = false)
 |-- HPRD: double (nullable = true)
 |-- Risk of Short Staffing: string (nullable = false)
 |-- num_fines: long (nullable = true)
 |-- num_payment_denials: long (nullable = true)
 |-- total_fine_amount: long (nullable = true)
 |-- total_payment_denial_length: long (nullable = true)



In [0]:
df_agg_home = df_agg_home.fillna(0, subset = ["num_fines", "num_payment_denials","total_fine_amount", "total_payment_denial_length", "contractor_reliance_ratio"])
#df_agg_home.display()

In [0]:

file_path = "/FileStore/tables/NH_HealthCitations_Aug2024.csv"  
nursing_home_ssc = spark.read.csv(file_path, header=True, inferSchema=True, encoding="UTF-8" )
nursing_home_ssc = nursing_home_ssc.withColumn("Survey Date", F.to_date(F.col("Survey Date"), "yyyy-MM-dd"))

print(df_agg_home.select("PROVNUM").distinct().count())
nursing_home_ssc = nursing_home_ssc.filter(
    (F.col("Survey Date") >= F.lit("2024-01-01")) & (F.col("Survey Date") <= F.lit("2024-03-31"))
)
nursing_home_ssc = nursing_home_ssc.select(
    col("CMS Certification Number (CCN)").alias("PROVNUM"),
    "Scope Severity Code"
)

from pyspark.sql.window import Window
from pyspark.sql import functions as F

window = Window.partitionBy("PROVNUM").orderBy(F.col("Scope Severity Code").desc())

df_with_rank = nursing_home_ssc.withColumn("rank", F.row_number().over(window))

nursing_home_ssc = df_with_rank.filter(F.col("rank") == 1)

#nursing_home_ssc.select("PROVNUM", "Scope Severity Code").display()


14626


In [0]:

df_agg_home = df_agg_home.join(nursing_home_ssc, on = "PROVNUM", how = "Left")
df_agg_home.printSchema()
#df_agg_home.display()


root
 |-- PROVNUM: string (nullable = true)
 |-- total_contractor_hours: double (nullable = true)
 |-- total_employee_hours: double (nullable = true)
 |-- total_resident_census: long (nullable = true)
 |-- Total_nursing_Hours: double (nullable = true)
 |-- total_days: long (nullable = false)
 |-- contractor_reliance_ratio: double (nullable = false)
 |-- Reliance on Contracting: string (nullable = false)
 |-- HPRD: double (nullable = true)
 |-- Risk of Short Staffing: string (nullable = false)
 |-- num_fines: long (nullable = true)
 |-- num_payment_denials: long (nullable = true)
 |-- total_fine_amount: long (nullable = true)
 |-- total_payment_denial_length: long (nullable = true)
 |-- Scope Severity Code: string (nullable = true)
 |-- rank: integer (nullable = true)



In [0]:
null_count = df_agg_home.filter(F.col("Scope Severity Code").isNull()).count()
total = df_agg_home.count()
df_agg_home.printSchema()
print(f"Number of null values in 'Scope Severity Code: {null_count}, Total Number of Values: {total}")

root
 |-- PROVNUM: string (nullable = true)
 |-- total_contractor_hours: double (nullable = true)
 |-- total_employee_hours: double (nullable = true)
 |-- total_resident_census: long (nullable = true)
 |-- Total_nursing_Hours: double (nullable = true)
 |-- total_days: long (nullable = false)
 |-- contractor_reliance_ratio: double (nullable = false)
 |-- Reliance on Contracting: string (nullable = false)
 |-- HPRD: double (nullable = true)
 |-- Risk of Short Staffing: string (nullable = false)
 |-- num_fines: long (nullable = true)
 |-- num_payment_denials: long (nullable = true)
 |-- total_fine_amount: long (nullable = true)
 |-- total_payment_denial_length: long (nullable = true)
 |-- Scope Severity Code: string (nullable = true)
 |-- rank: integer (nullable = true)

Number of null values in 'Scope Severity Code: 9418, Total Number of Values: 14626


In [0]:
df_agg_home = df_agg_home.withColumn(
    "Severity Points",
    F.when(F.col("Scope Severity Code") == "J", 50)
     .when(F.col("Scope Severity Code") == "K", 100)
     .when(F.col("Scope Severity Code") == "L", 150)
     .when(F.col("Scope Severity Code") == "G", 10)
     .when(F.col("Scope Severity Code") == "H", 20)
     .when(F.col("Scope Severity Code") == "I", 30)
     .when(F.col("Scope Severity Code") == "D", 2)
     .when(F.col("Scope Severity Code") == "E", 4)
     .when(F.col("Scope Severity Code") == "F", 6)
     .when(F.col("Scope Severity Code") == "A", 0)
     .when(F.col("Scope Severity Code") == "B", 0)
     .when(F.col("Scope Severity Code") == "C", 0)
     .otherwise(None)
)
df_sev = df_agg_home.groupBy("Severity Points").count()

df_sev.orderBy("Severity Points").show()

+---------------+-----+
|Severity Points|count|
+---------------+-----+
|           null| 9418|
|              0|   26|
|              2| 1700|
|              4| 1355|
|              6|  863|
|             10|  709|
|             20|   27|
|             50|  367|
|            100|  107|
|            150|   54|
+---------------+-----+



In [0]:
df_agg_home.printSchema()

root
 |-- PROVNUM: string (nullable = true)
 |-- total_contractor_hours: double (nullable = true)
 |-- total_employee_hours: double (nullable = true)
 |-- total_resident_census: long (nullable = true)
 |-- Total_nursing_Hours: double (nullable = true)
 |-- total_days: long (nullable = false)
 |-- contractor_reliance_ratio: double (nullable = false)
 |-- Reliance on Contracting: string (nullable = false)
 |-- HPRD: double (nullable = true)
 |-- Risk of Short Staffing: string (nullable = false)
 |-- num_fines: long (nullable = true)
 |-- num_payment_denials: long (nullable = true)
 |-- total_fine_amount: long (nullable = true)
 |-- total_payment_denial_length: long (nullable = true)
 |-- Scope Severity Code: string (nullable = true)
 |-- rank: integer (nullable = true)
 |-- Severity Points: integer (nullable = true)



In [0]:
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import functions as F
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import StandardScaler
assembler = VectorAssembler(inputCols=[
    "total_contractor_hours", "contractor_reliance_ratio", 
    "HPRD", "num_fines", "total_fine_amount", 
    "num_payment_denials", "total_payment_denial_length"], 
    outputCol="features")


df_features = assembler.transform(df_agg_home)

kmeans = KMeans(featuresCol="features", k=12, seed=1)

scaler = StandardScaler(inputCol="features", outputCol="scaled_features", withMean=True, withStd=True)
df_scaled = scaler.fit(df_features).transform(df_features)

model = kmeans.fit(df_scaled)

df_clustered = model.transform(df_scaled)

df_clustered.select("PROVNUM", "features", "prediction").show(5)


+-------+--------------------+----------+
|PROVNUM|            features|prediction|
+-------+--------------------+----------+
| 015171|(7,[2],[4.2829391...|        10|
| 015371|(7,[2],[3.6745227...|        10|
| 015425|(7,[2],[3.9094496...|        10|
| 035059|(7,[0,1,2],[2535....|        10|
| 055253|(7,[0,1,2],[479.5...|        10|
+-------+--------------------+----------+
only showing top 5 rows



In [0]:

df_non_null = df_clustered.filter(F.col("Scope Severity Code").isNotNull())
df_non_null.select("PROVNUM", "Scope Severity Code", "prediction").show(10)


+-------+-------------------+----------+
|PROVNUM|Scope Severity Code|prediction|
+-------+-------------------+----------+
| 055253|                  G|        10|
| 055074|                  D|        10|
| 055304|                  F|        10|
| 055356|                  G|        10|
| 035166|                  D|        10|
| 045373|                  E|        10|
| 056214|                  E|        10|
| 105119|                  D|        10|
| 055563|                  D|        10|
| 056258|                  D|        10|
+-------+-------------------+----------+
only showing top 10 rows



In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Group by cluster (prediction) and Scope Severity Code to count occurrences
cluster_code_counts = df_non_null.groupBy("prediction", "Scope Severity Code").count()

# Rank the Scope Severity Codes within each cluster by their frequency
window = Window.partitionBy("prediction").orderBy(F.desc("count"))

# Add a ranking column and keep the most common Scope Severity Code per cluster
df_cluster_modes = cluster_code_counts.withColumn("rank", F.row_number().over(window)).filter(F.col("rank") == 1)

# Show the most common Scope Severity Code for each cluster
df_cluster_modes.select("prediction", "Scope Severity Code", "count").show()


+----------+-------------------+-----+
|prediction|Scope Severity Code|count|
+----------+-------------------+-----+
|         0|                  D|  294|
|         1|                  J|    8|
|         2|                  G|   68|
|         3|                  K|    3|
|         4|                  G|  177|
|         5|                  J|   11|
|         6|                  L|    1|
|         7|                  G|   84|
|         8|                  E|    8|
|         9|                  D|   76|
|        10|                  D| 1294|
|        11|                  J|   22|
+----------+-------------------+-----+



In [0]:
# Calculate the total number of rows in each cluster
total_in_cluster = df_non_null.groupBy("prediction").count().withColumnRenamed("count", "total_count")

# Join with the mode data to get the counts of the most common Scope Severity Code in each cluster
cluster_purity = df_cluster_modes.join(total_in_cluster, on="prediction")

# Calculate the purity (percentage of rows in each cluster that have the most common Scope Severity Code)
cluster_purity = cluster_purity.withColumn("purity", F.col("count") / F.col("total_count"))

# Show cluster purity
cluster_purity.select("prediction", "Scope Severity Code", "purity").show()


+----------+-------------------+-------------------+
|prediction|Scope Severity Code|             purity|
+----------+-------------------+-------------------+
|         0|                  D|0.41761363636363635|
|         1|                  J|0.27586206896551724|
|         2|                  G| 0.4533333333333333|
|         3|                  K|                0.5|
|         4|                  G| 0.4338235294117647|
|         5|                  J| 0.2619047619047619|
|         6|                  L|                1.0|
|         7|                  G|                0.4|
|         8|                  E|0.38095238095238093|
|         9|                  D| 0.4175824175824176|
|        10|                  D| 0.3812610489098409|
|        11|                  J|0.36065573770491804|
+----------+-------------------+-------------------+



In [0]:
from pyspark.sql.window import Window
from pyspark.sql import functions as F


mode_value = df_agg_home.filter(F.col("Scope Severity Code").isNotNull()) \
    .groupBy("Scope Severity Code") \
    .count() \
    .orderBy(F.desc("count")) \
    .first()["Scope Severity Code"]
print(f"Mode value of 'Scope Severity Code': {mode_value}")
counts = df_agg_home.groupBy("Scope Severity Code").count().orderBy(F.desc("count"))

df_filled = df_agg_home.fillna({"Scope Severity Code": mode_value})

#df_filled.display()
counts.display()

Mode value of 'Scope Severity Code': D


Scope Severity Code,count
,9418
D,1700
E,1355
F,863
G,709
J,367
K,107
L,54
H,27
B,18
