In [None]:
import imageio.v3 as iio
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np
import open3d as o3d
from pathlib import Path
import pandas as pd
import segment
from scipy import ndimage as ndi
from skimage import color, feature, filters, morphology, measure, segmentation, util
from stl import mesh
import sys
import yaml
%load_ext autoreload
%autoreload 2

## Load in Images

In [None]:
# Read YAML input file
# NOTE(review): this is an absolute, user-specific path; the notebook will only
# run unchanged on the machine that has it.
yaml_file = Path(r'C:\Users\gusb\Research\PSAAP\segmentflow-input-gus.yml')
# The context manager closes the file handle even if YAML parsing raises.
# NOTE(review): yaml.FullLoader is kept to preserve behavior; if the input file
# only contains plain YAML types, yaml.safe_load would be the stricter choice.
with open(yaml_file, 'r') as stream:
    UI = yaml.load(stream, Loader=yaml.FullLoader)   # User Input
# Process User Input: unpack every setting into a flat ui_* variable
ui_ct_img_dir           = UI['Files']['CT Scan Dir']
ui_stl_dir_location     = UI['Files']['STL Dir']
ui_output_filename_base = UI['Files']['STL Prefix']
ui_stl_overwrite        = UI['Files']['Overwrite Existing STL Files']
ui_single_particle_iso  = UI['Files']['Particle ID']
ui_suppress_save_msg    = UI['Files']['Suppress Save Messages']
ui_file_suffix          = UI['Load']['File Suffix']
ui_slice_crop           = UI['Load']['Slice Crop']
ui_row_crop             = UI['Load']['Row Crop']
ui_col_crop             = UI['Load']['Col Crop']
ui_use_median_filter    = UI['Preprocess']['Apply Median Filter']
ui_rescale_range        = UI['Preprocess']['Rescale Intensity Range']
ui_n_otsu_classes       = UI['Binarize']['Number of Otsu Classes']
ui_n_selected_classes   = UI['Binarize']['Number of Classes to Select']
ui_use_int_dist_map     = UI['Segment']['Use Integer Distance Map']
ui_min_peak_distance    = UI['Segment']['Min Peak Distance']
ui_exclude_borders      = UI['Segment']['Exclude Border Particles']
ui_n_erosions           = UI['STL']['Number of Pre-Surface Meshing Erosions']
ui_median_filter        = UI['STL']['Smooth Voxels with Median Filtering']
ui_spatial_res          = UI['STL']['Pixel-to-Length Ratio']
ui_voxel_step_size      = UI['STL']['Marching Cubes Voxel Step Size']
ui_mesh_smooth_n_iters  = UI['STL']['Number of Smoothing Iterations']
ui_mesh_simplify_n_tris = UI['STL']['Target number of Triangles/Faces']
ui_mesh_simplify_factor = UI['STL']['Simplification factor Per Iteration']
ui_show_segment_fig     = UI['Plot']['Show Segmentation Figure']
ui_n_imgs               = UI['Plot']['Number of Images']
ui_plot_maxima          = UI['Plot']['Plot Maxima']
ui_show_label_fig       = UI['Plot']['Show Particle Labels Figure']
ui_label_idx            = UI['Plot']['Particle Label Image Index']
ui_show_stl_fig         = UI['Plot']['Show Random STL Figure']
# Load images
print('Loading images...')
imgs = segment.load_images(
    ui_ct_img_dir,
    slice_crop=ui_slice_crop,
    row_crop=ui_row_crop,
    col_crop=ui_col_crop,
    convert_to_float=True,
    file_suffix=ui_file_suffix
)
print('--> Images loaded as 3D array: ', imgs.shape)
print('--> Size of array (GB): ', imgs.nbytes / 1E9)
# Plot images
# NOTE(review): the preview count is fixed at 4; ui_n_imgs is read above but
# not used here.
fig, axes = segment.plot_imgs(imgs, n_imgs=4)
plt.show()

## Add preprocessing step

### Median filter followed by intensity rescale

In [None]:
# Run the preprocessing helper on the loaded stack.
# NOTE(review): median_filter is fixed to True in this call; the
# ui_use_median_filter setting read from the input file is not consulted.
print('Preprocessing images...')
imgs_pre = segment.preprocess(
    imgs,
    median_filter=True,
    rescale_intensity_range=ui_rescale_range,
)
print('--> Preprocessing complete')
print('--> Size of array (GB): ', imgs_pre.nbytes / 1E9)

# Preview a few of the preprocessed slices
fig, axes = segment.plot_imgs(imgs_pre, n_imgs=4)
plt.show()

