# In[3]:
import numpy as np
import math

# In[2]:
def find_closest_centroid(data_point, centroids):
    """Return the position of the centroid nearest to ``data_point``.

    Distances come from ``find_dist``. Only a strictly smaller distance
    replaces the current best, so ties keep the earliest index. If
    ``centroids`` is empty (or no distance compares below infinity) the
    result is ``None``.
    """
    best_idx, best_dist = None, float('inf')
    for idx, candidate in enumerate(centroids):
        dist = find_dist(data_point, candidate)
        if not dist < best_dist:
            continue
        best_idx, best_dist = idx, dist
    return best_idx


def k_means(k, imgs, iterations, rng=None):
    """Cluster the rows of ``imgs`` into ``k`` groups with Lloyd's K-means.

    Each round assigns every row to its nearest centroid (via
    ``find_closest_centroid``), then moves each non-empty cluster's centroid
    to the mean of its rows. A cluster that receives no rows keeps its
    previous centroid.

    Args:
        k: Number of clusters. Must not exceed ``imgs.shape[0]``, because the
            initial centroids are ``k`` distinct rows drawn without
            replacement.
        imgs: NumPy array with one data point per row.
        iterations: Maximum number of assign/update rounds. The loop stops
            early once the assignments repeat, since further rounds could not
            change either returned value.
        rng: Optional ``numpy.random.Generator`` used to pick the initial
            centroids. When ``None`` the legacy global ``np.random`` state is
            used, as before.

    Returns:
        ``(centroids, assignments)``. ``centroids`` is a floating-point copy
        of the selected rows after the updates. ``assignments`` holds, per
        row of ``imgs``, the index returned by ``find_closest_centroid``.
        With ``iterations <= 0`` the assignments are computed against the
        initial centroids.
    """
    def assign(current_centroids):
        # One nearest-centroid lookup per row.
        return np.array(
            [find_closest_centroid(img, current_centroids) for img in imgs])

    num_images = imgs.shape[0]
    sampler = np.random if rng is None else rng

    # Fancy indexing returns a copy, so updating centroids never mutates imgs.
    centroids = imgs[sampler.choice(num_images, k, replace=False)]
    # Integer input would otherwise truncate every mean stored below.
    if not np.issubdtype(centroids.dtype, np.floating):
        centroids = centroids.astype(np.float64)

    assignments = None
    for _ in range(iterations):
        new_assignments = assign(centroids)
        # Identical assignments reproduce identical centroids, so any further
        # rounds would be no-ops.
        if assignments is not None and np.array_equal(
                new_assignments, assignments):
            break
        assignments = new_assignments

        for i in range(k):
            cluster_points = imgs[assignments == i]
            if len(cluster_points) > 0:
                centroids[i] = np.mean(cluster_points, axis=0)

    if assignments is None:
        # iterations <= 0: previously this raised UnboundLocalError.
        assignments = assign(centroids)

    return centroids, assignments


def find_dist(x1, x2, k=None, population=100, m=1):
    """Return the combined distance between two 5-element feature vectors.

    Indices 0-2 form one group (``d_c``, presumably colour) and indices 3-4
    another (``d_s``, presumably spatial position; confirm against the data
    layout). The result follows the SLIC superpixel formula::

        D = sqrt(d_c**2 + (d_s / S)**2 * m**2),   S = sqrt(population / k)

    Args:
        x1, x2: Indexable sequences with at least 5 numeric entries.
        k: Number of clusters used to compute ``S``. If omitted, a
            module-level ``k`` is used, which is what the original
            two-argument version implicitly relied on.
        population: ``N`` in the formula (previously hard-coded to 100).
        m: Weight applied to the second (``d_s``) term (previously
            hard-coded to 1).

    Returns:
        The distance as a Python float.

    Raises:
        NameError: if ``k`` is not given and no module-level ``k`` exists.
            This is the same exception type the original raised.
        IndexError: if either vector has fewer than 5 entries.
    """
    if k is None:
        # Backward compatibility: the old body looked up the global name `k`
        # at call time (e.g. one set in another notebook cell).
        try:
            k = globals()["k"]
        except KeyError:
            raise NameError(
                "find_dist() needs the number of clusters: pass k=... "
                "or define a module-level 'k'") from None

    # Convert to float before subtracting: unsigned integer features
    # (e.g. uint8 pixels) would otherwise wrap around instead of going negative.
    p = [float(x1[i]) for i in range(5)]
    q = [float(x2[i]) for i in range(5)]

    d_c_sq = sum((p[i] - q[i]) ** 2 for i in range(3))
    d_s_sq = sum((p[i] - q[i]) ** 2 for i in range(3, 5))

    S = math.sqrt(population / k)
    return math.sqrt(d_c_sq + (d_s_sq / S ** 2) * m ** 2)
