In [2]:
# Report the GPU assigned to this Colab runtime (model, driver/CUDA versions, memory usage).
! nvidia-smi

Sun Sep  6 15:06:45 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.66       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from google.colab import drive
drive.mount('/content/drive')
% cd 'drive/My Drive'

In [None]:
import tensorflow as tf 
import numpy as np
from tensorflow.keras.layers import *
from matplotlib import pyplot as plt
import os
import shutil
import pickle 
from tqdm.notebook import tqdm
from zipfile import ZipFile
from glob import glob
import time


# Configs

In [None]:
#@title Config
test_name = "RMSprop" #@param ["Adam", "RMSprop", "Adagrad"]
batch_size = 32 #@param {type:"number"}
hr_height = 256 #@param {type:"number"}
hr_width = 256 #@param {type:"number"}
scale = 4 #@param {type:"number"}
RESET_CHECKPOINTS = False #@param {type:"boolean"}
training_path = '/content/Data/' #@param {type:"string"}

#@title training parameters
iterations = 151 #@param {type:"number"}
evaluation_interval = 5 #@param {type:"number"}
interval = 10 #@param {type:"number"}
gen_lr = 1e-4 #@param {type:"number"}
disc_lr = 1e-4 #@param {type:"number"}
# The generator is updated once every K discriminator steps (see the training loop).
K = 2 #@param {type:"number"}

ad_loss_weight = 1e-3 #@param {type:"number"}
gf = 32 #@param {type:"number"}
gk = 4 #@param {type:"number"}

#@title image parameters
# NOTE: the upscaling factor `scale` is set once, in the Config section above.
channels = 1 #@param {type:"number"}
total_images = 100000 #@param {type:"number"}

# ~3/4 of the records for training; the remainder (so nothing is dropped) for validation.
train_len = total_images * 3 // 4
val_len = total_images - train_len

