In [2]:
import numpy as np
import maxflow
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
import cv2

#### Calculate $\beta$
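 - $\beta$ controls the contrast-sensitive smoothness term. That term charges $\gamma \, \text{dist}(i,j)^{-1} \exp(-\beta \lVert z_i - z_j \rVert^2)$ whenever two neighbouring pixels $i, j$ (with colours $z_i, z_j$) are given different labels (bg vs fg).
 - Setting $\beta$ from the image itself normalises for overall contrast. A pair whose squared colour difference equals the image average gets $\beta \lVert z_i - z_j \rVert^2 = 1/2$. As a result, boundaries are cheap across strong colour edges and expensive inside uniform regions.
 - With $\beta = 0$, colour is ignored and every boundary pair costs the same. A larger $\beta > 0$ makes the cut follow image edges more closely.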
 
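 $$
     \beta = \frac{1}{2\,\text{E}\left[\lVert z_i - z_j \rVert^2\right]}
 $$
 
 Here $z_i$ and $z_j$ are the colour vectors of two neighbouring pixels, and $\text{E}$ is the average over all pairs of 8-connected neighbouring pixels in the image. An image of width $w$ and height $h$ has $4wh - 3h - 3w + 2$ such pairs.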

In [3]:
class GrabCut:
    """Graph-cut foreground segmentation with Gaussian-mixture colour models (GrabCut).

    The trimap starts with:

    - every pixel outside ``bgRect`` as sure background,
    - every pixel inside it as probable foreground,
    - every pixel where ``mask == 1`` as sure foreground.

    :meth:`Iterative_step` then repeats ``n_iters`` times:

    1. Fit one Gaussian mixture to the background pixels (BG + PR_BG) and one
       to the foreground pixels (FG + PR_FG).
    2. Build an 8-connected flow graph (source = foreground, sink =
       background) and relabel only the probable pixels from its minimum cut.

    Parameters
    ----------
    img : array_like, shape (H, W) or (H, W, C)
        Input image. Stored as float64 in ``self.img``; a 2-D image gets a
        trailing channel axis of length 1.
    gamma : float
        Weight of the smoothness (n-link) term.
    n_components : int
        Number of Gaussian components in each colour model.
    max_iter : int
        ``max_iter`` of the k-means that seeds each mixture's means.
    bgRect : tuple of int
        ``(x_start, x_end, y_start, y_end)``: half-open column/row range that
        may contain the object. Pixels outside it are sure background.
    mask : array_like of shape (H, W), or None
        Pixels equal to 1 are forced to sure foreground.
    n_iters : int
        Number of fit/cut rounds done by :meth:`Iterative_step`.
    random_state : int or None, optional
        Seed for k-means and the mixtures, so results are reproducible.
    """

    # Trimap labels.
    BG = 0      # sure background
    FG = 1      # sure foreground
    PR_BG = 2   # probable background
    PR_FG = 3   # probable foreground

    # PyMaxflow 3x3 neighbourhood structures (centre = edge origin) for the four
    # "forward" neighbours. Added with symmetric=True, they give 8-connectivity:
    # each neighbouring pair gets one edge in each direction.
    _RIGHT = np.array([[0, 0, 0], [0, 0, 1], [0, 0, 0]])
    _DOWN = np.array([[0, 0, 0], [0, 0, 0], [0, 1, 0]])
    _DOWN_RIGHT = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1]])
    _DOWN_LEFT = np.array([[0, 0, 0], [0, 0, 0], [1, 0, 0]])

    def __init__(self, img, gamma, n_components, max_iter, bgRect, mask, n_iters,
                 random_state=0):
        img = np.asarray(img, dtype=np.float64)
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        if img.ndim != 3:
            raise ValueError(f"img must have shape (H, W) or (H, W, C), got {img.shape}")
        if n_components < 1:
            raise ValueError(f"n_components must be >= 1, got {n_components}")

        self.img = img
        self.height, self.width = img.shape[:2]
        self.k = int(n_components)
        self.gamma = float(gamma)
        self.n_iters = int(n_iters)
        self.bgRect = bgRect
        self.mask = mask
        self.max_iters = max_iter
        self.random_state = random_state
        self.beta = 0.0

        self.set_beta()            # contrast normalisation for the n-links
        self.init_graph()          # one node per pixel
        self.set_graph_weights()   # n-links (smoothness term)
        self.init_trimap(mask)     # BG / PR_FG / FG from bgRect and mask
        self.add_terminal_edges()  # hard t-links for the sure pixels

    @staticmethod
    def _neighbour_diffs(img):
        """Colour differences for the four forward neighbour directions.

        For an (H, W, C) image, returns ``(left, upleft, up, upright)`` with

        - ``left[r, c] = img[r, c+1] - img[r, c]``
        - ``upleft[r, c] = img[r+1, c+1] - img[r, c]``
        - ``up[r, c] = img[r+1, c] - img[r, c]``
        - ``upright[r, c] = img[r+1, c] - img[r, c+1]``
        """
        left = img[:, 1:] - img[:, :-1]
        upleft = img[1:, 1:] - img[:-1, :-1]
        up = img[1:, :] - img[:-1, :]
        upright = img[1:, :-1] - img[:-1, 1:]
        return left, upleft, up, upright

    def set_beta(self):
        """Set ``self.beta = 1 / (2 * mean ||z_i - z_j||^2)`` over all 8-connected pairs.

        The four difference arrays are also kept as ``self._left_diff``,
        ``self._upleft_diff``, ``self._up_diff`` and ``self._upright_diff``.
        """
        (self._left_diff, self._upleft_diff,
         self._up_diff, self._upright_diff) = self._neighbour_diffs(self.img)
        height, width = self.img.shape[:2]
        # h*(w-1) + (h-1)*w + 2*(h-1)*(w-1) horizontal, vertical and diagonal pairs.
        n_pairs = 4 * width * height - 3 * width - 3 * height + 2
        total = sum(float((d * d).sum()) for d in (self._left_diff, self._upleft_diff,
                                                  self._up_diff, self._upright_diff))
        if n_pairs <= 0 or total == 0.0:
            # Single pixel or constant image: every difference is zero, so
            # exp(-beta * 0) == 1 for any beta. Avoid dividing by zero.
            self.beta = 0.0
        else:
            self.beta = n_pairs / (2.0 * total)

    def _nlink_weights(self):
        """Return ``[(structure, weights), ...]`` for the four forward directions.

        ``weights`` has shape (H, W) and is indexed by the edge's origin pixel.
        The weight is ``gamma / dist * exp(-beta * ||z_i - z_j||^2)``, with
        dist = 1 for horizontal/vertical pairs and sqrt(2) for diagonal ones.
        Entries whose neighbour would fall outside the image stay 0.
        """
        height, width = self.img.shape[:2]
        left, upleft, up, upright = self._neighbour_diffs(self.img)
        diagonal_scale = self.gamma / np.sqrt(2.0)

        def smoothness(diff, scale):
            return scale * np.exp(-self.beta * (diff * diff).sum(axis=-1))

        right_w = np.zeros((height, width))
        down_w = np.zeros((height, width))
        down_right_w = np.zeros((height, width))
        down_left_w = np.zeros((height, width))
        right_w[:, :-1] = smoothness(left, self.gamma)             # (r, c)   -> (r, c+1)
        down_w[:-1, :] = smoothness(up, self.gamma)                # (r, c)   -> (r+1, c)
        down_right_w[:-1, :-1] = smoothness(upleft, diagonal_scale)  # (r, c)   -> (r+1, c+1)
        down_left_w[:-1, 1:] = smoothness(upright, diagonal_scale)   # (r, c+1) -> (r+1, c)
        return [(self._RIGHT, right_w), (self._DOWN, down_w),
                (self._DOWN_RIGHT, down_right_w), (self._DOWN_LEFT, down_left_w)]

    def _hard_capacity(self):
        """t-link capacity that pins a sure pixel to its terminal.

        Each n-link weight is at most ``gamma`` (the exp factor is <= 1), and a
        pixel has at most 8 neighbours. So this exceeds any n-link cut around a
        single pixel, and a hard-labelled pixel can never switch sides.
        """
        return 8.0 * self.gamma + 1.0

    def init_graph(self):
        """Create an empty float flow graph with one node per pixel (``self.pixels``)."""
        self.graph = maxflow.Graph[float]()
        self.pixels = self.graph.add_grid_nodes((self.height, self.width))

    def set_graph_weights(self):
        """Add the symmetric 8-connected smoothness edges (n-links) to ``self.graph``."""
        for structure, weights in self._nlink_weights():
            self.graph.add_grid_edges(self.pixels, weights=weights,
                                      structure=structure, symmetric=True)

    def init_trimap(self, mask=None):
        """Build ``self.trimap`` from ``self.bgRect`` and a sure-foreground mask.

        Pixels outside the rectangle are BG, pixels inside are PR_FG, and
        pixels where ``mask == 1`` are FG. ``mask`` defaults to ``self.mask``;
        None means "no sure-foreground hints".
        """
        if mask is None:
            mask = getattr(self, "mask", None)
        self.trimap = np.full((self.height, self.width), self.BG, dtype=np.uint8)
        x_start, x_end, y_start, y_end = self.bgRect
        self.trimap[y_start:y_end, x_start:x_end] = self.PR_FG  # rectangle may contain the object
        if mask is not None:
            self.trimap[np.asarray(mask) == 1] = self.FG

    def add_terminal_edges(self):
        """Pin sure-FG pixels to the source and sure-BG pixels to the sink."""
        hard = self._hard_capacity()
        source_caps = np.where(self.trimap == self.FG, hard, 0.0)
        sink_caps = np.where(self.trimap == self.BG, hard, 0.0)
        self.graph.add_grid_tedges(self.pixels, source_caps, sink_caps)

    def Iterative_step(self):
        """Run ``n_iters`` rounds of colour-model fitting and min-cut relabelling.

        Only probable pixels (PR_BG / PR_FG) are relabelled; sure pixels act
        as hard constraints.

        Returns
        -------
        numpy.ndarray
            The updated ``self.trimap``.

        Raises
        ------
        ValueError
            If the background or foreground region has fewer pixels than
            ``n_components`` (e.g. ``bgRect`` covers the whole image).
        """
        for _ in range(self.n_iters):
            r_ind, c_ind = np.nonzero((self.trimap == self.PR_BG) | (self.trimap == self.PR_FG))
            if r_ind.size == 0:
                break  # nothing left to relabel

            bg_region = (self.trimap == self.BG) | (self.trimap == self.PR_BG)
            fg_region = (self.trimap == self.FG) | (self.trimap == self.PR_FG)
            self.bg_set = self.img[bg_region]
            self.fg_set = self.img[fg_region]
            for name, samples in (("background", self.bg_set), ("foreground", self.fg_set)):
                if len(samples) < self.k:
                    raise ValueError(
                        f"{name} region has {len(samples)} pixels, fewer than "
                        f"n_components={self.k}; adjust bgRect or mask")

            # 1) Colour models: k-means seeds the component means, EM refines them.
            self.init_k_means()
            bg_km = self.kmb.fit(self.bg_set)
            fg_km = self.kmf.fit(self.fg_set)
            self.init_gmm(bg_km, fg_km)
            self.bg_gmm.fit(self.bg_set)
            self.fg_gmm.fit(self.fg_set)

            # 2) Data term D = -log p(z | model), using the full mixture density.
            z = self.img[r_ind, c_ind]
            d_bg = -self.bg_gmm.score_samples(z)
            d_fg = -self.fg_gmm.score_samples(z)
            # Adding the same constant to both t-links of a pixel does not move
            # the min cut; shift so both capacities are non-negative.
            shift = np.minimum(d_bg, d_fg)
            edge_weights_bg = np.zeros((self.height, self.width))
            edge_weights_fg = np.zeros((self.height, self.width))
            edge_weights_bg[r_ind, c_ind] = d_bg - shift
            edge_weights_fg[r_ind, c_ind] = d_fg - shift

            # 3) Fresh graph each round, so t-links from earlier rounds don't accumulate.
            self.init_graph()
            self.set_graph_weights()
            self.add_terminal_edges()
            # Source = foreground: a pixel cut from the source (labelled BG)
            # pays D_bg; one cut from the sink (labelled FG) pays D_fg.
            self.graph.add_grid_tedges(self.pixels, edge_weights_bg, edge_weights_fg)
            self.graph.maxflow()
            self.refine(r_ind, c_ind, edge_weights_bg, edge_weights_fg)
        return self.trimap

    def init_k_means(self):
        """Create the k-means estimators: ``self.kmb`` (background) and ``self.kmf`` (foreground)."""
        self.kmb = KMeans(n_clusters=self.k, max_iter=self.max_iters, n_init=1,
                          random_state=self.random_state)
        self.kmf = KMeans(n_clusters=self.k, max_iter=self.max_iters, n_init=1,
                          random_state=self.random_state)
        self.kmd = self.kmb  # alias for the original (misspelled) attribute name

    def init_gmm(self, bg_kmeans=None, fg_kmeans=None):
        """Create the background/foreground mixtures, optionally seeded by fitted k-means.

        When a fitted k-means is given, its ``cluster_centers_`` become the
        mixture's ``means_init``.
        """
        def make(kmeans):
            return GaussianMixture(
                n_components=self.k,
                covariance_type="full",
                means_init=None if kmeans is None else kmeans.cluster_centers_,
                random_state=self.random_state,
            )

        self.bg_gmm = make(bg_kmeans)
        self.fg_gmm = make(fg_kmeans)

    def get_cov_det(self):
        """Store per-component covariance determinants in ``bg_cov_det`` / ``fg_cov_det``."""
        self.bg_cov_det = np.linalg.det(self.bg_gmm.covariances_)
        self.fg_cov_det = np.linalg.det(self.fg_gmm.covariances_)

    def get_cov_inv(self):
        """Store per-component covariance inverses in ``bg_cov_inv`` / ``fg_cov_inv``."""
        self.bg_cov_inv = np.linalg.inv(self.bg_gmm.covariances_)
        self.fg_cov_inv = np.linalg.inv(self.fg_gmm.covariances_)

    def refine(self, r_ind, c_ind, edge_weights_bg=None, edge_weights_fg=None):
        """Relabel the given probable pixels from the last min cut.

        Pixels on the source side become PR_FG; pixels on the sink side
        become PR_BG. ``edge_weights_bg`` / ``edge_weights_fg`` are accepted
        for backward compatibility and are unused: the graph is rebuilt every
        round, so there are no t-links to undo.
        """
        in_sink = self.graph.get_grid_segments(self.pixels)[r_ind, c_ind]
        self.trimap[r_ind, c_ind] = np.where(in_sink, self.PR_BG, self.PR_FG)
    
    

In [None]:
# Load the image (OpenCV reads BGR) and run the segmentation.
IMAGE_PATH = "lantern.jpg"
GAMMA = 50                  # smoothness weight
N_COMPONENTS = 5            # Gaussians per colour model
KMEANS_MAX_ITER = 10        # k-means iterations used to seed each mixture
RECT = (365, 869, 40, 924)  # (x_start, x_end, y_start, y_end) of the region that may contain the object
N_ITERS = 1                 # GrabCut fit/cut rounds

img = cv2.imread(IMAGE_PATH)
if img is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(f"could not read image {IMAGE_PATH!r}")
print(img.shape)
mask = np.zeros(img.shape[:2], dtype=np.uint8)  # no user hints: nothing marked sure-foreground
gc = GrabCut(img, GAMMA, N_COMPONENTS, KMEANS_MAX_ITER, RECT, mask, N_ITERS)
gc.Iterative_step()  # without this, the trimap is just the initial rectangle

In [None]:
# Shape of the float image held by the segmenter; compare with img.shape printed above.
print(f"{gc.img.shape}")

In [None]:
# Black out every pixel the trimap labels as sure or probable background.
img[(gc.trimap == gc.BG) | (gc.trimap == gc.PR_BG)] = (0, 0, 0)


In [None]:
# Show the result inline. A cv2.imshow window blocks the kernel (breaking
# Restart & Run All); OpenCV images are BGR, matplotlib expects RGB.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
ax.set_title("GrabCut result: background pixels blacked out")
ax.axis("off")
plt.show()