In [1]:
import tensorflow as tf

In [2]:
# Report how many GPUs TensorFlow can see and the name of the default GPU device.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
print(tf.test.gpu_device_name())

Num GPUs Available:  1
/device:GPU:0


2022-07-13 23:11:07.365324: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-13 23:11:07.370469: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-13 23:11:07.371625: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-13 23:11:07.373181: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags

In [3]:
import numpy as np
import pandas as pd
import xarray as xr
import h5py

import matplotlib.pyplot as plt

from datetime import datetime
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, losses
from tensorflow.keras.layers import Input, Lambda, LeakyReLU, Add, Dense, Activation, Flatten, Conv2D, Conv2DTranspose, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import glorot_uniform, constant, TruncatedNormal

%matplotlib inline

In [4]:
# Read the low-resolution (lr) and high-resolution (hr) arrays, together with the
# mean / stddev arrays stored next to them, fully into memory.
with h5py.File('processed_data/np_gan_standard.h5', 'r') as hf:
    (data_lr, data_lr_mean, data_lr_stddev,
     data_hr, data_hr_mean, data_hr_stddev) = [
        hf[key][:]
        for key in ('np_lr', 'np_lr_mean', 'np_lr_stddev',
                    'np_hr', 'np_hr_mean', 'np_hr_stddev')
    ]

In [5]:
# Shape, mean and stddev of the low-resolution data, one item per line ...
print(data_lr.shape, data_lr_mean, data_lr_stddev, sep='\n')
print('\n')
# ... then the same three items for the high-resolution data.
print(data_hr.shape, data_hr_mean, data_hr_stddev, sep='\n')

(8520, 96, 96, 2)
[ 0.7051484 -1.0147774]
[3.1869051 2.8827915]


(8520, 192, 192, 2)
[ 0.701198  -1.0068085]
[3.149407  2.8781955]


In [6]:
# Stage 1: hold out 20% of all samples as the test set.
x_train, x_test, y_train, y_test = train_test_split(
    data_lr, data_hr, test_size=0.2, random_state=42)

# Stage 2: move 25% of the remainder into the validation set (60/20/20 overall).
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.25, random_state=42)

# Shapes of the LR inputs first, then of the HR targets.
for split in (x_train, x_val, x_test, y_train, y_val, y_test):
    print(split.shape)

# Value range per split: three maxima followed by three minima,
# one line for the inputs and one line for the targets.
for group in ((x_train, x_val, x_test), (y_train, y_val, y_test)):
    print(*[np.max(arr) for arr in group], *[np.min(arr) for arr in group])

(5112, 96, 96, 2)
(1704, 96, 96, 2)
(1704, 96, 96, 2)
(5112, 192, 192, 2)
(1704, 192, 192, 2)
(1704, 192, 192, 2)
6.890482 5.989441 5.7024393 -5.4777956 -5.0524907 -5.5057893
7.035497 5.8600492 6.0442076 -5.315754 -5.2317853 -5.2996855


