In [1]:
# Mount Google Drive; the data paths used later in this notebook live under
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:


!pip install tensorflow keras

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras

import glob
import os

from keras import Input
from keras.applications import VGG19
from keras.callbacks import TensorBoard
from keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense
from keras.layers import Conv2D, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam


import random
from numpy import asarray
from itertools import repeat

import imageio
from imageio import imread
from PIL import Image
from skimage.transform import resize as imresize
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

print("Tensorflow version " + tf.__version__)
# print("Keras version " + tf.keras.__version__)

In [None]:
# data path
TRAIN_PATH = '/content/drive/MyDrive/normal/'
VAL_PATH = '/content/drive/MyDrive/'
TEST_PATH = '/content/drive/MyDrive/'
data_path = TRAIN_PATH

# NOTE(review): the training loop at the end of this notebook iterates over a
# hard-coded range(100) and does not read `epochs` — confirm which is intended.
epochs = 5

# mini-batch size (kept small due to RAM limits)
batch_size = 12

# define the shape of low resolution image (LR)
low_resolution_shape = (32, 32, 1)

# define the shape of high resolution image (HR)
high_resolution_shape = (224, 224, 1)

# optimizer for discriminator, generator
# (positional arguments: learning rate 0.0002, beta_1 0.5)
common_optimizer = Adam(0.0002, 0.5)

# use seed for reproducible results
# Batch sampling and the random flips go through np.random, so NumPy (and the
# stdlib RNG) must be seeded as well as TensorFlow.
SEED = 2020
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## III. Data

Load data, process data, EDA

In [None]:
def get_train_images(data_path, classes=('NORMAL',)):
    """
    Collect the paths of all training images stored below `data_path`.

    Parameters
    ----------
    data_path : str
        Root directory holding one sub-directory per class. A trailing path
        separator is optional.
    classes : iterable of str, optional
        Names of the class sub-directories to scan. Defaults to ('NORMAL',);
        the full candidate set is 'CNV', 'DME', 'DRUSEN' and 'NORMAL'.

    Returns
    -------
    list of str
        Every entry matched by `<data_path>/<class>/*`, sorted within each
        class. The list is empty when nothing matches.
    """
    image_list = []

    for class_type in classes:
        # os.path.join copes with data_path lacking a trailing separator,
        # which plain string concatenation silently turned into a bad pattern.
        pattern = os.path.join(data_path, class_type, '*')
        # glob order depends on the filesystem; sort so that seeded sampling
        # picks the same files on every run.
        image_list.extend(sorted(glob.glob(pattern)))

    return image_list
# Show the collected image paths as the cell output (the whole list is displayed).
get_train_images(data_path)

In [None]:
def find_img_dims(image_list):
    """
    Find the smallest and largest side length across a set of image files.

    Parameters
    ----------
    image_list : sequence of str
        Paths of the images to inspect.

    Returns
    -------
    tuple of (int, int)
        (smallest side over all images, largest side over all images).

    Raises
    ------
    ValueError
        If `image_list` is empty.
    """
    if len(image_list) == 0:
        raise ValueError("image_list is empty; cannot determine image dimensions")

    min_size = []
    max_size = []

    for image_path in image_list:
        # The context manager closes the file handle straight away instead of
        # leaving one open per image.
        with Image.open(image_path) as im:
            dims = im.size
        min_size.append(min(dims))
        max_size.append(max(dims))

    return min(min_size), max(max_size)



In [None]:
from PIL import Image

