<a href="https://colab.research.google.com/github/DejanGjer/SRGAN/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Show which GPU this Colab runtime was given (model, driver/CUDA version, memory usage).
!nvidia-smi

Wed Jul  7 12:11:17 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Mount Google Drive at /content/drive: the dataset archive is read from it and the
# TensorBoard logs / model checkpoints of the training cell are written to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
zip_path = "/content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/CelebaHQ_org.zip"
!cp "{zip_path}" .
!unzip -q "CelebaHQ_org.zip"
!rm "CelebaHQ_org.zip"

In [None]:
import os,shutil
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import datetime
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import keras
import matplotlib.pyplot as plt
import numpy as np

import cv2
from tqdm.notebook import tqdm

from tensorflow.keras.models import Model
from tensorflow.keras.layers import BatchNormalization, LeakyReLU, Conv2D, Dense, \
                         Flatten, Add, PReLU, Conv2DTranspose, Lambda, UpSampling2D                    
from keras.optimizers import Adam
from tensorflow.keras.applications import VGG19
from keras.callbacks import ReduceLROnPlateau


# Load the TensorBoard notebook extension (provides the %tensorboard magic used further down).
%load_ext tensorboard

# Reaching this line means every import above succeeded.
print("Setup complete!")

Setup complete!


In [None]:

class LoadDataset:
    """Paired high-/low-resolution image datasets built from one image directory.

    Every image under ``<root_dir>/<relative_path>`` is read twice with
    ``tf.keras.preprocessing.image_dataset_from_directory``: once resized to
    ``image_size_large`` (HR targets) and once to ``image_size_small`` (LR
    inputs). Both reads use the same ``seed``, ``shuffle`` flag and
    ``validation_split`` so that, when an LR dataset and its HR counterpart are
    zipped and iterated together, batch *k* of each is meant to hold the same
    files.
    NOTE(review): that pairing relies on two independently shuffled datasets
    producing identical orders; re-verify it after TensorFlow upgrades.

    Attributes:
        path: directory the images are read from.
        hr_images_train, lr_images_train: batched HR / LR datasets ("training"
            subset when ``validation_split`` is set, otherwise every file).
        hr_images_valid, lr_images_valid: batched HR / LR "validation" subsets,
            or ``None`` when ``validation_split`` is ``None``.
    """

    def __init__(self, relative_path, validation_split=None, image_size_large=256,
                 image_size_small=64, batch_size=16, shuffle=True,
                 interpolation="bilinear", root_dir="CelebaHQ", seed=1337,
                 prefetch_buffer=None):
        """Build all datasets eagerly (``relative_path`` is printed first).

        Args:
            relative_path: sub-directory of ``root_dir`` holding the images.
            validation_split: fraction of files held out for validation, or
                ``None`` to put every file in the ``*_train`` datasets.
            image_size_large: side length (pixels) of the square HR images.
            image_size_small: side length (pixels) of the square LR images.
            batch_size: images per batch in every dataset.
            shuffle: whether files are shuffled.
            interpolation: resize method passed to ``image_dataset_from_directory``.
            root_dir: dataset root directory.
            seed: shuffle / split seed, shared by the HR and LR reads.
            prefetch_buffer: number of *batches* to prefetch; ``None`` lets
                tf.data tune it (``AUTOTUNE``).
        """
        print(relative_path)
        self.path = os.path.join(root_dir, relative_path)
        self.image_size_large = image_size_large
        self.image_size_small = image_size_small
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.interpolation = interpolation
        self.shuffle = shuffle
        self.seed = seed
        self.prefetch_buffer = prefetch_buffer

        has_split = validation_split is not None
        train_subset = "training" if has_split else None

        self.hr_images_train = self._prefetch(self.generate(image_size_large, train_subset))
        self.lr_images_train = self._prefetch(self.generate(image_size_small, train_subset))

        self.hr_images_valid = None
        self.lr_images_valid = None
        if has_split:
            self.hr_images_valid = self._prefetch(self.generate(image_size_large, "validation"))
            self.lr_images_valid = self._prefetch(self.generate(image_size_small, "validation"))

    def _prefetch(self, dataset):
        """Attach a prefetch buffer to ``dataset``.

        Elements of these datasets are whole batches, so a buffer of N keeps N
        full batches in memory; by default tf.data chooses N itself.
        """
        buffer_size = self.prefetch_buffer
        if buffer_size is None:
            buffer_size = tf.data.experimental.AUTOTUNE
        return dataset.prefetch(buffer_size=buffer_size)

    def generate(self, image_size, subset):
        """Read ``self.path`` as an unlabeled, batched dataset of square RGB images.

        Args:
            image_size: side length every image is resized to.
            subset: "training", "validation", or ``None`` (no split).
        """
        return tf.keras.preprocessing.image_dataset_from_directory(
            self.path,
            labels="inferred",
            label_mode=None,
            class_names=None,
            color_mode="rgb",
            batch_size=self.batch_size,
            image_size=(image_size, image_size),
            shuffle=self.shuffle,
            seed=self.seed,
            validation_split=self.validation_split,
            subset=subset,
            interpolation=self.interpolation,
            follow_links=False,
        )