In [7]:
def generator(input_shape = (96, 96, 2), nf = 64, r = 2):
    """
    Build the super-resolution generator: a convolutional head, 16 residual
    blocks, a long skip connection around them, and a depth-to-space
    upsampling stage with block size r.

    Arguments:
    input_shape -- shape of one low-resolution sample, H*W*C
    nf -- integer, number of filters in every layer ahead of the upsampling stage
    r -- integer, resolution ratio between output and input

    Returns:
    model -- a Model() instance in Keras
    """

    def conv(n_filters):
        # All convolutions in this network share one geometry:
        # 3x3 kernel, stride 1, 'same' padding.
        return Conv2DTranspose(filters=n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same')

    n_channels = input_shape[2]
    X_input = Input(input_shape)

    # Head layer.
    # Shall we use a mirror padding and finally cutoff the edge, like the paper does? FIXME
    features = conv(nf)(X_input)
    # Shall we use relu, or leaky_relu? FIXME
    features = Activation('relu')(features)

    # Kept for the long skip connection that is added back after the residual stack.
    head_features = features

    # Residual stack: conv -> relu -> conv, plus the identity shortcut.
    for _ in range(16):
        block_input = features
        features = conv(nf)(features)
        features = Activation('relu')(features)
        features = conv(nf)(features)
        features = Add()([features, block_input])
        # Are we missing a relu activation here, if we follow the resnet paper? FIXME

    features = conv(nf)(features)
    features = Add()([features, head_features])

    # Upsampling: expand to r^2 * nf channels, then fold the extra channels into
    # space with depth_to_space, which leaves nf channels at r times the resolution.
    features = conv((r**2) * nf)(features)
    features = Lambda(lambda x:tf.nn.depth_to_space(x,r))(features)
    features = Activation('relu')(features)

    # Project back to the number of channels of the input.
    output = conv(n_channels)(features)

    return Model(inputs = X_input, outputs = output)

In [8]:
gen_model = generator(input_shape = (96, 96, 2))
# summary() prints the table itself and returns None, so wrapping it in print()
# only appended a stray "None" line to the output.
gen_model.summary()

2022-07-13 23:11:09.868396: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-13 23:11:09.869588: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-13 23:11:09.870679: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-13 23:11:09.871793: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-07-13 23:11:09.872868: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from S

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 96, 96, 2)]  0           []                               
                                                                                                  
 conv2d_transpose (Conv2DTransp  (None, 96, 96, 64)  1216        ['input_1[0][0]']                
 ose)                                                                                             
                                                                                                  
 activation (Activation)        (None, 96, 96, 64)   0           ['conv2d_transpose[0][0]']       
                                                                                                  
 conv2d_transpose_1 (Conv2DTran  (None, 96, 96, 64)  36928       ['activation[0][0]']         

                                                                                                  
 add_6 (Add)                    (None, 96, 96, 64)   0           ['conv2d_transpose_14[0][0]',    
                                                                  'add_5[0][0]']                  
                                                                                                  
 conv2d_transpose_15 (Conv2DTra  (None, 96, 96, 64)  36928       ['add_6[0][0]']                  
 nspose)                                                                                          
                                                                                                  
 activation_8 (Activation)      (None, 96, 96, 64)   0           ['conv2d_transpose_15[0][0]']    
                                                                                                  
 conv2d_transpose_16 (Conv2DTra  (None, 96, 96, 64)  36928       ['activation_8[0][0]']           
 nspose)  

                                                                                                  
 activation_15 (Activation)     (None, 96, 96, 64)   0           ['conv2d_transpose_29[0][0]']    
                                                                                                  
 conv2d_transpose_30 (Conv2DTra  (None, 96, 96, 64)  36928       ['activation_15[0][0]']          
 nspose)                                                                                          
                                                                                                  
 add_14 (Add)                   (None, 96, 96, 64)   0           ['conv2d_transpose_30[0][0]',    
                                                                  'add_13[0][0]']                 
                                                                                                  
 conv2d_transpose_31 (Conv2DTra  (None, 96, 96, 64)  36928       ['add_14[0][0]']                 
 nspose)  

In [9]:
def discriminator(input_shape = (192, 192, 2)):
    """
    Build the discriminator: eight 3x3 convolutions with LeakyReLU(0.2),
    followed by two fully-connected layers.

    Arguments:
    input_shape -- shape of the images of the dataset, H*W*C

    Returns:
    model -- a Model() instance in Keras; its single output unit has no
             activation (the loss functions in this notebook treat it as a logit)
    """

    # Define the input as a tensor with shape input_shape
    X_input = Input(input_shape)

    # (filters, stride) for conv1..conv8. The filter count doubles every two
    # layers and every second layer uses stride 2.
    conv_specs = [(32, 1), (32, 2),
                  (64, 1), (64, 2),
                  (128, 1), (128, 2),
                  (256, 1), (256, 2)]

    X = X_input
    for filters, stride in conv_specs:
        X = Conv2D(filters=filters, kernel_size=(3,3), strides=(stride, stride), padding="same")(X)
        X = LeakyReLU(alpha=0.2)(X)

    X = Flatten()(X)

    # Both fully-connected layers share the same truncated-normal initializer.
    k_init = TruncatedNormal(stddev=0.02)

    #first fully-connect
    X = Dense(units=1024, kernel_initializer=k_init)(X)
    X = LeakyReLU(alpha=0.2)(X)

    #second fully-connect, no activation FIXME
    X = Dense(units=1, kernel_initializer=k_init)(X)

    model = Model(inputs = X_input, outputs = X)
    return model

In [10]:
disc_model = discriminator(input_shape = (192, 192, 2))
# summary() prints the table itself and returns None, so wrapping it in print()
# only appended a stray "None" line to the output.
disc_model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 192, 192, 2)]     0         
                                                                 
 conv2d (Conv2D)             (None, 192, 192, 32)      608       
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 192, 192, 32)      0         
                                                                 
 conv2d_1 (Conv2D)           (None, 96, 96, 32)        9248      
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 96, 96, 32)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 96, 96, 64)        18496     
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 96, 96, 64)        0   

