In [3]:
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics

# Start (or reuse) the Spark session for this evaluation.
spark = SparkSession.builder.appName("EvaluateTrafficModel").getOrCreate()

# Column names shared by the DataFrame evaluator and the RDD-based metrics.
LABEL_COL = "traffic_label"
PREDICTION_COL = "prediction"

try:
    # Load predictions from HDFS (CSV format, first row is the header).
    predictions = spark.read.option("header", True).option("inferSchema", True).csv(
        "hdfs://localhost:9000/user/hdoop/toronto_traffic/output/final_predictions_csv"
    )

    # Print schema to confirm the structure
    print("Schema of predictions:")
    predictions.printSchema()

    # Both metric APIs work on doubles; cast explicitly so the result does not
    # depend on which type inferSchema picked for these columns.
    predictions = predictions.withColumn(LABEL_COL, predictions[LABEL_COL].cast("double"))
    predictions = predictions.withColumn(PREDICTION_COL, predictions[PREDICTION_COL].cast("double"))

    # Cache the cast data: it is scanned once per metric below (four evaluator
    # passes plus the confusion-matrix RDD). Without caching, every pass would
    # re-read the CSV from HDFS.
    predictions = predictions.cache()
    total_rows = predictions.count()

    # Rows with a null label or prediction cannot be scored, because they break
    # the metric computations. Drop them and report how many were removed,
    # rather than failing deep inside Spark.
    predictions = predictions.dropna(subset=[LABEL_COL, PREDICTION_COL])
    scored_rows = predictions.count()
    if scored_rows == 0:
        raise ValueError(
            f"No rows with both {LABEL_COL} and {PREDICTION_COL} set; nothing to evaluate."
        )
    if scored_rows < total_rows:
        print(
            f"Warning: dropped {total_rows - scored_rows} of {total_rows} rows "
            f"with a null {LABEL_COL} or {PREDICTION_COL}."
        )

    evaluator = MulticlassClassificationEvaluator(labelCol=LABEL_COL, predictionCol=PREDICTION_COL)

    # Pass the metric as a per-call param override instead of mutating the
    # shared evaluator with setMetricName(). "f1", "weightedPrecision" and
    # "weightedRecall" are averages over classes, weighted by each class's
    # share of the true labels.
    accuracy = evaluator.evaluate(predictions, {evaluator.metricName: "accuracy"})
    f1 = evaluator.evaluate(predictions, {evaluator.metricName: "f1"})
    precision = evaluator.evaluate(predictions, {evaluator.metricName: "weightedPrecision"})
    recall = evaluator.evaluate(predictions, {evaluator.metricName: "weightedRecall"})

    print("\nEvaluation Metrics:")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")

    # MulticlassMetrics expects an RDD of (prediction, label) pairs in that
    # order. select() fixes the column order, and since Row is a tuple,
    # tuple() turns each row into that pair.
    pred_rdd = predictions.select(PREDICTION_COL, LABEL_COL).rdd.map(tuple)
    metrics = MulticlassMetrics(pred_rdd)
    # Rows are actual labels and columns are predicted labels, both ordered by
    # label value ascending.
    conf_matrix = metrics.confusionMatrix().toArray()

    print("\nConfusion Matrix:")
    print(conf_matrix)
finally:
    # Release cluster resources even if loading or evaluation fails.
    spark.stop()


Schema of predictions:
root
 |-- date: date (nullable = true)
 |-- traffic_label: double (nullable = true)
 |-- prediction: double (nullable = true)


Evaluation Metrics:
Accuracy:  0.6099
F1 Score:  0.6022
Precision: 0.6189
Recall:    0.6099


                                                                                


Confusion Matrix:
[[10223. 11506.]
 [ 5474. 16329.]]
