In [116]:
#from identification import get_starting_points, get_vessel_graph, get_data_graph, identify_segments, get_segment_mask
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer

In [141]:
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer

In [158]:
import numpy as np
from matplotlib import pyplot as plt
import time

from skimage.morphology import skeletonize_3d, binary_dilation, convex_hull_image, ball, area_closing
from skimage import measure, segmentation, feature, filters, morphology
from scipy.ndimage import distance_transform_edt
from scipy.signal import fftconvolve
from scipy.ndimage import zoom
from mpl_toolkits.mplot3d import Axes3D
import pickle

In [159]:
# """This module contains functions necessary for the identification of vessel branches."""
# pylint: disable=import-error


from queue import SimpleQueue

import numpy as np
import networkx as nx

from scipy.ndimage import distance_transform_edt
from scipy.spatial.distance import euclidean
from skimage.draw import circle_perimeter, disk
from skimage.feature import peak_local_max


PROJECTION_ANGLES = {
    (45, 0): ["5", "6", "13", "9", "12a"],
    (-45, 0): ["5", "6", "13", "9", "12a"],
    (45, 20): ["5", "15"],
    (45, 20): ["5", "13"],
    (30, -30): ["5", "6", "13", "9", "12a"],
    (-30, -30): ["5", "13", "12a", "12b"],
    (-30, 30): ["5", "6", "9", "10", "10"],
}


def get_starting_points(img):
    """Get starting points for centerline detection algorithm. The starting point is selected
    as the maximum value of the euclidean distance transform (EDT) image of the segmented vessel.

    Args:
        img (np.array): EDT image of the segmented vessel.

    Returns:
        tuple: Starting point for centerline detection algorithm.
    """
    # initialize helper image
    helper = np.zeros_like(img)
    # find local maxima on EDT image.
    starting_points = peak_local_max(img)
    # get coordinates of local maxima
    helper[tuple(starting_points.T)] = 1
    starting_points = list(zip(*np.nonzero(helper)))
    # sort starting points by vessel diameter.
    starting_points.sort(key=lambda x: img[x])
    return starting_points


def get_sphere_points(center, radius, im_size):
    size = np.min([np.max(center) + 2*radius -1, im_size[0]-1, im_size[1]-1, im_size[2]-1])
    x, y, z = np.meshgrid(np.arange(size), np.arange(size), np.arange(size))
    d = np.sqrt((x-center[0])**2 + (y-center[1])**2 + (z-center[2])**2)
    surface_points = np.argwhere(np.isclose(d, radius))
    return [tuple(point) for point in surface_points]

def sphere_perimeter(coords, radius, shape):
    # Create arrays of indices for each dimension
    x_indices = np.arange(shape[0])
    y_indices = np.arange(shape[1])
    z_indices = np.arange(shape[2])
    
    # Create a 3D grid of indices for each dimension
    x_grid, y_grid, z_grid = np.meshgrid(x_indices, y_indices, z_indices, indexing='ij')
    
    # Calculate the distance of each point from the center
    distances = np.sqrt((x_grid - coords[0])**2 + (y_grid - coords[1])**2 + (z_grid - coords[2])**2)
    
    # Find the indices of points that are on the sphere surface
    sphere_indices = np.where(np.isclose(distances, radius))
    
    return sphere_indices

def sphere_full(coords, radius, shape):
    # Create arrays of indices for each dimension
    x_indices = np.arange(shape[0]-1)
    y_indices = np.arange(shape[1]-1)
    z_indices = np.arange(shape[2]-1)
    
    # Create a 3D grid of indices for each dimension
    x_grid, y_grid, z_grid = np.meshgrid(x_indices, y_indices, z_indices, indexing='ij')
    
    # Calculate the distance of each point from the center
    distances = np.sqrt((x_grid - coords[0])**2 + (y_grid - coords[1])**2 + (z_grid - coords[2])**2)
    
    sphere_indices = np.where(distances <= radius)
    
    # Create a binary mask with True values at the indices of points on the sphere surface
    mask = np.zeros(shape, dtype=bool)
    mask[sphere_indices] = True
    
    return mask

def is_inside_image(p, shape):
    return p[0] > 0 and p[1] > 0 and p[2] > 0 and p[0] > shape[0] and p[1] > shape[1] and p[2] > shape[2]

def get_points_of_interest(point, radius, visited, edt_img):
    """Get list of points that can be on the vessel centerline. The method utilizes local maxima
    to find the centerline on the basis of the EDT image of the vessel. The returned points are
    sorted by the EDT image value, with the pixels on the largest vessel at the beginning of the
    list.

    Args:
        point (tuple(int, int)): Point on centerline.
        radius (int): Radius of search circle.
        visited (np.array(bool)): Boolean array of visited image pixels.

    Returns:
        list: A list of points that can be on the vessel centerline.
    """
    # get points on circle circumference
    #circle_points = circle_perimeter(*point, radius=radius, method="andres")
    circle_points = sphere_perimeter(point, radius, edt_img.shape)
    #print("Circle points: ", len(circle_points), edt_img.shape)
    # create image with edt values of point on circle circumference
    helper = np.zeros_like(edt_img)
    helper[circle_points] = edt_img[circle_points]
    # find local maxima
    coords = peak_local_max(
        helper, min_distance=radius // 3, threshold_abs=2, threshold_rel=0.5
    )
    #print("Local peaks on surface: ", len(coords))
    # find coordinates of local maxima
    helper[circle_points] = 0
    helper[tuple(coords.T)] = 1
    helper[visited] = 0
    points_of_interest = list(zip(*np.nonzero(helper)))
    points_of_interest = list(filter(lambda x: is_inside_image(x, edt_img.shape), points_of_interest))
    # sort points by vessel diameter
    points_of_interest.sort(key=lambda x: edt_img[x], reverse=True)
    return points_of_interest


