# Project 01 - Color Compression

## Thông tin sinh viên

- Họ và tên: Văn Bá Đức Kiên
- MSSV: 22127218
- Lớp: 22CLC10

## Import các thư viện liên quan

In [218]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

## Helper functions

In [219]:
def read_img(img_path):
    '''
    Load an image file and return its pixels as an RGB array.

    Parameters
    ----------
    img_path : str
        Path of the image file to open

    Returns
    -------
    np.ndarray
        Pixel data of the image after conversion to the 'RGB' mode,
        i.e. an array of shape (height, width, 3)
    '''

    # Force the RGB mode first so that every input ends up with three channels
    rgb_image = Image.open(img_path).convert('RGB')
    return np.asarray(rgb_image)


def show_img(img_2d):
    '''
    Display an image with matplotlib.

    Parameters
    ----------
    img_2d : np.ndarray
        Image to display (any array accepted by ``plt.imshow``)
    '''

    plt.imshow(img_2d)
    # imshow only draws onto the current figure. Without show() nothing is
    # rendered when running as a script, and in a notebook the figure would
    # only appear once the whole cell has finished executing.
    plt.show()



def save_img(img_2d, img_path):
    '''
    Write an image array to disk.

    Parameters
    ----------
    img_2d : np.ndarray
        Image data, interpreted as 'RGB' pixels
    img_path : str
        Destination path; it is passed unchanged to ``Image.save``
    '''

    # Wrap the array in a PIL image, then let PIL write the file
    rgb_image = Image.fromarray(img_2d, mode='RGB')
    rgb_image.save(img_path)


def convert_img_to_1d(img_2d):
    '''
    Flatten an image into a list of pixels.

    Parameters
    ----------
    img_2d : np.ndarray
        Image whose total number of values is a multiple of 3,
        e.g. an array of shape (height, width, 3)

    Returns
    -------
    np.ndarray
        Array of shape (number of pixels, 3), one row per pixel
    '''

    # -1 lets numpy infer the number of pixels; each row holds 3 channel values
    flat_shape = (-1, 3)
    return np.reshape(img_2d, flat_shape)


def kmeans(img_1d, k_clusters, max_iter, init_centroids='random'):
    '''
    K-Means algorithm

    Parameters
    ----------
    img_1d : np.ndarray with shape=(height * width, num_channels)
        Original (1D) image
    k_clusters : int
        Number of clusters (must be >= 1)
    max_iter : int
        Maximum number of iterations (must be >= 1)
    init_centroids : str, default='random'
        The method used to initialize the centroids for K-means clustering
        'random' --> Centroids are initialized with random integer values in [0, 255] for each channel
        'in_pixels' --> Each centroid is a distinct color picked at random from the colors present in the image

    Returns
    -------
    centroids : np.ndarray with shape=(k_clusters, num_channels)
        Stores the (rounded) color centroids for each cluster
    labels : np.ndarray with shape=(height * width, )
        Stores the cluster label for each pixel in the image

    Raises
    ------
    ValueError
        If k_clusters or max_iter is not positive, if init_centroids is not
        one of the supported methods, or (raised by np.random.choice) if
        'in_pixels' is requested with more clusters than distinct colors.
    '''

    # Fail early with a clear message instead of hitting an unbound local later
    if k_clusters < 1 or max_iter < 1:
        raise ValueError('k_clusters and max_iter must be positive integers')

    pixels = np.asarray(img_1d)
    num_channels = pixels.shape[1]

    # INIT CENTROIDS (f4 = float32)
    if init_centroids == 'random':
        centroids = np.random.randint(256, size=(k_clusters, num_channels)).astype('f4')
    elif init_centroids == 'in_pixels':
        # Sample among the distinct colors so that no two centroids start identical
        unique_pixels = np.unique(pixels, axis=0).astype('f4')
        picked = np.random.choice(unique_pixels.shape[0], k_clusters, replace=False)
        centroids = unique_pixels[picked]
    else:
        raise ValueError(
            f"init_centroids must be 'random' or 'in_pixels', got {init_centroids!r}"
        )

    # Convergence threshold; it does not change between iterations
    limit = max(1e-2, 1 / max_iter)

    for it in range(max_iter):
        print(f'\rNumber of interations: {it + 1}', end='')

        # ASSIGN LABELS
        # Broadcasting: pixels(n, c) - centroids(k, 1, c) --> diff(k, n, c)
        diff = pixels - centroids[:, np.newaxis, :]
        # Squared distances are enough: the square root does not change the argmin
        sq_dist = np.sum(diff ** 2, axis=2)
        labels = np.argmin(sq_dist, axis=0)

        # UPDATE CENTROIDS
        cluster_ids, counts = np.unique(labels, return_counts=True)

        # Start from fresh random centroids so that any cluster without pixels
        # is re-seeded; non-empty clusters are overwritten with their mean color
        new_centroids = np.random.randint(256, size=(k_clusters, num_channels)).astype('f4')
        for cluster_id, count in zip(cluster_ids, counts):
            # Accumulate in float64 so that large clusters cannot overflow
            channel_sums = np.sum(pixels[labels == cluster_id], axis=0, dtype=np.float64)
            new_centroids[cluster_id] = channel_sums / count

        # CHECK CONVERGENCE
        if np.all(np.abs(new_centroids - centroids) < limit):
            break
        centroids = new_centroids
    print('\n')
    return np.round(new_centroids), labels
        



