In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation
from vis_utils import load_volume, VolumeVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d
from skimage import filters, morphology

from scipy.ndimage.filters import convolve, correlate
from scipy import signal

In [2]:
source_dir = './data/'
files = list(sorted(glob.glob(source_dir + '/*/*.raw')))
list(enumerate(files))

[(0, './data/P01/P01_60um_1612x623x1108.raw'),
 (1, './data/P02/P02_60um_1387x778x1149.raw'),
 (2, './data/P03/P03_60um_1473x1163x1148.raw'),
 (3, './data/P04/P04_60um_1273x466x1045.raw'),
 (4, './data/P05/P05_60um_1454x817x1102.raw'),
 (5, './data/P06/P06_60um_1425x564x1028.raw'),
 (6, './data/P07/P7_60um_1216x692x926.raw'),
 (7, './data/P08/P08_60um_1728x927x1149.raw'),
 (8, './data/P09/P09_60um_1359x456x1040.raw'),
 (9, './data/P10/P10_60um_1339x537x1035.raw'),
 (10, './data/P11/P11_60um_1735x595x1150.raw'),
 (11, './data/P12/P12_60um_1333x443x864.raw'),
 (12, './data/P13/P13_60um_1132x488x877.raw'),
 (13, './data/P14/P14_60um_1927x746x1124.raw'),
 (14, './data/P15/P15_60um_1318x640x1059.raw'),
 (15, './data/P16/P16_60um_1558x687x1084.raw'),
 (16, './data/P17/P17_60um_1573x555x968.raw'),
 (17, './data/P320/320_60um_1739x553x960.raw'),
 (18, './data/P333/333_60um_1762x989x1095.raw'),
 (19, './data/P73/73_60um_1729x854x1143.raw')]

In [3]:
%%time
volume = load_volume(files[11], scale=0.5)
visualizer = VolumeVisualizer(volume, binary=False).visualize()

CPU times: user 4.62 s, sys: 833 ms, total: 5.45 s
Wall time: 21.4 s


## simple threshold segmentation

In [4]:
threshold = 70
mask_raw = volume > threshold

VolumeVisualizer(mask_raw).visualize()

In [7]:
def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    labeled = measure.label(binary_mask, connectivity=connectivity)
    region_props = measure.regionprops(labeled)
    
    main_regions_masks = []
    
    for props in region_props:
        if props.area >= min_size:
            main_regions_masks.append((props.filled_image, props.bbox))
            
    return main_regions_masks

def merge_masks(masks, img_shape):
    result_mask = np.zeros(img_shape, dtype=np.uint8)
    for mask, bbox in masks:
        min1, min2, min3, max1, max2, max3 = bbox
        result_mask[min1:max1, min2:max2, min3:max3] += mask.astype(np.uint8)
        
    return result_mask

In [9]:
%%time
main_regions_masks = get_main_regions(mask_raw, min_size=5_000, connectivity=1)

CPU times: user 7.78 s, sys: 1.36 s, total: 9.14 s
Wall time: 9.99 s


In [11]:
mask = main_regions_masks[0][0]
VolumeVisualizer(mask).visualize()

## utility functions

In [59]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
   
    inner_radius = outer_radius - thickness
    
    outer_sphere = morphology.ball(radius=outer_radius)
    if filled:
        return outer_sphere
    
    inner_sphere = morphology.ball(radius=inner_radius)
    
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(mask, ball_radius, dtype=np.uint16):
    kernel = spherical_kernel(ball_radius, filled=True)
    return signal.convolve(mask.astype(dtype), kernel.astype(dtype), mode='same')

def get_arterial_regions(conv_img, lower_hist_fraction, upper_hist_fraction):
    upper_hist_value = upper_hist_fraction * conv_img.max()
    lower_hist_value = lower_hist_fraction * conv_img.max()
    return filters.apply_hysteresis_threshold(conv_img, lower_hist_value, upper_hist_value)

def reconstruct_from_skeleton(skeleton, ball_radius):
    
    print(skeleton.shape)
    
    mask = np.zeros(skeleton.shape, dtype=np.uint8)
    mask = np.pad(mask, ball_radius)
    
    kernel = spherical_kernel(ball_radius, filled=True)
    central_points = np.argwhere(skeleton == 1)
    
    for central_point in central_points:
        start_corner = tuple(central_point)
        end_corner = tuple(central_point + 2*ball_radius + 1)
        
        start1, start2, start3 = start_corner
        end1, end2, end3 = end_corner
        
        mask_slice = mask[start1:end1, start2:end2, start3:end3]
        mask_slice[:] = np.logical_or(mask_slice, kernel)
        
    print(mask.shape)
        
    return mask[ball_radius:-ball_radius, ball_radius:-ball_radius, ball_radius:-ball_radius]

# high level functions

def get_tree_core(tree_mask, kernel_radius, max_fraction):
    convolved_mask = convolve_with_ball(tree_mask, kernel_radius)
    core_voxels = convolved_mask > max_fraction * convolved_mask.max()
    core_skeleton = skeletonize_3d(core_voxels.astype(np.uint8))
    core_reconstruction = reconstruct_from_skeleton(core_skeleton, kernel_radius)
    
    return core_reconstruction

def expand_tree_reconstruction(tree_mask, reconstruction, kernel_radius, max_fraction):
    convolved_mask = convolve_with_ball(tree_mask, kernel_radius)
    
    convolved_max = convolved_mask.max()
    threshold_value = int(max_fraction * convolved_max)
    
    # set current reconstruction to infinity
    convolved_mask += reconstruction * (convolved_max + 2)
    
#     return convolved_mask
    
    return filters.apply_hysteresis_threshold(convolved_mask, threshold_value, convolved_max + 1)

In [40]:
%%time
core_rec = get_tree_core(mask, 20, 0.95)

(411, 196, 645)
(451, 236, 685)
CPU times: user 8.48 s, sys: 3.18 s, total: 11.7 s
Wall time: 12.4 s


In [41]:
VolumeVisualizer(np.logical_or(core_rec, mask)).visualize()

In [64]:
%%time
rec = expand_tree_reconstruction(mask, core_rec, kernel_radius=10, max_fraction=0.2)

CPU times: user 7.08 s, sys: 3.98 s, total: 11.1 s
Wall time: 12.4 s


In [65]:
VolumeVisualizer(rec).visualize()