def get_disk_3d(center, radius):
    size = np.max(center) + 2*radius
    x, y, z = np.meshgrid(np.arange(size), np.arange(size), np.arange(size))
    d = np.sqrt((x-center[0])**2 + (y-center[1])**2 + (z-center[2])**2)
    surface_points = np.argwhere(d <= radius)
    return [tuple(point) for point in surface_points]

def generate_sphere(radius):
    large = ball(radius)
    small = ball(np.max(1, radius-1))
    padded_small = np.pad(small, 1, mode='constant')
    sphere = large - padded_small
    return sphere

def get_vessel_graph(seg_img):
    """Produce a graph which describes the vessel structure.

    Args:
        seg_img (np.array): Segmentation of vessel image.

    Returns:
        nx.Graph: Directed graph describing the vessel structure.
    """
    # perform euclidean distance transform on base image
    edt_img = distance_transform_edt(seg_img)
    print("EDT transform completed")
    # get starting points
    starting_point_stack = get_starting_points(edt_img)
    print("Finding starting points completed")

    # create vessel tree data structure
    vessel_graph = nx.DiGraph()
    # create array for holding information on visited pixels
    visited = np.zeros_like(edt_img).astype(bool)
    i = 0

    while starting_point_stack:
        i+=1
        # select new starting point and check if it was not visited already
        starting_point = starting_point_stack.pop()
        if visited[starting_point]:
            continue

        vessel_graph.add_node(starting_point)
        points_to_examine = [starting_point]
        if i % 10 == 0:
            print("Step i", i, len(starting_point_stack))
        j = 0
        while points_to_examine:
            j += 1
            if j % 100 == 0:
                print("Step: ", i, "-", j, " Points to examine: ", len(points_to_examine))
            # get point from queue
            point = points_to_examine.pop()
            # calculate vessel radius
            radius = int(edt_img[point])
            # if radius is too small, don't go further into this vessel
            # also limited by min_distance having to be >= 1 in peak_local_max
            if radius <= 1:
                continue
            # get points that can be on centreline
            points_of_interest = get_points_of_interest(point, radius, visited, edt_img)
                
            # eliminate points that are too close to each other
            for poi_1 in points_of_interest:
                for poi_2 in points_of_interest:
                    if poi_1 != poi_2 and euclidean(poi_1, poi_2) < np.maximum(edt_img[poi_1], edt_img[poi_2]):
                        if edt_img[poi_1] > edt_img[poi_2]:
                            points_of_interest.pop(points_of_interest.index(poi_1))
                        else:
                            points_of_interest.pop(points_of_interest.index(poi_2))

            # add points of interest to examination list and to vessel graph

            for poi in points_of_interest:
                # avoid node duplication
                if not (poi in points_to_examine or vessel_graph.has_node(poi)):
                    points_to_examine.append(poi)
                    vessel_graph.add_edge(point, poi)
                    
            # remove potential centreline pixels next to analyzed point to prevent going backwards
            disk_points = sphere_full(point, radius, visited.shape)

            visited[disk_points] = True

    return vessel_graph.to_undirected()


def remove_graph_components(graph):
    """Remove small (less than 4 nodes) graph components.

    Args:
        graph (nx.Graph): Undirected graph object from NetworkX library.

    Returns:
        nx.Graph: Graph without small components.
    """
    # get components
    components = list(nx.connected_components(graph))
    # remove small components
    for component in components:
        if len(component) < 2:
            for node in component:
                graph.remove_node(node)
    return graph


# pylint: disable=too-many-boolean-expressions,invalid-name
def connect_endings(graph, edt_img, multiplier=2.):
    """Fix discontinuities in vessel graph.

    Args:
        graph (nx.Graph): Graph object from NetworkX library.
        edt_img (np.array): Numpy array containing euclidean distance transformed image of vessel.
        multiplier (int, optional): Multiplier for maximum distance between nodes to connect.
        Defaults to 2.

    Returns:
        nx.Graph: Graph without discontinuities.
    """
    # find potential discontinuities
    endings = [node for node, degree in graph.degree() if degree == 1]
    # for every ending run BFS connection search
    while endings:
        # get point to find connection for
        start = endings.pop()
        # calculate search area
        r = edt_img[start] * multiplier
        search_area = int(r * r * np.pi)
        # setup BFS
        points = SimpleQueue()
        points.put(start)
        visited = []
        # run BFS on a restricted area
        while not points.empty() and search_area:
            search_area -= 1
            # get point
            x, y = points.get()
            visited.append((x, y))

            # check if point is a node and is a valid connection
            if graph.has_node((x, y)) and not nx.has_path(graph, start, (x, y)):
                graph.add_edge(start, (x, y))
                # this is to prevent accidentally creating bifurcations
                if (x, y) in endings:
                    endings.pop(endings.index((x, y)))
                break

            # add point to search if it is in segmentation mask and it is not visited
            for dx in range(-1, 2):
                for dy in range(-1, 2):
                    new_point = (x + dx, y + dy)
                    if (
                        x + dx >= 0
                        and x + dx < edt_img.shape[0]
                        and y + dy >= 0
                        and y + dy < edt_img.shape[1]
                        and new_point not in visited
                        and edt_img[x + dx, y + dy] > 0
                    ):
                        visited.append(new_point)
                        points.put(new_point)

    return graph

