In [None]:
import os
import random
from glob import glob

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping

from tools.data_loading import get_colormap, read_image, read_mask
from unet import UnetBuilder
from metrics import dice_coef, dice_loss, iou

In [None]:
# HYPERPARAMS
IMG_HEIGHT, IMG_WIDTH = 256, 256
INPUT_SHAPE = (IMG_HEIGHT, IMG_WIDTH, 3)  # 3 input channels per pixel
BATCH_SIZE = 32
N_CLASSES = 16  # 15 labels of abdominal organs + background label
LEARNING_RATE = 1e-4
EPOCHS = 75

# PATHS
DATASET_PATH = "./data"
MODEL_PATH = os.path.join("files", "model.h5")  # best checkpoint written during training
CSV_PATH = os.path.join("files", "data.csv")    # per-epoch training log

# REPRODUCIBILITY: one seed, applied to every RNG the pipeline may draw from
# (Python's `random`, NumPy, TensorFlow), so Restart & Run All gives the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
def create_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist yet.

    ``os.makedirs(..., exist_ok=True)`` replaces a separate existence check. This
    avoids a check-then-create race. It also raises ``FileExistsError`` when ``path``
    exists but is not a directory, instead of silently accepting it.
    """
    os.makedirs(path, exist_ok=True)


# Create whichever folders MODEL_PATH and CSV_PATH point into, so editing either
# path in the config cell cannot leave the checkpoint or log without a directory.
for output_dir in sorted({os.path.dirname(MODEL_PATH), os.path.dirname(CSV_PATH)}):
    if output_dir:  # a bare filename has dirname "" -> nothing to create
        create_dir(output_dir)

In [None]:
def load_dataset(path, split=0.25, log_feedback=False, limit=5000, random_state=42):
    """Collect image/mask file paths and split them into training and validation sets.

    Images are listed from ``<path>/train/img`` and masks from ``<path>/train/msk``.
    Both listings are sorted, and the i-th image is paired with the i-th mask.
    NOTE(review): this assumes both folders use matching file names; confirm.

    Args:
        path: Dataset root directory.
        split: Fraction of pairs reserved for validation.
        log_feedback: If True, print how many pairs landed in each split.
        limit: Keep at most this many pairs after sorting (default 5000). ``None`` keeps all.
        random_state: Seed forwarded to ``train_test_split`` for a reproducible split.

    Returns:
        ``((train_x, train_y), (valid_x, valid_y))``: lists of image and mask paths.

    Raises:
        ValueError: If no images are found, or the image and mask counts differ.
    """
    images = sorted(glob(os.path.join(path, "train", "img", "*")))[:limit]
    masks = sorted(glob(os.path.join(path, "train", "msk", "*")))[:limit]
    if not images:
        raise ValueError(f"No images found under {os.path.join(path, 'train', 'img')}")
    if len(images) != len(masks):
        raise ValueError(
            f"Found {len(images)} images but {len(masks)} masks; each image needs exactly one mask"
        )

    split_size = int(split * len(images))
    # Split images and masks in ONE call so the same permutation is applied to both
    # and every (image, mask) pair stays aligned across train/validation.
    train_x, valid_x, train_y, valid_y = train_test_split(
        images, masks, test_size=split_size, random_state=random_state
    )
    if log_feedback:
        print(f"NUMBER OF PAIRS:\nTraining: {len(train_x)}/{len(train_y)}\nValidation: {len(valid_x)}/{len(valid_y)}")
    return (train_x, train_y), (valid_x, valid_y)


(train_x, train_y), (valid_x, valid_y) = load_dataset(path=DATASET_PATH, log_feedback=True)

# Invariants: every image has a mask in both splits, and neither split is empty.
assert len(train_x) == len(train_y), "training images and masks are misaligned"
assert len(valid_x) == len(valid_y), "validation images and masks are misaligned"
assert train_x and valid_x, "a dataset split is empty"

In [None]:
CLASSES, COLORMAP = get_colormap("./organ_labels.json")


def preprocess(x, y):
    """Turn one (image path, mask path) pair of string tensors into model-ready tensors.

    The files are decoded in Python through ``tf.numpy_function``, using the project's
    ``read_image`` / ``read_mask`` helpers. The static shapes are then declared, so
    downstream layers see ``(IMG_HEIGHT, IMG_WIDTH, 3)`` images and
    ``(IMG_HEIGHT, IMG_WIDTH, N_CLASSES)`` masks.

    NOTE(review): the declared output dtypes (float32 image, uint8 mask) must match
    what read_image / read_mask actually return, or numpy_function fails at runtime.
    """
    def _load_pair(image_path_bytes, mask_path_bytes):
        # Inside numpy_function the string tensors arrive as bytes objects.
        image_path, mask_path = image_path_bytes.decode(), mask_path_bytes.decode()
        return read_image(image_path), read_mask(mask_path, COLORMAP)

    output_dtypes = [tf.float32, tf.uint8]
    image, mask = tf.numpy_function(_load_pair, [x, y], output_dtypes)

    # numpy_function outputs carry no static shape; declare it explicitly.
    for tensor, static_shape in (
        (image, [IMG_HEIGHT, IMG_WIDTH, 3]),
        (mask, [IMG_HEIGHT, IMG_WIDTH, N_CLASSES]),
    ):
        tensor.set_shape(static_shape)
    return image, mask

In [None]:
def tf_dataset(x, y, batch_size=8, shuffle=True, shuffle_buffer=5000):
    """Build a batched ``tf.data`` pipeline of (image, mask) pairs.

    Args:
        x: Sequence of image file paths.
        y: Sequence of mask file paths, aligned element-wise with ``x``.
        batch_size: Number of pairs per batch.
        shuffle: Shuffle the path pairs before decoding (default True, as before).
            Pass False for evaluation pipelines where order is irrelevant.
        shuffle_buffer: Size of the shuffle buffer (default 5000).

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches produced by ``preprocess``.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        # Shuffle the cheap path strings rather than decoded tensors.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    # Decode files in parallel; map still emits elements in input order by default.
    dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size=batch_size)
    # Let tf.data tune how many batches to prepare ahead of the training step.
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

