In [None]:
import numpy as np
from matplotlib import pyplot as plt
import PIL
# import timeit

In [None]:
def get_image_info(image):
    """Return the lengths of the first two axes of `image` and its channel count.

    The values are returned in axis order: ``image.shape[0]`` first, then
    ``image.shape[1]``. A 2-D array is reported as having one channel;
    otherwise the channel count is the length of the last axis.

    NOTE(review): callers unpack the result as ``width, height, channels``.
    If the array uses a (rows, columns, channels) layout, the first value is
    really the row count -- confirm before relying on the names.
    """
    first_axis, second_axis = image.shape[0], image.shape[1]

    if image.ndim == 2:
        channels = 1
    else:
        channels = image.shape[-1]

    return first_axis, second_axis, channels

In [None]:
def rand_k_centroids(matrix, k_clusters, random_type='random'):
    """Pick ``k_clusters`` initial centroids for k-means.

    Parameters
    ----------
    matrix : np.ndarray, shape (n_pixels, n_channels)
        Flattened image; one row per pixel.
    k_clusters : int
        Number of centroids to return.
    random_type : {'random', 'in_pixels'}
        'random'    -- every channel value is drawn uniformly from 0..255.
        'in_pixels' -- centroids are rows sampled from the distinct pixels
                       that occur in ``matrix``.

    Returns
    -------
    np.ndarray, shape (k_clusters, n_channels)
        Integer centroids are returned as int64 in both modes so that
        subtracting them from unsigned pixel data cannot wrap around.

    Raises
    ------
    ValueError
        If ``random_type`` is not recognised, or if 'in_pixels' is requested
        for a matrix without any rows.
    """
    if random_type == 'random':
        return np.random.randint(
            0, 256, size=(k_clusters, matrix.shape[1]), dtype=np.int64
        )

    if random_type == 'in_pixels':
        # axis=0 keeps whole pixels (rows); without it np.unique flattens the
        # matrix into scalar channel values and the "centroids" would be
        # arbitrary mixes of those values rather than real pixels.
        unique_pixels = np.unique(matrix, axis=0)
        n_unique = unique_pixels.shape[0]
        if n_unique == 0:
            raise ValueError("cannot pick centroids from an empty pixel matrix")

        # Prefer distinct centroids; repeats are allowed only when the image
        # has fewer distinct colours than requested clusters.
        allow_repeats = n_unique < k_clusters
        rng = np.random.default_rng()
        chosen_rows = rng.choice(n_unique, size=k_clusters, replace=allow_repeats)
        rand_centroids = unique_pixels[chosen_rows]

        if np.issubdtype(rand_centroids.dtype, np.integer):
            rand_centroids = rand_centroids.astype(np.int64)
        return rand_centroids

    raise ValueError(
        f"unknown random_type {random_type!r}; expected 'random' or 'in_pixels'"
    )


In [None]:
def update_centroids(labels, centroids, image_1d, k_clusters):
    """Recompute each centroid as the mean of the pixels assigned to it.

    Parameters
    ----------
    labels : np.ndarray
        Cluster index of every row of ``image_1d``.
    centroids : np.ndarray, shape (k_clusters, n_channels)
        Current centroids; reused for clusters that received no pixels.
    image_1d : np.ndarray, shape (n_pixels, n_channels)
        Pixel data, one row per pixel.
    k_clusters : int
        Number of clusters.

    Returns
    -------
    np.ndarray, shape (k_clusters, n_channels)
        Newly allocated array holding the updated centroids.
    """
    channels = image_1d.shape[1]
    updated = np.empty((k_clusters, channels))

    for cluster_id in range(k_clusters):
        members = labels == cluster_id
        # An empty cluster keeps its previous centroid instead of becoming NaN.
        updated[cluster_id] = (
            image_1d[members].mean(axis=0)
            if np.any(members)
            else centroids[cluster_id]
        )

    return updated


In [None]:

def kmeans(image_1d, k_clusters, max_iter=100, random_type='random'):
    """Cluster the rows of ``image_1d`` into ``k_clusters`` groups.

    Pixels are assigned to the centroid with the smallest Manhattan (L1)
    distance, and every centroid is then moved to the mean of its pixels.
    Iteration stops when the centroids no longer change or after
    ``max_iter`` updates.

    Parameters
    ----------
    image_1d : np.ndarray, shape (n_pixels, n_channels)
        Flattened image, one row per pixel.
    k_clusters : int
        Number of clusters.
    max_iter : int
        Maximum number of centroid updates. With ``max_iter <= 0`` the
        initial centroids are returned together with their assignment.
    random_type : str
        Initialisation mode, forwarded to ``rand_k_centroids``.

    Returns
    -------
    centroids : np.ndarray, shape (k_clusters, n_channels), float64
    labels : np.ndarray, shape (n_pixels,)
        Index of the nearest returned centroid for every pixel.
    """
    # Do all arithmetic in float64. Image data is normally uint8, and
    # uint8 - uint8 wraps around modulo 256 instead of going negative, which
    # silently produced wrong distances on the first iteration.
    pixels = np.asarray(image_1d, dtype=np.float64)
    centroids = np.asarray(
        rand_k_centroids(image_1d, k_clusters, random_type), dtype=np.float64
    )

    def nearest_centroid(current):
        # Broadcasts to (n_pixels, k_clusters, n_channels), then sums the
        # per-channel absolute differences to get the L1 distance.
        distance = np.sum(np.abs(pixels[:, np.newaxis] - current), axis=-1)
        return np.argmin(distance, axis=-1)

    # Assign once up front so `labels` is always defined, even if max_iter <= 0.
    labels = nearest_centroid(centroids)

    for _ in range(max_iter):
        new_centroids = update_centroids(labels, centroids, pixels, k_clusters)

        # Exact comparison: once the assignment stops changing, the means are
        # recomputed from the same pixels and are bit-for-bit identical.
        if np.array_equal(centroids, new_centroids):
            break

        centroids = new_centroids
        # Re-assign after every move so the returned labels always match the
        # returned centroids, including when the iteration cap is reached.
        labels = nearest_centroid(centroids)

    return centroids, labels

