<a href="https://colab.research.google.com/github/SESHG14/ESRGAN_STUDENT_2023/blob/main/ESRGAN_Student_1_2023.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training a Student Network using Knowledge Distillation

- Uses a reduced ESRGAN generator with 8 residual-in-residual blocks instead of the teacher's 16.
- Otherwise the student has the same structure as the teacher model.

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras import layers, Model
from sklearn.model_selection import train_test_split
from IPython.display import clear_output

import numpy as np
from keras import Model
from keras.layers import Conv2D, PReLU,BatchNormalization, Flatten
from keras.layers import UpSampling2D, LeakyReLU, Dense, Input, add
from tqdm import tqdm

# Student Model Structure Functions:

In [None]:
def res_block(ip):
    """Build one dense-style residual block on top of the tensor ``ip``.

    Four stages of Conv2D(64, 3x3, same) + LeakyReLU(0.2) each add their
    input back (local skip connections).  A final Conv2D + BatchNormalization
    follows, and the block output is ``ip`` plus that result.

    ``ip`` must have 64 channels, because every ``add`` merges it with a
    64-filter convolution output.
    """
    x = ip
    for _ in range(4):
        shortcut = x
        x = Conv2D(64, (3, 3), padding="same")(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = add([shortcut, x])

    # Closing convolution + normalisation before the outer skip connection.
    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization(momentum=0.5)(x)

    return add([ip, x])

def res_in_res_block(ip):
    """Residual-in-residual block.

    Chains three ``res_block`` stages and adds the block input ``ip`` back
    onto the result (outer skip connection).
    """
    out = ip
    for _ in range(3):
        out = res_block(out)
    return add([ip, out])


def upscale_block(ip):
    """Double the spatial size of ``ip``.

    Applies Conv2D(256, 3x3, same), then UpSampling2D(size=2), then
    LeakyReLU(0.2).
    """
    conv = Conv2D(256, (3, 3), padding="same")
    upsample = UpSampling2D(size=2)
    activation = LeakyReLU(alpha=0.2)
    return activation(upsample(conv(ip)))

# Proposed ESRGAN Generator
def create_gen(gen_ip, num_res_block):
    """Build the ESRGAN-style generator model.

    Args:
        gen_ip: Keras ``Input`` for the low-resolution image.
        num_res_block: number of ``res_in_res_block`` stages in the trunk
            (this notebook's student uses 8).

    Returns:
        A ``Model`` mapping ``gen_ip`` to a 3-channel image upscaled 4x
        (two ``upscale_block`` stages of 2x each).
    """
    # Shallow feature extraction.
    x = Conv2D(64, (9, 9), padding="same")(gen_ip)
    x = LeakyReLU(alpha=0.2)(x)
    trunk_skip = x

    # Deep trunk of residual-in-residual blocks.
    for _ in range(num_res_block):
        x = res_in_res_block(x)

    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization(momentum=0.5)(x)
    x = add([x, trunk_skip])

    # Two 2x upsampling stages -> 4x overall.
    for _ in range(2):
        x = upscale_block(x)

    op = Conv2D(3, (9, 9), padding="same")(x)

    return Model(inputs=gen_ip, outputs=op)


# Discriminator building block used by create_disc.
def discriminator_block(ip, filters, strides=1, bn=True):
    """Conv2D(filters, 3x3, ``strides``, same) -> optional BatchNorm -> LeakyReLU(0.2).

    Args:
        ip: input tensor.
        filters: number of convolution filters.
        strides: convolution stride (2 halves the spatial size).
        bn: whether to insert BatchNormalization(momentum=0.8).
    """
    x = Conv2D(filters, (3, 3), strides=strides, padding="same")(ip)
    x = BatchNormalization(momentum=0.8)(x) if bn else x
    return LeakyReLU(alpha=0.2)(x)

# Discriminator, as described in the original paper
def create_disc(disc_ip):
    """Build the discriminator: 8 conv stages, then a dense head ending in a sigmoid.

    Returns:
        A ``Model`` mapping ``disc_ip`` to one sigmoid "realness" score per image.
    """
    df = 64

    # (filters, strides, use_batchnorm) for each conv stage, in order.
    stages = [
        (df, 1, False),
        (df, 2, True),
        (df * 2, 1, True),
        (df * 2, 2, True),
        (df * 4, 1, True),
        (df * 4, 2, True),
        (df * 8, 1, True),
        (df * 8, 2, True),
    ]

    x = disc_ip
    for filters, strides, use_bn in stages:
        x = discriminator_block(x, filters, strides=strides, bn=use_bn)

    x = Flatten()(x)
    x = Dense(df * 16)(x)
    x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)

    return Model(disc_ip, validity)

