In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation, feature
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d, binary_dilation
from skimage import filters, morphology
from scipy.ndimage.filters import convolve, correlate
from scipy import signal
from skimage.filters import frangi, sato
from skimage.draw import line_nd
from PIL import Image
import pickle
from queue import PriorityQueue

In [2]:
# Identifier of the tree being processed: selects the data sub-directory under
# `source_dir` and the per-tree entries of the root-selection settings.
TREE_NAME = 'P12'

## Loading skeleton and thiccness_map

In [3]:
source_dir = './data/'
# Both inputs live in <source_dir>/<TREE_NAME>/; they are indexed with the same voxel
# coordinates later on, so presumably share one volume shape (TODO confirm).
skeleton = np.load(f'{source_dir}{TREE_NAME}/skeleton.npy')
thiccness_map = np.load(f'{source_dir}{TREE_NAME}/thiccness-map.npy')

## Utility visualisation functions

In [4]:
def visualize_addition(base, base_with_addition):
    """Show `base` (value 1) plus the voxels present only in `base_with_addition` (value 4)."""
    base_bin = np.asarray(base > 0, dtype=np.uint8)
    added_only = np.asarray(base_with_addition > 0, dtype=np.uint8)
    # Voxels already in the base are not part of the addition.
    added_only[base_bin == 1] = 0
    ColorMapVisualizer(base_bin + 4 * added_only).visualize()
    
def visualize_lsd(lsd_mask):
    """Render `lsd_mask`, cast to uint8, with the colour-map visualizer."""
    visualizer = ColorMapVisualizer(lsd_mask.astype(np.uint8))
    visualizer.visualize()
    
