# Create and Visualize an UprightTransform

In [8]:
from collections import namedtuple, defaultdict
from timeit import default_timer as timer

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import collections as mc
from scipy.spatial import voronoi_plot_2d

import allensdk.core.json_utilities as ju
from allensdk.internal.pipeline_modules.cell_types.morphology import upright_transform as ut
from neuron_morphology.transforms.upright_transform import *

# Transforms produced by each method, keyed by method name (the cells below
# use 'perpendicular_to_pia' and 'ray_trace_min_dist'); each list holds one
# UprightTransform per entry of `inputs`, in the same order. defaultdict(list)
# lets each method append without registering its key first.
transforms = defaultdict(list)

# Interactive Qt figure windows; swap with the line below to render inline.
%matplotlib qt
#%matplotlib inline 

In [9]:
# Root of the production data share; every alignment input lives below it.
production = "/Volumes/programs/celltypes/production/"

# One *_input_cell_alignment.json per specimen (three mouse, one human).
input_files = [
    production + relative_path
    for relative_path in (
        "mousecelltypes/prod2164/specimen_922613228/922613228_input_cell_alignment.json",
        "mousecelltypes/prod2413/specimen_962261526/962261526_input_cell_alignment.json",
        "mousecelltypes/prod1556/specimen_825269133/825269133_input_cell_alignment.json",
        "humancelltypes/prod392/specimen_889779159/889779159_input_cell_alignment.json",
    )
]
# Uncomment to restrict the run to a single specimen while debugging:
# input_files = input_files[2:3]

In [10]:
# Parsed contents of one cell-alignment input file (see read_file).
Input = namedtuple('Input', 'soma_coords pia_coords wm_coords soma_center soma_coords_str pia_coords_str wm_coords_str')
def read_file(filename):
    """Load one cell-alignment JSON file and parse its soma, pia and white-matter paths.

    Parameters
    ----------
    filename : str
        Path to a JSON file whose ``'primary'`` entry contains ``'Soma'``,
        ``'Pia'`` and ``'White Matter'`` objects, each with a ``'path'`` value.

    Returns
    -------
    Input
        The three paths converted with ``convert_coords_str``, the soma center
        (mean of the soma path's x values and of its y values), and the three
        raw ``'path'`` values.
    """
    primary = ju.read(filename)['primary']
    soma_coords_str, pia_coords_str, wm_coords_str = (
        primary[region]['path'] for region in ('Soma', 'Pia', 'White Matter')
    )
    soma_coords, pia_coords, wm_coords = (
        convert_coords_str(raw) for raw in (soma_coords_str, pia_coords_str, wm_coords_str)
    )

    # Plain vertex average of the soma outline (not an area-weighted centroid).
    soma_center = np.asarray([soma_coords['x'].mean(), soma_coords['y'].mean()])

    return Input(
        soma_coords=soma_coords,
        pia_coords=pia_coords,
        wm_coords=wm_coords,
        soma_center=soma_center,
        soma_coords_str=soma_coords_str,
        pia_coords_str=pia_coords_str,
        wm_coords_str=wm_coords_str,
    )

In [11]:
def _add_ax(plot_func, *args, ax=None, **kwargs):
    if ax is None:
        fig = plt.figure()
        ax = fig.gca()
    return plot_func(*args, ax=ax, **kwargs)

In [12]:
def plot_pia_wm(data, ax=None):
    """Draw the pia boundary and then the white-matter boundary of *data*.

    Parameters
    ----------
    data : object
        Anything with ``pia_coords`` and ``wm_coords`` attributes, each a
        mapping with ``'x'`` and ``'y'`` sequences.
    ax : matplotlib Axes, optional
        Axes to draw into; a falsy value creates a new figure.

    Returns
    -------
    The Axes that was drawn on.
    """
    ax = ax or plt.subplots()[1]
    for boundary in (data.pia_coords, data.wm_coords):
        ax.plot(boundary['x'], boundary['y'])
    return ax

