In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as tf
import kornia
import numpy as np

## 1 Canny edge detection

### Q1.1

In [2]:
def gauss_derivative_filtering(image:torch.Tensor, length:int=3, sigma:float=1.) -> tuple:
    """Filter ``image`` with the x- and y-derivatives of a 2-D Gaussian.

    Args:
        image: tensor fed to ``conv2d``; the kernels are built with a single
            input and output channel, so a [B, 1, H, W] float tensor is expected.
        length: side length of the square kernel.
        sigma: standard deviation of the Gaussian. The kernel is sampled at
            ``length`` evenly spaced positions between ``+sigma`` and ``-sigma``.

    Returns:
        ``(Ix, Iy)``: the responses to the x- and y-derivative kernels. The
        image is replicate-padded by ``length // 2`` on every side first.
    """
    # Positions run from +sigma down to -sigma, i.e. the kernel is stored
    # flipped, because conv2d computes a cross-correlation, not a convolution.
    taps = torch.linspace(sigma, -sigma, length)
    grid_x = taps.expand(length, -1)
    grid_y = taps.reshape(length, 1).expand(-1, length)

    # 2-D Gaussian sampled on the grid
    unnormalised = torch.exp(-(torch.square(grid_x) + torch.square(grid_y)) / (2*sigma**2))
    gaussian = unnormalised / (2*torch.pi*sigma**2)

    # derivative kernels, shaped [out=1, in=1, length, length] for conv2d
    kernel_shape = (1, 1, length, length)
    kernel_x = (-(grid_x / sigma**2) * gaussian).view(*kernel_shape)
    kernel_y = (-(grid_y / sigma**2) * gaussian).view(*kernel_shape)

    # replicate the border by half a kernel on each side
    half = length // 2
    padded = tf.pad(image, [half, half, half, half], mode='replicate')

    return tf.conv2d(padded, kernel_x), tf.conv2d(padded, kernel_y)

In [3]:
def imshow_row2(img1:torch.Tensor, img2:torch.Tensor, title1:str, title2:str, cm:str=None) -> None:
    """Show two image tensors side by side in one figure row.

    Each tensor is squeezed before being handed to ``imshow``; ``cm`` is the
    colormap used for both panels (``None`` keeps matplotlib's default).
    """
    _, axes = plt.subplots(1,2, figsize=(12,6))
    panels = zip(axes, (img1, img2), (title1, title2))
    for axis, picture, caption in panels:
        axis.imshow(picture.squeeze(), cmap=cm)
        axis.set_title(caption)

In [4]:
def magnitude_rl(magnitude:torch.Tensor, direction:torch.Tensor) -> tuple:
    """Gather, per pixel, the two neighbours to compare against in NMS.

    For a pixel at (i, j) the integer ``direction`` code selects the pair:

        0 -> (i+1, j  ) and (i-1, j  )
        1 -> (i+1, j+1) and (i-1, j-1)
        2 -> (i  , j+1) and (i  , j-1)
        3 -> (i-1, j+1) and (i+1, j-1)

    Neighbours that fall outside the image are taken from the nearest border
    pixel (replicate padding). Pixels whose code is not 0..3 get 0 for both.

    Args:
        magnitude: gradient magnitude; the last two dims are (H, W).
        direction: integer codes, broadcastable against ``magnitude``.

    Returns:
        ``(right, left)``: the first and second neighbour of each pair.

    The original roll-and-patch implementation indexed ``[-2]`` / ``[1]`` to
    repair the wrapped border and therefore raised on images with a spatial
    dimension of size 1. Gathering with clamped indices gives the same values
    and has no such limit.
    """
    rows, cols = magnitude.shape[-2:]
    row_idx = torch.arange(rows, device=magnitude.device)
    col_idx = torch.arange(cols, device=magnitude.device)

    def neighbour(d_row: int, d_col: int) -> torch.Tensor:
        # value at (i + d_row, j + d_col), indices clamped to the image
        r = (row_idx + d_row).clamp(0, rows - 1)
        c = (col_idx + d_col).clamp(0, cols - 1)
        return magnitude.index_select(-2, r).index_select(-1, c)

    # direction code -> (row, col) offset of the "right" neighbour;
    # the "left" neighbour is the mirrored offset
    offsets = {0: (1, 0), 1: (1, 1), 2: (0, 1), 3: (-1, 1)}

    right = 0
    left = 0
    for code, (d_row, d_col) in offsets.items():
        chosen = direction == code
        right = right + torch.where(chosen, neighbour(d_row, d_col), 0)
        left = left + torch.where(chosen, neighbour(-d_row, -d_col), 0)

    return (right, left)

