In [1]:
import functools
import glob
import pickle
from collections import deque
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy import signal
from scipy.ndimage import zoom
from scipy.ndimage.filters import convolve, correlate
from skimage import measure, segmentation, feature
from skimage import filters, morphology
from skimage.filters import frangi, sato
from skimage.morphology import skeletonize, skeletonize_3d, binary_dilation

from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer

In [2]:
# Which specimen to process; used as the key into `files`, `scales` and `thresholds` below.
TREE_NAME = 'P11'

In [3]:
source_dir = './data/'
# Map tree name -> raw volume path. The glob pattern is '<source_dir>/<tree>/<file>.raw',
# so the tree name is the file's parent directory. Taking `parent.name` (instead of a
# fixed index into path.split('/')) keeps this correct if source_dir gains or loses
# path components.
files = {Path(path).parent.name: path for path in sorted(glob.glob(source_dir + '*/*.raw'))}
files

{'P01': './data/P01/P01_60um_1612x623x1108.raw',
 'P04': './data/P04/P04_60um_1273x466x1045.raw',
 'P11': './data/P11/P11_60um_1735x595x1150.raw',
 'P12': './data/P12/P12_60um_1333x443x864.raw'}

In [4]:
# Per-tree `scale` argument for load_volume. The loaded P11 volume has shape
# (575, 298, 868), roughly half of the 1735x595x1150 encoded in its filename
# (axes in reverse order), so this presumably down-samples every axis — confirm
# in vis_utils.load_volume.
# NOTE(review): 'P01' is present in `files` but has no entry here, so
# TREE_NAME = 'P01' raises KeyError.
scales = {
    'P04': 0.5,
    'P11': 0.5,
    'P12': 0.5,
}

volume = load_volume(files[TREE_NAME], scale=scales[TREE_NAME])
volume.shape
# VolumeVisualizer(volume, binary=False).visualize()

(575, 298, 868)

In [5]:
# Per-tree intensity threshold: voxels strictly above it become foreground.
# NOTE(review): like `scales`, there is no 'P01' entry.
thresholds = {
    'P04': 30,
    'P11': 100, # a threshold of 40 captures the whole tree
    'P12': 70,
}

mask = volume > thresholds[TREE_NAME]
volume = None  # drop the raw volume so its memory can be reclaimed
# VolumeVisualizer(mask, binary=True).visualize()

## Utility functions

### Visualisation functions

In [6]:
def visualize_addition(base, base_with_addition):
    """Colour-map view of what `base_with_addition` adds on top of `base`.

    Non-zero voxels of `base` are drawn with value 1; non-zero voxels of
    `base_with_addition` that are not part of `base` are drawn with value 4.
    Both arrays must have the same shape (boolean-mask assignment below).
    """
    in_base = (base > 0).astype(np.uint8)
    added_only = (base_with_addition > 0).astype(np.uint8)
    added_only[in_base == 1] = 0
    ColorMapVisualizer(in_base + 4 * added_only).visualize()
    
def visualize_lsd(lsd_mask):
    """Render `lsd_mask`, cast to uint8, with the colour-map visualizer."""
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize()
    
