In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation, feature
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d, binary_dilation

from skimage import filters, morphology

from scipy.ndimage.filters import convolve, correlate
from scipy import signal

from skimage.filters import frangi, sato
from skimage.draw import line_nd

from PIL import Image
import pickle

In [2]:
source_dir = './data/P12/'
skeleton = np.load(source_dir + 'skeleton.npy')
thiccness_map = np.load(source_dir + 'thiccness-map.npy')

# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
skeleton = np.pad(skeleton, 1)
thiccness_map = np.pad(thiccness_map, 1)

In [3]:
def visualize_addition(base, base_with_addition):
    base = (base.copy() > 0).astype(np.uint8)
    addition = (base_with_addition > 0).astype(np.uint8)
    addition[base == 1] = 0
    ColorMapVisualizer(base + addition * 4).visualize()
    
def visualize_lsd(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize()
    
def visualize_gradient(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8), binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8) * 255, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    skeleton = skeletonize((mask > 0).astype(np.uint8))
    if not visualize_mask or visualize_both_versions:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if visualize_mask or visualize_both_versions:
        skeleton = skeleton.astype(np.uint8) * 4
        mask = (mask > 0).astype(np.uint8) * 3
        mask[skeleton != 0] = 0
        ColorMapVisualizer(skeleton + mask).visualize()

