In [56]:
import h5py
import numpy as np
import nrrd

def read_hdf5_file(file_path):
    """Load every top-level entry of an HDF5 file into NumPy arrays.

    The entry names are printed first; afterwards each entry is read and its
    shape is printed.

    Parameters:
    - file_path: Path of the HDF5 file; it is opened read-only.

    Returns:
    - dict mapping each top-level name to a NumPy array with its contents.
    """
    with h5py.File(file_path, 'r') as handle:
        print("Datasets in the file:")
        for key in handle:
            print(key)

        # Materialise the arrays while the file is still open.
        arrays = {}
        for key in handle:
            loaded = np.array(handle[key])
            arrays[key] = loaded
            print(f"{key} shape: {loaded.shape}")

    return arrays

In [57]:
# Network output for the test volume; the file holds a single 'predictions' dataset.
pred_path = '/home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/test/dataset/layers_1_predictions.h5'
# pred_path = '/home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/val/dataset/layers_1.h5'
data = read_hdf5_file(pred_path)
# NOTE(review): the label volume below is read from the *train* split, while the
# predictions and the raw image come from the *test* split -- confirm that these
# describe the same volume before comparing `pred` against `gt`.
d2 = read_hdf5_file('/home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/train/dataset/layers_1.h5')
img_path = '/home/james/Documents/VS/pytorch-3dunet-instanceSeg/data/Vesuvius/test/raw/layers_1.nrrd'
# Only the first value returned by nrrd.read is kept; the second is discarded.
img, _ = nrrd.read(img_path)

pred = data['predictions']
gt = d2['label']
# Sanity check: volume shapes, plus the sum of the class values in each label volume.
print(pred.shape, img.shape, np.sum(pred), np.sum(gt))

Datasets in the file:
predictions
predictions shape: (256, 256, 256)
Datasets in the file:
label
raw
label shape: (256, 256, 256)
raw shape: (256, 256, 256)
(256, 256, 256) (256, 256, 256) 18763119 19231572


In [58]:
import numpy as np
from skimage.color import gray2rgb, label2rgb
from skimage.segmentation import find_boundaries, mark_boundaries
from skimage.util import img_as_float
from skimage.morphology import dilation, square
import random

def mark_boundaries_color(image, label_img, color=None, outline_color=None, mode='outer', background_label=0, dilation_size=1):
    """Return image with each label's boundary drawn in a colormap color derived from the label value.

    Parameters:
    - image: Input image. It is copied, converted to float32 and, when 2D, expanded to RGB.
    - label_img: Image with labeled regions.
    - color: Ignored in this version.
    - outline_color: If specified, this color is painted around each boundary
      (underneath the label color). If None, no outline is drawn.
    - mode: Choose 'inner', 'outer', or 'thick' to define boundary type (passed to find_boundaries).
    - background_label: Label to be treated as the background; it gets no boundary.
    - dilation_size: Size of the dilation square for the boundaries.

    Returns:
    - float32 copy of the image with boundaries highlighted.
    """
    # Imported locally so the function does not rely on a later notebook cell
    # having already imported pyplot.
    import matplotlib.pyplot as plt

    # Ensure input image is in float and has three channels
    marked = img_as_float(image, force_copy=True).astype(np.float32, copy=False)
    if marked.ndim == 2:
        marked = gray2rgb(marked)

    unique_labels = np.unique(label_img)
    foreground_labels = unique_labels[unique_labels != background_label]
    if foreground_labels.size == 0:
        # Nothing to draw (also covers an empty label image).
        return marked

    color_map = plt.get_cmap('nipy_spectral')  # You can change 'nipy_spectral' to any other colormap

    # Labels are mapped into the colormap by dividing by the largest label.
    # Computed once; fall back to 1 so a zero maximum cannot divide by zero.
    max_label = np.max(unique_labels)
    scale = max_label if max_label != 0 else 1

    for label in foreground_labels:
        normalized_color = color_map(label / scale)[:3]  # Get RGB values only
        label_boundaries = find_boundaries(label_img == label, mode=mode)
        label_boundaries = dilation(label_boundaries, square(dilation_size))
        if outline_color is not None:
            # Paint the wider outline first; painting it last would cover the
            # label color completely because the outline mask contains the
            # boundary mask.
            outlines = dilation(label_boundaries, square(dilation_size + 1))
            marked[outlines] = outline_color
        marked[label_boundaries] = normalized_color

    return marked


