In [53]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation, feature
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d, binary_dilation

from skimage import filters, morphology

from scipy.ndimage.filters import convolve, correlate
from scipy import signal

from skimage.filters import frangi, sato
from skimage.draw import line_nd

from PIL import Image
import pickle

In [4]:
source_dir = './data/P12/'
skeleton = np.load(source_dir + 'skeleton.npy')
thiccness_map = np.load(source_dir + 'thiccness-map.npy')

# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
skeleton = np.pad(skeleton, 1)
thiccness_map = np.pad(thiccness_map, 1)

In [45]:
def visualize_addition(base, base_with_addition):
    base = (base.copy() > 0).astype(np.uint8)
    addition = (base_with_addition > 0).astype(np.uint8)
    addition[base == 1] = 0
    ColorMapVisualizer(base + addition * 4).visualize()
    
def visualize_lsd(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize()
    
def visualize_gradient(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8), binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8) * 255, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    skeleton = skeletonize((mask > 0).astype(np.uint8))
    if not visualize_mask or visualize_both_versions:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if visualize_mask or visualize_both_versions:
        skeleton = skeleton.astype(np.uint8) * 4
        mask = (mask > 0).astype(np.uint8) * 3
        mask[skeleton != 0] = 0
        ColorMapVisualizer(skeleton + mask).visualize()

