In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d

from skimage import filters, morphology

from scipy.ndimage.filters import convolve, correlate
from scipy import signal

from skimage.filters import frangi, sato

from PIL import Image
import pickle

In [2]:
# One sub-directory per case (P01, P02, ...); sorting keeps list indices stable,
# which matters because later cells pick volumes by index (e.g. files[11]).
source_dir = './data/'
files = sorted(glob.glob(source_dir + '/*/*.raw'))
list(enumerate(files))

[(0, './data/P01/P01_60um_1612x623x1108.raw'),
 (1, './data/P02/P02_60um_1387x778x1149.raw'),
 (2, './data/P03/P03_60um_1473x1163x1148.raw'),
 (3, './data/P04/P04_60um_1273x466x1045.raw'),
 (4, './data/P05/P05_60um_1454x817x1102.raw'),
 (5, './data/P06/P06_60um_1425x564x1028.raw'),
 (6, './data/P07/P7_60um_1216x692x926.raw'),
 (7, './data/P08/P08_60um_1728x927x1149.raw'),
 (8, './data/P09/P09_60um_1359x456x1040.raw'),
 (9, './data/P10/P10_60um_1339x537x1035.raw'),
 (10, './data/P11/P11_60um_1735x595x1150.raw'),
 (11, './data/P12/P12_60um_1333x443x864.raw'),
 (12, './data/P13/P13_60um_1132x488x877.raw')]

In [13]:
%%time
threshold = 70  # cut-off applied to the loaded volume to obtain a binary mask
# files[11] is P12 in the listing above; with scale=0.5 the printed shape suggests
# each axis is halved (and reversed relative to the file name) — TODO confirm in load_volume.
mask = load_volume(files[11], scale=0.5) > threshold
print(mask.shape)

(432, 222, 666)
CPU times: user 1.43 s, sys: 1.19 s, total: 2.62 s
Wall time: 2.88 s


In [4]:
# VolumeVisualizer(mask, binary=True).visualize()

## utility functions

