# __Severity prediction__

## üìä Analyse et Pr√©paration des Donn√©es d'Accidents de la Route (USA)

Ce notebook a pour objectif de pr√©parer, nettoyer et transformer un jeu de donn√©es d'accidents de la route aux √âtats-Unis afin de le rendre exploitable pour des t√¢ches de visualisation ou de mod√©lisation (pr√©diction de la gravit√© par exemple).

Nous utilisons **Apache Spark** pour g√©rer efficacement de grands volumes de donn√©es, avec des √©tapes de :
- Chargement
- Nettoyage (valeurs manquantes)
- Transformation des types
- Feature engineering
- Sauvegarde du dataset pr√™t √† l‚Äôemploi


In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

## üì• Chargement du jeu de donn√©es

Nous commen√ßons par charger le jeu de donn√©es brut au format CSV, puis nous affichons un aper√ßu des colonnes et du sch√©ma.

In [2]:
spark = SparkSession.builder \
    .master("local[*]") \
    .appName("accidents") \
    .config("spark.executor.memory", "6g") \
    .config("spark.driver.memory", "6g") \
    .config("spark.driver.maxResultSize", "2g") \
    .config("spark.memory.fraction", "0.6") \
    .config("spark.memory.storageFraction", "0.3") \
    .getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

data = spark.read.csv(f"file:///home/clement/2026/S8/Projet_commun/data/US_Accidents_March23.csv", header=True, inferSchema=True)
data.show(5)

25/06/19 23:17:43 WARN Utils: Your hostname, clement-HVY-WXX9 resolves to a loopback address: 127.0.1.1; using 192.168.1.134 instead (on interface wlp1s0)
25/06/19 23:17:43 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/19 23:17:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

+---+-------+--------+-------------------+-------------------+-----------------+------------------+-------+-------+------------+--------------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------------------+--------------+-------------+-----------+------------+--------------+--------------+---------------+-----------------+-----------------+-------+-----+--------+--------+--------+-------+-------+----------+-------+-----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
| ID| Source|Severity|         Start_Time|           End_Time|        Start_Lat|         Start_Lng|End_Lat|End_Lng|Distance(mi)|         Description|              Street|        City|    County|State|   Zipcode|Country|  Timezone|Airport_Code|  Weather_Timestamp|Temperature(F)|Wind_Chill(F)|Humidity(%)|Pressure(in)|Visibility(mi)|Wind_Direction|Wind_Speed(mph)|Precipitation(in)|Weather_Condition|Ameni

In [3]:
data.printSchema()

root
 |-- ID: string (nullable = true)
 |-- Source: string (nullable = true)
 |-- Severity: integer (nullable = true)
 |-- Start_Time: timestamp (nullable = true)
 |-- End_Time: timestamp (nullable = true)
 |-- Start_Lat: double (nullable = true)
 |-- Start_Lng: double (nullable = true)
 |-- End_Lat: double (nullable = true)
 |-- End_Lng: double (nullable = true)
 |-- Distance(mi): double (nullable = true)
 |-- Description: string (nullable = true)
 |-- Street: string (nullable = true)
 |-- City: string (nullable = true)
 |-- County: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Zipcode: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Timezone: string (nullable = true)
 |-- Airport_Code: string (nullable = true)
 |-- Weather_Timestamp: timestamp (nullable = true)
 |-- Temperature(F): double (nullable = true)
 |-- Wind_Chill(F): double (nullable = true)
 |-- Humidity(%): double (nullable = true)
 |-- Pressure(in): double (nullable = true)
 |-- V

## üßπ Nettoyage des donn√©es

Certaines colonnes contiennent de nombreuses valeurs manquantes. Nous allons :
- Compter le nombre de valeurs manquantes par colonne
- Supprimer celles avec plus de 100 000 valeurs manquantes
- Remplacer les valeurs manquantes restantes dans les colonnes cat√©gorielles par la modalit√© la plus fr√©quente (mode)

In [4]:
print("Nombre de lignes :", data.count())



Nombre de lignes : 7728394


                                                                                

In [5]:
from pyspark.sql.functions import col, sum

print("Nombre de lignes avec des valeurs manquantes :")
data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show()

Nombre de lignes avec des valeurs manquantes :




+---+------+--------+----------+--------+---------+---------+-------+-------+------------+-----------+------+----+------+-----+-------+-------+--------+------------+-----------------+--------------+-------------+-----------+------------+--------------+--------------+---------------+-----------------+-----------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
| ID|Source|Severity|Start_Time|End_Time|Start_Lat|Start_Lng|End_Lat|End_Lng|Distance(mi)|Description|Street|City|County|State|Zipcode|Country|Timezone|Airport_Code|Weather_Timestamp|Temperature(F)|Wind_Chill(F)|Humidity(%)|Pressure(in)|Visibility(mi)|Wind_Direction|Wind_Speed(mph)|Precipitation(in)|Weather_Condition|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset|Civil_Twilight|Nautical_Twil

                                                                                

In [6]:
from pyspark.sql.functions import col, sum, when

missing_counts = data.select([
    sum(when(col(c).isNull(), 1).otherwise(0)).alias(c) for c in data.columns
])

missing_dict = missing_counts.collect()[0].asDict()
columns_with_missing_gt_100k = [col_name for col_name, count in missing_dict.items() if count > 100000]

print("Colonnes avec plus de 100 000 valeurs manquantes :")
print(columns_with_missing_gt_100k)




Colonnes avec plus de 100 000 valeurs manquantes :
['End_Lat', 'End_Lng', 'Weather_Timestamp', 'Temperature(F)', 'Wind_Chill(F)', 'Humidity(%)', 'Pressure(in)', 'Visibility(mi)', 'Wind_Direction', 'Wind_Speed(mph)', 'Precipitation(in)', 'Weather_Condition']


                                                                                

In [7]:
data =  data.drop(*columns_with_missing_gt_100k)
data.show(5)

+---+-------+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------+-----+--------+--------+--------+-------+-------+----------+-------+-----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
| ID| Source|Severity|         Start_Time|           End_Time|        Start_Lat|         Start_Lng|Distance(mi)|         Description|              Street|        City|    County|State|   Zipcode|Country|  Timezone|Airport_Code|Amenity| Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station| Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset|Civil_Twilight|Nautical_Twilight|Astronomical_Twilight|
+---+-------+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+-------------------

