# Segment each fibre in 3D/4D slices with a trained UnetID
- Author: Rui Guo (KU Leuven), rui.guo1@kuleuven.be
- Date: July 13 2022

## Import packages

In [None]:
import imlib as im
import pylib as py
import torchlib as tl
import torchsegnet as tn
import os
import imageio
import numpy as np
import matplotlib.pyplot as plt
import time

## Paths for output and pretrained model

In [None]:
# Output folder handed to the segmentation step.
output_dir = "./output/T700-T-17/3D&4DData_Group/"
# Folder of the trained network handed to the segmentation step.
trainedNet_dir = "./output/T700-T-17/model/visionary-disco-21/"

## Load data

In [None]:
# Use this cell when the data are stored as h5 files.
# file_id values start from 0; None is passed here
# (presumably meaning "no single file selected" - confirm in imlib).
dataset_folder = "./data/T700-T-17/3D&4DData_Group/4Ddemo/h5/"
orig_3D4DData, new_fileName_list = im.load_fromH5PY(
    dataset_folder,
    file_id=None,
    dtype=np.uint8,
)

In [None]:
# If the data are saved as images
# dataset_folder_img = './data/T700-T-17/3D&4DData_Group/4Ddemo/images/'
# orig_3D4DData, ids_3D4Ddata = im.read_3D4Ddata(dataset_folder_img)

## Visualize data
- **set_id:** Specify which set of data you want to see
- **visulize_plane_range:** Specify which plane, and which range within that plane, you want to see
- **overlay:** Set True if you want to see the overlay results

In [None]:
# Data handed to the viewer and the identifier of the set to display.
data_list = [orig_3D4DData]
set_id = new_fileName_list[0]

# Ranges along the slice, height and width directions.
# A one-element range presumably pins that direction to a single index
# (confirm in imlib).
slice_range = [0]
height_range = [0, 200]
width_range = [0, 200]
# Alternative selections for the other two planes:
# slice_range, height_range, width_range = [0, 200], [0], [0, 200]
# slice_range, height_range, width_range = [0, 200], [0, 200], [0]

visulize_plane_range = [slice_range, height_range, width_range]
overlay = False
show_model = None

In [None]:
# Display the selected data set. overlay and show_model are taken from the
# settings cell so that changing them there actually takes effect (they were
# previously hard-coded to False / None in this call).
im.visulize_3D4Ddata(
    data_list,
    set_id,
    visulize_plane_range,
    overlay=overlay,
    show_model=show_model,
)

## Automatically segment fibre

In [None]:
# Settings for the automatic segmentation run.
# output_dir and trainedNet_dir keep the values assigned in the path cell;
# the former self-assignments (x = x) were no-ops and have been dropped.
dataset = orig_3D4DData
net_var = 'UnetID'
dataset_name = 'Default'  # Default is segm_results_3D4D
checkpoint_id = 'last_id'  # or checkpoint_id=200
crop_slice_shape = (64, 64, 1)
constant_value = 0  # This value is mainly used to consider the edge effects
out_threshold = 0.5
save_format = ['H5', 'png']
# Post-processing options, used to remove small artifacts.
postproc_param = {
    "method": "open",
    "kernel": "matlab_kernel(7)",
    "iteration": 1,
    "save_postproc_results": True,
}
load_checkpoint = False
load_trainingmodel = True

In [None]:
# Run the segmentation and print the elapsed time in seconds.
# Every argument is read from the settings cell; checkpoint_id,
# crop_slice_shape and out_threshold were previously hard-coded here, so
# editing them in the settings cell had no effect.
start_t = time.time()
tn.segm_3D4D(
    dataset=dataset,
    net_var=net_var,
    output_dir=output_dir,
    trainedNet_dir=trainedNet_dir,
    dataset_name=dataset_name,
    checkpoint_id=checkpoint_id,
    crop_slice_shape=crop_slice_shape,
    constant_value=constant_value,
    out_threshold=out_threshold,
    save_format=save_format,
    load_checkpoint=load_checkpoint,
    load_trainingmodel=load_trainingmodel,
    **postproc_param,  # post-processing options forwarded as keyword arguments
)
end_t = time.time()
print('cost: ', end_t-start_t)