In [1]:
%load_ext autoreload

from pathlib import Path
import csv
from pprint import pprint
import collections
import itertools

from tqdm import tqdm, trange

import numpy as np

from scipy import ndimage as ndi

import networkx as nx
from dataclasses import dataclass

import skimage
import skimage.feature
import skimage.filters
import skimage.morphology
import skimage.segmentation as seg
import skimage.measure

import cv2 as cv

from bokeh.plotting import output_notebook, figure, show
from bokeh.layouts import column, row, layout, gridplot
from bokeh.models import ColumnDataSource, Slider

import matplotlib.pyplot as plt

%matplotlib inline

import fish

output_notebook()

In [2]:
# Directory layout, resolved relative to the kernel's current working directory:
# the project root is taken to be the parent of the notebook's directory.
THIS_DIR = Path.cwd()
ROOT_DIR = THIS_DIR.parent
DATA_DIR = ROOT_DIR.joinpath('data')

In [3]:
@dataclass(frozen=True)
class Point:
    """One detection read from a row of the centroids CSV.

    Instances are immutable and hashable (frozen dataclass), which lets them be
    used directly as graph nodes.
    """

    index: int  # position of the row in the source CSV
    frame: int  # frame number the detection belongs to
    x: float  # centroid x coordinate
    y: float  # centroid y coordinate
    area: float
    perimeter: float

    def distance_to(self, p):
        """Return the Euclidean distance between this point's (x, y) and ``p``'s (x, y)."""
        dx = self.x - p.x
        dy = self.y - p.y
        return np.sqrt(dx ** 2 + dy ** 2)

In [4]:
# Each CSV row is (frame, x, y, area, perimeter); the row's position in the file
# becomes the Point's index.
with (THIS_DIR / "D1-1__lower=25_upper=200_smoothing=3__centroids.csv").open(newline='') as csv_file:
    rows = csv.reader(csv_file, delimiter=',')
    points = [
        Point(row_number, int(frame), float(x), float(y), float(area), float(perimeter))
        for row_number, (frame, x, y, area, perimeter) in enumerate(rows)
    ]

In [5]:
point_count = len(points)
print(point_count)
# Peek at the first ten parsed detections.
pprint(points[:10])