def resize_and_find_img_dims(image_list, target_size=(224, 224)):
    """
    Resize every image to `target_size` and report the min / max side length.

    Because each image is resized before it is measured, the returned values
    describe the resized images (i.e. they follow `target_size`), not the
    dimensions of the files on disk.

    Parameters
    ----------
    image_list : sequence of str
        Paths of the images to open.
    target_size : tuple of (int, int), optional
        Size passed to `Image.resize`. Defaults to (224, 224).

    Returns
    -------
    tuple of (int, int)
        (smallest side, largest side) over all resized images.

    Raises
    ------
    ValueError
        If `image_list` is empty.
    """
    if len(image_list) == 0:
        raise ValueError("image_list is empty; cannot determine image dimensions")

    min_size = []
    max_size = []

    for image_path in image_list:
        # Close the underlying file as soon as the resized copy exists.
        with Image.open(image_path) as im:
            resized = im.resize(target_size)
        min_size.append(min(resized.size))
        max_size.append(max(resized.size))

    return min(min_size), max(max_size)

# Example usage
# NOTE: resize_and_find_img_dims resizes each image to its target size before
# measuring it, so the numbers printed here reflect the resize target rather
# than the original file dimensions.
image_list = get_train_images(data_path)
min_size, max_size = resize_and_find_img_dims(image_list)
print('The min and max image dims are {} and {} respectively.'.format(min_size, max_size))
# print()


# image_list = get_train_images(data_path)
# min_size, max_size = find_img_dims(image_list)
# print('The min and max image dims are {} and {} respectively.'
#       .format(min_size, max_size))

## IV. Utility functions

Quantitative metrics for image quality  
Loss functions  
Plots  
Image processing: sampling and saving images

### IV A. Metrics

#### 1. PSNR - Peak Signal-to-Noise ratio


PSNR is the ratio between maximum possible power of signal and power of corrupting noise (Wikipedia).


$${ PSNR = 10  \log_{10}  \left( {MAX_I^2 \over MSE} \right) }$$

$ MAX_I $  -  maximum possible pixel value of image $I$  
$ MSE $  -  mean squared error pixel by pixel

In [None]:
def compute_psnr(original_image, generated_image, max_val=1.0):
    """
    Mean peak signal-to-noise ratio between two image batches.

    Parameters
    ----------
    original_image, generated_image : array-like or tf.Tensor
        Image batches of identical shape; both are converted to float32.
    max_val : float, optional
        Dynamic range of the pixel values (max - min allowed value). The
        default 1.0 is correct for images in [0, 1]; images scaled to
        [-1, 1] (e.g. a tanh generator output) need max_val=2.0, or must be
        rescaled to [0, 1] before calling.

    Returns
    -------
    tf.Tensor
        Scalar tensor: the PSNR averaged over the batch.
    """
    original_image = tf.convert_to_tensor(original_image, dtype=tf.float32)
    generated_image = tf.convert_to_tensor(generated_image, dtype=tf.float32)
    psnr = tf.image.psnr(original_image, generated_image, max_val=max_val)

    return tf.math.reduce_mean(psnr)

In [None]:
def plot_psnr(psnr):
    """
    Plot the PSNR history on a new 10x8 figure.

    `psnr` is a dict whose 'psnr_quality' entry holds one value per epoch.
    """
    history = psnr['psnr_quality']

    plt.figure(figsize=(10, 8))
    plt.plot(history)
    plt.title('PSNR')
    plt.ylabel('PSNR')
    plt.xlabel('Epochs')

#### 2. SSIM - Structural Similarity Index


