In [1]:
import os
import datetime

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.vgg19 import VGG19

In [2]:
"""
Fonction de chargement des images dans un dossier
"""
def load_dataset(path, image_shape):
    
    list_im = []
    ims = []
    
    #On recupere la liste des fichiers contenu dans le dossier
    for im in os.listdir(path):
        list_im.append(os.path.join(path,im))
    
    #On charge les images, on les resize et on les transforme en numpy array
    for i in list_im:
        ims.append(np.array(Image.open(i).resize(image_shape)))
    
    #On retourne la liste d images de type numpy array
    return np.array(ims)

In [3]:
"""
Fonction de normalisation des donnees : les pixels des images seront compris entre 0 et 1
"""
def normalisation(dataset):
    return (dataset.astype(np.float32) - 127.5) / 127.5 

In [4]:
"""
Fonction de prediction de nouvelles images et affichage des résultats
"""
def prediction_et_resultat_plot(x_test_hr, x_test_lr, generateur, nb_images, it, mode):
    
    #On choisit aléatoirement nb_images parmi la base de test
    indice_images = np.random.randint(0, x_test_hr.shape[0], nb_images)
    lr_images = x_test_lr[indice_images]
    hr_images = x_test_hr[indice_images]
    
    #Selon le mode, on charge les poids puis on prédit
    if mode == "train":
        images_generes = generator.predict(lr_images)
    elif mode == "inference":
        generateur.load_weights("./output2/gen_model_final.h5")
        images_generes = generateur.predict(lr_images)
    
    #Denormalisation des images
    lr_images = 0.5 * lr_images + 0.5
    hr_images = 0.5 * hr_images + 0.5
    images_generes = 0.5 * images_generes + 0.5
    
    #Pour chaque, on cree un plot avec l image basse resolution / l image generee / l image haute resolution
    for i in range(lr_images.shape[0]):
        
        plt.figure(figsize=(20, 40))
        plt.subplot(1,3,1)
        plt.imshow(lr_images[i])
        plt.axis('off')
        plt.title("image basse résolution")
        
        plt.subplot(1,3,2)
        plt.imshow(images_generes[i])
        plt.axis('off')
        plt.title("image générée")
        
        plt.subplot(1,3,3)
        plt.imshow(hr_images[i])
        plt.axis('off')
        plt.title("image haute résolution")
        
        plt.savefig('./output2/result_image_%d.png' % i)
        plt.close()  

    return   

In [5]:
"""
Fonction de chargement du modele VGG19
"""
def creation_vgg(hr_shape):
    vgg = VGG19(include_top = False ,  input_shape = hr_shape , weights="imagenet")
    features = vgg.get_layer(index = 9).output
    model = Model(inputs=[vgg.inputs], outputs=[features])
    return model

In [6]:
"""
Fonction de creation du discriminateur (architecture provenant de l'article de recherche du SRGAN(voir readme))
"""
def creation_discriminateur(hr_shape):

    def discri_block(inp, filters, strides = 1, bn = True):
        db = Conv2D(filters = filters, kernel_size = 3, strides = strides, padding='same')(inp)
        if bn:
            db = BatchNormalization(momentum = 0.8)(db)
        db = LeakyReLU(alpha = 0.2)(db)
        return db

    inp = Input(shape = hr_shape)

    d = discri_block(inp, 64, 1, bn=False)
    d = discri_block(d, 64, 2, True)
    d = discri_block(d, 128, 1, True)
    d = discri_block(d, 128, 2, True)
    d = discri_block(d, 256, 1, True)
    d = discri_block(d, 256, 2, True)
    d = discri_block(d, 512, 1, True)
    d = discri_block(d, 512, 2, True)

    d = Dense(1024)(d)
    d = LeakyReLU(alpha = 0.2)(d)
    d_final = Dense(1, activation = 'sigmoid')(d)

    return Model(inp, d_final)

