In [None]:
import os
import pandas as pd
import tensorflow as tf
import numpy as np
from tensorflow.keras import Input, Model, initializers, regularizers, models
from tensorflow.keras.layers import Layer, Conv2D, Conv2DTranspose, GlobalAveragePooling2D, AveragePooling2D, MaxPool2D, UpSampling2D,\
                                    BatchNormalization, Activation, Flatten, Dense, Input,\
                                    Add, Multiply, Concatenate, concatenate, Softmax
import pathlib
from datetime import datetime
from matplotlib import pyplot as plt
from tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.python.keras.models import load_model
from tensorflow.python.keras.utils.vis_utils import plot_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.activations import softmax
import time

In [None]:
# Training mode switch: when True the model is trained from scratch even if
# saved copies are present; when False the model is loaded and training resumes
# from its last checkpoint if one already exists.
train_from_zero = True

# Dataset switch: when True the dataset stored in "data" is used, otherwise the
# one stored in "large_data".
use_small_dataset = True

# NOTE(review): only the small-dataset paths are assigned here; with
# use_small_dataset = False neither images_path nor captions_path is defined.
if use_small_dataset:
    images_path, captions_path = (
        "/kaggle/input/flickr8k/Images",
        "/kaggle/input/flickr8k/captions.txt",
    )

# Target size every image is resized to, and spatial shape of the model input.
img_size = (256, 256)

batch_size = 16
# Used when loading the dataset. shuffle_size <= batch_size
shuffle_size = 16

# Frequency (in iterations) at which the model is saved.
# NOTE(review): not referenced anywhere else in the visible notebook.
save_freq = 20
epochs = 130

# Number of pages used to display the original / predicted images.
nb_pages = 5
# Number of images shown on each page.
nb_imgs_displayed = 5

In [None]:
def preprocess_data():
    """
    Build the training, validation and test data from the caption file.

    :return: ``(train_dataset, valid_dataset, test_dataset)``. The first two are
        whatever ``make_dataset`` returns; the third is a ``(noisy, clean)``
        pair of tensors built eagerly from the test image paths.
    """
    # Collect every image path listed in the caption file
    image_paths = retrieve_images(captions_path)

    # Partition the paths (note the train / test / val order of the helper)
    train_data, test_data, valid_data = train_test_val_split(image_paths)
    print("Number of training samples: ", len(train_data))
    print("Number of test samples: ", len(test_data))
    print("Number of validation samples: ", len(valid_data))

    # Streaming pipelines for the two subsets consumed by fit()
    train_dataset = make_dataset(train_data)
    valid_dataset = make_dataset(valid_data)

    # The test subset is materialised in memory as two stacked tensors
    noisy_test = tf.convert_to_tensor([create_noisy_image(path) for path in test_data])
    clean_test = tf.convert_to_tensor([create_clean_image(path) for path in test_data])

    return train_dataset, valid_dataset, (noisy_test, clean_test)

In [None]:
def retrieve_images(filename):
    """
    Read a caption file and return the distinct image paths it references.

    With ``use_small_dataset`` the file is read line by line and the text before
    the first comma is taken as the image name; otherwise the file is parsed as
    a ``|``-separated CSV and the ``image_name`` column is used. Each name is
    stripped, joined to ``images_path`` and kept only when it ends with "jpg"
    (which also discards a header line).

    :param filename: path of the caption file
    :return: sorted list of unique image paths. Sorting matters: iterating a set
        of strings yields a different order from one interpreter run to the
        next, which would make any later split impossible to reproduce.
    """
    if use_small_dataset:
        with open(filename) as caption_file:
            # Image name and captions are separated using a comma
            raw_names = [line.split(",", 1)[0] for line in caption_file]
    else:
        raw_names = pd.read_csv(filename, sep="|")["image_name"]

    images = set()
    for raw_name in raw_names:
        img_name = os.path.join(images_path, raw_name.strip())
        if img_name.endswith("jpg"):
            images.add(img_name)

    print("Nb Images: {}".format(len(images)))
    return sorted(images)

In [None]:
def read_image(img_path, size=img_size):
    """
    Load a JPEG file as a 3-channel float32 tensor resized to ``size``.

    NOTE(review): the result appears to stay on the 0-255 scale (the callers in
    this notebook divide by 255 themselves) - confirm before reusing elsewhere.

    :param img_path: path of the JPEG file
    :param size: (height, width) passed to ``tf.image.resize``
    :return: the resized image tensor
    """
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, size)
    return tf.image.convert_image_dtype(resized, tf.float32)

