In [None]:
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from skimage import data, filters
# Notebook-wide default figure size for every matplotlib figure created below.
plt.rcParams['figure.figsize'] = [10, 10]

In [2]:
# Function to read an input image from the local "images/" folder
def readImage(number, jpg = False):
    """Load ``images/<number>.jpg`` or ``images/<number>.tif`` in RGB order.

    Parameters
    ----------
    number : converted with ``str`` to build the file name.
    jpg : when True the ``.jpg`` file is read, otherwise the ``.tif`` file.

    Returns
    -------
    The decoded image with its last axis reversed. OpenCV decodes colour
    images as BGR, so the reversal yields RGB.

    Raises
    ------
    FileNotFoundError
        If OpenCV could not read the file (``cv.imread`` returns ``None``
        instead of raising, which previously surfaced as an opaque TypeError
        on the slice below).
    """
    extension = ".jpg" if jpg == True else ".tif"
    path = "images/" + str(number) + extension
    img = cv.imread(path)
    if img is None:
        raise FileNotFoundError("Could not read image: " + path)
    # Reverse the channel axis: BGR -> RGB
    return img[:,:,::-1]

# Show an image in a matplotlib figure
def displayImage(a, title1 = "Img"):
    """Draw `a` with the gray colormap on a fixed 0-255 scale, titled `title1`."""
    plt.imshow(a, cmap='gray', vmin=0, vmax=255)
    plt.title(title1)
    plt.show()

In [3]:
# Area-based cleaning of image
def clean(img, areaThreshold = 25):
    """Remove connected components whose pixel area is below `areaThreshold`.

    Parameters
    ----------
    img : image to clean; it is converted with ``np.uint8`` before labelling,
        and non-zero pixels are treated as foreground by OpenCV.
    areaThreshold : components with an area strictly smaller than this are
        set to 0.

    Returns
    -------
    A new uint8 array (the input is not modified) in which every pixel that
    belongs to a too-small component is 0; all other pixels are unchanged.
    """
    binary = np.uint8(img)
    # 4-connectivity labelling; only the label image and the stats are needed.
    _, labels, stats, _ = cv.connectedComponentsWithStats(binary, 4, cv.CV_32S)

    cleaned = np.copy(binary)
    # One flag per label, then looked up for every pixel through the label
    # image. This replaces a per-pixel Python loop with the same result.
    too_small = stats[:, cv.CC_STAT_AREA] < areaThreshold
    cleaned[too_small[labels]] = 0
    return cleaned

In [4]:
def removeDisk(img, upper, lower):
    """Blank the top/bottom bands and the component first hit on the centre row.

    The image is modified IN PLACE and also returned.

    Steps:
      1. rows ``[0, upper)`` and rows ``[lower, height)`` are set to 0;
      2. connected components (4-connectivity) are labelled;
      3. the centre row (``height // 2``) is scanned left to right and the
         component owning the first non-zero pixel is set to 0.
         (Presumably this is the circular field-of-view border the function
         is named after - confirm against the input images.)

    Parameters
    ----------
    img : 2-D single-channel image (pixels are compared with 0 one by one).
    upper : number of rows to blank at the top.
    lower : first row index of the band to blank at the bottom.

    Returns
    -------
    The same array object as `img`. If the centre row contains no foreground
    pixel, only the bands are blanked (the previous scan ran past the end of
    the row and raised IndexError in that case).
    """
    # max(..., 0) keeps the semantics of the former range()-based loops for
    # negative arguments (no rows for upper < 0, every row for lower < 0).
    img[:max(upper, 0)] = 0
    img[max(lower, 0):] = 0

    labels = cv.connectedComponentsWithStats(np.uint8(img), 4, cv.CV_32S)[1]

    cent = int(img.shape[0]/2)
    foreground_cols = np.flatnonzero(img[cent])
    if foreground_cols.size == 0:
        # Nothing crosses the centre row, so there is no component to remove.
        return img

    diskLabel = labels[cent][foreground_cols[0]]
    img[labels == diskLabel] = 0
    return img

In [5]:
# Pick a single colour plane out of an RGB image
def extractChannel(img, channel = 'g'):
    """Return one channel of `img`: 'g' -> green, 'r' -> red, anything else -> blue.

    The planes are unpacked as R, G, B (not OpenCV's usual B, G, R) because
    readImage already reversed the channel order when loading.
    """
    red, green, blue = cv.split(img)
    if channel == 'g':
        return green
    return red if channel == 'r' else blue

