In [1]:
import os
from datetime import datetime
from ctc_timings import get_im_centers, get_graph
from visualize_lp_solution import load_tiff_frames
import networkx as nx
import igraph
import numpy as np
import pandas as pd

In [2]:
DATA_ROOT = '/home/draga/PhD/data/cell_tracking_challenge/'
OUT_ROOT = '/home/draga/PhD/code/experiments/ctc/'
DS_NAME = 'Fluo-N2DL-HeLa/'
SEQ = '01_ST'
MIGRATION_ONLY = False
# Ground-truth tracking annotations for the same sequence number as SEQ (e.g. '01_ST' -> '01_GT'),
# derived from DATA_ROOT so that changing DATA_ROOT or SEQ keeps the GT path in sync.
GT_PATH = os.path.join(DATA_ROOT, DS_NAME, f"{SEQ.split('_')[0]}_GT", 'TRA/')

## Build and Solve Initial Model

In [3]:
# Input label images: GT sequences are read from TRA/, computed ones from SEG/.
im_dir = os.path.join(DATA_ROOT, DS_NAME, SEQ, 'TRA/' if SEQ.endswith('GT') else 'SEG/')

# Per-sequence output folders for LP models and their solutions.
seq_out_root = os.path.join(OUT_ROOT, DS_NAME, SEQ)
model_root = os.path.join(seq_out_root, 'models/')
sol_root = os.path.join(seq_out_root, 'output/')
for out_dir in (model_root, sol_root):
    os.makedirs(out_dir, exist_ok=True)

# One timestamp (e.g. 05Mar23_1412) shared by this run's model and solution files.
current_datetime = datetime.now().strftime("%d%b%y_%H%M")
out_path = os.path.join(seq_out_root, 'runtimes.csv')
model_path = os.path.join(model_root, current_datetime + '.lp')
sol_path = os.path.join(sol_root, current_datetime + '.sol')

In [4]:
# Extract per-frame object centres from the label images in im_dir, then build the candidate
# tracking graph over them (build_time is the build duration returned by get_graph).
coords, min_t, max_t, corners = get_im_centers(im_dir)
graph, build_time = get_graph(coords, min_t, max_t, corners)

Building kD trees: 100%|██████████| 92/92 [00:01<00:00, 58.88it/s]


Computing appearance/exit costs


Making appearance/exit edges: 100%|██████████| 8602/8602 [00:00<00:00, 71669.30it/s]
Making migration & division edges:  48%|████▊     | 44/91 [00:03<00:05,  8.11it/s]

Building 44


Making migration & division edges: 100%|██████████| 91/91 [00:11<00:00,  8.03it/s]

Build duration:  13.20469331741333





In [5]:
# Build the Gurobi flow model from the candidate graph and solve it.
# The second return value (`flow`) is not used in this section.
m, flow = graph._to_gurobi_model()
m.optimize()

Set parameter Username
Academic license - for non-commercial use only - expires 2023-05-12
Gurobi Optimizer version 9.5.2 build v9.5.2rc0 (linux64)
Thread count: 4 physical cores, 8 logical processors, using up to 8 threads
Optimize a model with 25809 rows, 110332 columns and 424120 nonzeros
Model fingerprint: 0xfa47d5f2
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [2e-02, 4e+02]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+00]

Concurrent LP optimizer: dual simplex and barrier
Showing barrier log only...

Presolve removed 139 rows and 2 columns
Presolve time: 0.62s
Presolved: 25670 rows, 110330 columns, 422622 nonzeros

Ordering time: 0.07s

Barrier statistics:
 AA' NZ     : 3.038e+05
 Factor NZ  : 1.262e+06 (roughly 70 MB of memory)
 Factor Ops : 1.681e+08 (less than 1 second per iteration)
 Threads    : 3

Barrier performed 0 iterations in 1.01 seconds (0.29 work units)
Barrier solve interrupted - model solved by another algorithm



In [6]:
def store_solution_on_graph(opt_model, graph):
    """Copy the optimal flow values of a solved model onto the edges of ``graph._g``.

    Each model variable is expected to be named
    ``flow[<eid>,<var_name>,<src_id>,<src_label>,<target_id>,<target_label>]``, where ``<eid>``
    is the index of the graph edge the variable was created for.

    After the call every edge has a ``'flow'`` attribute: the variable's value (as ``float``)
    when it is strictly positive, and ``0`` otherwise.

    Parameters
    ----------
    opt_model :
        An optimized model exposing ``getVars()``, whose variables have ``VarName`` and ``X``.
    graph :
        Project flow-graph wrapper; the igraph graph is ``graph._g``.

    Raises
    ------
    ValueError
        If a variable name does not split into the six comma-separated fields above, or its
        ids are not integers.
    """
    prefix, suffix = 'flow[', ']'
    # One dense per-edge list assigned in bulk: much cheaper than selecting a subset of
    # edges and assigning into it, and independent of selection order.
    edge_flows = [0] * graph._g.ecount()
    for var in opt_model.getVars():
        name = var.VarName
        # Remove the exact prefix/suffix (str.lstrip/rstrip would strip character *sets*).
        if name.startswith(prefix):
            name = name[len(prefix):]
        if name.endswith(suffix):
            name = name[:-len(suffix)]
        eid, _var_name, src_id, _src_label, target_id, _target_label = name.split(',')
        # validate the id fields as the original parsing did
        int(src_id)
        int(target_id)
        value = float(var.X)
        if value > 0:
            edge_flows[int(eid)] = value
    graph._g.es['flow'] = edge_flows


In [7]:
# Copy the LP flow values onto the graph edges, then aggregate them per vertex into `coords`
# (presumably this is what adds the per-vertex flow columns such as 'in-mig' used below - confirm).
# TODO: slow step (the "Summing flow" pass takes ~40 s here) - can we avoid it or make it faster?
store_solution_on_graph(m, graph)
graph.save_flow_info(coords)


Summing flow: 100%|██████████| 8958/8958 [00:41<00:00, 216.42it/s]


## Oracle

The first step in creating an oracle is finding the correct "context" i.e. groups of vertices and edges in the ground truth that correspond to a given problem vertex `v` in the solution.

This is a reincarnation of the graph matching problem for benchmarking against ground truth solutions. 

Requirements given a sol vertex `v`:

- find associated vertices in ground truth
  - there could be many, especially for a split
- don't pick up any unassociated vertices (because these might be other vertices in our own graph)

Given that merge vertices are typically the result of undersegmentation, we pick all GT vertices whose bounding boxes overlap `v` in the solution.

**NOTE** the proportion of overlap could be its own parameter in an interactive system

