In [16]:
import nrrd
import numpy as np
import os
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from skimage.segmentation import mark_boundaries
import cv2
from scipy import ndimage as ndi
from helper import *
import graph_tool.all as gt
import plotly.graph_objects as go
import time

current_directory = os.getcwd()
label_path = f"{current_directory}/data/label/manual_1_label.nrrd"
raw_data_path = f"{current_directory}/data/raw/manual_1_raw.nrrd"

mask_data, mask_header = nrrd.read(label_path)
raw_data, raw_header = nrrd.read(raw_data_path)

In [17]:
#create reduced data

reduction_factor = 2
reduced_mask_data = reduce_array(mask_data, (reduction_factor,reduction_factor,reduction_factor), method='mean')
reduced_raw_data = reduce_array(raw_data, (reduction_factor,reduction_factor,reduction_factor), method='mean')
print(mask_data.shape, reduced_mask_data.shape)

(256, 256, 256) (128, 128, 128)


In [18]:
#create origianl and rotated graphs

# original_directed_graph, wp = create_directed_energy_graph_from_mask(reduced_mask_data)
# original_directed_graph, original_source, original_sink, original_weights = add_directed_source_sink(original_directed_graph, reduced_mask_data.shape[0], reduced_mask_data.shape[1], reduced_mask_data.shape[2])
# o_src, o_tgt = original_directed_graph.vertex(original_source), original_directed_graph.vertex(original_sink)

axis = 1 #0 for x, 1 for y and 2 for z
k = 1 # number of times to rotate the array
# due to the direction and symmetry, only a single rotation 
# on y and z along with the original is needed to get the best cut
reduced_mask_data_t = np.rot90(reduced_mask_data, k=k,  axes=(axis, (axis+1)%3))
reduced_raw_data_t = np.rot90(reduced_raw_data, k=k, axes=(axis, (axis+1)%3))

rotated_directed_graph, wp = create_directed_energy_graph_from_mask(reduced_mask_data_t)
rotated_directed_graph, rotated_source, rotated_sink, rotated_weights = add_directed_source_sink(rotated_directed_graph, reduced_mask_data_t.shape[0], reduced_mask_data_t.shape[1], reduced_mask_data_t.shape[2])
r_src, r_tgt = rotated_directed_graph.vertex(rotated_source), rotated_directed_graph.vertex(rotated_sink)

128 128 128


In [19]:
def calculate_seam_iter(directed_graph, src, tgt, weights, test_size):
    # Compute the residual capactiy of the edges
    # stime = time.time()
    # res = gt.boykov_kolmogorov_max_flow(directed_graph, src, tgt, weights) #time complexity: edges * vertices^2 * abs(min cut) 
    # print("Time taken to calculate max flow boykov:", time.time()-stime)
    sptime = time.time()
    res = gt.push_relabel_max_flow(directed_graph, src, tgt, weights) #time complexity: vertices^3 <- scales better for our graph
    print("Time taken to calculate max flow push relabel:", time.time()-sptime)
    stime = time.time()
    #use the residual graph to get the max flow
    flow = sum(weights[e] - res[e] for e in tgt.in_edges())
    # print("The maximum flow from source to sink is:", flow)
    print("Time taken to calculate max flow:", time.time()-stime, "flow:", flow)
    stime = time.time()
    # Determine the minimum cut partition
    part = gt.min_st_cut(directed_graph, src, weights, res)
    # print("The number of vertices in the partiton:", sum(part))
    print("Time taken to calculate max flow and min cut:", time.time()-stime)
    # Find the boundary vertices
    stime = time.time()
    boundary_vertices = find_boundary_vertices(np.array(directed_graph.get_edges()), part)
    # print("Number of boundary vertices:", len(boundary_vertices))
    print("Time taken to calculate seam:", time.time()-stime)

    shape = (test_size, test_size, test_size)

    # Convert the boundary vertices to a 3D array
    stime = time.time()
    boundary_array = boundary_vertices_to_array(boundary_vertices, shape, 'x')
    # print("Boundary points marked:", np.sum(boundary_array))
    print("Time taken to convert boundary vertices to array:", time.time()-stime)

    stime = time.time()
    graph_seam_indices = seam_voxels_to_graph_indices(boundary_array)
    # print("Number of graph indices in the seam:", len(graph_seam_indices))
    print("Time taken to convert seam voxels to graph indices:", time.time()-stime)
    swtime = time.time()
    weights = update_outgoing_edges_p(directed_graph, graph_seam_indices, 1e8) #this is a bottleneck
    print("Time taken to update weights outside fxn:", time.time()-swtime)

    return boundary_array, weights, flow