In [5]:
def magnitude_edge(magnitude:torch.Tensor, dire_pent:torch.Tensor) -> torch.Tensor:
    """Non-maximum suppression with an explicit double loop (slow reference).

    A pixel keeps its magnitude when it is strictly greater than its first
    neighbour and greater than or equal to its second neighbour along the
    quantised direction; otherwise it is set to 0. Borders use replicate
    padding. This is the slow alternative to ``magnitude_rl``.

    Args:
        magnitude: gradient magnitude holding a single H x W image
            (e.g. shape [1, 1, H, W]).
        dire_pent: integer direction code (0..3) per pixel, H x W elements.

    Returns:
        Tensor of shape [1, 1, H, W] with the surviving magnitudes.
    """
    # direction code -> coordinates (in the padded image) of the two neighbours
    choices = {
        0: (lambda i, j: ((i+1, j  ), (i-1, j  ))),
        1: (lambda i, j: ((i+1, j+1), (i-1, j-1))),
        2: (lambda i, j: ((i  , j+1), (i  , j-1))),
        3: (lambda i, j: ((i-1, j+1), (i+1, j-1))),
    }
    # Take the size from the argument. The previous code read a global
    # `image`, which only worked if such a variable happened to exist.
    h, w = magnitude.shape[-2:]
    # reshape instead of squeeze so that H == 1 or W == 1 keeps both axes
    pad_mag = tf.pad(magnitude, [1,1,1,1], mode='replicate').reshape(h+2, w+2)
    codes = dire_pent.reshape(h, w)

    # perform non-maximum suppression on each pixel
    edge = torch.zeros(h, w, dtype=magnitude.dtype, device=magnitude.device)
    for i in range(h):
        for j in range(w):
            pr, pl = choices[codes[i, j].item()](i+1, j+1)
            center = pad_mag[i+1, j+1]
            if center > pad_mag[pr] and center >= pad_mag[pl]:
                edge[i, j] = center
    return edge.view(1,1,h,w)

In [6]:
def MyCanny(image:torch.Tensor, sigma:float) -> torch.Tensor:
    """First three Canny stages: gradient, magnitude/direction, NMS.

    Args:
        image: tensor passed to ``gauss_derivative_filtering`` (3x3 kernel).
        sigma: standard deviation of the Gaussian derivative filter.

    Returns:
        The non-maximum-suppressed gradient magnitude, same shape as the
        gradient images; suppressed pixels are 0.

    Side effect: plots the magnitude and direction maps via ``imshow_row2``.
    """
    # 1. Filter image with x, y derivatives of Gaussian filter
    Ix, Iy = gauss_derivative_filtering(image, 3, sigma)

    # 2. Find magnitude & direction of image gradient
    magnitude = torch.sqrt(torch.square(Ix) + torch.square(Iy))
    # atan gives an angle in [-pi/2, pi/2]; 0/0 (flat regions) becomes 0
    direction = torch.atan(Iy / Ix).nan_to_num_()
    imshow_row2(magnitude.squeeze(), direction.squeeze(), "magnitude", "direction", "gray")

    # 3. Perform non-maximum suppression
    # Quantise the angle into pi/4-wide bins centred on -pi/2, -pi/4, 0, pi/4
    # and pi/2 (bins 0..4); +-pi/2 are the same vertical comparison, hence % 4.
    angle_bin = torch.div(direction.squeeze() + torch.pi*5/8, torch.pi/4, rounding_mode='floor').int() % 4
    # Both Ix and Iy grow with increasing index, so an angle of +pi/4 points
    # towards (i+1, j+1): that is neighbour code 1, and -pi/4 is code 3.
    # The raw bin has those two swapped (it would compare along the edge
    # instead of across it), so mirror it: 0->0, 1->3, 2->2, 3->1.
    dire_pent = (4 - angle_bin) % 4

    right, left = magnitude_rl(magnitude, dire_pent)
    # strict on one side, non-strict on the other, so a two-pixel plateau keeps one pixel
    NMS_bool = torch.logical_and(magnitude > right, magnitude >= left)
    NMS = torch.where(NMS_bool, magnitude, 0)
    # return Non Maximum Suppression edges with dimension 1*1*H*W
    return NMS

