### Libraries

In [1]:
import os
import cv2
import numpy as np

from sklearn import cluster
from skimage import segmentation
from skimage import color 
from skimage.future import graph

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
%matplotlib inline

### General Functions

In [2]:
def load_imgs(directory, filelist, skip_unreadable=False):
    """Load image files from ``directory`` as RGB numpy arrays.

    Parameters
    ----------
    directory : str
        Folder containing the image files.
    filelist : iterable of str
        File names (relative to ``directory``) to load, in order.
    skip_unreadable : bool, default False
        If True, files that ``cv2.imread`` cannot decode (e.g. ``.DS_Store``
        or other non-image entries returned by ``os.listdir``) are skipped.
        If False, such a file raises ``ValueError``.

    Returns
    -------
    list of numpy.ndarray
        One RGB image per successfully read file, in ``filelist`` order.

    Raises
    ------
    ValueError
        If a file cannot be read and ``skip_unreadable`` is False.
    """
    imgs = []
    for file in filelist:
        path = os.path.join(directory, file)
        img = cv2.imread(path)
        # cv2.imread signals failure (missing file, non-image, unsupported
        # format) by returning None instead of raising; without this check the
        # failure surfaces later as an opaque error inside cv2.cvtColor.
        if img is None:
            if skip_unreadable:
                continue
            raise ValueError(f'cv2.imread could not read image file: {path}')
        # OpenCV decodes to BGR channel order; matplotlib expects RGB.
        imgs.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return imgs

In [3]:
def filt(img, *filters):
    """Apply each callable in ``filters`` to ``img`` in order.

    ``filt(x, f, g)`` evaluates ``g(f(x))``; with no filters the input is
    returned unchanged.
    """
    result = img
    for step in filters:
        result = step(result)
    return result

In [4]:
def multiplot(img_array, figkw=None):
    """Show a 2-D grid of images in a single matplotlib figure.

    Parameters
    ----------
    img_array : sequence of sequences of array-like
        ``img_array[r][c]`` is drawn in row ``r``, column ``c``. Rows may have
        different lengths; the grid is as wide as the longest row.
    figkw : dict, optional
        Keyword arguments forwarded to ``plt.figure`` (e.g. ``figsize``).

    2-D images are drawn with the gray colormap; other arrays use matplotlib's
    default handling (e.g. RGB for HxWx3). Nothing is drawn for empty input.
    """
    nrows = len(img_array)
    ncols = max((len(imgs) for imgs in img_array), default=0)
    if ncols == 0:
        return
    fig = plt.figure(**(figkw or {}))
    for row, imgs in enumerate(img_array):
        for col, img in enumerate(imgs):
            # Subplot indices are 1-based and run row-major across the grid.
            ax = fig.add_subplot(nrows, ncols, row * ncols + col + 1)
            ax.imshow(img, cmap='gray' if np.ndim(img) == 2 else None)
            ax.axis('off')
    fig.tight_layout()
    plt.show()