def visualize_ultimate(lsd, base_mask):
    visualize_lsd(lsd)
    visualize_mask_non_bin(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [5]:
def trim_skeleton(skeleton):   
    padded_skeleton = np.pad(skeleton, 1)
    new_skeleton = np.zeros(padded_skeleton.shape)
    queue = [tuple(np.argwhere(padded_skeleton)[0])]
    new_skeleton[queue[0]] == -1;
    
    while(len(queue) > 0):
        x, y, z = queue.pop(0)
        
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    if dx == dy == dz == 0:
                        continue
                    
                    neighbour_x = x + dx
                    neighbour_y = y + dy
                    neighbour_z = z + dz
                    if padded_skeleton[neighbour_x, neighbour_y, neighbour_z] == 0:
                        continue
                        
                    if new_skeleton[neighbour_x, neighbour_y, neighbour_z] == 0:
                        queue.append((neighbour_x, neighbour_y, neighbour_z))
                        new_skeleton[neighbour_x, neighbour_y, neighbour_z] = 2;
                        new_skeleton[x, y, z] = 1
                        
    return (new_skeleton[1:-1, 1:-1, 1:-1] == 1).astype(np.uint8)

def mark_bifurcation_regions(skeleton):
    
    padded_skeleton = np.pad(skeleton, 1)
    bifurcations_map = np.zeros(padded_skeleton.shape)
    
    for skeleton_voxel in np.argwhere(padded_skeleton > 0):
        x, y, z = tuple(skeleton_voxel)
        kernel_radius = 1
        kernel = np.ones((3, 3, 3))
        kernel[1, 1, 1] = 0
        
        skeleton_slice = padded_skeleton[
            x-kernel_radius:x+kernel_radius + 1,
            y-kernel_radius:y+kernel_radius + 1,
            z-kernel_radius:z+kernel_radius + 1
        ]
        
        intersections = (skeleton_slice > 0) * kernel
        bifurcations_map[x, y, z] = np.sum(intersections)
        
    return (bifurcations_map[1:-1, 1:-1, 1:-1] > 2).astype(np.uint8)

def mark_leaves(skeleton):
    trimmed = trim_skeleton(skeleton)
    leaves = skeleton - trimmed
    return leaves

def mark_nodes(skeleton):
    bifurcation_map = mark_bifurcation_regions(skeleton)
    leaves_map = mark_leaves(skeleton)
    return bifurcation_map + leaves_map

In [6]:
class Node:
    def __init__(self, coords, **kwargs):
        self.coords = coords
        self.data = kwargs
        self.edges = []
        
    def add_edge(self, edge):
        self.edges.append(edge)
        
    def __repr__(self):
        return f'Node {str(self.coords)}'
        
class Edge:
    def __init__(self, node_a, node_b, **kwargs):
        self.node_a = node_a
        self.node_b = node_b
        self.data = kwargs
        
    def __repr__(self):
        return f'Edge {self.node_a.coords} -> {self.node_b.coords}'

In [41]:
def construct_graph(skeleton, thiccness_map, nodes_dict, root_min_thiccness):
    
    bifurcations_map = mark_bifurcation_regions(skeleton)
    leaves_map = mark_leaves(skeleton)
    nodes_map = bifurcations_map + leaves_map
    
    # 1 node object for each pixel of bifurcation
    labels = measure.label(nodes_map)
    nodes_per_label = dict()
    for i in range(1, np.max(labels) + 1):
        nodes_per_label[i] = Node((-1, -1, -1))
    nodes_dict = dict()
    for node_pixel_coords in np.argwhere(nodes_map):
        node_object = nodes_per_label[labels[tuple(node_pixel_coords)]]
        nodes_dict[tuple(node_pixel_coords)] = node_object
        if node_object.coords == (-1, -1, -1): # TODO nie rob tego tutaj, tylko potem ustal na srodek ciezkosci
            node_object.coords = tuple(node_pixel_coords)
            node_object.thiccness = thiccness_map[tuple(node_pixel_coords)]
        
    root_voxels = np.argwhere(leaves_map * (thiccness_map > root_min_thiccness))
    if root_voxels.shape[0] > 1:
        raise Exception("multiple potential roots found, consider thiccening the root")
        
    root_coords = tuple(root_voxels[0])
    queue = [(root_coords, root_coords)]
    
    edges = []
    
    visited = np.zeros(skeleton.shape, dtype=np.bool)
    visited[root_coords] = True;
    
    while(len(queue) > 0):
        coords, last_node_coords = queue.pop(0)
        x, y, z = coords
        
        if (nodes_map[x, y, z] == 1) and (nodes_dict[coords] != nodes_dict[last_node_coords]):
            last_node = nodes_dict[last_node_coords]
            current_node = nodes_dict[coords]
            edges.append(Edge(last_node, current_node))
            last_node_coords = coords
        
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    if dx == dy == dz == 0:
                        continue
                    
                    neighbour_x = x + dx
                    neighbour_y = y + dy
                    neighbour_z = z + dz
                    if visited[neighbour_x, neighbour_y, neighbour_z]:
                        continue
                        
                    if skeleton[neighbour_x, neighbour_y, neighbour_z] == 1:
                        queue.append(((neighbour_x, neighbour_y, neighbour_z), last_node_coords))
                        visited[neighbour_x, neighbour_y, neighbour_z] = True
                        
    return edges

In [42]:
%%time
edges = construct_graph(skeleton, thiccness_map, nodes_dict, root_min_thiccness=6)

CPU times: user 10.3 s, sys: 616 ms, total: 10.9 s
Wall time: 10.9 s


In [43]:
edges

[Edge (395, 79, 279) -> (382, 90, 282),
 Edge (382, 90, 282) -> (316, 90, 353),
 Edge (382, 90, 282) -> (305, 122, 243),
 Edge (316, 90, 353) -> (320, 106, 364),
 Edge (320, 106, 364) -> (317, 110, 370),
 Edge (305, 122, 243) -> (295, 120, 228),
 Edge (316, 90, 353) -> (293, 94, 368),
 Edge (317, 110, 370) -> (316, 114, 367),
 Edge (317, 110, 370) -> (319, 117, 378),
 Edge (316, 114, 367) -> (317, 121, 369),
 Edge (293, 94, 368) -> (284, 102, 361),
 Edge (316, 114, 367) -> (307, 117, 362),
 Edge (319, 117, 378) -> (317, 123, 385),
 Edge (317, 121, 369) -> (326, 125, 362),
 Edge (319, 117, 378) -> (325, 131, 368),
 Edge (320, 106, 364) -> (348, 115, 343),
 Edge (307, 117, 362) -> (321, 116, 357),
 Edge (307, 117, 362) -> (294, 116, 360),
 Edge (307, 117, 362) -> (309, 117, 348),
 Edge (284, 102, 361) -> (273, 117, 363),
 Edge (348, 115, 343) -> (349, 111, 343),
 Edge (293, 94, 368) -> (268, 99, 391),
 Edge (309, 117, 348) -> (308, 119, 343),
 Edge (325, 131, 368) -> (318, 139, 364),
 Ed

In [44]:
len(edges)

2324

## graph visualisation

In [76]:
def draw_graph(edges, shape='auto', lsd=False):
    if shape == 'auto':
        max_x, max_y, max_z = (0, 0, 0)
        for i, edge in enumerate(edges):
            max_x, max_y, max_z = tuple(np.maximum([max_x, max_y, max_z], edge.node_a.coords))
            max_x, max_y, max_z = tuple(np.maximum([max_x, max_y, max_z], edge.node_b.coords))
        shape = (max_x + 1, max_y + 1, max_z + 1)

    graph = np.zeros(shape)

    for i, edge in enumerate(edges):
        graph[line_nd(edge.node_a.coords, edge.node_b.coords)] = i
        
    return graph if lsd else graph > 0

In [79]:
graph = draw_graph(edges, shape = skeleton.shape)

In [80]:
visualize_mask_bin(graph)

In [82]:
volume = np.pad(load_volume('./data/P12/P12_60um_1333x443x864.raw', scale=0.5), 1)
volume = volume > 70

In [86]:
graph = draw_graph(edges, shape = volume.shape)

In [88]:
visualize_addition(volume, graph)

In [84]:
volume.shape

(434, 224, 668)

In [85]:
skeleton.shape

(413, 199, 649)