# 3D U-net training and testing

This script is adapted from https://keras.io/examples/vision/oxford_pets_image_segmentation/
We perform image segmentation with a U-net-style network (called "3D U-net" in this project, although its layers are 2D convolutions).
The script contains both the training part and the test part, applied to different datasets.

#### Import packages

In [None]:
### Import packages ############################
import os
import numpy as np
import PIL
from PIL import ImageOps
from PIL import Image
import random
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras import layers
from IPython.display import Image, display
from tensorflow.keras.preprocessing.image import load_img

import tensorflow as tf

#### Functions

In [None]:
class RootObject(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(photos, masks)`` batches as Numpy arrays.

    Photos and masks are paired purely by list position, so
    ``input_img_paths[k]`` must correspond to ``target_img_paths[k]``.

    Args:
        batch_size: number of samples per batch.
        img_size: size tuple passed to ``load_img``; every photo and mask is
            resized to it.
        input_img_paths: paths of the photos.
        target_img_paths: paths of the masks, aligned with ``input_img_paths``.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        # Only full batches are counted: a trailing partial batch is dropped,
        # and a set smaller than ``batch_size`` yields zero batches.
        n_full_batches, _ = divmod(len(self.target_img_paths), self.batch_size)
        return n_full_batches

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair for batch number ``idx``.

        ``x`` is float32 of shape ``(batch_size,) + img_size + (3,)`` (RGB photos);
        ``y`` is uint8 of shape ``(batch_size,) + img_size + (1,)`` (grayscale
        masks, shifted down by one).
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        photo_paths = self.input_img_paths[start:stop]
        mask_paths = self.target_img_paths[start:stop]

        batch_shape = (self.batch_size,) + self.img_size
        x = np.zeros(batch_shape + (3,), dtype="float32")
        for slot, photo_path in enumerate(photo_paths):
            x[slot] = load_img(photo_path, target_size=self.img_size)

        y = np.zeros(batch_shape + (1,), dtype="uint8")
        for slot, mask_path in enumerate(mask_paths):
            mask = load_img(mask_path, target_size=self.img_size, color_mode="grayscale")
            y[slot] = np.expand_dims(mask, 2)
            # Inherited from the Oxford-pets example, whose masks use labels 1, 2, 3.
            # NOTE(review): confirm the root masks really start at 1 -- y is uint8,
            # so any 0-valued pixel wraps around to 255 here.
            y[slot] -= 1

        return x, y

In [None]:
# Model building.
# A U-net-like encoder/decoder adapted from the Keras Oxford-pets example. Unlike
# the original U-net it uses batch normalization, separable convolutions in the
# encoder and additive residual shortcuts inside every stage.

def get_model(img_size, num_classes):
    """Build the (uncompiled) segmentation network.

    Args:
        img_size: spatial size of the RGB input as a tuple; the input shape is
            ``img_size + (3,)``.
        num_classes: number of channels of the final per-pixel layer.

    Returns:
        A ``keras.Model`` whose output has ``num_classes`` sigmoid channels.
        The network halves the spatial size four times (entry block + three
        encoder stages) and doubles it four times, so the output matches the
        input size only when both dimensions are divisible by 16.

    The only shortcuts are the additive residuals inside each stage; there are
    no encoder-to-decoder skip concatenations.
    """

    def down_stage(block_input, filters):
        # (relu -> separable conv -> BN) twice, then a strided max-pool; the
        # shortcut is a strided 1x1 projection of the stage input.
        h = block_input
        for _ in range(2):
            h = layers.Activation("relu")(h)
            h = layers.SeparableConv2D(filters, 3, padding="same")(h)
            h = layers.BatchNormalization()(h)
        h = layers.MaxPooling2D(3, strides=2, padding="same")(h)
        shortcut = layers.Conv2D(filters, 1, strides=2, padding="same")(block_input)
        return layers.add([h, shortcut])

    def up_stage(block_input, filters):
        # (relu -> transposed conv -> BN) twice, then 2x upsampling; the shortcut
        # is the upsampled stage input projected by a 1x1 conv.
        h = block_input
        for _ in range(2):
            h = layers.Activation("relu")(h)
            h = layers.Conv2DTranspose(filters, 3, padding="same")(h)
            h = layers.BatchNormalization()(h)
        h = layers.UpSampling2D(2)(h)
        shortcut = layers.UpSampling2D(2)(block_input)
        shortcut = layers.Conv2D(filters, 1, padding="same")(shortcut)
        return layers.add([h, shortcut])

    inputs = keras.Input(shape=img_size + (3,))

    # Entry block: a strided convolution halves the resolution.
    h = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    h = layers.BatchNormalization()(h)
    h = layers.Activation("relu")(h)

    # Encoder (downsampling) stages, identical apart from their depth.
    for filters in (64, 128, 256):
        h = down_stage(h, filters)

    # Decoder (upsampling) stages.
    for filters in (256, 128, 64, 32):
        h = up_stage(h, filters)

    # Per-pixel classification layer.
    outputs = layers.Conv2D(num_classes, 3, activation="sigmoid", padding="same")(h)
    return keras.Model(inputs, outputs)

In [None]:
def display_mask(i, preds=None):
    """Display the predicted mask of sample ``i`` in the notebook.

    The class map is the per-pixel ``argmax`` over the last (class) axis,
    autocontrasted so that the classes are visible.

    Args:
        i: index of the sample in ``preds``.
        preds: predictions to read from. Defaults to the notebook global
            ``val_preds`` (looked up at call time), as before.
    """
    if preds is None:
        preds = val_preds
    class_map = np.expand_dims(np.argmax(preds[i], axis=-1), axis=-1)
    img = PIL.ImageOps.autocontrast(keras.preprocessing.image.array_to_img(class_map))
    display(img)
    
def EraseFile(repertoire):
    """Delete every file directly inside ``repertoire`` (results of a previous run).

    Sub-directories and their contents are left untouched, since ``os.remove``
    cannot delete them. If the folder does not exist yet, it is created so that
    callers can save into it right afterwards.

    Args:
        repertoire: path of the folder to clear, with or without a trailing slash.
    """
    os.makedirs(repertoire, exist_ok=True)
    for entry in os.listdir(repertoire):
        path = os.path.join(repertoire, entry)
        if os.path.isfile(path) or os.path.islink(path):
            os.remove(path)

def save_mask_pred(save_path, i, preds=None):
    """Save the predicted mask of sample ``i`` as ``pred_mask_<i>.png`` in ``save_path``.

    The class map is the per-pixel ``argmax`` over the last (class) axis,
    autocontrasted so that the classes are visible.

    Args:
        save_path: destination folder (with or without a trailing slash); it
            must already exist.
        i: index of the sample in ``preds``; also used in the file name.
        preds: predictions to read from. Defaults to the notebook global
            ``val_preds`` (looked up at call time), as before.
    """
    if preds is None:
        preds = val_preds
    class_map = np.expand_dims(np.argmax(preds[i], axis=-1), axis=-1)
    img = PIL.ImageOps.autocontrast(keras.preprocessing.image.array_to_img(class_map))
    # os.path.join works whether or not save_path ends with '/'.
    img.save(os.path.join(save_path, f"pred_mask_{i}.png"))

#### Script 

In [None]:
# Dataset selection -- adapt this cell to the dataset being used.
# Keep exactly one (input_dir, target_dir) assignment active.

# Black roots label with white background
input_dir, target_dir = (
    "00.Datasets/modified/blackroots/Photo/",
    "00.Datasets/modified/blackroots/Masque/",
)

# White roots label with black background
# input_dir, target_dir = (
#     "00.Datasets/modified/whiteroots/Photo/",
#     "00.Datasets/modified/whiteroots/Masque/",
# )

# Beige roots label with white background
# input_dir, target_dir = (
#     "00.Datasets/modified/cremeroots/Photo/",
#     "00.Datasets/modified/cremeroots/Masque/",
# )

In [None]:
# First parameters to define.
img_size = (720,) * 2  # network input size; photos and masks are resized to it
num_classes = 2  # number of channels of the output mask
# Images per batch. Only full batches are produced, so it should not exceed the
# number of images in any of the splits (train, val, test).
batch_size = 12

In [None]:
# Dataset import.
# Photos and labels are paired by position after sorting, so both folders must
# contain the same number of PNG files whose names sort in the same order.

def _list_pngs(directory):
    # Sorted paths of the visible PNG files of `directory`. Hidden files (names
    # starting with '.') are skipped in BOTH folders so they cannot shift the pairing.
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(".png") and not fname.startswith(".")
    )

