# Image Segmentation Techniques

## Importing the required libraries

In [None]:
import cv2

import numpy as np
import pandas as pd
from math import log10, sqrt, floor, ceil

from matplotlib import pyplot as plt
from matplotlib import axes as axes
from google.colab.patches import cv2_imshow

!pip install PyMaxflow
from maxflow.fastmin import aexpansion_grid,abswap_grid

## Thresholding functions

In [200]:
# Function to implement basic global thresholding
def global_thresholding(image, t, diff_thres):
  """Segment `image` with basic iterative global thresholding.

  Starting from threshold `t`, pixels are split into the classes
  ``intensity <= t`` and ``intensity > t``. The threshold then moves to the
  midpoint of the two class means. This repeats until two consecutive
  thresholds differ by at most `diff_thres`.

  Args:
    image: array of non-negative intensities.
    t: initial threshold.
    diff_thres: convergence tolerance on the change of the threshold.

  Returns:
    (segmented, thresholds, hist):
      segmented: list of float arrays, one per iteration, 1.0 where the
        pixel is above that iteration's threshold and 0.0 elsewhere.
      thresholds: the threshold used for each entry of `segmented`.
      hist: list of pixel counts indexed by floored intensity 0..max(image).
  """
  image = np.asarray(image)
  # Histogram of floored intensities; one bin per value 0..max(image).
  levels = np.floor(np.asarray(image, dtype=float)).astype(np.int64).ravel()
  counts = np.bincount(levels, minlength=int(np.max(image)) + 1)
  hist = counts.tolist()
  intensities = np.arange(counts.size)

  segmented = []
  thresholds = []
  t_old = 0
  t_new = t
  while True:
    segmented.append((image > t_new).astype(float))
    thresholds.append(t_new)
    # Stop once the threshold moved by no more than the tolerance.
    if abs(t_old - t_new) <= diff_thres:
      break
    # Class 1 holds intensities 0..t_new and class 2 holds t_new+1..max.
    # Clamp the split so a threshold outside the histogram range cannot
    # index past either end.
    split = min(max(int(t_new) + 1, 0), counts.size)
    n1 = counts[:split].sum()
    n2 = counts[split:].sum()
    if n1 == 0 or n2 == 0:
      # One class is empty, so its mean is undefined and the threshold cannot
      # be refined. Keep the current result instead of dividing by zero.
      break
    m1 = (intensities[:split] * counts[:split]).sum() / n1
    m2 = (intensities[split:] * counts[split:]).sum() / n2
    t_old = t_new
    t_new = floor((m1 + m2) / 2)

  return segmented, thresholds, hist

In [206]:
def otsu_thresholding(image):
  """Binarize `image` with Otsu's optimal global threshold.

  Args:
    image: array of non-negative intensities.

  Returns:
    (new_image, k, hist, separability):
      new_image: float array, 1.0 where pixel > k and 0.0 elsewhere.
      k: threshold that maximizes the between-class variance. If several
        thresholds tie, k is the truncated average of all of them.
      hist: normalized histogram as a list, indexed by floored intensity.
      separability: sigma_B^2(k) / sigma_G^2, or 0.0 when the image has
        zero variance.
  """
  image = np.asarray(image)
  # Normalized histogram p_i of floored intensities 0..max(image).
  levels = np.floor(np.asarray(image, dtype=float)).astype(np.int64).ravel()
  counts = np.bincount(levels, minlength=int(np.max(image)) + 1)
  p = counts / levels.size
  intensities = np.arange(p.size)
  # Cumulative class probability P1(k) and cumulative mean m(k) = sum_{i<=k} i*p_i.
  # Weighting by i (not i+1) keeps the global mean equal to the real mean
  # intensity, which the global variance below relies on.
  c_hist = np.cumsum(p)
  c_mean = np.cumsum(intensities * p)
  mean_global = c_mean[-1]
  # Between-class variance sigma_B^2(k) = (mG*P1 - m)^2 / (P1*(1-P1)).
  # It is left at 0 where one of the classes is empty.
  sigma_b = np.zeros_like(p)
  valid = (c_hist != 0) & (c_hist != 1)
  sigma_b[valid] = ((mean_global * c_hist[valid] - c_mean[valid]) ** 2
                    / (c_hist[valid] * (1 - c_hist[valid])))
  # If the maximum is not unique, use the average of the maximizing k's.
  k = int(np.flatnonzero(sigma_b == sigma_b.max()).mean())
  new_image = (image > k).astype(float)
  # Global variance; a constant image has none, so separability is 0.
  global_var = float(np.sum(p * (intensities - mean_global) ** 2))
  separability = float(sigma_b[k]) / global_var if global_var > 0 else 0.0
  return new_image, k, p.tolist(), separability