#VGG19 for the feature map obtained by the j-th convolution (after activation)
#before the i-th maxpooling layer within the VGG19 network.(as described in the paper)
#Build a pre-trained VGG19 model that outputs image features extracted at the
# third block of the model
# VGG architecture: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
from keras.applications import VGG19

def build_vgg(hr_shape):
    """Return an ImageNet VGG19 feature extractor for images of shape ``hr_shape``.

    The model outputs the activations of ``vgg.layers[10]``. In Keras' VGG19
    with ``include_top=False``, that is the last convolution of the third
    block (after its ReLU).
    """
    backbone = VGG19(weights="imagenet", include_top=False, input_shape=hr_shape)
    feature_layer = backbone.get_layer(index=10)
    return Model(inputs=backbone.inputs, outputs=feature_layer.output)

# Combined model
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
    """Wire generator, discriminator and VGG into one trainable GAN model.

    Args:
        gen_model: generator ``Model`` (LR -> SR).
        disc_model: discriminator ``Model``; its ``trainable`` flag is set to
            False here.
        vgg: feature extractor applied to the generated image.
        lr_ip, hr_ip: Keras inputs. ``hr_ip`` is declared as a model input
            but is not used to compute any output.

    Returns:
        ``Model([lr_ip, hr_ip]) -> [validity, gen_features]``.
    """
    sr = gen_model(lr_ip)
    sr_features = vgg(sr)

    # Freeze the discriminator before it is wired into the combined model
    # (the usual Keras GAN pattern; takes effect when the caller compiles).
    disc_model.trainable = False
    sr_validity = disc_model(sr)

    return Model(inputs=[lr_ip, hr_ip], outputs=[sr_validity, sr_features])

# Learning-rate schedulers for the generator and discriminator

In [None]:
import tensorflow as tf

def MultiStepLR(initial_learning_rate, lr_steps, lr_rate, name='MultiStepLR'):
    """Build a multi-step decay schedule.

    The learning rate starts at ``initial_learning_rate``. Each boundary in
    ``lr_steps`` that is passed multiplies it by ``lr_rate`` once more.
    ``name`` is accepted for API compatibility but is not used.
    """
    values = [initial_learning_rate]
    # One extra value per boundary, each derived from the previous one.
    while len(values) <= len(lr_steps):
        values.append(values[-1] * lr_rate)
    return tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=lr_steps, values=values)


def CosineAnnealingLR_Restart(initial_learning_rate, t_period, lr_min):
    """Cosine-annealing learning-rate schedule with warm restarts.

    Every ``t_period`` steps the rate restarts at ``initial_learning_rate``.
    The period is fixed (``t_mul=1``) and peaks do not decay (``m_mul=1``).
    Within each period the rate anneals toward ``lr_min``, which is passed as
    the fraction ``alpha = lr_min / initial_learning_rate``.
    """
    # Use the stable schedules namespace (as MultiStepLR does). The old
    # tf.keras.experimental alias is deprecated and absent under Keras 3.
    return tf.keras.optimizers.schedules.CosineDecayRestarts(
        initial_learning_rate=initial_learning_rate,
        first_decay_steps=t_period, t_mul=1.0, m_mul=1.0,
        alpha=lr_min / initial_learning_rate)



    # Discriminator lr scheduler: 1e-4, halved at 50k / 100k / 200k / 300k steps.
lr_scheduler_D = MultiStepLR(1e-4, [50000, 100000, 200000, 300000], 0.5)

# Generator (ESRGAN) lr scheduler: same multi-step schedule.
lr_scheduler_G = MultiStepLR(1e-4, [50000, 100000, 200000, 300000], 0.5)

# Alternative: cosine annealing with restarts.
# lr_scheduler = CosineAnnealingLR_Restart(2e-4, 250000, 1e-7)

##############################
# Draw figure of the generator schedule
##############################
N_iter = 1000000
step_list = list(range(0, N_iter, 1000))
lr_list = [lr_scheduler_G(step).numpy() for step in step_list]

