In [1]:
import os
import sys

os.environ["SPARK_HOME"] = "/home/talentum/spark" 
os.environ["PYLIB"] = os.environ["SPARK_HOME"] + "/python/lib"
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.6" 
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3" 
sys.path.insert(0, os.environ["PYLIB"] + "/py4j-0.10.7-src.zip") 
sys.path.insert(0, os.environ["PYLIB"] + "/pyspark.zip")


In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit, round, rand, sum as spark_sum

In [6]:
# spark = SparkSession.builder \
#     .appName("HeartDiseasePreprocessingPySpark") \
#     .config("spark.driver.memory", "8g") \
#     .config("spark.jars.packages", "mysql:mysql-connector-java:8.0.28") \
#     .getOrCreate()

# print("Spark Session initialized with 8g driver memory and MySQL connector.")

spark = SparkSession.builder \
    .appName("HeartDiseasePreprocessingPySpark") \
    .config("spark.jars.packages", "mysql:mysql-connector-java:8.0.28") \
    .getOrCreate()
  


In [10]:
hdfs_raw_data_path = "Dataset"
hdfs_output_csv_path = "Cleaned_dataset"


In [11]:
#     Applies common cleaning and transformation steps to the heart disease dataset.
#     Assumes all relevant columns have already been renamed to their standardized forms.
#     This function focuses on recoding, filtering invalid values, dropping nulls, and duplicates.

In [12]:
# Preprocessing Function
def preprocess_heart_data(df):
    
    final_target_col = "HeartDiseaseorAttack"

    # --- Recode the target variable (HeartDiseaseorAttack) ---
    if final_target_col in df.columns:
        df = df.withColumn(final_target_col, \
                            when(col(final_target_col) == 2, lit(0)) \
                            .otherwise(col(final_target_col)))
        df = df.filter(~col(final_target_col).isin([7, 9])) # Filter out invalid codes (e.g., Don't know/Refused)
    else:
        print(f"Warning: '{final_target_col}' not found in DataFrame. Target processing might be incomplete.")

    # --- Recode other columns based on BRFSS interpretation and filter invalid codes ---
    if "HighBP" in df.columns:
        df = df.withColumn('HighBP', when(col('HighBP') == 1, 1).when(col('HighBP') == 2, 0).otherwise(col('HighBP'))).filter(col('HighBP') != 9)
    if "HighChol" in df.columns:
        df = df.withColumn('HighChol', when(col('HighChol') == 1, 1).when(col('HighChol') == 2, 0).otherwise(col('HighChol'))).filter(~col('HighChol').isin([7, 9]))
    if "CholCheck" in df.columns:
        df = df.withColumn('CholCheck', when(col('CholCheck') == 1, 1).when(col('CholCheck').isin([2, 3]), 0).otherwise(col('CholCheck'))).filter(col('CholCheck') != 9)

    if "BMI" in df.columns:
        df = df.withColumn('BMI', round(col('BMI') / 100, 2))

    if "Smoker" in df.columns:
        df = df.withColumn('Smoker', when(col('Smoker') == 1, 1).when(col('Smoker') == 2, 0).otherwise(col('Smoker'))).filter(~col('Smoker').isin([7, 9]))
    if "Stroke" in df.columns:
        df = df.withColumn('Stroke', when(col('Stroke') == 1, 1).when(col('Stroke') == 2, 0).otherwise(col('Stroke'))).filter(~col('Stroke').isin([7, 9]))
    if "Diabetes" in df.columns:
        df = df.withColumn('Diabetes',
                            when(col('Diabetes') == 1, 2)  # Yes
                            .when(col('Diabetes') == 2, 0)  # No
                            .when(col('Diabetes') == 3, 2)  # Pre-diabetes (grouped with Yes)
                            .when(col('Diabetes') == 4, 1)  # Gestational
                            .otherwise(col('Diabetes'))).filter(~col('Diabetes').isin([7, 9]))
    if "PhysActivity" in df.columns:
        df = df.withColumn('PhysActivity', when(col('PhysActivity') == 1, 1).when(col('PhysActivity') == 2, 0).otherwise(col('PhysActivity'))).filter(col('PhysActivity') != 9)
    if "Fruits" in df.columns:
        df = df.withColumn('Fruits', when(col('Fruits') == 1, 1).when(col('Fruits') == 2, 0).otherwise(col('Fruits'))).filter(col('Fruits') != 9)
    if "Veggies" in df.columns:
        df = df.withColumn('Veggies', when(col('Veggies') == 1, 1).when(col('Veggies') == 2, 0).otherwise(col('Veggies'))).filter(col('Veggies') != 9)
    if "HvyAlcoholConsump" in df.columns:
        df = df.withColumn('HvyAlcoholConsump', when(col('HvyAlcoholConsump') == 1, 1).when(col('HvyAlcoholConsump') == 2, 0).otherwise(col('HvyAlcoholConsump'))).filter(col('HvyAlcoholConsump') != 9)
    if "AnyHealthcare" in df.columns:
        df = df.withColumn('AnyHealthcare', when(col('AnyHealthcare') == 1, 1).when(col('AnyHealthcare') == 2, 0).otherwise(col('AnyHealthcare'))).filter(~col('AnyHealthcare').isin([7, 9]))
    if "NoDocbcCost" in df.columns:
        df = df.withColumn('NoDocbcCost', when(col('NoDocbcCost') == 1, 1).when(col('NoDocbcCost') == 2, 0).otherwise(col('NoDocbcCost'))).filter(~col('NoDocbcCost').isin([7, 9]))

    if "GenHlth" in df.columns:
        df = df.filter(~col('GenHlth').isin([7, 9]))
    if "MentHlth" in df.columns:
        df = df.withColumn('MentHlth', when(col('MentHlth') == 88, 0).otherwise(col('MentHlth'))).filter(~col('MentHlth').isin([77, 99]))
    if "PhysHlth" in df.columns:
        df = df.withColumn('PhysHlth', when(col('PhysHlth') == 88, 0).otherwise(col('PhysHlth'))).filter(~col('PhysHlth').isin([77, 99]))
    if "DiffWalk" in df.columns:
        df = df.withColumn('DiffWalk', when(col('DiffWalk') == 1, 1).when(col('DiffWalk') == 2, 0).otherwise(col('DiffWalk'))).filter(~col('DiffWalk').isin([7, 9]))
    if "Sex" in df.columns:
        df = df.withColumn('Sex', when(col('Sex') == 1, 1).when(col('Sex') == 2, 0).otherwise(col('Sex'))) # Recode Female to 0, Male to 1
    if "Age" in df.columns:
        df = df.filter(col('Age') != 14)
    if "Education" in df.columns:
        df = df.filter(col('Education') != 9)
    if "Income" in df.columns:
        df = df.filter(~col('Income').isin([77, 99]))

    # Define the final set of common columns expected after preprocessing
    final_common_columns = [
        final_target_col, "HighBP", "HighChol", "CholCheck", "BMI",
        "Smoker", "Stroke", "Diabetes", "PhysActivity", "Fruits", "Veggies",
        "HvyAlcoholConsump", "AnyHealthcare", "NoDocbcCost", "GenHlth",
        "MentHlth", "PhysHlth", "DiffWalk", "Sex", "Age", "Education", "Income"
    ]

    # Select only the desired columns that are present in the DataFrame
    selected_cols = [c for c in final_common_columns if c in df.columns]
    df = df.select(*selected_cols)

    # Drop rows with any remaining null values
    df = df.na.drop()

    # Remove duplicates
    df = df.dropDuplicates()

    return df


