## Extracting Dihedral Angles at Triple Junctions from Segmentation Stacks

### Prep

In [None]:
### Imports
import scipy.interpolate as scipolate
import itertools
import numpy as np
import matplotlib.pyplot as plt
import scipy.spatial as sps

from skimage import io
from ipywidgets import interact

from mpl_toolkits.mplot3d import Axes3D

In [None]:
### Load input
# Read the segmentation stack from disk and print its dtype and shape as a quick sanity check.

im = io.imread(
    '../../../ForceInferenceProject/Data/Generated/'
    'three_intersecting_spheres_aniso.tif'
)
print(im.dtype, im.shape)

In [None]:
### Show input segmentations

@interact(z=(0, im.shape[0] - 1, 1))
def show_stack(z=im.shape[0] // 2):
    """Display plane ``z`` of the raw segmentation stack in grayscale."""
    _, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(im[z], cmap='gray')
    plt.show()

### Identifying Object Outlines

In [None]:
### Identify outlines by comparing shifted images

# Pad the image by 1 voxel on all sides (reflect mode)
im_pad = np.pad(im, 1, mode='reflect')

# A voxel is an outline voxel if its label differs from any neighbour at an
# offset in {0, 1}^3 or at the mirrored offset in {-1, 0}^3.
#
# For each shift, shifted axes are compared over the full padded length, which
# gives `comparison` one extra element along that axis. Unshifted axes must be
# restricted to the interior (1:-1). Otherwise they keep the padded length and
# the result ends up misaligned with `im` by one voxel along those axes.
outlines = np.zeros(im.shape, dtype=bool)  # plain `bool`: np.bool was removed in NumPy 1.24
for shift in itertools.product([0, 1], repeat=3):
    fwd = tuple(slice(1, None) if s else slice(1, -1) for s in shift)
    bwd = tuple(slice(None, -1) if s else slice(1, -1) for s in shift)
    comparison = im_pad[fwd] != im_pad[bwd]
    # Leading part: voxel i vs. its neighbour at -shift
    outlines |= comparison[:im.shape[0], :im.shape[1], :im.shape[2]]
    # Trailing part: voxel i vs. its neighbour at +shift (symmetry)
    outlines |= comparison[-im.shape[0]:, -im.shape[1]:, -im.shape[2]:]

# Re-annotate the cell identities (outline voxels keep their label, others become 0)
outlines_id = outlines * im

# Report
print(outlines.dtype, outlines.shape)
print(outlines_id.dtype, outlines_id.shape)

In [None]:
### Show identified outlines

@interact(z=(0, im.shape[0] - 1, 1))
def show_stack(z=im.shape[0] // 2):
    """Display plane ``z`` of the label-annotated outline stack in grayscale."""
    _, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(outlines_id[z], cmap='gray')
    plt.show()

### Identifying Triple Edges and Triple Junctions

In [None]:
### Find coordinates of all voxels involved in triple junctions

# (z, y, x) coordinates of every outline voxel, one row per voxel
OCs = np.array(np.where(outlines)).T  # OCs <- "Outline Coordinates"

# Keep the outline voxels whose 2x2x2 cube (the voxel plus its +1 neighbours
# along each axis; im_pad is offset by one voxel) contains exactly three
# distinct labels. Each such Triple Edge (TE) is stored at the cube centre,
# i.e. voxel coordinate + 0.5.
TEs = np.array([  # TEs <- "Triple Edges"
    OC + 0.5
    for OC in OCs
    if len(set(im_pad[OC[0]+1:OC[0]+3,
                      OC[1]+1:OC[1]+3,
                      OC[2]+1:OC[2]+3].flatten())) == 3
])

In [None]:
### Build a dict of TJs structured as: {tuple(cell1_ID, cell2_ID, cell3_ID) : array(TEs, Dimensions)}

# Prepare defaultdict
from collections import defaultdict
TJs = defaultdict(lambda : [[],[],[]])  # TJs <- "Triple Junctions"

# Go through TEs: read the labels in each TE's 2x2x2 cube and group the TE
# coordinates under the sorted label tuple.
for TE in TEs:
    # TE = voxel + 0.5 with non-negative coordinates, so int() recovers the
    # voxel index (builtin int: np.int was removed in NumPy 1.24)
    iz, iy, ix = (int(c) for c in TE)
    selection = im_pad[iz+1:iz+3, iy+1:iy+3, ix+1:ix+3]
    TJ_ID = tuple(sorted(set(selection.flatten())))
    for dim in range(3):
        TJs[TJ_ID][dim].append(TE[dim])

# Convert TJ lists to (n_TEs, 3) numpy arrays with columns (z, y, x)
for key in TJs.keys():
    TJs[key] = np.array(TJs[key]).T

In [None]:
### Show identified TJs on image stack

@interact(z=(0, im.shape[0]-1, 1))
def show_stack(z=im.shape[0]//2):
    """Overlay the TEs lying in plane ``z`` on the outline image, one colour per TJ."""
    plt.figure(figsize=(8,8))
    plt.imshow(outlines_id[z], cmap='gray')

    n_TJs = len(TJs)
    for color_idx, coords in enumerate(TJs.values()):
        # TEs sit at voxel + 0.5, so flooring z selects this plane's TEs
        in_plane = coords[np.floor(coords[:, 0]) == z]
        # scatter needs one colour value per point, so repeat the TJ's index
        # and map it through the hsv colormap
        plt.scatter(in_plane[:, 2], in_plane[:, 1],
                    c=[color_idx] * len(in_plane),
                    cmap='hsv', vmin=0, vmax=n_TJs, s=20)

    plt.show()

In [None]:
### Show identified TJs as 3D scatter

fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(111, projection='3d')

# One colour per TJ; plot axes (x, y, z) correspond to array columns (2, 1, 0)
for TJ_num, coords in enumerate(TJs.values()):
    ax.scatter(coords[:, 2], coords[:, 1], coords[:, 0],
               c=[TJ_num] * coords.shape[0],
               cmap='hsv', vmin=0, vmax=len(TJs), s=10)

## Also show cell outlines [takes several seconds to render!]
#ax.scatter([c[2] for c in OCs],
#           [c[1] for c in OCs],
#           [c[0] for c in OCs],
#           c='gray', alpha=0.05, linewidth=0, s=5)

# Identical limits on every axis. NOTE: voxel size/resolution is not taken
# into account, so the result looks squashed if the stack is anisotropic!
for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
    set_lim([0, 400])

plt.show()

### Spline Fitting

*It turns out that spline fitting requires the input points to be roughly ordered along the spline, which isn't guaranteed in our case. Ordering the points is a far harder problem than one might imagine (it is a variant of the traveling salesman problem), but luckily it can be solved reasonably well with a breadth-first search (BFS) over a nearest-neighbor graph of the points. This solution is partially inspired by Imanol Luengo's answer to [this SO question](https://stackoverflow.com/questions/37742358/sorting-points-to-form-a-continuous-line).*

In [None]:
### Get an example TJ (the first entry of the TJs dict)

TJ = list(TJs.values())[0]

In [None]:
### Get all pairwise (Euclidean) distances between the example TJ's TEs

import scipy.spatial.distance as ds
dists = ds.squareform(ds.pdist(TJ, metric='euclidean'))  # square (n_TEs, n_TEs) matrix

In [None]:
### Plot distance matrix

plt.figure(figsize=(8,8))
plt.imshow(dists)
plt.gca().set(title="distances", xlabel='TEs', ylabel='TEs')
plt.show()

In [None]:
### Generate graph based on N nearest neighbors

# Get nearest neighbors.
# - The kNN search runs on the TE coordinates (TJ), i.e. in space, not on the
#   rows of the distance matrix.
# - n_neighbors must be passed by keyword: positional arguments raise a
#   TypeError in scikit-learn >= 1.0.
# - kneighbors_graph() without X excludes each point itself, so at most
#   len(TJ) - 1 neighbours exist.
from sklearn.neighbors import NearestNeighbors  # TODO: Replace with scipy kdTree
N_neighbors = 10  # Note: the outcome seems quite robust to this number; still think about how to set this?
N_neighbors = max(1, min(N_neighbors, len(TJ) - 1))
NN_graph = NearestNeighbors(n_neighbors=N_neighbors).fit(TJ).kneighbors_graph()

# Create the (undirected) graph.
# from_scipy_sparse_matrix was removed in NetworkX 3.0; from_scipy_sparse_array
# replaces it (available since 2.7).
import networkx as nx
_from_sparse = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
G = _from_sparse(NN_graph)
nx.is_connected(G)  # TODO: Use this as a sanity check!

In [None]:
### Find BFS tree from all sources
# For every possible start node, record the BFS visiting order of G and the
# summed distance between consecutive nodes in that order.

all_paths = []
all_dsums = []
for n in G.nodes():

    # BFS visiting order: nodes of the BFS tree, in the order they were discovered
    path = list(nx.bfs_tree(G, n))

    # Sum of distances between consecutive nodes along that order
    dsum = sum((dists[a, b] for a, b in zip(path[:-1], path[1:])), 0.0)

    # Keep results
    all_paths.append(path)
    all_dsums.append(dsum)

# Select the best solution. If G is disconnected, a BFS only reaches one
# component, and a short partial order would trivially have the smallest sum.
# So prefer orders covering the most nodes, then the smallest summed distance
# (ties go to the first start node, as with argmin).
best_idx = min(range(len(all_paths)),
               key=lambda i: (-len(all_paths[i]), all_dsums[i]))
best_path = all_paths[best_idx]
if len(best_path) < len(TJ):
    print(f"WARNING: neighbor graph is disconnected; only {len(best_path)} "
          f"of {len(TJ)} TEs could be ordered")
sorted_TJ = TJ[best_path]

In [None]:
### Visualize the sorted result

# Pairwise distances with the TEs in path order
sorted_dists = ds.squareform(ds.pdist(sorted_TJ))

fig, ax = plt.subplots(1, 2, figsize=(13,6.5))

# Left: distance matrix in path order (for a well-ordered curve, small
# distances cluster along the diagonal)
ax[0].imshow(sorted_dists)
ax[0].set(title="sorted distances", xlabel='TEs', ylabel='TEs')

# Right: all TEs in the (y, z) plane with the sorted path drawn on top
ax[1].scatter(TJ[:,1], TJ[:,0], s=20, alpha=0.5)
ax[1].plot(sorted_TJ[:,1], sorted_TJ[:,0], c='r', alpha=0.75)
ax[1].set(title="path scatter", xlabel='y', ylabel='z')

plt.tight_layout()
plt.show()

In [None]:
def sortPoints(TJ, n_neighbors=10):
    """Order the points of a roughly curve-shaped point cloud along the curve.

    A k-nearest-neighbour graph is built on the point coordinates, and a
    breadth-first search (BFS) is run from every point. The BFS visiting order
    that covers the most points is returned; ties are broken by the smallest
    summed Euclidean step length between consecutive points, then by the
    lowest start index.

    Parameters
    ----------
    TJ : array-like, shape (n_points, n_dims)
        Point coordinates (in this notebook: the TEs of one TJ as z, y, x).
    n_neighbors : int, optional
        Number of spatial nearest neighbours each point is linked to
        (default 10, clamped to ``n_points - 1``). The outcome has been
        observed to be fairly robust to this value.

    Returns
    -------
    numpy.ndarray, shape (n_ordered, n_dims)
        The points in the selected order. If the neighbour graph is
        disconnected, only the largest component can be ordered. The
        remaining points are dropped and a RuntimeWarning is issued.
    """
    import warnings
    from scipy.sparse import coo_matrix
    from scipy.sparse.csgraph import breadth_first_order
    from scipy.spatial import cKDTree

    TJ = np.asarray(TJ)
    n_points = len(TJ)
    if n_points < 2:
        # Nothing to order (and no neighbours to link)
        return TJ.copy()

    # Spatial kNN graph. Query k+1 because each point also finds itself;
    # the resulting self-loops are harmless for BFS.
    k = max(1, min(int(n_neighbors), n_points - 1))
    _, neighbors = cKDTree(TJ).query(TJ, k=k + 1)
    rows = np.repeat(np.arange(n_points), k + 1)
    graph = coo_matrix((np.ones(rows.size), (rows, neighbors.ravel())),
                       shape=(n_points, n_points)).tocsr()

    best_path, best_key = None, None
    for start in range(n_points):
        # BFS order over the graph treated as undirected (kNN is not symmetric)
        path = breadth_first_order(graph, start, directed=False,
                                   return_predecessors=False)
        steps = np.diff(TJ[path], axis=0)
        length = np.sqrt((steps ** 2).sum(axis=1)).sum()
        # A disconnected graph yields partial orders whose sums are trivially
        # small, so coverage is ranked before length.
        key = (-len(path), length)
        if best_key is None or key < best_key:
            best_path, best_key = path, key

    if len(best_path) < n_points:
        warnings.warn(
            f"neighbour graph is disconnected; only {len(best_path)} of "
            f"{n_points} points could be ordered (try a larger n_neighbors)",
            RuntimeWarning,
        )
    return TJ[best_path]

In [None]:
def zsortPoints(interfaces, TE, n_neighbors=10):
    """Return the points of ``interfaces`` in breadth-first order from point ``TE``.

    A k-nearest-neighbour graph is built on the point coordinates, and a BFS
    is run from the point with index ``TE``.

    Parameters
    ----------
    interfaces : array-like, shape (n_points, n_dims)
        Point coordinates.
    TE : int
        Row index into ``interfaces`` of the BFS start point.
    n_neighbors : int, optional
        Number of spatial nearest neighbours each point is linked to
        (default 10, clamped to ``n_points - 1``).

    Returns
    -------
    numpy.ndarray
        The points reachable from ``TE`` in BFS visiting order, starting
        with ``interfaces[TE]``. Points in other graph components are omitted.
    """
    from scipy.sparse import coo_matrix
    from scipy.sparse.csgraph import breadth_first_order
    from scipy.spatial import cKDTree

    interfaces = np.asarray(interfaces)
    n_points = len(interfaces)
    if n_points < 2:
        # No neighbours to link; indexing still validates TE
        return interfaces[[TE]]

    # Spatial kNN graph (k+1 because each point also finds itself)
    k = max(1, min(int(n_neighbors), n_points - 1))
    _, neighbors = cKDTree(interfaces).query(interfaces, k=k + 1)
    rows = np.repeat(np.arange(n_points), k + 1)
    graph = coo_matrix((np.ones(rows.size), (rows, neighbors.ravel())),
                       shape=(n_points, n_points)).tocsr()

    path = breadth_first_order(graph, TE, directed=False,
                               return_predecessors=False)
    return interfaces[path]

In [None]:
tckDict = {}
fpDict = {}
ierDict = {}
msgDict = {}
splevDict = {}
for TJ_ID in TJs.keys():

    # Sort the TEs of this TJ along the junction line (splprep expects the
    # points in order along the curve)
    TJs[TJ_ID] = sortPoints(TJs[TJ_ID])
    n_points = TJs[TJ_ID].shape[0]

    # splprep requires more points than the spline degree (m > k) and k >= 1,
    # so short TJs get a lower-degree spline and single-TE TJs are skipped
    if n_points < 2:
        print(f"Skipping TJ {TJ_ID}: only {n_points} TE, cannot fit a spline")
        continue
    degree = min(3, n_points - 1)

    # With full_output=1, splprep returns ((tck, u), fp, ier, msg):
    # tck[0] is the (knots, coefficients, degree) tuple and tck[1] is u.
    # The coordinates are passed as (x, y, z) = columns (2, 1, 0).
    tck, fp, ier, msg = scipolate.splprep(x = [TJs[TJ_ID][:,2], TJs[TJ_ID][:,1],
                                            TJs[TJ_ID][:,0]], k = degree, full_output=1)
    # could adjust s to fit the spline more smoothly to the curve

    tckDict[TJ_ID] = tck
    fpDict[TJ_ID] = fp
    ierDict[TJ_ID] = ier
    msgDict[TJ_ID] = msg

    # Evaluate the spline at 100 evenly spaced parameter values in [0, 1];
    # the result is a list of [x, y, z] coordinate arrays
    splevDict[TJ_ID] = scipolate.splev(x = np.linspace(0.0,1.0,100), tck = tckDict[TJ_ID][0])

In [None]:
# Prepare defaultdict
from itertools import combinations
ICs = defaultdict(lambda : [[],[],[]])  # ICs <- "Interface Coordinates"

# Go through OCs. Every pair of distinct labels in an OC's 2x2x2 cube
# (im_pad is offset by one voxel) is an interface the OC contributes to.
# combinations() yields exactly one pair when there are two labels and
# nothing when there is only one.
for OC in OCs:
    # builtin int: np.int was removed in NumPy 1.24
    iz, iy, ix = (int(c) for c in OC)
    selection = im_pad[iz+1:iz+3, iy+1:iy+3, ix+1:ix+3]
    for IC_ID in combinations(sorted(set(selection.flatten())), 2):
        for dim in range(3):
            ICs[IC_ID][dim].append(OC[dim])

# Convert IC lists to (n_points, 3) numpy arrays with columns (z, y, x)
for key in ICs.keys():
    ICs[key] = np.array(ICs[key]).T

In [None]:
# For every z plane, collect the distinct label pairs (interfaces) that
# appear in the 2x2x2 cubes of that plane's outline voxels.
interfaceByZ = []
for z in range(im.shape[0]):
    pairs = set()
    # Only visit this plane's OCs instead of scanning all OCs for every plane
    for OC in OCs[OCs[:, 0] == z]:
        iy, ix = int(OC[1]), int(OC[2])  # builtin int: np.int was removed in NumPy 1.24
        selection = im_pad[z+1:z+3, iy+1:iy+3, ix+1:ix+3]
        # combinations() yields the single pair for two labels, nothing for one
        pairs.update(itertools.combinations(sorted(set(selection.flatten())), 2))
    interfaceByZ.append(list(pairs))

In [None]:
# For every z plane, greedily cluster each TJ's in-plane TEs. Take the first
# unassigned TE, gather all unassigned TEs within `cluster_radius` of it, and
# store their mean position as one (z, y, x) row.
cluster_radius = 15.0
TEpointsByZ = {}
for z in range(im.shape[0]):
    centroids = []
    for TJ_ID in TJs.keys():
        # TEs sit at voxel + 0.5, so flooring z selects this plane's TEs
        TEs_in_plane = TJs[TJ_ID][np.floor(TJs[TJ_ID][:,0])==z]
        if len(TEs_in_plane) == 0:
            continue
        tree = sps.KDTree(TEs_in_plane)  # built once per TJ and plane
        unassigned = np.ones(len(TEs_in_plane), dtype=bool)
        for seed in range(len(TEs_in_plane)):
            if not unassigned[seed]:
                continue
            members = [i for i in tree.query_ball_point(TEs_in_plane[seed], cluster_radius)
                       if unassigned[i]]
            unassigned[members] = False
            cluster = TEs_in_plane[members]
            centroids.append([z, cluster[:, 1].mean(), cluster[:, 2].mean()])
    # reshape keeps planes without TEs as (0, 3), so [:, k] indexing still works
    TEpointsByZ[z] = np.array(centroids, dtype=float).reshape(-1, 3)
        

In [None]:
@interact(z=(0, im.shape[0]-1, 1))
def show_stack(z=im.shape[0]//2):
    """Overlay the clustered TE points of plane ``z`` on the outline image."""

    # Prep and plot image
    plt.figure(figsize=(8,8))
    plt.imshow(outlines_id[z], cmap='gray')

    # Rows are (z, y, x); the reshape keeps planes without TEs indexable
    points = np.asarray(TEpointsByZ[z]).reshape(-1, 3)
    plt.scatter(points[:, 2], points[:, 1], s=30)

    # Finish
    plt.show()

In [None]:
# For every z plane, map each clustered TE point (as a tuple) to the outline
# voxels of that plane lying within a radius of 10.0 of it.
surroundingCords = {}
for z in range(im.shape[0]):
    plane_OCs = OCs[OCs[:, 0] == z]
    if not len(plane_OCs):
        surroundingCords[z] = {}
        continue
    tree = sps.KDTree(plane_OCs)
    surroundingCords[z] = {
        tuple(point): plane_OCs[tree.query_ball_point(point, 10.0)]
        for point in TEpointsByZ[z]
    }

In [None]:
@interact(z=(0, im.shape[0]-1, 1))
def show_stack(z=im.shape[0]//2):
    """Overlay plane ``z``'s TE points and their surrounding outline voxels."""

    # Prep and plot image
    plt.figure(figsize=(8,8))
    plt.imshow(outlines_id[z], cmap='gray')

    # Rows are (z, y, x); the reshape keeps planes without TEs indexable
    points = np.asarray(TEpointsByZ[z]).reshape(-1, 3)
    plt.scatter(points[:, 2], points[:, 1], s=2)

    # Outline voxels near each TE point; each scatter call takes the next colour in the cycle
    for point in points:
        nearby = surroundingCords[z].get(tuple(point))
        if nearby is None:
            continue
        plt.scatter(nearby[:, 2], nearby[:, 1], s=2)

    # Finish
    plt.show()

### TODO

- Brodland approach
    - Fit splines to outlines in the image plane
    - Identify angles in image plane for each TE *[sort of needed in both]*
    - Fit splines to TJs *[needed in both]*
    - Find normal plane to TJ-spline at each TE *[needed in both]*
    - Project image plane angles onto normal plane
    
    
- Better approach?
    - Fit splines to TJs *[needed in both]*
    - Find normal plane to TJ-spline at each TE *[needed in both]*
    - Identify angles in normal plane for each TE *[sort of needed in both]*