On supprimons les colonnes "Description", "ID", "Source" qui ne nous serons pas utiles dans la suite.

In [8]:
data = data.drop("Description", "ID", "Source")
data.show(5)

+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------+-----+--------+--------+--------+-------+-------+----------+-------+-----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|Severity|         Start_Time|           End_Time|        Start_Lat|         Start_Lng|Distance(mi)|              Street|        City|    County|State|   Zipcode|Country|  Timezone|Airport_Code|Amenity| Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station| Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset|Civil_Twilight|Nautical_Twilight|Astronomical_Twilight|
+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------+-----+--------+-

In [9]:
print("Nombre de lignes avec des valeurs manquantes :")
data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show()

Nombre de lignes avec des valeurs manquantes :




+--------+----------+--------+---------+---------+------------+------+----+------+-----+-------+-------+--------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|Severity|Start_Time|End_Time|Start_Lat|Start_Lng|Distance(mi)|Street|City|County|State|Zipcode|Country|Timezone|Airport_Code|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset|Civil_Twilight|Nautical_Twilight|Astronomical_Twilight|
+--------+----------+--------+---------+---------+------------+------+----+------+-----+-------+-------+--------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|       0|         0|       0|       

                                                                                

In [10]:
from pyspark.sql.functions import col, count, when, desc

features_cat = ["Street", "City", "Zipcode", "Timezone", "Airport_Code", "Sunrise_Sunset", "Civil_Twilight", "Nautical_Twilight", "Astronomical_Twilight"]

for feature in features_cat:
    mode_value = data.groupBy(feature).count().orderBy(desc("count")).first()[0]
    data = data.fillna({feature: mode_value})

print("Nombre de lignes avec des valeurs manquantes :")
data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show()

                                                                                

Nombre de lignes avec des valeurs manquantes :




+--------+----------+--------+---------+---------+------------+------+----+------+-----+-------+-------+--------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|Severity|Start_Time|End_Time|Start_Lat|Start_Lng|Distance(mi)|Street|City|County|State|Zipcode|Country|Timezone|Airport_Code|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset|Civil_Twilight|Nautical_Twilight|Astronomical_Twilight|
+--------+----------+--------+---------+---------+------------+------+----+------+-----+-------+-------+--------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|       0|         0|       0|       

                                                                                

## üîß Conversion des types et traitement des bool√©ens / cat√©gorielles

Nous proc√©dons √† :
- La conversion des colonnes bool√©ennes en 0/1
- L'encodage des colonnes binaires et cat√©gorielles avec `StringIndexer`

In [11]:
data.printSchema()

root
 |-- Severity: integer (nullable = true)
 |-- Start_Time: timestamp (nullable = true)
 |-- End_Time: timestamp (nullable = true)
 |-- Start_Lat: double (nullable = true)
 |-- Start_Lng: double (nullable = true)
 |-- Distance(mi): double (nullable = true)
 |-- Street: string (nullable = false)
 |-- City: string (nullable = false)
 |-- County: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Zipcode: string (nullable = false)
 |-- Country: string (nullable = true)
 |-- Timezone: string (nullable = false)
 |-- Airport_Code: string (nullable = false)
 |-- Amenity: boolean (nullable = true)
 |-- Bump: boolean (nullable = true)
 |-- Crossing: boolean (nullable = true)
 |-- Give_Way: boolean (nullable = true)
 |-- Junction: boolean (nullable = true)
 |-- No_Exit: boolean (nullable = true)
 |-- Railway: boolean (nullable = true)
 |-- Roundabout: boolean (nullable = true)
 |-- Station: boolean (nullable = true)
 |-- Stop: boolean (nullable = true)
 |-- Traffic_Calming: bo

In [12]:
boolean_features = ["Amenity", "Bump", "Crossing", "Give_Way", "Junction", "No_Exit", "Railway", "Roundabout", "Station", "Stop", "Traffic_Calming", "Traffic_Signal", "Turning_Loop"]
for feature in boolean_features:
    data = data.withColumn(feature, col(feature).cast("int"))

data.show(5)

+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|Severity|         Start_Time|           End_Time|        Start_Lat|         Start_Lng|Distance(mi)|              Street|        City|    County|State|   Zipcode|Country|  Timezone|Airport_Code|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset|Civil_Twilight|Nautical_Twilight|Astronomical_Twilight|
+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------+----+--------+------

In [13]:
from pyspark.ml.feature import StringIndexer

binary_features = ["Sunrise_Sunset", "Civil_Twilight", "Nautical_Twilight", "Astronomical_Twilight"]
for feature in binary_features:
    indexer = StringIndexer(inputCol=feature, outputCol=feature + "_index")
    data = indexer.fit(data).transform(data)
    data = data.drop(feature)
data.show(5)

                                                                                

+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------------+--------------------+-----------------------+---------------------------+
|Severity|         Start_Time|           End_Time|        Start_Lat|         Start_Lng|Distance(mi)|              Street|        City|    County|State|   Zipcode|Country|  Timezone|Airport_Code|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset_index|Civil_Twilight_index|Nautical_Twilight_index|Astronomical_Twilight_index|
+--------+-------------------+-------------------+-----------------+------------------+------------+--------------------+------------+----------+-----+----------+-------+----

## üïí Feature Engineering temporel

Les colonnes de dates sont d√©compos√©es en sous-composantes (ann√©e, mois, jour, heure, etc.) afin de permettre une meilleure mod√©lisation et visualisation des comportements temporels.
Une nouvelle colonne `duration` est aussi ajout√©e pour quantifier la dur√©e d'un accident en minutes.


In [14]:
from pyspark.sql.functions import year, month, dayofmonth, hour, minute, second

data = data.withColumn("duration", (col("End_Time").cast("long") - col("Start_Time").cast("long")) / 60)  # duration in minutes

data = data.withColumn("Start_Year", year(col("Start_Time"))) \
        .withColumn("Start_Month", month(col("Start_Time"))) \
        .withColumn("Start_Day", dayofmonth(col("Start_Time"))) \
        .withColumn("Start_Hour", hour(col("Start_Time"))) \
        .withColumn("Start_Minute", minute(col("Start_Time"))) \
        .withColumn("Start_Second", second(col("Start_Time"))) \
        .drop("Start_Time")