def visualize_gradient(lsd_mask):
    """Render `lsd_mask`, cast to uint8, with the colour-map visualizer in gradient mode."""
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Show the non-zero voxels of `mask` as a 0/1 volume in binary mode."""
    foreground = (mask > 0).astype(np.uint8)
    VolumeVisualizer(foreground, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Show the non-zero voxels of `mask` as a 0/255 uint8 volume in non-binary mode."""
    foreground = (mask > 0).astype(np.uint8)
    VolumeVisualizer(foreground * 255, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonize the non-zero voxels of `mask` and display the result.

    Parameters
    ----------
    visualize_mask : bool
        If True, show the skeleton (value 4) overlaid on the mask (value 3);
        otherwise show the bare skeleton in binary mode.
    visualize_both_versions : bool
        If True, show both views regardless of `visualize_mask`.
    """
    skeleton = skeletonize((mask > 0).astype(np.uint8))
    show_bare = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_bare:
        VolumeVisualizer(skeleton, binary=True).visualize()

    if show_overlay:
        skeleton_layer = skeleton.astype(np.uint8) * 4
        mask_layer = (mask > 0).astype(np.uint8) * 3
        # Skeleton voxels take precedence over the mask colour.
        mask_layer[skeleton_layer != 0] = 0
        ColorMapVisualizer(skeleton_layer + mask_layer).visualize()

def visualize_ultimate(lsd, base_mask):
    """Show every diagnostic view of the map `lsd`, including its difference from `base_mask`."""
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

### Convolution utils

In [7]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a 3-D ball structuring element, optionally hollowed into a shell.

    Parameters
    ----------
    outer_radius : int
        Radius passed to ``morphology.ball``.
    thickness : int
        Shell width used when ``filled`` is False, capped at ``outer_radius``.
        Ignored when ``filled`` is True. With ``thickness >= outer_radius`` only
        the centre voxel is removed.
    filled : bool
        True returns the full ball. False subtracts a concentric ball of radius
        ``outer_radius - thickness``, which leaves a hollow shell.

    Returns
    -------
    The array produced by ``morphology.ball(outer_radius)``, modified in place
    when a shell is requested.
    """
    sphere = morphology.ball(radius=outer_radius)
    if not filled:
        shell_width = min(thickness, outer_radius)
        core = morphology.ball(radius=outer_radius - shell_width)
        # The core is centred inside the outer ball, offset by the shell width on each axis.
        lo, hi = shell_width, shell_width + core.shape[0]
        sphere[lo:hi, lo:hi, lo:hi] -= core
    return sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True, out_dtype=np.float32):
    """Sum (or fraction) of `img` inside a filled ball centred on every voxel.

    Parameters
    ----------
    img : array
        Volume to convolve; cast to `dtype` first.
    ball_radius : int
        Radius of the filled ``spherical_kernel`` used as the convolution kernel.
    dtype : integer dtype
        Working dtype of the convolution. It must hold ``kernel.sum() * img.max()``
        without overflow; uint16 suffices for a binary `img` up to radius ~24.
    normalize : bool
        If False, return the raw convolution (dtype `dtype`). If True, divide by
        the ball's voxel count, which gives the fill fraction in [0, 1] for a binary `img`.
    out_dtype : float dtype
        dtype of the normalized result. It defaults to float32 because float16
        rounds every fraction above about 1 - 2**-12 up to 1.0. For balls of more
        than 4096 voxels (radius >= 10), a ball missing a voxel would then look
        completely full, which breaks ``fill_threshold=1`` comparisons. Pass
        np.float16 to trade that precision for memory.
    """
    kernel = spherical_kernel(ball_radius, filled=True)
    convolved = signal.convolve(img.astype(dtype), kernel.astype(dtype), mode='same')

    if not normalize:
        return convolved

    return (convolved / kernel.sum()).astype(out_dtype)

## Main region extraction

In [8]:
def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Return the connected components of `binary_mask` with at least `min_size` voxels.

    Components are labelled with ``measure.label(..., connectivity=connectivity)``
    and kept in ``measure.regionprops`` order.

    Returns
    -------
    (masks, labels, bboxes) : three parallel lists. For each kept component they hold
    ``props.filled_image`` (the region with holes filled; per the scikit-image docs
    this image is cropped to the region's bounding box), ``props.label`` and
    ``props.bbox``.
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)
    large_regions = [props for props in measure.regionprops(labeled) if props.area >= min_size]

    main_regions_masks = [props.filled_image for props in large_regions]
    regions_labels = [props.label for props in large_regions]
    bounding_boxes = [props.bbox for props in large_regions]
    return main_regions_masks, regions_labels, bounding_boxes

