In [1]:
import os 
import sys

from tensorflow import keras
import tensorflow as tf
from PIL import Image
import numpy as np
from model import FSRCNN

# Allow local import from parent directory
# sys.path.insert(0, "..")
# from dataset import deserialize
# from common import normalize_y


# def preprocess(example): 
#     lr, hr = deserialize(example)
#     hr = normalize_y(hr)    
#     lr = normalize_y(lr)
#     return lr, hr

# Patch geometry of the TFRecord data; `preprocess` reshapes the raw bytes with it:
# HR patches are SIZE x SIZE x CHN, LR patches are (SIZE // R) x (SIZE // R) x CHN,
# i.e. R is the down-/up-scaling factor between the two.
SIZE, CHN, R = 100, 1, 4


def normalize_y(y_array, scale=219, offset=16):
    """Linearly rescale luma (Y-channel) values, mapping ``offset`` to 0.

    With the defaults, 16 maps to 0.0 and 16 + 219 = 235 maps to 1.0
    (presumably the 8-bit "studio range" for luma). Inputs outside
    [offset, offset + scale] land outside [0, 1]; nothing is clipped.

    Args:
        y_array: tensor/array-like of pixel values; cast to float32 first.
        scale: divisor applied after subtracting ``offset``.
        offset: value subtracted before scaling.

    Returns:
        float32 tensor ``(y_array - offset) / scale``.
    """
    as_float = tf.cast(y_array, dtype="float32")
    shifted = as_float - offset
    return shifted / scale


def preprocess(example):
    """Decode one serialized TFRecord example into a normalized ``(lr, hr)`` pair.

    The record is expected to hold two byte-string features, ``"lr"`` and
    ``"hr"``, each containing raw uint8 pixels. They are reshaped to
    ``[SIZE // R, SIZE // R, CHN]`` (low-res) and ``[SIZE, SIZE, CHN]``
    (high-res), then passed through ``normalize_y``.

    Returns:
        tuple ``(lr, hr)`` of float32 tensors, suitable for ``Dataset.map``.
    """
    features = {key: tf.io.FixedLenFeature([], tf.string) for key in ("lr", "hr")}
    parsed = tf.io.parse_single_example(example, features)

    hr_shape = [SIZE, SIZE, CHN]
    lr_shape = [SIZE // R, SIZE // R, CHN]

    # Raw bytes -> flat uint8 vector -> image-shaped tensor.
    hr_pixels = tf.reshape(tf.io.decode_raw(parsed["hr"], out_type="uint8"), shape=hr_shape)
    lr_pixels = tf.reshape(tf.io.decode_raw(parsed["lr"], out_type="uint8"), shape=lr_shape)

    return normalize_y(lr_pixels), normalize_y(hr_pixels)





In [2]:
# config and parameters
IS_FSRCNN_S = False    # True -> train the small FSRCNN-s variant, False -> full FSRCNN
RESUME = True          # True -> continue from the newest saved weights, False -> train from scratch
n_tfrecords = 32
batch_size = 16
epochs = 300           # absolute final epoch count (not "additional" epochs when resuming)

# prep the dataset
# NOTE(review): the validation list reuses n_tfrecords -- confirm the valid split really has 32 shards.
# NOTE(review): the training data is never shuffled, so every epoch sees the same order.
train_dir = [f"../tfrecords/div2k_train{i}.tfrecords" for i in range(n_tfrecords)]
valid_dir = [f"../tfrecords/div2k_valid{i}.tfrecords" for i in range(n_tfrecords)]
AUTOTUNE = tf.data.experimental.AUTOTUNE
# A parallel map keeps element order (deterministic by default); prefetch overlaps
# input decoding with training. Neither changes the data that is produced.
train_dataset = (tf.data.TFRecordDataset(train_dir)
                 .map(preprocess, num_parallel_calls=AUTOTUNE)
                 .batch(batch_size)
                 .prefetch(AUTOTUNE))
valid_dataset = (tf.data.TFRecordDataset(valid_dir)
                 .map(preprocess, num_parallel_calls=AUTOTUNE)
                 .batch(batch_size)
                 .prefetch(AUTOTUNE))

# prep the model
model = FSRCNN(d=32, s=5, m=1, r=4) if IS_FSRCNN_S else FSRCNN()
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer, loss=loss_fn)

name_model = "FSRCNN_S" if IS_FSRCNN_S else "FSRCNN"
log_dir = f"logs/{name_model}"
# "{epoch:03d}" is left literal here and filled in by ModelCheckpoint,
# e.g. checkpoints/FSRCNN149.ckpt
checkpoint_path = f"checkpoints/{name_model}{{epoch:03d}}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)


