In [1]:
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from identification import get_vessel_graph, remove_graph_components, parametrize_graph, clean_data_graph

In [2]:
import numpy as np
import networkx as nx
import time

from queue import SimpleQueue
from skimage import measure, morphology 
from scipy.ndimage import distance_transform_edt, zoom
from scipy.signal import fftconvolve
from scipy.spatial.distance import euclidean

In [3]:
visualize_steps: bool = True  # show interactive 3-D previews after each pipeline stage

### Visualisation utilities

In [4]:
def visualize_mask_bin(mask):
    """Show the non-zero voxels of ``mask`` in binary rendering mode."""
    foreground = np.asarray(mask > 0, dtype=np.uint8)
    VolumeVisualizer(foreground, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Show the non-zero voxels of ``mask`` as 255 (rest 0) in non-binary rendering mode."""
    foreground = (mask > 0).astype(np.uint8)
    scaled = foreground * 255
    VolumeVisualizer(scaled, binary=False).visualize()
    
def draw_graph_on_model(binary_model, graph, line_samples=300):
    """Rasterise ``graph`` into a uint8 volume layered over ``binary_model`` for display.

    Voxel values in the returned array:
        0   background
        30  voxels that are non-zero in ``binary_model``
        255 voxels sampled along each edge
        40  node voxels (drawn last, so they override edge voxels)

    Args:
        binary_model (np.ndarray): 3-D mask; any non-zero voxel counts as foreground.
        graph: Object exposing ``edges`` (pairs of nodes) and ``nodes`` (e.g. an
            ``nx.Graph``); each node is an ``(x, y, z)`` integer index into ``binary_model``.
        line_samples (int, optional): Number of points sampled along each edge. Edges
            much longer than this many voxels will be drawn with gaps. Defaults to 300.

    Returns:
        np.ndarray: uint8 array with the same shape as ``binary_model``.
    """
    canvas = np.zeros(np.shape(binary_model), dtype=np.uint8)
    # Cast to bool so an integer 0/1 mask is used as a mask, not as fancy indices
    # along axis 0 (which would paint entire slabs).
    canvas[np.asarray(binary_model, dtype=bool)] = 30

    for (x1, y1, z1), (x2, y2, z2) in graph.edges:
        xs = np.linspace(x1, x2, num=line_samples, endpoint=True, dtype=np.int32)
        ys = np.linspace(y1, y2, num=line_samples, endpoint=True, dtype=np.int32)
        zs = np.linspace(z1, z2, num=line_samples, endpoint=True, dtype=np.int32)
        # One vectorised fancy-index write per edge instead of a per-voxel Python loop.
        canvas[xs, ys, zs] = 255

    for x, y, z in graph.nodes:
        canvas[x, y, z] = 40
    return canvas

### Read data

In [5]:
# Raw scan read as a flat uint8 stream (count=-1: read every byte); the next cell reshapes it.
volume = np.fromfile('../data/P13/data.raw', dtype=np.uint8, count=-1)

In [6]:
# Restore the 3-D layout of the raw stream, then take the coarse region of interest.
volume = volume.reshape(877, 488, 1132)
volume = volume[200:700, :, 100:650]

# Intensity written over regions to exclude; it lies below the threshold of 32
# used later to build the binary mask.
val = 20

# Blank regions in the coarse crop (every write uses the same value, so order is irrelevant).
for _region in (
    np.s_[:, 200:, -35:],
    np.s_[-150:, 200:, -100:-35],
    np.s_[:, -200:-150, -100:-35],
    np.s_[:, :210, -50:],
    np.s_[:150, :150, -85:-50],
):
    volume[_region] = val

# Fine crop, then blank two more areas expressed in the final frame.
volume = volume[-170:-40, 200:480, 30:260]
for _region in (np.s_[-90:, :, -110:], np.s_[:, -80:, -92:]):
    volume[_region] = val

In [7]:
np.shape(volume)

(130, 280, 230)

In [8]:
# Grey-level preview of the cropped volume.
if visualize_steps:
    VolumeVisualizer(
        volume,
        binary=False,
    ).visualize()

In [9]:
# Foreground = voxels brighter than 32 (the blanked value 20 falls below this).
mask = np.greater(volume, 32)

In [10]:
def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    """Keep only large connected components and crop to their joint bounding box.

    Args:
        binary_mask (np.ndarray): Mask to label with ``skimage.measure.label``.
        min_size (int, optional): Minimum component size (``regionprops.area``) to keep.
        connectivity (int, optional): Connectivity passed to ``measure.label``.

    Returns:
        tuple: ``(main_regions, bounding_boxes)`` where ``main_regions`` is a bool array
        of the kept components cropped to the union of their bounding boxes, and
        ``bounding_boxes`` is the list of ``regionprops.bbox`` tuples of kept components.

    Raises:
        ValueError: If no component has at least ``min_size`` voxels.
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)

    kept_labels = []
    bounding_boxes = []
    for props in measure.regionprops(labeled):
        if props.area >= min_size:
            kept_labels.append(props.label)
            bounding_boxes.append(props.bbox)

    if not bounding_boxes:
        # Previously surfaced as an opaque "zero-size array to reduction" error from np.min.
        raise ValueError(
            f'No connected region has at least {min_size} voxels; '
            'lower min_size or check how the mask was thresholded.'
        )

    # Single vectorised membership test instead of one full-volume logical_or per region.
    main_regions = np.isin(labeled, kept_labels)

    # bbox layout is (min_0, ..., min_{n-1}, max_0, ..., max_{n-1}).
    ndim = labeled.ndim
    boxes = np.asarray(bounding_boxes)
    lower_bounds = boxes[:, :ndim].min(axis=0)
    upper_bounds = boxes[:, ndim:].max(axis=0)
    crop = tuple(slice(lo, hi) for lo, hi in zip(lower_bounds, upper_bounds))

    return main_regions[crop], bounding_boxes