def visualize_gradient(lsd_mask):
    """Render `lsd_mask`, cast to uint8, with the colour-map visualizer in gradient mode."""
    visualizer = ColorMapVisualizer(lsd_mask.astype(np.uint8))
    visualizer.visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Render the positive voxels of `mask` as a binary (0/1) volume."""
    binary_volume = (mask > 0).astype(np.uint8)
    VolumeVisualizer(binary_volume, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Render the positive voxels of `mask` as a 0/255 volume in non-binary mode."""
    scaled_volume = (mask > 0).astype(np.uint8) * 255
    VolumeVisualizer(scaled_volume, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonize the positive voxels of `mask` and display the result.

    Args:
        mask: volume to skeletonize (voxels > 0 are foreground).
        visualize_mask: if truthy, show the skeleton (value 4) over the mask (value 3);
            otherwise show the bare binary skeleton.
        visualize_both_versions: if truthy, show both renderings.
    """
    skeleton = skeletonize((mask > 0).astype(np.uint8))
    show_bare = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_bare:
        VolumeVisualizer(skeleton, binary=True).visualize()

    if show_overlay:
        skeleton_layer = skeleton.astype(np.uint8) * 4
        mask_layer = (mask > 0).astype(np.uint8) * 3
        # Skeleton voxels take the skeleton colour instead of the mask colour.
        mask_layer[skeleton_layer != 0] = 0
        ColorMapVisualizer(skeleton_layer + mask_layer).visualize()

def visualize_ultimate(lsd, base_mask):
    """Run all visualisations of `lsd`: colour map, volume, addition over `base_mask`, skeleton."""
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

## Resolving nodes mask

### Resolving leaves mask

In [5]:
def trim_skeleton(skeleton):
    """Return the skeleton with its end points removed.

    A voxel is kept when it is non-zero and has more than one positive voxel among its
    26 neighbours; everything else (including isolated voxels and branch tips) is dropped.

    Args:
        skeleton: 3-D array whose non-zero voxels form the skeleton.

    Returns:
        uint8 array of the same shape with 1 on the kept voxels and 0 elsewhere.
    """
    from scipy import ndimage

    skeleton = np.asarray(skeleton)
    # 3x3x3 kernel without its centre: convolving counts occupied 26-neighbours.
    kernel = np.ones((3, 3, 3), dtype=np.int32)
    kernel[1, 1, 1] = 0
    # mode='constant' treats everything outside the volume as background, so voxels on
    # the border neither wrap around (negative indices) nor raise IndexError.
    neighbours_count = ndimage.convolve(
        (skeleton > 0).astype(np.int32), kernel, mode='constant', cval=0
    )
    return ((skeleton != 0) & (neighbours_count > 1)).astype(np.uint8)


def mark_leaves(skeleton):
    """Return a 0/1 uint8 mask of skeleton leaves.

    Leaves are the non-zero skeleton voxels that `trim_skeleton` removes, i.e. those with
    at most one positive 26-neighbour.
    """
    trimmed = trim_skeleton(skeleton)
    # Compare as booleans: `skeleton - trimmed` is only a 0/1 mask when the skeleton is
    # stored as 0/1 (a 0/255 skeleton would give 254 on non-leaf voxels).
    leaves = (np.asarray(skeleton) != 0) & (trimmed == 0)
    return leaves.astype(np.uint8)

In [6]:
# Leaves: skeleton voxels with at most one skeleton neighbour (branch end points).
leaves_mask = mark_leaves(skeleton)

### Resolving bifurcations mask

In [7]:
def mark_bifurcation_regions(skeleton):
    """Mark skeleton voxels that have more than two skeleton neighbours.

    Neighbours are counted over the 26-neighbourhood; voxels outside the volume count
    as background.

    Args:
        skeleton: 3-D array whose positive voxels form the skeleton.

    Returns:
        uint8 array of the same shape with 1 on bifurcation voxels and 0 elsewhere.
    """
    from scipy import ndimage

    occupied = np.asarray(skeleton) > 0
    # Built once (the kernel does not depend on the voxel); centre excluded.
    kernel = np.ones((3, 3, 3), dtype=np.int32)
    kernel[1, 1, 1] = 0
    # Zero padding outside the volume (mode='constant', cval=0) is equivalent to padding
    # the skeleton by one voxel of background.
    neighbours_count = ndimage.convolve(
        occupied.astype(np.int32), kernel, mode='constant', cval=0
    )
    return (occupied & (neighbours_count > 2)).astype(np.uint8)


def mark_nodes(skeleton):
    """Return a 0/1 uint8 mask of graph nodes: bifurcation voxels plus leaves."""
    bifurcation_map = mark_bifurcation_regions(skeleton)
    leaves_map = mark_leaves(skeleton)
    # logical_or keeps the result a 0/1 mask even if the inputs overlap or are not
    # strictly 0/1, and cannot overflow like uint8 addition.
    return np.logical_or(bifurcation_map, leaves_map).astype(np.uint8)

In [8]:
bifurcations_mask = mark_bifurcation_regions(skeleton)
# A voxel is a node voxel when it is a bifurcation or a leaf.
nodes_mask = np.greater(bifurcations_mask + leaves_mask, 0).astype(np.uint8)

## Constructing graph

In [9]:
class Node:
    """Graph node identified by `coords`, with attached edges and a metadata dict.

    Metadata is read and written with `node[key]`. Hashing uses `coords`; equality is
    left as identity, which is consistent with that hash.
    """

    def __init__(self, coords):
        self.coords = coords
        self.edges = []
        self.data = {}

    def add_edge(self, edge):
        """Append `edge` to this node's edges (the other endpoint is not updated)."""
        self.edges.append(edge)

    def get_neighbours(self):
        """Return the node at the other end of each attached edge, in edge order."""
        return [e.node_a if e.node_a.coords != self.coords else e.node_b for e in self.edges]

    def copy_without_edges(self):
        """Return a new node with the same coords and a shallow copy of `data`, but no edges.

        The metadata dict is copied so that assigning keys on the copy (e.g. updating
        'thiccness' on the cleaned graph) does not silently rewrite the source node.
        Values such as voxel arrays are still shared.
        """
        copied_node = Node(self.coords)
        copied_node.data = dict(self.data)
        return copied_node

    def __setitem__(self, key, value):
        self.data[key] = value

    def __getitem__(self, key):
        return self.data[key]

    def __hash__(self):
        return hash(self.coords)

    def __repr__(self):
        return f'Node {str(self.coords)}'
        
        
class Edge:
    """Connection from `node_a` to `node_b`; metadata is accessed as `edge[key]`."""

    def __init__(self, node_a, node_b):
        self.node_a, self.node_b = node_a, node_b
        self.data = {}

    def __setitem__(self, key, value):
        self.data.__setitem__(key, value)

    def __getitem__(self, key):
        return self.data.__getitem__(key)

    def __repr__(self):
        start, end = self.node_a.coords, self.node_b.coords
        return f'Edge {start} -> {end}'

In [10]:
def construct_graph(skeleton, nodes_mask, thiccness_map):
    """Build graph edges from a skeleton split into node regions and edge segments.

    Every connected region of `nodes_mask` becomes a `Node`. Every connected component
    of skeleton voxels outside `nodes_mask` becomes an `Edge` between the nodes it
    touches (26-neighbourhood).

    Args:
        skeleton: 3-D array whose positive voxels form the skeleton.
        nodes_mask: 3-D array of the same shape; non-zero voxels are node voxels.
        thiccness_map: array indexed by voxel coordinates; sampled at each node's
            first voxel.

    Returns:
        (edges, bad_edges): `edges` lists the `Edge`s whose segment touches exactly two
        nodes; `bad_edges` holds the voxel coordinates of segments touching any other
        number of nodes.
    """
    from collections import deque

    nodes_labels = measure.label(nodes_mask)
    nodes_props = measure.regionprops(nodes_labels)
    print('nodes found (regions on nodes mask):', nodes_labels.max())
    voxel_to_node = dict()

    # regionprops only yields labels >= 1, so no background check is needed.
    for props in nodes_props:
        node = Node(tuple(props.coords[0]))
        node['voxels'] = props.coords
        node['thiccness'] = thiccness_map[tuple(props.coords[0])]

        for c in props.coords:
            voxel_to_node[tuple(c)] = node

    # Boolean edge mask; unlike `skeleton - nodes_mask`, it does not rely on the
    # skeleton being stored as 0/1.
    edges_mask = (np.asarray(skeleton) > 0) & (np.asarray(nodes_mask) == 0)
    edges_labels = measure.label(edges_mask)
    print('edges found:', edges_labels.max())

    shape = edges_mask.shape
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the supported dtype.
    visited = np.zeros(shape, dtype=bool)

    def neighbours_of(x, y, z):
        """Yield the in-bounds 26-neighbours of voxel (x, y, z)."""
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    if dx == dy == dz == 0:
                        continue
                    nx, ny, nz = x + dx, y + dy, z + dz
                    # Explicit bounds check: negative indices would silently wrap around.
                    if 0 <= nx < shape[0] and 0 <= ny < shape[1] and 0 <= nz < shape[2]:
                        yield nx, ny, nz

    def find_touching_nodes(source_voxel):
        """Flood-fill the edge segment containing `source_voxel`; return the nodes it touches."""
        touching_nodes = set()
        source = tuple(int(c) for c in source_voxel)
        visited[source] = True
        queue = deque([source])  # O(1) pops from the left, unlike list.pop(0)
        while queue:
            voxel = queue.popleft()
            for neighbour in neighbours_of(*voxel):
                if visited[neighbour]:
                    continue

                potential_node = voxel_to_node.get(neighbour)
                if potential_node is not None:
                    touching_nodes.add(potential_node)

                if edges_mask[neighbour]:
                    visited[neighbour] = True
                    queue.append(neighbour)
        return list(touching_nodes)

    edges = []
    bad_edges = []
    for props in measure.regionprops(edges_labels):
        touching_nodes = find_touching_nodes(props.coords[0])
        if len(touching_nodes) != 2:
            print(f'bad edge found! touching nodes count: {len(touching_nodes)}')
            bad_edges.append(props.coords)
            continue

        edge = Edge(touching_nodes[0], touching_nodes[1])
        edge['voxels'] = props.coords
        edges.append(edge)
    return edges, bad_edges

In [11]:
%%time
edges, bad_edges = construct_graph(skeleton, nodes_mask, thiccness_map)
print("Number of bad edged found:", len(bad_edges))

nodes found (regions on nodes mask): 2290
edges found: 2339
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
Number of bad edged found: 14
CPU times: user 5.72 s, sys: 908 ms, total: 6.63 s
Wall time: 6.63 s


## Cleaning DAG

### Finding the root node

In [12]:
def convert_to_nodes_list(edges):
    """Copy the graph described by `edges` into fresh, bidirectionally linked nodes.

    Each distinct endpoint is copied once (without its edges). Every edge is then
    recreated between the copies and registered on both of them.

    Returns:
        List of the copied nodes, in order of first appearance in `edges`.
    """
    copies = {}
    for e in edges:
        for endpoint in (e.node_a, e.node_b):
            # Copy each node only once; re-copying would only replace an identical copy.
            if endpoint not in copies:
                copies[endpoint] = endpoint.copy_without_edges()

    for e in edges:
        new_edge = Edge(copies[e.node_a], copies[e.node_b])
        # Shallow-copy the metadata so later updates on the copied graph do not leak
        # back into the input edges.
        new_edge.data = dict(e.data)
        copies[e.node_a].add_edge(new_edge)
        copies[e.node_b].add_edge(new_edge)

    return list(copies.values())

def find_tree_root_candidates(nodes, root_degree, thiccness_tolerance):
    """Return the nodes that could be the tree root.

    Candidates have exactly `root_degree` edges, and their thiccness is at least the
    maximum thiccness among such nodes minus `thiccness_tolerance`.

    Raises:
        ValueError: if no node has exactly `root_degree` edges.
    """
    proper_degree_nodes = [node for node in nodes if len(node.edges) == root_degree]
    if not proper_degree_nodes:
        raise ValueError(f'no node with degree {root_degree} found; cannot select a root')

    root_thickness = max(node['thiccness'] for node in proper_degree_nodes)
    threshold = root_thickness - thiccness_tolerance
    return [node for node in proper_degree_nodes if node['thiccness'] >= threshold]

def visualize_root(root, skeleton, mark_radius=2):
    """Display `skeleton` with a cube of value 4 drawn around every voxel of `root`.

    Each cube extends `mark_radius` voxels on either side of the root voxel (clipped at
    the volume's lower edges), so the marker is centred on the voxel.
    """
    visualisation = skeleton.astype(np.uint8)  # astype already returns a copy
    for v in root['voxels']:
        x, y, z = (int(c) for c in v)
        # Clamp lower bounds at 0: a negative start would wrap around and usually yield
        # an empty slice. The +1 makes the cube symmetric around the voxel.
        visualisation[max(x - mark_radius, 0): x + mark_radius + 1,
                      max(y - mark_radius, 0): y + mark_radius + 1,
                      max(z - mark_radius, 0): z + mark_radius + 1] = 4
    visualize_lsd(visualisation)

In [13]:
%%time

# Expected number of edges of the root node, per tree (keyed by TREE_NAME); unlisted trees use 1.
roots_degrees = {
    'P01': 1,
    'P05': 1,
    'P12': 1,
}

# Per-tree thiccness tolerance: candidates may be up to this much thinner than the
# thickest node of the expected degree; unlisted trees use 0.
root_thiccness_tolerance = {
    'P01': 0,
    'P05': 4,
    'P12': 0,
}

nodes = convert_to_nodes_list(edges)
root_candidates = find_tree_root_candidates(nodes, roots_degrees.get(TREE_NAME, 1), 
                                            root_thiccness_tolerance.get(TREE_NAME, 0))
print(f'found {len(root_candidates)} root candidate(s)')

# Per-tree index into root_candidates of the node used as root; presumably picked by
# inspecting the visualisation below — TODO confirm when adding a new tree.
candidates_indices = {
    'P01': 0,
    'P05': 6,
    'P12': 0,
}

root = root_candidates[candidates_indices.get(TREE_NAME, 0)]
visualize_root(root, skeleton) # verify whether the proper node was selected

found 1 root candidate(s)
CPU times: user 6 s, sys: 3.54 s, total: 9.55 s
Wall time: 56.1 s


### Removing cycles (obtaining a DAG)

In [15]:
def remove_dag_cycles(root):
    """Build a cycle-free copy of the graph reachable from `root`.

    Nodes are taken from a priority queue ordered by descending node thiccness, with
    ties broken by insertion order (`counter`). The first time a node is taken, it is
    copied and attached to exactly one already-copied parent. Every copied node except
    the root therefore has a single incoming edge, and the result is a tree rooted at
    the copy of `root`. Copied edges point from parent (`node_a`) to child (`node_b`).

    Args:
        root: node of the undirected graph (e.g. from `convert_to_nodes_list`).

    Returns:
        The copy of `root`. Edge metadata dicts are shared with the source edges
        (assigned, not copied).
    """
    # Monotonic tie-breaker so the queue never has to compare Node objects.
    counter = 0
    
    new_root = root.copy_without_edges()
    # coords -> already-processed source nodes adjacent to that node (parent candidates).
    coords_to_old_parents = {}
    # coords -> copied node; also acts as the "already placed in the tree" set.
    coords_to_new_node = { new_root.coords: new_root }
    
    queue = PriorityQueue()
    for node in root.get_neighbours():
        coords_to_old_parents[node.coords] = [root]
        queue.put(((-node['thiccness'], counter), node))
        counter += 1
        
    while not queue.empty():
        _, node = queue.get()
        
        # A node is queued once per processed neighbour; only its first pop counts.
        if coords_to_new_node.get(node.coords) is not None:
            continue
        
        parent_candidates = coords_to_old_parents[node.coords]
        # Picks the candidate with the *smallest* thiccness (first one on ties).
        # NOTE(review): confirm min (not max) is the intended parent criterion.
        proper_parent_thiccness = min([p['thiccness'] for p in parent_candidates])
        proper_parent = [p for p in parent_candidates if p['thiccness'] == proper_parent_thiccness][0]
        # First edge of that parent touching `node`; any parallel edges are dropped.
        edge_from_parent = [e for e in proper_parent.edges if e.node_a == node or e.node_b == node][0]
        
        new_node = node.copy_without_edges()
        new_parent = coords_to_new_node[proper_parent.coords]
        new_edge = Edge(new_parent, new_node)
        new_edge.data = edge_from_parent.data
        new_parent.add_edge(new_edge)
        
        coords_to_new_node[new_node.coords] = new_node
        
        # Register `node` as a parent candidate of every neighbour and queue them; already
        # copied neighbours (including `node`'s own parent) are skipped when popped.
        for neighbour in node.get_neighbours():
            parents = coords_to_old_parents.get(neighbour.coords, [])
            coords_to_old_parents[neighbour.coords] = parents + [node]
            queue.put(((-neighbour['thiccness'], counter), neighbour))
            counter += 1
            
    return new_root

In [16]:
%%time

# Turn the (possibly cyclic) graph into a tree rooted at the selected root.
clean_root = remove_dag_cycles(root)

CPU times: user 383 ms, sys: 360 µs, total: 383 ms
Wall time: 386 ms


### Removing redundant nodes and edges

In [17]:
def merge_edges(a, b, node_a, node_b):
    """Merge edge `a` and the following edge `b` into one edge from `node_a` to `node_b`.

    The merged edge keeps a shallow copy of `a`'s metadata. Its voxels are `a`'s voxels,
    then the voxels of the node joining the two edges (`b.node_a`), then `b`'s voxels.
    """
    new_edge = Edge(node_a, node_b)
    # Copy instead of aliasing: the 'voxels' assignment below must not overwrite the
    # voxels stored on `a` itself (and on any graph sharing `a`'s data dict).
    new_edge.data = dict(a.data)
    new_edge['voxels'] = np.concatenate([a['voxels'], b.node_a['voxels'], b['voxels']])
    return new_edge


def remove_dag_redundant_nodes(root):
    """Return a copy of the tree under `root` with pass-through nodes removed.

    A child with exactly one outgoing edge is dropped, and its incoming and outgoing
    edges are merged with `merge_edges`. Children are processed first, so whole chains
    collapse into single edges. Edges are expected to point from parent (`node_a`) to
    child (`node_b`).

    NOTE: recursive; trees deeper than Python's recursion limit will raise RecursionError.
    """
    new_root = root.copy_without_edges()
    for edge in root.edges:
        new_child = remove_dag_redundant_nodes(edge.node_b)

        if len(new_child.edges) == 1:
            # `new_child` only passes through: bypass it.
            only_edge = new_child.edges[0]
            new_root.add_edge(merge_edges(edge, only_edge, new_root, only_edge.node_b))
        else:
            new_edge = Edge(new_root, new_child)
            # Shallow copy so later per-edge updates on the cleaned tree do not leak
            # back into the source tree's edges.
            new_edge.data = dict(edge.data)
            new_root.add_edge(new_edge)

    return new_root

In [18]:
%%time

# Collapse chains of pass-through nodes into single edges.
clean_root = remove_dag_redundant_nodes(clean_root)

CPU times: user 44.4 ms, sys: 6.41 ms, total: 50.8 ms
Wall time: 48.8 ms


### Obtaining clean nodes and edges

In [19]:
def get_nodes_with_dfs(root):
    """Return all nodes of the tree under `root` in depth-first pre-order.

    Children are followed through `edge.node_b`, in the order of each node's `edges`.
    An edge whose `node_a` is not the node it is stored on is printed, since the
    traversal assumes parent -> child orientation.

    Iterative, so deep trees do not hit Python's recursion limit, and sub-lists are not
    re-concatenated at every level.
    """
    nodes = []
    stack = [root]
    while stack:
        node = stack.pop()
        nodes.append(node)
        for e in node.edges:
            if e.node_a != node:
                print(e)
        # Push children in reverse so the first edge's subtree is visited first.
        stack.extend(e.node_b for e in reversed(node.edges))
    return nodes


def get_edges_with_dfs(root):
    """Return all edges of the tree under `root` in depth-first pre-order.

    Each edge is listed before the edges of the subtree under its `node_b`; siblings
    keep the order of their parent's `edges`. Iterative, so deep trees do not hit
    Python's recursion limit.
    """
    edges = []
    stack = list(reversed(root.edges))
    while stack:
        e = stack.pop()
        edges.append(e)
        # Reverse so the first child edge is popped next.
        stack.extend(reversed(e.node_b.edges))
    return edges

In [20]:
clean_nodes = get_nodes_with_dfs(clean_root)
clean_edges = get_edges_with_dfs(clean_root)

# Sanity check: every node is reached exactly once, i.e. the cleaned graph is a tree.
assert len({id(n) for n in clean_nodes}) == len(clean_nodes), 'cleaned graph is not a tree'

print(f'# of nodes: {len(clean_nodes)}, # of edges: {len(clean_edges)}')

# of nodes: 5783, # of edges: 5782


## Populating graph with basic metadata

### Reordering edge voxels

In [21]:
def reorder_edges_voxels(edge):
    """Sort `edge['voxels']` by breadth-first distance from the edge's start node.

    The search starts at the first voxel of `edge.node_a`. It walks the
    26-neighbourhood, restricted to the node's voxels and the edge's voxels. The edge
    voxels are then stored back as tuples, in the order they were reached; node voxels
    are excluded from the result.

    NOTE(review): edge voxels not reachable from `edge.node_a` this way are dropped.
    """
    from collections import deque

    node_voxels = [tuple(voxel) for voxel in edge.node_a['voxels']]
    edge_voxels = [tuple(voxel) for voxel in edge['voxels']]
    # Sets keep every membership test O(1) instead of scanning lists.
    node_voxel_set = set(node_voxels)
    allowed = node_voxel_set | set(edge_voxels)

    start = node_voxels[0]
    sorted_voxels = [start]
    seen = {start}
    queue = deque([start])

    while queue:
        x, y, z = queue.popleft()

        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    if dx == dy == dz == 0:
                        continue

                    neighbour = (x + dx, y + dy, z + dz)
                    if neighbour in seen or neighbour not in allowed:
                        continue

                    seen.add(neighbour)
                    sorted_voxels.append(neighbour)
                    queue.append(neighbour)

    edge['voxels'] = [voxel for voxel in sorted_voxels if voxel not in node_voxel_set]
    
def fix_edges_voxels(root):
    """Reorder the voxels of every edge in the tree under `root`, in place."""
    for edge in get_edges_with_dfs(root):
        reorder_edges_voxels(edge)

In [22]:
%%time
fix_edges_voxels(clean_root)
# Spot-check: the root's voxels, then the first edge's voxels (now ordered starting next to the root).
print(clean_root['voxels'])
print(clean_root.edges[0]['voxels'])

[[531  77 333]]
[(530, 77, 333), (529, 77, 333), (528, 77, 333), (527, 78, 334), (526, 77, 334), (525, 77, 334), (524, 77, 334), (523, 77, 334), (522, 77, 334), (521, 77, 334), (520, 78, 334), (519, 78, 335), (518, 78, 335), (517, 78, 335), (516, 79, 335), (515, 79, 335), (514, 78, 336), (513, 79, 336), (512, 78, 337), (511, 78, 338), (510, 79, 337), (509, 78, 338), (508, 79, 338), (507, 78, 339), (506, 79, 339), (505, 79, 339), (504, 79, 339), (503, 79, 339), (502, 79, 340), (501, 80, 340), (500, 80, 341), (499, 80, 342), (498, 80, 342), (497, 80, 342), (496, 80, 343), (495, 81, 343), (494, 82, 343)]
CPU times: user 5.83 s, sys: 5.18 ms, total: 5.84 s
Wall time: 5.85 s


### Edge and node thiccness

In [23]:
def fix_nodes_thiccness(root, thiccness_map):
    """Set each node's 'thiccness' to the mean of `thiccness_map` over the node's voxels."""
    for node in get_nodes_with_dfs(root):
        samples = [thiccness_map[tuple(voxel)] for voxel in node['voxels']]
        node['thiccness'] = np.mean(samples)
    

def add_edges_thiccness(root, thiccness_map):
    """Store per-voxel thiccness samples ('thiccness_list') and their mean ('mean_thiccness') on each edge."""
    for edge in get_edges_with_dfs(root):
        samples = np.array([thiccness_map[tuple(voxel)] for voxel in edge['voxels']])
        edge['thiccness_list'] = samples
        edge['mean_thiccness'] = np.mean(samples)

In [24]:
%%time
fix_nodes_thiccness(clean_root, thiccness_map)
add_edges_thiccness(clean_root, thiccness_map)
# Display the per-voxel thiccness samples of the root's first edge.
clean_root.edges[0]['thiccness_list']

CPU times: user 341 ms, sys: 12.1 ms, total: 353 ms
Wall time: 354 ms


array([19., 19., 19., 19., 19., 19., 19., 19., 19., 19., 19., 19., 19.,
       19., 19., 19., 19., 19., 19., 19., 19., 19., 19., 19., 19., 19.,
       19., 19., 19., 19., 19., 19., 19., 19., 19., 19., 19.])

### Centroids and edge lengths

In [25]:
def set_nodes_centroids(root):
    """Store each node's mean voxel coordinate under 'centroid'."""
    for node in get_nodes_with_dfs(root):
        node['centroid'] = np.asarray(node['voxels']).mean(axis=0)

        
def calculate_edge_length(edge, chunk_length=1):
    """Approximate the length of `edge` as a polyline.

    Edge voxels (assumed ordered, e.g. by `reorder_edges_voxels`) are grouped into
    consecutive chunks of `chunk_length`; the last chunk may be shorter. The polyline
    runs from the start node's centroid, through each chunk's mean position, to the end
    node's centroid.

    Args:
        edge: object with `edge['voxels']` and `edge.node_a['centroid']` /
            `edge.node_b['centroid']`.
        chunk_length: number of voxels averaged per polyline vertex (>= 1).

    Returns:
        Total Euclidean length of the polyline, in voxel units.

    Raises:
        ValueError: if `chunk_length` < 1.
    """
    if chunk_length < 1:
        raise ValueError(f'chunk_length must be >= 1, got {chunk_length}')

    # reshape(-1, 3) keeps an empty voxel list as a (0, 3) array rather than shape (0,),
    # which would break the concatenation below.
    voxels = np.asarray(edge['voxels'], dtype=float).reshape(-1, 3)
    needed_nans = (-len(voxels)) % chunk_length
    padded = np.concatenate([voxels, np.full((needed_nans, 3), np.nan)])

    # NaN padding lets nanmean average a short final chunk over its real voxels only.
    edge_centroids = np.nanmean(padded.reshape(-1, chunk_length, 3), axis=1)

    points = np.vstack([
        np.asarray(edge.node_a['centroid'], dtype=float),
        edge_centroids,
        np.asarray(edge.node_b['centroid'], dtype=float),
    ])
    lengths = np.linalg.norm(np.diff(points, axis=0), axis=1)
    return np.sum(lengths)
    

def set_edges_length(root, chunk_length=1):
    """Store each edge's polyline length (from `calculate_edge_length`) under 'length'."""
    for edge in get_edges_with_dfs(root):
        edge['length'] = calculate_edge_length(edge, chunk_length)

In [26]:
%%time
set_nodes_centroids(clean_root)
# Must run after the centroids: each edge polyline starts and ends at node centroids.
set_edges_length(clean_root, 2)
print(clean_root.edges[0]['length'])

42.32046023354219
CPU times: user 907 ms, sys: 11.4 ms, total: 918 ms
Wall time: 910 ms


## Creating DAG object

In [27]:
class DAG:
    """Cleaned tree bundle: the root, flattened node/edge lists, volume shape and metadata.

    Extra metadata is read and written with `dag[key]`.
    """

    def __init__(self, root, volume_shape):
        self.root = root
        self.nodes = get_nodes_with_dfs(root)
        self.edges = get_edges_with_dfs(root)
        self.volume_shape = volume_shape
        self.data = dict()

    def __setitem__(self, key, value):
        self.data.__setitem__(key, value)

    def __getitem__(self, key):
        return self.data.__getitem__(key)
        

def save_dag(dag, filename):
    """Pickle `dag` into `filename`, overwriting any existing file."""
    with open(filename, 'wb') as handle:
        pickle.dump(dag, handle)
        

def load_dag(filename):
    """Load and return a DAG pickled by `save_dag`.

    WARNING: unpickling can execute arbitrary code; only load files you created or trust.
    """
    with open(filename, 'rb') as source:
        return pickle.load(source)

In [28]:
# Bundle the cleaned tree with the volume shape so it can be saved and drawn later.
dag = DAG(clean_root, thiccness_map.shape)

## DAG visualization

In [29]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return `morphology.ball(outer_radius)`.

    If `filled` is False, the ball is hollowed out so that only a shell of `thickness`
    voxels (capped at `outer_radius`) remains.
    """
    sphere = morphology.ball(radius=outer_radius)
    if filled:
        return sphere

    inner_radius = outer_radius - min(thickness, outer_radius)
    inner = morphology.ball(radius=inner_radius)

    # The inner ball is centred inside the outer one.
    offset = outer_radius - inner_radius
    window = slice(offset, offset + inner.shape[0])
    sphere[window, window, window] -= inner
    return sphere


def print_kernels(image, nodes, value):
    """Return a copy of `image` with a solid ball of `value` stamped at every node.

    Each ball is centred on `node.coords` and has radius `int(node['thiccness'])`.

    Args:
        image: 3-D array; not modified.
        nodes: iterable of nodes with `coords` and a 'thiccness' entry.
        value: value written into the covered voxels.
    """
    image = image.copy()
    nodes = list(nodes)
    if not nodes:
        return image

    max_kernel_radius = int(max(node['thiccness'] for node in nodes))
    kernels = [spherical_kernel(radius) for radius in range(max_kernel_radius + 1)]

    # Pad by the largest radius so balls near the border fit without clipping logic.
    padded_shape = tuple(s + 2 * max_kernel_radius for s in image.shape)
    kernels_image = np.zeros(padded_shape, dtype=bool)

    for node in nodes:
        x, y, z = (int(coord) + max_kernel_radius for coord in node.coords)
        kernel_radius = int(node['thiccness'])
        kernel = kernels[kernel_radius]

        mask_slice = kernels_image[
            x - kernel_radius:x + kernel_radius + 1,
            y - kernel_radius:y + kernel_radius + 1,
            z - kernel_radius:z + kernel_radius + 1,
        ]
        mask_slice |= kernel.astype(bool)

    # Crop with explicit end indices: `[r:-r]` would be empty when r == 0.
    crop = tuple(slice(max_kernel_radius, max_kernel_radius + s) for s in image.shape)
    image[kernels_image[crop]] = value
    return image


def draw_nodes(image, nodes, value=2):
    """Return a copy of `image` with every node drawn as a ball of `value`."""
    return print_kernels(image, nodes, value)

    
def draw_edges(image, edges, value='mean_thiccness', interpolate=True):
    """Return a copy of `image` with every edge drawn into it.

    Args:
        image: array to draw on; not modified.
        edges: iterable of edges.
        value: either a metadata key (any str) whose per-edge value is drawn, or a
            constant fill value.
        interpolate: if True, draw a straight line between the two nodes' `coords`
            (via `line_nd`); otherwise paint the edge's own voxels.
    """
    image = image.copy()

    for edge in edges:
        # isinstance also accepts str subclasses such as numpy.str_.
        fill_value = edge[value] if isinstance(value, str) else value

        if interpolate:
            image[line_nd(edge.node_a.coords, edge.node_b.coords)] = fill_value
        else:
            for v in edge['voxels']:
                image[tuple(v)] = fill_value

    return image

def draw_central_line(image, dag):
    """Return a copy of `image` with all edge voxels and node voxels of `dag` set to 1."""
    drawn = draw_edges(image, dag.edges, value=1, interpolate=False)
    node_voxels = (tuple(v) for node in dag.nodes for v in node['voxels'])
    for voxel in node_voxels:
        drawn[voxel] = 1
    return drawn

In [30]:
# Nodes as balls of value 25, overlaid with edges coloured by their mean thiccness.
visualization = draw_edges(
    draw_nodes(np.zeros(skeleton.shape), dag.nodes, 25),
    dag.edges,
    value='mean_thiccness',
)
visualize_gradient(visualization)

In [31]:
# Edges coloured by their polyline length.
visualization = draw_edges(np.zeros(skeleton.shape), dag.edges, value='length')
visualize_gradient(visualization)

In [32]:
# Centre line of the DAG (value 1) plus skeleton voxels it does not cover (value 4).
central_line = draw_central_line(np.zeros_like(skeleton, dtype=np.float64), dag)
visualize_addition(central_line, skeleton)

## Saving the DAG

In [33]:
save_dag(dag, f'{source_dir}{TREE_NAME}/dag.pkl')