In [11]:
def compute_losses(x_HR, x_SR, d_HR, d_SR, alpha_advers=0.001, isGAN=False):
    """
    Compute the content loss and, optionally, the adversarial losses.

    Arguments:
    x_HR -- ground-truth high-resolution batch
    x_SR -- generator output for the same batch
    d_HR -- discriminator logits for x_HR (ignored when isGAN is False)
    d_SR -- discriminator logits for x_SR (ignored when isGAN is False)
    alpha_advers -- weight of the adversarial term in the generator loss
    isGAN -- when False only the content loss is computed

    Returns:
    isGAN False -- the scalar content (MSE) loss
    isGAN True  -- (g_loss, d_loss, advers_perf, content_loss, g_advers_loss), where
                   advers_perf is the list [true positive, true negative,
                   false positive, false negative] of discriminator rates
    """

    # Mean squared error per sample over axes 1-3, then averaged over the batch.
    per_sample_mse = tf.reduce_mean((x_HR - x_SR)**2, axis=[1, 2, 3])
    mean_content = tf.reduce_mean(per_sample_mse)

    if not isGAN:
        return mean_content

    # Generator side: it is rewarded when the discriminator labels SR as real (1).
    g_advers_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_SR), logits=d_SR)
    mean_g_advers = tf.reduce_mean(g_advers_loss)

    # Discriminator side: HR samples are labelled 1 and SR samples 0.
    all_logits = tf.concat([d_HR, d_SR], axis=0)
    all_labels = tf.concat([tf.ones_like(d_HR), tf.zeros_like(d_SR)], axis=0)
    d_advers_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=all_labels, logits=all_logits)

    prob_HR = tf.sigmoid(d_HR)
    prob_SR = tf.sigmoid(d_SR)

    def rate(mask):
        # Fraction of entries for which the boolean mask is True.
        return tf.reduce_mean(tf.cast(mask, tf.float32))

    advers_perf = [rate(prob_HR > 0.5),  # % true positive
                   rate(prob_SR < 0.5),  # % true negative
                   rate(prob_SR > 0.5),  # % false positive
                   rate(prob_HR < 0.5)]  # % false negative

    g_loss = mean_content + alpha_advers*mean_g_advers
    d_loss = tf.reduce_mean(d_advers_loss)

    return g_loss, d_loss, advers_perf, mean_content, mean_g_advers

In [12]:
# Restore generator weights from the checkpoint prefix "gan_v1/ckp/".
# The returned load-status object is the cell's displayed output.
gen_model.load_weights("gan_v1/ckp/")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2cc00f7b20>

In [13]:
# Content (MSE) loss of the restored generator on the test split; no discriminator involved.
print(compute_losses(x_HR=y_test, x_SR=gen_model.predict(x_test), d_HR=None, d_SR=None,
                     alpha_advers=0.001, isGAN=False))

2022-07-13 23:11:22.121148: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8100


tf.Tensor(0.011780102, shape=(), dtype=float32)


In [14]:
# Cross-check the hand-written content loss against Keras' built-in MSE evaluation.
adam = tf.keras.optimizers.Adam(
    learning_rate=1e-4,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
)
gen_model.compile(loss=losses.MeanSquaredError(), optimizer=adam)

gen_model.evaluate(x=x_test, y=y_test)



0.01178010180592537

In [15]:
from time import time

# Per-epoch loss history. pretrain() appends to these module-level lists, so
# re-running this cell resets the history while calling pretrain() again extends it.
train_loss = []
val_loss = []

def pretrain(epochs=20, batch_size=128):
    '''
        Train a freshly built generator on the content (MSE) loss only, without
        a discriminator / adversarial training.

        Reads the module-level arrays x_train / y_train and x_val / y_val.
        Side effect: appends each epoch's training and validation loss to the
        module-level lists train_loss and val_loss.

        epochs -- number of passes over the training set
        batch_size -- mini-batch size
        output:  generator model
    '''

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)

    # Local names are chosen so they do not shadow the notebook-level
    # `gen_model` and `adam` objects.
    generator_net = generator(input_shape = (96, 96, 2))
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Start training
    print('Training network ...')
    for epoch in range(1, epochs+1):
        print('Epoch: %d' %(epoch))
        start_time = time()
        epoch_loss, N = 0, 0

        for batch_LR, batch_HR in train_dataset:
            N_batch = batch_LR.shape[0]

            with tf.GradientTape() as gen_tape:
                batch_SR = generator_net(batch_LR, training=True)
                gen_loss = compute_losses(batch_HR, batch_SR, None, None, alpha_advers=0.001, isGAN=False)

            grad_of_gen = gen_tape.gradient(gen_loss, generator_net.trainable_variables)
            optimizer.apply_gradients(zip(grad_of_gen, generator_net.trainable_variables))

            # Weight by the batch size so that a smaller final batch does not
            # skew the epoch average.
            epoch_loss += gen_loss * N_batch
            N += N_batch

        epoch_loss = epoch_loss / N

        val_SR = generator_net.predict(x_val, verbose=0)
        gen_val_loss = compute_losses(y_val, val_SR, None, None, alpha_advers=0.001, isGAN=False)

        print('Epoch generator training loss = %.6f, val loss = %6f' %(epoch_loss, gen_val_loss))
        print('Epoch took %.2f seconds\n' %(time() - start_time), flush=True)
        train_loss.append(epoch_loss)
        val_loss.append(gen_val_loss)

    print('Done.')

    return generator_net

