## Segment a sparse 3D image with a single material component  

The goal of this notebook is to develop a 3D segmentation algorithm (a 3D U-Net trained with `SparseSegmenter`) that improves segmentation in regions where features are detected.

**Data:** AM parts from Xuan Zhang. 

In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import h5py
import sys
from tomo_encoders import Patches
from tomo_encoders import DataFile
import tensorflow as tf
import time
from tomo_encoders.tasks import SparseSegmenter
from tomo_encoders.misc_utils.feature_maps_vis import view_midplanes

In [2]:
# Cap TensorFlow's memory on the first GPU by configuring it as a virtual device
# with a fixed limit. GPU_mem_limit is in GB; TensorFlow expects MB, hence * 1000.
GPU_mem_limit = 42.0
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        limited_device = tf.config.experimental.VirtualDeviceConfiguration(
            memory_limit=GPU_mem_limit * 1000.0)
        tf.config.experimental.set_virtual_device_configuration(gpus[0], [limited_device])
    except RuntimeError as err:
        # TensorFlow raises this if the GPUs were already initialized; report it and continue.
        print(err)

In [3]:
data_path = '/data02/MyArchive/AM_part_Xuan'  # ensure this path matches where your data is located.

# Step used on every axis of the slices below (keeps every `test_binning`-th voxel).
test_binning = 2

# Crop/stride slices per volume name. The first axis is trimmed by a per-volume margin;
# the remaining two axes are only strided. These are passed as `s_crops` to the loader.
dict_scrops = {
    'mli_L206_HT_650_L3': (slice(100, -100, test_binning),)
                          + (slice(None, None, test_binning),) * 2,
    'AM316_L205_fs_tomo_L5': (slice(50, -50, test_binning),)
                             + (slice(None, None, test_binning),) * 2,
}

In [4]:
# Build the `datasets` input for the segmenter: for every volume name, the CT
# reconstruction file (X), the ground-truth segmentation file (Y), the dataset tags
# to read from each file (presumably HDF5 keys — TODO confirm), and the crop slices.
datasets = {
    filename: {
        'fpath_X': os.path.join(data_path, 'data', filename + '_rec_1x1_uint16.hdf5'),
        'fpath_Y': os.path.join(data_path, 'seg_data', filename, filename + '_GT.hdf5'),
        'data_tag_X': 'data',
        'data_tag_Y': 'SEG',
        's_crops': s_crops,
    }
    for filename, s_crops in dict_scrops.items()
}

In [5]:
# syx = slice(600,-600,None)
# # view_midplanes(X[:,syx,syx])
# X = X[:,syx,syx]
# Y = Y[:,syx,syx]

## Train a U-Net for segmentation

In [6]:
# Model / training configuration and construction of a new segmenter.
model_path = '/data02/MyArchive/aisteer_3Dencoders/models/AM_part_segmenter'
# Alternative tag previously used: 'test_noblanks_pt2cutoff_nostd'.
# NOTE(review): the Test section loads 'segmenter_Unet_test_noblanks_pt2cutoff_nostd',
# so it presumably does not evaluate the model trained here under 'tmp' — confirm intent.
descriptor_tag = 'tmp'

model_size = (64, 64, 64)

# U-Net hyperparameters, forwarded to SparseSegmenter as keyword arguments.
model_params = dict(
    n_filters=[32, 64],
    n_blocks=2,
    activation='lrelu',
    batch_norm=True,
    isconcat=[True, True],
    pool_size=[2, 4],
    stdinput=False,
)

# Training hyperparameters, read by the fe.train(...) cell.
training_params = dict(
    sampling_method="random",
    batch_size=24,
    n_epochs=30,
    random_rotate=True,
    add_noise=0.05,
    max_stride=4,
    cutoff=0.2,
)

fe = SparseSegmenter(
    model_initialization='define-new',
    model_size=model_size,
    descriptor_tag=descriptor_tag,
    **model_params
)


In [7]:
# fe.models["segmenter"].summary()

In [8]:
# for ii in range(len(fe.models['segmenter'].layers)):
#     lshape = str(fe.models['segmenter'].layers[ii].output_shape)
#     lname = str(fe.models['segmenter'].layers[ii].name)
#     print(lshape + "    ::    "  + lname) 

