<a href="https://colab.research.google.com/github/alexandrufalk/tensorflow/blob/Master/SRGAN_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import VGG19
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image


# Define the Generator
- The generator uses residual blocks and upsampling layers to convert low-resolution images to high-resolution images.

- A residual block helps in training deeper networks by allowing gradients to flow through skip connections.

In [4]:
def residual_block(input_tensor, filters=64, kernel_size=3):
    """Apply one residual block to ``input_tensor``.

    The main path is Conv2D -> BatchNormalization -> PReLU -> Conv2D ->
    BatchNormalization; its result is summed element-wise with the block
    input (the skip connection).

    Args:
        input_tensor: Keras tensor entering the block. Because it is added
            back to the main-path output, its channel count is expected to
            equal ``filters``.
        filters: number of output channels of both convolutions.
        kernel_size: spatial size of both convolution kernels.
    Returns:
        Keras tensor: main-path output plus ``input_tensor``.
    """
    # Layers are created in the order they are applied.
    main_path = [
        layers.Conv2D(filters, kernel_size, padding='same'),
        layers.BatchNormalization(),
        # shared_axes=[1, 2] shares the PReLU slope across height and width.
        layers.PReLU(shared_axes=[1, 2]),
        layers.Conv2D(filters, kernel_size, padding='same'),
        layers.BatchNormalization(),
    ]

    out = input_tensor
    for layer in main_path:
        out = layer(out)

    # Skip connection: add the untouched block input back in.
    return layers.Add()([out, input_tensor])

# Upsampling Block
Uses pixel shuffle (sub-pixel convolution, implemented here with `tf.nn.depth_to_space`) to upscale the image.


In [5]:
def upsample_block(input_tensor, filters=256, kernel_size=3, scale=2):
    """Upscale a feature map with a sub-pixel convolution step.

    A convolution first produces ``filters`` channels, then
    ``tf.nn.depth_to_space`` with block size ``scale`` moves channel data
    into the spatial dimensions, and a PReLU is applied to the result.

    Args:
        input_tensor: Keras tensor to upscale.
        filters: channels produced by the convolution before the
            rearrangement (presumably must be divisible by ``scale ** 2``
            for ``depth_to_space`` to succeed).
        kernel_size: spatial size of the convolution kernel.
        scale: block size passed to ``tf.nn.depth_to_space``.
    Returns:
        Keras tensor after convolution, depth-to-space and PReLU.
    """
    expanded = layers.Conv2D(filters, kernel_size, padding='same')(input_tensor)

    # depth_to_space is a raw TF op, so it is wrapped in a Lambda layer.
    shuffled = layers.Lambda(lambda t: tf.nn.depth_to_space(t, scale))(expanded)

    # shared_axes=[1, 2] shares the PReLU slope across height and width.
    return layers.PReLU(shared_axes=[1, 2])(shuffled)

# Generator Model

In [6]:
def build_generator(hr_shape):
    """Build the SRGAN generator (low-resolution image -> 4x upscaled image).

    Structure: 9x9 Conv + PReLU stem, 16 residual blocks, a 3x3 Conv +
    BatchNorm whose output is added to the stem output (long skip
    connection), two 2x upsampling blocks, and a 9x9 Conv with ``tanh``
    producing 3 output channels.

    Args:
        hr_shape: tuple, high-resolution image shape, e.g., (None, None, 3).
            Currently unused: the model input is always declared as
            ``(None, None, 3)``. The parameter is kept so existing callers
            keep working.
    Returns:
        Keras Model named 'Generator'.
    """
    inputs = layers.Input(shape=(None, None, 3))

    # Initial Conv layer
    x = layers.Conv2D(64, 9, padding='same')(inputs)
    x = layers.PReLU(shared_axes=[1, 2])(x)

    # Keep the 64-channel stem output for the long skip connection below.
    stem = x

    # 16 Residual blocks
    residual = stem
    for _ in range(16):
        residual = residual_block(residual)

    # Conv layer after residual blocks
    x = layers.Conv2D(64, 3, padding='same')(residual)
    x = layers.BatchNormalization()(x)
    # Add the stem features, not the raw 3-channel `inputs`: x has 64
    # channels here, so summing it with `inputs` mismatches channel counts.
    x = layers.Add()([x, stem])

    # Upsampling blocks (each uses the default scale of 2, so 4x overall)
    x = upsample_block(x)
    x = upsample_block(x)

    # Output layer; tanh keeps pixel values in [-1, 1]
    outputs = layers.Conv2D(3, 9, padding='same', activation='tanh')(x)

    model = models.Model(inputs, outputs, name='Generator')
    return model

# Define the Discriminator
The discriminator is a CNN that classifies images as real (ground-truth high-resolution) or fake (generated).