In [25]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a 3-D ball, or a hollow spherical shell, built from ``morphology.ball``.

    Parameters
    ----------
    outer_radius : int
        Radius passed to ``morphology.ball``; the kernel has side ``2*outer_radius + 1``.
    thickness : int, optional
        Shell thickness in voxels. Only used when ``filled`` is False.
    filled : bool, optional
        If True (default) return the solid ball. Otherwise remove a concentric
        ball of radius ``outer_radius - thickness`` from it.

    Returns
    -------
    numpy.ndarray
        The array produced by ``morphology.ball`` (presumably uint8 with 0/1
        values — the callers use ``.sum()`` as the ball volume), possibly hollowed.

    Raises
    ------
    ValueError
        If ``filled`` is False and ``thickness`` is negative.
    """
    outer_sphere = morphology.ball(radius=outer_radius)
    if filled:
        return outer_sphere

    if thickness < 0:
        # A negative thickness would give a negative slice offset below.
        raise ValueError(f'thickness must be non-negative, got {thickness}')

    inner_radius = outer_radius - thickness
    if inner_radius < 0:
        # The shell is at least as thick as the ball: nothing to hollow out
        # (and morphology.ball cannot take a negative radius).
        return outer_sphere

    inner_sphere = morphology.ball(radius=inner_radius)

    # Offset that centres the inner ball inside the outer one.
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True):
    """Sum ``img`` over a filled ball centred on every voxel.

    Both ``img`` and the ball kernel are cast to ``dtype`` before the
    ``signal.convolve(..., mode='same')`` call, so the output keeps the input shape.

    Parameters
    ----------
    img : numpy.ndarray
        Volume to convolve (typically a binary mask).
    ball_radius : int
        Radius of the filled ball from ``spherical_kernel``.
    dtype : numpy dtype, optional
        Working dtype for the convolution. NOTE(review): with an integer dtype
        the result presumably stays integer, so large balls may overflow uint16 — confirm.
    normalize : bool, optional
        If True, divide by the ball volume and return float16 (for a 0/1 mask,
        the filled fraction of each ball). If False, return the raw convolution.

    Returns
    -------
    numpy.ndarray
        Raw convolution, or the normalised float16 map when ``normalize`` is True.
    """
    ball = spherical_kernel(ball_radius, filled=True)
    summed = signal.convolve(img.astype(dtype), ball.astype(dtype), mode='same')

    if normalize:
        return (summed / ball.sum()).astype(np.float16)
    return summed

def get_arterial_regions(conv_img, lower_hyst_fraction, upper_hyst_fraction):
    """Hysteresis-threshold ``conv_img`` with thresholds given relative to its maximum.

    Parameters
    ----------
    conv_img : numpy.ndarray
        Image to threshold (e.g. the output of ``convolve_with_ball``).
    lower_hyst_fraction, upper_hyst_fraction : float
        Low and high hysteresis thresholds as fractions of ``conv_img.max()``.

    Returns
    -------
    numpy.ndarray
        Boolean mask from ``filters.apply_hysteresis_threshold``.
    """
    peak = conv_img.max()
    return filters.apply_hysteresis_threshold(
        conv_img,
        lower_hyst_fraction * peak,
        upper_hyst_fraction * peak,
    )

def reconstruct_from_skeleton(skeleton, ball_radius):
    """Re-inflate a 3-D skeleton by stamping a filled ball on every skeleton voxel.

    Parameters
    ----------
    skeleton : numpy.ndarray
        3-D array; every non-zero voxel is used as a ball centre, so 0/1, bool
        and 0/255 skeletons are all handled.
    ball_radius : int
        Radius in voxels of the ball taken from ``spherical_kernel``; 0 is allowed.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as ``skeleton``: 1 inside any stamped
        ball, 0 elsewhere.
    """
    radius = ball_radius
    side = 2 * radius + 1
    kernel = spherical_kernel(radius, filled=True)

    # Pad by `radius` on every side so balls centred near the border fit without
    # clipping logic; the padding is cropped off before returning.
    padded = np.zeros(tuple(dim + 2 * radius for dim in skeleton.shape), dtype=np.uint8)

    for z, y, x in np.argwhere(skeleton):
        # In padded coordinates this window is centred on the skeleton voxel.
        window = padded[z:z + side, y:y + side, x:x + side]
        window[:] = np.logical_or(window, kernel)

    # Explicit stop indices: a slice like [r:-r] would be empty for r == 0.
    interior = tuple(slice(radius, radius + dim) for dim in skeleton.shape)
    return padded[interior]

# high level functions

def get_tree_core(tree_mask, kernel_radius, max_fraction):
    """Return the ball-inflated skeleton of the densest part of ``tree_mask``.

    Voxels whose normalised ball convolution exceeds ``max_fraction`` times the
    maximum of that convolution are skeletonised. The skeleton is then
    re-inflated with balls of ``kernel_radius``.

    Parameters
    ----------
    tree_mask : numpy.ndarray
        3-D mask of the tree.
    kernel_radius : int
        Ball radius for both the convolution and the re-inflation.
    max_fraction : float
        Fraction of the convolution maximum a voxel must exceed to be in the core.

    Returns
    -------
    numpy.ndarray
        Output of ``reconstruct_from_skeleton`` (uint8, same shape as ``tree_mask``).
    """
    density = convolve_with_ball(tree_mask, kernel_radius)
    cutoff = max_fraction * density.max()
    skeleton = skeletonize_3d((density > cutoff).astype(np.uint8))
    return reconstruct_from_skeleton(skeleton, kernel_radius)

def expand_tree_reconstruction(tree_mask, reconstruction, kernel_radius, max_fraction):
    """Grow ``reconstruction`` into dense parts of ``tree_mask`` and rebuild the new part.

    The method works on raw ball-convolution counts of ``tree_mask`` and runs
    two hysteresis passes:

    1. Voxels of ``reconstruction`` are boosted above every possible count so
       they act as seeds. Growth then proceeds into voxels whose count exceeds
       ``max_fraction`` of the ball volume. Grown voxels not already in
       ``reconstruction`` form the *expansion*.
    2. The same growth is repeated, seeded from the expansion alone.

    The pass-2 region is skeletonised and re-inflated with balls of ``kernel_radius``.

    Parameters
    ----------
    tree_mask : numpy.ndarray
        3-D mask of the tree (assumed 0/1, so counts never exceed the ball volume).
    reconstruction : numpy.ndarray
        Current reconstruction, same shape as ``tree_mask``; non-zero = inside.
    kernel_radius : int
        Radius of the convolution ball and of the re-inflation balls.
    max_fraction : float
        Minimum filled fraction of the ball for a voxel to be grown into.

    Returns
    -------
    numpy.ndarray
        uint8 0/1 volume from ``reconstruct_from_skeleton``.
    """
    # Raw counts (not normalised fractions) so they share units with kernel_vol and
    # threshold_value; int32 leaves headroom for the boost added below.
    counts = convolve_with_ball(tree_mask, kernel_radius, normalize=False).astype(np.int32)

    kernel_vol = int(spherical_kernel(kernel_radius).sum())
    threshold_value = int(max_fraction * kernel_vol)
    # Counts are at most kernel_vol, so adding `boost` lifts a voxel above every
    # unboosted one (the "set to infinity" trick); seeds must exceed `seed_level`.
    boost = kernel_vol + 2
    seed_level = kernel_vol + 5

    recon_mask = np.asarray(reconstruction) > 0

    boosted = counts.copy()
    boosted[recon_mask] += boost
    expanded_rec = filters.apply_hysteresis_threshold(boosted, threshold_value, seed_level)
    # Boolean set difference; subtracting the arrays would wrap around in uint8.
    expansion = expanded_rec & ~recon_mask

    # `counts` is not needed afterwards, so boost the expansion in place.
    counts[expansion] += boost
    expanded_expansion = filters.apply_hysteresis_threshold(counts, threshold_value, seed_level)

    ee_skeleton = skeletonize_3d(expanded_expansion.astype(np.uint8))
    return reconstruct_from_skeleton(ee_skeleton, kernel_radius)

def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Return the large connected components of ``binary_mask``.

    Parameters
    ----------
    binary_mask : numpy.ndarray
        Mask to label with ``measure.label``.
    min_size : int, optional
        Minimum ``area`` (voxel count) a component needs to be kept.
    connectivity : int, optional
        Connectivity passed to ``measure.label``.

    Returns
    -------
    list of tuple
        ``(filled_image, bbox)`` for each kept component, in the order produced
        by ``measure.regionprops``. ``filled_image`` is cropped to ``bbox``.
    """
    labels = measure.label(binary_mask, connectivity=connectivity)
    return [
        (region.filled_image, region.bbox)
        for region in measure.regionprops(labels)
        if region.area >= min_size
    ]