**Also Note**: We can't use the exact same matching as in the metrics computation... can we?
- Because of the majority-overlap requirement, a computed vertex that overlaps "a good chunk" (but less than the majority) of a ground truth vertex won't be matched. We'd like to find all of these overlapped vertices - but what if two computed vertices overlap the same ground truth vertex? Why don't we just look for nearby vertices among the false negatives or non-splits of the existing match?

In [8]:
# load GT graph
def get_gt_graph(gt_path):
    """Build the ground-truth tracking graph from a CTC ground-truth ``TRA`` folder.

    Vertices are the object centres returned by ``get_im_centers(gt_path)``. Edges are:

    - migration edges (``is_parent == 0``) between consecutive detections of the same track
      label, in time order;
    - division edges (``is_parent == 1``) from a parent track's last detection to each child
      track's first detection, taken from ``man_track.txt`` (columns: track id, start frame,
      end frame, parent id; parent 0 means no parent).

    Parameters
    ----------
    gt_path : str
        Folder holding the GT label images and ``man_track.txt``.

    Returns
    -------
    (igraph.Graph, pandas.DataFrame)
        The directed GT graph and the vertex table it was built from.

    Raises
    ------
    IndexError
        If a parent/child detection listed in ``man_track.txt`` is missing from the images.
    """
    coords, _min_t, _max_t, _corners = get_im_centers(gt_path)
    srcs = []
    dests = []
    is_parent = []

    # Migration edges: one group per track label (all labels, including the largest one),
    # with detections ordered by frame.
    for _, track in coords.sort_values(by='t', kind='stable').groupby('label', sort=True):
        track_ids = track.index.values
        srcs.extend(track_ids[:-1])
        dests.extend(track_ids[1:])
        is_parent.extend([0] * (len(track_ids) - 1))

    # Division edges from the CTC track file.
    man_track = pd.read_csv(os.path.join(gt_path, 'man_track.txt'), sep=' ', header=None)
    man_track.columns = ['current', 'start_t', 'end_t', 'parent']
    child_tracks = man_track[man_track.parent != 0]
    for _, row in child_tracks.iterrows():
        parent_id = row['parent']
        parent_end_t = man_track[man_track.current == parent_id]['end_t'].values[0]
        # Combine both conditions in one mask (chained boolean indexing triggers reindexing).
        parent_coords = coords[(coords.label == parent_id) & (coords.t == parent_end_t)]
        child_coords = coords[(coords.label == row['current']) & (coords.t == row['start_t'])]
        srcs.append(parent_coords.index.values[0])
        dests.append(child_coords.index.values[0])
        is_parent.append(1)

    edges = pd.DataFrame({
        'sources': srcs,
        'dests': dests,
        'is_parent': is_parent
    })
    # use_vids=True: edge endpoints are vertex ids, i.e. index values of `coords`
    graph = igraph.Graph.DataFrame(edges, directed=True, vertices=coords, use_vids=True)
    return graph, coords

gt_graph, gt_coords = get_gt_graph(GT_PATH)

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  parent_coords = coords[(coords.label == parent_id)][coords.t == parent_end_t]


In [9]:
# load the solution and GT label images; both are indexed by frame `t` below
sol_ims = load_tiff_frames(im_dir)
gt_ims = load_tiff_frames(GT_PATH)
# merge vertices: incoming migration flow > 1 in the LP solution
merge_rows = coords[coords['in-mig'] > 1]

In [10]:
# matching vertices with bounding box overlap
def get_gt_match_vertices(coords, gt_coords, sol_ims, gt_ims, v_id, label_key='label'):
    """Return ids of GT vertices whose masks overlap solution vertex ``v_id``.

    Overlap is decided by traccuracy's ``get_labels_with_overlap`` between the GT frame and a
    binary mask of ``v_id``'s label in the solution frame at the same time point.

    Parameters
    ----------
    coords : pandas.DataFrame
        Solution vertex table indexed by vertex id, with ``label_key`` and ``'t'`` columns.
    gt_coords : pandas.DataFrame
        GT vertex table indexed by vertex id, with ``'label'`` and ``'t'`` columns.
    sol_ims, gt_ims :
        Solution / GT label images indexable by frame ``t``.
    v_id :
        Solution vertex to match.
    label_key : str
        Column of ``coords`` holding the vertex's label value in ``sol_ims``.

    Returns
    -------
    list
        GT vertex ids, in the order the overlap helper reports their labels, without duplicates.

    Raises
    ------
    ValueError
        If ``v_id`` maps to more than one row of ``coords``, or an overlapping GT label does not
        appear exactly once in ``gt_coords`` at that frame.
    """
    from traccuracy.matchers._compute_overlap import get_labels_with_overlap

    # get mask of problem blob
    problem_info = coords.loc[[v_id], [label_key, 't']]
    problem_label = problem_info[label_key].values[0]
    problem_t = problem_info['t'].values[0]
    if (ct := len(problem_info)) > 1:
        raise ValueError(f"Solution label {problem_label} appears {ct} times in frame {problem_t}.")
    mask = sol_ims[problem_t] == problem_label
    gt_ov_labels, _ = get_labels_with_overlap(gt_ims[problem_t], mask)

    # filter the GT table to this frame once instead of once per label
    gt_frame_coords = gt_coords[gt_coords.t == problem_t]
    gt_v_ids = []
    # dict.fromkeys de-duplicates while preserving the helper's order
    for label in dict.fromkeys(gt_ov_labels):
        rows = gt_frame_coords[gt_frame_coords.label == label]
        if (ct := len(rows)) != 1:
            raise ValueError(f"GT label {label} appears {ct} times in frame {problem_t}.")
        gt_v_ids.append(rows.index.values[0])
    return gt_v_ids
    


If a merge on vertex `v` is a result of undersegmentation, we expect to be able to find two (or potentially more) GT vertices with bounding boxes overlapping `v` - the one matching `v` itself, and the one matching the unidentified cell. 

Below we check the overlapping GT vertices for each merge vertex

In [11]:
# Tally how many GT vertices overlap each merge vertex: none, exactly one, or several.
count_no_match = count_single_match = count_multi_match = 0
for merge_vid in merge_rows.index:
    gt_matched = get_gt_match_vertices(coords, gt_coords, sol_ims, gt_ims, merge_vid)
    print(f'Matching GT vertices for {merge_vid}: {gt_matched}')
    n_gt = len(gt_matched)
    if n_gt > 1:
        count_multi_match += 1
    elif n_gt == 1:
        count_single_match += 1
    else:
        count_no_match += 1
print(f"Unmatched: {count_no_match}\nMatched: {count_single_match}\nMulti-matched: {count_multi_match}")

