In [1]:
import glob
import pickle
from collections import deque
from itertools import product

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy import signal
from scipy.ndimage import zoom
from scipy.ndimage.filters import convolve, correlate
from skimage import filters, morphology
from skimage import measure, segmentation, feature
from skimage.filters import frangi, sato
from skimage.morphology import skeletonize, skeletonize_3d

from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer

In [5]:
source_dir = './data/*'
# All .raw volumes in ./data, sorted by file name so the indices shown below are stable.
files = list(sorted(glob.glob(source_dir + '.raw')))
list(enumerate(files))

[(0, './data/P01_60um_1612x623x1108.raw'),
 (1, './data/P12_60um_1333x443x864.raw')]

In [7]:
%%time
# Load files[1] (the P12 volume in the listing above) with scale=0.5; the printed shape is
# roughly half the dimensions encoded in the file name, with the axes in reverse order.
mask = load_volume(files[1], scale=0.5)
mask.shape

CPU times: user 1.33 s, sys: 399 ms, total: 1.73 s
Wall time: 2.52 s


(432, 222, 666)

In [9]:
threshold = 70
# Binarise the loaded volume at `threshold`.
# NOTE: this overwrites `mask`, so the cell is not idempotent: re-running it compares the
# boolean mask against 70 and yields an all-False volume. Re-run the loading cell first.
mask = mask > threshold
VolumeVisualizer(mask, binary=True).visualize()

## Utility functions

### Visualisation functions

In [10]:
def visualize_addition(base, base_with_addition):
    """Show `base` as label 1 and voxels present only in `base_with_addition` as label 2.

    Both inputs are binarised with `> 0`; voxels already in `base` are never labelled 2.
    """
    base_binary = (base > 0).astype(np.uint8)
    only_in_addition = (base_with_addition > 0) & (base_binary == 0)
    labels = base_binary + only_in_addition.astype(np.uint8) * 2
    ColorMapVisualizer(labels).visualize()
    
def visualize_lsd(lsd_mask):
    """Display a label / kernel-size map with ColorMapVisualizer after casting it to uint8."""
    as_bytes = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_bytes).visualize()
    