In [13]:
def plot_inputs(input_data, ax=None):
    """Plot the raw soma outline, soma center, pia and white matter of one input.

    Parameters
    ----------
    input_data : Input
        Record with ``soma_coords``, ``soma_center``, ``pia_coords`` and
        ``wm_coords`` (e.g. from ``read_file``).
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted.

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    soma_coords = input_data.soma_coords
    soma_center = input_data.soma_center

    ax.plot(soma_coords['x'], soma_coords['y'])
    # A one-point Line2D with the default line style draws nothing, so the
    # soma center needs an explicit marker to be visible.
    ax.plot(soma_center[0], soma_center[1], 'o')
    plot_pia_wm(input_data, ax)
    # Labels follow plotting order: soma, soma center, pia, wm.
    ax.legend(['soma', 'soma_center', 'pia', 'wm'])
    return ax

### Grab all inputs and plot

In [14]:
# Number of cells; every figure below uses one subplot column per cell.
n = len(input_files)
fig, axes = plt.subplots(1, n, squeeze=False)

# Parse each input file (kept in `inputs`, same order as `input_files`) and
# plot its raw soma / pia / white matter in its own column.
inputs = []
for i, input_file in enumerate(input_files):
    input_data = read_file(input_file)
    inputs.append(input_data)
    column_ax = axes[0, i]
    plot_inputs(input_data, column_ax)
    column_ax.axis('scaled')

### Smoothing

In [15]:
def gauss_smooth(coords, n_interp=100, sigma=10):
    """Resample a 2-D polyline and Gaussian-smooth each coordinate.

    The polyline is parameterised by vertex index (not arc length), linearly
    resampled to *n_interp* points, and each of x and y is then filtered
    with ``scipy.ndimage.gaussian_filter1d``.

    Parameters
    ----------
    coords : mapping
        ``{'x': x_values, 'y': y_values}`` of equal length.
    n_interp : int, default 100
        Number of resampled points.
    sigma : float, default 10
        Gaussian standard deviation, in resampled-point units.

    Returns
    -------
    dict
        ``{'x': smoothed_x, 'y': smoothed_y}``, each of length *n_interp*.
    """
    from scipy.ndimage import gaussian_filter1d

    source_t = np.linspace(0, 1, len(coords['x']))
    target_t = np.linspace(0, 1, n_interp)

    return {
        axis: gaussian_filter1d(np.interp(target_t, source_t, coords[axis]), sigma)
        for axis in ('x', 'y')
    }

In [16]:
def smooth(coords, method, **kwargs):
    """Smooth a polyline with the named smoothing method.

    Parameters
    ----------
    coords : mapping
        ``{'x': x_values, 'y': y_values}``.
    method : str
        Name of the smoother; currently only ``'gauss'`` (``gauss_smooth``).
    **kwargs
        Forwarded to the smoother, e.g. ``n_interp`` or ``sigma`` for
        ``'gauss'``. With no kwargs the smoother's defaults are used.

    Returns
    -------
    dict
        Smoothed coordinates as returned by the chosen smoother.

    Raises
    ------
    KeyError
        If *method* is not a known smoother.
    """
    smoothers = {'gauss': gauss_smooth}
    try:
        smoother = smoothers[method]
    except KeyError:
        raise KeyError(
            f"unknown smoothing method {method!r}; expected one of {sorted(smoothers)}"
        ) from None
    return smoother(coords, **kwargs)

In [17]:
def plot_smoothed_coords(input_data, smooth_input, ax=None):
    """Overlay smoothed pia/white-matter boundaries on the raw ones.

    Parameters
    ----------
    input_data : object
        Raw boundaries (``pia_coords`` / ``wm_coords``), e.g. an ``Input``.
    smooth_input : object
        Smoothed boundaries with the same fields, e.g. a ``SmoothInputs``.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted, matching
        the other plotting helpers.

    Returns
    -------
    The Axes drawn on. Line order: raw pia, raw wm, smoothed pia, smoothed wm.
    """
    if ax is None:
        _, ax = plt.subplots()

    plot_pia_wm(input_data, ax)    # raw boundaries
    plot_pia_wm(smooth_input, ax)  # smoothed boundaries

    return ax

In [18]:
fig, axes = plt.subplots(1, n, squeeze=False)

# Gaussian-smoothed pia / white-matter boundaries, index-aligned with `inputs`.
SmoothInputs = namedtuple('SmoothInputs', 'pia_coords, wm_coords')
smooth_inputs = []
for i, input_data in enumerate(inputs):
    smooth_input = SmoothInputs(
        pia_coords=smooth(input_data.pia_coords, 'gauss'),
        wm_coords=smooth(input_data.wm_coords, 'gauss'),
    )
    smooth_inputs.append(smooth_input)
    column_ax = axes[0, i]
    plot_smoothed_coords(input_data, smooth_input, column_ax)
    column_ax.axis('scaled')


### Perpendicular Pia
Plot lines perpendicular to the pia, then the projection of the soma center onto the pia.

In [19]:
def perpendicular_line(xlist, ylist, scale=1000):
    """Return a segment perpendicular to the first segment of a polyline.

    The returned segment starts at the midpoint of the segment
    (xlist[0], ylist[0]) -> (xlist[1], ylist[1]) and extends *scale* units
    along (-dy, dx), i.e. the segment direction rotated by +90 degrees.
    Only the first two points are used. A zero-length first segment
    divides by zero.

    Returns
    -------
    (x_seg, y_seg)
        Two 2-element lists: the x and y coordinates of the segment ends.
    """
    x0, x1 = xlist[0], xlist[1]
    y0, y1 = ylist[0], ylist[1]
    dx = x1 - x0
    dy = y1 - y0
    mid_x = x0 + dx / 2
    mid_y = y0 + dy / 2

    # Normalise the (-dy, dx) normal to unit length, then stretch it to *scale*.
    factor = scale / np.sqrt(dx ** 2 + dy ** 2)
    return [mid_x, mid_x - dy * factor], [mid_y, mid_y + dx * factor]

In [20]:
def plot_perp_pia(data, soma_center, ax=None, plot_every=3):
    """Plot pia/white matter, lines perpendicular to the pia, and the soma center.

    Parameters
    ----------
    data : object
        Has ``pia_coords`` / ``wm_coords`` mappings with ``'x'`` and ``'y'``
        sequences (e.g. a ``SmoothInputs`` record).
    soma_center : sequence of two floats
        (x, y) of the soma center, drawn as a marker for reference.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted.
    plot_every : int, default 3
        Draw a perpendicular for every ``plot_every``-th pia segment.

    Returns
    -------
    The Axes drawn on.

    Raises
    ------
    ValueError
        If *plot_every* is less than 1.
    """
    if plot_every < 1:
        raise ValueError(f"plot_every must be >= 1, got {plot_every}")
    if ax is None:
        _, ax = plt.subplots()

    plot_pia_wm(data, ax)

    pia_x = data.pia_coords['x']
    pia_y = data.pia_coords['y']
    # Segment i joins vertex i and vertex i + 1, so there are len - 1 segments.
    for i in range(0, len(pia_x) - 1, plot_every):
        x_seg, y_seg = perpendicular_line(pia_x[i:i + 2], pia_y[i:i + 2])
        ax.plot(x_seg, y_seg, 'k')

    # soma_center was previously accepted but never used; show it so the
    # perpendiculars can be compared against the soma position.
    ax.plot(soma_center[0], soma_center[1], 'o')

    return ax

In [21]:
fig, axes = plt.subplots(1, n, squeeze=False)

# One column per cell: smoothed boundaries with perpendiculars to the pia.
for column_ax, smoothed, input_data in zip(axes[0], smooth_inputs, inputs):
    plot_perp_pia(smoothed, input_data.soma_center, column_ax)
    column_ax.axis('scaled')


In [22]:
fig, axes = plt.subplots(1, n,squeeze=False)

# For each cell: project the soma center onto the smoothed pia, draw the
# soma -> projection segment, and store an UprightTransform under
# 'perpendicular_to_pia' (one per input, in `inputs` order).
for i, input_data in enumerate(inputs):
    # project_to_polyline comes from the star import; presumably returns the
    # closest point on the pia polyline to the soma center — TODO confirm.
    proj = project_to_polyline(smooth_inputs[i].pia_coords, input_data.soma_center)
    plot_pia_wm(smooth_inputs[i], axes[0, i])
    axes[0, i].plot([proj[0], input_data.soma_center[0]], [proj[1], input_data.soma_center[1]])
    axes[0, i].axis('scaled')
    
    # Angle (from +x) of the vector pointing from the soma center to its projection.
    theta = np.arctan2((proj[1] - input_data.soma_center[1]), (proj[0] - input_data.soma_center[0]))
    # Rotating by pi/2 - angle sends that direction to pi/2 (+y), assuming
    # rotation_from_angle builds a counter-clockwise rotation about z — TODO confirm.
    theta = np.pi / 2 - theta
    rot = rotation_from_angle(theta, axis=2)
    # 3-D translation (z = 0) that moves the soma center to the origin.
    translation = np.append(-input_data.soma_center, 0)
    # translate_first=True presumably applies the translation before the
    # rotation, so the rotation pivots about the soma — TODO confirm.
    affine = affine_from_transform_translation(translation=translation,
                                               transform=rot,
                                               translate_first=True)
    transforms['perpendicular_to_pia'].append(UprightTransform(affine))

### Ray trace minimum distance

In [23]:
fig, axes = plt.subplots(1, n,squeeze=False)

# For each cell, get a pia point (px, py) and a white-matter point (wx, wy)
# from ut.calculate_shortest — presumably the endpoints of the shortest
# pia-to-white-matter line through the soma center; TODO confirm. Unlike the
# perpendicular method, this uses the raw (unsmoothed) path strings.
for i, input_data in enumerate(inputs):
    start = timer()
    px, py, wx, wy= ut.calculate_shortest(input_data.soma_center[0], input_data.soma_center[1],
                                          input_data.pia_coords_str, input_data.wm_coords_str)
    end = timer()
    # Elapsed seconds spent in calculate_shortest for this cell.
    print(end - start)
    # Rotate so the white-matter -> pia direction points along +y (assumes
    # rotation_from_angle is a counter-clockwise rotation about z).
    theta = np.pi /2 - np.arctan2((py - wy), (px - wx))
    rot = rotation_from_angle(theta, axis=2)
    # Move the soma center to the origin (z = 0) before rotating.
    translation = np.append(-input_data.soma_center, 0)
    affine = affine_from_transform_translation(translation=translation,
                                               transform=rot,
                                               translate_first=True)
    plot_pia_wm(input_data, axes[0, i])
    axes[0, i].plot([px, wx],[py, wy])
    
    # One transform per input, in `inputs` order.
    transforms['ray_trace_min_dist'].append(UprightTransform(affine))
    

35.96802252402995
16.31279934104532
47.13320497702807
98.00265596504323


### Voronoi
The Voronoi diagram creates lines of equal distance between points. Ignoring the edges that fall outside the region between the boundaries leaves just the centerline between the pia and the white matter. Projecting the soma onto this midline finds the minimum-distance path between the white matter and the pia that passes through the soma.

In [24]:
def plot_voronoi(data, input_data, ax=None):
    """Build an UprightTransform via the Voronoi midline and plot the construction.

    Parameters
    ----------
    data : object
        Has ``pia_coords`` / ``wm_coords`` mappings with ``'x'`` and ``'y'``
        sequences (e.g. a ``SmoothInputs`` record).
    input_data : Input
        Supplies ``soma_coords`` and ``soma_center``.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted.

    Returns
    -------
    (T, ax)
        The transform returned by ``UprightTransform.from_coords`` and the
        Axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    soma_coords = input_data.soma_coords
    soma = input_data.soma_center
    pia_coords = data.pia_coords
    wm_coords = data.wm_coords

    T, v_diagram, mid_line, min_proj, theta = UprightTransform.from_coords(
        soma_coords, pia_coords, wm_coords, n_interp=3)
    # To overlay the full diagram: voronoi_plot_2d(v_diagram, ax=ax)

    # Label each artist explicitly. A bare list passed to legend() is matched
    # to artists by matplotlib's internal handle order, which need not follow
    # the order in which lines, collections and patches were added.
    ax.plot(soma_coords['x'], soma_coords['y'], label='soma')
    # A one-point line needs a marker to be visible.
    ax.plot(soma[0], soma[1], 'o', label='soma_center')
    ax.plot(pia_coords['x'], pia_coords['y'], label='pia')
    ax.plot(wm_coords['x'], wm_coords['y'], label='wm')
    ax.add_collection(mc.LineCollection(mid_line, label='midline'))
    ax.arrow(soma[0], soma[1], min_proj[0] - soma[0], min_proj[1] - soma[1],
             label='soma_projection')
    ax.legend()

    return T, ax

