In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d

from skimage import filters, morphology

from scipy.ndimage.filters import convolve, correlate
from scipy import signal

from skimage.filters import frangi, sato

from PIL import Image
import pickle

In [2]:
source_dir = './data/'
# Sorted list of raw volumes found in source_dir, displayed with their indices.
files = sorted(glob.glob(source_dir + '/*.raw'))
[(index, path) for index, path in enumerate(files)]

[(0, './data/P01_60um_1612x623x1108.raw'),
 (1, './data/P12_60um_1333x443x864.raw')]

In [3]:
%%time
# Load the first .raw volume listed above via vis_utils.load_volume with scale=0.5
# (the printed shape is roughly half the dimensions encoded in the file name).
# NOTE(review): despite its name, `mask` holds the loaded, un-thresholded values at
# this point; the next cell thresholds it into a boolean mask.
mask = load_volume(files[0], scale=0.5)
mask.shape

CPU times: user 2.86 s, sys: 1.39 s, total: 4.25 s
Wall time: 4.93 s


(554, 312, 806)

In [4]:
threshold = 20
# Binarise the loaded volume. Guarded so re-running this cell is a no-op: applying
# `> threshold` to an already-boolean mask would make every voxel False.
if mask.dtype != np.bool_:
    mask = mask > threshold
# VolumeVisualizer(mask, binary=True).visualize()

## Utility functions

### Visualisation functions

In [2]:
def visualize_addition(base, base_with_addition):
    """Colour-map view of `base` together with what `base_with_addition` adds to it.

    Voxels > 0 in `base` are labelled 1, voxels > 0 only in `base_with_addition`
    are labelled 2, and everything else is 0.
    """
    base_labels = (base > 0).astype(np.uint8)
    extra_labels = (base_with_addition > 0).astype(np.uint8) * 2
    # Anything already in the base keeps label 1.
    extra_labels[base_labels != 0] = 0
    ColorMapVisualizer(base_labels + extra_labels).visualize()
    
def visualize_lsd(lsd_mask):
    """Colour-map view of a label / kernel-size map, cast to uint8 first."""
    viewer = ColorMapVisualizer(lsd_mask.astype(np.uint8))
    viewer.visualize()
    
