In [1]:
import networkx as nx
import numpy as np
import pandas as pd
from visualize_lp_solution import load_tiff_frames
from ctc_fluo_metrics import filter_to_migration_sol
import napari
from napari_graph import DirectedGraph
from napari.layers import Graph

In [2]:
def assign_track_id(sol):
    """Assign unique integer track ID to each node. 

    Nodes that have more than one incoming edge, or more than
    two children get assigned track ID -1.

    Args:
        sol (nx.DiGraph): directed solution graph
    """
    roots = [node for node in sol.nodes if sol.in_degree(node) == 0]
    nx.set_node_attributes(sol, -1, 'track-id')
    track_id = 1
    for root in roots:
        for edge_key in nx.dfs_edges(sol, root):
            source, dest = edge_key[0], edge_key[1]
            source_out = sol.out_degree(source)
            # true root
            if sol.in_degree(source) == 0:
                sol.nodes[source]['track-id'] = track_id
            # merge into dest or triple split from source
            elif sol.in_degree(dest) > 1 or source_out > 2:
                sol.nodes[source]['track-id'] = -1
                sol.nodes[dest]['track-id'] = -1
                continue
            # double parent_split
            elif source_out == 2:
                track_id += 1
            sol.nodes[dest]['track-id'] = track_id
        track_id += 1

def mask_by_id(nodes, seg):
    masks = np.zeros_like(seg)
    max_id = nodes['track-id'].max()
    for i in range(1, max_id+1):
        track_nodes = nodes[nodes['track-id'] == i]
        for row in track_nodes.itertuples():
            t = row.t
            orig_label = row.label
            mask = seg[t] == orig_label
            masks[t][mask] = row._11 + 1
    
    # colour weird vertices with 1
    unassigned = nodes[nodes['track-id'] == -1]
    for row in unassigned.itertuples():
        t = row.t
        orig_label = row.label
        mask = seg[t] == orig_label
        masks[t][mask] = 1

    return masks

def get_point_colour(sol, merges, bad_parents):
    merges = set(merges)
    bad_parents = set(bad_parents)
    colours = ['white' for _ in range(sol.number_of_nodes())]
    for node in merges:
        parents = [edge[0] for edge in sol.in_edges(node)]
        children = [edge[1] for edge in sol.out_edges(node)]

        # colour the parents orange
        for parent in parents:
            colours[parent] = 'orange'
        # colour the merge node red
        colours[node] = 'red'
        # colour the children yellow
        for child in children:
            colours[child] = 'yellow'

    for node in bad_parents:
        children = [edge[1] for edge in sol.out_edges(node)]
        # colour children pink
        for child in children:
            colours[child] =  'pink'
        # colour parent purple
        colours[node] = 'purple'
    return colours

In [3]:
def get_colour_vs_of_interest(graph, vs, pred_colour, v_colour, succ_colour, orig_colour=None):
    if orig_colour is None:
        colours = ['white' for _ in range(list(graph.nodes)[-1]+1)]
    else:
        colours = orig_colour
    for node in vs:
        parents = [edge[0] for edge in graph.in_edges(node)]
        children = [edge[1] for edge in graph.out_edges(node)]
        # colour the parents orange
        for parent in parents:
            colours[parent] = pred_colour
        # colour the merge node red
        colours[node] = v_colour
        # colour the children yellow
        for child in children:
            colours[child] = succ_colour
    return colours

In [4]:
def store_colour_vs_of_interest(graph, vs, pred_colour, v_colour, succ_colour, orig_colour=False):
    if not orig_colour:
        nx.set_node_attributes(graph, 'white', 'color')
    for node in vs:
        parents = [edge[0] for edge in graph.in_edges(node)]
        children = [edge[1] for edge in graph.out_edges(node)]
        # colour the parents orange
        for parent in parents:
            graph.nodes[parent]['color'] = pred_colour
        # colour the merge node red
        graph.nodes[node]['color'] = v_colour
        # colour the children yellow
        for child in children:
            graph.nodes[child]['color'] = succ_colour

## Load images, segmentation and Gold Standard Ground Truth

In [6]:
seg = load_tiff_frames('/home/draga/PhD/data/cell_tracking_challenge/ST_Segmentations/Fluo-N2DL-HeLa/01_ST/SEG/')
data = load_tiff_frames('/home/draga/PhD/data/cell_tracking_challenge/ST_Segmentations/Fluo-N2DL-HeLa/01/')
truth = load_tiff_frames('/home/draga/PhD/data/cell_tracking_challenge/ST_Segmentations/Fluo-N2DL-HeLa/01_GT/TRA/')

## Load unchanged model solution

In [8]:
sol = nx.read_graphml('/home/draga/PhD/data/cell_tracking_challenge/ST_Segmentations/Fluo-N2DL-HeLa/01_RES_IC/full_sol.graphml', node_type=int)
assign_track_id(sol)
node_df = pd.DataFrame.from_dict(sol.nodes, orient='index')
sol_mask = mask_by_id(node_df, seg)

merges = [node for node in sol.nodes if sol.in_degree(node) > 1]
bad_parents = [node for node in sol.nodes if sol.out_degree(node) > 2]

merge_edges = [edge for node in merges for edge in sol.in_edges(node)]
merge_edges.extend([edge for node in merges for edge in sol.out_edges(node)])

### Colour original merge vertices, their predecessors, and successors grey