def connect_endings_mst(graph, edt_image, multiplier=10):
    nodes = list(graph.nodes())
    n = len(nodes)
    
    G = nx.Graph()
    G.add_nodes_from(nodes)

    for i in range(n):
        for j in range(i+1, n):
            d = euclidean(nodes[i], nodes[j]) 
            d_thresh = np.maximum(edt_image[nodes[i]], edt_image[nodes[j]]) * multiplier
            if d <= d_thresh:
                G.add_edge(nodes[i], nodes[j], weight=d)
    return nx.minimum_spanning_tree(G)


def sorted_nodes_by_distance(graph, node):
    distances = {}
    for n in list(graph.nodes):
        if n == node:
            continue
        dist = euclidean(node, n)
        distances[n] = dist
    sorted_nodes = sorted(distances, key=distances.get)
    return sorted_nodes
    
# pylint: disable=too-many-boolean-expressions,invalid-name
def connect_endings_3d(graph, edt_img, multiplier=3.):
    """Fix discontinuities in vessel graph.

    Args:
        graph (nx.Graph): Graph object from NetworkX library.
        edt_img (np.array): Numpy array containing euclidean distance transformed image of vessel.
        multiplier (int, optional): Multiplier for maximum distance between nodes to connect.
        Defaults to 2.

    Returns:
        nx.Graph: Graph without discontinuities.
    """
    # find potential discontinuities
    endings = [node for node, degree in graph.degree() if degree <= 1]
    root = np.unravel_index(np.argmax(edt_img), edt_img.shape)
    
    if root not in graph.nodes:
        graph.add_node(root)
        
    print("Endings: ", len(endings))
    # for every ending run BFS connection search
    while endings:
        # get point to find connection for
        start = endings.pop()
        # calculate search area
        r = edt_img[start] * multiplier
        search_area = int(4/3 * r * r * r * np.pi)
        # setup BFS
        points = SimpleQueue()
        points.put(start)
        visited = np.zeros_like(edt_img).astype(bool)

        # run BFS on a restricted area
        i = 0

        to_visit = sorted_nodes_by_distance(graph, start)
        print("Endings: ", len(endings))

        while i < len(to_visit):
            if nx.has_path(graph, root, start) or euclidean(start, to_visit[i]) > edt_img[start] * multiplier:
                break
            curr = to_visit[i]
            i += 1
            if not nx.has_path(graph, start, curr) and curr != start:
                graph.add_edge(start, curr)
                if curr in endings:
                    endings.pop(endings.index(curr))
                break

    return graph


def get_node_data(node, prev_node, next_node, edt_img):
    """Get data for node in vessel graph.

    Args:
        node (tuple): Current node for which data is extracted.
        prev_node (tuple): Previous node in graph traversal.
        next_node (tuple): Next node in graph traversal.
        edt_img (np.array): Euclidean distance transformed image of segmented vessel.

    Returns:
        dict: Dictionary containing node data.
    """
    data = {}
    # get vessel diameter
    data["vessel_diameter"] = float(edt_img[node])
    # get vessel diameter gradient
    data["vessel_diameter_grad"] = float(edt_img[node] - edt_img[prev_node])
    # get vessel fragment length
    v1 = np.subtract(prev_node, node).astype(np.float32)
    data["vessel_length"] = float(np.linalg.norm(v1))
    # get angle between nodes
    v1 /= np.linalg.norm(v1)
    v2 = np.subtract(next_node, node).astype(np.float32)
    v2 /= np.linalg.norm(v2)
    data["angle"] = float(np.arccos(np.clip(np.dot(v1, v2), -1.0, 1.0)))
    return data


def parametrize_graph(vessel_graph, edt_img):
    """Retrieve general segment data from detailed vessel graph.

    Args:
        vessel_graph (nx.Graph): Graph describing detailed vessel structure.
        edt_img (np.array): Numpy array containing euclidean distance transformed segmentation of
        vessel.

    Returns:
        nx.Digraph: Directed, parametrized graph containing vessel segment data in nodes.
    """
    # get nodes and sort them by vessel diameter
    # this list can be used in the future to include disconnected graph components
    nodes = [node for node, degree in vessel_graph.degree() if degree == 1]
    nodes.sort(key=lambda n: edt_img[n])
    # select starting point
    start = nodes.pop()
    # initiate graph
    data_graph = nx.Graph()
    # this variable is for identifying nodes in parametrized graph
    data_graph.add_node(start)
    # initiate paths to explore from this point
    paths_to_explore = [(start, neighbor) for neighbor in vessel_graph.neighbors(start)]
    visited = [start]
    # check every path for bifurcations or endings
    while paths_to_explore:
        # get relevant data
        start, node = paths_to_explore.pop()
        prev = start
        # flag node as visited
        visited.append(node)
        # create object to gather data from vessel segment
        data = {"nodes": {start: get_node_data(start, start, node, edt_img)}}
        # measure length of vessel segment
        total_len = 0
        vessel_diameters = []
        # traverse segment while no bifurcations or endings were detected
        while vessel_graph.degree(node) == 2:
            # get next node
            neighbors = list(vessel_graph.neighbors(node))
            neighbors.pop(neighbors.index(prev))
            next_node = neighbors[0]
            # gather data on next node
            data["nodes"][node] = get_node_data(node, prev, next_node, edt_img)
            total_len += data["nodes"][node]["vessel_length"]
            vessel_diameters.append(data["nodes"][node]["vessel_diameter"])
            # go ahead
            prev = node
            node = next_node
            visited.append(node)
        data["nodes"][node] = get_node_data(node, prev, node, edt_img)
        total_len += data["nodes"][node]["vessel_length"]
        vessel_diameters.append(data["nodes"][node]["vessel_diameter"])
        # this happens after ending/bifurcation was detected
        data["segment_length"] = total_len
        data["average_vessel_diameter"] = sum(vessel_diameters) / len(vessel_diameters)
        # add node with collected data
        data_graph.add_node(node)
        data_graph.add_edge(start, node, **data)
        # add paths to explore if there are new ones
        paths_to_explore.extend(
            [
                (node, neighbor)
                for neighbor in vessel_graph.neighbors(node)
                if neighbor not in visited
            ]
        )
    return data_graph


