In [34]:
import nrrd
import numpy as np
import os
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from skimage.segmentation import mark_boundaries
import cv2
import argparse
from numba import jit
from scipy import ndimage as ndi

# Root of the manually labelled test cubes. Set EMBEDSEG_CUBES_DIR to point at a
# different machine's copy instead of editing this absolute, user-specific path.
CUBES_DIR = os.environ.get(
    "EMBEDSEG_CUBES_DIR",
    "/Users/jamesdarby/Documents/VesuviusScroll/GP/EmbedSegScrolls/manually_labelled_cubes/test",
)
mask_path = os.path.join(CUBES_DIR, "instances", "layers_1.nrrd")
raw_data_path = os.path.join(CUBES_DIR, "volumes", "slices_1.nrrd")

mask_data, mask_header = nrrd.read(mask_path)
raw_data, raw_header = nrrd.read(raw_data_path)
# Every later cell indexes the mask and the raw volume with the same indices.
assert mask_data.shape == raw_data.shape, (mask_data.shape, raw_data.shape)

In [36]:
#helper functions: TODO move to separate file
import numpy as np
from skimage.color import gray2rgb, label2rgb
from skimage.segmentation import find_boundaries
from skimage.util import img_as_float
from skimage.morphology import dilation, square
import random

def mark_boundaries_color(image, label_img, color=None, outline_color=None, mode='outer', background_label=0, dilation_size=1):
    """Return image with boundaries between labeled regions highlighted with consistent colors derived from labels.

    Each non-background label is drawn in the ``nipy_spectral`` colormap colour at
    ``label / max(unique labels)``. A given label therefore keeps the same colour
    across calls whose label images share the same maximum label.

    Parameters:
    - image: 2-D grayscale or RGB image to draw on; converted with ``img_as_float``.
    - label_img: Label image with the same spatial shape as ``image``.
    - color: Ignored; kept for signature compatibility.
    - outline_color: Optional RGB triple. If given, a band one footprint step wider
      than each coloured boundary is painted in this colour underneath it, so it
      shows as an outline.
    - mode: ``find_boundaries`` mode: 'inner', 'outer', or 'thick'.
    - background_label: Label value that is never outlined.
    - dilation_size: Side of the square footprint used to thicken each boundary.

    Returns:
    - float32 RGB copy of ``image`` with the boundaries painted in.
    """
    marked = img_as_float(image, force_copy=True).astype(np.float32, copy=False)
    if marked.ndim == 2:
        marked = gray2rgb(marked)

    unique_labels = np.unique(label_img)
    if unique_labels.size == 0:
        return marked

    color_map = plt.get_cmap('nipy_spectral')
    max_label = np.max(unique_labels)
    # Guard against dividing by a non-positive maximum (e.g. only label 0 present
    # with a non-zero background_label), which would yield NaN/inf colours.
    scale = max_label if max_label > 0 else 1

    for label in unique_labels:
        if label == background_label:
            continue
        label_color = color_map(label / scale)[:3]  # RGB only, drop alpha
        label_boundaries = dilation(find_boundaries(label_img == label, mode=mode), square(dilation_size))
        if outline_color is not None:
            # Paint the wider outline first and the label colour on top. Painting
            # the outline last (as before) covered the whole coloured boundary.
            marked[dilation(label_boundaries, square(dilation_size + 1))] = outline_color
        marked[label_boundaries] = label_color

    return marked

def _label_rgb(label):
    """Return a deterministic pseudo-random RGB triple in [0, 1) for ``label``.

    A private ``random.Random`` seeded with the label gives the same label the same
    colour on every call, without touching the global random state.
    Assumes integer-valued labels (non-integer labels are truncated by ``int``).
    """
    rng = random.Random(int(label))
    return (rng.random(), rng.random(), rng.random())


def mark_boundaries_multicolor(image, label_img, color=None, outline_color=None, mode='outer', background_label=0, dilation_size=1):
    """Return image with boundaries between labeled regions highlighted with consistent colors.

    Each non-background label gets a colour derived only from its own value (see
    ``_label_rgb``), so colours stay stable across slices regardless of which
    other labels are present.

    Parameters:
    - image: 2-D grayscale or RGB image to draw on; converted with ``img_as_float``.
    - label_img: Label image with the same spatial shape as ``image``.
    - color: Ignored; kept for signature compatibility.
    - outline_color: Optional RGB triple painted in a band one footprint step
      wider than each boundary, underneath the label colour.
    - mode: ``find_boundaries`` mode: 'inner', 'outer', or 'thick'.
    - background_label: Label value that is never outlined.
    - dilation_size: Side of the square footprint used to thicken each boundary.

    Returns:
    - float32 RGB copy of ``image`` with the boundaries painted in.
    """
    marked = img_as_float(image, force_copy=True).astype(np.float32, copy=False)
    if marked.ndim == 2:
        marked = gray2rgb(marked)

    for label in np.unique(label_img):
        if label == background_label:
            continue
        label_boundaries = dilation(find_boundaries(label_img == label, mode=mode), square(dilation_size))
        if outline_color is not None:
            # Dilate by one step more than the boundary. With the same footprint, the
            # outline equalled the boundary at dilation_size=1 and was fully overpainted.
            marked[dilation(label_boundaries, square(dilation_size + 1))] = outline_color
        marked[label_boundaries] = _label_rgb(label)

    return marked