def visualize_gradient(lsd_mask):
    """Like `visualize_lsd`, but asks ColorMapVisualizer for its gradient rendering."""
    as_bytes = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_bytes).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Render the non-zero voxels of `mask` as a 0/1 binary volume."""
    binary = np.greater(mask, 0).astype(np.uint8)
    VolumeVisualizer(binary, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Render the non-zero voxels of `mask` as a 0/255 uint8 volume (non-binary mode)."""
    intensity = np.where(mask > 0, 255, 0).astype(np.uint8)
    VolumeVisualizer(intensity, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonize the non-zero voxels of `mask` and display the result.

    Args:
        mask: array whose non-zero voxels are skeletonized.
        visualize_mask: if True, show the skeleton (label 4) overlaid on the mask
            (label 3); if False, show the bare skeleton as a binary volume.
        visualize_both_versions: show both the bare skeleton and the overlay.
    """
    skeleton = skeletonize((mask > 0).astype(np.uint8))
    show_bare = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_bare:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if not show_overlay:
        return

    skeleton_layer = skeleton.astype(np.uint8) * 4
    mask_layer = (mask > 0).astype(np.uint8) * 3
    # Skeleton voxels take precedence over the mask label.
    mask_layer[skeleton_layer != 0] = 0
    ColorMapVisualizer(mask_layer + skeleton_layer).visualize()

def visualize_ultimate(lsd, base_mask):
    """Show every view of a kernel-size map `lsd`, in order.

    The views are: the colour map, the non-binary volume, the voxels `lsd` adds on
    top of `base_mask`, and the skeleton overlaid on `lsd`.
    """
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [11]:
def load_lsd_trees(filename):
    """Unpickle and return the object stored in `filename`.

    Security: `pickle.load` can execute arbitrary code; only load files you created.
    """
    with open(filename, mode='rb') as handle:
        return pickle.load(handle)

def save_lsd_trees(lsd_trees, filename):
    """Pickle `lsd_trees` into `filename`, overwriting any existing file."""
    with open(filename, mode='wb') as handle:
        pickle.dump(lsd_trees, handle)

In [12]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a ball, or a spherical shell, of radius `outer_radius`.

    Args:
        outer_radius: radius passed to `morphology.ball`.
        thickness: shell thickness in voxels; only used when `filled` is False.
        filled: if True return the full ball; otherwise remove the centred inner
            ball of radius `outer_radius - thickness` from it.

    Returns:
        The array produced by `morphology.ball(outer_radius)`, with the inner-ball
        voxels set to 0 when a shell is requested.
    """
    outer_sphere = morphology.ball(radius=outer_radius)
    inner_radius = outer_radius - thickness
    # With thickness > outer_radius there is no inner ball and the shell is the full ball.
    # Previously a negative radius reached morphology.ball, and depending on the
    # scikit-image version the subtraction below could fail on mismatched shapes.
    if filled or inner_radius < 0:
        return outer_sphere

    inner_sphere = morphology.ball(radius=inner_radius)

    # Offset that centres the inner ball inside the outer one.
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True):
    """Convolve `img` with a filled ball of radius `ball_radius` (output has img's shape).

    Args:
        img: 3-D array (the ball from `spherical_kernel` is 3-D); cast to `dtype`.
        ball_radius: radius of the ball kernel.
        dtype: accumulator dtype. If it is an integer type that cannot hold the
            worst-case sum (max |img| times the ball volume), int64 is used instead
            so the result cannot silently wrap around.
        normalize: if True, divide by the ball volume and return float16. For a
            binary `img` this is the fraction of each voxel's ball inside the mask.

    Returns:
        The raw convolution (mode='same'), or the normalised float16 array.
    """
    kernel = spherical_kernel(ball_radius, filled=True)
    kernel_volume = int(kernel.sum())

    # Guard against integer overflow in the accumulator for large radii or non-binary input.
    if np.issubdtype(dtype, np.integer) and img.size > 0:
        img_peak = max(abs(int(img.max())), abs(int(img.min())))
        if img_peak * kernel_volume > np.iinfo(dtype).max:
            dtype = np.int64

    convolved = signal.convolve(img.astype(dtype), kernel.astype(dtype), mode='same')

    if not normalize:
        return convolved

    return (convolved / kernel_volume).astype(np.float16)


def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Return the connected components of `binary_mask` with at least `min_size` voxels.

    Components are labelled with `measure.label` at the given connectivity and
    visited in `measure.regionprops` order.

    Returns:
        A tuple `(masks, labels, bboxes)` of equal-length lists holding, per kept
        component, its `filled_image`, its label and its `bbox`. Per scikit-image,
        `filled_image` is hole-filled and cropped to the bounding box.
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)
    large_regions = [
        region for region in measure.regionprops(labeled) if region.area >= min_size
    ]
    return (
        [region.filled_image for region in large_regions],
        [region.label for region in large_regions],
        [region.bbox for region in large_regions],
    )

