# Prerequisites

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import PIL
import cv2
import glob
import imgaug.augmenters as iaa
import imgaug as ia
import datetime

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, BatchNormalization, Layer, ReLU, Dropout, concatenate
from tensorflow.keras.callbacks import EarlyStopping
from keras import backend as K
from tensorboard.plugins.hparams import api as hp

import warnings
warnings.filterwarnings("ignore")

In [None]:
# List the visible GPUs and switch each of them to on-demand memory allocation.
# The setting is applied to every device rather than only the first one,
# because TensorFlow expects memory growth to be configured identically
# across GPUs.
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
def imshow(img: np.ndarray):
    """Show an image with matplotlib, without axes.

    The figure size is picked from the aspect ratio: a 7x7 figure when
    ``2 * height > width``, otherwise a 20x20 figure (used for wide strips
    such as several images concatenated side by side).

    Args:
        img: Image array; ``img.shape[0]`` is read as the height and
            ``img.shape[1]`` as the width.
    """
    is_wide_strip = img.shape[0] * 2 <= img.shape[1]
    plt.figure(figsize=(20, 20) if is_wide_strip else (7, 7))
    plt.axis('off')
    plt.imshow(img)

In [None]:
! mkdir ./tensorboard/
! mkdir ./tensorboard/autoencoder_skip/
! mkdir ./tensorboard/autoencoder_skip/fit/
! mkdir ./tensorboard/autoencoder_skip/hparam_tuning/
! mkdir ./models/
! mkdir ./models/autoencoder_skip/

# Preparing dataset

In [None]:
# ds = []
# for idx, file in enumerate(glob.glob("/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba/*")):
#     ds.append(np.array(PIL.Image.open(file).resize((112, 96))))
#     if idx >= 10001: break
# ds = np.array(ds)
# ds.shape

In [None]:
# idxs = np.random.choice(len(ds), 5)
# imshow(np.concatenate(ds[idxs], 1))

In [None]:
# def create_line_mask(img):
#     mask = np.full(img.shape, 255, np.uint8)
#     for _ in range(np.random.randint(6, 10)):
#         x1, x2 = np.random.randint(1, img.shape[1]), np.random.randint(1, img.shape[1])
#         y1, y2 = np.random.randint(1, img.shape[0]), np.random.randint(1, img.shape[0])
#         thickness = np.random.randint(4, 6)
#         cv2.line(mask, (x1, y1), (x2, y2), (1, 1, 1), thickness)

#     masked_image = cv2.bitwise_and(img, mask)

#     return masked_image

In [None]:
# idxs = np.random.choice(len(ds), 5)
# masked = np.array(list(map(create_line_mask, ds[idxs])))
# imshow(np.concatenate(masked, 1))

In [None]:
# for idx, sample in enumerate(ds):
#     PIL.Image.fromarray(sample).save(f'./data/samples/{idx}.png')
#     PIL.Image.fromarray(create_line_mask(sample)).save(f'./data/samples_line_masked/{idx}.png')
#     PIL.Image.fromarray(create_line_mask(sample)).save(f'./data/samples_square_masked/{idx}.png')