In [37]:
def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` along ``axis`` of ``raw_data`` with ``mask_data`` boundaries overlaid.

    Reads the notebook globals ``raw_data`` and ``mask_data``. A ``slice_index``
    past the end of the chosen axis is clamped to its last slice, so the shared
    slider cannot index out of range on non-cubic volumes.
    """
    slice_index = min(slice_index, raw_data.shape[axis] - 1)
    overlay = mark_boundaries_color(
        np.take(raw_data, slice_index, axis=axis),
        np.take(mask_data, slice_index, axis=axis),
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    # No colorbar: the overlay is an RGB image, so a colour scale would be meaningless.
    ax.imshow(overlay)
    ax.set_title(f'Slice {slice_index}')
    plt.show()

# Create a slider to browse through slices; its range covers the longest axis.
interact(plot_slice, slice_index=IntSlider(min=0, max=max(mask_data.shape)-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

1. Scanning from left to right, find the first background pixel in every row and remove it; then do the same for the next background pixel, and so on, until a pass that removes a background pixel from every row can no longer be made.
2. Repeat this along all 3 axes.

In [18]:
import numpy as np

def deep_shrink_array(array, axis):
    """
    Trim the leading slices along ``axis`` that contain only background (0) voxels.

    For every position perpendicular to ``axis``, this finds the index of the first
    non-background voxel (or the axis length if the whole line is background). It
    then removes as many leading slices as the smallest of those indices, i.e.
    every slice before the first one that contains any non-zero voxel. This is done
    with one vectorised reduction rather than a Python loop over every voxel.

    :param array: numpy array (written for 3-D volumes; any rank works)
    :param axis: Axis along which to trim; negative values count from the end
    :return: ``array`` itself when nothing is trimmed, otherwise a view without the
        leading all-background slices. The view is empty (length 0 along ``axis``)
        when ``array`` is entirely background.
    :raises IndexError: if ``axis`` is out of range for ``array``
    """
    axis = range(array.ndim)[axis]  # normalise negative axes; IndexError if out of range
    other_axes = tuple(i for i in range(array.ndim) if i != axis)

    # occupied[k] is True when slice k along `axis` holds at least one non-zero voxel.
    occupied = np.any(array != 0, axis=other_axes)
    occupied_slices = np.flatnonzero(occupied)
    max_removal_depth = occupied_slices[0] if occupied_slices.size else array.shape[axis]

    if max_removal_depth == 0:
        return array
    trim = [slice(None)] * array.ndim
    trim[axis] = slice(max_removal_depth, None)
    return array[tuple(trim)]

            

In [29]:
def shrink_one_layer(array, axis):
    """
    Shrink an array by exactly one element along ``axis``.

    Every line parallel to ``axis`` loses one voxel. If the line contains a
    background voxel (value == 0), the first one is removed. Otherwise the line's
    last voxel is dropped (plain truncation).

    :param array: numpy array (written for 3-D volumes; any rank works)
    :param axis: Axis along which to shrink the array (negative values allowed)
    :return: New C-contiguous array with the same dtype, one shorter along ``axis``
        (length 0 along ``axis`` when the input had length <= 1 there)
    """
    shape = list(array.shape)
    depth = shape[axis]
    new_depth = depth - 1
    if new_depth <= 0:
        # Nothing survives. The result is empty along `axis` only; the old code
        # always zeroed axis 0, which gave the wrong shape for axis 1/2.
        shape[axis] = 0
        return np.empty(shape, dtype=array.dtype)

    # Put `axis` last so each perpendicular position is one row of `lines`.
    lines = np.moveaxis(array, axis, -1)
    is_background = lines == 0
    # argmax finds the first background voxel per row. Rows without background
    # fall back to dropping their last voxel instead.
    drop_index = np.where(is_background.any(axis=-1), is_background.argmax(axis=-1), new_depth)
    keep = np.arange(depth) != drop_index[..., np.newaxis]
    # Every row keeps exactly new_depth voxels, so the flat selection reshapes cleanly.
    shrunk_lines = lines[keep].reshape(lines.shape[:-1] + (new_depth,))
    return np.ascontiguousarray(np.moveaxis(shrunk_lines, -1, axis))

In [51]:
def dual_shrink_one_layer(data_array, control_array, axis):
    """
    Shrink two same-shaped arrays by one element along ``axis``, driven by ``control_array``.

    For every line parallel to ``axis``, the voxel to drop is chosen from
    ``control_array``: its first background voxel (value == 0) if it has one,
    otherwise its last voxel. The same position is then removed from the matching
    line of ``data_array``, so the two arrays stay aligned voxel-for-voxel.

    :param data_array: numpy array modified according to control_array
    :param control_array: numpy array used to determine the voxel removal
    :param axis: Axis along which to shrink the arrays (negative values allowed)
    :return: Tuple ``(new_data_array, new_control_array)``. Each is a new
        C-contiguous array with its input's dtype, one shorter along ``axis``
        (length 0 along ``axis`` when the input had length <= 1 there).
    :raises ValueError: if the two arrays do not have the same shape
    """
    shape = list(control_array.shape)
    depth = shape[axis]
    new_depth = depth - 1
    if new_depth <= 0:
        # Nothing survives. The results are empty along `axis` only; the old code
        # always zeroed axis 0, which gave the wrong shape for axis 1/2.
        shape[axis] = 0
        return (np.empty(shape, dtype=data_array.dtype),
                np.empty(shape, dtype=control_array.dtype))
    if data_array.shape != control_array.shape:
        raise ValueError(f"data_array shape {data_array.shape} does not match control_array shape {control_array.shape}")

    # Put `axis` last so each perpendicular position is one row.
    control_lines = np.moveaxis(control_array, axis, -1)
    data_lines = np.moveaxis(data_array, axis, -1)
    is_background = control_lines == 0
    # argmax finds the first background voxel per row. Rows without background
    # fall back to dropping their last voxel instead.
    drop_index = np.where(is_background.any(axis=-1), is_background.argmax(axis=-1), new_depth)
    keep = np.arange(depth) != drop_index[..., np.newaxis]
    # Every row keeps exactly new_depth voxels, so the flat selection reshapes cleanly.
    row_shape = control_lines.shape[:-1] + (new_depth,)
    new_control_array = np.ascontiguousarray(np.moveaxis(control_lines[keep].reshape(row_shape), -1, axis))
    new_data_array = np.ascontiguousarray(np.moveaxis(data_lines[keep].reshape(row_shape), -1, axis))

    return new_data_array, new_control_array

In [67]:
# First pass: drop one background voxel per line along axis 2 of the mask,
# removing the same positions from the raw volume so the two stay aligned.
rd, ts = dual_shrink_one_layer(raw_data, mask_data, axis=2)
print(ts.shape)

(256, 256, 255)


In [68]:
# Alternate axis-1 and axis-2 shrinks for 100 rounds. This cell rebinds rd/ts
# from their current values, so re-running it shrinks the volumes further;
# re-run the previous cell first to reproduce the original result.
for i in range(100):
    for shrink_axis in (1, 2):
        rd, ts = dual_shrink_one_layer(rd, ts, shrink_axis)

In [69]:
def plot_slice(slice_index, axis=0):
    """Show grayscale slice ``slice_index`` along ``axis`` of the shrunken raw volume ``rd``.

    After the shrink loop, ``rd`` is no longer cubic. A ``slice_index`` past the end
    of the chosen axis is therefore clamped to its last slice instead of raising
    IndexError.
    """
    slice_index = min(slice_index, rd.shape[axis] - 1)
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(np.take(rd, slice_index, axis=axis), cmap='gray')
    fig.colorbar(im, ax=ax)
    ax.set_title(f'Slice {slice_index}')
    plt.show()

# Create a slider to browse through slices; its range covers the longest axis.
interact(plot_slice, slice_index=IntSlider(min=0, max=max(rd.shape)-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [70]:
def plot_slice(slice_index, axis=0):
    """Show slice ``slice_index`` along ``axis`` of the shrunken label volume ``ts``.

    After the shrink loop, ``ts`` is no longer cubic. A ``slice_index`` past the end
    of the chosen axis is therefore clamped to its last slice instead of raising
    IndexError.
    """
    slice_index = min(slice_index, ts.shape[axis] - 1)
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(np.take(ts, slice_index, axis=axis))
    fig.colorbar(im, ax=ax)
    ax.set_title(f'Slice {slice_index}')
    plt.show()

# Create a slider to browse through slices; its range covers the longest axis.
interact(plot_slice, slice_index=IntSlider(min=0, max=max(ts.shape)-1, step=1, value=0), axis=IntSlider(min=0, max=2, step=1, value=0))

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>