# Spark K-Means Clustering Demo


Run the Colab demo: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GuLkMWvxZQTlfVEPcsv-OwZ4rgFJ2MiT?usp=sharing)



## Set Up Spark in Colab

In [1]:
# Install Java 11, Spark 3.3.2, and findspark
!apt-get install openjdk-11-jdk -qq > /dev/null
!wget -q https://archive.apache.org/dist/spark/spark-3.3.2/spark-3.3.2-bin-hadoop3.tgz
!tar xf spark-3.3.2-bin-hadoop3.tgz
!pip install -q findspark


In [2]:
# Set environment variables
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.3.2-bin-hadoop3"


In [3]:
# Initialize Spark
import findspark
findspark.init()

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Colab Spark MLlib Setup") \
    .getOrCreate()

print("Spark Session started successfully!")


Spark Session started successfully!



## Load the Iris Dataset

In [4]:
# Download and load the Iris dataset
!wget -q https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data -O iris.csv


In [5]:
# Load with labels
df = spark.read.csv("iris.csv", inferSchema=True, header=False)
df = df.toDF("sepal_length", "sepal_width", "petal_length", "petal_width", "label")
df.show(5)


+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|      label|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|
+------------+-----------+------------+-----------+-----------+
only showing top 5 rows



In [6]:
# Assemble features
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    outputCol="features"
)
df = assembler.transform(df)

## Train Model

In [7]:
# Train KMeans model
from pyspark.ml.clustering import KMeans

kmeans = KMeans(featuresCol="features", k=3, seed=1)
model = kmeans.fit(df)

print(" Model trained. Cluster centers:")
for center in model.clusterCenters():
    print(center)


 Model trained. Cluster centers:
[5.9016129  2.7483871  4.39354839 1.43387097]
[5.006 3.418 1.464 0.244]
[6.85       3.07368421 5.74210526 2.07105263]
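The centers printed above come from alternating two steps: assign each point to its nearest center, then move each center to the mean of its assigned points. The block below is a minimal plain-Python sketch of that loop (Lloyd's algorithm), not Spark's implementation, which is distributed and uses its own initialization. The function names and the toy data are hypothetical, chosen only for illustration.

```python
# Plain-Python sketch of Lloyd's algorithm (illustration only, not Spark's code).

def squared_distance(p, q):
    """Squared Euclidean distance between two equal-length points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def lloyd_kmeans(points, centers, max_iter=20):
    """Alternate assignment and update steps until the centers stop moving."""
    centers = [tuple(c) for c in centers]
    labels = []
    for _ in range(max_iter):
        # Assignment step: each point joins its nearest center.
        labels = [
            min(range(len(centers)), key=lambda j: squared_distance(p, centers[j]))
            for p in points
        ]
        # Update step: each center moves to the mean of its assigned points.
        new_centers = []
        for j, old in enumerate(centers):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:
                new_centers.append(tuple(sum(dim) / len(members) for dim in zip(*members)))
            else:
                new_centers.append(old)  # keep an empty cluster's center unchanged
        if new_centers == centers:
            break
        centers = new_centers
    return centers, labels


# Hypothetical toy data and starting centers, not taken from the Iris dataset.
toy_points = [(1.0, 1.0), (1.0, 2.0), (9.0, 9.0), (9.0, 10.0)]
toy_centers, toy_labels = lloyd_kmeans(toy_points, [(0.0, 0.0), (10.0, 10.0)])
print(toy_centers, toy_labels)
```

The result depends on the starting centers, which is why the notebook fixes `seed=1` when constructing `KMeans`.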


In [8]:
# Make predictions
predictions = model.transform(df)
predictions.select("features", "label", "prediction").show(10)


+-----------------+-----------+----------+
|         features|      label|prediction|
+-----------------+-----------+----------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|         1|
|[4.9,3.0,1.4,0.2]|Iris-setosa|         1|
|[4.7,3.2,1.3,0.2]|Iris-setosa|         1|
|[4.6,3.1,1.5,0.2]|Iris-setosa|         1|
|[5.0,3.6,1.4,0.2]|Iris-setosa|         1|
|[5.4,3.9,1.7,0.4]|Iris-setosa|         1|
|[4.6,3.4,1.4,0.3]|Iris-setosa|         1|
|[5.0,3.4,1.5,0.2]|Iris-setosa|         1|
|[4.4,2.9,1.4,0.2]|Iris-setosa|         1|
|[4.9,3.1,1.5,0.1]|Iris-setosa|         1|
+-----------------+-----------+----------+
only showing top 10 rows
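The `prediction` column is the index of the closest cluster center. As a rough check, the sketch below recomputes that assignment in plain Python for the first five rows, using the centers copied from the training output above. It assumes squared Euclidean distance, and the helper name `nearest_center` is hypothetical.

```python
# Cluster centers copied from the training output (index = cluster id).
centers = [
    [5.9016129, 2.7483871, 4.39354839, 1.43387097],
    [5.006, 3.418, 1.464, 0.244],
    [6.85, 3.07368421, 5.74210526, 2.07105263],
]


def nearest_center(point, centers):
    """Return the index of the center with the smallest squared Euclidean distance."""
    distances = [
        sum((x - c) ** 2 for x, c in zip(point, center)) for center in centers
    ]
    return distances.index(min(distances))


# First five feature vectors from the predictions table.
rows = [
    [5.1, 3.5, 1.4, 0.2],
    [4.9, 3.0, 1.4, 0.2],
    [4.7, 3.2, 1.3, 0.2],
    [4.6, 3.1, 1.5, 0.2],
    [5.0, 3.6, 1.4, 0.2],
]

assigned = [nearest_center(row, centers) for row in rows]
print(assigned)  # each entry should match the prediction column for that row
```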



## Evaluate

In [9]:
# Evaluate Clustering
from pyspark.ml.evaluation import ClusteringEvaluator

# Silhouette Score
evaluator = ClusteringEvaluator(featuresCol="features", predictionCol="prediction", metricName="silhouette")
silhouette = evaluator.evaluate(predictions)

# WSSSE (Within Set Sum of Squared Errors), reported by Spark as the training cost
wssse = model.summary.trainingCost

print("\nK-Means Evaluation Metrics:")
print(f"Silhouette Score : {silhouette:.2f}")
print(f"WSSSE            : {wssse:.2f}")



K-Means Evaluation Metrics:
Silhouette Score : 0.74
WSSSE            : 78.94
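To make the two metrics concrete, here is a small plain-Python sketch of their textbook definitions. WSSSE sums each point's squared distance to its own cluster center. The silhouette of a point compares its mean distance to its own cluster (`a`) with its mean distance to the nearest other cluster (`b`). Squared Euclidean distance is used because it is the default distance measure of `ClusteringEvaluator`. The function names and toy data are hypothetical, and Spark computes the same quantities in a distributed way rather than with these loops.

```python
def squared_distance(p, q):
    """Squared Euclidean distance between two equal-length points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def wssse(points, labels, centers):
    """Sum of squared distances from each point to its assigned center."""
    return sum(squared_distance(p, centers[lab]) for p, lab in zip(points, labels))


def mean_silhouette(points, labels):
    """Average silhouette over all points; needs at least two clusters."""
    clusters = sorted(set(labels))
    scores = []
    for i, p in enumerate(points):
        # Mean distance from p to each cluster, excluding p itself.
        mean_dist = {}
        for c in clusters:
            others = [q for j, q in enumerate(points) if labels[j] == c and j != i]
            if others:
                mean_dist[c] = sum(squared_distance(p, q) for q in others) / len(others)
        own = labels[i]
        if own not in mean_dist:
            scores.append(0.0)  # singleton cluster: score is 0 by convention
            continue
        a = mean_dist[own]
        b = min(d for c, d in mean_dist.items() if c != own)
        scores.append((b - a) / max(a, b) if max(a, b) > 0 else 0.0)
    return sum(scores) / len(scores)


# Hypothetical toy data: two tight, well-separated clusters.
toy_points = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]
toy_labels = [0, 0, 1, 1]
toy_centers = [(0.0, 0.5), (10.0, 0.5)]

toy_wssse = wssse(toy_points, toy_labels, toy_centers)
toy_silhouette = mean_silhouette(toy_points, toy_labels)
print(f"WSSSE: {toy_wssse:.2f}, silhouette: {toy_silhouette:.2f}")
```

A silhouette near 1 indicates tight, well-separated clusters, while values near 0 indicate overlapping ones. WSSSE always decreases as `k` grows, so it is mainly useful for comparing runs with the same `k` or for spotting diminishing returns across several values of `k`.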


In [10]:
# Cluster Sizes
print("\n Cluster Sizes:")
predictions.groupBy("prediction").count().show()



 Cluster Sizes:
+----------+-----+
|prediction|count|
+----------+-----+
|         1|   50|
|         2|   38|
|         0|   62|
+----------+-----+



In [11]:
# Compare with real labels
print("\n Cluster Prediction vs Real Label:")
predictions.groupBy("label", "prediction").count().orderBy("label").show(20)



 Cluster Prediction vs Real Label:
+---------------+----------+-----+
|          label|prediction|count|
+---------------+----------+-----+
|    Iris-setosa|         1|   50|
|Iris-versicolor|         0|   48|
|Iris-versicolor|         2|    2|
| Iris-virginica|         0|   14|
| Iris-virginica|         2|   36|
+---------------+----------+-----+
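Cluster ids are arbitrary, so comparing them with the real labels requires mapping each cluster to its most frequent label first. The sketch below does this in plain Python with the counts copied from the table above, then reports purity (the fraction of rows that carry their cluster's majority label). The variable names are hypothetical, and purity is one possible summary among several.

```python
from collections import defaultdict

# (label, cluster, count) rows copied from the table above.
crosstab = [
    ("Iris-setosa", 1, 50),
    ("Iris-versicolor", 0, 48),
    ("Iris-versicolor", 2, 2),
    ("Iris-virginica", 0, 14),
    ("Iris-virginica", 2, 36),
]

# Group the counts by cluster id.
by_cluster = defaultdict(dict)
for label, cluster, count in crosstab:
    by_cluster[cluster][label] = count

# Map each cluster to its most frequent label.
majority = {c: max(counts, key=counts.get) for c, counts in by_cluster.items()}

# Purity: share of rows whose label matches their cluster's majority label.
correct = sum(counts[majority[c]] for c, counts in by_cluster.items())
total = sum(count for _, _, count in crosstab)
purity = correct / total

print(majority)
print(f"Purity: {purity:.3f}")
```

With these counts, setosa forms its own cluster, while 14 virginica rows fall into the cluster dominated by versicolor and 2 versicolor rows fall into the cluster dominated by virginica.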

