In [None]:
import imageio.v3 as iio
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np
from pathlib import Path
import segment
from scipy import ndimage as ndi
from skimage import color, feature, filters, morphology, measure, segmentation, util
from stl import mesh
import sys
import yaml
%load_ext autoreload
%autoreload 2

## Load images

In [None]:
# Read YAML input file
# NOTE(review): machine-specific absolute path; point this at your own
# segmentflow input file before running.
yaml_file = Path(r'C:\Users\gusb\Research\PSAAP\segmentflow-input-gus.yml')
# `with` closes the file even if parsing raises; a bare open()/close() pair
# leaks the handle on a YAML error.
with open(yaml_file, 'r') as stream:
    UI = yaml.load(stream, Loader=yaml.FullLoader)   # User Input
# Process User Input
ui_ct_img_dir           = UI['Files']['CT Scan Dir']
ui_stl_dir_location     = UI['Files']['STL Dir']
ui_output_filename_base = UI['Files']['STL Prefix']
ui_stl_overwrite        = UI['Files']['Overwrite Existing STL Files']
ui_single_particle_iso  = UI['Files']['Particle ID']
ui_suppress_save_msg    = UI['Files']['Suppress Save Messages']
ui_file_suffix          = UI['Load']['File Suffix']
ui_slice_crop           = UI['Load']['Slice Crop']
ui_row_crop             = UI['Load']['Row Crop']
ui_col_crop             = UI['Load']['Col Crop']
ui_use_median_filter    = UI['Preprocess']['Apply Median Filter']
ui_rescale_range        = UI['Preprocess']['Rescale Intensity Range']
ui_n_otsu_classes       = UI['Binarize']['Number of Otsu Classes']
ui_n_selected_classes   = UI['Binarize']['Number of Classes to Select']
ui_use_int_dist_map     = UI['Segment']['Use Integer Distance Map']
ui_min_peak_distance    = UI['Segment']['Min Peak Distance']
ui_exclude_borders      = UI['Segment']['Exclude Border Particles']
ui_erode_particles      = UI['STL']['Erode Particles']
ui_voxel_step_size      = UI['STL']['Marching Cubes Voxel Step Size']
ui_spatial_res          = UI['STL']['Pixel-to-Length Ratio']
ui_show_segment_fig     = UI['Plot']['Show Segmentation Figure']
ui_n_imgs               = UI['Plot']['Number of Images']
ui_plot_maxima          = UI['Plot']['Plot Maxima']
ui_show_label_fig       = UI['Plot']['Show Particle Labels Figure']
ui_label_idx            = UI['Plot']['Particle Label Image Index']
ui_show_stl_fig         = UI['Plot']['Show Random STL Figure']
# Load images
print('Loading images...')
imgs = segment.load_images(
    ui_ct_img_dir,
    slice_crop=ui_slice_crop,
    row_crop=ui_row_crop,
    col_crop=ui_col_crop,
    convert_to_float=True,
    file_suffix=ui_file_suffix
)
print('--> Images loaded as 3D array: ', imgs.shape)
print('--> Size of array (GB): ', imgs.nbytes / 1E9)
# Plot images
fig, axes = segment.plot_imgs(imgs, n_imgs=4)
plt.show()

## Preprocess images

### Median filter followed by intensity rescale

In [None]:
print('Preprocessing images...')
# Median filtering follows the 'Apply Median Filter' setting from the YAML
# input instead of always being applied.
imgs_pre = segment.preprocess(
    imgs,
    median_filter=ui_use_median_filter,
    rescale_intensity_range=ui_rescale_range,
)
print('--> Preprocessing complete')
print('--> Size of array (GB): ', imgs_pre.nbytes / 1E9)
# Plot preprocessed images
fig, axes = segment.plot_imgs(imgs_pre, n_imgs=4)
plt.show()