def generate_2d_img(img_2d_shape, centroids, labels):
    '''
    Generate a 2D image based on K-means cluster centroids

    Parameters
    ----------
    img_2d_shape : tuple (height, width, 3)
        Shape of image
    centroids : np.ndarray with shape=(k_clusters, num_channels)
        Store color centroids
    labels : np.ndarray with shape=(height * width, )
        Store label for pixels (cluster's index on which the pixel belongs)

    Returns
    -------
    np.ndarray
        New image of shape img_2d_shape and dtype uint8, in which every
        pixel is replaced by the centroid of its cluster
    '''

    palette = np.asarray(centroids)
    pixel_labels = np.asarray(labels, dtype=np.intp)
    # Integer-array indexing looks up every pixel's centroid in one
    # vectorized step instead of a Python-level loop over all pixels
    return palette[pixel_labels].reshape(img_2d_shape).astype('uint8')


# Your additional functions here


## Your tests

In [220]:
# YOUR CODE HERE

## Main function

In [221]:
# YOUR CODE HERE
def main():
    '''
    Interactive entry point of the color compression program.

    Asks for the image path, the number of colors, the maximum number of
    K-means iterations and the centroid initialization method, then
    compresses the image, displays the result and saves it to a path given
    by the user. Any invalid input prints a message and ends the function.
    '''
    # Input and validate input image
    img_path = input("Enter the image's path: ")
    try:
        img_2d = read_img(img_path)
    except Exception:
        # Top-level boundary: any failure to open or decode the file is reported
        # the same way, while Ctrl+C (KeyboardInterrupt) is no longer swallowed
        print("Invalid image's path or invalid image!")
        return

    # Input and validate k_clusters and max_iter
    try:
        k_clusters = int(input("Enter the number of colors in the compressed image: "))
        max_iter = int(input("Enter the max number of iterations for K-means: "))
    except ValueError:
        print("number of colors and max number of iterations must be positive integers!")
        return
    if k_clusters <= 0 or max_iter <= 0:
        print("number of colors and max number of iterations must be positive integers!")
        return

    # Input and validate init_centroids
    init_centroids = input("Enter the centroids initialization type (random or in_pixels): ")
    if init_centroids not in ("random", "in_pixels"):
        print("Centroids initialization type must be either \"random\" or \"in_pixels\"!")
        return

    # Run the K-means algorithm
    print("Compressing, please wait...")
    img_1d = convert_img_to_1d(img_2d)
    centroids, labels = kmeans(img_1d, k_clusters, max_iter, init_centroids)
    compressed_2d_img = generate_2d_img(img_2d.shape, centroids, labels)

    print("Compressed image: ")
    show_img(compressed_2d_img)

    # Input path to save compressed image to
    saved_img_path = input("Enter the path to save the compressed image to (including file name and extension): ")
    try:
        save_img(compressed_2d_img, saved_img_path)
    except Exception:
        # Covers a missing folder, an unwritable file or an unsupported extension
        print("Invalid folder or file or invalid filename extension of compressed image!")
    else:
        print("Image sucessfully saved")
    
    

In [None]:
# Run the interactive program only when this file is executed directly
# (the guard keeps main() from running when the module is imported).
if __name__ == '__main__':
    main()