Matching GT vertices for 467: [467, 468]
Matching GT vertices for 576: [577]
Matching GT vertices for 619: [621]
Matching GT vertices for 2314: [2321]
Matching GT vertices for 2514: [2522]
Matching GT vertices for 2585: [2594]
Matching GT vertices for 2667: [2677]
Matching GT vertices for 2757: [2769]
Matching GT vertices for 3628: [3646]
Matching GT vertices for 3872: [3892]
Matching GT vertices for 4075: [4097]
Matching GT vertices for 4208: [4234]
Matching GT vertices for 4669: [4700]
Matching GT vertices for 4784: [4816]
Matching GT vertices for 4788: [4820]
Matching GT vertices for 4901: [4934]
Matching GT vertices for 4906: [4939]
Matching GT vertices for 5021: [5054]
Matching GT vertices for 5113: [5146]
Matching GT vertices for 5232: [5266]
Matching GT vertices for 5353: [5388]
Matching GT vertices for 6720: [6756]
Matching GT vertices for 6769: [6805]
Matching GT vertices for 6847: [6883]
Matching GT vertices for 6976: [7012]
Matching GT vertices for 6980: [7016]
Unmatched: 0


As we can see, we have matched every vertex (meaning we can be confident the merge vertices are real vertices in the ground truth graph), but only one of the merge vertices has an additional overlapping vertex i.e. was undersegmented by this overlap measure.

So, we need to be a bit less strict with our matching criteria, and match GT vertices within a given tolerance of the merge vertex. Keeping in mind we are trying not to make false associations for unrelated vertices in our solution, we can limit this approximate match to only **GT vertices within a given distance of `v` that have no other matching vertices in the solution graph** - these are false negative vertices which, if added to the solution, cannot possibly be a mis-association of an existing vertex.

**NOTE** We should probably **only** be doing this. If multiple computed vertices overlap the same GT vertex, how would we (1) identify that and (2) decide among them? Of course, in this check we dismiss a reference vertex if it has **any** overlap, so I think we're actually still running into this problem.

In [12]:
# vertices in gt frame near to v that don't have matching vertices in solution
def get_gt_unmatched_vertices(coords, gt_coords, sol_ims, gt_ims, v_id, dist):
    """Return GT vertices near solution vertex ``v_id`` that overlap no solution label.

    Candidates are GT vertices in ``v_id``'s frame within Euclidean distance ``dist`` of its
    position (``z, y, x`` if ``coords`` has a ``'z'`` column, else ``y, x``), in the sorted order
    returned by the KD-tree query. A candidate is kept when its GT mask overlaps no label of the
    solution frame, i.e. it is a false negative.

    Returns
    -------
    list
        GT vertex ids (index values of ``gt_coords``); empty if the GT frame has no vertices.
    """
    from scipy.spatial import KDTree
    from traccuracy.matchers._compute_overlap import get_labels_with_overlap

    problem_row = coords.loc[[v_id]]
    problem_t = problem_row['t'].values[0]
    cols = ['z', 'y', 'x'] if 'z' in coords.columns else ['y', 'x']
    problem_coords = tuple(problem_row[cols].values[0])

    # build kdt from gt frame
    gt_frame_coords = gt_coords[gt_coords['t'] == problem_t]
    if gt_frame_coords.empty:
        # nothing to search (and a KD-tree cannot be built from zero points)
        return []
    coord_indices = gt_frame_coords.index.values
    gt_tree = KDTree(gt_frame_coords[cols].values)

    # get nearby vertices
    nearby = gt_tree.query_ball_point(problem_coords, dist, return_sorted=True)
    potential_unmatched = coord_indices[nearby]

    sol_frame = sol_ims[problem_t]
    gt_frame = gt_ims[problem_t]
    unmatched = []
    # check if they don't overlap with any solution vertices i.e. they are a fn
    for gt_vid in potential_unmatched:
        gt_label = gt_coords.at[gt_vid, 'label']
        _, sol_overlaps = get_labels_with_overlap(gt_frame == gt_label, sol_frame)
        if not len(sol_overlaps):
            unmatched.append(gt_vid)
    return unmatched

In [13]:
def get_gt_unmatched_vertices_near_parent(coords, gt_coords, sol_ims, gt_ims, v_id, v_parents, dist, label_key='label'):
    """Return GT vertices near any parent of ``v_id`` that overlap no solution label.

    Candidates are GT vertices in ``v_id``'s frame lying within Euclidean distance ``dist`` of
    the position (taken from ``coords``) of *any* vertex in ``v_parents``. A candidate is kept
    when its GT mask overlaps no label of the solution frame, i.e. it is a false negative.

    Ordering: candidates are grouped by parent, in ``v_parents`` order, each group sorted by
    distance to that parent; a vertex near several parents appears once, at its first position.
    Note that the first entry is therefore the closest to the *first* parent, not necessarily the
    closest overall.

    Parameters
    ----------
    v_parents : list
        Vertex ids (index values of ``coords``) of ``v_id``'s parents.
    dist : float
        Search radius.
    label_key : str
        Accepted for call-compatibility with ``get_gt_match_vertices``; not used.

    Returns
    -------
    list
        GT vertex ids (index values of ``gt_coords``); empty when there are no parents or no GT
        vertices in the frame.
    """
    from scipy.spatial import KDTree
    from traccuracy.matchers._compute_overlap import get_labels_with_overlap

    problem_row = coords.loc[[v_id]]
    problem_t = problem_row['t'].values[0]
    cols = ['z', 'y', 'x'] if 'z' in coords.columns else ['y', 'x']

    gt_frame_coords = gt_coords[gt_coords['t'] == problem_t]
    if not len(v_parents) or gt_frame_coords.empty:
        # nothing to search around / nothing to search in
        return []
    parent_coords = coords.loc[v_parents, cols].values

    # build kdt from gt frame
    coord_indices = gt_frame_coords.index.values
    gt_tree = KDTree(gt_frame_coords[cols].values)

    # get nearby vertices close to any parent of v
    nearby = []
    for neighbour_list in gt_tree.query_ball_point(parent_coords, dist, return_sorted=True):
        nearby.extend(neighbour_list)
    # de-duplicate before the (expensive) overlap checks, keeping first-seen order
    nearby = list(dict.fromkeys(nearby))
    potential_unmatched = coord_indices[nearby]

    sol_frame = sol_ims[problem_t]
    gt_frame = gt_ims[problem_t]
    unmatched = []
    # check if they don't overlap with any solution vertices i.e. they are a fn
    for gt_vid in potential_unmatched:
        gt_label = gt_coords.at[gt_vid, 'label']
        _, sol_overlaps = get_labels_with_overlap(gt_frame == gt_label, sol_frame)
        if not len(sol_overlaps):
            unmatched.append(gt_vid)
    return unmatched