In [9]:
# Load every (X, Y) volume pair described in `datasets`. The log reports one
# "Shape X ..., shape Y ..." line per dataset. Xs and Ys are presumably sequences
# in the same order as `datasets` — TODO confirm against SparseSegmenter.load_datasets.
Xs, Ys = fe.load_datasets(datasets)

loading data...
copy to gpu time per 1 size chunk: 1.91 ms
processing time per 1 size chunk: 0.41 ms
copy from gpu time per 1 size chunk: 3.54 ms
total time:  2.6554770469665527
done
Shape X (451, 2100, 2100), shape Y (451, 2100, 2100)
loading data...
copy to gpu time per 1 size chunk: 2.86 ms
processing time per 1 size chunk: 0.27 ms
copy from gpu time per 1 size chunk: 5.33 ms
total time:  3.40153169631958
done
Shape X (400, 2600, 2600), shape Y (400, 2600, 2600)


In [None]:
# Train on the loaded volumes with the hyperparameters in `training_params`,
# then save the model(s) under `model_path`.
train_kwargs = {key: training_params[key]
                for key in ("max_stride", "random_rotate", "add_noise", "cutoff")}
fe.train(Xs, Ys,
         training_params["batch_size"],
         training_params["sampling_method"],
         training_params["n_epochs"],
         **train_kwargs)
fe.save_models(model_path)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30

## Test the segmenter on patches of the first volume

In [None]:
# Settings for loading a previously saved segmenter.
# NOTE(review): this tag differs from the 'tmp' descriptor_tag used for training,
# so the loaded model is presumably not the one just trained — confirm intent.
descriptor_tag = 'test_noblanks_pt2cutoff_nostd'
model_path = '/data02/MyArchive/aisteer_3Dencoders/models/AM_part_segmenter'
model_names = {"segmenter": f"segmenter_Unet_{descriptor_tag}"}
bin_size = (64, 64, 64)  # NOTE(review): not used below; patch extraction uses fe.model_size

In [None]:
# Grid stride passed to Patches for inference below.
# TODO: write the stitch function (presumably to reassemble patch predictions into a volume).
max_stride = 8

In [None]:
# Rebuild the segmenter from the saved model(s) named in `model_names` under `model_path`.
fe = SparseSegmenter(
    model_initialization='load-model',
    model_names=model_names,
    model_path=model_path,
)

In [None]:
# X = fe._normalize_volume(X)

In [None]:
# Lay out model-sized patches on a grid over the first loaded volume.
# NOTE(review): units/semantics of `stride` are defined by Patches — confirm.
patches = Patches(
    Xs[0].shape,
    initialize_by="grid",
    patch_size=fe.model_size,
    stride=max_stride,
)

In [None]:
# Cut the patches out of the volume as float32 and run the segmenter on all of them.
x = patches.extract(Xs[0], fe.model_size).astype(np.float32)
# A trailing singleton axis is added for the model input and dropped from its output.
# NOTE(review): every patch is predicted in one call; watch memory on large grids.
y_pred = fe.models["segmenter"].predict(x[..., None])[..., 0]
# y_pred = np.round(y_pred).astype(np.uint8)  # optional: binarize the predictions

In [None]:
# Inspect one patch: midplanes of the prediction, then of the input patch.
ii = 25  # patch index into `x` / `y_pred`
pred_patch, input_patch = y_pred[ii], x[ii]
view_midplanes(vol=pred_patch)
view_midplanes(vol=input_patch)

In [None]:
# Mean predicted value over the inspected patch.
np.mean(y_pred[ii])

In [None]:
# Show the layer summary of the loaded segmenter model (Keras-style `summary()`).
fe.models["segmenter"].summary()

In [None]:
# Print each layer's output shape and name for the loaded segmenter.
# Iterate over the layers directly rather than with an `ii` index: `ii` is the
# module-level patch index used by the inspection cells, and reassigning it here made
# re-running those cells silently inspect a different patch.
for layer in fe.models['segmenter'].layers:
    print(str(layer.output_shape) + "    ::    " + str(layer.name))