In [11]:
main_regions, bounding_boxes = get_main_regions(
    binary_mask=mask,
    min_size=25_000,
)

In [12]:
# One bounding box is recorded per kept region.
print('Number of main regions:', len(bounding_boxes), sep=' ')

Number of main regions: 1


In [13]:
# Binary preview of the kept components.
if visualize_steps:
    visualize_mask_bin(mask=main_regions)

In [14]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a ball structuring element, or a hollow shell of it.

    Args:
        outer_radius (int): Radius passed to ``morphology.ball``.
        thickness (int, optional): Shell thickness when ``filled`` is False; capped at
            ``outer_radius``.
        filled (bool, optional): If True return the solid ball unchanged.

    Returns:
        np.ndarray: The ball, or the ball minus a centred inner ball.
    """
    ball = morphology.ball(radius=outer_radius)
    if filled:
        return ball

    shell = min(thickness, outer_radius)
    core = morphology.ball(radius=outer_radius - shell)

    # The inner ball is centred inside the outer one at an offset equal to the shell width.
    span = slice(shell, shell + core.shape[0])
    ball[span, span, span] -= core
    return ball

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True, fft=True):
    """Convolve ``img`` with a filled ball of radius ``ball_radius``.

    Args:
        img (np.ndarray): Input volume; cast to ``dtype`` before convolving.
        ball_radius (int): Radius passed to ``spherical_kernel``.
        dtype: dtype that both ``img`` and the kernel are cast to before convolving.
        normalize (bool): If True, divide by the kernel sum. For a 0/1 ``img`` this is the
            fraction of the ball covering non-zero voxels. The result is returned as float16.
        fft (bool): If True use ``scipy.signal.fftconvolve``; otherwise
            ``scipy.signal.convolve`` (which selects its own evaluation method).

    Returns:
        np.ndarray: ``mode='same'`` convolution, same shape as ``img``.
    """
    kernel = spherical_kernel(ball_radius, filled=True)
    img_cast = img.astype(dtype)
    kernel_cast = kernel.astype(dtype)

    if fft:
        convolved = fftconvolve(img_cast, kernel_cast, mode='same')
    else:
        # The original code referenced an unimported `signal` name (NameError).
        # A local import keeps this fix self-contained.
        from scipy.signal import convolve
        convolved = convolve(img_cast, kernel_cast, mode='same')

    if not normalize:
        return convolved

    return (convolved / kernel.sum()).astype(np.float16)

def calculate_reconstruction(mask, kernel_sizes=(10, 9, 8, 7), fill_threshold=0.5, iters=1,
                             conv_dtype=np.uint16, fft=True, verbose=True):
    """Grow a binary mask by ball filling and record which kernel filled each voxel.

    For every iteration and every radius in ``kernel_sizes`` (in the given order), the
    working mask is convolved with a filled ball. Every voxel whose normalised fill
    fraction exceeds ``fill_threshold`` is set to 1 in the working mask. Later kernels and
    iterations therefore see the grown mask.

    Args:
        mask (np.ndarray): Input mask; copied to uint8, so the caller's array is untouched.
        kernel_sizes (Iterable[int], optional): Ball radii to apply, in order.
        fill_threshold (float, optional): Fill fraction above which a voxel is filled.
        iters (int, optional): Number of passes over ``kernel_sizes``.
        conv_dtype: dtype forwarded to ``convolve_with_ball``.
        fft (bool, optional): Forwarded to ``convolve_with_ball``.
        verbose (bool, optional): Print progress after each kernel and iteration.

    Returns:
        list[np.ndarray]: One uint8 map per iteration. A voxel holds ``k + 1``, where ``k``
        is the last radius (in ``kernel_sizes`` order) that exceeded the threshold there
        during that iteration, or 0 if none did.
    """
    filled = mask.astype(np.uint8)  # astype copies, so the input mask is never modified
    kernel_sizes_maps = []

    for iteration in range(1, iters + 1):
        kernel_size_map = np.zeros(filled.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill_fraction = convolve_with_ball(filled, kernel_size, dtype=conv_dtype,
                                               normalize=True, fft=fft)
            above_threshold = fill_fraction > fill_threshold
            kernel_size_map[above_threshold] = kernel_size + 1
            filled[above_threshold] = 1

            if verbose:
                print(f'Iteration {iteration} kernel {kernel_size} done')

        kernel_sizes_maps.append(kernel_size_map)
        if verbose:
            print(f'Iteration {iteration} ended successfully')

    return kernel_sizes_maps

In [15]:
# order=0 (nearest neighbour) keeps the downscaled mask binary.
scaled_mask = zoom(input=main_regions, zoom=0.7, order=0)

In [40]:
np.shape(scaled_mask)

(91, 102, 161)

In [16]:
# Binary preview of the downscaled mask.
if visualize_steps:
    visualize_mask_bin(mask=scaled_mask)

In [17]:
%%time
s_recos = calculate_reconstruction(scaled_mask, 
                                   kernel_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 
                                   iters=3)

Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 ended successfully
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 ended successfully
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7 done
Iteration 3 kernel 8 done
Iteration 3 kernel 9 done
Iteration 3 kernel 10 done
Iteration 3 kernel 11 done
Iteration 3 kernel 1

In [18]:
# Non-zero entries of the last iteration's kernel-size map.
bin_reco = np.greater(s_recos[-1], 0)

In [19]:
# Binary preview of the reconstructed mask.
if visualize_steps:
    visualize_mask_bin(mask=bin_reco)

In [20]:
%%time
edt_img = distance_transform_edt(bin_reco)

CPU times: total: 203 ms
Wall time: 205 ms


In [21]:
# Colour-mapped distance transform (values truncated to uint8 for display).
if visualize_steps:
    edt_preview = edt_img.astype(np.uint8)
    ColorMapVisualizer(edt_preview).visualize()

In [22]:
%%time
vessel_graph = get_vessel_graph(bin_reco, 3)

Starting points:  3025
i: 0 j: 0 1
i: 1 j: 0 1
i: 1 j: 10 10
i: 1 j: 20 13
i: 1 j: 30 13
i: 1 j: 40 8
i: 1 j: 50 8
i: 1 j: 60 5
i: 1 j: 70 8
i: 1 j: 80 9
i: 1 j: 90 8
i: 2 j: 0 1
i: 3 j: 0 1
i: 4 j: 0 1
i: 5 j: 0 1
i: 6 j: 0 1
i: 7 j: 0 1
i: 8 j: 0 1
i: 9 j: 0 1
i: 10 j: 0 1
i: 11 j: 0 1
i: 12 j: 0 1
i: 13 j: 0 1
i: 14 j: 0 1
i: 15 j: 0 1
i: 16 j: 0 1
i: 17 j: 0 1
i: 18 j: 0 1
i: 19 j: 0 1
i: 20 j: 0 1
i: 21 j: 0 1
i: 22 j: 0 1
i: 23 j: 0 1
i: 24 j: 0 1
i: 25 j: 0 1
i: 26 j: 0 1
i: 27 j: 0 1
i: 28 j: 0 1
i: 29 j: 0 1
i: 30 j: 0 1
i: 31 j: 0 1
i: 32 j: 0 1
i: 33 j: 0 1
i: 34 j: 0 1
i: 35 j: 0 1
i: 36 j: 0 1
i: 37 j: 0 1
i: 38 j: 0 1
i: 39 j: 0 1
i: 40 j: 0 1
i: 41 j: 0 1
i: 42 j: 0 1
i: 43 j: 0 1
i: 44 j: 0 1
i: 45 j: 0 1
i: 46 j: 0 1
i: 47 j: 0 1
i: 48 j: 0 1
i: 49 j: 0 1
i: 50 j: 0 1
i: 51 j: 0 1
i: 52 j: 0 1
i: 53 j: 0 1
i: 54 j: 0 1
i: 55 j: 0 1
i: 56 j: 0 1
i: 57 j: 0 1
i: 58 j: 0 1
i: 59 j: 0 1
i: 60 j: 0 1
i: 61 j: 0 1
i: 62 j: 0 1
i: 63 j: 0 1
i: 64 j: 0 1
i: 65 j: 0 1
i: 66 j: 

In [23]:
# For any NetworkX graph the degree sum equals 2 * edge count, so this is the mean degree.
print('Number of nodes', vessel_graph.number_of_nodes())
print('Number of edges', vessel_graph.number_of_edges())
print('Average degree', 2 * vessel_graph.number_of_edges() / vessel_graph.number_of_nodes())

Number of nodes 2723
Number of edges 121
Average degree 0.08887256702166728


In [24]:
# Overlay of the raw vessel graph on the reconstructed mask.
if visualize_steps:
    model_with_graph = draw_graph_on_model(binary_model=bin_reco, graph=vessel_graph)
    ColorMapVisualizer(np.array(model_with_graph, dtype=np.uint8)).visualize()

In [25]:
%%time
vessel_graph_rm = remove_graph_components(vessel_graph)

CPU times: total: 15.6 ms
Wall time: 13.3 ms


In [26]:
# For any NetworkX graph the degree sum equals 2 * edge count, so this is the mean degree.
print('Number of nodes', vessel_graph_rm.number_of_nodes())
print('Number of edges', vessel_graph_rm.number_of_edges())
print('Average degree', 2 * vessel_graph_rm.number_of_edges() / vessel_graph_rm.number_of_nodes())

Number of nodes 140
Number of edges 121
Average degree 1.7285714285714286


In [27]:
# Overlay of the pruned vessel graph on the reconstructed mask.
if visualize_steps:
    model_with_graph = draw_graph_on_model(binary_model=bin_reco, graph=vessel_graph_rm)
    ColorMapVisualizer(np.array(model_with_graph, dtype=np.uint8)).visualize()

In [28]:
def connect_endings_3d(graph, edt_img, multiplier=3.):
    """Fix discontinuities in a vessel graph by linking loose endings.

    The root is the voxel with the largest value in ``edt_img``; it is added to the graph
    if missing. Every degree-1 node ("ending") that has no path from the root is joined
    to the nearest node that is not already reachable from it. The link is only made if
    that node lies within ``edt_img[ending] * multiplier``. The graph is modified in place.

    Args:
        graph (nx.Graph): Graph whose nodes are voxel-coordinate tuples indexing ``edt_img``.
        edt_img (np.ndarray): Euclidean distance transform of the vessel mask.
        multiplier (float, optional): Scales the EDT value at an ending into the maximum
            linking distance. Defaults to 3.

    Returns:
        nx.Graph: The same ``graph`` object, with linking edges added.
    """
    # Find potential discontinuities before the root is (possibly) added.
    endings = [node for node, degree in graph.degree() if degree == 1]
    # Plain ints: equal and hash-identical to the numpy indices, but cleaner as node keys.
    root = tuple(int(c) for c in np.unravel_index(np.argmax(edt_img), edt_img.shape))

    if root not in graph.nodes:
        graph.add_node(root)

    print("Endings: ", len(endings))

    while endings:
        start = endings.pop()
        # Already connected to the root: nothing to repair for this ending.
        if nx.has_path(graph, root, start):
            continue

        max_distance = edt_img[start] * multiplier
        # The graph does not change until an edge is added (after which we stop),
        # so the reachable set is computed once instead of one has_path call per candidate.
        reachable = nx.descendants(graph, start)

        for candidate in sorted_nodes_by_distance(graph, start):
            if euclidean(start, candidate) > max_distance:
                break  # candidates are sorted by distance; the rest are farther
            if candidate not in reachable:
                graph.add_edge(start, candidate)
                if candidate in endings:
                    endings.remove(candidate)
                break
    return graph

def sorted_nodes_by_distance(graph, node):
    """Return every node of ``graph`` except ``node``, nearest first (Euclidean).

    Ties keep the graph's node iteration order, because ``sorted`` is stable.
    """
    distance_to = {
        other: euclidean(node, other)
        for other in graph.nodes
        if other != node
    }
    return sorted(distance_to, key=distance_to.__getitem__)

In [29]:
%%time
for r in [3, 5, 10, 17]:
    vessel_graph_cn = connect_endings_3d(vessel_graph, edt_img, r)

Endings:  70
Endings:  50
Endings:  46
Endings:  44
CPU times: total: 1.33 s
Wall time: 1.32 s


In [30]:
# For any NetworkX graph the degree sum equals 2 * edge count, so this is the mean degree.
print('Number of nodes', vessel_graph_cn.number_of_nodes())
print('Number of edges', vessel_graph_cn.number_of_edges())
print('Average degree', 2 * vessel_graph_cn.number_of_edges() / vessel_graph_cn.number_of_nodes())

Number of nodes 141
Number of edges 140
Average degree 1.9858156028368794


In [31]:
# Overlay of the reconnected vessel graph on the reconstructed mask.
if visualize_steps:
    model_with_graph = draw_graph_on_model(binary_model=bin_reco, graph=vessel_graph_cn)
    ColorMapVisualizer(np.array(model_with_graph, dtype=np.uint8)).visualize()

In [32]:
visualize_steps: bool = True  # NOTE(review): re-enables previews even if disabled in the config cell

In [42]:
%%time
data_graph = parametrize_graph(vessel_graph_cn, edt_img)

CPU times: total: 31.2 ms
Wall time: 19 ms


In [43]:
# For any NetworkX graph the degree sum equals 2 * edge count, so this is the mean degree.
print('Number of nodes', data_graph.number_of_nodes())
print('Number of edges', data_graph.number_of_edges())
print('Average degree', 2 * data_graph.number_of_edges() / data_graph.number_of_nodes())

Number of nodes 77
Number of edges 76
Average degree 1.974025974025974


In [44]:
# Overlay of the parametrised graph on the reconstructed mask.
if visualize_steps:
    model_with_graph = draw_graph_on_model(binary_model=bin_reco, graph=data_graph)
    ColorMapVisualizer(np.array(model_with_graph, dtype=np.uint8)).visualize()

In [36]:
%%time
data_graph_cl = clean_data_graph(data_graph)

CPU times: total: 0 ns
Wall time: 0 ns


In [37]:
# Overlay of the cleaned graph on the reconstructed mask.
if visualize_steps:
    model_with_graph = draw_graph_on_model(binary_model=bin_reco, graph=data_graph_cl)
    ColorMapVisualizer(np.array(model_with_graph, dtype=np.uint8)).visualize()

In [38]:
# Stats for the *cleaned* graph; this cell previously reported `data_graph` again.
print('Number of nodes', len(data_graph_cl.nodes))
print('Number of edges', len(data_graph_cl.edges))
print('Average degree', sum(dict(data_graph_cl.degree).values()) / len(data_graph_cl.nodes))

Number of nodes 24
Number of edges 13
Average degree 1.0833333333333333


In [45]:
# Final overlay of the cleaned graph. Guarded like every other preview cell so that
# visualize_steps = False really disables all interactive windows.
if visualize_steps:
    model_with_graph = draw_graph_on_model(bin_reco, data_graph_cl)
    ColorMapVisualizer(model_with_graph.astype(np.uint8)).visualize()