def visualize_mask_bin(mask):
    """Binary volume rendering of the voxels of `mask` that are > 0 (as 0/1 uint8)."""
    foreground = np.where(mask > 0, 1, 0).astype(np.uint8)
    VolumeVisualizer(foreground, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Non-binary volume rendering: voxels > 0 become 255, all others 0 (uint8)."""
    intensity = np.where(mask > 0, 255, 0).astype(np.uint8)
    VolumeVisualizer(intensity, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Show the skeleton (skeletonize_3d) of the voxels of `mask` that are > 0.

    visualize_mask=True shows a colour-map overlay: skeleton voxels = 4, the
    remaining mask voxels = 3. visualize_mask=False shows a binary rendering of
    the skeleton alone. visualize_both_versions=True shows both, skeleton-only first.
    """
    skeleton = skeletonize_3d((mask > 0).astype(np.uint8))
    show_skeleton_only = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_skeleton_only:
        VolumeVisualizer(skeleton, binary=True).visualize()

    if show_overlay:
        skeleton_layer = skeleton.astype(np.uint8) * 4
        body_layer = (mask > 0).astype(np.uint8) * 3
        body_layer[skeleton_layer != 0] = 0
        ColorMapVisualizer(skeleton_layer + body_layer).visualize()

def visualize_ultimate(lsd, base_mask):
    """Run every view for `lsd`: colour map, non-binary volume, addition over
    `base_mask`, and the skeleton overlay."""
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [3]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a ball-shaped kernel from ``morphology.ball``, optionally hollowed into a shell.

    Parameters
    ----------
    outer_radius : int
        Radius passed to ``morphology.ball``.
    thickness : int
        Shell thickness, only used when ``filled`` is False. The shell is
        ``ball(outer_radius)`` minus a centred ``ball(outer_radius - thickness)``.
        If ``thickness > outer_radius`` nothing is hollowed out and the solid ball
        is returned.
    filled : bool
        If True return the solid ball and ignore ``thickness``.

    Raises
    ------
    ValueError
        If ``filled`` is False and ``thickness`` is negative.
    """
    outer_sphere = morphology.ball(radius=outer_radius)
    if filled:
        return outer_sphere

    if thickness < 0:
        raise ValueError(f'thickness must be non-negative, got {thickness}')

    inner_radius = outer_radius - thickness
    if inner_radius < 0:
        # An empty inner ball: the slice below would otherwise start past the centre,
        # fall partly outside `outer_sphere` and fail to broadcast.
        return outer_sphere

    inner_sphere = morphology.ball(radius=inner_radius)

    # Offset that centres the inner ball inside the outer one.
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True, out_dtype=np.float16):
    """Convolve `img` with a solid ball of radius `ball_radius` (``mode='same'``).

    Parameters
    ----------
    img : np.ndarray
        Volume to convolve. The normalised output is the fraction of the ball
        covered by `img` only when `img` is a 0/1 mask.
    ball_radius : int
        Radius passed to ``spherical_kernel(..., filled=True)``.
    dtype : numpy dtype
        Accumulator dtype for `img` and the kernel. For integer dtypes it is widened
        automatically when the largest possible response (kernel voxels times the
        maximum of `img`) would not fit, instead of silently wrapping around.
    normalize : bool
        If False return the raw convolution; if True divide by the kernel voxel count.
    out_dtype : numpy dtype
        dtype of the normalised result (default float16, as before). float16 spacing
        just below 1.0 is 2**-11, so ratios within ~2.4e-4 of 1 round to exactly 1.0.
        Pass ``np.float32`` when large kernels are compared against thresholds near 1.

    Returns
    -------
    np.ndarray
        Same shape as `img`; voxels outside the volume contribute 0.
    """
    kernel = spherical_kernel(ball_radius, filled=True)

    work_dtype = np.dtype(dtype)
    if np.issubdtype(work_dtype, np.integer) and img.size:
        img_max = np.max(img)
        if np.isfinite(img_max):
            # Worst case: every kernel voxel overlaps the image maximum.
            peak = int(kernel.sum()) * max(int(img_max), 0)
            if peak > np.iinfo(work_dtype).max:
                work_dtype = np.promote_types(work_dtype, np.min_scalar_type(peak))

    convolved = signal.convolve(img.astype(work_dtype), kernel.astype(work_dtype), mode='same')

    if not normalize:
        return convolved

    return (convolved / kernel.sum()).astype(out_dtype)


def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Connected components of `binary_mask` with at least `min_size` voxels.

    Components are labelled with ``measure.label(..., connectivity=connectivity)``.
    Returns a list of ``(filled_image, bbox)`` pairs from ``measure.regionprops``,
    in regionprops order (not sorted by size).
    """
    labels = measure.label(binary_mask, connectivity=connectivity)
    return [
        (region.filled_image, region.bbox)
        for region in measure.regionprops(labels)
        if region.area >= min_size
    ]

In [4]:
def annihilate_jemiolas(mask, kernel_sizes=(10, 9, 8), fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    """Iteratively grow `mask` with voxels whose ball neighbourhood is mostly filled.

    Each iteration does the following:
    - For every radius in `kernel_sizes` (in the given order), compute the fraction
      of a ball around each voxel that lies inside the current mask.
    - Keep, per voxel, the radius with the highest fraction (ties keep the earlier
      radius).
    - Label voxels whose best fraction is above `fill_threshold` with that radius;
      all others get 0.
    - Add those voxels to the mask before the next iteration.

    Parameters
    ----------
    mask : np.ndarray
        Treated as binary (> 0). Non-binary input, such as a previous kernel-size
        map, is binarised first so the fill fractions stay in [0, 1].
    kernel_sizes : iterable of int
        Ball radii to try. A radius of 0 cannot be told apart from "no radius" in
        the output map.
    fill_threshold : float
        Minimum (exclusive) fill fraction for a voxel to be labelled.
    iters : int
        Number of grow iterations.
    conv_dtype : numpy dtype
        Accumulator dtype forwarded to ``convolve_with_ball``.

    Returns
    -------
    list of np.ndarray (uint8)
        One kernel-size map per iteration, same shape as `mask`.
    """
    kernel_sizes_maps = []
    mask = (mask > 0).astype(np.uint8)

    for i in range(iters):
        kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)
        best_fill_percentage = np.zeros(mask.shape, dtype=np.float16)

        for kernel_size in kernel_sizes:
            fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True)
            better_fill_indices = fill_percentage > best_fill_percentage
            kernel_size_map[better_fill_indices] = kernel_size
            best_fill_percentage[better_fill_indices] = fill_percentage[better_fill_indices]
            print(f'Iteration {i + 1} kernel {kernel_size} done')

        kernel_size_map *= best_fill_percentage > fill_threshold
        kernel_sizes_maps.append(kernel_size_map)
        # Logical union; `kernel_size_map + mask` in uint8 wraps to 0 for kernel size 255.
        mask = ((kernel_size_map > 0) | (mask > 0)).astype(np.uint8)
        print(f'Iteration {i + 1} ended successfully')

    return kernel_sizes_maps

