# Imports

In [None]:
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

import os
import datetime
import pickle

In [None]:
# Pin TensorFlow to the first physical GPU, when at least one is present.
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        # Only gpus[0] stays visible; any other physical GPU is hidden from TF.
        tf.config.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as err:
        # Raised when visible devices are changed after the GPUs were initialized.
        print(err)

# Data

In [None]:
DATA_PATH = "./data/"
BATCH_SIZE = 8
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
SEED = 123
N_CLASSES = 66

def getDataset(path):
    return tf.keras.utils.image_dataset_from_directory(
                path,
                labels=None,
                color_mode='rgb',
                batch_size=BATCH_SIZE,
                image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
                shuffle=True,
                seed=SEED,
                validation_split=None,
                interpolation='bilinear',
                crop_to_aspect_ratio=True,
            )

In [None]:
train_x = getDataset(DATA_PATH + "training/images")
train_y = getDataset(DATA_PATH + "training/instances")

In [None]:
test_x = getDataset(DATA_PATH + "validation/images")
test_y = getDataset(DATA_PATH + "validation/instances")

In [None]:
# Pair each image batch with the mask batch at the same position.
# NOTE(review): correct pairing assumes both sides yield files in the same
# order (same SEED, same directory listing order) — confirm on the data.
train_ds, test_ds = (
    tf.data.Dataset.zip((images, masks))
    for images, masks in ((train_x, train_y), (test_x, test_y))
)

