In [5]:
import urllib.request
from pathlib import Path

from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# 1. Start Spark Session
spark = SparkSession.builder.appName("IrisClassification").getOrCreate()

# Everything after session creation runs inside try/finally so the session is
# released even if the download, read, fit or evaluation raises. Without this,
# a failing notebook cell leaves the SparkSession alive in the kernel.
try:
    # 2. Load dataset (download the CSV to the driver's working directory)
    data_url = "https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv"
    local_path = "iris.csv"
    urllib.request.urlretrieve(data_url, local_path)
    # Use an explicit file:// URI. A bare relative path is resolved against
    # Spark's default filesystem (fs.defaultFS), which is not the local disk
    # when it points at e.g. HDFS.
    # NOTE(review): the file only exists on the driver, so this assumes local
    # mode or a filesystem shared with the executors.
    iris_uri = Path(local_path).resolve().as_uri()
    iris_df = spark.read.csv(iris_uri, header=True, inferSchema=True)

    # 3. Encode label column: map the string "species" column to a numeric
    # "label" column. By default StringIndexer orders labels by frequency.
    # The indexer is fit on the full dataset before the split, so the label
    # mapping does not depend on which rows land in train vs. test.
    indexer = StringIndexer(inputCol="species", outputCol="label")
    iris_df = indexer.fit(iris_df).transform(iris_df)

    # 4. Assemble the four numeric measurements into one vector column,
    # which is the input format Spark ML estimators expect.
    assembler = VectorAssembler(
        inputCols=["sepal_length", "sepal_width", "petal_length", "petal_width"],
        outputCol="features"
    )
    iris_df = assembler.transform(iris_df)

    # 5. Split data 80/20 with a fixed seed so reruns give the same split
    train_df, test_df = iris_df.randomSplit([0.8, 0.2], seed=42)

    # 6. Train Decision Tree model (default hyperparameters)
    dt = DecisionTreeClassifier(featuresCol="features", labelCol="label")
    model = dt.fit(train_df)

    # 7. Make predictions on the held-out rows
    predictions = model.transform(test_df)

    # 8. Evaluate the model.
    # The test split is only ~20% of 150 rows, so this accuracy is a noisy
    # estimate.
    evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
    accuracy = evaluator.evaluate(predictions)

    print(f"Test Accuracy: {accuracy:.2f}")

    # Optional: Show a few predictions
    predictions.select("features", "label", "prediction").show(50)
finally:
    # Stop Spark session (always, even if a step above failed)
    spark.stop()


Test Accuracy: 1.00
+-----------------+-----+----------+
|         features|label|prediction|
+-----------------+-----+----------+
|[4.4,3.0,1.3,0.2]|  0.0|       0.0|
|[4.6,3.2,1.4,0.2]|  0.0|       0.0|
|[4.6,3.6,1.0,0.2]|  0.0|       0.0|
|[4.8,3.1,1.6,0.2]|  0.0|       0.0|
|[4.9,3.1,1.5,0.1]|  0.0|       0.0|
|[5.0,2.3,3.3,1.0]|  1.0|       1.0|
|[5.0,3.5,1.3,0.3]|  0.0|       0.0|
|[5.1,3.5,1.4,0.2]|  0.0|       0.0|
|[5.3,3.7,1.5,0.2]|  0.0|       0.0|
|[5.4,3.0,4.5,1.5]|  1.0|       1.0|
|[5.4,3.4,1.5,0.4]|  0.0|       0.0|
|[5.4,3.7,1.5,0.2]|  0.0|       0.0|
|[5.4,3.9,1.7,0.4]|  0.0|       0.0|
|[5.5,2.5,4.0,1.3]|  1.0|       1.0|
|[5.6,2.9,3.6,1.3]|  1.0|       1.0|
|[5.7,2.9,4.2,1.3]|  1.0|       1.0|
|[5.8,2.7,5.1,1.9]|  2.0|       2.0|
|[6.3,2.5,4.9,1.5]|  1.0|       1.0|
|[6.4,3.1,5.5,1.8]|  2.0|       2.0|
|[6.5,3.0,5.2,2.0]|  2.0|       2.0|
|[6.5,3.0,5.5,1.8]|  2.0|       2.0|
|[6.5,3.0,5.8,2.2]|  2.0|       2.0|
|[6.7,3.3,5.7,2.5]|  2.0|       2.0|
|[6.8,3.0,5.5,2.1]