# In [None]:
"""
ORIGINALLY kaka_processing.py IN THE VIRTUAL DRIVE
This is code from Arianna Salili-James, who was a PhD student of Stephen Marsland's.
It aims to segment the kaka's whole head; the segmentation may also include the nozzle, since the nozzle is present in the images.
To run:
    Edit the hard-coded file-name prefix (the dirty hack) at the bottom of the script,
    then run it from the directory that contains the images.
Next steps:
(1) verify the segmentation more thoroughly
(2) detect and remove the nozzle properly -- colour? shape?
(3) extract either the contour or some points, to be discussed
"""

import numpy as np
import skimage.io as io
from skimage.filters import threshold_otsu
from skimage import measure
import matplotlib.pyplot as plt
import cv2
from copy import deepcopy
from collections import Counter
import matplotlib.colors as colors
import matplotlib.cm as cmx
import re
import matplotlib
matplotlib.use('Agg')

def round_pixel_colours(image, base=5):
    """Round every pixel value to the nearest multiple of ``base``.

    Parameters
    ----------
    image : array_like
        Image array; presumably uint8 RGB (H, W, 3) as returned by
        ``skimage.io.imread`` -- TODO confirm for other inputs.
    base : int, default 5
        Rounding step.

    Returns
    -------
    numpy.ndarray
        New uint8 array with the same shape as ``image``.

    Notes
    -----
    ``np.round`` rounds halves to even, so e.g. 255 with ``base=10`` becomes
    260. Results are clipped to [0, 255] before the uint8 cast so such values
    saturate at 255 instead of wrapping around (260 -> 4).
    """
    rounded = base * np.round(np.asarray(image) / base)
    return np.clip(rounded, 0, 255).astype(np.uint8)

def recolour(scalarMap, image_colours, image):
    """Replace each listed colour in ``image`` with a colour from ``scalarMap``.

    Parameters
    ----------
    scalarMap : object with ``to_rgba(i)``
        Called with the index of each entry in ``image_colours``; the first
        three returned components are taken as RGB floats and scaled by 255.
    image_colours : sequence of str
        Colours as strings such as ``"[250 120  30]"``; the integers are
        extracted with a regex.
    image : numpy.ndarray
        Image with a trailing channel axis matching the colour length
        (presumably RGB) -- TODO confirm.

    Returns
    -------
    numpy.ndarray
        A copy of ``image`` in which every pixel that equals
        ``image_colours[i]`` in ALL channels is set to the i-th mapped colour.
    """
    source = np.asarray(image)
    recoloured = deepcopy(image)
    for idx, colour_str in enumerate(image_colours):
        # to_rgba gives floats; scale RGB to 0-255 and truncate to ints.
        new_colour = list(np.int_(np.array(scalarMap.to_rgba(idx)[:3]) * 255))
        orig_colour = np.int_(re.findall(r"\d+", colour_str))
        # Match on the untouched input (so recoloured pixels are never
        # recoloured again) and require every channel to agree.
        mask = np.all(source == orig_colour, axis=-1)
        recoloured[mask] = new_colour
    return recoloured

def _fill_largest_contour(img_orig, img_orig_g):
    """Locate the main object and paint it pure red on a copy of the image.

    Thresholds the greyscale image with Otsu's method, traces iso-contours
    with marching squares, and keeps the longest one (most points).

    Returns
    -------
    X, Y : numpy.ndarray
        Column and row coordinates of the longest contour.
    img_filled : numpy.ndarray
        Copy of ``img_orig`` with the contour's polygon filled with [255, 0, 0].

    Raises
    ------
    ValueError
        If thresholding yields no contour at all.
    """
    thresh = threshold_otsu(img_orig_g)
    binary = img_orig_g > thresh
    contours = measure.find_contours(binary, 0.8)
    if not contours:
        raise ValueError("segment: no contour found after Otsu thresholding")
    # max() keeps the first of equally long contours, as np.argmax would.
    longest = max(contours, key=len)
    X = longest[:, 1]
    Y = longest[:, 0]
    # One int32 polygon of (x, y) points, the layout cv2.fillPoly takes.
    polygon = np.array([np.array([X, Y]).T], dtype=np.int32)
    img_filled = cv2.fillPoly(deepcopy(img_orig), polygon, [255, 0, 0])
    return X, Y, img_filled