In [None]:
print('Binarizing images...')
# exclude_borders stays False at this stage: border particles are removed only
# after segmentation, so regions connected to the border are not dropped early.
imgs_binarized, thresh_vals = segment.binarize_multiotsu(
    imgs_pre,
    exclude_borders=False,
    n_selected_thresholds=ui_n_selected_classes,
    n_otsu_classes=ui_n_otsu_classes,
)
print('--> Binarization complete')
print('--> Size of array (GB): ', imgs_binarized.nbytes / 1E9)
# Show a few binarized slices
fig, axes = segment.plot_imgs(imgs_binarized, n_imgs=4)
plt.show()

## Segment the Images

In [None]:
print('Segmenting images...')
# With return_dict=True the watershed results come back as a dict; later cells
# read its 'integer-labels' entry.
segment_dict = segment.watershed_segment(
    imgs_binarized,
    return_dict=True,
    use_int_dist_map=ui_use_int_dist_map,
    min_peak_distance=ui_min_peak_distance,
)
print('--> Segmentation complete')

In [None]:
# sys.getsizeof() only measures the dict itself, not the objects it holds, so
# each entry's size is added manually. For NumPy arrays use .nbytes (the data
# buffer size, as in the other cells); sys.getsizeof() omits the buffer for
# array views.
print('--> Size of segmentation results (GB):')
dict_size = sys.getsizeof(segment_dict)
for key, val in segment_dict.items():
    val_size = val.nbytes if isinstance(val, np.ndarray) else sys.getsizeof(val)
    dict_size += val_size
    print(f'----> {key}: {val_size / 1E9}')
print(f'----> Total: {dict_size / 1E9}')

In [None]:
# Visualize each pipeline stage: raw, preprocessed, binarized, segmentation.
fig, axes = segment.plot_segment_steps(
    imgs, imgs_pre, imgs_binarized, segment_dict
)
plt.show()

In [None]:
# Count segmented particles as the largest integer label (assumes labels run
# consecutively from 1 — TODO confirm for watershed_segment output). The digit
# count sets the zero-padding width of particle IDs in STL filenames.
n_particles = np.max(segment_dict['integer-labels'])
n_particles_digits = len(f'{n_particles}')
print('--> Total number of particles segmented: ' + str(n_particles))

### Exclude border particles

In [None]:
# Optionally remove particles that touch the volume boundary. This runs after
# segmentation, so labels are already separated when borders are cleared.
if ui_exclude_borders:
    # Recount from the current labels before clearing borders
    n_particles = np.max(segment_dict['integer-labels'])
    n_particles_digits = len(str(n_particles))
    print('--> Number of particles before border exclusion: ', str(n_particles))
    print()
    print('Excluding border particles...')
    segment_dict['integer-labels'] = segmentation.clear_border(segment_dict['integer-labels'])
# regionprops ignores label 0 (background), so each region is one particle
regions = measure.regionprops(segment_dict['integer-labels'])
n_particles_noborder = len(regions)
print('--> Number of particles: ', str(n_particles_noborder))
fig, axes = segment.plot_segment_steps(
    imgs, imgs_pre, imgs_binarized, segment_dict
)
plt.show()

### Iterate through regions (define Open3D mesh helpers, then mesh each particle)

In [None]:
import open3d as o3d

def check_properties(mesh):
    """Print a report of basic topological checks for an Open3D triangle mesh.

    Parameters
    ----------
    mesh : open3d.geometry.TriangleMesh
        Mesh to inspect.

    Returns
    -------
    None
        The report (one aligned line per check, then a blank line) is printed.
    """
    # Run every check first, then print them together as one report.
    tri_count = len(mesh.triangles)
    manifold_with_boundary = mesh.is_edge_manifold(allow_boundary_edges=True)
    manifold_no_boundary = mesh.is_edge_manifold(allow_boundary_edges=False)
    manifold_verts = mesh.is_vertex_manifold()
    intersects_self = mesh.is_self_intersecting()
    is_closed = mesh.is_watertight()
    can_orient = mesh.is_orientable()
    report = "\n".join([
        f"  n_triangles:            {tri_count}",
        f"  edge_manifold:          {manifold_with_boundary}",
        f"  edge_manifold_boundary: {manifold_no_boundary}",
        f"  vertex_manifold:        {manifold_verts}",
        f"  self_intersecting:      {intersects_self}",
        f"  watertight:             {is_closed}",
        f"  orientable:             {can_orient}",
    ])
    print(report)
    print()