In [None]:
# Split sizes: train_valid_set holds 29,900 images (28,000 train + 1,900 validation);
# test_set holds 100 images.
train_set_size = 28000
valid_set_size = 1900
test_set_size = 100
batch_size = 32
# Generator input / output image shapes.
low_reso_shape = (64, 64, 3)
high_reso_shape = (256, 256, 3)
# SRResNet upscales by a factor of 4 in each spatial dimension.
assert high_reso_shape[:2] == (4 * low_reso_shape[0], 4 * low_reso_shape[1])

# The dataset image sizes are taken from the model shapes so the two cannot drift apart.
# The validation fraction is chosen so exactly valid_set_size files are held out.
train_valid_dataset = LoadDataset(
    "train_valid_set",
    validation_split=valid_set_size / (train_set_size + valid_set_size),
    image_size_large=high_reso_shape[0],
    image_size_small=low_reso_shape[0],
    batch_size=batch_size,
)
test_dataset = LoadDataset(
    "test_set",
    image_size_large=high_reso_shape[0],
    image_size_small=low_reso_shape[0],
    batch_size=batch_size,
    shuffle=False,
)

train_valid_set
Found 29900 files belonging to 1 classes.
Using 28000 files for training.
Found 29900 files belonging to 1 classes.
Using 28000 files for training.
Found 29900 files belonging to 1 classes.
Using 1900 files for validation.
Found 29900 files belonging to 1 classes.
Using 1900 files for validation.
test_set
Found 100 files belonging to 1 classes.
Found 100 files belonging to 1 classes.


In [None]:
import os,shutil
import cv2
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.models import Model
from tensorflow.keras.layers import BatchNormalization, LeakyReLU, Conv2D, Dense, \
                         Flatten, Add, PReLU, Conv2DTranspose, Lambda, UpSampling2D                    
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau

class SRResNet:
    """SRResNet generator: upsamples a low-resolution RGB image by ``upscale_factor``.

    ``build`` assembles:
    Rescaling(1/255) -> 9x9 conv + PReLU -> ``res_blocks`` residual blocks ->
    3x3 conv + BN with a long skip connection to the first feature map ->
    log2(``upscale_factor``) sub-pixel x2 upsampling stages -> 9x9 conv with ``tanh``.
    The model is compiled with Adam and an MSE loss; its outputs lie in [-1, 1].

    Attributes:
        upscale_factor: total spatial upscaling; must be a power of two.
        prelu_shared_axes: axes over which each PReLU shares its slope.
            ``(1, 2)`` gives one learnable slope per channel. ``None`` gives one
            slope per (row, col, channel) element, which ties the model to a
            fixed input size and adds height*width*channels parameters per PReLU
            (262,144 for the first 64x64x64 activation).
        model: the compiled ``keras.Model``.
    """

    def __init__(self, input_shape=(64, 64, 3), upscale_factor=4, prelu_shared_axes=(1, 2)):
        # Validate before building anything so a bad factor fails fast.
        self._num_upsample_blocks(upscale_factor)
        self.upscale_factor = upscale_factor
        self.prelu_shared_axes = prelu_shared_axes
        self.model = self.build(input_shape)

    @staticmethod
    def _num_upsample_blocks(upscale_factor):
        """Return how many x2 upsample stages produce ``upscale_factor``.

        Each ``upsample_block`` doubles height and width, so the factor must be a
        power of two and the stage count is its base-2 logarithm.

        Raises:
            ValueError: if ``upscale_factor`` is not a positive power of two.
        """
        if (not isinstance(upscale_factor, int) or upscale_factor < 1
                or upscale_factor & (upscale_factor - 1)):
            raise ValueError(f"upscale_factor must be a power of two >= 1, got {upscale_factor!r}")
        return upscale_factor.bit_length() - 1

    def _prelu(self):
        """PReLU activation using the configured ``prelu_shared_axes``."""
        return PReLU(shared_axes=self.prelu_shared_axes)

    def SubpixelConv2D(self, scale):
        """Layer moving channel groups into space: (H, W, C*scale**2) -> (H*scale, W*scale, C)."""
        return Lambda(lambda x: tf.nn.depth_to_space(x, scale))

    def res_block(self, input_layer):
        """Conv-BN-PReLU-Conv-BN plus identity skip; input must have 64 channels."""
        x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(input_layer)
        x = BatchNormalization(momentum=0.8)(x)
        x = self._prelu()(x)
        x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        return Add()([input_layer, x])

    def upsample_block(self, input_layer):
        """Conv to 256 channels, then depth_to_space(2): doubles H and W, 64 channels out."""
        x = Conv2D(filters=256, kernel_size=3, strides=1, padding="same")(input_layer)
        x = self.SubpixelConv2D(2)(x)
        return self._prelu()(x)

    def build(self, input_shape, res_blocks=8):
        """Build and compile the generator for inputs of ``input_shape`` (expected in [0, 255])."""
        inputs = keras.Input(shape=input_shape)
        x = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)(inputs)

        x = Conv2D(filters=64, kernel_size=9, padding="same")(x)
        x = self._prelu()(x)
        long_skip = x

        for _ in range(res_blocks):
            x = self.res_block(x)

        x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = Add()([long_skip, x])

        for _ in range(self._num_upsample_blocks(self.upscale_factor)):
            x = self.upsample_block(x)

        outputs = Conv2D(filters=3, kernel_size=9, strides=1, padding="same", activation="tanh")(x)

        model = Model(inputs=inputs, outputs=outputs, name="SRResNet")
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.9, beta_2=0.999),
            loss="mse",
            loss_weights=1,
        )
        return model


class SRValidator:
    """SRGAN discriminator: scores an image as real (towards 1) or generated (towards 0).

    ``build`` stacks ``blocks_num`` blocks of [3x3 conv, 3x3 stride-2 conv],
    using 64, 128, 256, ... filters. These are followed by Dense(1024) + LeakyReLU
    and a sigmoid Dense(1). The model is compiled with Adam, binary
    cross-entropy and an accuracy metric.

    Attributes:
        upscale_factor: kept for parity with the other model classes; not used here.
        model: the compiled ``keras.Model``.
    """

    def __init__(self, input_shape=(256, 256, 3)):
        self.upscale_factor = 4
        self.model = self.build(input_shape)

    def disc_block(self, input, n_filters, batch_norm):
        """Conv [-> BN] -> LeakyReLU -> stride-2 Conv -> BN -> LeakyReLU.

        The stride-2 convolution halves height and width (rounding up).
        ``batch_norm`` only controls the BN after the first convolution; the
        second convolution is always batch-normalised. (``input`` shadows the
        builtin; the parameter name is kept for API compatibility.)
        """
        x = Conv2D(filters=n_filters, kernel_size=3, padding='same')(input)
        if batch_norm:
            x = BatchNormalization(momentum=0.8)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(filters=n_filters, kernel_size=3, strides=2, padding='same')(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = LeakyReLU(alpha=0.2)(x)
        return x

    def build(self, input_shape, blocks_num=4):
        """Build and compile the discriminator for images of ``input_shape``.

        No rescaling layer is added, so inputs are expected in the generator's
        tanh range [-1, 1] (the training code maps real HR images to that range).
        """
        inputs = keras.Input(shape=input_shape)
        x = inputs
        for i in range(blocks_num):
            # No BN directly on the convolution applied to the raw input.
            x = self.disc_block(x, (2 ** i) * 64, i > 0)
        x = Flatten()(x)
        x = Dense(1024)(x)
        x = LeakyReLU(alpha=0.2)(x)
        outputs = Dense(1, activation='sigmoid')(x)
        model = Model(inputs=inputs, outputs=outputs, name="SRValidator")
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.9, beta_2=0.999),
            loss="binary_crossentropy",
            loss_weights=1,
            metrics=['accuracy'],
        )
        return model
    

