In [63]:
# Loading python package
from pyspark.sql import SparkSession
from pyspark.sql.functions import when, col
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline

In [3]:
# Initialize a Spark session
spark = SparkSession.builder.appName("MushroomClassification").getOrCreate()

In [4]:
# Load your mushroom dataset into a DataFrame (replace 'data_file' with your dataset path)
data = spark.read.csv("/content/mushrooms.csv", header=True, inferSchema=True)

In [5]:
data.show(5)

+-----+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+---------+----------+-----------+---------+-----------------+----------+-------+
|class|cap-shape|cap-surface|cap-color|bruises|odor|gill-attachment|gill-spacing|gill-size|gill-color|stalk-shape|stalk-root|stalk-surface-above-ring|stalk-surface-below-ring|stalk-color-above-ring|stalk-color-below-ring|veil-type|veil-color|ring-number|ring-type|spore-print-color|population|habitat|
+-----+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+---------+----------+-----------+---------+-----------------+----------+-------+
|    p|        x|          s|        n|      t|   p|              f|           c|        n|   

Dataset Attribute Information:
*   classes: edible=e, poisonous=p
*   cap-shape: bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s
*   cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s
*   cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r,pink=p,purple=u,*   red=e,white=w,yellow=y
*   bruises: bruises=t,no=f
*   odor: almond=a,anise=l,creosote=c,fishy=y,foul=f,musty=m,none=n, pungent=p,spicy=s
*   gill-attachment: attached=a,descending=d,free=f,notched=n
*   gill-spacing: close=c,crowded=w,distant=d
*   gill-size: broad=b,narrow=n
*   gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y
*   stalk-shape: enlarging=e,tapering=t
*   stalk-root: bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r,missing=?
*   stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s
*   stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s
*   stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y
*   stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y
*   veil-type: partial=p,universal=u
*   veil-color: brown=n,orange=o,white=w,yellow=y
*   ring-number: none=n,one=o,two=t
*   ring-type: cobwebby=c,evanescent=e,flaring=f,large=l,none=n,pendant=p,sheathing=s,zone=z
*   spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r,orange=o,purple=u,white=w,yellow=y
*   population: abundant=a,clustered=c,numerous=n,scattered=s,several=v,solitary=y
*   habitat: grasses=g,leaves=l,meadows=m,paths=p,urban=u,waste=w,woods=d

#  Dataset Preparation

In [6]:
# Rename the "class" column to "edible?"
df_renamed = data.withColumnRenamed("class", "edible?")

# Replace 'e' with 1 and 'p' with 0 in the 'class' column
df_updated = df_renamed.withColumn("edible?", when(col("edible?") == 'e', 1).otherwise(0))

# Show the updated DataFrame
df_updated.show()

+-------+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+---------+----------+-----------+---------+-----------------+----------+-------+
|edible?|cap-shape|cap-surface|cap-color|bruises|odor|gill-attachment|gill-spacing|gill-size|gill-color|stalk-shape|stalk-root|stalk-surface-above-ring|stalk-surface-below-ring|stalk-color-above-ring|stalk-color-below-ring|veil-type|veil-color|ring-number|ring-type|spore-print-color|population|habitat|
+-------+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+---------+----------+-----------+---------+-----------------+----------+-------+
|      0|        x|          s|        n|      t|   p|              f|           c|     

### Checking if there are features columns with only one values within the dataset

In [17]:
# cap-shape column
cap_shape_df = df_updated.groupBy("cap-shape").count()
cap_shape_df.show()

# cap-surface column
cap_surface_df = df_updated.groupBy("cap-surface").count()
cap_surface_df.show()

# cap-color column
cap_color_df = df_updated.groupBy("cap-color").count()
cap_color_df.show()

# bruises column
bruises_df = df_updated.groupBy("bruises").count()
bruises_df.show()

# odor column
odor_df = df_updated.groupBy("odor").count()
odor_df.show()

# gill-attachment column
gill_attachment_df = df_updated.groupBy("gill-attachment").count()
gill_attachment_df.show()

# gill-spacing column
gill_spacing_df = df_updated.groupBy("gill-spacing").count()
gill_spacing_df.show()

# gill-size column
gill_size_df = df_updated.groupBy("gill-size").count()
gill_size_df.show()

# gill-color column
gill_color_df = df_updated.groupBy("gill-color").count()
gill_color_df.show()

# stalk-shape column
stalk_shape_df = df_updated.groupBy("stalk-shape").count()
stalk_shape_df.show()

# stalk-root column
stalk_root_df = df_updated.groupBy("stalk-root").count()
stalk_root_df.show()

# stalk-surface-above-ring column
stalk_surface_above_ring_df = df_updated.groupBy("stalk-surface-above-ring").count()
stalk_surface_above_ring_df.show()

