<a href="https://colab.research.google.com/github/Vakurin/SuperResolutionGAN-TF2.0/blob/master/PreGenerator_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PreGenerator-Training Pipeline
---

In [0]:
# Make default version tf2.0 in colab
%tensorflow_version 2.x
%matplotlib inline

!pip install tensorlayer
import tensorlayer as tl

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Mean

import time
import os
import matplotlib.pyplot as plt

# For Colab 
import zipfile
from google.colab import drive

## Load Data And Generator

In [2]:
!git clone https://github.com/Vakurin/SuperResolutionGAN-TF2.0.git
%cd SuperResolutionGAN-TF2.0/

from models import model_generator
from dataset import get_train_data

generator = model_generator()
generator_optimizer = tf.keras.optimizers.Adam(1e-4)

Cloning into 'SuperResolutionGAN-TF2.0'...
remote: Enumerating objects: 20, done.[K
remote: Counting objects:   5% (1/20)[Kremote: Counting objects:  10% (2/20)[Kremote: Counting objects:  15% (3/20)[Kremote: Counting objects:  20% (4/20)[Kremote: Counting objects:  25% (5/20)[Kremote: Counting objects:  30% (6/20)[Kremote: Counting objects:  35% (7/20)[Kremote: Counting objects:  40% (8/20)[Kremote: Counting objects:  45% (9/20)[Kremote: Counting objects:  50% (10/20)[Kremote: Counting objects:  55% (11/20)[Kremote: Counting objects:  60% (12/20)[Kremote: Counting objects:  65% (13/20)[Kremote: Counting objects:  70% (14/20)[Kremote: Counting objects:  75% (15/20)[Kremote: Counting objects:  80% (16/20)[Kremote: Counting objects:  85% (17/20)[Kremote: Counting objects:  90% (18/20)[Kremote: Counting objects:  95% (19/20)[Kremote: Counting objects: 100% (20/20)[Kremote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects:   6% (

## Global Var

In [0]:
PATH_FOR_SAVE_MODELS = '/content/drive/My Drive/SRGAN/pre-train/'
EPOCHS  = 2500

PATH_FOLDER_WITH_IMAGES  = '/tmp/Dataset/DIV2K_train_HR/'
IMG_HR_SIZE = 256
IMG_LR_SIZE = 64
BATCH_SIZE = 4

## Colab Settings

In [4]:
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [5]:
zip_ref = zipfile.ZipFile("/content/drive/My Drive/SRGAN/DIV2K_train_HR.zip", 'r')
zip_ref.extractall("/tmp/Dataset")
zip_ref.close()

print('total training test images:', len(os.listdir(PATH_FOLDER_WITH_IMAGES)))

total training test images: 800


## Get Dataset

In [0]:
train_ds = get_train_data(IMG_HR_SIZE, IMG_LR_SIZE, BATCH_SIZE, PATH_FOLDER_WITH_IMAGES)

## Training Checkpoints

In [0]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                  psnr=tf.Variable(-1.0),
                                  optimizer=generator_optimizer,
                                  model=generator)

checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                    directory=checkpoint_dir,
                                                    max_to_keep=3)

### Training Setup

In [0]:
def generator_pre_train(train_ds, epochs, save_path='./pre-train-model.h5'):
  
  loss_mean = Mean()
  n_step_epoch = round(epochs // BATCH_SIZE)

  for epoch in range(n_step_epoch):
    now = time.time()
    for step, (lr, hr) in enumerate(train_ds):
                
        loss, generated_image = train_step(lr, hr)
        loss_mean(loss)

        if step % 1000 == 0:
            loss_value = loss_mean.result()
            loss_mean.reset_states()

            print(f'{epoch * BATCH_SIZE}/{epochs}: loss = {loss_value.numpy():.3f},  ({time.time() - now} s)')
            checkpoint_manager.save()
        
    # Save Image
    if (epoch != 0) and (epoch * BATCH_SIZE % 10 == 0):
          print('____SAVE_____')
          tl.vis.save_images(generated_image.numpy(), [2, 4], os.path.join(save_path, 'pre-train_g_{}.png'.format(epoch)))
          generator.save(os.path.join(save_path, 'pre-generator-{}.h5'.format(epoch * BATCH_SIZE)))
      

    

@tf.function
def train_step(lr, hr):
    with tf.GradientTape() as g_tape:
        generated_image = generator(lr, training=True)
        # Loss MSE
        loss_value = tf.keras.losses.MeanSquaredError()(hr, generated_image)

    gradients = g_tape.gradient(loss_value, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients, generator.trainable_variables))

    return loss_value, generated_image

## Training

In [13]:
generator_pre_train(train_ds, EPOCHS, save_path=PATH_FOR_SAVE_MODELS)

0/2500: loss = 0.019,  (1.1226136684417725 s)
4/2500: loss = 0.021,  (1.0921790599822998 s)
8/2500: loss = 0.021,  (1.0913665294647217 s)
12/2500: loss = 0.021,  (1.0995068550109863 s)
16/2500: loss = 0.020,  (1.0974464416503906 s)
20/2500: loss = 0.020,  (1.0884041786193848 s)
____SAVE_____
24/2500: loss = 0.020,  (1.0831165313720703 s)
28/2500: loss = 0.020,  (1.1057724952697754 s)
32/2500: loss = 0.020,  (1.044996738433838 s)
36/2500: loss = 0.019,  (1.0598022937774658 s)
40/2500: loss = 0.019,  (1.046924114227295 s)
____SAVE_____
44/2500: loss = 0.019,  (1.0613272190093994 s)
48/2500: loss = 0.022,  (1.0685710906982422 s)
52/2500: loss = 0.020,  (1.071953535079956 s)
56/2500: loss = 0.019,  (1.0297062397003174 s)


KeyboardInterrupt: ignored