def choose_root_node(data_graph):
    """Select root node for data graph.

    Args:
        data_graph (nx.Graph): Undirected vessel data graph.

    Returns:
        tuple: coordinates of root node.
    """
    # select edge with biggest average vessel diameter
    edges = [
        k
        for k, _ in sorted(
            nx.get_edge_attributes(data_graph, "average_vessel_diameter").items(),
            key=lambda item: item[1],
        )
    ]
    n1, n2 = edges[-1]
    # select root by node degree
    if data_graph.degree(n1) == 1 and data_graph.degree(n2) != 1:
        return n1
    if data_graph.degree(n2) == 1 and data_graph.degree(n1) != 1:
        return n2
    # select root by biggest vessel diameter
    n1_diameter = data_graph[n1][n2]["nodes"][n1]["vessel_diameter"]
    n2_diameter = data_graph[n1][n2]["nodes"][n2]["vessel_diameter"]
    return n1 if n1_diameter > n2_diameter else n2


def clean_data_graph(data_graph, min_segment_length=20):
    """Clean data graph by removing small segments, nodes of degree 2 and selecting a new root.

    Args:
        data_graph (nx.Graph): Undirected vessel data graph.
        min_segment_length (int, optional): Minimum segment length permissible in graph.
        Defaults to 20[px].

    Returns:
        nx.Graph: Cleaned vessel data graph.
    """
    # remove small segments
    for edge, segment_length in nx.get_edge_attributes(
        data_graph, "segment_length"
    ).items():
        if segment_length < min_segment_length:
            data_graph.remove_edge(*edge)
    # remove isolated nodes (degree == 0)
    data_graph.remove_nodes_from(list(nx.isolates(data_graph)))
    # merge segments for nodes of degree 2
    while [n for n, degree in data_graph.degree() if degree == 2]:
        for node, degree in list(data_graph.degree()):
            if degree != 2:
                continue
            n1, n2 = list(data_graph.neighbors(node))
            new_edge_data = {}
            new_edge_data["nodes"] = (
                data_graph[node][n1]["nodes"] | data_graph[node][n2]["nodes"]
            )
            new_edge_data["segment_length"] = sum(
                node_data["vessel_length"]
                for node_data in new_edge_data["nodes"].values()
            )
            new_edge_data["average_vessel_diameter"] = sum(
                node_data["vessel_diameter"]
                for node_data in new_edge_data["nodes"].values()
            ) / len(new_edge_data["nodes"])
            data_graph.remove_node(node)
            data_graph.add_edge(n1, n2, **new_edge_data)
    # select new root
    root_node = choose_root_node(data_graph)
    nx.set_node_attributes(data_graph, False, "root")
    data_graph.nodes[root_node]["root"] = True
    return data_graph


def identify_segments(data_graph, primary_angle, secondary_angle):
    """Identify vessel segments in vessel data graph.

    Args:
        data_graph (nx.Graph): Undirected vessel data graph.
        primary_angle (int): Positioner primary angle.
        secondary_angle (int): Positioner secondary angle.

    Raises:
        AttributeError: if angles are outside of the expected bounds (at least 20 degrees off the
        expected values defined in PROJECTION_ANGLES).

    Returns:
        nx.Graph: vessel data graph with labeled vessel segments.
    """
    for angles, vessel_labels in PROJECTION_ANGLES.items():
        primary, secondary = angles
        if (
            abs(primary_angle - primary) <= 20
            and abs(secondary_angle - secondary) <= 20
        ):
            labels = vessel_labels.copy()
            break
    else:
        raise AttributeError("Image has incorrect projection angles.")

    for k, v in data_graph.nodes(data="root"):
        if v:
            root_node = k
            break

    current_label = ""
    for n1, n2 in nx.bfs_edges(data_graph, root_node):
        if labels:
            current_label = labels.pop(0)
        data_graph[n1][n2]["vessel_label"] = current_label

    return data_graph