# TODO: merge the main regions when more than one passes the size threshold

In [9]:
# Keep components of at least 10_000 voxels (get_main_regions' default min_size).
main_regions = get_main_regions(mask)
print('number of main regions:', len(main_regions[0])) # TODO: merge main regions if more than 1
# Only the first kept region is used. NOTE(review): filled_image is cropped to the
# region's bounding box, so mask_main presumably has a smaller shape than `mask` — confirm.
mask_main = main_regions[0][0].astype(np.uint8)

VolumeVisualizer(mask_main, binary=True).visualize()
# VolumeVisualizer(skeletonize_3d(mask_main.astype(np.uint8)), binary=True).visualize()

number of main regions: 1


## Reconstruction by filling holes

In [10]:
def annihilate_jemiolas(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    """Grow a binary mask by filling voxels surrounded by enough foreground.

    For each of `iters` passes, and for each radius in `kernel_sizes` in the given
    order, the foreground fraction inside a ball of that radius centred on every
    voxel is computed. Voxels whose fraction exceeds `fill_threshold` become
    foreground immediately, so later radii and later passes see the grown mask.
    The caller's array is not modified, because `astype` returns a copy.
    The default `kernel_sizes` list is only iterated, never mutated.

    Returns
    -------
    list of uint8 arrays, one per pass. Each array holds ``kernel_size + 1`` of the
    last radius (in iteration order) that exceeded the threshold at that voxel,
    and 0 where no radius did.
    """
    grown = mask.astype(np.uint8)
    per_pass_maps = []

    for i in range(iters):
        radius_map = np.zeros(grown.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fraction = convolve_with_ball(grown, kernel_size, dtype=conv_dtype, normalize=True)
            selected = fraction > fill_threshold
            radius_map[selected] = kernel_size + 1
            grown[selected] = 1
            print(f'Iteration {i + 1} kernel {kernel_size} done')

        per_pass_maps.append(radius_map)
        print(f'Iteration {i + 1} ended successfully')

    return per_pass_maps

In [11]:
# Hole filling is expensive, so its result is cached next to the raw volume.
# The cache is recomputed only when the file is missing, which keeps
# Restart & Run All working on a fresh checkout.
reconstruction_path = source_dir + TREE_NAME + '/reconstruction.npy'
if Path(reconstruction_path).exists():
    lsd_trees = np.load(reconstruction_path)
else:
    lsd_trees = annihilate_jemiolas(mask_main, kernel_sizes=range(0, 13), iters=3)
    np.save(reconstruction_path, lsd_trees)

In [12]:
# Binary reconstruction: the non-zero voxels of the last pass's kernel-size map.
reconstruction = (lsd_trees[-1] > 0).astype(np.uint8)
lsd_trees = None  # release the per-pass maps
# visualize_mask_bin(reconstruction)
visualize_skeleton(reconstruction)

## Onionization

In [13]:
def onionize(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.8, conv_dtype=np.uint16):
    """Label each voxel with (largest qualifying ball radius) + 1.

    Radii from `kernel_sizes` are tried in ascending order. A voxel is set to
    ``kernel_size + 1`` whenever the ball of that radius centred on it is at least
    `fill_threshold` full, so its final value comes from the largest qualifying
    radius. The mask itself is not grown. Voxels where no radius qualifies stay 0.
    When ``fill_threshold < 1``, the centre voxel need not itself be foreground.
    The default `kernel_sizes` list is only read, never mutated.

    Returns
    -------
    uint8 array with the shape of `mask`.
    """
    binary = mask.astype(np.uint8)
    radius_map = np.zeros(binary.shape, dtype=np.uint8)

    for kernel_size in sorted(kernel_sizes):
        fraction = convolve_with_ball(binary, kernel_size, dtype=conv_dtype, normalize=True)
        radius_map[fraction >= fill_threshold] = kernel_size + 1
        print(f'Kernel {kernel_size} done')

    return radius_map

In [14]:
%%time
# onion = onionize(reconstruction, kernel_sizes=range(13), fill_threshold=1)
# np.save(source_dir + TREE_NAME + '/onionization', onion)
# ===============================================================================
onion = np.load(source_dir + TREE_NAME + '/onionization.npy')

Kernel 0 done
Kernel 1 done
Kernel 2 done
Kernel 3 done
Kernel 4 done
Kernel 5 done
Kernel 6 done
Kernel 7 done
Kernel 8 done
Kernel 9 done
Kernel 10 done
Kernel 11 done
Kernel 12 done


FileNotFoundError: [Errno 2] No such file or directory: './data/P11/onionization'

In [17]:
# Colour the onion map by value (kernel_size + 1 of the largest full ball).
visualize_lsd(onion)

## Skeleton fixing

In [15]:
# 3-D skeleton of the reconstruction; the trimming below starts from it.
skeleton = skeletonize_3d(reconstruction)

### dumb stretching

In [11]:
def iters_wrapper(func):
    """Decorator adding an ``iters`` keyword that re-applies ``func`` to its own output.

    ``wrapped(data, *args, iters=n, **kwargs)`` calls ``func(data, *args, **kwargs)``
    once, then feeds the result back into ``func`` another ``n - 1`` times, passing
    the same extra ``*args``/``**kwargs`` each time. The first call always happens,
    so ``iters <= 1`` means a single application.

    ``functools.wraps`` preserves the wrapped function's ``__name__`` and
    ``__doc__``, so decorated functions keep their identity and help text.
    """
    @functools.wraps(func)
    def inner(data, *args, iters=1, **kwargs):
        result = func(data, *args, **kwargs)
        for _ in range(iters - 1):
            result = func(result, *args, **kwargs)
        return result

    return inner

In [12]:
@iters_wrapper
def stretch_skeleton(skeleton, kernel_size_map):
    """Move each skeleton voxel to the largest `kernel_size_map` value inside its ball.

    For a skeleton voxel with map value k, the map is inspected inside a filled
    ball of radius ``k - 1`` centred on the voxel. If the voxel already holds the
    maximum it stays put. Otherwise it moves to the first voxel, in np.argwhere
    order, that holds the maximum. The moved voxels are rasterised into a fresh
    float64 array with the shape of `skeleton`.

    `kernel_size_map` presumably encodes ``radius + 1`` per voxel, as produced by
    `onionize`. Decorated with `iters_wrapper`, so pass ``iters=n`` to repeat.
    """
    max_radius = int(kernel_size_map.max())
    padded_skeleton = np.pad(skeleton, max_radius)
    padded_kernel_map = np.pad(kernel_size_map, max_radius)

    # Radii used are 0..max_radius - 1. The extra entry guarantees that kernels[0]
    # exists even when the map is all zeros.
    kernels = [spherical_kernel(radius) for radius in range(max_radius + 1)]
    new_skeleton = np.zeros(padded_skeleton.shape)

    for voxel_coords in np.argwhere(padded_skeleton):
        x, y, z = (int(c) for c in voxel_coords)
        # Convert with int() so a uint8 map cannot wrap around. The clamp makes a
        # voxel with map value 0 use radius 0 instead of a bogus kernels[-1].
        kernel_radius = max(int(padded_kernel_map[x, y, z]) - 1, 0)
        kernel = kernels[kernel_radius]

        neighbours = padded_kernel_map[
            x - kernel_radius:x + kernel_radius + 1,
            y - kernel_radius:y + kernel_radius + 1,
            z - kernel_radius:z + kernel_radius + 1,
        ] * kernel

        local_max = neighbours.max()
        if local_max == neighbours[kernel_radius, kernel_radius, kernel_radius]:
            target_voxel = (x, y, z)
        else:
            dx, dy, dz = np.argwhere(neighbours == local_max)[0] - kernel_radius
            target_voxel = (x + dx, y + dy, z + dz)

        new_skeleton[target_voxel] = 1

    # Use explicit stop indices: a [r:-r] slice would be empty when r == 0.
    return new_skeleton[tuple(slice(max_radius, dim - max_radius) for dim in new_skeleton.shape)]

### eating leaves

In [13]:
@iters_wrapper
def trim_skeleton(skeleton):
    """Remove the leaves of a breadth-first spanning tree of the skeleton.

    The BFS starts at the first non-zero voxel in np.argwhere order and walks
    26-connected skeleton voxels. A voxel is kept only if it discovered at least
    one new voxel, i.e. it is an inner node of the BFS tree. Consequences:
      * the start voxel survives whenever it has any neighbour, even if it is an endpoint;
      * voxels not connected to the start voxel are dropped entirely.

    Decorated with `iters_wrapper`, so ``iters=n`` trims n layers of leaves.

    Returns
    -------
    uint8 0/1 array with the shape of `skeleton`. An empty skeleton gives all zeros.
    """
    padded_skeleton = np.pad(skeleton, 1)
    voxels = np.argwhere(padded_skeleton)
    if len(voxels) == 0:
        return np.zeros(np.shape(skeleton), dtype=np.uint8)

    # BFS state per voxel: 0 = unvisited, -1 = root, 2 = visited leaf (so far),
    # 1 = visited and discovered at least one child (kept).
    new_skeleton = np.zeros(padded_skeleton.shape)
    root = tuple(voxels[0])
    new_skeleton[root] = -1

    # Same neighbour order as a dx/dy/dz nested loop, so the BFS order is unchanged.
    offsets = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
        if (dx, dy, dz) != (0, 0, 0)
    ]

    queue = deque([root])
    while queue:
        x, y, z = queue.popleft()
        for dx, dy, dz in offsets:
            neighbour = (x + dx, y + dy, z + dz)
            if padded_skeleton[neighbour] == 0 or new_skeleton[neighbour] != 0:
                continue
            queue.append(neighbour)
            new_skeleton[neighbour] = 2
            new_skeleton[x, y, z] = 1

    return (new_skeleton[1:-1, 1:-1, 1:-1] == 1).astype(np.uint8)

## thiccness map

In [14]:
def propagate_thiccness(skeleton, kernel_size_map):
    """Spread skeleton thickness values over the non-zero region of `kernel_size_map`.

    Each skeleton voxel is seeded with its own `kernel_size_map` value. A
    multi-source BFS over 26-connected voxels with non-zero `kernel_size_map` then
    copies that value outward. Every reached voxel takes the value of whichever
    seed's wave reaches it first; seeds are processed in np.argwhere order.

    Returns
    -------
    float64 array with the shape of `skeleton`; 0 where nothing was reached.
    """
    padded_skeleton = np.pad(skeleton, 1)
    padded_kernels_map = np.pad(kernel_size_map, 1)

    seeds = padded_skeleton > 0
    thiccness_map = np.zeros(padded_kernels_map.shape)
    thiccness_map[seeds] = padded_kernels_map[seeds]

    # Keep an explicit visited mask. "thiccness_map > 0" cannot serve as the
    # visited flag, because a seed with thickness 0 would then re-queue voxels forever.
    visited = seeds.copy()

    offsets = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
        if (dx, dy, dz) != (0, 0, 0)
    ]

    # deque.popleft is O(1); list.pop(0) made the BFS quadratic in the queue length.
    queue = deque(tuple(coords) for coords in np.argwhere(seeds))
    while queue:
        x, y, z = queue.popleft()
        thiccness = thiccness_map[x, y, z]

        for dx, dy, dz in offsets:
            neighbour = (x + dx, y + dy, z + dz)
            if visited[neighbour] or padded_kernels_map[neighbour] == 0:
                continue
            visited[neighbour] = True
            thiccness_map[neighbour] = thiccness
            queue.append(neighbour)

    return thiccness_map[1:-1, 1:-1, 1:-1]

