In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
import pickle
from skimage import measure, segmentation, feature
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d, binary_dilation
from skimage import filters, morphology
from scipy.ndimage.filters import convolve, correlate
from scipy import signal
from skimage.filters import frangi, sato
from PIL import Image

In [2]:
TREE_NAME = 'P12'

## Utility visualisation functions

In [3]:
def visualize_addition(base, base_with_addition):
    base = (base.copy() > 0).astype(np.uint8)
    addition = (base_with_addition > 0).astype(np.uint8)
    addition[base == 1] = 0
    ColorMapVisualizer(base + addition * 4).visualize()
    
def visualize_lsd(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize()
    
def visualize_gradient(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8), binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8) * 255, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    skeleton = skeletonize((mask > 0).astype(np.uint8))
    if not visualize_mask or visualize_both_versions:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if visualize_mask or visualize_both_versions:
        skeleton = skeleton.astype(np.uint8) * 4
        mask = (mask > 0).astype(np.uint8) * 3
        mask[skeleton != 0] = 0
        ColorMapVisualizer(skeleton + mask).visualize()

def visualize_ultimate(lsd, base_mask):
    visualize_lsd(lsd)
    visualize_mask_non_bin(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

## Loading specimen reconstruction

In [4]:
source_dir = './data/'
reconstruction = np.load(source_dir + TREE_NAME + '/reconstruction.npy')
visualize_skeleton(reconstruction)

## Obtaining and trimming skeleton

In [5]:
%%time
skeleton = skeletonize_3d(reconstruction)

CPU times: user 3.31 s, sys: 28.4 ms, total: 3.34 s
Wall time: 3.35 s


In [6]:
def iters_wrapper(func): 
    def inner(data, *args, iters=1, **kwargs): 
        result = func(data, *args, **kwargs)
        print('iteration 1 done')
        for i in range(iters - 1):
            result = func(result, *args, **kwargs)  
            print(f'iteration {i + 2} done')
        return result
    return inner 

@iters_wrapper
def trim_skeleton(skeleton):   
    new_skeleton = np.zeros(skeleton.shape)
    skeleton_voxels = np.argwhere(skeleton)
    
    for voxel in skeleton_voxels:
        x, y, z = tuple(voxel)
        neighbours_count = 0
        
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    if dx == dy == dz == 0:
                        continue
                    
                    neighbour_x = x + dx
                    neighbour_y = y + dy
                    neighbour_z = z + dz
                    if skeleton[neighbour_x, neighbour_y, neighbour_z] > 0:
                        neighbours_count += 1
                        
        if neighbours_count > 1:
            new_skeleton[x, y, z] = 1
                        
    return new_skeleton.astype(np.uint8)

In [7]:
%%time
iterations = {
    'P01': 5,
    'P05': 25,
    'P12': 8,
}

trimmed_skeleton = trim_skeleton(skeleton, iters=iterations.get(TREE_NAME, 8))
visualize_addition(trimmed_skeleton, reconstruction)
visualize_addition(trimmed_skeleton, skeleton)

iteration 1 done
iteration 2 done
iteration 3 done
iteration 4 done
iteration 5 done
iteration 6 done
iteration 7 done
iteration 8 done
CPU times: user 29.1 s, sys: 2.19 s, total: 31.3 s
Wall time: 1min


## Propagating thiccness

In [8]:
def spherical_kernel(outer_radius, thickness=1, filled=True):    
    outer_sphere = morphology.ball(radius=outer_radius)
    if filled:
        return outer_sphere
    
    thickness = min(thickness, outer_radius)
    
    inner_radius = outer_radius - thickness
    inner_sphere = morphology.ball(radius=inner_radius)
    
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True):
    kernel = spherical_kernel(ball_radius, filled=True)
    convolved = signal.convolve(img.astype(dtype), kernel.astype(dtype), mode='same')
    
    if not normalize:
        return convolved
    
    return (convolved / kernel.sum()).astype(np.float16)

