# CDACS Model Experiment - Training

## Import necessary libraries

In [None]:
import tensorflow as tf

# Turn on memory growth for every GPU TensorFlow can see, then echo the
# device list as the cell output.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
physical_devices

In [None]:
import os
from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
from module.model_utils import *
from module.metrics import *
from module.dataset_utils import BasicDatasetProcess

## Import datasets

In [None]:
import datasets.camelyon16

# Look up the dataset wrappers for the 'camelyon16' dataset, 'HE_CR' variant.
# The result is indexed by split name further down (e.g. ['train']).
dataset_wrappers_he = BasicDatasetProcess.get_dataset_wrapper_from_dataset('camelyon16', 'HE_CR')

# Echo the wrappers as the cell output.
dataset_wrappers_he

## Preprocessing

### Define batch preprocessing parameters

In [None]:
# proportion used when splitting the training data into training | validation
train_proportion = 0.7

# patch size cut from a WSI before it is fed to the model
patch_size = 1000

# input size of the model
# e.g. patch_size = 1000, input_size = 1024:
# a WSI of 10000x10000 pixels is split into 100 patches of 1000x1000 pixels,
# and each 1000x1000 patch is then resized to 1024x1024
input_size = 1024
# number of random patches drawn during training
num_patches = 100


# if the WSIs are too big for training, try turning these off
cache = False
prefetch = False


# number of patches fed into the model at the same time
BATCH_SIZE = 10
# buffer size for shuffling
BUFFER_SIZE = 100

### Preprocess the dataset in batches using the Color Deconvolution (CD) algorithm and the Adaptive Color Segmentation (ACS), a.k.a. Color Region (CR), algorithm

In [None]:
# Randomly split the 'train' wrapper into a training and a validation part.
he_train_pre, he_val_pre = dataset_wrappers_he['train'].random_split(train_proportion)

# Apply the same process() -> cd_normalize() chain to both parts
# (training part first, then validation part).
he_train, he_val = (
    part.process().cd_normalize() for part in (he_train_pre, he_val_pre)
)

train_dataset, val_dataset = he_train, he_val
train_dataset, val_dataset

### Prepare the training dataset by randomly patching sub-images to prevent artifacts in the image-segmentation training process

In [None]:
print(train_dataset)
# Training pipeline: unpack -> (optional cache) -> random patches -> resize.
# The cache step is applied only when the `cache` flag is set; it sits
# before the random patching step.
train_images = (
    train_dataset
    .unpack_datapoint()
    .assert_callback(lambda ds: ds.cache() if cache else ds)
    .random_patches(num_patches, patch_size)
    .resize_image(input_size)
)
print(train_images)

### Prepare the validation dataset by patching sub-images in order, for complete validation

In [None]:
print(val_dataset)
# Validation pipeline: unpack -> large patches -> resize -> (optional cache).
# Here the cache step comes last, after patching and resizing.
val_images = (
    val_dataset
    .unpack_datapoint()
    .extract_large_patches(patch_size)
    .resize_image(input_size)
    .assert_callback(lambda ds: ds.cache() if cache else ds)
)
print(val_images)

### Define training parameters

In [None]:
# Sizes of the prepared training and validation datasets, as reported by
# their wrappers.
TRAIN_LENGTH = train_images.dataset_size
VALIDATION_LENGTH = val_images.dataset_size

# Number of full batches per epoch. The training batches are built with
# .repeat(), so model.fit needs this value to know where an epoch ends.
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

### Visually verify prepared dataset

In [None]:
# Inspect up to 10 validation samples drawn through a shuffle buffer of 10
# and stop at the first one whose mask ratio (Dataset.get_ratio) reaches
# 1e-2. sample_image / sample_mask always hold the last sample inspected,
# so later cells can use them even if no sample reaches the threshold
# (in which case nothing is displayed here).
candidate_samples = he_val.unpack_datapoint().processed_dataset.shuffle(10).take(10)
for image, mask in candidate_samples:
    sample_image, sample_mask = image, mask

    mask_ratio_reached = Dataset.get_ratio(sample_mask) >= 1e-2
    if not mask_ratio_reached:
        continue

    display([sample_image, sample_mask])
    break

### Data augmentation using TensorFlow built-in functions to further prevent artifacts in the image-segmentation training process

In [None]:
def augment(input_image, input_mask):
    """Randomly flip an image and its mask together.

    A left-right flip and then an up-down flip are each applied when a fresh
    ``tf.random.uniform(())`` sample is greater than 0.5. Whenever a flip is
    applied, it is applied to both the image and the mask.

    Args:
        input_image: image tensor accepted by the ``tf.image.flip_*`` ops.
        input_mask: mask tensor accepted by the ``tf.image.flip_*`` ops.

    Returns:
        The ``(input_image, input_mask)`` pair after the random flips.
    """
    # Same order as before: horizontal decision first, vertical second.
    for flip in (tf.image.flip_left_right, tf.image.flip_up_down):
        if tf.random.uniform(()) > 0.5:
            input_image, input_mask = flip(input_image), flip(input_mask)

    return input_image, input_mask

