<a href="https://colab.research.google.com/github/Tianananana/active-contour/blob/main/Active_Contour.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt 
from torch.autograd.functional import jacobian
from torchvision import transforms

In [2]:
def bilinear_interpolation(coord, img):
  """
  Sample a 2d image at (possibly fractional) point coordinates by bilinear interpolation.

  :param coord: tensor of shape (N, 2); coord[:, 0] is used as the index into the
                first image axis and coord[:, 1] as the index into the second axis.
  :param img: 2d tensor to sample from (e.g. the image or one of its gradient images).
  :return: 1d tensor of length N with the interpolated values.

  Points lying exactly on an integer row and/or column need no special casing: the
  corresponding fractional offset is 0, so the extra neighbour gets zero weight.
  Assumes 0 <= coord and ceil(coord) < img.shape along each axis (floor/ceil of the
  coordinate are used directly as indices; no bounds clamping is done).
  """
  x = coord[:, 0]
  y = coord[:, 1]

  x1 = torch.floor(x).long()
  y1 = torch.floor(y).long()
  x2 = torch.ceil(x).long()
  y2 = torch.ceil(y).long()

  # Fractional offsets inside the pixel cell; exactly 0 when the coordinate is an
  # integer (then x1 == x2 and the formula degenerates to 1d / nearest sampling).
  fx = x - x1
  fy = y - y1

  # q_ab is the pixel at (x_a, y_b).
  q11 = img[x1, y1]
  q21 = img[x2, y1]
  q12 = img[x1, y2]
  q22 = img[x2, y2]

  return (q11 * (1 - fx) * (1 - fy) +
          q21 * fx * (1 - fy) +
          q12 * (1 - fx) * fy +
          q22 * fx * fy)

