In [1]:
# Import required libraries
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, isnull, when, count
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.sql.types import IntegerType, FloatType, DoubleType
import mlflow
import pandas as pd

In [2]:
# Build (or reuse) the local SparkSession used by every cell below.
# NOTE(review): with a local[*] master the executor memory setting is
# presumably not what bounds the job; confirm against the Spark docs.
spark = (
    SparkSession.builder
    .appName("Spark_MLlib_GBT_Label_Separation")
    .master("local[*]")
    .config("spark.driver.memory", "12g")
    .config("spark.executor.memory", "6g")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/26 19:22:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Point MLflow at a tracking store in the local ./mlruns directory
mlflow.set_tracking_uri("file:./mlruns")
experiment_name = "MLlib_GBTClassifier"

# Register the experiment the first time the notebook runs, then activate it
if mlflow.get_experiment_by_name(experiment_name) is None:
    mlflow.create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)

# Output locations for saved models and per-label log files
MODEL_SAVE_DIR = "./models"
LOG_DIR = "./logs"

# Make sure both output directories exist before anything is written to them
for _output_dir in (MODEL_SAVE_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)

In [4]:
# Load the reference table that describes every tag (CSV with a header row)
REFERENCE_TABLE_PATH = "../EDA/Results/reference_table_real.csv"
reference_table_missing_values_real = spark.read.csv(REFERENCE_TABLE_PATH, header=True)

# Keep only the tags whose "Value Type" is Continuous or Categorical ...
_feature_rows = (
    reference_table_missing_values_real
    .filter(col("Value Type").isin(["Continuous", "Categorical"]))
    .collect()
)

# ... and drop the "DataType" tag, which is not used as a model feature
feature_names = []
for _row in _feature_rows:
    _tag = _row["Tag"]
    if _tag != "DataType":
        feature_names.append(_tag)

# Name of the column holding the binary label the models are trained on
target_name = "target"

In [5]:
# Locate the three scaled parquet splits, then load each into a Spark DataFrame
data_dir = "../Cleaning & Preparation/Train Test (Scaled) Data"
_train_path = os.path.join(data_dir, "scaled_train_data.parquet")
_test_path = os.path.join(data_dir, "scaled_test_data.parquet")
_validation_path = os.path.join(data_dir, "scaled_validation_data.parquet")

train_data_spark = spark.read.parquet(_train_path)
test_data_spark = spark.read.parquet(_test_path)
validation_data_spark = spark.read.parquet(_validation_path)

In [6]:
# Function to replace -1 labels with 0
def replace_invalid_labels(df, label_name):
    """Return `df` with every -1 in column `label_name` rewritten to 0.

    Args:
        df: Spark DataFrame containing the label column.
        label_name: Name of the column to rewrite.

    Returns:
        A new DataFrame; values other than -1 in `label_name` are kept as-is.
    """
    print(f"Replacing -1 with 0 in column {label_name}...")
    current_value = col(label_name)
    remapped_value = when(current_value == -1, 0).otherwise(current_value)
    return df.withColumn(label_name, remapped_value)

# Normalise the 'target' column of every split (train, test, validation, in
# that order) immediately after loading, mapping -1 to 0.
train_data_spark, test_data_spark, validation_data_spark = (
    replace_invalid_labels(_split_df, target_name)
    for _split_df in (train_data_spark, test_data_spark, validation_data_spark)
)

Replacing -1 with 0 in column target...
Replacing -1 with 0 in column target...
Replacing -1 with 0 in column target...


In [7]:
# Values of the "label" column to model, one GBT model per value (1 through 9)
labels_to_process = list(range(1, 10))

In [8]:
# Slice every split by the value of its "label" column.
# Result shape: label_dataframes[label]["train" | "validation" | "test"]
_split_sources = {
    "train": train_data_spark,
    "validation": validation_data_spark,
    "test": test_data_spark,
}
label_dataframes = {
    label: {
        split_name: split_df.filter(col("label") == label)
        for split_name, split_df in _split_sources.items()
    }
    for label in labels_to_process
}