In [6]:
def diskKernel(n):
    """Elliptical structuring element whose side length is ``2*n - 1``."""
    side = 2*n - 1
    return cv.getStructuringElement(cv.MORPH_ELLIPSE, (side, side))

In [7]:
def saveImage(path, img, cmap='gray'):
    """Write `img` to `path` with matplotlib's imsave, using colormap `cmap`."""
    plt.imsave(fname=path, arr=img, cmap=cmap)

### Part 1: Thresholding

In [None]:
# Part 1 pipeline on image 1: threshold -> area clean -> median blur ->
# remove border component -> area clean. Every stage is displayed and saved.
# NOTE(review): readImage is redefined in the Part 3 cell; this cell needs the
# earlier definition (colour image from images/), because extractChannel
# splits the image into three channels.
img = readImage(1)
g = extractChannel(img, 'g')

# Mean adaptive threshold, inverted binary output (max value 255),
# with 11 and 2 passed as the last two arguments (block size and C).
th2 = cv.adaptiveThreshold(g,255,cv.ADAPTIVE_THRESH_MEAN_C,\
            cv.THRESH_BINARY_INV,11,2)
displayImage(th2)
saveImage('part1_thresh.jpg', th2)
# Drop connected components smaller than 50 pixels.
arClean = clean(th2, 50)
saveImage('part1_area_50.jpg', arClean)
displayImage(arClean)
# 3x3 median filter on the cleaned mask.
med = cv.medianBlur(arClean, 3)
displayImage(med)
saveImage('part1_med.jpg', med)
# Blank the top 20 and bottom 20 rows and remove the component first met on
# the centre row. removeDisk works in place, so `med` and `disk` refer to the
# same array from here on.
disk = removeDisk(med, 20, img.shape[0] - 20)
displayImage(disk)
saveImage('part1_removedisk.jpg', disk)
# Final pass: drop components smaller than 60 pixels (reuses the name arClean).
arClean = clean(disk, 60)
saveImage('part1_final.jpg', arClean)
displayImage(arClean)

### Part 2: Region Growing

In [None]:
def union(img1, img2):
    """Set to 255 every pixel of `img1` where `img2` equals 255.

    `img1` is modified IN PLACE and returned; `img2` is left untouched.
    Only pixels that are exactly 255 in `img2` are copied over.

    If the first two dimensions of the images differ, `img1` is returned
    unchanged.
    """
    if img1.shape[0] != img2.shape[0] or img1.shape[1] != img2.shape[1]:
        return img1
    # Boolean-mask assignment replaces the former per-pixel Python loop.
    img1[img2 == 255] = 255
    return img1

# output = img1 - img2
def subtract(img1, img2):
    """Clear (set to 0) every pixel that is 255 in both `img1` and `img2`.

    `img1` is modified IN PLACE and returned; `img2` is left untouched.
    Pixels of `img1` that are not exactly 255 are never changed.

    If the first two dimensions of the images differ, `img1` is returned
    unchanged.
    """
    if img1.shape[0] != img2.shape[0] or img1.shape[1] != img2.shape[1]:
        return img1
    # Boolean-mask assignment replaces the former per-pixel Python loop.
    img1[(img2 == 255) & (img1 == 255)] = 0
    return img1

def get8n(x, y, shape):
    """Return the 8 neighbours of (x, y) as a list of ``(x, y)`` tuples.

    `x` is bounded by ``shape[1] - 1`` and `y` by ``shape[0] - 1``. A
    coordinate that is shifted by +/-1 is clamped into the valid range, so at
    the image border the list contains repeated entries (and possibly the
    point itself). A coordinate that is not shifted is passed through as is.

    Order: top-left, top, top-right, left, right, bottom-left, bottom,
    bottom-right.
    """
    last_x = shape[1] - 1
    last_y = shape[0] - 1

    def clamp(value, upper):
        return min(max(value, 0), upper)

    steps = [(-1, -1), (0, -1), (1, -1),
             (-1, 0), (1, 0),
             (-1, 1), (0, 1), (1, 1)]

    neighbours = []
    for dx, dy in steps:
        # Only shifted coordinates are clamped; unshifted ones are kept as is.
        nx = x if dx == 0 else clamp(x + dx, last_x)
        ny = y if dy == 0 else clamp(y + dy, last_y)
        neighbours.append((nx, ny))
    return neighbours