SSIM measures the perceptual difference between two similar images [(see Wikipedia)](https://en.wikipedia.org/wiki/Structural_similarity).

$${ SSIM(x, y) = {(2 \mu_x \mu_y + c_1) (2 \sigma_{xy} + c_2) \over (\mu_x^2 + \mu_y^2 + c_1) ( \sigma_x^2 + \sigma_y^2 +c_2)}  }$$


$ \mu_x, \mu_y$       - average value for image $x, y$    
$ \sigma_x, \sigma_y$ - standard deviation for image $x, y$     
$ \sigma_{xy}$        - covariance  of $x$ and $y$      
$ c_1, c_2 $          - coefficients

In [None]:
def compute_ssim(original_image, generated_image, max_val=1.0):
    """
    Mean structural similarity index between two image batches.

    Parameters
    ----------
    original_image, generated_image : array-like or tf.Tensor
        Image batches of identical shape; both are converted to float32.
    max_val : float, optional
        Dynamic range of the pixel values (max - min allowed value). The
        default 1.0 is correct for images in [0, 1]; images scaled to
        [-1, 1] need max_val=2.0, or must be rescaled to [0, 1] first.

    Returns
    -------
    tf.Tensor
        Scalar tensor: the SSIM averaged over the batch.
    """
    original_image = tf.convert_to_tensor(original_image, dtype=tf.float32)
    generated_image = tf.convert_to_tensor(generated_image, dtype=tf.float32)
    ssim = tf.image.ssim(original_image, generated_image, max_val=max_val, filter_size=11,
                         filter_sigma=1.5, k1=0.01, k2=0.03)

    return tf.math.reduce_mean(ssim)

In [None]:
def plot_ssim(ssim):
    """
    Plot the SSIM history on a new 10x8 figure.

    `ssim` is a dict whose 'ssim_quality' entry holds one value per epoch.
    """
    history = ssim['ssim_quality']

    plt.figure(figsize=(10, 8))
    plt.plot(history)
    plt.title('SSIM')
    plt.ylabel('SSIM')
    plt.xlabel('Epochs')

### IV B. Loss Functions

The most important contribution of the SRGAN paper was the use of a *perceptual loss* function.


**Perceptual Loss**  is a weighted sum of the *content loss* and *adversarial loss*.


$${ l^{SR} = l_X^{SR} + 10^{-3}l_{Gen}^{SR}}$$

$l^{SR}$ - perceptual loss   
$l_X^{SR}$ - content loss   
$l_{Gen}^{SR}$ - adversarial loss


****************************

**1. Content Loss**   
The SRGAN replaced the *MSE loss* with a *VGG loss*. Both losses are defined below:

         
**Pixel-wise MSE loss** is the mean squared error between each pixel in the original HR image and the corresponding pixel in the generated SR image.


**VGG loss** is the Euclidean distance between the feature maps of the generated SR image and the original HR image. The feature maps are the activation layers of the pre-trained VGG19 network.

$${ l_{{VGG}/{i,j}}^{SR} = {1 \over {W_{i,j}H_{i,j}}} \sum\limits_{x=1}^{W_{i,j}} \sum\limits_{y=1}^{H_{i,j}}  ({\phi}_{i,j}(I^{HR})_{x,y} - {\phi}_{i,j} (G_{{\theta}_G} (I^{LR}))_{x,y})^2}$$


$ l_{{VGG}/{i,j}}^{SR} $  -  VGG loss    
$ {\phi}_{i,j} $  -   the feature map obtained by the j-th convolution (after activation) before the i-th maxpooling layer within the VGG19 network



**2. Adversarial Loss**  
This is calculated based on probabilities provided by Discriminator.

$${ l_{Gen}^{SR} = \sum\limits_{n=1}^{N} - \log{D_{{\theta}_D}} (G_{{\theta}_G} (I^{LR}))}$$

$ l_{Gen}^{SR} $  -  generative loss  
$ D $  -  discriminator function    
$ D_{{\theta}_D} $  -  discriminator function parametrized with $ {\theta}_D $   
$ {D_{{\theta}_D}} (G_{{\theta}_G} (I^{LR})) $   -  probability that the reconstructed image $ G_{{\theta}_G} (I^{LR}) $  is a natural HR image

#### Plot loss function

In [None]:
def plot_loss(losses):
    """
    Plot discriminator and generator loss histories on one 10x8 figure.

    `losses` is a dict with the per-epoch series under 'd_history'
    (discriminator) and 'g_history' (generator).
    """
    plt.figure(figsize=(10, 8))

    # Draw the curves in a fixed order so colours and legend entries stay stable.
    for key, label in (('d_history', "Discriminator loss"),
                       ('g_history', "Generator loss")):
        plt.plot(losses[key], label=label)

    plt.title("Loss")
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

### IV C. Sampling, saving images

In [None]:
import numpy as np
from skimage.io import imread
from skimage.transform import resize

def sample_images(image_list, batch_size, high_resolution_shape, low_resolution_shape):
    """
    Draw a random batch of images and build paired HR / LR arrays.

    Parameters
    ----------
    image_list : sequence of str
        Paths of the candidate images.
    batch_size : int
        Number of images to draw (sampling is with replacement).
    high_resolution_shape : tuple of int
        Output shape passed to `resize` for the HR version, e.g. (224, 224, 1).
    low_resolution_shape : tuple of int
        Output shape passed to `resize` for the LR version, e.g. (32, 32, 1).

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        (hr_images, lr_images). Each is the stacked batch with one extra
        trailing axis appended, so when the requested shapes already end in a
        channel axis of 1 the arrays presumably end in (..., 1, 1) — the
        sampling cell later in this notebook checks for exactly that.
        The arrays are not normalised here.
    """
    # Randomly sample a batch of images from image_list
    images_batch = np.random.choice(image_list, size=batch_size)

    lr_images = []
    hr_images = []

    for img_path in images_batch:
        # Read as a single-channel (grayscale) image, not RGB
        img = imread(img_path, as_gray=True)
        img = img.astype(np.float32)

        # Honour the shapes requested by the caller instead of hard-coded
        # (224, 224, 1) / (32, 32, 1)
        img_high_resolution = resize(img, output_shape=high_resolution_shape, anti_aliasing=True)
        img_low_resolution = resize(img, output_shape=low_resolution_shape, anti_aliasing=True)

        # Random horizontal flip, applied to both versions so the pair stays aligned
        if np.random.random() < 0.5:
            img_high_resolution = np.fliplr(img_high_resolution)
            img_low_resolution = np.fliplr(img_low_resolution)

        hr_images.append(img_high_resolution)
        lr_images.append(img_low_resolution)

    # Convert lists into numpy ndarrays (trailing axis kept for compatibility
    # with the cells that consume this function's output)
    hr_images = np.array(hr_images)[..., np.newaxis]
    lr_images = np.array(lr_images)[..., np.newaxis]
    return hr_images, lr_images


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def save_images(original_image, lr_image, sr_image, path):
    """
    Save LR, HR (original) and generated SR
    images in one panel

    Parameters
    ----------
    original_image, lr_image, sr_image : array-like
        Single images scaled to [-1, 1]; singleton axes are squeezed away
        before plotting.
    path : str
        Destination handed to `savefig`. Missing parent directories are
        created first.
    """
    import os

    fig, axes = plt.subplots(1, 3, figsize=(10, 6))

    panels = (('HR', original_image),
              ('LR', lr_image),
              ('SR - generated', sr_image))

    for ax, (title, img) in zip(axes, panels):
        # Scale from [-1, 1] to [0, 1] and drop singleton axes
        img = np.clip((np.asarray(img) + 1) / 2.0, 0, 1).squeeze()
        # Fix the colour range: without vmin/vmax each panel is stretched to
        # its own min/max, which hides brightness/contrast differences
        # between HR, LR and SR.
        ax.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
        ax.axis("off")
        ax.set_title(title)

    fig.tight_layout()

    # savefig raises if the target directory does not exist yet
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig.savefig(path)
    plt.close(fig)


## V. SRGAN-VGG19

The SRGAN has the following code components:
 1. Generator network
 2. Discriminator network
 3. Feature extractor (the original SRGAN uses the VGG19 network; this notebook uses MobileNet teacher/student models instead)
 4. Adversarial framework

### V 1. Generator

There are 16 residual blocks and 3 upsampling blocks (32 → 64 → 128 → 256 pixels), followed by a centre crop to 224 × 224. The generator is based on the architecture outlined in [2].

In [None]:
def residual_block(x):
    """
    Residual block: Conv -> ReLU -> BatchNorm -> Conv -> BatchNorm, then the
    input is added back onto the result.

    NOTE: the generator cell further down defines another `residual_block`
    (BatchNorm before ReLU); once that cell has run, it replaces this one.
    """
    conv_kwargs = dict(kernel_size=3, strides=1, padding="same")
    bn_momentum = 0.8

    y = Conv2D(filters=64, **conv_kwargs)(x)
    y = Activation(activation="relu")(y)
    y = BatchNormalization(momentum=bn_momentum)(y)

    y = Conv2D(filters=64, **conv_kwargs)(y)
    y = BatchNormalization(momentum=bn_momentum)(y)

    return Add()([y, x])


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Add, UpSampling2D, Activation, Cropping2D
from tensorflow.keras.models import Model

def residual_block(x):
    """
    Residual block used by the generator.

    Applies Conv(3x3, 64) -> BatchNorm -> ReLU -> Conv(3x3, 64) -> BatchNorm
    and adds the block input back onto the result (identity skip).
    """
    bn_momentum = 0.8
    conv_kwargs = dict(filters=64, kernel_size=3, strides=1, padding='same')

    y = Conv2D(**conv_kwargs)(x)
    y = BatchNormalization(momentum=bn_momentum)(y)
    y = Activation('relu')(y)

    y = Conv2D(**conv_kwargs)(y)
    y = BatchNormalization(momentum=bn_momentum)(y)

    return Add()([y, x])

def build_generator():
    """
    Build the generator: a (32, 32, 1) input is mapped to a (224, 224, 1)
    tanh output.

    Layout: 9x9 head conv -> 16 residual blocks -> 3x3 conv + BatchNorm ->
    skip connection from the head (through a 1x1 conv) -> three x2
    upsampling stages (32 -> 64 -> 128 -> 256) -> crop 16 px per side
    (256 -> 224) -> 9x9 conv -> tanh.
    """
    n_residual_blocks = 16
    bn_momentum = 0.8

    lr_input = Input(shape=(32, 32, 1))

    # Head convolution; its output also feeds the long skip connection
    head = Conv2D(filters=64, kernel_size=9, strides=1, padding='same', activation='relu')(lr_input)

    # Residual trunk
    trunk = head
    for _ in range(n_residual_blocks):
        trunk = residual_block(trunk)

    trunk = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(trunk)
    trunk = BatchNormalization(momentum=bn_momentum)(trunk)

    # Long skip connection from the head, passed through a 1x1 convolution
    skip = Conv2D(filters=64, kernel_size=1, strides=1, padding='same')(head)
    x = Add()([trunk, skip])

    # Three identical x2 upsampling stages
    for _ in range(3):
        x = UpSampling2D(size=2)(x)
        x = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(x)
        x = Activation('relu')(x)

    # Trim 16 pixels from every border
    x = Cropping2D(cropping=((16, 16), (16, 16)))(x)

    x = Conv2D(filters=1, kernel_size=9, strides=1, padding='same')(x)
    sr_output = Activation('tanh')(x)

    return Model(inputs=[lr_input], outputs=[sr_output], name='generator')

# Instantiate the generator and print its layer summary.
generator = build_generator()
generator.summary()



### V 2. Discriminator

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU, Dense, Flatten, Reshape
from tensorflow.keras.models import Model

def build_discriminator():
    """
    Build the discriminator for (224, 224, 1) images.

    Eight 3x3 conv stages, each followed by LeakyReLU and (except the first)
    BatchNorm, reduce the input to a (14, 14, 1) map. The final Dense layer
    with a sigmoid acts on the last axis, so the model outputs one
    probability per spatial position rather than a single scalar.
    """
    leaky_slope = 0.2
    bn_momentum = 0.8

    # (filters, strides, followed by BatchNorm?) for each conv stage, in order
    stages = [
        (64, 2, False),   # (112, 112, 64)
        (64, 2, True),    # (56, 56, 64)
        (128, 2, True),   # (28, 28, 128)
        (128, 2, True),   # (14, 14, 128)
        (256, 1, True),   # (14, 14, 256)
        (256, 1, True),   # (14, 14, 256)
        (512, 1, True),   # (14, 14, 512)
        (1, 1, True),     # (14, 14, 1)
    ]

    image_input = Input(shape=(224, 224, 1))

    x = image_input
    for n_filters, stride, use_batch_norm in stages:
        x = Conv2D(filters=n_filters, kernel_size=3, strides=stride, padding='same')(x)
        x = LeakyReLU(alpha=leaky_slope)(x)
        if use_batch_norm:
            x = BatchNormalization(momentum=bn_momentum)(x)

    patch_probs = Dense(units=1, activation='sigmoid')(x)

    return Model(inputs=[image_input], outputs=[patch_probs], name='discriminator')

# Instantiate the discriminator, mark it trainable and print its layer summary.
discriminator = build_discriminator()
discriminator.trainable = True

discriminator.summary()


### V 3. MobileNet Extractor

In [None]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.layers import Input, Lambda, Conv2D, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def build_teacher_model():
    """
    Build the teacher feature extractor.

    A (224, 224, 1) grayscale input is replicated to RGB and passed through
    an ImageNet-pretrained MobileNet; the model returns the activations of
    layer 'conv_pw_11_relu'.
    """
    gray_input = Input(shape=(224, 224, 1))

    # MobileNet is built for 3-channel input, so replicate the gray channel
    as_rgb = Lambda(lambda t: tf.image.grayscale_to_rgb(t))(gray_input)

    backbone = MobileNet(include_top=False, weights='imagenet',
                         input_tensor=as_rgb, input_shape=(224, 224, 3))

    features = backbone.get_layer('conv_pw_11_relu').output  # Shape: (14, 14, 512)

    return Model(inputs=gray_input, outputs=features)

def build_student_model():
    """
    Build the student feature extractor.

    Uses a randomly initialised MobileNet with alpha=0.5 on the RGB-replicated
    (224, 224, 1) input, then upsamples x2 and applies a 3x3 conv with 512
    filters and ReLU.
    """
    gray_input = Input(shape=(224, 224, 1))

    # MobileNet is built for 3-channel input, so replicate the gray channel
    as_rgb = Lambda(lambda t: tf.image.grayscale_to_rgb(t))(gray_input)

    # Use a smaller version of MobileNet (no pretrained weights)
    backbone = MobileNet(include_top=False, weights=None, input_tensor=as_rgb,
                         input_shape=(224, 224, 3), alpha=0.5)

    # The backbone output is (7, 7, 512); upsample it to (14, 14, 512)
    upsampled = UpSampling2D(size=(2, 2))(backbone.output)
    features = Conv2D(512, (3, 3), padding='same', activation='relu')(upsampled)

    return Model(inputs=gray_input, outputs=features)

# Build the teacher model and freeze it so its weights are never updated.
teacher_model = build_teacher_model()
teacher_model.trainable = False

# Build the student model
student_model = build_student_model()

# Create a combined model for training
# NOTE(review): `combined_model` is not referenced again in the cells visible
# here — confirm whether it is still needed.
input_layer = Input(shape=(224, 224, 1))
teacher_preds = teacher_model(input_layer)
student_preds = student_model(input_layer)

combined_model = Model(inputs=input_layer, outputs=[student_preds, teacher_preds])


# Compile the student on its own with an MSE loss, then print both summaries.
student_model.compile(optimizer=Adam(), loss='mse')
teacher_model.summary(),student_model.summary()

In [None]:
from tensorflow.keras.optimizers.legacy import Adam

# Define the common optimizer with legacy Adam
# NOTE: this rebinds both `Adam` and `common_optimizer`, replacing the
# optimizer created in the configuration cell with a legacy-Adam instance that
# uses the same arguments (0.0002, 0.5).
common_optimizer = Adam(0.0002, 0.5)



In [None]:
def build_adversarial_model(generator, discriminator, teacher_model, student_model):
    """
    Chain the generator with the discriminator and both feature extractors.

    The model takes a low-resolution image (shape `low_resolution_shape`) and
    returns [discriminator probabilities, student features, teacher features],
    all computed on the generated high-resolution image.

    Side effect: sets `discriminator.trainable = False` on the object passed in.
    """
    lr_input = Input(shape=low_resolution_shape)

    # Super-resolved images produced from the low-resolution input
    sr_images = generator(lr_input)

    # Feature maps of the generated images from the teacher and the student
    teacher_features = teacher_model(sr_images)
    student_features = student_model(sr_images)

    # Freeze the discriminator inside the adversarial model
    discriminator.trainable = False

    # Discriminator's probability estimate for the generated images
    probs = discriminator(sr_images)

    return Model(lr_input, [probs, student_features, teacher_features])

# Ensure the teacher model is not trainable
teacher_model.trainable = False

# Build the adversarial model
# (this call also sets discriminator.trainable = False as a side effect)
adversarial_model = build_adversarial_model(generator, discriminator, teacher_model, student_model)

# Compile the adversarial model
# The loss / weight lists follow the output order
# [probs, student_features, teacher_features]:
#   probs            -> binary cross-entropy, weight 1e-3
#   student_features -> MSE, weight 1
#   teacher_features -> no loss, weight 0
adversarial_model.compile(
    loss=['binary_crossentropy', 'mse', None],  # None for teacher output as it's not trained
    loss_weights=[1e-3, 1, 0],  # 0 weight for teacher output
    optimizer=common_optimizer
)
adversarial_model.summary()

In [None]:
# adversarial_model = build_adversarial_model(generator, discriminator, fe_model)
# adversarial_model.compile(loss=['binary_crossentropy', 'mean_squared_error'],
#                            optimizer=common_optimizer,
#                            loss_weights=[0.5, 0.5])
# adversarial_model.summary()

## VI. Training


In [None]:
# initialize
# History containers filled by the training loop and read by plot_loss,
# plot_psnr and plot_ssim (one entry appended per loop iteration).

losses = {"d_history":[], "g_history":[]}
psnr = {'psnr_quality': []}
ssim = {'ssim_quality': []}

In [None]:
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array
def resize_images(images, target_size):
    """
    Resize every image array in `images` to `target_size`.

    Each array is converted to a PIL image with `array_to_img`, resized, and
    converted back with `img_to_array`; the results are stacked into one
    numpy array.
    """
    return np.array([
        img_to_array(array_to_img(img_array).resize(target_size))
        for img_array in images
    ])

In [None]:

# build_adversarial_model() has already set discriminator.trainable = False.
# compile() captures the trainable state in effect at that moment, so compiling
# here without re-enabling training would leave the discriminator frozen and
# its train_on_batch calls would never update any weight.
discriminator.trainable = True
discriminator.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])

