# Train the Unet_ID to segment each fibre
- Author: Rui Guo (KU Leuven), rui.guo1@kuleuven.be
- Date: Jan 2023

## Import packages

In [None]:
import os

import fibresegt as fs
import numpy as np
import matplotlib.pyplot as plt

## Specify the output folder you want to save

In [None]:
# Folder in which the training run stores its results.
output_dir = './output/demo/'

## Load the images
Get the images from the data folder.  
Note: only one image is required for training.

In [None]:
# Edit these two variables to point at your own grayscale training slice.
dataset_name = 'slice_00000_H300-250_550_W150-250_400.tif'
dataset_folder = './data/demo/training_data/grayscale_slice/'

In [None]:
# Read the grayscale slice into a NumPy array and print a summary of it.
dataset_file = fs.join(dataset_folder, dataset_name)
raw_slice = fs.imread(dataset_file)
origData = np.array(raw_slice)
fs.data_info(origData)

# Show the slice for a quick visual check.
plt.figure()
plt.imshow(origData, cmap='gray')
plt.show()

## Create or load masks

In [None]:
label_folder = './data/demo/training_data/mask_slice/'
# The mask file name is derived from the image name: 'Masks_<image stem>.png'.
# os.path.splitext strips whatever extension the image has ('.tif', '.tiff', ...).
# Slicing off a fixed 4 characters only works for 3-letter extensions.
label_name = 'Masks_' + os.path.splitext(dataset_name)[0] + '.png'

In [None]:
# Build the mask path the same way as the image path: pass folder and file name
# as separate arguments rather than concatenating strings first. The result is
# then correct whether or not label_folder ends with a separator (assuming
# fs.join behaves like os.path.join).
label_file = fs.join(label_folder, label_name)

In [None]:
# Record which image/mask pair is used; this is later handed to train_net
# through preprocess_info.
data_info = {
    'dataset_file': dataset_file,
    'label_file': label_file,
}

### (A) Manual annotation
If you don't have labels for your dataset yet, annotate it first (uncomment the line below).

In [None]:
# Uncomment the line below to annotate origData by hand. Judging by the arguments,
# the resulting mask is presumably saved to label_file (TODO confirm in fs.annotate).
# fs.annotate(origData, label_file)

### (B) Load the masks
If you already have labels for your dataset, there is no need to annotate again; just load them.

In [None]:
# Load the existing mask for this slice.
labelInnerFibre = np.array(fs.imread(label_file))
# If the mask was stored with a channel axis (e.g. saved as a colour image),
# keep only its first channel.
labelInnerFibre = labelInnerFibre[:, :, 0] if labelInnerFibre.ndim > 2 else labelInnerFibre
fs.data_info(labelInnerFibre)

# Show the mask for a visual sanity check.
plt.figure()
plt.imshow(labelInnerFibre, cmap='gray')
plt.show()

## Sample  
- **1. Choose the image size**  
This is used to set the image size for samples
- **2. Choose the stride_step for training**  
This sets the step between consecutive sampling windows. If the step equals the image size, the samples will not overlap.
- **3. Choose the itera_enlarge_size**  
This sets how much the fibre edge is enlarged.

In [None]:
# Sampling parameters (see the notes above).
image_size = (64, 64)        # size of each sampled crop
train_stride_step = 8        # step between neighbouring crops
itera_enlarge_size = 2       # how much the fibre edge is enlarged
fig_path = fs.join(label_folder, 'crop_data')
save_fig = False             # presumably True also saves the crops under fig_path

# Keep a record of the sampling settings for the training run.
sample_info = {
    'image_size': image_size,
    'train_stride_step': train_stride_step,
    'itera_enlarge_size': itera_enlarge_size,
}

In [None]:
# Cut the image/mask pair into training samples of size image_size.
Data = fs.generate_training_samples(
    data=[origData, labelInnerFibre],
    data_shape=image_size,
    stride_step=train_stride_step,
    itera_enlarge_size=itera_enlarge_size,
    fig_path=fig_path,
    save_fig=save_fig,
    show_img=True,
)

## Train the model

In [None]:
# Training hyper-parameters.
# `Data` and `image_size` come from the sampling cells above and are deliberately
# not reassigned here. train_net must receive the same image_size that was used
# as data_shape when sampling, and a second hard-coded copy could silently drift
# out of sync.
val_percent        = 0.0  # share of the samples held out for validation (presumably; 0.0 -> no validation split)
epochs             = 200
batch_size         = 16
learning_rate      = 0.001
net_var            = 'UnetID'
# Augmentation settings passed to train_net. Alternative: data_aug = None
# (presumably trains without augmentation).
data_aug           = {'brightness':0.3, 'contrast':0.3,
                      'GaussianBlur_kernel':5, 'GaussianBlur_sigma': (0.7, 1.3)}
preprocess_info    = dict(data_info=data_info,
                          sample_info=sample_info)

In [None]:
# Launch training; results presumably end up under output_dir.
fs.apis.train_net(
    Data=Data,
    image_size=image_size,
    output_dir=output_dir,
    val_percent=val_percent,
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    net_var=net_var,
    data_aug=data_aug,
    preprocess_info=preprocess_info,
)