<a href="https://colab.research.google.com/github/svgladysh/Super-Resolution-GAN-Experiments/blob/master/super_resolution_gan_on_celebfaces_attributes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Super-Resolution using Generative Adversarial Network

### an experiment on "CelebFaces Attributes" dataset


******************************


![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/80.jpeg)

************************



I have also experimented with Super-Resolution GAN on some other datasets and posted my preliminary results in these Notebooks / Kernels:

*****************************

#### - Super-Resolution using GAN on Simpsons dataset

https://www.kaggle.com/svgladysh/super-resolution-gan-on-simpsons-springfield/

*****************************

#### - Super-Resolution using GAN on Labeled Faces in the Wild dataset

https://www.kaggle.com/svgladysh/super-resolution-gan-on-labeled-faces-in-the-wild

******************************

# References:

*************


[1] A Deep Journey into Super-resolution: A Survey

https://arxiv.org/pdf/1904.07523.pdf


******************


[2] Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

https://arxiv.org/pdf/1609.04802.pdf

************


[3] Lectures from Yandex on Super-Resolution GAN:  


https://www.youtube.com/watch?v=tk84ia1K8-E

https://www.youtube.com/watch?v=2t05gq13xy0


**************

[4] Lectures from Moscow Institute of Physics and Technology at Deep Learning School:

https://www.dlschool.org/



***********************

[5] Generative Adversarial Networks Projects by Kailash Ahirwar

https://www.amazon.com/Generative-Adversarial-Networks-Projects-next-generation/dp/1789136679

https://github.com/PacktPublishing/Generative-Adversarial-Networks-Projects

****************************

[6] Generative Adversarial Nets 

https://arxiv.org/pdf/1406.2661.pdf

*******************************************************

[7] Perceptual Losses for Real-Time Style Transfer and Super-Resolution

https://arxiv.org/pdf/1603.08155.pdf

*******************************************************

[8]  "CelebFaces Attributes" dataset

http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

*********************


[9]  Deep Residual Learning for Image Recognition 

https://arxiv.org/pdf/1512.03385.pdf


********************

[10] Very Deep Convolutional Networks for Large-Scale Image Recognition

https://arxiv.org/abs/1409.1556

************************

[11] ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 

https://arxiv.org/pdf/1809.00219.pdf

*******************

[12] The relativistic discriminator: a key element missing from standard GAN

https://arxiv.org/pdf/1807.00734.pdf

********************



## Super-Resolution: Problem Definition


Super-resolution is the process of recovering a high-resolution (HR) image from a low-resolution (LR) image. 

A recovered HR image is then referred to as a super-resolution (SR) image.

Super-resolution is still considered **a challenging research problem in computer vision**.


##### Challenges:

* Super-Resolution = ill-posed inverse problem

Instead of a single unique solution, there exist multiple solutions for the same low-resolution image. To constrain the solution-space, reliable prior information is typically required. 

* Up-scaling factor = increase in complexity

The complexity of the problem increases as the up-scaling factor increases. At higher factors, the recovery of missing scene details becomes even more complex, and consequently it often leads to reproduction of wrong information. 

* Assessment of the quality of output

Assessing the quality of the output is not straightforward: quantitative metrics (e.g. PSNR, SSIM) correlate only loosely with human perception.
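The ill-posedness mentioned in the first challenge can be made concrete with a toy example. The sketch below assumes a very simple degradation model (averaging non-overlapping 2x2 blocks, which is not necessarily how LR images are produced in this Notebook) and shows two different HR patches that collapse to the same LR pixel.

```python
import numpy as np

def downsample_2x(img):
    """Average every non-overlapping 2x2 block (an assumed, simplified degradation)."""
    h, w = img.shape
    return img.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))

hr_a = np.array([[0.0, 1.0],
                 [1.0, 0.0]])   # checkerboard patch
hr_b = np.array([[0.5, 0.5],
                 [0.5, 0.5]])   # flat grey patch

lr_a = downsample_2x(hr_a)
lr_b = downsample_2x(hr_b)

# Both HR patches produce the same LR pixel, so the LR image alone
# cannot tell us which HR patch was the original.
print(lr_a, lr_b)
```

A super-resolution model therefore has to rely on prior knowledge about natural images to pick one plausible HR candidate.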


##### Quantitative metrics

* PSNR = Peak Signal-to-Noise Ratio

PSNR is the ratio between the maximum possible power of a signal and the power of the corrupting noise, usually expressed in decibels.

* SSIM = Structural Similarity Index

SSIM measures the perceived similarity between two images by comparing their luminance, contrast and structure.
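As an illustration of the first metric, here is a minimal NumPy sketch of PSNR computed from the mean squared error. The function name `psnr` and the `max_value` parameter are choices made for this example; 255 assumes 8-bit pixel values.

```python
import numpy as np

def psnr(reference, estimate, max_value=255.0):
    """Peak Signal-to-Noise Ratio in decibels between two images of the same shape."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return float("inf")  # identical images: no noise at all
    return 10.0 * np.log10(max_value ** 2 / mse)

hr = np.full((4, 4), 100.0)   # toy "ground truth"
sr = hr + 5.0                 # toy "reconstruction" with a constant error of 5
value = psnr(hr, sr)
print(value)
```

Higher values mean a smaller pixel-wise error, which does not always agree with how sharp or natural the image looks to a human.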


**********************************************

### Single Image Super-Resolution

* Single Image Super-Resolution (SISR) - generating an up-scaled image from a single source image

* Multiple Image Super-Resolution - generating an up-scaled image from multiple source images


In this Notebook I am focusing only on Single Image Super-Resolution.

## Super-Resolution: Methods Classification

The authors of the survey [1] group super-resolution methods into the following taxonomy, based on their features:



***********************************


![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/SR_Taxonomy.png)



*******************************************

Picture from [1]

In this Notebook / Kernel I will be focusing only on Super-Resolution methods **based on GAN models** (see the bars on the right side of the diagram above).



## GAN 

A Generative Adversarial Network (GAN) [6] is a deep neural network architecture based on a game-theoretic approach, in which two components of the model, a generator and a discriminator, compete with each other.

The Generator tries to fool the Discriminator by creating fake images, whereas the Discriminator tries not to be fooled and learns to detect fake images better. In this way the Generator learns to generate better, more realistic images.



*******************************

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/11.png)


Picture from https://mc.ai/review-gan

***************************



### GAN Applied to the problem of Super-Resolution: 

The Generator creates SR images that the Discriminator cannot tell apart from real HR images. In this manner, HR images with better perceptual quality are generated.



********************************

## Loss Functions


**Perceptual Loss** is a weighted sum of the content loss and adversarial loss

***************************

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/0.png)



****************************

**Content Loss** can be of two types:

         
**Pixel-wise MSE loss** is the mean squared error between each pixel of the real image and the corresponding pixel of the generated image

**************************

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/1.png)

***************************

**VGG loss** is the Euclidean distance between the feature maps of the generated image and the real image


***********************

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/2.png)
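The difference between the two content losses can be sketched in NumPy. Note the assumptions: `toy_features` is a hypothetical stand-in for a VGG feature map (it only computes horizontal and vertical gradients), and both losses use a mean squared difference, which is what the `'mse'` loss in this Notebook computes.

```python
import numpy as np

def pixel_mse(hr, sr):
    """Pixel-wise MSE between the real and the generated image."""
    return float(np.mean((hr - sr) ** 2))

def toy_features(img):
    """Hypothetical feature extractor: horizontal and vertical gradients (NOT VGG)."""
    dx = np.diff(img, axis=1)[:-1, :]
    dy = np.diff(img, axis=0)[:, :-1]
    return np.stack([dx, dy], axis=-1)

def feature_mse(hr, sr, extractor=toy_features):
    """MSE between feature maps of the real and the generated image."""
    return float(np.mean((extractor(hr) - extractor(sr)) ** 2))

rng = np.random.default_rng(0)
hr = rng.random((8, 8))
sr_shifted = hr + 0.1   # same structure, constant brightness offset

print(pixel_mse(hr, sr_shifted))    # penalizes the offset
print(feature_mse(hr, sr_shifted))  # the gradients are unchanged
```

The pixel loss reacts to every intensity difference, while a loss in feature space only reacts to differences that the feature extractor is sensitive to.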

***********************


**Adversarial Loss** is calculated from the probabilities provided by the Discriminator

*************************

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/3.png)

**********************

**Discriminator** is trained to solve maximization:

************************

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/4.png)

***********************

**Generator** is trained to solve minimization:

************************

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/5.png)
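A small numeric sketch may help to read the objectives above. It assumes the standard GAN objective from [6] for the discriminator, the commonly used non-saturating form `-log D(G(LR))` for the generator, and a weight of `1e-3` for the adversarial term (the value used in `loss_weights` later in this Notebook). The function names are chosen for this example only.

```python
import numpy as np

def discriminator_objective(d_real, d_fake, eps=1e-12):
    """Value the discriminator maximizes: E[log D(HR)] + E[log(1 - D(SR))]."""
    return float(np.mean(np.log(d_real + eps)) + np.mean(np.log(1.0 - d_fake + eps)))

def generator_adversarial_loss(d_fake, eps=1e-12):
    """Loss the generator minimizes (non-saturating form): E[-log D(SR)]."""
    return float(np.mean(-np.log(d_fake + eps)))

def perceptual_loss(content_loss, adversarial_loss, adversarial_weight=1e-3):
    """Weighted sum of the content loss and the adversarial loss."""
    return content_loss + adversarial_weight * adversarial_loss

# Discriminator outputs are probabilities that an image is a real HR image.
confident_d = discriminator_objective(np.array([0.9, 0.8]), np.array([0.1, 0.2]))
confused_d = discriminator_objective(np.array([0.5, 0.5]), np.array([0.5, 0.5]))

fooled = generator_adversarial_loss(np.array([0.9, 0.9]))   # D believes the SR images
caught = generator_adversarial_loss(np.array([0.1, 0.1]))   # D rejects the SR images
print(confident_d, confused_d, fooled, caught)
```

A discriminator that separates real from generated images well reaches a higher objective value, and a generator that fools the discriminator reaches a lower adversarial loss.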

****************************


## SRGAN


The authors of SRGAN [2] proposed an adversarial objective function that promotes super-resolved (SR) outputs lying close to the manifold of natural images.

The main highlight of their work is a multi-task loss formulation that consists of three main parts:

* an MSE loss that encodes pixel-wise similarity,
* a perceptual similarity metric, defined as a distance over high-level image representations (e.g., deep network features), and
* an adversarial loss that balances a min-max game between a generator and a discriminator (the standard GAN objective [6]).


*********************************************


#### SRGAN Conceptual architecture 

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/12.jpeg)

HR - High Resolution image

LR - Low Resolution image

SR - Super Resolution image

Generator - estimates, for a given LR image, its corresponding HR image; this estimate is the SR image

Discriminator - is trained to distinguish SR images from real HR images

*****************************

***********************

#### SRGAN Neural network architecture

 

 

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/13.jpeg)



Picture from [2]

*******************************

Letters and numbers in the diagram above indicate the following architectural parameters for each convolutional layer in SRGAN:

* kernel size (k)  
* number of feature maps (n) 
* stride (s) 
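To make the notation concrete, the following sketch parses a label of the form `k9n64s1` and counts the trainable parameters of the corresponding convolution. Both helper names are introduced for this example; the label format is assumed to be `k<kernel>n<feature maps>s<stride>` as listed above.

```python
import re

def parse_layer_label(label):
    """Parse a label such as 'k9n64s1' into kernel size, feature maps and stride."""
    match = re.fullmatch(r"k(\d+)n(\d+)s(\d+)", label)
    if match is None:
        raise ValueError("unexpected layer label: {!r}".format(label))
    k, n, s = (int(group) for group in match.groups())
    return {"kernel_size": k, "filters": n, "strides": s}

def conv2d_param_count(kernel_size, in_channels, filters, use_bias=True):
    """Number of trainable parameters of a standard 2D convolution."""
    weights = kernel_size * kernel_size * in_channels * filters
    return weights + (filters if use_bias else 0)

# A 9x9 convolution with 64 feature maps and stride 1, applied to an RGB image.
first_layer = parse_layer_label("k9n64s1")
first_layer_params = conv2d_param_count(first_layer["kernel_size"],
                                        in_channels=3,
                                        filters=first_layer["filters"])
print(first_layer, first_layer_params)
```

The parsed values map directly onto the `kernel_size`, `filters` and `strides` arguments of the `Conv2D` layers used in the implementation below.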


*****************************

## ESRGAN

Enhanced Super-Resolution Generative Adversarial Networks (ESRGAN) [11] builds upon SRGAN [2]. The main aim of ESRGAN is to improve the overall perceptual quality of Super-Resolution.

The core novelties of ESRGAN in comparison with SRGAN:

* removed Batch-Norm layers, which was shown to increase performance and reduce computational complexity

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/bn.png)



Picture from [11]

******************************

* introduced the Residual-in-Residual Dense Block (RRDB), based on the authors' observation that more layers and connections tend to boost performance

![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/rrdb.png)



Picture from [11]

************************************

* discriminator based on the Relativistic GAN [12],  which tries to predict the probability that a real image is relatively more realistic than a fake one


![](https://github.com/svgladysh/Super-Resolution-GAN-Experiments/raw/master/relaGAN.png)



Picture from [11]
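The relativistic idea can be sketched in a few lines of NumPy. The sketch assumes the relativistic average formulation, where the raw (pre-sigmoid) discriminator output for one image is compared with the mean raw output over the images of the other kind; the function names are chosen for this example.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def relativistic_average(real_logits, fake_logits):
    """Return (D_Ra(real, fake), D_Ra(fake, real)) for batches of raw discriminator outputs."""
    d_real = sigmoid(real_logits - np.mean(fake_logits))  # "real is more realistic than fake"
    d_fake = sigmoid(fake_logits - np.mean(real_logits))  # "fake is more realistic than real"
    return d_real, d_fake

real_logits = np.array([2.0, 1.0])
fake_logits = np.array([-1.0, 0.0])
d_real, d_fake = relativistic_average(real_logits, fake_logits)
print(d_real, d_fake)
```

Unlike the standard discriminator, the output here depends on both real and generated images at the same time, so the generator receives a training signal from both.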

**********************************************



***********************************************

# Implementation: initial steps

The SRGAN implementation is based on the code from [5], with several changes to make it work with the CelebFaces Attributes (CelebA) dataset in the Kaggle environment, some extra features added for experimentation, and hyper-parameters tuned slightly for this dataset.


### Import libraries 

* Initial steps of the model implementation: import the libraries

In [2]:
import tensorflow as tf
from keras import Input
from keras.applications import VGG19, InceptionResNetV2
from keras.callbacks import TensorBoard
from keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam

import glob
import time
import os
import cv2
import base64
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

from imageio import imread
from skimage.transform import resize as imresize
from copy import deepcopy
from tqdm import tqdm
from pprint import pprint
from PIL import Image
from sklearn.model_selection import train_test_split

Using TensorFlow backend.


### Define some hyper-parameters

* Set some hyper-parameters

In [0]:
# number of training iterations; 2 is only a quick smoke test, start with about 100 for a real run and then continue further
epochs = 2

# batch size of 8, because larger batches no longer fit into RAM on Kaggle
batch_size = 8

# shape of the low-resolution (LR) images
low_resolution_shape = (64, 64, 3)

# shape of the high-resolution (HR) images
high_resolution_shape = (256, 256, 3)

# for simplicity, use Adam as the optimizer (learning rate 0.0002, beta_1 0.5)
common_optimizer = Adam(0.0002, 0.5)

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
data_dir = "/content/drive/My Drive/Colab Notebooks/00.NeedToCheck/img_align_celeba/*.*"

#### CelebA dataset contains images (photos) of people provided with attribute labels  [8]




# Generator

* The generator network

First we create a "building block": the residual block.

See the ResNet architecture in paper [9]    https://arxiv.org/pdf/1512.03385.pdf

In [0]:
def residual_block(x):

    filters = [64, 64]
    #filters = [128, 128]
    kernel_size = 3
    strides = 1
    padding = "same"
    momentum = 0.8
    activation = "relu"

    res = Conv2D(filters=filters[0], kernel_size=kernel_size, strides=strides, padding=padding)(x)
    res = Activation(activation=activation)(res)
    res = BatchNormalization(momentum=momentum)(res)

    res = Conv2D(filters=filters[1], kernel_size=kernel_size, strides=strides, padding=padding)(res)
    res = BatchNormalization(momentum=momentum)(res)

    res = Add()([res, x])
    return res

In [0]:
def build_generator():
    
    # use 16 residual blocks in the generator
    residual_blocks = 16
    momentum = 0.8
    
    # the input shape matches LR (Low Resolution) images
    input_shape = (64, 64, 3)
    
    # input layer of the generator network
    input_layer = Input(shape=input_shape)
    
    # pre-residual block: a convolutional layer before the residual blocks
    gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same', activation='relu')(input_layer)
    
    # add the 16 residual blocks
    res = residual_block(gen1)
    for i in range(residual_blocks - 1):
        res = residual_block(res)
    
    # post-residual block: a convolutional layer and a batch-norm layer after the residual blocks
    gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
    gen2 = BatchNormalization(momentum=momentum)(gen2)
    
    # sum the outputs of the pre-residual block (gen1) and the post-residual block (gen2)
    gen3 = Add()([gen2, gen1])
    
    # first upsampling block: 64x64 -> 128x128, followed by a trainable convolution
    gen4 = UpSampling2D(size=2)(gen3)
    gen4 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen4)
    gen4 = Activation('relu')(gen4)
    
    # second upsampling block: 128x128 -> 256x256, followed by a trainable convolution
    gen5 = UpSampling2D(size=2)(gen4)
    gen5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen5)
    gen5 = Activation('relu')(gen5)
    
    # output convolution layer; tanh keeps the pixel values in [-1, 1]
    gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen5)
    output = Activation('tanh')(gen6)
    
    # the generator model
    model = Model(inputs=[input_layer], outputs=[output], name='generator')
    return model
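The two upsampling blocks are what turn a 64x64 input into a 256x256 output. The sketch below imitates only the resizing part in NumPy, assuming nearest-neighbour interpolation (each pixel is repeated `size` times along height and width); the trainable convolutions that follow each upsampling step in the generator are left out.

```python
import numpy as np

def upsample_nearest(x, size=2):
    """Repeat every pixel `size` times along height and width (channels untouched)."""
    return np.repeat(np.repeat(x, size, axis=0), size, axis=1)

tiny = np.array([[1, 2],
                 [3, 4]])
print(upsample_nearest(tiny))

# Two consecutive 2x upsampling steps give the overall 4x scale factor.
lr = np.zeros((64, 64, 3), dtype=np.float32)
sr_shape = upsample_nearest(upsample_nearest(lr)).shape
print(sr_shape)
```

The resizing itself adds no detail; any new detail has to come from the convolutions that follow it.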

# Discriminator

* The discriminator network

In [0]:
def build_discriminator():
    
    # hyper-parameters
    leakyrelu_alpha = 0.2
    momentum = 0.8
    
    # the input shape matches HR (High Resolution) images
    input_shape = (256, 256, 3)
    
    # input layer of the discriminator network
    input_layer = Input(shape=input_shape)
    
    # 8 convolutional blocks; all but the first one use batch normalization
    dis1 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(input_layer)
    dis1 = LeakyReLU(alpha=leakyrelu_alpha)(dis1)

    dis2 = Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(dis1)
    dis2 = LeakyReLU(alpha=leakyrelu_alpha)(dis2)
    dis2 = BatchNormalization(momentum=momentum)(dis2)

    dis3 = Conv2D(filters=128, kernel_size=3, strides=1, padding='same')(dis2)
    dis3 = LeakyReLU(alpha=leakyrelu_alpha)(dis3)
    dis3 = BatchNormalization(momentum=momentum)(dis3)

    dis4 = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(dis3)
    dis4 = LeakyReLU(alpha=leakyrelu_alpha)(dis4)
    dis4 = BatchNormalization(momentum=momentum)(dis4)

    dis5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(dis4)
    dis5 = LeakyReLU(alpha=leakyrelu_alpha)(dis5)
    dis5 = BatchNormalization(momentum=momentum)(dis5)

    dis6 = Conv2D(filters=256, kernel_size=3, strides=2, padding='same')(dis5)
    dis6 = LeakyReLU(alpha=leakyrelu_alpha)(dis6)
    dis6 = BatchNormalization(momentum=momentum)(dis6)

    dis7 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same')(dis6)
    dis7 = LeakyReLU(alpha=leakyrelu_alpha)(dis7)
    dis7 = BatchNormalization(momentum=momentum)(dis7)

    dis8 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same')(dis7)
    dis8 = LeakyReLU(alpha=leakyrelu_alpha)(dis8)
    dis8 = BatchNormalization(momentum=momentum)(dis8)
    
    # fully connected layer, applied to the channel axis at every spatial position
    dis9 = Dense(units=1024)(dis8)
    dis9 = LeakyReLU(alpha=leakyrelu_alpha)(dis9)
    
    # final fully connected layer for classification; the output shape is (16, 16, 1)
    output = Dense(units=1, activation='sigmoid')(dis9)
    
    
    model = Model(inputs=[input_layer], outputs=[output], name='discriminator')
    return model
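Because there is no flattening step before the `Dense` layers, the discriminator output keeps its spatial dimensions. The sketch below traces the spatial size through the eight convolutions, assuming `padding='same'` so that a stride `s` maps a size `n` to `ceil(n / s)`. The helper name is introduced for this example.

```python
import math

def trace_spatial_size(size, strides):
    """Return the spatial size after each 'same'-padded convolution."""
    sizes = [size]
    for s in strides:
        size = math.ceil(size / s)
        sizes.append(size)
    return sizes

# Strides of the eight convolutional layers dis1 ... dis8.
strides = [1, 2, 1, 2, 1, 2, 1, 2]
sizes = trace_spatial_size(256, strides)
print(sizes)

# One real/fake label is needed per spatial position of the output.
batch_size = 8
label_shape = (batch_size, sizes[-1], sizes[-1], 1)
print(label_shape)
```

This is why the label arrays in the training loop have the shape `(batch_size, 16, 16, 1)` rather than `(batch_size, 1)`.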

# Pre-trained VGG19 

Pre-trained VGG19 will be used for feature extraction from real images and generated images


In [0]:
def build_vgg():
    
    # the input shape matches HR (High Resolution) images
    input_shape = (256, 256, 3)
    
    # load VGG19 pre-trained on the ImageNet dataset
    vgg = VGG19(weights="imagenet")
    
    # take the output of the layer with index 9
    # (overriding `outputs` like this relies on the older Keras version used in this Notebook)
    vgg.outputs = [vgg.layers[9].output]
    
    # define the input layer
    input_layer = Input(shape=input_shape)
    
    # extract the features
    features = vgg(input_layer)
    
    # the feature-extraction model
    model = Model(inputs=[input_layer], outputs=[features])
    return model

Let's consider pre-trained InceptionResNetV2, ResNet50, ResNet152V2 as possible alternatives, which could be compared with VGG19


In [10]:
"""
def build_InceptionResNetV2():
    input_shape = (256, 256, 3)
    
    resnetV2 = InceptionResNetV2(weights='imagenet')
    resnetV2.outputs = [resnetV2.layers[-1].output]
    
    input_layer = Input(shape=input_shape)
    
    features = resnetV2(input_layer)
    
    model = Model(inputs=[input_layer], outputs=[features])
    return model
"""

In [11]:
"""
# requires: from keras.applications import ResNet50
def build_ResNet50():
    input_shape = (256, 256, 3)
    
    resnet50 = ResNet50(weights='imagenet', input_shape=input_shape)
    resnet50.outputs = [resnet50.layers[-1].output]
    
    input_layer = Input(shape=input_shape)
    
    features = resnet50(input_layer)
    
    model = Model(inputs=[input_layer], outputs=[features])
    return model
"""

In [12]:
"""
# requires: from keras.applications import ResNet152V2 (if available in the installed Keras version)
def build_ResNet152V2():
    input_shape = (256, 256, 3)
    
    resnet152 = ResNet152V2(weights='imagenet', input_shape=input_shape)
    resnet152.outputs = [resnet152.layers[-1].output]
    
    input_layer = Input(shape=input_shape)
    
    features = resnet152(input_layer)
    
    model = Model(inputs=[input_layer], outputs=[features])
    return model
"""

## Sampling images

* Implement a function for sampling images

In [0]:
def sample_images(data_dir, batch_size, high_resolution_shape, low_resolution_shape):
    
    # list all images that match the data_dir glob pattern
    all_images = glob.glob(data_dir)
    
    # pick a random batch of image paths
    images_batch = np.random.choice(all_images, size=batch_size)

    low_resolution_images = []
    high_resolution_images = []

    for img in images_batch:
        # read the current image as a numpy ndarray
        img1 = imread(img, as_gray=False, pilmode='RGB')
        img1 = img1.astype(np.float32)
        
        # resize to the HR and LR shapes
        img1_high_resolution = imresize(img1, high_resolution_shape)
        img1_low_resolution = imresize(img1, low_resolution_shape)
        
        # augmentation: random horizontal flip, applied to both versions together
        if np.random.random() < 0.5:
            img1_high_resolution = np.fliplr(img1_high_resolution)
            img1_low_resolution = np.fliplr(img1_low_resolution)

        high_resolution_images.append(img1_high_resolution)
        low_resolution_images.append(img1_low_resolution)
    
    # convert the lists to numpy ndarrays
    return np.array(high_resolution_images), np.array(low_resolution_images)
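Two details of the sampling step are easy to get wrong: the flip must be applied to the HR and LR images together, and the pixel values must later be scaled to the range the generator's `tanh` output uses. The sketch below isolates both steps. The helper names and the toy "downsampling by striding" are assumptions made for this example and are not part of the Notebook's pipeline.

```python
import numpy as np

def paired_random_flip(hr, lr, rng):
    """Flip both images left-right with probability 0.5, using one shared coin toss."""
    if rng.random() < 0.5:
        return np.fliplr(hr), np.fliplr(lr)
    return hr, lr

def to_model_range(img):
    """Scale pixel values from [0, 255] to [-1, 1]."""
    return img / 127.5 - 1.0

rng = np.random.default_rng(0)
hr = np.arange(16, dtype=np.float32).reshape(4, 4)
lr = hr[::2, ::2]   # toy 2x "downsampling" by striding, for illustration only

hr_aug, lr_aug = paired_random_flip(hr, lr, rng)
print(hr_aug)
print(lr_aug)
print(to_model_range(np.array([0.0, 127.5, 255.0])))
```

Flipping only one of the two images would teach the generator to mirror its input, which is why a single random draw decides for the whole pair.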

## Saving images

* Implement a function for saving images

In [0]:
def save_images(low_resolution_image, original_image, generated_image, path):

    # save the low-resolution, high-resolution (original) and generated high-resolution images in one figure

    fig = plt.figure()
    
    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(original_image)
    ax.axis("off")
    ax.set_title("ORIGINAL")
    
    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(low_resolution_image)
    ax.axis("off")
    ax.set_title("LOW_RESOLUTION")

    ax = fig.add_subplot(1, 3, 3)
    ax.imshow(generated_image)
    ax.axis("off")
    ax.set_title("GENERATED")

    plt.savefig(path)
    # close the figure so that figures do not accumulate in memory during training
    plt.close(fig)
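One thing to keep in mind when plotting: the images in this Notebook are scaled to [-1, 1] before they reach the models, while `imshow` expects floating-point RGB data in [0, 1] and clips everything outside that range. A small helper such as the one below (hypothetical, not part of the original code) can map the images back before they are passed to `save_images`.

```python
import numpy as np

def denormalize(img):
    """Map an image from [-1, 1] back to [0, 1], clipping small overshoots."""
    img = np.asarray(img, dtype=np.float32)
    return np.clip((img + 1.0) / 2.0, 0.0, 1.0)

sample = np.array([-1.0, 0.0, 1.0, 1.2])
restored = denormalize(sample)
print(restored)
```

For example, `save_images(denormalize(lr), denormalize(hr), denormalize(sr), path)` would plot the full intensity range instead of a clipped one.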


## VGG19 compilation

* Compile the pre-trained VGG19 network

In [15]:
vgg = build_vgg()
vgg.trainable = False
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
vgg.summary()












Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
vgg19 (Model)                multiple                  143667240 
Total params: 143,667,240
Trainable params: 0
Non-trainable params: 143,667,240
_________________________________________________________________


In [16]:
"""
vgg = build_InceptionResNetV2()
vgg.trainable = False
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
"""

In [17]:
"""
resnet50 = build_ResNet50()
resnet50.trainable = False
resnet50.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
"""

In [18]:
"""
resnet152 = build_ResNet152V2()
resnet152.trainable = False
resnet152.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
"""

## Discriminator compilation

* Compile the discriminator network

In [19]:

discriminator = build_discriminator()
discriminator.trainable = True
discriminator.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])





In [20]:
"""
discriminator = build_discriminator()
discriminator.trainable = False
discriminator.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
"""

## Generator build


* Build the generator


In [21]:
generator = build_generator()




## Adversarial model compilation

* Compile the adversarial model, which includes the generator, the discriminator and the pre-trained VGG19 network

In [0]:
def build_adversarial_model(generator, discriminator, vgg):
    
    # input layer for the high-resolution images
    input_high_resolution = Input(shape=high_resolution_shape)

    # input layer for the low-resolution images
    input_low_resolution = Input(shape=low_resolution_shape)

    # generate high-resolution images from the low-resolution images
    generated_high_resolution_images = generator(input_low_resolution)

    # extract feature maps from the generated images
    features = vgg(generated_high_resolution_images)
    
    # Freeze the discriminator inside the adversarial model: while the generator
    # is being trained, the discriminator must not be updated.
    # The discriminator is deliberately NOT recompiled here: it was already compiled
    # while trainable, and recompiling it in the frozen state would also stop
    # discriminator.train_on_batch from updating its weights.
    discriminator.trainable = False

    # the discriminator estimates the probabilities for the generated high-resolution images
    probs = discriminator(generated_high_resolution_images)

    # create and compile the adversarial model
    adversarial_model = Model([input_low_resolution, input_high_resolution], [probs, features])
    adversarial_model.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer=common_optimizer)
    return adversarial_model

In [23]:
adversarial_model = build_adversarial_model(generator, discriminator, vgg)





# Training loop on CelebA dataset

* Run the training loop on the CelebA dataset



We train SRGAN in 2 stages:
* in the first stage we train the discriminator
* in the second stage we train the adversarial network, inside which the generator is trained while the discriminator is frozen
* ... yes, that is the trick!

In [0]:
# keep the loss history across all iterations (not reset on every iteration)
d_history = []
g_history = []

for epoch in range(epochs):
    print("Epoch:{}".format(epoch))
    
    # --- stage 1: train the discriminator ---
    
    # sample a batch of images
    high_resolution_images, low_resolution_images = sample_images(data_dir=data_dir, batch_size=batch_size,
                                                                          low_resolution_shape=low_resolution_shape,
                                                                          high_resolution_shape=high_resolution_shape)
    
    # normalize the images from [0, 255] to [-1, 1]
    high_resolution_images = high_resolution_images / 127.5 - 1.
    low_resolution_images = low_resolution_images / 127.5 - 1.
    
    # generate high-resolution images from the low-resolution images
    generated_high_resolution_images = generator.predict(low_resolution_images)
    
    # create a batch of real and fake labels; the shape matches the discriminator output
    real_labels = np.ones((batch_size, 16, 16, 1))
    fake_labels = np.zeros((batch_size, 16, 16, 1))
    
    # train the discriminator on real and on generated images
    # (train_on_batch returns [loss, accuracy]; np.mean averages these two values)
    d_loss_real = discriminator.train_on_batch(high_resolution_images, real_labels)
    d_loss_real =  np.mean(d_loss_real)
    d_loss_fake = discriminator.train_on_batch(generated_high_resolution_images, fake_labels)
    d_loss_fake =  np.mean(d_loss_fake)
    # total discriminator loss: the arithmetic mean of the values on real and on fake labels
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    d_history.append(d_loss)
    print("D_loss:", d_loss)
    
    
    # --- stage 2: train the generator ---
    
    # sample a batch of images
    high_resolution_images, low_resolution_images = sample_images(data_dir=data_dir, batch_size=batch_size,
                                                                    low_resolution_shape=low_resolution_shape,
                                                                    high_resolution_shape=high_resolution_shape)
    
    # normalize the images from [0, 255] to [-1, 1]
    high_resolution_images = high_resolution_images / 127.5 - 1.
    low_resolution_images = low_resolution_images / 127.5 - 1.
    
    # extract feature maps for the real high-resolution images
    image_features = vgg.predict(high_resolution_images)
    
    # train the generator through the adversarial model (the discriminator is frozen inside it)
    # g_loss is [total loss, adversarial loss, content loss]; g_loss[1] is the adversarial term
    g_loss = adversarial_model.train_on_batch([low_resolution_images, high_resolution_images],
                                             [real_labels, image_features])
    g_history.append( 0.5 * (g_loss[1]) )
    print( "G_loss:", 0.5 * (g_loss[1]) )
    
    # save sample images every 20 epochs
    if epoch % 20 == 0:
        high_resolution_images, low_resolution_images = sample_images(data_dir=data_dir, batch_size=batch_size,
                                                                        low_resolution_shape=low_resolution_shape,
                                                                        high_resolution_shape=high_resolution_shape)
        
        # normalize the images from [0, 255] to [-1, 1]
        high_resolution_images = high_resolution_images / 127.5 - 1.
        low_resolution_images = low_resolution_images / 127.5 - 1.

        generated_images = generator.predict_on_batch(low_resolution_images)

        for index, img in enumerate(generated_images):
            save_images(low_resolution_images[index], high_resolution_images[index], img,
                        path="img_{}_{}".format(epoch, index))


Epoch:0

D_loss: 0.38818633556365967



#### Intermediate validation of the results

Let's look at these image triples with our own eyes:
ORIGINAL -- LOW RESOLUTION -- GENERATED

We validate the result in order to decide whether training should continue for some more epochs.
If the generated images look good, training can be stopped.

In [0]:
# start a fresh loss history for this run (kept across all iterations of the loop)
d_history = []
g_history = []

for epoch in range(epochs):
    print("Epoch:{}".format(epoch))
    
    # --- stage 1: train the discriminator ---
    
    # sample a batch of images
    high_resolution_images, low_resolution_images = sample_images(data_dir=data_dir, batch_size=batch_size,
                                                                          low_resolution_shape=low_resolution_shape,
                                                                          high_resolution_shape=high_resolution_shape)
    
    # normalize the images from [0, 255] to [-1, 1]
    high_resolution_images = high_resolution_images / 127.5 - 1.
    low_resolution_images = low_resolution_images / 127.5 - 1.
    
    # generate high-resolution images from the low-resolution images
    generated_high_resolution_images = generator.predict(low_resolution_images)
    
    # create a batch of real and fake labels; the shape matches the discriminator output
    real_labels = np.ones((batch_size, 16, 16, 1))
    fake_labels = np.zeros((batch_size, 16, 16, 1))
    
    # train the discriminator on real and on generated images
    # (train_on_batch returns [loss, accuracy]; np.mean averages these two values)
    d_loss_real = discriminator.train_on_batch(high_resolution_images, real_labels)
    d_loss_real =  np.mean(d_loss_real)
    d_loss_fake = discriminator.train_on_batch(generated_high_resolution_images, fake_labels)
    d_loss_fake =  np.mean(d_loss_fake)
    # total discriminator loss: the arithmetic mean of the values on real and on fake labels
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    d_history.append(d_loss)
    print("D_loss:", d_loss)
    
    
    # --- stage 2: train the generator ---
    
    # sample a batch of images
    high_resolution_images, low_resolution_images = sample_images(data_dir=data_dir, batch_size=batch_size,
                                                                    low_resolution_shape=low_resolution_shape,
                                                                    high_resolution_shape=high_resolution_shape)
    
    # normalize the images from [0, 255] to [-1, 1]
    high_resolution_images = high_resolution_images / 127.5 - 1.
    low_resolution_images = low_resolution_images / 127.5 - 1.
    
    # extract feature maps for the real high-resolution images
    image_features = vgg.predict(high_resolution_images)
    
    # train the generator through the adversarial model (the discriminator is frozen inside it)
    # g_loss is [total loss, adversarial loss, content loss]; g_loss[1] is the adversarial term
    g_loss = adversarial_model.train_on_batch([low_resolution_images, high_resolution_images],
                                             [real_labels, image_features])
    g_history.append( 0.5 * (g_loss[1]) )
    print( "G_loss:", 0.5 * (g_loss[1]) )
    
    # save sample images every 20 epochs
    if epoch % 20 == 0:
        high_resolution_images, low_resolution_images = sample_images(data_dir=data_dir, batch_size=batch_size,
                                                                        low_resolution_shape=low_resolution_shape,
                                                                        high_resolution_shape=high_resolution_shape)
        
        # normalize the images from [0, 255] to [-1, 1]
        high_resolution_images = high_resolution_images / 127.5 - 1.
        low_resolution_images = low_resolution_images / 127.5 - 1.

        generated_images = generator.predict_on_batch(low_resolution_images)

        for index, img in enumerate(generated_images):
            save_images(low_resolution_images[index], high_resolution_images[index], img,
                        path="/kaggle/working/img_{}_{}".format(epoch, index))

## Save models weights

* Save the model weights

In [0]:
generator.save_weights("/kaggle/working/generator.h5")
discriminator.save_weights("/kaggle/working/discriminator.h5")

# Eval mode

* Prediction (inference) mode

In [0]:
#discriminator = build_discriminator()
#generator = build_generator()

generator.load_weights("/kaggle/working/generator.h5")
discriminator.load_weights("/kaggle/working/discriminator.h5")

high_resolution_images, low_resolution_images = sample_images(data_dir=data_dir, batch_size=10,
                                                                      low_resolution_shape=low_resolution_shape,
                                                                      high_resolution_shape=high_resolution_shape)

high_resolution_images = high_resolution_images / 127.5 - 1.
low_resolution_images = low_resolution_images / 127.5 - 1.

generated_images = generator.predict_on_batch(low_resolution_images)

## Save images

* Save the images

In [0]:
for index, img in enumerate(generated_images):
    save_images(low_resolution_images[index], high_resolution_images[index], img,
                path="/kaggle/working/gen_{}".format(index))