In [13]:
print(f"Reading 2013 data from HDFS: {hdfs_raw_data_path}/2013.csv")
df_2013_raw = spark.read.csv(f"{hdfs_raw_data_path}/2013.csv", header=True, inferSchema=True)


Reading 2013 data from HDFS: Dataset/2013.csv


In [14]:
# Define ALL columns to select and rename for 2013
selected_cols_2013_initial = [
    'CVDCRHD4', '_RFHYPE5', 'TOLDHI2', '_CHOLCHK', '_BMI5', 'SMOKE100', 'CVDSTRK3', 'DIABETE3',
    '_TOTINDA', '_FRTLT1', '_VEGLT1', '_RFDRHV4', 'HLTHPLN1', 'MEDCOST', 'GENHLTH',
    'MENTHLTH', 'PHYSHLTH', 'DIFFWALK', 'SEX', '_AGEG5YR', 'EDUCA', 'INCOME2'
]

In [15]:

df_2013_renamed = df_2013_raw.select(*selected_cols_2013_initial) \
    .withColumnRenamed('CVDCRHD4', 'HeartDiseaseorAttack') \
    .withColumnRenamed('_RFHYPE5', 'HighBP') \
    .withColumnRenamed('TOLDHI2', 'HighChol') \
    .withColumnRenamed('_CHOLCHK', 'CholCheck') \
    .withColumnRenamed('_BMI5', 'BMI') \
    .withColumnRenamed('SMOKE100', 'Smoker') \
    .withColumnRenamed('CVDSTRK3', 'Stroke') \
    .withColumnRenamed('DIABETE3', 'Diabetes') \
    .withColumnRenamed('_TOTINDA', 'PhysActivity') \
    .withColumnRenamed('_FRTLT1', 'Fruits') \
    .withColumnRenamed('_VEGLT1', 'Veggies') \
    .withColumnRenamed('_RFDRHV4', 'HvyAlcoholConsump') \
    .withColumnRenamed('HLTHPLN1', 'AnyHealthcare') \
    .withColumnRenamed('MEDCOST', 'NoDocbcCost') \
    .withColumnRenamed('GENHLTH', 'GenHlth') \
    .withColumnRenamed('MENTHLTH', 'MentHlth') \
    .withColumnRenamed('PHYSHLTH', 'PhysHlth') \
    .withColumnRenamed('DIFFWALK', 'DiffWalk') \
    .withColumnRenamed('SEX', 'Sex') \
    .withColumnRenamed('_AGEG5YR', 'Age') \
    .withColumnRenamed('EDUCA', 'Education') \
    .withColumnRenamed('INCOME2', 'Income')