In [5]:
def annihilate_jemiolas_faster(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    """Greedy variant of `annihilate_jemiolas` that grows the mask after every kernel.

    For each radius, in the given order:
    - Voxels whose ball fill fraction exceeds `fill_threshold` are labelled with
      that radius (later radii overwrite earlier ones).
    - Those voxels are set to 1 in the working mask straight away, so the next
      radius already sees the grown mask.

    Returns one uint8 kernel-size map per iteration. The caller's `mask` is not
    modified; a uint8 copy is used.
    """
    working_mask = mask.astype(np.uint8)
    per_iteration_maps = []

    for iteration in range(1, iters + 1):
        size_map = np.zeros(working_mask.shape, dtype=np.uint8)

        for radius in kernel_sizes:
            fill = convolve_with_ball(working_mask, radius, dtype=conv_dtype, normalize=True)
            selected = fill > fill_threshold
            size_map[selected] = radius
            working_mask[selected] = 1
            print(f'Iteration {iteration} kernel {radius} done')

        per_iteration_maps.append(size_map)
        print(f'Iteration {iteration} ended successfully')

    return per_iteration_maps

## Main region extraction

In [9]:
# Keep the largest connected component (by filled voxel count) rather than whichever
# comes first in label order. The result is get_main_regions' per-region image, so it
# is cropped to that component's bounding box and no longer shares `mask`'s coordinates.
main_regions = get_main_regions(mask)
assert main_regions, 'get_main_regions found no component >= min_size; check the threshold'
mask_main = max(main_regions, key=lambda region: region[0].sum())[0].astype(np.uint8)
del main_regions
# VolumeVisualizer(mask_main, binary=True).visualize()
# VolumeVisualizer(skeletonize_3d(mask_main.astype(np.uint8)), binary=True).visualize()

## Experiments (exploratory): central annihilation and iterative tree filling

In [17]:
def central_annihilation(mask, kernel_sizes=(10, 9, 8), fill_threshold=0.8, conv_dtype=np.uint16, zero_kernel_label=30):
    """Label each voxel with the largest ball radius whose fill fraction reaches `fill_threshold`.

    Radii are processed in ascending order, so a larger qualifying radius overwrites
    a smaller one. Radius 0 (a single voxel) is written as `zero_kernel_label`
    instead of 0, so it stays distinguishable from "no qualifying radius".

    Parameters
    ----------
    mask : np.ndarray
        Treated as binary (> 0), so the fill fractions stay in [0, 1].
    kernel_sizes : iterable of int
        Ball radii to test.
    fill_threshold : float
        Minimum (inclusive) fill fraction. With the float16 fractions produced by
        ``convolve_with_ball`` by default, fractions within ~2.4e-4 of 1 compare
        equal to 1.
    conv_dtype : numpy dtype
        Accumulator dtype forwarded to ``convolve_with_ball``.
    zero_kernel_label : int
        Value written for radius 0 (default 30, as before; must fit in uint8).

    Returns
    -------
    kernel_size_map : np.ndarray (uint8)
    best_fill_percentage : np.ndarray (float16)
        Fill fraction of the radius recorded in `kernel_size_map`. This is the
        largest qualifying radius, not necessarily the highest fraction. It is 0
        where no radius qualified.
    """
    mask = (mask > 0).astype(np.uint8)
    kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)
    best_fill_percentage = np.zeros(mask.shape, dtype=np.float16)

    for kernel_size in sorted(kernel_sizes):
        fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True)

        above_threshold_fill_indices = fill_percentage >= fill_threshold

        kernel_size_map[above_threshold_fill_indices] = kernel_size if kernel_size != 0 else zero_kernel_label

        best_fill_percentage[above_threshold_fill_indices] = fill_percentage[above_threshold_fill_indices]

        print(f'Kernel {kernel_size} done')

    return kernel_size_map, best_fill_percentage

In [7]:
# Only unpickle files produced by this notebook: pickle.load can execute arbitrary code.
with open('./filled_trees/P01/lsd_1st_to_6tf_faster.trees', 'rb') as trees_file:
    lsd_trees = pickle.load(trees_file)

# 0/1 uint8 version of the last iteration's kernel-size map.
best_filled_tree = np.greater(lsd_trees[-1], 0).astype(np.uint8)

In [8]:
lsd_trees = list()  # drop references to the loaded maps so their memory can be reclaimed

In [18]:
kernel_size_map, best_fill_percentage = central_annihilation(
    best_filled_tree,
    kernel_sizes=range(3),
    fill_threshold=1,
)