In [9]:
store_colour_vs_of_interest(sol, set(merges + bad_parents), 'silver', 'black', 'gray')

## Load opinionated oracle solution

In [15]:
fixed_sol = nx.read_graphml('/home/draga/PhD/data/cell_tracking_challenge/ST_Segmentations/Fluo-N2DL-HeLa/01_RES_IC/oracle_introduced_mig_near_parent.graphml', node_type=int)
filter_to_migration_sol(fixed_sol)
fixed_node_df = pd.DataFrame.from_dict(fixed_sol.nodes, orient='index')

still_merges = [node for node in fixed_sol.nodes if fixed_sol.in_degree(node) > 1]
still_bad_parents = [node for node in fixed_sol.nodes if fixed_sol.out_degree(node) > 2]

In [16]:
introduced

Unnamed: 0.1,Unnamed: 0,merge_id,new_id,t,y,x,new_label
0,0,467,8606,9,505.0,907.0,390
1,1,576,8607,11,505.0,907.0,391
2,2,2314,8608,37,435.0,177.0,392
3,3,2514,8609,39,509.0,906.0,393
4,4,2585,8610,40,506.0,911.0,394
5,5,2667,8611,41,507.0,905.0,395
6,6,2757,8612,42,505.0,908.0,396
7,7,3628,8613,51,448.0,184.0,397
8,8,3872,8614,53,448.0,184.0,398
9,9,4075,8615,55,448.0,184.0,399


In [17]:
# introduced_vertices
introduced = pd.read_csv('/home/draga/PhD/data/cell_tracking_challenge/ST_Segmentations/Fluo-N2DL-HeLa/01_RES_IC/oracle_introduced.csv')
introduced_vs = list(introduced['new_id'])
store_colour_vs_of_interest(fixed_sol, introduced_vs, 'springgreen', 'darkgreen', 'limegreen')
store_colour_vs_of_interest(fixed_sol, set(still_merges + still_bad_parents), 'coral', 'red', 'maroon', True)

## Load vertex introduction only oracle solution

In [18]:
vertex_sol = nx.read_graphml('/home/draga/PhD/data/cell_tracking_challenge/ST_Segmentations/Fluo-N2DL-HeLa/01_RES_IC/oracle_introduced_near_parent_no_edges.graphml', node_type=int)
filter_to_migration_sol(vertex_sol)
vertices_node_df = pd.DataFrame.from_dict(vertex_sol.nodes, orient='index')

vertex_sol_merges = [node for node in vertex_sol.nodes if vertex_sol.in_degree(node) > 1]
vertex_sol_bad_parents = [node for node in vertex_sol.nodes if vertex_sol.out_degree(node) > 2]

In [19]:
store_colour_vs_of_interest(vertex_sol, introduced_vs, 'springgreen', 'darkgreen', 'limegreen')
store_colour_vs_of_interest(vertex_sol, set(vertex_sol_merges + vertex_sol_bad_parents), 'coral', 'red', 'maroon', True)

In [None]:
vertex_sol.out_edges(3056)

OutEdgeDataView([(3056, 3155)])

## Load into napari

### Original solution

In [20]:

coords_df = node_df[['t', 'y', 'x']]
merge_graph = DirectedGraph(edges=merge_edges, coords=coords_df)
layer = Graph(
    merge_graph, 
    out_of_slice_display=True,
    ndim=3, 
    # scale=(50, 1, 1), 
    size=5, 
    properties=node_df,
    face_color=list(nx.get_node_attributes(sol, "color").values()),
)

full_graph = DirectedGraph(edges=list(sol.edges.keys()), coords=coords_df)
full_layer = Graph(
    full_graph, 
    out_of_slice_display=True,
    ndim=3, 
    # scale=(50, 1, 1), 
    size=5, 
    properties=node_df,
    face_color=list(nx.get_node_attributes(sol, "color").values()),
)

### Fixed edges solution

In [21]:
fixed_coords = fixed_node_df[['t', 'y', 'x']]
fixed_edge_graph = DirectedGraph(edges=list(fixed_sol.edges.keys()), coords=fixed_coords)
fixed_layer = Graph(
    fixed_edge_graph, 
    out_of_slice_display=True,
    # ndim=3, 
    # scale=(50, 1, 1), 
    size=5, 
    properties=fixed_node_df,
    face_color=list(nx.get_node_attributes(fixed_sol, "color").values()),
)

### Just vertices solution

In [22]:
vertices_coords = vertices_node_df[['t', 'y', 'x']]
vertices_graph = DirectedGraph(edges=list(vertex_sol.edges.keys()), coords=vertices_coords)
vertices_layer = Graph(
    vertices_graph, 
    out_of_slice_display=True,
    # ndim=3, 
    # scale=(50, 1, 1), 
    size=5, 
    properties=vertices_node_df,
    face_color=list(nx.get_node_attributes(vertex_sol, "color").values()),
)

In [23]:
viewer = napari.Viewer()
viewer.add_image(data)
viewer.add_labels(
    seg,
    # scale=(50, 1, 1), 
)
viewer.add_labels(
    sol_mask,
    name='Solution'
)
viewer.add_labels(truth)
viewer.add_layer(layer)
viewer.add_layer(full_layer)
viewer.add_layer(fixed_layer)
viewer.add_layer(vertices_layer)

  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


<Graph layer 'vertices_graph' at 0x7f26bba56190>