# **Stroke Prediction Pipeline**

In [None]:
#Import Libraries
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.ml import Pipeline
from pyspark.ml.feature import Imputer, StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression, LinearSVC, GBTClassifier, MultilayerPerceptronClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import time
import logging


In [None]:
#Data information
# Name of the binary target column and the HDFS location of the raw CSV.
label_column = "stroke"
csv_input_path = (
    "hdfs://master:9000/user/sat3812/stroke_project/data/"
    "healthcare-dataset-stroke-data.csv"
)

#Reduce py4j logging noise so the pipeline output is easier to read
logger = logging.getLogger("py4j")
logger.setLevel(logging.ERROR)

## **Helper Functions**

In [None]:
#Cleaning functions
# Text tokens that should be read as missing values. They are kept in
# lowercase because convert_to_null compares them against a trimmed,
# lowercased copy of each cell.
null_token_list = list(map(str.lower, ["na", "n/a", "null", "none", "n.a.", "na.", "", "N/A"]))

#Function to convert null tokens to actual nulls
def convert_to_null(column_name: str):
    """
    Build a Column expression that maps text null tokens to a real null.

    The cell is trimmed and lowercased before being compared with
    ``null_token_list``; matching cells become null, every other cell keeps
    its original (untrimmed) value.
    """
    original_value = F.col(column_name)
    normalized_value = F.lower(F.trim(original_value))
    is_null_token = normalized_value.isin(null_token_list)
    return F.when(is_null_token, None).otherwise(original_value)


#Function to clean column names
def clean_column_names(dataframe: DataFrame):
    """
    Return ``dataframe`` with standardized column names.

    Each name is stripped of surrounding whitespace and lowercased, spaces and
    hyphens become underscores, and parentheses are removed. Columns whose
    name is already in standard form are left untouched.
    """
    # Single-character mapping equivalent to the replace chain:
    # " " -> "_", "-" -> "_", "(" and ")" deleted.
    name_translation = str.maketrans({" ": "_", "-": "_", "(": None, ")": None})

    renamed_dataframe = dataframe
    for current_name in dataframe.columns:
        target_name = current_name.strip().lower().translate(name_translation)
        if target_name == current_name:
            continue
        renamed_dataframe = renamed_dataframe.withColumnRenamed(current_name, target_name)
    return renamed_dataframe


#Function to normalize numeric columns
def normalize_numeric(dataframe: DataFrame, numeric_columns: tuple = ("age", "avg_glucose_level", "bmi")):
    """
    Cast the continuous numeric columns to double.

    Columns that Spark read as strings (for example because the CSV contains
    text tokens such as "N/A") first have their null tokens turned into real
    nulls with ``convert_to_null``, so the cast never sees them. Columns that
    are already numeric are cast directly.

    Args:
        dataframe: DataFrame whose column names have already been standardized.
        numeric_columns: Names of the columns to normalize.

    Returns:
        A new DataFrame with every column in ``numeric_columns`` of type double.
    """
    column_types = dict(dataframe.dtypes)
    for column_name in numeric_columns:
        # Only string columns can hold text null tokens; casting them directly
        # would silently null (or, under ANSI mode, fail on) tokens like "N/A".
        if column_types.get(column_name) == "string":
            source_value = convert_to_null(column_name)
        else:
            source_value = F.col(column_name)
        dataframe = dataframe.withColumn(column_name, source_value.cast("double"))
    return dataframe

