# Project 1

##### Import NumPy

In [None]:
import numpy as np

##### Initialize centroids

Initialize the centroids in one of two ways:
1. `'random'`: choose *k* random colors
2. `'in_pixels'`: choose *k* pixels of the image (sampled without replacement) as centroids

In [None]:
def centroids_init(img, k, type):
    """Pick ``k`` initial centroids for K-Means.

    Parameters
    ----------
    img : np.ndarray
        Pixel array with one pixel per row, i.e. ``(n_pixels, n_channels)``.
    k : int
        Number of centroids to create.
    type : str
        ``'random'`` draws ``k`` random colors, each channel uniform in
        ``[0, 255]`` (dtype ``uint8``); ``'in_pixels'`` picks ``k`` rows of
        ``img`` sampled without replacement.

    Returns
    -------
    np.ndarray
        Array of shape ``(k, n_channels)``.

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported strategies.
    """
    # NOTE: ``type`` shadows the builtin, but the name is kept so existing
    # keyword callers keep working.
    if type == 'random':
        return np.random.randint(0, 256, size=(k, img.shape[1]), dtype=np.uint8)
    if type == 'in_pixels':
        # replace=False -> distinct pixel positions (their colors may still repeat)
        return img[np.random.choice(img.shape[0], size=k, replace=False)]
    # Fail fast: returning an empty list here used to surface later as an
    # unrelated "argmin of an empty sequence" error inside label_pixels.
    raise ValueError(
        f"unknown centroid initialization {type!r}; expected 'random' or 'in_pixels'"
    )

##### Label pixels with centroids

For each pixel in the image, compute the Manhattan (L1) distance to every centroid, then label the pixel with the index of its nearest centroid.

In [None]:
def label_pixels(img, centroids):
    """Assign each pixel to its nearest centroid under the Manhattan (L1) distance.

    Parameters
    ----------
    img : np.ndarray
        Pixel array of shape ``(n_pixels, n_channels)``.
    centroids : np.ndarray
        Centroid array of shape ``(k, n_channels)``.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(n_pixels,)`` holding, for each pixel, the
        index of the closest centroid (ties go to the lowest index).
    """
    # Promote to float before subtracting: with unsigned integer inputs
    # (e.g. a uint8 image or the uint8 'random' centroids) the difference
    # would wrap around instead of going negative, corrupting the distances.
    pixels = np.asarray(img, dtype=np.float64)
    centers = np.asarray(centroids, dtype=np.float64)
    # (n_pixels, 1, c) - (k, c) broadcasts to (n_pixels, k, c); sum over channels.
    distances = np.abs(pixels[:, np.newaxis] - centers).sum(axis=2)
    return np.argmin(distances, axis=1)

##### Recalculate centroids after labeling

Recompute each centroid as the mean color of the pixels labeled with it. A cluster that received no pixels gets an all-zero centroid.

In [None]:
def recalculate_centroids(img, labels, k):
    """Recompute each centroid as the mean of the pixels assigned to it.

    Parameters
    ----------
    img : np.ndarray
        Pixel array of shape ``(n_pixels, n_channels)``.
    labels : np.ndarray
        Cluster index for every pixel, shape ``(n_pixels,)``.
    k : int
        Number of clusters.

    Returns
    -------
    np.ndarray
        Float array of shape ``(k, n_channels)``. A cluster with no assigned
        pixels keeps an all-zero centroid.
    """
    # Size the output from the data instead of assuming 3 channels, so RGBA
    # (or any other channel count) works the same way as RGB.
    new_centroids = np.zeros((k, img.shape[1]))
    for i in range(k):
        members = img[labels == i]
        # np.mean of an empty selection would give NaN (with a warning);
        # leave the centroid of an empty cluster at zero instead.
        if members.shape[0]:
            new_centroids[i] = members.mean(axis=0)
    return new_centroids

##### K-Means algorithm