In [16]:
# Short content-loss-only run of a fresh generator, for comparison with the restored checkpoint.
gen_model3 = pretrain(epochs=5, batch_size=128)

Training network ...
Epoch: 1
Epoch generator training loss = 0.550237, val loss = 0.098597
Epoch took 40.95 seconds

Epoch: 2
Epoch generator training loss = 0.078973, val loss = 0.066691
Epoch took 36.89 seconds

Epoch: 3
Epoch generator training loss = 0.058697, val loss = 0.052967
Epoch took 36.83 seconds

Epoch: 4
Epoch generator training loss = 0.048061, val loss = 0.044717
Epoch took 36.59 seconds

Epoch: 5
Epoch generator training loss = 0.041491, val loss = 0.039424
Epoch took 36.50 seconds

Done.


In [17]:
# Evaluate the freshly pretrained generator with Keras' built-in MSE.
# `adam` is the optimizer object created in an earlier cell; it is only needed
# to satisfy compile(), since no fit() is run here.
gen_model3.compile(loss=losses.MeanSquaredError(), optimizer=adam)
gen_model3.evaluate(x=x_test, y=y_test)



0.040661320090293884

In [27]:
from time import time

def train(gen_model, disc_model, epochs=20, batch_size=128, alpha_advers=0.001):
    '''
        This method trains the generator and discriminator adversarially, with
        one update of each network per mini-batch.
        Notice the two models should be arguments of this function; both are
        updated in place.

        Reads the module-level arrays x_train / y_train and x_val / y_val.

        gen_model -- generator Keras model
        disc_model -- discriminator Keras model; its outputs are treated as logits
        epochs -- number of passes over the training set
        batch_size -- mini-batch size
        alpha_advers -- weight of the adversarial term in the generator loss
        output: generator model and discriminator model

        NOTE: a later cell defines another function named `train`; once that
        cell has run, it replaces this definition.
    '''

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)

    # The discriminator is trained with a 5x smaller learning rate than the generator.
    adam_slow = tf.keras.optimizers.Adam(learning_rate=2e-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    adam_fast = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Start training
    print('Training network ...')
    for epoch in range(1, epochs+1):
        print('Epoch: %d' %(epoch))
        start_time = time()
        epoch_g_loss, epoch_d_loss, N = 0, 0, 0

        for batch_LR, batch_HR in train_dataset:
            N_batch = batch_LR.shape[0]

            # Both tapes record the same forward pass; each is then used once,
            # for the gradients of its own network.
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                batch_SR = gen_model(batch_LR, training=True)
                d_HR = disc_model(batch_HR, training=True)
                d_SR = disc_model(batch_SR, training=True)
                # Pass the caller's alpha_advers through; it used to be
                # hard-coded to 0.001 here, which silently ignored the argument.
                g_loss, d_loss, advers_perf, content_loss, g_advers_loss \
                        = compute_losses(batch_HR, batch_SR, d_HR, d_SR, alpha_advers=alpha_advers, isGAN=True)

            grad_of_gen = gen_tape.gradient(g_loss, gen_model.trainable_variables)
            adam_fast.apply_gradients(zip(grad_of_gen, gen_model.trainable_variables))

            grad_of_disc = disc_tape.gradient(d_loss, disc_model.trainable_variables)
            adam_slow.apply_gradients(zip(grad_of_disc, disc_model.trainable_variables))

            # Weight by the batch size so that a smaller final batch does not
            # skew the epoch averages.
            epoch_g_loss += g_loss * N_batch
            epoch_d_loss += d_loss * N_batch
            N += N_batch

        epoch_g_loss = epoch_g_loss / N
        epoch_d_loss = epoch_d_loss / N

        # Validation metrics, using the same adversarial weight as training.
        val_SR = gen_model.predict(x_val, verbose=0)
        val_d_HR = disc_model.predict(y_val, verbose=0)
        val_d_SR = disc_model.predict(val_SR, verbose=0)
        val_g_loss, val_d_loss, val_advers_perf, val_content_loss, val_g_advers_loss \
                    = compute_losses(y_val, val_SR, val_d_HR, val_d_SR, alpha_advers=alpha_advers, isGAN=True)

        print('Epoch generator loss = %.6f, discriminator loss = %.6f' %(epoch_g_loss, epoch_d_loss))
        print('Epoch val: g_loss = %.6f, d_loss = %.6f, content_loss = %.6f, advers_loss = %.6f' \
              %(val_g_loss, val_d_loss, val_content_loss, val_g_advers_loss))
        print('Epoch took %.2f seconds\n' %(time() - start_time), flush=True)

    print('Done.')

    return gen_model, disc_model

In [24]:
# This is an old cell, kept only as a baseline. Do not run it again!
# NOTE: `gen_model_copy = gen_model` binds a second name to the same model object;
# it is not a copy, so train() updates the weights of gen_model (and disc_model) in place.
gen_model_copy = gen_model
gen_model_gan, disc_model_gan = train(gen_model_copy, disc_model, epochs=20, batch_size=128, alpha_advers=0.001)

Training network ...
Epoch: 1
Epoch generator loss = 0.013467, discriminator loss = 0.550541
Epoch val: g_loss = 0.013649, d_loss = 0.457951, content_loss = 0.012198, advers_loss = 1.451408
Epoch took 71.27 seconds

Epoch: 2
Epoch generator loss = 0.011170, discriminator loss = 0.392340
Epoch val: g_loss = 0.013552, d_loss = 0.312669, content_loss = 0.012297, advers_loss = 1.254653
Epoch took 71.60 seconds

Epoch: 3
Epoch generator loss = 0.011610, discriminator loss = 0.248692
Epoch val: g_loss = 0.014092, d_loss = 0.238705, content_loss = 0.012744, advers_loss = 1.347879
Epoch took 71.25 seconds

Epoch: 4
Epoch generator loss = 0.012828, discriminator loss = 0.195272
Epoch val: g_loss = 0.015999, d_loss = 0.171700, content_loss = 0.013317, advers_loss = 2.681944
Epoch took 71.15 seconds

Epoch: 5
Epoch generator loss = 0.013769, discriminator loss = 0.190774
Epoch val: g_loss = 0.017635, d_loss = 0.212673, content_loss = 0.014325, advers_loss = 3.310584
Epoch took 71.14 seconds

Epoc

In [18]:
# This is another old cell, kept only as a baseline. Do not run it again!
# NOTE: `gen_model_copy = gen_model` binds a second name to the same model object;
# it is not a copy, so train() updates the weights of gen_model (and disc_model) in place.
gen_model_copy = gen_model
gen_model_gan, disc_model_gan = train(gen_model_copy, disc_model, epochs=20, batch_size=128, alpha_advers=0.001)

Training network ...
Epoch: 1


2022-07-13 23:13:13.058785: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Epoch generator loss = 0.013959, discriminator loss = 0.690694
Epoch val: g_loss = 0.013202, d_loss = 0.687033, content_loss = 0.012513, advers_loss = 0.688950
Epoch took 77.75 seconds

Epoch: 2
Epoch generator loss = 0.011227, discriminator loss = 0.678161
Epoch val: g_loss = 0.012624, d_loss = 0.666862, content_loss = 0.011940, advers_loss = 0.684485
Epoch took 71.49 seconds

Epoch: 3
Epoch generator loss = 0.010499, discriminator loss = 0.655292
Epoch val: g_loss = 0.012619, d_loss = 0.643332, content_loss = 0.011914, advers_loss = 0.705594
Epoch took 71.12 seconds

Epoch: 4
Epoch generator loss = 0.010140, discriminator loss = 0.634736
Epoch val: g_loss = 0.012720, d_loss = 0.621345, content_loss = 0.011954, advers_loss = 0.766147
Epoch took 70.96 seconds

Epoch: 5
Epoch generator loss = 0.010020, discriminator loss = 0.615354
Epoch val: g_loss = 0.012659, d_loss = 0.614914, content_loss = 0.011993, advers_loss = 0.665825
Epoch took 71.05 seconds

Epoch: 6
Epoch generator loss = 0.

In [19]:
# This is another old cell, which continues training from the previous baseline cell.
# Kept only as a baseline. Do not run it again!
gen_model_gan, disc_model_gan = train(gen_model_gan, disc_model_gan, epochs=100, batch_size=128, alpha_advers=0.001)

Training network ...
Epoch: 1
Epoch generator loss = 0.017364, discriminator loss = 0.575698
Epoch val: g_loss = 0.014610, d_loss = 0.483025, content_loss = 0.013723, advers_loss = 0.886674
Epoch took 71.34 seconds

Epoch: 2
Epoch generator loss = 0.013059, discriminator loss = 0.493132
Epoch val: g_loss = 0.013709, d_loss = 0.500946, content_loss = 0.012655, advers_loss = 1.054259
Epoch took 71.64 seconds

Epoch: 3
Epoch generator loss = 0.012047, discriminator loss = 0.507576
Epoch val: g_loss = 0.013299, d_loss = 0.516748, content_loss = 0.012349, advers_loss = 0.949914
Epoch took 71.17 seconds

Epoch: 4
Epoch generator loss = 0.011527, discriminator loss = 0.522405
Epoch val: g_loss = 0.013220, d_loss = 0.534289, content_loss = 0.012265, advers_loss = 0.955288
Epoch took 70.93 seconds

Epoch: 5
Epoch generator loss = 0.011245, discriminator loss = 0.527741
Epoch val: g_loss = 0.013343, d_loss = 0.548862, content_loss = 0.012327, advers_loss = 1.015619
Epoch took 71.00 seconds

Epoc

Epoch: 43
Epoch generator loss = 0.010957, discriminator loss = 0.473624
Epoch val: g_loss = 0.013990, d_loss = 0.474965, content_loss = 0.012781, advers_loss = 1.208763
Epoch took 71.21 seconds

Epoch: 44
Epoch generator loss = 0.011019, discriminator loss = 0.465066
Epoch val: g_loss = 0.014281, d_loss = 0.476811, content_loss = 0.012904, advers_loss = 1.377354
Epoch took 71.23 seconds

Epoch: 45
Epoch generator loss = 0.011107, discriminator loss = 0.461701
Epoch val: g_loss = 0.014512, d_loss = 0.482131, content_loss = 0.013198, advers_loss = 1.313444
Epoch took 71.49 seconds

Epoch: 46
Epoch generator loss = 0.011182, discriminator loss = 0.463264
Epoch val: g_loss = 0.014340, d_loss = 0.498312, content_loss = 0.012887, advers_loss = 1.453358
Epoch took 71.28 seconds

Epoch: 47
Epoch generator loss = 0.011015, discriminator loss = 0.471588
Epoch val: g_loss = 0.014259, d_loss = 0.483507, content_loss = 0.012950, advers_loss = 1.308823
Epoch took 71.31 seconds

Epoch: 48
Epoch gene

Epoch: 85
Epoch generator loss = 0.011403, discriminator loss = 0.514446
Epoch val: g_loss = 0.014015, d_loss = 0.528680, content_loss = 0.013201, advers_loss = 0.813569
Epoch took 71.26 seconds

Epoch: 86
Epoch generator loss = 0.011351, discriminator loss = 0.512347
Epoch val: g_loss = 0.014078, d_loss = 0.512271, content_loss = 0.013144, advers_loss = 0.933782
Epoch took 71.28 seconds

Epoch: 87
Epoch generator loss = 0.011288, discriminator loss = 0.517924
Epoch val: g_loss = 0.014093, d_loss = 0.515251, content_loss = 0.013174, advers_loss = 0.919740
Epoch took 71.27 seconds

Epoch: 88
Epoch generator loss = 0.011291, discriminator loss = 0.517844
Epoch val: g_loss = 0.014116, d_loss = 0.514136, content_loss = 0.013201, advers_loss = 0.915710
Epoch took 71.25 seconds

Epoch: 89
Epoch generator loss = 0.011325, discriminator loss = 0.516837
Epoch val: g_loss = 0.014176, d_loss = 0.511293, content_loss = 0.013250, advers_loss = 0.926291
Epoch took 71.25 seconds

Epoch: 90
Epoch gene

In [23]:
# Persist both adversarially trained models as HDF5 files under gan/v2/.
# NOTE(review): the generator contains a Lambda layer; Keras saves Lambda layers by
# serializing the Python function, which may not reload in a different environment.
# Confirm if portability of these files matters.
gen_model_gan.save('gan/v2/gen_model_gan_v2.h5')
disc_model_gan.save('gan/v2/disc_model_gan_v2.h5')



In [25]:
# Reload the generator and discriminator from the HDF5 files under gan/v2/.
gen_model_reopen = tf.keras.models.load_model('gan/v2/gen_model_gan_v2.h5')
disc_model_reopen = tf.keras.models.load_model('gan/v2/disc_model_gan_v2.h5')



In [28]:
# Train the reloaded models for one more epoch.
# NOTE: the recorded output has no g_count / d_count fields, so this ran with the
# one-update-per-batch version of train(), not the later adaptive one.
gen_model_reopen, disc_model_reopen = train(gen_model_reopen, disc_model_reopen, epochs=1, batch_size=128, alpha_advers=0.001)

Training network ...
Epoch: 1
Epoch generator loss = 0.016659, discriminator loss = 0.535702
Epoch val: g_loss = 0.014914, d_loss = 0.380175, content_loss = 0.013659, advers_loss = 1.255487
Epoch took 71.95 seconds

Done.


In [29]:
def generator_loss(x_HR, x_SR, d_SR, alpha_advers=0.001):
    """
    Generator objective: content loss plus a weighted adversarial term.

    Arguments:
    x_HR -- ground-truth high-resolution batch
    x_SR -- generator output for the same batch
    d_SR -- discriminator logits for x_SR
    alpha_advers -- weight of the adversarial term

    Returns:
    (g_loss, content_loss, g_advers_loss) -- three scalars, where
    g_loss = content_loss + alpha_advers * g_advers_loss
    """
    # Content term: mean squared error over every element of the batch.
    residual = x_HR - x_SR
    content_loss = tf.reduce_mean(residual**2)

    # Adversarial term: cross-entropy of the SR logits against the "real" label (1).
    real_labels = tf.ones_like(d_SR)
    per_sample_bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=real_labels, logits=d_SR)
    g_advers_loss = tf.reduce_mean(per_sample_bce)

    return content_loss + alpha_advers * g_advers_loss, content_loss, g_advers_loss