def fill_jemiola(shape, cache, fill_threshold=0.5):
    """Label each voxel with the cache key whose map gives it the highest fill.

    Parameters
    ----------
    shape : tuple of int
        Shape of the output map.
    cache : dict
        Maps a key (a kernel radius in the calls in this notebook) to a fill map
        broadcastable to ``shape``. Keys are stored as uint8. On ties the key
        seen first wins.
    fill_threshold : float, optional
        Voxels whose best fill is not above this value get label 0.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape ``shape`` holding the winning key, or 0.
    """
    best_radius = np.zeros(shape, dtype=np.uint8)
    best_fill = np.zeros(shape, dtype=np.float16)

    for radius, fill in cache.items():
        improved = fill > best_fill
        best_radius[improved] = radius
        best_fill[improved] = fill[improved]

    best_radius[best_fill <= fill_threshold] = 0
    return best_radius

## Main connected region of the mask

In [16]:
# get_main_regions keeps regionprops order, not size order, so pick the largest
# qualifying region (by filled voxel count) explicitly instead of taking the first.
main_regions = get_main_regions(mask)
assert main_regions, 'get_main_regions found no region of at least min_size voxels'
mask_main = max(main_regions, key=lambda region: region[0].sum())[0]
# VolumeVisualizer(mask_main, binary=True).visualize()

## Cached convolution results (iterations 1–2 loaded from disk, iteration 3 computed below)