def get_segment_mask(data_graph, seg_img):
    """Generate representation of vessel segments in the form of segmentation masks.

    Args:
        data_graph (nx.Graph): NetworkX graph containing general vessel information.
        seg_img (np.array): Numpy array containing segmented vessel image.

    Returns:
        tuple(np.array, dict): Numpy array containing segment mask and label dict.
    """
    masks = []
    labels = {}
    for idx, nodes in enumerate(data_graph.edges.data("nodes")):
        try:
            labels[idx + 1] = data_graph[nodes[0]][nodes[1]]["vessel_label"]
        except KeyError:
            labels[idx + 1] = "UNKN"
            # WARN print(idx,nodes[0],nodes[1]," no label")
        # calculate distances from nodes in segment
        mask = np.ones_like(seg_img)
        mask[nodes[:-1]] = 0
        idxs = tuple(np.array(list(nodes[-1].keys())).T)
        mask[idxs] = 0
        mask = distance_transform_edt(mask)
        # append to summary mask
        masks.append(mask)
    masks = np.stack(masks)
    # select label closest to point
    result = np.argmin(masks, axis=0) + 1
    result[seg_img == 0] = 0
    result = result.astype(np.uint8)
    return result, labels


def get_data_graph(seg_img):
    """Function collecting steps for creating a data graph.

    Args:
        seg_img (np.array): Numpy array containing segmented vessel image.

    Returns:
        nx.Graph: Vessel data graph.
    """
    edt_img = distance_transform_edt(seg_img)
    vessel_graph = get_vessel_graph(seg_img)
    print(nx.info(vessel_graph))
    vessel_graph = remove_graph_components(vessel_graph)
    print(nx.info(vessel_graph))
    vessel_graph = connect_endings_3d(vessel_graph, edt_img)
    print(nx.info(vessel_graph))
    data_graph = parametrize_graph(vessel_graph, edt_img)
    data_graph = clean_data_graph(data_graph)
    return data_graph


### Visualisation utilities

In [160]:
def visualize_addition(base, base_with_addition):
    base = (base.copy() > 0).astype(np.uint8)
    addition = (base_with_addition > 0).astype(np.uint8)
    addition[base == 1] = 0
    ColorMapVisualizer(base + addition * 4).visualize()
    
def visualize_lsd(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize()
    
def visualize_gradient(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8), binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8) * 255, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    skeleton = skeletonize_3d((mask > 0).astype(np.uint8))
    if not visualize_mask or visualize_both_versions:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if visualize_mask or visualize_both_versions:
        skeleton = skeleton.astype(np.uint8) * 4
        mask = (mask > 0).astype(np.uint8) * 3
        mask[skeleton != 0] = 0
        ColorMapVisualizer(skeleton + mask).visualize()

def visualize_ultimate(lsd, base_mask):
    visualize_lsd(lsd)
    visualize_mask_non_bin(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)
    
def plot_3d_graph(G):
    pos = nx.spring_layout(G, dim=3)
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111, projection='3d')
    for node in G.nodes():
        (x,y,z) = pos[node]
        ax.scatter(x, y, z, c='b', marker='o')
        ax.text(x, y, z, node, fontsize=10)
    for edge in G.edges():
        (u,v) = edge
        (x1,y1,z1) = pos[u]
        (x2,y2,z2) = pos[v]
        ax.plot([x1,x2], [y1,y2], [z1,z2], 'k-', alpha=0.5)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    plt.show()
    
def draw_graph_on_model(binary_model, graph):
    mask = np.zeros(binary_model.shape, dtype=np.uint8)
    mask[binary_model] = 30
    

    for edge in graph.edges:
        node1, node2 = edge
        x1, y1, z1 = node1
        x2, y2, z2 = node2
        line_x = np.linspace(x1, x2, num=300, endpoint=True, dtype=np.int32)
        line_y = np.linspace(y1, y2, num=300, endpoint=True, dtype=np.int32)
        line_z = np.linspace(z1, z2, num=300, endpoint=True, dtype=np.int32)
        for i in range(len(line_x)):
            mask[line_x[i], line_y[i], line_z[i]] = 255

    for node in graph.nodes:
        x, y, z = node
        mask[x, y, z] = 40
#         for dx in (-1, 2):
#             for dy in (-1, 2):
#                 for dz in (-1, 2):
#                     mask[x+dx, y+dy, z+dz] = 150

    return mask


### Read data

In [161]:
volume = np.fromfile('../data/P13/data.raw', dtype=np.uint8)

In [162]:
volume = volume.reshape(877, 488, 1132)
volume = volume[200:700, :, 100:650]
val = 20

volume[:, 200:, -35:] = val
volume[-150:, 200:, -100:-35] = val
volume[:, -200:-150, -100:-35] = val
volume[:, :210, -50:] = val
volume[:150, :150, -85:-50] = val

volume = volume[-170:-40,200:480,30:260]
volume[-90:, :, -110:] = val
volume[:, -80:, -92:] = val

In [163]:
VolumeVisualizer(volume, binary=False).visualize() 

In [164]:
mask = volume > 32

In [165]:
def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    labeled = measure.label(binary_mask, connectivity=connectivity)
    region_props = measure.regionprops(labeled)
    
    main_regions = np.zeros(binary_mask.shape)
    bounding_boxes = []
    for props in region_props:
        if props.area >= min_size:
            bounding_boxes.append(props.bbox)
            main_regions = np.logical_or(main_regions, labeled==props.label)
            
    lower_bounds = np.min(bounding_boxes, axis=0)[:3]
    upper_bounds = np.max(bounding_boxes, axis=0)[3:]

    return main_regions[
        lower_bounds[0]:upper_bounds[0],
        lower_bounds[1]:upper_bounds[1],
        lower_bounds[2]:upper_bounds[2],
    ], bounding_boxes

