In [8]:
# Cell 1: Initialization and Spark Session Setup
import os
import sys

# Set environment variables for Spark (adjust paths if necessary for your environment)
os.environ["SPARK_HOME"] = "/home/talentum/spark" # Adjust as per your Spark installation
os.environ["PYLIB"] = os.environ["SPARK_HOME"] + "/python/lib"
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.6" # Ensure this matches your Python version
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3" # Ensure this matches your Python version
sys.path.insert(0, os.environ["PYLIB"] + "/py4j-0.10.7-src.zip") # Adjust py4j version if needed
sys.path.insert(0, os.environ["PYLIB"] + "/pyspark.zip")

# Ensure MySQL connector package is included for Spark submit.
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages mysql:mysql-connector-java:8.0.28 pyspark-shell'


from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit, round, rand, sum as spark_sum

# Initialize Spark Session with increased driver memory and MySQL connector
spark = SparkSession.builder \
    .appName("HeartDiseasePreprocessingPySpark") \
    .config("spark.driver.memory", "8g") \
    .config("spark.jars.packages", "mysql:mysql-connector-java:8.0.28") \
    .getOrCreate()

print("Spark Session initialized with 8g driver memory and MySQL connector.")

# Cell 2: Define Preprocessing Function
def preprocess_heart_data(df):
    """
    Applies common cleaning and transformation steps to the heart disease dataset.
    Assumes all relevant columns have already been renamed to their standardized forms.
    This function focuses on recoding, filtering invalid values, dropping nulls, and duplicates.
    """
    final_target_col = "HeartDiseaseorAttack"

    # --- Recode the target variable (Heart Disease/Attack) ---
    if final_target_col in df.columns:
        df = df.withColumn(final_target_col, \
                            when(col(final_target_col) == 2, lit(0)) \
                            .otherwise(col(final_target_col)))
        df = df.filter(~col(final_target_col).isin([7, 9])) # Filter out invalid codes (e.g., Don't know/Refused)
    else:
        # This case should ideally not happen if renaming in Cells 3 & 4 is correct
        print(f"Warning: '{final_target_col}' not found in DataFrame. Target processing might be incomplete.")

    # --- Recode other columns based on BRFSS interpretation and filter invalid codes ---
    # HighBP: 1=Yes, 2=No -> 0=No, 1=Yes (filter 9=Don't know/Refused)
    if "HighBP" in df.columns:
        df = df.withColumn('HighBP', when(col('HighBP') == 1, 1).when(col('HighBP') == 2, 0).otherwise(col('HighBP'))).filter(col('HighBP') != 9)
    # HighChol: 1=Yes, 2=No -> 0=No, 1=Yes (filter 7=Don't know, 9=Refused)
    if "HighChol" in df.columns:
        df = df.withColumn('HighChol', when(col('HighChol') == 1, 1).when(col('HighChol') == 2, 0).otherwise(col('HighChol'))).filter(~col('HighChol').isin([7, 9]))
    # CholCheck: 1=Yes, 2=No, 3=Never -> 0=No, 1=Yes (filter 9=Don't know/Refused)
    if "CholCheck" in df.columns:
        df = df.withColumn('CholCheck', when(col('CholCheck') == 1, 1).when(col('CholCheck').isin([2, 3]), 0).otherwise(col('CholCheck'))).filter(col('CholCheck') != 9)
    
    # BMI: Already numeric, just round
    if "BMI" in df.columns:
        df = df.withColumn('BMI', round(col('BMI'), 0)) # Assuming _BMI5 / 100 already happened during rename/initial processing

    # Smoker: 1=Yes, 2=No -> 0=No, 1=Yes (filter 7, 9)
    if "Smoker" in df.columns:
        df = df.withColumn('Smoker', when(col('Smoker') == 1, 1).when(col('Smoker') == 2, 0).otherwise(col('Smoker'))).filter(~col('Smoker').isin([7, 9]))
    # Stroke: 1=Yes, 2=No -> 0=No, 1=Yes (filter 7, 9)
    if "Stroke" in df.columns:
        df = df.withColumn('Stroke', when(col('Stroke') == 1, 1).when(col('Stroke') == 2, 0).otherwise(col('Stroke'))).filter(~col('Stroke').isin([7, 9]))
    # Diabetes: 1=Yes, 2=No, 3=Pre-diabetes, 4=Gestational -> 0=No, 1=Gestational, 2=Yes/Pre-diabetes (filter 7, 9)
    if "Diabetes" in df.columns:
        df = df.withColumn('Diabetes',
                            when(col('Diabetes') == 1, 2)  # Yes
                            .when(col('Diabetes') == 2, 0)  # No
                            .when(col('Diabetes') == 3, 2)  # Pre-diabetes (grouped with Yes)
                            .when(col('Diabetes') == 4, 1)  # Gestational
                            .otherwise(col('Diabetes'))).filter(~col('Diabetes').isin([7, 9]))
    # PhysActivity: 1=Yes, 2=No -> 0=No, 1=Yes (filter 9)
    if "PhysActivity" in df.columns:
        df = df.withColumn('PhysActivity', when(col('PhysActivity') == 1, 1).when(col('PhysActivity') == 2, 0).otherwise(col('PhysActivity'))).filter(col('PhysActivity') != 9)
    # Fruits: 1=Yes, 2=No -> 0=No, 1=Yes (filter 9)
    if "Fruits" in df.columns:
        df = df.withColumn('Fruits', when(col('Fruits') == 1, 1).when(col('Fruits') == 2, 0).otherwise(col('Fruits'))).filter(col('Fruits') != 9)
    # Veggies: 1=Yes, 2=No -> 0=No, 1=Yes (filter 9)
    if "Veggies" in df.columns:
        df = df.withColumn('Veggies', when(col('Veggies') == 1, 1).when(col('Veggies') == 2, 0).otherwise(col('Veggies'))).filter(col('Veggies') != 9)
    # HvyAlcoholConsump: 1=Yes, 2=No -> 0=No, 1=Yes (filter 9)
    if "HvyAlcoholConsump" in df.columns:
        df = df.withColumn('HvyAlcoholConsump', when(col('HvyAlcoholConsump') == 1, 1).when(col('HvyAlcoholConsump') == 2, 0).otherwise(col('HvyAlcoholConsump'))).filter(col('HvyAlcoholConsump') != 9)
    # AnyHealthcare: 1=Yes, 2=No -> 0=No, 1=Yes (filter 7, 9)
    if "AnyHealthcare" in df.columns:
        df = df.withColumn('AnyHealthcare', when(col('AnyHealthcare') == 1, 1).when(col('AnyHealthcare') == 2, 0).otherwise(col('AnyHealthcare'))).filter(~col('AnyHealthcare').isin([7, 9]))
    # NoDocbcCost: 1=Yes, 2=No -> 0=No, 1=Yes (filter 7, 9)
    if "NoDocbcCost" in df.columns:
        df = df.withColumn('NoDocbcCost', when(col('NoDocbcCost') == 1, 1).when(col('NoDocbcCost') == 2, 0).otherwise(col('NoDocbcCost'))).filter(~col('NoDocbcCost').isin([7, 9]))
    
    # GenHlth: 1-5 scale (filter 7, 9)
    if "GenHlth" in df.columns:
        df = df.filter(~col('GenHlth').isin([7, 9]))
    # MentHlth: 1-30 days, 88=None -> 0=None, else days (filter 77, 99)
    if "MentHlth" in df.columns:
        df = df.withColumn('MentHlth', when(col('MentHlth') == 88, 0).otherwise(col('MentHlth'))).filter(~col('MentHlth').isin([77, 99]))
    # PhysHlth: 1-30 days, 88=None -> 0=None, else days (filter 77, 99)
    if "PhysHlth" in df.columns:
        df = df.withColumn('PhysHlth', when(col('PhysHlth') == 88, 0).otherwise(col('PhysHlth'))).filter(~col('PhysHlth').isin([77, 99]))
    # DiffWalk: 1=Yes, 2=No -> 0=No, 1=Yes (filter 7, 9)
    if "DiffWalk" in df.columns:
        df = df.withColumn('DiffWalk', when(col('DiffWalk') == 1, 1).when(col('DiffWalk') == 2, 0).otherwise(col('DiffWalk'))).filter(~col('DiffWalk').isin([7, 9]))
    # Sex: 1=Male, 2=Female -> 0=Female, 1=Male
    if "Sex" in df.columns:
        df = df.withColumn('Sex', when(col('Sex') == 1, 1).when(col('Sex') == 2, 0).otherwise(col('Sex'))) # Recode Female to 0, Male to 1
    # Age: 1-13 (age groups), filter 14 (Don't know/Refused)
    if "Age" in df.columns:
        df = df.filter(col('Age') != 14)
    # Education: 1-6 (filter 9)
    if "Education" in df.columns:
        df = df.filter(col('Education') != 9)
    # Income: 1-8 (filter 77, 99)
    if "Income" in df.columns:
        df = df.filter(~col('Income').isin([77, 99]))

    # Define the final set of common columns expected after preprocessing
    final_common_columns = [
        final_target_col, "HighBP", "HighChol", "CholCheck", "BMI",
        "Smoker", "Stroke", "Diabetes", "PhysActivity", "Fruits", "Veggies",
        "HvyAlcoholConsump", "AnyHealthcare", "NoDocbcCost", "GenHlth",
        "MentHlth", "PhysHlth", "DiffWalk", "Sex", "Age", "Education", "Income"
    ]

    # Select only the desired columns that are present in the DataFrame
    # This ensures consistency in column order and presence before union
    selected_cols = [c for c in final_common_columns if c in df.columns]
    df = df.select(*selected_cols)

    # Drop rows with any remaining null values across selected columns
    df = df.na.drop()

    # Remove duplicates
    df = df.dropDuplicates()

    return df