In [22]:
"""
Fonction de creation du generateur (architecture provenant de l'article de recherche du SRGAN(voir readme))
"""
def creation_generateur(lr_shape):
    
    def residual_block(inp):
    
        model_rb = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(inp)
        model_rb = BatchNormalization(momentum = 0.8)(model_rb)
        model_rb = PReLU(alpha_initializer='zeros')(model_rb)
        model_rb = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(model_rb)
        model_rb = BatchNormalization(momentum = 0.8)(model_rb)
        model_rb = add([inp, model_rb])
    
        return model_rb

    def deconvolution(inp):
    
        model_dc = Conv2D(filters = 256, kernel_size = 3, strides = 1, padding = "same")(inp)
        model_dc = UpSampling2D(size = 2)(model_dc)
        model_dc = LeakyReLU(alpha = 0.2)(model_dc)
    
        return model_dc

    inp = Input(shape = lr_shape)

    model_g = Conv2D(filters = 64, kernel_size=9, strides=1, padding='same')(inp)
    model_g = PReLU(alpha_initializer='zeros')(model_g)
    
    sauv_out = model_g

    for i in range(16):
        model_g = residual_block(model_g)

    model_g = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding='same')(model_g)
    model_g = BatchNormalization(momentum=0.8)(model_g)
    model_g = add([sauv_out, model_g])

    for i in range(2):
        model_g = deconvolution(model_g)

    model_g = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = 'same', activation = 'tanh')(model_g)

    return Model(inp, model_g)

In [8]:
"""
Fonction de creation du SRGAN
"""
def creation_SRGAN(hr_shape, lr_shape):
    
    lr_images = Input(lr_shape)
    hr_images = Input(hr_shape)
    
    generated_hr = generateur(lr_images)
    generated_feature_map = vgg(generated_hr)
    
    #On entraine pas le discriminateur ici (on le fait avant)
    discriminateur.trainable = False
    
    return Model([lr_images, hr_images], [discriminateur(generated_hr), generated_feature_map])

In [9]:
"""
Fonction d entrainement du modele
"""
def train(generateur, discriminateur, srgan, vgg, x_train_hr, x_train_lr, epochs, batch_size):
        
        shape_output_discrinateur = (16, 16, 1)
        start_time = datetime.datetime.now()

        for epoch in range(epochs + 1):

            #On choisit aléatoirement nb_images parmi la base d entrainement
            indice_images = np.random.randint(0, x_test_hr.shape[0], batch_size)
            lr_images = x_test_lr[indice_images]
            hr_images = x_test_hr[indice_images]

            generated_images = generateur.predict(lr_images)
            
            #L output shape du discriminateur est : (batch_size,16,16,1)
            #on associe la classe 1 aux vrais images et 0 aux images generees
            target_1 = np.ones((batch_size,) + shape_output_discrinateur)
            target_0 = np.zeros((batch_size,) + shape_output_discrinateur)
            
            #Entrainement du discriminateur et recuperation des erreurs
            d_loss_vrai_im = discriminateur.train_on_batch(hr_images, target_1)
            d_loss_gen_im = discriminateur.train_on_batch(generated_images, target_0)
            
            #Moyenne des erreurs
            #d_loss = 0.5 * np.add(d_loss_vrai_im, d_loss_gen_im)

            #On choisit aléatoirement nb_images parmi la base d entrainement
            indice_images = np.random.randint(0, x_test_hr.shape[0], batch_size)
            lr_images = x_test_lr[indice_images]
            hr_images = x_test_hr[indice_images]

            target_1 = np.ones((batch_size,) + shape_output_discrinateur)
            
            #On récupères les features maps des images huates résolutions (du modèle VGG19) 
            #pour les comparer aux features maps des images generees
            
            feature_map_hr_images = vgg.predict(hr_images)
            
            #Entrainement du SRGAN avec recuperation de l'erreur
            g_loss = srgan.train_on_batch([lr_images, hr_images], [target_1, feature_map_hr_images])

            #Suivi du temps d'apprentissage
            time = datetime.datetime.now() - start_time
            #Affichage des epochs et du temps
            print("epoch : %d -- time :  %s" % (epoch, time))
                     
            #Affichage des erreurs
            #print("Loss HR , Loss LR, Loss GAN")
            #print(d_loss, g_loss)
            
            #Affichage des images generees et sauvegarde des poids des neurones des differents reseaux toutes les 1k images
            if (epoch % 1000 == 0) and (epoch > 0):
                prediction_et_resultat_plot(x_test_hr, x_test_lr, generator, 2, epoch, "train")
                generateur.save_weights('./output2/gen_model_%d.h5' % epoch)
                #discriminator.save_weights('./output2/dis_model_%d.h5' % epoch)
                #combined.save_weights('./output2/srgan_model_%d.h5' % epoch)

In [10]:
#image haute resolution (hr) de taille 256*256
#image basse resolution (lr) de taille 64*64
#Facteur de 4 entre les images haute et basse résolution
image_shape1 = (256,256)
image_shape2 = (64,64)

