In [6]:
import matplotlib.pyplot as plt
import os
from os.path import join
import numpy as np
from PIL import Image
import matplotlib.image as mpimg
from skimage.color import rgb2gray
from skimage.color import label2rgb
from skimage.filters import gaussian
from sklearn.cluster import KMeans

plt.close('all')


def clear():
    """Clear the terminal of the process running this code.

    os.system writes to the kernel's own terminal, so inside Jupyter this does
    not clear notebook cell output. 'cls' is used on Windows, where no 'clear'
    command exists.
    """
    os.system('cls' if os.name == 'nt' else 'clear')


clear()

# Seed NumPy's legacy global RNG. The EM initialisation draws from
# np.random.normal, and scikit-learn's KMeans falls back to this global RNG
# when no random_state is given.
np.random.seed(110)

# RGB overlay colors handed to label2rgb, one per segment label. Keep at least
# max(segmentCounts) entries so every segment gets a distinct color.
colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0.5, 0.5], [0.5, 0, 0.5]]

# Images are read from ./Input/<name>.png.
# Alternative inputs from an earlier version: 'balloons', 'mountains', 'nature',
# 'ocean', 'polarlights'.
imgNames = ['water_coins', 'jump', 'tiger']
segmentCounts = [2, 3, 4, 5]
# One list per image that collects the final segmentation for each segment count.
images_last = [[] for _ in imgNames]
k = 0

In [None]:
# Gaussian-mixture (EM) colour segmentation.
# For every image in imgNames and every segment count in segmentCounts:
#   - fit a mixture of isotropic, unit-variance Gaussians to the RGB pixels,
#   - write the input as 0.png and one smoothed label overlay per EM iteration
#     to Output/<nSegments>_segments/<imgName>/<iteration>.png,
#   - keep the final overlay in images_last for display at the end.

maxIterations = 20       # hard cap on EM iterations per (image, nSegments)
convergenceTol = 1e-7    # stop once both squared parameter changes fall below this
nColors = 3              # images are forced to RGB on load


def _load_rgb(path):
    """Return the image at `path` as a (rows, cols, 3) uint8 RGB array.

    Converting through PIL's 'RGB' mode also accepts RGBA / palette / grayscale
    PNGs (which broke the 3-channel reshape before). The context manager closes
    the file handle.
    """
    with Image.open(path) as im:
        return np.array(im.convert('RGB'))


def _init_params(pixels, nSegments):
    """Initial mixture weights pi (nSegments, 1) and means mu (nSegments, nColors).

    Each mean starts at the per-channel pixel mean and each weight at
    1/nSegments. Tiny N(0, 1e-4) offsets break the symmetry: + for odd
    segments, - for even ones. The np.random draws happen in the same order as
    the original cell, so seeded runs reproduce the same starting point.
    """
    increment = np.random.normal(0, .0001, 1)[0]
    pi = np.full((nSegments, 1), 1 / nSegments)
    pi[1::2] += increment
    pi[0::2] -= increment

    channel_means = pixels.mean(axis=0)
    mu = np.empty((nSegments, pixels.shape[1]))
    for seg in range(nSegments):
        if seg % 2 == 1:
            # Odd segments draw a fresh offset; the next even segment reuses it.
            increment = np.random.normal(0, .0001, 1)[0]
            mu[seg] = channel_means + increment
        else:
            mu[seg] = channel_means - increment
    return pi, mu


def _e_step(pixels, mu, pi):
    """Responsibilities Ws (nPixels, nSegments) of each segment for each pixel.

    log A_ij = log(pi_j) - 0.5 * ||x_i - mu_j||^2 is normalised per pixel
    with the log-sum-exp trick, so large squared distances do not underflow.
    """
    sq_dist = np.empty((pixels.shape[0], mu.shape[0]))
    for seg in range(mu.shape[0]):
        diff = pixels - mu[seg]
        sq_dist[:, seg] = np.einsum('ij,ij->i', diff, diff)
    with np.errstate(divide='ignore'):  # an empty segment has pi == 0 -> log = -inf
        logA = np.log(pi.ravel()) - 0.5 * sq_dist
    logA -= logA.max(axis=1, keepdims=True)
    Ws = np.exp(logA)
    Ws /= Ws.sum(axis=1, keepdims=True)
    return Ws


