In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler
import pandas as pd
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import seaborn as sns

# Start Spark session
spark = SparkSession.builder.appName("BlobsClustering").getOrCreate()

# Experiment parameters, kept in one place so the synthetic data and the model
# stay in sync: the blobs are generated with N_CLUSTERS centres and KMeans is
# asked for the same number of clusters.
N_CLUSTERS = 4
DATA_SEED = 42   # random_state for make_blobs
KMEANS_SEED = 1  # seed for KMeans initialisation

# Generate synthetic 2-D dataset using make_blobs()
X, y_true = make_blobs(
    n_samples=300, centers=N_CLUSTERS, cluster_std=1.0, random_state=DATA_SEED
)
df_pd = pd.DataFrame(X, columns=["x", "y"])
df_pd["label"] = y_true  # true labels (for visualization only; not sent to Spark)

# Convert to Spark DataFrame -- coordinates only, the true labels are
# deliberately excluded so the clustering never sees them.
df = spark.createDataFrame(df_pd[["x", "y"]])

# Assemble features into a single vector column. The result is cached because
# it is reused by the fit, the transform, and every fit in the elbow analysis;
# without caching each use recomputes the pandas -> Spark conversion and the
# assembler.
vec_assembler = VectorAssembler(inputCols=["x", "y"], outputCol="features")
df_features = vec_assembler.transform(df).cache()

# Train KMeans model
kmeans = KMeans(
    k=N_CLUSTERS, seed=KMEANS_SEED, featuresCol="features", predictionCol="cluster"
)
model = kmeans.fit(df_features)

# Predict clusters: adds the "cluster" column configured above.
predictions = model.transform(df_features)

# Convert back to pandas for plotting
preds_pd = predictions.select("x", "y", "cluster").toPandas()

def _plot_labelled_scatter(x, y, hue, palette, title, legend_title):
    """Draw and display one 8x6 scatter plot of (x, y) coloured by ``hue``.

    Parameters
    ----------
    x, y : array-like
        Point coordinates.
    hue : array-like
        Per-point group label used to colour the points.
    palette : str
        Seaborn palette name.
    title : str
        Figure title.
    legend_title : str
        Heading shown above the legend entries.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.scatterplot(x=x, y=y, hue=hue, palette=palette, s=60, ax=ax)
    ax.set(title=title, xlabel="x", ylabel="y")
    ax.legend(title=legend_title)
    ax.grid(True)
    fig.tight_layout()
    plt.show()


# Plot the clustered points. The k in the title is read from the estimator so
# it cannot drift from the value the model was actually trained with.
_plot_labelled_scatter(
    preds_pd["x"],
    preds_pd["y"],
    preds_pd["cluster"],
    palette="Set2",
    title=f"KMeans Clustering (k={kmeans.getK()}) on Synthetic Blob Data",
    legend_title="Cluster",
)

# Optional: plot original labels for comparison. Cluster ids assigned by KMeans
# are arbitrary, so colours need not correspond between the two figures.
_plot_labelled_scatter(
    X[:, 0],
    X[:, 1],
    y_true,
    palette="Set1",
    title="Original Labels (for reference only)",
    legend_title="True Label",
)

In [None]:
# Elbow plot to find optimal K-value: fit one KMeans model per candidate k and
# record its training cost; the k where the curve stops dropping sharply
# ("the elbow") is a reasonable choice.
K_VALUES = range(2, 10)  # defined once so the x values always match `errors`

errors = []
for k in K_VALUES:
    # Same seed and feature column as the main model so the costs are comparable.
    km = KMeans(k=k, seed=1, featuresCol="features")
    model_k = km.fit(df_features)
    errors.append(model_k.summary.trainingCost)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(list(K_VALUES), errors, marker="o")
ax.set(
    title="Elbow Method for Optimal k",
    xlabel="Number of Clusters (k)",
    ylabel="Within Set Sum of Squared Errors (WSSSE)",
)
ax.set_xticks(list(K_VALUES))  # k is an integer; avoid fractional tick labels
ax.grid(True)
fig.tight_layout()
plt.show()