In [3]:
# Written by W.T. Chung

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import tensorflow as tf
import matplotlib.pyplot as plt
import os


In [4]:
# Root of the dataset tree; the LR/ and HR/ sub-directories are read from here.
input_path = './dataset/'

# Fix the global random seed so runs are repeatable.
tf.keras.utils.set_random_seed(812)
# Request deterministic TensorFlow kernels (API added in TF 2.9). This costs
# some speed, which is why TensorFlow leaves it off by default.
tf.config.experimental.enable_op_determinism()

In [5]:
# Adapted from the Keras EDSR tutorial: https://keras.io/examples/vision/edsr/
from tensorflow.keras import layers
from tensorflow.keras.models import Model


# Residual Block

def ResBlock(inputs):
    """EDSR-style residual block: Conv-ReLU-Conv plus an identity skip.

    Args:
        inputs: 4-D feature tensor whose last axis is the channel axis.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    # Match the convolutions to the incoming channel count so the skip ``Add``
    # stays shape-compatible for any ``num_filters`` used by make_model
    # (a fixed 64 breaks the Add whenever the trunk is not 64 channels wide).
    filters = inputs.shape[-1]
    if filters is None:
        # Channel axis unknown at graph-build time: fall back to the original width.
        filters = 64
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    return layers.Add()([inputs, x])


# Upsampling Block
def Upsampling(inputs, factor=2, num_stages=3, **kwargs):
    """Sub-pixel upsampling head.

    Applies ``num_stages`` rounds of ``Conv2D(64 * factor**2)`` followed by
    ``tf.nn.depth_to_space(block_size=factor)``. Each round multiplies the
    spatial size by ``factor``, so the total scale is ``factor ** num_stages``
    (8x with the defaults, i.e. 16x16 -> 128x128 for this dataset).

    Args:
        inputs: 4-D feature tensor (batch, H, W, C).
        factor: per-stage upscaling factor.
        num_stages: number of conv + depth_to_space rounds (3 reproduces the
            original hard-coded head).
        **kwargs: forwarded to every Conv2D.

    Returns:
        Tensor of shape (batch, H * factor**num_stages, W * factor**num_stages, 64).
    """
    x = inputs
    for _ in range(num_stages):
        x = layers.Conv2D(64 * (factor ** 2), 3, padding="same", **kwargs)(x)
        # Rearranges factor**2 channel groups into a factor x factor spatial block.
        x = tf.nn.depth_to_space(x, block_size=factor)
    return x


def make_model(num_filters=64, num_of_residual_blocks=16):
    """Build the EDSR-like single-channel super-resolution network.

    Args:
        num_filters: width of the head conv and of the conv closing the residual trunk.
        num_of_residual_blocks: number of ResBlock units stacked in the trunk.

    Returns:
        A keras ``Model`` mapping (batch, H, W, 1) inputs through the residual
        trunk and the Upsampling head to a single-channel output.
    """
    # Spatial size left unspecified so any H x W single-channel field is accepted.
    image_in = layers.Input(shape=(None, None, 1))
    # Head convolution; its output also feeds the long skip connection below.
    head = layers.Conv2D(num_filters, 3, padding="same")(image_in)

    trunk = head
    for _ in range(num_of_residual_blocks):
        trunk = ResBlock(trunk)
    trunk = layers.Conv2D(num_filters, 3, padding="same")(trunk)

    # Long skip connection around the whole residual trunk.
    features = layers.Add()([head, trunk])
    upscaled = Upsampling(features)
    image_out = layers.Conv2D(1, 3, padding="same")(upscaled)
    return Model(image_in, image_out)

# Build the network with its default hyper-parameters and print the layer table.
model = make_model(num_filters=64, num_of_residual_blocks=16)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None, None, 1)]      0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, None, None, 64)       640       ['input_1[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, None, None, 64)       36928     ['conv2d[0][0]']              
                                                                                                  
 conv2d_2 (Conv2D)           (None, None, None, 64)       36928     ['conv2d_1[0][0]']            
                                                                                              

In [6]:
#get filenames
def getFiles(mode = "train", scalar = 'YOH', root=None):
    """List the sample identifiers available for one split and scalar field.

    Scans ``<root>/HR/<scalar>/<mode>`` and keeps every file whose name starts
    with ``scalar``. The returned identifiers have that prefix stripped, which
    is how ``getXY`` rebuilds the file names (``scalar + identifier``).

    Args:
        mode: split sub-directory, e.g. "train", "val" or "test".
        scalar: field name; used both as directory name and as file-name prefix.
        root: dataset root; defaults to the module-level ``input_path``.

    Returns:
        Sorted list of identifiers (file names minus the ``scalar`` prefix).
    """
    if root is None:
        root = input_path
    hr_dir = os.path.join(root, "HR", scalar, mode)
    # Filter and strip using ``scalar`` itself: a hard-coded 'YOH' / [3:] only
    # works for the default field. Sorting makes the index -> file mapping
    # independent of os.listdir's arbitrary order, so seeded shuffles reproduce.
    return sorted(
        name[len(scalar):] for name in os.listdir(hr_dir) if name.startswith(scalar)
    )

In [7]:
# Sample identifiers for each split (the default 'YOH' field).
train_files, val_files, test_files = (getFiles(split) for split in ("train", "val", "test"))

In [8]:
# Global standardization statistics applied to both LR inputs and HR targets
# (presumably precomputed over the training data — TODO confirm provenance).
my_mean = 0.003057
my_std = 0.002693
def getXY(idx,filenames,mode = "train",scalar = 'YOH'):
    """Load one (LR, HR) sample pair, standardize it, and return float32 tensors.

    Reads raw little-endian float32 files named ``scalar + filenames[idx]`` from
    ``<input_path>LR/<scalar>/<mode>`` and ``<input_path>HR/<scalar>/<mode>``,
    reshapes them to (16, 16, 1) and (128, 128, 1), and applies
    ``(value - my_mean) / my_std``.

    Returns:
        ``[X, Y]``: list of the LR and HR ``tf.float32`` tensors.
    """
    fname = scalar + filenames[idx]

    def _load(res_prefix, shape):
        # res_prefix is "LR/" or "HR/"; the file lives in <root><res>/<scalar>/<mode>/.
        folder = input_path + res_prefix + scalar + '/' + mode
        raw = np.fromfile(folder + "/" + fname, dtype="<f4").reshape(shape)
        return tf.convert_to_tensor((raw - my_mean) / my_std, dtype=tf.float32)

    return [_load("LR/", (16, 16, 1)), _load("HR/", (128, 128, 1))]

def getTrainXY(idx):
    """Load training sample ``idx`` (an index into ``train_files``)."""
    return getXY(idx,train_files,mode = "train")

def getValXY(idx):
    """Load validation sample ``idx`` (an index into ``val_files``)."""
    return getXY(idx,val_files,mode = "val")

def getTestXY(idx):
    """Load test sample ``idx`` (an index into ``test_files``)."""
    # Use the test identifiers: indexing ``val_files`` here paired validation
    # file names with the "test" directory, loading wrong or missing files.
    return getXY(idx,test_files,mode = "test")


In [9]:
# Load the first training pair to confirm the LR/HR shapes, then record them.
X0, Y0 = getTrainXY(0)
print("Feature shapes:", X0.shape, sep="\n")
print("Label shapes:", Y0.shape, sep="\n")

(nx_in, ny_in, nc_in), (nx_out, ny_out, nc_out) = X0.shape, Y0.shape

Feature shapes:
(16, 16, 1)
Label shapes:
(128, 128, 1)


In [10]:
def _load_pair_with_shapes(loader, file_idx):
    """Run ``loader(file_idx)`` through tf.py_function and restore static shapes.

    Tensors returned by tf.py_function carry no static shape, so downstream
    ops would see unknown-rank tensors. Pin them to rank 3 with one channel,
    matching the (H, W, 1) reshape done in getXY.
    """
    X, Y = tf.py_function(loader, inp=[file_idx], Tout=[tf.float32, tf.float32])
    X.set_shape((None, None, 1))
    Y.set_shape((None, None, 1))
    return X, Y

def load_train_dataset_wrapper(file_idx):
    """tf.data map function: index into ``train_files`` -> (X, Y) tensors."""
    return _load_pair_with_shapes(getTrainXY, file_idx)

def load_val_dataset_wrapper(file_idx):
    """tf.data map function: index into ``val_files`` -> (X, Y) tensors."""
    return _load_pair_with_shapes(getValXY, file_idx)

# Global batch size; keep it divisible by the number of GPUs, since the
# MirroredStrategy used for training splits each batch across replicas.
batch_size = 32
nfile = len(train_files)
train_ds = (tf.data.Dataset.from_tensor_slices(range(nfile))
    .shuffle(nfile)
    .map(load_train_dataset_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation data is kept in file order (no shuffle).
nfile = len(val_files)
val_ds = (tf.data.Dataset.from_tensor_slices(range(nfile))
    .map(load_val_dataset_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [11]:
# Output locations for per-epoch checkpoints and the CSV training log;
# each directory is created only if it is not already present.
log_dir = "./logs"
checkpoint_dir = "./ckpt"
for _out_dir in (checkpoint_dir, log_dir):
    if not os.path.exists(_out_dir):
        os.makedirs(_out_dir)

def _find_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-numbered ``ckpt-<epoch>`` entry, or None.

    Only entries named ``ckpt-<digits>`` (the pattern written by the
    ModelCheckpoint callback in run_training) are considered. Any other entry
    in the directory is ignored instead of crashing the epoch parse. The
    choice depends on the epoch number rather than on file ctimes.
    """
    candidates = []
    for name in os.listdir(checkpoint_dir):
        stem, _, epoch = name.rpartition('-')
        if stem == 'ckpt' and epoch.isdecimal():
            candidates.append((int(epoch), name))
    if not candidates:
        return None
    return checkpoint_dir + "/" + max(candidates)[1]