Below, we check whether unmatched GT vertices exist for the merge vertices above for different distance measures.

In [14]:
# counts_none = []
# counts_one = []
# counts_multi = []
# dists = [10, 20, 25, 30, 35, 40, 50, 60, 70]
# for dist in dists:
#     count_no_unmatched = 0
#     count_one_unmatched = 0
#     count_multi_unmatched = 0
#     for i, _ in merge_rows.iterrows():
#         parent_ids = [v for v in graph._g.neighbors(i, mode='in') if graph._g.es[graph._g.get_eid(v, i)]['flow'] > 0]
#         unmatched_gt = get_gt_unmatched_vertices_near_parent(coords, gt_coords, sol_ims, gt_ims, i, parent_ids, dist)
#         ct = len(unmatched_gt)
#         if ct == 0:
#             count_no_unmatched += 1
#         elif ct == 1:
#             count_one_unmatched += 1
#         else:
#             count_multi_unmatched += 1
#             # print(f"At distance {dist}, unmatched GT near {i}: {unmatched_gt}")
#     counts_none.append(count_no_unmatched)
#     counts_one.append(count_one_unmatched)
#     counts_multi.append(count_multi_unmatched)
# for i, dist in enumerate(dists):
#     print(f"Distance: {dist}\nNone: {counts_none[i]}, One: {counts_one[i]}, Multi: {counts_multi[i]}")


Based on the quick exploration above, for this dataset a distance of around 40px captures many close-by unmatched vertices without finding multiple for a given vertex (finding several makes it more likely that some of them are unrelated vertices). Note that the cells below use a distance of 50px. Given the two functions above, let's see how many merge vertices have more than one associated GT vertex.

In [15]:
# count_no_match = 0
# count_single_match = 0
# count_multi_match = 0
# for i, _ in merge_rows.iterrows():
#     gt_matched = get_gt_match_vertices(coords, gt_coords, sol_ims, gt_ims, i)
#     # print(f'Matching GT vertices for {i}: {gt_matched}')
#     if len(gt_matched) == 1:
#         parent_ids = [v for v in graph._g.neighbors(i, mode='in') if graph._g.es[graph._g.get_eid(v, i)]['flow'] > 0]
#         gt_unmatched = get_gt_unmatched_vertices_near_parent(coords, gt_coords, sol_ims, gt_ims, i, parent_ids, 50)
#     else:
#         gt_unmatched = []
#     # print(f'Unmatched GT vertices for {i}: {gt_unmatched}')
#     ct = len(gt_matched) + len(gt_unmatched)
#     if ct == 0:
#         count_no_match += 1
#     elif ct == 1:
#         count_single_match += 1
#     else:
#         count_multi_match += 1
# print(f"Unmatched: {count_no_match}\nMatched: {count_single_match}\nMulti-matched: {count_multi_match}")

## Introducing Vertices

As we can see, after using the more relaxed matching measure when only one overlapping GT vertex is found, we find two associated vertices for half of the merge vertices, and a single vertex for the rest. Now we need to decide what to do with them.

Merge vertices **must** split on the next frame and this occurs in two (currently observed) ways in the dataset:

- A cell is undersegmented for a single frame, and the merge divides into its two constituents on the next frame
- A cell is undersegmented for multiple frames, and the extra flow in the merge vertex is shunted to a cell that divides in the next frame
  - which **should** be division flow

For now, we make minimal changes by simply introducing the additional vertices and fixing their incoming and outgoing edges. Introduced vertex `v'` is matched to `v`'s furthest parent. Outgoing edges are slightly more complex.