def discriminator_loss(d_HR, d_SR):
    """
    Discriminator objective: mean sigmoid cross-entropy over the HR logits
    (labelled 1) and the SR logits (labelled 0), concatenated along axis 0.

    Arguments:
    d_HR -- discriminator logits for real high-resolution samples
    d_SR -- discriminator logits for generated samples

    Returns:
    scalar loss
    """
    stacked_logits = tf.concat([d_HR, d_SR], axis=0)
    stacked_labels = tf.concat([tf.ones_like(d_HR), tf.zeros_like(d_SR)], axis=0)
    per_sample_bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=stacked_labels, logits=stacked_logits)
    return tf.reduce_mean(per_sample_bce)

In [52]:
from time import time

@tf.function
def train_step(generator, discriminator, generator_optimizer, discriminator_optimizer, batch_LR, batch_HR, alpha_advers=0.001):
    """
    Run one adaptive adversarial update on a single mini-batch.

    The generator is updated repeatedly (at most 20 times) while the
    discriminator loss stays below 0.45; afterwards the discriminator is
    updated repeatedly (at most 20 times) while its loss stays above 0.65.
    Each loop runs at least once because of its initial d_loss value.

    Arguments:
    generator -- generator model (inside this function the name shadows the
                 notebook-level generator() builder)
    discriminator -- discriminator model whose outputs are used as logits
    generator_optimizer -- optimizer applied to the generator variables
    discriminator_optimizer -- optimizer applied to the discriminator variables
    batch_LR -- low-resolution input batch
    batch_HR -- matching high-resolution target batch
    alpha_advers -- weight of the adversarial term in the generator loss

    Returns:
    (g_loss, d_loss, g_count, d_count) -- the generator loss of the last
    generator iteration, the discriminator loss of the last discriminator
    iteration (both evaluated before that iteration's weight update), and
    the number of generator / discriminator updates performed.
    """
    
    # ---- Phase 1: generator updates, discriminator held fixed ----
    g_count = 0
    # Starting at 0.0 (< 0.45) guarantees the first generator iteration runs.
    d_loss = tf.constant(0.0)
    # Logits for the real batch are computed once here and reused in every
    # generator iteration below.
    d_HR = discriminator(batch_HR, training=False)
    # NOTE(review): the loop condition depends on a tensor, so this is presumably
    # converted by AutoGraph into a graph-level while loop -- confirm.
    while(d_loss < tf.constant(0.45) and g_count < 20):
        g_count += 1
        with tf.GradientTape() as gen_tape:
            batch_SR = generator(batch_LR, training=True)
            d_SR = discriminator(batch_SR, training=False)
            g_loss, content_loss, g_advers_loss = generator_loss(batch_HR, batch_SR, d_SR, alpha_advers=alpha_advers)

        grad_of_gen = gen_tape.gradient(g_loss, generator.trainable_variables)
        generator_optimizer.apply_gradients(zip(grad_of_gen, generator.trainable_variables))
        # d_SR was produced by the forward pass above, i.e. before the update just
        # applied, so this d_loss lags the generator weights by one step.
        d_loss = discriminator_loss(d_HR, d_SR)

    # ---- Phase 2: discriminator updates, generator held fixed ----
    d_count = 0
    # Starting at 100.0 (> 0.65) guarantees the first discriminator iteration runs.
    d_loss = tf.constant(100.0)
    while(d_loss > tf.constant(0.65) and d_count < 20):
        d_count += 1
        with tf.GradientTape() as disc_tape:
            batch_SR = generator(batch_LR, training=False)
            d_HR = discriminator(batch_HR, training=True)
            d_SR = discriminator(batch_SR, training=True)
            d_loss = discriminator_loss(d_HR, d_SR)

        grad_of_disc = disc_tape.gradient(d_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(grad_of_disc, discriminator.trainable_variables))
        
    return g_loss, d_loss, g_count, d_count

    