data = data.withColumn("End_Year", year(col("End_Time"))) \
        .withColumn("End_Month", month(col("End_Time"))) \
        .withColumn("End_Day", dayofmonth(col("End_Time"))) \
        .withColumn("End_Hour", hour(col("End_Time"))) \
        .withColumn("End_Minute", minute(col("End_Time"))) \
        .withColumn("End_Second", second(col("End_Time"))) \
        .drop("End_Time")
data.show(5) 

+--------+-----------------+------------------+------------+--------------------+------------+----------+-----+----------+-------+----------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------------+--------------------+-----------------------+---------------------------+--------+----------+-----------+---------+----------+------------+------------+--------+---------+-------+--------+----------+----------+
|Severity|        Start_Lat|         Start_Lng|Distance(mi)|              Street|        City|    County|State|   Zipcode|Country|  Timezone|Airport_Code|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset_index|Civil_Twilight_index|Nautical_Twilight_index|Astronomical_Twilight_index|duration|Start_Year|Start_Month|Start_Day|Start_Hour|Start_Minute|Start_Second|End_Year|End_Month|End_Day|End_Hour|End

## üó∫ Nettoyage des colonnes g√©ographiques

Certaines colonnes comme `Street`, `City`, `County`, `Zipcode`, etc., sont supprim√©es pour :
- R√©duire la cardinalit√©
- Limiter le bruit

In [15]:
# Drop Street, City, County, Country, Zipcode, Airport_Code, State

data = data.drop("Street", "City", "County", "Country", "Zipcode", "Airport_Code")
data.show(5) 

+--------+-----------------+------------------+------------+-----+----------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------------+--------------------+-----------------------+---------------------------+--------+----------+-----------+---------+----------+------------+------------+--------+---------+-------+--------+----------+----------+
|Severity|        Start_Lat|         Start_Lng|Distance(mi)|State|  Timezone|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset_index|Civil_Twilight_index|Nautical_Twilight_index|Astronomical_Twilight_index|duration|Start_Year|Start_Month|Start_Day|Start_Hour|Start_Minute|Start_Second|End_Year|End_Month|End_Day|End_Hour|End_Minute|End_Second|
+--------+-----------------+------------------+------------+-----+----------+-------+----+--------+--------+--------+-------+-------+-

In [16]:
features = ["Timezone", "State"]

for feature in features:
    indexer = StringIndexer(inputCol=feature, outputCol=feature + "_index")
    data = indexer.fit(data).transform(data)
    data = data.drop(feature)
data.show(5)



+--------+-----------------+------------------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------------+--------------------+-----------------------+---------------------------+--------+----------+-----------+---------+----------+------------+------------+--------+---------+-------+--------+----------+----------+--------------+-----------+
|Severity|        Start_Lat|         Start_Lng|Distance(mi)|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset_index|Civil_Twilight_index|Nautical_Twilight_index|Astronomical_Twilight_index|duration|Start_Year|Start_Month|Start_Day|Start_Hour|Start_Minute|Start_Second|End_Year|End_Month|End_Day|End_Hour|End_Minute|End_Second|Timezone_index|State_index|
+--------+-----------------+------------------+------------+-------+----+--------+--------+--------+-------+------

                                                                                

## üß† Cr√©ation de variables d√©riv√©es

Deux nouvelles colonnes sont cr√©√©es :
- `is_weekend` : indique si l'accident a eu lieu un samedi ou dimanche
- `is_night` : indique si l'accident a eu lieu avant 6h du matin

In [17]:
data = data.withColumn("is_weekend", (col("Start_Day") >= 6).cast("int")) \
            .withColumn("is_night", (col("Start_Hour") < 6).cast("int")) \
            
data.show(5)

+--------+-----------------+------------------+------------+-------+----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------------+--------------------+-----------------------+---------------------------+--------+----------+-----------+---------+----------+------------+------------+--------+---------+-------+--------+----------+----------+--------------+-----------+----------+--------+
|Severity|        Start_Lat|         Start_Lng|Distance(mi)|Amenity|Bump|Crossing|Give_Way|Junction|No_Exit|Railway|Roundabout|Station|Stop|Traffic_Calming|Traffic_Signal|Turning_Loop|Sunrise_Sunset_index|Civil_Twilight_index|Nautical_Twilight_index|Astronomical_Twilight_index|duration|Start_Year|Start_Month|Start_Day|Start_Hour|Start_Minute|Start_Second|End_Year|End_Month|End_Day|End_Hour|End_Minute|End_Second|Timezone_index|State_index|is_weekend|is_night|
+--------+-----------------+------------------+------------+-------+----+-

## Pr√©paration du train set et du test set

In [18]:
from pyspark.ml.feature import VectorAssembler

feature_cols = ["Start_Lat", "Start_Lng", "Distance(mi)", "Amenity", "Bump", "Crossing", "Give_Way",
                "Junction", "No_Exit", "Railway", "Roundabout", "Station", "Stop", "Traffic_Calming", 
                "Traffic_Signal", "Turning_Loop","Sunrise_Sunset_index", "Civil_Twilight_index", 
                "Nautical_Twilight_index", "Astronomical_Twilight_index", "duration", "Start_Year", 
                "Start_Month", "Start_Day", "Start_Hour", "Start_Minute", "Start_Second", "End_Year", 
                "End_Month", "End_Day", "End_Hour", "End_Minute", "End_Second", "Timezone_index", 
                "State_index", "is_weekend", "is_night"]

assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
data = assembler.transform(data)

In [19]:
train_df, test_df = data.randomSplit([0.8, 0.2], seed=42)

print("Train set size:", train_df.count())
print("Test set size:", test_df.count())

                                                                                

Train set size: 6184251




Test set size: 1544143


                                                                                

## Random Forest

In [None]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Entra√Ænement du mod√®le Random Forest
rf = RandomForestClassifier(labelCol="Severity", featuresCol="features", numTrees=50, maxDepth=5, maxBins=64, seed=42)
rf_model = rf.fit(train_df)
rf_pred = rf_model.transform(test_df)
rf_pred.select("features", "Severity", "prediction", "probability").show(5)

# √âvaluation avec diff√©rentes m√©triques
# Accuracy
rf_evaluator_accuracy = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="accuracy")
rf_accuracy = rf_evaluator_accuracy.evaluate(rf_pred)
print("Test set accuracy with Random Forest:", rf_accuracy)