In [20]:
original_boundary_array = np.zeros(reduced_mask_data.shape)
rotated_boundary_array = np.zeros(reduced_mask_data_t.shape)

In [21]:
iters = 32
for i in range(0,iters):
    # bau, original_weights = calculate_seam_iter(original_directed_graph, o_src, o_tgt, original_weights, reduced_mask_data.shape[0])
    rbau, rotated_weights, flow = calculate_seam_iter(rotated_directed_graph, r_src, r_tgt, rotated_weights, reduced_mask_data_t.shape[0])

    # original_boundary_array += bau

    overlap = np.sum(np.logical_and(rotated_boundary_array, rbau))
    print("Sum of overlap:", overlap)
    rotated_boundary_array += rbau

    print("Iteration:", i)

Time taken to calculate max flow push relabel: 22.853574991226196
Time taken to calculate max flow: 0.1616818904876709 flow: 24331
Time taken to calculate max flow and min cut: 0.8152980804443359
edges that cross the partition: 12452352
boundary vertices: 27181
Time taken to calculate seam: 0.17088723182678223
Time taken to convert boundary vertices to array: 1.5631158351898193
Time taken to convert seam voxels to graph indices: 0.014582157135009766
Time taken to update weights:  11.62621784210205
Time taken to update weights outside fxn: 11.629797220230103
Sum of overlap: 0
Iteration: 0
Time taken to calculate max flow push relabel: 30.2525851726532
Time taken to calculate max flow: 0.16325998306274414 flow: 25564
Time taken to calculate max flow and min cut: 0.8815698623657227
edges that cross the partition: 12452352
boundary vertices: 26587
Time taken to calculate seam: 0.17076897621154785
Time taken to convert boundary vertices to array: 1.5747621059417725
Time taken to convert sea

In [22]:
# Create the new boundary array initialized to zeros
boundary_array = np.zeros_like(original_boundary_array)
boundary_array[original_boundary_array > 0] = 1
boundary_array[rotated_boundary_array > 0] = 2
if reduced_mask_data_t.shape[0] <= 64:

    # Plot the 3d boundary array with plotly
    showAll = False
    # Define the grid dimensions
    x_dim, y_dim, z_dim = boundary_array.shape
    x, y, z = np.indices((x_dim, y_dim, z_dim))

    # Flatten the data
    x, y, z = x.flatten(), y.flatten(), z.flatten()
    if showAll:
        # vis_data = np.copy(test_data)
        vis_data = np.where(reduced_mask_data_t != 0, 1, 0)
        vis_data[boundary_array == 1] = 2
    else:
        vis_data = boundary_array

    values = vis_data.flatten()

    # Calculate the linear index using C-style row-major order
    index = lambda i, j, k: i * y_dim * z_dim + j * z_dim + k
    indices = [index(i, j, k) for i, j, k in zip(x, y, z)]

    # Create a scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker=dict(
            size=1,
            color=values,
            # colorscale='Viridis',
            colorbar=dict(title='Color Scale'),
            # opacity=0.8
            colorscale=[[0, 'rgba(0, 0, 0, 0)'], [0.5, 'blue'], [1.0, 'red']],  # Blue for 1, transparent for 0
            # opacity=[1 if value == 1 else 0 for value in values]  # Fully opaque for 1, fully transparent for 0
        )
    )])

    # Update layout to remove background grids and keep axes fixed
    fig.update_layout(scene=dict(
        xaxis=dict(showbackground=False, showgrid=False, showline=True, zeroline=True),
        yaxis=dict(showbackground=False, showgrid=False, showline=True, zeroline=True),
        zaxis=dict(showbackground=False, showgrid=False, showline=True, zeroline=True)
    ))

    fig.show()


In [23]:
# boundary_array = np.where(boundary_array != 0, 1, 0)
boundary_array = rotated_boundary_array
boundary_array = np.where(boundary_array != 0, 1, 0)
def plot_slice(slice_index, axis=0):
    plt.figure(figsize=(8, 6))
    if axis == 1:
        plt.imshow(mark_boundaries_color(reduced_raw_data_t[:,slice_index,:], boundary_array[:,slice_index,:]))
    elif axis == 2:
        plt.imshow(mark_boundaries_color(reduced_raw_data_t[:,:,slice_index], boundary_array[:,:,slice_index]))
    else:
        plt.imshow(mark_boundaries_color(reduced_raw_data_t[slice_index,:,:], boundary_array[slice_index,:,:]))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()
 