In [None]:
#Handling class imbalance with oversampling
def replicate_minority_class_with_oversampling(dataframe: DataFrame, target_column: str = label_column):
    """
    Oversample the minority class of a classification DataFrame by replication.

    The minority and majority classes are determined from the observed label
    counts. Every row outside the minority class is kept once, and the
    minority rows are appended ``majority_count // minority_count`` times, so
    the minority class ends up roughly as large as the majority class.

    Args:
        dataframe: Training data containing ``target_column``.
        target_column: Name of the integer-valued label column.

    Returns:
        The oversampled DataFrame in randomized row order. If fewer than two
        classes are present, or the classes are already balanced, the input
        ``dataframe`` is returned unchanged.
    """
    # Count records per class label (rows with a null label are ignored)
    class_count_by_label = {
        int(row[target_column]): int(row["count"])
        for row in dataframe.groupBy(target_column).count().collect()
        if row[target_column] is not None
    }

    # With a single (or no) class there is nothing to balance. Defaulting a
    # missing class to a count of 1 would instead trigger thousands of unions.
    if len(class_count_by_label) < 2:
        return dataframe

    # Identify minority and majority class labels from the actual counts
    minority_label_value = min(class_count_by_label, key=class_count_by_label.get)
    majority_label_value = max(class_count_by_label, key=class_count_by_label.get)
    minority_count = class_count_by_label[minority_label_value]
    majority_count = class_count_by_label[majority_label_value]

    # If classes are already balanced return original DataFrame
    if minority_count >= majority_count:
        return dataframe

    # Separate minority rows from everything else
    minority_dataframe = dataframe.filter(F.col(target_column) == minority_label_value)
    remaining_dataframe = dataframe.filter(F.col(target_column) != minority_label_value)

    # minority_count >= 1 here because it comes from a groupBy count
    replication_factor = max(1, majority_count // minority_count)

    # Append the minority rows replication_factor times
    balanced_dataframe = remaining_dataframe
    for _ in range(replication_factor):
        balanced_dataframe = balanced_dataframe.union(minority_dataframe)

    #Print oversampling details (the minority class now appears replication_factor times)
    print(
        f"Oversampled minority class from {minority_count} "
        f"to approximately {minority_count * replication_factor}; "
        f"majority was {majority_count}."
    )

    # Shuffle the rows; randomSplit([1.0]) keeps every row but does not reorder them randomly
    return balanced_dataframe.orderBy(F.rand(seed=42))



## Preprocessing 

In [None]:
def preprocessing():
    """
    Build the (unfitted) feature-preprocessing stages.

    Stages, in order:
      - Impute missing values in the continuous numeric columns with the median
      - StringIndex + OneHotEncode the categorical string columns; with
        handleInvalid="keep", unseen or null categories go to an extra bucket
      - Assemble numeric, binary and one-hot columns into ``unscaled_features``
      - Scale ``unscaled_features`` to unit standard deviation (no centering)
        as ``features``

    Returns:
        A list of pyspark.ml stages suitable for ``Pipeline(stages=...)``.
    """
    #Define columns
    numeric_feature_columns = ["age", "avg_glucose_level", "bmi"]
    categorical_string_columns = ["gender", "ever_married", "work_type", "residence_type", "smoking_status"]
    binary_numeric_columns = ["hypertension", "heart_disease"]

    #Imputer for numeric columns (output overwrites the input columns in place)
    numeric_imputer = Imputer(
        inputCols=numeric_feature_columns,
        outputCols=numeric_feature_columns,
        # Imputer defaults to "mean"; this pipeline is specified as median imputation
        strategy="median",
    )

    #StringIndexers and OneHotEncoders for categorical columns
    string_indexers = [
        StringIndexer(inputCol=col_name, outputCol=f"{col_name}_index", handleInvalid="keep")
        for col_name in categorical_string_columns
    ]

    one_hot_encoders = [
        OneHotEncoder(inputCol=f"{col_name}_index", outputCol=f"{col_name}_onehot")
        for col_name in categorical_string_columns
    ]

    #Columns that make up the single feature vector
    assembled_inputs = (
        numeric_feature_columns
        + binary_numeric_columns
        + [f"{col_name}_onehot" for col_name in categorical_string_columns]
    )

    #Combine all features into a single vector
    feature_assembler = VectorAssembler(
        inputCols=assembled_inputs,
        outputCol="unscaled_features"
    )

    #Standardize features (withMean=False keeps sparse one-hot vectors sparse)
    feature_scaler = StandardScaler(
        inputCol="unscaled_features",
        outputCol="features",
        withMean=False,
        withStd=True
    )

    #Return preprocessing stages
    return [numeric_imputer] + string_indexers + one_hot_encoders + [feature_assembler, feature_scaler]


#Function to figure out feature vector length
def infer_feature_vector_length(fitted_preprocessing_pipeline, dataframe: DataFrame, features_column: str = "features"):
    """
    Return the length of the feature vector produced by a fitted preprocessing pipeline.

    Only one row of ``dataframe`` is transformed to obtain a sample vector.

    Args:
        fitted_preprocessing_pipeline: A fitted model exposing ``transform``.
        dataframe: Data with the raw input columns the pipeline expects.
        features_column: Name of the vector column to measure.

    Raises:
        ValueError: If ``dataframe`` has no rows, so no sample vector exists.
    """
    sample_row = (
        fitted_preprocessing_pipeline.transform(dataframe.limit(1))
        .select(features_column)
        .head()
    )
    if sample_row is None:
        raise ValueError("Cannot infer the feature vector length from an empty DataFrame.")
    return len(sample_row[features_column])

### Evaluator Function

In [None]:
#Function to evaluate model performance
def evaluate_models(predictions: DataFrame, label_column_name: str = label_column):
    """
    Compute binary-classification metrics for a predictions DataFrame.

    ``predictions`` must hold the label column, a ``prediction`` column and the
    raw-prediction column read by BinaryClassificationEvaluator (its default
    is ``rawPrediction``).

    Metrics:
      - accuracy, AUC (ROC), precision, recall (sensitivity), specificity
      - confusion matrix: TP, TN, FP, FN (positive class = 1)
    Ratios whose denominator is zero are reported as 0.0.

    Returns a dictionary.
    """
    # Find AUC
    auc_evaluator = BinaryClassificationEvaluator(labelCol=label_column_name, metricName="areaUnderROC")
    area_under_roc = auc_evaluator.evaluate(predictions)

    # Count every (label, prediction) pair in ONE Spark job instead of five
    # separate count() passes, each of which re-ran the uncached transform.
    pair_counts = {
        (row[label_column_name], row["prediction"]): int(row["count"])
        for row in predictions.groupBy(label_column_name, "prediction").count().collect()
    }

    # Calculate accuracy: null labels/predictions count toward the total but never as correct
    total_records = sum(pair_counts.values())
    correct_predictions = sum(
        count
        for (label_value, predicted_value), count in pair_counts.items()
        if label_value is not None and predicted_value is not None and label_value == predicted_value
    )
    accuracy = (correct_predictions / total_records) if total_records else 0.0

    # Confusion matrix (int labels and double predictions compare equal in Python: 1 == 1.0)
    true_positive = pair_counts.get((1, 1.0), 0)
    false_positive = pair_counts.get((0, 1.0), 0)
    false_negative = pair_counts.get((1, 0.0), 0)
    true_negative = pair_counts.get((0, 0.0), 0)

    # Calculate precision, recall, specificity
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
    specificity = true_negative / (true_negative + false_positive) if (true_negative + false_positive) else 0.0

    # Return metrics as a dictionary
    return {
        "accuracy": accuracy,
        "auc": area_under_roc,
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "true_positive": true_positive,
        "true_negative": true_negative,
        "false_positive": false_positive,
        "false_negative": false_negative,
        "total_records": total_records,
    }

### Models

In [None]:
# Model Training
def _load_dataset(spark_session):
    """
    Load, clean and split the stroke CSV.

    Returns:
        (training_dataframe, testing_dataframe): an 80/20 random split (seed 42)
        in which the training part has had its minority class oversampled.
    """
    print(f"Loading: {csv_input_path}")
    input_dataframe = spark_session.read.csv(csv_input_path, header=True, inferSchema=True)
    input_dataframe = clean_column_names(input_dataframe)
    input_dataframe = (
        normalize_numeric(input_dataframe)
        .dropna(subset=[label_column])
        .withColumn(label_column, F.col(label_column).cast("int"))
    )

    # Split before oversampling so the test set never contains replicated rows
    training_raw_dataframe, testing_dataframe = input_dataframe.randomSplit([0.8, 0.2], seed=42)
    training_dataframe = replicate_minority_class_with_oversampling(training_raw_dataframe, label_column)
    return training_dataframe, testing_dataframe


def _build_model_specs(input_dim):
    """
    Return (name, estimator, param_grid) triples for every compared model.

    ``input_dim`` is the feature-vector length; it fixes the MLP input layer size.
    """
    model_specs = []

    # Logistic Regression
    logistic_regression = LogisticRegression(labelCol=label_column, featuresCol="features")
    lr_grid = ParamGridBuilder().addGrid(logistic_regression.regParam, [0.01, 0.1]).build()
    model_specs.append(("Logistic Regression", logistic_regression, lr_grid))

    # Linear SVM
    linear_svm = LinearSVC(labelCol=label_column, featuresCol="features", maxIter=100, regParam=0.1)
    svm_grid = ParamGridBuilder().addGrid(linear_svm.regParam, [0.01, 0.1, 0.5]).build()
    model_specs.append(("Linear SVM", linear_svm, svm_grid))

    # Gradient-Boosted Trees
    gbt = GBTClassifier(labelCol=label_column, featuresCol="features", maxIter=50, maxDepth=5, stepSize=0.1, seed=42)
    gbt_grid = ParamGridBuilder().addGrid(gbt.maxDepth, [3, 5]).addGrid(gbt.stepSize, [0.05, 0.1]).build()
    model_specs.append(("Gradient Boosted Trees", gbt, gbt_grid))

    # Multilayer Perceptron (MLP); layers[0] must equal the feature-vector length
    mlp = MultilayerPerceptronClassifier(
        labelCol=label_column,
        featuresCol="features",
        layers=[input_dim, 64, 32, 2],
        maxIter=100,
        stepSize=0.05,
        blockSize=128,
        seed=42
    )
    mlp_grid = ParamGridBuilder().addGrid(mlp.maxIter, [100]).build()
    model_specs.append(("Multilayer Perceptron", mlp, mlp_grid))

    return model_specs


def _print_summary(summary):
    """Print a fixed-width comparison table of the per-model metric dictionaries."""
    print("MODEL COMPARISON SUMMARY")
    print("-" * 86)
    print(f"{'Model':<28} {'Acc':>8} {'AUC':>8} {'Prec':>8} {'Recall':>8} {'Spec':>8} {'Time(s)':>9}")
    print("-" * 86)
    for r in summary:
        print(
            f"{r['model']:<28} {r['accuracy']:>8.3f} {r['auc']:>8.3f} {r['precision']:>8.3f} "
            f"{r['recall']:>8.3f} {r['specificity']:>8.3f} {r['train_seconds']:>9.2f}"
        )


def models():
    """
    Run the full stroke-prediction experiment and print the results.

    Steps:
      - Load and split the data.
      - Fit the preprocessing pipeline once on the (oversampled) training set.
      - Tune each model with 3-fold cross-validation on AUC-ROC.
      - Evaluate each best model on the held-out test split.
      - Print a comparison table.

    The Spark session is stopped even if a step fails.
    """
    spark_session = SparkSession.builder.appName("sat5165_stroke_pipeline").getOrCreate()
    spark_session.sparkContext.setLogLevel("ERROR")
    try:
        training_dataframe, testing_dataframe = _load_dataset(spark_session)

        # NOTE(review): oversampling happens before CrossValidator, so copies of the same
        # minority row can land in both the training and validation folds. The CV AUC used
        # for model selection is therefore optimistic; consider class weights (weightCol).

        # Fit preprocessing once and let every CV fold reuse the same feature layout.
        # Refitting StringIndexer/OneHotEncoder per fold can change the vector length when a
        # rare category is absent from a fold's training part. That breaks the MLP, whose
        # input layer is fixed to input_dim.
        fitted_preprocessor = Pipeline(stages=preprocessing()).fit(training_dataframe)
        input_dim = infer_feature_vector_length(fitted_preprocessor, training_dataframe)
        print(f"[info] MLP input_dim = {input_dim}")

        # Cached because CrossValidator scans the training features many times
        training_features = fitted_preprocessor.transform(training_dataframe).cache()
        testing_features = fitted_preprocessor.transform(testing_dataframe)

        # Evaluator (AUC-ROC)
        auc_evaluator = BinaryClassificationEvaluator(labelCol=label_column, metricName="areaUnderROC")

        summary = []
        for name, estimator, param_grid in _build_model_specs(input_dim):
            cross_validator = CrossValidator(
                estimator=estimator,
                estimatorParamMaps=param_grid,
                evaluator=auc_evaluator,
                numFolds=3,
                parallelism=2,
                seed=42
            )
            print(f"\nTraining {name}...")
            # Timing covers cross-validation plus the final refit (preprocessing was fitted above)
            start_time = time.time()
            best_model = cross_validator.fit(training_features)
            train_seconds = time.time() - start_time
            print(f"Total time: {train_seconds:.2f} seconds")

            predictions = best_model.transform(testing_features)
            metrics = evaluate_models(predictions, label_column)
            metrics["model"] = name
            metrics["train_seconds"] = train_seconds
            summary.append(metrics)

            print(f"\n{name} Performance Results")
            for k in ["accuracy", "auc", "precision", "recall", "specificity"]:
                print(f"  {k}: {metrics[k]:.3f}")
            print(
                f"  TP: {metrics['true_positive']}  TN: {metrics['true_negative']}  "
                f"FP: {metrics['false_positive']}  FN: {metrics['false_negative']}"
            )

        # Summary table
        _print_summary(summary)
    finally:
        spark_session.stop()

    print("[Stroke Pipeline Complete.]")


In [None]:
#Run the full pipeline only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    print("[main] starting models()")
    models()