In [16]:
df_2013_processed = preprocess_heart_data(df_2013_renamed)
print(f"2013 Processed Row Count: {df_2013_processed.count()}")


2013 Processed Row Count: 292441


In [17]:
print(f"Reading 2015 data from HDFS: {hdfs_raw_data_path}/2015.csv")
df_2015_raw = spark.read.csv(f"{hdfs_raw_data_path}/2015.csv", header=True, inferSchema=True)


Reading 2015 data from HDFS: Dataset/2015.csv


In [18]:
# Define ALL columns to select and rename for 2015
selected_cols_2015_initial = [
    '_MICHD', '_RFHYPE5', 'TOLDHI2', '_CHOLCHK', '_BMI5', 'SMOKE100', 'CVDSTRK3', 'DIABETE3',
    '_TOTINDA', '_FRTLT1', '_VEGLT1', '_RFDRHV5', 'HLTHPLN1', 'MEDCOST', 'GENHLTH',
    'MENTHLTH', 'PHYSHLTH', 'DIFFWALK', 'SEX', '_AGEG5YR', 'EDUCA', 'INCOME2'
]


In [19]:
df_2015_renamed = df_2015_raw.select(*selected_cols_2015_initial) \
    .withColumnRenamed('_MICHD', 'HeartDiseaseorAttack') \
    .withColumnRenamed('_RFHYPE5', 'HighBP') \
    .withColumnRenamed('TOLDHI2', 'HighChol') \
    .withColumnRenamed('_CHOLCHK', 'CholCheck') \
    .withColumnRenamed('_BMI5', 'BMI') \
    .withColumnRenamed('SMOKE100', 'Smoker') \
    .withColumnRenamed('CVDSTRK3', 'Stroke') \
    .withColumnRenamed('DIABETE3', 'Diabetes') \
    .withColumnRenamed('_TOTINDA', 'PhysActivity') \
    .withColumnRenamed('_FRTLT1', 'Fruits') \
    .withColumnRenamed('_VEGLT1', 'Veggies') \
    .withColumnRenamed('_RFDRHV5', 'HvyAlcoholConsump') \
    .withColumnRenamed('HLTHPLN1', 'AnyHealthcare') \
    .withColumnRenamed('MEDCOST', 'NoDocbcCost') \
    .withColumnRenamed('GENHLTH', 'GenHlth') \
    .withColumnRenamed('MENTHLTH', 'MentHlth') \
    .withColumnRenamed('PHYSHLTH', 'PhysHlth') \
    .withColumnRenamed('DIFFWALK', 'DiffWalk') \
    .withColumnRenamed('SEX', 'Sex') \
    .withColumnRenamed('_AGEG5YR', 'Age') \
    .withColumnRenamed('EDUCA', 'Education') \
    .withColumnRenamed('INCOME2', 'Income')



In [20]:
df_2015_processed = preprocess_heart_data(df_2015_renamed)
print(f"2015 Processed Row Count: {df_2015_processed.count()}")



2015 Processed Row Count: 249137


In [21]:
print("Uniting processed dataframes...")
merged_df = df_2013_processed.unionByName(df_2015_processed)
print(f"Cleaned DataFrames merged. Total rows: {merged_df.count()}")
print("Schema of merged DataFrame:")
merged_df.printSchema()


Uniting processed dataframes...
Cleaned DataFrames merged. Total rows: 541578
Schema of merged DataFrame:
root
 |-- HeartDiseaseorAttack: double (nullable = true)
 |-- HighBP: double (nullable = true)
 |-- HighChol: double (nullable = true)
 |-- CholCheck: double (nullable = true)
 |-- BMI: double (nullable = true)
 |-- Smoker: double (nullable = true)
 |-- Stroke: double (nullable = true)
 |-- Diabetes: double (nullable = true)
 |-- PhysActivity: double (nullable = true)
 |-- Fruits: double (nullable = true)
 |-- Veggies: double (nullable = true)
 |-- HvyAlcoholConsump: double (nullable = true)
 |-- AnyHealthcare: double (nullable = true)
 |-- NoDocbcCost: double (nullable = true)
 |-- GenHlth: double (nullable = true)
 |-- MentHlth: double (nullable = true)
 |-- PhysHlth: double (nullable = true)
 |-- DiffWalk: double (nullable = true)
 |-- Sex: double (nullable = true)
 |-- Age: double (nullable = true)
 |-- Education: double (nullable = true)
 |-- Income: double (nullable = true)