In [None]:
def create_noise(array):
    """
    Return ``array`` with Gaussian noise added and the result clipped to [0, 255].

    The noise is drawn from a standard normal distribution (operation seed 1),
    has the same shape as ``array`` and is scaled by a fixed factor of 75.
    """
    noise_factor = 75
    gaussian = tf.random.normal(tf.shape(array), mean=0.0, stddev=1.0, seed=1)
    noisy_array = array + noise_factor * gaussian
    return tf.clip_by_value(noisy_array, 0, 255.0)


def create_noisy_image(img_path, size=img_size):
    """Read the image at ``img_path``, add noise to it and divide the result by 255."""
    noisy = create_noise(read_image(img_path, size))
    return noisy / 255

def create_clean_image(img_path, size=img_size):
    """Read the image at ``img_path`` and divide it by 255, without adding noise."""
    clean = read_image(img_path, size)
    return clean / 255


In [None]:
def train_test_val_split(caption_data, test_frac=0.1, val_frac=0.1, shuffle=True):
    """
    Split a sequence of items into train, test and validation lists.

    The test items come first, then the validation items, and everything left
    over is training data. Subset sizes are ``int(len * frac)``, so rounding
    always favours the training subset.

    :param caption_data: sequence of items to split (image paths in this
        notebook); it is never modified
    :param test_frac: fraction of the items assigned to the test subset
    :param val_frac: fraction of the items assigned to the validation subset
    :param shuffle: when True, the items are shuffled with ``np.random.shuffle``
        (global NumPy random state) before being split
    :return: train_data, test_data, val_data (lists)
    """
    # Work on a copy: shuffling is in place and must not reorder the caller's data
    all_images = list(caption_data)

    if shuffle:
        np.random.shuffle(all_images)

    total = len(all_images)
    test_size = int(total * test_frac)
    val_size = int(total * val_frac)
    held_out = test_size + val_size

    test_data = all_images[:test_size]
    val_data = all_images[test_size:held_out]
    train_data = all_images[held_out:]

    return train_data, test_data, val_data

