In [34]:
import os
import sys
import numpy as np
from typing import Callable

# Put the parent directory on sys.path so the distance_measures module can be imported
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, parent_dir)

import distance_measures as measures


In [35]:
class KMeans:
    """
    Implementation of k-means clustering algorithm.

    Parameters:
        k (int): Number of centroids (clusters).
        dist_meas (Callable): Dissimilarity measure function (e.g., Euclidean distance).
    """

    def __init__(
        self,
        k: int,
        dist_meas: Callable[[np.ndarray, np.ndarray], float]
    ) -> None:
        """
        Initialize the KMeans clustering algorithm.

        Args:
            k (int): Number of centroids.
            dist_meas (Callable): Dissimilarity measure function taking two
                arrays and returning a float.
        """
        self.k: int = k
        self.dist_meas: Callable[[np.ndarray, np.ndarray], float] = dist_meas

    def _generate_centroids(
        self,
        X: np.ndarray,
        m: int
    ) -> np.ndarray:
        """
        Generate k centroids, each the mean of a random contiguous window of m rows of X.

        For every centroid a window start is drawn uniformly from all valid
        positions (0 .. n_samples - m inclusive) using NumPy's global random
        state; the centroid is the feature-wise mean of the m consecutive rows
        beginning at that position. Windows for different centroids are drawn
        independently and may overlap.

        Args:
            X (np.ndarray): Input data of shape (n_samples, n_features).
            m (int): Number of consecutive data points to average for each
                centroid. Must satisfy 1 <= m <= n_samples.

        Returns:
            np.ndarray: Array of shape (k, n_features) and dtype float32
                containing the centroids.

        Raises:
            ValueError: If m is smaller than 1 or larger than n_samples.
        """
        n, p = X.shape  # n: number of samples, p: number of features

        # Reject window sizes that cannot fit in X; m == 0 would otherwise
        # produce NaN centroids and m > n an opaque error from NumPy.
        if not 1 <= m <= n:
            raise ValueError(
                f"m must satisfy 1 <= m <= n_samples ({n}), got {m}"
            )

        # Centroids are stored as float32 regardless of the dtype of X.
        centroids: np.ndarray = np.zeros((self.k, p), dtype=np.float32)

        for centr_id in range(self.k):
            # randint's upper bound is exclusive, so n - m + 1 makes the last
            # valid start (n - m) reachable and lets m == n work.
            slice_start: int = np.random.randint(0, n - m + 1)
            # Average the m consecutive rows starting at slice_start
            centroids[centr_id] = np.mean(X[slice_start: slice_start + m], axis=0)

        return centroids