class SRGAN:
    """Generator and frozen discriminator stacked into one model for generator training.

    ``model`` maps an LR image batch to ``[generated_hr, validity]``. It is
    compiled with two losses:
    - MSE (content loss, weight ``content_weight``) on the image output;
    - binary cross-entropy (adversarial loss, weight ``adversarial_weight``) on
      the validity output.
    ``validator.trainable`` is set to ``False`` before ``model`` is compiled, so
    training ``model`` is intended to update only the generator. The
    discriminator is trained separately through its own ``train_on_batch``.

    Attributes:
        upscale_factor: kept for parity with the other model classes; not used here.
        generator: the SR generator (``SRResNet().model`` unless one is supplied).
        validator: the discriminator (``SRValidator().model`` unless one is supplied).
        model: the compiled combined ``keras.Model``.
    """

    def __init__(self, low_reso_shape, generator=None, validator=None,
                 content_weight=1, adversarial_weight=1e-3):
        """
        Args:
            low_reso_shape: shape of one LR input image, e.g. ``(64, 64, 3)``.
            generator: optional pre-built (e.g. pre-trained) generator model.
            validator: optional pre-built discriminator model.
            content_weight: loss weight of the MSE content term.
            adversarial_weight: loss weight of the adversarial term.
        """
        self.upscale_factor = 4
        # Identity check: `!= None` would defer to the object's own comparison operator.
        self.generator = generator if generator is not None else SRResNet().model
        self.validator = validator if validator is not None else SRValidator().model
        self.model = self.build(low_reso_shape, content_weight, adversarial_weight)

    def build(self, low_reso_shape, content_weight=1, adversarial_weight=1e-3):
        """Freeze the discriminator, stack it on the generator and compile the result."""
        lr_images = keras.Input(shape=low_reso_shape)
        hr_gen_images = self.generator(lr_images)

        # Freeze before compile: Keras uses the trainable state in effect when a
        # model is compiled, so the combined model leaves discriminator weights alone.
        self.validator.trainable = False
        gen_valid = self.validator(hr_gen_images)

        model = Model(inputs=lr_images, outputs=[hr_gen_images, gen_valid], name="SRGAN")
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.9, beta_2=0.999),
            loss=["mse", "binary_crossentropy"],
            loss_weights=[content_weight, adversarial_weight],
        )
        return model


In [None]:
# Build a fresh SRGAN (new generator and discriminator) only to inspect its structure;
# `srgan` is rebuilt around the pre-trained generator before training.
srgan = SRGAN(low_reso_shape=low_reso_shape)
srgan.model.summary()

Model: "SRGAN"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
SRResNet (Functional)        (None, 256, 256, 3)       8560899   
_________________________________________________________________
SRValidator (Functional)     (None, 1)                 138912577 
Total params: 147,473,476
Trainable params: 8,558,723
Non-trainable params: 138,914,753
_________________________________________________________________


In [None]:
# Layer-by-layer summary of the generator inside `srgan`.
srgan.generator.summary()

Model: "SRResNet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 64, 64, 3)    0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 64, 64, 64)   15616       rescaling[0][0]                  
__________________________________________________________________________________________________
p_re_lu (PReLU)                 (None, 64, 64, 64)   262144      conv2d[0][0]                     
___________________________________________________________________________________________

In [None]:
# Layer-by-layer summary of the discriminator inside `srgan`.
srgan.validator.summary()

Model: "SRValidator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 256, 256, 64)      1792      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 128, 128, 64)      36928     
_________________________________________________________________
batch_normalization_17 (Batc (None, 128, 128, 64)      256       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 128, 128, 128)     

In [None]:
# Inline TensorBoard over every run that train_srgan_equal writes under this directory.
%tensorboard --logdir "/content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/logs/srgan/mse_content_loss"

In [None]:
# Pre-trained SRResNet generator (epoch-29 checkpoint of an earlier SRResNet run).
# It becomes the starting generator of the SRGAN trained further down.
generator = keras.models.load_model("/content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/logs/srresnet/20210517-182937/model_checkpoints/epoch_29")



In [None]:
def _make_run_dirs(log_root):
    """Create a timestamped run directory under ``log_root``.

    Returns:
        (train_log_dir, save_model_dir): the TensorBoard log directory (created by
        the summary writer) and the checkpoint directory (created here).
    """
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    save_dir = os.path.join(log_root, current_time)
    save_model_dir = os.path.join(save_dir, "model_checkpoints")
    # os.makedirs creates missing parents and tolerates directories that already exist.
    os.makedirs(save_model_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, "charts"), exist_ok=True)
    return os.path.join(save_dir, "tensorboard_train"), save_model_dir