### Q1.2

In [7]:
def window_8(pad:torch.Tensor) -> tuple:
    """Return the eight neighbour windows of a padded map, clockwise.

    ``pad`` has spatial size H x W in its last two dims. Each returned view
    has spatial size (H-2) x (W-2) and is ``pad`` shifted by one pixel
    towards one of the 8 neighbours. The order starts at the top neighbour
    and proceeds clockwise: top, top-right, right, bottom-right, bottom,
    bottom-left, left, top-left.
    """
    inner_h = pad.shape[-2] - 2
    inner_w = pad.shape[-1] - 2
    # (row, col) start offset of each window inside the padded map
    starts = (
        (0, 1), (0, 2), (1, 2), (2, 2),
        (2, 1), (2, 0), (1, 0), (0, 0),
    )
    return tuple(
        pad[..., top:top + inner_h, lft:lft + inner_w]
        for top, lft in starts
    )

In [8]:
def MyCannyFull(image:torch.Tensor, sigma:float, high_max_ratio:float) -> torch.Tensor:
    """Canny edge detection including hysteresis thresholding.

    Args:
        image: input passed to ``MyCanny``.
        sigma: standard deviation for the Gaussian derivative filter.
        high_max_ratio: the high threshold, expressed as a fraction of the
            maximum NMS response. Must be greater than 0.

    Returns:
        A float map with 1. on edge pixels and 0. elsewhere. If
        ``high_max_ratio`` is not positive a message is printed and ``image``
        is returned unchanged.
    """
    # For this function, instead of giving a high threshold
    # high threshold is (0 < high_max_ratio <= 1) of max intensity.
    # Validate before running the (expensive, plotting) Canny stages.
    if high_max_ratio <= 0.:
        print(f'high_max_ratio {high_max_ratio} should be (0, 1]')
        return image

    edge = MyCanny(image, sigma)

    # 4. Hysteresis thresholding
    # every pixel that survived NMS is at least a weak candidate
    weak = edge > 0
    high = torch.max(edge) * high_max_ratio
    # Requiring `weak` as well keeps an all-zero NMS map (high == 0) from
    # turning every pixel into a strong edge.
    hyst = torch.where(torch.logical_and(edge >= high, weak), 1., 0.)

    # Grow the edge set until it stops changing. The padded neighbour windows
    # must be rebuilt from the current map on every pass; building them once
    # up front meant only direct neighbours of the initial strong pixels were
    # ever added.
    while True:
        pad_hyst = tf.pad(hyst, [1,1,1,1], mode='constant', value=0.)
        near_edge = torch.zeros_like(weak)
        for neighbour in window_8(pad_hyst):
            near_edge = torch.logical_or(near_edge, neighbour == 1)
        grown = torch.where(torch.logical_and(near_edge, weak), 1., hyst)
        if torch.equal(grown, hyst):
            break
        hyst = grown

    return hyst

## 2 Seam Carving

