In [1]:
%load_ext autoreload

from pathlib import Path
import csv
from pprint import pprint
import collections
import itertools
from copy import deepcopy
from multiprocessing import Pool

from tqdm.notebook import tqdm, trange

import numpy as np

from scipy import ndimage as ndi

import networkx as nx
from dataclasses import dataclass

import skimage
import skimage.feature
import skimage.filters
import skimage.morphology
import skimage.segmentation as seg
import skimage.measure

import cv2 as cv

from bokeh.plotting import output_notebook, figure, show
from bokeh.layouts import column, row, layout, gridplot
from bokeh.models import ColumnDataSource, Slider

import matplotlib.pyplot as plt

%matplotlib inline

import fish

output_notebook()

In [2]:
THIS_DIR = Path.cwd()
ROOT_DIR = THIS_DIR.parent
DATA_DIR = ROOT_DIR / 'data'
OUT_DIR = THIS_DIR / "out" / "graphs"

In [3]:
@dataclass(frozen = True)
class Point:
    index: int
    frame: int
    x: float
    y: float
    area: float
    perimeter: float
        
    def distance_to(self, p):
        return np.sqrt(((self.x - p.x) ** 2) + ((self.y - p.y) ** 2))

In [4]:
points = []
with (THIS_DIR / 'out' / 'edges' / "D1-1__lower=25_upper=200_smoothing=3__centroids.csv").open(newline='') as f:
    spamreader = csv.reader(f, delimiter=',')
    for idx, (frame, x, y, area, perimeter) in enumerate(spamreader):
        points.append(Point(idx, int(frame), float(x), float(y), float(area), float(perimeter)))

In [5]:
print(len(points))
pprint(points[:10])