print(reduced_raw_data_t.max(), reduced_raw_data_t.min())
# Create a slider to browse through slices
interact(plot_slice, slice_index=IntSlider(min=0, max=reduced_raw_data.shape[0]-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

65535 4421


interactive(children=(IntSlider(value=0, description='slice_index', max=127), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [24]:
def remove_voxels(data_array, boundary_array, num_voxels_to_remove, direction='x'):
    """
    Removes a specified number of voxels from each row in the given direction
    based on a boundary mask array.
    
    Parameters:
    data_array (np.ndarray): The original 3D array from which voxels are to be removed.
    boundary_array (np.ndarray): The boundary mask 3D array specifying the voxels to be removed.
    num_voxels_to_remove (int): Number of voxels to be removed from each row.
    direction (str): The direction in which voxels are to be removed ('x', 'y', 'z').
    
    Returns:
    np.ndarray: The resulting 3D array with the specified voxels removed.
    """
    assert data_array.shape == boundary_array.shape, "Data array and boundary array must have the same shape."
    assert direction in ['x', 'y', 'z'], "Direction must be 'x', 'y', or 'z'."
    
    if direction == 'x':
        new_shape = (data_array.shape[0], data_array.shape[1], data_array.shape[2] - num_voxels_to_remove)
        result_array = np.zeros(new_shape, dtype=data_array.dtype)
        for i in range(data_array.shape[0]):
            for j in range(data_array.shape[1]):
                data_row = data_array[i, j, :]
                boundary_row = boundary_array[i, j, :]
                keep_indices = np.where(boundary_row == 0)[0]
                keep_indices = keep_indices[:new_shape[2]]  # Only keep up to the new size
                result_array[i, j, :] = data_row[keep_indices][:new_shape[2]]
    
    elif direction == 'y':
        new_shape = (data_array.shape[0], data_array.shape[1] - num_voxels_to_remove, data_array.shape[2])
        result_array = np.zeros(new_shape, dtype=data_array.dtype)
        for i in range(data_array.shape[0]):
            for k in range(data_array.shape[2]):
                data_row = data_array[i, :, k]
                boundary_row = boundary_array[i, :, k]
                keep_indices = np.where(boundary_row == 0)[0]
                keep_indices = keep_indices[:new_shape[1]]  # Only keep up to the new size
                result_array[i, :, k] = data_row[keep_indices][:new_shape[1]]
    
    elif direction == 'z':
        new_shape = (data_array.shape[0] - num_voxels_to_remove, data_array.shape[1], data_array.shape[2])
        result_array = np.zeros(new_shape, dtype=data_array.dtype)
        for j in range(data_array.shape[1]):
            for k in range(data_array.shape[2]):
                data_row = data_array[:, j, k]
                boundary_row = boundary_array[:, j, k]
                keep_indices = np.where(boundary_row == 0)[0]
                keep_indices = keep_indices[:new_shape[0]]  # Only keep up to the new size
                result_array[:, j, k] = data_row[keep_indices][:new_shape[0]]
    
    return result_array

In [25]:
print(iters)
densified_data = remove_voxels(reduced_raw_data_t, boundary_array, iters, direction='x')
densified_label = remove_voxels(reduced_mask_data_t, boundary_array, iters, direction='x')
print(densified_data.shape, densified_label.shape)

32
(128, 128, 96) (128, 128, 96)


In [30]:
#save teh boundary array, densified data and densified label as nrrd files in /output
nrrd.write(f"{current_directory}/output/boundary_array.nrrd", boundary_array)
nrrd.write(f"{current_directory}/output/densified_data.nrrd", densified_data)
nrrd.write(f"{current_directory}/output/densified_label.nrrd", densified_label)

In [29]:
def plot_slice(slice_index, axis=0):
    plt.figure(figsize=(8, 6))
    if axis == 1:
        plt.imshow(densified_data[:,slice_index,:],cmap='gray')
    elif axis == 2:
        plt.imshow(densified_data[:,:,slice_index],cmap='gray')
    else:
        plt.imshow(densified_data[slice_index,:,:],cmap='gray')
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()
 
print(reduced_raw_data_t.max(), reduced_raw_data_t.min())
# Create a slider to browse through slices
interact(plot_slice, slice_index=IntSlider(min=0, max=densified_data.shape[0]-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

65535 4421


interactive(children=(IntSlider(value=0, description='slice_index', max=127), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [28]:
def plot_slice(slice_index, axis=0):
    plt.figure(figsize=(8, 6))
    if axis == 1:
        plt.imshow(densified_label[:,slice_index,:])
    elif axis == 2:
        plt.imshow(densified_label[:,:,slice_index])
    else:
        plt.imshow(densified_label[slice_index,:,:])
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()
 
print(reduced_raw_data_t.max(), reduced_raw_data_t.min())
# Create a slider to browse through slices
interact(plot_slice, slice_index=IntSlider(min=0, max=densified_data.shape[0]-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

65535 4421


interactive(children=(IntSlider(value=0, description='slice_index', max=127), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>