In [None]:
import imageio as iio
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np
from pathlib import Path
import segment
from scipy import ndimage as ndi
from skimage import color, feature, filters, morphology, measure, segmentation, util
from stl import mesh
import sys
import yaml
%load_ext autoreload
%autoreload 2

## Load in Images

In [None]:
# Read YAML input file
# NOTE(review): this is an absolute, machine-specific path; point it at your
# own input file before running the notebook on another machine.
yaml_file = Path(r'C:\Users\gusb\Research\PSAAP\segmentflow-input-gus.yml')
# The context manager closes the file handle even if parsing raises, which the
# previous manual open()/close() pair did not guarantee.
with open(yaml_file, 'r') as stream:
    # NOTE(review): FullLoader accepts more than plain scalars/lists/dicts.
    # If the input file only holds plain data, yaml.safe_load would be the
    # more conservative choice -- confirm before switching.
    UI = yaml.load(stream, Loader=yaml.FullLoader)   # User Input
# Process User Input
# Each top-level section of the YAML mapping is looked up once, then its
# entries are copied into flat `ui_*` names used by the cells below.

# --- Files -------------------------------------------------------------
files_cfg = UI['Files']
ui_ct_img_dir = files_cfg['CT Scan Dir']
ui_stl_dir_location = files_cfg['STL Dir']
ui_output_filename_base = files_cfg['STL Prefix']
ui_stl_overwrite = files_cfg['Overwrite Existing STL Files']
ui_single_particle_iso = files_cfg['Particle ID']
ui_suppress_save_msg = files_cfg['Suppress Save Messages']

# --- Load --------------------------------------------------------------
load_cfg = UI['Load']
ui_file_suffix = load_cfg['File Suffix']
ui_slice_crop = load_cfg['Slice Crop']
ui_row_crop = load_cfg['Row Crop']
ui_col_crop = load_cfg['Col Crop']

# --- Preprocess --------------------------------------------------------
preprocess_cfg = UI['Preprocess']
ui_use_median_filter = preprocess_cfg['Apply Median Filter']
ui_rescale_range = preprocess_cfg['Rescale Intensity Range']

# --- Binarize ----------------------------------------------------------
binarize_cfg = UI['Binarize']
ui_n_otsu_classes = binarize_cfg['Number of Otsu Classes']
ui_n_selected_classes = binarize_cfg['Number of Classes to Select']

# --- Segment -----------------------------------------------------------
segment_cfg = UI['Segment']
ui_use_int_dist_map = segment_cfg['Use Integer Distance Map']
ui_min_peak_distance = segment_cfg['Min Peak Distance']
ui_exclude_borders = segment_cfg['Exclude Border Particles']

# --- STL ---------------------------------------------------------------
stl_cfg = UI['STL']
ui_erode_particles = stl_cfg['Erode Particles']
ui_voxel_step_size = stl_cfg['Marching Cubes Voxel Step Size']
ui_spatial_res = stl_cfg['Pixel-to-Length Ratio']

# --- Plot --------------------------------------------------------------
plot_cfg = UI['Plot']
ui_show_segment_fig = plot_cfg['Show Segmentation Figure']
ui_n_imgs = plot_cfg['Number of Images']
ui_plot_maxima = plot_cfg['Plot Maxima']
ui_show_label_fig = plot_cfg['Show Particle Labels Figure']
ui_label_idx = plot_cfg['Particle Label Image Index']
ui_show_stl_fig = plot_cfg['Show Random STL Figure']
# Load images
# The crop ranges and file suffix come from the user input; conversion to
# float is always requested here.
print('Loading images...')
load_kwargs = dict(
    slice_crop=ui_slice_crop,
    row_crop=ui_row_crop,
    col_crop=ui_col_crop,
    convert_to_float=True,
    file_suffix=ui_file_suffix,
)
imgs = segment.load_images(ui_ct_img_dir, **load_kwargs)
print('--> Images loaded as 3D array: ', imgs.shape)
print('--> Size of array (GB): ', imgs.nbytes / 1E9)

# Plot images (the number of images shown is fixed at 4 in this notebook)
fig, axes = segment.plot_imgs(imgs, n_imgs=4)
plt.show()

## Add preprocessing step

### Median filter followed by intensity rescale

In [None]:
# Preprocess the loaded images; the rescale range comes from the user input.
# NOTE(review): median_filter is hard-coded to True here, so the
# 'Apply Median Filter' setting (ui_use_median_filter) is not consulted --
# confirm this is intentional for this notebook.
print('Preprocessing images...')
imgs_pre = segment.preprocess(
    imgs,
    median_filter=True,
    rescale_intensity_range=ui_rescale_range,
)
print('--> Preprocessing complete')
print('--> Size of array (GB): ', imgs_pre.nbytes / 1E9)

