In [2]:
from pyspark.sql import SparkSession

# Créer une session Spark
spark = SparkSession.builder \
    .appName("Mushroom Classification") \
    .getOrCreate()


In [5]:
# Remplace le chemin par celui du fichier CSV
data = spark.read.csv("mushrooms.csv", header=True, inferSchema=True)
data.show(5)  # Afficher les 5 premières lignes pour vérifier le chargement


+------+----------+------------+----------+--------+-----+----------------+-------------+----------+-----------+------------+-----------+-------------------------+-------------------------+-----------------------+-----------------------+----------+-----------+------------+----------+------------------+-----------+-------+
|class |cap-shape |cap-surface |cap-color |bruises |odor |gill-attachment |gill-spacing |gill-size |gill-color |stalk-shape |stalk-root |stalk-surface-above-ring |stalk-surface-below-ring |stalk-color-above-ring |stalk-color-below-ring |veil-type |veil-color |ring-number |ring-type |spore-print-color |population |habitat|
+------+----------+------------+----------+--------+-----+----------------+-------------+----------+-----------+------------+-----------+-------------------------+-------------------------+-----------------------+-----------------------+----------+-----------+------------+----------+------------------+-----------+-------+
|p     |x         |s        

In [8]:
from pyspark.ml.feature import StringIndexer
# Retire l'espace de fin pour chaque nom de colonne
for col_name in data.columns:
    data = data.withColumnRenamed(col_name, col_name.strip())

# Encode la colonne 'class' en numérique pour créer la colonne 'label'
indexer = StringIndexer(inputCol="class", outputCol="label")
data = indexer.fit(data).transform(data)
data.select("class", "label").show(5)  # Vérifie l'encodage


+------+-----+
| class|label|
+------+-----+
|p     |  1.0|
|e     |  0.0|
|e     |  0.0|
|p     |  1.0|
|e     |  0.0|
+------+-----+
only showing top 5 rows



In [11]:
from pyspark.ml.feature import StringIndexer

# Liste des colonnes catégorielles à encoder (toutes sauf "label" et "features")
categorical_columns = [col for col in data.columns if col not in ["class", "label"]]

# Appliquer StringIndexer sur chaque colonne catégorielle
for col_name in categorical_columns:
    indexer = StringIndexer(inputCol=col_name, outputCol=f"{col_name}_index")
    data = indexer.fit(data).transform(data)

# Vérifie l’encodage
data.select([f"{col}_index" for col in categorical_columns] + ["label"]).show(5)


+---------------+-----------------+---------------+-------------+----------+---------------------+------------------+---------------+----------------+-----------------+----------------+------------------------------+------------------------------+----------------------------+----------------------------+---------------+----------------+-----------------+---------------+-----------------------+----------------+-------------+-----+
|cap-shape_index|cap-surface_index|cap-color_index|bruises_index|odor_index|gill-attachment_index|gill-spacing_index|gill-size_index|gill-color_index|stalk-shape_index|stalk-root_index|stalk-surface-above-ring_index|stalk-surface-below-ring_index|stalk-color-above-ring_index|stalk-color-below-ring_index|veil-type_index|veil-color_index|ring-number_index|ring-type_index|spore-print-color_index|population_index|habitat_index|label|
+---------------+-----------------+---------------+-------------+----------+---------------------+------------------+---------------

In [12]:
from pyspark.ml.feature import VectorAssembler

# Utiliser les colonnes encodées dans VectorAssembler
assembler = VectorAssembler(inputCols=[f"{col}_index" for col in categorical_columns], outputCol="features")
data = assembler.transform(data)
data.select("features", "label").show(5)  # Vérifie l'assemblage des caractéristiques


