In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import cv2
from skimage import measure, segmentation
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize, skeletonize_3d

from skimage import filters, morphology

from scipy.ndimage.filters import convolve, correlate
from scipy import signal

from skimage.filters import frangi, sato

from PIL import Image
import pickle

In [2]:
# Collect every `.raw` volume found in the data directory, in sorted order,
# and display the (index, path) pairs so a volume can be picked by index.
source_dir = './data/'
files = sorted(glob.glob(source_dir + '/*.raw'))
list(enumerate(files))

[(0, './data/P01_60um_1612x623x1108.raw'),
 (1, './data/P12_60um_1333x443x864.raw')]

In [3]:
%%time
# Load the volume at index 1 of the listing above with `scale=0.5`; the recorded
# output shape is half of the dimensions given in the file name.
# NOTE(review): the index is hard-coded, so the chosen volume depends on the
# sorted directory listing - confirm this is the intended sample.
mask = load_volume(files[1], scale=0.5)
mask.shape

CPU times: user 1.28 s, sys: 439 ms, total: 1.72 s
Wall time: 2.53 s


(432, 222, 666)

In [4]:
# Binarise the loaded volume: voxels with a value above `threshold` become True.
# NOTE(review): this rebinds `mask`, so the cell is not idempotent - running it a
# second time compares the already boolean array against 70 and yields an
# all-False mask. Re-run the loading cell before re-running this one.
threshold = 70
mask = mask > threshold
# VolumeVisualizer(mask, binary=True).visualize()

## Utility functions

### Visualisation functions

In [5]:
def visualize_addition(base, base_with_addition):
    """Show what `base_with_addition` adds on top of `base` as a colour map.

    Voxels that are positive in `base` are labelled 1, voxels that are positive
    only in `base_with_addition` are labelled 2, everything else stays 0.
    Neither input array is modified.
    """
    base_labels = (base > 0).astype(np.uint8)
    # Keep only the voxels that are new relative to the base.
    added_only = ((base_with_addition > 0) & (base_labels == 0)).astype(np.uint8)
    ColorMapVisualizer(base_labels + added_only * 2).visualize()
    
def visualize_lsd(lsd_mask):
    """Display `lsd_mask`, cast to uint8, with the colour-map visualizer."""
    labels = lsd_mask.astype(np.uint8)
    viewer = ColorMapVisualizer(labels)
    viewer.visualize()
    
