In [1]:
import nrrd
import numpy as np
import os
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from skimage.segmentation import mark_boundaries
import cv2
from scipy import ndimage as ndi
from helper import *
import graph_tool.all as gt
import plotly.graph_objects as go
import time

# Resolve the manually labelled volume and its raw scan relative to the working
# directory, then load each NRRD file as an (array, header) pair.
current_directory = os.getcwd()
label_path, raw_data_path = (
    f"{current_directory}/data/label/manual_1_label.nrrd",
    f"{current_directory}/data/raw/manual_1_raw.nrrd",
)

(mask_data, mask_header), (raw_data, raw_header) = (nrrd.read(path) for path in (label_path, raw_data_path))

In [2]:
# Rotate both volumes to the preferred cut angle, then build a resolution pyramid of each.
# NOTE(review): mask_data / raw_data are overwritten here, so re-running this cell without
# re-running the load cell applies the rotation a second time.
axis = 1  # first axis of the rotation plane (axes=(axis, (axis+1) % 3)); 0 for x, 1 for y, 2 for z
k = 1  # number of 90-degree rotations passed to np.rot90
rotation_plane = (axis, (axis + 1) % 3)
mask_data, raw_data = (np.rot90(volume, k, axes=rotation_plane) for volume in (mask_data, raw_data))

# Pyramid index [0] is full resolution; each further index is 2x lower res: 256, 128, 64, 32.
# The second argument (3) presumably controls the number of coarser levels -- confirm in coarsen_image.
reduced_mask_data, reduced_raw_data = (coarsen_image(volume, 3) for volume in (mask_data, raw_data))

In [3]:
# Build the directed energy graph for the coarse pyramid level selected by res_index.
res_index = 2
(rotated_directed_graph, r_src, r_tgt, rotated_weights,
 x_pos, y_pos, z_pos) = create_masked_directed_energy_graph_from_mask(reduced_mask_data[res_index])

64 64 64
Time taken to add vertices to coord_to_vertex: 1.1938362121582031
Time taken to add edges to graph: 2.0253918170928955
Time taken to add source and sink nodes: 0.08239316940307617


In [4]:
def calculate_seam_iter(directed_graph, src, tgt, weights, test_size, x_pos, y_pos, z_pos):
    """Compute the minimum s-t cut of an energy graph and rasterise its boundary (the seam).

    Parameters
    ----------
    directed_graph : graph_tool.Graph
        Directed graph, e.g. as built by ``create_masked_directed_energy_graph_from_mask``.
    src, tgt : graph_tool.Vertex
        Source and sink vertices of the cut.
    weights : graph_tool edge property map
        Edge capacities used for the max-flow computation.
    test_size : int or tuple of int
        Shape of the output volume. An int ``n`` means the cube ``(n, n, n)``;
        a 3-tuple (e.g. ``array.shape``) allows non-cubic volumes.
    x_pos, y_pos, z_pos
        Per-vertex coordinates, forwarded to ``boundary_vertices_to_array_masked``.

    Returns
    -------
    boundary_array
        Volume produced by ``boundary_vertices_to_array_masked`` from the cut's boundary vertices.
    flow
        Total flow entering ``tgt`` (the max-flow value, i.e. the min-cut capacity).
    """
    # Boykov-Kolmogorov returns the *residual* capacity of every edge.
    # Worst case is O(E * V^2 * |C|). gt.push_relabel_max_flow (O(V^3)) is a drop-in
    # alternative the author noted may scale better for this graph.
    stime = time.time()
    res = gt.boykov_kolmogorov_max_flow(directed_graph, src, tgt, weights)
    print("Time taken to calculate max flow boykov:", time.time() - stime)

    # Flow value = used capacity (capacity - residual) summed over the sink's in-edges.
    stime = time.time()
    flow = sum(weights[e] - res[e] for e in tgt.in_edges())
    print("Time taken to sum flow into sink:", time.time() - stime, "flow:", flow)

    # Source-side partition of the minimum cut, read off the residual graph.
    stime = time.time()
    part = gt.min_st_cut(directed_graph, src, weights, res)
    print("Time taken to calculate min cut:", time.time() - stime)

    # Boundary vertices of the partition (computed from the edge list) form the seam.
    stime = time.time()
    boundary_vertices = find_boundary_vertices(np.array(directed_graph.get_edges()), part)
    print("Time taken to calculate seam:", time.time() - stime)

    # An int keeps the previous cubic behaviour; a tuple is used as the shape directly.
    shape = (test_size,) * 3 if np.ndim(test_size) == 0 else tuple(test_size)

    stime = time.time()
    # NOTE(review): the meaning of the 'x' argument is defined in helper -- confirm there.
    boundary_array = boundary_vertices_to_array_masked(boundary_vertices, shape, 'x', x_pos, y_pos, z_pos)
    print("Time taken to convert boundary vertices to array:", time.time() - stime)

    return boundary_array, flow