import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.ticker as mtick
mpl.style.use('default')
import seaborn
seaborn.set(style='whitegrid')
seaborn.set_context('paper')

plt.figure(1)
plt.subplot(111)
plt.title('Title', fontsize=16, color='k')
plt.plot(step_list, lr_list, linewidth=1.5, label='learning rate scheme')
plt.legend(loc='upper right', shadow=False)
ax = plt.gca()
# Label iterations in thousands ("250K"). A FuncFormatter is evaluated for
# whatever ticks matplotlib finally chooses. set_xticklabels() on a snapshot of
# get_xticks() can mislabel ticks and raises a FixedFormatter warning.
ax.xaxis.set_major_formatter(
    mtick.FuncFormatter(lambda value, _pos: str(int(value / 1000)) + 'K'))
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))

ax.set_ylabel('Learning rate')
ax.set_xlabel('Iteration')
plt.show()

# Loss Functions:

- Pixel-wise loss (L1 or L2 between images)
- Content (perceptual) loss computed on VGG19 feature maps
- Relativistic average (RaGAN) losses for the generator and the discriminator


In [None]:
import tensorflow as tf
from tensorflow.keras.applications.vgg19 import preprocess_input, VGG19


def PixelLoss(criterion='l1'):
    """Return a Keras pixel-wise loss object.

    Args:
        criterion: 'l1' -> MeanAbsoluteError, 'l2' -> MeanSquaredError.

    Raises:
        NotImplementedError: for any other criterion.
    """
    if criterion not in ('l1', 'l2'):
        raise NotImplementedError(
            'Loss type {} is not recognized.'.format(criterion))
    if criterion == 'l2':
        return tf.keras.losses.MeanSquaredError()
    return tf.keras.losses.MeanAbsoluteError()


def ContentLoss(criterion='l1', output_layer=54, before_act=True):
    """Build a VGG19 feature-space (content / perceptual) loss.

    Args:
        criterion: 'l1' (mean absolute error) or 'l2' (mean squared error)
            between feature maps.
        output_layer: which VGG19 layer to compare. 22 -> ``vgg.layers[5]``
            (low-level features, block2_conv2); 54 -> ``vgg.layers[20]``
            (high-level features, block5_conv4).
        before_act: if True, the chosen layer's activation is removed so that
            features are taken before its ReLU.

    Returns:
        A ``tf.function`` ``content_loss(hr, sr)`` that expects images in
        [0, 1].

    Raises:
        NotImplementedError: for an unknown ``criterion`` or ``output_layer``.
            Both are validated before VGG19 is built, so bad arguments fail
            fast without loading the network.
    """
    if criterion not in ('l1', 'l2'):
        raise NotImplementedError(
            'Loss type {} is not recognized.'.format(criterion))

    if output_layer == 22:  # Low-level feature
        pick_layer = 5
    elif output_layer == 54:  # High-level feature
        pick_layer = 20
    else:
        # Report the offending layer, not the criterion.
        raise NotImplementedError(
            'VGG output layer {} is not recognized.'.format(output_layer))

    if criterion == 'l1':
        loss_func = tf.keras.losses.MeanAbsoluteError()
    else:
        loss_func = tf.keras.losses.MeanSquaredError()

    vgg = VGG19(input_shape=(None, None, 3), include_top=False)

    if before_act:
        vgg.layers[pick_layer].activation = None

    fea_extractor = tf.keras.Model(vgg.input, vgg.layers[pick_layer].output)

    @tf.function
    def content_loss(hr, sr):
        # Inputs are in [0, 1]; VGG preprocessing expects [0, 255].
        # 12.75 is a rescale factor for the VGG feature maps.
        preprocess_sr = preprocess_input(sr * 255.) / 12.75
        preprocess_hr = preprocess_input(hr * 255.) / 12.75
        sr_features = fea_extractor(preprocess_sr)
        hr_features = fea_extractor(preprocess_hr)

        return loss_func(hr_features, sr_features)

    return content_loss