# Cell 3: Load 2013 CSV Data, Rename Columns, and Preprocess
print("Loading and preprocessing 2013 data...")
df_2013_raw = spark.read.csv('Dataset/2013.csv', header=True, inferSchema=True)

# Define ALL columns to select and rename for 2013 BEFORE calling preprocess_heart_data
# Importantly, CVDCRHD4 is now selected and renamed to HeartDiseaseorAttack here.
selected_cols_2013_initial = [
    'CVDCRHD4', '_RFHYPE5', 'TOLDHI2', '_CHOLCHK', '_BMI5', 'SMOKE100', 'CVDSTRK3', 'DIABETE3',
    '_TOTINDA', '_FRTLT1', '_VEGLT1', '_RFDRHV4', 'HLTHPLN1', 'MEDCOST', 'GENHLTH',
    'MENTHLTH', 'PHYSHLTH', 'DIFFWALK', 'SEX', '_AGEG5YR', 'EDUCA', 'INCOME2'
]

df_2013_renamed = df_2013_raw.select(*selected_cols_2013_initial) \
    .withColumnRenamed('CVDCRHD4', 'HeartDiseaseorAttack') \
    .withColumnRenamed('_RFHYPE5', 'HighBP') \
    .withColumnRenamed('TOLDHI2', 'HighChol') \
    .withColumnRenamed('_CHOLCHK', 'CholCheck') \
    .withColumnRenamed('_BMI5', 'BMI') \
    .withColumnRenamed('SMOKE100', 'Smoker') \
    .withColumnRenamed('CVDSTRK3', 'Stroke') \
    .withColumnRenamed('DIABETE3', 'Diabetes') \
    .withColumnRenamed('_TOTINDA', 'PhysActivity') \
    .withColumnRenamed('_FRTLT1', 'Fruits') \
    .withColumnRenamed('_VEGLT1', 'Veggies') \
    .withColumnRenamed('_RFDRHV4', 'HvyAlcoholConsump') \
    .withColumnRenamed('HLTHPLN1', 'AnyHealthcare') \
    .withColumnRenamed('MEDCOST', 'NoDocbcCost') \
    .withColumnRenamed('GENHLTH', 'GenHlth') \
    .withColumnRenamed('MENTHLTH', 'MentHlth') \
    .withColumnRenamed('PHYSHLTH', 'PhysHlth') \
    .withColumnRenamed('DIFFWALK', 'DiffWalk') \
    .withColumnRenamed('SEX', 'Sex') \
    .withColumnRenamed('_AGEG5YR', 'Age') \
    .withColumnRenamed('EDUCA', 'Education') \
    .withColumnRenamed('INCOME2', 'Income')