def repair_mesh(mesh):
    """Apply Open3D's basic clean-up passes to ``mesh`` and return it.

    The passes run in order: degenerate triangles, duplicated triangles,
    duplicated vertices, then non-manifold edges. They are called on ``mesh``
    itself, and the same object is returned.
    """
    print('Repairing mesh...')
    cleanup_passes = (
        mesh.remove_degenerate_triangles,
        mesh.remove_duplicated_triangles,
        mesh.remove_duplicated_vertices,
        mesh.remove_non_manifold_edges,
    )
    for cleanup in cleanup_passes:
        cleanup()
    return mesh

def smooth_mesh(mesh, n_iter=10):
    """Return the result of Open3D simple (neighbor-average) smoothing of ``mesh``.

    Parameters
    ----------
    mesh : open3d.geometry.TriangleMesh
        Mesh to smooth.
    n_iter : int, optional
        Number of smoothing iterations. Defaults to 10.
    """
    print('Smoothing mesh...')
    return mesh.filter_smooth_simple(number_of_iterations=n_iter)

def simplify_mesh(mesh, n_tris=500, max_n_tris=None):
    """Decimate ``mesh`` to the smallest triangle target that stays watertight.

    Starting at ``n_tris``, quadric decimation is retried with the target
    raised by one triangle until the result is watertight.

    Parameters
    ----------
    mesh : open3d.geometry.TriangleMesh
        Mesh to simplify.
    n_tris : int, optional
        First target triangle count to try. Defaults to 500.
    max_n_tris : int or None, optional
        Largest target to try. Defaults to ``max(n_tris, len(mesh.triangles))``:
        once the target reaches the mesh's own triangle count, decimation has
        nothing left to remove, so larger targets cannot help.

    Returns
    -------
    simplified : open3d.geometry.TriangleMesh
        First watertight decimation result.
    n_tris : int
        Target triangle count that produced ``simplified``.

    Raises
    ------
    RuntimeError
        If no target up to ``max_n_tris`` yields a watertight mesh.
    """
    print('Simplifying mesh...')
    if max_n_tris is None:
        max_n_tris = max(n_tris, len(mesh.triangles))
    # Iterative search: the search may need thousands of +1 steps, which
    # would exceed Python's recursion limit if done recursively.
    while True:
        simplified = mesh.simplify_quadric_decimation(n_tris)
        if simplified.is_watertight():
            return simplified, n_tris
        if n_tris >= max_n_tris:
            raise RuntimeError(
                'No watertight simplification found for targets up to '
                f'{max_n_tris} triangles'
            )
        n_tris += 1

