In [1]:
from pyspark import SparkFiles

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Reuse the active Spark session if one exists; otherwise start a new one
# under the application name below.
spark = (
    SparkSession.builder
    .appName("Decision Tree Model")
    .getOrCreate()
)

24/09/27 17:20:29 WARN Utils: Your hostname, AI-CJB-LAP-459 resolves to a loopback address: 127.0.1.1; using 192.168.1.164 instead (on interface wlp0s20f3)
24/09/27 17:20:29 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/27 17:20:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/09/27 17:20:30 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
# Register the remote Iris CSV with Spark so it is fetched locally.
url = "https://raw.githubusercontent.com/selva86/datasets/master/Iris.csv"
spark.sparkContext.addFile(url)

# Resolve the path of the fetched copy, then read it with a header row and
# schema inference enabled.
iris_csv_path = SparkFiles.get("Iris.csv")
df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv(iris_csv_path)
)

# Preview the first five rows.
df.show(5)

+---+-------------+------------+-------------+------------+-----------+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|    Species|
+---+-------------+------------+-------------+------------+-----------+
|  1|          5.1|         3.5|          1.4|         0.2|Iris-setosa|
|  2|          4.9|         3.0|          1.4|         0.2|Iris-setosa|
|  3|          4.7|         3.2|          1.3|         0.2|Iris-setosa|
|  4|          4.6|         3.1|          1.5|         0.2|Iris-setosa|
|  5|          5.0|         3.6|          1.4|         0.2|Iris-setosa|
+---+-------------+------------+-------------+------------+-----------+
only showing top 5 rows



In [3]:
# Stage 1: map the string values of 'Species' to numeric indices in 'label'.
label_indexer = StringIndexer(inputCol="Species", outputCol="label")

# Stage 2: pack the four measurement columns into one 'features' vector column.
assembler = VectorAssembler(
    inputCols=[
        "SepalLengthCm",
        "SepalWidthCm",
        "PetalLengthCm",
        "PetalWidthCm",
    ],
    outputCol="features",
)

# Apply both stages to the raw frame: index the labels first, then assemble.
data = assembler.transform(label_indexer.fit(df).transform(df))

# 80/20 train/test split with a fixed seed.
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

In [4]:
# Decision tree estimator reading the 'features' vector and the 'label' column.
dt_classifier = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="label",
)

# Fit the tree on the training split only.
model = dt_classifier.fit(train_data)

In [5]:
# Evaluator that compares the 'prediction' column against 'label' using accuracy.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)

# Score the held-out split with the fitted tree, then evaluate those predictions.
predictions = model.transform(test_data)
accuracy = evaluator.evaluate(predictions)

print(f"Test Accuracy: {accuracy:.2f}")

Test Accuracy: 0.92


In [6]:
# Per-feature importance scores of the fitted tree, as a plain array.
feature_importance = model.featureImportances.toArray()

# Show feature importance: pair each assembler input column with its score,
# in the same order the columns were given to the assembler.
for column, importance in zip(assembler.getInputCols(), feature_importance):
    print(f"Feature '{column}': {importance:.2f}")

Feature 'SepalLengthCm': 0.00
Feature 'SepalWidthCm': 0.02
Feature 'PetalLengthCm': 0.53
Feature 'PetalWidthCm': 0.45


In [7]:
# Print the fitted tree's textual description (its nested if/else split rules).
# Features are reported by index ("feature 2"), not by column name;
# presumably the index is the column's position in the assembler's inputCols.
print(model.toDebugString)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_8e6b3ae511ab, depth=4, numNodes=13, numClasses=3, numFeatures=4
  If (feature 2 <= 2.45)
   Predict: 0.0
  Else (feature 2 > 2.45)
   If (feature 3 <= 1.65)
    If (feature 2 <= 4.95)
     Predict: 1.0
    Else (feature 2 > 4.95)
     If (feature 3 <= 1.55)
      Predict: 2.0
     Else (feature 3 > 1.55)
      Predict: 1.0
   Else (feature 3 > 1.65)
    If (feature 2 <= 4.85)
     If (feature 1 <= 3.05)
      Predict: 2.0
     Else (feature 1 > 3.05)
      Predict: 1.0
    Else (feature 2 > 4.85)
     Predict: 2.0



In [8]:
# Save the model.
# A plain `model.save(path)` fails when the target path already exists, so
# re-running this cell (or Restart & Run All after a first run) would raise.
# Going through the writer with overwrite() keeps the cell re-runnable.
model_path = "Dtree_model"
model.write().overwrite().save(model_path)

# Load the model back from the same path to confirm the round trip works.
from pyspark.ml.classification import DecisionTreeClassificationModel
loaded_model = DecisionTreeClassificationModel.load(model_path)