# In [None]:
import kornia
import torch

def MySeamCarving(input_img, width, height):
    """Shrink an image towards ``width`` x ``height`` via ``CarvingHelper``.

    The NumPy array is wrapped as a float tensor with two leading singleton
    dimensions. Column seams are processed first; the tensor is then
    transposed so the same helper can process row seams, and transposed back
    before being returned.

    Args:
        input_img: NumPy array holding the image (presumably a 2-D grayscale
            H x W array -- TODO confirm against callers).
        width: Target size along the last tensor dimension.
        height: Target size along the second-to-last tensor dimension.

    Returns:
        The tensor returned by the second helper call, transposed back to the
        original orientation.
    """
    # NumPy array -> float tensor with two extra leading axes.
    tensor = torch.from_numpy(input_img).float()[None, None]

    # Seam counts are derived from the size before any carving happens.
    rows, cols = tensor.shape[2], tensor.shape[3]
    seams_along_width = cols - width
    seams_along_height = rows - height

    # Vertical seams first.
    tensor = CarvingHelper(tensor, seams_along_width)

    # Swap the two spatial axes so horizontal seams become vertical ones,
    # carve, then restore the original orientation.
    tensor = CarvingHelper(tensor.transpose(2, 3), seams_along_height)
    return tensor.transpose(2, 3)

# In [None]:
def CarvingHelper(input_img, remove_seams):
    """Remove ``remove_seams`` minimum-energy vertical seams from an image.

    Each iteration computes a Sobel-gradient energy map, finds the connected
    top-to-bottom path of minimum cumulative energy by dynamic programming,
    and deletes one pixel per row along that path, so the image gets one
    column narrower per iteration.

    Args:
        input_img: 4-D tensor indexed as (batch, channel, row, column). The
            energy computation squeezes the batch axis, so a batch size of 1
            is assumed. The same seam is removed from every channel.
        remove_seams: Number of seams to remove. Values <= 0 return
            ``input_img`` unchanged.

    Returns:
        A tensor with the same leading dimensions and ``remove_seams`` fewer
        columns (or ``input_img`` itself when nothing is removed).
    """
    if remove_seams <= 0:
        return input_img

    for _ in range(remove_seams):
        # (a.) compute the energy image: |dx| + |dy|, summed over channels
        spatial_gradient = kornia.filters.SpatialGradient(mode='sobel')(input_img)
        spatial_gradient = spatial_gradient.squeeze(0)
        gradient_magnitude = torch.abs(spatial_gradient[:, 0]) + torch.abs(spatial_gradient[:, 1])
        E = gradient_magnitude.sum(dim=0)

        # (b.)-(f.) cumulative cost map, minimum in the bottom row, backtrack
        seam_cols = _min_energy_seam(E)

        # (g.) remove the seam from the image
        # (h.) carry the carved image into the next iteration, so each new
        #      seam is computed on the already-narrowed image
        input_img = _drop_seam(input_img, seam_cols)

    return input_img


def _min_energy_seam(energy):
    """Return the column index, per row, of the minimum-energy vertical seam."""
    height, width = energy.shape
    device = energy.device
    col_index = torch.arange(width, device=device)
    # Sentinel that keeps the out-of-range neighbours of the border columns
    # from ever being selected.
    inf = energy.new_full((1,), float("inf"))

    # back_ptr[r, c] = column in row r - 1 that the best path into (r, c) uses.
    back_ptr = torch.zeros((height, width), dtype=torch.long, device=device)

    # The first row of the cost map is the energy itself.
    cost = energy[0]
    for row in range(1, height):
        upper_left = torch.cat((inf, cost[:-1]))
        upper_right = torch.cat((cost[1:], inf))
        best, offset = torch.stack((upper_left, cost, upper_right)).min(dim=0)
        back_ptr[row] = col_index + offset - 1
        # M[r, c] = E[r, c] + min(M[r-1, c-1], M[r-1, c], M[r-1, c+1])
        cost = energy[row] + best

    # Start from the cheapest cell in the bottom row and follow the pointers up.
    seam = torch.empty(height, dtype=torch.long, device=device)
    col = int(torch.argmin(cost))
    for row in range(height - 1, -1, -1):
        seam[row] = col
        col = int(back_ptr[row, col])
    return seam


def _drop_seam(image, seam_cols):
    """Delete column ``seam_cols[r]`` from every row ``r`` of a 4-D image tensor."""
    batch, channels, height, width = image.shape
    keep = torch.ones((height, width), dtype=torch.bool, device=image.device)
    keep[torch.arange(height, device=image.device), seam_cols] = False
    # Masking the two spatial axes flattens them in row-major order, so the
    # surviving pixels can be reshaped straight into the narrower image.
    return image[:, :, keep].reshape(batch, channels, height, width - 1)