In [None]:
print(train_images)
# Training batches: shuffle -> batch -> repeat indefinitely -> augment.
# NOTE(review): augment is mapped after batch(), so it is called on whole
# batches rather than on single samples - confirm this is intended.
train_batches = train_images.processed_dataset.shuffle(BUFFER_SIZE)
train_batches = train_batches.batch(BATCH_SIZE).repeat()
train_batches = train_batches.map(augment)
if prefetch:
    train_batches = train_batches.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
print(train_batches)

print(val_images)
# Validation batches: fixed order, no shuffling, repetition or augmentation.
val_batches = val_images.processed_dataset.batch(BATCH_SIZE)
if prefetch:
    val_batches = val_batches.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
print(val_batches)

## Setup UNet + MobileNetV2 hybrid model for 1024 input size

In [None]:
from tensorflow.keras.optimizers import Adam
# The optimizer is passed to model.compile by its string identifier.
# Alternative kept for reference: an Adam instance built with 1e-4.
# optimizer = Adam(1e-4)
optimizer = 'adam'

In [None]:
# Build the model wrapper (2 output channels, 1 input channel) and take the
# underlying model object from it.
model_obj = MobileNetV2_1024_Model(
    output_channels=2,
    input_channels=1,
    input_size=input_size,
)
model = model_obj.model

# The loss is configured with from_logits=True.
segmentation_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Accuracy plus the two overlap metrics from module.metrics.
tracked_metrics = ['accuracy', jacard_coef, dice_coef]

model.compile(
    optimizer=optimizer,
    loss=segmentation_loss,
    metrics=tracked_metrics,
)

### Resume previous training (optional)
- Set cell below to `code` type
- Uncomment `initial_epoch` parameter in model.fit to resume actual epoch count

### Define a series of callback functions that will be used during training

In [None]:
def large_prediction(image, mask, patch_size=patch_size):
    """Predict a mask for ``image`` with the current model.

    Args:
        image: image handed to ``model_obj.easy_predict_single``.
        mask: reference mask; returned unchanged.
        patch_size: patch size forwarded to ``easy_predict_single``. The
            default is the value of the module-level ``patch_size`` at the
            time this function is defined.

    Returns:
        The tuple ``(image, mask, predicted_mask)``.
    """
    return image, mask, model_obj.easy_predict_single(image, patch_size=patch_size)

In [None]:
# Sanity check: run a prediction on the selected sample and show it next
# to the image and its reference mask.
n_image, n_mask, pred_mask = large_prediction(sample_image, sample_mask)
display([sample_image, sample_mask, pred_mask])

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that saves and shows a sample prediction after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the sample image, save the figure and display it.

        Args:
            epoch: 0-based epoch index passed in by Keras.
            logs: metric dictionary passed in by Keras (unused here).
        """
        # Use one 1-based epoch number for both the saved file name and the
        # printed message, so the image name matches the message and the
        # 1-based `{epoch}` used in Keras checkpoint file names.
        epoch_number = epoch + 1

        clear_output(wait=True)

        img, mask, pred_mask = large_prediction(sample_image, sample_mask)
        fig = display([img, mask, pred_mask], show=False)
        fig.savefig(os.path.join(output_folder, f'pred_sample_image_epoch_{epoch_number:04d}.png'))

        plt.show()
        # Release the figure explicitly so figures do not pile up over a
        # long training run.
        plt.close(fig)
        print('\nSample Prediction after epoch {}\n'.format(epoch_number))

In [None]:
# Folder for model checkpoints; DisplayCallback saves its sample images
# here as well.
output_folder = 'checkpoints'
os.makedirs(output_folder, exist_ok=True)

# TensorBoard logs go to their own directory.
logdir = "logs/"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

# Save the full model whenever val_dice_coef reaches a new maximum.
# NOTE(review): the "val_dict" part of the file name carries val_dice_coef;
# possibly a typo for "val_dice" - left unchanged so file names stay stable.
best_checkpoint_filepath = output_folder + "/model_epoch_{epoch:04d}_val_dict_{val_dice_coef:.5f}.hdf5"
model_best_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=best_checkpoint_filepath,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
)

# Periodic full-model snapshot, independent of any metric.
# NOTE(review): `period` is a deprecated ModelCheckpoint argument in newer
# Keras releases (replaced by `save_freq`) - confirm against the installed
# TensorFlow version before upgrading.
checkpoint_filepath = output_folder + "/model_epoch_{epoch:04d}.hdf5"
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_freq='epoch',
    period=10,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    verbose=0,
)

## Training

In [None]:
EPOCHS = 500
# NOTE(review): VALIDATION_STEPS is computed here but is not passed to
# model.fit below - confirm whether it should be.
VALIDATION_STEPS = VALIDATION_LENGTH // BATCH_SIZE

# Callbacks, in the order they are handed to Keras.
training_callbacks = [
    DisplayCallback(),
    tensorboard_callback,
    model_best_checkpoint_callback,
    model_checkpoint_callback,
]

model_history = model.fit(
    train_batches,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=val_batches,
    # initial_epoch=initial_epoch,
    callbacks=training_callbacks,
)