In [5]:
# Seam on the coarse level, then upsample it to full resolution (pyramid level 0) and
# dilate it into the band in which the full-resolution seam is searched.
coarse_mask = reduced_mask_data[res_index]
boundary_array, flow = calculate_seam_iter(rotated_directed_graph, r_src, r_tgt, rotated_weights, coarse_mask.shape[0], x_pos, y_pos, z_pos)
# Derive the factor from the pyramid instead of hard-coding 4 (= 256 // 64), so changing
# res_index in the graph-building cell keeps b_arr_up aligned with level 0.
full_res_upscale = reduced_mask_data[0].shape[0] // coarse_mask.shape[0]
b_arr_up = upscale_and_dilate_3d(boundary_array, upscale_factor=full_res_upscale, dilation_amount=1)
print(b_arr_up.shape)

Time taken to calculate max flow boykov: 4.599825143814087
Time taken to calculate max flow: 0.041953086853027344 flow: 8848
Time taken to calculate max flow and min cut: 0.0896310806274414
edges that cross the partition: 1540352
boundary vertices: 6523
Time taken to calculate seam: 0.01758098602294922
Time taken to convert boundary vertices to array: 0.026566267013549805
(256, 256, 256)


In [6]:
# Restrict the full-resolution volume to the band around the upscaled coarse seam:
# voxels where b_arr_up is 0 are set to -1 on a copy, then the graph is built from it.
res_index = 0
masked_array = reduced_mask_data[res_index].copy()
print(masked_array.shape, b_arr_up.shape)
outside_band = b_arr_up == 0
masked_array[outside_band] = -1

(masked_graph, masked_src, masked_sink, masked_weights,
 x_pos, y_pos, z_pos) = create_masked_directed_energy_graph_from_mask(masked_array)

(256, 256, 256) (256, 256, 256)
256 256 256
Time taken to add vertices to coord_to_vertex: 2.06321382522583
Time taken to add edges to graph: 3.194391965866089
Time taken to add source and sink nodes: 1.9652209281921387


In [7]:
# Report the band-restricted graph size, then browse slices of masked_array.
print(masked_graph.num_vertices(), masked_graph.num_edges(), masked_src, masked_sink)
res_i = 0