In [166]:
main_regions, bounding_boxes = get_main_regions(mask, min_size=25_000)
print('number of main regions:', len(bounding_boxes))
mask_main = main_regions
# mask = None
# visualize_mask_non_bin(mask_main)

number of main regions: 1


In [41]:
visualize_mask_bin(mask_main)

In [167]:
#ColorMapVisualizer(img_edt.astype(np.uint8)).visualize()

In [168]:
#img_edt = distance_transform_edt(mask_main)

In [169]:
def spherical_kernel(outer_radius, thickness=1, filled=True):    
    outer_sphere = morphology.ball(radius=outer_radius)
    if filled:
        return outer_sphere
    
    thickness = min(thickness, outer_radius)
    
    inner_radius = outer_radius - thickness
    inner_sphere = morphology.ball(radius=inner_radius)
    
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True, fft=True):
    kernel = spherical_kernel(ball_radius, filled=True)
    if fft:
        convolved = fftconvolve(img.astype(dtype), kernel.astype(dtype), mode='same')
    else:
        convolved = signal.convolve(img.astype(dtype), kernel.astype(dtype), mode='same')
    
    if not normalize:
        return convolved
    
    return (convolved / kernel.sum()).astype(np.float16)

def calculate_reconstruction(mask, kernel_sizes=[10, 9, 8, 7], fill_threshold=0.5, iters=1, conv_dtype=np.uint16, fft=True):
    kernel_sizes_maps = []
    mask = mask.astype(np.uint8)
    
    for i in range(iters):
        kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True, fft=fft)
            
            above_threshold_fill_indices = fill_percentage > fill_threshold
            kernel_size_map[above_threshold_fill_indices] = kernel_size + 1

            mask[above_threshold_fill_indices] = 1
            
            print(f'Iteration {i + 1} kernel {kernel_size} done')

        kernel_sizes_maps.append(kernel_size_map)
        print(f'Iteration {i + 1} ended successfully')

    return kernel_sizes_maps

In [170]:
scaled_mask = zoom(mask_main, zoom=0.7, order=2)

In [171]:
s_recos = calculate_reconstruction(scaled_mask, 
                                   kernel_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 
                                   iters=3)

Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 ended successfully
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 ended successfully
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7 done
Iteration 3 kernel 8 done
Iteration 3 kernel 9 done
Iteration 3 kernel 10 done
Iteration 3 kernel 11 done
Iteration 3 kernel 1

In [172]:
bin_reco = s_recos[-1] > 0

In [173]:
%%time
edt_img = distance_transform_edt(bin_reco)

CPU times: total: 203 ms
Wall time: 200 ms


In [174]:
ColorMapVisualizer(edt_img.astype(np.uint8)).visualize()

In [155]:
%%time
vessel_graph = get_vessel_graph(bin_reco)
print(nx.info(vessel_graph))

EDT transform completed
Finding starting points completed
Step i 40 2753
Step i 60 2733
Step i 80 2713
Step i 140 2653
Step i 150 2643
Step i 160 2633
Step i 170 2623
Step i 190 2603
Step i 200 2593
Step i 230 2563
Step i 250 2543
Step i 270 2523
Step i 290 2503
Step i 320 2473
Step i 350 2443
Step i 370 2423
Step i 380 2413
Step i 390 2403
Step i 400 2393
Step i 410 2383
Step i 420 2373
Step i 430 2363
Step i 440 2353
Step i 450 2343
Step i 460 2333
Step i 470 2323
Step i 480 2313
Step i 490 2303
Step i 500 2293
Step i 510 2283
Step i 520 2273
Step i 530 2263
Step i 540 2253
Step i 550 2243
Step i 560 2233
Step i 570 2223
Step i 580 2213
Step i 590 2203
Step i 600 2193
Step i 610 2183
Step i 620 2173
Step i 630 2163
Step i 640 2153
Step i 650 2143
Step i 660 2133
Step i 670 2123
Step i 680 2113
Step i 690 2103
Step i 700 2093
Step i 710 2083
Step i 720 2073
Step i 730 2063
Step i 740 2053
Step i 750 2043
Step i 760 2033
Step i 770 2023
Step i 780 2013
Step i 790 2003
Step i 800 1993
S




In [97]:
model_with_graph = draw_graph_on_model(bin_reco, vessel_graph)
ColorMapVisualizer(model_with_graph.astype(np.uint8)).visualize()

In [134]:
%%time
vessel_graph_rm = remove_graph_components(vessel_graph_cn)
print(nx.info(vessel_graph_rm))

Graph with 1526 nodes and 1525 edges
CPU times: total: 0 ns
Wall time: 3.02 ms





In [135]:
model_with_graph = draw_graph_on_model(bin_reco, vessel_graph_rm)
ColorMapVisualizer(model_with_graph.astype(np.uint8)).visualize()

In [149]:
pickle.dump(vessel_graph_rm, open('removed-small.pickle', 'wb'))

In [156]:
%%time
for r in [3, 5, 7, 10, 15, 19]:
    vessel_graph_cn = connect_endings_3d(vessel_graph, edt_img, r)
    print(nx.info(vessel_graph_cn))