df_2013_processed = preprocess_heart_data(df_2013_renamed)

output_path_2013 = "Dataset/tmp/heart_disease_2013_processed.parquet"
df_2013_processed.write.mode("overwrite").parquet(output_path_2013)
print(f"Processed 2013 data saved to {output_path_2013}")
print(f"2013 Processed Row Count: {df_2013_processed.count()}")


# Cell 4: Load 2015 CSV Data, Rename Columns, and Preprocess
print("Loading and preprocessing 2015 data...")
df_2015_raw = spark.read.csv('Dataset/2015.csv', header=True, inferSchema=True)

# Define ALL columns to select and rename for 2015 BEFORE calling preprocess_heart_data
# Importantly, _MICHD is now selected and renamed to HeartDiseaseorAttack here.
selected_cols_2015_initial = [
    '_MICHD', '_RFHYPE5', 'TOLDHI2', '_CHOLCHK', '_BMI5', 'SMOKE100', 'CVDSTRK3', 'DIABETE3',
    '_TOTINDA', '_FRTLT1', '_VEGLT1', '_RFDRHV5', 'HLTHPLN1', 'MEDCOST', 'GENHLTH',
    'MENTHLTH', 'PHYSHLTH', 'DIFFWALK', 'SEX', '_AGEG5YR', 'EDUCA', 'INCOME2'
]