In [7]:
# NOTE: pickle.load can execute arbitrary code — only open cache files you produced yourself.
# Presumably {kernel_radius: convolution map} dicts as consumed by fill_jemiola — TODO confirm.
with open('convolved_volume_cache', 'rb') as f:
    cache_iter1 = pickle.load(f)
with open('double_cache', 'rb') as f:
    cache_iter2 = pickle.load(f)

In [13]:
# Main-region mask plus the kernel-size maps of both cached fill passes (uint8 sum).
ultimate_mask = (
    mask_main.astype(np.uint8)
    + fill_jemiola(mask_main.shape, cache_iter1, fill_threshold=0.5)
    + fill_jemiola(mask_main.shape, cache_iter2, fill_threshold=0.5)
)

In [4]:
from pathlib import Path

# Reuse the saved two-iteration mask when it exists; otherwise derive it from
# `ultimate_mask` (previous cell) and save it, so this cell no longer depends on
# a file written only by commented-out code.
MASK_AFTER_2_ITERS_PATH = Path('mask_after_2_iters.npy')
if MASK_AFTER_2_ITERS_PATH.exists():
    mask_after_2_iters = np.load(MASK_AFTER_2_ITERS_PATH)
else:
    mask_after_2_iters = ultimate_mask > 0
    np.save(MASK_AFTER_2_ITERS_PATH, mask_after_2_iters)
VolumeVisualizer(mask_after_2_iters.astype(np.uint8) * 255, binary=False).visualize()

In [5]:
# Third pass: normalised ball convolutions of the two-iteration mask for radii 1..10.
# The dict starts empty, so every radius is computed on each run of this cell.
kernel_sizes = range(1, 11)
cache_iter3 = {}
for k in kernel_sizes:
    cache_iter3[k] = convolve_with_ball(mask_after_2_iters, k)
    print(k)

1
2
3
4
5
6
7
8
9
10


In [6]:
# mask3: winning kernel radius per voxel from the third fill (0 = not filled).
mask3 = fill_jemiola(mask_after_2_iters.shape, cache_iter3, fill_threshold=0.5)
mask2 = mask_after_2_iters.copy()
ColorMapVisualizer(mask3).visualize()

In [8]:
lol = mask2.copy()
# Voxels filled by the third pass that were not already set in the two-iteration mask.
lsd = np.logical_and(mask3 > 0, lol != 1)
VolumeVisualizer(lsd * 255, binary=False).visualize()

In [22]:
# Union of the two-iteration mask and the newly filled voxels, as a 0/1 uint8 volume.
reconstruction = np.logical_or(lsd, lol).astype(np.uint8)
VolumeVisualizer(reconstruction, binary=True).visualize()

In [20]:
# Overlay labels: 0 = neither, 1 = reconstruction only, 2 = mask_main only, 3 = both.
ColorMapVisualizer(2 * mask_main.astype(np.uint8) + reconstruction).visualize()

## Experiments: iterative tree reconstruction (*igranie z ogniem* — "playing with fire")

In [29]:
# Grow the tree outward from a dense core, one kernel radius per step.
# total_rec labels each voxel with the step that first added it (1 = core, 2.. = expansions).
rec = get_tree_core(reconstruction, 15, 0.95)
total_rec = rec.astype(np.uint8)
new_rec = rec.copy()
print('core is nice')

for i, kernel_radius in enumerate([10, 9, 8]):  # radii 7 .. 1 not enabled yet
    # NOTE(review): only the previous step's result is passed as the seed, not the
    # accumulated `rec` — confirm this is intended.
    new_rec = expand_tree_reconstruction(reconstruction, new_rec, kernel_radius=kernel_radius, max_fraction=0.5)
    rec = np.logical_or(rec, new_rec).astype(np.uint8)

    # Voxels of this step that no earlier step has claimed.
    just_expansion = new_rec * (total_rec == 0)
    total_rec += just_expansion * (i + 2)

    print('iter for', kernel_radius, 'ended successfully XD')

core is nice
iter for 10 ended successfully XD
iter for 9 ended successfully XD
iter for 8 ended successfully XD


In [32]:
# 3-D skeleton of the merged reconstruction.
ColorMapVisualizer(
    skeletonize_3d(reconstruction)
).visualize()