In [None]:
# Threshold the preprocessed stack with multi-Otsu.
print('Binarizing images...')
# Border exclusion is deferred until after segmentation so that connected
# regions touching the border are not dropped as a whole at this stage.
imgs_binarized, thresh_vals = segment.binarize_multiotsu(
    imgs_pre,
    n_otsu_classes=ui_n_otsu_classes,
    n_selected_thresholds=ui_n_selected_classes,
    exclude_borders=False,
)
print('--> Binarization complete')
print('--> Size of array (GB): ', imgs_binarized.nbytes / 1E9)

# Preview a few of the binarized slices
fig, axes = segment.plot_imgs(imgs_binarized, n_imgs=4)
plt.show()

## Segment the Images

In [None]:
# Split the binarized volume into individually labeled particles.
print('Segmenting images...')
segment_dict = segment.watershed_segment(
    imgs_binarized,
    min_peak_distance=ui_min_peak_distance,
    use_int_dist_map=ui_use_int_dist_map,
    return_dict=True,
)
print('--> Segmentation complete')

# The largest label value is taken as the particle count; its digit count is
# kept so particle IDs can be zero-padded to a common width in file names.
n_particles = np.max(segment_dict['integer-labels'])
n_particles_digits = len(str(n_particles))
print('--> Total number of particles segmented: ' + str(n_particles))

# Show the intermediate results of each step side by side
fig, axes = segment.plot_segment_steps(
    imgs, imgs_pre, imgs_binarized, segment_dict
)
plt.show()

### Exclude border particles

In [None]:
# Optionally drop every labeled particle that touches the volume border.
if ui_exclude_borders:
    # Report the particle count prior to clearing the border
    n_particles = np.max(segment_dict['integer-labels'])
    n_particles_digits = len(str(n_particles))
    print('--> Number of particles before border exclusion: ', str(n_particles))
    print()
    print('Excluding border particles...')
    # The label image stored in segment_dict is replaced by the cleared one
    segment_dict['integer-labels'] = segmentation.clear_border(
        segment_dict['integer-labels']
    )

# Region properties are computed whether or not borders were cleared; these
# regions are what the meshing step iterates over.
regions = measure.regionprops(segment_dict['integer-labels'])
n_particles_noborder = len(regions)
print('--> Number of particles: ', str(n_particles_noborder))

fig, axes = segment.plot_segment_steps(
    imgs, imgs_pre, imgs_binarized, segment_dict
)
plt.show()

## Surface meshing

### Define functions

In [None]:
def create_surface_mesh(
        imgs, slice_crop, row_crop, col_crop,
        min_slice, min_row, min_col, spatial_res=1, voxel_step_size=1
):
    """Build a numpy-stl mesh of a single particle using marching cubes.

    Parameters
    ----------
    imgs : array
        Volume passed to ``measure.marching_cubes``. The caller in this
        notebook passes one particle padded by one voxel on every side.
    slice_crop, row_crop, col_crop : sequence or None
        Crop ranges applied when the images were loaded. Only element 0 of
        each is used, as an offset; ``None`` means an offset of 0.
    min_slice, min_row, min_col : int
        Lower corner of the particle's bounding box; added to the offsets
        (minus one, to undo the one-voxel padding).
    spatial_res : float, optional
        Factor every mesh coordinate is multiplied by after offsetting.
    voxel_step_size : int, optional
        Forwarded to marching cubes as ``step_size``.

    Returns
    -------
    tuple
        ``(stl_mesh, verts, faces, normals, values)``: the numpy-stl mesh
        followed by the raw marching-cubes outputs. Only ``stl_mesh`` is
        offset and scaled; ``verts`` is returned as produced.
    """
    verts, faces, normals, values = measure.marching_cubes(
        imgs, step_size=voxel_step_size,
        allow_degenerate=False
    )
    # numpy-stl stores one record per triangle, so the mesh is sized by the
    # number of faces.
    n_faces = faces.shape[0]
    stl_mesh = mesh.Mesh(
        np.zeros(n_faces, dtype=mesh.Mesh.dtype),
        remove_empty_areas=False
    )
    # verts[faces] has shape (n_faces, 3, 3): for every face, the coordinates
    # of its three vertices. One vectorised assignment replaces a Python loop
    # over every face and vertex.
    stl_mesh.vectors[...] = verts[faces]

    def _crop_start(crop):
        # First element of a crop range, or 0 when no crop was applied.
        return 0 if crop is None else crop[0]

    # Offsets place the particle back at its position in the full scan. One is
    # subtracted to account for voxel padding on the front end of each
    # dimension.
    # NOTE(review): marching cubes reports vertex coordinates in the axis
    # order of `imgs`. If `imgs` is indexed (slice, row, col), the first
    # coordinate (exposed by numpy-stl as x) would be the slice direction,
    # while x_offset below is derived from the column crop. Confirm the
    # intended axis convention.
    x_offset = _crop_start(col_crop) + min_col - 1
    y_offset = _crop_start(row_crop) + min_row - 1
    z_offset = _crop_start(slice_crop) + min_slice - 1
    # Apply offsets to (x, y, z) coordinates of mesh
    stl_mesh.x += x_offset
    stl_mesh.y += y_offset
    stl_mesh.z += z_offset
    # stl_mesh.vectors are the position vectors. Multiplying by the
    # spatial resolution of the scan makes these vectors physical.
    stl_mesh.vectors *= spatial_res
    return stl_mesh, verts, faces, normals, values

