In [1]:
import os.path as osp
import os

import numpy as np
from sklearn.cluster import KMeans
from skimage.io import imread
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt

In [2]:
def train_kmeans_models(images, min_k, max_k, random_state=None):
    """Fit one K-means model per K in ``min_k..max_k`` (inclusive) for each image.

    Every pixel is treated as one sample whose features are its channel values.

    Parameters
    ----------
    images : array or list/tuple of arrays
        A single image or several images. Each is expected to be (H, W) or
        (H, W, C); a 2-D image is treated as single-channel.
    min_k, max_k : int
        Inclusive range of cluster counts to fit; ``min_k`` must not exceed
        ``max_k``.
    random_state : optional
        Forwarded unchanged to every ``KMeans`` so results can be made
        reproducible. ``None`` (the default) matches KMeans' own default.

    Returns
    -------
    dict
        Maps image index -> ``{'models': [...], 'labels': [...]}``. Both lists
        are ordered by increasing K, and each label array has shape
        ``image.shape[:2]``.
    """
    assert min_k <= max_k

    # Accept any list/tuple of images; anything else is a single image.
    if not isinstance(images, (list, tuple)):
        images = [images]

    data = {}
    for i, image in enumerate(images):
        height, width = image.shape[:2]
        # Flatten to (n_pixels, n_channels); -1 also covers 2-D grayscale input.
        X = image.reshape(height * width, -1)
        print("Training models on image {}".format(i))
        data[i] = {'models': [], 'labels': []}
        for k in range(min_k, max_k + 1):
            print("Training on K = " + str(k))
            model = KMeans(n_clusters=k, random_state=random_state)
            labels = model.fit_predict(X).reshape(height, width)
            data[i]['models'].append(model)
            data[i]['labels'].append(labels)

    return data

In [1]:
def plot_maps(images, data, min_k, max_k,
              colors = ((0, 0, 0), (255, 255, 255), (255, 0, 0),
                        (0, 255, 0), (0, 0, 255), (255, 255, 0),
                        (0, 255, 255), (255, 0, 255))
             ):
    """Show each image next to its K-means label maps for K = min_k..max_k.

    The grid has one row per entry of ``data``. Column 0 shows the original
    image; column j (j >= 1) shows the label map for K = min_k + j - 1, with
    cluster ``l`` painted in ``colors[l]``.

    Parameters
    ----------
    images : array or list/tuple of arrays
        Image(s) indexed by row; a single array is wrapped in a list.
    data : dict
        Row index -> ``{'labels': [...]}`` (as built by
        ``train_kmeans_models``), holding one label array per K in
        ``min_k..max_k``, in order.
    min_k, max_k : int
        Inclusive range of K values that ``data`` covers.
    colors : sequence of RGB triples
        Palette indexed by cluster label; needs at least ``max_k`` entries.

    Raises
    ------
    ValueError
        If ``colors`` has fewer than ``max_k`` entries.
    """
    assert min_k <= max_k

    if not isinstance(images, (list, tuple)):
        images = [images]

    k_values = list(range(min_k, max_k + 1))
    # Labels for K clusters run 0..K-1, so the palette must cover max_k colors.
    if len(colors) < max_k:
        raise ValueError(
            "colors has {} entries but K = {} needs at least {}".format(
                len(colors), max_k, max_k))
    palette = np.asarray(colors, dtype=int)

    n_rows = len(data)
    n_cols = len(k_values) + 1  # original image + one panel per K
    height, width = images[0].shape[:2]
    # Pick the figure height so each panel roughly keeps the first image's
    # aspect ratio: (rows * H) / (cols * W).
    fig_aspect = (height * n_rows) / float(width * n_cols)

    # squeeze=False keeps `ax` 2-D even when there is only one image row.
    _, ax = plt.subplots(nrows=n_rows, ncols=n_cols,
                         figsize=(15, 15 * fig_aspect), squeeze=False)

    ax[0][0].set_title("Original")
    for col, k in enumerate(k_values, start=1):
        ax[0][col].set_title("K = " + str(k))

    for row, row_axes in enumerate(ax):
        row_axes[0].imshow(images[row], aspect='auto')
        row_axes[0].set_axis_off()
        for j in range(len(k_values)):
            labels = data[row]['labels'][j]
            # Fancy-index the palette: a pixel with label l gets colors[l].
            pred_image = palette[labels]
            row_axes[j + 1].imshow(pred_image, aspect='auto')
            row_axes[j + 1].set_axis_off()

    plt.tight_layout()
    plt.show()

In [1]:
def segment(image, bg_val=(255,255,255), return_filter=False,
            n_clusters=3, random_state=None, inplace=True):
    """Cluster pixel colors with K-means and black out the background cluster.

    The background cluster is the one whose center is nearest (Euclidean
    distance) to ``bg_val``.

    Parameters
    ----------
    image : ndarray
        (H, W, C) image (a 2-D image is treated as single-channel).
    bg_val : sequence of numbers
        Color the background is expected to be close to (white by default).
    return_filter : bool
        If True, also return a (H, W) mask that is 0 on background pixels and
        1 elsewhere.
    n_clusters : int
        Number of K-means clusters (previously hard-coded to 3).
    random_state : optional
        Forwarded to ``KMeans`` for reproducible clustering.
    inplace : bool
        If True (default, as before), ``image`` itself is modified and
        returned. If False, a modified copy is returned and ``image`` is left
        untouched.

    Returns
    -------
    ndarray or (ndarray, ndarray)
        The image with background pixels set to 0, plus the binary mask when
        ``return_filter`` is True.
    """
    height, width = image.shape[:2]
    model = KMeans(n_clusters=n_clusters, random_state=random_state)
    labels = model.fit_predict(image.reshape(height * width, -1)).reshape(height, width)

    # Pick the cluster whose center is closest to bg_val. A signed sum of
    # (bg_val - center) would reduce to "brightest cluster" and ignore bg_val.
    centers = np.asarray(model.cluster_centers_, dtype=float)
    distances = np.linalg.norm(centers - np.asarray(bg_val, dtype=float), axis=1)
    bg_px_cluster = int(np.argmin(distances))
    is_bg = labels == bg_px_cluster

    if not inplace:
        image = image.copy()
    # Convert background pixels to black.
    image[is_bg] = 0

    if return_filter:
        bin_filt = (~is_bg).astype(labels.dtype)
        return image, bin_filt
    return image