def consistent_color(label):
    """Generate a consistent RGB color for a given label.

    The color is drawn from a private random generator seeded with
    ``hash(label)``, so the same label always yields the same color within a
    process and the module-level ``random`` state is left untouched.
    (For integer labels the hash, and therefore the color, is also stable
    across runs; string hashes are salted per process.)

    Parameters:
    - label: Hashable label value.

    Returns:
    - List of three floats in [0, 1): the R, G and B components.
    """
    # A dedicated generator produces the same sequence as seeding the global
    # one, without clobbering random state other code may depend on.
    generator = random.Random(hash(label))
    return [generator.random() for _ in range(3)]

def mark_boundaries_multicolor(image, label_img, color=None, outline_color=None, mode='outer', background_label=0, dilation_size=1):
    """Return image with boundaries between labeled regions highlighted with consistent colors.

    Parameters:
    - image: Input image. It is copied, converted to float32 and, when 2D, expanded to RGB.
    - label_img: Image with labeled regions.
    - color: Ignored; every label gets its own color from consistent_color.
    - outline_color: If specified, this color is painted around each boundary
      (underneath the label color). If None, no outline is drawn.
    - mode: Boundary type passed to find_boundaries.
    - background_label: Label to be treated as the background; it gets no boundary.
    - dilation_size: Size of the dilation square for the boundaries.

    Returns:
    - float32 copy of the image with boundaries highlighted.
    """
    # Ensure input image is in float and has three channels
    marked = img_as_float(image, force_copy=True).astype(np.float32, copy=False)
    if marked.ndim == 2:
        marked = gray2rgb(marked)

    # Generate consistent colors for each unique label in label_img
    unique_labels = np.unique(label_img)
    color_map = {label: consistent_color(label) for label in unique_labels if label != background_label}

    for label, label_color in color_map.items():
        label_boundaries = find_boundaries(label_img == label, mode=mode)
        label_boundaries = dilation(label_boundaries, square(dilation_size))
        if outline_color is not None:
            # The outline footprint must be larger than the boundary footprint,
            # otherwise (e.g. with dilation_size=1) the outline mask equals the
            # boundary mask and is fully overwritten by the label color below.
            outlines = dilation(label_boundaries, square(dilation_size + 1))
            marked[outlines] = outline_color
        marked[label_boundaries] = label_color

    return marked