# Custom generator

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that loads (input, target) image pairs from disk per batch.

    ``X`` and ``Y`` are index-aligned sequences of image file paths: ``X[i]``
    is opened as the network input and ``Y[i]`` as its target. Each loaded
    array is divided by 255 before being stored in the batch
    (assumes 8-bit images -- TODO confirm against the dataset).

    Roughly 20% of the samples are sent through the augmentation pipeline.
    The pipeline itself only fires with probability 0.1, so few samples are
    actually altered. When it fires, the *same* sampled transformation is
    applied to the input and to its target so the pair stays aligned.

    Args:
        X: Sequence of file paths of the input images.
        Y: Sequence of file paths of the target images, same length as ``X``.
        batch_size: Number of samples per batch. Trailing samples that do not
            fill a whole batch are dropped (see ``__len__``).
        dim: The two leading per-sample dimensions of the batch arrays.
        n_channels: Size of the last dimension of the batch arrays.

    Raises:
        ValueError: If ``X`` and ``Y`` differ in length or are empty.
    """

    def __init__(self, X, Y, batch_size=128, dim=(112, 96), n_channels=3):
        super().__init__()

        # Both conditions must hold; the pairs are matched purely by index.
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        if len(X) == 0:
            raise ValueError("DataGenerator needs at least one sample")

        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels

        self.on_epoch_end()

    def __len__(self):
        """Number of complete batches per epoch."""
        return len(self.X) // self.batch_size

    def __getitem__(self, index):
        """Return the ``(X_batch, Y_batch)`` pair for batch number ``index``."""
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        return self.__data_generation(indexes)

    def on_epoch_end(self):
        """Reset the sample order to the original (unshuffled) order."""
        self.indexes = np.arange(len(self.X))

    def __data_generation(self, idxs):
        """Load, optionally augment, and scale the samples listed in ``idxs``."""
        # Size the buffers by the number of indices actually requested so no
        # uninitialised rows from ``np.empty`` can leak into a short batch.
        X_batch = np.empty((len(idxs), self.dim[0], self.dim[1], self.n_channels))
        Y_batch = np.empty((len(idxs), self.dim[0], self.dim[1], self.n_channels))

        for i, idx in enumerate(idxs):
            image = np.array(PIL.Image.open(self.X[idx]))
            label = np.array(PIL.Image.open(self.Y[idx]))
            if np.random.randint(0, 100) < 20:
                # Freeze the random parameters once and reuse them for both
                # arrays, so input and target receive the same crop/pad.
                paired = self._make_augmenter().to_deterministic()
                image = self.augment(image, paired)
                label = self.augment(label, paired)
            # Scale after augmenting: ``pad_cval=(0, 255)`` is expressed in the
            # unscaled intensity range.
            X_batch[i,] = image / 255
            Y_batch[i,] = label / 255

        return X_batch, Y_batch

    def _make_augmenter(self):
        """Build the crop-and-pad augmentation pipeline."""
        return iaa.Sequential([
            iaa.Sometimes(0.1, iaa.CropAndPad(
                percent=(-0.05, 0.1),
                pad_mode=ia.ALL,
                pad_cval=(0, 255)))
        ])

    def augment(self, img, augmenter=None):
        """Apply the augmentation pipeline to a single image.

        Args:
            img: One image array. It is passed with ``image=`` (singular);
                passing a lone (H, W, C) array as ``images=`` would make
                imgaug treat it as a batch of H separate images.
            augmenter: Optional pre-built (e.g. deterministic) augmenter to
                reuse; a fresh random pipeline is built when omitted.

        Returns:
            The augmented image.
        """
        if augmenter is None:
            augmenter = self._make_augmenter()
        return augmenter(image=img)

In [None]:
def train_test_split(X, Y, train_size=0.8):
    """Split two index-aligned sequences into leading (train) and trailing (test) parts.

    No shuffling is done: the first ``int(train_size * len(X))`` items form
    the training part and the remainder the test part. The cut position is
    derived from ``X`` and applied to both sequences.

    Returns:
        Tuple ``(X_train, X_test, Y_train, Y_test)``.
    """
    cut = int(train_size * len(X))
    head, tail = slice(None, cut), slice(cut, None)
    return X[head], X[tail], Y[head], Y[tail]

In [None]:
# Collect the masked inputs (X) and their clean targets (Y). Both listings are
# sorted so that the two lists line up by position, and only the first
# SAMPLE_LIMIT files of each are kept.
SAMPLE_LIMIT = 1000
masked_paths = sorted(glob.glob("/kaggle/input/cv-project3/data/samples_masked/*.png"))
clean_paths = sorted(glob.glob("/kaggle/input/cv-project3/data/samples/*.png"))
X = masked_paths[:SAMPLE_LIMIT]
Y = clean_paths[:SAMPLE_LIMIT]

# Ordered (unshuffled) split with the default train fraction.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y)

In [None]:
# Batch generators for the training and held-out file lists (default batch
# size and image dimensions).
train_gen, test_gen = DataGenerator(X_train, Y_train), DataGenerator(X_test, Y_test)

# Metrics

In [None]:
def dice_coef(y_true, y_pred):
    """Dice coefficient computed over all elements of both tensors.

    Both tensors are flattened, so a single scalar is produced:
    ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))``.
    ``K.epsilon()`` is added to the denominator so that two all-zero tensors
    give 0 instead of NaN.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    total = K.sum(y_true_f) + K.sum(y_pred_f)
    return (2. * intersection) / (total + K.epsilon())

