In [1]:
import nrrd
import numpy as np
import os
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from skimage.segmentation import mark_boundaries
import cv2
from scipy import ndimage as ndi
from helper import *
import graph_tool.all as gt
import plotly.graph_objects as go
import time
from datetime import datetime


# Input volumes are located relative to the directory the kernel was started in.
current_directory = os.getcwd()
filename = "manual_2"  # base name shared by the label/raw inputs; also reused when naming the saved outputs
label_path = f"{current_directory}/data/label/{filename}_label.nrrd"
raw_data_path = f"{current_directory}/data/raw/{filename}_raw.nrrd"

# nrrd.read returns a (data, header) pair for each file.
mask_data, mask_header = nrrd.read(label_path)
raw_data, raw_header = nrrd.read(raw_data_path)

In [2]:
#rotate the data for best cut angle and calculate reduced data representations
#NOTE: don't rotate manual label 2
# NOTE: with rotate=True this cell rebinds mask_data/raw_data, so re-running it
# rotates the already-rotated volumes again; reload the data first if needed.
rotate = False
if rotate:
    axis = 1 #0 for x, 1 for y and 2 for z
    k = 1 # number of times to rotate the array
    # Mask and raw volume get the same rotation, in the plane spanned by
    # `axis` and the next axis (wrapping around), so they stay aligned.
    mask_data = np.rot90(mask_data, k, axes=(axis, (axis+1)%3))
    raw_data = np.rot90(raw_data, k, axes=(axis, (axis+1)%3))

# Multi-resolution stacks consumed by the later cells (coarsen_image is
# presumably provided by the `helper` star import -- confirm).
reduced_mask_data = coarsen_image(mask_data, 3)
reduced_raw_data = coarsen_image(raw_data, 3)
#[0] is full res, each further index is 2x lower res, 256, 128, 64, 32

In [9]:
#visualise data to confirm alignment
res_i = 0  # resolution level to display (index into the reduced_* stacks; 0 is full res)

def plot_slice(slice_index, axis=0):
    """Show one 2D slice of the raw volume with the label boundaries overlaid.

    Reads the notebook globals ``reduced_raw_data``, ``reduced_mask_data`` and
    ``res_i``; the overlay image itself is produced by ``mark_boundaries_color``.

    Parameters
    ----------
    slice_index : int
        Position of the slice along ``axis``. A value past the end of that
        axis is clamped to its last slice, so one shared slider can serve
        axes of different lengths.
    axis : int, default 0
        1 or 2 slices along that axis; any other value slices along axis 0.
    """
    raw_volume = reduced_raw_data[res_i]
    mask_volume = reduced_mask_data[res_i]

    cut_axis = axis if axis in (1, 2) else 0
    slice_index = min(slice_index, raw_volume.shape[cut_axis] - 1)

    # Equivalent to volume[i, :, :], volume[:, i, :] or volume[:, :, i].
    selector = [slice(None)] * 3
    selector[cut_axis] = slice_index
    selector = tuple(selector)

    plt.figure(figsize=(8, 6))
    plt.imshow(mark_boundaries_color(raw_volume[selector], mask_volume[selector]))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