In [None]:
def visuallization(image_1d):
    """Plot the elbow curve: within-cluster sum of squares (WCSS) for K = 1..10.

    For every K, k-means is run on ``image_1d`` with at most 1000 iterations
    and 'in_pixels' initialisation. The squared Euclidean distances from
    each pixel to its cluster centroid are then summed, and the totals are
    plotted against K.
    """
    k_values = range(1, 11)
    wcss = []

    for k in k_values:
        centroids, labels = kmeans(image_1d, k, 1000, 'in_pixels')

        total = 0
        for cluster_id in range(k):
            members = image_1d[labels == cluster_id]
            total += np.sum(
                np.linalg.norm(members - centroids[cluster_id], axis=1)**2
            )
        wcss.append(total)

    # Plot WCSS versus K
    plt.plot(k_values, wcss)
    plt.xlabel('Number of clusters (K)')
    plt.ylabel('WCSS')
    plt.title('Elbow Method')
    plt.show()

In [None]:
def show_image(*img, captions):
    """Display the given images side by side in a single row.

    Parameters
    ----------
    *img : array-like
        One or more images accepted by ``Axes.imshow``.
    captions : sequence of str
        Titles, matched to the images by position. Images without a
        corresponding caption are shown untitled.

    Raises
    ------
    ValueError
        If no image is passed.
    """
    if not img:
        raise ValueError("show_image() needs at least one image")

    n_images = len(img)
    # One subplot per image (two images give the previous 8x8 figure).
    # squeeze=False keeps `axes` a 2-D array even for a single image, so
    # `.flat` always works.
    _, axes = plt.subplots(1, n_images, figsize=(4 * n_images, 8), squeeze=False)

    for i, (ax, picture) in enumerate(zip(axes.flat, img)):
        ax.imshow(picture)
        ax.axis('off')
        if i < len(captions):
            ax.set_title(captions[i], fontsize=12)


In [None]:
def main():
    """Compress the colours of an image with k-means and save the result.

    Prompts for the input file name and the output format, reduces the image
    to ``k_clusters`` colours, shows the original next to the compressed
    image, and writes the compressed image beside the input using the
    requested extension.
    """
    # Local imports keep this cell self-contained. `import PIL` alone does
    # not load the `PIL.Image` submodule, so it is imported explicitly here.
    import os
    from PIL import Image

    # Input picture name
    img_name = input('Enter image name: ')
    # Accept "png" as well as ".png"; avoid shadowing the builtin `format`.
    out_format = input('Output format (png, pdf): ').strip().lstrip('.')

    # Open and convert to numpy.array; the context manager closes the file.
    with Image.open(img_name) as source:
        image = np.array(source)

    # Get image info
    width, height, num_color_channel = get_image_info(image)
    image_1d = np.reshape(image, (height * width, num_color_channel))

    # Setting k_cluster
    k_clusters = 7

    # Call KMeans
    centroids, labels = kmeans(image_1d, k_clusters, 100, 'in_pixels')

    # Using to visualize Elbow method for optimize K_clusters
    # visuallization(image_1d)

    # Calculate average time of 10 loops (requires `import timeit`)
    # in_pixel_time = timeit.timeit(lambda: kmeans(image_1d, k_clusters, 100, 'in_pixels'), number=10)
    # print("Average time for 'in_pixel': ", in_pixel_time / 10)
    # random_time = timeit.timeit(lambda: kmeans(image_1d, k_clusters, 100, 'random'), number=10)
    # print("Average time for 'random': ", random_time / 10)

    # Replace every pixel by its centroid colour and restore the original
    # shape. Using image.shape (rather than re-adding a channel axis) keeps
    # grayscale images 2-D, which Image.fromarray requires.
    compress_img = centroids[labels]
    compress_img = np.reshape(compress_img, image.shape).astype(np.uint8)

    show_image(image, compress_img, captions=['Origin', 'Color compression'])

    # Save picture with format. splitext replaces only the final extension,
    # so names such as "./pics/my.photo.jpg" or "photo" are handled correctly.
    save_img = Image.fromarray(compress_img)
    out_name = os.path.splitext(img_name)[0] + '.' + out_format
    save_img.save(out_name)


if __name__ == "__main__":
    main()