# iPython Notebook for Training a New Model 

## This iPython notebook details code for training. 

Note: Training these models requires a GPU; it cannot be done on a CPU. In addition, it is highly recommended that you create a standalone Python script using this notebook as a guide. Training inside a Jupyter notebook is not recommended.

## Please make sure the following files are in the directory from which you are training:
1. training.py 
2. unet3d.py
3. cunet3d.py
4. unetplusplus3d.py
5. testing3D.py

## Please label your data as follows: 

### Training Data 
1. Training Data - scar_imgs_train.npy
2. Training Segmentations - scar_segs_train.npy

### Validation Data 
1. Validation Data - scar_imgs_val.npy
2. Validation Segmentations - scar_segs_val.npy

For data normalization, it is recommended that you use the mean intensity of your training data. Please save this 
number as: scar_shift.npy


#### Training and Validation Data Dimensions: 

Please ensure your training data is of the format: (N_studies,256,256,16). Unfortunately, the only compatible size is 256x256x16. This will be corrected in future releases. 

Please ensure your segmentation data has the shape (N_studies,256,256,16,3). The last axis holds the class channels: 0 - background voxels, 1 - myocardial voxels, 2 - scar voxels.

In [1]:
# Required Libraries
import os

import matplotlib.pyplot as plt
import numpy as np

# Keras Libraries
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator


# My libraries
from unet3d import get_unet3D_multi
from cunet3d import get_Cunet3D_mulit
from unetplusplus3d import get_unetpp3D_multi

from training import training_main
from testing3D import testing_parallel

2023-12-08 08:27:28.468896: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [2]:
# DEFINE PATH WHERE DATA IS HERE:
dpath = ''  # TODO: input the directory that holds the .npy files (empty string = current directory)

print('*'*100)
print('LOADING DATA')


def load_npy(*names):
    """Load the first file in `names` that exists inside `dpath`.

    The first name is the one documented at the top of this notebook; any
    further names are legacy file names accepted for backward compatibility.

    Raises:
        FileNotFoundError: if none of the candidate files exist.
    """
    for name in names:
        candidate = os.path.join(dpath, name)  # works with or without a trailing '/'
        if os.path.exists(candidate):
            return np.load(candidate)
    raise FileNotFoundError('None of ' + ', '.join(names) + ' found in ' + repr(dpath))


# Intensity shift used for normalization (documented name first, legacy name second).
shift = load_npy('scar_shift.npy', 'scar_mean.npy')

imgs_train = load_npy('scar_imgs_train.npy')
imgs_val = load_npy('scar_imgs_val.npy', 'scar_imgs_test.npy')
segs_train = load_npy('scar_segs_train.npy')
segs_val = load_npy('scar_segs_val.npy', 'scar_segs_test.npy')

# Every image volume needs a matching segmentation volume.
assert imgs_train.shape[0] == segs_train.shape[0], 'training images and segmentations differ in number of studies'
assert imgs_val.shape[0] == segs_val.shape[0], 'validation images and segmentations differ in number of studies'

img_rows = imgs_train.shape[1]
img_cols = imgs_train.shape[2]
slices = imgs_train.shape[3]
print('*'*100)



**************************************************


In [None]:
# DEFINE PATH WHERE YOU WOULD LIKE TO SAVE LEARNING CURVES AND WEIGHTS

path1 = './results/'  # TODO: input path for saving here

# Downstream cells strip the final character of path1 to build per-round
# directories and append file names directly to it, so guarantee a trailing
# separator and make sure the directory exists.
path1 = os.path.join(path1, '')
os.makedirs(path1, exist_ok=True)

In [None]:
# HYPERPARAMETER SELECTION

# Model Parameters
classes = 3  # NUMBER OF OUTPUT CHANNELS: 0 - BACKGROUND, 1 - MYOCARDIUM, 2 - SCAR
layers = 4  # INPUT NUMBER OF LAYERS. AT THIS TIME, THIS ONLY APPLIES TO THE U-NET MODEL
min_convs = 4  # THIS SETS THE MINIMUM NUMBER OF CONVOLUTIONS IN THE BOTTOM LAYER
kernel = (3,3,3)  # DO NOT CHANGE