In [201]:
def multiple_otsu(image, plot=True):
  """Segment `image` into three classes with two-threshold Otsu.

  Every threshold pair 1 <= k1 < k2 <= L-2 (L = number of histogram bins) is
  scored by its between-class variance. If several pairs share the maximum,
  k1 and k2 are the truncated averages over those pairs.

  Args:
    image: array of non-negative intensities.
    plot: when True (the default), show the original image, the
      segmentation, and the histogram with both thresholds.

  Returns:
    (segmented, k1, k2, sep):
      segmented: float array, 0 for pixels <= k1, L/2 for k1 < pixel <= k2,
        and L for pixels > k2.
      k1, k2: the two thresholds, with k1 < k2.
      sep: separability sigma_B^2(k1, k2) / sigma_G^2, or 0.0 for a
        zero-variance image.
    (Earlier versions returned None.)

  Raises:
    ValueError: if the histogram has fewer than 4 bins, so that no pair
      1 <= k1 < k2 <= L-2 exists.
  """
  image = np.asarray(image)
  # Normalized histogram of floored intensities 0..max(image).
  levels = np.floor(np.asarray(image, dtype=float)).astype(np.int64).ravel()
  counts = np.bincount(levels, minlength=int(np.max(image)) + 1)
  p = counts / levels.size
  hist = p.tolist()
  n_bins = p.size
  intensities = np.arange(n_bins)
  # Cumulative probability P(k) and cumulative mean m(k) = sum_{i<=k} i*p_i.
  c_hist = np.cumsum(p)
  c_mean = np.cumsum(intensities * p)
  mean_global = c_mean[-1]

  def class_term(prob, cum):
    # prob * (class mean - global mean)^2; an empty class contributes 0.
    mean = np.divide(cum, prob, out=np.zeros_like(prob), where=prob != 0)
    return prob * (mean - mean_global) ** 2

  # Score every (k1, k2) at once: rows index k1, columns index k2.
  k1_grid, k2_grid = np.meshgrid(intensities, intensities, indexing="ij")
  valid = (k1_grid >= 1) & (k1_grid < k2_grid) & (k2_grid <= n_bins - 2)
  if not valid.any():
    raise ValueError("multiple_otsu needs at least 4 histogram bins (max intensity >= 3)")
  p1 = c_hist[k1_grid]
  p2 = c_hist[k2_grid] - c_hist[k1_grid]
  p3 = c_hist[-1] - c_hist[k2_grid]
  sigma_b = (class_term(p1, c_mean[k1_grid])
             + class_term(p2, c_mean[k2_grid] - c_mean[k1_grid])
             + class_term(p3, c_mean[-1] - c_mean[k2_grid]))
  # Pairs with k2 <= k1 give a non-positive middle-class weight. They must
  # never win or tie, otherwise averaging can collapse k1 and k2 together.
  sigma_b = np.where(valid, sigma_b, -np.inf)
  rows, cols = np.nonzero(sigma_b == sigma_b.max())
  k1 = int(rows.mean())
  k2 = int(cols.mean())

  # Global variance and separability; a constant image has no variance.
  global_var = float(np.sum(p * (intensities - mean_global) ** 2))
  sep = float(sigma_b[k1, k2]) / global_var if global_var > 0 else 0.0

  # Three-level segmented image.
  segmented = np.where(image > k2, n_bins,
                       np.where(image > k1, n_bins / 2, 0)).astype(float)

  if plot:
    fig, ax = plt.subplots(1, 3, figsize=(30, 5))
    ax[0].imshow(image, cmap='gray')
    ax[0].set_title("Original Image")
    ax[1].imshow(segmented, cmap='gray')
    name = "Segmented Image\nThreshold 1 = "+str(k1) + "\nThreshold 2 = "+str(k2) +"\nSeparability = "+str(sep)
    ax[1].set_title(name)
    ax[2].bar(range(n_bins), hist, width=0.5)
    ax[2].vlines(k1, 0, max(hist), 'red')
    ax[2].vlines(k2, 0, max(hist), 'red')
    ax[2].set_title("Histogram of Smoothened Image\n(Threshold shown in Red)")
    # Lay out after the titles exist so the multi-line titles are accounted for.
    fig.tight_layout()
    # pyplot.show() does not take a figure argument.
    plt.show()

  return segmented, k1, k2, sep

## Functions to plot the segmentation steps

