In [2]:
# Colab only: mount Google Drive so the project folder under /content/gdrive/MyDrive is reachable.
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv

In [12]:
# Project folder on the mounted Drive and its dataset sub-folder.
root = "/content/gdrive/MyDrive/COMP448/"
dataset = os.path.join(root, "Dataset")
# Make the project folder the working directory, so relative paths
# (e.g. those given to read_image / write_image) resolve inside it.
os.chdir(root)

In [15]:
# Relative image paths given to read_image / write_image are resolved against the current working directory.
def read_image(imName):
    """Read an image file and return it as an RGB array.

    imName : path of the image; a relative path is resolved against the current
             working directory.

    Returns a C-contiguous H x W x 3 array in RGB channel order.

    Raises FileNotFoundError when OpenCV cannot read the file (missing, unreadable
    or unsupported) -- cv.imread reports this by returning None instead of raising.
    """
    im = cv.imread(imName)
    if im is None:
        raise FileNotFoundError(f"could not read image: {imName!r}")
    # imread yields BGR; cvtColor returns a contiguous RGB copy (np.flip would return a
    # negative-stride view that OpenCV rejects when used as an output array).
    return cv.cvtColor(im, cv.COLOR_BGR2RGB)

def write_image(im, imName):
    """Write an RGB image to disk; OpenCV picks the format from the file extension.

    im     : image in RGB channel order (converted to OpenCV's BGR before writing).
    imName : destination path; a relative path is resolved against the current
             working directory.

    Raises OSError when cv.imwrite reports failure. imwrite signals this by
    returning False rather than raising, e.g. when the target directory does not exist.
    """
    written = cv.imwrite(imName, cv.cvtColor(im, cv.COLOR_RGB2BGR))
    if not written:
        raise OSError(f"could not write image: {imName!r}")

In [16]:
def plot_image(im, w, h, cmap = "gray", title = False):
    """Show an image in a new w x h inch figure with the axis ticks hidden.

    im    : array accepted by imshow; single-channel arrays are drawn with cmap.
    w, h  : figure width and height in inches.
    cmap  : colormap name (matplotlib ignores it for RGB input).
    title : optional title; no title is drawn when it is falsy.
    """
    _, ax = plt.subplots(figsize = (w, h))
    shown = ax.imshow(im, cmap = cmap, interpolation = "none")
    # Register the image as pyplot's "current image", exactly as plt.imshow does.
    plt.sci(shown)
    ax.set_xticks([])
    ax.set_yticks([])
    if title:
        ax.set_title(title)

In [17]:
def uint8_normalize(im):
    """Linearly rescale an array so its minimum maps to 0 and its maximum to 255.

    Values are rounded to the nearest integer (ties to even, numpy's rule). The result
    has the platform int dtype, not uint8, so callers can keep doing signed arithmetic on it.

    Edge cases:
      * a constant array (max == min) has no range to stretch and maps to all zeros
        instead of dividing by zero;
      * an empty array returns an empty int array.
    """
    # Work in float64 so that `M - m` cannot overflow for narrow integer dtypes (e.g. int8).
    im = np.asarray(im, dtype = np.float64)
    if im.size == 0:
        return np.zeros(im.shape, dtype = int)
    M = np.amax(im)
    m = np.amin(im)
    if M == m:
        return np.zeros(im.shape, dtype = int)
    normalized = (255/(M - m))*(im - m)
    return normalized.round().astype(int)

In [18]:
def plot_segmentation(im, mask, w, h, show_type = "w"):
    """Display an image with the pixels selected by mask painted white.

    im        : H x W x C image array; only the first three channels are painted.
    mask      : H x W array; non-zero / True entries mark the pixels to highlight.
    w, h      : figure width and height in inches.
    show_type : overlay style; only "w" (paint masked pixels white) is implemented.

    The input image is not modified.

    Raises ValueError for any other show_type, instead of silently drawing nothing.
    """
    if show_type != "w":
        raise ValueError(f"unsupported show_type {show_type!r}; only 'w' is implemented")
    # A boolean mask selects pixels; a 0/1 integer mask would instead index rows 0 and 1.
    mask = np.asarray(mask, dtype = bool)
    overlay = np.copy(im)
    overlay[mask, :3] = 255
    plt.figure(figsize = (w, h))
    plt.imshow(overlay, interpolation = "none")
    plt.xticks([])
    plt.yticks([])

