<a href="https://colab.research.google.com/github/Biribbissolo/Prova/blob/main/Train_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install / upgrade segmentation-models quietly; it provides the Unet class imported further down.
!pip install -U -q segmentation-models
# Pin TensorFlow to a fixed release so the Keras imports used in this notebook resolve consistently.
!pip install tensorflow==2.9.3
# Pin h5py (presumably for the .h5 checkpoints written and read later — confirm the pin is still installable on the current runtime).
!pip install h5py==2.10.0
# NOTE(review): plotly is not imported anywhere in the visible notebook — confirm it is still needed.
!pip install plotly==5.3.1

In [None]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used later in the notebook live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import numpy as np
import os

import nibabel as nib
import tensorflow as tf

import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
from tqdm import tqdm

import skimage
from skimage.io import imread, imshow, imsave
from skimage.transform import resize

from tensorflow import keras
from keras.callbacks import ModelCheckpoint
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from keras import metrics
from keras.callbacks import ReduceLROnPlateau

from segmentation_models import Unet
import math
from math import floor

import random
from random import seed
from random import random

Useful functions

In [None]:
def dispenser(dataset_path, Total_Slices):
    """Decide how many slices to sample from each subject.

    Every entry of ``dataset_path`` is treated as one subject folder holding
    an ``imaging.nii.gz`` volume. The ``Total_Slices`` budget is shared out in
    proportion to each subject's slice count (first axis of the volume shape):
    the more slices a subject has, the more are selected. Allocation happens
    in multiples of a "brick", roughly 5% of the smallest subject's slice
    count.

    Args:
        dataset_path: Folder containing one sub-folder per subject.
        Total_Slices: Desired overall number of slices. The sum of the
            returned map can differ from it because of rounding.

    Returns:
        1-D float array with one entry per subject, in sorted folder-name
        order, giving the number of slices to take from that subject. No
        entry exceeds the number of slices the subject actually has.

    Raises:
        ValueError: If ``dataset_path`` contains no entries, or a volume is
            not three-dimensional.
    """
    directory = sorted(os.listdir(dataset_path))
    if not directory:
        raise ValueError('No subjects found in ' + str(dataset_path))

    All_Slices = np.zeros(len(directory))
    for n, subject in enumerate(directory):
        case_path = os.path.join(dataset_path, subject)
        volume = nib.load(os.path.join(case_path, "imaging.nii.gz"))
        # Only the shape is needed here; the 3-way unpack also rejects non-3D volumes.
        n_slice, _, _ = volume.shape
        All_Slices[n] = n_slice

    # A brick is ~5% of the smallest subject. It must be at least 1, otherwise
    # the division below fails for subjects with fewer than 10 slices.
    brick = max(1, round((min(All_Slices) / 100) * 5))
    num_bricks = round(Total_Slices / brick)
    aux_sum = sum(All_Slices)
    Slice_Map = np.round((All_Slices / aux_sum) * num_bricks) * brick

    # Rounding to whole bricks (or an over-large budget) can ask for more
    # slices than a subject has; cap it so callers can sample without
    # running past the end of the volume.
    return np.minimum(Slice_Map, All_Slices)


def visualizer(segm, IMG_HEIGHT, IMG_WIDTH):
    """Turn a 4-channel segmentation into a colour image for display.

    The input is resized to ``(IMG_HEIGHT, IMG_WIDTH, 4)``; every pixel whose
    value in a channel is exactly 1 is painted with that channel's colour.
    Channels are painted in order 0..3, so a later channel overwrites an
    earlier one where both equal 1. Pixels where no channel equals 1 stay
    black.

    Args:
        segm: Segmentation array with 4 channels on the last axis
            (presumably background, kidney, tumour, cyst — confirm).
        IMG_HEIGHT: Output height in pixels.
        IMG_WIDTH: Output width in pixels.

    Returns:
        Float array of shape ``(IMG_HEIGHT, IMG_WIDTH, 3)`` with values 0/1.
    """
    resized = resize(segm, (IMG_HEIGHT, IMG_WIDTH, 4), mode='constant', preserve_range=True)

    # Colour per channel index: red, green, blue, yellow.
    palette = ((1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0))

    canvas = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3))
    for channel, colour in enumerate(palette):
        canvas[resized[:, :, channel] == 1] = colour

    return canvas