In [202]:
# Function to plot the steps in a segmentation algorithm
def plot_all_steps(image, segmented, thresholds):
  """Show `image` followed by every intermediate segmentation, four per row.

  Args:
    image: the original image, drawn in the first cell.
    segmented: list of segmented images, one per iteration.
    thresholds: threshold used for each entry of `segmented`, in the same
      order; shown as that cell's title.
  """
  n = len(segmented)
  rows = ceil((n + 1) / 4)
  # squeeze=False always returns a 2-D array of axes, so one code path covers
  # both the single-row and the multi-row layouts.
  fig, ax = plt.subplots(rows, 4, figsize=(20, rows * 4), squeeze=False)
  cells = ax.ravel()
  # Draw the image that was passed in, not a notebook-level global.
  cells[0].imshow(image, cmap="gray")
  cells[0].set_title("Original Image")
  for idx, cell in enumerate(cells[1:]):
    if idx < n:
      cell.imshow(segmented[idx], cmap="gray")
      cell.set_title("Threshold = " + str(thresholds[idx]))
    else:
      # Unused trailing cells in the last row.
      cell.set_axis_off()

# Function to plot the segmented and original image for histogram based methods
def plot_histogram(image, segmented, hist, threshold, sep=0):
  """Show the thresholded image, its segmentation and its histogram side by side.

  Args:
    image: image that was thresholded (titled "Smoothened Image").
    segmented: result of thresholding `image`.
    hist: histogram values, one bar per intensity index.
    threshold: threshold, drawn as a red vertical line on the histogram.
    sep: separability score; added to the title only when non-zero.
  """
  fig, ax = plt.subplots(1, 3, figsize=(30, 5))
  ax[0].imshow(image, cmap='gray')
  ax[0].set_title("Smoothened Image")
  ax[1].imshow(segmented, cmap='gray')
  name = "Segmented Image\nThreshold = " + str(threshold)
  if sep != 0:
    name += "\nSeparability = " + str(sep)
  ax[1].set_title(name)
  ax[2].bar(range(len(hist)), hist, width=0.5)
  ax[2].vlines(threshold, 0, max(hist), 'red')
  ax[2].set_title("Histogram of Smoothened Image\n(Threshold shown in Red)")
  # Lay out after the titles exist so the multi-line titles are accounted for.
  fig.tight_layout()
  # pyplot.show() does not take a figure argument.
  plt.show()

## Functions for Graph Cut segmentation techniques

In [397]:
# Function to get the data cost
def get_D(I, levels):
  """Return the data cost of assigning each pixel of `I` to each level.

  Entry [..., l] is the squared difference between the pixel value and
  ``levels[l]``. For a 2-D `I` of shape (H, W) the result has shape (H, W, L).
  """
  pixel_axis = np.expand_dims(I, axis=-1)      # I with a trailing length-1 axis
  label_axis = np.reshape(levels, (1, 1, -1))  # levels laid out along axis 2
  return np.square(pixel_axis - label_axis)

# Function to get the interaction cost
def get_V(levels):
  """Return the pairwise interaction cost between every two labels.

  Entry [a, b] is |levels[a] - levels[b]|. It is zero on the diagonal and
  grows with the gap between the two levels.
  """
  flat = np.ravel(levels)
  return np.abs(flat[:, np.newaxis] - flat[np.newaxis, :])

# Function to give the complement of an image
def complement(img):
  """Swap the values 0 and max(img), leaving every other value unchanged.

  For a 0/1 label map this inverts the labels. A new float64 array is
  returned and `img` is not modified.
  """
  top = np.max(img)
  # The zero test takes precedence, as in a per-pixel if/elif.
  swapped = np.where(img == 0, top, np.where(img == top, 0, img))
  return swapped.astype(np.float64)

# Graph cut algorithms - alpha-beta swap and alpha expansion
def graphcut(img):
  """Two-label graph-cut segmentation of `img` by alpha-beta swap and alpha expansion.

  The image is scaled by its maximum. Each pixel is labelled with one of the
  levels 0.25 / 0.75, using the data cost from get_D (squared distance to the
  level) and the interaction cost from get_V (absolute level difference).

  Returns:
    (ab_swap, a_expansion): the label maps from abswap_grid and
    aexpansion_grid, each passed through complement(), which swaps 0 and the
    maximum label. Presumably the solvers return indices into `levels`
    (0 -> 0.25), in which case darker pixels end up as 1. TODO confirm
    against the PyMaxflow docs.
  """
  peak = np.max(img)
  # An all-zero image would otherwise divide 0/0 and feed NaN costs to the
  # optimizers.
  I = img / peak if peak != 0 else np.zeros(np.shape(img))
  levels = np.array([0.25, 0.75])
  D = get_D(I, levels)
  V = get_V(levels)
  ab_swap = complement(abswap_grid(D, V))
  a_expansion = complement(aexpansion_grid(D, V))
  return ab_swap, a_expansion

## Metrics