### Using the Transform
`transform_coords` is a helper function that builds an N×3 array of points (with z set to zero) from a dict of coordinates (`{'x': [x_values], 'y': [y_values]}`), passes it to the transform `T` as a 3×N array, and returns the transformed points as an N×3 array.

In [25]:
def transform_coords(coords, T):
    """Apply transform *T* to 2-D coordinates lifted into 3-D with z = 0.

    Parameters
    ----------
    coords : mapping
        ``{'x': x_values, 'y': y_values}``; each value may be any array-like
        (list, tuple or numpy array) with the same number of points.
    T : object with a ``transform`` method
        Called with a 3 x N float array (one point per column), e.g. an
        UprightTransform; expected to return an array of the same layout.

    Returns
    -------
    numpy.ndarray
        N x 3 array of transformed points (one point per row).

    Raises
    ------
    ValueError
        If the x and y sequences have different lengths.
    """
    # np.asarray lets plain lists/tuples through (the previous version needed
    # numpy arrays for .shape/.reshape); float matches the old hstack result.
    x = np.asarray(coords['x'], dtype=float).reshape(-1)
    y = np.asarray(coords['y'], dtype=float).reshape(-1)
    points = np.column_stack((x, y, np.zeros_like(x)))
    # T works on column vectors: hand it 3 x N and transpose back to N x 3.
    return T.transform(points.T).T