def _to_uint8(images):
    """Map images from [-1, 1] to displayable uint8 in [0, 255] (rounded and clipped)."""
    images = np.asarray(images)
    return np.clip(np.round((images + 1.0) * 127.5), 0, 255).astype("uint8")


def _log_discriminator_metrics(writer, validator, hr_batch, generated_images, real, fake, step):
    """Evaluate the discriminator on real and generated images and log loss/accuracy scalars."""
    real_loss, real_acc = validator.test_on_batch(hr_batch, y=real)
    fake_loss, fake_acc = validator.test_on_batch(generated_images, y=fake)
    with writer.as_default():
        tf.summary.scalar('real_loss', real_loss, step=step)
        tf.summary.scalar('real_acc', real_acc, step=step)
        tf.summary.scalar('fake_loss', fake_loss, step=step)
        tf.summary.scalar('fake_acc', fake_acc, step=step)
        tf.summary.scalar('disc_loss', (real_loss + fake_loss) / 2, step=step)
        tf.summary.scalar('disc_acc', (real_acc + fake_acc) / 2, step=step)


def _log_image_triplet(writer, prefix, generator, lr_batch, hr_batch, step, num_images):
    """Log LR input, generated and real HR images under '<prefix>_*' TensorBoard tags.

    ``lr_batch`` holds images as loaded (0-255); ``hr_batch`` and the generator
    output are in [-1, 1].
    """
    lr_subset = lr_batch[:num_images]
    with writer.as_default():
        tf.summary.image(f'{prefix}_input_images', lr_subset.numpy().astype("uint8"), step=step)
        # "genarated" spelling kept so tags stay comparable with earlier runs.
        tf.summary.image(f'{prefix}_genarated_images',
                         _to_uint8(generator.predict_on_batch(lr_subset)), step=step)
        tf.summary.image(f'{prefix}_real_images', _to_uint8(hr_batch[:num_images].numpy()), step=step)


