In [None]:
# IMPLEMENTED BY 20185183 CHUNGGI LEE

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
%matplotlib inline

from random import *

from copy import deepcopy

# Class for Data Generation
class data_gen:
    """Generate a labelled 2-D dataset made of two Gaussian clusters.

    Two switches shape the dataset:

    * ``balance``   -- truthy: both clusters are drawn with sample shape
      (10, 10), i.e. 100 points each; falsy: the second cluster uses
      (20, 20), i.e. 400 points.
    * ``spherical`` -- truthy: both clusters use the identity covariance;
      falsy: the clusters use off-diagonal terms of -0.5 and +0.5.
    """

    def __init__(self, balance, spherical):
        self._balance = balance
        self._spherical = spherical

    def pos_gen(self, start, end, interval):
        """Return a square grid of (x, y) coordinates built with ``np.mgrid``."""
        grid_x, grid_y = np.mgrid[start:end:interval, start:end:interval]
        return np.dstack((grid_x, grid_y))

    def generate(self):
        """Draw both clusters from multivariate normal distributions.

        Returns a ``(points, labels)`` pair: ``points`` holds one 2-D sample
        per row (first cluster, then second cluster) and ``labels`` holds 0
        for rows of the first cluster and 1 for rows of the second.
        """
        # Sample-grid shape of each cluster (balanced or not)
        shape_first = (10, 10)
        shape_second = (10, 10) if self._balance else (20, 20)

        # Covariance of each cluster (spherical or not)
        if self._spherical:
            cov_first = [[1, 0], [0, 1]]
            cov_second = [[1, 0], [0, 1]]
        else:
            cov_first = [[1, -0.5], [-0.5, 1]]
            cov_second = [[1, 0.5], [0.5, 1]]

        # The first cluster is drawn before the second one so that a seeded
        # global NumPy generator yields the same stream of samples.
        samples_first = np.random.multivariate_normal(
            mean=[0, 0], cov=cov_first, size=shape_first)
        samples_second = np.random.multivariate_normal(
            mean=[3, 3], cov=cov_second, size=shape_second)

        # One integer label per sample: 0 for the first cluster, 1 for the second
        labels = np.repeat(
            [0, 1],
            [shape_first[0] * shape_first[1], shape_second[0] * shape_second[1]],
        )

        points = np.concatenate(
            (samples_first.reshape(-1, 2), samples_second.reshape(-1, 2)),
            axis=0,
        )
        return points, labels

# Function For Data Plot
def plot(data, color):
    """Show a scatter plot of 2-D points coloured by their integer label.

    ``data`` is indexed as ``data[:, 0]`` / ``data[:, 1]`` for the x / y
    coordinates, and ``color`` indexes into a two-entry palette.
    """
    palette = np.array(["#1f77b4", "#ff7f0e"])
    xs, ys = data[:, 0], data[:, 1]
    plt.scatter(xs, ys, color=palette[color])
    plt.show()

In [None]:
# "Balanced" and "Spherical"
data, color = data_gen(balance=True, spherical=True).generate()
plot(data=data, color=color)

In [None]:
# "Imbalanced" and "Spherical"
data, color = data_gen(balance=False, spherical=True).generate()
plot(data=data, color=color)

In [None]:
# "Balanced" and "Non-spherical"
data, color = data_gen(balance=True, spherical=False).generate()
plot(data=data, color=color)

In [None]:
# "Imbalanced" and "Non-spherical"
data, color = data_gen(balance=False, spherical=False).generate()
plot(data=data, color=color)

In [None]:
from scipy.spatial import distance
from matplotlib.animation import FuncAnimation
from matplotlib.colors import to_rgb

