In [1]:
import contextlib
import datetime
import os
from collections.abc import Mapping

import h5py
import tensorflow as tf
from tensorflow import keras

In [2]:
def build_3d_unet(input_shape=(128, 128, 32, 4)):
    """Build a 3D convolutional encoder-decoder for voxel-wise binary segmentation.

    Two 2x2x2 max-pool stages downsample the volume, a 256-filter bottleneck
    convolution follows, and two 2x2x2 upsampling stages restore the resolution.
    A final 1x1x1 sigmoid convolution produces one probability per voxel.

    NOTE(review): despite the name, there are no skip connections between the
    encoder and decoder, so this is not a U-Net in the strict sense.

    Args:
        input_shape: Shape of one sample without the batch axis, e.g.
            (dim1, dim2, dim3, channels). Assumes channels-last (the Keras
            default): the first three entries are treated as spatial. A spatial
            size that is not a multiple of 4 is padded up by the "same"-padded
            pooling; the decoder output is cropped back to the input size.

    Returns:
        An uncompiled ``keras.Model`` whose output has the input's spatial size
        and a single channel.
    """
    def conv_relu(filters, tensor):
        # 3x3x3 "same"-padded ReLU convolution: changes channel count, keeps spatial size.
        return keras.layers.Conv3D(filters=filters, kernel_size=(3, 3, 3),
                                   activation='relu', padding="same")(tensor)

    inputs = keras.layers.Input(shape=input_shape)

    # Encoder: extract features, then max-pool to keep the strongest responses.
    encoded = conv_relu(32, inputs)
    encoded = conv_relu(64, encoded)
    encoded = keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding="same")(encoded)
    encoded = conv_relu(128, encoded)
    encoded = keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding="same")(encoded)

    # Bottleneck at 1/4 of the (padded) spatial resolution.
    bottleneck = conv_relu(256, encoded)

    # Decoder: upsample back to the (padded) input resolution.
    decoded = keras.layers.UpSampling3D(size=(2, 2, 2))(bottleneck)
    decoded = conv_relu(128, decoded)
    decoded = keras.layers.UpSampling3D(size=(2, 2, 2))(decoded)
    decoded = conv_relu(64, decoded)
    decoded = conv_relu(32, decoded)

    # Each "same"-padded pool rounds odd sizes up, so after two pools and two 2x
    # upsamplings a spatial size d becomes 4 * ceil(d / 4). Crop that surplus
    # (0-3 voxels per axis) so the output lines up voxel-for-voxel with the input.
    # No layer is added when every size is already a multiple of 4.
    spatial_dims = tuple(input_shape[:3])
    if all(d is not None for d in spatial_dims):
        surplus = [(-d) % 4 for d in spatial_dims]
        if any(surplus):
            decoded = keras.layers.Cropping3D(
                cropping=tuple((0, extra) for extra in surplus))(decoded)

    # 1x1x1 sigmoid head: one foreground probability per voxel.
    outputs = keras.layers.Conv3D(filters=1, kernel_size=(1, 1, 1), activation='sigmoid')(decoded)
    return keras.models.Model(inputs=[inputs], outputs=[outputs])

In [3]:
def load_data_train(h5_file_path, batch_size, train_fraction=0.8):
    """Endlessly yield (X, Y) batches from the leading training split of an HDF5 file.

    The first ``int(num_samples * train_fraction)`` entries along axis 0 of the
    'X' and 'Y' datasets form the training split. The remaining entries are
    presumably the validation split and are never read here. Batches are read in
    order, without shuffling, and the generator restarts from the beginning after
    every pass.

    Args:
        h5_file_path: Path (or anything else ``h5py.File`` accepts) of an HDF5
            file containing 'X' and 'Y' datasets. An already-open ``h5py.File``
            or any mapping with 'X' and 'Y' array-like entries is also accepted.
            It is used as-is and is not closed by this generator.
        batch_size: Number of samples per batch. The last batch of each pass is
            smaller when the split size is not a multiple of ``batch_size``.
        train_fraction: Fraction of samples, taken from the start, used for training.

    Yields:
        ``(X_batch, Y_batch)`` slices of the two datasets.

    Raises:
        ValueError: If the training split is empty. Without this check the
            ``while True`` loop would spin forever without yielding anything.
    """
    if isinstance(h5_file_path, Mapping):
        # Caller owns the handle (open h5py File/Group or plain mapping): don't close it.
        source = contextlib.nullcontext(h5_file_path)
    else:
        source = h5py.File(h5_file_path, 'r')

    with source as hf:
        X = hf['X']
        Y = hf['Y']
        num_samples = X.shape[0]
        train_samples = int(num_samples * train_fraction)
        if train_samples <= 0:
            raise ValueError(
                f"empty training split: {num_samples} samples * {train_fraction} -> {train_samples}")

        while True:
            for start in range(0, train_samples, batch_size):
                # Clamp to the split boundary so the final batch of a pass never
                # pulls in samples from the held-out portion.
                stop = min(start + batch_size, train_samples)
                yield X[start:stop], Y[start:stop]