# Restore the frozen flag so the attribute again matches what the adversarial
# model was compiled with (discriminator fixed during generator updates).
discriminator.trainable = False

In [None]:
# Draw one batch to inspect the array shapes produced by sample_images.
hr_images, lr_images = sample_images(image_list,
                                     batch_size=batch_size,
                                     low_resolution_shape=low_resolution_shape,
                                     high_resolution_shape=high_resolution_shape)

# Normalize the images (maps intensities in 0..255 to [-1, 1])
hr_images = hr_images / 127.5 - 1.
lr_images = lr_images / 127.5 - 1.

# Collapse a duplicated trailing singleton axis (..., 1, 1) -> (..., 1).
# The *_new names are always bound, so later cells cannot hit a NameError when
# the arrays already have the expected rank.
if hr_images.shape[-1] == 1 and hr_images.shape[-2] == 1:
    hr_images_new = hr_images.reshape(hr_images.shape[:-2] + (1,))
else:
    hr_images_new = hr_images

if lr_images.shape[-1] == 1 and lr_images.shape[-2] == 1:
    lr_images_new = lr_images.reshape(lr_images.shape[:-2] + (1,))
else:
    lr_images_new = lr_images

In [None]:
# Quick sanity check of the sampled batch.
# NOTE: a notebook cell only displays its last expression, so the two len()
# calls below produce no visible output; only the shapes tuple is shown.
len(hr_images)
len(image_list)
hr_images_new.shape,lr_images_new.shape

