## Training
### Import library

In [1]:
import os
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.optimizers import Adam, SGD
from sklearn.model_selection import train_test_split
from unetr_2d import build_unetr_2d
from metrics import dice_loss, dice_coef
from patchify import patchify
import matplotlib.pyplot as plt

In [2]:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

Num GPUs Available:  1


In [3]:
np.random.seed(42)
tf.random.set_seed(42)

### Configuration

In [4]:
# UNETR  Configration
cf = {}
cf["image_size"] = 256
cf["num_classes"] = 2
cf["num_channels"] = 3
cf["num_layers"] = 12
cf["hidden_dim"] = 128
cf["mlp_dim"] = 32
cf["num_heads"] = 6
cf["dropout_rate"] = 0.1
cf["patch_size"] = 16
cf["num_patches"] = (cf["image_size"]**2) // (cf["patch_size"]**2)
cf["flat_patches_shape"] = (
    cf["num_patches"],
    cf["patch_size"] * cf["patch_size"] * cf["num_channels"]
)

### Input pipeline

In [5]:
def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def load_dataset(path, split=0.1):
    # Loading the images and masks
    X = sorted(glob(os.path.join(path, "images", "*.png")))
    Y = sorted(glob(os.path.join(path, "masks", "*.png")))

    # Spliting the data into training and testing
    split_size = int(len(X) * split)

    train_x, valid_x = train_test_split(X, test_size=split_size, random_state=42)
    train_y, valid_y = train_test_split(Y, test_size=split_size, random_state=42)

    train_x, test_x = train_test_split(train_x, test_size=split_size, random_state=42)
    train_y, test_y = train_test_split(train_y, test_size=split_size, random_state=42)

    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

def read_image(path):
    path = path.decode()
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (cf["image_size"], cf["image_size"]))
    image = image / 255.0

    # Processing to patches
    patch_shape = (cf["patch_size"], cf["patch_size"], cf["num_channels"])
    patches = patchify(image, patch_shape, cf["patch_size"])
    patches = np.reshape(patches, cf["flat_patches_shape"])
    patches = patches.astype(np.float32)

    return patches

def read_mask(path):
    path = path.decode()
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    mask = cv2.resize(mask, (cf["image_size"], cf["image_size"]))
    mask = mask / 255.0
    mask = mask.astype(np.float32)
    mask = np.expand_dims(mask, axis=-1)
    return mask

def tf_parse(x, y):
    def _parse(x, y):
        x = read_image(x)
        y = read_mask(y)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float32, tf.float32])
    x.set_shape(cf["flat_patches_shape"])
    y.set_shape([cf["image_size"], cf["image_size"], 1])
    return x, y

def tf_dataset(X, Y, batch=2):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    ds = ds.map(tf_parse).batch(batch).prefetch(10)
    return ds

In [8]:
# Directory for storing files
create_dir("files")

# Hyperparameters
batch_size = 8
lr = 0.1
num_epochs = 1
model_path = os.path.join("files", "model.h5")
csv_path = os.path.join("files", "log.csv")

# Dataset
dataset_path = "../data/MSD"
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_dataset(dataset_path)

print(f"Train: \t{len(train_x)} - {len(train_y)}")
print(f"Valid: \t{len(valid_x)} - {len(valid_y)}")
print(f"Test: \t{len(test_x)} - {len(test_y)}")

train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)

Train: 	2452 - 2452
Valid: 	306 - 306
Test: 	306 - 306


### Build Model

In [12]:
smooth = 1e-15
def dice_coef(y_true, y_pred):
    y_true = tf.keras.layers.Flatten()(y_true)
    y_pred = tf.keras.layers.Flatten()(y_pred)
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

In [7]:
# Model
model = build_unetr_2d(cf)
model.compile(
    loss=dice_loss, 
    optimizer=SGD(lr), 
    metrics=[dice_coef, "acc"]
)
print(model.summary())