+--------------------+-----+
|            features|label|
+--------------------+-----+
|(22,[1,3,4,7,8,9,...|  1.0|
|(22,[1,2,3,4,8,9,...|  0.0|
|(22,[0,1,2,3,4,8,...|  0.0|
|(22,[2,3,4,7,8,9,...|  1.0|
|(22,[1,2,6,8,10,1...|  0.0|
+--------------------+-----+
only showing top 5 rows



In [13]:
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)


In [14]:
from pyspark.ml.classification import LogisticRegression

# Initialiser et entraîner le modèle de régression logistique
lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(train_data)


In [16]:
# Prédictions sur les données de test
predictions = model.transform(test_data)


In [21]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Précision
evaluator_precision = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="precisionByLabel")
precision = evaluator_precision.evaluate(predictions)

# Rappel
evaluator_recall = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="recallByLabel")
recall = evaluator_recall.evaluate(predictions)

# F1-score
evaluator_f1 = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
f1_score = evaluator_f1.evaluate(predictions)





In [22]:
# Tableau de contingence pour observer FP et VP
predictions.groupBy("label", "prediction").count().show()


+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0|  748|
|  0.0|       1.0|    6|
|  1.0|       0.0|   15|
|  0.0|       0.0|  782|
+-----+----------+-----+



In [24]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# AUC pour la courbe ROC
evaluator_roc = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
auc = evaluator_roc.evaluate(predictions)

from pyspark.ml.evaluation import BinaryClassificationEvaluator

# AUC pour la courbe ROC
evaluator_roc = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
auc = evaluator_roc.evaluate(predictions)

print(f"Précision : {precision:.2f}")
print(f"Rappel : {recall:.2f}")
print(f"F1-score : {f1_score:.2f}")
print(f"AUC de la courbe ROC : {auc:.2f}")


Précision : 0.98
Rappel : 0.99
F1-score : 0.99
AUC de la courbe ROC : 1.00


In [25]:
# Tableau de contingence pour observer FP et VP
predictions.groupBy("label", "prediction").count().show()


+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0|  748|
|  0.0|       1.0|    6|
|  1.0|       0.0|   15|
|  0.0|       0.0|  782|
+-----+----------+-----+



In [28]:
model.save("mushroom_classification_model")

In [27]:
!pip install Flask


Collecting Flask
  Downloading flask-3.0.3-py3-none-any.whl.metadata (3.2 kB)
Collecting Werkzeug>=3.0.0 (from Flask)
  Downloading werkzeug-3.0.6-py3-none-any.whl.metadata (3.7 kB)
Collecting itsdangerous>=2.1.2 (from Flask)
  Downloading itsdangerous-2.2.0-py3-none-any.whl.metadata (1.9 kB)
Downloading flask-3.0.3-py3-none-any.whl (101 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hDownloading itsdangerous-2.2.0-py3-none-any.whl (16 kB)
Downloading werkzeug-3.0.6-py3-none-any.whl (227 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m228.0/228.0 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: Werkzeug, itsdangerous, Flask
Successfully installed Flask-3.0.3 Werkzeug-3.0.6 itsdangerous-2.2.0


In [29]:
# Importer les bibliothèques
from flask import Flask, request, jsonify
from pyspark.ml.classification import LogisticRegressionModel
from pyspark.ml.linalg import Vectors
import json
from werkzeug.serving import make_server
import threading

# Charger le modèle PySpark
model = LogisticRegressionModel.load("mushroom_classification_model")

# Initialiser l’application Flask
app = Flask(__name__)

# Créer le point d’entrée pour les prédictions
@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Recevoir les données JSON du champignon à classifier
        data = request.json
        # Convertir les caractéristiques en vecteur
        features = Vectors.dense(data["features"])
        
        # Faire la prédiction
        prediction = model.predict(features)
        
        # Retourner la prédiction
        return jsonify({"prediction": int(prediction)})
    except Exception as e:
        return jsonify({"error": str(e)})

# Lancer le serveur Flask dans un thread séparé pour le maintenir actif dans Jupyter
class FlaskThread(threading.Thread):
    def __init__(self, app):
        threading.Thread.__init__(self)
        self.server = make_server('0.0.0.0', 5000, app)
        self.ctx = app.app_context()
        self.ctx.push()

    def run(self):
        print("API démarrée sur http://localhost:5000")
        self.server.serve_forever()

    def shutdown(self):
        self.server.shutdown()

# Créer et démarrer le thread Flask
flask_thread = FlaskThread(app)
flask_thread.start()


API démarrée sur http://localhost:5000


In [32]:
import requests

# Exemple de caractéristiques (remplace par des valeurs réelles du champignon)
features = [0.0, 1.0, 3.0, 1.0, 4.0, 0.0, 0.0, 1.0, 7.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 4.0]

response = requests.post("http://localhost:5000/predict", json={"features": features})
print(response.json())


INFO:werkzeug:127.0.0.1 - - [30/Oct/2024 09:36:04] "POST /predict HTTP/1.1" 200 -


{'prediction': 1}


In [None]:
"""
0 : Champignon comestible.
1 : Champignon toxique.
"""