In [7]:
def build_discriminator(hr_shape):
    """Build the discriminator CNN that scores an image as real or fake.

    Structure: a 3x3 Conv + LeakyReLU stem with 64 filters, then 15
    Conv -> BatchNorm -> LeakyReLU blocks. Every even-numbered block uses
    stride 2 and doubles the filter count, so there are seven stride-2
    blocks and the width grows from 64 to 8192. The result is flattened and
    passed through Dense(1024) + LeakyReLU and a single sigmoid unit.

    Args:
        hr_shape: tuple, shape of the images to classify (height, width,
            channels). NOTE(review): because the features are flattened and
            fed to a Dense layer, height and width presumably need to be
            concrete integers rather than None - confirm against callers.
    Returns:
        Keras Model named 'Discriminator' with one sigmoid output per image.
    """
    inputs = layers.Input(shape=hr_shape)

    # Stem: one convolution at full resolution, without batch normalisation.
    features = layers.Conv2D(64, 3, strides=1, padding='same')(inputs)
    features = layers.LeakyReLU(alpha=0.2)(features)

    # 15 Convolutional blocks
    width = 64
    for block_index in range(1, 16):
        downsample = block_index % 2 == 0
        if downsample:
            # Halve the spatial size and double the channel count together.
            width *= 2
        stride = 2 if downsample else 1

        features = layers.Conv2D(width, 3, strides=stride, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU(alpha=0.2)(features)

    # Classification head
    features = layers.Flatten()(features)
    features = layers.Dense(1024)(features)
    features = layers.LeakyReLU(alpha=0.2)(features)
    outputs = layers.Dense(1, activation='sigmoid')(features)

    return models.Model(inputs, outputs, name='Discriminator')

# Define the SRGAN Model

The combined model chains the generator with the discriminator and the VGG feature extractor. The discriminator is frozen inside this model and is only used to compute the adversarial loss for the generator.

In [8]:
def build_srgan(generator, discriminator, vgg):
    """Wire generator, discriminator and VGG into one combined model.

    The generator output is fed to both ``vgg`` and ``discriminator``; the
    combined model maps the generator input to the pair
    ``[validity, vgg_features]``.

    Side effect: sets ``discriminator.trainable = False`` on the model that
    was passed in (intended so that training the combined model updates
    only the generator - confirm how the discriminator is compiled and
    trained elsewhere).

    Args:
        generator: Keras Model, generator model
        discriminator: Keras Model, discriminator model
        vgg: Keras Model, pre-trained VGG model for perceptual loss
    Returns:
        Keras Model with outputs ``[validity, vgg_features]``
    """
    # Freeze the discriminator before it is embedded in the combined graph.
    discriminator.trainable = False

    super_resolved = generator.output
    feature_maps = vgg(super_resolved)
    realness = discriminator(super_resolved)

    return models.Model(generator.input, [realness, feature_maps])


# Define Loss Functions
Content loss: the mean squared error between the generated image and the ground-truth high-resolution image.

In [9]:
def content_loss(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred``.

    The squared difference is averaged over every element, so the result
    is a single scalar tensor.

    Args:
        y_true: reference tensor (the ground-truth image batch).
        y_pred: predicted tensor, broadcast-compatible with ``y_true``.
    Returns:
        Scalar tensor holding the mean of the squared element-wise error.
    """
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_error)

In [11]:
# Adversarial loss: binary cross-entropy, used to train the generator to fool
# the discriminator.
# The loss object is built with its default arguments. The discriminator in
# this file ends in a sigmoid, so it outputs probabilities
# (assumes the default is from_logits=False - TODO confirm for the installed Keras).
adversarial_loss = tf.keras.losses.BinaryCrossentropy()

# Perceptual Loss
Compares feature maps extracted by a pre-trained VGG19 network from the generated image and from the ground-truth image, instead of comparing raw pixels.

In [12]:
# Pre-trained VGG19 feature extractor used by the perceptual loss
def build_vgg():
    """Return a frozen VGG19 truncated at the 'block5_conv4' layer.

    VGG19 is loaded with ImageNet weights and without its classification
    head, accepting 3-channel images of any spatial size. The returned
    model maps an image batch to the output of 'block5_conv4'.

    Returns:
        Keras Model from the VGG19 input to the 'block5_conv4' output.
    """
    backbone = VGG19(weights='imagenet', include_top=False, input_shape=(None, None, 3))
    # The network is only a fixed feature extractor; its weights must not train.
    backbone.trainable = False

    # Select the output of 'block5_conv4' for perceptual loss
    feature_map = backbone.get_layer('block5_conv4').output
    return models.Model(backbone.input, feature_map)

# Shared feature extractor, built once at module level. Running this line
# loads the ImageNet VGG19 weights (downloaded on first use).
vgg = build_vgg()


def perceptual_loss(y_true, y_pred):
    """Return the mean squared error between VGG features of two image batches.

    Both batches go through the module-level ``vgg`` model, and the squared
    difference of the resulting feature maps is averaged into one scalar.

    NOTE(review): the inputs are passed to VGG19 as-is, with no
    preprocessing applied here - confirm that callers scale images the way
    the ImageNet-trained VGG19 expects.

    Args:
        y_true: ground-truth image batch.
        y_pred: generated image batch.
    Returns:
        Scalar tensor holding the mean squared feature difference.
    """
    feature_gap = vgg(y_true) - vgg(y_pred)
    return tf.reduce_mean(tf.square(feature_gap))


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m80134624/80134624[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step
