In [1]:
import glob
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy import signal
from scipy.ndimage import zoom
from scipy.ndimage.filters import convolve, correlate
from skimage import measure, segmentation, feature
from skimage import filters, morphology
from skimage.filters import frangi, sato
from skimage.morphology import skeletonize, skeletonize_3d, binary_dilation

from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer

In [2]:
# Specimen to process; used as a key into `files` and the per-specimen config dicts below.
# NOTE(review): the thresholds dict only lists 'P04', 'P11' and 'P12', so e.g. 'P01' fails at thresholding.
TREE_NAME = 'P12'

## Loading the specimen volume

In [3]:
source_dir = './data/'
# Map specimen name -> raw scan path; the name is the scan's parent directory ('./data/P12/...' -> 'P12').
# Using the parent directory instead of a fixed index into path.split('/') keeps this correct for any
# source_dir depth and for the '\\' separators glob returns on Windows.
files = {
    os.path.basename(os.path.dirname(path)): path
    for path in sorted(glob.glob(os.path.join(source_dir, '*', '*.raw')))
}
files

{'P01': './data/P01/P01_60um_1612x623x1108.raw',
 'P04': './data/P04/P04_60um_1273x466x1045.raw',
 'P11': './data/P11/P11_60um_1735x595x1150.raw',
 'P12': './data/P12/P12_60um_1333x443x864.raw'}

In [4]:
%%time

# Per-specimen `scale` passed to load_volume; specimens not listed fall back to 0.5.
# (0.5 appears to halve every axis: P12 is 1333x443x864 on disk and loads as (432, 222, 666).)
scales = dict(P04=0.5, P11=0.5, P12=0.5)

volume = load_volume(
    files[TREE_NAME],
    scale=scales.get(TREE_NAME, 0.5),
)
print(volume.shape)
VolumeVisualizer(volume, binary=False).visualize()

(432, 222, 666)
CPU times: user 4.1 s, sys: 960 ms, total: 5.06 s
Wall time: 13.1 s


## Utility visualisation functions

In [5]:
def visualize_addition(base, base_with_addition):
    """Colour-map view of `base` plus whatever `base_with_addition` adds on top of it.

    Non-zero voxels of `base` get label 1. Voxels that are non-zero only in
    `base_with_addition` get label 4. Both arrays are expected to share a shape.
    """
    base_labels = (base > 0).astype(np.uint8)
    added_labels = (base_with_addition > 0).astype(np.uint8)
    # voxels already present in the base are not part of the addition
    added_labels[base_labels == 1] = 0
    ColorMapVisualizer(base_labels + 4 * added_labels).visualize()
    
def visualize_lsd(lsd_mask):
    """Display `lsd_mask` with ColorMapVisualizer after casting its values to uint8."""
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize()
    