In [9]:
# Function to verify data integrity
def verify_data(df, label):
    """Check that `df` has no nulls and that every feature column is numeric.

    Args:
        df: Spark DataFrame to validate.
        label: Identifier used only in the printed and raised messages.

    Raises:
        ValueError: if any column contains a null, if a name listed in the
            module-level `feature_names` is absent from `df`, or if a feature
            column is not IntegerType, FloatType or DoubleType.
    """
    # Single aggregation over all columns. `c` is a plain string here, so
    # `when(isnull(c), c)` yields a non-null literal for null rows and null
    # otherwise; `count` skips nulls and therefore tallies the null rows.
    null_counts = df.select(
        [count(when(isnull(c), c)).alias(c) for c in df.columns]
    ).collect()[0].asDict()
    if any(null_counts.values()):
        print(f"Null values detected in label {label}: {null_counts}")
        raise ValueError(f"Null values detected in label {label}. Investigate the data.")

    # Index the schema once instead of rescanning it for every feature.
    # setdefault keeps the first field when a column name is duplicated.
    column_types = {}
    for field in df.schema.fields:
        column_types.setdefault(field.name, field.dataType)

    # Verify column types. Only the three types below are accepted; any other
    # Spark type (including other numeric ones) is reported as an error.
    for col_name in feature_names:
        if col_name not in column_types:
            print(f"Column {col_name} is missing from label {label}.")
            raise ValueError(f"Column {col_name} is missing from label {label}.")
        col_type = column_types[col_name]
        if not isinstance(col_type, (IntegerType, FloatType, DoubleType)):
            print(f"Column {col_name} in label {label} is of type {col_type}, expected numeric.")
            raise ValueError(f"Column {col_name} in label {label} is not numeric.")

    print(f"Label {label}: No nulls, all columns have valid types.")

# Validation runs in its own pass, after the per-label split has been built:
# every label is checked on its train, validation and test slice in that order.
for label in labels_to_process:
    for _split_name in ("train", "validation", "test"):
        verify_data(label_dataframes[label][_split_name], f"{label}_{_split_name}")

24/11/26 19:23:02 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

Label 1_train: No nulls, all columns have valid types.


                                                                                

Label 1_validation: No nulls, all columns have valid types.


                                                                                

Label 1_test: No nulls, all columns have valid types.


                                                                                

Label 2_train: No nulls, all columns have valid types.


                                                                                

Label 2_validation: No nulls, all columns have valid types.


                                                                                

Label 2_test: No nulls, all columns have valid types.


                                                                                

Label 3_train: No nulls, all columns have valid types.


                                                                                

Label 3_validation: No nulls, all columns have valid types.


                                                                                

Label 3_test: No nulls, all columns have valid types.


                                                                                

Label 4_train: No nulls, all columns have valid types.


                                                                                

Label 4_validation: No nulls, all columns have valid types.


                                                                                

Label 4_test: No nulls, all columns have valid types.


                                                                                

Label 5_train: No nulls, all columns have valid types.


                                                                                

Label 5_validation: No nulls, all columns have valid types.


                                                                                

Label 5_test: No nulls, all columns have valid types.
Label 6_train: No nulls, all columns have valid types.
Label 6_validation: No nulls, all columns have valid types.
Label 6_test: No nulls, all columns have valid types.


                                                                                

Label 7_train: No nulls, all columns have valid types.


                                                                                

Label 7_validation: No nulls, all columns have valid types.


                                                                                

Label 7_test: No nulls, all columns have valid types.


                                                                                

Label 8_train: No nulls, all columns have valid types.


                                                                                

Label 8_validation: No nulls, all columns have valid types.


                                                                                

Label 8_test: No nulls, all columns have valid types.


                                                                                

Label 9_train: No nulls, all columns have valid types.


                                                                                

Label 9_validation: No nulls, all columns have valid types.




Label 9_test: No nulls, all columns have valid types.


                                                                                

In [10]:
# Train a GBT model for each label