61032
[Point(index=0, frame=0, x=886.6498503634032, y=789.5937010118283, area=1169.5, perimeter=189.37972366809845),
 Point(index=1, frame=0, x=805.7322936972059, y=774.5230669265757, area=256.5, perimeter=83.01219260692596),
 Point(index=2, frame=0, x=404.7261904761905, y=725.6124338624338, area=630.0, perimeter=112.91168737411499),
 Point(index=3, frame=0, x=1004.5797101449275, y=711.4589371980676, area=34.5, perimeter=23.899494767189026),
 Point(index=4, frame=0, x=763.5616883116883, y=721.2513528138528, area=616.0, perimeter=106.56854152679443),
 Point(index=5, frame=0, x=811.6480411046884, y=680.2312138728323, area=259.5, perimeter=75.01219260692596),
 Point(index=6, frame=0, x=833.8497232865047, y=655.3277990634311, area=391.5, perimeter=84.66904675960541),
 Point(index=7, frame=0, x=651.5867597620894, y=635.7879493147142, area=644.5, perimeter=109.9827550649643),
 Point(index=8, frame=0, x=946.7846790890269, y=630.432298136646, area=402.5, perimeter=114.81118214130402),
 Point(i

In [6]:
def make_point_array(points):
    arr = []
    for point in points:
        arr.append([point.frame, point.x, point.y])
    return np.array(arr)

point_array = make_point_array(points)
print(point_array.shape)
print(point_array[:10])

(61032, 3)
[[   0.          886.64985036  789.59370101]
 [   0.          805.7322937   774.52306693]
 [   0.          404.72619048  725.61243386]
 [   0.         1004.57971014  711.4589372 ]
 [   0.          763.56168831  721.25135281]
 [   0.          811.6480411   680.23121387]
 [   0.          833.84972329  655.32779906]
 [   0.          651.58675976  635.78794931]
 [   0.          946.78467909  630.43229814]
 [   0.          339.58005822  623.00946143]]


In [7]:
center = np.mean(point_array, axis = 0)[1:]
print(center)
center_point = Point(-1, -1, x = center[0], y = center[1], area = -1, perimeter = -1)
print(center_point)
dist_from_center = np.linalg.norm(point_array[:, 1:] - center, axis=-1)

[648.35334895 488.02237576]
Point(index=-1, frame=-1, x=648.3533489519203, y=488.02237575537663, area=-1, perimeter=-1)


In [8]:
def show_points(point_array):
    def _(doc):
        p = figure(match_aspect = True)
        p.x_range.range_padding = p.y_range.range_padding = 0

        source = ColumnDataSource(data=dict(
            x = point_array[:, 1],
            y = point_array[:, 2],
        ))
        
        now = ColumnDataSource(data=dict(
            x = point_array[point_array[:, 0] == 0, 1],
            y = point_array[point_array[:, 0] == 0, 2],
        ))

        slider = Slider(start=0, end=last_frame, value=0, step=1, title="Frame", callback_policy='throttle', callback_throttle=100)
        
        def update(attr, old, new):
            now.data = dict(
                x = point_array[point_array[:, 0] == new, 1],
                y = point_array[point_array[:, 0] == new, 2],
            )
            
        slider.on_change('value_throttled', update)

        p.circle(x = 'x', y = 'y', source = source, size=.1)
        p.circle(x = 'x', y = 'y', source = now, color='red')
        
        doc.add_root(column(p, slider))
        
    return show(_)

cutoff = 430
filtered_point_array = point_array[dist_from_center < cutoff]
print(filtered_point_array[filtered_point_array[:, 0] == 1, 0])
show_points(filtered_point_array)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [9]:
def window(seq, n):
    "Returns a sliding window (of width n) over data from the iterable"
    "   s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ...                   "
    it = iter(seq)
    result = tuple(itertools.islice(it, n))
    if len(result) == n:
        yield result
    for elem in it:
        result = result[1:] + (elem,)
        yield result

In [10]:
def group_by_frame(pts):
    groups = collections.defaultdict(list)
    
    for p in pts:
        groups[p.frame].append(p)
        
    return dict(groups.items())

In [11]:
filtered_points = [p for p in points if p.distance_to(center_point) < cutoff]
print(len(points), len(filtered_points))
by_frame = group_by_frame(filtered_points)
last_frame = len(by_frame) - 1

61032 59892


In [12]:
g = nx.DiGraph()

for (curr_index, curr_points), (next_index, next_points) in window(tqdm(list(sorted(by_frame.items(), key = lambda x: x[0]))), 2):
    edges = (
        (a, b, a.distance_to(b))
        for a, b in itertools.product(curr_points, next_points)
    )
    edges = (
        (a, b, d) for a, b, d in edges
        if d < 50
    )
    g.add_weighted_edges_from(edges)

HBox(children=(FloatProgress(value=0.0, max=1716.0), HTML(value='')))




In [13]:
print(len(points))
print(len(filtered_points))
print(len(g.nodes))
print(len(g.edges))

61032
59892
59743
81181


In [14]:
starts = by_frame[0]
ends = by_frame[last_frame]

In [15]:
pprint(starts)
pprint(ends)

[Point(index=0, frame=0, x=886.6498503634032, y=789.5937010118283, area=1169.5, perimeter=189.37972366809845),
 Point(index=1, frame=0, x=805.7322936972059, y=774.5230669265757, area=256.5, perimeter=83.01219260692596),
 Point(index=2, frame=0, x=404.7261904761905, y=725.6124338624338, area=630.0, perimeter=112.91168737411499),
 Point(index=3, frame=0, x=1004.5797101449275, y=711.4589371980676, area=34.5, perimeter=23.899494767189026),
 Point(index=4, frame=0, x=763.5616883116883, y=721.2513528138528, area=616.0, perimeter=106.56854152679443),
 Point(index=5, frame=0, x=811.6480411046884, y=680.2312138728323, area=259.5, perimeter=75.01219260692596),
 Point(index=6, frame=0, x=833.8497232865047, y=655.3277990634311, area=391.5, perimeter=84.66904675960541),
 Point(index=7, frame=0, x=651.5867597620894, y=635.7879493147142, area=644.5, perimeter=109.9827550649643),
 Point(index=8, frame=0, x=946.7846790890269, y=630.432298136646, area=402.5, perimeter=114.81118214130402),
 Point(index=9

In [16]:
def get_path_length(graph, path):
    length = 0
    for a, b in window(path, 2):
        length += graph[a][b]['weight']
    return length

In [17]:
@dataclass
class Path:
    points: list
    length: int

    @property
    def coordinates(self):
        return np.array([[point.x, point.y] for point in self.points])

In [18]:
def find_paths(g, start):
    try:
        return nx.shortest_path(g, source = start, weight = 'weight')
    except Exception as e:
        print(e)


def shortest_paths_between(g, starts, ends):
    shortest_paths = {}
    
#     with Pool() as pool:
#         results = pool.starmap(find_paths, ((g, start) for start in starts))
    for paths in results:
        for end, path in paths.items():
            if end not in ends:
                continue
            shortest_paths[start, end] = Path(points = path, length = get_path_length(g, path))
    return shortest_paths

In [19]:
shortest_paths = shortest_paths_between(g, by_frame[0], by_frame[last_frame])
# shortest_paths = shortest_paths_between(g, by_frame[0], by_frame[100])

NameError: name 'results' is not defined

In [None]:
for (start, end), path in sorted(shortest_paths.items(), key = lambda kv: kv[1].length)[:10]:
    print(start.index, end.index, path.length)

In [None]:
paths = []
for (start, end), path in sorted(shortest_paths.items(), key = lambda kv: kv[1].length):
    paths.append(path.coordinates)
print(paths[0])

In [None]:
def show_paths(point_array, paths):
    def _(doc):
        p = figure(match_aspect = True)
        p.x_range.range_padding = p.y_range.range_padding = 0

        source = ColumnDataSource(data=dict(
            x = point_array[:, 1],
            y = point_array[:, 2],
        ))
        now = ColumnDataSource(data=dict(
            x = point_array[point_array[:, 0] == 0, 1],
            y = point_array[point_array[:, 0] == 0, 2],
        ))
        
        start = ColumnDataSource(data = dict(
            x = [paths[0][0, 0]],
            y = [paths[0][0, 1]],
        ))
        path = ColumnDataSource(data=dict(
            x = paths[0][:, 0],
            y = paths[0][:, 1],
        ))

        path_slider = Slider(start=0, end=len(paths), value=0, step=1, title="Path #", callback_policy='throttle', callback_throttle=100)
        
        def path_update(attr, old, new):
            path.data = dict(
                x = paths[new][:, 0],
                y = paths[new][:, 1],
            )
            start.data = dict(
                x = [paths[new][0, 0]],
                y = [paths[new][0, 1]],
            )
            
        path_slider.on_change('value_throttled', path_update)
        
        frame_slider = Slider(start=0, end=last_frame, value=0, step=1, title="Frame #", callback_policy='throttle', callback_throttle=100)
        
        def frame_update(attr, old, new):
            now.data = dict(
                x = point_array[point_array[:, 0] == new, 1],
                y = point_array[point_array[:, 0] == new, 2],
            )
            
        frame_slider.on_change('value_throttled', frame_update)

        p.circle(x = 'x', y = 'y', source = source, size=.1)
        p.circle(x = 'x', y = 'y', source = now, color='red')
        p.line(x = 'x', y = 'y', source = path, color='green', alpha=0.5)
        p.x(x = 'x', y = 'y', source = start, color='black', size = 20, line_width=5, line_alpha=0.5)
        
        doc.add_root(column(p, row(frame_slider, path_slider)))
        
    return show(_)

show_paths(filtered_point_array, paths)

In [None]:
frames = fish.read(DATA_DIR / "D1-1.hsv")

In [None]:
BLUE = (255, 0, 0)
GREEN = (0, 255, 0)
RED = (0, 0, 255)
YELLOW = (0, 255, 255)

LOOK = 50

LINE_OPTS = dict(
    isClosed = False,
    thickness = 1,
    lineType = cv.LINE_AA,
)

def original_with_paths(frames, paths):
    paths = [p.coordinates.astype(np.int0) for p in paths]
    
    for frame_index, frame in enumerate(frames):
        frame = cv.cvtColor(frame, cv.COLOR_GRAY2BGR)
        
        frame = cv.polylines(
            frame,
            [p[max(frame_index - LOOK, 0) : frame_index] for p in paths],
            color=RED,
            **LINE_OPTS,
        )
        frame = cv.polylines(
            frame,
            [p[frame_index : frame_index + LOOK] for p in paths],
            color=GREEN,
            **LINE_OPTS,
        )
        
        for idx, p in enumerate(paths):
            x, y = p[frame_index]
            frame = cv.putText(
                frame,
                str(idx),
                (x + 15, y),
                fontFace=cv.FONT_HERSHEY_DUPLEX,
                fontScale=1,
                color=YELLOW,
                thickness=1,
                lineType=cv.LINE_AA,
            )
        
        yield frame

In [None]:
# f = frames[100:].copy()
# fish.make_movie(
#     OUT_DIR / "test_one-path.mp4",
#     frames = original_with_paths(f, [paths[0]]),
#     num_frames = len(f),
#     fps = 5,
# )

In [None]:
def find_paths_basic(g, by_frame):
    possible_paths = shortest_paths_between(g, by_frame[0], by_frame[last_frame])
    paths = {}

    while len(possible_paths) > 0:
        (start, end), path = min(possible_paths.items(), key = lambda kv: kv[1].length)
        paths[start, end] = path
        possible_paths = {(s, e): p for (s, e), p in possible_paths.items() if s is not start and e is not end}
        
    return paths

In [None]:
real_paths = find_paths_basic(g, by_frame)
print(len(real_paths))

In [None]:
f = frames[100:].copy()
fish.make_movie(
    OUT_DIR / "test_all-paths.mp4",
    frames = original_with_paths(f, real_paths.values()),
    num_frames = len(f),
    fps = 5,
)

In [None]:
def find_paths_increase_weights_by_flux(g, starts, ends):
    g = deepcopy(g)
    starts = starts.copy()
    ends = ends.copy()
    
    possible_paths = shortest_paths_between(g, starts, ends)
    real_paths = {}

    while len(possible_paths) > 0:
        (start, end), path = min(possible_paths.items(), key = lambda kv: kv[1].length)
        print(start.index, end.index, path.length)
        real_paths[start, end] = path
        possible_paths = {(s, e): p for (s, e), p in possible_paths.items() if s is not start and e is not end}

In [None]:
find_paths_increase_weights_by_flux(g, by_frame[0], by_frame[last_frame])