def visualize_gradient(lsd_mask):
    """Display `lsd_mask` (cast to uint8) with ColorMapVisualizer's gradient rendering."""
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Render the non-zero voxels of `mask` as a 0/1 uint8 volume in binary mode."""
    binary_volume = (mask > 0).astype(np.uint8)
    VolumeVisualizer(binary_volume, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Render the non-zero voxels of `mask` as a 0/255 uint8 volume in non-binary mode."""
    intensity_volume = 255 * (mask > 0).astype(np.uint8)
    VolumeVisualizer(intensity_volume, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonize the non-zero voxels of `mask` and display the result.

    visualize_mask=True shows a colour map with skeleton voxels as label 4 and the
    remaining mask voxels as label 3. visualize_mask=False shows the bare skeleton
    as a binary volume. visualize_both_versions=True shows both views.
    """
    skeleton = skeletonize((mask > 0).astype(np.uint8))

    show_bare_skeleton = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_bare_skeleton:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if show_overlay:
        # NOTE(review): assumes skeletonize returns 0/1 (or bool) values; a 0/255 output
        # would overflow uint8 when multiplied by 4.
        skeleton_labels = skeleton.astype(np.uint8) * 4
        mask_labels = (mask > 0).astype(np.uint8) * 3
        mask_labels[skeleton_labels != 0] = 0  # skeleton label wins where both overlap
        ColorMapVisualizer(skeleton_labels + mask_labels).visualize()

def visualize_ultimate(lsd, base_mask):
    """Run all visual checks on `lsd`, in this order.

    The checks are: the colour map, the 0/255 volume render, the voxels it adds
    over `base_mask`, and the skeleton overlaid on the mask.
    """
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

## Thresholding the volume to get a binary mask

In [6]:
# Per-specimen threshold: voxels with a value strictly above it become foreground.
thresholds = dict(
    P04=30,
    P11=100,  # 40 to get whole tree
    P12=70,
)

# NOTE(review): no entry and no default for other specimens (e.g. 'P01'), so this raises KeyError for them.
mask = volume > thresholds[TREE_NAME]
volume = None  # release the raw volume; re-run the loading cell before running this cell again
visualize_mask_bin(mask)

## Extracting the main region

In [7]:
def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Return the connected components of `binary_mask` with at least `min_size` voxels.

    Components are labelled with ``skimage.measure.label`` at the given
    `connectivity` and are returned in label order, not sorted by size.

    Returns a 3-tuple of parallel lists:
      - each region's ``filled_image`` (region mask cropped to its bounding box, holes filled),
      - its label in the labelled volume,
      - its bounding box (``regionprops`` ``bbox``).
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)
    large_regions = [props for props in measure.regionprops(labeled) if props.area >= min_size]
    return (
        [props.filled_image for props in large_regions],
        [props.label for props in large_regions],
        [props.bbox for props in large_regions],
    )

In [8]:
main_region_min_size = {
    'P04': 10_000,
    'P11': 10_000,
    'P12': 10_000,
}

min_region_size = main_region_min_size.get(TREE_NAME, 10_000)
main_regions, _, _ = get_main_regions(mask, min_size=min_region_size)
print('number of main regions:', len(main_regions), f'{"PANIC!!!" if len(main_regions) > 1 else ""}')
if not main_regions:
    raise ValueError(f'no connected region with at least {min_region_size} voxels found for {TREE_NAME}; '
                     'lower the threshold or the minimum region size')
# get_main_regions yields regions in label order, not by size, so pick the largest explicitly
# (identical to main_regions[0] when there is exactly one). The result is the region's
# hole-filled mask cropped to its bounding box (regionprops `filled_image`).
mask_main = max(main_regions, key=np.count_nonzero).astype(np.uint8)
mask = None  # free the full-size mask; re-run the thresholding cell before running this cell again
main_regions = None
visualize_mask_bin(mask_main)

number of main regions: 1 


## Filling holes in the mask to get rid of "mistletoes" (spurious structures caused by holes) in the skeleton

In [9]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Build a 3-D ball kernel of radius `outer_radius` (``morphology.ball``).

    With ``filled=False`` an inner ball of radius ``outer_radius - thickness`` is
    subtracted from the centre, leaving a shell. `thickness` is clamped to
    `outer_radius` and is ignored when ``filled=True``.
    """
    sphere = morphology.ball(radius=outer_radius)
    if not filled:
        shell_width = min(thickness, outer_radius)
        core = morphology.ball(radius=outer_radius - shell_width)
        # the core is centred inside the outer ball, offset by the shell width on every axis
        lo = shell_width
        hi = lo + core.shape[0]
        sphere[lo:hi, lo:hi, lo:hi] -= core
    return sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True):
    """Convolve `img` with a filled ball of radius `ball_radius` (``scipy.signal.convolve``, mode='same').

    Both operands are cast to `dtype` before convolving. With ``normalize=True``
    the sums are divided by the ball's voxel count and returned as float16. For a
    0/1 image this is the filled fraction of each voxel's ball neighbourhood.
    Otherwise the raw sums are returned.
    """
    ball = spherical_kernel(ball_radius, filled=True)
    counts = signal.convolve(img.astype(dtype), ball.astype(dtype), mode='same')
    if normalize:
        # NOTE(review): float16 spacing just above 0.5 is ~4.9e-4, so fractions exceeding 0.5 by
        # less than ~2.4e-4 round to exactly 0.5 and fail a strict `> 0.5` test.
        return (counts / ball.sum()).astype(np.float16)
    return counts