def compute_confusion_matrix(predictions, label_col, prediction_col):
    """Count TP/TN/FP/FN of a binary prediction DataFrame in one aggregation.

    Args:
        predictions: Spark DataFrame holding the true and predicted classes.
        label_col: Name of the ground-truth column (compared against 0 and 1).
        prediction_col: Name of the predicted-class column (compared against 0 and 1).

    Returns:
        dict with the integer counts under the keys 'tp', 'tn', 'fp' and 'fn'.
    """
    actual = col(label_col)
    predicted = col(prediction_col)
    # `when` without `otherwise` is null for non-matching rows and `count`
    # skips nulls, so each aggregate counts exactly one cell of the matrix.
    cells = predictions.agg(
        count(when((actual == 1) & (predicted == 1), 1)).alias("tp"),
        count(when((actual == 0) & (predicted == 0), 1)).alias("tn"),
        count(when((actual == 0) & (predicted == 1), 1)).alias("fp"),
        count(when((actual == 1) & (predicted == 0), 1)).alias("fn"),
    ).collect()[0]
    return {'tp': cells["tp"], 'tn': cells["tn"], 'fp': cells["fp"], 'fn': cells["fn"]}


# Hyper-parameters shared by every per-label model (all set to the defaults)
maxDepth = 5  # Maximum depth of the tree (default=5). Controls overfitting.
maxBins = 32  # Maximum number of bins used for splitting features (default=32).
maxIter = 20  # Number of iterations (trees) (default=20).
stepSize = 0.1  # Learning rate (default=0.1). Controls the rate at which the model learns.
subsamplingRate = 1.0  # Fraction of data to use for training each tree (default=1.0).
lossType = "logistic"  # Loss function to minimize (default='logistic').
minInstancesPerNode = 1  # Minimum instances per node (default=1).
minInfoGain = 0.0  # Minimum information gain for a split (default=0.0).

gbt_params = {
    "maxDepth": maxDepth,
    "maxBins": maxBins,
    "maxIter": maxIter,
    "stepSize": stepSize,
    "subsamplingRate": subsamplingRate,
    "lossType": lossType,
    "minInstancesPerNode": minInstancesPerNode,
    "minInfoGain": minInfoGain,
}

# These objects do not depend on the label, so they are built once and reused.
assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
gbt = GBTClassifier(featuresCol="features", labelCol=target_name, **gbt_params)
evaluator = BinaryClassificationEvaluator(labelCol=target_name, metricName="areaUnderPR")

for label in labels_to_process:
    print(f"\nProcessing label: {label}")

    train_df = label_dataframes[label]["train"]
    validation_df = label_dataframes[label]["validation"]
    test_df = label_dataframes[label]["test"]

    # Check if there's sufficient data. head(1) can stop at the first row,
    # whereas count() has to scan the whole split.
    if any(len(split_df.head(1)) == 0 for split_df in (train_df, validation_df, test_df)):
        print(f"Insufficient data for label {label}. Skipping this label.")
        continue

    # Assemble features into a feature vector. The test split is assembled
    # as well but is not scored anywhere in this cell.
    train_df = assembler.transform(train_df)
    validation_df = assembler.transform(validation_df)
    test_df = assembler.transform(test_df)

    # One MLflow run per label; the context manager closes the run on exit,
    # so no explicit end_run() call is needed.
    with mlflow.start_run(run_name=f"GBT_Label_{label}"):
        # Log the configuration first so a run that fails still records it
        mlflow.log_param("label", label)
        mlflow.log_params(gbt_params)

        # Train the model
        gbt_model = gbt.fit(train_df)

        # Predict on the validation set
        validation_predictions = gbt_model.transform(validation_df)

        # Evaluate the model. The metric is the area under the
        # precision-recall curve, so it is reported as AUPR rather than AUC.
        aupr = evaluator.evaluate(validation_predictions)
        print(f"AUPR for label {label}: {aupr}")

        # Confusion matrix on the validation predictions
        confusion_matrix_val = compute_confusion_matrix(validation_predictions, target_name, 'prediction')

        # Compute key metrics, falling back to 0 when a denominator is empty
        tp = confusion_matrix_val['tp']
        tn = confusion_matrix_val['tn']
        fp = confusion_matrix_val['fp']
        fn = confusion_matrix_val['fn']
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        print(f"Precision for label {label}: {precision}")
        print(f"Recall for label {label}: {recall}")
        print(f"F1 Score for label {label}: {f1_score}")
        print(f"Confusion Matrix: {confusion_matrix_val}")

        # Capture the feature importance for each model: pair every assembled
        # feature name with the importance at the same vector position.
        importances = gbt_model.featureImportances
        feature_importance_list = [
            (feature_names[idx], float(imp))
            for idx, imp in enumerate(importances.toArray())
        ]

        # Convert to DataFrame
        feature_importance_df = pd.DataFrame(feature_importance_list, columns=["feature", "importance"])

        # Save feature importances to a file
        feature_importance_path = os.path.join(LOG_DIR, f"feature_importance_label_{label}.csv")
        feature_importance_df.to_csv(feature_importance_path, index=False)
        mlflow.log_artifact(feature_importance_path, artifact_path="feature_importance")

        # Save the model
        model_path = os.path.join(MODEL_SAVE_DIR, f"gbt_model_label_{label}")
        gbt_model.write().overwrite().save(model_path)
        mlflow.log_artifact(model_path, artifact_path="models")

        # Log validation metrics to MLflow
        mlflow.log_metrics({
            "areaUnderPR": aupr,
            "precision": precision,
            "recall": recall,
            "f1_score": f1_score,
            "true_positives": tp,
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn,
        })