def visualize_ultimate(lsd, base_mask):
    visualize_lsd(lsd)
    visualize_mask_non_bin(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [4]:
def trim_skeleton(skeleton):   
    # TODO: fix wczesniejszych skeletonow
    padded_skeleton = np.pad(skeleton, 1)
    new_skeleton = np.zeros(padded_skeleton.shape)
    queue = [tuple(np.argwhere(padded_skeleton)[0])]
    
    while(len(queue) > 0):
        x, y, z = queue.pop(0)
        
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    if dx == dy == dz == 0:
                        continue
                    
                    neighbour_x = x + dx
                    neighbour_y = y + dy
                    neighbour_z = z + dz
                    if padded_skeleton[neighbour_x, neighbour_y, neighbour_z] == 0:
                        continue
                        
                    if new_skeleton[neighbour_x, neighbour_y, neighbour_z] == 0:
                        queue.append((neighbour_x, neighbour_y, neighbour_z))
                        new_skeleton[neighbour_x, neighbour_y, neighbour_z] += 1;
                        new_skeleton[x, y, z] += 1
                        
    return (new_skeleton[1:-1, 1:-1, 1:-1] > 1).astype(np.uint8)



def mark_bifurcation_regions(skeleton):
    
    padded_skeleton = np.pad(skeleton, 1)
    bifurcations_map = np.zeros(padded_skeleton.shape)
    
    for skeleton_voxel in np.argwhere(padded_skeleton > 0):
        x, y, z = tuple(skeleton_voxel)
        kernel_radius = 1
        kernel = np.ones((3, 3, 3))
        kernel[1, 1, 1] = 0
        
        skeleton_slice = padded_skeleton[
            x-kernel_radius:x+kernel_radius + 1,
            y-kernel_radius:y+kernel_radius + 1,
            z-kernel_radius:z+kernel_radius + 1
        ]
        
        intersections = (skeleton_slice > 0) * kernel
        bifurcations_map[x, y, z] = np.sum(intersections)
        
    return (bifurcations_map[1:-1, 1:-1, 1:-1] > 2).astype(np.uint8)

def mark_leaves(skeleton):
    trimmed = trim_skeleton(skeleton)
    leaves = skeleton - trimmed
    return leaves

def mark_nodes(skeleton):
    bifurcation_map = mark_bifurcation_regions(skeleton)
    leaves_map = mark_leaves(skeleton)
    return bifurcation_map + leaves_map

In [23]:
class Node:
    def __init__(self, coords, **kwargs):
        self.coords = coords
        self.data = kwargs
        self.edges = []
        
    def __hash__(self):
        return hash(self.coords)
        
    def add_edge(self, edge):
        self.edges.append(edge)
        
    def __repr__(self):
        return f'Node {str(self.coords)}'
    
    def get_neighbours(self):
        return [e.node_a if e.node_a != self else e.node_b for e in self.edges]
        
class Edge:
    def __init__(self, node_a, node_b, **kwargs):
        self.node_a = node_a
        self.node_b = node_b
        self.data = kwargs
        
    def __repr__(self):
        return f'Edge {self.node_a.coords} -> {self.node_b.coords}'

In [24]:
def construct_graph(skeleton, thiccness_map, root_min_thiccness):
    
    bifurcations_map = mark_bifurcation_regions(skeleton)
    leaves_map = mark_leaves(skeleton)
    nodes_map = bifurcations_map + leaves_map
    
    # 1 node object for each pixel of bifurcation
    labels = measure.label(nodes_map)
#     regionprops = measure.regionprops(labels)
    
#     for props in regionprops:
#         if props.label < 1:
#             continue
            
    print('regions found:', labels.max())
        
        
    nodes_per_label = dict()
    for i in range(1, np.max(labels) + 1):
        nodes_per_label[i] = Node((-1, -1, -1))
        
    nodes_dict = dict()
    for node_pixel_coords in np.argwhere(nodes_map):
        node_object = nodes_per_label[labels[tuple(node_pixel_coords)]]
        nodes_dict[tuple(node_pixel_coords)] = node_object
        if node_object.coords == (-1, -1, -1): # TODO nie rob tego tutaj, tylko potem ustal na srodek ciezkosci
            node_object.coords = tuple(node_pixel_coords)
            node_object.thiccness = thiccness_map[tuple(node_pixel_coords)]
        
    root_voxels = np.argwhere(leaves_map * (thiccness_map > root_min_thiccness))
    if root_voxels.shape[0] > 1:
        raise Exception("multiple potential roots found, consider thiccening the root")
        
    root_coords = tuple(root_voxels[0])
    queue = [(root_coords, root_coords)]
    
    edges = []
    
    visited = np.zeros(skeleton.shape, dtype=np.bool)
    visited[root_coords] = True;
    
    while(len(queue) > 0):
        coords, last_node_coords = queue.pop(0)
        x, y, z = coords
        
        if (nodes_map[x, y, z] == 1) and (nodes_dict[coords] != nodes_dict[last_node_coords]):
            last_node = nodes_dict[last_node_coords]
            current_node = nodes_dict[coords]
            edges.append(Edge(last_node, current_node))
            last_node_coords = coords
        
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    if dx == dy == dz == 0:
                        continue
                    
                    neighbour_x = x + dx
                    neighbour_y = y + dy
                    neighbour_z = z + dz
                    if visited[neighbour_x, neighbour_y, neighbour_z]:
                        continue
                        
                    if skeleton[neighbour_x, neighbour_y, neighbour_z] == 1:
                        queue.append(((neighbour_x, neighbour_y, neighbour_z), last_node_coords))
                        visited[neighbour_x, neighbour_y, neighbour_z] = True
                        
    return edges

In [25]:
def construct_graph_xd(skeleton, thiccness_map, root_min_thiccness):
    
    bifurcations_map = mark_bifurcation_regions(skeleton)
    leaves_map = mark_leaves(skeleton)
    nodes_map = ((bifurcations_map + leaves_map) > 0).astype(np.uint8)
    
    nodes_labels = measure.label(nodes_map)
    nodes_props = measure.regionprops(nodes_labels)
    print('regions found:', nodes_labels.max())
    
    voxel_to_node = dict()
    
    for props in nodes_props:
        if props.label < 1:
            continue
            
        node = Node(coords=tuple(props.coords[0]))
        node.voxels = props.coords
        node.label = props.label
        node.thiccness = thiccness_map[tuple(props.coords[0])]
        
        for c in props.coords:
            voxel_to_node[tuple(c)] = node
            
    edges_mask = skeleton - nodes_map
    edges_labels = measure.label(edges_mask > 0)
    print('edges found:', edges_labels.max())
    
    visited = np.zeros(skeleton.shape, dtype=np.bool)
    
    def find_touching_nodes(source_voxel):
        touching_nodes = set()
        queue = [source_voxel]
        
        while len(queue) > 0:
            x, y, z = queue.pop(0)
            
            for dx in [-1, 0, 1]:
                for dy in [-1, 0, 1]:
                    for dz in [-1, 0, 1]:
                        if dx == dy == dz == 0:
                            continue

                        neighbour_x = x + dx
                        neighbour_y = y + dy
                        neighbour_z = z + dz
                        if visited[neighbour_x, neighbour_y, neighbour_z]:
                            continue
                            
                        potential_node = voxel_to_node.get((neighbour_x, neighbour_y, neighbour_z))
                        if potential_node is not None:
                            touching_nodes.add(potential_node)

                        if edges_mask[neighbour_x, neighbour_y, neighbour_z] == 1:
                            queue.append((neighbour_x, neighbour_y, neighbour_z))
                            visited[neighbour_x, neighbour_y, neighbour_z] = True
                            
        return list(touching_nodes)
       
        
    edges_props = measure.regionprops(edges_labels)
    
    edges = []
    for props in edges_props:
        edge_voxel = props.coords[0]
        touching_nodes = find_touching_nodes(edge_voxel)
        if len(touching_nodes) != 2:
            print(f'something went wrong! touching nodes count: {len(touching_nodes)}')
            continue
            
        edge = Edge(touching_nodes[0], touching_nodes[1])
        edge.voxels = props.coords
        edges.append(edge)
        
    return edges

In [26]:
%%time
edges = construct_graph_xd(skeleton, thiccness_map, root_min_thiccness=6)

regions found: 2300
edges found: 2338
something went wrong! touching nodes count: 1
something went wrong! touching nodes count: 1
something went wrong! touching nodes count: 1
something went wrong! touching nodes count: 1
something went wrong! touching nodes count: 1
something went wrong! touching nodes count: 1
something went wrong! touching nodes count: 1
CPU times: user 8.18 s, sys: 913 ms, total: 9.1 s
Wall time: 9.21 s


In [27]:
edges_dict = {}
for e in edges:
    def universal_hash(node_a, node_b):
        hash_a = hash(node_a)
        hash_b = hash(node_b)
        if hash_a < hash_b:
            return hash((hash_a, hash_b))
        
        return hash((hash_b, hash_a))
        
    current_val = edges_dict.get(universal_hash(e.node_a, e.node_b), 0)
    edges_dict[universal_hash(e.node_a, e.node_b)] = current_val + 1

In [28]:
np_counters = np.array(list(edges_dict.values()))
np.sum(np_counters == 2)

4

## graph visualisation

In [29]:
def draw_graph(edges, shape='auto', lsd=False):
    if shape == 'auto':
        max_x, max_y, max_z = (0, 0, 0)
        for i, edge in enumerate(edges):
            max_x, max_y, max_z = tuple(np.maximum([max_x, max_y, max_z], edge.node_a.coords))
            max_x, max_y, max_z = tuple(np.maximum([max_x, max_y, max_z], edge.node_b.coords))
        shape = (max_x + 1, max_y + 1, max_z + 1)

    graph = np.zeros(shape)

    for i, edge in enumerate(edges):
        graph[line_nd(edge.node_a.coords, edge.node_b.coords)] = i
        
    return graph if lsd else graph > 0

In [30]:
graph = draw_graph(edges, shape = skeleton.shape)

In [13]:
visualize_mask_bin(graph)

In [31]:
def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    labeled = measure.label(binary_mask, connectivity=connectivity)
    region_props = measure.regionprops(labeled)
    
    main_regions_masks = []
    regions_labels = []
    bounding_boxes = []
    
    for props in region_props:
        if props.area >= min_size:
            main_regions_masks.append(props.filled_image)
            regions_labels.append(props.label)
            bounding_boxes.append(props.bbox)
            
    return main_regions_masks, regions_labels, bounding_boxes

volume = np.pad(load_volume('./data/P12/P12_60um_1333x443x864.raw', scale=0.5), 1)
volume = volume > 70
main_regions = get_main_regions(volume)
mask_main = np.pad(main_regions[0][0].astype(np.uint8), 1)

In [32]:
visualize_addition(graph, mask_main)

## graph cleanup

In [35]:
def remove_redundant_edges(edges):
    def universal_hash(node_a, node_b):
        hash_a = hash(node_a)
        hash_b = hash(node_b)
        if hash_a < hash_b:
            return hash((hash_a, hash_b))

        return hash((hash_b, hash_a))

    edges_dict = {}
    
    for e in edges:
        edge_hash = universal_hash(e.node_a, e.node_b)
        edges_dict[edge_hash] = e
        
    return list(edges_dict.values())
        
        
def convert_to_nodes_list(edges):
    nodes = {}
    for e in edges:
        nodes[e.node_a] = e.node_a
        nodes[e.node_b] = e.node_b
        
    for e in edges:
        nodes[e.node_a].add_edge(e)
        nodes[e.node_b].add_edge(e)
        
    return nodes
        

def find_nearest_proper_node(source, edge):
    neighbours = len(source.edges)
    if neighbours != 2:
        return source
    
    new_edge = source.edges[0] if source.edges[0] != edge else source.edges[1]
    new_node = new_edge.node_a if new_edge.node_a != source else new_edge.node_b
    return find_nearest_proper_node(new_node, new_edge)


def remove_redundant_nodes(root, parent=None):
    
    new_root = root.copy()
    new_root.neighbours = []
    
    queue = [root]
    while len(queue) > 0:
        node = queue.pop(0)
    
        for e in node.edges:
            neighbour = e.node_a if e.node_a != node else e.node_b
            proper_neighbour = find_nearest_proper_node(neighbour, e)

            if proper_neighbour in node.neighbours:
                continue
                
            proper_node = proper_neighbour.copy()
            proper_node.neighbours = [node]
            node.neighbours.append(proper_node)
            
    return new_root
        

In [34]:
nodes = convert_to_nodes_list(edges)

In [36]:
root = None
for n in nodes:
    if len(n.edges) == 1:
        root = n
        break
        
new_root = remove_redundant_nodes(root)

AttributeError: 'Node' object has no attribute 'copy'

In [None]:
2324