In [75]:
import os
import nrrd
import numpy as np
from scipy.ndimage import binary_dilation

In [76]:
import numpy as np
from skimage.color import gray2rgb, label2rgb
from skimage.segmentation import find_boundaries
from skimage.util import img_as_float
from skimage.morphology import dilation, square
import random

def mark_boundaries_color(image, label_img, color=None, outline_color=None, mode='outer', background_label=0, dilation_size=1):
    """Return an RGB copy of ``image`` with each labeled region's boundary drawn in its own color.

    Colors come from matplotlib's ``nipy_spectral`` colormap, sampled at ``label / max_label``.
    A given label therefore keeps its color as long as the largest label in ``label_img`` is unchanged.

    Parameters:
    - image: Input image; a 2-D (grayscale) image is converted to RGB.
    - label_img: Array (or array-like) of region labels matching the spatial shape of ``image``.
    - color: Ignored; kept only for signature compatibility.
    - outline_color: If given, an RGB color for a one-pixel ring drawn around each (dilated)
      boundary. The label-colored boundary is painted on top, so both remain visible.
    - mode: Boundary type passed to ``find_boundaries`` ('inner', 'outer' or 'thick').
    - background_label: Label value that receives no boundary.
    - dilation_size: Side length of the square footprint used to thicken each boundary.

    Returns:
    - Float RGB image (a copy) with boundaries highlighted.
    """
    # Local import: pyplot is only imported in a later cell, so a top-to-bottom run would
    # otherwise raise NameError here.
    import matplotlib.pyplot as plt

    label_img = np.asarray(label_img)  # `list == value` would silently be False
    marked = img_as_float(image, force_copy=True).astype(np.float32, copy=False)
    if marked.ndim == 2:
        marked = gray2rgb(marked)

    unique_labels = np.unique(label_img)
    if unique_labels.size == 0:
        return marked

    max_label = unique_labels.max()
    # Avoid dividing by zero when the largest label is 0 (e.g. background_label != 0).
    scale = max_label if max_label != 0 else 1
    color_map = plt.get_cmap('nipy_spectral')  # any matplotlib colormap name works here
    thicken = square(dilation_size)
    ring = square(3)  # grows the boundary by one pixel on every side

    for label in unique_labels:
        if label == background_label:
            continue
        label_color = color_map(label / scale)[:3]  # drop the alpha channel
        label_boundaries = dilation(find_boundaries(label_img == label, mode=mode), thicken)
        if outline_color is not None:
            # Paint the outline first: it is a superset of the boundary, so painting it last
            # would completely hide the label color.
            marked[dilation(label_boundaries, ring)] = outline_color
        marked[label_boundaries] = label_color

    return marked


def consistent_color(label):
    """Return a deterministic RGB color (three floats in [0, 1)) for ``label``.

    The color is drawn from a private ``random.Random`` seeded with ``hash(label)``. This gives
    the same values as seeding the global generator, but leaves the module-level ``random``
    state untouched.

    NOTE(review): ``hash()`` of a ``str`` is randomized per interpreter process unless
    PYTHONHASHSEED is set, so string labels are only consistent within one session.
    Integer labels hash to themselves and are stable.
    """
    rng = random.Random(hash(label))
    return [rng.random() for _ in range(3)]

def mark_boundaries_multicolor(image, label_img, color=None, outline_color=None, mode='outer', background_label=0, dilation_size=1):
    """Return an RGB copy of ``image`` with each region's boundary in a label-derived color.

    Each non-background label gets the color ``consistent_color(label)``, so the same label
    always gets the same color regardless of which other labels are present.

    Parameters:
    - image: Input image; a 2-D (grayscale) image is converted to RGB.
    - label_img: Array (or array-like) of region labels matching the spatial shape of ``image``.
    - color: Ignored; kept only for signature compatibility.
    - outline_color: If given, an RGB color for a one-pixel ring drawn around each (dilated)
      boundary. The label-colored boundary is painted on top, so both remain visible.
    - mode: Boundary type passed to ``find_boundaries`` ('inner', 'outer' or 'thick').
    - background_label: Label value that receives no boundary.
    - dilation_size: Side length of the square footprint used to thicken each boundary.

    Returns:
    - Float RGB image (a copy) with boundaries highlighted.
    """
    label_img = np.asarray(label_img)  # `list == value` would silently be False
    marked = img_as_float(image, force_copy=True).astype(np.float32, copy=False)
    if marked.ndim == 2:
        marked = gray2rgb(marked)

    thicken = square(dilation_size)
    # The outline must extend beyond the thickened boundary. Re-dilating by `dilation_size`
    # (as before) is a no-op for the default size 1, which made outline_color invisible.
    ring = square(3)

    for label in np.unique(label_img):
        if label == background_label:
            continue
        label_color = consistent_color(label)
        label_boundaries = dilation(find_boundaries(label_img == label, mode=mode), thicken)
        if outline_color is not None:
            marked[dilation(label_boundaries, ring)] = outline_color
        marked[label_boundaries] = label_color

    return marked

def plot_segmentation_results(test_slice, segmentation):
    """Show ``test_slice`` with segmentation boundaries overlaid, next to the plain slice.

    Parameters:
    - test_slice: 2-D image to display.
    - segmentation: Label array (or array-like) with the same shape as ``test_slice``.
    """
    # Local imports: `mark_boundaries` is not imported in any visible cell, and pyplot is only
    # imported in a later cell, so calling this function would otherwise raise NameError.
    import matplotlib.pyplot as plt
    from skimage.segmentation import mark_boundaries

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Left: slice with region boundaries drawn by skimage.
    axes[0].imshow(mark_boundaries(test_slice, np.asarray(segmentation)))
    axes[0].set_title("Marked Boundary")

    # Right: the raw slice for comparison.
    axes[1].imshow(test_slice, cmap='gray')
    axes[1].set_title("Unmarked Boundary")

    plt.show()