In [None]:
def kmeans(img, k_clusters, max_iter, init_centroids='random', rtol=8e-3):
    """Cluster pixel colors with K-Means (Manhattan-distance assignment).

    Parameters
    ----------
    img : np.ndarray
        Pixel array of shape ``(n_pixels, n_channels)``.
    k_clusters : int
        Number of clusters / colors.
    max_iter : int
        Upper bound on the number of assign-and-update iterations.
    init_centroids : str
        Initialization strategy passed to ``centroids_init``
        (``'random'`` or ``'in_pixels'``).
    rtol : float
        Relative tolerance handed to ``np.allclose`` for the convergence test
        (default ``8e-3``, the value previously hard-coded).

    Returns
    -------
    labels : np.ndarray
        Cluster index per pixel, shape ``(n_pixels,)``. All ``-1`` if
        ``max_iter <= 0`` (no iteration ran).
    centroids : np.ndarray
        Centroid array of shape ``(k_clusters, n_channels)``.
    """
    centroids = centroids_init(img, k_clusters, init_centroids)
    # Placeholder returned unchanged only when the loop body never runs.
    labels = np.full(img.shape[0], -1)

    for _ in range(max_iter):
        # Assignment step: nearest centroid for every pixel.
        labels = label_pixels(img, centroids)

        # Update step: centroid = mean of its assigned pixels.
        new_centroids = recalculate_centroids(img, labels, k_clusters)

        # Converged when every centroid component moved by at most
        # atol + rtol * |new value| (np.allclose semantics). Breaking before
        # adopting new_centroids keeps the returned labels consistent with
        # the returned centroids.
        if np.allclose(centroids, new_centroids, rtol=rtol):
            break

        # NOTE(review): if max_iter is exhausted, the returned labels were
        # computed from the previous centroids, not from these.
        centroids = new_centroids

    return labels, centroids

#### Executing K-Means

##### Preprocessing image

In [None]:
from PIL import Image

filepath = './images/phong_canh.png'

# Normalize to 3-channel RGB: PNGs may be RGBA, palette ('P') or grayscale
# ('L'). These would either break the ``h, w, c`` unpacking below or feed an
# alpha channel into the color clustering. The ``with`` block closes the
# file handle once the converted copy has been created.
with Image.open(filepath) as src:
    image = src.convert('RGB')
img = np.array(image, dtype=int)

# Flatten to one row per pixel: (h, w, c) -> (h*w, c)
h, w, c = img.shape
image_reshape = img.reshape(h*w, c)

##### Parameters for the K-Means algorithm

In [None]:
# K-Means settings: number of colors (clusters), iteration cap, and the
# centroid initialization strategy ('random' or 'in_pixels').
k, it, init = 7, 1000, 'random'

##### Run K-Means

In [None]:
# Cluster the flattened pixels; yields a label per pixel plus the centroid colors.
labels, centroids = kmeans(
    image_reshape,
    k_clusters=k,
    max_iter=it,
    init_centroids=init,
)

##### Postprocessing image

In [None]:
import matplotlib.pyplot as plt

# Rebuild the image array: every pixel takes the color of its cluster centroid.
quantized = centroids[labels].reshape((h, w, c))
# Centroids are float means, so round to the nearest integer and clamp to the
# 8-bit range before casting; a bare astype(np.uint8) truncates (127.9 -> 127).
compressed_img = np.clip(np.rint(quantized), 0, 255).astype(np.uint8)
# Wrap the uint8 array as a PIL image for display and saving.
compressed_image = Image.fromarray(compressed_img)

##### Show image before and after

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# Side-by-side comparison: original on the left, quantized on the right.
panels = zip(axes, (image, compressed_image), ('Before', 'After'))
for ax, shown, caption in panels:
    ax.imshow(shown)
    ax.axis('off')
    ax.set_title(caption)

plt.tight_layout()
plt.show()

##### Export image

In [None]:
import os
# Output name = input file stem + cluster count, e.g. phong_canh_k7.png
base_name = os.path.basename(filepath)
filename, _ext = os.path.splitext(base_name)
output_file = f'{filename}_k{k}.png'

# Write the quantized image to the current working directory.
compressed_image.save(output_file)