'''
For loss functions, you have the following options:
1. wcce_kld (Weighted Adaptive Categorical Entropy with KL Divergence) - THIS WAS USED IN THE MANUSCRIPT
2. dice_loss (Dice Coefficient Loss)
3. dice_gen_loss (Generalized Dice Coefficient Loss)
4. weighted_categorical_crossentropy (Weighted Adaptive Categorical Cross Entropy)
5. categorical_crossentropy (Adaptive Categorical Cross Entropy. DEFAULT if none of the other options are input)
'''
lossfun = 'wcce_kld'

# Hyperparameters
learning_rate = 1e-5  # LEARNING RATE
decay = 0.6  # EXPONENTIAL DECAY RATE OF THE CLASS WEIGHTS BETWEEN TRAINING ROUNDS (ADAPTIVE CROSS ENTROPY)
batches = 4  # BATCH SIZE
epochs_total = 30  # TOTAL NUMBER OF EPOCHS
epochs_batch = 5  # EPOCHS PER TRAINING ROUND (NOT BATCHES PER EPOCH); SHOULD DIVIDE epochs_total EVENLY
weights_init = (1,14,48)  # INITIAL CLASS WEIGHTS (BACKGROUND, MYOCARDIUM, SCAR) IF USING WEIGHTED CROSS ENTROPY

# Data-parallel training across the visible GPUs.
strategy = tf.distribute.MirroredStrategy()

# Per-epoch learning curves; filled in round by round by the training cell.
train_loss = np.zeros((epochs_total,))
train_dice = np.zeros((epochs_total,))
train_myo = np.zeros((epochs_total,))

val_loss = np.zeros((epochs_total,))
val_dice = np.zeros((epochs_total,))
val_myo = np.zeros((epochs_total,))



In [None]:
# USER MODEL TYPE SELECTION AND MODEL DEFINITION

# The selection is: U-Net, Cascaded U-Net, U-Net++. Please input the name of the model of choice here.
model_type = "U-Net"  # TODO: choose "U-Net", "Cascaded U-Net" or "U-Net++"

# Map each user-facing model name to its constructor; all share the same signature.
model_builders = {
    "U-Net": get_unet3D_multi,
    "Cascaded U-Net": get_Cunet3D_mulit,
    "U-Net++": get_unetpp3D_multi,
}

if model_type not in model_builders:
    # Fail here rather than letting later cells hit an undefined `model`.
    raise ValueError('Please enter one of the appropriate options: U-Net, Cascaded U-Net, or U-Net++')

# Build under the distribution strategy so variables are mirrored across devices.
with strategy.scope():
    model = model_builders[model_type](img_rows, img_cols, slices, classes, layers, min_convs, kernel, learning_rate, lossfun, weights_init)
model.summary()


In [None]:
# TRAINING
#
# Training runs in rounds of `epochs_batch` epochs. Round 1 trains the model
# built in the previous cell with `weights_init`. Every later round rebuilds
# the model with class weights decayed toward 1, loads the final_weights.h5
# written into the preceding round's directory, and trains again. Per-epoch
# metrics from all rounds are saved to `path1` at the end.

model_builders = {
    "U-Net": get_unet3D_multi,
    "Cascaded U-Net": get_Cunet3D_mulit,
    "U-Net++": get_unetpp3D_multi,
}
if model_type not in model_builders:
    raise ValueError('Please enter one of the appropriate options: U-Net, Cascaded U-Net, or U-Net++')


def _round_path(round_number):
    """Directory for one training round: path1 without its trailing '/' plus '_batchN/'."""
    return path1.rstrip('/') + "_batch" + str(round_number) + "/"


def _record_history(cache, start, stop):
    """Copy one round's Keras history into epochs [start, stop) of the metric arrays."""
    train_loss[start:stop] = np.array(cache.history['loss'])
    train_dice[start:stop] = np.array(cache.history['dice_coef'])
    train_myo[start:stop] = np.array(cache.history['myo_dice'])

    val_loss[start:stop] = np.array(cache.history['val_loss'])
    val_dice[start:stop] = np.array(cache.history['val_dice_coef'])
    val_myo[start:stop] = np.array(cache.history['val_myo_dice'])