def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` of ``masked_array`` along ``axis`` (values other than 1 or 2 mean axis 0)."""
    slice_axis = 1 if axis == 1 else (2 if axis == 2 else 0)
    plt.figure(figsize=(8, 6))
    plt.imshow(np.take(masked_array, slice_index, axis=slice_axis))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()


# Slider-driven slice browser
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=masked_array.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

463739 2461814 463737 463738


interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [8]:
# Overlay the upscaled, dilated coarse seam band (b_arr_up) on the full-resolution
# label mask reduced_mask_data[res_i]. The background is the mask, not the raw scan.
res_i = 0


def plot_slice(slice_index, axis=0):
    """Show ``mark_boundaries_color`` of ``b_arr_up`` over the level-``res_i`` mask for one slice.

    ``axis`` selects the slicing axis; values other than 1 or 2 fall back to axis 0.
    """
    slice_axis = 1 if axis == 1 else (2 if axis == 2 else 0)
    plt.figure(figsize=(8, 6))
    band_slice = np.take(b_arr_up, slice_index, axis=slice_axis)
    mask_slice = np.take(reduced_mask_data[res_i], slice_index, axis=slice_axis)
    plt.imshow(mark_boundaries_color(band_slice, mask_slice))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()


# Slider-driven slice browser, bounded by the plotted array
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=b_arr_up.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [9]:
# Full-resolution min cut, restricted to the band around the coarse seam.
full_res_size = masked_array.shape[0]
boundary_array, flow = calculate_seam_iter(
    masked_graph, masked_src, masked_sink, masked_weights,
    full_res_size, x_pos, y_pos, z_pos,
)

Time taken to calculate max flow boykov: 2.7150728702545166
Time taken to calculate max flow: 0.6388070583343506 flow: 119257
Time taken to calculate max flow and min cut: 5.117893695831299
edges that cross the partition: 2461814
boundary vertices: 105451
Time taken to calculate seam: 0.06766295433044434
Time taken to convert boundary vertices to array: 0.44994616508483887


In [10]:
# Summarise the full-resolution result: max-flow value and the sum over boundary_array
# (the seam voxel count if it is a 0/1 mask).
print("flow:", flow)
seam_total = np.sum(boundary_array)
print("boundary_array:", seam_total)

flow: 119257
boundary_array: 65536


In [11]:
# Browse slices of the full-resolution boundary (seam) volume.
def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` of ``boundary_array`` along ``axis`` (values other than 1 or 2 mean axis 0)."""
    slice_axis = 1 if axis == 1 else (2 if axis == 2 else 0)
    plt.figure(figsize=(8, 6))
    plt.imshow(np.take(boundary_array, slice_index, axis=slice_axis))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()


# Slider-driven slice browser
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=boundary_array.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [12]:
# Overlay the full-resolution seam (boundary_array) on the full-resolution label mask
# reduced_mask_data[res_i]. The background is the mask, not the raw scan.
res_i = 0


def plot_slice(slice_index, axis=0):
    """Show ``mark_boundaries_color`` of ``boundary_array`` over the level-``res_i`` mask for one slice.

    ``axis`` selects the slicing axis; values other than 1 or 2 fall back to axis 0.
    """
    slice_axis = 1 if axis == 1 else (2 if axis == 2 else 0)
    plt.figure(figsize=(8, 6))
    seam_slice = np.take(boundary_array, slice_index, axis=slice_axis)
    mask_slice = np.take(reduced_mask_data[res_i], slice_index, axis=slice_axis)
    plt.imshow(mark_boundaries_color(seam_slice, mask_slice))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()


# Bound the slider by the array actually being plotted (boundary_array), not another volume.
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=boundary_array.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [13]:
def multi_res_seam_iter(res_index, mask_array_data, b_arr_up):
    """Recompute the seam at pyramid level ``res_index`` inside the band ``b_arr_up``.

    A copy of ``mask_array_data[res_index]`` has every voxel where ``b_arr_up == 0``
    set to -1. A graph is built from it and its min-cut seam is computed.

    Returns the seam volume from ``calculate_seam_iter``; the flow value is discarded.
    """
    level_mask = mask_array_data[res_index].copy()
    level_mask[b_arr_up == 0] = -1

    t_start = time.time()
    graph, source, sink, capacities, xs, ys, zs = create_masked_directed_energy_graph_from_mask(level_mask)
    print(f"Time taken to create graph {res_index}:", time.time() - t_start)

    seam, _flow = calculate_seam_iter(graph, source, sink, capacities, level_mask.shape[0], xs, ys, zs)
    return seam

def multi_res_seam_calculation(mask_array_data, res_index=3, upscale_factor=2, dilation_amount=1):
    """Coarse-to-fine seam search over a resolution pyramid.

    Computes the min-cut seam on ``mask_array_data[res_index]``. Then, for each finer
    level ``res_index-1`` down to 0, it upscales and dilates the previous seam into a
    band and recomputes the seam inside that band with ``multi_res_seam_iter``.

    Parameters
    ----------
    mask_array_data : sequence of arrays
        Pyramid where index 0 is full resolution and higher indices are coarser.
    res_index : int
        Coarsest level to start from.
    upscale_factor : int
        Resolution ratio between consecutive pyramid levels (2 for a pyramid that halves
        per level). Applied at every upscaling step, including the first.
    dilation_amount : int
        Dilation passed to ``upscale_and_dilate_3d`` when forming each search band.

    Returns
    -------
    The seam volume computed at pyramid level 0.
    """
    stime = time.time()
    coarse_mask = mask_array_data[res_index]
    directed_graph, src, tgt, weights, x_pos, y_pos, z_pos = create_masked_directed_energy_graph_from_mask(coarse_mask)
    print(f"Time taken to create graph {res_index}:", time.time() - stime)
    boundary_array, _flow = calculate_seam_iter(directed_graph, src, tgt, weights, coarse_mask.shape[0], x_pos, y_pos, z_pos)

    for i in range(res_index - 1, -1, -1):
        # Upscale inside the loop so it happens only when a finer level consumes the band,
        # and always with the caller's upscale_factor.
        b_arr_up = upscale_and_dilate_3d(boundary_array, upscale_factor=upscale_factor, dilation_amount=dilation_amount)
        boundary_array = multi_res_seam_iter(i, mask_array_data, b_arr_up)
    return boundary_array



In [18]:
# Coarse-to-fine seam search starting at pyramid level 2, with no band dilation.
# NOTE(review): the recorded run of this cell reported flow 8848 at level 2 but
# 547000015151 at level 1; with dilation_amount=0 the band may be too thin -- verify.
res_index = 2
mask_array_data = reduced_mask_data
num_seams_to_remove = 1

for i in range(num_seams_to_remove):
    # Until the TODO below is implemented, each pass recomputes the seam from an unchanged mask_array_data.
    boundary_array = multi_res_seam_calculation(mask_array_data, res_index, dilation_amount=0)
    # TODO: remove seam from mask_array_data

64 64 64
Time taken to add vertices to coord_to_vertex: 1.2046380043029785
Time taken to add edges to graph: 2.0474579334259033
Time taken to add source and sink nodes: 0.08918094635009766
Time taken to create graph 2: 3.3590247631073
Time taken to calculate max flow boykov: 4.213392019271851
Time taken to calculate max flow: 0.04012274742126465 flow: 8848
Time taken to calculate max flow and min cut: 0.09046030044555664
edges that cross the partition: 1540352
boundary vertices: 6523
Time taken to calculate seam: 0.021227121353149414
Time taken to convert boundary vertices to array: 0.026913881301879883
128 128 128
Time taken to add vertices to coord_to_vertex: 0.13570833206176758
Time taken to add edges to graph: 0.09845685958862305
Time taken to add source and sink nodes: 0.3834991455078125
Time taken to create graph 1: 0.620053768157959
Time taken to calculate max flow boykov: 0.14374279975891113
Time taken to calculate max flow: 0.15477871894836426 flow: 547000015151
Time taken to 

In [19]:
# Browse slices of the boundary volume, i.e. the seam selected for removal.
res_i = 0


def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` of ``boundary_array`` along ``axis`` (values other than 1 or 2 mean axis 0)."""
    slice_axis = 1 if axis == 1 else (2 if axis == 2 else 0)
    plt.figure(figsize=(8, 6))
    plt.imshow(np.take(boundary_array, slice_index, axis=slice_axis))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()


# Slider-driven slice browser
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=boundary_array.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [20]:
# Overlay the seam (boundary_array) on the level-res_i label mask, one slice at a time.
res_i = 0


def plot_slice(slice_index, axis=0):
    """Show ``mark_boundaries_color`` of ``boundary_array`` over ``reduced_mask_data[res_i]`` for one slice.

    ``axis`` selects the slicing axis; values other than 1 or 2 fall back to axis 0.
    """
    slice_axis = 1 if axis == 1 else (2 if axis == 2 else 0)
    plt.figure(figsize=(8, 6))
    seam_slice = np.take(boundary_array, slice_index, axis=slice_axis)
    mask_slice = np.take(reduced_mask_data[res_i], slice_index, axis=slice_axis)
    plt.imshow(mark_boundaries_color(seam_slice, mask_slice))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()


# Slider-driven slice browser
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=boundary_array.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>