def postprocess_mesh(
        stl_save_path, smooth_iter=10, simplify_n_tris=250,
        repair_mesh=True, return_props=True
):
    """Smooth, simplify and/or repair an STL file in place with Open3D.

    The mesh at ``stl_save_path`` is read, processed, and written back to the
    same path (overwriting it).

    Parameters
    ----------
    stl_save_path : str or pathlib.Path
        Path of the STL file to process.
    smooth_iter : int or None, optional
        Iterations of ``filter_smooth_simple``; None skips smoothing.
    simplify_n_tris : int or None, optional
        Target triangle count for ``simplify_quadric_decimation``; None skips
        simplification.
    repair_mesh : bool, optional
        Run Open3D's clean-up passes. Note: this boolean shadows the
        module-level ``repair_mesh()`` helper inside this function, so the
        repair steps are inlined here.
    return_props : bool, optional
        If True, return a dict of mesh properties; otherwise return None.

    Returns
    -------
    dict or None
        Triangle count and Open3D topology checks of the processed mesh.

    Raises
    ------
    ValueError
        If no triangles could be read from ``stl_save_path``.
    OSError
        If Open3D fails to write the processed mesh back to disk.
    """
    stl_save_path = str(stl_save_path)
    stl_mesh = o3d.io.read_triangle_mesh(stl_save_path)
    # read_triangle_mesh does not raise on failure; it returns an empty mesh.
    if not stl_mesh.has_triangles():
        raise ValueError(f'No triangles read from STL: {stl_save_path}')
    if smooth_iter is not None:
        stl_mesh = stl_mesh.filter_smooth_simple(number_of_iterations=smooth_iter)
    if simplify_n_tris is not None:
        stl_mesh = stl_mesh.simplify_quadric_decimation(simplify_n_tris)
    if repair_mesh:
        stl_mesh.remove_degenerate_triangles()
        stl_mesh.remove_duplicated_triangles()
        stl_mesh.remove_duplicated_vertices()
        stl_mesh.remove_non_manifold_edges()
    # Smoothing/simplification move vertices. Recompute so the facet normals
    # written to the STL match the final geometry; Open3D's STL writer also
    # requires triangle normals to be present.
    stl_mesh.compute_triangle_normals()
    written = o3d.io.write_triangle_mesh(
        stl_save_path, stl_mesh,
        # Currently unsupported to save STLs in ASCII format
        # write_ascii=True
    )
    # write_triangle_mesh reports failure through its return value only.
    if not written:
        raise OSError(f'Failed to write post-processed STL: {stl_save_path}')
    if not return_props:
        return None
    mesh_props = {}
    mesh_props['n_triangles'] = len(stl_mesh.triangles)
    mesh_props['watertight'] = stl_mesh.is_watertight()
    mesh_props['self_intersecting'] = stl_mesh.is_self_intersecting()
    mesh_props['orientable'] = stl_mesh.is_orientable()
    mesh_props['edge_manifold'] = stl_mesh.is_edge_manifold(allow_boundary_edges=True)
    mesh_props['edge_manifold_boundary'] = stl_mesh.is_edge_manifold(allow_boundary_edges=False)
    mesh_props['vertex_manifold'] = stl_mesh.is_vertex_manifold()
    return mesh_props