df_2015_renamed = df_2015_raw.select(*selected_cols_2015_initial) \
    .withColumnRenamed('_MICHD', 'HeartDiseaseorAttack') \
    .withColumnRenamed('_RFHYPE5', 'HighBP') \
    .withColumnRenamed('TOLDHI2', 'HighChol') \
    .withColumnRenamed('_CHOLCHK', 'CholCheck') \
    .withColumnRenamed('_BMI5', 'BMI') \
    .withColumnRenamed('SMOKE100', 'Smoker') \
    .withColumnRenamed('CVDSTRK3', 'Stroke') \
    .withColumnRenamed('DIABETE3', 'Diabetes') \
    .withColumnRenamed('_TOTINDA', 'PhysActivity') \
    .withColumnRenamed('_FRTLT1', 'Fruits') \
    .withColumnRenamed('_VEGLT1', 'Veggies') \
    .withColumnRenamed('_RFDRHV5', 'HvyAlcoholConsump') \
    .withColumnRenamed('HLTHPLN1', 'AnyHealthcare') \
    .withColumnRenamed('MEDCOST', 'NoDocbcCost') \
    .withColumnRenamed('GENHLTH', 'GenHlth') \
    .withColumnRenamed('MENTHLTH', 'MentHlth') \
    .withColumnRenamed('PHYSHLTH', 'PhysHlth') \
    .withColumnRenamed('DIFFWALK', 'DiffWalk') \
    .withColumnRenamed('SEX', 'Sex') \
    .withColumnRenamed('_AGEG5YR', 'Age') \
    .withColumnRenamed('EDUCA', 'Education') \
    .withColumnRenamed('INCOME2', 'Income')

df_2015_processed = preprocess_heart_data(df_2015_renamed)

output_path_2015 = "Dataset/tmp/heart_disease_2015_processed.parquet"
df_2015_processed.write.mode("overwrite").parquet(output_path_2015)
print(f"Processed 2015 data saved to {output_path_2015}")
print(f"2015 Processed Row Count: {df_2015_processed.count()}")


# Cell 5: Load Saved Data and Union
print("Loading processed data from Parquet and uniting...")
df_2013_reloaded = spark.read.parquet(output_path_2013)
df_2015_reloaded = spark.read.parquet(output_path_2015)