In [5]:
def gridplot(img_array, img_sz=(2, 2), title='', savepath=None, show=False):
    """Draw a gap-free grid of images, optionally saving and/or displaying it.

    Parameters
    ----------
    img_array : sequence of sequences of array-like
        ``img_array[r][c]`` goes in grid cell (r, c). Rows may be ragged; the
        grid is as wide as the longest row.
    img_sz : (float, float)
        Width and height in inches allotted to each grid cell.
    title : str
        Figure suptitle.
    savepath : str, optional
        Output path; must end with ``.png``. Missing parent directories are
        created.
    show : bool
        If True call ``plt.show()``; otherwise close the figure so repeated
        calls (e.g. in a parameter sweep) do not accumulate open figures.

    Raises
    ------
    ValueError
        If ``savepath`` does not end with ``.png`` or ``img_array`` holds no
        images.
    """
    # Validate before building the figure so a bad path doesn't leak a figure.
    if savepath and not savepath.endswith('.png'):
        raise ValueError('savepath should include extension. ex) "images/img00.png"')
    rows = len(img_array)
    cols = max((len(imgs) for imgs in img_array), default=0)
    if rows == 0 or cols == 0:
        raise ValueError('img_array must contain at least one image')

    gs = GridSpec(rows, cols, wspace=0.0, hspace=0.0)
    fig = plt.figure(figsize=(img_sz[0] * cols, img_sz[1] * rows))
    for row, imgs in enumerate(img_array):
        for col, img in enumerate(imgs):
            ax = fig.add_subplot(gs[row, col])
            ax.imshow(img, cmap='gray' if np.ndim(img) == 2 else None)
            ax.axis('off')
    fig.suptitle(title)
    if savepath:
        dirname = os.path.dirname(savepath)
        # os.makedirs('') raises, so only create a directory when one is given.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        fig.savefig(savepath)
    if show:
        plt.show()
    else:
        plt.close(fig)

### Filters

In [6]:
def kmeans(*args, **kwargs):
    """Return a filter that color-quantizes an image with k-means.

    ``args``/``kwargs`` are forwarded to ``sklearn.cluster.KMeans``
    (e.g. ``kmeans(n_clusters=4)``).

    The returned ``func(img, **overrides)`` expects an (H, W, C) image. It
    clusters its pixels and replaces every pixel with its cluster center,
    cast to ``uint8``. ``overrides`` update ``kwargs`` for that single call.
    """
    def func(img, **overrides):
        # Use the factory's kwargs via the closure. The inner function must not
        # declare its own **kwargs, or it would shadow them with an empty dict.
        params = {**kwargs, **overrides}
        height, width, channels = img.shape
        pixels = img.reshape(height * width, channels)
        model = cluster.KMeans(*args, **params)
        model.fit(pixels)
        # Map each pixel to its cluster center; astype truncates float centers.
        quantized = model.cluster_centers_[model.labels_]
        return quantized.reshape(height, width, channels).astype(np.uint8)
    return func

def smooth(*args, **kwargs):
    """Return a filter that denoises a color image with non-local means.

    ``args``/``kwargs`` are forwarded to ``cv2.fastNlMeansDenoisingColored``
    after the image (e.g. ``smooth(h=10, templateWindowSize=7)``).

    The returned ``func(img, **overrides)`` applies the denoiser, with
    ``overrides`` updating ``kwargs`` for that single call.

    Note: OpenCV documents ``templateWindowSize`` as odd (recommended 7).
    """
    def func(img, **overrides):
        # Use the factory's kwargs via the closure. The inner function must not
        # declare its own **kwargs, or it would shadow them with an empty dict
        # and silently drop options such as templateWindowSize.
        params = {**kwargs, **overrides}
        return cv2.fastNlMeansDenoisingColored(img, *args, **params)
    return func

def graphcut(segmap=False, n_segments=20, mode='similarity'):
    """Return a filter that segments an image with SLIC + normalized graph cut.

    The returned ``func(img, segmap=..., n_segments=..., mode=...)`` takes its
    defaults from this factory and works in three steps:

    1. over-segments ``img`` into roughly ``n_segments`` SLIC superpixels;
    2. builds a region adjacency graph from mean region color (``mode`` is
       passed to ``rag_mean_color``);
    3. merges regions with a normalized cut.

    Returns the merged label map when ``segmap`` is True. Otherwise returns
    the image with each region filled by its average color
    (``label2rgb(kind='avg')``).
    """
    def func(img, segmap=segmap, n_segments=n_segments, mode=mode):
        superpixels = segmentation.slic(img, n_segments=n_segments)
        rag = graph.rag_mean_color(img, superpixels, mode=mode)
        labels = graph.cut_normalized(superpixels, rag)
        if segmap:
            # Only the labels are wanted; skip the average-color rendering.
            return labels
        return color.label2rgb(labels, img, kind='avg')
    return func