def annihilate_jemiolas(mask, kernel_sizes=(10, 9, 8), fill_threshold=0.5, iters=1, conv_dtype=np.uint16,
                        verbose=True):
    """Fill holes in a binary mask by ball-based majority filling (removes "mistletoes" from the skeleton).

    In each iteration, every ball radius in ``kernel_sizes`` is applied in order. A voxel is
    set to 1 when the fraction of foreground voxels in the ball centred on it is greater than
    ``fill_threshold``. Filled voxels are kept, so later radii and later iterations see the
    grown mask.

    Parameters
    ----------
    mask : numpy.ndarray
        Volume whose non-zero voxels are foreground. It is not modified (a uint8 copy is used).
    kernel_sizes : iterable of int
        Ball radii passed to ``convolve_with_ball``, applied in the given order. The iterable
        is materialised once, so a one-shot iterator is applied in every iteration, not only
        the first.
    fill_threshold : float
        Filled fraction of the ball that must be exceeded (strictly) to set a voxel.
    iters : int
        Number of passes over ``kernel_sizes``.
    conv_dtype : numpy dtype
        Dtype used for the convolution. It must be able to hold the voxel count of the
        largest ball.
    verbose : bool
        When True (default), print a progress line after every kernel and every iteration.

    Returns
    -------
    list of numpy.ndarray
        One uint8 map per iteration, each the same shape as ``mask``. A voxel holds
        ``radius + 1`` for the last radius (in ``kernel_sizes`` order) whose fill exceeded the
        threshold there during that iteration, and 0 if none did. The ``+ 1`` keeps radius 0
        distinguishable from "never filled".
    """
    kernel_sizes = tuple(kernel_sizes)  # re-used in every iteration, so a generator must not be exhausted
    kernel_sizes_maps = []
    mask = mask.astype(np.uint8)  # astype copies, so the caller's mask is left untouched

    for i in range(iters):
        kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True)

            above_threshold_fill_indices = fill_percentage > fill_threshold
            kernel_size_map[above_threshold_fill_indices] = kernel_size + 1
            # grow the working mask so subsequent radii/iterations build on the fill
            mask[above_threshold_fill_indices] = 1

            if verbose:
                print(f'Iteration {i + 1} kernel {kernel_size} done')

        kernel_sizes_maps.append(kernel_size_map)
        if verbose:
            print(f'Iteration {i + 1} ended successfully')

    return kernel_sizes_maps

In [10]:
# Ball radii tried by annihilate_jemiolas (applied in order) and number of passes, per specimen.
kernel_sizes = dict(
    P04=range(13),
    P11=range(13),
    P12=range(13),
)

number_of_iterations = dict(
    P04=3,
    P11=3,
    P12=3,
)

lsd_trees = annihilate_jemiolas(
    mask_main,
    kernel_sizes=kernel_sizes.get(TREE_NAME, range(13)),
    iters=number_of_iterations.get(TREE_NAME, 3),
)

Iteration 1 kernel 0 done
Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 ended successfully
Iteration 2 kernel 0 done
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 ended successfully
Iteration 3 kernel 0 done
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7 done
Iteration 3 kernel 8 done
Iteration 3 kernel 9 d

## Verifying the obtained reconstruction

In [11]:
# Inspect the map from the last iteration.
# If the reconstruction looks bad, try the maps from earlier iterations (lsd_trees[:-1]).
# If the skeleton still has mistletoes, increase the number of iterations.
reconstruction = lsd_trees[-1].astype(bool).astype(np.uint8)
visualize_mask_non_bin(reconstruction)  # check for holes
visualize_skeleton(reconstruction)  # check for mistletoes
visualize_addition(mask_main, reconstruction)  # check for anomalies

## Saving the reconstruction

In [12]:
# Pad with one layer of zeros on every side here so consumers of the saved file need not pad it.
# NOTE: this rebinds `reconstruction`; re-running the cell pads again, so re-run the previous cell first.
reconstruction = np.pad(reconstruction, pad_width=1)
output_path = source_dir + TREE_NAME + '/reconstruction'
np.save(output_path, reconstruction)