In [None]:
def make_dataset(images):
    """
    Build a batched ``tf.data`` pipeline of (noisy image, clean image) pairs.

    Shuffling is applied to the image paths, before they are decoded and
    batched. Shuffling after ``batch`` would only reorder whole batches (their
    content would never change) and would keep ``shuffle_size`` decoded batches
    in memory. Because the shuffle buffer now only holds path strings,
    ``shuffle_size`` can be raised up to ``len(images)`` at negligible cost to
    get a uniform shuffle.

    :param images: list of image paths
    :return: dataset yielding ``(noisy_batch, clean_batch)`` tuples of
        ``batch_size`` images
    """
    def load_pair(img_path):
        # Both versions come from the same path, so input and target stay aligned
        return create_noisy_image(img_path), create_clean_image(img_path)

    paths = tf.data.Dataset.from_tensor_slices(images).shuffle(shuffle_size)
    pairs = paths.map(load_pair, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    return pairs.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
class Convolutional_block(Layer):
    """Stack of four 3x3, 64-filter, stride-1 'same' convolutions, each followed by a ReLU."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        conv_args = dict(filters=64, kernel_size=(3, 3), strides=1, padding='same')
        self.conv_1 = Conv2D(**conv_args)
        self.conv_2 = Conv2D(**conv_args)
        self.conv_3 = Conv2D(**conv_args)
        self.conv_4 = Conv2D(**conv_args)

    def call(self, X):
        """Apply the four convolution + ReLU stages in sequence and return the last activation."""
        features = X
        for conv in (self.conv_1, self.conv_2, self.conv_3, self.conv_4):
            features = conv(features)
            features = Activation('relu')(features)
        return features
    
class Channel_attention(Layer):
    """
    Re-weights the channels of its input.

    The input is globally average-pooled, passed through a 2-unit ReLU dense
    layer and a ``C``-unit sigmoid dense layer, and the resulting per-channel
    weights are multiplied with the input.

    :param C: number of units of the sigmoid layer; expected to match the number
        of channels of the input
    """

    def __init__(self, C=64, **kwargs):
        super().__init__(**kwargs)
        self.C = C
        self.gap = GlobalAveragePooling2D()
        self.dense_middle = Dense(units=2, activation='relu')
        self.dense_sigmoid = Dense(units=self.C, activation='sigmoid')

    def get_config(self):
        """Return the base layer config extended with ``C``."""
        config = super().get_config().copy()
        config['C'] = self.C
        return config

    def call(self, X):
        """Return ``X`` multiplied by its learned per-channel weights."""
        pooled = self.gap(X)
        squeezed = self.dense_middle(pooled)
        channel_weights = self.dense_sigmoid(squeezed)
        return Multiply()([X, channel_weights])
    
class Avg_pool_Unet_Upsample_msfe(Layer):
    """
    One branch of the multi-scale feature extractor.

    The input is average-pooled, passed through a U-Net style encoder/decoder
    (four max-pooling stages, four transposed-convolution stages with skip
    connections) ending in a 1x1 convolution with 3 filters, and finally
    upsampled with bilinear interpolation.

    :param avg_pool_size: pool size of the initial average pooling
    :param upsample_rate: factor of the final bilinear upsampling
    """

    def __init__(self, avg_pool_size, upsample_rate, **kwargs):
        super().__init__(**kwargs)
        self.avg_pool_size = avg_pool_size
        self.upsample_rate = upsample_rate
        # ---initialization for Avg pooling---
        self.avg_pool = AveragePooling2D(pool_size=avg_pool_size, padding='same')

        # --- initialization for Unet---
        # Filter counts are spelled out as integers (the values the original
        # halving loop produced); a true division would hand floats to Keras.
        self.deconv_lst = [
            Conv2DTranspose(filters=n_filters, kernel_size=[3, 3], strides=2, padding='same')
            for n_filters in (256, 128, 64, 32)
        ]

        # Encoder stages. Sub-layers are created in a fixed order so that
        # automatic layer names and weight ordering stay stable.
        self.conv_32_down_lst = self._conv_pair(64)
        self.conv_64_down_lst = self._conv_pair(128)
        self.conv_128_down_lst = self._conv_pair(256)
        self.conv_256_down_lst = self._conv_pair(512)
        self.conv_512_down_lst = self._conv_pair(1024)

        # Decoder stages
        self.conv_32_up_lst = self._conv_pair(64)
        self.conv_64_up_lst = self._conv_pair(128)
        self.conv_128_up_lst = self._conv_pair(256)
        self.conv_256_up_lst = self._conv_pair(512)

        # 1x1 projection to 3 channels
        self.conv_3 = Conv2D(filters=3, kernel_size=[1, 1])

        self.pooling1_unet = MaxPool2D(pool_size=[2, 2], padding='same')
        self.pooling2_unet = MaxPool2D(pool_size=[2, 2], padding='same')
        self.pooling3_unet = MaxPool2D(pool_size=[2, 2], padding='same')
        self.pooling4_unet = MaxPool2D(pool_size=[2, 2], padding='same')

        # ---initialization for Upsampling---
        self.upsample = UpSampling2D(upsample_rate, interpolation='bilinear')

    @staticmethod
    def _conv_pair(filters):
        """Create two 3x3 ReLU 'same' convolutions with L2 (0.001) kernel regularisation."""
        return [
            Conv2D(filters=filters, kernel_size=[3, 3], activation='relu', padding='same',
                   kernel_regularizer=regularizers.l2(l2=0.001))
            for _ in range(2)
        ]

    @staticmethod
    def _apply_all(layers, tensor):
        """Feed ``tensor`` through ``layers`` one after the other."""
        for layer in layers:
            tensor = layer(tensor)
        return tensor

    def get_config(self):
        """Return the base layer config extended with the two constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'avg_pool_size': self.avg_pool_size,
            'upsample_rate': self.upsample_rate
        })
        return config

    def upsample_and_concat(self, x1, x2, i):
        """Upsample ``x1`` with the ``i``-th transposed convolution and concatenate it with ``x2``."""
        deconv = self.deconv_lst[i](x1)
        return Concatenate()([deconv, x2])

    def unet(self, input):
        """Run the encoder/decoder on ``input`` and return the 3-channel projection."""
        # ---Unet downsampling---
        conv1 = self._apply_all(self.conv_32_down_lst, input)
        pool1 = self.pooling1_unet(conv1)

        conv2 = self._apply_all(self.conv_64_down_lst, pool1)
        pool2 = self.pooling2_unet(conv2)

        conv3 = self._apply_all(self.conv_128_down_lst, pool2)
        pool3 = self.pooling3_unet(conv3)

        conv4 = self._apply_all(self.conv_256_down_lst, pool3)
        pool4 = self.pooling4_unet(conv4)

        conv5 = self._apply_all(self.conv_512_down_lst, pool4)

        # ---Unet upsampling--- (each stage is concatenated with the encoder
        # output of the matching resolution)
        up6 = self.upsample_and_concat(conv5, conv4, 0)
        conv6 = self._apply_all(self.conv_256_up_lst, up6)

        up7 = self.upsample_and_concat(conv6, conv3, 1)
        conv7 = self._apply_all(self.conv_128_up_lst, up7)

        up8 = self.upsample_and_concat(conv7, conv2, 2)
        conv8 = self._apply_all(self.conv_64_up_lst, up8)

        up9 = self.upsample_and_concat(conv8, conv1, 3)
        conv9 = self._apply_all(self.conv_32_up_lst, up9)

        return self.conv_3(conv9)

    def call(self, X):
        """Average-pool ``X``, run the U-Net on it and upsample the result."""
        pooled = self.avg_pool(X)
        unet_out = self.unet(pooled)
        return self.upsample(unet_out)