In [16]:
%%time
# Strip 8 layers of BFS-tree leaves (trim_skeleton re-applied 8 times via iters_wrapper).
trimmed_skeleton = trim_skeleton(skeleton, iters=8)

CPU times: user 25.7 s, sys: 1.78 s, total: 27.4 s
Wall time: 27.5 s


In [17]:
# Restore trimmed voxels where the structure is thin: onion < 5 means no ball of
# radius >= 4 was full there (onionize stores kernel_size + 1).
ends = (skeleton - trimmed_skeleton) * (onion < 5)
trimmed_full = trimmed_skeleton + ends

# Keep only the component labelled 1. NOTE(review): that is the first component in
# measure.label's scan order, not necessarily the largest — confirm this is intended.
labels = measure.label(trimmed_full)
trimmed_ultimate = (labels == 1).astype(np.uint8)

In [70]:
# Highlight (value 4) the skeleton voxels that did not survive trimming.
visualize_addition(trimmed_ultimate, skeleton)

In [18]:
%%time
# Spread each kept skeleton voxel's onion value over the onion region by BFS.
thicc_map = propagate_thiccness(trimmed_ultimate, onion)
visualize_gradient(thicc_map)

CPU times: user 1min 25s, sys: 1.27 s, total: 1min 26s
Wall time: 1min 46s