hr_shape = (hr_height, hr_width, channels)
lr_height = hr_height // scale
lr_width = hr_width // scale
lr_shape = (lr_height, lr_width, channels)
# Training batches per epoch, counting the final partial batch (ceil division).
total = -(-train_len // batch_size)

 

In [None]:

# Make sure this experiment's output folders exist: samples, checkpoints and saved models.
for _subdir in ('samples', 'checkpoints', 'models'):
    dir_path = './%s/%s' % (_subdir, test_name)
    os.makedirs(dir_path, exist_ok=True)
del _subdir

# Data preprocessing

In [None]:
# Feature spec of each serialized record: one encoded-image byte string for the
# high-resolution ('hr') image and one for the low-resolution ('lr') image.
# (The misspelt name is kept because parse_example refers to it.)
ex_discriptoin = {
    key: tf.io.FixedLenFeature((), tf.string) for key in ('hr', 'lr')
}

def parse_example(ex):
  """Decode one serialized record into an (hr, lr) pair of single-channel JPEG-decoded images."""
  features = tf.io.parse_single_example(ex, ex_discriptoin)
  hr, lr = [tf.io.decode_jpeg(features[key], channels=1) for key in ('hr', 'lr')]
  return hr, lr

def normalize(hr, lr):
  """Rescale 0..255 pixel values: LR to [0, 1] (generator input), HR to [-1, 1] (tanh target range).

  Returns the pair in the same (hr, lr) order it was given.
  """
  lr_scaled = lr / 255
  hr_scaled = (hr / 255) * 2 - 1
  return hr_scaled, lr_scaled

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Sort the shard list: glob() returns files in arbitrary order, and the take()/skip()
# split below is positional, so an unstable file order could silently move records
# between the validation and training sets from one run (or resume) to the next.
record_files = sorted(glob('/content/drive/My Drive/tfrecords_data/*'))
if not record_files:
    # Without this, TFRecordDataset([]) is just an empty dataset and training silently does nothing.
    raise FileNotFoundError("No TFRecord files found in '/content/drive/My Drive/tfrecords_data/'")

# NOTE(review): with num_parallel_reads=AUTOTUNE the interleave cycle length may be chosen at
# runtime, so record order (and hence the split) could still differ across machines — confirm,
# and consider a fixed value if a reproducible split matters.
record_ds = tf.data.TFRecordDataset(record_files, num_parallel_reads=AUTOTUNE)

# The first val_len records go to validation, the next train_len records to training.
val_ds = (record_ds.take(val_len)
          .map(parse_example, num_parallel_calls=AUTOTUNE)
          .map(normalize, num_parallel_calls=AUTOTUNE))
train_ds = (record_ds.skip(val_len).take(train_len)
            .map(parse_example, num_parallel_calls=AUTOTUNE)
            .map(normalize, num_parallel_calls=AUTOTUNE))

In [None]:
def take_elements(n):
    """Draw ``n`` random (hr, lr) pairs from ``train_ds`` (1000-element shuffle buffer).

    Returns two tensors stacked along a new leading axis: the HR images and the LR images.
    """
    pairs = list(train_ds.shuffle(1000).take(n))
    hr_images = [pair[0] for pair in pairs]
    lr_images = [pair[1] for pair in pairs]
    return tf.convert_to_tensor(hr_images), tf.convert_to_tensor(lr_images)

# Sample Images


In [None]:
def compute_metrics(hr, fake):
    """Return the (PSNR, SSIM) tensors comparing ``fake`` against ``hr``.

    Both inputs must already be scaled to [0, 1]; that is why max_val=1 is passed to both
    metrics. Callers reduce the returned tensors with a mean.
    """
    max_val = 1
    return tf.image.psnr(hr, fake, max_val), tf.image.ssim(hr, fake, max_val)

def sample_image(epoch, test_name=test_name, save_sample=True):
    """Super-resolve a few random training images and compare them with the originals.

    Prints the mean PSNR/SSIM over the sampled images, then draws a grid with one row per
    sample: the generated HR image on the left and the ground-truth HR image on the right.

    Args:
        epoch: Used only to name the saved files.
        test_name: Experiment name; files are written to ``samples/<test_name>/``.
        save_sample: If True, save the grid plus the first generated and real images as
            standalone JPEGs; otherwise display the grid inline.
    """
    rows, cols = 2, 2
    imgs_hr, imgs_lr = take_elements(rows)  # one sampled pair per grid row

    fake_hr = G(imgs_lr, training=False)  # tanh output, range [-1, 1]

    # Map both batches from [-1, 1] to [0, 1], the range compute_metrics expects.
    fake_hr = 0.5 * fake_hr + 0.5
    imgs_hr = 0.5 * imgs_hr + 0.5

    psnr, ssim = compute_metrics(imgs_hr, fake_hr)  # (reference, generated)
    print('PSNR= ', np.mean(psnr))
    print('SSIM= ', np.mean(ssim))

    titles = ['Generated', 'Original']
    fig, axs = plt.subplots(rows, cols)
    fig.set_size_inches(fig.get_size_inches() * 4)
    for row in range(rows):
        for col, image in enumerate([fake_hr, imgs_hr]):
            axs[row, col].imshow(image[row, :, :, 0], cmap='gray')
            axs[row, col].set_title(titles[col])
            axs[row, col].axis('off')

    if save_sample:
        fig.savefig("samples/%s/%d.png" % (test_name, epoch))
        # Also save the first generated image and its ground-truth HR image on their own.
        plt.imsave(arr=fake_hr[0, :, :, 0], fname=('samples/%s/%d_generated.jpg' % (test_name, epoch)), cmap='gray')
        plt.imsave(arr=imgs_hr[0, :, :, 0], fname=('samples/%s/%d_real.jpg' % (test_name, epoch)), cmap='gray')
    else:
        plt.show()
    plt.close(fig)
        
        

# **Models**

## Build Generator

In [None]:
def pixel_shuffle(scale):
  """Return a callable that moves channel depth into ``scale`` x ``scale`` spatial blocks.

  Thin wrapper around ``tf.nn.depth_to_space`` for use inside a Keras ``Lambda`` layer.
  """
  def _shuffle(x):
    return tf.nn.depth_to_space(x, scale)
  return _shuffle


def upsample(x_in, num_filters):
  """Double the spatial resolution: 4x4 conv to ``num_filters`` channels, pixel shuffle x2, LeakyReLU.

  The x2 pixel shuffle leaves ``num_filters // 4`` channels in the output.
  """
  conv = Conv2D(num_filters, 4, padding='same')
  shuffle = Lambda(pixel_shuffle(scale=2))
  activation = LeakyReLU()
  return activation(shuffle(conv(x_in)))

def RRDB(lyr, name):
  """Dense block of four (3x3 conv -> BN -> LeakyReLU) stages joined by additive dense links.

  The first stage sees the block input; each later stage sees the sum of the input and all
  earlier stage outputs. The block returns the sum of the input and all four stage outputs.
  ``lyr`` must have ``gf * 2`` channels so the Add layers line up. The convolutions are named
  ``a_<name>``, ``b_<name>``, ``c_<name>`` and ``d_<name>``.
  """
  collected = [lyr]
  x = lyr
  for tag in ('a', 'b', 'c', 'd'):
    x = Conv2D(gf*2, 3, strides=1, padding='same', name=f'{tag}_{name}')(x)
    x = BatchNormalization(epsilon=1e-5, beta_initializer='glorot_normal', gamma_initializer='glorot_normal')(x)
    x = LeakyReLU()(x)
    collected.append(x)
    # Pass a fresh list so later appends cannot alias the list handed to an earlier Add layer.
    x = Add()(list(collected))
  return x
def res_block(pre_layer):
    """Residual block: three (3x3 conv -> BN -> LeakyReLU) stages plus an identity skip.

    ``pre_layer`` must have ``gf`` channels so it can be added to the stage output.
    """
    x = pre_layer
    for _ in range(3):
        x = Conv2D(gf, (3, 3), padding='same', strides=1)(x)
        x = BatchNormalization(epsilon=1e-5, beta_initializer='glorot_normal',
                               gamma_initializer='glorot_normal')(x)
        x = LeakyReLU()(x)
    return Add()([x, pre_layer])

  

def build_G():
    """Build the super-resolution generator as a Keras functional model.

    The model is fully convolutional, so any input height/width is accepted. It maps a
    ``channels``-channel LR image to an image upscaled by ``scale`` with ``channels`` channels
    and tanh output in [-1, 1] (the range the HR targets are normalised to).

    Architecture:
      * a 3x3 conv of the input, kept as a global skip;
      * two independent branches (3x3 conv to ``gf`` channels + three residual blocks),
        concatenated to ``2*gf`` channels;
      * two RRDB dense blocks, then the global skip is added back;
      * log2(scale) x2 pixel-shuffle upsampling steps and a final 9x9 tanh conv.

    Raises:
        ValueError: if ``scale`` is not a power of two >= 2 (each upsample step doubles size).
    """
    if scale < 2 or scale & (scale - 1):
        raise ValueError(f'scale must be a power of two >= 2, got {scale}')
    n_upsamples = scale.bit_length() - 1  # log2(scale)

    input_layer = Input(shape=(None, None, channels))

    # Global skip connection, added back after the dense blocks.
    skip = Conv2D(gf*2, (3, 3), padding='same', strides=1)(input_layer)

    def _feature_branch():
        # Basic-detail extractor: 3x3 conv to gf channels followed by three residual blocks.
        branch = Conv2D(gf, (3, 3), padding='same', strides=1)(input_layer)
        for _ in range(3):
            branch = res_block(branch)
        return branch

    branch_A = _feature_branch()
    branch_B = _feature_branch()

    layer = Concatenate()([branch_A, branch_B])  # 2*gf channels, matching RRDB's convs
    layer = RRDB(layer, '1')
    layer = RRDB(layer, '2')
    layer = Add()([skip, layer])

    for _ in range(n_upsamples):
        layer = upsample(layer, gf*4)
    layer = Conv2D(channels, 9, padding='same', activation='tanh')(layer)
    return tf.keras.models.Model(input_layer, layer)




## Build Discriminator

In [None]:
def build_D():
    """Build the discriminator: a stack of 3x3 conv blocks followed by a dense head.

    Takes an HR-sized image (``hr_shape``) and returns one raw logit per image. There is no
    final activation; the BCE losses are built with ``from_logits=True``.
    """
    input_layer = Input(shape=hr_shape)

    # The first block has no batch normalisation.
    layer = Conv2D(64, 3, strides=1)(input_layer)
    layer = LeakyReLU(0.2)(layer)

    # (filters, stride) for the remaining conv -> BN -> LeakyReLU blocks ('valid' padding).
    conv_specs = [(64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2)]
    for filters, stride in conv_specs:
        layer = Conv2D(filters, 3, strides=stride)(layer)
        layer = BatchNormalization()(layer)
        layer = LeakyReLU(0.2)(layer)

    layer = Flatten()(layer)
    layer = Dense(1024)(layer)
    layer = LeakyReLU(0.2)(layer)
    output_layer = Dense(1)(layer)
    return tf.keras.models.Model(input_layer, output_layer)


# Checkpoint utilities


In [None]:
def save_checkpoint(step_num, psnr_list, ssim_list, test_name=test_name):
    """Persist the global G and D plus training progress under ./checkpoints/<test_name>/.

    Writes generator.h5 and discriminator.h5, then a pickle of
    (step_num, psnr_list, ssim_list), then a second copy of each model as *_old.h5.
    load_checkpoint falls back to the *_old.h5 copy when the primary file cannot be opened.
    """
    ckpt_dir = './checkpoints/{}'.format(test_name)
    G.save(f'{ckpt_dir}/generator.h5')
    D.save(f'{ckpt_dir}/discriminator.h5')

    # The progress pickle has no *_old backup, so write it to a temp file and rename it
    # into place: an interruption mid-write then cannot leave a truncated pickle behind.
    progress_path = f'{ckpt_dir}/steps_psnr_ssim.pickle'
    tmp_path = progress_path + '.tmp'
    with open(tmp_path, 'wb') as saving_file:
        pickle.dump((step_num, psnr_list, ssim_list), saving_file)
    os.replace(tmp_path, progress_path)

    G.save(f'{ckpt_dir}/generator_old.h5')
    D.save(f'{ckpt_dir}/discriminator_old.h5')
    print('#####################checkpoint saved#########################')

def load_checkpoint(test_name=test_name):
    """Load the state written by save_checkpoint from ./checkpoints/<test_name>/.

    Each model is loaded from its primary .h5 file, falling back to the *_old.h5 copy when
    loading the primary raises OSError.

    Returns:
        (step, psnr_list, ssim_list, generator, discriminator)

    Note: the progress file is unpickled, so only load checkpoints you wrote yourself.
    """
    ckpt_dir = f'./checkpoints/{test_name}'

    def _load_with_fallback(stem):
        try:
            return tf.keras.models.load_model(f'{ckpt_dir}/{stem}.h5')
        except OSError:
            return tf.keras.models.load_model(f'{ckpt_dir}/{stem}_old.h5')

    g = _load_with_fallback('generator')
    d = _load_with_fallback('discriminator')

    # Context manager: the file is closed even if unpickling fails.
    with open(f'{ckpt_dir}/steps_psnr_ssim.pickle', 'rb') as saving_file:
        step, psnr_list, ssim_list = pickle.load(saving_file)
    return step, psnr_list, ssim_list, g, d
   

def reset_checkpoint(test_name=test_name):
    """Wipe this experiment's samples/, checkpoints/ and models/ folders and save a fresh checkpoint.

    Each folder is deleted if present and then (re)created unconditionally, so the
    save_checkpoint call below always has a directory to write into. The current global G and
    D are saved at step 0 with empty PSNR/SSIM histories.
    """
    for subdir in ('samples', 'checkpoints', 'models'):
        dir_path = './%s/%s' % (subdir, test_name)
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
        os.makedirs(dir_path, exist_ok=True)

    save_checkpoint(0, [], [], test_name)
    print('reset checkpoint')
  


### Load checkpoints


In [None]:
# Optionally start over: build fresh networks and overwrite the saved checkpoint with them.
if RESET_CHECKPOINTS:
  G, D = build_G(), build_D()
  reset_checkpoint()

# Resume from the checkpoint on disk (the one just written, when resetting).
last_epoch, psnr_list, ssim_list, G, D = load_checkpoint(test_name)
print(last_epoch, psnr_list, ssim_list)


# Losses and optimizers

In [None]:
# Shared loss objects; the discriminator emits raw logits, hence from_logits=True.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

def disc_loss(real_logits, fake_logits):
    """Discriminator loss: average of the BCE on real and on generated images.

    Label convention used in this notebook: real -> 0, fake -> 1. Targets are jittered with
    Gaussian noise (std 0.05); the noise is not clipped, so targets may fall slightly outside
    [0, 1].
    """
    noise_std = 0.05
    real_targets = tf.zeros_like(real_logits) + noise_std * tf.random.normal(real_logits.shape, mean=0)
    fake_targets = tf.ones_like(fake_logits) + noise_std * tf.random.normal(fake_logits.shape, mean=0)
    return 0.5 * (cross_entropy(real_targets, real_logits) + cross_entropy(fake_targets, fake_logits))
   

In [None]:
def gen_loss(fake_logits):
    """Adversarial loss for G: BCE pulling D's logits on generated images toward the 'real' label (0)."""
    real_label = tf.zeros_like(fake_logits)
    return cross_entropy(real_label, fake_logits)

## Optimizers


In [None]:
# The experiment name selects the optimizer class; D and G get separate instances,
# each with its own learning rate.
_OPTIMIZERS = {
    'Adam': tf.keras.optimizers.Adam,
    'Adagrad': tf.keras.optimizers.Adagrad,
    'RMSprop': tf.keras.optimizers.RMSprop,
}
if test_name not in _OPTIMIZERS:
    # Fail here rather than printing and later crashing with a NameError on gen_opt/disc_opt.
    raise ValueError(f"Error optimizer is not set: unknown test_name {test_name!r}, "
                     f"expected one of {sorted(_OPTIMIZERS)}")
disc_opt = _OPTIMIZERS[test_name](disc_lr)
gen_opt = _OPTIMIZERS[test_name](gen_lr)

# Training Functions

In [None]:
def train_gen(hr, lr):
    """Run one generator update on a batch and return scalar diagnostics.

    Loss = MSE(hr, G(lr)) + ad_loss_weight * adversarial loss. D runs with training=False and
    is not updated. Besides the two losses, returns the means of a few gradient tensors for
    monitoring: of the total loss (G.trainable_variables[-5] / [4]) and of the adversarial
    term alone ([-3] / [0]).
    """
    # persistent=True because gradient() is called twice on this tape.
    with tf.GradientTape(persistent=True) as tape:
        generated = G(lr, training=True)
        judged = D(generated, training=False)
        content = mse_loss_fn(hr, generated)
        adversarial = gen_loss(judged)
        total_loss = content + ad_loss_weight * adversarial

    params = G.trainable_variables
    grads = tape.gradient(total_loss, params)
    gen_opt.apply_gradients(zip(grads, params))
    adv_only = tape.gradient(adversarial, params)

    return {
        'content_loss': tf.math.reduce_mean(content),
        'adv_loss': tf.math.reduce_mean(adversarial),
        'deep_grads': np.mean(grads[-5]),
        'shallow_grads': np.mean(grads[4]),
        'adv_deep_grads': np.mean(adv_only[-3]),
        'adv_shallow_grads': np.mean(adv_only[0]),
    }
        
    
    
def train_disc(hr, lr):
    """Run one discriminator update on a batch and return scalar diagnostics.

    G runs with training=False and is not updated. Returns the loss and the means of the
    last and first gradient tensors of D.trainable_variables.
    """
    with tf.GradientTape() as tape:
        generated = G(lr, training=False)
        logits_real = D(hr, training=True)
        logits_fake = D(generated, training=True)
        loss = disc_loss(logits_real, logits_fake)

    params = D.trainable_variables
    grads = tape.gradient(loss, params)
    disc_opt.apply_gradients(zip(grads, params))

    return {
        'loss': tf.math.reduce_mean(loss),
        'deep_grads': np.mean(grads[-1]),
        'shallow_grads': np.mean(grads[0]),
    }


# Evaluation Function

In [None]:
def evaluate():
    """Evaluate G and D on the whole validation set and print the results.

    Printed values are averaged over validation batches:
      * the discriminator loss,
      * the generator's adversarial loss and MSE content loss,
      * PSNR/SSIM of G's output against the HR target, both mapped to [0, 1].
    Both networks run with training=False, so nothing is updated.
    """
    adv_loss = 0
    d_loss = 0
    content_loss = 0
    psnr = []
    ssim = []
    step = 0
    n_val_batches = -(-val_len // batch_size)  # ceil: the final partial batch is also evaluated
    for batch_hr, batch_lr in tqdm(val_ds.batch(batch_size), total=n_val_batches, unit=' batch', desc='Evaluating: '):
        fake = G(batch_lr, training=False)
        fake_logits = D(fake, training=False)
        real_logits = D(batch_hr, training=False)

        # Accumulate into the running sums; the batch value never replaces an accumulator.
        content_loss += tf.math.reduce_mean(mse_loss_fn(batch_hr, fake))
        adv_loss += tf.math.reduce_mean(gen_loss(fake_logits))
        d_loss += tf.math.reduce_mean(disc_loss(real_logits, fake_logits))
        step += 1

        p, s = compute_metrics(batch_hr * 0.5 + 0.5, fake * 0.5 + 0.5)
        psnr.append(tf.math.reduce_mean(p))
        ssim.append(tf.math.reduce_mean(s))

    if step == 0:
        print('Evaluation skipped: the validation set is empty.')
        return

    # Average over the validation batches actually seen (not the training batch count).
    mean_psnr = np.mean(psnr)
    mean_ssim = np.mean(ssim)
    adv_loss_mean = adv_loss / step
    disc_loss_mean = d_loss / step
    content_loss_mean = content_loss / step
    print("*-"*25)
    print("disc loss = {:.3f} adv loss={:.3f}  content_loss = {:.3f}\nOverall PSNR: {:.3f}, SSIM: {:.3f}".format(disc_loss_mean,  adv_loss_mean,content_loss_mean, mean_psnr, mean_ssim))
    print("*-"*25)
    

# **Training Loops**

In [None]:
for epoch in range(last_epoch, iterations):
    # Per-epoch running sums. Discriminator stats are accumulated on every step; generator
    # stats only on the steps where G is actually updated (once every K steps), so no
    # stale generator result is counted twice.
    adv_loss = 0
    d_loss = 0
    content_loss = 0
    psnr = []
    ssim = []
    adv_sh_grads = 0
    adv_deep_grads = 0
    gen_sh_grads = 0
    gen_deep_grads = 0
    disc_sh_grads = 0
    disc_deep_grads = 0
    step = 0       # discriminator updates (= batches seen this epoch)
    gen_steps = 0  # generator updates this epoch
    k = K          # forces a generator update on the first batch
    # Report running metrics ~10 times per epoch; max() avoids modulo-by-zero on tiny datasets.
    checkout_interval = max(1, total // 10)
    for batch_hr, batch_lr in tqdm(train_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE), total= total, unit=' batch',desc=f'Epoch {epoch}'):

        d = train_disc(batch_hr, batch_lr)
        d_loss += d['loss']
        disc_sh_grads += d['shallow_grads']
        disc_deep_grads += d['deep_grads']

        if k == K:
            g = train_gen(batch_hr, batch_lr)
            k = 0
            gen_steps += 1
            adv_loss += g['adv_loss']
            content_loss += g['content_loss']
            adv_sh_grads += g['adv_shallow_grads']
            adv_deep_grads += g['adv_deep_grads']
            gen_sh_grads += g['shallow_grads']
            gen_deep_grads += g['deep_grads']
        k += 1
        step += 1

        if step % checkout_interval == 0:
            # Metrics on the current training batch, with both images mapped to [0, 1].
            fake = G(batch_lr, training=False) * 0.5 + 0.5
            p, s = compute_metrics(batch_hr * 0.5 + 0.5, fake)
            # Store per-batch means so a smaller final batch cannot make the list ragged.
            psnr.append(np.mean(p))
            ssim.append(np.mean(s))
            print(f'\n adv_loss: {adv_loss/gen_steps:.4f} content_loss: {content_loss/gen_steps:.4f}  disc_loss: {d_loss/step:.4f} \n psnr: {np.mean(p):.4f}, ssim: {np.mean(s):.4f}')

    # Guard the averages against an epoch that saw no batches.
    gen_n = max(gen_steps, 1)
    disc_n = max(step, 1)

    print("*-"*25)
    print('GRADIENTS:')
    print('Gen gradients: shallow: {}   deep:{}'.format(gen_sh_grads/gen_n, gen_deep_grads/gen_n))
    print('disc gradients: shallow: {}   deep:{}'.format(disc_sh_grads/disc_n, disc_deep_grads/disc_n))
    print('adv gradients: shallow: {}   deep:{}'.format(adv_sh_grads/gen_n, adv_deep_grads/gen_n))
    mean_psnr = np.mean(psnr)
    mean_ssim = np.mean(ssim)
    adv_loss_mean = adv_loss/gen_n
    disc_loss_mean = d_loss/disc_n
    content_loss_mean = content_loss/gen_n
    print("*-"*25)
    print("disc loss = {:.3f} adv loss={:.3f}  content_loss = {:.3f}\nOverall PSNR: {:.3f}, SSIM: {:.3f}".format(disc_loss_mean,  adv_loss_mean,content_loss_mean, mean_psnr, mean_ssim))
    print("*-"*25)

    #############saving checkpoint#################
    psnr_list.append(mean_psnr)
    ssim_list.append(mean_ssim)
    save_checkpoint(epoch+1, psnr_list, ssim_list, test_name)
    # Keep a permanent snapshot of both models every `interval` epochs.
    if epoch % interval == 0:
        G.save('./models/{}/G_{}.h5'.format(test_name, epoch))
        D.save('./models/{}/D_{}.h5'.format(test_name, epoch))

    if epoch % evaluation_interval == 0:
        evaluate()
    sample_image(epoch, test_name, save_sample=True)