# Segment each fibre in a 2D slice with a trained UnetID
- Author: Rui Guo (KU Leuven), rui.guo1@kuleuven.be
- Date: July 11 2022

## Import packages

In [None]:
import imlib as im
import pylib as py
import torchlib as tl
import torchsegnet as tn
import os
import imageio
import numpy as np
import time
import matplotlib.pyplot as plt
import cv2 as cv

## Paths for output and pretrained model

In [None]:
# Directory that holds the trained net model
trainedNet_dir = './output/T700-T-17/model/visionary-disco-21/'
# Directory for the output results
output_dir = './output/T700-T-17/2DData/'

## Load the images

In [None]:
# File name of the image to load, and the folder it is read from
dataset_name = '10N_slice_00009.tif'
dataset_folder = './data/T700-T-17/2DData/'

In [None]:
# Join folder and file name as separate components so the resulting path does
# not depend on dataset_folder ending with a path separator.
dataset_file = os.path.join(dataset_folder, dataset_name)
# The loaded image is wrapped in a list, so origData has a leading axis of
# length 1 and the image itself is origData[0].
origData = np.array([imageio.imread(dataset_file)])

In [None]:
# Open a figure, print the array shape followed by its max, min, mean and
# standard deviation, then show a crop of the first image in grayscale.
plt.figure()
print(origData.shape)
for statistic in (np.max, np.min, np.mean, np.std):
    print(statistic(origData))

plt.imshow(origData[0, :512, :256], cmap='gray')
plt.show()

## Automatically segment fibre

In [None]:
# Settings for the tn.segm_2D call, collected in a single cell.
dataset = origData[0]
net_var = 'UnetID'
# The next three names are re-bound to themselves (no-ops) so that they are
# listed here alongside the other settings.
output_dir = output_dir
trainedNet_dir = trainedNet_dir
dataset_name = dataset_name  # Default is segm_results_2D
checkpoint_id = 'last_id'  # or checkpoint_id=200
crop_input_shape = (64, 64, 1)
constant_value = 0  # This value is mainly used to consider the edge effects
out_threshold = 0.5
# NOTE(review): pro_process is not referenced anywhere else in the visible
# notebook -- possibly meant "post_process"; confirm.
pro_process = True
# Remove small artifacts
postproc_param = {
    "method": "open",
    "kernel": "matlab_kernel(7)",
    "iteration": 1,
    "save_postproc_results": True,
}
load_checkpoint = False
load_trainingmodel = True

In [None]:
# Run the 2D segmentation and report the elapsed time in seconds.
# checkpoint_id, crop_input_shape and out_threshold are taken from the
# variables defined in the parameter cell instead of being repeated as
# literals, so editing that cell actually changes this call.
start_t = time.perf_counter()  # monotonic clock, intended for measuring intervals
segm_img = tn.segm_2D(dataset=dataset,
                      net_var=net_var,
                      output_dir=output_dir,
                      trainedNet_dir=trainedNet_dir,
                      dataset_name=dataset_name,
                      checkpoint_id=checkpoint_id,
                      crop_input_shape=crop_input_shape,
                      constant_value=constant_value,
                      out_threshold=out_threshold,
                      load_checkpoint=load_checkpoint,
                      load_trainingmodel=load_trainingmodel,
                      **postproc_param)
end_t = time.perf_counter()
print('cost: ', end_t-start_t)

## Visualize the results
`whole_visulize_area` spans the whole image.

In [None]:
# Shape of the first image; printed for reference.
img_size = origData[0].shape
print(img_size)
# One [start, stop] pair per axis for axes 0 and 1, each covering the full extent.
whole_visulize_area = [[0, img_size[axis]] for axis in (0, 1)]

In [None]:
# Area to show; edit the bounds as needed.
# NOTE(review): presumably [[start, stop] on axis 0, [start, stop] on axis 1],
# mirroring whole_visulize_area -- confirm against im.overlay.
crop_visulize_area = [
    [600, 800],
    [120, 350],
]

In [None]:
# Call im.overlay with the original image, the segmentation result, the
# cropped display area and the selected show_model.
show_model = 'inner fibre'
overlay_img = im.overlay(
    origData[0],
    segm_img,
    visulize_area=crop_visulize_area,
    show_model=show_model,
)