In [14]:
def annihilate_jemiolas_faster(mask, kernel_sizes=(10, 9, 8), fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    """Iteratively grow a mask by filling voxels whose surrounding ball is mostly full.

    In every iteration the ball radii are processed in the given order. For radius
    `k`, the fraction of each voxel's radius-`k` ball that lies inside the current
    mask is computed. Voxels with a fraction strictly above `fill_threshold` are
    added to the mask and recorded as `k + 1` in that iteration's map. The mask grows
    immediately, so later radii and later iterations see it, and later radii
    overwrite earlier map values.

    Args:
        mask: array; a uint8 copy is used, so the caller's array is not modified.
        kernel_sizes: iterable of ball radii, processed in order. A tuple default
            avoids the shared-mutable-default pitfall; lists and ranges still work.
        fill_threshold: exclusive lower bound on the fill fraction.
        iters: number of passes.
        conv_dtype: accumulator dtype forwarded to `convolve_with_ball`.

    Returns:
        List with one uint8 kernel-size map per iteration (0 = never above threshold).
        Map values are `k + 1` stored in uint8, so radii must stay below 255.
    """
    # Materialise once: a generator would otherwise be exhausted after the first pass.
    kernel_sizes = tuple(kernel_sizes)
    kernel_sizes_maps = []
    mask = mask.astype(np.uint8)

    for i in range(iters):
        kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True)

            above_threshold_fill_indices = fill_percentage > fill_threshold
            kernel_size_map[above_threshold_fill_indices] = kernel_size + 1

            # Grow the working mask right away.
            mask[above_threshold_fill_indices] = 1

            print(f'Iteration {i + 1} kernel {kernel_size} done')

        kernel_sizes_maps.append(kernel_size_map)
        print(f'Iteration {i + 1} ended successfully')

    return kernel_sizes_maps

def onionize(mask, kernel_sizes=(10, 9, 8), fill_threshold=0.8, conv_dtype=np.uint16):
    """Label each voxel with (largest ball radius whose fill fraction passes) + 1.

    For every radius `k` (ascending), voxels whose radius-`k` ball is at least
    `fill_threshold` inside `mask` are set to `k + 1`. Processing in ascending order
    and overwriting means each voxel ends with the largest passing radius. The map is
    not restricted to `mask`: a background voxel whose ball (radius >= 1) is full
    enough is labelled too.

    Args:
        mask: array; a uint8 copy is used, so the caller's array is not modified.
        kernel_sizes: iterable of ball radii. A tuple default avoids the
            shared-mutable-default pitfall; lists and ranges still work.
        fill_threshold: inclusive lower bound on the fill fraction.
        conv_dtype: accumulator dtype forwarded to `convolve_with_ball`.

    Returns:
        uint8 array of `mask`'s shape; 0 where no radius passed.
    """
    mask_u8 = mask.astype(np.uint8)
    kernel_size_map = np.zeros(mask_u8.shape, dtype=np.uint8)

    for kernel_size in sorted(kernel_sizes):
        fill_percentage = convolve_with_ball(mask_u8, kernel_size, dtype=conv_dtype, normalize=True)
        kernel_size_map[fill_percentage >= fill_threshold] = kernel_size + 1
        print(f'Kernel {kernel_size} done')

    return kernel_size_map

## Main region extraction

In [15]:
# First component (regionprops label order, not necessarily the largest) with at least
# 10_000 voxels, as a hole-filled mask cropped to its bounding box.
mask_main = get_main_regions(mask)[0][0].astype(np.uint8)
# VolumeVisualizer(mask_main, binary=True).visualize()
# VolumeVisualizer(skeletonize_3d(mask_main.astype(np.uint8)), binary=True).visualize()

In [17]:
# Three growth passes with ball radii 0..12; lsd_trees holds one kernel-size map per pass.
lsd_trees = annihilate_jemiolas_faster(mask_main, kernel_sizes=range(0, 13), iters=3)

Iteration 1 kernel 0 done
Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 ended successfully
Iteration 2 kernel 0 done
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 ended successfully
Iteration 3 kernel 0 done
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7 done
Iteration 3 kernel 8 done
Iteration 3 kernel 9 d

In [18]:
# Skeleton of the final pass overlaid on the filled tree.
visualize_skeleton(lsd_trees[-1])

In [19]:
# save_lsd_trees(lsd_trees, './filled_trees/P12/trees_3_iters.trees')

In [None]:
# lsd_trees = load_lsd_trees('./filled_trees/P12/trees_3_iters.trees')

