In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, isnull, expr
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans
from pyspark.sql.functions import when
from pyspark.ml.stat import Correlation
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import functions as F


# Utilisation de la base de données Hive, et charge les données de la table

In [2]:
spark = SparkSession.builder \
    .appName("Categorie") \
    .config("spark.hadoop.hive.metastore.uris", "thrift://hive-metastore:9083") \
    .enableHiveSupport() \
    .getOrCreate()


In [3]:
spark.sql("USE concessionnaire")
client_immat_df = spark.sql("SELECT * FROM client_vehicules_clusters_cleaned")

# Afficher un aperçu des données

In [4]:
client_immat_df.show(5)
client_immat_df.printSchema()


+---------------+---------+--------+--------+-------+----+---------------+-------------------+---------------+--------+
|immatriculation|puissance|longueur|nbPortes|   prix| age|nbenfantacharge|situationfamilliale|deuxiemevoiture|category|
+---------------+---------+--------+--------+-------+----+---------------+-------------------+---------------+--------+
|        0 NQ 98|     55.0|     1.0|     3.0| 8540.0|50.0|            0.0|                1.0|              1|       1|
|        0 PW 98|     75.0|     1.0|     5.0|13750.0|58.0|            0.0|                0.0|              0|       0|
|        1 AK 80|     75.0|     1.0|     5.0|18310.0|35.0|            0.0|                0.0|              0|       7|
|     1000 KN 59|     75.0|     1.0|     5.0|18310.0|24.0|            0.0|                0.0|              0|       7|
|     1001 CT 14|    150.0|     3.0|     5.0|27020.0|68.0|            2.0|                1.0|              0|       5|
+---------------+---------+--------+----

In [5]:
missing_counts = client_immat_df.select([isnull(col(c)).alias(c) for c in client_immat_df.columns]).groupby().sum().show()

++
||
++
||
++



# Préparation des données

In [6]:
training_data = client_immat_df.drop("immatriculation")

In [7]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=[
        "age", "nbenfantacharge", "situationfamilliale", "deuxiemevoiture"
    ], 
    outputCol="features"
)

assembled_data = assembler.transform(training_data)

In [8]:
train_data, test_data = assembled_data.randomSplit([0.8, 0.2], seed=42)

In [9]:
from pyspark.ml.classification import RandomForestClassifier

# Initialiser le modèle
rf = RandomForestClassifier(
    featuresCol="features", 
    labelCol="category", 
    numTrees=100, 
    maxDepth=10, 
    seed=42
)

# Entraîner le modèle
rf_model = rf.fit(train_data)


In [10]:
predictions = rf_model.transform(test_data)

In [11]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(
    labelCol="category", 
    predictionCol="prediction", 
    metricName="accuracy"
)

accuracy = evaluator.evaluate(predictions)
print(f"Précision du modèle : {accuracy}")

Précision du modèle : 0.9894144720020455
