In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation, feature
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d, binary_dilation

from skimage import filters, morphology

from scipy.ndimage.filters import convolve, correlate
from scipy import signal

from skimage.filters import frangi, sato

from PIL import Image
import pickle

In [2]:
source_dir = './data/*/'
# Every .raw volume found one directory below ./data, in sorted path order.
files = sorted(glob.glob(f'{source_dir}*.raw'))
list(enumerate(files))

[(0, './data/P01/P01_60um_1612x623x1108.raw'),
 (1, './data/P02/P02_60um_1387x778x1149.raw'),
 (2, './data/P03/P03_60um_1473x1163x1148.raw'),
 (3, './data/P04/P04_60um_1273x466x1045.raw'),
 (4, './data/P05/P05_60um_1454x817x1102.raw'),
 (5, './data/P06/P06_60um_1425x564x1028.raw'),
 (6, './data/P07/P7_60um_1216x692x926.raw'),
 (7, './data/P08/P08_60um_1728x927x1149.raw'),
 (8, './data/P09/P09_60um_1359x456x1040.raw'),
 (9, './data/P10/P10_60um_1339x537x1035.raw'),
 (10, './data/P11/P11_60um_1735x595x1150.raw'),
 (11, './data/P12/P12_60um_1333x443x864.raw'),
 (12, './data/P13/P13_60um_1132x488x877.raw')]

In [7]:
%%time
# Load one volume, passing scale=0.5 to the project helper load_volume.
# NOTE(review): files[1] is ./data/P02/..., but the recorded output shape
# (432, 222, 666) looks like P12 (1333x443x864) halved with axes reversed, and
# every cache loaded/saved below lives in ./data/P12. This cell was probably
# edited after it ran -- confirm which volume is intended (files[11] is P12).
mask = load_volume(files[1], scale=0.5)
mask.shape

CPU times: user 1.33 s, sys: 399 ms, total: 1.73 s
Wall time: 2.52 s


(432, 222, 666)

In [9]:
threshold = 70
# Binarise the intensity volume. The dtype guard makes this cell idempotent:
# re-running it on an already-boolean mask would otherwise compare True/False
# against 70 and silently produce an all-False mask.
if mask.dtype != bool:
    mask = mask > threshold
VolumeVisualizer(mask, binary=True).visualize()

## Utility functions

### Visualisation functions

In [3]:
def visualize_addition(base, base_with_addition):
    """Render ``base`` together with what ``base_with_addition`` adds to it.

    Voxels non-zero in ``base`` get value 1; voxels non-zero only in
    ``base_with_addition`` get value 4; everything else is 0. The resulting
    ``uint8`` volume is shown with ``ColorMapVisualizer``.
    """
    in_base = base > 0
    only_in_addition = (base_with_addition > 0) & ~in_base
    labels = in_base.astype(np.uint8) + only_in_addition.astype(np.uint8) * 4
    ColorMapVisualizer(labels).visualize()
    
def visualize_lsd(lsd_mask):
    """Show ``lsd_mask`` cast to ``uint8`` (values truncated/wrapped by the cast) as a colour map."""
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize()
    
def visualize_gradient(lsd_mask):
    """Like ``visualize_lsd`` but asks ``ColorMapVisualizer`` for its gradient rendering."""
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Show the non-zero voxels of ``mask`` as a 0/1 volume in binary mode."""
    binary_volume = np.greater(mask, 0).astype(np.uint8)
    VolumeVisualizer(binary_volume, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Show the non-zero voxels of ``mask`` as a 0/255 ``uint8`` volume in non-binary mode."""
    intensity_volume = np.where(mask > 0, 255, 0).astype(np.uint8)
    VolumeVisualizer(intensity_volume, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonize the non-zero voxels of ``mask`` and display the result.

    Args:
        mask: volume whose non-zero voxels are skeletonized.
        visualize_mask: when True, show the skeleton (value 4) overlaid on the
            mask (value 3) with ``ColorMapVisualizer``; when False, show the
            bare skeleton with ``VolumeVisualizer``.
        visualize_both_versions: show both renderings regardless of
            ``visualize_mask``.
    """
    skeleton_voxels = skeletonize((mask > 0).astype(np.uint8))
    show_bare_skeleton = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_bare_skeleton:
        VolumeVisualizer(skeleton_voxels, binary=True).visualize()

    if show_overlay:
        skeleton_layer = skeleton_voxels.astype(np.uint8) * 4
        mask_layer = (mask > 0).astype(np.uint8) * 3
        # Skeleton voxels take precedence over the mask underneath them.
        mask_layer[skeleton_layer != 0] = 0
        ColorMapVisualizer(skeleton_layer + mask_layer).visualize()

def visualize_ultimate(lsd, base_mask):
    """Run all visualisations for ``lsd``: colour map, non-binary mask,
    addition over ``base_mask`` and skeleton overlay, in that order."""
    for render in (visualize_lsd, visualize_mask_non_bin):
        render(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [4]:
def load_lsd_trees(filename):
    """Unpickle and return the object stored in ``filename``.

    Security: ``pickle.load`` can execute arbitrary code -- only load files
    from a trusted source.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

def save_lsd_trees(lsd_trees, filename):
    """Pickle ``lsd_trees`` into ``filename``, overwriting any existing file."""
    with open(filename, mode='wb') as handle:
        pickle.dump(lsd_trees, handle)

In [20]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a ball structuring element, optionally hollowed into a shell.

    Args:
        outer_radius: radius passed to ``morphology.ball``.
        thickness: shell thickness used when ``filled`` is False; clipped to
            ``outer_radius``.
        filled: when True, return the full ball and ignore ``thickness``.

    Returns:
        ``morphology.ball(outer_radius)``; when ``filled`` is False, the voxels
        of a concentric ball of radius ``outer_radius - thickness`` are
        subtracted from it so only the outer shell remains.
    """
    ball = morphology.ball(radius=outer_radius)
    if filled:
        return ball

    inner_radius = outer_radius - min(thickness, outer_radius)
    inner_ball = morphology.ball(radius=inner_radius)

    # The inner ball has half-width inner_radius, so centring it inside the
    # outer one means starting (outer_radius - inner_radius) voxels in.
    offset = outer_radius - inner_radius
    window = slice(offset, offset + inner_ball.shape[0])
    ball[window, window, window] -= inner_ball
    return ball

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True, out_dtype=np.float32):
    """Convolve ``img`` with a filled ball of radius ``ball_radius``.

    With ``normalize=True`` the result is the fraction of each voxel's ball
    neighbourhood that is non-zero in ``img`` (for a 0/1 ``img``).

    Args:
        img: 3-D array (the ball kernel from ``spherical_kernel`` is 3-D).
        ball_radius: radius of the ball kernel.
        dtype: accumulation dtype for the convolution. Integer dtypes are
            promoted automatically when the largest possible sum would not fit.
        normalize: divide by the ball volume and return fractions.
        out_dtype: dtype of the normalized result. Defaults to float32: with
            float16, ``1 - 1/N`` rounds to exactly 1.0 once the ball has more
            than 4096 voxels (radius around 10 and up), so "one voxel missing"
            would pass a ``>= 1`` fill threshold.

    Returns:
        The raw convolution (``normalize=False``) or the fill fractions cast
        to ``out_dtype``; same shape as ``img`` (``mode='same'``).
    """
    kernel = spherical_kernel(ball_radius, filled=True)

    # Guard against silent integer overflow (e.g. uint16 once the ball holds
    # more than 65535 voxels of value 1).
    if np.issubdtype(dtype, np.integer) and np.size(img) > 0:
        worst_case = int(kernel.sum()) * int(np.max(img))
        if worst_case > np.iinfo(dtype).max:
            dtype = np.promote_types(dtype, np.min_scalar_type(worst_case))

    convolved = signal.convolve(img.astype(dtype), kernel.astype(dtype), mode='same')

    if not normalize:
        return convolved

    return (convolved / kernel.sum()).astype(out_dtype)


def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Collect connected components of ``binary_mask`` with at least ``min_size`` voxels.

    Components come from ``measure.label`` and are reported in the order
    ``measure.regionprops`` yields them -- not sorted by size.

    Returns:
        Three parallel lists: each component's ``filled_image`` (cropped to
        its bounding box, holes filled), its label, and its ``bbox``.
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)
    large_regions = [
        region for region in measure.regionprops(labeled) if region.area >= min_size
    ]

    main_regions_masks = [region.filled_image for region in large_regions]
    regions_labels = [region.label for region in large_regions]
    bounding_boxes = [region.bbox for region in large_regions]
    return main_regions_masks, regions_labels, bounding_boxes

def get_largest_region(binary_mask, connectivity=3):
    """Keep only the largest connected component of ``binary_mask``.

    Args:
        binary_mask: array whose non-zero voxels are labelled by ``measure.label``.
        connectivity: passed to ``measure.label``.

    Returns:
        ``uint8`` array of the same shape: 1 on the largest component (ties go
        to the lowest label), 0 elsewhere. All zeros if there is no foreground.
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)

    # Voxel count per label; index 0 is background and must never win.
    region_sizes = np.bincount(labeled.ravel())
    if region_sizes.size <= 1:
        # No foreground at all. Previously this fell through to
        # ``labeled == 0`` and returned the *background* as the largest region.
        return np.zeros(labeled.shape, dtype=np.uint8)

    region_sizes[0] = 0
    largest_label = int(np.argmax(region_sizes))  # first maximum -> lowest label
    return (labeled == largest_label).astype(np.uint8)

In [6]:
def annihilate_jemiolas_faster(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    """Grow ``mask`` by filling voxels whose ball neighbourhood is mostly full.

    Works on a ``uint8`` copy of ``mask``; the input is not modified. In each
    of ``iters`` passes, every ``kernel_size`` is tried in the given order.
    Voxels whose ball fill fraction exceeds ``fill_threshold`` are set to 1 in
    the working mask, so later kernels and passes see the grown mask. They are
    also recorded as ``kernel_size + 1`` in that pass's map; later kernels
    overwrite earlier ones.

    Returns:
        List with one ``uint8`` kernel-size map per pass.
    """
    working_mask = mask.astype(np.uint8)
    kernel_sizes_maps = []

    for i in range(iters):
        size_map = np.zeros(working_mask.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill = convolve_with_ball(working_mask, kernel_size, dtype=conv_dtype, normalize=True)
            mostly_full = fill > fill_threshold

            size_map[mostly_full] = kernel_size + 1
            working_mask[mostly_full] = 1

            print(f'Iteration {i + 1} kernel {kernel_size} done')

        kernel_sizes_maps.append(size_map)
        print(f'Iteration {i + 1} ended successfully')

    return kernel_sizes_maps

def onionize(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.8, conv_dtype=np.uint16):
    """Label each voxel with the largest ball size whose neighbourhood is full enough.

    For every ``kernel_size`` (processed in ascending order), voxels whose
    ball fill fraction is ``>= fill_threshold`` are set to ``kernel_size + 1``.
    Larger sizes overwrite smaller ones, so each voxel ends up with the
    largest passing size plus one, or 0 if no size passes. ``mask`` itself is
    not grown.

    Returns:
        ``uint8`` array of the same shape as ``mask``.
    """
    binary = mask.astype(np.uint8)
    kernel_size_map = np.zeros(binary.shape, dtype=np.uint8)

    for kernel_size in sorted(kernel_sizes):
        fill = convolve_with_ball(binary, kernel_size, dtype=conv_dtype, normalize=True)
        kernel_size_map[fill >= fill_threshold] = kernel_size + 1
        print(f'Kernel {kernel_size} done')

    return kernel_size_map

## Main region extraction

In [15]:
# First component with >= 10_000 voxels (get_main_regions' default min_size),
# in regionprops order -- not necessarily the largest one. filled_image is
# cropped to that component's bounding box, so mask_main is smaller than mask.
# Raises IndexError if no component is large enough.
mask_main = get_main_regions(mask)[0][0].astype(np.uint8)
# VolumeVisualizer(mask_main, binary=True).visualize()
# VolumeVisualizer(skeletonize_3d(mask_main.astype(np.uint8)), binary=True).visualize()

In [7]:
# Presumably the pickled result of the commented-out call below (one
# kernel-size map per iteration).
# NOTE(review): this cache is from ./data/P12 while `mask` was loaded from
# files[1] (./data/P02/...) -- confirm the intended volume. pickle can run
# arbitrary code on load; only load trusted files.
# lsd_trees = annihilate_jemiolas_faster(mask_main, kernel_sizes=range(0, 13), iters=3)
lsd_trees = load_lsd_trees('./data/P12/reconstructions')

In [16]:
# Skeleton overlaid on the last-iteration map.
visualize_skeleton(lsd_trees[-1])

In [15]:
# Binarise the last iteration's kernel-size map into a 0/1 reconstruction.
reconstruction = (lsd_trees[-1] > 0).astype(np.uint8)
# visualize_mask_bin(reconstruction)

In [16]:
# Raw 3-D skeleton of the reconstruction; trimmed and repaired below.
skeleton = skeletonize_3d(reconstruction)

In [8]:
%%time
# Cached onion map; the commented-out onionize call presumably produced it.
# NOTE(review): the later save cell writes 'onion-fully-filled.npy', not
# 'onion.npy' -- confirm which cache is meant.
#onion = onionize(reconstruction, kernel_sizes=range(12), fill_threshold=1)
onion = np.load('./data/P12/onion.npy')

CPU times: user 0 ns, sys: 220 ms, total: 220 ms
Wall time: 454 ms


In [9]:
# Onion (kernel-size) map as a colour map.
visualize_lsd(onion)

In [10]:
# np.save('./data/P12/onion-fully-filled', onion)

## Skeleton fixing

In [11]:
def iters_wrapper(func):
    """Decorator that lets ``func`` be applied repeatedly via an ``iters`` keyword.

    ``wrapped(data, *args, iters=k, **kwargs)`` feeds each result back in as
    the new ``data``, passing the same extra arguments each time, for ``k``
    applications in total. ``iters`` values below 1 still apply ``func`` once.
    """
    import functools

    # wraps() keeps the decorated function's __name__ and __doc__.
    @functools.wraps(func)
    def inner(data, *args, iters=1, **kwargs):
        result = func(data, *args, **kwargs)
        for _ in range(iters - 1):
            result = func(result, *args, **kwargs)
        return result

    return inner

### Trimming leaves (removing skeleton end voxels)

In [12]:
@iters_wrapper
def trim_skeleton(skeleton):
    """Remove one layer of BFS-tree leaves from a 3-D skeleton.

    A breadth-first search over 26-connected skeleton voxels starts at the
    first non-zero voxel in ``np.argwhere`` order. A visited voxel is kept
    only if it discovered at least one not-yet-visited neighbour. Voxels that
    discovered nothing are leaves of the BFS tree and are dropped.

    NOTE(review): only the connected component containing the start voxel is
    traversed, so every other component is removed entirely. The start voxel
    is kept whenever it has a neighbour, even if it is a skeleton endpoint.

    The ``iters_wrapper`` decorator adds an ``iters`` keyword:
    ``trim_skeleton(s, iters=k)`` applies this ``k`` times.

    Args:
        skeleton: 3-D array; non-zero voxels belong to the skeleton.

    Returns:
        ``uint8`` array of the same shape with 1 on kept voxels. An empty
        skeleton yields an all-zero array instead of raising ``IndexError``.
    """
    from collections import deque

    padded_skeleton = np.pad(skeleton, 1)
    # Visit state: 0 unvisited, -1 start, 2 discovered, 1 discovered a neighbour (kept).
    new_skeleton = np.zeros(padded_skeleton.shape, dtype=np.int8)

    skeleton_voxels = np.argwhere(padded_skeleton)
    if len(skeleton_voxels) == 0:
        # Nothing left (e.g. after many iterations).
        return np.zeros(np.shape(skeleton), dtype=np.uint8)

    start = tuple(skeleton_voxels[0])
    # Mark the start as visited (the original used `==` here, a no-op comparison).
    new_skeleton[start] = -1
    # deque.popleft is O(1); list.pop(0) made the search quadratic.
    queue = deque([start])

    while queue:
        x, y, z = queue.popleft()

        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    if dx == dy == dz == 0:
                        continue

                    neighbour = (x + dx, y + dy, z + dz)
                    if padded_skeleton[neighbour] == 0:
                        continue

                    if new_skeleton[neighbour] == 0:
                        queue.append(neighbour)
                        new_skeleton[neighbour] = 2
                        new_skeleton[x, y, z] = 1

    return (new_skeleton[1:-1, 1:-1, 1:-1] == 1).astype(np.uint8)

## Thickness map

In [17]:
def propagate_thiccness(skeleton, kernel_size_map):
    """Spread skeleton thickness values to the surrounding voxels.

    Each skeleton voxel starts with its own ``kernel_size_map`` value. A
    multi-source breadth-first search (26-connectivity) then copies values to
    unvisited neighbours, entering only voxels where ``kernel_size_map`` is
    non-zero. Every reachable voxel receives the value of whichever skeleton
    voxel's search front reached it first.

    Args:
        skeleton: 3-D array; non-zero voxels are the BFS sources.
        kernel_size_map: 3-D array of the same shape (e.g. the onion map);
            voxels where it is 0 are never entered.

    Returns:
        float64 array of the input shape; voxels that are not reached stay 0.
    """
    from collections import deque

    padded_skeleton = np.pad(skeleton, 1)
    padded_kernels_map = np.pad(kernel_size_map, 1)

    thiccness_map = np.zeros(padded_kernels_map.shape)
    skeleton_voxels = padded_skeleton > 0
    thiccness_map[skeleton_voxels] = padded_kernels_map[skeleton_voxels]

    # Track visits explicitly. The original used ``thiccness_map > 0`` as the
    # visited flag. That flag never becomes true for a propagated value of 0,
    # so a skeleton voxel with thickness 0 made the search loop forever.
    visited = skeleton_voxels.copy()

    offsets = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
        if (dx, dy, dz) != (0, 0, 0)
    ]

    # deque.popleft is O(1); list.pop(0) made the search quadratic.
    queue = deque(tuple(coords) for coords in np.argwhere(padded_skeleton))
    while queue:
        x, y, z = queue.popleft()
        thiccness = thiccness_map[x, y, z]

        for dx, dy, dz in offsets:
            neighbour = (x + dx, y + dy, z + dz)
            if visited[neighbour] or padded_kernels_map[neighbour] == 0:
                continue

            visited[neighbour] = True
            thiccness_map[neighbour] = thiccness
            queue.append(neighbour)

    return thiccness_map[1:-1, 1:-1, 1:-1]

In [18]:
%%time
# Strip 8 layers of BFS leaves from the raw skeleton.
# NOTE(review): iters=8 is a magic number; consider a named config constant.
trimmed_skeleton = trim_skeleton(skeleton, iters=8)

CPU times: user 27.6 s, sys: 2.22 s, total: 29.8 s
Wall time: 30.4 s


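Thickness is computed from the trimmed skeleton: each trimmed-skeleton voxel takes its onion value, and that value is propagated outward (breadth-first) through the non-zero voxels of the onion map.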

In [23]:
%%time
# Spread each trimmed-skeleton voxel's onion value through the non-zero onion voxels.
thicc_map = propagate_thiccness(trimmed_skeleton, onion)
visualize_gradient(thicc_map)

CPU times: user 57.5 s, sys: 1.1 s, total: 58.6 s
Wall time: 1min 11s


In [33]:
def make_ends_meet(skeleton, trimmed_skeleton, thiccness_map, ends_max_thiccness):
    """Re-attach thin trimmed-off ends to the trimmed skeleton and keep the largest piece.

    Voxels that are in ``skeleton`` but not in ``trimmed_skeleton``, and whose
    ``thiccness_map`` value is at most ``ends_max_thiccness``, are added back
    to ``trimmed_skeleton``. The largest connected component of the result
    (``get_largest_region`` with ``connectivity=3``) is returned.

    Returns:
        ``uint8`` array, 1 on the kept component.
    """
    skeleton_voxels = skeleton > 0
    trimmed_voxels = trimmed_skeleton > 0
    # Boolean set difference instead of ``skeleton - trimmed_skeleton``: the
    # subtraction underflows for uint8 inputs that are not 0/1 subsets and
    # raises TypeError when both inputs are boolean arrays.
    ends = skeleton_voxels & ~trimmed_voxels & (thiccness_map <= ends_max_thiccness)

    trimmed_with_ends = (trimmed_voxels | ends).astype(np.uint8)
    return get_largest_region(trimmed_with_ends, connectivity=3)

In [34]:
# Re-attach trimmed ends with thickness <= 4 and keep the largest component.
full_skeleton = make_ends_meet(skeleton, trimmed_skeleton, thicc_map, ends_max_thiccness=4)

In [35]:
# Value 4 highlights raw-skeleton voxels missing from full_skeleton.
visualize_addition(full_skeleton, skeleton)

In [32]:
# Thickness map cast to uint8 and shown as a colour map.
visualize_lsd(thicc_map)

## Bifurcation detection

In [48]:
def mark_bifurcation_regions(skeleton):
    """Mark skeleton voxels that have more than two skeleton neighbours.

    The neighbourhood is the full 3x3x3 cube around each voxel (26 neighbours
    in 3-D, excluding the voxel itself). Voxels outside the array count as
    background.

    Returns:
        ``uint8`` array of ``skeleton``'s shape: 1 where a skeleton voxel has
        three or more skeleton neighbours, 0 elsewhere.
    """
    from scipy import ndimage

    skeleton_voxels = np.asarray(skeleton) > 0

    # Count skeleton neighbours of every voxel in a single convolution instead
    # of slicing the array (and rebuilding the kernel) once per voxel.
    neighbour_kernel = np.ones((3,) * skeleton_voxels.ndim, dtype=np.int32)
    neighbour_kernel[(1,) * skeleton_voxels.ndim] = 0
    neighbour_counts = ndimage.convolve(
        skeleton_voxels.astype(np.int32), neighbour_kernel, mode='constant', cval=0
    )

    return (skeleton_voxels & (neighbour_counts > 2)).astype(np.uint8)

def mark_leaves(skeleton):
    """Return the skeleton voxels removed by a single ``trim_skeleton`` pass.

    NOTE(review): trim_skeleton only traverses one connected component, so on
    a multi-component skeleton every other component is reported here too.

    Returns:
        ``uint8`` array: 1 on skeleton voxels that
        ``trim_skeleton(skeleton, iters=1)`` drops, 0 elsewhere.
    """
    trimmed = trim_skeleton(skeleton, iters=1)
    # Boolean set difference: ``skeleton - trimmed`` raises for bool input and
    # yields non-0/1 values for non-binary integer skeletons.
    return ((skeleton > 0) & (trimmed == 0)).astype(np.uint8)

def print_kernels(mask, thiccness_map):
    """Stamp a filled ball at every non-zero voxel of ``mask``.

    At a voxel whose ``thiccness_map`` value is ``t`` the ball has radius
    ``int(t - 1)``. Voxels with ``t`` of 0 or less are skipped, since there is
    no ball for them. Overlapping balls are OR-ed together.

    Args:
        mask: 3-D array; non-zero voxels are stamp centres.
        thiccness_map: 3-D array of the same shape giving the local thickness.

    Returns:
        float64 array of ``mask``'s shape with 1.0 inside any stamped ball.
    """
    # Keep at least the radius-0 kernel so thickness values in (0, 1) still
    # stamp a single voxel, and so the padding below is never 0. With a pad
    # of 0, the old ``[0:-0]`` crop returned an empty array.
    max_kernel_radius = max(int(thiccness_map.max()), 1)
    kernels = [spherical_kernel(radius) for radius in range(max_kernel_radius)]

    padded_mask = np.pad(mask, max_kernel_radius)
    padded_thiccness_map = np.pad(thiccness_map, max_kernel_radius)
    kernels_image = np.zeros(padded_mask.shape)

    for x, y, z in np.argwhere(padded_mask > 0):
        kernel_radius = int(padded_thiccness_map[x, y, z] - 1)
        if kernel_radius < 0:
            # Previously kernels[-1] silently picked the largest ball and then
            # failed to broadcast into an empty slice.
            continue
        kernel = kernels[kernel_radius]

        mask_slice = kernels_image[
            x - kernel_radius:x + kernel_radius + 1,
            y - kernel_radius:y + kernel_radius + 1,
            z - kernel_radius:z + kernel_radius + 1
        ]
        mask_slice[:] = np.logical_or(mask_slice, kernel)

    # Explicit crop back to the unpadded shape.
    inner = tuple(
        slice(max_kernel_radius, max_kernel_radius + size) for size in np.shape(mask)
    )
    return kernels_image[inner]

In [49]:
%%time
# Candidate graph nodes of the fixed skeleton: bifurcation voxels (> 2
# neighbours) plus voxels removed by one trimming pass. A voxel in both maps
# gets value 2; later code only tests nodes_map > 0.
bifurcation_map = mark_bifurcation_regions(full_skeleton)
leaves_map = mark_leaves(full_skeleton)

nodes_map = bifurcation_map + leaves_map

CPU times: user 4.08 s, sys: 716 ms, total: 4.8 s
Wall time: 4.92 s


In [50]:
# Stamp a ball sized by the local thickness at every node voxel.
nodes_image = print_kernels(nodes_map, thicc_map)
# visualize_mask_bin(nodes_image)

In [52]:
# Node balls (value 1) against the rest of the reconstruction (value 4).
visualize_addition(nodes_image, reconstruction)

In [54]:
# Persist results next to the P12 inputs (np.save appends the '.npy' suffix).
np.save('./data/P12/thiccness-map', thicc_map)
np.save('./data/P12/skeleton', full_skeleton)