In [20]:
# Binary reconstruction: every voxel labelled by the last filling pass.
reconstruction = (lsd_trees[-1] > 0).astype(np.uint8)
# Free the per-pass maps. NOTE: cells above that index lsd_trees[-1] fail after this runs.
lsd_trees = []
# visualize_mask_bin(reconstruction)

In [22]:
# Centreline of the reconstruction (scikit-image skeletonize_3d).
skeleton = skeletonize_3d(reconstruction)

In [23]:
%%time
# Thickness map: each voxel gets (largest radius in 0..11 whose ball is >= 80% full) + 1.
onion = onionize(reconstruction, kernel_sizes=range(12), fill_threshold=0.8)
# NOTE: the np.save cell further down writes ./onions/P12/onion.npy, not this path.
# onion = np.load('./data/P12/processed/onion.npy')

Kernel 0 done
Kernel 1 done
Kernel 2 done
Kernel 3 done
Kernel 4 done
Kernel 5 done
Kernel 6 done
Kernel 7 done
Kernel 8 done
Kernel 9 done
Kernel 10 done
Kernel 11 done
CPU times: user 1min 5s, sys: 15.3 s, total: 1min 20s
Wall time: 1min 21s


In [24]:
# Display the thickness map (cast to uint8) with ColorMapVisualizer.
visualize_lsd(onion)

In [25]:
# np.save('./onions/P12/onion', onion)

In [26]:
def correct_skeleton(skeleton, kernel_size_map):
    """Snap each skeleton voxel to the thickest voxel inside its own ball.

    For a skeleton voxel whose `kernel_size_map` value is `t`, the filled ball of
    radius `t - 1` centred on it is searched for the largest map value. If the voxel
    itself holds that maximum it stays; otherwise it moves to the first maximal voxel
    in row-major order. Several voxels may land on the same target.

    Args:
        skeleton: 3-D array; non-zero voxels are skeleton voxels.
        kernel_size_map: 3-D array of the same shape holding radius + 1 per voxel
            (e.g. the output of `onionize`; 0 = background). Float maps are accepted;
            values are truncated to int. A value of 0 is treated as radius 0, so that
            voxel stays where it is.

    Returns:
        Float array with the shape of `skeleton`: 1 at every target voxel, 0 elsewhere.
    """
    max_radius = int(kernel_size_map.max())
    padded_skeleton = np.pad(skeleton, max_radius)
    padded_kernel_map = np.pad(kernel_size_map, max_radius)

    # kernels[r] is the filled ball of radius r; radius 0 is always built so that
    # zero-valued voxels (or an all-zero map) are handled.
    kernels = [spherical_kernel(radius) for radius in range(max(max_radius, 1))]

    new_skeleton = np.zeros(padded_skeleton.shape)

    for x, y, z in np.argwhere(padded_skeleton):
        # int(): on a uint8 map `0 - 1` wraps to 255 (IndexError), and float values
        # cannot index a list.
        kernel_radius = max(int(padded_kernel_map[x, y, z]) - 1, 0)
        kernel = kernels[kernel_radius]

        kernel_diameter = 2 * kernel_radius + 1
        corner_x, corner_y, corner_z = x - kernel_radius, y - kernel_radius, z - kernel_radius
        kernel_map_slice = padded_kernel_map[
            corner_x:corner_x + kernel_diameter,
            corner_y:corner_y + kernel_diameter,
            corner_z:corner_z + kernel_diameter,
        ]

        # Map values restricted to the ball.
        neighbours = kernel_map_slice * kernel
        local_max = neighbours.max()

        if local_max == neighbours[kernel_radius, kernel_radius, kernel_radius]:
            target_voxel = (x, y, z)
        else:
            dx, dy, dz = np.argwhere(neighbours == local_max)[0] - kernel_radius
            target_voxel = (x + dx, y + dy, z + dz)

        new_skeleton[target_voxel] = 1

    # Explicit stop indices: `[max_radius:-max_radius]` is empty when max_radius == 0.
    unpad = tuple(slice(max_radius, max_radius + size) for size in skeleton.shape)
    return new_skeleton[unpad]