Model: "UNETR_2D"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 768)]   0           []                               
                                                                                                  
 dense (Dense)                  (None, 256, 128)     98432       ['input_1[0][0]']                
                                                                                                  
 tf.__operators__.add (TFOpLamb  (None, 256, 128)    0           ['dense[0][0]']                  
 da)                                                                                              
                                                                                                  
 layer_normalization (LayerNorm  (None, 256, 128)    256         ['tf.__operators__.add[0][

### Training Loop

In [9]:
callbacks = [
    ModelCheckpoint(model_path, verbose=1, save_best_only=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7, verbose=1),
    CSVLogger(csv_path),
    EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=False)
]

model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=valid_dataset,
    callbacks=callbacks
)

Epoch 1: val_loss improved from inf to 0.95544, saving model to files\model.h5


<keras.callbacks.History at 0x2885488dfd0>

## Testing

In [10]:
# Directory for storing results
create_dir("results")

In [14]:
# Load the model 
model_path = os.path.join("files", "model.h5")
model = tf.keras.models.load_model(
    model_path, 
    custom_objects={
        'dice_loss': dice_loss,
        'dice_coef': dice_coef
    }
)

In [16]:
for index, (x, y) in enumerate(tqdm(zip(test_x, test_y), total=10)):
    
    print(f"Iteration: {index}")
    
    # Stopping condition
    if (index==10): break

    # Read the image
    image = cv2.imread(x, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (cf["image_size"], cf["image_size"]))
    x = image / 255.0
    print(x.shape)

    patch_shape = (cf["patch_size"], cf["patch_size"], cf["num_channels"])
    patches = patchify(x, patch_shape, cf["patch_size"])
    patches = np.reshape(patches, cf["flat_patches_shape"])
    patches = patches.astype(np.float32)
    patches = np.expand_dims(patches, axis=0)
    
    # Read Mask
    mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
    mask = cv2.resize(mask, (cf["image_size"], cf["image_size"]))
    mask = mask / 255.0
    mask = np.expand_dims(mask, axis=-1)
    mask = np.concatenate([mask, mask, mask], axis=-1)
    print(mask.shape)

    # Prediction
    pred = model.predict(patches, verbose=0)[0]
    pred = np.concatenate([pred, pred, pred], axis=-1)
    print(pred.shape)

    # Save final mask
    cat_images = np.concatenate([image, mask*255, pred*255], axis=1)
    save_image_path = f"results/test_{index}.png"
    cv2.imwrite(save_image_path, cat_images)

  0%|          | 0/10 [00:00<?, ?it/s]

Iteration: 0
(256, 256, 3)
(256, 256, 3)


 20%|██        | 2/10 [00:00<00:01,  5.36it/s]

(256, 256, 3)
Iteration: 1
(256, 256, 3)
(256, 256, 3)
(256, 256, 3)
Iteration: 2
(256, 256, 3)
(256, 256, 3)


 40%|████      | 4/10 [00:00<00:00,  7.45it/s]

(256, 256, 3)
Iteration: 3
(256, 256, 3)
(256, 256, 3)
(256, 256, 3)
Iteration: 4
(256, 256, 3)
(256, 256, 3)


 60%|██████    | 6/10 [00:00<00:00,  8.14it/s]

(256, 256, 3)
Iteration: 5
(256, 256, 3)
(256, 256, 3)
(256, 256, 3)
Iteration: 6
(256, 256, 3)
(256, 256, 3)


 80%|████████  | 8/10 [00:01<00:00,  8.82it/s]

(256, 256, 3)
Iteration: 7
(256, 256, 3)
(256, 256, 3)
(256, 256, 3)
Iteration: 8
(256, 256, 3)
(256, 256, 3)


100%|██████████| 10/10 [00:01<00:00,  8.09it/s]

(256, 256, 3)
Iteration: 9
(256, 256, 3)
(256, 256, 3)
(256, 256, 3)
Iteration: 10