def plot_segmentation_results(test_slice, segmentation):
    """Show a slice with segmentation boundaries next to the plain slice.

    Parameters:
    - test_slice: Image slice to display.
    - segmentation: Label image for the slice; converted with np.array before
      being handed to mark_boundaries.
    """
    _, (marked_ax, plain_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: slice with the segmentation boundaries drawn on it.
    overlay = mark_boundaries(test_slice, np.array(segmentation))
    marked_ax.imshow(overlay)
    marked_ax.set_title("Marked Boundary")

    # Right panel: the untouched slice in grayscale.
    plain_ax.imshow(test_slice, cmap='gray')
    plain_ax.set_title("Unmarked Boundary")

    plt.show()

In [59]:
# List the distinct values that occur in the prediction volume.
unique_values = np.unique(pred)
print(unique_values)

[0 1 2]


In [60]:
#show prediction
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt
def plot_slice(slice_index, axis=0):
    """Display one 2D slice of the prediction volume `pred`.

    Parameters:
    - slice_index: Position of the slice along the chosen axis.
    - axis: 1 or 2 slices along the second or third dimension; any other
      value slices along the first dimension.
    """
    plt.figure(figsize=(8, 6))
    if axis == 1:
        plane = pred[:, slice_index, :]
    elif axis == 2:
        plane = pred[:, :, slice_index]
    else:
        plane = pred[slice_index, :, :]
    plt.imshow(plane)
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

# Sliders: slice position (range taken from the first dimension of `pred`) and axis 0-2.
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=pred.shape[0]-1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [66]:
#show ground truth label
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt
def plot_slice(slice_index, axis=0):
    """Display one 2D slice of the ground-truth label volume `gt`.

    Parameters:
    - slice_index: Position of the slice along the chosen axis.
    - axis: 1 or 2 slices along the second or third dimension; any other
      value slices along the first dimension.
    """
    plt.figure(figsize=(8, 6))
    if axis == 1:
        plane = gt[:, slice_index, :]
    elif axis == 2:
        plane = gt[:, :, slice_index]
    else:
        plane = gt[slice_index, :, :]
    plt.imshow(plane)
    plt.colorbar()
    plt.title(f'Slice {slice_index}')
    plt.show()

# The slider range comes from `gt`, the volume this cell actually indexes
# (it was previously taken from `pred`, which is loaded from a different file).
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=gt.shape[0]-1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [67]:
#show prediction overlaid on raw data
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt
def plot_slice(slice_index, axis=0):
    """Display one raw-image slice with the prediction boundaries drawn on it.

    Parameters:
    - slice_index: Position of the slice along the chosen axis.
    - axis: 1 or 2 slices along the second or third dimension; any other
      value slices along the first dimension.
    """
    plt.figure(figsize=(8, 6))
    if axis == 1:
        raw_plane, label_plane = img[:, slice_index, :], pred[:, slice_index, :]
    elif axis == 2:
        raw_plane, label_plane = img[:, :, slice_index], pred[:, :, slice_index]
    else:
        raw_plane, label_plane = img[slice_index, :, :], pred[slice_index, :, :]
    # The overlay is an RGB image, so no colorbar is drawn: imshow ignores the
    # colormap for RGB data and a colorbar would show an unrelated scale.
    plt.imshow(mark_boundaries_color(raw_plane, label_plane))
    plt.title(f'Slice {slice_index}')
    plt.show()

# Sliders: slice position (range taken from the first dimension of `pred`) and axis 0-2.
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=pred.shape[0]-1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>

In [68]:
import scipy.ndimage
import scipy.ndimage
def label_foreground_structures(input_array, min_size=100000, foreground_value=2):
    """Label connected foreground components and discard the small ones.

    Use an aggressive min_size to remove small structures, then a decent
    overlap stride to recover the ones that are erroneously cut off at the
    edges.

    Parameters:
    - input_array: Array of class values; elements equal to `foreground_value`
      are treated as foreground.
    - min_size: Minimum number of elements a component needs in order to be kept.
    - foreground_value: Class value that marks the foreground (default 2).

    Returns:
    - Integer array with the shape of `input_array`: 0 for background and for
      removed components, otherwise the label assigned by scipy.ndimage.label.
      Surviving labels are not renumbered, so they may be non-consecutive.
    """
    # Find connected components in the foreground
    foreground = (input_array == foreground_value)
    labeled_array, num_features = scipy.ndimage.label(foreground)

    # Size of each component, indexed by label (index 0 is the background).
    # minlength guarantees the background slot exists even for an empty array.
    component_sizes = np.bincount(labeled_array.ravel(), minlength=num_features + 1)

    # Keep only components that reach the minimum size; never the background.
    large_components = component_sizes >= min_size
    large_components[0] = False

    # Set small components to 0 (background)
    filtered_array = labeled_array.copy()
    filtered_array[~large_components[labeled_array]] = 0

    # Count the surviving components directly: the largest surviving label id
    # is not the number of components, because labels are not renumbered.
    num_kept = int(np.count_nonzero(large_components))

    print(f"Number of connected foreground structures before filtering: {num_features}")
    print(f"Number of connected foreground structures after filtering: {num_kept}")

    return filtered_array

In [69]:
# Label the connected foreground components of the prediction, using the function's default min_size.
labeled_arr = label_foreground_structures(pred)

Number of connected foreground structures before filtering: 32
Number of connected foreground structures after filtering: 2


In [70]:
#show post processing instance segmentation overlaid on raw data
def plot_slice(slice_index, axis=0):
    """Display one raw-image slice with the filtered instance boundaries drawn on it.

    Parameters:
    - slice_index: Position of the slice along the chosen axis.
    - axis: 1 or 2 slices along the second or third dimension; any other
      value slices along the first dimension.
    """
    plt.figure(figsize=(8, 6))
    if axis == 1:
        raw_plane, label_plane = img[:, slice_index, :], labeled_arr[:, slice_index, :]
    elif axis == 2:
        raw_plane, label_plane = img[:, :, slice_index], labeled_arr[:, :, slice_index]
    else:
        raw_plane, label_plane = img[slice_index, :, :], labeled_arr[slice_index, :, :]
    # The overlay is an RGB image, so no colorbar is drawn: imshow ignores the
    # colormap for RGB data and a colorbar would show an unrelated scale.
    plt.imshow(mark_boundaries_color(raw_plane, label_plane))
    plt.title(f'Slice {slice_index}')
    plt.show()

# Sliders: slice position (range taken from the first dimension of `pred`) and axis 0-2.
interact(
    plot_slice,
    slice_index=IntSlider(min=0, max=pred.shape[0]-1, step=1, value=0),
    axis=IntSlider(min=0, max=2, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='slice_index', max=255), IntSlider(value=0, description='…

<function __main__.plot_slice(slice_index, axis=0)>