In [None]:
import os
import datetime

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.vgg19 import VGG19

In [None]:
"""
Fonction de chargement des images dans un dossier
"""
def load_dataset(path, image_shape):
    """Load every image of a directory into one NumPy array.

    Files are read in sorted filename order. ``os.listdir`` alone returns
    entries in arbitrary order, which would make the result depend on the
    file system; the rest of the notebook indexes the HR and LR arrays with
    the same indices, so both directories must be read in a deterministic,
    matching order.

    Args:
        path: Directory whose entries are all opened as images.
        image_shape: Size tuple forwarded to ``Image.resize`` (Pillow expects
            ``(width, height)``).

    Returns:
        ``np.ndarray`` stacking the resized images, one entry per file. Each
        image is converted to RGB first so that every entry has three
        channels and the stack is a regular array.
    """
    ims = []

    for nom_fichier in sorted(os.listdir(path)):
        # The context manager releases the file handle once the pixels
        # have been copied into the array.
        with Image.open(os.path.join(path, nom_fichier)) as im:
            ims.append(np.array(im.convert("RGB").resize(image_shape)))

    return np.array(ims)

In [None]:
"""
Data normalisation function: image pixel values (0..255) are rescaled to the range [-1, 1]
"""
def normalisation(dataset):
    """Rescale pixel values so that 0 maps to -1 and 255 maps to 1.

    Args:
        dataset: NumPy array of pixel values.

    Returns:
        A ``float32`` array of the same shape, computed as
        ``(dataset - 127.5) / 127.5``.
    """
    demi_plage = 127.5
    pixels = dataset.astype(np.float32)
    return (pixels - demi_plage) / demi_plage

In [None]:
"""
Fonction de prediction de nouvelles images et affichage des résultats
"""
def prediction_et_resultat_plot(x_test_hr, x_test_lr, generateur, nb_images, it, mode):
    """Predict super-resolved images for random samples and save comparison plots.

    For each sampled index a figure with three panels (low resolution,
    generated, high resolution) is written under ``./output2``.

    Args:
        x_test_hr: Array of high-resolution images, indexed along axis 0.
        x_test_lr: Array of low-resolution images, indexed with the same
            indices as ``x_test_hr``.
        generateur: Model exposing ``predict`` (and ``load_weights`` for
            inference mode).
        nb_images: Number of indices drawn with ``np.random.randint`` (the
            same index may be drawn more than once).
        it: Training iteration; used in the file name in "train" mode so
            successive previews do not overwrite each other.
        mode: "train" predicts with the current weights; "inference" first
            loads ``./output2/gen_model_final.h5``.

    Raises:
        ValueError: If ``mode`` is neither "train" nor "inference".
    """
    if mode not in ("train", "inference"):
        raise ValueError('mode must be "train" or "inference", got %r' % (mode,))

    # Draw nb_images random indices from the test set.
    indice_images = np.random.randint(0, x_test_hr.shape[0], nb_images)
    lr_images = x_test_lr[indice_images]
    hr_images = x_test_hr[indice_images]

    # Inference uses the saved final weights; training uses the live model.
    if mode == "inference":
        generateur.load_weights("./output2/gen_model_final.h5")
    images_generes = generateur.predict(lr_images)

    # Undo the normalisation: x -> 0.5 * x + 0.5
    # (presumably inputs are in [-1, 1], giving [0, 1] for imshow).
    lr_images = 0.5 * lr_images + 0.5
    hr_images = 0.5 * hr_images + 0.5
    images_generes = 0.5 * images_generes + 0.5

    # savefig does not create missing directories.
    os.makedirs("./output2", exist_ok=True)

    # One figure per sample: low resolution / generated / high resolution.
    for i in range(lr_images.shape[0]):

        plt.figure(figsize=(20, 40))
        plt.subplot(1, 3, 1)
        plt.imshow(lr_images[i])
        plt.axis('off')
        plt.title("image basse résolution")

        plt.subplot(1, 3, 2)
        plt.imshow(images_generes[i])
        plt.axis('off')
        plt.title("image générée")

        plt.subplot(1, 3, 3)
        plt.imshow(hr_images[i])
        plt.axis('off')
        plt.title("image haute résolution")

        if mode == "train":
            # Include the iteration so each checkpoint keeps its own previews.
            chemin = './output2/result_image_%d_%d.png' % (it, i)
        else:
            chemin = './output2/result_image_%d.png' % i
        plt.savefig(chemin)
        plt.close()

    return