def train(gen_model, disc_model, epochs=20, batch_size=128, alpha_advers=0.001):
    '''
        This method trains the generator and discriminator adversarially,
        delegating each mini-batch to train_step(), which decides how many
        generator and discriminator updates to perform.
        Notice the two models should be arguments of this function; both are
        updated in place.

        Reads the module-level arrays x_train / y_train and x_val / y_val.

        NOTE: this definition replaces the `train` function defined in an
        earlier cell of this notebook.

        gen_model -- generator Keras model
        disc_model -- discriminator Keras model; its outputs are treated as logits
        epochs -- number of passes over the training set
        batch_size -- mini-batch size
        alpha_advers -- weight of the adversarial term in the generator loss
        output: generator model and discriminator model
    '''

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)

    g_opt = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_opt = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Start training
    print('Training network ...')
    for epoch in range(1, epochs+1):
        print('Epoch: %d' %(epoch))
        start_time = time()
        epoch_g_loss, epoch_d_loss, N, g_count_tot, d_count_tot = 0, 0, 0, 0, 0

        for batch_LR, batch_HR in train_dataset:
            N_batch = batch_LR.shape[0]
            g_loss, d_loss, g_count, d_count \
                = train_step(gen_model, disc_model, g_opt, d_opt, batch_LR, batch_HR, alpha_advers)

            # Weight by the batch size so that a smaller final batch does not
            # skew the epoch averages.
            epoch_g_loss += g_loss * N_batch
            epoch_d_loss += d_loss * N_batch
            N += N_batch
            # Total number of generator / discriminator updates in this epoch.
            g_count_tot += g_count
            d_count_tot += d_count

        epoch_g_loss = epoch_g_loss / N
        epoch_d_loss = epoch_d_loss / N

        # Validation metrics on the full validation split.
        val_SR = gen_model.predict(x_val, verbose=0)
        val_d_HR = disc_model.predict(y_val, verbose=0)
        val_d_SR = disc_model.predict(val_SR, verbose=0)

        val_g_loss, val_content_loss, val_advers_loss = generator_loss(y_val, val_SR, val_d_SR, alpha_advers)
        val_d_loss = discriminator_loss(val_d_HR, val_d_SR)

        print('Epoch generator loss = %.6f, discriminator loss = %.6f, g_count = %d, d_count = %d' %(epoch_g_loss, epoch_d_loss, g_count_tot, d_count_tot))
        print('Epoch val: g_loss = %.6f, d_loss = %.6f, content_loss = %.6f, advers_loss = %.6f' \
              %(val_g_loss, val_d_loss, val_content_loss, val_advers_loss))
        print('Epoch took %.2f seconds\n' %(time() - start_time), flush=True)

    print('Done.')

    return gen_model, disc_model