def Make_Dataset(dataset_path, Total_Slices, view_dataset):
    """Assemble image and one-hot mask arrays from randomly sampled slices.

    Each entry of ``dataset_path`` is one subject folder and must contain
    ``imaging.nii.gz`` and ``segmentation.nii.gz``. ``dispenser`` decides how
    many slices each subject contributes; that many distinct slice indices
    are then drawn at random along the first axis of the volume. Slices whose
    in-plane size is not 512x512 are resized to it.

    Args:
        dataset_path: Folder containing one sub-folder per subject.
        Total_Slices: Requested number of slices. The number actually
            produced is the sum of the per-subject slice map.
        view_dataset: Pass ``'on'`` to plot every image next to its
            colour-coded segmentation; any other value skips plotting.

    Returns:
        Tuple ``(X_train, Y_train)``. ``X_train`` is uint8 with shape
        ``(N, 512, 512, 3)`` (the single-channel slice repeated on the last
        axis); ``Y_train`` is float32 with shape ``(N, 512, 512, 4)`` holding
        the one-hot segmentation.
    """
    print('Beginning definition of Dataset useful for training...')
    print('\n')

    directory = os.listdir(dataset_path)
    directory.sort()

    Slice_Map = dispenser(dataset_path, Total_Slices)
    Total_Slices = int(sum(Slice_Map))              # Actual number of slices after rounding

    IMG_HEIGHT = 512
    IMG_WIDTH = 512
    IMG_CHANNELS = 3
    NUM_CLASSES = 4

    X_train = np.zeros([Total_Slices, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS], dtype=np.uint8)
    Y_train = np.zeros([Total_Slices, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES], dtype=np.float32)

    checkpoint = 0     # Next free row in X_train / Y_train

    for n, subject in enumerate(directory):
        case_path = dataset_path + '/' + subject
        volume = nib.load(os.path.join(case_path, "imaging.nii.gz"))
        segmentation = nib.load(os.path.join(case_path, "segmentation.nii.gz"))

        n_slice, height, width = volume.shape  # Subject dimensions

        vol_data = volume.get_fdata().astype(np.int16)
        mask_data = segmentation.get_fdata().astype(np.uint8)

        # Never ask for more slices than the subject has.
        slices_per_subject = min(int(Slice_Map[n]), n_slice)

        # Distinct random slice indices, used directly so that every slice
        # (including the first and the last) can be picked and none twice.
        chosen = np.random.permutation(n_slice)[:slices_per_subject]
        needs_resize = height != IMG_HEIGHT or width != IMG_WIDTH

        for j, slice_idx in enumerate(tqdm(chosen)):
            image = vol_data[slice_idx].astype(np.float64)
            labels = mask_data[slice_idx]

            if needs_resize:
                image = resize(image, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
                # Resize the mask in one-hot form and take the argmax, so that
                # interpolation cannot invent intermediate label values.
                one_hot = to_categorical(labels, num_classes=NUM_CLASSES, dtype='float32')
                one_hot = resize(one_hot, (IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), mode='constant', preserve_range=True)
                labels = np.argmax(one_hot, axis=2).astype(np.uint8)

            Y_train[checkpoint + j] = to_categorical(labels, num_classes=NUM_CLASSES, dtype='float32')
            # The (H, W, 1) result is broadcast over the 3 image channels.
            # NOTE(review): intensities read as int16 are stored as uint8 here, so
            # values outside 0-255 do not survive the cast. Presumably an intensity
            # window/rescale is wanted — confirm before changing, since any
            # already-trained checkpoints presumably depend on this encoding.
            X_train[checkpoint + j] = resize(image, (IMG_HEIGHT, IMG_WIDTH, 1), mode='constant', preserve_range=True)

        checkpoint = checkpoint + slices_per_subject

        print(volume.shape)

    # Drop unused trailing rows if some subject had fewer slices than requested.
    if checkpoint < Total_Slices:
        X_train = X_train[:checkpoint]
        Y_train = Y_train[:checkpoint]

    if view_dataset == 'on':
        for n in range(X_train.shape[0]):
            fig = plt.figure(figsize=(20, 20))
            ax1 = fig.add_subplot(121)
            ax1.imshow(X_train[n, :, :, :])
            ax1.set_title('TAC Image')
            ax2 = fig.add_subplot(122)
            ax2.imshow(visualizer(Y_train[n, :, :, :], IMG_HEIGHT, IMG_WIDTH))
            ax2.set_title('Manual Segmentation')

    print('The definition phase of the Dataset for training has ended.')
    print('\n')

    return X_train, Y_train

In [None]:
dataset_path = '/content/drive/your_dataset_path'

# NOTE(review): with the default below, validation slices are sampled from the
# same subjects as the training slices, so validation metrics are not measured
# on held-out data. Point this at a folder of separate subjects.
val_dataset_path = dataset_path

print('Definition of the Trainin set')
print('\n')

X_train, Y_train = Make_Dataset(dataset_path, 8000, 'off')  # about 8000 training slices are requested

print('Definition of the Validation set')
print('\n')

X_val, Y_val = Make_Dataset(val_dataset_path, 1200, 'off')  # about 1200 validation slices are requested

In [None]:
# Augmentation used for the training data: random rotation, shifts, zoom and
# horizontal flips; pixels exposed by a transform are filled from the nearest edge.
image_datagen = ImageDataGenerator(
    rotation_range=23,
    width_shift_range=0.21,
    height_shift_range=0.21,
    zoom_range=0.28,
    horizontal_flip=True,
    vertical_flip=False,
    fill_mode='nearest',
)

# Validation data goes through a generator with no augmentation configured.
val_datagen = ImageDataGenerator()

# Seed handed to the training generator below.
seed = 1
def XYaugmentGenerator(X1, y, seed, batch_size):
    """Yield (image batch, mask batch) pairs forever.

    Two flows are created from the module-level ``image_datagen``: one over
    the images and one over the masks. Both receive the same ``seed`` and
    ``batch_size``, presumably so that each mask batch gets the same random
    transforms as its image batch.

    Args:
        X1: Image array.
        y: Mask array, aligned with ``X1`` along the first axis.
        seed: Random seed passed to both flows.
        batch_size: Number of samples per yielded batch.

    Yields:
        Tuple ``(augmented images, augmented masks)``.
    """
    shared = dict(batch_size=batch_size, seed=seed)
    image_stream = image_datagen.flow(X1, y, **shared)
    mask_stream = image_datagen.flow(y, X1, **shared)
    while True:
        # Each flow returns (data, labels); only the augmented data is kept.
        image_batch = image_stream.next()[0]
        mask_batch = mask_stream.next()[0]
        yield image_batch, mask_batch

# Run the following block if the training is from scratch: round one.

In [None]:
NUM_CLASSES = 4

BACKBONE = 'efficientnetb5'  # <-------- chosen encoder architecture

# U-Net on 512x512x3 inputs with an ImageNet-initialised, trainable encoder and
# a transposed-convolution decoder; one sigmoid output channel per class.
model = Unet(
    backbone_name=BACKBONE,
    input_shape=(512, 512, 3),
    classes=NUM_CLASSES,
    activation='sigmoid',
    encoder_weights='imagenet',
    encoder_freeze=False,
    decoder_block_type='transpose',
    decoder_filters=(512, 256, 128, 64, 32),
    decoder_use_batchnorm=True,
)

# Optimizer, loss and the metric that the training callbacks monitor as 'val_tp'

m = [metrics.TruePositives(name='tp')]
model.compile(optimizer='Nadam', loss='mse', metrics=m)

# Run the following block if the training is resumed: rounds following the first.

In [None]:
# Reload a saved checkpoint without its compile state; it is compiled again below.
model = load_model(
    '/content/drive/output_training_path/Epoch_15-Val_Loss0.00022.h5',
    custom_objects=None,
    compile=False,
)

# Metric, optimizer and loss for this round

m = [metrics.TruePositives(name='tp')]

# Suggested learning rates for successive rounds: 1e-03 --> 3.16227766e-04 --> 1e-04.
opt = tf.keras.optimizers.Nadam(learning_rate=1e-03)

model.compile(optimizer=opt, loss='mse', metrics=m)

Let's get the training going!

In [None]:
path = '/content/drive/output_training_path'
batch_size = 16
n_epochs = 100     # Upper bound on the number of epochs

# CSVLogger: append the per-epoch metrics to a ';'-separated log file
csv_logger = CSVLogger('./log.out', append=True, separator=';')

# Early stopping: stop when 'val_tp' has not risen by at least 0.01 for 65 epochs
earlystopping = EarlyStopping(monitor='val_tp', verbose=1, min_delta=0.01, patience=65, mode='max')

# Learning-rate schedule: multiply the rate by ~10**-0.5 after 20 epochs without a 'val_tp' gain
reduce_LR = ReduceLROnPlateau(monitor='val_tp', factor=0.316227766, patience=20, verbose=1, mode='max', min_delta=0.0001, cooldown=0, min_lr=0)

# Checkpointing: save the model whenever 'val_tp' reaches a new maximum.
# The file name records the epoch and the validation loss at that epoch.
checkpoint = ModelCheckpoint(path + '/Epoch_{epoch:02d}-Val_Loss{val_loss:.5f}.h5', monitor='val_tp', mode='max', save_best_only=True, verbose=1)

callbacks_list = [csv_logger, reduce_LR, earlystopping, checkpoint]

# Whole number of batches needed to cover each set once (last batch may be partial)
steps_per_epoch = int(np.ceil(len(X_train) / batch_size))
validation_steps = int(np.ceil(len(X_val) / batch_size))

# Train model. Model.fit accepts generators directly (fit_generator is deprecated);
# sample order is decided by the generators themselves.
results = model.fit(XYaugmentGenerator(X_train, Y_train, seed, batch_size),
                    steps_per_epoch=steps_per_epoch,
                    validation_data=val_datagen.flow(X_val, Y_val, batch_size),
                    validation_steps=validation_steps,
                    epochs=n_epochs,
                    callbacks=callbacks_list)

This round is done!