# Floor division: rounding up could index past the end of the metric arrays.
num_epoch_batches = epochs_total // epochs_batch

# Round 1
savepath1 = _round_path(1)
cache = training_main(imgs_train, segs_train, model, savepath1, batches, epochs_batch, imgs_val, segs_val)
_record_history(cache, 0, epochs_batch)
pathii = savepath1  # directory of the most recently completed round

for ii in range(num_epoch_batches - 1):
    # Decay the myocardium/scar weights exponentially toward 1; background weight is kept.
    # NOTE(review): ii starts at 0, so round 2 uses exp(0), i.e. the undecayed weights — confirm intent.
    weights1 = (weights_init[1]-1)*np.exp(-ii*decay) + 1
    weights2 = (weights_init[2]-1)*np.exp(-ii*decay) + 1
    weights = (weights_init[0], weights1, weights2)

    # Rebuild so the loss is compiled with the decayed weights (previously weights_init was passed here).
    with strategy.scope():
        model = model_builders[model_type](img_rows, img_cols, slices, classes, layers, min_convs, kernel, learning_rate, lossfun, weights)
    model.summary()

    # Continue from the previous round, not always from round 1.
    old_model_weight_path = pathii + 'final_weights.h5'
    model.load_weights(old_model_weight_path)

    ind1 = epochs_batch*(ii+1)
    ind2 = epochs_batch*(ii+2)

    print('-'*100)
    print('Training Epochs'+str(ind1)+"-"+str(ind2))
    print('-'*100)
    pathii = _round_path(ii+2)
    cache = training_main(imgs_train, segs_train, model, pathii, batches, epochs_batch, imgs_val, segs_val)
    _record_history(cache, ind1, ind2)

# Save the learning curves (os.path.join tolerates a missing trailing '/').
for curve_name, curve in (('train_loss', train_loss), ('train_dice', train_dice), ('train_myo', train_myo),
                          ('val_loss', val_loss), ('val_dice', val_dice), ('val_myo', val_myo)):
    np.save(os.path.join(path1, curve_name + '.npy'), curve)

In [None]:
# SAVING THE RESULTS OF THE VALIDATION DATA SET

# Directory of the last training round. `pathii` is only assigned inside the
# training loop when more than one round runs; with a single round the
# checkpoints live in savepath1.
final_round_path = pathii if num_epoch_batches > 1 else savepath1
savepath = final_round_path + 'validation/'

# Create the path (including any missing parents; no error if it already exists)
os.makedirs(savepath, exist_ok=True)

gpus = 1

# Testing model on validation set only
preds, contours = testing_parallel(savepath, model, imgs_val, segs_val, gpus, classes, shift)

In [None]:
# PLOTTING REPRESENTATIVE SEGMENTATION MASKS

m = 0  # TODO: index of the study to plot (0-based, so use m = 0 if you only have 1 study)

# One panel per slice of study m on a 4-column grid.
# NOTE(review): assumes preds is indexed (study, row, col, slice, ...) like the
# input volumes; confirm against testing_parallel.
n_cols = 4
n_rows = int(np.ceil(slices / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows), squeeze=False)
for ii, ax in enumerate(axes.flat):
    ax.axis('off')
    if ii < slices:
        ax.imshow(preds[m, :, :, ii])
        ax.set_title('Slice ' + str(ii))
plt.show()
    

In [None]:
# PLOTTING REPRESENTATIVE SEGMENTATION MASKS (contours returned by testing_parallel)

m = 0  # TODO: index of the study to plot (0-based, so use m = 0 if you only have 1 study)

# One panel per slice of study m on a 4-column grid.
# NOTE(review): assumes contours is indexed (study, row, col, slice, channel) and
# that the channel axis has 3 or 4 entries so imshow renders it as colour; confirm.
n_cols = 4
n_rows = int(np.ceil(slices / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows), squeeze=False)
for ii, ax in enumerate(axes.flat):
    ax.axis('off')
    if ii < slices:
        ax.imshow(contours[m, :, :, ii])
        ax.set_title('Slice ' + str(ii))
plt.show()
    