In [None]:
stl_dir_location = ui_stl_dir_location
output_filename_base = ui_output_filename_base
suppress_save_msg = ui_suppress_save_msg
slice_crop = ui_slice_crop
row_crop = ui_row_crop
col_crop = ui_col_crop
spatial_res = ui_spatial_res
voxel_step_size = ui_voxel_step_size
allow_degenerate_tris = False
erode_particles = ui_erode_particles
# Honor 'Overwrite Existing STL Files' from the YAML input; each save path is
# checked below before meshing.
stl_overwrite = ui_stl_overwrite
# 'Erode Particles' drives erosion: True -> 1 erosion, an int -> that many,
# a falsy value -> no erosion.
# NOTE(review): confirm whether the YAML value is a bool or an erosion count.
n_erosions = int(erode_particles) if erode_particles else None
# Debug limit: only mesh the first n_regions_max regions (None = all regions).
n_regions_max = 100
n_saved = 0
n_not_saved = 0
bbox_dict = {
    'min_slice' : [],
    'max_slice' : [],
    'min_row' : [],
    'max_row' : [],
    'min_col' : [],
    'max_col' : [],
}
regions_subset = regions if n_regions_max is None else regions[:n_regions_max]
# Crop offsets are the same for every particle, so compute them once, in
# (slice, row, col) order to match the axis order of the image stack.
crop_offset = np.array([
    0 if slice_crop is None else slice_crop[0],
    0 if row_crop is None else row_crop[0],
    0 if col_crop is None else col_crop[0],
])
for region in regions_subset:
    # Get bounding slice, row, and column
    min_slice, min_row, min_col, max_slice, max_row, max_col = region.bbox
    # Skip particles spanning fewer than 2 voxels in any dimension
    if (
        max_slice - min_slice < 2
        or max_row - min_row < 2
        or max_col - min_col < 2
    ):
        continue
    # np.pad surrounds the particle mask with one voxel of zeros so marching
    # cubes can close the surface where the particle fills its bounding box.
    imgs_particle_padded = np.pad(region.image, 1)
    if n_erosions is not None:
        for _ in range(n_erosions):
            imgs_particle_padded = morphology.binary_erosion(
                imgs_particle_padded
            )
        fn_suffix = f'-{n_erosions}_erosions'
    else:
        # Set fn_suffix in case there are no erosions
        fn_suffix = ''
    # Create save path
    fn = (
        f'{output_filename_base}'
        f'{str(region.label).zfill(n_particles_digits)}'
        f'{fn_suffix}.stl'
    )
    stl_save_path = Path(stl_dir_location) / fn
    if stl_save_path.exists() and not stl_overwrite:
        raise ValueError(f'STL already exists: {stl_save_path}')
    if imgs_particle_padded.max() == 0:
        # Nothing left to mesh (e.g. erosion removed the whole particle)
        n_not_saved += 1
        print(
            f'Surface mesh not created for particle {region.label}: '
            'Array empty.'
        )
        continue
    # Do Surface Meshing - Marching Cubes
    verts, faces, _, _ = measure.marching_cubes(
        imgs_particle_padded, step_size=voxel_step_size,
        allow_degenerate=allow_degenerate_tris
    )
    # marching_cubes returns vertices in array-index order (slice, row, col),
    # so offsets are added in that same order (STL x/y/z = slice/row/col).
    # The "- 1" undoes the padding voxel; multiplying by the spatial
    # resolution converts voxel indices to physical lengths.
    particle_offset = crop_offset + (min_slice - 1, min_row - 1, min_col - 1)
    verts = (verts + particle_offset) * spatial_res
    # numpy-stl mesh with one record per triangular face
    stl_mesh = mesh.Mesh(
        np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype),
        remove_empty_areas=False
    )
    # verts[faces] has shape (n_faces, 3, 3): the three vertices of each face
    stl_mesh.vectors[:] = verts[faces]
    # Save STL only if mesh is closed
    if not stl_mesh.is_closed():
        n_not_saved += 1
        print(
            f'Particle {region.label} not saved: surface not '
            'closed.'
        )
        continue
    stl_mesh.save(stl_save_path)
    props = postprocess_mesh(
        stl_save_path, smooth_iter=10, simplify_n_tris=250,
        repair_mesh=True, return_props=True
    )
    mesh_props = {'particleID' : region.label}
    mesh_props.update(props)
    n_saved += 1
    bbox_dict['min_slice'].append(min_slice)
    bbox_dict['max_slice'].append(max_slice)
    bbox_dict['min_row'].append(min_row)
    bbox_dict['max_row'].append(max_row)
    bbox_dict['min_col'].append(min_col)
    bbox_dict['max_col'].append(max_col)
    if not suppress_save_msg:
        print(f'STL saved: {stl_save_path}')
        print(mesh_props)
print(f'--> STL files saved: {n_saved}; not saved: {n_not_saved}')

In [None]:
# Render the last STL path built by the loop above (that particle may have
# been skipped if its surface was not closed).
fig, ax = segment.plot_mesh_3D(stl_save_path)
ax.view_init(azim=30, elev=None)
plt.show()

In [None]:
# NOTE(review): plot_stl() is called without arguments — confirm which STL it
# renders by default.
fig, ax = segment.plot_stl()
ax.view_init(azim=30, elev=None)
plt.show()

In [None]:
# Show every slice of the last padded (and possibly eroded) particle array.
fig, axes = segment.plot_imgs(
    imgs_particle_padded,
    imgs_per_row=5,
    n_imgs=imgs_particle_padded.shape[0],
)
plt.show()
plt.show()