In [10]:
class ActiveContourImg:
  """
  Gradient-descent active contour on a 2d image between two fixed end points.

  The contour is a polyline of ``no_pts`` points. The first and last points stay
  fixed; only the interior points are moved by ``update_coord``. Coordinates follow
  the indexing used by ``bilinear_interpolation``: coord[:, 0] indexes the first
  image axis (rows) and coord[:, 1] the second axis (columns).
  """

  def __init__(self, img, coord, alpha=0.1, epsilon=0.01, lr=0.01):
    """
    :param img: 2d image tensor.
    :param coord: (no_pts, 2) tensor of initial contour points. It is copied, so the
                  caller's tensor is not modified by the optimisation.
    :param alpha: weight of the segment-length term sum_i (l_i - l_0)^2.
    :param epsilon: added to pixel intensities before they are inverted.
    :param lr: gradient-descent step size.
    """
    self.img = img
    k = 7  # Gaussian kernel size
    s = 5  # Gaussian sigma

    # Gaussian blur via torchvision transforms. The tensor backend expects a channel
    # axis ([..., C, H, W]), so add one for the 2d image and remove it afterwards.
    data_transforms = transforms.Compose([transforms.GaussianBlur(k, sigma=s)])
    self.smooth = data_transforms(self.img.unsqueeze(0)).squeeze(0)
    # Clamp the blurred image itself so the blur is actually used below.
    self.smooth = torch.clamp(self.smooth, min=1e-10, max=255)

    # Image gradients of the smoothed image. OpenCV's "x" is the column axis, while
    # coord[:, 0] indexes rows; so dx_img (derivative along coord[:, 0]) is the Sobel
    # derivative along rows (dx=0, dy=1), and dy_img the one along columns.
    # The 1/8 factor normalises the 3x3 Sobel kernel to a unit-spacing derivative.
    smooth_np = self.smooth.detach().numpy()
    self.dx_img = torch.from_numpy(
        1/8 * cv2.Sobel(smooth_np, cv2.CV_64F, 0, 1, ksize=3)).type(torch.FloatTensor)
    self.dy_img = torch.from_numpy(
        1/8 * cv2.Sobel(smooth_np, cv2.CV_64F, 1, 0, ksize=3)).type(torch.FloatTensor)

    # Private copy of the contour; update_coord() modifies its interior in place.
    self.coord = coord.detach().clone()
    self.no_pts = self.coord.shape[0]

    # Fixed end points and target segment length l_0: the straight-line distance
    # between the end points divided evenly over the no_pts - 1 segments.
    self.pa = self.coord[0, :].clone()  # start pt
    self.pb = self.coord[-1, :].clone()  # end pt
    l_total = torch.norm(self.pa - self.pb)
    self.l_0 = (l_total / (self.no_pts - 1)).item()

    # Parameters
    self.epsilon = epsilon
    self.alpha = alpha
    self.lr = lr
    # eqn (6) of paper: mean intensity of the pixels under the two end points
    # (coordinates truncated to integer indices).
    self.P_reg = (self.img[int(self.pa[0]), int(self.pa[1])] +
                  self.img[int(self.pb[0]), int(self.pb[1])]) / 2

    # Populated by train_one_round(). dtype: torch.
    self.l = None
    self.g = None
    self.j_l = None
    self.j_g = None
    self.loss = None

  @property
  def curr_coord(self):
    """Current contour points, shape (no_pts, 2)."""
    return self.coord

  def update_l(self, coord):
    """
    Return the segment lengths l_i = |r_{i+1} - r_i|. Shape: (no_pts - 1,).
    """
    return torch.norm(coord[1:, :] - coord[:-1, :], dim=-1)

  def update_g(self, interp_px):
    """
    Return g_i = 1/(I_i + eps) + 1/(I_{i+1} + eps) per segment. Shape: (no_pts - 1,).
    """
    return (interp_px[:-1] + self.epsilon)**(-1) + (interp_px[1:] + self.epsilon)**(-1)

  def update_j_l(self):
    """
    Return the Jacobian of the segment lengths w.r.t. the contour points.
    Shape: (no_pts - 1, no_pts, 2). Computed with PyTorch's autograd jacobian.
    """
    return jacobian(self.update_l, self.coord)

  def update_j_g(self, interp_px, interp_px_dx, interp_px_dy):
    """
    Return the Jacobian of g w.r.t. the contour points. Shape: (no_pts - 1, no_pts, 2).

    d/dr_k [1/(I_k + eps)] = -grad I(r_k) / (I_k + eps)^2. g_i depends only on
    points i and i+1, so row i has non-zeros in columns i and i+1.
    """
    vx = -torch.diag(interp_px_dx / (interp_px + self.epsilon)**2)
    vy = -torch.diag(interp_px_dy / (interp_px + self.epsilon)**2)

    jx_g = vx[:-1, :] + vx[1:, :]
    jy_g = vy[:-1, :] + vy[1:, :]

    return torch.stack((jx_g, jy_g), dim=-1)

  def dL1dr(self):
    """
    Gradient of the first loss term sum_i l_i * g_i w.r.t. all contour points,
    i.e. J_g^T l + J_l^T g for each coordinate axis. Shape: (no_pts, 2).

    NOTE(review): loss_calculation scales this term by P_reg / 2, but this gradient
    is not scaled. Confirm whether that is intentional.
    """
    grad_x = torch.matmul(self.j_g[:, :, 0].T, self.l) + torch.matmul(self.j_l[:, :, 0].T, self.g)
    grad_y = torch.matmul(self.j_g[:, :, 1].T, self.l) + torch.matmul(self.j_l[:, :, 1].T, self.g)
    return torch.stack((grad_x, grad_y), dim=-1)

  def dL2dr(self):
    """
    Gradient of the length term sum_i (l_i - l_0)^2 w.r.t. the interior points
    (the 1st and n-th points are fixed). Shape: (no_pts - 2, 2).
    """
    # Differentiate a detached copy so no autograd state leaks into self.coord.
    inner = self.coord[1:-1, :].detach().clone()
    if inner.shape[0] == 0:
      # Only the two fixed end points: nothing to move.
      return torch.zeros_like(inner)
    inner.requires_grad_(True)
    li_loss = (torch.sum((torch.norm(inner[1:, :] - inner[:-1, :], dim=1) - self.l_0)**2) +
               (torch.norm(inner[0, :] - self.pa) - self.l_0)**2 +
               (torch.norm(self.pb - inner[-1, :]) - self.l_0)**2)
    li_loss.backward()
    return inner.grad

  def grad_descent(self):
    """Total gradient for the interior points: dL1/dr + alpha * dL2/dr. Shape: (no_pts - 2, 2)."""
    return self.dL1dr()[1:-1, :] + self.alpha * self.dL2dr()

  def update_coord(self):
    """Take one gradient-descent step on the interior points in place; end points stay fixed."""
    updated_coord = self.coord[1:-1, :] - self.lr * self.grad_descent()
    self.coord[1:-1] = updated_coord

  def loss_calculation(self, interp_px):
    """
    Return the scalar loss
      P_reg/2 * sum_i (l_i/(I_i + eps) + l_i/(I_{i+1} + eps)) + alpha * sum_i (l_i - l_0)^2.
    """
    loss_1 = torch.sum(self.l / (interp_px[:-1] + self.epsilon) +
                       self.l / (interp_px[1:] + self.epsilon)) * self.P_reg / 2
    loss_2 = torch.sum((self.l - self.l_0)**2)
    return loss_1 + self.alpha * loss_2

  def train_one_round(self):
    """Evaluate the loss and Jacobians at the current contour, then take one descent step."""
    interp_px = bilinear_interpolation(self.coord, self.img)
    interp_px_dx = bilinear_interpolation(self.coord, self.dx_img)
    interp_px_dy = bilinear_interpolation(self.coord, self.dy_img)
    self.l = self.update_l(self.coord)
    self.g = self.update_g(interp_px)
    self.j_l = self.update_j_l()
    self.j_g = self.update_j_g(interp_px, interp_px_dx, interp_px_dy)
    self.loss = self.loss_calculation(interp_px)
    self.update_coord()


