# EM Algorithm for image segmentation 

In [None]:
import numpy as np

## EM Algorithm
#### Performance 
For better performance, the following principles are followed:
1. In order to avoid high cpu usage, loops in python are avoided
   (especially loops over the size of the dataset).
   NumPy operations (implemented in C) over whole arrays are used instead.
2. In order to avoid high memory and cpu usage,
   array allocations are avoided and intermediate and final results
   are stored in preallocated buffers.

### Mathematics

In [None]:
def spherical_gaussian(data_m_d, mean_c_d, std_c,
                       buffer_m_d, buffer_c,
                       out_m_c):
    """Evaluate the spherical Gaussian density of every datum under every cluster.

    For cluster ``c`` and datum ``x`` (a row of ``data_m_d``) the value written
    to ``out_m_c[:, c]`` is::

        (2*pi*std_c^2)^(-d/2) * exp(-||x - mean_c||^2 / (2*std_c^2))

    where ``d`` is the number of columns of ``buffer_m_d``.

    Args:
        data_m_d: the data, one row per datum.
        mean_c_d: the cluster means, one row per cluster.
        std_c: one standard deviation per cluster.
        buffer_m_d: scratch array shaped like the data; overwritten.
        buffer_c: scratch array with one slot per cluster; overwritten.
        out_m_c: output array, one row per datum and one column per cluster;
            overwritten with the densities.

    Returns:
        None. The result is stored in ``out_m_c``.
    """
    n_clusters = len(mean_c_d)
    n_dims = buffer_m_d.shape[1]
    # math: buffer_c = 2*std^2
    np.square(std_c, out=buffer_c)
    np.multiply(buffer_c, 2, out=buffer_c)
    for c in range(n_clusters):
        # View on column c of the output; every write below lands in out_m_c.
        column = out_m_c.T[c]
        # math: buffer_m_d = (x-m_c)^2, per dimension
        np.subtract(data_m_d, mean_c_d[c], out=buffer_m_d)
        np.square(buffer_m_d, out=buffer_m_d)
        # math: column = ||x-m_c||^2
        # Summing before exponentiating needs one exp per datum and is
        # equivalent to the product of the per-dimension exponentials.
        np.sum(buffer_m_d, axis=1, out=column)
        # math: column = exp{ -||x-m_c||^2 / (2*std^2) }
        np.divide(column, -buffer_c[c], out=column)
        np.exp(column, out=column)

    # math: buffer_c = sqrt{2*pi*std^2}
    np.multiply(buffer_c, np.pi, out=buffer_c)
    np.sqrt(buffer_c, out=buffer_c)
    # Divide by sqrt{2*pi*std^2} once per data dimension (not once per
    # cluster): a d-dimensional spherical Gaussian carries d such factors.
    # Successive divisions avoid forming the d-th power explicitly.
    for _ in range(n_dims):
        np.divide(out_m_c, buffer_c, out=out_m_c)

    