print("\nProcessing complete.")


Processing label: 1


24/11/26 19:26:20 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

AUC for label 1: 0.9984564236676148


                                                                                

Precision for label 1: 0.9973010093266896
Recall for label 1: 0.9557239057239058
F1 Score for label 1: 0.9760698990277908
Confusion Matrix: {'tp': 62447, 'tn': 171285, 'fp': 169, 'fn': 2893}

Processing label: 2


                                                                                

AUC for label 2: 0.9926040254803992


                                                                                

Precision for label 2: 0.9720844374985844
Recall for label 2: 0.9634755474739312
F1 Score for label 2: 0.9677608473840569
Confusion Matrix: {'tp': 85837, 'tn': 185445, 'fp': 2465, 'fn': 3254}

Processing label: 3


                                                                                

AUC for label 3: 0.9648958338147932


                                                                                

Precision for label 3: 0.8532301544225654
Recall for label 3: 0.9971027071854267
F1 Score for label 3: 0.9195730287010593
Confusion Matrix: {'tp': 567503, 'tn': 17580, 'fp': 97620, 'fn': 1649}

Processing label: 4


                                                                                

AUC for label 4: 0.7341298606762959


                                                                                

Precision for label 4: 0.6681818045540264
Recall for label 4: 0.998638631657802
F1 Score for label 4: 0.8006527259948634
Confusion Matrix: {'tp': 2451541, 'tn': 17368, 'fp': 1217432, 'fn': 3342}

Processing label: 5


                                                                                

AUC for label 5: 0.9983420039940153


                                                                                

Precision for label 5: 0.9670473482242505
Recall for label 5: 0.9862663406845624
F1 Score for label 5: 0.9765622949528885
Confusion Matrix: {'tp': 316268, 'tn': 107959, 'fp': 10777, 'fn': 4404}

Processing label: 6
AUC for label 6: 0.9121168816076213
Precision for label 6: 0.9818982387475538
Recall for label 6: 0.6110519104886588
F1 Score for label 6: 0.7533076850896124
Confusion Matrix: {'tp': 4014, 'tn': 70834, 'fp': 74, 'fn': 2555}

Processing label: 7


                                                                                

AUC for label 7: 0.9919171481531437


                                                                                

Precision for label 7: 0.9271423816197853
Recall for label 7: 0.9816704149652908
F1 Score for label 7: 0.9536275617230621
Confusion Matrix: {'tp': 6548313, 'tn': 679777, 'fp': 514586, 'fn': 122269}

Processing label: 8


                                                                                

AUC for label 8: 0.9977719983934743


                                                                                

Precision for label 8: 0.9733505042756959
Recall for label 8: 0.9740694497865907
F1 Score for label 8: 0.9737098443215431
Confusion Matrix: {'tp': 3560631, 'tn': 1056130, 'fp': 97487, 'fn': 94787}

Processing label: 9


                                                                                

AUC for label 9: 0.9763970731502984


                                                                                

Precision for label 9: 0.9158861331305223
Recall for label 9: 0.918088649887691
F1 Score for label 9: 0.9169860689500097
Confusion Matrix: {'tp': 179434, 'tn': 2423450, 'fp': 16479, 'fn': 16009}

Processing complete.
