In [23]:
import pandas as pd
import numpy as np
from collections import defaultdict, deque
from scipy.optimize import linear_sum_assignment

In [24]:
ts_df = pd.read_csv("./data/tracks.txt", sep="\t")
ts_df = ts_df.loc[ts_df["t"] <= 242]
cell_names = ts_df["name"].unique()
valid_cell_names = []
for name in cell_names:
    time_points = ts_df.loc[ts_df['name'] == name]["t"].values
    if len(time_points == 1) and time_points[0] == 242:
        continue
    valid_cell_names.append(name)


In [25]:

# name mapping
def map_names(did):
    """ Re-map cells to use their 'name' given their 'did'. Only applies to a
        few select cells where the tracker uses their 'name' instead of 'did'.
    """
    if   did == "P4a": return "Z3"
    elif did == "P4p": return "Z2"
    elif did == "P0a": return "AB"
    else: return did

untracked_nodes = ["AB", "P0", "P1"]

first_internal_layer = ["ABa", "ABp", "EMS", "P2"]

In [26]:
import json

def load_json(file_path):
    """
    Load a JSON file and return its content as a Python object.
    
    :param file_path: Path to the JSON file.
    :return: Parsed JSON content as a Python object.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
    
lineage_data = load_json('./data/cell_lineage.json')


In [27]:
class Node:
    def __init__(self, key, coordinate):
        self.key = key
        self.coordinate = coordinate
        self.children = []
        self.optimized = False
        self.parent = None
    

In [28]:
terminal_nodes = []
internal_front = []
internal_pool = []
coordinate_map = {}
depth_group = defaultdict(list)
def dfs(node, terminal_nodes, depth):
    children = node.get("children", [])
    lookup_name = map_names(node["did"])
    if lookup_name in valid_cell_names or lookup_name in untracked_nodes:
        if lookup_name in valid_cell_names:
            coordinate = ts_df.loc[ts_df['name'] == lookup_name].values[-1][1:4]*0.1625
        else:
            coordinate = np.array([0, 0, 0])
        tree_node = Node(lookup_name, coordinate)
        coordinate_map[lookup_name] = coordinate
        for child in children:
            child_node = dfs(child, terminal_nodes, depth+1)
            if child_node:
                child_node.parent = tree_node
                tree_node.children.append(child_node)
        if len(children) == 0:
            terminal_nodes.append(tree_node)
            tree_node.optimized = True
            depth_group[depth].append(tree_node)
            return tree_node
        if lookup_name in first_internal_layer:
            tree_node.optimized = True
            internal_front.append(tree_node)
        elif lookup_name not in untracked_nodes:
            internal_pool.append(tree_node.key)
        depth_group[depth].append(tree_node)
        return tree_node
    else:
        return None

root = dfs(lineage_data, terminal_nodes, 0)
terminal_nodes_names = [node.key for node in terminal_nodes]

In [43]:
# pruning "fake" terminal nodes
terminal_nodes = []
internal_front = []
internal_pool = []
coordinate_map = {}
fake_terminal = []
depth_group = defaultdict(list)
def dfs(node, terminal_nodes, depth):
    children = node.get("children", [])
    lookup_name = map_names(node["did"])
    if lookup_name in valid_cell_names or lookup_name in untracked_nodes:
        if lookup_name in valid_cell_names:
            coordinate = ts_df.loc[ts_df['name'] == lookup_name].values[-1][1:4]*0.1625
        else:
            coordinate = np.array([0, 0, 0])
        tree_node = Node(lookup_name, coordinate)
        coordinate_map[lookup_name] = coordinate
        for child in children:
            child_node = dfs(child, terminal_nodes, depth+1)
            if child_node:
                child_node.parent = tree_node
                tree_node.children.append(child_node)
        if len(children) == 0:
            terminal_nodes.append(tree_node)
            tree_node.optimized = True
            depth_group[depth].append(tree_node)
            return tree_node
        # "fake" terminal nodes:
        if len(tree_node.children) == 0:
            fake_terminal.append(tree_node)
        if lookup_name in first_internal_layer:
            tree_node.optimized = True
            internal_front.append(tree_node)
        elif lookup_name not in untracked_nodes:
            internal_pool.append(tree_node.key)
        depth_group[depth].append(tree_node)
        return tree_node
    else:
        return None

root = dfs(lineage_data, terminal_nodes, 0)
terminal_nodes_names = [node.key for node in terminal_nodes]

In [44]:
# fake terminal pruning with BFS
fake_q = deque(fake_terminal)
while len(fake_q) > 0:
    cur_node = fake_q.popleft()
    if (cur_node.key in first_internal_layer): print('hmmm....')
    internal_pool.remove(cur_node.key)
    cur_parent = cur_node.parent
    cur_parent.children.remove(cur_node)
    if len(cur_parent.children) < 1:
        fake_q.append(cur_parent)

381? check if it is still a tree, at most two children, true terminal nodes, exactly one parent node

In [45]:
cost_q = deque()
for node in internal_front:
    cost_q.append((node, 0))
level_cost = defaultdict(int)
lineage_cost = 0
while len(cost_q) > 0:
    cur_node, cur_level = cost_q.popleft()
    for child in cur_node.children:
        lineage_cost += np.linalg.norm(cur_node.coordinate - child.coordinate)
        level_cost[cur_level] += np.linalg.norm(cur_node.coordinate - child.coordinate)
        cost_q.append((child, cur_level+1))
print(lineage_cost)

2767.4784896724996


In [46]:
# strictly optimize layer by layer
for depth in range(3, 11):
    cur_level_nodes = []
    cur_level_nodes_names = []
    for node in depth_group[depth]:
        if node.key in internal_pool:
            cur_level_nodes.append(node)
            cur_level_nodes_names.append(node.key)
    parents = [node.parent.key for node in cur_level_nodes]
    cost_mat = np.zeros((len(cur_level_nodes_names), len(parents)))
    for i in range(len(cur_level_nodes_names)):
        for j in range(len(parents)):
            cost_mat[i][j] = np.linalg.norm(coordinate_map[cur_level_nodes_names[i]] - coordinate_map[parents[j]])
    row_indices, col_indices = linear_sum_assignment(cost_mat)
    for row, col in zip(row_indices, col_indices):
        cur_node = cur_level_nodes[col]
        new_node_name = cur_level_nodes_names[row]
        new_node_coordinate = coordinate_map[new_node_name]
        cur_node.key = new_node_name
        cur_node.coordinate = new_node_coordinate
        cur_node.optimized = True

terminal_parents = [node.parent.key for node in terminal_nodes]
cost_mat = np.zeros((len(terminal_nodes_names), len(terminal_parents)))
for i in range(len(terminal_nodes_names)):
    for j in range(len(terminal_parents)):
        cost_mat[i][j] = np.linalg.norm(coordinate_map[terminal_nodes_names[i]] - coordinate_map[terminal_parents[j]])
row_indices, col_indices = linear_sum_assignment(cost_mat)
for row, col in zip(row_indices, col_indices):
    cur_node = terminal_nodes[col]
    new_node_name = terminal_nodes_names[row]
    new_node_coordinate = coordinate_map[new_node_name]
    cur_node.key = new_node_name
    cur_node.coordinate = new_node_coordinate
    cur_node.optimized = True

In [47]:
cost_q = deque(internal_front)
lineage_cost = 0
while len(cost_q) > 0:
    cur_node = cost_q.popleft()
    for child in cur_node.children:
        lineage_cost += np.linalg.norm(cur_node.coordinate - child.coordinate)
        cost_q.append(child)
print(lineage_cost)

2498.1560431395287


In [None]:
optimize_q = internal_front.copy()
cur_cost = 0
terminal_parents = []
while len(internal_pool) > 0:
    parents = []
    next_front = []
    for node in optimize_q:
        for child in node.children:
            if child.optimized: 
                terminal_parents.append(node)
                continue
            parents.append(node.key)
            next_front.append(child)
    cost_mat = np.zeros((len(internal_pool), len(parents)))
    for i in range(len(internal_pool)):
        for j in range(len(parents)):
            cost_mat[i][j] = np.linalg.norm(coordinate_map[internal_pool[i]] - coordinate_map[parents[j]])
    row_indices, col_indices = linear_sum_assignment(cost_mat)
    for row, col in zip(row_indices, col_indices):
        cur_node = next_front[col]
        new_node_name = internal_pool[row]
        new_node_coordinate = coordinate_map[new_node_name]
        cur_node.key = new_node_name
        cur_node.coordinate = new_node_coordinate
        cur_node.optimized = True

    for index in reversed(row_indices):
        del internal_pool[index]
    optimize_q = next_front
    cur_cost += cost_mat[row_indices, col_indices].sum()
    print(cost_mat[row_indices, col_indices].sum())


22.94056466762293
53.88515892774103
100.90815688828023
153.67046134169698
303.0360285968159
471.59221458006994
801.0344045949911


In [None]:
terminal_parents = [node.parent.key for node in terminal_nodes]
cost_mat = np.zeros((len(terminal_nodes_names), len(terminal_parents)))
for i in range(len(terminal_nodes_names)):
    for j in range(len(terminal_parents)):
        cost_mat[i][j] = np.linalg.norm(coordinate_map[terminal_nodes_names[i]] - coordinate_map[terminal_parents[j]])
row_indices, col_indices = linear_sum_assignment(cost_mat)
for row, col in zip(row_indices, col_indices):
    cur_node = terminal_nodes[col]
    new_node_name = terminal_nodes_names[row]
    new_node_coordinate = coordinate_map[new_node_name]
    cur_node.key = new_node_name
    cur_node.coordinate = new_node_coordinate
    cur_node.optimized = True

In [61]:
cost_q = deque(internal_front)
lineage_cost = 0
while len(cost_q) > 0:
    cur_node = cost_q.popleft()
    for child in cur_node.children:
        lineage_cost += np.linalg.norm(cur_node.coordinate - child.coordinate)
        cost_q.append(child)
print(lineage_cost)

2682.920752885195


In [40]:
len(terminal_nodes)

474

In [31]:
sorted(depth_group.keys(), reverse=True)[:-3]

[10, 9, 8, 7, 6, 5, 4, 3]

In [32]:
len(internal_pool)

692

In [70]:
cur_cost = 0
for depth in sorted(depth_group.keys(), reverse=True)[:-3]:
    optimize_q = depth_group[depth]
    parents_dict = {}
    next_front = []
    parents = []
    for node in optimize_q:
        if node.parent.optimized: continue
        if node.parent.key in parents_dict:
            parents_dict[node.parent.key].append(node.key)
        else:
            parents_dict[node.parent.key] = [node.key]
            parents.append(node.parent.key)
            next_front.append(node.parent)

    cost_mat = np.zeros((len(internal_pool), len(parents)))
    for i in range(len(internal_pool)):
        for j in range(len(parents)):
            for child in parents_dict[parents[j]]:
                cost_mat[i][j] += np.linalg.norm(coordinate_map[internal_pool[i]] - coordinate_map[child])
    row_indices, col_indices = linear_sum_assignment(cost_mat)
    for row, col in zip(row_indices, col_indices):
        cur_node = next_front[col]
        new_node_name = internal_pool[row]
        new_node_coordinate = coordinate_map[new_node_name]
        cur_node.key = new_node_name
        cur_node.coordinate = new_node_coordinate

    for index in reversed(row_indices):
        del internal_pool[index]
    optimize_q = next_front
    cur_cost += cost_mat[row_indices, col_indices].sum()
    print(len(optimize_q))

90
159
103
60
30
16
8
0


In [10]:
optimize_q = terminal_nodes.copy()
cur_cost = 0
while len(optimize_q) > 0:
    parents_dict = {}
    next_front = []
    parents = []
    for node in optimize_q:
        if node.parent.optimized: continue
        if node.parent.key in parents_dict:
            parents_dict[node.parent.key].append(node.key)
        else:
            parents_dict[node.parent.key] = [node.key]
            parents.append(node.parent.key)
            next_front.append(node.parent)

    cost_mat = np.zeros((len(internal_pool), len(parents)))
    for i in range(len(internal_pool)):
        for j in range(len(parents)):
            for child in parents_dict[parents[j]]:
                cost_mat[i][j] += np.linalg.norm(coordinate_map[internal_pool[i]] - coordinate_map[child])
    row_indices, col_indices = linear_sum_assignment(cost_mat)
    for row, col in zip(row_indices, col_indices):
        cur_node = next_front[col]
        new_node_name = internal_pool[row]
        new_node_coordinate = coordinate_map[new_node_name]
        cur_node.key = new_node_name
        cur_node.coordinate = new_node_coordinate

    for index in reversed(row_indices):
        del internal_pool[index]
    optimize_q = next_front
    cur_cost += cost_mat[row_indices, col_indices].sum()
    print(len(optimize_q))


155
109
72
48
25
14
5
0


In [11]:
cost_q = deque(internal_front)
lineage_cost = 0
while len(cost_q) > 0:
    cur_node = cost_q.popleft()
    for child in cur_node.children:
        lineage_cost += np.linalg.norm(cur_node.coordinate - child.coordinate)
        cost_q.append(child)
print(lineage_cost)

2818.391986433554