In [77]:
import os
import nrrd
import numpy as np
from scipy.ndimage import binary_dilation, binary_closing

def _to_tri_class(data, border_width=2):
    """Convert an instance-label volume into tri-class labels of the same shape and dtype.

    Every non-zero value in ``data`` is treated as one structure. In the result, 0 is
    background, 1 is the border (the shell added by ``border_width`` iterations of
    ``binary_dilation``) and 2 is the structure itself.
    """
    result = np.zeros_like(data)
    for value in np.unique(data):
        if value == 0:  # background
            continue
        structure_mask = data == value
        # Hole filling with binary_closing is deliberately disabled. Per the original author,
        # closing made structures touching the volume edge get dilated from the edge, while
        # without it the structure at the edge stays intact.
        dilated_mask = binary_dilation(structure_mask, iterations=border_width)
        # Structures are painted one after another. Where they touch, a later structure's
        # border overwrites an earlier one's foreground, and a later foreground overwrites an
        # earlier border. NOTE(review): confirm this order dependence is intended.
        result[dilated_mask & ~structure_mask] = 1  # border class
        result[structure_mask] = 2                   # foreground class
    return result


def process_nrrd_files(root_dir, border_width=2, output_prefix="tri_class_"):
    """Write a tri-class copy of every ``.nrrd`` file in any directory named ``label`` under ``root_dir``.

    Each output is saved next to its input as ``<output_prefix><name>``, reusing the input's
    header. Files whose names already start with ``output_prefix`` are skipped. Without this,
    re-running the cell would re-process earlier outputs into ``tri_class_tri_class_*`` files.

    Parameters:
    - root_dir: Directory searched recursively with ``os.walk``.
    - border_width: ``binary_dilation`` iterations that form the border class (default 2).
    - output_prefix: Prefix for output file names; also used to recognise and skip outputs.
    """
    for subdir, _, files in os.walk(root_dir):
        if os.path.basename(subdir) != 'label':
            continue
        for file in sorted(files):
            if not file.endswith('.nrrd'):
                continue
            if output_prefix and file.startswith(output_prefix):
                continue  # a previous run's output, not a source label volume
            nrrd_path = os.path.join(subdir, file)
            data, header = nrrd.read(nrrd_path)

            result = _to_tri_class(data, border_width)

            output_path = os.path.join(subdir, f"{output_prefix}{file}")
            nrrd.write(output_path, result, header)
            print(f"Processed and saved: {output_path}")

In [78]:
# Root of the dataset. Every `label/` sub-directory below it is converted to tri-class labels.
# Set the VESUVIUS_DATA_DIR environment variable to point elsewhere instead of editing this path.
root_directory = os.environ.get(
    'VESUVIUS_DATA_DIR',
    '/home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/',
)
process_nrrd_files(root_directory)

Processed and saved: /home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/test/label/tri_class_3350_4000_8450_xyz_256_res1_s4_label.nrrd
Processed and saved: /home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/test/label/tri_class_layers_1.nrrd


In [79]:
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt

# Tri-class label volume written by process_nrrd_files above
# (0 = background, 1 = border, 2 = foreground).
pred = nrrd.read('/home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/test/label/tri_class_layers_1.nrrd')[0]


def plot_slice(slice_index, axis=0):
    """Display the 2-D slice of ``pred`` at ``slice_index`` along ``axis``.

    ``axis`` values other than 1 or 2 fall back to axis 0, as before. The slider below spans
    the largest dimension of ``pred``, so ``slice_index`` is clamped to the valid range of the
    chosen axis. On non-cubic volumes this avoids both IndexError and unreachable slices.
    """
    axis = axis if axis in (1, 2) else 0
    slice_index = min(max(slice_index, 0), pred.shape[axis] - 1)
    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(np.take(pred, slice_index, axis=axis))
    fig.colorbar(image, ax=ax)
    ax.set_title(f'Slice {slice_index}')
    plt.show()


interact(plot_slice,
         slice_index=IntSlider(min=0, max=max(pred.shape) - 1, step=1, value=0),
         axis=IntSlider(min=0, max=2, step=1, value=0));

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [80]:
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt

# Original instance-label volume: the input that process_nrrd_files turned into
# tri_class_layers_1.nrrd.
pred2 = nrrd.read('/home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/test/label/layers_1.nrrd')[0]


def plot_slice(slice_index, axis=0):
    """Display the 2-D slice of ``pred2`` at ``slice_index`` along ``axis``.

    ``axis`` values other than 1 or 2 fall back to axis 0, as before. The slider below spans
    the largest dimension of ``pred2``, so ``slice_index`` is clamped to the valid range of the
    chosen axis. On non-cubic volumes this avoids both IndexError and unreachable slices.
    """
    axis = axis if axis in (1, 2) else 0
    slice_index = min(max(slice_index, 0), pred2.shape[axis] - 1)
    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(np.take(pred2, slice_index, axis=axis))
    fig.colorbar(image, ax=ax)
    ax.set_title(f'Slice {slice_index}')
    plt.show()


interact(plot_slice,
         slice_index=IntSlider(min=0, max=max(pred2.shape) - 1, step=1, value=0),
         axis=IntSlider(min=0, max=2, step=1, value=0));

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>