When neither of `v`'s current children is a merge vertex, we connect `v` to its closest child, and `v'` to the other child. When a child **is** a merge vertex, it means frame `t+1` does not contain a reasonable split for `v`, and the flow was sent elsewhere to cope. As a result, we terminate `v`'s predecessor `u` with the longest edge to `v`, i.e. `u` flows to the target - this "divests" it of any additional flow. We also do this with `v`'s furthest parent if no additional vertex is introduced - implying there is no reasonable split available for `v` at time `t`.

In [16]:
# For every merge vertex decide whether to delete it, terminate it, or introduce a missing
# vertex next to it. Introduced vertices receive fresh vertex indices and labels, allocated
# sequentially.
last_label = 0
last_index = 0
v_info = None
oracle = {}
for merge_vid in merge_rows.index:
    gt_matched = get_gt_match_vertices(coords, gt_coords, sol_ims, gt_ims, merge_vid)
    parent_ids = [p for p in graph._g.neighbors(merge_vid, mode='in')
                  if graph._g.es[graph._g.get_eid(p, merge_vid)]['flow'] > 0]
    gt_unmatched = get_gt_unmatched_vertices_near_parent(coords, gt_coords, sol_ims, gt_ims, merge_vid, parent_ids, 50)
    problem_xy = np.asarray(coords.loc[[merge_vid]][['y', 'x']].values[0])
    n_found = len(gt_matched) + len(gt_unmatched)

    if n_found == 0:
        # no GT vertex found for this vertex at all: delete it
        decision = 'delete'
    elif n_found == 1:
        # the only GT vertex found is v itself
        decision = 'terminate'
    else:
        decision = 'introduce'
        if len(gt_matched) > 1:
            # several GT vertices overlap v: the closest is v itself, the second closest is introduced
            gt_dists = [
                np.linalg.norm(problem_xy - np.asarray(gt_coords.loc[[gt_vid], ['y', 'x']].values[0]))
                for gt_vid in gt_matched
            ]
            v_info = gt_coords.loc[gt_matched[np.argsort(gt_dists)[1]]]
        else:
            # no extra overlap, but unmatched GT vertices exist near v's parents: take the first one
            # (closest to the first parent returned, not necessarily the closest overall)
            v_info = gt_coords.loc[gt_unmatched[0]]

    if v_info is None:
        oracle[merge_vid] = {'decision': decision, 'v_info': None, 'parent': None}
    else:
        if last_label == 0:
            # first introduction: new labels start after the existing ones; vertex indices start
            # 5 past the last coords index ("hypervertices" - presumably the graph's extra
            # non-detection vertices occupy the ids in between; TODO confirm)
            next_label = coords['label'].max() + 1
            new_index = max(coords.index.values) + 5
        else:
            next_label, new_index = last_label + 1, last_index + 1
        last_label, last_index = next_label, new_index
        oracle[merge_vid] = {
            'decision': decision,
            'v_info': (new_index, list(v_info[['t', 'y', 'x']]) + [next_label]),
            'parent': None
        }
    v_info = None

## Introducing just vertices - no edge fixing

In [17]:
# Flatten the oracle's 'introduce' decisions into one list per column of the table below.
introduce_vertices = {vid: entry for vid, entry in oracle.items() if entry['decision'] == 'introduce'}

merge_vids, new_vids, new_t, y, x, new_label = [], [], [], [], [], []
for key, entry in introduce_vertices.items():
    v_info = entry['v_info']
    # v_info is (new vertex index, [t, y, x, new label])
    new_id, (t_val, y_val, x_val, label_val) = v_info
    merge_vids.append(key)
    new_vids.append(new_id)
    new_t.append(int(t_val))
    y.append(y_val)
    x.append(x_val)
    new_label.append(label_val)

In [18]:
# one row per vertex the oracle decided to introduce, keyed by the merge vertex it came from
oracle_intro_df = pd.DataFrame({
    'merge_id': merge_vids,
    'new_id': new_vids,
    't': new_t,
    'y': y,
    'x': x,
    'new_label': new_label
})
oracle_intro_df.head()

Unnamed: 0,merge_id,new_id,t,y,x,new_label
0,467,8606,9,505.0,907.0,390
1,576,8607,11,505.0,907.0,391
2,2314,8608,37,435.0,177.0,392
3,2514,8609,39,509.0,906.0,393
4,2585,8610,40,506.0,911.0,394


In [19]:
# Debug: print every incoming edge (attributes, source, target) of a hand-picked vertex (2959)
# and the vertex itself, before the oracle modifies the graph.
for e in graph._g.incident(2959, mode='in'):
    print(graph._g.es[e], graph._g.es[e].source, graph._g.es[e].target)
print(graph._g.vs[2959])

igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 48022, {'cost': 120.29703974082108, 'var_name': 'e_43.78_44.318', 'label': '120.2', 'flow': 0}) 2803 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 48037, {'cost': 167.97343896117354, 'var_name': 'e_43.79_44.318', 'label': '167.9', 'flow': 0}) 2804 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 48047, {'cost': 130.07172618650856, 'var_name': 'e_43.88_44.318', 'label': '130.0', 'flow': 0}) 2805 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 48467, {'cost': 218.18774836685924, 'var_name': 'e_43.228_44.318', 'label': '218.1', 'flow': 0}) 2847 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 48591, {'cost': 53.96202604113287, 'var_name': 'e_43.307_44.318', 'label': '53.96', 'flow': 0}) 2860 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 48599, {'cost': 43.6193398920349, 'var_name': 'e_43.308_44.318', 'label': '43.61', 'flow': 0}) 2861 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3

In [20]:
# Map new vertex index -> (frame, [y, x], new label) for every 'introduce' decision, and add
# those vertices to the graph (no edge fixing).
introduce_vertices = {vid: entry for vid, entry in oracle.items() if entry['decision'] == 'introduce'}
introduce_oracle = {}
for entry in introduce_vertices.values():
    new_id, info = entry['v_info']
    introduce_oracle[new_id] = (int(info[0]), info[1:-1], info[-1])
graph.introduce_vertices(introduce_oracle)

HI


Rebuilding frames: 100%|██████████| 26/26 [00:02<00:00,  9.49it/s]


In [21]:
# graph._g.neighbors(graph._g.vs[8520], mode='in')
# print(graph._g.vs[-1])
# print(graph._g.vs[8621])
# graph._g.vs[-60]

## Introducing vertices and fixing edges

In [22]:
def introduce_vertex(merge_v, graph, oracle):
    """Add the vertex the oracle proposed for merge vertex ``merge_v`` and rewire edges around it.

    Expects ``oracle[merge_v]['v_info'] == (new_vid, [t, y, x, new_label])`` and
    ``oracle[merge_v]['parent']`` to be ``None`` or a vertex id set by an earlier call.
    "Flow parents/children" are neighbours connected to ``merge_v`` by an edge with positive flow.

    Incoming side: the new vertex is attached to ``merge_v``'s flow parent that is furthest away
    (Euclidean distance between ``'coords'``), whose edge to ``merge_v`` is removed. If
    ``merge_v`` has fewer than two flow parents, ``oracle[merge_v]['parent']`` is used instead.

    Outgoing side, depending on which of ``merge_v``'s first two flow children are oracle keys:

    - neither: the furthest child is moved from ``merge_v`` to the new vertex;
    - exactly one: the edge to it is removed and the edge to the other child gets cost 0; the new
      vertex becomes that child's oracle parent if the child is itself to be introduced,
      otherwise it is sent to the target;
    - both: both edges are removed and ``merge_v`` and the new vertex are both sent to the target.

    Side effects: mutates ``graph`` and possibly ``oracle[child]['parent']``.

    Raises
    ------
    ValueError
        If no parent can be determined for the new vertex, or ``merge_v`` has fewer than two
        flow children. Both checks happen before the graph is modified.
    """
    oracle_info = oracle[merge_v]
    new_vid, info = oracle_info['v_info']
    t = int(info[0])
    new_coords = tuple(info[1:3])
    new_label = info[3]

    def get_flow(src, dst):
        # edges added after the last solve may carry no flow value yet: treat them as 0
        flow_val = graph._g.es[graph._g.get_eid(src, dst)]['flow']
        return 0 if flow_val is None else flow_val

    def get_distance(u, w):
        # Euclidean distance between two vertices' coordinates
        return np.linalg.norm(np.asarray(u['coords']) - np.asarray(w['coords']))

    children = [graph._g.vs[v] for v in graph._g.neighbors(merge_v, 'out') if get_flow(merge_v, v) > 0]
    parents = [graph._g.vs[v] for v in graph._g.neighbors(merge_v, 'in') if get_flow(v, merge_v) > 0]

    if len(parents) < 2:
        # merge_v has already been dealt with; an earlier introduction may have set its new parent
        new_parent = oracle_info['parent']
        if new_parent is None:
            raise ValueError(f"Vertex {merge_v} only has parents {parents}. New vertex {new_vid}:{info} will have no parent connnection!")
    else:
        merge_vertex = graph._g.vs[merge_v]
        new_parent = parents[0].index if get_distance(parents[0], merge_vertex) > get_distance(parents[1], merge_vertex) else parents[1].index

    if len(children) < 2:
        raise ValueError(f"Merge vertex {merge_v} has {len(children)} outgoing flow edge(s); expected at least 2.")

    # add vertex and move the chosen parent's edge from merge_v to it
    graph.introduce_vertex(new_vid, t, new_coords, new_label)
    if graph._g.are_connected(new_parent, merge_v):
        graph._g.delete_edges([(new_parent, merge_v)])
    graph.add_edge(new_parent, new_vid, is_fixed=True)

    child_a, child_b = children[0].index, children[1].index
    is_merge_a = child_a in oracle
    is_merge_b = child_b in oracle
    # find longest edge
    merge_vertex = graph._g.vs[merge_v]
    furthest_child = child_a if get_distance(merge_vertex, children[0]) > get_distance(merge_vertex, children[1]) else child_b

    if not is_merge_a and not is_merge_b:
        # neither child is a merge: hand the furthest child over to the new vertex
        graph._g.delete_edges([(merge_v, furthest_child)])
        graph.add_edge(new_vid, furthest_child, is_fixed=True)
    elif is_merge_a != is_merge_b:
        # exactly one child is a merge
        merge_child, other_child = (child_a, child_b) if is_merge_a else (child_b, child_a)
        graph._g.delete_edges([(merge_v, merge_child)])
        # make the remaining child edge free so the solver keeps it
        graph._g.es[graph._g.get_eid(merge_v, other_child)]['cost'] = 0
        if oracle[merge_child]['decision'] == 'introduce':
            # the merge child will also be split; its new vertex gets new_vid as parent
            oracle[merge_child]['parent'] = new_vid
        else:
            # nowhere reasonable to send the new vertex: terminate it
            graph.add_edge(new_vid, graph.target.index, is_fixed=True)
    else:
        # both children are merges: detach merge_v from both and terminate both vertices
        graph._g.delete_edges([(merge_v, child_a)])
        graph._g.delete_edges([(merge_v, child_b)])
        graph.add_edge(new_vid, graph.target.index, is_fixed=True)
        graph.add_edge(merge_v, graph.target.index, is_fixed=True)


Now that we've got functions to deal with the decisions made by the oracle, we can apply them on the graph. We save the full graph for visualization purposes, and then build and solve the optimization model again on the updated graph.

Because the graph changes as we introduce/terminate vertices, it's important that our oracle can deal with these changes appropriately. In particular, by the time we reach a given frame and introduce a vertex `v'`, its associated merge vertex `v` may no longer be a merge vertex, if its edges were already fixed as part of a previous introduction.