def propagate_thiccness(skeleton, kernel_size_map):
    """Spread skeleton-voxel thicknesses over the whole tree by breadth-first search.

    Every skeleton voxel with a positive `kernel_size_map` value seeds the search
    with that value. The search then floods 26-connected voxels whose map value is
    non-zero. Each reached voxel takes the thickness of the seed whose front arrives
    first, so the result depends on hop distance with ties broken by queue order.

    Args:
        skeleton: 3-D array; non-zero voxels are skeleton voxels.
        kernel_size_map: 3-D array of the same shape; non-zero voxels form the tree.

    Returns:
        Float array with the shape of `skeleton`. Voxels not reachable from any
        seed stay 0.
    """
    padded_skeleton = np.pad(skeleton, 1)
    padded_kernels_map = np.pad(kernel_size_map, 1)

    thiccness_map = np.zeros(padded_kernels_map.shape)
    # Only seed from skeleton voxels inside the tree. A seed with thickness 0 never
    # marks its neighbours as visited (they stay 0), so the search would re-enqueue
    # them forever.
    seeds = (padded_skeleton > 0) & (padded_kernels_map > 0)
    thiccness_map[seeds] = padded_kernels_map[seeds]

    # Same neighbour order as nested dx/dy/dz loops; (0, 0, 0) is always already visited.
    offsets = [offset for offset in product((-1, 0, 1), repeat=3) if offset != (0, 0, 0)]

    # deque.popleft() is O(1); list.pop(0) made the search quadratic in the voxel count.
    queue = deque(tuple(coords) for coords in np.argwhere(seeds))
    while queue:
        x, y, z = queue.popleft()
        thiccness = thiccness_map[x, y, z]

        for dx, dy, dz in offsets:
            neighbour = (x + dx, y + dy, z + dz)
            if thiccness_map[neighbour] > 0:
                continue
            # Zero padding guarantees the search never leaves the padded volume.
            if padded_kernels_map[neighbour] == 0:
                continue

            thiccness_map[neighbour] = thiccness
            queue.append(neighbour)

    return thiccness_map[1:-1, 1:-1, 1:-1]

In [27]:
%%time

# Five rounds of snapping skeleton voxels towards the locally thickest voxel of `onion`.
corrected_skeleton = skeleton
for i in range(5):
    corrected_skeleton = correct_skeleton(corrected_skeleton, onion)

CPU times: user 7.64 s, sys: 902 ms, total: 8.55 s
Wall time: 8.57 s


In [28]:
%%time
# Spread corrected-skeleton thicknesses over every voxel with onion > 0 (26-connected BFS).
thicc_map = propagate_thiccness(corrected_skeleton, onion)

CPU times: user 46.7 s, sys: 385 ms, total: 47 s
Wall time: 47.3 s


In [29]:
# Propagated thickness (cast to uint8) in ColorMapVisualizer's gradient mode.
visualize_gradient(thicc_map)

# Skeleton trimming (trim it like a pro)

