In [1]:
import sys
sys.path.append('../')

In [2]:
import zarr
import numpy as np
import networkx as nx
from skimage.measure import regionprops
#from segmentation import segment_stats
from linajea_cost_test import get_merge_graph_from_array
from linajea_cost_test import get_merge_graph
from funlib.math import decode64, encode64
import networkx as nx
from candidate_graph import iterate_tree

In [3]:
z = zarr.open('anno_alice_T2030_tiny_2.zarr','r')
gt_image = z['gt_trackimage'][:]
raw = z['Raw'][:]
fragments = z['Fragments'][:]

In [4]:
# annotation label is 10000, 10001, 10004, 10005, 10007, 10010, 10013, 10016, 10017 pick them out
# annotaion is made by merge frgments by napari tool
#gt_mask = gt_image >=10000
#gt_image = gt_mask * gt_image

In [5]:
# create annotation graph by gt_trackimage

def create_annotation_graph(gt_image):
    '''
    gt_image: ndarray image by (t,z,y,x) each cell has an annotated label. for exmaple 10000, 10001, .... 

    return
    anno_graph: nx.Digrph() which create a graph with edges.
                these edges connected node are generated by encode cell.
                e.g label n will be encode to u at time 0
                                    encode to v at time 1
                then the edge will be (v, u, source = v, target = u)
    
    '''
    time = gt_image.shape[0]
    anno_graph = nx.DiGraph()
    for label in np.unique(gt_image):
        if label == 0:
            continue   
        for t in range(time):
            mask = gt_image[t]==label
            regions = regionprops(mask.astype(int))
            for props in regions:
                z0, y0, x0 = props.centroid
                u = encode64((t, int(z0), int(y0), int(x0), int(props.area)),bits=[9,12,12,12,19])
            mask = gt_image[t]==label
            regions = regionprops(mask.astype(int))
            for props in regions:
                z0, y0, x0 = props.centroid
                v = encode64((t, int(z0), int(y0), int(x0), int(props.area)),bits=[9,12,12,12,19])
            anno_graph.add_edge(v, u, source = v, target = u)
    return anno_graph


In [6]:
# get the id list of one time frame
from segmentation import segment_stats
id, positions, _ = segment_stats(gt_image[0],0)

t = 0
merge_tree = z['Merge_tree/Merge/'+str(t)]
scores = z['Merge_tree/Scoring/'+str(t)]
merge_tree = get_merge_graph_from_array(merge_tree,scores)

# count the number of how many annotation id in merge tree at time t
def count_anno_id(id,merge_tree):
    '''
    merge_tree: nx.Digrph merge_tree create by waterz 
    id: ndarray (n,) 1d array store all ids in GT image
    correct_id_num: int count the number of how many annotation node(id) in merge tree
    '''
    mt_id = list(merge_tree.nodes)
    correct_id_num = len(set(id)&set(mt_id))
    return correct_id_num 

correct_id_num = count_anno_id(id,merge_tree)


# get the sovler output

In [7]:
# read .gexf file which created by run_candidate_graph.ipynb
graph = nx.read_graphml('test_solver.graphml')
selected_key = 'selected'

'''
NOTE:
the Digraph nodes will convert to string after nx.read_graphml
the can didate_graph.node should be 
{19140315621363712: {'t': 0, 'z': 8, 'y': 12, 'x': 2, 'score': 0.04924517869949341, 'parent': 53339542484161024, 'id': 19140315621363712}
the graph will convert 19140315621363712 to '19140315621363712'
{'19140315621363712': {'t': 0, 'z': 8, 'y': 12, 'x': 2, 'score': 0.04924517869949341, 'parent': 53339542484161024, 'id': 19140315621363712}
'''
print(graph.edges(data=True))

