In [4]:
import torch

def kmeans(points, x_prime):
    n = points.size(0)
    k = n // x_prime
    
    # Initialize centroids randomly
    centroids = points[torch.randperm(n)[:k]]
    
    old_centroids = centroids.clone()  # Initialize old_centroids
    
    while True:
        # Assign each point to the nearest centroid
        distances = torch.cdist(points, centroids)
        _, assignments = torch.min(distances, dim=1)
        
        # Check if any cluster is too large or too small
        cluster_sizes = torch.bincount(assignments, minlength=k)
        if (cluster_sizes != x_prime).any():
            # Re-assign points to different clusters
            for i in range(k):
                while cluster_sizes[i] > x_prime:
                    # Find the point farthest from the centroid and re-assign it
                    farthest = torch.argmin(distances[:, i])
                    assignments[farthest] = -1  # Mark point as unassigned
                    cluster_sizes[i] -= 1
                    
                while cluster_sizes[i] < x_prime:
                    # Find the point closest to the centroid and re-assign it
                    closest = torch.argmin(distances[:, i])
                    if assignments[closest] == -1:  # Skip points that are already unassigned
                        continue
                    assignments[closest] = i
                    cluster_sizes[i] += 1
        
        # Calculate new centroids
        centroids = torch.stack([points[assignments == i].mean(dim=0) for i in range(k)])
        
        # Check if the algorithm has converged
        if (centroids == old_centroids).all():
            break
        
        old_centroids = centroids
    
    # Return cluster assignments
    return assignments


In [5]:
points = torch.randn((512*3, 3))

x_prime = 128

assignments = kmeans(points, x_prime)
print(assignments.shape, assignments)

torch.Size([1536]) tensor([11,  2,  1,  ..., 11, 11, 11])


In [6]:
# count the number of occurance of each value in assignemnts
for values in torch.unique(assignments):
    print(values, (assignments == values).sum())

tensor(-1) tensor(5)
tensor(0) tensor(140)
tensor(1) tensor(131)
tensor(2) tensor(137)
tensor(3) tensor(121)
tensor(4) tensor(96)
tensor(5) tensor(122)
tensor(6) tensor(102)
tensor(7) tensor(122)
tensor(8) tensor(100)
tensor(9) tensor(135)
tensor(10) tensor(109)
tensor(11) tensor(216)


In [7]:
import torch
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist

# generate N 3D points
N = 100
points = torch.randn(N, 3)

# set the number of clusters
k = 5

# initialize k-means algorithm
kmeans = KMeans(n_clusters=k)

# fit the algorithm to the data
kmeans.fit(points)

# get the cluster labels for each point
labels = kmeans.labels_

# calculate the number of points in each group
counts = torch.tensor([torch.sum(torch.tensor(labels) == i) for i in range(k)])

# adjust the cluster assignments
while not torch.all(counts == N//k):
    # find the groups with more than N/k points
    overfull_groups = torch.where(counts > N//k)[0]
    if overfull_groups.nelement() > 0:
        # iterate over each overfull group
        for i in overfull_groups:
            group_points = points[labels == i]
            centroid = kmeans.cluster_centers_[i]
            distances = cdist(group_points, centroid.reshape(1,-1))
            farthest_point_idx = torch.argmax(torch.tensor(distances))
            farthest_point = group_points[farthest_point_idx]
            # find the closest underfull group
            underfull_groups = torch.where(counts < N//k)[0]
            distances = cdist(kmeans.cluster_centers_[underfull_groups], farthest_point.reshape(1,-1))
            closest_group_idx = torch.argmin(torch.tensor(distances))
            closest_group = underfull_groups[closest_group_idx]
            # reassign the farthest point to the closest underfull group
            labels[torch.where(labels == i)[0][farthest_point_idx]] = closest_group
            counts[i] -= 1
            counts[closest_group] += 1

    # find the groups with fewer than N/k points
    underfull_groups = torch.where(counts < N//k)[0]
    if underfull_groups.nelement() > 0:
        # iterate over each underfull group
        for i in underfull_groups:
            group_points = points[labels == i]
            centroid = kmeans.cluster_centers_[i]
            # find the closest unassigned point
            distances = cdist(points, centroid.reshape(1,-1))
            unassigned_points = torch.where(labels == -1)[0]
            unassigned_distances = distances[unassigned_points, i]
            closest_point_idx = torch.argmin(unassigned_distances)
            closest_point = points[unassigned_points[closest_point_idx]]
            # assign the closest un




ValueError: XA must be a 2-dimensional array.

In [2]:
import torch

k = 5
matrix = torch.randn(k, 3)

# Reshape the matrix to [k, 1, 3]
reshaped_matrix = matrix.view(k, 1, 3)

# Repeat each row along dimension 1, each row is repeated 10 times
inflated_matrix = reshaped_matrix.repeat(1, 10, 1)

# Reshape the inflated matrix back to [10*k, 3]
inflated_matrix = inflated_matrix.view(10*k, 3)

print(inflated_matrix)

tensor([[-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [-0.2455,  1.5602, -0.3678],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [ 0.7741, -0.7501,  0.8914],
        [-1.4721,  0.4531, -0.3946],
        [-1.4721,  0.4531, -0.3946],
        [-1.4721,  0.4531, -0.3946],
        [-1.4721,  0.4531, -0.3946],
        [-1.4721,  0.4531, -0.3946],
        [-1.4721,  0.4531, -0.3946],
        [-1.4721,  0.4531, -0.3946],
 

In [5]:
points = torch.empty((0,3))
points2 = torch.randn((10,3))

points_all = torch.cat([points, points2], dim=0)
print(points_all.shape)


torch.Size([10, 3])
