# Train the Unet_ID to segment each fibre
- Author: Rui Guo (KU Leuven), rui.guo1@kuleuven.be
- Date: July 08 2022

## Import packages

In [None]:
import imageio
import numpy as np
import imlib as im
import pylib as py
import torchsegnet as tn
import matplotlib.pyplot as plt
from PIL import Image

## Specify the output folder where results will be saved

In [None]:
# Destination folder handed to tn.train_net; presumably where checkpoints
# and the trained model end up.
output_dir = "./output/T700-T-17/"

## Load the images
Load the training image from the data folder.  
Note: only one image is required for training.

In [None]:
# Adjust these two variables to point at your own training image.
dataset_name = "crop_h500(500_1000)_w250(1150_1400)_T17_10N_slice_00000.png"
dataset_folder = "./data/T700-T-17/Training_Data/grayscale_slice/"

In [None]:
# Read the training slice and print its shape, value range and dtype as a
# quick sanity check (one value per line).
dataset_file = py.join(dataset_folder, dataset_name)
origData = np.array(imageio.imread(dataset_file))
print(origData.shape, origData.max(), origData.min(), origData.dtype, sep='\n')

# Preview (at most) the top-left 512 x 256 region of the slice.
plt.figure()
plt.imshow(origData[:512, :256], cmap='gray')
plt.show()

In [None]:
# Optional (disabled): save the top-left 500 x 256 crop of the loaded slice as
# a .tif file named crop_slice_00000.tif in dataset_folder. Uncomment to use.
# NOTE(review): if re-enabled, the name `im` would shadow the `imlib as im`
# import at the top of the notebook; consider renaming it.
# im = Image.fromarray(origData[0:500,0:256])
# im.save(dataset_folder + '/crop_slice_'+str(0).zfill(5)+'.tif')

## Create or load masks

In [None]:
# Location of the mask image that pairs with the training slice.
label_name = "Masks_crop_h500(500_1000)_w250(1150_1400)_T17_10N_slice_00000.png"
label_folder = "./data/T700-T-17/Training_Data/mask_slice/"

In [None]:
# Pass folder and file name as separate arguments, the same way dataset_file
# is built. Assuming py.join behaves like os.path.join, this keeps the path
# correct even if label_folder lacks a trailing slash (the previous
# `label_folder + label_name` relied on that slash being present).
label_file = py.join(label_folder, label_name)

### (A) Manual annotation
If you don't have labels (masks) for your dataset yet, you need to annotate it first.

In [None]:
# Uncomment to annotate origData manually with pylib's annotation helper.
# NOTE(review): presumably this writes the resulting mask to label_file —
# confirm against pylib before relying on it.
# py.annotate(origData, label_file)

### (B) Load the masks
If you already have labels for your dataset, there is no need to annotate again; just load them.

In [None]:
# Load the mask. If it was saved with colour channels (ndim > 2), keep only
# the first channel along the last axis.
labelInnerFibre = np.array(imageio.imread(label_file))
if labelInnerFibre.ndim > 2:
    labelInnerFibre = labelInnerFibre[:, :, 0]
print(labelInnerFibre.shape)

# py.sample presumably pairs image and mask patches by position, so a size
# mismatch would silently misalign them; fail early with a clear message.
if labelInnerFibre.shape != origData.shape[:2]:
    raise ValueError(
        f"Mask shape {labelInnerFibre.shape} does not match image shape "
        f"{origData.shape[:2]}"
    )

plt.figure()
plt.imshow(labelInnerFibre, cmap='gray')
plt.show()

## Sample patches  
- **1. Choose the image size**  
Sets the size of each sampled patch.
- **2. Choose the stride_step for training and validation**   
Sets the step between consecutive patches. If the stride equals the patch size, the patches do not overlap; a smaller stride (e.g. 8 with 64×64 patches) produces overlapping patches and therefore more samples.

In [None]:
# Patch size for sampling and the step between neighbouring patches.
# A stride smaller than the patch size yields overlapping patches.
train_stride_step = 8
image_size = (64,) * 2

In [None]:
# Cut the image and its mask into training patches of size image_size,
# stepping by train_stride_step; show_img=True presumably asks py.sample to
# display the sampled data.
Data = py.sample(
    data=[origData, labelInnerFibre],
    data_shape=image_size,
    stride_step=train_stride_step,
    shrink_size=0,
    show_img=True,
)

## Train the model

In [None]:
# Training hyper-parameters, passed positionally to tn.train_net.
val_percent        = 0.2     # fraction of samples presumably held out for validation (not a probability; train/val split, not train/test)
epochs             = 200
batch_size         = 16
learning_rate      = 0.001
net_var            = 'UnetID'  # network variant selected inside tn.train_net
save_checkpoint    = True
save_trainingmodel = True

In [None]:
# Launch training. The positional order must match tn.train_net's signature;
# amp presumably toggles automatic mixed precision (disabled here).
tn.train_net(
    Data, image_size, output_dir,
    val_percent, epochs, batch_size,
    learning_rate, net_var,
    save_checkpoint, save_trainingmodel,
    amp=False,
)