In [44]:
from pyspark.sql import Row, SparkSession

In [45]:
spark = SparkSession.builder.getOrCreate()

In [46]:
# Assuming dataRDD is your RDD containing data points
dataRDD = spark.read.csv("data.csv", header=True, inferSchema=True).rdd

# Step 1: Randomly initialize K centroids
k = 3  # specify the number of clusters
centroids = dataRDD.takeSample(withReplacement=False, num=k, seed=42)

print("Initial centroids:")
for index, centroid in enumerate(centroids):
    print(f"Centroid {index}: {centroid}")

Initial centroids:
Centroid 0: Row(x1=99.3715860581051, x2=31.88346837545848)
Centroid 1: Row(x1=84.20068926901318, x2=13.86568097708431)
Centroid 2: Row(x1=87.44382714322302, x2=-23.348612125353185)


In [47]:
# Step 2: Define a function to assign each data point to the nearest centroid
def assign_to_centroid(point, centroids):
    best_centroid = None
    best_distance = float("inf")
    for centroid in centroids:
        distance = sum((a - b) ** 2 for a, b in zip(point, centroid))
        if distance < best_distance:
            best_distance = distance
            best_centroid = centroid

    return best_centroid, point

In [48]:
# Step 3: Assign each data point to the nearest centroid
assigned_data = dataRDD.map(lambda point: assign_to_centroid(point, centroids))

# Step 4: Calculate new centroids by averaging points assigned to each centroid
new_centroids = (
    assigned_data.groupByKey()
    .mapValues(lambda points: [sum(p) / len(points) for p in zip(*points)])
    .values()
    .collect()
)

new_centroids = [Row(x1=centroid[0], x2=centroid[1]) for centroid in new_centroids]

In [49]:
# Step 5: Iterate until convergence or a specified number of iterations
max_iterations = 10
for iteration in range(max_iterations):
    assigned_data = dataRDD.map(lambda point: assign_to_centroid(point, centroids))
    new_centroids = (
        assigned_data.groupByKey()
        .mapValues(lambda points: [sum(p) / len(points) for p in zip(*points)])
        .values()
        .collect()
    )
    new_centroids = [Row(x1=centroid[0], x2=centroid[1]) for centroid in new_centroids]

In [50]:
final_centroids = new_centroids
print("Final centroids:")
for index, centroid in enumerate(final_centroids):
    print(f"Centroid {index}: {centroid}")

Final centroids:
Centroid 0: Row(x1=85.58130353160162, x2=-23.416604265260847)
Centroid 1: Row(x1=4.785748415740877, x2=13.679017999658925)
Centroid 2: Row(x1=103.05401794144838, x2=43.26637266357961)