def make_or_restore_model(alpha=1e-3,checkpoint_dir=checkpoint_dir,ckpt=None):
    """Build the model, optionally load checkpoint weights, and compile it.

    Either restores a checkpoint or creates a fresh model when none is available.

    Args:
        alpha: Adam learning rate.
        checkpoint_dir: directory holding ``ckpt-<epoch>`` saved models.
        ckpt: epoch number of a specific checkpoint to load. When falsy
            (None or 0), the highest-numbered checkpoint in ``checkpoint_dir``
            is loaded if one exists; otherwise a new model is created.

    Returns:
        ``[model, nckpt]``: the compiled model and the epoch of the restored
        checkpoint (0 for a new model), suitable for ``fit(initial_epoch=...)``.
    """
    model = make_model()

    if ckpt:
        latest_checkpoint = checkpoint_dir + "/ckpt-" + str(ckpt)
    else:
        latest_checkpoint = _find_latest_checkpoint(checkpoint_dir)

    if latest_checkpoint is None:
        print("Creating a new model")
        nckpt = 0
    else:
        print("Restoring from", latest_checkpoint)
        nckpt = int(latest_checkpoint.split('-')[-1])
        # Only the weights are copied into the freshly built model; any
        # optimizer state in the checkpoint is not carried over, because a new
        # optimizer is compiled below.
        restored_model = tf.keras.models.load_model(latest_checkpoint)
        model.set_weights(restored_model.get_weights())

    # Compile with optimizer, loss function and metrics.
    # NOTE: Although we're only using MSE as the public metric here, we will be also evaluating with SSIM-based metrics to search for best models so you may want to monitor other metrics
    model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=alpha),
        loss=tf.keras.losses.MeanSquaredError(reduction='sum_over_batch_size'),
        metrics =[tf.keras.metrics.MeanAbsoluteError()]) #add other metrics here

    return [model,nckpt]