Kernel 0 done
Kernel 1 done
Kernel 2 done


In [19]:
visualize_lsd(lsd_mask=kernel_size_map)

In [13]:
visualize_addition(base=kernel_size_map, base_with_addition=best_filled_tree)

In [14]:
spherical_kernel(outer_radius=0, filled=True)

array([[[1]]], dtype=uint8)

In [12]:
# Only unpickle files produced by this notebook: pickle.load can execute arbitrary code.
with open('./filled_trees/P01/lsd_1st_to_6tf_faster.trees', 'rb') as trees_file:
    lsd_trees = pickle.load(trees_file)

In [16]:
visualize_mask_non_bin(mask=mask_main)

In [13]:
# Same definition as the first loading cell: a 0/1 uint8 mask of the last iteration's
# kernel-size map. The raw map holds kernel radii, so any consumer that treats
# best_filled_tree as a mask would otherwise see non-binary values.
best_filled_tree = (lsd_trees[-1] > 0).astype(np.uint8)

In [18]:
visualize_mask_non_bin(mask=best_filled_tree)

In [19]:
visualize_addition(base=mask_main, base_with_addition=best_filled_tree)

In [14]:
visualize_skeleton(mask=best_filled_tree)

In [10]:
lsd_trees = annihilate_jemiolas_faster(mask_main, kernel_sizes=[*range(1, 13)], iters=6)

Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 ended successfully
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 ended successfully
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7 done
Iteration 3 kernel 8 done
Iteration 3 kernel 9 done
Iteration 3 kernel 10 done
Iteration 3 kernel 11 done
Iteration 3 kernel 1

In [9]:
lsd_trees = annihilate_jemiolas(mask_main, kernel_sizes=[*range(1, 13)], iters=10)

Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 ended successfully
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 ended successfully
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7 done
Iteration 3 kernel 8 done
Iteration 3 kernel 9 done
Iteration 3 kernel 10 done
Iteration 3 kernel 11 done
Iteration 3 kernel 1

In [10]:
# The annihilation functions return one kernel-size map per iteration, so this should
# equal the `iters` of whichever run produced lsd_trees.
len(lsd_trees)

10

In [12]:
# The context manager guarantees the pickle is flushed and the file is closed.
# NOTE(review): if the annihilate_jemiolas(..., iters=10) cell ran last, these maps are
# not the "faster" 6-iteration run the file name suggests; confirm before relying on it.
with open("lsd_1st_to_6tf_faster.trees", "wb") as trees_file:
    pickle.dump(lsd_trees, trees_file)

In [11]:
visualize_ultimate(lsd=lsd_trees[-1], base_mask=mask_main)

In [16]:
# Read back the file written by the pickle.dump cell. The name must include the `.trees`
# extension; without it the open fails with FileNotFoundError.
# Only unpickle files produced by this notebook: pickle.load can execute arbitrary code.
with open('lsd_1st_to_6tf_faster.trees', 'rb') as f:
    lsd_trees = pickle.load(f)

In [10]:
# Resume from the last iteration's map. NOTE(review): its values are kernel radii (the
# kernel_size_map written by the annihilation functions), not 0/1, so binarise it (> 0)
# before using it as a mask.
latest_lsd_tree = lsd_trees[-1]

In [11]:
lsd_trees = list()  # drop references to the loaded maps so their memory can be reclaimed

In [12]:
# Binarise first: latest_lsd_tree holds kernel radii (up to 12 here). The fill fractions
# computed from it are only meaningful for a 0/1 mask.
new_lsd_trees = annihilate_jemiolas((latest_lsd_tree > 0).astype(np.uint8), list(range(1, 13)), iters=8)

Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 ended successfully
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 ended successfully
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7 done
Iteration 3 kernel 8 done
Iteration 3 kernel 9 done
Iteration 3 kernel 10 done
Iteration 3 kernel 11 done
Iteration 3 kernel 1

In [13]:
visualize_ultimate(lsd=new_lsd_trees[-1], base_mask=mask_main)

In [15]:
# Disabled: would persist the second-stage maps. If re-enabled, prefer a context
# manager so the file is flushed and closed:
#     with open("lsd_11th_to_18th.trees", "wb") as f:
#         pickle.dump(new_lsd_trees, f)
# pickle.dump(new_lsd_trees, open( "lsd_11th_to_18th.trees", "wb" ) )

In [13]:
# The recorded NameError shows this ran in a kernel session where the cell defining
# new_lsd_trees had not been executed; run that cell first.
print(*map(len, (lsd_trees, new_lsd_trees)))

NameError: name 'new_lsd_trees' is not defined