In [None]:
def displayExample(display_list):
    """Show the given images side by side in one matplotlib figure.

    Args:
        display_list: sequence of image tensors/arrays accepted by
            ``tf.keras.utils.array_to_img``, drawn left to right. The first
            three panels are titled 'Input Image', 'True Mask' and
            'Predicted Mask'. Any further panels get a numbered title
            instead of raising IndexError.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(15, 15))
    for i, img in enumerate(display_list):
        plt.subplot(1, n_panels, i + 1)
        plt.title(titles[i] if i < len(titles) else f'Image {i + 1}')
        plt.imshow(tf.keras.utils.array_to_img(img))
        plt.axis('off')
    plt.show()

In [None]:
for images, masks in train_ds.take(1):
    # Inspect the first (image, mask) pair of a single batch.
    sample_image = images[0]
    sample_mask = masks[0]
    print(sample_mask.shape)
    displayExample([sample_image, sample_mask])

In [None]:
def normalize(input_image, input_mask, num_classes=N_CLASSES):
    """Scale images to [0, 1] and one-hot encode the per-pixel mask class ids.

    Args:
        input_image: image tensor with 0-255 pixel values.
        input_mask: mask tensor whose channel 0 (last axis) holds the class id
            of each pixel (expected range [0, num_classes - 1]).
        num_classes: depth of the one-hot encoding (defaults to N_CLASSES).

    Returns:
        (float32 image in [0.0, 1.0], one-hot mask with a trailing axis of
        size ``num_classes``).
    """
    input_image = tf.cast(input_image, tf.float32) / 255.0  # [0.0, 1.0]
    # `[..., 0]` (rather than `[:, :, :, 0]`) handles batched (B, H, W, C)
    # and unbatched (H, W, C) elements alike.
    class_ids = tf.cast(input_mask[..., 0], tf.uint8)  # [0, num_classes - 1]
    input_mask = tf.one_hot(class_ids, num_classes)  # One hot each pixel
    return input_image, input_mask

In [None]:
# Lazily apply normalize() to every (image batch, mask batch) pair.
train_ds, test_ds = train_ds.map(normalize), test_ds.map(normalize)

# Model

In [None]:
# MobileNetV2 backbone used as the U-Net encoder; include_top=False drops the
# classification head so intermediate feature maps can be tapped.
base_model = tf.keras.applications.MobileNetV2(input_shape=[IMAGE_HEIGHT, IMAGE_WIDTH, 3], include_top=False)

# Use the activations of these layers: the first four become skip connections,
# the last one is the bottleneck. Spatial sizes are for the current 224x224
# input (H/2, H/4, H/8, H/16, H/32) and scale with IMAGE_HEIGHT/IMAGE_WIDTH.
layer_names = [
    'block_1_expand_relu',   # 112x112
    'block_3_expand_relu',   # 56x56
    'block_6_expand_relu',   # 28x28
    'block_13_expand_relu',  # 14x14
    'block_16_project',      # 7x7
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

# Freeze the encoder so only the decoder layers are updated during training.
down_stack.trainable = False

In [None]:
# Decoder blocks, applied from the bottleneck upward; each doubles the spatial
# size. Sizes below are for the current 224x224 input (7x7 bottleneck).
up_stack = [
    pix2pix.upsample(512, 3),  # 7x7 -> 14x14
    pix2pix.upsample(256, 3),  # 14x14 -> 28x28
    pix2pix.upsample(128, 3),  # 28x28 -> 56x56
    pix2pix.upsample(64, 3),   # 56x56 -> 112x112
]

In [None]:
def unet_model(output_channels:int):
    """Build a U-Net from the frozen encoder ``down_stack`` and decoder ``up_stack``.

    Args:
        output_channels: number of channels in the output map (one per class).

    Returns:
        A tf.keras.Model mapping (IMAGE_HEIGHT, IMAGE_WIDTH, 3) images to
        (IMAGE_HEIGHT, IMAGE_WIDTH, output_channels) maps. The last layer is
        given no activation, so outputs are raw logits (the loss is compiled
        with ``from_logits=True``).
    """
    inputs = tf.keras.layers.Input(shape=[IMAGE_HEIGHT, IMAGE_WIDTH, 3])

    # Encoder: the deepest activation is the bottleneck, the others are skips
    # consumed from deepest to shallowest.
    skips = down_stack(inputs)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Decoder: upsample, then concatenate the encoder activation of equal size.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final stride-2 upsampling back to the input resolution
    # (112x112 -> 224x224 for the current 224x224 input).
    last = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2,
        padding='same')
    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
model = unet_model(output_channels=N_CLASSES)

# The model outputs raw logits, hence from_logits=True.
# tf.keras.metrics.Precision() was removed: it thresholds predictions at 0.5
# and expects probabilities in [0, 1], so on logits it is meaningless (and
# Keras' thresholded metrics may reject out-of-range predictions during fit).
model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['categorical_accuracy'])

In [None]:
# Print the layer-by-layer summary of the compiled U-Net.
model.summary()

In [None]:
# Render the model graph with tensor shapes on each edge.
# NOTE(review): plot_model needs pydot/graphviz installed — confirm in this env.
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
def create_mask(pred_mask):
    """Collapse per-class scores into a class-id map with a size-1 channel axis.

    The argmax is taken over the last axis, which is then re-added with size 1.
    If the result still carries a batch axis (rank > 3), only the first element
    of the batch is returned.
    """
    class_map = tf.expand_dims(tf.math.argmax(pred_mask, axis=-1), axis=-1)
    return class_map[0] if len(class_map.shape) > 3 else class_map

In [None]:
def show_predictions(dataset=None, num=1):
    """Plot input, true mask and predicted mask for the first example of ``num`` batches.

    The global ``model`` predicts each whole batch; only element 0 is shown.

    Args:
        dataset: tf.data.Dataset yielding (image_batch, one_hot_mask_batch)
            pairs. Required; the ``None`` default is kept only for signature
            compatibility.
        num: number of batches to visualize.

    Raises:
        ValueError: if ``dataset`` is None.
    """
    if dataset is None:
        raise ValueError("show_predictions() requires a dataset to sample from")
    for image, mask in dataset.take(num):
        pred_mask = create_mask(model.predict(image))
        true_mask = create_mask(mask[0])
        displayExample([image[0], true_mask, pred_mask])

In [None]:
# Sanity check before training: the decoder is still freshly initialized here.
show_predictions(train_ds)

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws a sample training prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Replace the previous cell output with the new sample.
        clear_output(wait=True)
        show_predictions(train_ds)
        # Report the epoch 1-based.
        print(f'\nSample Prediction after epoch {epoch + 1}\n')

In [None]:
checkpoint_filepath = './models/model.h5'
# Ensure the checkpoint directory exists before the first save.
os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)

# Save the model at the end of an epoch only when val_loss reaches a new
# minimum. mode must be 'min' for a loss: 'max' would keep the worst epoch.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    verbose=0,
    mode='min',
    save_freq="epoch",
    save_best_only=True)

In [None]:
# Stop training once val_loss has not improved for 5 epochs in a row.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
)

## Training

In [None]:
EPOCHS = 2
STEPS_PER_EPOCH = 800 // BATCH_SIZE
VAL_SUBSPLITS = 5
VALIDATION_STEPS = 200 // BATCH_SIZE // VAL_SUBSPLITS

model_history = model.fit(train_ds,
                          epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_ds,
                          callbacks=[DisplayCallback(), model_checkpoint_callback, early_stopping_callback])

# Persist only the per-epoch metrics dict. The History callback object also
# holds a reference to the trained model, which makes pickling it heavy and
# version-dependent. The output directory may not exist yet.
os.makedirs('./histories', exist_ok=True)
with open('./histories/history.pickle', 'wb') as fh:
    pickle.dump(model_history.history, fh)

In [None]:
def displayLearningCurves(history):
    """Plot train vs. validation curves for loss and categorical accuracy.

    Args:
        history: object whose ``history`` dict (e.g. the return value of
            ``model.fit``) contains 'loss', 'val_loss', 'categorical_accuracy'
            and 'val_categorical_accuracy'. Each metric gets its own figure.
    """
    curves = (
        ('loss', 'Loss curves', 'Loss'),
        ('categorical_accuracy', 'Accuracy curves', 'Acc'),
    )
    for metric, title, y_label in curves:
        plt.plot(history.history[metric])
        plt.plot(history.history['val_' + metric])
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend(['train', 'test'], loc='upper left')
        plt.show()

In [None]:
# Plot the loss and accuracy curves recorded during training.
displayLearningCurves(model_history)

In [None]:
# Reload the checkpoint written by ModelCheckpoint (save_best_only=True) and
# print its architecture to confirm it round-trips.
test = tf.keras.models.load_model(checkpoint_filepath)
test.summary()