In [378]:
def calculate_metrics(img, mask):
  """Score a predicted segmentation `img` against the ground-truth `mask`.

  A pixel counts as positive when its value is > 0 and as negative when it
  is == 0. Pixels that are neither in either array (negative values or NaN)
  are ignored.

  Args:
    img: predicted segmentation.
    mask: ground-truth annotation, with the same shape as `img`.

  Returns:
    dict with the keys "Accuracy", "Dice Coefficient", "Jaccard Index",
    "Sensitivity" and "Specificity". A metric whose denominator is zero
    (e.g. Sensitivity when the mask has no positive pixel) is NaN.

  Raises:
    ValueError: if `img` and `mask` have different shapes.
  """
  img = np.asarray(img)
  mask = np.asarray(mask)
  if img.shape != mask.shape:
    raise ValueError(f"shape mismatch: img {img.shape} vs mask {mask.shape}")
  pred_pos, pred_neg = img > 0, img == 0
  true_pos, true_neg = mask > 0, mask == 0
  tp = int(np.count_nonzero(pred_pos & true_pos))
  fp = int(np.count_nonzero(pred_pos & true_neg))
  fn = int(np.count_nonzero(pred_neg & true_pos))
  tn = int(np.count_nonzero(pred_neg & true_neg))

  def ratio(num, den):
    # Undefined metric (empty denominator) -> NaN instead of ZeroDivisionError.
    return num / den if den else float("nan")

  metrics = {}
  metrics["Accuracy"] = ratio(tp + tn, tp + tn + fp + fn)
  metrics["Dice Coefficient"] = ratio(2 * tp, 2 * tp + fn + fp)
  metrics["Jaccard Index"] = ratio(tp, tp + fn + fp)
  metrics["Sensitivity"] = ratio(tp, tp + fn)
  metrics["Specificity"] = ratio(tn, tn + fp)
  return metrics

## Final Code

In [None]:
# Load the input as a single-channel (grayscale) image.
original = cv2.imread("1.bmp", cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
if original is None:
  raise FileNotFoundError("Could not read image '1.bmp'")

# Change the initializations here
start = 125  # initial threshold
diff = 1     # stop once consecutive thresholds differ by at most this much
ksize = 0    # box-blur kernel size; 0 disables smoothing

if ksize == 0:
  smoothened = original
else:
  smoothened = cv2.blur(original, (ksize, ksize))
segmented, thresholds, hist = global_thresholding(smoothened, start, diff)

# To visualize segmentation algorithm iteration wise, uncomment the next two lines
# print("Initialization = ", start, "\nDifference Threshold = ", diff)
# plot_all_steps(original, segmented, thresholds)

# To plot the final result of segmentation
plot_histogram(smoothened, segmented[-1], hist, thresholds[-1])

In [None]:
# Load the input image and its ground-truth annotation as grayscale.
img = cv2.imread("1.bmp", cv2.IMREAD_GRAYSCALE)
mask = cv2.imread("1_anno.bmp", cv2.IMREAD_GRAYSCALE)
# cv2.imread signals a missing or unreadable file by returning None.
for path, loaded in (("1.bmp", img), ("1_anno.bmp", mask)):
  if loaded is None:
    raise FileNotFoundError(f"Could not read image '{path}'")

otsu, _, _, _ = otsu_thresholding(img)
ab_swap, a_expansion = graphcut(img)
# NOTE(review): otsu_thresholding marks pixels brighter than its threshold as 1.
# graphcut() passes the solver labels through complement(), which swaps 0 and
# the maximum label. If label 0 is the darker level (0.25), the two methods
# mark opposite intensity classes as foreground. Confirm which polarity
# 1_anno.bmp uses before comparing the metrics.

otsu_metrics = calculate_metrics(otsu, mask)
ab_swap_metrics = calculate_metrics(ab_swap, mask)
a_exp_metrics = calculate_metrics(a_expansion, mask)

# Visual comparison: inputs on the top row, segmentations on the bottom row.
fig, ax = plt.subplots(2, 3, figsize=(20, 10))
panels = [
  (ax[0][0], img, "Original Image"),
  (ax[0][1], mask, "Ground Truth"),
  (ax[1][0], otsu, "Image segmented using Otsu Thresholding"),
  (ax[1][1], ab_swap, "Image segmented using alpha-beta swap"),
  (ax[1][2], a_expansion, "Image segmented using alpha expansion"),
]
for axis, picture, title in panels:
  axis.imshow(picture, cmap="gray")
  axis.set_title(title)
ax[0][2].set_axis_off()

# Metrics table: one row per method, one column per metric.
metric_names = ['Accuracy', 'Dice Coefficient', 'Jaccard Index', 'Sensitivity', 'Specificity']
results = [
  ('Otsu Thresholding', otsu_metrics),
  ('Alpha Beta Swap', ab_swap_metrics),
  ('Alpha Expansion', a_exp_metrics),
]
data = [[method] + [metrics[name] for name in metric_names] for method, metrics in results]
df = pd.DataFrame(data, columns=['Method'] + metric_names)
df.head()