def region_growing(img, seed):
    """Grow an 8-connected region of non-zero pixels starting at `seed`.

    Parameters
    ----------
    img : 2-D single-channel image; any non-zero pixel counts as foreground.
    seed : ``(row, col)`` of the starting pixel.

    Returns
    -------
    An array shaped and typed like `img` that is 255 on the seed pixel and on
    every foreground pixel reachable from it through 8-connected foreground
    pixels, and 0 elsewhere. The seed is always marked, even when it lies on
    background.

    Raises
    ------
    IndexError
        If `seed` lies outside the image.

    Notes
    -----
    Neighbours are generated here in (row, col) order and bounds-checked
    against (height, width). Feeding (row, col) into a helper that clamps the
    first value against the width and the second against the height breaks on
    non-square images (IndexError or unreachable pixels).
    """
    from collections import deque

    height, width = img.shape[0], img.shape[1]
    row0, col0 = seed[0], seed[1]
    if not (0 <= row0 < height and 0 <= col0 < width):
        raise IndexError("seed " + str((row0, col0)) + " is outside the image")

    outimg = np.zeros_like(img)
    # Separate visited map so each pixel is queued at most once (O(1) lookup).
    seen = np.zeros((height, width), dtype=bool)
    outimg[row0, col0] = 255
    seen[row0, col0] = True
    pending = deque([(row0, col0)])

    offsets = ((-1, -1), (-1, 0), (-1, 1),
               (0, -1), (0, 1),
               (1, -1), (1, 0), (1, 1))

    while pending:
        row, col = pending.popleft()
        for d_row, d_col in offsets:
            n_row, n_col = row + d_row, col + d_col
            if not (0 <= n_row < height and 0 <= n_col < width):
                continue
            if img[n_row, n_col] != 0 and not seen[n_row, n_col]:
                seen[n_row, n_col] = True
                outimg[n_row, n_col] = 255
                pending.append((n_row, n_col))
    return outimg

def region_removing(img, seed, areaThreshold = 2000):
    """Mark a size-limited 8-connected region around `seed` for removal.

    A breadth-first search starts at `seed` and marks foreground (non-zero)
    pixels until `areaThreshold` pixels have been marked or the connected
    region is exhausted.

    Parameters
    ----------
    img : 2-D single-channel image; any non-zero pixel counts as foreground.
    seed : ``(row, col)`` of the starting pixel.
    areaThreshold : maximum number of pixels marked in addition to the seed.
        With a value <= 0 nothing is marked, not even the seed.

    Returns
    -------
    An array shaped and typed like `img`, 255 on the marked pixels and 0
    elsewhere.

    Raises
    ------
    IndexError
        If `seed` lies outside the image (only checked when
        ``areaThreshold > 0``).

    Notes
    -----
    * Each pixel is counted once, when it is first marked. Counting every
      revisit of an already-marked neighbour exhausts the budget long before
      `areaThreshold` distinct pixels are covered.
    * Neighbours are generated in (row, col) order and bounds-checked against
      (height, width), so non-square images are handled correctly.
    """
    from collections import deque

    outimg = np.zeros_like(img)
    if areaThreshold <= 0:
        return outimg

    height, width = img.shape[0], img.shape[1]
    row0, col0 = seed[0], seed[1]
    if not (0 <= row0 < height and 0 <= col0 < width):
        raise IndexError("seed " + str((row0, col0)) + " is outside the image")

    seen = np.zeros((height, width), dtype=bool)
    outimg[row0, col0] = 255
    seen[row0, col0] = True
    pending = deque([(row0, col0)])

    # The visiting order decides which pixels are taken when the budget runs
    # out part-way through a neighbourhood.
    offsets = ((-1, -1), (0, -1), (1, -1),
               (-1, 0), (1, 0),
               (-1, 1), (0, 1), (1, 1))

    removed = 0
    while pending and removed < areaThreshold:
        row, col = pending.popleft()
        for d_row, d_col in offsets:
            n_row, n_col = row + d_row, col + d_col
            if not (0 <= n_row < height and 0 <= n_col < width):
                continue
            if img[n_row, n_col] != 0 and not seen[n_row, n_col]:
                seen[n_row, n_col] = True
                outimg[n_row, n_col] = 255
                pending.append((n_row, n_col))
                removed += 1
                if removed == areaThreshold:
                    break
    return outimg

