# Using K-means for Image segmentation 

I suggest that you create your own script to read pictures from a directory
and convert them. This will both make the reviewing process faster, but also give you a chance to see the overall picture of how the 50 tests cases do all in all more easily.



## Importing the libraries we will need

In [1]:
import numpy as np
from imageio import imread, imwrite
from skimage.transform import rescale
from skimage.color import rgb2lab, lab2rgb # Remove if not needed
from sklearn.metrics.pairwise import euclidean_distances

from datetime import datetime

In [3]:
img_name = "IMG_NAME"
img_extension = "IMG_EXT"
image_raw = imread(f"{img_name}.{img_extension}")
image_cilab_raw = rgb2lab(image_raw[:, :, :3])

## Preproccesing functions

In [4]:
def compare(a, b):
    return sum((a-b)**2)

def comparision_matrix(M):
    width = M.shape[0]
    height = M.shape[1]

    # result will be a widthxheight matrix,
    # where the 4 values in each cells are the results of the compare-function with the cell:
    #   [left,right,up,down]
    result = np.zeros((width, height, 4))

    # Row comparision
    for i, j in zip(range(0, width), range(1, width)):
        for k in range(0, height):
            tmp = compare(M[i, k], M[j, k])
            result[i, k, 1] = tmp
            result[j, k, 0] = tmp
    
    # Column comparison
    for l in range(0, width):
        for m, n in zip(range(0, height), range(1, height)):
            tmp = compare(M[l, m], M[l, n])
            result[l, m, 3] = tmp
            result[l, n, 2] = tmp

    return result

def compress_matrix(M):
    compressed_M = np.zeros((M.shape[0], M.shape[1], 1))
    for i in (range(M.shape[0])):
        for j in range(M.shape[1]):
            compressed_M[i, j] = np.average(M[i, j])

    return compressed_M

## Clustering functions

In [5]:
def initial_centroids(points, k, dim):

    if k*dim <= len(points):
        print("Wait...")
        Mean = points.mean(0)
        return np.array([Mean + np.random.randn(dim) for _ in range(k)])

    centroids = points.copy()
    np.random.shuffle(centroids)
    return [centroids[i*dim:dim*(i+1)] for i in range(k)]

def closest_centroid(points, centroids):
    # returns an array containing the index to the nearest centroid for each point
    distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
    return np.argmin(distances, axis=0)

def find_centers(centers, points, assignments):
    points = points.flatten()
    for i in range(len(centers)):
        cluster = points[assignments == i]
        if len(cluster) != 0:
            centers[i] = cluster.mean(0)
    return centers

## Postprocessing functions

In [None]:
# Change values in this for different colors.
# The number corresponds to the nth cluster, so changing it will change
# that cluster's color
def color_picker(number):
    if number == 0:
        return [255,255,255]
    elif number == 1:
        return [0,0,0]
    elif number == 2:
        return [0,0,255]
    elif number == 3:
        return [255,0,0]
    elif number == 4:
        return [0,255,0]
    elif number == 5:
        return [255,255,0]
    elif number == 6:
        return [50,50,255]
    elif number == 7:
        return [240,70,24]
    elif number == 8:
        return [12,129,33]
    elif number == 9:
        return [68,0,35]

def create_image_from_cluster(assignments, M, shape):
    assignments = assignments.reshape(shape) 
    for i in (range(M.shape[0])):
        for j in range(M.shape[1]):
            M[i,j] = color_picker(assignments[i,j])
    return M

### Running the clustering algorithm

In [None]:

k = 3 # Number of clusters: Can be changed
dim = 1 # Dimension of clusters: Will require manual tinkering if changed

# !! NOTE THAT IN YOUR TESTING k SHOULD BE EQUAL TO 3 AND dim EQUAL TO 1 ¡¡

M = comparision_matrix(image_cilab_raw) 
M = compress_matrix(M)
centroids = initial_centroids(M.flatten(), k, dim)

while(True):
    close = closest_centroid(M.reshape(-1, 1), centroids)
    prev_centroids = centroids.copy()
    centroids = find_centers(centroids, M, close)
    if sum(centroids-prev_centroids) == 0:
        break

### Create visual representation of clusters

In [None]:
new_img = create_image_from_cluster(close, image_raw[:,:,:3], (M.shape[0],M.shape[1]))

## Example of displaying image in editor

In [None]:
plt.imshow(new_img)

## Example of how to save an image

In [None]:
dir_path = "./" # Make this the full path

imwrite(f"{dir_path}{img_name}-{datetime.now()}.png", new_img) # We save as png because all we have are the colors in [r, g, b]