def grabcut(segmap=False, rect_shape=(0.2, 0.2, 0.8, 0.8), iter_count=1):
    """Return a filter that extracts the foreground with OpenCV GrabCut.

    Factory arguments become the defaults of the returned
    ``func(img, segmap=..., rect_shape=..., iter_count=...)``.

    Parameters
    ----------
    segmap : bool
        If True ``func`` returns the binary foreground mask (uint8, 1 = fg).
        Otherwise it returns ``img`` with background pixels zeroed.
    rect_shape : (x1, y1, x2, y2)
        Initial foreground rectangle as fractions of the image size. (x1, y1)
        is the top-left corner and (x2, y2) the bottom-right corner; x is
        measured along the width (columns) and y along the height (rows).
    iter_count : int
        Number of GrabCut iterations.
    """
    def func(img, segmap=segmap, rect_shape=rect_shape, iter_count=iter_count):
        # numpy shape is (rows, cols) == (height, width).
        height, width = img.shape[:2]
        mask = np.zeros((height, width), np.uint8)
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        x1, y1, x2, y2 = rect_shape
        # cv2.grabCut expects rect = (x, y, w, h) in pixels with x along the
        # columns, so convert the fractional corners into that form.
        left, top = int(x1 * width), int(y1 * height)
        right, bottom = int(x2 * width), int(y2 * height)
        rect = (left, top, right - left, bottom - top)
        cv2.grabCut(img, mask, rect, bgd_model, fgd_model, iter_count, cv2.GC_INIT_WITH_RECT)
        # Definite and probable background -> 0; definite/probable fg -> 1.
        fg_mask = np.where((mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD), 0, 1).astype('uint8')
        if segmap:
            return fg_mask
        return img * fg_mask[:, :, np.newaxis]
    return func

def medianblur(ksize=5):
    """Return a filter that applies ``cv2.medianBlur`` with aperture ``ksize``.

    The returned callable accepts ``(img, ksize=...)`` and uses this factory's
    ``ksize`` unless one is passed explicitly. OpenCV requires an odd
    ``ksize`` greater than 1.
    """
    def apply_median(img, ksize=ksize):
        return cv2.medianBlur(img, ksize=ksize)

    return apply_median

### load data

In [7]:
# Images are read from ./eyes_imgs; point `directory` elsewhere if needed,
# e.g. os.path.join(os.path.expanduser("~"), "Desktop", "eyes_imgs").
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
directory = 'eyes_imgs'
# os.listdir order is arbitrary and may include non-image entries
# (.DS_Store, .ipynb_checkpoints). Sort for a reproducible row order and keep
# only image files so cv2.imread never sees a non-image.
filelist = sorted(f for f in os.listdir(directory) if f.lower().endswith(IMG_EXTENSIONS))
imgs = load_imgs(directory, filelist)

### playground

##### trial_2

In [25]:
# One row per image (first 16):
# original | denoised | graph-cut regions | GrabCut foreground.
rows = []
for source in imgs[:16]:
    denoised = filt(source, smooth(templateWindowSize=40))
    regions = filt(denoised, graphcut(n_segments=300))
    foreground = filt(regions, grabcut(segmap=False, rect_shape=(0.1, 0.1, 0.9, 0.9)))
    rows.append([source, denoised, regions, foreground])
img_array = rows

In [None]:
# Save the grid built above to seg_imgs/seg01.png and also display it inline.
gridplot(img_array, savepath='seg_imgs/seg01.png', title='', show=True)

##### trial_1