# Photos
input_img_paths = _list_pngs(input_dir)

# Labels
target_img_paths = _list_pngs(target_dir)

if len(input_img_paths) != len(target_img_paths):
    raise ValueError(
        f"{len(input_img_paths)} photos in {input_dir!r} but "
        f"{len(target_img_paths)} labels in {target_dir!r}; "
        "they are paired by position and must match one-to-one."
    )

In [None]:
# Reset Keras' global state so that re-running the model cells does not keep
# piling up old models in memory.
keras.backend.clear_session()

# Build the network and print its layer summary.
model = get_model(img_size=img_size, num_classes=num_classes)
model.summary()

In [None]:
val_samples = 12  # number of validation images

# Shuffle (fixed seed, reproducible) before splitting so that the validation set
# is a random subset of the data. Photos and labels are shuffled as pairs, so
# each photo always stays with its label. random.shuffle's permutation depends
# only on the seed and the list length, so this yields the same order as
# shuffling each list separately with Random(60).
if len(input_img_paths) != len(target_img_paths):
    raise ValueError("photos and labels must be paired one-to-one before shuffling")
pairs = list(zip(input_img_paths, target_img_paths))
random.Random(60).shuffle(pairs)
input_img_paths[:] = [photo for photo, _ in pairs]
target_img_paths[:] = [label for _, label in pairs]

