In [25]:
%matplotlib qt
import pharynx_io
import image_processing
import numpy as np
import matplotlib.pyplot as plt
import skimage.measure
import seaborn as sns
from skimage import segmentation
import xarray as xr
from scipy.interpolate import UnivariateSpline, interp1d, make_lsq_spline, make_interp_spline
from matplotlib.widgets import Slider, Button, RadioButtons

In [10]:
img_path = "/Users/sean/code/wormAnalysis/data/paired_ratio_movement_data_sean/2017_02_22-HD233_SAY47/2017_02_22-HD233_SAY47.tif"
strain_map_path = "/Users/sean/code/wormAnalysis/data/paired_ratio_movement_data_sean/2017_02_22-HD233_SAY47/indexer.csv"

strains = pharynx_io.load_strain_map(strain_map_path)

raw_imgs = pharynx_io.load_images(img_path, "TL/470_1/410_1/470_2/410_2", strains)

In [12]:
seg_imgs_ = raw_imgs > 2000

fl_imgs = raw_imgs[:,1,:,:]

(rot_fl, rot_seg) = image_processing.center_and_rotate_pharynxes(fl_imgs, seg_imgs_[:,1,:,:])

In [13]:
def measure_profile(fl, mid, xs):
    ys = mid(xs)
    zs = fl[np.int_(xs),np.int_(ys)]
    zs = []
    for x,y in zip(np.int_(xs),np.int_(ys)):
        zs.append(fl[y, x].data)
    return zs

In [14]:
mids = image_processing.calculate_midlines(rot_seg)

In [24]:
wvls = raw_imgs.wavelength
midlines = {}
for wvl in wvls.data:
    midlines[wvl] = image_processing.calculate_midlines()

TL
470_1
410_1
470_2
410_2


In [17]:
for k in mmap.values():
    print(k)

0
1
2


In [63]:
prof_xs = np.linspace(40, 120, 100)
profs = np.asarray([measure_profile(rot_fl[i], mids[i], prof_xs) for i in range(rot_fl.shape[0])])

fig,(ax,ax1) = plt.subplots(2,1)
plt.subplots_adjust(left=0.25, bottom=0.25)

i = 0

axanimal = plt.axes([0.25, 0.1, 0.65, 0.03])

xs = np.linspace(40, 120)
im=ax.matshow(rot_fl[i])
mid,=ax.plot(xs, mids[i](xs), color='r', lw=1)


prof,= ax1.plot(prof_xs, profs[i,:])

def update(val):
    mid.set_ydata(mids[int(val)](xs))
    prof.set_ydata(profs[int(val),:])
    im.set_data(rot_fl[int(val)])
    fig.canvas.draw_idle()

a = np.arange(0,len(strains)-1)
s_animal = Slider(axanimal, 'Animal', min(a), max(a), valinit=min(a), valstep=1)

s_animal.on_changed(update)

0

In [11]:
from scipy.interpolate import make_interp_spline
from matplotlib.lines import Line2D

i=0
fig, ax1 = plt.subplots(figsize=(10,5))

xs = np.linspace(40, 120)
xs_subsampled = np.linspace(40, 120, 5)
ax1.imshow(rot_fl[i])
# ax1.plot(xs, mids[i](xs), label="original spline")
ax1.scatter(xs_subsampled, mids[i](
    xs_subsampled), s=100, zorder=2, edgecolors='black', color='white')

b = make_interp_spline(xs_subsampled, mids[i](xs_subsampled))
ax1.plot(xs, b(xs), label='b-spline', zorder=1, color='orange')

ax1.legend()

<matplotlib.legend.Legend at 0x13738ab00>

In [3]:
def dist(x, y):
    """
    Return the distance between two points.
    """
    d = x - y
    return np.sqrt(np.dot(d, d))


def dist_point_to_segment(p, s0, s1):
    """
    Get the distance of a point to a segment.

      *p*, *s0*, *s1* are *xy* sequences

    This algorithm from
    http://geomalgorithms.com/a02-_lines.html
    """
    p = np.asarray(p, float)
    s0 = np.asarray(s0, float)
    s1 = np.asarray(s1, float)
    v = s1 - s0
    w = p - s0

    c1 = np.dot(w, v)
    if c1 <= 0:
        return dist(p, s0)

    c2 = np.dot(v, v)
    if c2 <= c1:
        return dist(p, s1)

    b = c1 / c2
    pb = s0 + b * v
    return dist(p, pb)

In [176]:
from matplotlib.lines import Line2D
from matplotlib.artist import Artist