In [39]:
def trim_skeleton(skeleton, thicc_map, threshold=0.9):
    """Keep only skeleton voxels whose inscribed ball lies mostly inside the tree.

    For a skeleton voxel with thickness `t` (from `thicc_map`), the filled ball of
    radius `t - 1` centred on it is compared with the tree (`thicc_map > 0`). The
    voxel is kept when the fraction of ball voxels inside the tree is at least
    `threshold`.

    Args:
        skeleton: 3-D array; non-zero voxels are skeleton voxels.
        thicc_map: 3-D array of the same shape with radius + 1 per voxel (0 =
            background). Float maps, such as `propagate_thiccness` output, are
            accepted; values are truncated to int. A value of 0 is treated as
            radius 0, so that voxel's fill is 0 and it is dropped unless threshold <= 0.
        threshold: minimum (inclusive) fill fraction.

    Returns:
        Float array with the shape of `skeleton`: 1 for kept voxels, 0 elsewhere.
    """
    max_radius = int(thicc_map.max())
    padded_skeleton = np.pad(skeleton, max_radius)
    padded_thicc_map = np.pad(thicc_map, max_radius)

    # kernels[r] is the filled ball of radius r; radius 0 is always built.
    kernels = [spherical_kernel(radius) for radius in range(max(max_radius, 1))]

    trimmed_skeleton = np.zeros(padded_skeleton.shape)

    for x, y, z in np.argwhere(padded_skeleton):
        # int(): float maps cannot index a list, and on a uint8 map `0 - 1` would
        # wrap to 255.
        kernel_radius = max(int(padded_thicc_map[x, y, z]) - 1, 0)
        kernel = kernels[kernel_radius]

        kernel_diameter = 2 * kernel_radius + 1
        corner_x, corner_y, corner_z = x - kernel_radius, y - kernel_radius, z - kernel_radius
        kernel_map_slice = padded_thicc_map[
            corner_x:corner_x + kernel_diameter,
            corner_y:corner_y + kernel_diameter,
            corner_z:corner_z + kernel_diameter,
        ]

        fill = np.sum(kernel * (kernel_map_slice > 0).astype(np.uint8)) / np.sum(kernel)

        if fill >= threshold:
            trimmed_skeleton[x, y, z] = 1

    # Explicit stop indices: `[max_radius:-max_radius]` is empty when max_radius == 0.
    unpad = tuple(slice(max_radius, max_radius + size) for size in skeleton.shape)
    return trimmed_skeleton[unpad]


In [77]:
# Trim the raw skeleton (not corrected_skeleton): keep voxels whose ball of radius
# thickness - 1 is at least 75% inside the tree (thicc_map > 0).
trimmed_skeleton = trim_skeleton(skeleton, thicc_map.astype(np.uint8), threshold=0.75)

In [70]:
# Trimmed skeleton on its own.
visualize_lsd(trimmed_skeleton)

In [75]:
# Trimmed skeleton (label 1) vs. raw-skeleton voxels the trimming removed (label 2).
visualize_addition(trimmed_skeleton, skeleton)

In [78]:
# Same comparison against the corrected skeleton.
visualize_addition(trimmed_skeleton, corrected_skeleton)

In [73]:
# Number of connected components in the trimmed skeleton (measure.label default: full connectivity).
regions = measure.label(trimmed_skeleton)
np.max(regions)

1113

In [74]:
# Trimmed skeleton against the full reconstruction.
visualize_addition(trimmed_skeleton, reconstruction)

In [92]:
def eat_leaves(skeleton, iters=1, neighbour_threshold=3):
    """Erode a skeleton by repeatedly dropping voxels with few 26-neighbours.

    Each pass counts, for every voxel, the skeleton voxels in its 3x3x3
    neighbourhood (the voxel itself included). A voxel is kept only if that count
    exceeds `neighbour_threshold`. Counts are taken on the result of the previous
    pass; previously every pass recounted the original skeleton, so iters > 1
    behaved like iters == 1.

    Args:
        skeleton: 3-D array; non-zero voxels are skeleton voxels.
        iters: number of erosion passes.
        neighbour_threshold: a voxel survives when its 3x3x3 count exceeds this
            value. With 2, only endpoints (one neighbour) are removed per pass. The
            default 3 also removes plain line voxels (two neighbours).
            NOTE(review): confirm whether 3 or 2 is the intended default.

    Returns:
        `skeleton` multiplied by the final survival mask (same shape).
    """
    kernel = np.ones((3, 3, 3), dtype=np.uint8)
    new_skeleton = skeleton

    for i in range(iters):
        # Binarise first so non-0/1 inputs (e.g. 0/255) cannot overflow the uint8 count.
        binary = (new_skeleton > 0).astype(np.uint8)
        convolved = signal.convolve(binary, kernel, mode='same')
        new_skeleton = new_skeleton * (convolved > neighbour_threshold).astype(np.uint8)
        print(f'iteration {i + 1} done')

    return new_skeleton