# Split the paths into a training and a validation set. Slicing at an explicit
# index keeps `val_samples = 0` meaning "no validation" (`[:-0]` would instead
# produce an empty training set).
n_train = max(len(input_img_paths) - val_samples, 0)
train_input_img_paths = input_img_paths[:n_train]
train_target_img_paths = target_img_paths[:n_train]
val_input_img_paths = input_img_paths[n_train:]
val_target_img_paths = target_img_paths[n_train:]

# Instantiate data Sequences for each split
train_gen = RootObject(batch_size, img_size, train_input_img_paths, train_target_img_paths)
val_gen = RootObject(batch_size, img_size, val_input_img_paths, val_target_img_paths)


In [None]:
# Different parameters can be used here.
# Compared with the Keras example we changed the optimizer, the loss function
# and the number of epochs.

opti = keras.optimizers.RMSprop(
    learning_rate=0.00001,
    rho=0.9,
    momentum=0.0,
    epsilon=1e-07,
    centered=False,
    name="RMSprop",
)

# Other optimizers that were tried:
# opti2 = tf.keras.optimizers.SGD(learning_rate=0.000001, momentum=0.0, nesterov=False, name='SGD')
# opti3 = tf.keras.optimizers.Adam(learning_rate=0.00001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False, name='Adam')

# NOTE(review): the labels from RootObject have 1 channel while the model outputs
# num_classes (= 2) sigmoid channels; binary_crossentropy broadcasts the label to
# both channels, so both are pushed towards the same target. Confirm this is
# intended -- the Keras example pairs a softmax output with
# sparse_categorical_crossentropy (commented below).
model.compile(
    optimizer=opti,
    loss="binary_crossentropy",
    # loss="sparse_categorical_crossentropy",
    # Lowercase "accuracy" lets Keras pick the accuracy variant that fits the
    # label/prediction shapes. The former 'Accuracy' string is resolved by
    # tf.keras 2.x to keras.metrics.Accuracy, which only counts exact equality
    # between the float predictions and the integer labels.
    metrics=["accuracy"],
)

# Keep the weights of the epoch with the best validation loss.
callbacks = [keras.callbacks.ModelCheckpoint("root_segmentation.h5", save_best_only=True)]

# Train the model, doing validation at the end of each epoch.
epochs = 30

history = model.fit(
    train_gen,
    epochs=epochs,
    validation_data=val_gen,
    callbacks=callbacks,
)

In [None]:
### Loss and accuracy evolution during training
# Training and (when recorded) validation curves side by side. Setting `label=`
# alone shows nothing: the legend has to be drawn explicitly.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))
curves = history.history
for ax, metric, title in ((ax_loss, "loss", "Loss"), (ax_acc, "accuracy", "Accuracy")):
    ax.plot(history.epoch, curves[metric], label=f"Train {metric}")
    val_key = "val_" + metric
    if val_key in curves:  # recorded because fit() received validation_data
        ax.plot(history.epoch, curves[val_key], label=f"Validation {metric}")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.legend()
plt.show()

In [None]:
# Predict masks for every batch of the validation generator.
val_preds = model.predict(val_gen)

# Validation sample to inspect.
i = 11
photo_path, mask_path = val_input_img_paths[i], val_target_img_paths[i]

# 1) input photo (`Image` here is IPython's display Image, imported after PIL's)
display(Image(filename=photo_path))

# 2) ground-truth mask, contrast-stretched so it is visible
img = PIL.ImageOps.autocontrast(load_img(mask_path))
display(img)

# 3) mask predicted by the model
display_mask(i)

In [None]:
# Save the results: the predictions and the matching labels.
# The paths need to be adapted to the dataset in use.

path_pred_black = '03.3D_U-net/blackroots_pred/pred/'
#path_pred_white = '03.3D_U-net/whiteroots_pred/pred/'
#path_pred_creme = '03.3D_U-net/cremeroots_pred/pred/'

path_label_black = '03.3D_U-net/blackroots_pred/label/'
#path_label_white = '03.3D_U-net/whiteroots_pred/label/'
#path_label_creme = '03.3D_U-net/cremeroots_pred/label/'

pred_dir = path_pred_black
label_dir = path_label_black

# Create the folder if needed, then clear the predictions of a previous run.
os.makedirs(pred_dir, exist_ok=True)
EraseFile(pred_dir)

# pred_mask_<i>.png is the prediction for val_target_img_paths[i].
for i in range(len(val_preds)):
    save_mask_pred(pred_dir, i)

# The validation set is a random subset, so its labels are saved too.
os.makedirs(label_dir, exist_ok=True)
EraseFile(label_dir)

for name in val_target_img_paths:
    # `Image` is IPython.display.Image at this point (imported after PIL's Image)
    # and has no `open`, so PIL is referenced explicitly. The context manager
    # closes the source file.
    with PIL.Image.open(name) as img:
        # Keep the original file name, e.g. ".../Masque/DSC_5644.png" becomes
        # "label_DSC_5644.png". (Taking name.split('/')[2] picked a folder name
        # instead, so every label overwrote the same file.)
        path = os.path.join(label_dir, 'label_' + os.path.basename(name))
        img.save(path, 'png')