In [None]:
# Continue training the reloaded models for 100 epochs.
# NOTE: the recorded output includes g_count / d_count, so this ran with the
# adaptive, train_step-based version of train().
gen_model_reopen, disc_model_reopen = train(gen_model_reopen, disc_model_reopen, epochs=100, batch_size=128, alpha_advers=0.001)

Training network ...
Epoch: 1
Epoch generator loss = 0.013088, discriminator loss = 0.469322, g_count = 213, d_count = 40
Epoch val: g_loss = 0.013804, d_loss = 0.489812, content_loss = 0.013066, advers_loss = 0.738212
Epoch took 174.13 seconds

Epoch: 2
Epoch generator loss = 0.011306, discriminator loss = 0.462603, g_count = 245, d_count = 40
Epoch val: g_loss = 0.014872, d_loss = 0.460794, content_loss = 0.013433, advers_loss = 1.439021
Epoch took 187.67 seconds

Epoch: 3
Epoch generator loss = 0.011590, discriminator loss = 0.461018, g_count = 193, d_count = 40
Epoch val: g_loss = 0.014561, d_loss = 0.431687, content_loss = 0.013468, advers_loss = 1.092899
Epoch took 153.30 seconds

Epoch: 4
Epoch generator loss = 0.011671, discriminator loss = 0.458729, g_count = 200, d_count = 40
Epoch val: g_loss = 0.014865, d_loss = 0.456652, content_loss = 0.013603, advers_loss = 1.261903
Epoch took 157.56 seconds

Epoch: 5
Epoch generator loss = 0.011773, discriminator loss = 0.458742, g_coun