In [None]:
import cv2
import glob
import matplotlib.pyplot as plt
import numpy as np
from cellpose import models
from cellpose import utils
from cellpose import plot
from scipy.ndimage import generic_filter
import itertools
from tqdm import tqdm
from scipy.optimize import minimize, leastsq
import GPy
import GPyOpt
import scipy

## Synthetic Data Generation

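We define some generating points. For a point $\alpha$, we give it:
- $q_{\alpha}$, its $(x, y)$ coordinates
- $z_{\alpha}$, its height above the $(x, y)$ plane
- $p_{\alpha} > 0$, a weight that scales its distance

The squared distance from a pixel $r$ in the plane to $\alpha$ is a *weighted* Euclidean distance: the squared 3D distance from $(r, 0)$ to $(q_{\alpha}, z_{\alpha})$, scaled by $p_{\alpha}$. Each pixel is then assigned to the cell of the point with the smallest distance:
 $$d^2_{\alpha}(r) = p_{\alpha}\left(\lVert r - q_{\alpha}\rVert^2 + z_{\alpha}^2\right)$$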

In [None]:
class Point:
    """
    A single weighted generating point.

    Attributes:
        x, y: planar position.
        z: height above the (x, y) plane.
        p: weight applied to the squared distance.
        q: the planar position as the array [x, y].
    """

    def __init__(self, x, y, z, p):
        self.x, self.y, self.z = x, y, z
        self.p = p
        # planar position as a vector, used by `distance`
        self.q = np.array([x, y])

    def distance(self, r):
        """
        Weighted distance sqrt(p * (|r - q|^2 + |z|^2)) from r to this point.

        r may be a single (x, y) position or an array of them along the last axis;
        the result has one value per position.
        """
        planar_sq = np.linalg.norm(r - self.q, axis=-1) ** 2
        height_sq = np.linalg.norm(self.z) ** 2
        return np.sqrt(self.p * (planar_sq + height_sq))

    def numpy(self):
        """Return the point as the array [x, y, z]."""
        return np.array([self.x, self.y, self.z])
    
class Points:
    """
    Random weighted generating points laid out on a jittered N x N lattice over a
    max_x x max_y pixel grid, with N = floor(sqrt(num_points)).

    NOTE: exactly N*N points are created, so a `num_points` that is not a perfect square
    is effectively rounded down. Randomness comes from the global `np.random` state;
    seed it (e.g. `np.random.seed(...)`) for reproducible data.
    """
    def __init__(self, num_points = 121, max_x = 1000, max_y = 800, max_z = 0, min_p = 0.001, max_p = 0.005):
        self.num_points = num_points
        # number of lattice points along each axis
        self.N = np.sqrt(num_points).astype(int)

        self.max_x = max_x
        self.max_y = max_y
        self.max_z = max_z
        self.min_p = min_p
        self.max_p = max_p

        self.grid = self.initialize_grid()
        self.points = self.get_random_points(num_points)

    def initialize_grid(self):
        """
        Return all integer pixel coordinates [i, j] with 0 <= i < max_x and 0 <= j < max_y
        as a (max_x * max_y, 2) array, i-major (j varies fastest).
        """
        ii, jj = np.meshgrid(np.arange(self.max_x), np.arange(self.max_y), indexing='ij')
        return np.stack([ii.ravel(), jj.ravel()], axis=1)

    def sort(self, CMs):
        """
        Reorder self.points by increasing norm of CMs[k], the center of the cell spanned by
        point k (e.g. the `CMs` computed by a DistanceTransform).
        IMPORTANT: the center != generating point location

        The list is modified in place, so any other holder of the same list object (a
        DistanceTransform keeps a reference to `points.points`) sees the new order too.
        """
        order = np.array([np.linalg.norm(cm) for cm in CMs]).argsort()
        old_points = list(self.points)
        for new_index, old_index in enumerate(order):
            self.points[new_index] = old_points[old_index]

    def get_random_points(self, num_points):
        """
        Return N*N Point objects.

        Lattice positions are spaced evenly in [20, max - 20] along each axis, jittered by
        N(0, 10) pixels and clipped to the grid; heights are int(U(0, max_z)) and weights
        U(min_p, max_p).

        `num_points` is unused (self.N already fixes the count); kept for compatibility.
        """
        x_points = np.linspace(20, self.max_x-20, self.N).astype(int)
        y_points = np.linspace(20, self.max_y-20, self.N).astype(int)

        points = []
        for i in range(self.N):
            for j in range(self.N):
                # draw order (x, y, z, p) is kept so seeded runs stay reproducible
                x = int(x_points[i] + np.random.normal(0, 10))
                y = int(y_points[j] + np.random.normal(0, 10))
                z = int(np.random.uniform(0, self.max_z))
                p = np.random.uniform(self.min_p, self.max_p)

                x = np.clip(x, 0, self.max_x-1)
                y = np.clip(y, 0, self.max_y-1)

                points.append(Point(x, y, z, p))
        return points

    def image(self, show=False):
        """
        Render the points as black discs of radius N on a white (max_y, max_x) float32 image.
        If `show` is True the image is also plotted. Returns the image.
        """
        img = np.ones((self.max_y, self.max_x)).astype('float32')
        for point in self.points:
            # cv2 expects plain Python ints for the center and radius
            img = cv2.circle(img, (int(point.x), int(point.y)), int(self.N), (0, 0, 0), -1)
        if show:
            plt.figure(figsize=(15, 10))
            plt.imshow(img, cmap='gray')
            plt.show()
        return img

    def numpy(self):
        """
        Return a (max_x, max_y, depth) occupancy volume with 1 at each point's (x, y, z).

        depth is max_z, but at least max(z) + 1 over the points: with the default max_z = 0
        every z is 0, and a depth of 0 made every assignment raise IndexError.
        """
        depth = max([self.max_z] + [point.z + 1 for point in self.points])
        img = np.zeros((self.max_x, self.max_y, depth))
        for point in self.points:
            img[point.x, point.y, point.z] = 1
        return img

In [None]:
class DistanceTransform():
    """
    Weighted distance transform of a `Points` collection over its pixel grid.

    For every grid pixel, the transform is the smallest `point.distance(pixel)` over all
    generating points; that point's cell owns the pixel.
    """
    def __init__(self, points):
        self.grid = points.grid
        # NOTE: the same list object as `points.points`, so `Points.sort` reorders it too
        self.points = points.points
        self.max_x = points.max_x
        self.max_y = points.max_y

        # self.distances[i, k] = distance from self.points[i] to grid pixel k
        self.distances = np.array([point.distance(self.grid) for point in self.points])

        self.transform = None
        self.CMs = np.array([[0, 0] for _ in range(len(self.points))])
        self.points_in_cell = [0 for _ in range(len(self.points))]

    def normalize(self, x):
        """Min-max scale x to [0, 1]; a constant input maps to zeros instead of dividing by 0."""
        max_ = np.max(x)
        min_ = np.min(x)
        span = max_ - min_
        if span == 0:
            return np.zeros_like(x)
        return (x - min_) / span

    def compute_transform(self):
        """
        Compute the transform image and the cell statistics.

        self.transform holds the normalized distance to the closest point at pixel (x, y),
        stored transposed as (max_y, max_x). self.CMs[i] is the center of mass of the pixels
        owned by point i (NaN if it owns none), and self.points_in_cell[i] is their count.

        All three are recomputed from scratch, so calling this again gives the same result
        (previously the centers and counts kept accumulating).
        """
        n_points = len(self.points)

        # owning point of every pixel (first one on ties) and its distance
        closest_cell = np.argmin(self.distances, axis=0)
        nearest_distance = np.min(self.distances, axis=0)

        transform = np.zeros((self.max_x, self.max_y)).astype('float32')
        transform[self.grid[:, 0], self.grid[:, 1]] = nearest_distance

        # per-cell pixel counts and coordinate sums -> centers of mass
        counts = np.bincount(closest_cell, minlength=n_points)
        sums = np.stack([np.bincount(closest_cell, weights=self.grid[:, d], minlength=n_points)
                         for d in range(2)], axis=1)
        with np.errstate(invalid='ignore', divide='ignore'):
            centers = sums / counts[:, None]

        self.CMs = list(centers)
        self.points_in_cell = counts.tolist()
        self.transform = self.normalize(transform).T

    def visualize_transform(self):
        """Plot the transform with every generating point drawn as a small blue disc."""
        res_rgb = cv2.cvtColor(self.transform, cv2.COLOR_GRAY2RGB)
        for point in self.points:
            res_rgb = cv2.circle(res_rgb, (int(point.x), int(point.y)), 3, (0, 0, 1), -1)

        plt.figure(figsize=(15, 10))
        plt.imshow(res_rgb)
        plt.show()

## Cell Segmentation 

In [None]:
class Segmenter:
    """
    Segment cell images with Cellpose and extract the cell graph of each image i.

    The graph consists of masks, outlines, vertices, cells, pairwise borders, adjacent
    cell pairs and barycenters. Pixel coordinates are (row, col) array indices throughout.
    """
    def __init__(self):
        self.model = models.Cellpose(model_type='cyto')
        self.images = []
        self.N = None
        self.masks = []
        self.outlines = []
        self.vertices = []
        self.cells = []
        self.borders = []
        self.adjacent_cells = []
        self.barrycenters = []

    def _reset_results(self):
        """Clear every per-image result so `segment` can be called again on the same object."""
        self.masks = []
        self.outlines = []
        self.vertices = []
        self.cells = []
        self.borders = []
        self.adjacent_cells = []
        self.barrycenters = []

    def closest_nonzero(self, img, pt):
        """
        Return the value of the first non-zero pixel met while walking an outward square
        spiral around `pt`. The pixel at `pt` itself is not checked.

        Returns 0 once the spiral has swept the whole image without finding one, i.e. when
        `img` is entirely zero (this used to loop forever).
        """
        n_rows, n_cols = img.shape[0], img.shape[1]
        # once the spiral arms are longer than twice the largest side, every pixel was visited
        max_steps = 2 * max(n_rows, n_cols) + 2
        steps_before_rotating = 0
        sign = -1
        current_point = np.asarray(pt)

        while steps_before_rotating <= max_steps:
            # first walk along y (columns), then along x (rows)
            for direction in (np.array([0, 1]), np.array([1, 0])):
                for _ in range(steps_before_rotating):
                    current_point = current_point + sign * direction
                    r, c = current_point[0], current_point[1]
                    # bounds are checked before indexing so negative indices never wrap around
                    if 0 <= r < n_rows and 0 <= c < n_cols and img[r, c] != 0:
                        return img[r, c]
            sign *= -1
            steps_before_rotating += 1
        return 0

    def finetune_masks(self):
        """
        Fill the zero (background) pixels of every mask and relabel its cells.

        Each zero pixel takes the label of its nearest non-zero pixel (spiral search over the
        unfilled mask). The cells are then relabelled 1..K in increasing order of the distance
        of their barycenter from the origin. Masks without any cell are left unchanged.
        """
        for i in range(self.N):
            mask = self.masks[i]
            if not np.any(mask):
                # nothing to propagate from; searching every pixel would only return 0
                continue

            filled = mask.copy()
            for point in np.argwhere(mask == 0):
                filled[point[0], point[1]] = self.closest_nonzero(mask, point)

            # read labels from `filled` and write into a copy so renamed labels never collide
            relabelled = filled.copy()
            cells = np.unique(filled)
            center_norms = [np.linalg.norm(np.mean(np.argwhere(filled == cell), axis=0)) for cell in cells]
            for new_label, cell_index in enumerate(np.array(center_norms).argsort(), start=1):
                # index into `cells` (not `cell_index + 1`) so non-contiguous labels work too
                relabelled[filled == cells[cell_index]] = new_label
            self.masks[i] = relabelled

    def compute_barrycenters(self, i=0):
        """Store, for image i, the mean (row, col) position of each cell's pixels in self.barrycenters[i]."""
        if len(self.barrycenters) == 0:
            self.barrycenters = [None for _ in range(self.N)]
        if self.barrycenters[i] is None:
            self.barrycenters[i] = {alpha: None for alpha in self.cells[i]}

        for alpha in self.cells[i]:
            self.barrycenters[i][alpha] = np.mean(np.argwhere(self.masks[i] == alpha), axis=0)

    def compute_vertices(self):
        """
        Count the distinct mask labels in each pixel's 3x3 neighbourhood, for every image.

        The count image is appended to self.outlines; pixels with a count >= 2 lie next to a
        cell boundary. The (row, col) pixels with a count >= 3, i.e. junctions of at least
        three labels, are appended to self.vertices as an (n, 2) array.
        """
        kernel = lambda neighborhood: len(set(neighborhood))

        for i in range(self.N):
            counts = generic_filter(self.masks[i], kernel, (3, 3))
            self.outlines.append(counts)
            self.vertices.append(np.argwhere(counts >= 3))

    def find_cells(self):
        """Append the sorted unique labels of each mask (one label per cell) to self.cells."""
        for i in range(self.N):
            self.cells.append(np.unique(self.masks[i]))

    def get_border(self, alpha, beta, view_border = False, i = 0):
        """
        Find the border pixels between cells alpha and beta in image i.

        The border is stored as an (n, 2) array of (row, col) in both
        self.borders[i][alpha][beta] and self.borders[i][beta][alpha]. Nothing is stored when
        the border is empty.

        A pixel counts as a border pixel when its 3x3 neighbourhood, with every other cell
        zeroed out, contains both labels and at most one zero.
        """
        # if not initialized, initialize the borders to be None
        if len(self.borders) == 0:
            for j in range(self.N):
                self.borders.append({c1: {c2: None for c2 in self.cells[j]} for c1 in self.cells[j]})

        border = self.masks[i].copy()
        # zero out everything except the two cells
        border[(border != alpha) & (border != beta)] = 0

        non_zero = np.argwhere(border != 0)
        if len(non_zero) == 0:
            print("no pixels for cells", alpha, "and", beta)
            return

        # only filter the bounding box of the two cells, for performance; the +1 keeps the
        # last row/column (the old `[min:max]` crop dropped them)
        min_x, min_y = non_zero.min(axis=0)
        max_x, max_y = non_zero.max(axis=0)
        border = border[min_x:max_x + 1, min_y:max_y + 1]

        # (#distinct values) - (#zeros): equals 2 exactly when both labels are present and
        # at most one neighbour is zero
        kernel = lambda x: len(set(x)) - list(x).count(0)
        border = generic_filter(border, kernel, (3, 3))

        # back to full-image coordinates
        border_points = np.argwhere(border == 2) + np.array([min_x, min_y])

        if len(border_points) != 0:
            self.borders[i][alpha][beta] = border_points
            self.borders[i][beta][alpha] = border_points

        if view_border:
            points_to_draw = self.borders[i][alpha][beta]
            if points_to_draw is None:
                print("empty border for", alpha, beta)
                return
            img = self.masks[i].copy()
            for point in points_to_draw:
                img = cv2.circle(img, (int(point[1]), int(point[0])), 1, (255, 255, 255), -1)

            plt.figure(figsize=(15, 10))
            plt.imshow(img)
            plt.show()

    def edge(self, alpha, beta, i=0):
        """Return the stored border between cells alpha and beta in image i (None if missing)."""
        try:
            return self.borders[i][alpha][beta]
        except (IndexError, KeyError):
            print("edges does not exist")

    def edges(self, i=0):
        """Return the nested border dict of image i: edges[alpha][beta] -> (n, 2) array or None."""
        return self.borders[i]

    def pairs(self, i=0):
        """Return the stored adjacent cell pairs of image i whose border is non-empty."""
        pairs_l = []
        for (a, b) in self.adjacent_cells[i]:
            if self.borders[i][a][b] is None:
                print("empty border for", a, b)
            else:
                pairs_l.append((a, b))
        return pairs_l

    def get_edge_cells(self, i=0):
        """Return the set of labels that touch the outer frame of mask i."""
        mask = self.masks[i]
        frame = np.concatenate([mask[:, 0], mask[:, -1], mask[0, :], mask[-1, :]])
        return set(frame)

    def neighbors(self, v, i=0):
        """
        Return the mask values of the (up to 8) in-bounds pixels around v = (row, col) in
        image i. The value at v itself is excluded.
        """
        n_rows, n_cols = self.masks[i].shape[0], self.masks[i].shape[1]
        values = []
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                if dr == 0 and dc == 0:
                    continue
                r, c = v[0] + dr, v[1] + dc
                if 0 <= r < n_rows and 0 <= c < n_cols:
                    values.append(self.masks[i][r, c])
        return values

    def get_adjacent_cells(self, i=0):
        """
        Record in self.adjacent_cells[i] every unordered pair of labels that appear together
        around one of image i's vertices (each pair stored once).
        """
        if len(self.adjacent_cells) == 0:
            self.adjacent_cells = [[] for _ in range(self.N)]

        all_pairs = []
        for v in self.vertices[i]:
            # unique colors around v, read from image i (was always image 0)
            unique = set(self.neighbors(v, i))
            all_pairs += list(itertools.combinations(unique, 2))

        for p in set(all_pairs):
            if (p[0], p[1]) not in self.adjacent_cells[i] and (p[1], p[0]) not in self.adjacent_cells[i]:
                self.adjacent_cells[i].append((p[0], p[1]))

    def find_edges(self, i=0):
        """Find the adjacent pairs of image i and compute the border of each."""
        self.get_adjacent_cells(i)
        for (alpha, beta) in tqdm(self.adjacent_cells[i]):
            self.get_border(alpha, beta, i=i)

    def segment(self, images, diameter=None):
        """
        Main function -- segment the image(s) into cells and identify the edges.

        images: a single image or a list of images.
        diameter: cell diameter forwarded to Cellpose (None lets Cellpose estimate it).
        Results from any previous call are discarded first.
        """
        if not isinstance(images, list):
            images = [images]
        self._reset_results()
        self.images = images
        self.N = len(images)

        print("Evaluating the neural network")
        masks, flows, styles, diams = self.model.eval(images, diameter=diameter, flow_threshold=None, channels=[0,0])

        # original masks
        self.masks = list(masks)

        print("Fixing the masks")
        self.finetune_masks()

        print("Computing the vertices")
        self.compute_vertices()

        print("Identifying the cells")
        self.find_cells()

        print("Finding the borders between cells")
        for i in range(self.N):
            self.find_edges(i)

        print("Finding the cell barycenters")
        for i in range(self.N):
            self.compute_barrycenters(i)

    def visualize(self, name='outlines', specific_cell = None, show_vertices = True, i = 0, overlay=False, return_img=False):
        """
        Plot the segmentation of image i.

        name: 'masks' shows the label mask; anything else shows the outline counts.
        specific_cell: if given, every other label is zeroed (ignored when overlay=True).
        show_vertices: draw the vertices as small discs.
        overlay: draw the stored borders in red on the original grayscale image instead.
        return_img: also return the drawn image.
        """
        if name == 'masks':
            segmented = self.masks[i].copy()
        else:
            segmented = self.outlines[i].copy()

        if specific_cell is not None:
            segmented[segmented != specific_cell] = 0

        if overlay:
            # bound before the loop so an image without pairs still has something to show
            img = cv2.cvtColor(self.images[i], cv2.COLOR_GRAY2RGB)
            for (alpha, beta) in self.pairs(i):
                for point in self.borders[i][alpha][beta]:
                    img = cv2.circle(img, (int(point[1]), int(point[0])), 1, (1, 0, 0), -1)
        else:
            img = segmented.copy()

        if show_vertices:
            for point in self.vertices[i]:
                img = cv2.circle(img, (int(point[1]), int(point[0])), 2, (0, 0, 1), -1)

        plt.figure(figsize=(15, 10))
        plt.imshow(img)
        plt.show()

        if return_img:
            return img

## CAP Tiling

In [None]:
class VMSI():
    """
    Fit a CAP tiling to a segmented image and infer the edge tensions.

    Every cell alpha gets a generating point with position q_alpha, height z_alpha and
    weight p_alpha. Each pair of adjacent cells induces a circle (Equations 6 and 7, see
    `transform`). The parameters are fitted so that those circles pass through the observed
    border pixels, and the tensions then follow from the Young-Laplace law (`get_tensions`).

    Cells are assumed to be labelled 1..K: the parameters of cell alpha live at index alpha-1.

    Parameters:
        cell_pairs: list of (alpha, beta) adjacent cell pairs.
        edges: nested dict, edges[alpha][beta] -> (n, 2) array of (row, col) border pixels
            or None. Pairs with fewer than 3 pixels are dropped and their entry is set to
            None (this mutates the dict that was passed in).
        num_cells: number of cells.
        cells: iterable of cell labels.
        barrycenters: dict cell -> (row, col) barycenter, used to initialize q.
        edge_cells: cells touching the image frame; an edge joining two of them is skipped
            when computing and plotting tensions.
        width, height: image size used to clip the position bounds.
    """

    def __init__(self, cell_pairs, edges, num_cells, cells, barrycenters, edge_cells, width=500, height=500):
        self.edges = edges
        # NOTE(review): `edges` is a nested dict keyed by cell, so this counts cells with an
        # entry rather than edges. It only rescales `energy` and is kept so the optimizer
        # sees the same objective scale.
        self.num_edges = len(edges)
        self.num_cells = num_cells
        self.width = width
        self.height = height
        self.barrycenters = barrycenters
        self.cells = cells
        self.tension = {alpha: {beta: None for beta in self.cells} for alpha in self.cells}
        self.edge_cells = edge_cells

        # keep only edges with at least 3 pixels (fewer give no tangents or circle fit).
        # A new list is built: calling list.remove while iterating skipped the pair that
        # followed every removed one.
        self.cell_pairs = []
        for (alpha, beta) in cell_pairs:
            edge = self.edges[alpha][beta]
            if edge is None or len(edge) < 3:
                self.edges[alpha][beta] = None
            else:
                self.cell_pairs.append((alpha, beta))

        # init vertices (the two endpoints of every edge)
        self.vertices = {alpha: {beta: None for beta in self.cells} for alpha in self.cells}
        for (alpha, beta) in self.cell_pairs:
            self.vertices[alpha][beta] = self.get_vertices(alpha, beta)

        # init unit tangents at those vertices
        self.tangents = {alpha: {beta: None for beta in self.cells} for alpha in self.cells}
        for (alpha, beta) in self.cell_pairs:
            self.tangents[alpha][beta] = self.get_tangents(alpha, beta)

    def transform(self, q, z, p):
        """
        Transform generating points to the CAP tiling via Equations 6 and 7.

        q: (K, 2) positions, z: (K,) heights, p: (K,) weights; cell alpha is at index alpha-1.
        Returns nested dicts center[alpha][beta] and radius[alpha][beta] for every kept pair
        (None elsewhere). The radius is NaN when the quantity under the square root is
        negative, and the center is undefined when p_alpha == p_beta.
        """
        center = {alpha : {beta: None for beta in self.cells} for alpha in self.cells}
        radius = {alpha : {beta: None for beta in self.cells} for alpha in self.cells}

        for (alpha, beta) in self.cell_pairs:
            center[alpha][beta] = (p[beta-1]*q[beta-1] - p[alpha-1]*q[alpha-1]) / (p[beta-1] - p[alpha-1])
            radius[alpha][beta] = np.sqrt(((p[alpha-1]*p[beta-1]) * (np.linalg.norm(q[alpha-1] - q[beta-1])**2))/(p[alpha-1] - p[beta-1])**2 \
                                          - (p[alpha-1] * (z[alpha-1]**2) - p[beta-1] * (z[beta-1]**2))/(p[alpha-1] - p[beta-1]))
        return center, radius

    def energy(self, theta):
        """
        Objective of the main fit (Equation 5).

        theta is a flattened 4*K vector [x..., y..., z..., p...]. Returns the sum, over every
        kept edge and each of its pixels, of (|pixel - center| - radius)^2, divided by
        2 * self.num_edges.
        """
        q, z, p = self.extract_values(theta)
        center, radius = self.transform(q, z, p)

        energy = 0
        for (alpha, beta) in self.cell_pairs:
            edge = np.asarray(self.edges[alpha][beta])
            residuals = np.linalg.norm(edge - center[alpha][beta], axis=1) - radius[alpha][beta]
            energy += np.sum(residuals ** 2)

        return energy/(2*self.num_edges)

    def is_square(self, i):
        """Return True if the integer i is a perfect square (i >= 0), else False."""
        import math
        return i >= 0 and math.isqrt(i) ** 2 == i

    def extract_values(self, theta):
        """Split theta = [x..., y..., z..., p...] into q (K, 2), z (K,) and p (K,)."""
        N = len(theta)//4
        x = theta[:N]
        y = theta[N:2*N]
        z = theta[2*N:3*N]
        p = theta[3*N:4*N]
        q = np.array([x,y]).T
        return q, z, p

    def initialize_points(self):
        """
        Return an initial theta = [x..., y..., z..., p...]: (x, y) is each cell's barycenter,
        z = 0, and p ~ U(0.001, 0.005) from the global numpy random state.
        """
        x = [] ; y = [] ; z = [] ; p = []

        for alpha in self.cells:
            center = self.barrycenters[alpha]
            x.append(center[0]) ; y.append(center[1]) ; z.append(0)
            p.append(np.random.uniform(0.001, 0.005))

        return np.array(x + y + z + p)

    def get_vertices(self, alpha, beta):
        """
        Return the two endpoints [start, end] of the edge between alpha and beta.

        start is the pixel farthest from the edge mean; end is the pixel farthest from start.
        """
        edge = np.asarray(self.edges[alpha][beta])
        edge_mean = edge.mean(axis=0)
        start = edge[np.argmax(np.linalg.norm(edge - edge_mean, axis=1))]
        end = edge[np.argmax(np.linalg.norm(edge - start, axis=1))]
        return [np.array(start), np.array(end)]

    def get_tangents(self, alpha, beta):
        """
        Return the unit tangents [t1, t2] at the two vertices of the edge between alpha and beta.

        Each tangent points from the vertex towards the mean of the (up to) 5 interior edge
        pixels closest to it.
        """
        v1, v2 = self.vertices[alpha][beta]

        # interior pixels: every edge pixel except the two endpoints
        edge = np.array([e for e in self.edges[alpha][beta]
                         if list(e) not in [list(v1), list(v2)]]).reshape(-1, 2)

        tangents = []
        for v in (v1, v2):
            nearest = np.linalg.norm(edge - v, axis=1).argsort()[:5]
            direction = edge[nearest].mean(axis=0) - v
            tangents.append(direction / np.linalg.norm(direction))
        return tangents

    @staticmethod
    def _fit_circle_radius(edge):
        """
        Fit a circle to the edge pixels by least squares and return its radius.

        The center minimizes the spread of the pixel distances around their mean (residuals
        r_i - mean(r)); the radius is the mean distance to that center.
        """
        points = np.asarray(edge, dtype=float)

        def residuals(center):
            distances = np.linalg.norm(points - center, axis=1)
            return distances - distances.mean()

        # optimize both center coordinates from the pixel mean (the old call passed y as an
        # extra argument, so only x was fitted and a 1-element "center" was used)
        center = leastsq(residuals, points.mean(axis=0))[0]
        return np.mean(np.linalg.norm(points - center, axis=1))

    def _position_bounds(self, theta0, N):
        """
        Bounds for the x block theta0[:N] followed by the y block theta0[N:2N].

        Each coordinate may move +/-10 px around its start, clipped to [0, width] for x and
        [0, height] for y.
        """
        x_bounds = [(max(0, v - 10), min(v + 10, self.width)) for v in theta0[:N]]
        y_bounds = [(max(0, v - 10), min(v + 10, self.height)) for v in theta0[N:2 * N]]
        return x_bounds + y_bounds

    def initialize(self):
        """
        Minimization to determine the initial (q, p).

        We want the vector from each edge's circle center to each of its vertices to be
        orthogonal to the tangent there. A penalty keeps the ratio of the average |t| to the
        average pressure differential equal to the average measured radius of curvature, to
        avoid trivial solutions. Returns theta = [x..., y..., z..., p...] with z = 0.
        """

        # t_i for vertex i in {0, 1} between cells alpha and beta: (p_a - p_b) * (v_i - center),
        # the (scaled) vector from the edge's circle center to that vertex
        def t(alpha, beta, q, p, i):
            return (p[alpha-1] - p[beta-1])*self.vertices[alpha][beta][i] - (p[alpha-1]*q[alpha-1] - p[beta-1]*q[beta-1])

        def extract(theta):
            # theta here is [x..., y..., p...] (no z)
            N = len(theta)//3
            q = np.array([theta[:N], theta[N:2*N]]).T
            p = theta[2*N:3*N]
            return q, p

        # average measured radius of curvature of the observed edges
        R_avg = np.mean([self._fit_circle_radius(self.edges[alpha][beta]) for (alpha, beta) in self.cell_pairs])

        def ratio_constraint(theta):
            """
            Deviation of (mean |t| over both vertices of every pair) / (mean |p_a - p_b| over
            pairs) from the measured average radius R_avg.
            """
            q, p = extract(theta)
            sum_t_norm = 0
            sum_p_differential = 0
            for (alpha, beta) in self.cell_pairs:
                sum_p_differential += np.abs(p[alpha-1] - p[beta-1])
                for i in range(2):
                    sum_t_norm += np.linalg.norm(t(alpha, beta, q, p, i))

            # parentheses matter: the old `a / 2*b` computed (a/2)*b
            ratio = sum_t_norm / (2 * sum_p_differential)
            return ratio - R_avg

        # try to make each tangent and its t_i orthogonal by minimizing this function
        def E_initial(theta):
            q, p = extract(theta)
            if np.any(p < 0):
                return np.inf

            E = 1000000 * abs(ratio_constraint(theta))
            for (alpha, beta) in self.cell_pairs:
                # for both vertices at the extremities
                for i in range(2):
                    E += (t(alpha, beta, q, p, i) @ self.tangents[alpha][beta][i])**2
            return E

        theta0_with_z = list(self.initialize_points()) ; N = len(theta0_with_z)//4
        theta0 = np.array(theta0_with_z[:2*N] + theta0_with_z[3*N:4*N])  # only keep q and p

        # x, y within +/-10 px of the barycenters; p in [0.001, 0.005]
        bounds = self._position_bounds(theta0, N) + [(0.001, 0.005)] * N

        print("Finding the initialization of q and p")
        optimal = minimize(E_initial, theta0, options={'disp':True}, bounds=bounds).x
        print(E_initial(optimal))

        q, p = extract(optimal)
        x = q.T[0] ; y = q.T[1]

        # find z. For NOW, set to be 0 -- later, we will implement the actual equation
        z = [0 for _ in range(len(p))]

        return np.array(list(x) + list(y) + list(z) + list(p))

    def load_constraints(self, theta0):
        """
        Build the bounds and constraints for the main minimization over theta = [x, y, z, p].

        Bounds: x/y within +/-10 px of theta0 (clipped to the image), z in [0, 0.5] and
        p in [0.001, 0.005]. Constraints: for every pair, an inequality that keeps the
        quantity under the square root of Equation 7 non-negative.
        """
        constraints = []
        N = len(theta0)//4

        # one bounds block per variable; the old `i <= N` tests shifted the first element of
        # every block into the previous block's bounds (e.g. the first p got z's bounds)
        bounds = self._position_bounds(theta0, N) + [(0, 0.5)] * N + [(0.001, 0.005)] * N

        # non-imaginary constraint, i.e. force things inside of the root in Equation 7 to not be negative
        def imaginary_constraint(alpha, beta):
            def constraint(theta):
                q, z, p = self.extract_values(theta)
                return ((p[alpha-1]*p[beta-1]) * np.linalg.norm(q[alpha-1] - q[beta-1])**2)/(p[alpha-1] - p[beta-1])**2 \
                        - (p[alpha-1] * (z[alpha-1]**2) - p[beta-1] * (z[beta-1]**2))/(p[alpha-1] - p[beta-1])
            return constraint

        for (alpha, beta) in self.cell_pairs:
            constraints.append({'type': 'ineq', 'fun': imaginary_constraint(alpha, beta)})

        return bounds, constraints

    def fit(self):
        """
        Perform the minimization of Equation 5 with respect to the variables (q, z, p),
        store the resulting tensions and return (q, z, p).
        """
        theta0 = self.initialize()
        bounds, constraints = self.load_constraints(theta0)

        print("Main minimization")
        print(theta0)
        optimal = minimize(self.energy, theta0, options={'disp':True}, bounds=bounds, constraints=constraints).x

        q, z, p = self.extract_values(optimal)
        self.get_tensions(q, z, p)
        return q, z, p

    def get_tensions(self, q, z, p):
        """
        Apply the Young-Laplace law, tension = |p_alpha - p_beta| * radius, to every kept
        edge except those joining two edge cells. Results are stored symmetrically in
        self.tension.
        """
        center, radius = self.transform(q, z, p)

        for (alpha, beta) in self.cell_pairs:
            if alpha not in self.edge_cells or beta not in self.edge_cells:
                self.tension[alpha][beta] = np.abs((p[alpha-1] - p[beta-1]) * radius[alpha][beta])
                self.tension[beta][alpha] = self.tension[alpha][beta]

    def get_normalized_tensions(self):
        """
        Min-max normalize the tensions to [0, 1] (used for plotting in the CAP tiling).

        Entry [alpha][beta] holds the normalized value and [beta][alpha] its negation; when
        all tensions are equal they normalize to 0.
        """
        tensions_normalized = {alpha: {beta: None for beta in self.cells} for alpha in self.cells}

        used_pairs = [(alpha, beta) for (alpha, beta) in self.cell_pairs
                      if alpha not in self.edge_cells or beta not in self.edge_cells]
        values = [self.tension[alpha][beta] for (alpha, beta) in used_pairs]
        if not values:
            return tensions_normalized

        # independent min and max (an if/elif left min_T at inf for increasing tensions)
        min_T, max_T = min(values), max(values)
        span = max_T - min_T
        for (alpha, beta) in used_pairs:
            tensions_normalized[alpha][beta] = (self.tension[alpha][beta] - min_T) / span if span else 0.0
            tensions_normalized[beta][alpha] = -1 * tensions_normalized[alpha][beta]
        return tensions_normalized

    def CAP(self, img, q, z, p):
        """
        Draw the CAP tiling on a copy of the grayscale image `img` (converted to RGB).

        For every plotted edge, the (center, radius) pair from `transform` is computed and the
        shortest circular arc between the points closest to the edge's two vertices is drawn.
        The arc is colored by the normalized tension, and its endpoints are marked in black.
        Returns the drawn image.
        """
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        center, radius = self.transform(q, z, p)

        # normalize the tensions (to see the differences when plotting the colors)
        tensions_normalized = self.get_normalized_tensions()
        angles = np.linspace(0, 2*np.pi, num=70000)
        unit_circle = np.stack([np.cos(angles), np.sin(angles)], axis=1)

        for (alpha, beta) in self.cell_pairs:
            if alpha not in self.edge_cells or beta not in self.edge_cells:
                T = tensions_normalized[alpha][beta]
                rho, R = center[alpha][beta], radius[alpha][beta]

                # sampled points of the full circle
                circle = rho + R * unit_circle

                # the circle points closest to the edge extremities
                start, end = self.vertices[alpha][beta]
                cap_s = np.argmin(np.linalg.norm(circle - start, axis=1))
                cap_e = np.argmin(np.linalg.norm(circle - end, axis=1))

                # candidate arcs between the two points (with and without wrap-around); keep the shortest
                arcs = [np.concatenate([circle[cap_e:], circle[:cap_s+1]]), circle[cap_s:cap_e+1],
                        np.concatenate([circle[cap_s:], circle[:cap_e+1]]), circle[cap_e:cap_s+1]]
                lengths = [len(arc) if len(arc) != 0 else np.inf for arc in arcs]
                arc = arcs[np.argmin(np.array(lengths))]

                try:
                    # color will be determined by the tension
                    for point in arc:
                        img = cv2.circle(img, (int(point[1]), int(point[0])), 2, (1, (1-T), (1-T)), -1)

                    # plot the arc endpoints
                    img = cv2.circle(img, (int(arc[0][1]), int(arc[0][0])), 5, (0, 0, 0), -1)
                    img = cv2.circle(img, (int(arc[-1][1]), int(arc[-1][0])), 5, (0, 0, 0), -1)
                except (ValueError, OverflowError, cv2.error):
                    # NaN/inf circle (e.g. negative radicand) cannot be drawn
                    print(T)

        return img

In [None]:
def get_actual(seg, dtr, generating_points=None):
    """Build a reference VMSI model whose tensions come from the true generating points.

    For every cell of ``seg``, the initial point returned by
    ``VMSI.initialize_points`` is matched to the nearest generating point of the
    synthetic image. That generating point's ``(q, z, p)`` is taken as the
    ground-truth value for the cell. Tensions are then computed from these
    ground-truth values, and the resulting CAP overlay is plotted.

    Parameters
    ----------
    seg : Segmenter
        Segmentation of the synthetic image. Supplies the pairs, edges, cells,
        edge cells and barycenters passed to the VMSI constructor.
    dtr : DistanceTransform
        Its ``transform`` image is used as the background of the CAP plot.
    generating_points : Points, optional
        Points used to synthesize the image. Defaults to the notebook-level
        ``generating_points``, looked up at call time.

    Returns
    -------
    VMSI
        The reference model, after ``get_tensions`` has been called with the
        ground-truth ``(q, z, p)``.
    """
    if generating_points is None:
        generating_points = globals()["generating_points"]

    actual_model = VMSI(cell_pairs=seg.pairs(), edges=seg.edges(), num_cells=len(seg.cells[0]),
                        cells=seg.cells[0], edge_cells=seg.get_edge_cells(),
                        barrycenters=seg.barrycenters[0], height=256, width=256)

    # Use this model's own initialization. The previous code called the global
    # ``model``, which only exists if a later cell has already been run.
    q, _, _ = actual_model.extract_values(actual_model.initialize_points())

    # Generating points store q as (x, y). Reversing presumably matches the
    # coordinate order used by VMSI's q (CAP draws point[1], point[0]).
    generating_q = np.array([point.q[::-1] for point in generating_points.points])
    generating_z = np.array([point.z for point in generating_points.points])
    generating_p = np.array([point.p for point in generating_points.points])

    # For each cell, find the index of the closest generating point.
    closest = np.array(
        [np.argmin(np.linalg.norm(generating_q - q_i, axis=1)) for q_i in q],
        dtype=int,
    )
    q_actual = generating_q[closest]
    z_actual = generating_z[closest]
    p_actual = generating_p[closest]

    # Compute the reference tensions from the ground-truth values, not from the
    # initial guess, so that ``actual_model.tension`` really is the reference.
    actual_model.get_tensions(q_actual, z_actual, p_actual)

    img = actual_model.CAP(dtr.transform.copy(), q_actual, z_actual, p_actual)

    fig, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(img)
    ax.set_title("CAP overlay from ground-truth generating points")
    plt.show()

    return actual_model


def evaluate(model, seg=None, dtr=None):
    """Compare the fitted tensions of ``model`` with ground-truth tensions.

    A reference model is built with ``get_actual``, which also plots its CAP
    overlay. For every cell pair in ``model.cell_pairs`` except those where both
    cells are in ``model.edge_cells``, the absolute predicted and reference
    tensions are collected. Their Pearson correlation is printed and a scatter
    plot is drawn.

    Parameters
    ----------
    model : VMSI
        Fitted model. Its ``tension``, ``cell_pairs`` and ``edge_cells`` are read.
    seg, dtr : optional
        Segmentation and distance transform forwarded to ``get_actual``.
        When omitted, the notebook-level ``seg`` / ``dtr`` are looked up at
        call time, not when this function is defined.

    Returns
    -------
    tuple[list, list]
        ``(x_points, y_points)``: absolute predicted and reference tensions,
        index-aligned.
    """
    # ``import scipy`` alone does not guarantee that scipy.stats is loaded.
    from scipy import stats

    if seg is None:
        seg = globals()["seg"]
    if dtr is None:
        dtr = globals()["dtr"]

    actual_model = get_actual(seg, dtr)

    predicted_tension = model.tension
    actual_tension = actual_model.tension

    x_points = []
    y_points = []
    skipped = 0
    for alpha, beta in model.cell_pairs:
        # Same filter as the CAP plot: ignore pairs made of two edge cells.
        if alpha in model.edge_cells and beta in model.edge_cells:
            continue
        try:
            pred_val = abs(predicted_tension[alpha][beta])
            actual_val = abs(actual_tension[alpha][beta])
        except (KeyError, IndexError, TypeError):
            # The pair is missing from (or invalid in) one of the tension tables.
            skipped += 1
            continue
        x_points.append(pred_val)
        y_points.append(actual_val)

    if skipped:
        print(f"Skipped {skipped} cell pair(s) with no tension in one of the models.")

    if len(x_points) >= 2:
        print(stats.pearsonr(x_points, y_points))
    else:
        print(f"Only {len(x_points)} tension pair(s) collected; Pearson correlation is undefined.")

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.scatter(x_points, y_points)
    ax.set(xlabel="|predicted tension|", ylabel="|actual tension|",
           title="Predicted vs. actual tensions")
    return x_points, y_points

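## Example with Synthetic Data

Generate random points and compute their distance transform. Segment the resulting image and fit VMSI. Finally, compare the fitted tensions with those computed from the known generating points.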

In [None]:
# Fix the RNG so the synthetic tissue, and every downstream result, is
# reproducible under Restart & Run All.
# NOTE(review): assumes Points draws its random points from NumPy's global RNG;
# confirm in Points.get_random_points.
SEED = 0
np.random.seed(SEED)

config = {
    'max_x': 256,       # x extent of the sampling grid (range(max_x))
    'max_y': 256,       # y extent of the sampling grid (range(max_y))
    'min_p': 0.001,     # presumably the lower bound for each point's weight p
    'max_p': 0.004,     # presumably the upper bound for each point's weight p
    'num_points': 50,   # number of generating points
}

generating_points = Points(**config)

# Compute the distance transform, then order the points by the norm of the
# center of the cell each one spans (see Points.sort).
dtr = DistanceTransform(generating_points)
dtr.compute_transform()
generating_points.sort(dtr.CMs)
dtr.visualize_transform()

In [None]:
# Segment the synthetic distance-transform image and overlay the detected outlines.
seg = Segmenter()
synthetic_image = dtr.transform.copy()
seg.segment(synthetic_image, diameter=None)
seg.visualize('outlines', overlay=True)

In [None]:
# Fit VMSI to the segmentation of the synthetic image.
model = VMSI(cell_pairs=seg.pairs(), edges=seg.edges(), num_cells=len(seg.cells[0]),
             cells=seg.cells[0], barrycenters=seg.barrycenters[0],
             edge_cells=seg.get_edge_cells(), height=256, width=256)

q, z, p = model.fit()

fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(model.CAP(dtr.transform.copy(), q, z, p))
ax.set_title("Fitted CAP arcs, colored by tension (synthetic data)")
plt.show()

# Pass seg/dtr explicitly so the comparison uses this run's segmentation, not
# whatever evaluate's defaults resolve to. The trailing ';' keeps the returned
# lists from being dumped into the output.
evaluate(model, seg=seg, dtr=dtr);

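## Example with Real Data

Apply the same segmentation and VMSI fit to a real grayscale image (`test.png`, resized to 256×256). There is no ground truth for this image, so only the fitted CAP overlay is shown and no tension evaluation is run.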

In [None]:
# Real image: load as grayscale and resize to the 256x256 frame that the VMSI
# model below is configured for.
IMAGE_PATH = 'test.png'
raw_image = cv2.imread(IMAGE_PATH, cv2.IMREAD_GRAYSCALE)
if raw_image is None:
    # cv2.imread returns None instead of raising when a file is missing or unreadable.
    raise FileNotFoundError(f"could not read image: {IMAGE_PATH!r}")
image = cv2.resize(raw_image, (256, 256))
plt.imshow(image, cmap='gray')
plt.show()

seg = Segmenter()
seg.segment(image, diameter=None)
seg_img = seg.visualize('outlines', overlay=True, return_img=True)

model = VMSI(cell_pairs=seg.pairs(), edges=seg.edges(), num_cells=len(seg.cells[0]),
             cells=seg.cells[0], barrycenters=seg.barrycenters[0],
             edge_cells=seg.get_edge_cells(), height=256, width=256)

q, z, p = model.fit()

# CAP draws onto the image it is given, so pass a float32 copy.
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(model.CAP(image.astype(np.float32), q, z, p))
ax.set_title("Fitted CAP arcs, colored by tension (real data)")
plt.show()