# Plot preprocessed images
fig, axes = segment.plot_imgs(imgs_pre, n_imgs=4)
plt.show()

In [None]:
# Binarize the preprocessed images.
# NOTE(review): n_otsu_classes is hard-coded to 2 here, so the
# 'Number of Otsu Classes' setting (ui_n_otsu_classes) is not consulted --
# confirm this is intentional for this notebook.
print('Binarizing images...')
binarize_kwargs = dict(
    n_otsu_classes=2,
    # Borders must be excluded after seg. to avoid exclusion of connected regions
    exclude_borders=False,
)
imgs_binarized, thresh_vals = segment.binarize_multiotsu(
    imgs_pre, **binarize_kwargs
)
print('--> Binarization complete')
print('--> Size of array (GB): ', imgs_binarized.nbytes / 1E9)

# Plot binarized images
fig, axes = segment.plot_imgs(imgs_binarized, n_imgs=4)
plt.show()

## Segment the Images

In [None]:
# Segment the binarized images. Peak spacing and the distance-map option come
# from the user input; return_dict=True requests the results as a dict
# (stored in segment_dict and indexed by key in later cells).
print('Segmenting images...')
segment_dict = segment.watershed_segment(
    imgs_binarized,
    min_peak_distance=ui_min_peak_distance,
    use_int_dist_map=ui_use_int_dist_map,
    return_dict=True,
)
print('--> Segmentation complete')

In [None]:
# sys.getsizeof() doesn't represent nested objects; need to add manually.
# Start from the size of the dict container itself, then add the size reported
# for each value so the running total is actually accumulated and shown.
# NOTE(review): getsizeof may under-report objects that do not own their data
# (e.g. array views); if the values are NumPy arrays, `.nbytes` may be more
# representative -- confirm against the value types in segment_dict.
print('--> Size of segmentation results (GB):')
dict_size = sys.getsizeof(segment_dict)
for key, val in segment_dict.items():
    val_size = sys.getsizeof(val)
    dict_size += val_size
    print(f'----> {key}: {val_size / 1E9}')
print(f'----> Total: {dict_size / 1E9}')

In [None]:
# Show the pipeline stages together: raw, preprocessed, and binarized images
# plus the segmentation results.
fig, axes = segment.plot_segment_steps(
    imgs,
    imgs_pre,
    imgs_binarized,
    segment_dict,
)
plt.show()

In [None]:
# How Many Particles Were Segmented?
# The largest value in the integer label array is taken as the particle count.
n_particles = np.max(segment_dict['integer-labels'])
# String form is reused for both the digit count and the message below;
# the digit count is passed on to the STL export cell.
n_particles_str = str(n_particles)
n_particles_digits = len(n_particles_str)
print('--> Total number of particles segmented: ' + n_particles_str)

### Exclude border particles and iterate through regions

In [None]:
# Optionally clear border particles from the label array, then measure the
# remaining labeled regions and re-plot the segmentation steps.
if ui_exclude_borders:
    labels_with_border = segment_dict['integer-labels']
    # How Many Particles Were Segmented? (max label, before exclusion)
    n_particles = np.max(labels_with_border)
    n_particles_digits = len(str(n_particles))
    print('--> Number of particles before border exclusion: ', n_particles)
    print()
    print('Excluding border particles...')
    # The cleared labels replace the original entry in segment_dict
    segment_dict['integer-labels'] = segmentation.clear_border(
        labels_with_border
    )

# One region-properties entry per label that remains
regions = measure.regionprops(segment_dict['integer-labels'])
n_particles_noborder = len(regions)
print('--> Number of particles: ', n_particles_noborder)

fig, axes = segment.plot_segment_steps(
    imgs, imgs_pre, imgs_binarized, segment_dict
)
plt.show()

In [None]:
# Write the measured regions out as STL files and report how many were saved.
print()
print('Generating surface meshes...')
# Keyword options, all taken from the user input except return_n_saved,
# which asks the call to return the number of files saved.
stl_kwargs = dict(
    suppress_save_msg=ui_suppress_save_msg,
    slice_crop=ui_slice_crop,
    row_crop=ui_row_crop,
    col_crop=ui_col_crop,
    spatial_res=ui_spatial_res,
    voxel_step_size=ui_voxel_step_size,
    erode_particles=ui_erode_particles,
    stl_overwrite=ui_stl_overwrite,
    return_n_saved=True,
)
n_saved = segment.save_regions_as_stl_files(
    imgs,
    regions,
    ui_stl_dir_location,
    ui_output_filename_base,
    n_particles_digits,
    **stl_kwargs,
)
print(f'--> {n_saved} STL file(s) written!')