def visualize_mask_bin(mask):
    """Display the positive voxels of `mask` as a 0/1 uint8 volume in binary mode."""
    foreground = (mask > 0).astype(np.uint8)
    VolumeVisualizer(foreground, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Display the positive voxels of `mask` as a 0/255 uint8 volume in non-binary mode."""
    intensity = np.where(mask > 0, np.uint8(255), np.uint8(0))
    VolumeVisualizer(intensity, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonize the positive voxels of `mask` and display the result.

    Args:
        mask: volume whose positive voxels form the object to skeletonize.
        visualize_mask: when true, show the skeleton overlaid on the mask as a
            colour map; when false, show the bare skeleton as a binary volume.
        visualize_both_versions: when true, show both views (bare skeleton
            first, overlay second) regardless of `visualize_mask`.

    In the overlay, skeleton voxels carry four times the skeleton value
    (4 when the skeleton is 0/1) and the remaining mask voxels carry 3.
    """
    skeleton = skeletonize_3d((mask > 0).astype(np.uint8))

    show_plain = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_plain:
        VolumeVisualizer(skeleton, binary=True).visualize()

    if show_overlay:
        skeleton_labels = skeleton.astype(np.uint8) * 4
        body_labels = (mask > 0).astype(np.uint8) * 3
        # Skeleton labels take precedence over the surrounding mask label.
        overlay = np.where(skeleton_labels != 0, skeleton_labels, body_labels)
        ColorMapVisualizer(overlay).visualize()

def visualize_ultimate(lsd, base_mask):
    """Run all four views for `lsd`, in order.

    The views are: the colour map, the non-binary volume, the additions
    relative to `base_mask`, and the skeleton overlaid on the mask.
    """
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [6]:
def load_lsd_trees(filename):
    """Read and return the object pickled in the file at `filename`.

    SECURITY: unpickling can execute arbitrary code, so only load files that
    were produced by a trusted source (e.g. `save_lsd_trees`).
    """
    with open(filename, 'rb') as source:
        return pickle.load(source)

def save_lsd_trees(lsd_trees, filename):
    """Pickle `lsd_trees` into the file at `filename`, overwriting any existing file."""
    with open(filename, 'wb') as target:
        pickle.dump(lsd_trees, target)

In [7]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Build a ball-shaped kernel, either solid or as a hollow shell.

    Args:
        outer_radius: radius passed to `morphology.ball` for the outer ball.
        thickness: shell thickness; only used when `filled` is false.
        filled: when true, return the solid ball; otherwise subtract a centred
            ball of radius `outer_radius - thickness` from it.

    Returns:
        The array produced by `morphology.ball(radius=outer_radius)`, with the
        inner ball removed when `filled` is false.
    """
    kernel = morphology.ball(radius=outer_radius)

    if not filled:
        inner_radius = outer_radius - thickness
        core = morphology.ball(radius=inner_radius)

        # The inner ball sits centred inside the outer one.
        offset = outer_radius - inner_radius
        window = (slice(offset, offset + core.shape[0]),) * 3
        kernel[window] -= core

    return kernel

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True):
    """Convolve `img` with a solid ball of radius `ball_radius`.

    Args:
        img: input volume; cast to `dtype` before the convolution.
        ball_radius: radius of the solid ball kernel.
        dtype: dtype both the image and the kernel are cast to.
        normalize: when true, divide the result by the kernel sum and return
            it as float16; when false, return the raw convolution result.

    Returns:
        An array of the same shape as `img` (`mode='same'`): raw sums, or the
        sums divided by the kernel sum and stored as float16.

    NOTE(review): with an integer `dtype` the sums are limited by that dtype's
    range, and float16 keeps only about three decimal digits - confirm both
    are acceptable for the radii used.
    """
    ball = spherical_kernel(ball_radius, filled=True)
    sums = signal.convolve(img.astype(dtype), ball.astype(dtype), mode='same')

    if normalize:
        return (sums / ball.sum()).astype(np.float16)
    return sums


def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Return the connected components of `binary_mask` that are large enough.

    Args:
        binary_mask: mask to label with `measure.label`.
        min_size: minimum `area` a region needs in order to be kept.
        connectivity: connectivity passed to `measure.label`.

    Returns:
        A list of `(filled_image, bbox)` tuples, one per kept region, in the
        order `measure.regionprops` yields the regions (not sorted by size).
    """
    labels = measure.label(binary_mask, connectivity=connectivity)
    return [
        (region.filled_image, region.bbox)
        for region in measure.regionprops(labels)
        if region.area >= min_size
    ]

In [8]:
def annihilate_jemiolas(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    """Iteratively grow `mask` where ball neighbourhoods are mostly filled.

    In every iteration each voxel is assigned the kernel size whose ball has
    the highest fill fraction around it (the first such size on ties), provided
    that fraction exceeds `fill_threshold`; otherwise the voxel gets 0. Voxels
    with a non-zero assignment are added to the working mask before the next
    iteration. The input `mask` itself is not modified.

    Args:
        mask: input volume; cast to uint8 before processing.
        kernel_sizes: ball radii to try, in the given order.
        fill_threshold: fill fraction that must be exceeded.
        iters: number of iterations.
        conv_dtype: dtype forwarded to `convolve_with_ball`.

    Returns:
        A list with one uint8 kernel-size map per iteration.
    """
    working_mask = mask.astype(np.uint8)
    maps_per_iteration = []

    for i in range(iters):
        size_map = np.zeros(working_mask.shape, dtype=np.uint8)
        top_fill = np.zeros(working_mask.shape, dtype=np.float16)

        for kernel_size in kernel_sizes:
            fill = convolve_with_ball(working_mask, kernel_size, dtype=conv_dtype, normalize=True)
            improved = fill > top_fill
            size_map[improved] = kernel_size
            top_fill[improved] = fill[improved]
            print(f'Iteration {i + 1} kernel {kernel_size} done')

        # Drop assignments whose best fill does not exceed the threshold.
        size_map[~(top_fill > fill_threshold)] = 0
        maps_per_iteration.append(size_map)

        # Every voxel with an assigned kernel joins the mask for the next pass.
        working_mask = ((size_map + working_mask) > 0).astype(np.uint8)
        print(f'Iteration {i + 1} ended successfully')

    return maps_per_iteration

In [9]:
def annihilate_jemiolas_faster(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    """Grow `mask` kernel by kernel, filling voxels as soon as a ball passes the threshold.

    For every kernel size, in the given order, the voxels whose ball fill
    fraction exceeds `fill_threshold` are immediately set in the working mask
    and tagged with that kernel size, so later kernels see the grown mask and
    overwrite earlier tags. The input `mask` itself is not modified.

    Args:
        mask: input volume; cast to uint8 (a copy) before processing.
        kernel_sizes: ball radii to try, in the given order.
        fill_threshold: fill fraction that must be exceeded.
        iters: number of passes over `kernel_sizes`.
        conv_dtype: dtype forwarded to `convolve_with_ball`.

    Returns:
        A list with one uint8 kernel-size map per iteration.
    """
    working_mask = mask.astype(np.uint8)
    maps_per_iteration = []

    for i in range(iters):
        size_map = np.zeros(working_mask.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill = convolve_with_ball(working_mask, kernel_size, dtype=conv_dtype, normalize=True)
            passing = fill > fill_threshold

            working_mask[passing] = 1
            size_map[passing] = kernel_size

            print(f'Iteration {i + 1} kernel {kernel_size} done')

        maps_per_iteration.append(size_map)
        print(f'Iteration {i + 1} ended successfully')

    return maps_per_iteration

## Main region extraction

In [10]:
# Take the `filled_image` of the first region returned by `get_main_regions`
# (default size filter) and display it as a binary volume.
# NOTE(review): the regions are not sorted by size, so `[0]` is the first region
# that passes the filter rather than necessarily the largest one, and the
# indexing raises IndexError when no region is large enough - confirm.
mask_main = get_main_regions(mask)[0][0].astype(np.uint8)
VolumeVisualizer(mask_main, binary=True).visualize()
# VolumeVisualizer(skeletonize_3d(mask_main.astype(np.uint8)), binary=True).visualize()

## Fun: central annihilation experiments

In [11]:
def central_annihilation(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.8, conv_dtype=np.uint16):
    """Map every voxel to the largest ball whose fill fraction reaches the threshold.

    The kernel sizes are processed in ascending order and later matches
    overwrite earlier ones, so each voxel ends up with the largest radius whose
    ball, centred on that voxel, is filled to at least `fill_threshold`.

    The fill fraction is computed in float32 from the raw voxel counts. A
    float16 fraction is too coarse for this comparison: for a radius-10 ball
    (4169 voxels) 4168/4169 already rounds to exactly 1.0, so balls with
    missing voxels would pass `fill_threshold=1`.

    Args:
        mask: input volume; cast to uint8 before processing (the caller's
            array is not modified).
        kernel_sizes: iterable of ball radii to test.
        fill_threshold: minimum fill fraction (inclusive) for a ball to count.
        conv_dtype: dtype forwarded to `convolve_with_ball`.

    Returns:
        `(kernel_size_map, best_fill_percentage)`: a uint8 map of the selected
        radius per voxel (0 where no ball qualified) and a float16 map of the
        fill fraction of the selected ball.
    """
    # Radius 0 is stored as this label instead of 0, which would be
    # indistinguishable from "no ball qualified".
    # NOTE(review): the value 30 is taken from the original code - confirm it
    # cannot collide with a real radius in use.
    zero_radius_label = 30

    mask = mask.astype(np.uint8)

    kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)
    best_fill_percentage = np.zeros(mask.shape, dtype=np.float16)

    for kernel_size in sorted(kernel_sizes):
        # Work on raw counts; the normalised output of convolve_with_ball is
        # float16 and loses the precision needed near the threshold.
        fill_percentage = convolve_with_ball(
            mask, kernel_size, dtype=conv_dtype, normalize=False
        ).astype(np.float32)
        kernel_volume = spherical_kernel(kernel_size, filled=True).sum()
        fill_percentage /= np.float32(kernel_volume)

        above_threshold_fill_indices = fill_percentage >= fill_threshold
        kernel_size_map[above_threshold_fill_indices] = (
            kernel_size if kernel_size != 0 else zero_radius_label
        )
        best_fill_percentage[above_threshold_fill_indices] = fill_percentage[above_threshold_fill_indices]
        print(f'Kernel {kernel_size} done')

    return kernel_size_map, best_fill_percentage

In [12]:
# Load the saved list of LSD trees, report how many there are and keep the last one.
# NOTE(review): this file lives under `P01`, while the volume loaded earlier is
# `files[1]` (listed as P12) - confirm the two are meant to differ.
lsd_trees = load_lsd_trees('filled_trees/P01/lsd_1st_to_6tf_faster.trees')
print(len(lsd_trees))
best_lsd_tree = lsd_trees[-1]
# Rebind the name to an empty list; presumably to let the other trees be freed.
lsd_trees = []
visualize_lsd(best_lsd_tree)

6


In [None]:
# For every voxel of the binarised tree, find the largest ball radius in 0..19
# whose fill fraction reaches 1 (radius 0 is stored as 30 by central_annihilation).
# NOTE(review): the recorded output stops after kernel 15 and the cell has no
# execution count, so this run appears not to have finished.
kernel_size_map, best_fill_percentage = central_annihilation(best_lsd_tree > 0, range(20), fill_threshold=1)

Kernel 0 done
Kernel 1 done
Kernel 2 done
Kernel 3 done
Kernel 4 done
Kernel 5 done
Kernel 6 done
Kernel 7 done
Kernel 8 done
Kernel 9 done
Kernel 10 done
Kernel 11 done
Kernel 12 done
Kernel 13 done
Kernel 14 done
Kernel 15 done