# Create a slider to browse through slices
# The range follows the longest axis of the level actually displayed (res_i),
# not level 0 / axis 0, so it stays valid when res_i changes or the volume is not cubic.
interact(plot_slice, slice_index=IntSlider(min=0, max=max(reduced_raw_data[res_i].shape)-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [7]:
# Helper functions that are annoying to move to a separate file
def calculate_seam_iter(directed_graph, src, tgt, weights, test_size, x_pos, y_pos, z_pos):
    """Compute one seam as the minimum s-t cut of an energy graph.

    Runs Boykov-Kolmogorov max flow from ``src`` to ``tgt``, totals the flow
    arriving at ``tgt``, derives the min-cut partition from the residual
    capacities, and converts the vertices on the cut boundary into a volume.

    Parameters
    ----------
    directed_graph : graph-tool graph the flow is computed on.
    src, tgt : source and sink vertices of ``directed_graph``.
    weights : edge capacities, indexable by edge.
    test_size : int
        Edge length of the output volume; its shape is ``(test_size,) * 3``.
    x_pos, y_pos, z_pos :
        Forwarded unchanged to ``boundary_vertices_to_array_masked``
        (presumably per-vertex voxel coordinates -- confirm in helper).

    Returns
    -------
    tuple
        ``(boundary_array, flow)``: whatever
        ``boundary_vertices_to_array_masked`` returns for the cut boundary,
        and the total flow into ``tgt``.
    """
    # Boykov-Kolmogorov costs edges * vertices^2 * |min cut|.
    # gt.push_relabel_max_flow (vertices^3) takes the same arguments and was
    # noted by the original author as scaling better for this graph.
    residual = gt.boykov_kolmogorov_max_flow(directed_graph, src, tgt, weights)

    # Flow carried by an edge is capacity minus residual; the max flow is the
    # sum of that over every edge entering the sink.
    flow = 0
    for edge in tgt.in_edges():
        flow += weights[edge] - residual[edge]

    # Split the vertices into the two sides of the minimum cut.
    partition = gt.min_st_cut(directed_graph, src, weights, residual)

    # Vertices sitting on the cut boundary form the seam.
    edge_list = np.array(directed_graph.get_edges())
    seam_vertices = find_boundary_vertices(edge_list, partition)

    # Write the seam vertices into a cubic volume.
    volume_shape = (test_size,) * 3
    seam_volume = boundary_vertices_to_array_masked(seam_vertices, volume_shape, 'x', x_pos, y_pos, z_pos)

    return seam_volume, flow

def multi_res_seam_iter(res_index, mask_array_data, b_arr_up):
    """Recompute the seam at one resolution level, restricted to a search band.

    Parameters
    ----------
    res_index : int
        Index into ``mask_array_data`` of the level to process.
    mask_array_data : sequence of arrays
        Mask stack; the entry at ``res_index`` is read but not modified.
    b_arr_up : array
        Band from the coarser level, already brought to this level's shape.
        Voxels where it equals 0 are excluded from the search.

    Returns
    -------
    The boundary array returned by ``calculate_seam_iter`` for this level.
    """
    # astype always allocates a new array, so a separate .copy() is redundant;
    # order='C' yields a C-contiguous int16 working copy.
    masked_array = mask_array_data[res_index].astype(np.int16, order='C')
    # -1 marks voxels outside the band (presumably skipped by the graph
    # builder -- confirm in create_masked_directed_energy_graph_from_mask).
    masked_array[b_arr_up == 0] = -1
    directed_graph, src, tgt, weights, x_pos, y_pos, z_pos = create_masked_directed_energy_graph_from_mask(masked_array)
    # The flow value is not needed here; only the seam is returned.
    boundary_array, _flow = calculate_seam_iter(directed_graph, src, tgt, weights, masked_array.shape[0], x_pos, y_pos, z_pos)
    return boundary_array

def multi_res_seam_calculation(mask_array_data, res_index=3, upscale_factor=2, dilation_amount=1):
    """Find a seam coarse-to-fine across a multi-resolution mask stack.

    The seam is first computed on the whole volume at level ``res_index``.
    For each finer level down to 0, the previous seam is upscaled and dilated
    into a search band and the seam is recomputed inside that band only.

    Parameters
    ----------
    mask_array_data : sequence of arrays
        Mask stack indexed by resolution level; level 0 is the one whose seam
        is returned.
    res_index : int, default 3
        Level to start from. With 0 the seam is computed directly on level 0.
    upscale_factor : int, default 2
        Factor passed to ``upscale_and_dilate_3d`` for every level transition,
        including the first one.
    dilation_amount : int, default 1
        Dilation passed to ``upscale_and_dilate_3d``.

    Returns
    -------
    The boundary array for level 0 (or for ``res_index`` itself when it is 0).
    """
    coarse_mask = mask_array_data[res_index]
    directed_graph, src, tgt, weights, x_pos, y_pos, z_pos = create_masked_directed_energy_graph_from_mask(coarse_mask)
    boundary_array, _flow = calculate_seam_iter(directed_graph, src, tgt, weights, coarse_mask.shape[0], x_pos, y_pos, z_pos)

    for level in range(res_index - 1, -1, -1):
        # Upscaling at the top of the loop means the final (level 0) seam is
        # never upscaled, and nothing is upscaled at all when res_index is 0.
        search_band = upscale_and_dilate_3d(boundary_array, upscale_factor=upscale_factor, dilation_amount=dilation_amount)
        boundary_array = multi_res_seam_iter(level, mask_array_data, search_band)
    return boundary_array

In [8]:
res_index = 3 #res index to start at, higher values are lower res, and thus should be faster
# Both working volumes are re-seeded from the reduced_* stacks each time this cell runs.
mask_array_data = reduced_mask_data
raw_array_data = reduced_raw_data[0]  # only level 0 of the raw stack is carried forward
num_seams_to_remove = 76 #76 to hit the 180 vram limit -> mixed precision could increase this
boundary_arrays = []  # one seam array per completed iteration, in removal order

for i in range(num_seams_to_remove):
    stime = time.time()
    # Coarse-to-fine seam search on the current mask stack.
    boundary_array = multi_res_seam_calculation(mask_array_data, res_index, dilation_amount=1)
    # remove_voxels receives the level-0 mask, the raw volume and the seam
    # (presumably returning both volumes with the seam voxels removed -- confirm in helper).
    # Until the next line runs, mask_array_data holds that single result rather than a stack.
    mask_array_data, raw_array_data = remove_voxels(mask_array_data[0], raw_array_data, boundary_array)
    # Rebuild the multi-resolution stack so the next iteration sees the updated mask.
    mask_array_data = coarsen_image(mask_array_data, res_index)
    boundary_arrays.append(boundary_array)
    print(f"Time taken to calculate and remove seam {i}:", time.time()-stime)

Time taken to calculate and remove seam 0: 20.732311010360718
Time taken to calculate and remove seam 1: 20.351593017578125
Time taken to calculate and remove seam 2: 20.32573890686035
Time taken to calculate and remove seam 3: 20.478810787200928
Time taken to calculate and remove seam 4: 21.023407697677612
Time taken to calculate and remove seam 5: 20.909087896347046
Time taken to calculate and remove seam 6: 20.654766082763672


KeyboardInterrupt: 

In [87]:
#save results
# Create output directory if it doesn't exist
output_dir = os.path.join(os.getcwd(), 'output/densified_cubes')
os.makedirs(output_dir, exist_ok=True)

# Save mask_array_data[0] as NRRD. The file name is built from the input base
# name and num_seams_to_remove; no timestamp is included.
# NOTE(review): num_seams_to_remove is the requested count. If the seam loop was
# interrupted, len(boundary_arrays) is the number of seams actually removed.
mask_nrrd_path = os.path.join(output_dir, f'{filename}_{num_seams_to_remove}_densified_label.nrrd')
nrrd.write(mask_nrrd_path, mask_array_data[0])
print(f"Saved mask_array_data[0] to {mask_nrrd_path}")

# Save raw_array_data as NRRD, using the same naming scheme as the mask.
raw_nrrd_path = os.path.join(output_dir, f'{filename}_{num_seams_to_remove}_densified_data.nrrd')
nrrd.write(raw_nrrd_path, raw_array_data)
print(f"Saved raw_array_data to {raw_nrrd_path}")

Saved mask_array_data[0] to /Users/jamesdarby/Documents/VesuviusScroll/GP/3D_sheet_carving/output/densified_cubes/densified_label_20240606_214137.nrrd
Saved raw_array_data to /Users/jamesdarby/Documents/VesuviusScroll/GP/3D_sheet_carving/output/densified_cubes/densified_data_20240606_214137.nrrd


In [6]:
#plot results
res_i = 0  # level of mask_array_data to overlay; raw_array_data is a single volume, so this should stay 0

# NOTE: this reuses the name plot_slice from the earlier alignment cell and
# replaces that definition in the kernel namespace.
def plot_slice(slice_index, axis=0):
    """Show one 2D slice of the processed raw volume with label boundaries overlaid.

    Reads the notebook globals ``raw_array_data``, ``mask_array_data`` and
    ``res_i``; the overlay image itself is produced by ``mark_boundaries_color``.

    Parameters
    ----------
    slice_index : int
        Position of the slice along ``axis``. A value past the end of that
        axis is clamped to its last slice, so one shared slider can serve
        axes of different lengths.
    axis : int, default 0
        1 or 2 slices along that axis; any other value slices along axis 0.
    """
    mask_volume = mask_array_data[res_i]

    cut_axis = axis if axis in (1, 2) else 0
    slice_index = min(slice_index, raw_array_data.shape[cut_axis] - 1)

    # Equivalent to volume[i, :, :], volume[:, i, :] or volume[:, :, i].
    selector = [slice(None)] * 3
    selector[cut_axis] = slice_index
    selector = tuple(selector)

    plt.figure(figsize=(8, 6))
    plt.imshow(mark_boundaries_color(raw_array_data[selector], mask_volume[selector]))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

# Create a slider to browse through slices
# The range follows the longest axis of the volume, so every slice on every
# axis stays reachable even if the axes no longer have equal length.
interact(plot_slice, slice_index=IntSlider(min=0, max=max(raw_array_data.shape)-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>