# stalk-surface-below-ring column
stalk_surface_below_ring_df = df_updated.groupBy("stalk-surface-below-ring").count()
stalk_surface_below_ring_df.show()

# stalk-color-above-ring
stalk_color_above_ring_df = df_updated.groupBy("stalk-color-above-ring").count()
stalk_color_above_ring_df.show()

# stalk-color-below-ring column
stalk_color_below_ring_df = df_updated.groupBy("stalk-color-below-ring").count()
stalk_color_below_ring_df.show()

# veil-type column
veil_type_df = df_updated.groupBy("veil-type").count()
veil_type_df.show()

# veil-color column
veil_color_df = df_updated.groupBy("veil-color").count()
veil_color_df.show()

# ring-number column
ring_number_df = df_updated.groupBy("ring-number").count()
ring_number_df.show()

# ring-type column
ring_type_df = df_updated.groupBy("ring-type").count()
ring_type_df.show()

# spore-print-color column
spore_print_color_df = df_updated.groupBy("spore-print-color").count()
spore_print_color_df.show()

# population column
population_df = df_updated.groupBy("population").count()
population_df.show()

# habitat column
habitat_df = df_updated.groupBy("habitat").count()
habitat_df.show()

+---------+-----+
|cap-shape|count|
+---------+-----+
|        x| 3656|
|        f| 3152|
|        k|  828|
|        c|    4|
|        b|  452|
|        s|   32|
+---------+-----+

+-----------+-----+
|cap-surface|count|
+-----------+-----+
|          g|    4|
|          f| 2320|
|          y| 3244|
|          s| 2556|
+-----------+-----+

+---------+-----+
|cap-color|count|
+---------+-----+
|        g| 1840|
|        n| 2284|
|        e| 1500|
|        p|  144|
|        y| 1072|
|        w| 1040|
|        c|   44|
|        u|   16|
|        b|  168|
|        r|   16|
+---------+-----+

+-------+-----+
|bruises|count|
+-------+-----+
|      f| 4748|
|      t| 3376|
+-------+-----+

+----+-----+
|odor|count|
+----+-----+
|   l|  400|
|   m|   36|
|   f| 2160|
|   n| 3528|
|   p|  256|
|   y|  576|
|   c|  192|
|   a|  400|
|   s|  576|
+----+-----+

+---------------+-----+
|gill-attachment|count|
+---------------+-----+
|              f| 7914|
|              a|  210|
+---------------+-

In [18]:
# We can see from the previous cell that all the values in the column veil-type are the same.
# This mean that we can remove it from the dataset without impacting the results.
cleaned_df = df_updated.drop("veil-type")
cleaned_df.show()

+-------+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+----------+-----------+---------+-----------------+----------+-------+
|edible?|cap-shape|cap-surface|cap-color|bruises|odor|gill-attachment|gill-spacing|gill-size|gill-color|stalk-shape|stalk-root|stalk-surface-above-ring|stalk-surface-below-ring|stalk-color-above-ring|stalk-color-below-ring|veil-color|ring-number|ring-type|spore-print-color|population|habitat|
+-------+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+----------+-----------+---------+-----------------+----------+-------+
|      0|        x|          s|        n|      t|   p|              f|           c|        n|         k|          e|  

# Data preprocessing

In [53]:
# Define a list of categorical columns
categorical_cols = [
    "cap-shape",
    "cap-surface",
    "cap-color",
    "bruises",
    "odor",
    "gill-attachment",
    "gill-spacing",
    "gill-size",
    "gill-color",
    "stalk-shape",
    "stalk-root",
    "stalk-surface-above-ring",
    "stalk-surface-below-ring",
    "stalk-color-above-ring",
    "stalk-color-below-ring",
    "veil-color",
    "ring-number",
    "ring-type",
    "spore-print-color",
    "population",
    "habitat"
]

# Initialize StringIndexer transformers
indexers = [StringIndexer(inputCol=col, outputCol=col + "_index") for col in categorical_cols]

# Fit and transform the DataFrame using the StringIndexers
indexer_models = [indexer.fit(cleaned_df) for indexer in indexers]
indexed_df = cleaned_df
for indexer_model in indexer_models:
    indexed_df = indexer_model.transform(indexed_df)

# Show the transformed DataFrame
indexed_df.show()