# Precision (macro-average)
rf_evaluator_precision = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedPrecision")
rf_precision = rf_evaluator_precision.evaluate(rf_pred)
print("Test set precision with Random Forest:", rf_precision)

# Recall (macro-average)
rf_evaluator_recall = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedRecall")
rf_recall = rf_evaluator_recall.evaluate(rf_pred)
print("Test set recall with Random Forest:", rf_recall)

# F1-Score (macro-average)
rf_evaluator_f1 = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="f1")
rf_f1 = rf_evaluator_f1.evaluate(rf_pred)
print("Test set F1-score with Random Forest:", rf_f1)

# R√©sum√© des m√©triques
print("\n=== R√©sum√© des m√©triques Random Forest ===")
print(f"Accuracy:  {rf_accuracy:.4f}")
print(f"Precision: {rf_precision:.4f}")
print(f"Recall:    {rf_recall:.4f}")
print(f"F1-Score:  {rf_f1:.4f}")

from pyspark.sql.functions import col

# Affichage de la distribution des pr√©dictions par classe
print("\n=== Distribution des pr√©dictions ===")
rf_pred.groupBy("Severity", "prediction").count().orderBy("Severity", "prediction").show()

# Pour obtenir des m√©triques par classe individuellement
print("\n=== M√©triques par classe ===")
classes = rf_pred.select("Severity").distinct().collect()
for row in classes:
    class_label = row["Severity"]
    print(f"\nClasse {class_label}:")
    
    # Cr√©er des pr√©dictions binaires pour cette classe
    binary_pred = rf_pred.withColumn("binary_label", 
                                   (col("Severity") == class_label).cast("double")) \
                        .withColumn("binary_prediction", 
                                   (col("prediction") == class_label).cast("double"))
    
    # Calculer les m√©triques pour cette classe
    tp = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 1.0)).count()
    fp = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 1.0)).count()
    tn = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 0.0)).count()
    fn = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 0.0)).count()
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall:    {recall:.4f}")
    print(f"  F1-Score:  {f1:.4f}")
    print(f"  Support:   {tp + fn}")

                                                                                

