# SRGAN

Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
[2017 CVPR] [SRGAN & SRResNet]: 
https://arxiv.org/pdf/1609.04802.pdf

**Abstract**


Despite the breakthroughs in accuracy and speed of single image super-resolution using faster and deeper convolutional neural networks, one central problem remains largely unsolved: how do we recover the finer texture details when we super-resolve at large upscaling factors? The behavior of optimization-based super-resolution methods is principally driven by the choice of the objective function. Recent work has largely focused on minimizing the mean squared reconstruction error. The resulting estimates have high peak signal-to-noise ratios, but they are often lacking high-frequency details and are perceptually unsatisfying in the sense that they fail to match the fidelity expected at the higher resolution. In this paper, we present SRGAN, a generative adversarial network (GAN) for image super-resolution (SR). To our knowledge, it is the first framework capable of inferring photo-realistic natural images for 4x upscaling factors. To achieve this, we propose a perceptual loss function which consists of an adversarial loss and a content loss. The adversarial loss pushes our solution to the natural image manifold using a discriminator network that is trained to differentiate between the super-resolved images and original photo-realistic images. In addition, we use a content loss motivated by perceptual similarity instead of similarity in pixel space. Our deep residual network is able to recover photo-realistic textures from heavily downsampled images on public benchmarks. An extensive mean-opinion-score (MOS) test shows hugely significant gains in perceptual quality using SRGAN. The MOS scores obtained with SRGAN are closer to those of the original high-resolution images than to those obtained with any state-of-the-art method.
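
As a reading aid, the perceptual loss mentioned in the abstract can be written as a weighted sum of a content loss and an adversarial loss. The notation follows the paper ($G_{\theta_G}$ is the generator, $D_{\theta_D}$ the discriminator, $I^{LR}$ a low-resolution input, $N$ the number of training samples); whether the `SrganTrainer` used later in this notebook applies exactly this weighting is an assumption.

$$
l^{SR} = l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen},
\qquad
l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\!\left(G_{\theta_G}\!\left(I^{LR}\right)\right)
$$

Here $l^{SR}_{X}$ is the content loss (a pixel-wise MSE or a distance between VGG feature maps) and $l^{SR}_{Gen}$ is the adversarial term that rewards super-resolved images the discriminator takes for real ones.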

**Libraries**

In [1]:
import glob
import shutil
import random
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from data import STFDOGS20580
from model.srgan import generator, discriminator
from train import SrganTrainer, SrganGeneratorTrainer

**GPU settings**

In [2]:
gpus = tf.config.experimental.list_physical_devices('GPU')

for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

print("Number of GPUs available: ", len(gpus))

Number of GPUs available:  1


**Miscellaneous parameters**

In [3]:
# Directory where the model weights are saved
WEIGHTS_DIR = 'weights/srgan'
N_IMAGES = 100

**Helper functions**

In [4]:
def weights_file(filename):
    return os.path.join(WEIGHTS_DIR, filename)

**Dataset download / x4 interpolation / data split**

By default, the STFDOGS20580 images are stored in the `.stfdogs20580` folder at the project root. The dataset is downloaded and prepared automatically.

STFDOGS20580 class

In [5]:
stfdogs20580_train = STFDOGS20580(subset='train', n_images=N_IMAGES)
stfdogs20580_valid = STFDOGS20580(subset='valid', n_images=N_IMAGES)

In [6]:
stfdogs20580_train.remove_data(images_archive=False, images_preprocessed=True, cache=True)

In [7]:
train_ds = stfdogs20580_train.dataset(batch_size=16, random_transform=True)
valid_ds = stfdogs20580_valid.dataset(batch_size=16, random_transform=True, repeat_count=1)


Moving the data into the HR folder, shuffling and formatting

Downsampling the data into the LR folder

Splitting the HR data (train, valid): 

Splitting the LR data (train, valid): 

Creating cache: .stfdogs20580/caches\STFDOGS20580_train_LR_bicubic.cache ...

Creating cache: .stfdogs20580/caches\STFDOGS20580_train_HR.cache ...

Creating cache: .stfdogs20580/caches\STFDOGS20580_valid_LR_bicubic.cache ...

Creating cache: .stfdogs20580/caches\STFDOGS20580_valid_HR.cache ...
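
What `random_transform=True` does inside `STFDOGS20580.dataset` is not shown in this notebook. A common choice for super-resolution training pairs is an aligned random crop followed by a random flip, and the sketch below illustrates that idea with NumPy only. The helper names, the 96-pixel HR crop size and the flip probability are assumptions made for illustration, not the project's API.

```python
import numpy as np

def random_paired_crop(lr_img, hr_img, hr_crop_size=96, scale=4, rng=None):
    """Crop an LR patch and the HR patch covering the same image region."""
    if rng is None:
        rng = np.random.default_rng()
    lr_crop_size = hr_crop_size // scale
    lr_h, lr_w = lr_img.shape[:2]

    # Pick the top-left corner in LR coordinates (upper bound is exclusive).
    lr_top = int(rng.integers(0, lr_h - lr_crop_size + 1))
    lr_left = int(rng.integers(0, lr_w - lr_crop_size + 1))

    # The matching HR corner is the LR corner multiplied by the scale factor.
    hr_top, hr_left = lr_top * scale, lr_left * scale

    lr_patch = lr_img[lr_top:lr_top + lr_crop_size, lr_left:lr_left + lr_crop_size]
    hr_patch = hr_img[hr_top:hr_top + hr_crop_size, hr_left:hr_left + hr_crop_size]
    return lr_patch, hr_patch

def random_flip(lr_patch, hr_patch, rng=None):
    """Flip both patches left-right with probability 0.5 so they stay aligned."""
    if rng is None:
        rng = np.random.default_rng()
    if rng.random() < 0.5:
        return lr_patch[:, ::-1], hr_patch[:, ::-1]
    return lr_patch, hr_patch
```

The important point is that the same geometric transform is applied to both images of a pair; otherwise the generator would be trained against a target that no longer matches its input.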


**Modeling**

**Pre-trained models**

We first pre-train the generator on its own with an MSE loss, so that adversarial training starts from an already initialized generator; the adversarial phase then improves the perceptual quality of its output.
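
To make the MSE objective concrete, here is a minimal NumPy sketch of the pixel-wise MSE and of the PSNR derived from it. It assumes 8-bit images (maximum value 255); the function names are illustrative, and the metric actually reported by `SrganGeneratorTrainer` during evaluation is not shown in this notebook.

```python
import numpy as np

def mse(hr, sr):
    """Mean squared error between two images of the same shape."""
    # Cast to float first so that uint8 subtraction cannot wrap around.
    hr = np.asarray(hr, dtype=np.float64)
    sr = np.asarray(sr, dtype=np.float64)
    return float(np.mean((hr - sr) ** 2))

def psnr(hr, sr, max_val=255.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val^2 / MSE)."""
    err = mse(hr, sr)
    if err == 0.0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / err))
```

Because PSNR is a decreasing function of MSE, a generator trained only on MSE maximizes PSNR, which is the setting the abstract describes as producing high scores but overly smooth textures.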

In [8]:
# Create the directory for saved weights
os.makedirs(WEIGHTS_DIR, exist_ok=True)

**Generator pre-training**

In [None]:
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')
pre_trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=20000, 
                  evaluate_every=100, 
                  save_best_only=False)

pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

**Generator fine-tuning (GAN)**

In [None]:
gan_generator = generator()
gan_generator.load_weights(weights_file('pre_generator.h5'))

gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())
gan_trainer.train(train_ds,
                  steps=5000)

In [None]:
gan_trainer.generator.save_weights(weights_file('gan_generator.h5'))
gan_trainer.discriminator.save_weights(weights_file('gan_discriminator.h5'))
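
The internals of `SrganTrainer` are not shown here. The sketch below illustrates, with NumPy only, the two objectives that are usually alternated in this kind of GAN fine-tuning: a binary cross-entropy loss for the discriminator, and a content loss plus a weighted adversarial term for the generator. The function names and the `1e-3` default weight (the value used in the SRGAN paper) are assumptions, and the discriminator outputs are taken to be probabilities in (0, 1).

```python
import numpy as np

def binary_cross_entropy(targets, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 targets and predicted probabilities."""
    targets = np.asarray(targets, dtype=np.float64)
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0 - eps)
    return float(np.mean(-(targets * np.log(probs)
                           + (1.0 - targets) * np.log(1.0 - probs))))

def discriminator_loss(d_real, d_fake):
    """d_real = D(HR), d_fake = D(G(LR)): real images should score 1, fakes 0."""
    real_loss = binary_cross_entropy(np.ones_like(d_real, dtype=np.float64), d_real)
    fake_loss = binary_cross_entropy(np.zeros_like(d_fake, dtype=np.float64), d_fake)
    return real_loss + fake_loss

def generator_loss(content_loss, d_fake, adv_weight=1e-3):
    """Content loss plus a small adversarial term, -log D(G(LR)) averaged over the batch."""
    adversarial = binary_cross_entropy(np.ones_like(d_fake, dtype=np.float64), d_fake)
    return content_loss + adv_weight * adversarial
```

The small adversarial weight keeps the content loss dominant, so the generator stays faithful to the input image while being nudged toward outputs the discriminator accepts as real.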

**Evaluation**

In [None]:
pre_generator = generator()
gan_generator = generator()

pre_generator.load_weights(weights_file('pre_generator.h5'))
gan_generator.load_weights(weights_file('gan_generator.h5'))

In [None]:
from model import resolve_single
from utils import load_image

def resolve_and_plot(lr_image_path):
    lr = load_image(lr_image_path)
    
    pre_sr = resolve_single(pre_generator, lr)
    gan_sr = resolve_single(gan_generator, lr)
    
    plt.figure(figsize=(20, 20))
    
    images = [lr, pre_sr, gan_sr]
    titles = ['LR', 'SR (PRE)', 'SR (GAN)']
    positions = [1, 3, 4]
    
    for img, title, pos in zip(images, titles, positions):
        plt.subplot(2, 2, pos)
        plt.imshow(img)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])

In [None]:
resolve_and_plot('demo/dog-crop.jpg')
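
`resolve_single` is imported from the project's `model` package and its body is not shown. A helper of this kind typically adds a batch axis, runs the generator, and converts the result back to an 8-bit image. The NumPy sketch below illustrates that sequence with a toy nearest-neighbour upscaler standing in for the generator; both function names are hypothetical, and the real helper may differ.

```python
import numpy as np

def resolve_single_sketch(model, lr):
    """Super-resolve one H x W x C image with a callable that works on batches."""
    # Add a leading batch axis and switch to float for the model.
    lr_batch = np.expand_dims(lr, axis=0).astype(np.float32)
    sr_batch = np.asarray(model(lr_batch))
    # Bring the output back to a valid 8-bit image range.
    sr_batch = np.clip(sr_batch, 0, 255)
    sr_batch = np.round(sr_batch).astype(np.uint8)
    return sr_batch[0]  # drop the batch axis

def nearest_x4(batch):
    """Toy stand-in for a x4 generator: repeat every pixel 4 times per axis."""
    return batch.repeat(4, axis=1).repeat(4, axis=2)

lr = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
sr = resolve_single_sketch(nearest_x4, lr)
print(lr.shape, "->", sr.shape)
```

Clipping and rounding before the `uint8` cast matter because a generator can output values slightly outside [0, 255], which would otherwise wrap around and show up as speckle in `plt.imshow`.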