def _mode_colour_patches(img_bg_rm, patch_size):
    """Paint each patch_size x patch_size tile with its most common foreground colour.

    Pure-white ([255, 255, 255]) pixels are treated as background and ignored
    when picking the mode; tiles with no foreground pixel are left unchanged.

    Returns
    -------
    img_patches : numpy.ndarray
        Copy of ``img_bg_rm`` with each foreground tile filled by its mode colour.
    all_new_cols : list of str
        ``str()`` of each painted tile's mode colour, one entry per tile.
    """
    img_patches = deepcopy(img_bg_rm)
    n_rows, n_cols, _ = np.shape(img_patches)
    row_edges = list(np.arange(0, n_rows, patch_size)) + [n_rows]
    col_edges = list(np.arange(0, n_cols, patch_size)) + [n_cols]

    all_new_cols = []
    for r1, r2 in zip(row_edges[:-1], row_edges[1:]):
        for c1, c2 in zip(col_edges[:-1], col_edges[1:]):
            tile = img_bg_rm[r1:r2, c1:c2].reshape((r2 - r1) * (c2 - c1), 3)
            cols_round = [str(p) for p in tile if list(p) != [255, 255, 255]]
            if cols_round:
                # most_common breaks ties by first occurrence, like argmax did.
                mode_str = Counter(cols_round).most_common(1)[0][0]
                mode_colour = np.int_(re.findall(r"\d+", mode_str))
                img_patches[r1:r2, c1:c2] = list(mode_colour)
                all_new_cols.append(str(mode_colour))
    return img_patches, all_new_cols


def segment(img_orig, img_orig_g, f, plot=False, patch_size=20):
    """Segment the main object, summarise it as colour patches and save a figure.

    Parameters
    ----------
    img_orig : numpy.ndarray
        Colour image; presumably uint8 RGB (H, W, 3) from ``skimage.io.imread``
        -- TODO confirm (the patch code reshapes tiles to 3 channels).
    img_orig_g : numpy.ndarray
        Greyscale version of the same image, used only for thresholding.
    f : str
        Source file name; the figure is saved to ``<stem>_out.jpg``.
    plot : bool, default False
        If True, also draw the contour figure and a three-panel comparison
        (original / mode-colour patches / recoloured patches), and save the
        latter. Otherwise save only the mode-colour patch image.
    patch_size : int, default 20
        Side length in pixels of the square tiles used for colour patches.

    Raises
    ------
    ValueError
        If no contour is found in the thresholded image.
    """
    import os  # local import: `os` itself is not imported at module level

    # --- Segment main object ---
    X, Y, img_filled = _fill_largest_contour(img_orig, img_orig_g)

    if plot:
        # NOTE(review): this figure is never saved or closed; under the Agg
        # backend it is not displayed and it stays open after the call.
        fig, ax = plt.subplots(1, 2, figsize=(10, 4))
        ax[0].imshow(img_orig)
        ax[0].plot(X, Y)
        ax[1].imshow(img_filled)
        for axis in ax:
            axis.axis('off')

    # --- Colour patches ---
    # Round colours to multiples of 10, then whiten every pixel the fill did
    # not turn pure red (a pixel is "not red" if any channel differs).
    img_bg_rm = round_pixel_colours(img_orig, base=10)
    img_bg_rm[np.where(img_filled != [255, 0, 0])[:2]] = [255, 255, 255]

    img_patches, all_new_cols = _mode_colour_patches(img_bg_rm, patch_size)
    # Restore the background: only pixels that are white in ALL channels.
    img_patches[np.all(img_bg_rm == [255, 255, 255], axis=-1)] = [255, 255, 255]

    # Give each distinct patch colour an evenly spaced hue so that patches
    # are easier to tell apart. The mapping is arbitrary: hue proximity says
    # nothing about how similar the original patch colours were.
    patch_colours = np.unique(all_new_cols)
    cNorm = colors.Normalize(vmin=0, vmax=len(patch_colours))
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=plt.get_cmap('hsv'))
    img_patches_coloured = recolour(scalarMap, patch_colours, img_patches)

    if plot:
        fig, ax = plt.subplots(1, 3, figsize=(17, 6))
        ax[0].imshow(img_orig)
        ax[1].imshow(img_patches)
        ax[2].imshow(img_patches_coloured)
        for axis in ax:
            axis.axis('off')
    else:
        plt.imshow(img_patches)
        plt.axis('off')

    # splitext handles extensions of any length (".jpeg", ".tif", ...).
    plt.savefig(os.path.splitext(f)[0] + '_out.jpg')
    plt.close()

# Batch driver: load every matching image in the current directory, segment
# it, and save the result next to it as <stem>_out.jpg.
# HACK: the file-name prefix below is hard-coded -- edit it to pick the images.

from os import listdir
from os.path import isfile, join

onlyfiles = sorted(f for f in listdir('.') if isfile(join('.', f)))
for f in onlyfiles:
    # Outputs share the input prefix; skip them so a rerun does not
    # re-segment its own results.
    if f.startswith("GH017274") and not f.endswith("_out.jpg"):
        img_orig = io.imread(f)
        img_orig_g = io.imread(f, as_gray=True)
        segment(img_orig, img_orig_g, f)