Endings:  2571
Endings:  2570
Endings:  2568
Endings:  2566
Endings:  2564
Endings:  2562
Endings:  2560
Endings:  2558
Endings:  2556
Endings:  2554
Endings:  2553
Endings:  2552
Endings:  2550
Endings:  2548
Endings:  2546
Endings:  2544
Endings:  2542
Endings:  2541
Endings:  2540
Endings:  2538
Endings:  2536
Endings:  2534
Endings:  2533
Endings:  2531
Endings:  2529
Endings:  2527
Endings:  2525
Endings:  2523
Endings:  2521
Endings:  2519
Endings:  2518
Endings:  2516
Endings:  2514
Endings:  2512
Endings:  2511
Endings:  2509
Endings:  2507
Endings:  2505
Endings:  2504
Endings:  2502
Endings:  2500
Endings:  2498
Endings:  2496
Endings:  2494
Endings:  2492
Endings:  2490
Endings:  2488
Endings:  2486
Endings:  2484
Endings:  2482
Endings:  2480
Endings:  2478
Endings:  2476
Endings:  2474
Endings:  2472
Endings:  2470
Endings:  2469
Endings:  2468
Endings:  2466
Endings:  2464
Endings:  2463
Endings:  2461
Endings:  2459
Endings:  2458
Endings:  2456
Endings:  2454
Endings:  

Endings:  1599
Endings:  1597
Endings:  1595
Endings:  1594
Endings:  1593
Endings:  1591
Endings:  1589
Endings:  1587
Endings:  1585
Endings:  1584
Endings:  1582
Endings:  1580
Endings:  1579
Endings:  1577
Endings:  1575
Endings:  1573
Endings:  1571
Endings:  1569
Endings:  1567
Endings:  1565
Endings:  1563
Endings:  1561
Endings:  1559
Endings:  1557
Endings:  1556
Endings:  1554
Endings:  1552
Endings:  1550
Endings:  1548
Endings:  1546
Endings:  1544
Endings:  1542
Endings:  1540
Endings:  1538
Endings:  1536
Endings:  1534
Endings:  1532
Endings:  1531
Endings:  1529
Endings:  1527
Endings:  1525
Endings:  1524
Endings:  1522
Endings:  1520
Endings:  1518
Endings:  1516
Endings:  1514
Endings:  1512
Endings:  1510
Endings:  1508
Endings:  1506
Endings:  1504
Endings:  1502
Endings:  1500
Endings:  1498
Endings:  1496
Endings:  1494
Endings:  1492
Endings:  1490
Endings:  1488
Endings:  1487
Endings:  1485
Endings:  1483
Endings:  1481
Endings:  1480
Endings:  1478
Endings:  

Endings:  597
Endings:  595
Endings:  594
Endings:  592
Endings:  591
Endings:  590
Endings:  588
Endings:  586
Endings:  584
Endings:  582
Endings:  580
Endings:  578
Endings:  576
Endings:  574
Endings:  573
Endings:  572
Endings:  570
Endings:  568
Endings:  566
Endings:  564
Endings:  563
Endings:  561
Endings:  559
Endings:  557
Endings:  555
Endings:  553
Endings:  552
Endings:  550
Endings:  548
Endings:  546
Endings:  544
Endings:  542
Endings:  540
Endings:  539
Endings:  537
Endings:  535
Endings:  534
Endings:  532
Endings:  530
Endings:  528
Endings:  526
Endings:  525
Endings:  524
Endings:  523
Endings:  521
Endings:  519
Endings:  517
Endings:  515
Endings:  513
Endings:  512
Endings:  510
Endings:  509
Endings:  508
Endings:  506
Endings:  505
Endings:  503
Endings:  502
Endings:  500
Endings:  498
Endings:  496
Endings:  494
Endings:  492
Endings:  491
Endings:  489
Endings:  487
Endings:  485
Endings:  483
Endings:  482
Endings:  481
Endings:  479
Endings:  477
Ending




Endings:  2222
Endings:  2221
Endings:  2219
Endings:  2218
Endings:  2217
Endings:  2215
Endings:  2213
Endings:  2212
Endings:  2211
Endings:  2209
Endings:  2207
Endings:  2206
Endings:  2205
Endings:  2203
Endings:  2201
Endings:  2199
Endings:  2198
Endings:  2197
Endings:  2195
Endings:  2193
Endings:  2191
Endings:  2189
Endings:  2187
Endings:  2185
Endings:  2183
Endings:  2181
Endings:  2179
Endings:  2178
Endings:  2176
Endings:  2174
Endings:  2172
Endings:  2170
Endings:  2168
Endings:  2166
Endings:  2165
Endings:  2164
Endings:  2162
Endings:  2160
Endings:  2158
Endings:  2156
Endings:  2154
Endings:  2152
Endings:  2150
Endings:  2148
Endings:  2146
Endings:  2144
Endings:  2143
Endings:  2142
Endings:  2140
Endings:  2139
Endings:  2138
Endings:  2136
Endings:  2134
Endings:  2133
Endings:  2132
Endings:  2130
Endings:  2129
Endings:  2128
Endings:  2127
Endings:  2126
Endings:  2125
Endings:  2123
Endings:  2122
Endings:  2121
Endings:  2120
Endings:  2119
Endings:  