## Bifurcation detection

In [19]:
def mark_bifurcation_regions(skeleton, thiccness_map):
    """Count skeleton branches crossing a spherical shell around each skeleton voxel.

    For a skeleton voxel with thickness t, a hollow shell (``spherical_kernel`` with
    ``filled=False, thickness=2``) of outer radius ``t - 1`` (clamped at 0) is centred
    on the voxel. The number of connected components of skeleton ∩ shell is stored
    at that voxel. Presumably values > 2 indicate a branching point — confirm.

    `thiccness_map` may be integer (e.g. an onion map) or float (e.g. from
    `propagate_thiccness`).

    Returns
    -------
    float64 array with the shape of `skeleton`; 0 on non-skeleton voxels.
    """
    max_kernel_radius = int(thiccness_map.max())
    # Radii used are 0..max_kernel_radius - 1. The extra entry keeps kernels[0]
    # available when the map is all zeros.
    kernels = [
        spherical_kernel(radius, filled=False, thickness=2)
        for radius in range(max_kernel_radius + 1)
    ]

    padded_skeleton = np.pad(skeleton, max_kernel_radius)
    padded_thiccness_map = np.pad(thiccness_map, max_kernel_radius)
    bifurcations_map = np.zeros(padded_skeleton.shape)

    for skeleton_voxel in np.argwhere(padded_skeleton > 0):
        x, y, z = (int(c) for c in skeleton_voxel)
        # Float thickness values cannot index a list or bound a slice, so convert to
        # int. The clamp keeps a zero thickness from selecting kernels[-1].
        kernel_radius = max(int(padded_thiccness_map[x, y, z]) - 1, 0)
        kernel = kernels[kernel_radius]

        skeleton_slice = padded_skeleton[
            x - kernel_radius:x + kernel_radius + 1,
            y - kernel_radius:y + kernel_radius + 1,
            z - kernel_radius:z + kernel_radius + 1,
        ]

        intersections = (skeleton_slice > 0) * kernel
        # Labels are 1..n, so the maximum label equals the number of components.
        bifurcations_map[x, y, z] = np.max(measure.label(intersections))

    # Use explicit stop indices: a [r:-r] slice would be empty when r == 0.
    return bifurcations_map[
        tuple(slice(max_kernel_radius, dim - max_kernel_radius) for dim in bifurcations_map.shape)
    ]