In [None]:
"""
Fonction de chargement du modele VGG19
"""
def creation_vgg(hr_shape):
    """Build a truncated, ImageNet-pretrained VGG19 used as a feature extractor.

    Args:
        hr_shape: Input shape given to ``VGG19`` (no classification head is
            included).

    Returns:
        A ``Model`` mapping the VGG19 input to the output of the layer at
        index 9 (presumably ``block3_conv3`` — confirm with ``vgg.summary()``).
    """
    vgg = VGG19(include_top=False, input_shape=hr_shape, weights="imagenet")
    features = vgg.get_layer(index=9).output
    # ``vgg.inputs`` is already a list; wrapping it in another list builds a
    # nested input structure. Pass the single input tensor instead.
    return Model(inputs=vgg.input, outputs=features)

In [None]:
"""
Fonction de creation du discriminateur (architecture provenant de l'article de recherche du SRGAN(voir readme))
"""
def creation_discriminateur(hr_shape):
    """Build the SRGAN discriminator.

    Eight 3x3 convolution stages (each followed by LeakyReLU, with batch
    normalisation on all but the first) are followed by two ``Dense`` layers.
    No flattening is applied before the ``Dense`` layers, so they act on the
    last axis of the feature map.

    Args:
        hr_shape: Shape of the high-resolution input images.

    Returns:
        An uncompiled ``Model`` from the image input to the sigmoid output.
    """
    # (filters, strides, use batch-norm) for each convolution stage, in order.
    etages = [
        (64, 1, False),
        (64, 2, True),
        (128, 1, True),
        (128, 2, True),
        (256, 1, True),
        (256, 2, True),
        (512, 1, True),
        (512, 2, True),
    ]

    entree = Input(shape=hr_shape)

    x = entree
    for nb_filtres, pas, avec_bn in etages:
        x = Conv2D(filters=nb_filtres, kernel_size=3, strides=pas, padding="same")(x)
        if avec_bn:
            x = BatchNormalization(momentum=0.8)(x)
        x = LeakyReLU(alpha=0.2)(x)

    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    sortie = Dense(1, activation="sigmoid")(x)

    return Model(entree, sortie)