def on_mouse(event, x, y, flags, params):
    """Mouse callback: on a left-button press, log the click and store it.

    Prints the clicked (x, y) position with the value of the global `thres`
    at that pixel, then appends the position as ``(y, x)`` to the global
    `clicks` list. All other events are ignored.
    """
    if event != cv.EVENT_LBUTTONDOWN:
        return
    message = 'Seed: ' + str(x) + ', ' + str(y)
    print(message, thres[y,x])
    clicks.append((y,x))

# ---- Part 2 driver: interactive region growing, then interactive removal ----
# NOTE(review): readImage is redefined in the Part 3 cell; this cell needs the
# earlier definition (colour image from images/), because extractChannel
# splits the image into three channels.
n = 1
img = readImage(n)
g = extractChannel(img, 'g')
# Same adaptive threshold as in Part 1, followed by the default area cleaning.
thres = cv.adaptiveThreshold(g,255,cv.ADAPTIVE_THRESH_MEAN_C,\
            cv.THRESH_BINARY_INV,11,2)
thres = clean(thres)
# Accumulated result: starts empty and receives every grown region.
currOutput = np.zeros_like(thres)

# Growing loop. Each round: show the thresholded image, wait for a key press,
# and if the user clicked at least once, grow a region from the LAST click and
# merge it into currOutput. A key press without any click ends the loop.
while True:
    # on_mouse appends to this module-level list, so it is reset every round.
    clicks = []
    cv.namedWindow('Input')
    cv.setMouseCallback('Input', on_mouse, 0, )
    cv.imshow('Input', thres)
    cv.imshow('Current output', currOutput)
    cv.waitKey()
    if len(clicks) > 0:
        seed = clicks[-1]
        out = region_growing(thres, seed)
        # union modifies currOutput in place and returns it.
        currOutput = union(currOutput, out)
        cv.imshow('Region Growing', out)
        cv.waitKey()
        cv.destroyAllWindows()
    else:
        cv.destroyAllWindows()
        break

print(currOutput.shape)
# Removal loop. Same interaction, but the clicked region (limited to the area
# typed on stdin) is subtracted from currOutput.
# Note: on_mouse still prints the value of `thres`, not of currOutput, at the
# clicked position.
while True:
    clicks = []
    cv.namedWindow('Input')
    cv.setMouseCallback('Input', on_mouse, 0, )
    cv.imshow('Input', currOutput)
    cv.waitKey()
    if len(clicks) > 0:
        # int() raises ValueError if the typed text is not an integer.
        areaThres = int(input('Input area: '))
        seed = clicks[-1]
        temp = region_removing(currOutput, seed, areaThres)
        # subtract modifies currOutput in place and returns it.
        currOutput = subtract(currOutput, temp)
        cv.imshow('Region Removed', temp)
        cv.waitKey()
        cv.destroyAllWindows()
    else:
        cv.destroyAllWindows()
        break

displayImage(currOutput, 'Final image')

# The mask is written as a JPEG into output_part2/, where the Part 3 readImage
# looks for it.
# NOTE(review): JPEG is lossy, so the stored mask may not be strictly 0/255;
# the directory must already exist.
plt.imsave('output_part2/output_part2_' + str(n) + '.jpg', currOutput, cmap='gray')

### Part 3: Image matting

#### Create Trimap

In [6]:
def diskKernel(size):
    """Elliptical structuring element of ``size`` x ``size`` pixels.

    Note: this replaces the earlier diskKernel(n) in the notebook, which
    builds a kernel of side ``2*n - 1``.
    """
    dims = (size, size)
    return cv.getStructuringElement(cv.MORPH_ELLIPSE, dims)