def _m_step(pixels, Ws, mu_prev):
    """Re-estimate means mu (nSegments, nColors) and weights pi (nSegments, 1).

    A segment whose total responsibility is exactly zero keeps its previous
    mean instead of becoming 0/0 = NaN; its weight becomes 0.
    """
    counts = Ws.sum(axis=0)
    mu = mu_prev.copy()
    nonempty = counts > 0
    mu[nonempty] = (Ws[:, nonempty].T @ pixels) / counts[nonempty][:, None]
    pi = (counts / pixels.shape[0])[:, None]
    return mu, pi


def _smooth_rgb(image, sigma=2):
    """Gaussian-blur an RGB image channel-wise across scikit-image versions."""
    try:
        return gaussian(image, sigma=sigma, channel_axis=-1)
    except TypeError:  # scikit-image < 0.19 only knows `multichannel`
        return gaussian(image, sigma=sigma, multichannel=True)


def _segment_image(Ws, mu, shape, nSegments, palette):
    """Render the hard assignment argmax(Ws) as a smoothed colour-label image.

    Each pixel takes its most responsible segment's mean colour. That colour is
    truncated to uint8 and converted to grayscale. The gray levels are
    relabelled with KMeans and drawn with `palette` via label2rgb.
    argmax breaks ties by taking the first segment; the old `== max` lookup
    crashed on ties.
    """
    labels = Ws.argmax(axis=1)
    segpixels = mu[labels].astype(np.uint8).reshape(shape)
    segpixels = np.array(Image.fromarray(segpixels).convert('L'))
    # n_init=10 pins the pre-1.4 scikit-learn default, so results do not
    # depend on the installed version.
    kmeans = KMeans(nSegments, n_init=10).fit(segpixels.reshape(-1, 1))
    seglabels = kmeans.labels_.reshape(segpixels.shape)
    seglabels = label2rgb(seglabels, segpixels, palette)
    return _smooth_rgb(seglabels, sigma=2)


# Reset so re-running this cell does not append duplicates or depend on `k`.
images_last = [[] for _ in imgNames]

for k, imgName in enumerate(imgNames):
    # Load once per image (previously re-read for every segment count).
    inputPath = './Input/' + imgName + '.png'
    img_mpl = mpimg.imread(inputPath)
    print('Using Matplotlib Image Library: Image is of datatype ', img_mpl.dtype, 'and size ', img_mpl.shape)  # float
    img = _load_rgb(inputPath)
    print('Using Pillow (Python Image Library): Image is of datatype ', img.dtype, 'and size ', img.shape)  # uint8

    rows, cols = img.shape[:2]
    nPixels = rows * cols
    pixels = img.reshape(nPixels, nColors).astype(float)  # one RGB row per pixel

    for nSegments in segmentCounts:
        outputPath = join('Output', str(nSegments) + '_segments', imgName, '')
        os.makedirs(outputPath, exist_ok=True)
        plt.imsave(outputPath + '0.png', img)

        pi, mu = _init_params(pixels, nSegments)
        mu_last_iter, pi_last_iter = mu, pi

        for iteration in range(maxIterations):
            tag = 'Image: ' + imgName + ' nSegments: ' + str(nSegments) + ' iteration: ' + str(iteration + 1)
            print(tag + ' E-step')
            Ws = _e_step(pixels, mu, pi)

            print(tag + ' M-step: Mixture coefficients')
            mu, pi = _m_step(pixels, Ws, mu)
            print(np.transpose(pi))

            # Save this iteration's segmentation BEFORE the convergence test. The
            # old code broke out first, losing the final image and appending a
            # stale (or undefined) `seglabels`.
            seglabels = _segment_image(Ws, mu, img.shape, nSegments, colors)
            mpimg.imsave(outputPath + str(iteration + 1) + '.png', seglabels)

            muDiffSq = np.sum((mu - mu_last_iter) ** 2)
            piDiffSq = np.sum((pi - pi_last_iter) ** 2)
            if muDiffSq < convergenceTol and piDiffSq < convergenceTol:  # sign of convergence
                print('Convergence Criteria Met at Iteration: ', iteration + 1, '-- Exiting code')
                break

            mu_last_iter, pi_last_iter = mu, pi

        images_last[k].append(seglabels)

# Show the final segmentation for every (image, segment count) pair.
for imgName, seg_images in zip(imgNames, images_last):
    for nSegments, seg_img in zip(segmentCounts, seg_images):
        fig, ax = plt.subplots()
        ax.imshow(seg_img)
        ax.set_title(imgName + ': ' + str(nSegments) + ' segments')
        ax.axis('off')
        plt.show()