To deal with this we first:
- Introduce all vertices before terminating any. Introduced vertices are more likely to "fix" the local graph, and given our retrieval of them, we can be relatively sure they **need** to be there. Termination is a more fuzzy oracle decision, so we only make it if we really have to.
- When introducing a vertex `v'` at `t` whose associated merge vertex `v` is no longer a merge vertex, it means we must have introduced a vertex in the previous frame - this new vertex should be the parent of `v'`
    - We track this new parent in the oracle so that we can appropriately assign it

In [23]:
# introduce_vertices = dict(filter(lambda item: item[1]['decision'] == 'introduce', oracle.items()))
# for vid in introduce_vertices:
#     introduce_vertex(vid, graph, oracle)

In [24]:
# still_merged = []
# for merge_v in oracle.keys():
#     incoming_vs = graph._g.neighbors(merge_v, mode='in')
#     actual_incoming = []
#     for v in incoming_vs:
#         relevant_edge = graph._g.es[graph._g.get_eid(v, merge_v)]
#         if relevant_edge['flow'] > 0 or relevant_edge['cost'] == 0:
#             actual_incoming.append(v)
#     if len(actual_incoming) > 1:
#         print(f"Vertex {merge_v} is still a merge vertex. Incoming vertices: {actual_incoming}")
#         print(f"Oracle for {merge_v}: {oracle[merge_v]}")


In [25]:
# import networkx as nx
# import numpy as np
# full_path = "/home/draga/PhD/data/cell_tracking_challenge/Fluo-N2DL-HeLa/01_RES_IC/oracle_introduced_full.graphml"
# del(graph._g.vs['name'])
# del(graph._g.es['label'])
# for v in graph._g.vs:
#     v['y'] = v['coords'][0]
#     v['x'] = v['coords'][1]
#     for attr_name in graph._g.vertex_attributes():
#         if isinstance(v[attr_name], np.bool_):
#             v[attr_name] = int(v[attr_name])
#         elif v[attr_name] is None:
#             v[attr_name] = ''
# for e in graph._g.es:
#     for attr_name in graph._g.edge_attributes():
#         if e[attr_name] is None:
#             e[attr_name] = 0
# del(graph._g.vs['coords'])
# g_nx = graph._g.to_networkx()
# nx.write_graphml_lxml(g_nx, full_path)

In [26]:
# import networkx as nx
# mig_copy = graph._g.copy()
# mig_copy.delete_edges(lambda e: (e['flow'] == 0 and e['cost'] > 0))
# to_delete_ids = [v.index for v in mig_copy.vs if v['label'] in ['division', 'source', 'appearance', 'target']]
# mig_copy.delete_vertices(to_delete_ids)
# mig_nx = mig_copy.to_networkx()
# full_path = "/home/draga/PhD/data/cell_tracking_challenge/Fluo-N2DL-HeLa/01_RES_IC/oracle_introduced_mig.graphml"
# nx.write_graphml_lxml(mig_nx, full_path)

## Rebuild Model, Solve and Check

In [27]:
# Rebuild the Gurobi model from the modified graph and re-solve it.
new_m, flow = graph._to_gurobi_model()
new_m.optimize()

Gurobi Optimizer version 9.5.2 build v9.5.2rc0 (linux64)
Thread count: 4 physical cores, 8 logical processors, using up to 8 threads
Optimize a model with 25860 rows, 110553 columns and 424970 nonzeros
Model fingerprint: 0x6a120820
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [2e-02, 4e+02]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+00]

Concurrent LP optimizer: dual simplex and barrier
Showing barrier log only...

Presolve removed 139 rows and 2 columns
Presolve time: 0.35s
Presolved: 25721 rows, 110551 columns, 423471 nonzeros

Ordering time: 0.04s

Barrier statistics:
 AA' NZ     : 3.044e+05
 Factor NZ  : 1.279e+06 (roughly 70 MB of memory)
 Factor Ops : 1.737e+08 (less than 1 second per iteration)
 Threads    : 3

Barrier performed 0 iterations in 0.54 seconds (0.27 work units)
Barrier solve interrupted - model solved by another algorithm


Solved with dual simplex
Solved in 9454 iterations and 0.54 seconds (0.33 work units)
Opti