# Class for Hard and Soft KMeans
class KMeans:
    """Hard and soft k-means trained with alternating E and M steps.

    Parameters
    ----------
    k : number of clusters.
    data : array of samples, one row per sample; when ``is_image`` is true,
        an image array that is flattened to rows of 3 channel values.
    iter_num : number of E/M updates to run. ``None`` means "run until the
        centers stop moving" (bounded by ``_MAX_UPDATES``).
    is_hard : ``True`` for hard assignments (argmin of squared distance),
        ``False`` for soft responsibilities.
    beta : stiffness used by the soft variant; responsibilities are
        proportional to ``exp(-beta * ||x - center||)``.
        NOTE(review): the distance is the plain (un-squared) Euclidean norm,
        kept exactly as in the original implementation.
    is_image : whether ``data`` is an image (changes reshaping and plotting).

    Attributes
    ----------
    _res : latest responsibilities, shape (N,) of cluster indices in hard
        mode or (N, k) of weights in soft mode; ``None`` before any E-step
        or plot.
    _center, _prev_center : current centers and the centers before the most
        recent update, both of shape (k, d).
    _mx, _mn : global maximum / minimum of ``data``; random centers are drawn
        uniformly from this range.
    """

    # Upper bound on the number of updates when running until convergence.
    _MAX_UPDATES = 1001
    # A center coordinate that moves less than this counts as unchanged.
    _TOLERANCE = 1e-5

    def __init__(self, k, data, iter_num = None, is_hard = True, beta=None, is_image = False):
        data = np.asarray(data)
        self._k = k
        self._is_image = is_image
        self._res = None
        self._prev_center = None
        self._data = data
        self._iter_num = iter_num
        self._is_hard = is_hard
        self._beta = beta
        # Plain floats so that random.uniform never does small-integer
        # (e.g. uint8) arithmetic on the bounds.
        self._mx, self._mn = float(np.max(data)), float(np.min(data))

        # An image is clustered pixel by pixel: one row of 3 channels each.
        if is_image:
            self._img_shape = data.shape
            self._data = np.reshape(self._data, (-1, 3))

        # Initialize the centers randomly
        self._center = self.init_center(self._data)

    def _random_point(self, dim):
        """Return one point with ``dim`` coordinates drawn uniformly from [_mn, _mx]."""
        return np.array([uniform(self._mn, self._mx) for _ in range(dim)])

    def init_center(self, data):
        """Return ``k`` random centers and reset the stored previous centers.

        The previous centers are set to +inf so that the very first
        convergence check always reports "still moving", whatever the scale
        of the data.
        """
        dim = data.shape[-1]
        self._prev_center = np.full((self._k, dim), np.inf)
        return np.array([self._random_point(dim) for _ in range(self._k)])

    def plot(self):
        """Plot the current clustering (intermediate or final result).

        Non-image data is drawn as a scatter plot with the centers as red
        stars; image data is drawn as the image recoloured with the cluster
        centers. Before the first E-step there are no responsibilities yet,
        so an all-zero placeholder is stored in ``_res`` and plotted.
        """
        points, centers = self._data, self._center

        # 20 base colors, cycled when there are more than 20 clusters
        base_colors = np.array(["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
                                "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
                                "#bcbd22", "#17becf", "#aec7e8", "#ffbb78",
                                "#98df8a", "#ff9896", "#c5b0d5", "#c49c94",
                                "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5"])
        palette = base_colors[np.arange(self._k) % len(base_colors)]

        # Initial plot: there is no responsibility yet
        if self._res is None:
            if self._is_hard:
                self._res = np.zeros(len(points), dtype=int)
            else:
                self._res = np.zeros((len(points), self._k))

        if not self._is_image:
            # Hard: one discrete color per point, soft: blend of cluster colors
            if self._is_hard:
                plt.scatter(points[:, 0], points[:, 1], color=palette[self._res])
            else:
                rgb = np.array([to_rgb(name) for name in palette])
                # Clip guards against rounding pushing a channel just past 1.0
                blended = np.clip(self._res.dot(rgb), 0.0, 1.0)
                plt.scatter(points[:, 0], points[:, 1], c=blended)
            plt.scatter(centers[:, 0], centers[:, 1], marker="*", s=200, color="red")
            plt.show()
            return

        # Imported here so image plotting does not depend on another
        # notebook cell having imported cv2 first.
        import cv2

        # Centers truncated to 8-bit channel values
        center_px = np.clip(centers, 0, 255).astype(np.uint8)

        if self._is_hard:
            # Every pixel takes the color of its cluster center
            labels = np.reshape(self._res, self._img_shape[:-1])
            picture = center_px[labels]
        else:
            # Every pixel is the responsibility-weighted mix of the centers
            mixed = np.clip(self._res.dot(center_px), 0, 255).astype(np.uint8)
            picture = np.reshape(mixed, self._img_shape)
        plt.imshow(cv2.cvtColor(picture, cv2.COLOR_BGR2RGB))
        plt.show()

    def _center_distances(self, squared):
        """Return the (N, k) matrix of distances from every sample to every center.

        Filled one center at a time so only one (N, d) temporary is alive at
        once, which matters for large images.
        """
        centers = np.asarray(self._center)
        out = np.empty((self._data.shape[0], centers.shape[0]))
        for idx, center in enumerate(centers):
            offset = self._data - center
            if squared:
                out[:, idx] = np.sum(np.square(offset), axis=-1)
            else:
                out[:, idx] = np.linalg.norm(offset, 2, axis=-1)
        return out

    def estep(self):
        """Compute, store and return the responsibilities.

        Hard mode: index of the closest center for every sample, shape (N,).
        Soft mode: normalized weights ``exp(-beta * distance)``, shape (N, k),
        every row summing to 1.
        """
        if self._is_hard:
            res = np.argmin(self._center_distances(squared=True), axis=1)
        else:
            logits = -self._beta * self._center_distances(squared=False)
            # Subtracting the row maximum leaves the normalized result
            # unchanged but keeps exp() from underflowing to 0 for every
            # cluster at once (which would give 0/0 = NaN).
            logits -= logits.max(axis=1, keepdims=True)
            weights = np.exp(logits)
            res = weights / weights.sum(axis=1, keepdims=True)
        self._res = res
        return res

    def mstep(self, response):
        """Move every center to the (weighted) mean of the samples it owns.

        A cluster that owns nothing (or whose mean is NaN) is re-seeded at a
        random position instead of being left undefined.
        """
        dim = self._data.shape[-1]
        for idx in range(self._k):
            if self._is_hard:
                members = self._data[response == idx]
                weight = members.shape[0]
                total = np.sum(members, axis=0)
            else:
                column = response[:, idx]
                weight = np.sum(column)
                total = np.dot(column, self._data)

            # "weight > 0" is also False for NaN, so both the empty and the
            # degenerate case fall through to the re-seed branch.
            mean = np.true_divide(total, weight) if weight > 0 else None
            if mean is None or np.isnan(mean).any():
                self._center[idx] = self._random_point(dim)
            else:
                self._center[idx] = mean

    def _train(self):
        """Run one update: remember the centers, then E-step and M-step."""
        # Copy, because the M-step updates the centers in place
        self._prev_center = np.array(self._center)
        self.mstep(self.estep())

    def cost_func(self):
        """Return True while any center coordinate moved by at least ``_TOLERANCE``.

        The comparison uses the absolute displacement, so a move in either
        direction counts.
        """
        shift = np.abs(np.asarray(self._center) - self._prev_center)
        return not bool((shift < self._TOLERANCE).all())

    def train(self, plot_num=1):
        """Train the model, plotting the state every ``plot_num`` updates.

        With ``iter_num`` set, exactly that many updates are run. Otherwise
        updates continue until the centers stop moving or ``_MAX_UPDATES`` is
        reached. The initial state is always plotted once; a falsy
        ``plot_num`` disables the per-update plots.
        """
        def plot_if_due(step):
            if plot_num and step % plot_num == 0:
                self.plot()

        self.plot()
        if self._iter_num is None:
            step = 0
            moving = self.cost_func()
            while moving and step < self._MAX_UPDATES:
                self._train()
                # Checked right after the update so no extra pass is run
                # once the centers have settled.
                moving = self.cost_func()
                plot_if_due(step)
                step += 1
        else:
            for step in range(self._iter_num):
                self._train()
                plot_if_due(step)