def readImage(number, jpg = True):
    """Load a Part 2 result mask from ``output_part2/`` as a grayscale image.

    Note: this replaces the earlier readImage in the notebook (which reads a
    colour image from ``images/``).

    Parameters
    ----------
    number : converted with ``str`` to build ``output_part2_<number>``.
    jpg : when True the ``.jpg`` file is read, otherwise the ``.tif`` file.

    Returns
    -------
    The single-channel image decoded with ``cv.IMREAD_GRAYSCALE``.

    Raises
    ------
    FileNotFoundError
        If OpenCV could not read the file (``cv.imread`` returns ``None``
        instead of raising, which previously propagated to the caller).
    """
    extension = ".jpg" if jpg == True else ".tif"
    path = "output_part2/output_part2_" + str(number) + extension
    img = cv.imread(path, cv.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError("Could not read image: " + path)
    return img
# Show an image in a matplotlib figure
def displayImage(a, title1 = "Img"):
    """Draw `a` with the gray colormap on a fixed 0-255 scale, titled `title1`."""
    plt.imshow(a, cmap='gray', vmin=0, vmax=255)
    plt.title(title1)
    plt.show()

def find_trimap(currOutput , number, kernelSize):
    """Build a trimap (0 background / 127 unknown / 255 foreground) and save it.

    Parameters
    ----------
    currOutput : grayscale mask; pixels greater than 20 are foreground (the
        tolerance absorbs noise in masks that were stored as JPEG).
    number : image number, used in the output file name.
    kernelSize : passed to ``diskKernel`` to build the dilation kernel.

    Returns
    -------
    uint8 array shaped like `currOutput`:
      * 255 - foreground pixels that survive ``clean(..., 15)``;
      * 127 - non-foreground pixels inside the foreground dilated 4 times;
      * 0   - everything else.

    Side effects
    ------------
    Writes ``trimap/kernel_<kernelSize>/<number>.jpg`` (the directory is
    created when missing).

    Notes
    -----
    * The dilation runs on the binarised mask, so the unknown band is derived
      from the same ``> 20`` rule as the foreground. Dilating the raw image
      and testing ``== 255`` gives the same result for a strictly 0/255 input
      but loses the band around foreground that is not exactly 255.
    * ``vmin``/``vmax`` pin the saved grey levels; without them imsave rescales
      to the data range, so a trimap without any 255 pixel would store its
      127 band as white.
    """
    import os

    foreground = currOutput > 20
    binary = np.zeros(currOutput.shape , dtype = np.uint8)
    binary[foreground] = 255

    # clean returns a new array, so `binary` stays available for the dilation.
    trimap = clean(binary, 15)
    dilat = cv.dilate(binary, diskKernel(kernelSize), iterations = 4)
    trimap[(dilat == 255) & ~foreground] = 127

    out_dir = 'trimap/kernel_' + str(kernelSize)
    os.makedirs(out_dir, exist_ok=True)
    path = out_dir + '/' + str(number)+'.jpg'
    plt.imsave(path, trimap, cmap='gray', vmin=0, vmax=255)
    return trimap

In [7]:
# Build one trimap per Part 2 mask (images 1..20) and per dilation kernel size.
for i in range(1, 21):
    # The mask does not depend on the kernel size: read it once per image
    # instead of once per (image, kernel) pair.
    img = readImage(i)
    for s in [3, 5, 7, 9, 11]:
        find_trimap(img, i, s)

#### Image Matting

In [18]:
from matting import alpha_matting, load_image, save_image, estimate_foreground_background, stack_images

# Run alpha matting for every image (1..20) and every trimap kernel size, and
# save the resulting cutout under output_part3/kernel_<s>/<i>.png.
# NOTE(review): the input image is reloaded for every kernel size although its
# path only depends on i; it could be loaded once per image if the matting
# helpers do not modify it - confirm before hoisting.
for i in range(1, 21):
    for s in [3, 5, 7, 9, 11]:
        image = load_image("input/input" + str(i) + ".jpg", "RGB")
        # Trimap written by find_trimap for this kernel size.
        trimap = load_image("trimap/kernel_" + str(s) + "/" + str(i) +".jpg" , "GRAY")

        # method="cf" / preconditioner="vcycle" are options of the project's
        # `matting` module; see that module for their meaning.
        alpha = alpha_matting(image, trimap, method="cf", preconditioner="vcycle", print_info=True)

        foreground, background = estimate_foreground_background(image, alpha, print_info=True)

        # Combine the estimated foreground with alpha (presumably an RGBA
        # cutout - verify against stack_images).
        cutout = stack_images(foreground, alpha)

        # The target directory is not created here; it must already exist
        # unless save_image creates it.
        save_image("output_part3/kernel_" + str(s) + "/" + str(i) + ".png", cutout)