In [28]:
def convert_sol_igraph_to_nx(graph, inplace=True):
    """Convert the solution graph to a ``networkx.DiGraph`` with GraphML-friendly attributes.

    Before converting:

    - vertex ``'coords'`` is split into scalar ``'y'`` (first entry) and ``'x'`` (second entry);
    - on vertices, numpy booleans become ``int`` and ``None`` becomes ``0``;
    - on edges, ``None`` becomes ``0``;
    - vertex ``'coords'``, ``'name'``, ``'label'`` and edge ``'label'`` are dropped (when present).

    Parameters
    ----------
    graph :
        Project flow-graph wrapper; the igraph graph is ``graph._g``.
    inplace : bool, default True
        If True, ``graph._g`` itself is rewritten (it loses the dropped attributes). If False, a
        copy is converted and ``graph`` is left untouched.

    Returns
    -------
    networkx.DiGraph
    """
    g = graph._g if inplace else graph._g.copy()
    dropped_vertex_attrs = ('coords', 'name', 'label')
    dropped_edge_attrs = ('label',)

    def _vertex_value(value):
        if isinstance(value, np.bool_):
            return int(value)
        return 0 if value is None else value

    # Skip the split if coords were already removed by an earlier in-place conversion.
    if 'coords' in g.vertex_attributes():
        all_coords = g.vs['coords']
        g.vs['y'] = [c[0] for c in all_coords]
        g.vs['x'] = [c[1] for c in all_coords]

    # Whole-column assignments instead of a per-vertex / per-attribute Python loop.
    for attr_name in g.vertex_attributes():
        if attr_name not in dropped_vertex_attrs:
            g.vs[attr_name] = [_vertex_value(v) for v in g.vs[attr_name]]
    for attr_name in g.edge_attributes():
        if attr_name not in dropped_edge_attrs:
            g.es[attr_name] = [0 if v is None else v for v in g.es[attr_name]]

    for attr_name in dropped_vertex_attrs:
        if attr_name in g.vertex_attributes():
            del g.vs[attr_name]
    for attr_name in dropped_edge_attrs:
        if attr_name in g.edge_attributes():
            del g.es[attr_name]

    return g.to_networkx(create_using=nx.DiGraph)

In [29]:
# save info on graph
store_solution_on_graph(new_m, graph)
nx_g = convert_sol_igraph_to_nx(graph)
oracle_node_df = pd.DataFrame.from_dict(nx_g.nodes, orient='index')

In [30]:
# aggregate per-vertex flow sums into the node table (presumably adds columns such as the
# 'in-mig' used below - confirm)
graph.save_flow_info(oracle_node_df)

Summing flow: 100%|██████████| 8964/8964 [01:15<00:00, 118.29it/s]


In [31]:
# Debug: re-inspect vertex 2959's incoming edges and the vertex itself after the oracle changes
# and the re-solve, for comparison with the earlier printout.
for e in graph._g.incident(2959, mode='in'):
    print(graph._g.es[e], graph._g.es[e].source, graph._g.es[e].target)
print(graph._g.vs[2959])

igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 39310, {'cost': 120.29703974082108, 'var_name': 'e_43.78_44.318', 'flow': 0}) 2803 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 39325, {'cost': 167.97343896117354, 'var_name': 'e_43.79_44.318', 'flow': 0}) 2804 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 39335, {'cost': 130.07172618650856, 'var_name': 'e_43.88_44.318', 'flow': 0}) 2805 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 39755, {'cost': 218.18774836685924, 'var_name': 'e_43.228_44.318', 'flow': 0}) 2847 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 39879, {'cost': 53.96202604113287, 'var_name': 'e_43.307_44.318', 'flow': 0}) 2860 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 39887, {'cost': 43.6193398920349, 'var_name': 'e_43.308_44.318', 'flow': 0}) 2861 2959
igraph.Edge(<igraph.Graph object at 0x7fcdb1a3f040>, 39898, {'cost': 62.6603460350025, 'var_name': 'e_43.317_44.318', 'flow': 0}) 2862 2959
igraph.Edge(<i

In [32]:
# Export the post-oracle solution graph for visualisation.
full_path = os.path.join(DATA_ROOT, DS_NAME, '01_RES_IC', 'oracle_introduced_near_parent_no_edges.graphml')
nx.write_graphml_lxml(nx_g, full_path)

In [33]:
# merge vertices in the re-solved graph (incoming migration flow > 1)
new_merges = oracle_node_df[oracle_node_df['in-mig'] > 1]
# vertices the oracle decided to terminate are excluded from the re-check below
term_v = dict(filter(lambda item: item[1]['decision'] == 'terminate', oracle.items()))
# errors='ignore': terminated ids that are no longer merge vertices are simply skipped
new_merges_no_term = new_merges.drop(list(term_v.keys()), errors='ignore')


In [34]:
# merge vertices still present after re-solving, excluding oracle-terminated ones
new_merges_no_term

Unnamed: 0,pixel_value,t,is_source,is_target,is_appearance,is_division,y,x,_igraph_index,in-app,in-div,in-mig,out-mig,out-target
632,363,12,0,0,0,0,500.765372,891.763754,632,0.0,0.0,2.0,2.0,0
2687,363,41,0,0,0,0,492.375,891.059524,2687,0.0,0.0,2.0,2.0,0
2872,363,43,0,0,0,0,486.684524,892.791667,2872,0.0,0.0,2.0,2.0,0
4113,266,55,0,0,0,0,541.792321,935.307155,4113,0.0,0.0,2.0,2.0,0
4230,279,56,0,0,0,0,543.175439,971.371345,4230,0.0,0.0,2.0,2.0,0
4339,265,57,0,0,0,0,507.3792,961.9216,4339,0.0,0.0,1.5,1.5,0
4343,278,57,0,0,0,0,518.868056,1007.775463,4343,0.0,0.0,1.5,1.5,0
5533,325,67,0,0,0,0,340.281899,759.038576,5533,0.0,1.0,2.5,3.5,0
8612,396,42,False,False,False,False,505.0,908.0,8612,0.0,0.0,2.0,2.0,0


In [35]:
from ctc_fluo_metrics import introduce_gt_labels

# Radius passed to get_gt_unmatched_vertices_near_parent when searching around a
# merge vertex's parents (presumably pixels — TODO confirm units).
PARENT_SEARCH_RADIUS = 50

# Paint GT labels into a copy of the solution masks.
# NOTE(review): `.copy()` is shallow if sol_ims is a list of arrays; if
# introduce_gt_labels edits frames in place, sol_ims would change too — confirm.
new_sol_ims = sol_ims.copy()
introduce_gt_labels(nx_g, new_sol_ims, gt_ims)

count_no_match = 0
count_single_match = 0
count_multi_match = 0
# Iterate over index labels only: iterrows() builds a throwaway Series per row,
# and binding it to `_` would stop IPython from updating its `_` output variable.
for i in new_merges_no_term.index:
    gt_overlaps = get_gt_match_vertices(oracle_node_df, gt_coords, new_sol_ims, gt_ims, i, label_key='pixel_value')
    if len(gt_overlaps) == 1:
        # Exactly one GT vertex overlaps the merge vertex: discard it and search
        # instead for unmatched GT vertices near the merge vertex's positive-flow parents.
        parent_ids = [v for v in graph._g.neighbors(i, mode='in') if graph._g.es[graph._g.get_eid(v, i)]['flow'] > 0]
        gt_unmatched = get_gt_unmatched_vertices_near_parent(
            oracle_node_df, gt_coords, new_sol_ims, gt_ims, i, parent_ids, PARENT_SEARCH_RADIUS, label_key='pixel_value'
        )
        gt_overlaps = []
    else:
        gt_unmatched = []
    # Unpack rather than `+` so the no-match case never needs a concatenation.
    candidates = [*gt_overlaps, *gt_unmatched]
    if not candidates:
        count_no_match += 1
        print(f'Merge vertex {i} does not have a nearby vertex to introduce')
        continue
    if len(candidates) == 1:
        count_single_match += 1
    else:
        count_multi_match += 1
    # One consistent message for single and multiple matches; the counters tell them apart.
    print(f'Nearby fn vertices for merge vertex {i}: {candidates}')
print(f"Unmatched: {count_no_match}\nMatched: {count_single_match}\nMulti-matched: {count_multi_match}")

Nearby fn_vertices for merge vertex 632: [634, 635]
Nearby fn vertices for merge vertex 2687: [2699]
Nearby fn vertices for merge vertex 2872: [2887]
Nearby fn vertices for merge vertex 4113: [4163]
Merge vertex 4230 does not have a nearby vertex to introduce
Merge vertex 4339 does not have a nearby vertex to introduce
Merge vertex 4343 does not have a nearby vertex to introduce
Merge vertex 5533 does not have a nearby vertex to introduce
Merge vertex 8612 does not have a nearby vertex to introduce
Unmatched: 5
Matched: 3
Multi-matched: 1


In [36]:
# Rebuild the networkx copy of the tracking graph from the current state of graph._g.
# A DiGraph holds at most one edge per (u, v) pair, so any parallel edges in graph._g
# would collapse into one here.
nx_g = graph._g.to_networkx(create_using=nx.DiGraph)


In [37]:
# Export the rebuilt networkx graph as GraphML into the sequence's CTC result
# folder. The path is built from the config-cell constants (it resolves to
# .../Fluo-N2DL-HeLa/01_RES_IC/...) instead of a hardcoded absolute path, and
# the folder is created if missing so the write cannot fail on a fresh checkout.
full_path = os.path.join(DATA_ROOT, DS_NAME, '01_RES_IC', 'oracle_introduced_near_parent_no_edges.graphml')
os.makedirs(os.path.dirname(full_path), exist_ok=True)
nx.write_graphml_lxml(nx_g, full_path)

In [38]:
def terminate_merge_vertex(merge_v, graph):
    """Resolve a merge by rerouting the merge vertex's furthest parent to the target.

    ``graph`` is edited in place:

    * the edge from the parent furthest from ``merge_v`` into ``merge_v`` is
      deleted, and that parent is connected to ``graph.target`` with a fixed
      edge (``graph.add_edge(..., is_fixed=True)``);
    * the edge from ``merge_v`` to its furthest child is deleted.

    Only neighbours joined to ``merge_v`` by an edge with ``flow > 0`` are
    considered. "Furthest" is the Euclidean distance between the vertices'
    ``'coords'`` attributes.

    Args:
        merge_v: igraph vertex index of the merge vertex in ``graph._g``.
        graph: tracking-graph wrapper exposing the igraph graph ``_g``, a
            ``target`` vertex and ``add_edge(u, v, is_fixed=...)``.

    Raises:
        ValueError: if ``merge_v`` has fewer than two positive-flow parents or
            fewer than two positive-flow children.
    """
    g = graph._g

    def edge_flow(src, dst):
        return g.es[g.get_eid(src, dst)]['flow']

    def distance(u, v):
        # Norm of the difference, not the difference of norms: the latter is a
        # difference of distances from the origin and can even be negative.
        # NOTE(review): assumes every vertex carries a 'coords' attribute — confirm.
        return np.linalg.norm(np.asarray(u['coords']) - np.asarray(v['coords']))

    merge = g.vs[merge_v]
    parents = [g.vs[v] for v in g.neighbors(merge_v, mode='in') if edge_flow(v, merge_v) > 0]
    children = [g.vs[v] for v in g.neighbors(merge_v, mode='out') if edge_flow(merge_v, v) > 0]
    if len(parents) < 2 or len(children) < 2:
        raise ValueError(
            f'Vertex {merge_v} is not a merge vertex: expected at least 2 positive-flow '
            f'parents and children, found {len(parents)} and {len(children)}'
        )

    # max() keeps the first maximum it meets, so scanning in reverse lets the later
    # neighbour win ties -- the same result as `a if d(a) > d(b) else b` for two.
    furthest_parent = max(reversed(parents), key=lambda p: distance(p, merge)).index
    furthest_child = max(reversed(children), key=lambda c: distance(merge, c)).index

    # Send the furthest parent's track to the target instead of into the merge.
    graph._g.delete_edges([(furthest_parent, merge_v)])
    graph.add_edge(furthest_parent, graph.target.index, is_fixed=True)
    # Drop the merge vertex's edge to its furthest child.
    graph._g.delete_edges([(merge_v, furthest_child)])

### TODO:

- ~~Finish introduce_vertex~~
- ~~Write terminate_merge_vertex~~
- [ ] How do we handle changes to interconnected merge vertices, i.e. make sure that edits to one merge vertex stay consistent with edits to its neighbours?
  - Could any of these actions enforce a set of fixed edges that is unsatisfiable?
- ~~Save edited graph after running introduce~~
- ~~Save edited graph after running terminate~~
- ~~Add code to rebuild and rerun solution~~
- [ ] Investigate edge labels that contain `None`.
- [ ] Update coord rows after introducing vertices so we can save flow info
- [ ] Visualize/metric/check new solution
  - [ ] Potentially fix metric
  - [ ] Check original solution fp/tp/fn edges vs. oracle solution fp/tp/fn edges vs. re-solved fp/tp/fn edges
- [ ] Instead of deleting edges, fix their flow to 0
- [ ] Batch edge deletions/additions into a single operation
- [ ] Update the live model and re-solve