[('65231937549108737', '950218593931776', {'source': 65231937549108737, 'target': 950218593931776, 'overlap': 0.008579088471849867, 'selected': False}), ('65231937549108737', '31666141059356672', {'source': 65231937549108737, 'target': 31666141059356672, 'overlap': 0.005843681519357195, 'selected': False}), ('65231937549108737', '7599936094863872', {'source': 65231937549108737, 'target': 7599936094863872, 'overlap': 0.0058309037900874635, 'selected': False}), ('65231937549108737', '21321901339050496', {'source': 65231937549108737, 'target': 21321901339050496, 'overlap': 0.004901960784313725, 'selected': False}), ('65231937549108737', '66181984340151808', {'source': 65231937549108737, 'target': 66181984340151808, 'overlap': 0.0032232070910556, 'selected': False}), ('65231937549108737', '633344526123008', {'source': 65231937549108737, 'target': 633344526123008, 'overlap': 0.0016051364365971107, 'selected': False}), ('65231937549108737', '6368414345986048', {'source': 65231937549108737, '

In [8]:
def get_select_graph(graph,key='selected'):
    '''
    get a sub-graph form graph

    graph: nx.Digraph() candidate graph with some edges has attribute 'select' = True

    return
    select_graph: nx.Digraph()  sub-graph only have edges from candidate graph with 'select' is True
  
    '''
    select_graph = nx.DiGraph()
    for u,v,data in graph.edges(data=True):
        if data[key]:
            select_graph.add_edge(v, u, source = v, target = u)
            select_graph.add_node(u, **graph.nodes[u])
            select_graph.add_node(v, **graph.nodes[v])
    return select_graph

select_graph = get_select_graph(graph, key=selected_key)


def pick_select_edge(graph):
    '''
    graph: nx.Digraph() candidate graph with edge has attribute 'select' = True

    return
    edges: list of all edges in graph
    select_edges: list of edges with attribute 'select' == True
 
    '''
    edges = []
    selected_edges = []
    for u,v,data in graph.edges(data=True):
        edges.append((u,v))
        if data['selected']:
            selected_edges.append((u,v))
    return edges, selected_edges




def find_devison_node(select_graph):
    ''' 
    find the devision node
    select_graph: nx.Digraph()  sub-graph only have edges from graph with 'select' is True

    return
    devision_nodes: list of nodes wihch are split
    '''
    devision_nodes = []
    for node in select_graph.nodes:
        if select_graph.out_degree(node) == 2:
            devision_nodes.append(node)
    return devision_nodes
    

In [9]:
def select_nodes_t(select_graph,t):
    '''
    pick nodes at time t in select_graph

    select_graph: nx.Digraph()
    t: int 

    return
    nodes_t: a string list wich contains the nodes at time t 

    '''
    nodes_t = []
    for node in select_graph.nodes:
        print(node)
        cor = decode64(int(node),dims = 5,bits=[9,12,12,12,19])
        if cor[0]== t:
            nodes_t.append(node)
    return nodes_t
# pick out the start node

node_start = select_nodes_t(select_graph, 0)
print('the ids of nodes in t0 are:', node_start)

7599936094863872
65020831316575745
633344526123008
211192140268033
2603798155493376
3377854341447681
1020424099994112
23045961301361665
5453646393247232
6298243138849281
53234143984617984
13264542656102401
68996673946256896
107664238748635137
9148168692303872
50947151206812673
4538973003186688
21568054458586625
13722119915438592
2709428579075585
1161256111177728
1548275593248769
1477820981182464
25016311941698049
29132694457684992
1407452192970753
9112975713503232
1196380341148161
25860668156548096
16466303365678081
25332893939143680
24242238542321153
1372199103496704
879617892155393
950218593931776
1759287323919361
23433006572111362
44473166604534786
23432946482418689
18190406319605250
422246873038850
26142014240199170
6157436920529410
122969440607276546
27373656246064642
40180698964824578
48836088885676546
633550625833474
4749984723376130
457396870715906
422298364412418
70368744177666
the ids of nodes in t0 are: ['7599936094863872', '633344526123008', '2603798155493376', '10204240999

In [10]:
def find_path(graph, start, end):
    '''
    find a path between star and end in graph

    graph: nx.Digrph
    start: string  the start point (node id)
    end: string  the end point (node id)

    '''
    queue = [(start, [start])]
    while queue:
        (node, path) = queue.pop(0)
        for neighbor in graph.successors(node):
            if neighbor == end:
                yield path + [end]
            else:
                queue.append((neighbor, path + [neighbor]))
                
def find_all_paths(graph, start):
    '''
    find all path with a start point (node id)

    graph: nx.Digrph
    start: string  the start point  
    '''
    paths = []
    for node in graph.nodes():
        paths.extend(list(find_path(graph, start, node)))
    return [path for i, path in enumerate(paths) if not any(path == subpath[:len(path)] for subpath in paths[:i] + paths[i+1:])]

def find_repeated_nodes(paths):
    '''
    old one. not use any more
    path =[['19316374903985664', '28569847583478785', '1055642833915906'], ['19316374903985664', '28569847583478785', '140917883278850']]
    find_repeated_nodes(path) gives
    {'19316374903985664', '28569847583478785'}

    
    once you have all path from a node e.g.'19316374903985664'
    this function give a common nodes from two different paths.

    '''
    
    node_counts = {}
    for path in paths:
        for node in path:
            node_counts[node] = node_counts.get(node, 0) + 1
    repeated_nodes = set([node for node, count in node_counts.items() if count > 1])
    return repeated_nodes

def pick_nodes_once(lst,select_graph):
    '''

    lst : a list of paths of one node (from find_all_path)
        e.g node 1 has paths [[1,2,4],[1,2,3]]
    select_graph : nx.Digraph

    return:
    create lists. onece the node split it will creat a new list

    e.g for path =[['19316374903985664', '28569847583478785', '1055642833915906'], ['19316374903985664', '28569847583478785', '140917883278850']]
    pick_nodes_once(path,select_graph)
    it gives 
    [['19316374903985664', '28569847583478785'],
    ['1055642833915906'],
    ['140917883278850']]
    '''
    output = []
    nodes = set()
    for item in lst:
        new_item = []
        for node in item:
            if node not in nodes:
                nodes.add(node)
                new_item.append(node)
                if select_graph.out_degree(node) == 2:
                    output.append(new_item)
                    new_item = []                     
        if new_item:
            output.append(new_item)
    return output

In [11]:
# create a labels list that prevent napari error  (napari can only read labels under int32)
labels = []
for node in select_graph.nodes:
    labels.append(node)

In [12]:
track_data = []

l=0
num_node = 0
cell_mask = np.zeros(fragments.shape)
track_graph = {}
#track_graph = nx.DiGraph()# graph for arboretum
for start in node_start:    
    paths = find_all_paths(select_graph, start)
    pick = pick_nodes_once(paths,select_graph)

    for path in pick:
        label = labels.index(path[0]) # choose first node id as label
        for node in path:
            cor = decode64(int(node),dims=5,bits=[9,12,12,12,19])
            track_data.append([label,cor[0],cor[1],cor[2],cor[3]])
            if select_graph.out_degree(node) == 2:
                for neighbor in select_graph.out_edges(node):
                        track_graph[labels.index(neighbor[1])] = [label]
            t = cor[0]
            
            ids = z['Fragment_stats/id/'+str(t)]
            merge_tree = z['Merge_tree/Merge/'+str(t)]
            scores = z['Merge_tree/Scoring/'+str(t)]
            merge_tree = get_merge_graph_from_array(merge_tree,scores)

            if merge_tree.has_node(int(node)):
                sub_tree = iterate_tree(merge_tree,int(node))
                for a in sub_tree:
                    if a in list(ids):
                        # a is the leave node id
                        # get the label of this leave node
                        label_a = list(ids).index(a)+1
                        cell_mask[t][fragments[t] == label_a] = label

        

    if len(paths)>1:
        l=l+1
    #print('******************')
        print('the paths are',paths)
        #print('******************')
        for path in pick:
            print('sub-set of paths:',path)
        print('******************')



In [13]:
track_data

[[0, 0, 1, 26, 13],
 [0, 1, 3, 13, 13],
 [0, 2, 12, 1, 11],
 [2, 0, 0, 28, 3],
 [2, 1, 1, 4, 10],
 [2, 2, 13, 16, 0],
 [4, 0, 0, 1, 18],
 [4, 1, 0, 1, 18],
 [4, 2, 1, 3, 20],
 [6, 0, 13, 0, 9],
 [6, 1, 2, 7, 23],
 [6, 2, 3, 6, 25],
 [8, 0, 3, 0, 8],
 [8, 1, 5, 8, 28],
 [8, 2, 5, 18, 21],
 [10, 0, 5, 17, 22],
 [10, 1, 0, 9, 4],
 [10, 2, 7, 13, 7],
 [12, 0, 3, 10, 14],
 [12, 1, 8, 13, 7],
 [12, 2, 2, 12, 14],
 [14, 0, 4, 10, 27],
 [14, 1, 6, 16, 21],
 [14, 2, 1, 0, 10],
 [16, 0, 1, 12, 22],
 [16, 1, 11, 4, 4],
 [18, 0, 1, 25, 25],
 [18, 1, 3, 0, 27],
 [18, 2, 0, 23, 4],
 [20, 0, 0, 16, 20],
 [20, 1, 0, 6, 19],
 [22, 0, 0, 21, 9],
 [22, 1, 9, 23, 26],
 [22, 2, 9, 5, 17],
 [24, 0, 10, 4, 4],
 [24, 1, 12, 0, 9],
 [26, 0, 4, 2, 26],
 [26, 1, 13, 10, 13],
 [26, 2, 11, 6, 25],
 [28, 0, 10, 25, 18],
 [28, 1, 2, 23, 2],
 [28, 2, 1, 0, 27],
 [30, 0, 10, 3, 17],
 [30, 1, 11, 7, 24],
 [32, 0, 1, 1, 1],
 [32, 1, 0, 0, 1],
 [32, 2, 0, 0, 0],
 [34, 0, 9, 14, 28],
 [34, 1, 2, 0, 8],
 [34, 2, 9, 4, 3]]

In [14]:
track_graph

{}

In [18]:
import napari
viewer = napari.Viewer()
viewer.add_image(raw[3], name='Raw')
viewer.add_labels(gt_image, name='GT')
viewer.add_labels(fragments)
tracks = napari.layers.Tracks(track_data,   
                              graph=track_graph, 
                              name="track_test")
viewer.add_labels(cell_mask.astype(int))
viewer.add_layer(tracks)
viewer.window.add_plugin_dock_widget(plugin_name="napari-arboretum", 
                                     widget_name="Arboretum")



(<napari._qt.widgets.qt_viewer_dock_widget.QtViewerDockWidget at 0x2bfee96c0>,
 <napari_arboretum.plugin.Arboretum at 0x2d1b82290>)



In [14]:
# example for track

'''
               /  5  
           /  4 - 6
      / 2 / - 3 - 7
    1/- 8 - - 9 - 10

'''

g = nx.DiGraph()
g.add_edge(2,1)
g.add_edge(8,1)

g.add_edge(3,2)
g.add_edge(4,2)

g.add_edge(5,4)
g.add_edge(6,4)

g.add_edge(7,3)

g.add_edge(9,8)
g.add_edge(10,9)

node_start=(1)
lst = [[1,2,4,5],[1,2,4,6],[1,2,3,7],[1,8,9,10]]
track_graph = {}
output = []
nodes = set()
for item in lst:
    new_item = []
    for node in item:
        if node not in nodes:
            nodes.add(node)
            new_item.append(node)
            if g.in_degree(node) == 2:
                for neighbor in g.in_edges(node):
                    track_graph[neighbor[0]] = neighbor[1]
                output.append(new_item)
                new_item = []
                      
    if new_item:
        output.append(new_item)

for i in output:
    print(i)
print(track_graph)

[1]
[2]
[4]
[5]
[6]
[3, 7]
[8, 9, 10]
{2: 1, 8: 1, 3: 2, 4: 2, 5: 4, 6: 4}