class Multi_scale_feature_extraction(Layer):
    """
    Concatenates the input with the outputs of three pooled U-Net branches.

    The branches use pooling / upsampling factors 16, 8 and 2. Adding another
    scale means creating its branch in ``__init__`` and listing it in ``call``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.msfe_16 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=16, upsample_rate=16)
        self.msfe_8 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=8, upsample_rate=8)
        self.msfe_2 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=2, upsample_rate=2)

    def call(self, X):
        """Return ``X`` concatenated with the 16x, 8x and 2x branch outputs, in that order."""
        branch_outputs = [branch(X) for branch in (self.msfe_16, self.msfe_8, self.msfe_2)]
        return Concatenate()([X, *branch_outputs])
    
class Kernel_selecting_module(Layer):
    """
    Blends 3x3, 5x5 and 7x7 convolutions of the input with learned weights.

    The three convolution outputs are summed, globally average-pooled and passed
    through a 2-unit ReLU dense layer; three ``C``-unit dense layers then produce
    one score vector per kernel size. A softmax across the three kernel sizes
    turns the scores into per-channel weights used to mix the convolution outputs.

    :param C: number of filters of each convolution (and of output channels)
    """

    def __init__(self, C=21, **kwargs):
        super().__init__(**kwargs)
        self.C = C
        shared_conv_args = dict(filters=self.C, strides=1, padding='same')
        self.c_3 = Conv2D(kernel_size=(3, 3), kernel_regularizer=regularizers.l2(l2=0.001), **shared_conv_args)
        self.c_5 = Conv2D(kernel_size=(5, 5), kernel_regularizer=regularizers.l2(l2=0.001), **shared_conv_args)
        self.c_7 = Conv2D(kernel_size=(7, 7), kernel_regularizer=regularizers.l2(l2=0.001), **shared_conv_args)
        self.gap = GlobalAveragePooling2D()
        self.dense_two = Dense(units=2, activation='relu')
        self.dense_c1 = Dense(units=self.C)
        self.dense_c2 = Dense(units=self.C)
        self.dense_c3 = Dense(units=self.C)

    def get_config(self):
        """Return the base layer config extended with ``C``."""
        config = super().get_config().copy()
        config['C'] = self.C
        return config

    def call(self, X):
        """Return the softmax-weighted sum of the three convolution outputs."""
        branches = [self.c_3(X), self.c_5(X), self.c_7(X)]

        # Global descriptor of the summed branches, reshaped to (batch, 1, 1, C)
        fused = Add()(branches)
        descriptor = tf.reshape(self.gap(fused), [-1, 1, 1, self.C])
        bottleneck = self.dense_two(descriptor)

        # One score tensor per kernel size, stacked along axis 1 so that the
        # softmax makes the three kernels compete for each channel
        scores = concatenate(
            [self.dense_c1(bottleneck), self.dense_c2(bottleneck), self.dense_c3(bottleneck)], 1
        )
        weights = softmax(scores, axis=1)

        selected = []
        for position, branch in enumerate(branches):
            branch_weight = tf.reshape(weights[:, position, :, :], [-1, 1, 1, self.C])
            selected.append(Multiply()([branch, branch_weight]))

        return Add()(selected)

In [None]:
def get_autoencoder():
    """
    Assemble the denoising network.

    Pipeline: convolutional block -> channel attention -> 3-filter convolution,
    concatenated with the input -> multi-scale feature extraction -> kernel
    selecting module -> final 3-filter convolution.

    :return: uncompiled Keras ``Model`` mapping "dirty_image" to "clean_image"
    """
    noisy_input = Input(shape=(*img_size, 3), name="dirty_image")

    features = Convolutional_block()(noisy_input)
    attended = Channel_attention()(features)
    attended_rgb = Conv2D(filters=3, kernel_size=(3, 3), strides=1, padding='same')(attended)
    stacked = Concatenate()([noisy_input, attended_rgb])

    multi_scale = Multi_scale_feature_extraction()(stacked)

    selected = Kernel_selecting_module()(multi_scale)
    clean_output = Conv2D(
        filters=3, kernel_size=(3, 3), strides=1, padding='same', name="clean_image"
    )(selected)

    return Model(inputs=[noisy_input], outputs=[clean_output])

# Build one instance of the network when this cell runs and print its layer
# summary. main() builds (or loads) its own model and does not use this one.
model = get_autoencoder()
model.summary()

In [None]:
def PSNR(y_true, y_pred):
    """Peak signal-to-noise ratio between two image batches, for a maximum pixel value of 1."""
    return tf.image.psnr(y_true, y_pred, max_val=1)

def SSIM(y_true, y_pred):
    """Return 1 minus the mean structural similarity of two image batches (maximum pixel value 1)."""
    mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1))
    return 1 - mean_ssim

In [None]:
def main():
    """
    Run the whole experiment: data preparation, training, inference, evaluation
    and plotting.

    Behaviour depends on the module-level switches:

    * ``use_small_dataset`` selects the names under which models are saved.
    * ``train_from_zero`` False and a final saved model present: that model is
      loaded and nothing else happens.
    * ``train_from_zero`` False and only a checkpoint present: training resumes
      from the checkpoint.
    * otherwise a new model is built, compiled and trained.
    """
    train_dataset, valid_dataset, test_dataset = preprocess_data()

    if use_small_dataset:
        filepath, partial_filepath = "model", "partial_model"
    else:
        filepath, partial_filepath = "model_large", "partial_model_large"

    working_dir = "/kaggle/working/"

    if not train_from_zero and os.path.isdir(working_dir + filepath):
        # NOTE(review): the loaded model is not used for inference or evaluation
        # afterwards - confirm whether that is intended.
        load_model(filepath)
        return

    checkpoint = ModelCheckpoint(partial_filepath, save_best_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=5, min_lr=1e-5, verbose=1)
    early_stopping = EarlyStopping(patience=10, restore_best_weights=True)

    # we check if the model has any checkpoint
    if not train_from_zero and os.path.isdir(working_dir + partial_filepath):
        # NOTE(review): when resuming, the final model is written to the
        # checkpoint path rather than to `filepath` - confirm this is wanted.
        path = partial_filepath
        autoencoder = load_model(partial_filepath)
    else:
        path = filepath
        autoencoder = get_autoencoder()
        # NOTE(review): the output layer ("clean_image") has no activation, so
        # predictions are not constrained to [0, 1] as binary cross-entropy
        # expects - consider a sigmoid output or a regression loss.
        autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
        autoencoder.summary()
        plot_model(
            autoencoder, to_file='model_architecture.png', show_shapes=True, show_dtype=False,
            show_layer_names=True, rankdir='TB', expand_nested=True, dpi=200
        )

    then = time.time()

    history = autoencoder.fit(
        train_dataset,
        epochs=epochs,
        validation_data=valid_dataset,
        callbacks=[early_stopping, checkpoint, reduce_lr]
    )
    # Report the epochs actually run: early stopping may end training before
    # the configured maximum is reached.
    epochs_run = len(history.epoch)
    print(f"Training finished for {epochs_run} epochs after {time.time() - then} seconds")
    autoencoder.save(path)

    dirty_input, clean_output = test_dataset
    then = time.time()
    prediction = autoencoder.predict(dirty_input)
    print(f"Inference finished after {time.time() - then} seconds")

    show_history(history)

    then = time.time()
    evaluation = autoencoder.evaluate(dirty_input, clean_output)
    print(f"Evaluation finished after {time.time() - then} seconds")
    print(f"The result of evaluation is: Loss: {evaluation}")

    display(dirty_input, clean_output, prediction)

In [None]:
def show_history(history):
    """
    Save one "<name>_curve.png" figure per quantity logged during training.

    Every key of ``history.history`` that does not start with "val_" gets its
    own figure; when a matching "val_<name>" series exists it is drawn on the
    same axes. Selecting keys by name (instead of taking the first half of the
    key list) keeps the function working when there is no validation data or
    when callbacks log extra series such as the learning rate.

    :param history: object exposing ``epoch`` (x values) and ``history``
        (dict of series), as returned by ``Model.fit``
    """
    logged = history.history
    indicators = [name for name in logged if not name.startswith("val_")]

    for focus in indicators:
        val_focus = f"val_{focus}"

        # Defining Figure
        f = plt.figure(figsize=(10, 7))
        f.add_subplot()

        # Training curve, plus the validation curve when it was recorded
        plt.plot(history.epoch, logged[focus], label=focus)
        if val_focus in logged:
            plt.plot(history.epoch, logged[val_focus], label=val_focus)

        plt.title(f"{focus} Curve", fontsize=18)
        plt.xlabel("Epochs", fontsize=15)
        plt.ylabel(focus, fontsize=15)
        plt.grid(alpha=0.3)
        plt.legend()
        plt.savefig(f"{focus}_curve.png")
        plt.close()


def display(noisy_img, clean_img, predicted_img):
    """
    Save and show pages comparing original, noisy and predicted images.

    ``nb_pages`` figures are produced, each with three rows (original, dirty,
    predicted, in that order) of ``nb_imgs_displayed`` randomly drawn samples.
    The same sample index is used in all three rows of a column. Each figure is
    written to the "imgs" directory and then shown.

    :param noisy_img: batch of noisy images, values expected in [0, 1]
    :param clean_img: batch of clean images, values expected in [0, 1]
    :param predicted_img: batch of model outputs; values outside [0, 1] are
        clipped before display
    """
    pathlib.Path('imgs').mkdir(parents=True, exist_ok=True)

    nb_available = len(clean_img)
    if nb_available == 0:
        # Nothing to draw; np.random.randint would raise on an empty range
        print("No images to display")
        return

    def to_pixels(batch):
        # Scale to 0-255 integers, clipping so imshow receives a valid range
        scaled = tf.clip_by_value(tf.multiply(batch, 255.0), 0.0, 255.0)
        return tf.cast(scaled, dtype=tf.int32)

    # Row title paired with the batch shown on that row, top to bottom
    rows = [
        ("original\n", to_pixels(clean_img)),
        ("dirty\n", to_pixels(noisy_img)),
        ("predicted\n", to_pixels(predicted_img)),
    ]

    # One row of sample indices per page (drawn with replacement)
    indices = np.random.randint(nb_available, size=(nb_pages, nb_imgs_displayed))

    for page, page_indices in enumerate(indices):
        print("Plot")
        fig = plt.figure(figsize=(10, 6))
        # one sub-figure per row
        subfigs = fig.subfigures(nrows=len(rows), ncols=1)

        for subfig, (title, batch) in zip(subfigs, rows):
            subfig.suptitle(title)
            # squeeze=False keeps a 2-D array of axes even for a single column
            axs = subfig.subplots(nrows=1, ncols=nb_imgs_displayed, squeeze=False)[0]
            for ax, sample in zip(axs, page_indices):
                ax.imshow(batch[sample])
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

        # The page number keeps files distinct when several pages are saved
        # within the same second
        stamp = int(round(datetime.now().timestamp()))
        plt.savefig(f"imgs/{stamp}_{page}.png", dpi=200)
        plt.show()
        plt.close()

In [None]:
# Entry point of the notebook: runs data preparation, training, evaluation and
# plotting. Every cell above must have been executed first.
main()