In [19]:
def eliminate_circular_components(mask, th = 0.3):
    """Remove the compact (roundish) connected components of a binary mask.

    For every connected component (OpenCV's default 8-connectivity), the circularity
    4*pi*area / perimeter**2 is computed. It is large for compact blobs and small for thin,
    elongated shapes. Components with circularity >= th are removed.

    mask : 2-D array, non-zero / True on the foreground.
    th   : circularity threshold.

    Returns a boolean array of the same shape containing only the components below the threshold.
    """
    labelnum, labels, stats, _ = cv.connectedComponentsWithStats((1*mask).astype("uint8"))
    sqrtpi = np.sqrt(np.pi)
    circular_labels = []
    for c in range(1, labelnum):  # label 0 is the background
        x, y = stats[c, cv.CC_STAT_LEFT], stats[c, cv.CC_STAT_TOP]
        w, h = stats[c, cv.CC_STAT_WIDTH], stats[c, cv.CC_STAT_HEIGHT]
        # Trace the contour on the component's bounding box padded with a zero border,
        # not on a full-image mask. The contour is the same up to a translation (which
        # arcLength ignores), but the work per label is O(component) instead of O(image).
        component = np.pad((labels[y:y + h, x:x + w] == c).astype("uint8"), 1)
        contours, _ = cv.findContours(component, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
        perimeter = cv.arcLength(contours[0], closed = True)
        if perimeter == 0:
            # Single-pixel component: this substitute perimeter makes circularity == area.
            perimeter = 2*sqrtpi
        area = stats[c, cv.CC_STAT_AREA]
        if (4*np.pi*area)/(perimeter**2) >= th:
            circular_labels.append(c)
    return (labels > 0) & ~np.isin(labels, circular_labels)

In [20]:
def fill_regions(im, mask, window = 11):
    """Return a copy of an image whose masked pixels are replaced by local averages.

    The masked pixels are first zeroed. Each masked pixel then receives, per channel, the
    mean of the non-zero values of that channel inside the window x window box around it
    (box sums use OpenCV's default border handling).

    Pixels whose box holds no non-zero channel-0 value fall back to an expanding square
    search. It stops at the first size that contains a non-zero pixel and uses the mean of
    that square. Filled values are reused by later fallback pixels, which are processed in
    row-major order. If the whole image holds no non-zero pixel, such a pixel is left unchanged.

    im     : H x W x C image (3-channel RGB expected); not modified.
    mask   : H x W array, non-zero / True where pixels must be replaced; not modified.
    window : side length of the averaging box (default 11).

    Returns an array with the same shape and dtype as im.

    Note: validity is judged by value > 0, so genuinely black unmasked values are excluded
    from the averages as well. A masked pixel whose box has values in channel 0 but none in
    another channel gets 0 in that channel.
    """
    mask = np.array(mask, dtype = bool)  # private copy: the caller's mask stays untouched
    filled = np.copy(im)
    zeroed = np.copy(im)
    zeroed[mask] = 0

    box = (window, window)
    sums = cv.boxFilter(zeroed, cv.CV_32F, box, normalize = False)
    num_values = cv.boxFilter((zeroed > 0).astype(np.float32), cv.CV_32F, box, normalize = False)

    # Fallback for pixels whose box has no usable channel-0 value: grow the square until it has one.
    m, n = zeroed.shape[:2]
    troubled = np.argwhere(num_values[:, :, 0] == 0)
    for r, c in troubled:
        half = window // 2 + 1
        while True:
            patch = zeroed[max(r - half, 0):min(r + half + 1, m), max(c - half, 0):min(c + half + 1, n)]
            # Number of pixels with at least one non-zero channel, so the mean cannot exceed the max.
            count = np.count_nonzero(patch.any(axis = 2))
            if count > 0:
                zeroed[r, c] = np.sum(patch, (0, 1)) / count
                if mask[r, c]:
                    filled[r, c] = zeroed[r, c]
                break
            if half >= max(m, n):
                break  # the whole image was searched without finding a non-zero pixel
            half += 1
    # Fallback pixels are final; keep them out of the box-average step below.
    mask[troubled[:, 0], troubled[:, 1]] = False

    # Divide only where a channel actually has values, avoiding 0/0 in any channel.
    avg = np.divide(sums, num_values, out = np.zeros_like(sums), where = num_values > 0)
    filled[mask] = avg[mask]
    return filled

In [21]:
def remove_hairs(im, size = 5, r = 2, plots = False, circularity_th = 0.2):
    """Detect thin dark structures (hairs) in an RGB image and fill them in.

    Steps (each intermediate result is shown when plots is True):
      1. Centre-surround response on a Gaussian-smoothed gray image: mean over the ring
         (the (2*size+1) x (2*size+1) window minus a disk of radius r) minus mean over the
         disk. A pixel darker than its surrounding ring gets a positive response. The
         response is rescaled to 0..255 with uint8_normalize.
      2. Keep pixels whose response is more than one standard deviation above the mean.
      3. Drop connected components with circularity >= circularity_th
         (eliminate_circular_components), so only elongated ones remain.
      4. Replace the remaining pixels by local averages (fill_regions).

    im             : RGB image.
    size           : half-width of the filter window.
    r              : radius of the central disk.
    plots          : show the result of every step.
    circularity_th : components at or above this circularity are not treated as hairs.

    Returns a copy of im with the detected pixels filled in.

    Raises ValueError when the disk covers the whole window, leaving an empty ring
    (r**2 >= 2*size**2, e.g. size == 0). That case would otherwise divide by zero.
    """
    ksize = 2*size + 1

    # Disk of radius r centred in the ksize x ksize window, and its complement (the ring).
    offsets = np.arange(ksize) - size
    in_filter = (offsets[:, None]**2 + offsets[None, :]**2 <= r**2).astype(np.float64)
    out_filter = 1 - in_filter
    if not out_filter.any():
        raise ValueError(f"r = {r} covers the whole {ksize}x{ksize} window; the surrounding ring is empty")

    # Normalize both kernels so that filtering yields means rather than sums.
    in_filter = in_filter / np.sum(in_filter)
    out_filter = out_filter / np.sum(out_filter)

    gray = cv.cvtColor(im, cv.COLOR_RGB2GRAY)
    smooth_gray = cv.GaussianBlur(gray, ksize = (3,3), sigmaX = 1)

    in_filt = cv.filter2D(smooth_gray, cv.CV_32F, in_filter)
    out_filt = cv.filter2D(smooth_gray, cv.CV_32F, out_filter)
    diff = uint8_normalize(out_filt - in_filt)

    if plots:
        plot_image(diff, 7, 7, title = "Step 1")

    mask = (diff - np.mean(diff)) > np.std(diff)

    if plots:
        plot_image(mask.astype("uint8"), 7, 7, title = "Step 2")

    mask2 = eliminate_circular_components(mask, circularity_th)

    if plots:
        plot_image(mask2.astype("uint8"), 7, 7, title = "Step 3")

    removed = fill_regions(im, mask2)

    if plots:
        plot_image(removed, 7, 7, title = "Result")

    return removed