In [14]:
def run_training(train_dataset=train_ds,val_dataset=val_ds,epochs=100,lr=1e-3,ckpt=None,checkpoint_dir=checkpoint_dir,log_dir=log_dir,final_model_path='./final_model'):
    """Train (or resume) the model, checkpointing each epoch, then save it.

    The model is created or restored by make_or_restore_model inside a
    ``tf.distribute.MirroredStrategy`` scope. ``fit`` starts at
    ``initial_epoch=nckpt`` so a restored run continues its epoch numbering.

    Args:
        train_dataset: batched (X, Y) training dataset.
        val_dataset: batched (X, Y) validation dataset.
        epochs: index of the last epoch to run (absolute, not additional epochs).
        lr: Adam learning rate.
        ckpt: specific checkpoint epoch to restore; falsy restores the latest
            available checkpoint in ``checkpoint_dir``, if any.
        checkpoint_dir: where ``ckpt-<epoch>`` models are written and read.
        log_dir: where ``model_history_log.csv`` is appended.
        final_model_path: where the model is saved after the last epoch.
            It defaults to the previous hard-coded ``'./final_model'``; pass
            distinct paths so separate runs do not overwrite each other.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Create a MirroredStrategy for multi-gpu
    strategy = tf.distribute.MirroredStrategy()
    print('Number of GPUs: {}'.format(strategy.num_replicas_in_sync))

    # Variables must be created inside the strategy scope to be mirrored.
    with strategy.scope():
        model, nckpt = make_or_restore_model(alpha=lr,ckpt=ckpt,checkpoint_dir=checkpoint_dir)

    callbacks = [
        # Save a model every epoch as <checkpoint_dir>/ckpt-<epoch>.
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_dir + "/ckpt-{epoch}", save_freq="epoch"
        ),
        # Append per-epoch metrics so resumed runs extend the same log.
        tf.keras.callbacks.CSVLogger(log_dir + "/model_history_log.csv", append=True)
    ]

    history = model.fit(
        train_dataset,
        epochs=epochs,
        initial_epoch=nckpt,
        callbacks=callbacks,
        verbose=1,
        validation_data = val_dataset
    )

    # Save once after the final epoch (per-epoch copies come from ModelCheckpoint).
    model.save(final_model_path)
    return history

    
def finetune(train_ds=train_ds,val_dataset=val_ds,epochs=150,lr=1e-5,ckpt=None,checkpoint_dir=checkpoint_dir,log_dir=log_dir):
    """Continue training via run_training with a lower default learning rate.

    Note: the training-set parameter is named ``train_ds`` here but
    ``train_dataset`` in run_training; ``epochs`` is the absolute last epoch.
    """
    # Forward the caller's ``val_dataset``. Passing the module-level ``val_ds``
    # instead would silently ignore whatever validation set was supplied.
    run_training(train_dataset=train_ds,val_dataset=val_dataset,epochs=epochs,lr=lr,ckpt=ckpt,checkpoint_dir=checkpoint_dir,log_dir=log_dir)


In [17]:
# Resume from the latest checkpoint in checkpoint_dir (if any) and keep
# training up to epoch 193 with a learning rate of 1e-4.
finetune(lr=1e-4, epochs=193)

Number of GPUs: 1
Restoring from /content/drive/My Drive/pci_invited/ckptPretrainedV2/ckpt-92
Epoch 93/193
Epoch 94/193
Epoch 95/193
Epoch 96/193
Epoch 97/193
Epoch 98/193
Epoch 99/193
Epoch 100/193
Epoch 101/193
Epoch 102/193
Epoch 103/193
Epoch 104/193
Epoch 105/193
Epoch 106/193
Epoch 107/193
Epoch 108/193
Epoch 109/193
Epoch 110/193
Epoch 111/193
Epoch 112/193
Epoch 113/193
Epoch 114/193
Epoch 115/193
Epoch 116/193
Epoch 117/193
Epoch 118/193
Epoch 119/193
Epoch 120/193
Epoch 121/193
Epoch 122/193
Epoch 123/193
Epoch 124/193
Epoch 125/193
Epoch 126/193
Epoch 127/193
Epoch 128/193
Epoch 129/193
Epoch 130/193
Epoch 131/193
Epoch 132/193
Epoch 133/193
Epoch 134/193
Epoch 135/193
Epoch 136/193
Epoch 137/193
Epoch 138/193
Epoch 139/193
Epoch 140/193
Epoch 141/193
Epoch 142/193
Epoch 143/193
Epoch 144/193
Epoch 145/193
Epoch 146/193
Epoch 147/193
Epoch 148/193
Epoch 149/193
Epoch 150/193
Epoch 151/193
Epoch 152/193
Epoch 153/193
Epoch 154/193
Epoch 155/193
Epoch 156/193
Epoch 157/193
Epo

In [None]:
# Train a separate model from scratch in its own output tree, so its
# checkpoints and logs do not mix with the fine-tuning run above.
scratch_checkpoint_dir = './outputScratch/' + checkpoint_dir
scratch_log_dir = './outputScratch/' + log_dir
# Create each sub-directory independently. Guarding on ./outputScratch alone
# left the sub-directories missing when the parent already existed.
os.makedirs(scratch_checkpoint_dir, exist_ok=True)
os.makedirs(scratch_log_dir, exist_ok=True)

# ckpt=None starts a new model while the checkpoint directory is empty and
# resumes from its latest checkpoint on later runs. ckpt=1 would try to load
# ckpt-1, which cannot exist in a freshly created directory.
run_training(epochs=250, ckpt=None, lr=1e-4, checkpoint_dir=scratch_checkpoint_dir, log_dir=scratch_log_dir)