In [2]:
from pyspark.sql import SparkSession

# Créer une session Spark
spark = SparkSession.builder.appName("Mushroom Classification").getOrCreate()

# Charger le fichier CSV
df = spark.read.csv("mushrooms.csv", header=True, inferSchema=True)
df.show()

+-----+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+---------+----------+-----------+---------+-----------------+----------+-------+
|class|cap-shape|cap-surface|cap-color|bruises|odor|gill-attachment|gill-spacing|gill-size|gill-color|stalk-shape|stalk-root|stalk-surface-above-ring|stalk-surface-below-ring|stalk-color-above-ring|stalk-color-below-ring|veil-type|veil-color|ring-number|ring-type|spore-print-color|population|habitat|
+-----+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+---------+----------+-----------+---------+-----------------+----------+-------+
|    p|        x|          s|        n|      t|   p|              f|           c|        n|   

In [3]:
from pyspark.ml.feature import StringIndexer

# Indexer la colonne de classe (venimeux ou non)
indexer = StringIndexer(inputCol="class", outputCol="label")
df = indexer.fit(df).transform(df)

# Indexer les autres colonnes catégorielles
categorical_cols = df.columns[1:]  # toutes les colonnes sauf 'class' et 'label'
for col in categorical_cols:
    indexer = StringIndexer(inputCol=col, outputCol=col + "_index")
    df = indexer.fit(df).transform(df)

# Créer un nouveau DataFrame avec les colonnes indexées
indexed_columns = [col + "_index" for col in categorical_cols] + ["label"]
df_indexed = df.select(indexed_columns)
df_indexed.show()

+---------------+-----------------+---------------+-------------+----------+---------------------+------------------+---------------+----------------+-----------------+----------------+------------------------------+------------------------------+----------------------------+----------------------------+---------------+----------------+-----------------+---------------+-----------------------+----------------+-------------+-----------+-----+
|cap-shape_index|cap-surface_index|cap-color_index|bruises_index|odor_index|gill-attachment_index|gill-spacing_index|gill-size_index|gill-color_index|stalk-shape_index|stalk-root_index|stalk-surface-above-ring_index|stalk-surface-below-ring_index|stalk-color-above-ring_index|stalk-color-below-ring_index|veil-type_index|veil-color_index|ring-number_index|ring-type_index|spore-print-color_index|population_index|habitat_index|label_index|label|
+---------------+-----------------+---------------+-------------+----------+---------------------+----------

In [4]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(inputCols=indexed_columns[:-1], outputCol="features")
data = assembler.transform(df_indexed)

# Séparer les données en ensembles d'entraînement et de test
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

In [12]:
from pyspark.ml.classification import RandomForestClassifier

rf = RandomForestClassifier(labelCol="label", featuresCol="features")
rf_model = rf.fit(train_data)

In [10]:
predictions = rf_model.transform(test_data)
predictions.select("features", "label", "prediction").show()

+--------------------+-----+----------+
|            features|label|prediction|
+--------------------+-----+----------+
|(23,[7,8,9,10,11,...|  1.0|       1.0|
|(23,[6,7,8,9,11,1...|  0.0|       0.0|
|(23,[4,7,10,18,22...|  1.0|       1.0|
|(23,[4,7,10,14,18...|  1.0|       1.0|
|(23,[4,7,10,13,14...|  1.0|       1.0|
|(23,[4,7,10,12,14...|  1.0|       1.0|
|(23,[4,7,10,12,13...|  1.0|       1.0|
|(23,[4,7,10,11,14...|  1.0|       1.0|
|(23,[4,7,10,11,12...|  1.0|       1.0|
|(23,[4,7,10,11,12...|  1.0|       1.0|
|(23,[4,7,10,11,12...|  1.0|       1.0|
|(23,[4,7,10,11,12...|  1.0|       1.0|
|(23,[4,7,10,11,12...|  1.0|       1.0|
|(23,[4,7,10,11,12...|  1.0|       1.0|
|(23,[4,7,10,13,18...|  1.0|       1.0|
|(23,[4,7,10,12,18...|  1.0|       1.0|
|(23,[4,7,10,12,13...|  1.0|       1.0|
|(23,[4,7,10,11,12...|  1.0|       1.0|
|(23,[4,7,10,13,18...|  1.0|       1.0|
|(23,[4,7,10,12,18...|  1.0|       1.0|
+--------------------+-----+----------+
only showing top 20 rows



In [8]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print(f"Accuracy: {accuracy}")

Accuracy: 1.0