merged_df = df_2013_reloaded.unionByName(df_2015_reloaded)
print(f"Cleaned DataFrames merged. Total rows: {merged_df.count()}")
print("Schema of merged DataFrame:")
merged_df.printSchema()


# Cell 6: Perform Class Balancing on the Merged DataFrame (Downsampling)
print("Performing class balancing...")
counts_spark = merged_df.groupBy('HeartDiseaseorAttack').count().collect()
majority_count = 0
minority_count = 0
for row in counts_spark:
    if row['HeartDiseaseorAttack'] == 0:
        majority_count = row['count']
    elif row['HeartDiseaseorAttack'] == 1:
        minority_count = row['count']

print(f"Original class counts: HeartDiseaseorAttack=0: {majority_count}, HeartDiseaseorAttack=1: {minority_count}")

df_majority = merged_df.filter(col('HeartDiseaseorAttack') == 0)
df_minority = merged_df.filter(col('HeartDiseaseorAttack') == 1)

sampling_fraction = minority_count / majority_count if majority_count > 0 else 0

df_majority_downsampled = df_majority.sample(False, sampling_fraction, seed=42)

balanced_df = df_minority.union(df_majority_downsampled)

balanced_df = balanced_df.orderBy(rand(seed=42))

print("Class sizes after balancing:")
balanced_df.groupBy('HeartDiseaseorAttack').count().show()
print(f"Balanced DataFrame row count: {balanced_df.count()}")


# Cell 7: Dump Cleaned Dataset to MySQL
print("Attempting to dump cleaned dataset to MySQL...")
jdbcHostname = "127.0.0.1"
jdbcPort = 3306
jdbcDatabase = "test"
jdbcUsername = "bigdata"
jdbcPassword = "Bigdata@123"
jdbcTableName = "heart_disease_data_balanced"

jdbcUrl = f"jdbc:mysql://{jdbcHostname}:{jdbcPort}/{jdbcDatabase}?useSSL=false&allowPublicKeyRetrieval=true"
jdbcDriver = "com.mysql.cj.jdbc.Driver"

connectionProperties = {
  "user" : jdbcUsername,
  "password" : jdbcPassword,
  "driver" : jdbcDriver
}

try:
    print(f"Writing data to MySQL table: {jdbcTableName} at {jdbcUrl.split('?')[0]}...")
    balanced_df.write \
        .format("jdbc") \
        .option("url", jdbcUrl) \
        .option("dbtable", jdbcTableName) \
        .mode("overwrite") \
        .options(**connectionProperties) \
        .save()
    print(f"Successfully dumped cleaned data to MySQL table: {jdbcTableName}")
except Exception as e:
    print(f"Error writing to MySQL: {e}")

# Cell 8: Stop Spark Session
spark.stop()
print("Spark Session stopped.")

Spark Session initialized with 8g driver memory and MySQL connector.
Loading and preprocessing 2013 data...
Processed 2013 data saved to Dataset/tmp/heart_disease_2013_processed.parquet
2013 Processed Row Count: 292441
Loading and preprocessing 2015 data...
Processed 2015 data saved to Dataset/tmp/heart_disease_2015_processed.parquet
2015 Processed Row Count: 249137
Loading processed data from Parquet and uniting...
Cleaned DataFrames merged. Total rows: 541578
Schema of merged DataFrame:
root
 |-- HeartDiseaseorAttack: double (nullable = true)
 |-- HighBP: double (nullable = true)
 |-- HighChol: double (nullable = true)
 |-- CholCheck: double (nullable = true)
 |-- BMI: double (nullable = true)
 |-- Smoker: double (nullable = true)
 |-- Stroke: double (nullable = true)
 |-- Diabetes: double (nullable = true)
 |-- PhysActivity: double (nullable = true)
 |-- Fruits: double (nullable = true)
 |-- Veggies: double (nullable = true)
 |-- HvyAlcoholConsump: double (nullable = true)
 |-- AnyH