In [26]:
def plot_transformed_input(input_data, transform, ax=None):
    """Plot soma, pia and white matter of one input after applying *transform*.

    Parameters
    ----------
    input_data : Input
        Record with ``soma_coords``, ``pia_coords`` and ``wm_coords``.
    transform : object with a ``transform`` method
        e.g. an UprightTransform; applied via ``transform_coords``.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted (the old
        code accepted None but then crashed calling ``None.plot``).

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Order matters: it must match the legend labels below.
    for coords in (input_data.soma_coords, input_data.pia_coords, input_data.wm_coords):
        transformed = transform_coords(coords, transform)
        ax.plot(transformed[:, 0], transformed[:, 1])
    ax.legend(['soma', 'pia', 'wm'])

    return ax

In [27]:
def print_transform(T):
    """Print the rotation angle and the ``to_dict()`` contents of a transform.

    Parameters
    ----------
    T : object with ``affine`` and ``to_dict()``
        e.g. an UprightTransform. ``T.affine[0, 0]`` is taken to be
        cos(theta) of a rotation about z (assumption — confirm for non-pure
        rotations).

    Notes
    -----
    arccos recovers only the angle's magnitude (0 to pi); the rotation
    direction is not reported.
    """
    # Read the angle from T itself, not from any notebook-level variable.
    # Clipping guards against round-off pushing the cosine outside [-1, 1],
    # which would make arccos return NaN.
    theta = np.arccos(np.clip(T.affine[0, 0], -1.0, 1.0))

    print('Theta (rad): ', theta)
    print('Theta (deg): ', np.degrees(theta))

    for key, value in T.to_dict().items():
        print(key, value)

In [28]:
# Grid of results: one row per transform method, one column per cell. Each
# panel shows the input after applying that method's transform, titled with
# the method name and the rotation angle in degrees.
fig, axes = plt.subplots(len(transforms), n, squeeze=False)

for i, (transform_name, transform_list) in enumerate(transforms.items()):
    # transform_list[j] belongs to inputs[j]: each method cell appended one
    # transform per input, in `inputs` order.
    for j, input_data in enumerate(inputs):
        plot_transformed_input(input_data, transform_list[j], axes[i, j])
        # NOTE(review): arccos of affine[0, 0] yields only the angle's magnitude
        # (0-180 degrees), dropping the rotation direction, and gives NaN if
        # round-off pushes affine[0, 0] slightly outside [-1, 1].
        theta = np.degrees(np.arccos(transform_list[j].affine[0, 0]))
        axes[i, j].set_title(transform_name + f' {theta:.2f}')
        axes[i, j].axis('scaled')