def train_srgan_equal(dataset, srgan, epochs,
                      log_root="/content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/logs/srgan/mse_content_loss/",
                      log_every=20, images_every=100, checkpoint_every=4,
                      tensorboard_images_num=2):
    """Train an SRGAN with one discriminator update and one generator update per batch.

    For every (LR, HR) training batch:
    1. The discriminator is trained on the real HR batch (label 1) and on the
       generator's current output (label 0).
    2. The combined ``srgan.model`` is trained on the same batch. Its loss is
       MSE to the HR images plus an adversarial loss pushing the discriminator
       output towards 1.
    HR images are rescaled from [0, 255] to [-1, 1] to match the generator's
    tanh output.

    Logging and checkpointing:
    - Scalars go to TensorBoard every ``log_every`` steps.
    - Training images go every ``images_every`` steps.
    - One validation batch of images is logged per epoch.
    - Generator/discriminator are saved every ``checkpoint_every`` epochs and
      after the final epoch.

    Args:
        dataset: a ``LoadDataset`` created with ``validation_split``.
        srgan: an ``SRGAN``; its generator, validator and combined model are
            trained in place.
        epochs: number of passes over the training set.
        log_root: directory under which a timestamped run directory is created.
        log_every: steps between scalar summaries.
        images_every: steps between training-image summaries.
        checkpoint_every: epochs between checkpoints.
        tensorboard_images_num: images per image summary.

    Raises:
        ValueError: if ``dataset`` has no validation datasets.
    """
    if dataset.lr_images_valid is None or dataset.hr_images_valid is None:
        raise ValueError("train_srgan_equal needs a dataset created with validation_split")

    def to_signed_unit(images):
        # [0, 255] -> [-1, 1], the range of the generator's tanh output.
        return images / 127.5 - 1.0

    train_dataset = tf.data.Dataset.zip(
        (dataset.lr_images_train, dataset.hr_images_train.map(to_signed_unit)))
    valid_dataset = tf.data.Dataset.zip(
        (dataset.lr_images_valid, dataset.hr_images_valid.map(to_signed_unit)))

    generator, validator, gan = srgan.generator, srgan.validator, srgan.model

    train_log_dir, save_model_dir = _make_run_dirs(log_root)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Batches per epoch for the progress bar; negative cardinality means unknown/infinite.
    total_iterations = int(tf.data.experimental.cardinality(train_dataset))
    if total_iterations < 0:
        total_iterations = None

    step = 0
    try:
        for epoch in range(epochs):
            for lr_batch_train, hr_batch_train in tqdm(train_dataset, total=total_iterations,
                                                       desc=f"Epoch: {epoch}", unit="batches"):
                step += 1
                # Size labels from the actual batch (the last batch of an epoch may be
                # short) and shape them like the discriminator output, (batch, 1).
                n = lr_batch_train.shape[0]
                real = tf.ones([n, 1])
                fake = tf.zeros([n, 1])

                # Discriminator step: real HR -> 1, current generator output -> 0.
                generated_images = generator.predict_on_batch(lr_batch_train)
                # NOTE(review): Keras documents that the trainable flag in effect at
                # compile time is what counts; the toggle keeps the discriminator
                # trainable while its train function is first built and restores
                # the frozen flag afterwards.
                validator.trainable = True
                validator.train_on_batch(hr_batch_train, y=real)
                validator.train_on_batch(generated_images, y=fake)
                validator.trainable = False

                if step % log_every == 0:
                    _log_discriminator_metrics(train_summary_writer, validator, hr_batch_train,
                                               generated_images, real, fake, step)

                # Generator step through the combined model: content loss to HR plus
                # adversarial loss towards the "real" label.
                gan.train_on_batch(lr_batch_train, y=[hr_batch_train, real])

                if step % log_every == 0:
                    gan_loss, content_loss, adversarial_loss = gan.test_on_batch(
                        lr_batch_train, y=[hr_batch_train, real])
                    with train_summary_writer.as_default():
                        tf.summary.scalar('content_loss', content_loss, step=step)
                        # Tag spelling kept so curves stay comparable with earlier runs.
                        tf.summary.scalar('adverserial_loss', adversarial_loss, step=step)
                        tf.summary.scalar('gan_loss', gan_loss, step=step)

                if step % images_every == 0:
                    _log_image_triplet(train_summary_writer, "train", generator,
                                       lr_batch_train, hr_batch_train, step, tensorboard_images_num)

            # One validation batch per epoch (images only).
            for lr_batch_valid, hr_batch_valid in valid_dataset.take(1):
                _log_image_triplet(train_summary_writer, "valid", generator,
                                   lr_batch_valid, hr_batch_valid, step, tensorboard_images_num)

            # Periodic checkpoints, plus the final epoch so the last weights are always saved.
            if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == epochs:
                epoch_dir = os.path.join(save_model_dir, f"epoch_{epoch + 1}")
                validator.save(os.path.join(epoch_dir, "discriminator"))
                generator.save(os.path.join(epoch_dir, "generator"))
    finally:
        train_summary_writer.close()


In [None]:
# Adversarial fine-tuning: wrap the pre-trained SRResNet `generator` in a new SRGAN
# (fresh discriminator) and train it for 12 epochs.
srgan = SRGAN(low_reso_shape, generator=generator)
train_srgan_equal(dataset=train_valid_dataset, srgan=srgan, epochs=12)

HBox(children=(FloatProgress(value=0.0, description='Epoch: 0', max=875.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Epoch: 1', max=875.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Epoch: 2', max=875.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Epoch: 3', max=875.0, style=ProgressStyle(description_wid…


INFO:tensorflow:Assets written to: /content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/logs/srgan/mse_content_loss/20210706-154155/model_checkpoints/epoch_4/discriminator/assets




INFO:tensorflow:Assets written to: /content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/logs/srgan/mse_content_loss/20210706-154155/model_checkpoints/epoch_4/generator/assets


HBox(children=(FloatProgress(value=0.0, description='Epoch: 4', max=875.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Epoch: 5', max=875.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Epoch: 6', max=875.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Epoch: 7', max=875.0, style=ProgressStyle(description_wid…


INFO:tensorflow:Assets written to: /content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/logs/srgan/mse_content_loss/20210706-154155/model_checkpoints/epoch_8/discriminator/assets




INFO:tensorflow:Assets written to: /content/drive/My Drive/Colab Notebooks/SeminarskiB-NN/logs/srgan/mse_content_loss/20210706-154155/model_checkpoints/epoch_8/generator/assets


HBox(children=(FloatProgress(value=0.0, description='Epoch: 8', max=875.0, style=ProgressStyle(description_wid…