In [93]:
# Neighbour-count erosion of the raw skeleton, 5 passes.
eated_out_skeleton = eat_leaves(skeleton, iters=5)

iteration 1 done
iteration 2 done
iteration 3 done
iteration 4 done
iteration 5 done


In [95]:
# Result of the erosion.
visualize_mask_bin(eated_out_skeleton)

In [97]:
# Eroded skeleton (label 1) vs. raw-skeleton voxels it lost (label 2).
visualize_addition(eated_out_skeleton, skeleton)

In [120]:
def eat_leaves_smarter_iter(skeleton):
    """Run one BFS over a 3-D skeleton and label the resulting BFS tree.

    A breadth-first search over 26-connected skeleton voxels starts from the first
    skeleton voxel in `np.argwhere` (row-major) order. Every voxel the search
    discovers is labelled 2. A voxel that itself discovers at least one new voxel is
    relabelled 1, so label-2 voxels are the leaves of the BFS tree.

    Only the component containing the start voxel is explored: voxels of every
    other component stay 0, and so does an isolated start voxel.

    Args:
        skeleton: 3-D array; non-zero voxels are skeleton voxels.

    Returns:
        Float array with the shape of `skeleton` holding 0 (not reached), 1 (inner
        BFS node) or 2 (BFS leaf). An empty skeleton yields all zeros.
    """
    padded_skeleton = np.pad(skeleton, 1)
    new_skeleton = np.zeros(padded_skeleton.shape)

    skeleton_voxels = np.argwhere(padded_skeleton)
    if len(skeleton_voxels) == 0:
        # Previously `np.argwhere(...)[0]` raised IndexError here.
        return new_skeleton[1:-1, 1:-1, 1:-1]

    # No explicit "visited" mark is needed for the start voxel. The former
    # `new_skeleton[...] == -1` was a no-op comparison. The start voxel becomes 1 as
    # soon as it discovers its first neighbour, before anything else is processed,
    # so it is never re-enqueued.
    queue = deque([tuple(skeleton_voxels[0])])
    # Same order as nested dx/dy/dz loops, without the (0, 0, 0) offset.
    offsets = [offset for offset in product((-1, 0, 1), repeat=3) if offset != (0, 0, 0)]

    while queue:
        # deque.popleft() is O(1); list.pop(0) made the search quadratic.
        x, y, z = queue.popleft()

        for dx, dy, dz in offsets:
            neighbour = (x + dx, y + dy, z + dz)
            # Zero padding guarantees the search never leaves the padded volume.
            if padded_skeleton[neighbour] == 0:
                continue

            if new_skeleton[neighbour] == 0:
                queue.append(neighbour)
                new_skeleton[neighbour] = 2
                new_skeleton[x, y, z] = 1

    return new_skeleton[1:-1, 1:-1, 1:-1]

def eat_leaves_smarter(skeleton, iters=1):
    """Apply `eat_leaves_smarter_iter` `iters` times, keeping only label-1 voxels.

    Each pass drops the BFS-tree leaves (label 2) and everything outside the start
    voxel's component (label 0). Returns a uint8 0/1 array; with `iters=0` an
    unchanged copy of `skeleton` is returned.
    """
    pruned = skeleton.copy()
    for _ in range(iters):
        pruned = (eat_leaves_smarter_iter(pruned) == 1).astype(np.uint8)
    return pruned

In [128]:
%%time
# Ten BFS pruning passes; each drops BFS-tree leaves and every component except the one
# containing the start voxel.
smart_skeleton = eat_leaves_smarter(skeleton, iters=10)

