In [26]:
import nrrd
import numpy as np
import os
from scipy.ndimage import gaussian_filter
from skimage.morphology import *
from matplotlib import pyplot as plt
from ipywidgets import interact, IntSlider
from scipy.ndimage import binary_dilation, binary_erosion, binary_closing
from helper import *

In [27]:
# Load both manual annotations: a label volume and a raw volume for each.
# The NRRD headers are not needed and are thrown away into `_`.
current_directory = os.getcwd()
m1l, _ = nrrd.read(f"{current_directory}/data/manual_1_label.nrrd")
m1r, _ = nrrd.read(f"{current_directory}/data/manual_1_raw.nrrd")
m2l, _ = nrrd.read(f"{current_directory}/data/manual_2_label.nrrd")
m2r, _ = nrrd.read(f"{current_directory}/data/manual_2_raw.nrrd")

In [28]:
print(np.shape(m1r))  # dimensions of the first raw volume

(256, 256, 256)


In [29]:
# Rough foreground mask for m1r: smooth the raw volume, keep voxels brighter than
# the smoothed volume's global mean, then drop small specks and fill small holes.
smoothed_block = gaussian_filter(m1r, sigma=2)
intensity_cutoff = np.mean(smoothed_block)
obj_size = 20000  # minimum object / hole size in voxels (reused by the split step later)
threshold_mask = remove_small_holes(
    remove_small_objects(smoothed_block > intensity_cutoff, obj_size),
    obj_size,
)
mask_ratio = np.sum(threshold_mask) / threshold_mask.size  # fraction of voxels kept
print(threshold_mask.shape, intensity_cutoff, mask_ratio)

(256, 256, 256) 31809.679174363613 0.48309850692749023


In [30]:
# Interactive viewer: the raw volume m1r with every voxel outside threshold_mask
# zeroed. Both are module-level globals looked up on every redraw.
def plot_slice(slice_index, axis=0):
    """Display one slice of ``m1r * threshold_mask``.

    Parameters
    ----------
    slice_index : int
        Position along ``axis``. It is clamped to the last valid index of that axis,
        so one slider can serve all three axes even if the volume is not cubic.
    axis : int
        1 or 2 slices along that axis; any other value slices along axis 0.
    """
    # Same axis selection as an if/elif chain: 1 -> 1, 2 -> 2, anything else -> 0.
    take_axis = {1: 1, 2: 2}.get(axis, 0)
    slice_index = min(slice_index, m1r.shape[take_axis] - 1)
    plt.figure(figsize=(8, 6))
    raw_slice = np.take(m1r, slice_index, axis=take_axis)
    mask_slice = np.take(threshold_mask, slice_index, axis=take_axis)
    plt.imshow(raw_slice * mask_slice, cmap='gray')
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

# Sliders: slice position (up to the largest dimension; clamped per axis above) and slicing axis.
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=max(threshold_mask.shape) - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)


interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [31]:
def process_nrrd_file(nrrd_path, erosion_iterations=1, dilation_iterations=1):
    """Derive a tri-class label volume and a border-overlap mask from a label NRRD.

    Every distinct positive voxel value in the file is treated as one structure.
    For each structure:

    * the *interior* is its mask eroded ``erosion_iterations`` times. The volume
      edge counts as "inside", so structures touching the edge are not eroded there.
    * the *border* is the symmetric difference between the mask dilated
      ``dilation_iterations`` times and that interior.

    Parameters
    ----------
    nrrd_path : str
        Path of the NRRD label file; voxel values <= 0 are ignored.
    erosion_iterations : int, default 1
        Erosion steps for the interior. Values <= 0 leave the structure un-eroded.
    dilation_iterations : int, default 1
        Dilation steps for the outer edge of the border. Values <= 0 leave the
        structure un-dilated.

    Returns
    -------
    result : numpy.ndarray of uint8, same shape as the file's data
        0 = background, 1 = border, 2 = interior. Structures are painted in
        ascending label order, so where regions of different structures overlap
        the later label's class wins.
    weights : numpy.ndarray of uint8, same shape as the file's data
        1 where the borders of two or more structures overlap, 0 elsewhere.
    """
    # Imported locally: the notebook also star-imports skimage.morphology, whose
    # same-named functions do not take an ``iterations`` argument.
    from scipy.ndimage import binary_dilation, binary_erosion

    print(f"Processing: {nrrd_path}")
    data, _header = nrrd.read(nrrd_path)
    unique_values = np.unique(data[data > 0])  # ignore background & masking

    result = np.zeros_like(data, dtype=np.uint8)
    # Number of structure borders covering each voxel. uint16 so that many
    # touching structures cannot wrap the count back to 0.
    border_overlap = np.zeros(data.shape, dtype=np.uint16)

    for value in unique_values:
        structure_mask = data == value

        if erosion_iterations > 0:
            # Pad with True so voxels on the volume edge are not eroded just for
            # touching it; the padding is cropped away again after erosion.
            pad = erosion_iterations
            padded_structure = np.pad(structure_mask, pad_width=pad, mode='constant', constant_values=True)
            eroded_padded_structure = binary_erosion(padded_structure, iterations=erosion_iterations)
            crop = (slice(pad, -pad),) * structure_mask.ndim
            eroded_structure = eroded_padded_structure[crop]
        else:
            # scipy treats iterations < 1 as "erode until nothing changes", and a
            # zero-width crop would be empty, so skip erosion entirely.
            eroded_structure = structure_mask

        if dilation_iterations > 0:
            dilated_structure = binary_dilation(structure_mask, iterations=dilation_iterations)
        else:
            dilated_structure = structure_mask

        border = np.logical_xor(dilated_structure, eroded_structure)
        result[border] = 1            # border class
        result[eroded_structure] = 2  # interior class (wins over this structure's border)

        border_overlap += border

    # Final weights: 1 only where borders of at least two structures coincide.
    weights = (border_overlap > 1).astype(np.uint8)

    return result, weights