def DiscriminatorLoss(gan_type='ragan'):
    """Return the discriminator loss ``loss(hr, sr)`` for the given GAN type.

    ``hr`` and ``sr`` are raw discriminator outputs for real and generated
    images. They are passed through a sigmoid and scored with binary
    cross-entropy (real -> 1, fake -> 0).

    gan_type:
        'ragan': relativistic average GAN. Each output is compared with the
                 mean output of the other batch, and the two terms are averaged.
        'gan':   standard GAN. The real and fake terms are summed.

    Raises:
        NotImplementedError: for any other ``gan_type``.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    def discriminator_loss_ragan(hr, sr):
        real_term = bce(tf.ones_like(hr), tf.sigmoid(hr - tf.reduce_mean(sr)))
        fake_term = bce(tf.zeros_like(sr), tf.sigmoid(sr - tf.reduce_mean(hr)))
        return 0.5 * (real_term + fake_term)

    def discriminator_loss(hr, sr):
        real_term = bce(tf.ones_like(hr), tf.sigmoid(hr))
        fake_term = bce(tf.zeros_like(sr), tf.sigmoid(sr))
        return real_term + fake_term

    if gan_type == 'gan':
        return discriminator_loss
    if gan_type == 'ragan':
        return discriminator_loss_ragan
    raise NotImplementedError(
        'Discriminator loss type {} is not recognized.'.format(gan_type))


def GeneratorLoss(gan_type='ragan'):
    """Return the generator loss ``loss(hr, sr)`` for the given GAN type.

    ``hr`` and ``sr`` are raw discriminator outputs for real and generated
    images. They are passed through a sigmoid and scored with binary
    cross-entropy.

    gan_type:
        'ragan': relativistic average GAN. The targets are flipped relative
                 to the discriminator (generated -> 1, real -> 0), and the two
                 terms are averaged.
        'gan':   standard GAN. Only the generated outputs are pushed toward 1.

    Raises:
        NotImplementedError: for any other ``gan_type``.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    def generator_loss_ragan(hr, sr):
        fake_term = bce(tf.ones_like(sr), tf.sigmoid(sr - tf.reduce_mean(hr)))
        real_term = bce(tf.zeros_like(hr), tf.sigmoid(hr - tf.reduce_mean(sr)))
        return 0.5 * (fake_term + real_term)

    def generator_loss(hr, sr):
        return bce(tf.ones_like(sr), tf.sigmoid(sr))

    if gan_type == 'gan':
        return generator_loss
    if gan_type == 'ragan':
        return generator_loss_ragan
    raise NotImplementedError(
        'Generator loss type {} is not recognized.'.format(gan_type))

# Loss instances built from the factories above (L2 pixel/content, RaGAN adversarial).
pixel_loss_fn = PixelLoss(criterion="l2")
fea_loss_fn = ContentLoss(criterion="l2")
gen_loss_fn = GeneratorLoss(gan_type="ragan")
dis_loss_fn = DiscriminatorLoss(gan_type="ragan")

The cell below defines `pixelwise_mse` and `RelativisticAverageLoss`; use these instead of the loss builders above.

In [None]:
def pixelwise_mse(y_true, y_pred):
    """Mean squared error between two equally-shaped tensors.

    The squared error is first averaged over axis 0 (the batch axis), and the
    result is then averaged over all remaining elements.
    """
    squared_error = (y_true - y_pred)**2
    per_position_mean = tf.reduce_mean(squared_error, axis=0)
    return tf.reduce_mean(per_position_mean)

def RelativisticAverageLoss(non_transformed_disc, type_="G"):
    """Build a Relativistic Average (RaGAN) loss around a discriminator.

    Args:
        non_transformed_disc: callable mapping an image batch to raw,
            non-activated discriminator outputs. The losses apply
            ``tf.nn.sigmoid_cross_entropy_with_logits`` themselves, so the
            outputs are treated as logits.
        type_: 'G' for the generator loss, 'D' for the discriminator loss.

    Returns:
        A function ``loss(y_true, y_pred)`` taking a batch of real images
        (``y_true``) and generated images (``y_pred``).

        * 'D': scalar, mean(real -> 1 term) + mean(fake -> 0 term).
        * 'G': element-wise and unreduced (real -> 0 plus fake -> 1); callers
          reduce it themselves.

    Raises:
        ValueError: if ``type_`` is neither 'G' nor 'D'. The previous version
            silently returned None in that case.
    """
    if type_ not in ("G", "D"):
        raise ValueError(
            "type_ must be 'G' or 'D', got {!r}".format(type_))

    def _ra_logits(y_true, y_pred):
        """Return (real, fake) outputs, each shifted by the other batch's mean.

        Each batch goes through the discriminator once. The outputs are the
        same as evaluating D(x) - mean(D(y)) separately for both orders.
        """
        real_out = non_transformed_disc(y_true)
        fake_out = non_transformed_disc(y_pred)
        real_logits = real_out - tf.reduce_mean(fake_out)
        fake_logits = fake_out - tf.reduce_mean(real_out)
        return real_logits, fake_logits

    def loss_D(y_true, y_pred):
        """Relativistic average loss for the discriminator (scalar)."""
        real_logits, fake_logits = _ra_logits(y_true, y_pred)
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real_logits), logits=real_logits))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_logits), logits=fake_logits))
        return real_loss + fake_loss

    def loss_G(y_true, y_pred):
        """Relativistic average loss for the generator (element-wise, unreduced)."""
        real_logits, fake_logits = _ra_logits(y_true, y_pred)
        real_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(real_logits), logits=real_logits)
        fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(fake_logits), logits=fake_logits)
        return real_loss + fake_loss

    return loss_G if type_ == "G" else loss_D