In [None]:
def main(img, coord, rounds=10, alpha=0.1, epsilon=0.01, lr=0.01):
  """
  Run `rounds` descent iterations of ActiveContourImg, printing the loss each round
  and the final contour plus all losses at the end.

  :param img: 2d image tensor.
  :param coord: (no_pts, 2) tensor of initial contour points.
  :param rounds: number of descent steps.
  :param alpha, epsilon, lr: forwarded to ActiveContourImg.
  """
  # Forward the hyper-parameters; they were previously accepted but ignored.
  image = ActiveContourImg(img, coord, alpha=alpha, epsilon=epsilon, lr=lr)

  loss_all = []
  for i in range(rounds):
    print("round:{}==============".format(i))
    image.train_one_round()
    print("loss:", image.loss)
    # Store plain floats so np.array() below builds a numeric array.
    loss_all.append(image.loss.item())
  print("Final coord: {} ============= All loss: {}".format(image.curr_coord, np.array(loss_all)))


if __name__ == "__main__":
  # NOTE: `img` and `coord` must already be defined (e.g. by the TEST cell) when this runs.
  main(img, coord)

In [6]:
# TEST #
## Initialise coord and img ##
# 10 points (30.5, 32.5), (34.5, 36.5), ..., offset by 0.5 so no coordinate is an integer.
coord = torch.arange(30, 70, 2).reshape(10, 2).type(torch.FloatTensor) + 0.5
# Every row holds 0, 1, ..., 99, so img[r, c] == c.
img = torch.arange(100).type(torch.FloatTensor).repeat(100, 1)

print("Coord Shape: {}".format(coord.shape))
print("Image Shape: {}".format(img.shape))
print("Coord Type: {}".format(type(coord)))
print("Image Type: {}".format(type(img)))

Coord Shape: torch.Size([10, 2])
Image Shape: torch.Size([100, 100])
Coord Type: <class 'torch.Tensor'>
Image Type: <class 'torch.Tensor'>


In [8]:
# Sample the test image at every contour point.
bilinear_interpolation(coord=coord, img=img)

tensor([32.5000, 36.5000, 40.5000, 44.5000, 48.5000, 52.5000, 56.5000, 60.5000,
        64.5000, 68.5000])