In [None]:
import os
import tensorflow as tf
from tools.data_loading import load_data, read_image, read_mask
from unet import UnetBuilder
from metrics import dice_coef, dice_loss, iou
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping

In [None]:
# HYPERPARAMS
IMG_HEIGHT, IMG_WIDTH = 256, 256
BATCH_SIZE = 32
N_CLASSES = 16  # 15 labels of abdominal organs + background label
DATASET_PATH = "./jpg_data"
LEARNING_RATE = 1e-4
EPOCHS = 100
SEED = 42

# Seed Python `random`, NumPy and TensorFlow in one call, before the model is
# built, so weight initialisation and any dataset shuffling repeat across
# Restart & Run All. (Bit-exact GPU determinism would additionally need
# tf.config.experimental.enable_op_determinism().)
tf.keras.utils.set_random_seed(SEED)


In [None]:
def tf_parse(x, y):
    """Decode one (image, mask) sample inside a ``tf.data`` pipeline.

    The actual decoding is done eagerly by ``read_image`` / ``read_mask``
    wrapped in ``tf.numpy_function``. Both outputs are declared as float64.

    Args:
        x: Image input for one sample, as passed to ``read_image``.
        y: Mask input for one sample, as passed to ``read_mask``.

    Returns:
        ``(image, mask)`` tensors with static shapes
        ``[IMG_HEIGHT, IMG_WIDTH, 3]`` and ``[IMG_HEIGHT, IMG_WIDTH, 1]``.
    """
    def _load_pair(image_src, mask_src):
        return read_image(image_src), read_mask(mask_src)

    image, mask = tf.numpy_function(
        func=_load_pair,
        inp=[x, y],
        Tout=[tf.float64, tf.float64],
    )
    # numpy_function outputs have unknown static shape; pin them so batching
    # and the model's input layer see fully defined shapes.
    image.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])
    mask.set_shape([IMG_HEIGHT, IMG_WIDTH, 1])
    return image, mask


In [None]:
def tf_dataset(x, y, batch_size=8, shuffle=False, shuffle_buffer=None):
    """Build a batched, prefetched ``tf.data`` pipeline of (image, mask) pairs.

    Args:
        x: Per-sample image inputs, decoded by ``tf_parse``.
        y: Per-sample mask inputs, aligned with ``x``.
        batch_size: Samples per batch (the final batch may be smaller).
        shuffle: If True, shuffle the samples before decoding and reshuffle on
            every pass over the dataset (i.e. every epoch). Intended for the
            training split. Defaults to False so existing callers keep the
            original, order-preserving behaviour.
        shuffle_buffer: Shuffle buffer size. ``None`` uses the full dataset
            size, which gives a uniform shuffle.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        if shuffle_buffer is None:
            # Whole-dataset buffer; clamp to 1 so an empty input does not
            # produce an invalid buffer size.
            buffer_size = tf.maximum(dataset.cardinality(), 1)
        else:
            buffer_size = shuffle_buffer
        # Shuffle before `map` so the buffer holds the raw inputs (presumably
        # file paths) rather than decoded images.
        dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True)
    dataset = dataset.map(tf_parse, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

In [None]:
# load_data yields ((train_x, train_y), (valid_x, valid_y)). The x/y entries are
# later decoded by read_image / read_mask in tf_parse (presumably file paths).
train_split, valid_split = load_data(DATASET_PATH)
train_x, train_y = train_split
valid_x, valid_y = valid_split

In [None]:
# Build the training pipeline first, then the validation one, with the same batch size.
train_dataset, valid_dataset = (
    tf_dataset(features, labels, batch_size=BATCH_SIZE)
    for features, labels in ((train_x, train_y), (valid_x, valid_y))
)

In [None]:
# 3-channel input at the configured resolution (matches the image shape set in tf_parse).
input_shape = (IMG_HEIGHT, IMG_WIDTH, 3)
model = UnetBuilder.build_unet(
    input_shape=input_shape,
    n_classes=N_CLASSES,
)

In [None]:
# Optimise the Dice loss directly; track Dice coefficient and IoU as metrics.
# Uses the `tf` alias imported in the first cell instead of re-importing
# `tensorflow` mid-notebook.
metrics = [dice_coef, iou]

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(loss=dice_loss, optimizer=optimizer, metrics=metrics)

In [None]:
result_folder = "./files"


def prepare_result_paths(root):
    """Create the output folders under ``root`` if missing and return file paths.

    Layout::

        <root>/model.keras          best checkpoint
        <root>/csv_log/csv_log.csv  per-epoch training log

    Safe to call repeatedly. Existing folders are reused, and the returned CSV
    path always points at the ``.csv`` file, never at the ``csv_log`` folder.

    Args:
        root: Output root directory (created together with ``csv_log``).

    Returns:
        ``(csv_path, model_path)`` as strings.
    """
    csv_dir = os.path.join(root, "csv_log")
    # makedirs also creates `root`; exist_ok makes notebook re-runs work.
    os.makedirs(csv_dir, exist_ok=True)
    return os.path.join(csv_dir, "csv_log.csv"), os.path.join(root, "model.keras")


csv_path, model_path = prepare_result_paths(result_folder)

# Training callbacks, invoked by Keras in the order they appear in the list.
# Keep only the best checkpoint at model_path.
checkpoint = ModelCheckpoint(filepath=model_path, verbose=1, save_best_only=True)
# Divide the learning rate by 10 after 4 epochs without val_loss improvement.
lr_on_plateau = ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=4)
# Append per-epoch losses/metrics to the CSV log.
csv_logger = CSVLogger(filename=csv_path)
# Stop after 20 epochs without val_loss improvement; keep the final weights.
early_stopping = EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=False)

callbacks = [checkpoint, lr_on_plateau, csv_logger, early_stopping]

In [None]:
# Keep the returned History so per-epoch loss/metric curves can be inspected
# or plotted in later cells without relying on `_` / Out[N].
history = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
)