In [None]:
"""
Fonction de creation du generateur (architecture provenant de l'article de recherche du SRGAN(voir readme))
"""
def creation_generateur(lr_shape):
    """Build the SRGAN generator.

    Structure: a 9x9 convolution + PReLU head, 16 residual blocks, a
    convolution + batch-norm stage added back to the head output, two
    upsampling stages (convolution, 2x ``UpSampling2D``, LeakyReLU) and a
    final 9x9 convolution with ``tanh`` activation producing 3 channels.

    Args:
        lr_shape: Shape of the low-resolution input images.

    Returns:
        An uncompiled ``Model`` from the low-resolution input to the
        generated image.
    """
    def conv(nb_filtres, taille, **options):
        # Every convolution of the generator uses stride 1 and "same" padding.
        return Conv2D(filters=nb_filtres, kernel_size=taille, strides=1,
                      padding="same", **options)

    entree = Input(shape=lr_shape)

    # Head: its output is reused by the long skip connection below.
    x = conv(64, 9)(entree)
    x = PReLU(alpha_initializer="zeros")(x)
    tete = x

    # 16 residual blocks: conv-BN-PReLU-conv-BN plus identity shortcut.
    for _ in range(16):
        raccourci = x
        x = conv(64, 3)(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = PReLU(alpha_initializer="zeros")(x)
        x = conv(64, 3)(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = add([raccourci, x])

    # Long skip connection around the whole residual trunk.
    x = conv(64, 3)(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = add([tete, x])

    # Two 2x upsampling stages.
    for _ in range(2):
        x = conv(256, 3)(x)
        x = UpSampling2D(size=2)(x)
        x = LeakyReLU(alpha=0.2)(x)

    sortie = conv(3, 9, activation="tanh")(x)

    return Model(entree, sortie)

In [None]:
"""
Fonction de creation du SRGAN
"""
def creation_SRGAN(hr_shape, lr_shape):
    """Assemble the combined model used to train the generator.

    Relies on the module-level ``generateur``, ``vgg`` and ``discriminateur``
    objects, which must exist before this function is called. As a side
    effect, ``discriminateur.trainable`` is set to ``False``.

    Args:
        hr_shape: Shape of the high-resolution input.
        lr_shape: Shape of the low-resolution input.

    Returns:
        A ``Model`` taking ``[lr_images, hr_images]`` and returning
        ``[discriminator output on the generated image, VGG features of the
        generated image]``.
    """
    entree_lr = Input(lr_shape)
    entree_hr = Input(hr_shape)

    image_sr = generateur(entree_lr)
    features_sr = vgg(image_sr)

    # The discriminator is trained separately, not through this model.
    discriminateur.trainable = False
    validite = discriminateur(image_sr)

    return Model([entree_lr, entree_hr], [validite, features_sr])

In [None]:
"""
Fonction d entrainement du modele
"""
def train(generateur, discriminateur, srgan, vgg, x_train_hr, x_train_lr, epochs, batch_size,
          x_val_hr=None, x_val_lr=None):
    """Alternately train the discriminator and the generator (through ``srgan``).

    Each "epoch" here is a single iteration on one random batch; the loop
    runs ``epochs + 1`` iterations. Every 1000 iterations (except the first)
    preview images and the generator weights are written under ``./output2``,
    and the final generator weights are saved as
    ``./output2/gen_model_final.h5`` when the loop ends (the file loaded by
    the inference mode of ``prediction_et_resultat_plot``).

    Args:
        generateur: Generator model (``predict``, ``save_weights``).
        discriminateur: Compiled discriminator (``train_on_batch``,
            ``output_shape``).
        srgan: Compiled combined model (``train_on_batch``).
        vgg: Feature extractor (``predict``).
        x_train_hr: High-resolution training images, indexed along axis 0.
        x_train_lr: Low-resolution training images; entry ``i`` must
            correspond to ``x_train_hr[i]``.
        epochs: Index of the last iteration.
        batch_size: Number of images drawn (with replacement) per batch.
        x_val_hr: Optional high-resolution images used for the previews.
            When omitted, the module-level ``x_test_hr`` is used.
        x_val_lr: Optional low-resolution images used for the previews.
            When omitted, the module-level ``x_test_lr`` is used.

    Raises:
        ValueError: If the two training arrays do not hold the same number
            of images.
    """
    nb_images = x_train_hr.shape[0]
    if x_train_lr.shape[0] != nb_images:
        raise ValueError(
            "x_train_hr and x_train_lr must contain the same number of images "
            "(%d != %d)" % (nb_images, x_train_lr.shape[0])
        )

    # save_weights / savefig do not create missing directories.
    os.makedirs("./output2", exist_ok=True)

    # The discriminator output is (batch_size,) + patch shape. Real images
    # are labelled 1 and generated images 0; the targets never change, so
    # they are built once.
    shape_output_discrinateur = tuple(discriminateur.output_shape[1:])
    target_1 = np.ones((batch_size,) + shape_output_discrinateur)
    target_0 = np.zeros((batch_size,) + shape_output_discrinateur)

    start_time = datetime.datetime.now()

    for epoch in range(epochs + 1):

        # --- Discriminator step: random batch from the TRAINING set ---
        indice_images = np.random.randint(0, nb_images, batch_size)
        lr_images = x_train_lr[indice_images]
        hr_images = x_train_hr[indice_images]

        generated_images = generateur.predict(lr_images, verbose=0)

        discriminateur.train_on_batch(hr_images, target_1)
        discriminateur.train_on_batch(generated_images, target_0)

        # --- Generator step: a fresh random batch from the training set ---
        indice_images = np.random.randint(0, nb_images, batch_size)
        lr_images = x_train_lr[indice_images]
        hr_images = x_train_hr[indice_images]

        # VGG features of the real HR images are the regression target for
        # the features of the generated images.
        feature_map_hr_images = vgg.predict(hr_images, verbose=0)

        srgan.train_on_batch([lr_images, hr_images], [target_1, feature_map_hr_images])

        # Progress report: iteration index and elapsed time.
        elapsed = datetime.datetime.now() - start_time
        print("epoch : %d -- time :  %s" % (epoch, elapsed))

        # Every 1000 iterations: save preview images and generator weights.
        if (epoch % 1000 == 0) and (epoch > 0):
            # Resolved lazily so the notebook globals are only needed when a
            # preview is actually produced and no explicit set was given.
            preview_hr = x_test_hr if x_val_hr is None else x_val_hr
            preview_lr = x_test_lr if x_val_lr is None else x_val_lr
            prediction_et_resultat_plot(preview_hr, preview_lr, generateur, 2, epoch, "train")
            generateur.save_weights('./output2/gen_model_%d.h5' % epoch)

    # Weights file expected by the inference mode.
    generateur.save_weights("./output2/gen_model_final.h5")

In [None]:
#High-resolution (hr) images are resized to 256*256
#Low-resolution (lr) images are resized to 64*64
#Factor of 4 between the high- and low-resolution images
image_shape1 = (256,256)
image_shape2 = (64,64)

#Load / resize / normalise the data
# NOTE(review): HR and LR sets are loaded by separate load_dataset calls and
# later indexed with the same indices, so both calls must return the files in
# matching order — confirm the directory listings line up.
x_train_hr = normalisation(load_dataset("./div2k/DIV2K_train_HR", image_shape1))
x_train_lr = normalisation(load_dataset("./div2k/DIV2K_train_LR_bicubic/X2", image_shape2))
x_test_hr = normalisation(load_dataset("./div2k/DIV2K_valid_HR", image_shape1))
x_test_lr = normalisation(load_dataset("./div2k/DIV2K_valid_LR_bicubic/X2", image_shape2))

In [None]:
# Execution mode read by the final cell: "train" runs the training loop,
# "inference" loads the saved generator weights and plots predictions.
mode = "train"
#mode = "inference"

In [None]:
lr_shape = (64, 64, 3)
hr_shape = (256, 256, 3)

# Adam hyper-parameters (learning rate, beta_1) used for every network.
# Each compiled model gets its OWN optimizer instance: an optimizer keeps
# per-variable state and an iteration counter, so one instance should not be
# shared between models that are trained separately.
learning_rate = 0.0002
beta_1 = 0.5
optimizer = Adam(learning_rate, beta_1)

# Load the VGG feature extractor and compile it with the MSE loss.
# This network must not be trained!
vgg = creation_vgg(hr_shape)
vgg.trainable = False
vgg.compile(loss = 'mse', optimizer = Adam(learning_rate, beta_1), metrics = ['accuracy'])

# Build and compile the discriminator.
discriminateur = creation_discriminateur(hr_shape)
discriminateur.compile(loss = 'mse', optimizer = Adam(learning_rate, beta_1), metrics=['accuracy'])

# Build the generator (trained only through the combined model below).
generateur = creation_generateur(lr_shape)

# Build the combined SRGAN model: adversarial loss (weight 1e-3) plus
# VGG feature loss (weight 1).
srgan = creation_SRGAN(hr_shape,lr_shape)
srgan.compile(loss=['binary_crossentropy','mse'], loss_weights = [1e-3,1], optimizer = optimizer)

if mode == "train":
    # Start training.
    train(generateur, discriminateur, srgan, vgg, x_train_hr, x_train_lr, epochs = 20001, batch_size = 16)
elif mode == "inference":
    # Predict with the saved weights and plot the results.
    prediction_et_resultat_plot(x_test_hr, x_test_lr, generateur, 2, 3, mode)
else:
    # A typo in `mode` previously made this cell silently do nothing.
    raise ValueError('mode must be "train" or "inference", got %r' % (mode,))