In [9]:
def _lowest_energy_vertical_seam(image:torch.Tensor) -> torch.Tensor:
    """Return, for each row of ``image``, the column of the cheapest vertical seam."""
    h, w = image.shape[-2:]

    # 1. Spatial gradient per channel; x and y components are split along dim 2
    grad = kornia.filters.spatial_gradient(image, mode="sobel")
    Ix, Iy = torch.split(grad, 1, dim=2)

    # 2. Energy = gradient magnitude summed over channels, as an H x W map
    magnitude = torch.sqrt(torch.square(Ix) + torch.square(Iy))
    energy = torch.sum(magnitude, dim=1).reshape(h, w)

    # 3. DP score board: cheapest cost of any seam ending at each pixel
    score = torch.zeros_like(energy)
    score[0] = energy[0]
    for i in range(1, h):
        last = score[i-1]
        # upper-left / upper-right neighbours, with the border value repeated
        left = torch.cat((last[:1], last[:-1]))
        right = torch.cat((last[1:], last[-1:]))
        score[i] = energy[i] + torch.minimum(torch.minimum(last, left), right)

    # 4. Backtrack from the cheapest bottom pixel; ties resolve to the left
    seam = torch.zeros(h, dtype=torch.int64)
    col = int(torch.argmin(score[-1]))
    seam[-1] = col
    for i in range(h-2, -1, -1):
        lo, hi = max(col-1, 0), min(col+1, w-1)
        col = lo + int(torch.argmin(score[i, lo:hi+1]))
        seam[i] = col
    return seam


def MySeamCarving(image_origin:torch.Tensor, target_h:int, target_w:int) -> np.ndarray:
    """Shrink an image to ``target_h`` x ``target_w`` by seam carving.

    Vertical seams are removed first until the width matches; the image is
    then transposed so that the same vertical-seam routine removes horizontal
    seams, and transposed back at the end.

    Args:
        image_origin: image tensor with a leading batch dim of 1
            ([1, C, H, W]); it is not modified.
        target_h: desired height, must not exceed H.
        target_w: desired width, must not exceed W.

    Returns:
        The carved image as an ``H' x W' x C`` numpy array. If a target
        dimension is larger than the image, a message is printed and
        ``image_origin`` is returned unchanged (as a tensor).
    """
    h, w = image_origin.shape[-2:]
    # The width has to be checked against target_w. It used to be compared
    # with target_h, which rejected valid requests and let through enlarging
    # ones that then never terminated cleanly.
    if h < target_h or w < target_w:
        print('Target image dimension is not supported or needed to be processed by this function')
        return image_origin

    image = image_origin.detach().clone()

    # width pass
    for _ in range(w - target_w):
        image = CarvingHelper(image, _lowest_energy_vertical_seam(image))

    # height pass: rows become columns so the same routine applies
    if target_h < h:
        image = torch.transpose(image, dim0=2, dim1=3)
        for _ in range(h - target_h):
            image = CarvingHelper(image, _lowest_energy_vertical_seam(image))
        image = torch.transpose(image, dim0=2, dim1=3)

    # drop only the batch dim so single-channel images keep their channel axis
    return np.transpose(image.squeeze(0).numpy(), (1,2,0))

In [10]:
def CarvingHelper(image:torch.Tensor, seam:torch.Tensor) -> torch.Tensor:
    """Remove one vertical seam from ``image``.

    Args:
        image: tensor whose last two dims are (H, W); any leading dims
            (batch, channel) are carried along. It is left unmodified.
        seam: H column indices; ``seam[i]`` is the pixel dropped from row i.

    Returns:
        A new tensor with the same leading dims and spatial size H x (W-1).

    The previous version shifted pixels row by row inside ``image`` itself
    (mutating the caller's tensor) and hard-coded dim 2, so it only worked on
    4-D input. A boolean keep-mask removes the seam in one indexing operation.
    """
    h, w = image.shape[-2:]
    columns = torch.as_tensor(seam, dtype=torch.long, device=image.device)

    # True everywhere except the one seam pixel in each row
    keep = torch.ones(h, w, dtype=torch.bool, device=image.device)
    keep[torch.arange(h, device=image.device), columns] = False

    # Masked indexing flattens the last two dims in row-major order; every row
    # keeps exactly w-1 pixels, so the result folds back into H x (W-1).
    return image[..., keep].reshape(*image.shape[:-2], h, w - 1)