In [None]:
# Input pipelines: one over the training pairs, one over the validation pairs.
train_dataset = tf_dataset(x=train_x, y=train_y, batch_size=BATCH_SIZE)
valid_dataset = tf_dataset(x=valid_x, y=valid_y, batch_size=BATCH_SIZE)

In [None]:
# Build the U-Net for INPUT_SHAPE inputs and N_CLASSES segmentation classes.
model = UnetBuilder.build_unet(n_classes=N_CLASSES, input_shape=INPUT_SHAPE)

In [None]:
# Track Dice coefficient and IoU alongside the loss every epoch.
# NOTE(review): `dice_loss` is imported but unused; categorical cross-entropy is the
# training loss here. Swap it in via `loss=dice_loss` if Dice training was intended.
metrics = [dice_coef, iou]

# `tf` is already imported at the top of the notebook; no second TensorFlow import needed.
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=metrics)

In [None]:
# Training callbacks, all driven by validation loss:
# - ModelCheckpoint saves the model to MODEL_PATH whenever val_loss improves.
# - ReduceLROnPlateau multiplies the LR by 0.1 after 4 epochs without improvement.
# - CSVLogger writes per-epoch metrics. append=False starts a fresh log on every run,
#   because the model is rebuilt from scratch above and training never resumes.
#   Appending would concatenate unrelated runs whose epoch counters restart.
# - EarlyStopping stops after 12 stagnant epochs. With restore_best_weights=False the
#   in-memory `model` keeps the last-epoch weights; the best ones are at MODEL_PATH.
callbacks = [
    ModelCheckpoint(filepath=MODEL_PATH, monitor="val_loss", verbose=1, save_best_only=True),
    ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=4),
    CSVLogger(filename=CSV_PATH, append=False),
    EarlyStopping(monitor="val_loss", patience=12, restore_best_weights=False),
]

In [None]:
# Train, keeping the History object so per-epoch loss/metric curves can be inspected
# or plotted in later cells instead of being shown only as an opaque repr.
history = model.fit(train_dataset, validation_data=valid_dataset, epochs=EPOCHS, callbacks=callbacks)