61069
[Point(index=0, frame=0, x=886.9304782298358, y=789.5497501784439, area=1167.5, perimeter=179.5807341337204),
 Point(index=1, frame=0, x=805.4095607235141, y=774.6976744186046, area=258.0, perimeter=84.42640614509583),
 Point(index=2, frame=0, x=404.2520391517129, y=725.5921696574225, area=613.0, perimeter=110.91168737411499),
 Point(index=3, frame=0, x=1004.6421052631579, y=711.0315789473684, area=47.5, perimeter=29.55634891986847),
 Point(index=4, frame=0, x=763.6103513006168, y=721.1182622687047, area=621.5, perimeter=107.15432798862457),
 Point(index=5, frame=0, x=810.66968053044, y=680.2194092827004, area=276.5, perimeter=71.01219260692596),
 Point(index=6, frame=0, x=833.7481481481481, y=655.5041394335511, area=382.5, perimeter=83.84061968326569),
 Point(index=7, frame=0, x=651.4521426738517, y=635.757505773672, area=649.5, perimeter=110.81118214130402),
 Point(index=8, frame=0, x=946.8632707774799, y=630.305183199285, area=373.0, perimeter=114.56854152679443),
 Point(index

In [23]:
def group_by_frame(points):
    """Bucket points by their ``frame`` attribute.

    Returns a plain dict mapping each frame number to the list of points in that
    frame. Keys appear in order of first occurrence, and each list preserves the
    input order. Frames with no points do not appear as keys.
    """
    frames = {}
    for p in points:
        frames.setdefault(p.frame, []).append(p)
    return frames

In [24]:
by_frame = group_by_frame(points)
# Frames without any detections never appear as keys of by_frame, so
# len(by_frame) - 1 is only the last frame number when frames run contiguously
# from 0. Use the actual largest frame key instead.
last_frame = max(by_frame)

In [25]:
def window(seq, n):
    """Yield a sliding window (of width n) over data from the iterable.

        s -> (s0, s1, ..., s[n-1]), (s1, s2, ..., sn), ...

    Each window is a tuple of ``n`` consecutive items. Nothing is yielded if
    ``seq`` has fewer than ``n`` items.

    Raises:
        ValueError: if ``n`` is less than 1 (raised when iteration starts,
            since this is a generator).
    """
    if n < 1:
        # A width of 0 would otherwise yield () followed by 1-tuples.
        raise ValueError(f"window width must be at least 1, got {n}")
    it = iter(seq)
    result = tuple(itertools.islice(it, n))
    if len(result) == n:
        yield result
    # If seq was shorter than n, islice exhausted `it` and this loop is skipped.
    for elem in it:
        result = result[1:] + (elem,)
        yield result
        
# Sanity check: width-2 windows over 0..4 should print consecutive pairs.
for pair in window(range(5), 2):
    print(*pair)

0 1
1 2
2 3
3 4


In [26]:
# Layered DAG: every detection in a frame is linked to every detection in the next
# frame present in by_frame. Frames absent from by_frame are stepped over. Each
# edge is weighted by the distance between the two points.
g = nx.DiGraph()
# Add all detections up front so that points with no neighbouring frame
# (e.g. a single-frame dataset) are still nodes and valid shortest-path sources.
g.add_nodes_from(itertools.chain.from_iterable(by_frame.values()))

frames_in_order = sorted(by_frame.items(), key=lambda item: item[0])
for (_, curr_points), (_, next_points) in window(tqdm(frames_in_order), 2):
    edges = [(a, b, a.distance_to(b)) for a, b in itertools.product(curr_points, next_points)]
    # Stored under the default 'weight' edge attribute used by the shortest-path cell.
    g.add_weighted_edges_from(edges)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1716/1716 [00:17<00:00, 100.75it/s]


In [27]:
# `print(g)` produced only a blank line in the recorded output, so report the
# graph size explicitly with labels instead of bare numbers.
print(f"DiGraph with {g.number_of_nodes()} nodes and {g.number_of_edges()} edges")


61069
2183132


In [28]:
# Frame numbers need not start at 0 or be contiguous (frames with no detections
# are absent from by_frame), so use the actual smallest and largest frame keys.
starts = by_frame[min(by_frame)]
ends = by_frame[max(by_frame)]

In [29]:
# Show the detections in the first frame, then those in the last frame.
for frame_points in (starts, ends):
    pprint(frame_points)

[Point(index=0, frame=0, x=886.9304782298358, y=789.5497501784439, area=1167.5, perimeter=179.5807341337204),
 Point(index=1, frame=0, x=805.4095607235141, y=774.6976744186046, area=258.0, perimeter=84.42640614509583),
 Point(index=2, frame=0, x=404.2520391517129, y=725.5921696574225, area=613.0, perimeter=110.91168737411499),
 Point(index=3, frame=0, x=1004.6421052631579, y=711.0315789473684, area=47.5, perimeter=29.55634891986847),
 Point(index=4, frame=0, x=763.6103513006168, y=721.1182622687047, area=621.5, perimeter=107.15432798862457),
 Point(index=5, frame=0, x=810.66968053044, y=680.2194092827004, area=276.5, perimeter=71.01219260692596),
 Point(index=6, frame=0, x=833.7481481481481, y=655.5041394335511, area=382.5, perimeter=83.84061968326569),
 Point(index=7, frame=0, x=651.4521426738517, y=635.757505773672, area=649.5, perimeter=110.81118214130402),
 Point(index=8, frame=0, x=946.8632707774799, y=630.305183199285, area=373.0, perimeter=114.56854152679443),
 Point(index=9, fr

In [None]:
# Map (start, reachable_node) -> list of nodes on the minimum-total-weight path,
# for every start detection in the first frame.
shortest_paths = {}
for source in tqdm(starts):
    # NOTE(review): with only `source` given, nx.shortest_path returns a path to
    # every reachable node, so each start stores paths to up to all ~61k nodes.
    # If only start -> end paths are needed, restricting targets to `ends` would
    # cut time and memory drastically. TODO confirm what downstream cells use.
    routes = nx.shortest_path(g, source = source, weight = 'weight')
    shortest_paths.update(((source, target), route) for target, route in routes.items())










  0%|                                                                                                                                                                              | 0/41 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A








  2%|████                                                                                                                                                                  | 1/41 [00:09<06:02,  9.05s/it][A[A[A[A[A[A[A[A[A








  5%|████████                                                                                                                                                              | 2/41 [00:19<06:04,  9.35s/it][A[A[A[A[A[A[A[A[A








  7%|████████████▏                                                                                                                                                         | 3/41 [00:28<05:56,  9.38s/it][A[A[A[A[A[A[A[A[A








 10%|████████████████▏                 