In [32]:
# Tri-class labels and border-overlap mask for the first manual annotation,
# using 3 erosion and 3 dilation steps per structure.
m1_label, m1_border = process_nrrd_file(
    f"{current_directory}/data/manual_1_label.nrrd",
    erosion_iterations=3,
    dilation_iterations=3,
)

Processing: /Users/jamesdarby/Documents/VesuviusScroll/GP/Vesuvius_3D_datasets/data/manual_1_label.nrrd


In [33]:
# Interactive viewer for m1_border, the second array returned by process_nrrd_file
# (nonzero where borders of different labelled structures overlap).
def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` of m1_border; axis 1 or 2 picks that axis, anything else axis 0."""
    plt.figure(figsize=(8, 6))
    take_axis = {1: 1, 2: 2}.get(axis, 0)
    plane = np.take(m1_border, slice_index, axis=take_axis)
    plt.imshow(plane, cmap='gray')
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

# Sliders: slice position (bounded by the first dimension) and slicing axis.
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=m1_border.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)


interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [34]:
# Cut the foreground mask along voxels where structure borders overlap
# (m1_border == 1) so touching structures separate, then clean up as before.
# A new array is built instead of zeroing threshold_mask in place. This keeps
# threshold_mask equal to the mask whose ratio was printed earlier. The earlier
# slice viewer reads threshold_mask on every redraw, so it keeps showing that
# unsplit mask.
split_threshold_mask = remove_small_holes(
    remove_small_objects(threshold_mask & (m1_border != 1), obj_size),
    obj_size,
)

In [35]:
# Interactive viewer for split_threshold_mask: the foreground mask after cutting
# along overlapping borders and re-cleaning small objects/holes.
def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` of split_threshold_mask; axis 1/2 pick that axis, otherwise axis 0."""
    plt.figure(figsize=(8, 6))
    take_axis = {1: 1, 2: 2}.get(axis, 0)
    plt.imshow(np.take(split_threshold_mask, slice_index, axis=take_axis), cmap='gray')
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

# Sliders: slice position (bounded by the first dimension) and slicing axis.
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=split_threshold_mask.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)


interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [36]:
# Label connected foreground structures of the split mask. label_foreground_structures
# is not defined in this notebook (presumably from `helper`); foreground_value=1
# presumably selects the True voxels — confirm in helper.
instance_seg = label_foreground_structures(
    split_threshold_mask,
    foreground_value=1,
)

Number of connected foreground structures before filtering: 11
Number of connected foreground structures after filtering: 11


In [37]:
# Interactive viewer for instance_seg, the recovered instance labels, shown with
# matplotlib's default colormap so different labels get different colours.
def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` of instance_seg; axis 1/2 pick that axis, otherwise axis 0."""
    plt.figure(figsize=(8, 6))
    take_axis = {1: 1, 2: 2}.get(axis, 0)
    label_plane = np.take(instance_seg, slice_index, axis=take_axis)
    plt.imshow(label_plane)
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=instance_seg.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [38]:
# Interactive viewer: boundaries of the recovered instances (instance_seg) drawn
# over the matching slice of the raw volume m1r via mark_boundaries_color
# (not defined in this notebook; presumably from `helper`).
def plot_slice(slice_index, axis=0):
    """Overlay instance boundaries on raw slice ``slice_index``; axis 1/2 pick that axis, otherwise axis 0."""
    plt.figure(figsize=(8, 6))
    take_axis = {1: 1, 2: 2}.get(axis, 0)
    raw_plane = np.take(m1r, slice_index, axis=take_axis)
    label_plane = np.take(instance_seg, slice_index, axis=take_axis)
    plt.imshow(mark_boundaries_color(raw_plane, label_plane))
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=instance_seg.shape[0] - 1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>