def onionize(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.8, conv_dtype=np.uint16):
    mask = mask.astype(np.uint8)
    kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)

    for kernel_size in sorted(kernel_sizes):
        fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True)
        above_threshold_fill_indices = fill_percentage >= fill_threshold
        kernel_size_map[above_threshold_fill_indices] = kernel_size + 1
        print(f'Kernel {kernel_size} done')

    return kernel_size_map

In [9]:
%%time
kernel_sizes = {
    'P01': range(19),
    'P05': range(21),
    'P12': range(12),
}

fill_thresholds = {
    'P01': 0.8,
    'P05': 0.8,
    'P12': 0.8,
}

onion = onionize(reconstruction, 
                 kernel_sizes=kernel_sizes.get(TREE_NAME, range(12)), 
                 fill_threshold=fill_thresholds.get(TREE_NAME, 0.8))
visualize_lsd(onion)

Kernel 0 done
Kernel 1 done
Kernel 2 done
Kernel 3 done
Kernel 4 done
Kernel 5 done
Kernel 6 done
Kernel 7 done
Kernel 8 done
Kernel 9 done
Kernel 10 done
Kernel 11 done
CPU times: user 1min, sys: 13 s, total: 1min 13s
Wall time: 1min 19s


In [10]:
def propagate_thiccness(skeleton, kernel_size_map):   
    thiccness_map = np.zeros(kernel_size_map.shape)
    thiccness_map[skeleton > 0] = kernel_size_map[skeleton > 0]
    
    queue = list([tuple(coords) for coords in np.argwhere(skeleton)])
    while(len(queue) > 0):
        x, y, z = queue.pop(0)
        thiccness = thiccness_map[x, y, z]
        
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    neighbour_x = x + dx
                    neighbour_y = y + dy
                    neighbour_z = z + dz
                    if thiccness_map[neighbour_x, neighbour_y, neighbour_z] > 0:
                        continue
                        
                    if kernel_size_map[neighbour_x, neighbour_y, neighbour_z] == 0:
                        continue
                        
                    thiccness_map[neighbour_x, neighbour_y, neighbour_z] = thiccness
                    queue.append((neighbour_x, neighbour_y, neighbour_z))
                        
    return thiccness_map

In [11]:
%%time
thiccness_map = propagate_thiccness(trimmed_skeleton, onion)
visualize_gradient(thiccness_map)

CPU times: user 41.2 s, sys: 565 ms, total: 41.8 s
Wall time: 45.2 s


## Adding leaves to the skeleton

In [12]:
def get_largest_region(binary_mask, connectivity=3):
    labeled = measure.label(binary_mask, connectivity=connectivity)
    region_props = measure.regionprops(labeled)
    
    max_region = 0
    max_region_label = 0
    
    for props in region_props:
        if props.area > max_region:
            max_region = props.area
            max_region_label = props.label

    return (labeled == max_region_label).astype(np.uint8)

def make_ends_meet(skeleton, trimmed_skeleton, thiccness_map, ends_max_thiccness):
    ends = (skeleton - trimmed_skeleton) * (thiccness_map <= ends_max_thiccness)
    ends = ends.astype(np.uint8)
    
    trimmed_with_ends = (trimmed_skeleton > 0).astype(np.uint8) + ends
    return get_largest_region(trimmed_with_ends, connectivity=3)

In [13]:
ends_max_thiccnesses = {
    'P01': 9,
    'P05': 12,
    'P12': 4,
}

full_skeleton = make_ends_meet(skeleton, trimmed_skeleton, thiccness_map, 
                               ends_max_thiccness=ends_max_thiccnesses.get(TREE_NAME, 4))
visualize_addition(full_skeleton, skeleton) # full_skeleton is green (1)

## Saving skeleton and thiccness map

In [14]:
np.save(source_dir + TREE_NAME + '/skeleton', full_skeleton)
np.save(source_dir + TREE_NAME + '/thiccness-map', thiccness_map)