In [None]:
import cv2
img = cv2.imread('images/samoyed.jpg', cv2.IMREAD_COLOR)
# cv2.imread reports a missing or unreadable file by returning None rather
# than raising, so fail here with a clear message instead of later on.
if img is None:
    raise FileNotFoundError("cv2.imread could not load 'images/samoyed.jpg'")

In [None]:
import cv2
img = cv2.imread('images/sea.jpg', cv2.IMREAD_COLOR)
# cv2.imread reports a missing or unreadable file by returning None rather
# than raising, so fail here with a clear message instead of later on.
if img is None:
    raise FileNotFoundError("cv2.imread could not load 'images/sea.jpg'")

In [None]:
# Soft k-means on the loaded image, no fixed iteration count
kmeans = KMeans(
    k=10,
    data=img,
    iter_num=None,
    is_hard=False,
    beta=1,
    is_image=True,
)
kmeans.train()
kmeans.plot()

In [None]:
# Hard k-means on the loaded image, no fixed iteration count
kmeans = KMeans(
    k=15,
    data=img,
    iter_num=None,
    is_hard=True,
    beta=None,
    is_image=True,
)
kmeans.train()
kmeans.plot()

In [None]:
# Soft k-means on the point dataset, no fixed iteration count
kmeans = KMeans(
    k=10,
    data=data,
    iter_num=None,
    is_hard=False,
    beta=5,
    is_image=False,
)
kmeans.train()
kmeans.plot()

In [None]:
# Hard k-means on the point dataset, fixed at 10 iterations
kmeans = KMeans(
    k=2,
    data=data,
    iter_num=10,
    is_hard=True,
    beta=None,
    is_image=False,
)
kmeans.train()
kmeans.plot()

In [None]:
# Hard k-means on the point dataset, no fixed iteration count
kmeans = KMeans(
    k=10,
    data=data,
    iter_num=None,
    is_hard=True,
    beta=None,
    is_image=False,
)
kmeans.train()
kmeans.plot()