In [None]:
def jaccard_distance(y_true, y_pred, smooth=100):
    """Smoothed Jaccard distance, reduced over the last axis.

    ``smooth`` is added to both the overlap and the union before dividing,
    and the resulting ``1 - jaccard`` value is scaled by ``smooth`` again.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    total = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    jaccard_index = (overlap + smooth) / (total - overlap + smooth)
    return (1 - jaccard_index) * smooth

In [None]:
# Cosine-similarity metric object, computed along axis 1 of the tensors.
# NOTE(review): for (batch, height, width, channels) image tensors axis=1 would
# be the height axis -- confirm this is intended rather than the channel
# axis (-1).
cosine_similarity = tf.keras.metrics.CosineSimilarity(axis=1)

# Autoencoder

In [None]:
# class Autoencoder(keras.Model):
#     def __init__(self):
#         super(Autoencoder, self).__init__()


#     def __ConvBlock(self, out, kernel_size, prev_layer):
#         cnn = Conv2D(out, kernel_size, padding="same")(prev_layer)
#         cnn = BatchNormalization()(cnn)
#         cnn = ReLU()(cnn)
#         return cnn

#     def __EncodeBlock(self, out, kernel_size, prev_layer, dr_rate=0.1):
#         conv = self.__ConvBlock(out, kernel_size, prev_layer)
#         conv = self.__ConvBlock(out, kernel_size, conv)
#         conv = self.__ConvBlock(out, kernel_size, conv)
#         conv = MaxPooling2D((2, 2))(conv)
#         conv = Dropout(dr_rate)(conv)
#         return conv


#     def __DecodeBlock(self, out, kernel_size, prev_layer):
#         up = Conv2DTranspose(out, kernel_size, strides=(2, 2), padding="same")(prev_layer)
#         up = BatchNormalization()(up)
#         up = ReLU()(up)
#         return up


#     def model(self, input_shape=(112, 96, 3), dr_rate=0.1, kernel_size=(3, 3)):
#         inputs = keras.layers.Input(input_shape)

#         conv1 = self.__EncodeBlock(32, kernel_size, inputs, dr_rate) 
#         conv2 = self.__EncodeBlock(64, kernel_size, conv1, dr_rate)
#         conv3 = self.__EncodeBlock(128, kernel_size, conv2, dr_rate) 
#         conv4 = self.__EncodeBlock(256, kernel_size, conv3, dr_rate) 

#         deconv1 = self.__DecodeBlock(256, kernel_size, conv4)
#         deconv2 = self.__DecodeBlock(128, kernel_size, deconv1)
#         deconv3 = self.__DecodeBlock(64, kernel_size, deconv2)
#         deconv4 = self.__DecodeBlock(32, kernel_size, deconv3)

#         outputs = keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(deconv4)

#         return keras.models.Model(inputs=[inputs], outputs=[outputs])

## Parameter tuning on a small part of the dataset

In [None]:
# HP_DROPOUT = hp.HParam("dropout", hp.Discrete([0.1, 0.2, 0.3, 0.5]))
# HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd', 'adagrad']))

# METRIC = 'mean_absolute_error'

# with tf.summary.create_file_writer('./tensorboard/autoencoder/hparam_tuning').as_default():
#     hp.hparams_config(
#         hparams=[HP_DROPOUT, HP_OPTIMIZER],
#         metrics=[hp.Metric(METRIC, display_name='mean_absolute_error')]
#     )

In [None]:
# def train_test_model(hparams):
#     model = Autoencoder().model(input_shape=(112, 96, 3), dr_rate=hparams[HP_DROPOUT])
#     model.compile(
#         optimizer=hparams[HP_OPTIMIZER],
#         loss='mean_absolute_error',
#     )

#     history = model.fit(
#         train_gen, 
#         validation_data = test_gen, 
#         epochs=10, 
#         steps_per_epoch = len(train_gen), 
#         validation_steps = len(test_gen),
#         use_multiprocessing = True,
#     )
    
#     loss = model.evaluate(test_gen[0][0], test_gen[0][1])
#     return loss

In [None]:
# def run(run_dir, hparams):
#     with tf.summary.create_file_writer(run_dir).as_default():
#         hp.hparams(hparams)
#         loss = train_test_model(hparams)
#         tf.summary.scalar(METRIC, loss, step=1)

In [None]:
# session_num = 0

# for dr_rate in HP_DROPOUT.domain.values:
#     for optimizer in HP_OPTIMIZER.domain.values:
#         hparams = {
#             HP_DROPOUT: dr_rate,
#             HP_OPTIMIZER: optimizer,
#         }
#         run_name = "run-%d" % session_num
#         print('--- Starting trial: %s' % run_name)
#         print({h.name: hparams[h] for h in hparams})
#         run('logs/hparam_tuning/' + run_name, hparams)
#         session_num += 1

## Model training

In [None]:
# model = Autoencoder().model(input_shape=(112, 96, 3), dr_rate=0.2)
# model.compile(optimizer='adam', loss='mean_absolute_error', metrics=[dice_coef, jaccard_distance, cosine_similarity])
# keras.utils.plot_model(model, show_shapes=True, to_file='./autoencoder.png')

# early_stopping = tf.keras.callbacks.EarlyStopping(
#     monitor='val_loss', 
#     patience=10, 
#     min_delta=0.001, 
#     restore_best_weights=True
# )

In [None]:
# log_dir = "./tensorboard/autoencoder/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [None]:
# history = model.fit(
#     train_gen, 
#     validation_data = test_gen, 
#     epochs=500, 
#     steps_per_epoch = len(train_gen), 
#     validation_steps = len(test_gen),
#     use_multiprocessing = True,
#     callbacks=[tensorboard_callback]
# )

In [None]:
# model.save("./models/autoencoder/")

# Autoencoder with skip-connections

In [None]:
class Autoencoder_skip(keras.Model):
    """Factory for a convolutional autoencoder with skip connections.

    ``model()`` assembles and returns a Keras functional ``Model``: four
    encoder stages (32, 64, 128, 256 filters), four decoder stages that each
    concatenate the matching encoder feature map, and a final 3-channel
    sigmoid convolution.
    """

    def __init__(self):
        super().__init__()

    def __ConvBlock(self, out, kernel_size, prev_layer):
        """Conv2D ("same" padding) -> BatchNormalization -> ReLU."""
        x = Conv2D(out, kernel_size, padding="same")(prev_layer)
        x = BatchNormalization()(x)
        return ReLU()(x)

    def __EncodeBlock(self, out, kernel_size, prev_layer, dr_rate=0.1):
        """Three conv blocks, then 2x2 max-pooling and dropout.

        Returns:
            ``(features, downsampled)``: the pre-pooling feature map (used as
            the skip connection) and the pooled, dropout-regularised output.
        """
        features = prev_layer
        for _ in range(3):
            features = self.__ConvBlock(out, kernel_size, features)
        downsampled = MaxPooling2D((2, 2))(features)
        downsampled = Dropout(dr_rate)(downsampled)
        return features, downsampled

    def __DecodeBlock(self, out, conv_out, kernel_size, prev_layer, skip_con):
        """Conv block, stride-2 transposed conv, then merge with ``skip_con``.

        The upsampled tensor and the skip tensor are concatenated along
        axis 3 before batch normalisation and ReLU.
        """
        refined = self.__ConvBlock(conv_out, kernel_size, prev_layer)
        upsampled = Conv2DTranspose(out, kernel_size, strides=(2, 2), padding="same")(refined)
        merged = concatenate([upsampled, skip_con], axis=3)
        merged = BatchNormalization()(merged)
        return ReLU()(merged)

    def model(self, input_shape=(112, 96, 3), dr_rate=0.1, kernel_size=(3, 3)):
        """Build the functional model.

        Args:
            input_shape: Shape of one input sample.
            dr_rate: Dropout rate used in every encoder stage.
            kernel_size: Kernel size of the encoder/decoder convolutions.

        Returns:
            A ``keras.models.Model`` mapping the input to a 3-channel
            sigmoid output.
        """
        inputs = Input(input_shape)

        # Encoder: remember each stage's pre-pooling features for the decoder.
        skips = []
        x = inputs
        for filters in (32, 64, 128, 256):
            skip, x = self.__EncodeBlock(filters, kernel_size, x, dr_rate)
            skips.append(skip)

        # Decoder: walk the skip connections from deepest to shallowest.
        decoder_filters = ((256, 512), (128, 256), (64, 128), (32, 64))
        for (up_filters, conv_filters), skip in zip(decoder_filters, reversed(skips)):
            x = self.__DecodeBlock(up_filters, conv_filters, kernel_size, x, skip)

        outputs = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

        return keras.models.Model(inputs=[inputs], outputs=[outputs])

## Parameter tuning

In [None]:
# Hyperparameter search space: dropout rate and optimizer name.
HP_DROPOUT = hp.HParam("dropout", hp.Discrete([0.1, 0.2, 0.3, 0.5]))
HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd', 'adagrad']))

# Name under which each trial's result is logged.
METRIC = 'mean_absolute_error'

# Register the search space and the metric for the TensorBoard HParams plugin.
hparam_writer = tf.summary.create_file_writer('./tensorboard/autoencoder_skip/hparam_tuning/')
with hparam_writer.as_default():
    hp.hparams_config(
        metrics=[hp.Metric(METRIC, display_name='mean_absolute_error')],
        hparams=[HP_DROPOUT, HP_OPTIMIZER],
    )

In [None]:
def train_test_model(hparams):
    """Train a fresh skip-connection autoencoder for one hyperparameter set.

    The model is compiled with the optimizer named in ``hparams[HP_OPTIMIZER]``
    and a mean-absolute-error loss, trained for 10 epochs on ``train_gen``
    with ``test_gen`` as validation data, and then evaluated on ``test_gen``.

    Args:
        hparams: Mapping that contains values for ``HP_DROPOUT`` and
            ``HP_OPTIMIZER``.

    Returns:
        The value returned by ``model.evaluate`` on ``test_gen`` (the loss,
        since no additional metrics are compiled).
    """
    model = Autoencoder_skip().model(input_shape=(112, 96, 3), dr_rate=hparams[HP_DROPOUT])
    model.compile(
        optimizer=hparams[HP_OPTIMIZER],
        loss='mean_absolute_error',
    )

    model.fit(
        train_gen,
        validation_data=test_gen,
        epochs=10,
        steps_per_epoch=len(train_gen),
        validation_steps=len(test_gen),
        use_multiprocessing=True,
    )

    # Evaluate on the generator itself. Indexing ``test_gen[0]`` separately for
    # inputs and targets would load the batch twice, and because the generator
    # augments randomly the two loads are not guaranteed to match.
    return model.evaluate(test_gen)

In [None]:
def run(run_dir, hparams):
    """Execute one tuning trial and log it to ``run_dir``.

    Records the hyperparameter values used, trains and evaluates a model via
    ``train_test_model``, and writes the resulting loss as the ``METRIC``
    scalar at step 1.
    """
    trial_writer = tf.summary.create_file_writer(run_dir)
    with trial_writer.as_default():
        hp.hparams(hparams)
        tf.summary.scalar(METRIC, train_test_model(hparams), step=1)

In [None]:
# Grid search over every (dropout, optimizer) combination.
# Trials are written below the same directory that holds the hparams
# configuration, so TensorBoard shows the configuration and the runs in one
# HParams dashboard.
HPARAM_RUNS_DIR = './tensorboard/autoencoder_skip/hparam_tuning/'

session_num = 0

for dr_rate in HP_DROPOUT.domain.values:
    for optimizer in HP_OPTIMIZER.domain.values:
        hparams = {
            HP_DROPOUT: dr_rate,
            HP_OPTIMIZER: optimizer,
        }
        run_name = "run-%d" % session_num
        print('--- Starting trial: %s' % run_name)
        print({h.name: hparams[h] for h in hparams})
        run(HPARAM_RUNS_DIR + run_name, hparams)
        session_num += 1

## Model training

In [None]:
# model = Autoencoder_skip().model(input_shape=(112, 96, 3), dr_rate=0.2)
# model.compile(optimizer='adam', loss='mean_absolute_error', metrics=[dice_coef, jaccard_distance, cosine_similarity])
# # keras.utils.plot_model(model, show_shapes=True)

In [None]:
# history = model.fit(
#     train_gen, 
#     validation_data = test_gen, 
#     epochs=100, 
#     steps_per_epoch = len(train_gen), 
#     validation_steps = len(test_gen),
#     use_multiprocessing = True,
# #     callbacks=[tensorboard_callback]
# )

In [None]:
# imshow(model.predict(np.array([test_gen[0][0][3]])).reshape(112, 96, 3))

In [None]:
# imshow(test_gen[0][0][3])