In [None]:
# Download the MIRFLICKR-25k archive into the working directory (no caching: re-running this cell fetches it again).
!wget http://press.liacs.nl/mirflickr/mirflickr25k.v3b/mirflickr25k.zip

In [None]:
!unzip *.zip && rm -rf *.zip

In [None]:
# List the NVIDIA GPUs visible to this runtime (sanity check before training).
!nvidia-smi -L

In [None]:
import tensorflow as tf
import os
import numpy as np

class DataLoader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves ``(x, x)`` image batches for an autoencoder.

    Images are read from ``SEED_PATH``, resized to ``input_size[:2]``, cast to
    float32 and scaled to ``[0, 1]``. The same array is returned as both the
    model input and the reconstruction target.

    Args:
        paths: image file names, relative to ``SEED_PATH``. They are copied into
            a new list, so the caller's sequence is never shuffled in place.
        SEED_PATH: directory containing the images.
        batch_size: images per batch; must be positive. The final batch holds
            the remainder and may be smaller than ``batch_size``.
        input_size: ``(height, width, channels)`` of the produced images.
        shuffle: shuffle the file order on construction and after every epoch.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, paths, SEED_PATH,
                 batch_size, input_size=(128, 128, 3), shuffle=True):
        # Keras 3's PyDataset warns (and patches in defaults) when the base
        # initialiser is skipped; for the older tf.keras Sequence it is a no-op.
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.df = list(paths)
        self._shuffle = shuffle
        if self._shuffle:
            np.random.shuffle(self.df)
        self.batch_size = batch_size
        self.input_size = tuple(input_size)
        self.PATH = SEED_PATH

        self.n = len(self.df)
        self.m = 0  # batch cursor used by __next__
        self.max = self.__len__()

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch.

        Ceiling division: floor division would silently drop the last
        ``len(paths) % batch_size`` images and report 0 batches whenever
        there are fewer images than ``batch_size``.
        """
        return (len(self.df) + self.batch_size - 1) // self.batch_size

    def __get_input(self, path, target_size):
        """Load one image, resize it to ``target_size[:2]`` and scale to [0, 1] float32."""
        image = tf.keras.preprocessing.image.load_img(path)
        image_arr = tf.keras.preprocessing.image.img_to_array(image)
        image_arr = tf.image.resize(image_arr, (target_size[0], target_size[1])).numpy()
        image_arr = image_arr.astype("float32")
        return image_arr / 255.

    def __get_data(self, batches):
        """Stack the images at the given full paths; return them as both input and target."""
        X_batch = np.asarray([self.__get_input(x, self.input_size) for x in batches])
        return X_batch, X_batch

    def __getitem__(self, index):
        """Return batch ``index`` as an ``(x, x)`` pair."""
        names = self.df[index * self.batch_size:(index + 1) * self.batch_size]
        full_paths = [os.path.join(self.PATH, name) for name in names]
        return self.__get_data(full_paths)

    def __next__(self):
        """Return the next batch, wrapping back to batch 0 after the last one."""
        if self.m >= self.max:
            self.m = 0
        X, Y = self.__getitem__(self.m)
        self.m += 1
        return X, Y

    def on_epoch_end(self):
        """Keras hook run after every epoch: reshuffle so batch composition varies."""
        if self._shuffle:
            np.random.shuffle(self.df)

In [None]:
from tensorflow.keras.layers import *
import tensorflow as tf

In [None]:
import tensorflow as tf

# Architecture hyper-parameters (previously repeated inline on every layer).
IMAGE_SHAPE = (128, 128, 3)     # should match the DataLoader's input_size
L2_WEIGHT = 0.001               # L2 kernel regularisation on every hidden (de)conv layer
STAGE_FILTERS = (64, 128, 256)  # encoder stages; the decoder mirrors them in reverse
BOTTLENECK_FILTERS = 512


def _conv_kwargs():
    """Keyword arguments shared by every hidden Conv2D / Conv2DTranspose layer.

    Builds a fresh regularizer per call so each layer owns its own instance,
    exactly as when the arguments were written out inline.
    """
    return dict(
        kernel_size=(3, 3),
        kernel_initializer="he_uniform",
        padding="same",
        kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT),
    )


def _encoder_stage(net, filters):
    """Apply two 3x3 Conv2D+ReLU layers, then 2x2 max-pooling.

    Returns ``(pooled, skip)``, where ``skip`` is the pre-pooling activation
    that the mirrored decoder stage adds back in.
    """
    for _ in range(2):
        net = tf.keras.layers.Conv2D(filters=filters, **_conv_kwargs())(net)
        net = tf.keras.layers.Activation("relu")(net)
    return tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(net), net


def _decoder_stage(net, filters, skip):
    """Apply Conv2DTranspose+ReLU, then a stride-2 Conv2DTranspose+ReLU (doubling H and W).

    Finally adds ``skip``, the matching encoder activation, element-wise. This is a
    long encoder->decoder skip connection, summed rather than concatenated.
    """
    net = tf.keras.layers.Conv2DTranspose(filters=filters, **_conv_kwargs())(net)
    net = tf.keras.layers.Activation("relu")(net)
    net = tf.keras.layers.Conv2DTranspose(filters=filters, strides=2, **_conv_kwargs())(net)
    net = tf.keras.layers.Activation("relu")(net)
    return tf.keras.layers.add([net, skip])


# Named `inputs`, not `input`, so the Python builtin is not shadowed.
inputs = tf.keras.layers.Input(shape=IMAGE_SHAPE)

# Encoder: each stage halves the spatial size; keep the pre-pooling activations.
net = inputs
skips = []
for filters in STAGE_FILTERS:
    net, skip = _encoder_stage(net, filters)
    skips.append(skip)

# Bottleneck: linear (no activation) 3x3 conv.
encoder_output = tf.keras.layers.Conv2D(filters=BOTTLENECK_FILTERS, **_conv_kwargs())(net)

# Decoder: mirror the encoder, deepest stage first, adding the matching skip tensor.
net = encoder_output
for filters, skip in zip(reversed(STAGE_FILTERS), reversed(skips)):
    net = _decoder_stage(net, filters, skip)

# Sigmoid keeps reconstructions in [0, 1], the same range as the DataLoader's scaled images.
decoder_output = tf.keras.layers.Conv2D(IMAGE_SHAPE[-1], kernel_size=(3, 3), padding='same', activation='sigmoid', name='Decoding_Output')(net)
model = tf.keras.models.Model(inputs=[inputs], outputs=[decoder_output])

In [None]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

In [None]:
# Adam with an explicit learning rate. Pass the configured instance: compiling with the
# string "adam" left `opt` unused, so changing the learning rate here had no effect.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Pixel-wise mean squared error between the reconstruction and the input image.
model.compile(optimizer=opt, loss="mse")

In [None]:
# Training callbacks, constructed in the same order as before.
# ModelCheckpoint (default settings) rewrites model.h5 during training;
# TensorBoard writes training logs under ./logs.
checkpoint, tensorboard = (
    tf.keras.callbacks.ModelCheckpoint("model.h5"),
    tf.keras.callbacks.TensorBoard(log_dir="logs"),
)
callbacks = list((checkpoint, tensorboard))

In [None]:
import os
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
# file list (and therefore any seeded shuffle downstream) is reproducible.
IMAGE_DIR = "./mirflickr"
MAX_IMAGES = 50000  # cap on images used; presumably above the 25k archive size, so currently a no-op
images = sorted(
    name for name in os.listdir(IMAGE_DIR) if name.lower().endswith(".jpg")
)[:MAX_IMAGES]

In [None]:
# Autoencoder data source: batches of 128 images from ./mirflickr, each yielded as (x, x).
loader = DataLoader(
    paths=images,
    SEED_PATH="./mirflickr",
    batch_size=128,
)

In [None]:
# Grab one batch for inspection. next() goes through DataLoader.__next__, which advances loader.m;
# x and y are the same array because the loader returns the batch as both input and target.
x, y = next(loader)

In [None]:
# Expected (batch_size, 128, 128, 3) for both; presumably (128, 128, 128, 3) here — confirm in the output.
x.shape, y.shape

In [None]:
import matplotlib.pyplot as plt
# Show the first input image of the batch (pixel values already scaled to [0, 1]).
plt.imshow(x[0])

In [None]:
# Show the matching target; identical to x[0], since the autoencoder reconstructs its own input.
plt.imshow(y[0])

In [None]:
EPOCHS = 15  # full passes over `loader`

# Train the autoencoder. The loader supplies (x, x) pairs, so no separate targets are passed.
history = model.fit(
    loader,
    callbacks=callbacks,
    epochs=EPOCHS,
    verbose=1,
)

In [None]:
# Re-display the first input image from the pre-training batch, for comparison with the reconstruction.
plt.imshow(x[0])

In [None]:
import cv2
# Reconstruct the first image of the batch sampled before training.
# x[:1] keeps the batch axis and [0] drops it again. This avoids hard-coding the
# (128, 128, 3) shape, which would break if the image size changed.
image = model.predict(x[:1], verbose=0)[0]
plt.imshow(image)
plt.title("Reconstruction of x[0]");