In [None]:

# The set of training files does not change between iterations, so list the
# Drive folder once instead of on every pass through the loop.
image_list = get_train_images(data_path)

# savefig fails when the target directory is missing; create it up front.
sample_dir = "/content/drive/MyDrive/mobilenet_distillation"
os.makedirs(sample_dir, exist_ok=True)

# NOTE(review): the configuration cell defines `epochs = 5`, but this loop runs
# a hard-coded 100 iterations — confirm which value is intended.
for epoch in range(100):

    # ------------------------------------------------------------------
    # Train the discriminator network
    # ------------------------------------------------------------------
    hr_images, lr_images = sample_images(image_list,
                                         batch_size=batch_size,
                                         low_resolution_shape=low_resolution_shape,
                                         high_resolution_shape=high_resolution_shape)

    # Normalize the images to [-1, 1]
    hr_images = hr_images / 127.5 - 1.
    lr_images = lr_images / 127.5 - 1.

    # sample_images can return a duplicated trailing singleton axis
    # (batch, H, W, 1, 1); drop it so every tensor is (batch, H, W, 1).
    if hr_images.ndim == 5:
        hr_images = hr_images[..., 0]
    if lr_images.ndim == 5:
        lr_images = lr_images[..., 0]

    # Generate high-resolution images from low-resolution images
    generated_high_resolution_images = generator.predict(lr_images)

    # Generate a batch of true and fake labels (one label per 14x14 output cell)
    real_labels = np.ones((batch_size, 14, 14, 1))
    fake_labels = np.zeros((batch_size, 14, 14, 1))

    d_loss_real = discriminator.train_on_batch(hr_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(generated_high_resolution_images, fake_labels)

    # train_on_batch returns [loss, accuracy] when metrics are configured.
    # Average only the loss entries; taking np.mean over the whole list would
    # blend the accuracy value into the recorded loss.
    d_loss = 0.5 * (np.atleast_1d(d_loss_real)[0] + np.atleast_1d(d_loss_fake)[0])
    losses['d_history'].append(float(d_loss))

    # ------------------------------------------------------------------
    # Train the generator network
    # ------------------------------------------------------------------
    # Teacher features of the real HR images are the target for the student output
    teacher_features = teacher_model.predict(hr_images)

    # Train the generator (via the adversarial model)
    g_loss = adversarial_model.train_on_batch(lr_images, [real_labels, teacher_features, teacher_features])
    # g_loss[0] is the weighted total; g_loss[1] is the loss of the first
    # output (discriminator probabilities), which is what gets plotted.
    losses['g_history'].append(g_loss[1])

    # ------------------------------------------------------------------
    # Calculate PSNR and SSIM
    # ------------------------------------------------------------------
    # Compare the generated images with the HR images of THIS batch (not a
    # batch sampled in an earlier cell), and rescale both from [-1, 1] to
    # [0, 1] so the metrics' max_val=1.0 matches the data range.
    hr_unit = np.clip((hr_images + 1.0) / 2.0, 0.0, 1.0)
    sr_unit = np.clip((generated_high_resolution_images + 1.0) / 2.0, 0.0, 1.0)

    # Store plain floats so the histories plot without tensor conversion
    ps = float(compute_psnr(hr_unit, sr_unit).numpy())
    ss = float(compute_ssim(hr_unit, sr_unit).numpy())
    psnr['psnr_quality'].append(ps)
    ssim['ssim_quality'].append(ss)

    print(f"Epoch {epoch + 1}  PSNR {ps}  SSIM {ss}")

    # ------------------------------------------------------------------
    # Save one image sample every 5 iterations
    # ------------------------------------------------------------------
    if epoch % 5 == 0:
        index = 0
        save_images(hr_images[index], lr_images[index],
                    generated_high_resolution_images[index],
                    path=f"{sample_dir}/images_{epoch}_{index}")