Endings:  1332
Endings:  1330
Endings:  1328
Endings:  1326
Endings:  1324
Endings:  1323
Endings:  1321
Endings:  1319
Endings:  1318
Endings:  1316
Endings:  1314
Endings:  1312
Endings:  1311
Endings:  1310
Endings:  1309
Endings:  1307
Endings:  1305
Endings:  1304
Endings:  1302
Endings:  1301
Endings:  1299
Endings:  1297
Endings:  1295
Endings:  1293
Endings:  1291
Endings:  1289
Endings:  1287
Endings:  1286
Endings:  1284
Endings:  1282
Endings:  1280
Endings:  1278
Endings:  1277
Endings:  1275
Endings:  1274
Endings:  1272
Endings:  1270
Endings:  1268
Endings:  1266
Endings:  1264
Endings:  1262
Endings:  1261
Endings:  1259
Endings:  1257
Endings:  1255
Endings:  1254
Endings:  1252
Endings:  1250
Endings:  1249
Endings:  1248
Endings:  1247
Endings:  1246
Endings:  1244
Endings:  1243
Endings:  1241
Endings:  1239
Endings:  1237
Endings:  1235
Endings:  1233
Endings:  1232
Endings:  1230
Endings:  1228
Endings:  1226
Endings:  1224
Endings:  1222
Endings:  1220
Endings:  

Endings:  363
Endings:  362
Endings:  360
Endings:  359
Endings:  357
Endings:  355
Endings:  353
Endings:  351
Endings:  349
Endings:  347
Endings:  346
Endings:  345
Endings:  344
Endings:  342
Endings:  341
Endings:  340
Endings:  339
Endings:  337
Endings:  335
Endings:  334
Endings:  332
Endings:  330
Endings:  328
Endings:  326
Endings:  325
Endings:  324
Endings:  323
Endings:  322
Endings:  320
Endings:  318
Endings:  316
Endings:  314
Endings:  313
Endings:  311
Endings:  309
Endings:  307
Endings:  305
Endings:  303
Endings:  301
Endings:  299
Endings:  297
Endings:  295
Endings:  294
Endings:  293
Endings:  292
Endings:  290
Endings:  288
Endings:  286
Endings:  285
Endings:  283
Endings:  281
Endings:  279
Endings:  277
Endings:  276
Endings:  274
Endings:  273
Endings:  271
Endings:  269
Endings:  268
Endings:  266
Endings:  264
Endings:  262
Endings:  260
Endings:  258
Endings:  256
Endings:  254
Endings:  253
Endings:  252
Endings:  250
Endings:  248
Endings:  246
Ending

Endings:  194
Endings:  193
Endings:  192
Endings:  191
Endings:  190
Endings:  189
Endings:  188
Endings:  187
Endings:  186
Endings:  185
Endings:  184
Endings:  183
Endings:  182
Endings:  181
Endings:  180
Endings:  179
Endings:  178
Endings:  177
Endings:  176
Endings:  175
Endings:  174
Endings:  173
Endings:  172
Endings:  171
Endings:  170
Endings:  169
Endings:  168
Endings:  167
Endings:  166
Endings:  165
Endings:  164
Endings:  163
Endings:  162
Endings:  161
Endings:  160
Endings:  159
Endings:  158
Endings:  157
Endings:  156
Endings:  155
Endings:  154
Endings:  153
Endings:  152
Endings:  151
Endings:  150
Endings:  149
Endings:  148
Endings:  147
Endings:  146
Endings:  145
Endings:  144
Endings:  143
Endings:  142
Endings:  141
Endings:  140
Endings:  139
Endings:  138
Endings:  137
Endings:  136
Endings:  135
Endings:  134
Endings:  133
Endings:  132
Endings:  131
Endings:  130
Endings:  129
Endings:  128
Endings:  127
Endings:  126
Endings:  125
Endings:  124
Ending

Endings:  126
Endings:  125
Endings:  124
Endings:  123
Endings:  122
Endings:  121
Endings:  120
Endings:  119
Endings:  118
Endings:  117
Endings:  116
Endings:  115
Endings:  114
Endings:  113
Endings:  112
Endings:  111
Endings:  110
Endings:  109
Endings:  108
Endings:  107
Endings:  106
Endings:  105
Endings:  104
Endings:  103
Endings:  102
Endings:  101
Endings:  100
Endings:  99
Endings:  98
Endings:  97
Endings:  96
Endings:  95
Endings:  94
Endings:  93
Endings:  92
Endings:  91
Endings:  90
Endings:  89
Endings:  88
Endings:  87
Endings:  86
Endings:  85
Endings:  84
Endings:  83
Endings:  82
Endings:  81
Endings:  80
Endings:  79
Endings:  78
Endings:  77
Endings:  76
Endings:  75
Endings:  74
Endings:  73
Endings:  72
Endings:  71
Endings:  70
Endings:  69
Endings:  68
Endings:  67
Endings:  66
Endings:  65
Endings:  64
Endings:  63
Endings:  62
Endings:  61
Endings:  60
Endings:  59
Endings:  58
Endings:  57
Endings:  56
Endings:  55
Endings:  54
Endings:  53
Endings:  5

In [157]:
model_with_graph = draw_graph_on_model(bin_reco, vessel_graph_cn)
ColorMapVisualizer(model_with_graph.astype(np.uint8)).visualize()

In [137]:
%%time
data_graph = parametrize_graph(vessel_graph_cn, edt_img)

CPU times: total: 109 ms
Wall time: 111 ms


  v1 /= np.linalg.norm(v1)
  v2 /= np.linalg.norm(v2)


In [138]:
%%time
data_graph_cl = clean_data_graph(data_graph)

CPU times: total: 0 ns
Wall time: 2.01 ms


In [139]:
model_with_graph = draw_graph_on_model(bin_reco, data_graph_cl)

In [140]:
ColorMapVisualizer(model_with_graph.astype(np.uint8)).visualize()