+--------------------+--------+----------+--------------------+
|            features|Severity|prediction|         probability|
+--------------------+--------+----------+--------------------+
|(37,[0,1,5,14,20,...|       1|       2.0|[0.0,0.0101656887...|
|(37,[0,1,20,21,22...|       1|       2.0|[0.0,0.0069492012...|
|(37,[0,1,20,21,22...|       1|       2.0|[0.0,0.0069492012...|
|(37,[0,1,2,11,14,...|       1|       2.0|[0.0,0.0091299548...|
|(37,[0,1,5,14,20,...|       1|       2.0|[0.0,0.0101656887...|
+--------------------+--------+----------+--------------------+
only showing top 5 rows



                                                                                

Test set accuracy with Random Forest: 0.7972137295574309


                                                                                

Test set precision with Random Forest: 0.6355497305948685


                                                                                

Test set recall with Random Forest: 0.7972137295574309


                                                                                

Test set F1-score with Random Forest: 0.7072611566921142

=== R√©sum√© des m√©triques Random Forest ===
Accuracy:  0.7972
Precision: 0.6355
Recall:    0.7972
F1-Score:  0.7073

=== Distribution des pr√©dictions ===


                                                                                

+--------+----------+-------+
|Severity|prediction|  count|
+--------+----------+-------+
|       1|       2.0|  13598|
|       2|       2.0|1231012|
|       3|       2.0| 258740|
|       4|       2.0|  40793|
+--------+----------+-------+


=== M√©triques par classe ===


                                                                                


Classe 1:


                                                                                

  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   13598

Classe 3:


                                                                                

  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   258740

Classe 4:


                                                                                

  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   40793

Classe 2:




  Precision: 0.7972
  Recall:    1.0000
  F1-Score:  0.8872
  Support:   1231012


                                                                                

## Logistic Regression & Multilayer Perceptron Classifier

In [27]:
from pyspark.ml.classification import LogisticRegression, MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import col

# ===== LOGISTIC REGRESSION =====
print("="*50)
print("LOGISTIC REGRESSION")
print("="*50)

# Entra√Ænement du mod√®le Logistic Regression
lr = LogisticRegression(labelCol="Severity", featuresCol="features", maxIter=100, regParam=0.01, elasticNetParam=0.8)
lr_model = lr.fit(train_df)
lr_pred = lr_model.transform(test_df)
lr_pred.select("features", "Severity", "prediction", "probability").show(5)

# √âvaluation Logistic Regression
lr_evaluator_accuracy = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="accuracy")
lr_accuracy = lr_evaluator_accuracy.evaluate(lr_pred)

lr_evaluator_precision = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedPrecision")
lr_precision = lr_evaluator_precision.evaluate(lr_pred)

lr_evaluator_recall = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedRecall")
lr_recall = lr_evaluator_recall.evaluate(lr_pred)

lr_evaluator_f1 = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="f1")
lr_f1 = lr_evaluator_f1.evaluate(lr_pred)

# R√©sum√© des m√©triques Logistic Regression
print("\n=== R√©sum√© des m√©triques Logistic Regression ===")
print(f"Accuracy:  {lr_accuracy:.4f}")
print(f"Precision: {lr_precision:.4f}")
print(f"Recall:    {lr_recall:.4f}")
print(f"F1-Score:  {lr_f1:.4f}")

# Distribution des pr√©dictions Logistic Regression
print("\n=== Distribution des pr√©dictions Logistic Regression ===")
lr_pred.groupBy("Severity", "prediction").count().orderBy("Severity", "prediction").show()

# ===== MULTILAYER PERCEPTRON =====
print("\n" + "="*50)
print("MULTILAYER PERCEPTRON")
print("="*50)

# Entra√Ænement du mod√®le MLP
layers = [len(feature_cols), 128, 64, 32, 5]  # Input layer, two hidden layers, output layer (5 classes)
mlp = MultilayerPerceptronClassifier(labelCol="Severity", featuresCol="features", maxIter=100, layers=layers, blockSize=128, seed=42)
mlp_model = mlp.fit(train_df)
mlp_pred = mlp_model.transform(test_df)
mlp_pred.select("features", "Severity", "prediction", "probability").show(5)

# √âvaluation MLP
mlp_evaluator_accuracy = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="accuracy")
mlp_accuracy = mlp_evaluator_accuracy.evaluate(mlp_pred)

mlp_evaluator_precision = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedPrecision")
mlp_precision = mlp_evaluator_precision.evaluate(mlp_pred)

mlp_evaluator_recall = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedRecall")
mlp_recall = mlp_evaluator_recall.evaluate(mlp_pred)

mlp_evaluator_f1 = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="f1")
mlp_f1 = mlp_evaluator_f1.evaluate(mlp_pred)

# R√©sum√© des m√©triques MLP
print("\n=== R√©sum√© des m√©triques MLP ===")
print(f"Accuracy:  {mlp_accuracy:.4f}")
print(f"Precision: {mlp_precision:.4f}")
print(f"Recall:    {mlp_recall:.4f}")
print(f"F1-Score:  {mlp_f1:.4f}")

# Distribution des pr√©dictions MLP
print("\n=== Distribution des pr√©dictions MLP ===")
mlp_pred.groupBy("Severity", "prediction").count().orderBy("Severity", "prediction").show()

# ===== COMPARAISON DES MOD√àLES =====
print("\n" + "="*60)
print("COMPARAISON DES TROIS MOD√àLES")
print("="*60)

# Tableau comparatif (supposons que rf_accuracy, rf_precision, rf_recall, rf_f1 sont disponibles)
print(f"{'Mod√®le':<20} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1-Score':<10}")
print("-" * 60)
print(f"{'Random Forest':<20} {rf_accuracy:<10.4f} {rf_precision:<10.4f} {rf_recall:<10.4f} {rf_f1:<10.4f}")
print(f"{'Logistic Regression':<20} {lr_accuracy:<10.4f} {lr_precision:<10.4f} {lr_recall:<10.4f} {lr_f1:<10.4f}")
print(f"{'MLP Neural Net':<20} {mlp_accuracy:<10.4f} {mlp_precision:<10.4f} {mlp_recall:<10.4f} {mlp_f1:<10.4f}")

# Identification du meilleur mod√®le
models_performance = {
    'Random Forest': {'accuracy': rf_accuracy, 'f1': rf_f1},
    'Logistic Regression': {'accuracy': lr_accuracy, 'f1': lr_f1},
    'MLP Neural Net': {'accuracy': mlp_accuracy, 'f1': mlp_f1}
}

best_accuracy_model = max(models_performance.items(), key=lambda x: x[1]['accuracy'])
best_f1_model = max(models_performance.items(), key=lambda x: x[1]['f1'])

print(f"\n=== Meilleurs mod√®les ===")
print(f"Meilleure Accuracy: {best_accuracy_model[0]} ({best_accuracy_model[1]['accuracy']:.4f})")
print(f"Meilleur F1-Score:  {best_f1_model[0]} ({best_f1_model[1]['f1']:.4f})")

# ===== M√âTRIQUES PAR CLASSE POUR LOGISTIC REGRESSION =====
def calculate_class_metrics(predictions_df, model_name):
    print(f"\n=== M√©triques par classe - {model_name} ===")
    classes = predictions_df.select("Severity").distinct().collect()
    
    for row in classes:
        class_label = row["Severity"]
        print(f"\nClasse {class_label}:")
        
        # Cr√©er des pr√©dictions binaires pour cette classe
        binary_pred = predictions_df.withColumn("binary_label", 
                                       (col("Severity") == class_label).cast("double")) \
                            .withColumn("binary_prediction", 
                                       (col("prediction") == class_label).cast("double"))
        
        # Calculer les m√©triques pour cette classe
        tp = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 1.0)).count()
        fp = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 1.0)).count()
        tn = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 0.0)).count()
        fn = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 0.0)).count()
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall:    {recall:.4f}")
        print(f"  F1-Score:  {f1:.4f}")
        print(f"  Support:   {tp + fn}")

# Calculer les m√©triques par classe pour chaque mod√®le
calculate_class_metrics(lr_pred, "Logistic Regression")
calculate_class_metrics(mlp_pred, "MLP Neural Network")

LOGISTIC REGRESSION


                                                                                

+--------------------+--------+----------+--------------------+
|            features|Severity|prediction|         probability|
+--------------------+--------+----------+--------------------+
|(37,[0,1,5,14,20,...|       1|       2.0|[1.60170429709093...|
|(37,[0,1,20,21,22...|       1|       2.0|[1.44556405771679...|
|(37,[0,1,20,21,22...|       1|       3.0|[1.17582062223475...|
|(37,[0,1,2,11,14,...|       1|       2.0|[1.76436926235271...|
|(37,[0,1,5,14,20,...|       1|       2.0|[1.63432594728259...|
+--------------------+--------+----------+--------------------+
only showing top 5 rows



                                                                                


=== R√©sum√© des m√©triques Logistic Regression ===
Accuracy:  0.7926
Precision: 0.7040
Recall:    0.7926
F1-Score:  0.7174

=== Distribution des pr√©dictions Logistic Regression ===


                                                                                

+--------+----------+-------+
|Severity|prediction|  count|
+--------+----------+-------+
|       1|       2.0|  13586|
|       1|       3.0|     12|
|       2|       2.0|1213079|
|       2|       3.0|  17879|
|       2|       4.0|     54|
|       3|       2.0| 247874|
|       3|       3.0|  10835|
|       3|       4.0|     31|
|       4|       2.0|  38640|
|       4|       3.0|   2128|
|       4|       4.0|     25|
+--------+----------+-------+


MULTILAYER PERCEPTRON


                                                                                

+--------------------+--------+----------+--------------------+
|            features|Severity|prediction|         probability|
+--------------------+--------+----------+--------------------+
|(37,[0,1,5,14,20,...|       1|       2.0|[8.91477607081901...|
|(37,[0,1,20,21,22...|       1|       2.0|[4.10347228082112...|
|(37,[0,1,20,21,22...|       1|       2.0|[4.10434281760650...|
|(37,[0,1,2,11,14,...|       1|       2.0|[4.12301152839504...|
|(37,[0,1,5,14,20,...|       1|       2.0|[4.43935322644836...|
+--------------------+--------+----------+--------------------+
only showing top 5 rows



                                                                                


=== R√©sum√© des m√©triques MLP ===
Accuracy:  0.7972
Precision: 0.6355
Recall:    0.7972
F1-Score:  0.7073

=== Distribution des pr√©dictions MLP ===


                                                                                

+--------+----------+-------+
|Severity|prediction|  count|
+--------+----------+-------+
|       1|       2.0|  13598|
|       2|       2.0|1231012|
|       3|       2.0| 258740|
|       4|       2.0|  40793|
+--------+----------+-------+


COMPARAISON DES TROIS MOD√àLES
Mod√®le               Accuracy   Precision  Recall     F1-Score  
------------------------------------------------------------
Random Forest        0.7972     0.6355     0.7972     0.7073    
Logistic Regression  0.7926     0.7040     0.7926     0.7174    
MLP Neural Net       0.7972     0.6355     0.7972     0.7073    

=== Meilleurs mod√®les ===
Meilleure Accuracy: Random Forest (0.7972)
Meilleur F1-Score:  Logistic Regression (0.7174)

=== M√©triques par classe - Logistic Regression ===


                                                                                


Classe 1:


                                                                                

  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   13598

Classe 3:


                                                                                

  Precision: 0.3512
  Recall:    0.0419
  F1-Score:  0.0748
  Support:   258740

Classe 4:


                                                                                

  Precision: 0.2273
  Recall:    0.0006
  F1-Score:  0.0012
  Support:   40793

Classe 2:


                                                                                

  Precision: 0.8017
  Recall:    0.9854
  F1-Score:  0.8841
  Support:   1231012

=== M√©triques par classe - MLP Neural Network ===


                                                                                


Classe 1:


                                                                                

  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   13598

Classe 3:


                                                                                

  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   258740

Classe 4:


                                                                                

  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   40793

Classe 2:




  Precision: 0.7972
  Recall:    1.0000
  F1-Score:  0.8872
  Support:   1231012


                                                                                

## Grid Search Logistic Regression

In [28]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline
from pyspark.sql.functions import col
from pyspark.storagelevel import StorageLevel
import time

# 1. Pr√©paration des donn√©es
data = data.withColumn("label", col("Severity") - 1)

# Split en train / validation / test
train_df, temp_df = data.randomSplit([0.7, 0.3], seed=42)
val_df, test_df = temp_df.randomSplit([0.5, 0.5], seed=42)

print(f"Train set size: {train_df.count()}")
print(f"Validation set size: {val_df.count()}")
print(f"Test set size: {test_df.count()}")

# Cache les datasets (utilise MEMORY_AND_DISK pour √©viter l'OutOfMemoryError)
train_df.persist(StorageLevel.MEMORY_AND_DISK)
val_df.persist(StorageLevel.MEMORY_AND_DISK)
test_df.persist(StorageLevel.MEMORY_AND_DISK)

# 2. Logistic Regression Model
lr = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    maxIter=100,
    elasticNetParam=0.0,  # 0 = L2 (Ridge), 1 = L1 (Lasso), between = ElasticNet
)

# 3. Evaluator
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy"
)

# 4. Param Grid (modifiable selon complexit√©)
paramGrid = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.0, 0.01, 0.1, 0.5, 1.0]) \
    .addGrid(lr.elasticNetParam, [0.0, 0.25, 0.5, 0.75, 1.0]) \
    .addGrid(lr.maxIter, [50, 100, 200]) \
    .build()

print(f"Total parameter combinations: {len(paramGrid)}")

# 5. CrossValidator
crossval = CrossValidator(
    estimator=lr,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=3,
    seed=42,
    parallelism=2  # Ajuste selon ton CPU
)

print("Starting grid search...")
start_time = time.time()

# 6. Fit du mod√®le
cv_model = crossval.fit(train_df)

end_time = time.time()
print(f"Grid search completed in {(end_time - start_time)/60:.2f} minutes")

# 7. Meilleur mod√®le
best_model = cv_model.bestModel
print("\n=== BEST PARAMETERS ===")
print(f"Reg Param: {best_model.getRegParam()}")
print(f"ElasticNet Param: {best_model.getElasticNetParam()}")

# 8. √âvaluation sur validation set
val_predictions = best_model.transform(val_df)
val_accuracy = evaluator.evaluate(val_predictions)
print(f"\nValidation Accuracy: {val_accuracy:.4f}")

# Autres m√©triques
evaluator_f1 = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
evaluator_precision = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedPrecision")
evaluator_recall = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedRecall")

val_f1 = evaluator_f1.evaluate(val_predictions)
val_precision = evaluator_precision.evaluate(val_predictions)
val_recall = evaluator_recall.evaluate(val_predictions)

print(f"Validation F1 Score: {val_f1:.4f}")
print(f"Validation Precision: {val_precision:.4f}")
print(f"Validation Recall: {val_recall:.4f}")

# 9. √âvaluation finale sur le test set
test_predictions = best_model.transform(test_df)
test_accuracy = evaluator.evaluate(test_predictions)
test_f1 = evaluator_f1.evaluate(test_predictions)
test_precision = evaluator_precision.evaluate(test_predictions)
test_recall = evaluator_recall.evaluate(test_predictions)

print("\n=== FINAL TEST RESULTS ===")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")

print("\nGrid search complete!")

                                                                                

Train set size: 5410682


                                                                                

Validation set size: 1160026


                                                                                

Test set size: 1157686
Total parameter combinations: 75
Starting grid search...


25/06/19 08:39:01 ERROR StrongWolfeLineSearch: Encountered bad values in function evaluation. Decreasing step size to 0.5
25/06/19 08:39:02 ERROR StrongWolfeLineSearch: Encountered bad values in function evaluation. Decreasing step size to 0.25
25/06/19 08:39:03 ERROR StrongWolfeLineSearch: Encountered bad values in function evaluation. Decreasing step size to 0.5
25/06/19 08:39:04 ERROR StrongWolfeLineSearch: Encountered bad values in function evaluation. Decreasing step size to 0.25
25/06/19 08:39:04 ERROR StrongWolfeLineSearch: Encountered bad values in function evaluation. Decreasing step size to 0.125
25/06/19 08:39:05 ERROR StrongWolfeLineSearch: Encountered bad values in function evaluation. Decreasing step size to 0.0625
25/06/19 08:39:06 ERROR StrongWolfeLineSearch: Encountered bad values in function evaluation. Decreasing step size to 0.0703125
25/06/19 08:39:13 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed
25/06/19 08:3

Grid search completed in 70.23 minutes

=== BEST PARAMETERS ===
Reg Param: 0.1
ElasticNet Param: 0.25


                                                                                


Validation Accuracy: 0.7965


                                                                                

Validation F1 Score: 0.7063
Validation Precision: 0.6344
Validation Recall: 0.7965


                                                                                


=== FINAL TEST RESULTS ===
Test Accuracy: 0.7976
Test F1 Score: 0.7078
Test Precision: 0.6362
Test Recall: 0.7976

Grid search complete!


## Recherche d'hyperparam√®tres Random Forest

In [20]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Entra√Ænement du mod√®le Random Forest
rf = RandomForestClassifier(labelCol="Severity", featuresCol="features", numTrees=50, maxDepth=10, maxBins=64, seed=42)
rf_model = rf.fit(train_df)
rf_pred = rf_model.transform(test_df)
rf_pred.select("features", "Severity", "prediction", "probability").show(5)

# √âvaluation avec diff√©rentes m√©triques
# Accuracy
rf_evaluator_accuracy = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="accuracy")
rf_accuracy = rf_evaluator_accuracy.evaluate(rf_pred)
print("Test set accuracy with Random Forest:", rf_accuracy)

# Precision (macro-average)
rf_evaluator_precision = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedPrecision")
rf_precision = rf_evaluator_precision.evaluate(rf_pred)
print("Test set precision with Random Forest:", rf_precision)

# Recall (macro-average)
rf_evaluator_recall = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedRecall")
rf_recall = rf_evaluator_recall.evaluate(rf_pred)
print("Test set recall with Random Forest:", rf_recall)

# F1-Score (macro-average)
rf_evaluator_f1 = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="f1")
rf_f1 = rf_evaluator_f1.evaluate(rf_pred)
print("Test set F1-score with Random Forest:", rf_f1)

# R√©sum√© des m√©triques
print("\n=== R√©sum√© des m√©triques Random Forest ===")
print(f"Accuracy:  {rf_accuracy:.4f}")
print(f"Precision: {rf_precision:.4f}")
print(f"Recall:    {rf_recall:.4f}")
print(f"F1-Score:  {rf_f1:.4f}")

from pyspark.sql.functions import col

# Affichage de la distribution des pr√©dictions par classe
print("\n=== Distribution des pr√©dictions ===")
rf_pred.groupBy("Severity", "prediction").count().orderBy("Severity", "prediction").show()

# Pour obtenir des m√©triques par classe individuellement
print("\n=== M√©triques par classe ===")
classes = rf_pred.select("Severity").distinct().collect()
for row in classes:
    class_label = row["Severity"]
    print(f"\nClasse {class_label}:")
    
    # Cr√©er des pr√©dictions binaires pour cette classe
    binary_pred = rf_pred.withColumn("binary_label", 
                                   (col("Severity") == class_label).cast("double")) \
                        .withColumn("binary_prediction", 
                                   (col("prediction") == class_label).cast("double"))
    
    # Calculer les m√©triques pour cette classe
    tp = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 1.0)).count()
    fp = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 1.0)).count()
    tn = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 0.0)).count()
    fn = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 0.0)).count()
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall:    {recall:.4f}")
    print(f"  F1-Score:  {f1:.4f}")
    print(f"  Support:   {tp + fn}")

                                                                                

+--------------------+--------+----------+--------------------+
|            features|Severity|prediction|         probability|
+--------------------+--------+----------+--------------------+
|(37,[0,1,5,14,20,...|       1|       2.0|[0.0,0.0036939006...|
|(37,[0,1,20,21,22...|       1|       2.0|[0.0,0.0026071086...|
|(37,[0,1,20,21,22...|       1|       2.0|[0.0,0.0033414416...|
|(37,[0,1,2,11,14,...|       1|       2.0|[0.0,0.0044547536...|
|(37,[0,1,5,14,20,...|       1|       2.0|[0.0,0.0060030378...|
+--------------------+--------+----------+--------------------+
only showing top 5 rows



                                                                                

Test set accuracy with Random Forest: 0.8152075293544704


                                                                                

Test set precision with Random Forest: 0.8034139370473914


                                                                                

Test set recall with Random Forest: 0.8152075293544704


                                                                                

Test set F1-score with Random Forest: 0.7568436219036964

=== R√©sum√© des m√©triques Random Forest ===
Accuracy:  0.8152
Precision: 0.8034
Recall:    0.8152
F1-Score:  0.7568

=== Distribution des pr√©dictions ===


                                                                                

+--------+----------+-------+
|Severity|prediction|  count|
+--------+----------+-------+
|       1|       1.0|    209|
|       1|       2.0|  13159|
|       1|       3.0|    230|
|       2|       1.0|     11|
|       2|       2.0|1219455|
|       2|       3.0|  11418|
|       2|       4.0|    128|
|       3|       1.0|      6|
|       3|       2.0| 220102|
|       3|       3.0|  38460|
|       3|       4.0|    172|
|       4|       2.0|  38745|
|       4|       3.0|   1375|
|       4|       4.0|    673|
+--------+----------+-------+


=== M√©triques par classe ===


                                                                                


Classe 1:


                                                                                

  Precision: 0.9248
  Recall:    0.0154
  F1-Score:  0.0302
  Support:   13598

Classe 3:


                                                                                

  Precision: 0.7470
  Recall:    0.1486
  F1-Score:  0.2480
  Support:   258740

Classe 4:


                                                                                

  Precision: 0.6917
  Recall:    0.0165
  F1-Score:  0.0322
  Support:   40793

Classe 2:




  Precision: 0.8176
  Recall:    0.9906
  F1-Score:  0.8958
  Support:   1231012


                                                                                

In [21]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Entra√Ænement du mod√®le Random Forest
rf = RandomForestClassifier(labelCol="Severity", featuresCol="features", numTrees=50, maxDepth=10, maxBins=128, seed=42)
rf_model = rf.fit(train_df)
rf_pred = rf_model.transform(test_df)
rf_pred.select("features", "Severity", "prediction", "probability").show(5)

# √âvaluation avec diff√©rentes m√©triques
# Accuracy
rf_evaluator_accuracy = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="accuracy")
rf_accuracy = rf_evaluator_accuracy.evaluate(rf_pred)
print("Test set accuracy with Random Forest:", rf_accuracy)

# Precision (macro-average)
rf_evaluator_precision = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedPrecision")
rf_precision = rf_evaluator_precision.evaluate(rf_pred)
print("Test set precision with Random Forest:", rf_precision)

# Recall (macro-average)
rf_evaluator_recall = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="weightedRecall")
rf_recall = rf_evaluator_recall.evaluate(rf_pred)
print("Test set recall with Random Forest:", rf_recall)

# F1-Score (macro-average)
rf_evaluator_f1 = MulticlassClassificationEvaluator(labelCol="Severity", predictionCol="prediction", metricName="f1")
rf_f1 = rf_evaluator_f1.evaluate(rf_pred)
print("Test set F1-score with Random Forest:", rf_f1)

# R√©sum√© des m√©triques
print("\n=== R√©sum√© des m√©triques Random Forest ===")
print(f"Accuracy:  {rf_accuracy:.4f}")
print(f"Precision: {rf_precision:.4f}")
print(f"Recall:    {rf_recall:.4f}")
print(f"F1-Score:  {rf_f1:.4f}")

from pyspark.sql.functions import col

# Affichage de la distribution des pr√©dictions par classe
print("\n=== Distribution des pr√©dictions ===")
rf_pred.groupBy("Severity", "prediction").count().orderBy("Severity", "prediction").show()

# Pour obtenir des m√©triques par classe individuellement
print("\n=== M√©triques par classe ===")
classes = rf_pred.select("Severity").distinct().collect()
for row in classes:
    class_label = row["Severity"]
    print(f"\nClasse {class_label}:")
    
    # Cr√©er des pr√©dictions binaires pour cette classe
    binary_pred = rf_pred.withColumn("binary_label", 
                                   (col("Severity") == class_label).cast("double")) \
                        .withColumn("binary_prediction", 
                                   (col("prediction") == class_label).cast("double"))
    
    # Calculer les m√©triques pour cette classe
    tp = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 1.0)).count()
    fp = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 1.0)).count()
    tn = binary_pred.filter((col("binary_label") == 0.0) & (col("binary_prediction") == 0.0)).count()
    fn = binary_pred.filter((col("binary_label") == 1.0) & (col("binary_prediction") == 0.0)).count()
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall:    {recall:.4f}")
    print(f"  F1-Score:  {f1:.4f}")
    print(f"  Support:   {tp + fn}")

                                                                                

+--------------------+--------+----------+--------------------+
|            features|Severity|prediction|         probability|
+--------------------+--------+----------+--------------------+
|(37,[0,1,5,14,20,...|       1|       2.0|[0.0,0.0044409988...|
|(37,[0,1,20,21,22...|       1|       2.0|[0.0,0.0024579079...|
|(37,[0,1,20,21,22...|       1|       2.0|[0.0,0.0034302661...|
|(37,[0,1,2,11,14,...|       1|       2.0|[0.0,0.0055644805...|
|(37,[0,1,5,14,20,...|       1|       2.0|[0.0,0.0044348316...|
+--------------------+--------+----------+--------------------+
only showing top 5 rows



                                                                                

Test set accuracy with Random Forest: 0.8161633993742807


                                                                                

Test set precision with Random Forest: 0.8029935660074855


                                                                                

Test set recall with Random Forest: 0.8161633993742807


                                                                                

Test set F1-score with Random Forest: 0.7602770588550031

=== R√©sum√© des m√©triques Random Forest ===
Accuracy:  0.8162
Precision: 0.8030
Recall:    0.8162
F1-Score:  0.7603

=== Distribution des pr√©dictions ===


                                                                                

+--------+----------+-------+
|Severity|prediction|  count|
+--------+----------+-------+
|       1|       1.0|    208|
|       1|       2.0|  13195|
|       1|       3.0|    195|
|       2|       1.0|      5|
|       2|       2.0|1217167|
|       2|       3.0|  13733|
|       2|       4.0|    107|
|       3|       1.0|      6|
|       3|       2.0| 216304|
|       3|       3.0|  42294|
|       3|       4.0|    136|
|       4|       1.0|      1|
|       4|       2.0|  38557|
|       4|       3.0|   1631|
|       4|       4.0|    604|
+--------+----------+-------+


=== M√©triques par classe ===


                                                                                


Classe 1:


                                                                                

  Precision: 0.9455
  Recall:    0.0153
  F1-Score:  0.0301
  Support:   13598

Classe 3:


                                                                                

  Precision: 0.7311
  Recall:    0.1635
  F1-Score:  0.2672
  Support:   258740

Classe 4:


                                                                                

  Precision: 0.7131
  Recall:    0.0148
  F1-Score:  0.0290
  Support:   40793

Classe 2:




  Precision: 0.8195
  Recall:    0.9888
  F1-Score:  0.8962
  Support:   1231012


                                                                                

# üìà R√©sultats des Mod√®les de Pr√©diction

Apr√®s avoir pr√©par√© les donn√©es, nous avons entra√Æn√© plusieurs mod√®les de machine learning pour pr√©dire la gravit√© des accidents (`Severity`). Voici un r√©sum√© des performances obtenues :

### üîç M√©triques d'√©valuation
- **Accuracy (Pr√©cision globale)** : mesure la proportion de bonnes pr√©dictions.
- **F1-Score** : √©quilibre entre pr√©cision et rappel pour les classes d√©s√©quilibr√©es.
- **Matrice de confusion** : permet d'analyser les erreurs par classe.

### üìä Mod√®les test√©s
- `Logistic Regression`
- `Random Forest`
- `Multilayer Perceptron Classifier`

### üèÜ Observations
- Le mod√®le **Random Forest** obtient g√©n√©ralement la meilleure pr√©cision et un bon compromis entre biais et variance.
- **Logistic Regression** est rapide mais montre des limites sur des donn√©es complexes ou d√©s√©quilibr√©es.
- Le mod√®le **Multilayer Perceptron Classifier** offre des r√©sultats comparables √† Random Forest mais avec un temps d'entra√Ænement plus √©lev√©.

### ‚ö†Ô∏è Remarques
- La classe majoritaire est souvent sur-repr√©sent√©e, ce qui peut biaiser les m√©triques.
- Il peut √™tre utile d'ajuster les poids de classe ou d'utiliser un **√©chantillonnage √©quilibr√©**.
