In [None]:
# Download the DIV2K high-resolution training archive. `-nc` (no-clobber) skips
# the download when the archive already exists, so re-running the notebook does
# not fetch it again or leave DIV2K_train_HR.zip.1 duplicates behind.
!wget -nc http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip

In [None]:
# Extract the archive (expected to produce the DIV2K_train_HR folder read below
# via data_path). `-n` never overwrites existing files, so a re-run does not
# block on unzip's interactive "replace?" prompt; `-q` suppresses the per-file
# listing that would otherwise flood the cell output.
!unzip -q -n DIV2K_train_HR.zip

In [1]:
from model import DRN
from data import data_load,data_decode,data_prepare
from data import data_patch,data_augment,data_normalize
import tensorflow as tf

In [None]:
# --- Configuration -----------------------------------------------------------
data_path = 'DIV2K_train_HR'  # folder of HR training images (from the unzip step)
input_size = 64               # spatial size used to build input_shape
channel = 3                   # colour channels per image
scale = 4                     # upscaling factor passed to data prep and to DRN
dual = True                   # enable DRN's dual-regression outputs (and losses at compile time)
batch_size = 16
AUTOTUNE = tf.data.experimental.AUTOTUNE
input_shape = (input_size, input_size, channel)

# --- Input pipeline ----------------------------------------------------------
dataset = tf.data.Dataset.from_tensor_slices(data_load(data_path))
num_images = int(tf.data.experimental.cardinality(dataset))
assert num_images > 0, f"no training images found under {data_path!r}"

dataset = dataset.map(data_decode, AUTOTUNE)
dataset = dataset.map(lambda x: data_prepare(x, scale, input_shape), AUTOTUNE)
# Everything above runs once and is cached in memory; everything below runs
# again every epoch (presumably data_patch/data_augment are random — confirm in data.py).
dataset = dataset.cache()
dataset = dataset.map(lambda x, y: data_patch(x, y, scale, input_shape), AUTOTUNE)
dataset = dataset.map(data_augment, AUTOTUNE)
dataset = dataset.map(data_normalize, AUTOTUNE)
# Keras ignores fit(shuffle=...) for tf.data inputs, so without this every epoch
# would visit the images in the same order. Shuffling after data_patch keeps the
# buffered elements small (presumably input_shape-sized crops). One buffer slot
# per image gives a full permutation, reshuffled every epoch by default.
dataset = dataset.shuffle(num_images)
dataset = dataset.batch(batch_size).prefetch(AUTOTUNE)

model = DRN(input_shape=input_shape, model='DRN-S', scale=scale, dual=dual)

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam,SGD
import os
import math

def CosineAnnealingScheduler(T_max=30,lr_max=0.001,lr_min=0.00009,Pi=math.pi):
    """Build a cosine-annealing schedule for ``keras.callbacks.LearningRateScheduler``.

    lr(epoch) = lr_min + (lr_max - lr_min) * (1 + cos(Pi * epoch / T_max)) / 2

    The rate starts at ``lr_max`` at epoch 0 and reaches ``lr_min`` at epoch
    ``T_max``. No hard restart is applied. Past ``T_max`` the cosine keeps
    going, so the rate climbs back to ``lr_max`` at ``2 * T_max`` and repeats
    with period ``2 * T_max``.

    Args:
        T_max: half-period of the cosine, in epochs. Must be positive.
        lr_max: learning rate at epoch 0.
        lr_min: learning rate at epoch ``T_max``.
        Pi: value of pi used in the cosine. Accepts a float or a scalar tensor
            (kept for backward compatibility; the default is ``math.pi``).

    Returns:
        ``scheduler(epoch, lr) -> float``. The incoming ``lr`` is ignored; the
        result depends only on ``epoch``.

    Raises:
        ValueError: if ``T_max`` is not positive.
    """
    if T_max <= 0:
        raise ValueError(f"T_max must be positive, got {T_max}")
    pi = float(Pi)

    def scheduler(epoch, lr):
        # Return a plain Python float: some Keras versions of
        # LearningRateScheduler reject tensor-valued schedule outputs.
        return lr_min + (lr_max - lr_min) * 0.5 * (1.0 + math.cos(pi * epoch / T_max))

    return scheduler
    
def loss(y_true, y_pred):
    """Scalar mean absolute error between ``y_true`` and ``y_pred``.

    ``tf.keras.losses.MAE`` averages |y_true - y_pred| over the last axis. The
    outer ``reduce_mean`` then collapses all remaining axes into one scalar.
    Used as the loss for the model's first output when ``dual`` is enabled.
    """
    per_position_mae = tf.keras.losses.MAE(y_true, y_pred)
    return tf.math.reduce_mean(per_position_mae)
def dual_loss(y_true, y_pred):
    """Dual-regression loss: 0.1 times the MAE between the two halves of ``y_pred``.

    ``y_true`` is not used. ``y_pred`` is split into two equal halves along the
    last axis and the halves are compared with each other. The original names
    (``lr`` / ``sr2lr``) suggest the model concatenates an LR reference and an
    SR->LR reconstruction there — confirm against model.py. The 0.1 factor is the
    dual-loss weight (presumably the DRN paper's lambda — confirm).
    """
    lr_part, sr2lr_part = tf.split(y_pred, num_or_size_splits=2, axis=-1)
    mae = tf.math.reduce_mean(tf.keras.losses.MAE(lr_part, sr2lr_part))
    return mae * 0.1

model_path = "./models/"
model_name = "weights-{epoch:03d}-{loss:.4f}.h5"
os.makedirs(model_path, exist_ok=True)

# An integer `save_freq` in ModelCheckpoint counts *batches* (TF 2.1+), not
# epochs, so `save_freq=2` wrote a new weights file every two batches. Convert
# the intended "every N epochs" into a batch count from the dataset length.
checkpoint_every_epochs = 2
steps_per_epoch = int(tf.data.experimental.cardinality(dataset))
if steps_per_epoch > 0:
    save_freq = checkpoint_every_epochs * steps_per_epoch
else:
    # Unknown (-2) or infinite (-1) cardinality: fall back to once per epoch.
    save_freq = 'epoch'
checkpoint = ModelCheckpoint(os.path.join(model_path, model_name), save_freq=save_freq, save_best_only=False, save_weights_only=True)
lrscheduler = LearningRateScheduler(CosineAnnealingScheduler())
opt = Adam(1e-3)  # initial LR; the scheduler overwrites it at the start of each epoch

if dual:
    # One reconstruction loss plus log2(scale) dual losses (presumably one per
    # x2 stage of DRN — confirm against model.py). math.log2 is exact for
    # powers of two, unlike math.log(scale, 2).
    num_dual_losses = int(math.log2(scale))
    model.compile(loss=[loss] + [dual_loss] * num_dual_losses, optimizer=opt)
else:
    model.compile(loss='mean_absolute_error', optimizer=opt)

In [None]:
# Train the model. Checkpointing and the learning-rate schedule run as callbacks.
# Keep the returned History so loss curves can be inspected or plotted afterwards.
epochs = 400
history = model.fit(dataset, epochs=epochs, callbacks=[checkpoint, lrscheduler])