### Iterate through regions

In [None]:
def save_regions_as_stl_files(
    regions,
    stl_dir_location,
    output_filename_base,
    n_particles_digits,
    suppress_save_msg=True,
    slice_crop=None,
    row_crop=None,
    col_crop=None,
    stl_overwrite=False,
    spatial_res=1,
    n_erosions=None,
    median_filter_voxels=True,
    voxel_step_size=1,
    mesh_smooth_n_iters=None, 
    mesh_simplify_n_tris=None, 
    mesh_simplify_factor=None, 
):
    """Mesh each region, save it as an STL file, and write a properties CSV.

    For every region an STL named
    ``{output_filename_base}_{zero-padded label}.stl`` is written to
    ``stl_dir_location``, then passed to ``segment.postprocess_mesh``. A CSV
    named ``{output_filename_base}_properties.csv`` with one row per region
    is written to the same directory.

    Parameters
    ----------
    regions : iterable
        Region objects exposing ``label``, ``area``, ``centroid``, ``bbox``
        (six values: min slice/row/col then max slice/row/col) and ``image``.
    stl_dir_location : str or Path
        Output directory; created (with parents) if missing.
    output_filename_base : str
        Prefix for the STL and CSV file names.
    n_particles_digits : int
        Width to which the region label is zero-padded in the file name.
    suppress_save_msg : bool, optional
        When False, print a message for each STL saved.
    slice_crop, row_crop, col_crop : sequence or None, optional
        Forwarded to ``create_surface_mesh``.
    stl_overwrite : bool, optional
        When False, an already existing STL file raises ``ValueError``.
    spatial_res : float, optional
        Forwarded to ``create_surface_mesh``.
    n_erosions : int or None, optional
        Number of binary erosions applied before meshing. ``None`` is
        treated as 0.
    median_filter_voxels : bool, optional
        When True, apply ``filters.median`` to the particle before meshing.
    voxel_step_size : int, optional
        Forwarded to ``create_surface_mesh``.
    mesh_smooth_n_iters, mesh_simplify_n_tris, mesh_simplify_factor : optional
        Forwarded to ``segment.postprocess_mesh`` as ``smooth_iter``,
        ``simplify_n_tris`` and ``iterative_simplify_factor``.

    Raises
    ------
    ValueError
        If a target STL file already exists and ``stl_overwrite`` is False.
    """
    # The default of None means "no erosion". Normalise it so the size check
    # below can do arithmetic with it instead of raising TypeError.
    if n_erosions is None:
        n_erosions = 0
    # Create the output directory once, up front. This also guarantees the
    # properties CSV can be written when `regions` is empty.
    stl_dir = Path(stl_dir_location)
    stl_dir.mkdir(parents=True, exist_ok=True)
    # Columns that always lead the CSV, in this order; any further keys
    # (e.g. 'centroid' and the post-processing properties) follow them.
    base_columns = [
        'particleID',
        'meshed',
        'n_voxels',
        'n_triangles',
        'min_slice',
        'max_slice',
        'min_row',
        'max_row',
        'min_col',
        'max_col',
    ]
    # Bounding-box extent at or below which a dimension counts as too small
    max_small_extent = 2 + 2*n_erosions
    # One dict per region; turned into a DataFrame once after the loop rather
    # than concatenating a new frame per region.
    records = []
    for region in regions:
        # Create save path
        fn = (
            f'{output_filename_base}'
            f'_{str(region.label).zfill(n_particles_digits)}.stl'
        )
        stl_save_path = stl_dir / fn
        if stl_save_path.exists() and not stl_overwrite:
            raise ValueError(f'STL already exists: {stl_save_path}')
        # Get bounding slice, row, and column
        min_slice, min_row, min_col, max_slice, max_row, max_col = region.bbox
        props = {}
        props['particleID'] = region.label
        props['n_voxels']   = region.area
        props['centroid']   = region.centroid
        props['min_slice']  = min_slice
        props['max_slice']  = max_slice
        props['min_row']    = min_row
        props['max_row']    = max_row
        props['min_col']    = min_col
        props['max_col']    = max_col
        # Skip meshing for particles that are too small (marching cubes
        # limitation).
        # NOTE(review): this test requires ALL three dimensions to be small,
        # whereas the printed message says "at least one dimension". Confirm
        # whether `and` or `or` is intended; behavior is left unchanged.
        if (
            max_slice - min_slice <= max_small_extent
            and max_row - min_row <= max_small_extent
            and max_col - min_col <= max_small_extent
        ):
            props['meshed'] = False
            print(
                f'Surface mesh not created for particle {region.label}: '
                'Particle smaller than minimum width in at least one dimension.'
            )
        else:
            # Isolate the particle and surround it with one voxel of
            # background (np.pad already copies the particle into the
            # interior of the padded array).
            imgs_particle_padded = np.pad(region.image, 1)
            if n_erosions > 0:
                for _ in range(n_erosions):
                    imgs_particle_padded = morphology.binary_erosion(
                        imgs_particle_padded
                    )
                # Erosion can split a particle into pieces; when it does,
                # keep only the largest piece.
                particle_labeled = measure.label(
                    imgs_particle_padded, connectivity=1
                )
                particle_regions = measure.regionprops(particle_labeled)
                if len(particle_regions) > 1:
                    # Sort particle regions by area with largest first
                    particle_regions = sorted(
                        particle_regions, key=lambda r: r.area, reverse=True
                    )
                    # Clear non-zero voxels from imgs_particle_padded
                    imgs_particle_padded = np.zeros_like(
                        imgs_particle_padded, dtype=np.uint8
                    )
                    # Add non-zero voxels back for voxels belonging to largest
                    # particle present (particle_regions[0])
                    imgs_particle_padded[
                        particle_labeled == particle_regions[0].label
                    ] = 255  # (255 is max for 8-bit/np.uint8 image)
            if median_filter_voxels:
                # Median filter used to smooth particle in image/voxel form
                imgs_particle_padded = filters.median(imgs_particle_padded)
            # Mesh, save, then post-process the saved file. A RuntimeError
            # anywhere in this sequence marks the particle as not meshed
            # instead of aborting the whole run.
            try:
                # Only the mesh itself is needed; the raw marching-cubes
                # outputs are discarded.
                stl_mesh, *_ = create_surface_mesh(
                    imgs_particle_padded, slice_crop, row_crop, col_crop, 
                    min_slice, min_row, min_col, spatial_res=spatial_res, 
                    voxel_step_size=voxel_step_size
                )
                stl_mesh.save(stl_save_path)
                stl_mesh, mesh_props = segment.postprocess_mesh(
                    stl_save_path, smooth_iter=mesh_smooth_n_iters, 
                    simplify_n_tris=mesh_simplify_n_tris, 
                    iterative_simplify_factor=mesh_simplify_factor, 
                    recursive_simplify=False, resave_mesh=True
                )
                props['meshed'] = True
                # Post-processing properties are merged in and win on any
                # shared key.
                props = {**props, **mesh_props}
                if not suppress_save_msg:
                    print(f'STL saved: {stl_save_path}')
            except RuntimeError as error:
                props['meshed'] = False
                print(
                    f'Surface mesh not created for particle {region.label}:',
                    error
                )
        records.append(props)
    # Build the table in one go, then put the base columns first
    props_df = pd.DataFrame(records)
    extra_columns = [c for c in props_df.columns if c not in base_columns]
    props_df = props_df.reindex(columns=base_columns + extra_columns)
    csv_fn = (f'{output_filename_base}_properties.csv')
    csv_save_path = stl_dir / csv_fn
    props_df.to_csv(csv_save_path, index=False)
    # Count number of meshed particles
    n_saved = sum(1 for record in records if record['meshed'])
    print(f'{n_saved} STL file(s) saved: {stl_dir_location}')

In [None]:
# Mesh every remaining region and write the STL files plus properties CSV,
# using the settings read from the YAML input file.
save_regions_as_stl_files(
    regions,
    ui_stl_dir_location,
    ui_output_filename_base,
    n_particles_digits,
    # Output handling
    stl_overwrite=ui_stl_overwrite,
    suppress_save_msg=ui_suppress_save_msg,
    # Crop offsets and scaling used to position the meshes
    slice_crop=ui_slice_crop,
    row_crop=ui_row_crop,
    col_crop=ui_col_crop,
    spatial_res=ui_spatial_res,
    # Voxel-level processing before meshing
    n_erosions=ui_n_erosions,
    median_filter_voxels=ui_median_filter,
    voxel_step_size=ui_voxel_step_size,
    # Mesh post-processing
    mesh_smooth_n_iters=ui_mesh_smooth_n_iters,
    mesh_simplify_n_tris=ui_mesh_simplify_n_tris,
    mesh_simplify_factor=ui_mesh_simplify_factor,
)