In [10]:
def smooth_graphcut_grabcut_medianblur(templateWindowSize=40, n_segments=300, rect_shape=(0.1, 0.1, 0.9, 0.9), ksize=11,
                                       images=None):
    """Run smooth -> graphcut -> grabcut(mask) -> medianblur on each image.

    Parameters
    ----------
    templateWindowSize : int
        Forwarded to ``smooth``.
    n_segments : int
        Forwarded to ``graphcut``.
    rect_shape : tuple of 4 floats
        Forwarded to ``grabcut``, which is asked for its foreground mask.
    ksize : int
        Median-blur aperture forwarded to ``medianblur``.
    images : list, optional
        Images to process; defaults to the notebook-global ``imgs``.

    Returns
    -------
    list of lists
        One row per image: ``[original, smoothed, graphcut, grabcut mask,
        median-blurred mask]``.
    """
    if images is None:
        images = imgs
    rows = []
    for img in images:
        im1 = filt(img, smooth(templateWindowSize=templateWindowSize))
        im2 = filt(im1, graphcut(n_segments=n_segments))
        im3 = filt(im2, grabcut(segmap=True, rect_shape=rect_shape))
        # Honor the caller's ksize so parameter sweeps actually vary the blur.
        im4 = filt(im3, medianblur(ksize=ksize))
        rows.append([img, im1, im2, im3, im4])
    return rows

In [11]:
# Grid search over denoising window, superpixel count and median-blur size.
# Each combination is saved as one PNG under ./smooth_graphcut_grabcut_medianblur/.
# NOTE(review): OpenCV documents templateWindowSize as odd (recommended 7), so
# these even values may be worth revisiting. Also, the smooth/graphcut/grabcut
# stages are recomputed for every ksize even though only the final blur depends
# on it.
title = 'smooth_graphcut_grabcut_medianblur'
for templateWindowSize in [20, 40, 60, 80]:
    for n_segment in [200, 300, 400]:
        for ksize in [3, 5, 9, 13]:
            img_array = smooth_graphcut_grabcut_medianblur(
                templateWindowSize=templateWindowSize,
                # Use the swept value so the file name matches the parameters used.
                n_segments=n_segment,
                rect_shape=(0.1, 0.1, 0.9, 0.9), ksize=ksize)
            savepath = os.path.join(title, f'templateWindowSize({templateWindowSize})_n_segment({n_segment})_ksize({ksize})'+'.png')
            gridplot(img_array=img_array, title=title, savepath=savepath)

##### trial_3 (median-blur ksize sweep)

In [None]:
def by_ksize(templateWindowSize=40, n_segments=300, rect_shape=(0.1, 0.1, 0.9, 0.9), ksize=11):
    """Build one result row per image in the notebook-global ``imgs``.

    Each row is ``[original, smoothed, graphcut, grabcut mask, median-blurred
    mask]``. The row comes from chaining ``smooth``, ``graphcut``,
    ``grabcut(segmap=True)`` and ``medianblur`` with the given parameters.
    """
    def pipeline(original):
        smoothed = filt(original, smooth(templateWindowSize=templateWindowSize))
        regions = filt(smoothed, graphcut(n_segments=n_segments))
        mask = filt(regions, grabcut(segmap=True, rect_shape=rect_shape))
        blurred = filt(mask, medianblur(ksize=ksize))
        return [original, smoothed, regions, mask, blurred]

    return [pipeline(original) for original in imgs]

In [None]:
# Sweep only the median-blur kernel size, with the other stages fixed.
# The fixed values are set here so this cell does not depend on loop variables
# left over from an earlier sweep cell and therefore runs on a fresh kernel.
templateWindowSize = 40
n_segment = 300
title = 'smooth_graphcut_grabcut_medianblur'
for ksize in [3, 5, 9, 13]:
    # by_ksize forwards ksize to the median blur stage.
    img_array = by_ksize(
        templateWindowSize=templateWindowSize,
        n_segments=n_segment, rect_shape=(0.1, 0.1, 0.9, 0.9), ksize=ksize)
    savepath = os.path.join(title, f'templateWindowSize({templateWindowSize})_n_segment({n_segment})_ksize({ksize})'+'.png')
    gridplot(img_array=img_array, title=title, savepath=savepath)