class MidlineSplineEditorPlot(object):
    
    show_vertices = True
    epsilon = 5 # max pixel distance to count as a vertex hit
    
    def __init__(self, fig, ax, x, y, spl_resolution=5, k=3, bc_type=None):
        """
        x = vector of x positions of midline
        y = vector of y positions of midline
        """
        self.canvas = fig.canvas
        self.ax = ax
        self.x = x
        self.y = y
        self.k = k
        self.bc_type = bc_type
        self.spl = self.mk_spl(x, y)
        self.spl_resolution = spl_resolution
        self.spl_xs = np.linspace(np.min(x), np.max(x), spl_resolution)
        
        self.pts = Line2D(self.x, self.y, zorder=1, color='white', ls='', marker='o', animated=True)
        self.line = Line2D(self.spl_xs, self.spl(self.spl_xs), color='orange', animated=True, zorder=2)
        
        self._ind = None # active vertex
        
        self.ax.add_artist(self.line)
        self.ax.add_artist(self.pts)
        
        self.canvas.mpl_connect('draw_event', self.draw_callback)
        self.canvas.mpl_connect('button_press_event', self.button_press_callback)
        self.canvas.mpl_connect('button_release_event', self.button_release_callback)
        self.canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)


    def mk_spl(self, x, y):
#         return make_interp_spline(x, y, k=self.k, bc_type=self.bc_type)
        return UnivariateSpline(x, y, s=50, k=self.k)
        
    def draw_callback(self, event):
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.line)
        self.ax.draw_artist(self.pts)
        # do not need to blit here, this will fire before the screen is
        # updated

    def get_ind_under_point(self, event):
        # display coords
        d = np.hypot(self.x - event.xdata, self.y - event.ydata)
        indseq, = np.nonzero(d == d.min())
        ind = indseq[0]
        if d[ind] >= self.epsilon:
            ind = None
        return ind
    
    def button_press_callback(self, event):
        """Called when a mouse button is pressed"""
        if event.inaxes is None:
            return
        if event.button != 1:  # check that it is left mouse
            return
        self._ind = self.get_ind_under_point(event)
        
    def button_release_callback(self, event):
        """Called when a mouse button is released"""
        if event.button != 1:
            return
        self._ind = None
    
    def motion_notify_callback(self, event):
        if self._ind is None:
            return
        if event.inaxes is None:
            return
        mousex, mousey = event.xdata, event.ydata
        
        
        
        oldx = self.x
        oldy = self.y
        try:
            self.x[self._ind] = mousex
            self.y[self._ind] = mousey
            
            self.spl = self.mk_spl(self.x, self.y)
            self.spl_xs = np.linspace(np.min(self.x), np.max(self.x), self.spl_resolution)
            self.line.set_data(self.spl_xs, self.spl(self.spl_xs))
            self.pts.set_data(self.x, self.y)
        except ValueError:
            self.x = oldx
            self.y = oldy
        
        self.canvas.restore_region(self.background)
        self.ax.draw_artist(self.line)
        self.ax.draw_artist(self.pts)
        self.canvas.blit(self.ax.bbox)

In [175]:
xs_subsampled = np.linspace(40, 120, 6)
ys_subsampled = mids[i](xs_subsampled)

fig, ax = plt.subplots()
ax.imshow(rot_fl[i])
m = MidlineEditor(
    fig, ax, xs_subsampled, ys_subsampled, 
    spl_resolution=20, k=4, bc_type=None)
#     (
#         [(4, 0), (2, 0)], 
#         [(4, 0), (2, 0)]))

In [6]:
raw_imgs.where(raw_imgs.wavelength != 'TL', drop=True)

<xarray.DataArray (strain: 123, wavelength: 4, y: 130, x: 174)>
array([[[[220., ..., 224.],
         ...,
         [507., ..., 212.]],

        ...,

        [[242., ..., 229.],
         ...,
         [491., ..., 232.]]],


       ...,


       [[[215., ..., 212.],
         ...,
         [223., ..., 208.]],

        ...,

        [[222., ..., 227.],
         ...,
         [230., ..., 209.]]]])
Coordinates:
  * wavelength  (wavelength) <U5 '470_1' '410_1' '470_2' '410_2'
  * strain      (strain) <U5 'HD233' 'HD233' 'HD233' ... 'SAY47' 'SAY47' 'SAY47'
Dimensions without coordinates: y, x

In [7]:
raw_imgs

<xarray.DataArray (strain: 123, wavelength: 5, y: 130, x: 174)>
array([[[[1033, ...,  985],
         ...,
         [ 673, ...,  958]],

        ...,

        [[ 242, ...,  229],
         ...,
         [ 491, ...,  232]]],


       ...,


       [[[1048, ..., 1014],
         ...,
         [ 992, ...,  992]],

        ...,

        [[ 222, ...,  227],
         ...,
         [ 230, ...,  209]]]], dtype=uint16)
Coordinates:
  * wavelength  (wavelength) <U5 'TL' '470_1' '410_1' '470_2' '410_2'
  * strain      (strain) <U5 'HD233' 'HD233' 'HD233' ... 'SAY47' 'SAY47' 'SAY47'
Dimensions without coordinates: y, x