CPU times: user 33.4 s, sys: 1.31 s, total: 34.7 s
Wall time: 34.8 s


In [129]:
# Pruned skeleton (label 1) vs. raw-skeleton voxels removed by the pruning (label 2).
visualize_addition(smart_skeleton, skeleton)

In [138]:
# `xd`: voxels kept by the pruning; a single connected component remains (see output).
xd = smart_skeleton == 1
regions = measure.label(xd)
np.max(regions)

1

In [179]:
# Voxels that the pruning removed from the raw skeleton and that lie in thin parts of the
# tree (propagated thickness < 4) are treated as genuine branch ends. Written as explicit
# boolean logic: `(smart_skeleton - skeleton) > 0` only selected them because 0 - 1 wraps
# to 255 in uint8, and it silently breaks for signed or wider dtypes.
real_ends = (skeleton > 0) & (smart_skeleton == 0) & (thicc_map < 4)

In [180]:
# Add the thin removed ends back onto the pruned skeleton (the two sets are disjoint).
the_smartest_skeleton = smart_skeleton + real_ends

In [181]:
# Component count after restoring the ends.
regions = measure.label(the_smartest_skeleton)
np.max(regions)

9

In [175]:
visualize_lsd(the_smartest_skeleton)

In [182]:
# One label value per connected component (cast to uint8, so labels above 255 would wrap).
visualize_lsd(regions)

# Dethiccation: keep only voxels whose inscribed ball is sufficiently inside the tree

In [39]:
def dethiccate(thiccness_map, fill_threshold):
    """Zero out voxels whose inscribed ball is not sufficiently inside the tree.

    For every distinct positive thickness `t` in `thiccness_map`, each voxel with
    thickness `t` keeps its value only if the filled ball of radius `t - 1` around it
    is at least `fill_threshold` inside the tree (`thiccness_map > 0`).

    Args:
        thiccness_map: 3-D array of radius + 1 per voxel (0 = background).
        fill_threshold: minimum (inclusive) fill fraction.

    Returns:
        Float array of the same shape: the surviving thickness values, 0 elsewhere.
    """
    thicknesses = np.unique(thiccness_map)
    # Select the positive values explicitly. `[1:]` assumed a 0 was present and dropped
    # the thinnest real thickness whenever the map contained no background voxel.
    kernels_radii = thicknesses[thicknesses > 0] - 1.
    # Loop-invariant: the tree mask is the same for every radius.
    binary_tree = (thiccness_map > 0).astype(np.uint8)

    thin_tree = np.zeros(thiccness_map.shape)
    for radius in kernels_radii:
        filled_enough = convolve_with_ball(binary_tree, radius, normalize=True) >= fill_threshold
        voxels = thiccness_map == radius + 1
        thin_tree[voxels] = thiccness_map[voxels] * filled_enough[voxels]

        print(f'kernel {radius} done')

    return thin_tree

In [40]:
%%time
# Zero out voxels whose ball (radius thickness - 1) is less than 70% inside the tree.
thin_tree = dethiccate(thicc_map, fill_threshold=0.7)

kernel 0.0 done
kernel 1.0 done
kernel 2.0 done
kernel 3.0 done
kernel 4.0 done
kernel 5.0 done
kernel 6.0 done
kernel 7.0 done
kernel 8.0 done
kernel 9.0 done
kernel 10.0 done
kernel 11.0 done
CPU times: user 1min 40s, sys: 21.4 s, total: 2min 2s
Wall time: 2min 9s


In [41]:
# Thinned thickness map in gradient mode.
ColorMapVisualizer(thin_tree.astype(np.uint8)).visualize(gradient=True)

In [43]:
# Thinned tree (label 1) vs. voxels removed by dethiccate (label 2).
visualize_addition(thin_tree, thicc_map)