#Chargement / Resize / Normalisation des donnees
x_train_hr = normalisation(load_dataset("./div2k/DIV2K_train_HR", image_shape1))
x_train_lr = normalisation(load_dataset("./div2k/DIV2K_train_LR_bicubic/X2", image_shape2))
x_test_hr = normalisation(load_dataset("./div2k/DIV2K_valid_HR", image_shape1))
x_test_lr = normalisation(load_dataset("./div2k/DIV2K_valid_LR_bicubic/X2", image_shape2))

In [11]:
mode = "train"
#mode = "inference"

In [23]:
lr_shape = (64, 64, 3)
hr_shape = (256, 256, 3)

optimizer = Adam(0.0002, 0.5)

#Chargement du VGG et compilation avec la loss MSE
#Ce reseau ne doit pas être entraine !
vgg = creation_vgg(hr_shape)
vgg.trainable = False
vgg.compile(loss = 'mse',optimizer = optimizer,metrics = ['accuracy'])

#Creation du discriminateur et compilation
discriminateur = creation_discriminateur(hr_shape)
discriminateur.compile(loss = 'mse',optimizer = optimizer,metrics=['accuracy'])

#Creation du generateur
generateur = creation_generateur(lr_shape)

#Creation du SRGAN final
srgan = creation_SRGAN(hr_shape,lr_shape)
srgan.compile(loss=['binary_crossentropy','mse'], loss_weights = [1e-3,1], optimizer = optimizer)

if (mode == "train"):
    #Lancement de l'entrainement
    train(generateur, discriminateur, srgan, vgg, x_train_hr, x_train_lr, epochs = 20001, batch_size = 16)
elif (mode == "inference"):
    #Prediction 
    prediction_et_resultat_plot(x_test_hr, x_test_lr, generateur, 2, 3, mode)

epoch : 0 -- time :  0:00:25.414697
epoch : 1 -- time :  0:00:28.523527
epoch : 2 -- time :  0:00:31.654955
epoch : 3 -- time :  0:00:34.826622
epoch : 4 -- time :  0:00:37.967155
epoch : 5 -- time :  0:00:41.139949
epoch : 6 -- time :  0:00:44.268165
epoch : 7 -- time :  0:00:47.411929
epoch : 8 -- time :  0:00:50.546468
epoch : 9 -- time :  0:00:53.631740
epoch : 10 -- time :  0:00:56.775452
epoch : 11 -- time :  0:00:59.896094
epoch : 12 -- time :  0:01:03.089120
epoch : 13 -- time :  0:01:06.226826
epoch : 14 -- time :  0:01:09.377295
epoch : 15 -- time :  0:01:12.565969
epoch : 16 -- time :  0:01:15.734150
epoch : 17 -- time :  0:01:18.863649
epoch : 18 -- time :  0:01:22.006377
epoch : 19 -- time :  0:01:25.158026
epoch : 20 -- time :  0:01:28.305491
epoch : 21 -- time :  0:01:31.434977
epoch : 22 -- time :  0:01:34.617291
epoch : 23 -- time :  0:01:37.793029
epoch : 24 -- time :  0:01:40.968794
epoch : 25 -- time :  0:01:44.149196
epoch : 26 -- time :  0:01:47.302553
epoch : 27 

epoch : 219 -- time :  0:12:04.217902
epoch : 220 -- time :  0:12:07.419369
epoch : 221 -- time :  0:12:10.608438
epoch : 222 -- time :  0:12:13.784084
epoch : 223 -- time :  0:12:17.235940
epoch : 224 -- time :  0:12:20.441316
epoch : 225 -- time :  0:12:23.624076
epoch : 226 -- time :  0:12:26.857921
epoch : 227 -- time :  0:12:30.078505
epoch : 228 -- time :  0:12:33.270205
epoch : 229 -- time :  0:12:36.444464
epoch : 230 -- time :  0:12:39.637776
epoch : 231 -- time :  0:12:42.852479
epoch : 232 -- time :  0:12:46.034611
epoch : 233 -- time :  0:12:49.268474
epoch : 234 -- time :  0:12:53.088554
epoch : 235 -- time :  0:12:56.271553
epoch : 236 -- time :  0:12:59.445715
epoch : 237 -- time :  0:13:02.633590
epoch : 238 -- time :  0:13:05.834299
epoch : 239 -- time :  0:13:09.045706
epoch : 240 -- time :  0:13:12.261124
epoch : 241 -- time :  0:13:15.493520
epoch : 242 -- time :  0:13:18.681317
epoch : 243 -- time :  0:13:21.888740
epoch : 244 -- time :  0:13:25.092873
epoch : 245 