In [4]:
# Each run logs to its own timestamped directory, e.g. logs/20240101-120000.
log_dir = os.path.join("logs", f"{datetime.datetime.now():%Y%m%d-%H%M%S}")
tensorboard_callback = keras.callbacks.TensorBoard(histogram_freq=1, log_dir=log_dir)

# Every callback below watches the validation loss. ReduceLROnPlateau waits 3
# stagnant epochs before cutting the LR 10x; EarlyStopping waits 4 and then rolls
# the weights back to the best epoch.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=4,
    restore_best_weights=True,
    verbose=1,
)
model_checkpoint = keras.callbacks.ModelCheckpoint(
    'model-unet.best.keras',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=3,
    min_lr=1e-6,
)
def loss_history_callback(logs=None, epoch=None):
    """Print a short loss summary from a Keras-style ``logs`` dict.

    Args:
        logs: Mapping of metric name -> value. ``None`` is treated as empty.
        epoch: Zero-based epoch index. When given, an "Epoch N:" header is
            printed with N = epoch + 1.

    Each known loss key is printed only if present. The original version
    hard-coded the epoch number and raised KeyError/TypeError on missing keys or
    ``logs=None``. NOTE(review): 'output_layer_name_loss' and 'custom_loss_loss'
    look like placeholder names for a multi-output model; a single-output model
    compiled with one loss only reports 'loss'. This function is also not part
    of the ``callbacks`` list; wrap it, e.g. in ``keras.callbacks.LambdaCallback``,
    if it should run during training.
    """
    logs = logs or {}
    if epoch is not None:
        print(f"Epoch {epoch + 1}:")
    for label, key in (
        ("Binary Crossentropy Loss", "output_layer_name_loss"),
        ("Unified Focal Loss", "custom_loss_loss"),
        ("Combined Loss", "loss"),
    ):
        if key in logs:
            print(f"  {label}: {logs[key]}")
# Callbacks presumably handed to model.fit later; order kept as registered.
callbacks = [early_stopping, model_checkpoint, tensorboard_callback, reduce_lr]

In [5]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice score between two tensors, pooled over every element.

    Both inputs are cast to float32 and flattened, so the score covers the whole
    batch at once rather than averaging per sample. ``smooth`` keeps the ratio
    defined (and equal to 1) when both tensors are all zeros.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth * pred)
    denominator = tf.reduce_sum(truth) + tf.reduce_sum(pred) + smooth
    return (2. * overlap + smooth) / denominator

In [6]:
def iou_metric(y_true, y_pred, smooth=1e-6):
    """Soft Jaccard index (intersection over union), pooled over every element.

    Both inputs are cast to float32 and flattened, so the score covers the whole
    batch at once rather than averaging per sample. ``smooth`` keeps the ratio
    defined (and equal to 1) when both tensors are all zeros.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth * pred)
    # |A ∪ B| = |A| + |B| - |A ∩ B| (soft version)
    combined = tf.reduce_sum(truth) + tf.reduce_sum(pred) - overlap
    return (overlap + smooth) / (combined + smooth)

In [7]:
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss, ``1 - Dice``, pooled over every element of the batch.

    Inputs are cast to float32 first, matching ``dice_coefficient`` and
    ``iou_metric``. Without the cast, an integer or bool mask (e.g. labels read
    straight from HDF5) multiplied by a float prediction raises a dtype error in
    TensorFlow.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    # reduce_sum with no axis sums over all elements, same as flattening first.
    intersection = tf.reduce_sum(y_true * y_pred)
    dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)
    return 1 - dice

def combined_loss(y_true, y_pred):
    """Sum of binary cross-entropy and soft Dice loss.

    BCE supplies per-voxel gradients; the Dice term counteracts foreground/
    background imbalance by scoring overlap directly.
    """
    bce = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    overlap_penalty = dice_loss(y_true, y_pred)
    return bce + overlap_penalty