+-------+---------+-----------+---------+-------+----+---------------+------------+---------+----------+-----------+----------+------------------------+------------------------+----------------------+----------------------+----------+-----------+---------+-----------------+----------+-------+---------------+-----------------+---------------+-------------+----------+---------------------+------------------+---------------+----------------+-----------------+----------------+------------------------------+------------------------------+----------------------------+----------------------------+----------------+-----------------+---------------+-----------------------+----------------+-------------+
|edible?|cap-shape|cap-surface|cap-color|bruises|odor|gill-attachment|gill-spacing|gill-size|gill-color|stalk-shape|stalk-root|stalk-surface-above-ring|stalk-surface-below-ring|stalk-color-above-ring|stalk-color-below-ring|veil-color|ring-number|ring-type|spore-print-color|population|habitat|cap

In [54]:
# List of indexed categorical columns
indexed_cols = [col + "_index" for col in categorical_cols]

# Assemble the indexed categorical columns into a feature vector
assembler = VectorAssembler(inputCols=indexed_cols, outputCol="features")
assembled_df = assembler.transform(indexed_df)

# Model Training & Execution

In [58]:
# Split data into training and testing sets
train_data, test_data = assembled_df.randomSplit([0.8, 0.2], seed=42)

In [59]:
# Initializing the classifier
classifier = LogisticRegression(labelCol="edible?", featuresCol="features")

In [60]:
# Train the classifier
model = classifier.fit(train_data)

In [61]:
# Make predictions on the test data
predictions = model.transform(test_data)

In [64]:
# Evaluate the model using RegressionEvaluator
regression_evaluator = RegressionEvaluator(labelCol="edible?", predictionCol="prediction", metricName="mae")

# Calculate Mean Absolute Error
mae = regression_evaluator.evaluate(predictions)
print(f"MAE: {mae}")

# Calculate Mean Squared Error
mse = regression_evaluator.setMetricName("mse").evaluate(predictions)
print(f"MSE: {mse}")

# Calculate Root Mean Squared Error
rmse = regression_evaluator.setMetricName("rmse").evaluate(predictions)
print(f"RMSE: {rmse}")

# To calculate R-squared (R²) and Adjusted R-squared
n = test_data.count()  # Number of data points in the test set
k = len(indexed_df.columns) - 1  # Number of predictors (excluding the label column)
r2 = 1 - (mse / (n - k - 1))
print(f"R²: {r2}")

adjusted_r2 = 1 - ((1 - r2) * (n - 1) / (n - k - 1))
print(f"Adjusted R²: {adjusted_r2}")

# Calculate ROC AUC
evaluator = BinaryClassificationEvaluator(labelCol="edible?", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
roc_auc = evaluator.evaluate(predictions)
print(f"ROC AUC: {roc_auc}")

MAE: 0.0070921985815602835
MSE: 0.0070921985815602835
RMSE: 0.08421519210665189
R²: 0.9999952969505427
Adjusted R²: 0.999995165963754
ROC AUC: 0.9981113801452784


In [66]:
# Make predictions on the test data
predictions = model.transform(test_data)

# Select relevant columns for printing
selected_columns = ["edible?", "prediction", "probability", "features"]

# Reminder of edible? values meaning
print("If edible? = 1 the mushroom is edible otherwise it is poisonous.")

# Show the results for each prediction
for row in predictions.select(selected_columns).collect():
    edible_label = row["edible?"]
    prediction = row["prediction"]
    probability = row["probability"]
    features = row["features"]

    print(f"Actual edible?: {edible_label}, Predicted edible?: {prediction}")
    print(f"Probability: {probability}")
    print(f"Features: {features}\n")

[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
Actual edible?: 0, Predicted edible?: 0.0
Probability: [0.9999999072322201,9.2767779902303e-08]
Features: (21,[0,1,2,4,7,10,13,14,17,20],[2.0,1.0,2.0,3.0,1.0,1.0,1.0,1.0,1.0,3.0])

Actual edible?: 0, Predicted edible?: 0.0
Probability: [0.99999983524252,1.6475747999233903e-07]
Features: (21,[0,1,2,4,7,10,14,17,20],[2.0,1.0,2.0,3.0,1.0,1.0,1.0,1.0,2.0])

Actual edible?: 0, Predicted edible?: 0.0
Probability: [0.9999989832434013,1.0167565986929361e-06]
Features: (21,[0,1,2,4,7,10,17],[2.0,1.0,2.0,3.0,1.0,1.0,1.0])

Actual edible?: 0, Predicted edible?: 0.0
Probability: [0.9918034647451828,0.008196535254817228]
Features: (21,[0,1,4,7,10,11,12,17],[2.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])

Actual edible?: 0, Predicted edible?: 0.0
Probability: [0.9893158726642252,0.01068412733577484]
Features: (21,[0,1,4,7,10,11,17],[2.0,1.0,1.0,1.0,1.0,1.0,1.0])

Actual edible?: 0, Predicted edible?: 0.0
Probability: [0.9

In [None]:
# Stop the Spark session
spark.stop()