ResourceExhaustedError:  OOM when allocating tensor with shape[16,64,256,256] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node gradient_tape/model_15/model_12/block1_conv2/Conv2D/Conv2DBackpropInput
 (defined at /anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/optimizer_v2/optimizer_v2.py:464)
]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_25818]

Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/model_15/model_12/block1_conv2/Conv2D/Conv2DBackpropInput:
In[0] gradient_tape/model_15/model_12/block1_conv2/Conv2D/ShapeN:	
In[1] model_15/model_12/block1_conv2/Conv2D/ReadVariableOp (defined at /anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/layers/convolutional.py:231)	
In[2] gradient_tape/model_15/model_12/block1_conv2/ReluGrad:

Operation defined at: (most recent call last)
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/runpy.py", line 193, in _run_module_as_main
>>>     return _run_code(code, main_globals, None,
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/runpy.py", line 86, in _run_code
>>>     exec(code, run_globals)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/ipykernel_launcher.py", line 16, in <module>
>>>     app.launch_new_instance()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/traitlets/config/application.py", line 845, in launch_instance
>>>     app.start()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 612, in start
>>>     self.io_loop.start()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 149, in start
>>>     self.asyncio_loop.run_forever()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/asyncio/base_events.py", line 567, in run_forever
>>>     self._run_once()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/asyncio/base_events.py", line 1855, in _run_once
>>>     handle._run()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/asyncio/events.py", line 81, in _run
>>>     self._context.run(self._callback, *self._args)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/ioloop.py", line 690, in <lambda>
>>>     lambda f: self._run_callback(functools.partial(callback, future))
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/ioloop.py", line 743, in _run_callback
>>>     ret = callback()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/gen.py", line 787, in inner
>>>     self.run()
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/gen.py", line 748, in run
>>>     yielded = self.gen.send(value)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 365, in process_one
>>>     yield gen.maybe_future(dispatch(*args))
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/gen.py", line 209, in wrapper
>>>     yielded = next(result)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 268, in dispatch_shell
>>>     yield gen.maybe_future(handler(stream, idents, msg))
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/gen.py", line 209, in wrapper
>>>     yielded = next(result)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 543, in execute_request
>>>     self.do_execute(
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/tornado/gen.py", line 209, in wrapper
>>>     yielded = next(result)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 306, in do_execute
>>>     res = shell.run_cell(code, store_history=store_history, silent=silent)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
>>>     return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2876, in run_cell
>>>     result = self._run_cell(
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2922, in _run_cell
>>>     return runner(coro)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
>>>     coro.send(None)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3145, in run_cell_async
>>>     has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3337, in run_ast_nodes
>>>     if (await self.run_code(code, result,  async_=asy)):
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
>>>     exec(code_obj, self.user_global_ns, self.user_ns)
>>> 
>>>   File "<ipython-input-23-21ac72f3210c>", line 25, in <module>
>>>     train(generateur, discriminateur, srgan, vgg, x_train_hr, x_train_lr, epochs = 20001, batch_size = 16)
>>> 
>>>   File "<ipython-input-9-032021ca0777>", line 43, in train
>>>     g_loss = srgan.train_on_batch([lr_images, hr_images], [target_1, feature_map_hr_images])
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/engine/training.py", line 1900, in train_on_batch
>>>     logs = self.train_function(iterator)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/engine/training.py", line 878, in train_function
>>>     return step_function(self, iterator)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/engine/training.py", line 867, in step_function
>>>     outputs = model.distribute_strategy.run(run_step, args=(data,))
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/engine/training.py", line 860, in run_step
>>>     outputs = model.train_step(data)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/engine/training.py", line 816, in train_step
>>>     self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/optimizer_v2/optimizer_v2.py", line 530, in minimize
>>>     grads_and_vars = self._compute_gradients(
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/optimizer_v2/optimizer_v2.py", line 583, in _compute_gradients
>>>     grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
>>> 
>>>   File "/anaconda/envs/azureml_py38_tensorflow/lib/python3.8/site-packages/keras/optimizer_v2/optimizer_v2.py", line 464, in _get_gradients
>>>     grads = tape.gradient(loss, var_list, grad_loss)
>>> 