In [23]:
print("Performing class balancing...")
counts_spark = merged_df.groupBy('HeartDiseaseorAttack').count().collect()
majority_count = 0
minority_count = 0
for row in counts_spark:
    if row['HeartDiseaseorAttack'] == 0:
        majority_count = row['count']
    elif row['HeartDiseaseorAttack'] == 1:
        minority_count = row['count']

print(f"Original class counts: HeartDiseaseorAttack=0: {majority_count}, HeartDiseaseorAttack=1: {minority_count}")

df_majority = merged_df.filter(col('HeartDiseaseorAttack') == 0)
df_minority = merged_df.filter(col('HeartDiseaseorAttack') == 1)

sampling_fraction = minority_count / majority_count if majority_count > 0 else 0

df_majority_downsampled = df_majority.sample(False, sampling_fraction, seed=42)

balanced_df = df_minority.union(df_majority_downsampled)
balanced_df = balanced_df.orderBy(rand(seed=42))

print("Class sizes after balancing:")
balanced_df.groupBy('HeartDiseaseorAttack').count().show()
print(f"Balanced DataFrame row count: {balanced_df.count()}")




Performing class balancing...
Original class counts: HeartDiseaseorAttack=0: 498159, HeartDiseaseorAttack=1: 43419
Class sizes after balancing:
+--------------------+-----+
|HeartDiseaseorAttack|count|
+--------------------+-----+
|                 0.0|43462|
|                 1.0|43419|
+--------------------+-----+

Balanced DataFrame row count: 86881


In [24]:
# Dump Cleaned Dataset to MySQL (retained from previous versions, as requested by user's initial problem context)
print("Attempting to dump cleaned dataset to MySQL...")
jdbcHostname = "127.0.0.1"
jdbcPort = 3306
jdbcDatabase = "test"
jdbcUsername = "bigdata"
jdbcPassword = "Bigdata@123"
jdbcTableName = "heart_disease_data_balanced"

jdbcUrl = f"jdbc:mysql://{jdbcHostname}:{jdbcPort}/{jdbcDatabase}?useSSL=false&allowPublicKeyRetrieval=true"
jdbcDriver = "com.mysql.cj.jdbc.Driver"

connectionProperties = {
  "user" : jdbcUsername,
  "password" : jdbcPassword,
  "driver" : jdbcDriver
}

try:
    print(f"Writing data to MySQL table: {jdbcTableName} at {jdbcUrl.split('?')[0]}...")
    balanced_df.write \
        .format("jdbc") \
        .option("url", jdbcUrl) \
        .option("dbtable", jdbcTableName) \
        .mode("overwrite") \
        .options(**connectionProperties) \
        .save()
    print(f"Successfully dumped cleaned data to MySQL table: {jdbcTableName}")
except Exception as e:
    print(f"Error writing to MySQL: {e}")


Attempting to dump cleaned dataset to MySQL...
Writing data to MySQL table: heart_disease_data_balanced at jdbc:mysql://127.0.0.1:3306/test...
Successfully dumped cleaned data to MySQL table: heart_disease_data_balanced


In [25]:
# IMPORTANT: Save the balanced_df to HDFS as a single CSV file for the shell script to copy
print(f"Saving balanced data to HDFS as CSV at: {hdfs_output_csv_path}")
# Coalesce to 1 partition to get a single CSV file, useful for 'hdfs dfs -get'
balanced_df.coalesce(1).write.csv(hdfs_output_csv_path, header=True, mode="overwrite")
print("Balanced data saved to HDFS as CSV.")


Saving balanced data to HDFS as CSV at: Cleaned_dataset
Balanced data saved to HDFS as CSV.


In [26]:
# In preprocess_heart_data.py, after class balancing is done:

local_output_path = "file:////home/talentum/Heart_Disease_Project/cleaned_heart_data_csv/" # Note the three slashes for an absolute path

print(f"Saving balanced data directly to local path as CSV at: {local_output_path}")

# Coalesce to 1 partition to ensure a single CSV file is created in the output directory.
# 'header=True' includes column names, 'mode="overwrite"' handles existing files.
balanced_df.coalesce(1).write.csv(local_output_path, header=True, mode="overwrite")

print("Balanced data saved directly to local filesystem as CSV.")


# Stop Spark Session
spark.stop()
print("Spark Session stopped.")


Saving balanced data directly to local path as CSV at: file:////home/talentum/Heart_Disease_Project/cleaned_heart_data_csv/
Balanced data saved directly to local filesystem as CSV.
Spark Session stopped.