# Implementation

1. Load dataset.
2. Compile the Student Model.
3. Load the Teacher model.
4. Initialize optimizers.
5. Execute adversarial training step.

1. Load dataset.

In [None]:
n = 5000
lr_dir = "/content/drive/MyDrive/Image datasets/data/lr_images"
hr_dir = "/content/drive/MyDrive/Image datasets/data/hr_images"

# Sort BEFORE truncating. os.listdir order is arbitrary, and the original
# `lr_list.sort` (no parentheses) never sorted anything. LR and HR lists must
# line up index-for-index for the training pairs to match.
# NOTE(review): this assumes each LR file has an identically-named HR file.
lr_list = sorted(os.listdir(lr_dir))[:n]
hr_list = sorted(os.listdir(hr_dir))[:n]


def _read_rgb(path):
    """Read an image with OpenCV and convert it from BGR to RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (cv2.imread returns None).
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError("Could not read image: " + path)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


lr_images = np.array([_read_rgb(os.path.join(lr_dir, name)) for name in lr_list])
hr_images = np.array([_read_rgb(os.path.join(hr_dir, name)) for name in hr_list])

Normalize inputs to [0, 1] and split into train/test sets

In [None]:
# Bring pixel intensities from [0, 255] into [0, 1].
lr_images, hr_images = lr_images / 255., hr_images / 255.

# Hold out 20% of the LR/HR pairs for testing (fixed seed for reproducibility).
(lr_train, lr_test,
 hr_train, hr_test) = train_test_split(lr_images, hr_images,
                                       test_size=0.2, random_state=42)

2. Compiling the Student model:

In [None]:
hr_shape = hr_train.shape[1:]
lr_shape = lr_train.shape[1:]

lr_ip = Input(shape=lr_shape)
hr_ip = Input(shape=hr_shape)

# Student generator: 8 residual-in-residual blocks (the teacher uses 16).
generator = create_gen(lr_ip, num_res_block=8)

# The optimizers must exist before compile(). Previously they were only
# created in a later cell, so running top-to-bottom raised NameError.
optimizer_G = tf.keras.optimizers.Adam(learning_rate=lr_scheduler_G,
                                       beta_1=0.9,
                                       beta_2=0.999)
optimizer_D = tf.keras.optimizers.Adam(learning_rate=lr_scheduler_D,
                                       beta_1=0.9,
                                       beta_2=0.999)

discriminator = create_disc(hr_ip)
discriminator.compile(loss="binary_crossentropy", optimizer=optimizer_D, metrics=['accuracy'])

# Build VGG for the actual HR shape instead of a hard-coded (128, 128, 3).
vgg = build_vgg(hr_shape)
vgg.trainable = False

gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)

# Compilation parameters:
# - output 0 (discriminator validity): binary_crossentropy, weight 1e-3
# - output 1 (VGG feature map of the generated image): MSE against the
#   reference image's features, weight 1
gan_model.compile(loss=["binary_crossentropy", "mse"], loss_weights=[1e-3, 1], optimizer=optimizer_G)
gan_model.summary()

3. Load the Teacher:

In [None]:
from keras.models import load_model
from numpy.random import randint

# Pre-trained teacher generator, restored together with its compile configuration.
Teacher = load_model(
    filepath='/content/drive/MyDrive/Image datasets/data/models/gen7_e_40_9.04.h5',
    compile=True,
)

4. Initialize optimizers.

In [None]:
# Adam optimizers driven by the multi-step schedules lr_scheduler_G / lr_scheduler_D.
# NOTE(review): the compile cell above passes optimizer_G / optimizer_D to
# compile(). Rebinding these names here does not change the optimizer
# objects those compiled models already hold.
optimizer_G = tf.keras.optimizers.Adam(learning_rate=lr_scheduler_G,
                                       beta_1=0.9,
                                       beta_2=0.999)

optimizer_D = tf.keras.optimizers.Adam(learning_rate=lr_scheduler_D,
                                       beta_1=0.9,
                                       beta_2=0.999)
# The former bare `optimizer_X.build` lines referenced the method without
# calling it (a no-op). Keras builds optimizer state on first apply_gradients.

Organizing Batches

In [None]:
batch_size = 1

# Cut the training arrays into consecutive fixed-size batches. A trailing
# partial batch, if any, is dropped.
num_batches = hr_train.shape[0] // batch_size
train_hr_batches = [hr_train[k * batch_size:(k + 1) * batch_size] for k in range(num_batches)]
train_lr_batches = [lr_train[k * batch_size:(k + 1) * batch_size] for k in range(num_batches)]

# Combining the Pieces:

Model Training Step Function

In [None]:
loss_fn = tf.keras.losses.MeanSquaredError(reduction="none")
metric_fn = tf.keras.metrics.Mean()
student_psnr = tf.keras.metrics.Mean()
teacher_psnr = tf.keras.metrics.Mean()

# The discriminator built and compiled above plays the "teacher
# discriminator". It must be bound before the RaGAN losses are created
# (previously it was only defined in a later cell).
teacher_disc = discriminator

# Use the RelativisticAverageLoss defined in this notebook. There is no `utils` module here.
# NOTE(review): RelativisticAverageLoss expects raw logits, but create_disc()
# ends in a sigmoid Dense layer. Confirm whether a logits head is intended.
ra_generator = RelativisticAverageLoss(teacher_disc, type_="G")
ra_discriminator = RelativisticAverageLoss(teacher_disc, type_="D")


def perceptual_loss(y_true, y_pred):
    """L2 (pixelwise_mse) distance between `vgg` feature maps of two image batches.

    Replaces the undefined utils.PerceptualLoss. It uses the frozen ImageNet
    VGG19 extractor built by build_vgg (features from vgg.layers[10]).
    Inputs are expected to be already VGG-preprocessed (step_fn applies
    preprocess_input before calling this).
    """
    return pixelwise_mse(vgg(y_true), vgg(y_pred))

In [None]:
# Distillation hyper-parameters. step_fn combines the generator loss as
#   lambda_ * perceptual + alpha * relativistic + (1 - alpha) * pixel_mse
alpha = 1e-4
lambda_ = 5e-2  # trailing underscore: `lambda` is a reserved keyword
teacher_disc = discriminator
generator_metric = tf.keras.metrics.Mean()
discriminator_metric = tf.keras.metrics.Mean()
dummy_optimizer = tf.optimizers.Adam()

def preprocess_input(image):
    """Convert an RGB image tensor to VGG "caffe" input format.

    Reverses the channel axis (RGB -> BGR) and subtracts the per-channel
    means 103.939 / 116.779 / 123.68 (the values Keras' VGG preprocessing
    uses). No rescaling is applied.

    NOTE: this rebinds the `preprocess_input` imported from
    tensorflow.keras.applications.vgg19 for all later callers.
    """
    bgr = image[..., ::-1]
    channel_means = tf.constant([103.939, 116.779, 123.68])
    return tf.nn.bias_add(bgr, -channel_means)


def step_fn(image_lr, image_hr):
    """Run one knowledge-distillation update on a single batch.

    The student (`generator`) is trained to match the frozen `Teacher`
    with a weighted sum of a perceptual loss, a relativistic adversarial loss
    and a pixel MSE to the teacher output. If `teacher_disc` has trainable
    variables, it is also updated with the relativistic discriminator loss.

    Args:
        image_lr: batch of low-resolution images.
        image_hr: batch of matching high-resolution images.

    Returns:
        float32 scalar: the discriminator optimizer's iteration count.
    """
    # HR batches come from NumPy as float64; model outputs are float32.
    image_hr = tf.cast(image_hr, tf.float32)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Call the models directly. predict_on_batch/train_on_batch cannot run
        # inside tf.function and would not be recorded by the tape.
        teacher_fake = Teacher(image_lr, training=False)
        teacher_fake = tf.clip_by_value(teacher_fake, 0, 255)

        student_fake = generator(image_lr, training=True)
        # Clip the *student* output (was mistakenly clipping teacher_fake).
        student_fake = tf.clip_by_value(student_fake, 0, 255)

        image_hr = tf.clip_by_value(image_hr, 0, 255)

        # NOTE(review): images are scaled to [0, 1] earlier in this notebook,
        # but the clipping and PSNR max_val assume [0, 255]. Confirm the range.
        # Update the Mean metrics instead of shadowing them with locals.
        student_psnr(tf.reduce_mean(tf.image.psnr(student_fake, image_hr, max_val=255.0)))
        teacher_psnr(tf.reduce_mean(tf.image.psnr(teacher_fake, image_hr, max_val=255.0)))

        mse_loss = pixelwise_mse(teacher_fake, student_fake)

        image_hr = preprocess_input(image_hr)
        student_fake = preprocess_input(student_fake)
        teacher_fake = preprocess_input(teacher_fake)

        student_ra_loss = ra_generator(teacher_fake, student_fake)
        discriminator_loss = ra_discriminator(teacher_fake, student_fake)
        discriminator_metric(discriminator_loss)
        discriminator_loss = tf.reduce_mean(
            discriminator_loss) * (1.0 / batch_size)

        percep_loss = perceptual_loss(image_hr, student_fake)
        generator_loss = (lambda_ * percep_loss
                          + alpha * student_ra_loss
                          + (1 - alpha) * mse_loss)
        generator_metric(generator_loss)
        generator_loss = tf.reduce_mean(
            generator_loss) * (1.0 / batch_size)

    student_vars = generator.trainable_variables
    generator_gradient = gen_tape.gradient(generator_loss, student_vars)
    optimizer_G.apply_gradients(zip(generator_gradient, student_vars))

    # The training loop sets discriminator.trainable = False before calling
    # train_step. With no trainable variables there is nothing to update, and
    # apply_gradients would reject an empty list.
    disc_vars = teacher_disc.trainable_variables
    if disc_vars:
        discriminator_gradient = disc_tape.gradient(discriminator_loss, disc_vars)
        optimizer_D.apply_gradients(zip(discriminator_gradient, disc_vars))

    return tf.cast(optimizer_D.iterations, tf.float32)

# Graph-compiled wrapper around step_fn (the decorator and def must share an indent level).
@tf.function
def train_step(image_lr, image_hr):
    """Run `step_fn` once on one batch inside a tf.function graph.

    Args:
      image_lr: batch of low-resolution images.
      image_hr: batch of high-resolution images.

    Returns:
      Whatever `step_fn` returns.
    """
    return step_fn(image_lr, image_hr)




In [None]:
epochs = 10

# Targets for the discriminator's binary cross-entropy updates. They do not
# change between epochs, so build them once.
fake_label = np.zeros((batch_size, 1))  # 0 for generated images
real_label = np.ones((batch_size, 1))   # 1 for real HR images

for e in range(epochs):
    # step_fn accumulates into these Mean metrics. Reset them so each epoch
    # reports its own average rather than a running average across epochs.
    generator_metric.reset_state()
    discriminator_metric.reset_state()

    for lr_imgs, hr_imgs in tqdm(zip(train_lr_batches, train_hr_batches),
                                 total=len(train_hr_batches)):
        fake_imgs = generator.predict_on_batch(lr_imgs)  # Fake images

        # Train the discriminator on fake and real HR images.
        discriminator.trainable = True
        discriminator.train_on_batch(fake_imgs, fake_label)
        discriminator.train_on_batch(hr_imgs, real_label)

        # Train the generator with the discriminator frozen.
        discriminator.trainable = False
        train_step(lr_imgs, hr_imgs)

        clear_output()

    # generator_loss / discriminator_loss are locals of step_fn and do not
    # exist here. Read the epoch averages from the metrics it updates.
    g_loss = float(generator_metric.result())
    d_loss = float(discriminator_metric.result())

    print("epoch:", e+1, "g_loss:", g_loss, "d_loss:", d_loss)

    if (e + 1) % 5 == 0:
        # Save the student generator every 5 epochs.
        generator.save("/content/drive/MyDrive/Image datasets/data/models/Student/gen1_e_"
                       + str(e + 1) + "_" + "{:.2f}".format(round(g_loss, 2)) + ".h5")