def mark_bifurcation_regions2(skeleton):
    """Count each skeleton voxel's 26-connected skeleton neighbours.

    Voxels with more than two neighbours are candidate bifurcations; a single
    neighbour typically marks an endpoint. Non-skeleton voxels get 0.

    A single zero-padded 3x3x3 convolution does the counting (the centre weight is
    0, so a voxel does not count itself). This gives the same values as a
    per-voxel Python loop, without rebuilding the kernel for every voxel.

    Returns
    -------
    float64 array with the shape of `skeleton`.
    """
    occupied = np.asarray(skeleton) > 0

    neighbourhood = np.ones((3, 3, 3), dtype=np.uint8)
    neighbourhood[1, 1, 1] = 0

    # mode='constant', cval=0 matches padding the volume with a 1-voxel zero border.
    # The kernel is symmetric, so convolution equals correlation; counts (<= 26) fit in uint8.
    counts = convolve(occupied.astype(np.uint8), neighbourhood, mode='constant', cval=0)
    return np.where(occupied, counts, 0).astype(np.float64)

def print_kernels(mask, thiccness_map):
    """Stamp a filled ball onto every non-zero voxel of `mask`.

    The ball at voxel v has radius ``thiccness_map[v] - 1``, clamped at 0, and all
    balls are OR-ed together.

    Returns
    -------
    float64 array with the shape of `mask`: 1.0 inside the union of balls, 0.0 elsewhere.
    """
    max_kernel_radius = int(thiccness_map.max())
    # Radii used are 0..max_kernel_radius - 1. The extra entry keeps kernels[0]
    # available when the map is all zeros.
    kernels = [spherical_kernel(radius) for radius in range(max_kernel_radius + 1)]

    padded_mask = np.pad(mask, max_kernel_radius)
    padded_thiccness_map = np.pad(thiccness_map, max_kernel_radius)
    kernels_image = np.zeros(padded_mask.shape)

    for voxel in np.argwhere(padded_mask > 0):
        x, y, z = (int(c) for c in voxel)
        # The clamp makes a zero-thickness voxel stamp a single voxel. Without it,
        # kernels[-1] would be paired with an empty slice.
        kernel_radius = max(int(padded_thiccness_map[x, y, z] - 1), 0)
        kernel = kernels[kernel_radius]

        # Basic slicing returns a view, so the in-place OR writes into kernels_image.
        stamp_region = kernels_image[
            x - kernel_radius:x + kernel_radius + 1,
            y - kernel_radius:y + kernel_radius + 1,
            z - kernel_radius:z + kernel_radius + 1,
        ]
        stamp_region[:] = np.logical_or(stamp_region, kernel)

    # Use explicit stop indices: a [r:-r] slice would be empty when r == 0.
    return kernels_image[
        tuple(slice(max_kernel_radius, dim - max_kernel_radius) for dim in kernels_image.shape)
    ]

In [20]:
%%time
# Per skeleton voxel: number of 26-neighbours that are also skeleton voxels.
bifurcation_map = mark_bifurcation_regions2(trimmed_ultimate)

CPU times: user 1.03 s, sys: 105 ms, total: 1.14 s
Wall time: 1.16 s


In [148]:
# Overlay: skeleton voxels with more than 2 skeleton neighbours get +4 on top of the reconstruction.
visualize_lsd(reconstruction + (bifurcation_map > 2).astype(np.uint8) * 4)

In [21]:
# Expand each candidate bifurcation voxel into a filled ball sized by the local thickness.
bif_mask = (bifurcation_map > 2).astype(np.uint8)
bif_kernels_image = print_kernels(bif_mask, thicc_map)
visualize_mask_bin(bif_kernels_image)

In [22]:
# Highlight (value 4) the reconstruction voxels not covered by the bifurcation balls.
visualize_addition(bif_kernels_image, reconstruction)