# Batch sum to avoid adding very small values to a very large running total
def batch_sum_axis0(arr_m_k, buffer_k, out_k, batch_size=1000):
    """Sum ``arr_m_k`` over axis 0 in consecutive row batches.

    Each batch of rows is summed into ``buffer_k`` and the partial sum is then
    accumulated into ``out_k``.

    Args:
        arr_m_k: 2-D array to be summed along its first axis.
        buffer_k: scratch array holding one batch's column sums; overwritten.
        out_k: output array receiving the column totals; overwritten.
        batch_size: number of rows per batch (defaults to 1000).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    out_k.fill(0)
    # Slicing past the end of the array is clamped, so the last (possibly
    # shorter) batch needs no special handling.
    for start in range(0, len(arr_m_k), batch_size):
        arr_m_k[start:start + batch_size, :].sum(axis=0, out=buffer_k)
        np.add(out_k, buffer_k, out=out_k)
    
    
def reconstruction_error(current_m_d, initial_m_d, buffer_m_d):
    """Return the mean squared difference between two equally shaped arrays.

    ``buffer_m_d`` is used as scratch space and holds the element-wise squared
    differences when the function returns.
    """
    # difference, written straight into the scratch buffer
    np.subtract(current_m_d, initial_m_d, out=buffer_m_d)
    # square in place
    np.multiply(buffer_m_d, buffer_m_d, out=buffer_m_d)
    return np.mean(buffer_m_d)

### EM Algorithm Implementation
#### Cluster Initialization
The **means** of the clusters are initialized as a *linspace* around the total mean of the data.<br>
The *bounds of the linspace* are set for each dimension with the following rules:
> `Lower bound = mean_d - 0.2*k*std_d` <br>
> `Upper bound = mean_d + 0.2*k*std_d` <br>
>   where `mean_d`, `std_d` are the mean and std of the dimension (independently)
> * if the lower (upper) bound differs from the mean by more than 2.5*std, it is adjusted to
>   `mean-2.5*std` (`mean+2.5*std`) <br>
> * The lower bound cannot be less than the minimum of the dimension <br>
> * The upper bound cannot be greater than the maximum of the dimension

The **std** of all clusters is set to 1 <br>
The **pi** of all clusters is set to 1/k

#### Algorithm termination
The algorithm terminates when one of the following occurs:
* The maximum number of iterations is reached
* All convergence criteria are met


**Convergence criteria** take into account the change observed in the
cluster means and stds. Specifically:
1. The mean absolute difference between the current and previous cluster means
   (over all dimensions) is less than `d / sqrt(k)` and
2. The mean absolute difference between the current and previous cluster stds
   is less than `d / sqrt(k)`
   
#### Numerical stability
Measures taken for numerical stability include: 
* Summation over the gammas (`gama` in the code) is done in batches to avoid additions
  between very large and very small values
* When a datum is far from all cluster means, so that all Gaussian probability
  densities are approximated by zero, the datum is assigned equally to all clusters
  (avoiding division by zero)
* Divisions of the form `a/b^k` are performed via `k` successive divisions by `b`
  to avoid errors when `k` is big
* Minimum values are defined for extreme cases such as when a cluster has zero
  density over all the data

#### Logging and visualizing intermediate results
The EM class (which encapsulates the algorithm) allows a custom function
to be called after cluster initialization and after each iteration. This makes it
possible to log and visualize the intermediate stages of the algorithm.

In [None]:
min_con = 0.01        # lower bound applied to every cluster std in the M-step
min_gama = 1e-10      # not referenced by the code in this cell
epsilon = 0.001       # tolerance of the "cluster weights sum to 1" sanity check
min_gama_sum = 1e-4   # stands in for a cluster's total responsibility when it is exactly 0


class EM:
    """EM for a mixture of spherical Gaussians, working in preallocated buffers.

    Attributes set by ``__init__``:
        data: the data, one row per datum (m rows, d columns).
        k, d, m: number of clusters, of dimensions and of data.
        gama: (m, k) responsibilities (also known as r_ic).
        cluster_means: (k, d) cluster means.
        cluster_std: (k,) cluster standard deviations.
        cluster_p: (k,) cluster weights (pi).
        old_cluster_means, old_cluster_std: previous iteration's values, kept
            for the convergence test.
    The remaining ``buffer_*`` attributes are scratch space.
    """

    def __init__(self, data, k):
        self.data = data
        self.k = k
        self.d = d = len(data[0])
        self.m = m = len(data)
        self.gama = np.empty((m, k))  # also known as r_ic
        self.cluster_means = np.empty((k, d))
        self.cluster_std = np.empty(k)
        self.cluster_p = np.empty(k)
        # keep old to compare for convergence test
        self.old_cluster_means = np.empty((k, d))
        self.old_cluster_std = np.empty(k)

        # allocate buffers for intermediate results
        self.buffer_m_d = np.empty((m, d))
        self.buffer_k_d = np.empty((k, d))
        self.buffer_c = np.empty(k)
        self.buffer_c2 = np.empty(k)
        self.buffer_d = np.empty(d)
        self.buffer_m = np.empty(m)
        # np.argmax writes indices of type intp; `int` is narrower than intp
        # on some platforms, which makes argmax reject the buffer as `out`.
        self.buffer_m_int = np.empty(m, dtype=np.intp)

    def expectation(self):
        """E-step: fill ``self.gama`` with row-normalised responsibilities."""
        spherical_gaussian(self.data, self.cluster_means, self.cluster_std,
                           self.buffer_m_d, self.buffer_c, out_m_c=self.gama)
        # multiply by pi
        np.multiply(self.gama, self.cluster_p, out=self.gama)
        # now normalize gama to have sum=1 for each datum
        self.gama.sum(axis=1, out=self.buffer_m)
        # but first take care of the data for which every cluster has 0
        # probability: spread them equally over all clusters
        zero_sum_cases = self.buffer_m == 0  # boolean mask over the data
        self.gama[zero_sum_cases] = 1.0 / self.k
        self.buffer_m[zero_sum_cases] = 1  # no need to normalize those
        np.divide(self.gama, self.buffer_m[:, None], out=self.gama)

    def maximization(self):
        """M-step: update cluster means, weights and stds from ``self.gama``."""
        # sum by batches to avoid arithmetic errors
        batch_sum_axis0(self.gama, self.buffer_c2, out_k=self.buffer_c)
        # a cluster nobody is assigned to would otherwise cause 0/0 below
        self.buffer_c[self.buffer_c == 0] = min_gama_sum
        self.gama.T.dot(self.data, out=self.cluster_means)
        np.divide(self.cluster_means, self.buffer_c[:, None], out=self.cluster_means)

        np.divide(self.buffer_c, self.m, out=self.cluster_p)
        assert np.abs(self.cluster_p.sum() - 1.0) < epsilon, "sum of pi is != 1.0: %r" % self.cluster_p.sum()
        for c in range(self.k):
            # buffer_m_d = gama_c * (x - mean_c)^2, per datum and dimension
            np.subtract(self.data, self.cluster_means[c], out=self.buffer_m_d)
            np.square(self.buffer_m_d, out=self.buffer_m_d)
            np.multiply(self.buffer_m_d, self.gama[:, c, None], out=self.buffer_m_d)
            # Reduce over the dimensions into the dedicated per-datum buffer
            # rather than into a column of the array being reduced.
            self.buffer_m_d.sum(axis=1, out=self.buffer_m)
            # A spherical Gaussian shares one variance across all d
            # dimensions, so the weighted squared distance is averaged over
            # the dimensions as well as over the cluster's total weight.
            np.divide(self.buffer_m, self.buffer_c[c] * self.d, out=self.buffer_m)
            self.cluster_std[c] = max(np.sqrt(self.buffer_m.sum()), min_con)

    def init_clusters(self):
        """Spread the initial means on a linspace around the data mean.

        Per dimension the half-width is ``min(0.2*k*std, 2.5*std)``, clipped so
        the bounds stay within the observed minimum and maximum. Every cluster
        starts with std 1 and weight 1/k.
        """
        std = np.std(self.data, axis=0)
        mean = np.mean(self.data, axis=0)
        min_case = self.data.min(axis=0)
        max_case = self.data.max(axis=0)
        for dim in range(self.d):
            bound = min(0.2 * self.k * std[dim], 2.5 * std[dim])
            # bounds are offsets relative to the mean of the dimension
            lower_bound = max(-bound, min_case[dim] - mean[dim])
            upper_bound = min(bound, max_case[dim] - mean[dim])
            self.cluster_means.T[dim] = np.linspace(lower_bound, upper_bound, num=self.k)
        np.add(self.cluster_means, mean, out=self.cluster_means)
        self.cluster_std[:] = 1
        self.cluster_p[:] = 1.0 / self.k

    def converged(self):
        """Return True when means and stds both moved at most ``d / sqrt(k)``.

        The movement is the mean absolute difference between the current and
        the previous values.
        """
        threshold = self.d / np.sqrt(self.k)
        np.subtract(self.cluster_means, self.old_cluster_means, out=self.buffer_k_d)
        np.abs(self.buffer_k_d, out=self.buffer_k_d)
        np.subtract(self.cluster_std, self.old_cluster_std, out=self.buffer_c)
        np.abs(self.buffer_c, out=self.buffer_c)
        if self.buffer_k_d.mean() > threshold:
            return False
        if self.buffer_c.mean() > threshold:
            return False
        return True

    def em(self, repetations=3, logger=None):
        """Run EM for at most ``repetations`` iterations.

        ``logger``, when given, is called as
        ``logger(cluster_means, cluster_std, cluster_p, gama, iteration)``
        after the initial assignment (iteration 0) and after every iteration.
        """
        self.init_clusters()
        self.expectation()
        if logger is not None:
            logger(self.cluster_means, self.cluster_std, self.cluster_p, self.gama, 0)
        for repet in range(repetations):
            # Swap current and old buffers: the M-step overwrites the current
            # ones, so the previous values survive in old_* without a copy.
            self.cluster_means, self.old_cluster_means = \
                self.old_cluster_means, self.cluster_means
            self.cluster_std, self.old_cluster_std = \
                self.old_cluster_std, self.cluster_std

            self.maximization()
            self.expectation()
            if logger is not None:
                logger(self.cluster_means, self.cluster_std, self.cluster_p, self.gama, repet + 1)
            if self.converged():
                break

    def segment(self, out):
        """Write into ``out[i]`` the mean of the most responsible cluster of datum i."""
        np.argmax(self.gama, axis=1, out=self.buffer_m_int)
        # one indexed assignment instead of a Python loop over all the data
        out[:] = self.cluster_means[self.buffer_m_int]

    def error(self, segmented):
        """Return the mean squared error between the data and ``segmented``."""
        return reconstruction_error(self.data, segmented, buffer_m_d=self.buffer_m_d)

## Image Application

### Implement ImageLogger to log the results of the EM algorithm on the image
This class logs and visualizes the intermediate and final results of the EM algorithm
over an image.

In [None]:
import pandas as pd

def to_image(segmented, size):
    """Build a PIL image from a flat array of 3-channel pixels.

    ``size`` supplies the first two dimensions of the image array; values are
    cast to ``uint8`` before the image is created.
    """
    rows, cols = size[0], size[1]
    pixels = segmented.reshape((rows, cols, 3)).astype(np.uint8)
    return Image.fromarray(pixels)

def log_form_mean(mean, small=False):
    """Format a cluster mean as text.

    With ``small`` set the values are rounded to whole numbers and shown as
    integers; otherwise they are rounded to one decimal.
    """
    if not small:
        return str(np.round(mean, 1))
    whole = np.round(mean, 0)
    return str(whole.astype(int))

class ImageLogger:
    """Callable logger that records and optionally displays each EM stage.

    Every call segments the data with the current clusters, stores the
    reconstruction error in ``self.results`` (columns ``Clusters``,
    ``Iteration``, ``Error``) and, if any display flag is set, renders an HTML
    summary and/or the segmented image in the notebook.
    """

    def __init__(self, em, size, display_image=True, display_error=True,  display_clusters=True):
        self.em = em
        self.results = pd.DataFrame(columns=['Clusters', 'Iteration', 'Error'])
        self.size = size
        # segmented pixels, one 3-channel row per datum
        self.out = np.empty((len(em.data), 3))
        self.display_image = display_image
        self.display_error = display_error
        self.display_clusters = display_clusters
        self.display = display_image or display_error or display_clusters

    def __call__(self, means, std, p, _, repet):
        """Log one stage; ``repet`` is 0 for the initial assignment."""
        self.em.segment(self.out)
        error = self.em.error(self.out)
        self.results.loc[len(self.results)] = [self.em.k, repet, error]
        if not self.display:
            return
        head = 'Iteration '+str(repet) if repet > 0 else 'Initial assignment'
        # collect the fragments and join once instead of re-concatenating
        parts = ['<div style="margin-top:1rem">',
                 '<div style="margin-bottom:0.7rem;color:#264747;font-size:1.8rem">'+head+'</div>']
        if self.display_clusters:
            std = np.round(std, 2)
            for c in range(len(means)):
                parts.append(self._cluster_html(means[c], std[c], p[c]))
        if self.display_error:
            parts.append("<div style='margin-top:0.1rem'><strong>Error: "+str(error)+"</strong></div>")
        parts.append("</div>")
        display(HTML("".join(parts)))
        if self.display_image:
            display(self.get_image())

    def _cluster_html(self, mean, std, p):
        """Return the HTML card (colour swatch, mean, std, weight) of one cluster."""
        pc = str(np.round(p*100, 3))+'%'
        mc = log_form_mean(mean)
        mc_small = log_form_mean(mean, True)
        # rgb() is only guaranteed to accept integer channel values in [0, 255]
        # (CSS Color Level 3), so round and clip the float means.
        rgb = ",".join(str(ch) for ch in np.rint(np.clip(mean[:3], 0, 255)).astype(int))
        wrap_style = "display:inline-block;text-align:left;margin-right:1.5rem;margin-bottom:0.5rem;line-height: 1.3;color:darkslategray"
        style = "background:rgb("+rgb+");display:inline-block;width:7rem;height:2em;border:1px solid darkgray"
        return ('<div style="'+wrap_style+'"><div style="'+style+'" title="'+mc+'"></div><br>&nbsp;'
                + mc_small+'<br> &nbsp;std: '+str(std)+"<br/> &nbsp;p: &nbsp;&nbsp; "+pc+"</div>")

    def get_image(self):
        """Return the most recent segmentation as an image."""
        return to_image(self.out, self.size)

### Load Image

In [None]:
from PIL import Image
from IPython.display import display
from IPython.display import HTML

# Open the input picture; the bare `image` expression on the last line of the
# cell makes the notebook render it.
image = Image.open("data/traino.jpg")
image

### Now,  test the algorithm with K=32
We will run the algorithm and display the progression of the error and of
the color clusters for each iteration.

Colors correspond to the cluster means.

(The resulting image is not displayed here to reduce the size of this 
file however it will be available at "results/k-32.png" upon completion of this notebook)

In [None]:
# Turn the picture into a float matrix with one 3-channel row per pixel.
# convert("RGB") guarantees exactly 3 channels, which the reshape below
# requires, even when the file is grayscale or carries an alpha channel.
data = np.asarray(image.convert("RGB"))
height, width = data.shape[:2]
# Single-argument print(...) behaves the same on Python 2 and Python 3.
print('Original Input Size: %d X %d' % (height, width))

data_flat = data.astype(float).reshape((height * width, 3))
print('Flattened Size: %d' % len(data_flat))

# Run EM with 32 clusters for at most 20 iterations, logging every stage
# (cluster swatches and error, but not the image itself).
em = EM(data_flat, 32)
logger = ImageLogger(em, (height, width), display_image=False)
em.em(20, logger)

display(logger.results)
#display(logger.get_image())

### Test and obtain results for K=2, 4, 8, 16, 24, 32
NOTE: Maximum iterations are set to 20

In [None]:
import pandas as pd
import os, errno

# make directory for results (an already existing directory is fine)
try:
    os.makedirs("results")
except OSError as e:
    if e.errno != errno.EEXIST:
        raise

results = pd.DataFrame(columns=['Clusters', 'Iterations', 'Error'])
KS_TO_TEST = [2, 4, 8, 16, 24, 32]
IMAGE_COMPARISSON_HEIGHT = 1050
# The comparison sheet holds two thumbnails per row. The row count is computed
# once so the sheet width and the thumbnail height always agree, including
# for an odd number of k values.
grid_rows = (len(KS_TO_TEST) + 1) // 2
cell_height = IMAGE_COMPARISSON_HEIGHT // grid_rows
image_comparisson_width = int(np.ceil((float(IMAGE_COMPARISSON_HEIGHT) / grid_rows) * 2 * (float(image.width) / image.height)))
image_comparisson = Image.new("RGB", (image_comparisson_width, IMAGE_COMPARISSON_HEIGHT))

y = 0
for i, k in enumerate(KS_TO_TEST):
    # run EM for this k without any notebook output
    em = EM(data_flat, k)
    logger = ImageLogger(em, (len(data), len(data[0])),
         display_image=False, display_clusters=False, display_error=False)
    em.em(20, logger)
    logger.results.to_csv("results/k-"+str(k)+".csv",sep = ',')
    # save the full-size result before thumbnail() shrinks it in place
    image_k = logger.get_image()
    image_k.save("results/k-"+str(k)+".png")
    # keep only the last logged row (final iteration) in the summary table
    logger.results.rename(columns={'Iteration': 'Iterations'}, inplace=True)
    last_row = logger.results.iloc[-1]
    results.loc[len(results)] = last_row
    # Image.LANCZOS is the same filter as the former Image.ANTIALIAS alias,
    # which newer Pillow releases no longer provide.
    image_k.thumbnail((image_comparisson_width//2, cell_height), Image.LANCZOS)
    w, h = image_k.size
    # even index -> left column, odd index -> right column, then next row
    x = 0 if i%2==0 else w
    image_comparisson.paste(image_k, (x, y, x+w, y+h))
    if i%2==1:
        y = y+h
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print("k = %d done" % k)


results.to_csv("results/all-k.csv",sep = ',')
image_comparisson.save("results/all-k.png")
display(results)
display(image_comparisson)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
# Line plot of the final reconstruction error against the number of clusters,
# drawn on figure 0 at 8x6 inches.
plt.figure(num=0, figsize=(8, 6))
plt.plot(results['Clusters'], results['Error'])
plt.xlabel('Clusters')
plt.ylabel('Error')
plt.title('Reconstruction error for different k')

## Weaknesses

### Known logical implementation errors
The algorithm sometimes gives a higher error on iterations that happen after it has converged.
The magnitude of the difference is too small to be noticed by the naked eye.

### Limits
The algorithm, on the application of image segmentation behaves well for values of k at most around 180.
For k=256, for example, it merges most of the clusters into a single cluster with the same mean and std

### Possible improvements
The implementation can be tweaked to allow the application of the algorithm
over a simple random sample (SRS) of the total dataset, which would dramatically improve performance.


In the final stage, it should just perform an expectation step over the total dataset
to obtain the gammas and then use them to segment the initial data.