In [7]:
from pyspark.sql import SparkSession
from pyspark.ml.tuning import CrossValidatorModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# 1. Start Spark session
spark = SparkSession.builder \
    .appName("Traffic Status Evaluation") \
    .getOrCreate()

# 2. Load model
cv_model = CrossValidatorModel.load("traffic_status_rf_model")
rf_model = cv_model.bestModel.stages[-1]  # Get the final RandomForest model

# 3. Load predictions
predictions = spark.read.parquet("traffic_status_predictions.parquet")

# 4. Evaluate metrics
metrics = {
    "Accuracy": "accuracy",
    "F1 Score": "f1",
    "Precision": "weightedPrecision",
    "Recall": "weightedRecall"
}

for name, metric in metrics.items():
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName=metric
    )
    score = evaluator.evaluate(predictions)
    print(f"{name}: {score:.4f}")

# 5. Extract full feature importances
# Get feature names from pipeline
assembler = cv_model.bestModel.stages[4]  # Index of VectorAssembler
feature_names = assembler.getInputCols()

# Get importances from RandomForest model
importances = rf_model.featureImportances.toArray()

# Print feature importances
print("\n📊 Feature Importances:")
for name, score in zip(feature_names, importances):
    print(f"{name}: {score:.4f}")


Accuracy: 0.4846
F1 Score: 0.3852
Precision: 0.3451
Recall: 0.4846

📊 Feature Importances:
Month: 0.1284
Day: 0.0011
day_of_week: 0.0204
is_weekend: 0.0311
Mean_Temp_C: 0.1097
Total_Precip_mm: 0.0027
camera_vec: 0.0000
road_vec: 0.0000


In [8]:
cv_model.bestModel.stages

[StringIndexerModel: uid=StringIndexer_a80b13bc95c1, handleInvalid=keep,
 OneHotEncoderModel: uid=OneHotEncoder_2959f3a12b91, dropLast=true, handleInvalid=error,
 StringIndexerModel: uid=StringIndexer_2c51e5ab93cc, handleInvalid=keep,
 OneHotEncoderModel: uid=OneHotEncoder_f279c407b5a0, dropLast=true, handleInvalid=error,
 VectorAssembler_894ea8ad1ed7,
 StringIndexerModel: uid=StringIndexer_4a02c2b442fd, handleInvalid=keep,
 RandomForestClassificationModel: uid=RandomForestClassifier_32899780674b, numTrees=100, numClasses=4, numFeatures=641]

In [9]:
predictions.select("features", "label", "prediction").show(truncate=False)


+--------------------------------------------------------+-----+----------+
|features                                                |label|prediction|
+--------------------------------------------------------+-----+----------+
|(641,[0,1,2,4,5,235,551],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|0.0  |0.0       |
|(641,[0,1,2,4,5,172,501],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|1.0  |0.0       |
|(641,[0,1,2,4,5,321,624],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|0.0  |0.0       |
|(641,[0,1,2,4,5,181,356],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|0.0  |0.0       |
|(641,[0,1,2,4,5,100,444],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|1.0  |0.0       |
|(641,[0,1,2,4,5,51,403],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0]) |0.0  |0.0       |
|(641,[0,1,2,4,5,224,540],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|0.0  |0.0       |
|(641,[0,1,2,4,5,225,542],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|1.0  |0.0       |
|(641,[0,1,2,4,5,47,401],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0]) |0.0  |0.0       |
|(641,[0,1,2,4,5,152,486],[2.0,3.0,5.0,-6.8,4.2,1.0,1.0])|0.0  |0.0       |
|(641,[0,1,2