def find_latest_weights(directory, prefix):
    """Return ``(path, epoch)`` of the newest ``<prefix><epoch>.ckpt`` weights in *directory*.

    ``tf.train.latest_checkpoint`` is deliberately not used. It returns whatever
    was saved last in the directory, and FSRCNN and FSRCNN_S share
    ``checkpoints/``, so it could hand back the other architecture's weights.
    Only TF-format weight checkpoints (files with a ``.ckpt.index`` companion)
    are considered. Returns ``(None, 0)`` when nothing matches.
    """
    suffix = ".ckpt.index"
    best_path, best_epoch = None, 0
    if not os.path.isdir(directory):
        return best_path, best_epoch
    for fname in os.listdir(directory):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        # "FSRCNN_S149" must not match prefix "FSRCNN": require digits only.
        epoch_str = fname[len(prefix):-len(suffix)]
        if not epoch_str.isdecimal():
            continue
        epoch_num = int(epoch_str)
        if best_path is None or epoch_num > best_epoch:
            best_epoch = epoch_num
            best_path = os.path.join(directory, fname[:-len(".index")])
    return best_path, best_epoch


initial_epoch = 0
if RESUME:
    # load pre-trained weights. ModelCheckpoint names files with the 1-based
    # epoch number, so "<name>148.ckpt" means 148 epochs are done and training
    # resumes at initial_epoch=148 (displayed as "Epoch 149/...").
    latest, initial_epoch = find_latest_weights(checkpoint_dir, name_model)
    if latest is None:
        raise FileNotFoundError(
            f"RESUME=True but no {name_model}<epoch>.ckpt weights found in {checkpoint_dir!r}; "
            "set RESUME=False to train from scratch.")
    model.load_weights(latest)

# Training
callbacks = [
    # Stop after 3 epochs without val_loss improvement and roll back to the best weights of this run.
    tf.keras.callbacks.EarlyStopping(monitor="val_loss",
                                     patience=3,
                                     restore_best_weights=True),

    # Save *weights* every epoch (TF checkpoint format) so that model.load_weights
    # can resume from them. Without save_weights_only=True, Keras writes a
    # SavedModel directory that the resume path above never picks up.
    # `monitor` has no effect here because save_best_only is False.
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                       save_weights_only=True,
                                       verbose=1,
                                       monitor="loss",
                                       save_freq="epoch"),

    tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="epoch"),
]

history = model.fit(train_dataset,
                    initial_epoch=initial_epoch,
                    epochs=epochs,
                    callbacks=callbacks,
                    validation_data=valid_dataset)


Epoch 149/300
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
  13131/Unknown - 605s 46ms/step - loss: 0.0018
Epoch 00149: saving model to checkpoints\FSRCNN149.ckpt
Epoch 150/300
Epoch 00150: saving model to checkpoints\FSRCNN150.ckpt
Epoch 151/300
Epoch 00151: saving model to checkpoints\FSRCNN151.ckpt
Epoch 152/300
Epoch 00152: saving model to checkpoints\FSRCNN152.ckpt
Epoch 153/300
Epoch 00153: saving model to checkpoints\FSRCNN153.ckpt
Epoch 154/300
Epoch 00154: saving model to checkpoints\FSRCNN154.ckpt
Epoch 155/300
Epoch 00155: saving model to checkpoints\FSRCNN155.ckpt
Epoch 156/300
Epoch 00156: saving model to checkpoints\FSRCNN156.ckpt
Epoch 157/300
Epoch 00157: saving model to checkpoints\FSRCNN157.ckpt
Epoch 158/300
Epoch 00158: saving model to checkpoints\FSRCNN158.ckpt
Epoch 159/300
Epoch 00159: saving model to checkpoints\FSRCNN159.ckpt
Epoch 160/300
Epoch 00160: saving model to checkpoints\FSRCNN160.ckpt
Epoch 161/300
Epoch 00161: saving model 

In [11]:
def get_latest_model(): 
    model = FSRCNN(d=32, s=5, m=1, r=4) if IS_FSRCNN_S else FSRCNN()
    loss_fn = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    model.compile(optimizer, loss=loss_fn)

    name_model = "FSRCNN_S" if IS_FSRCNN_S else "FSRCNN"
    checkpoint_path  = f"checkpoints/"+ name_model+"{epoch:03d}.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    model.load_weights(latest)