# Tutorial for Training Models

In [1]:
import os
import tensorflow as tf
import glob
%load_ext autoreload
%autoreload 2
%matplotlib inline

## Preparing the Data

- Data preparation in CNN-based super-resolution research usually rests on one assumption: the LR image is the `bicubic`-downsampled version of the HR image. However, we also support other kinds of degradation, such as a `gaussian` kernel, a `bilinear` kernel, and additive noise.

- We usually crop images into patches for convenience during training, but reconstruct the whole image at test time (which is possible when the network is fully convolutional). Because the order of the downsampling and cropping operations does not matter, I suggest cropping the images and saving the patches to a `tfrecord` file first, then using the `map` method to downsample each HR patch.
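
The claim about ordering can be checked on a toy case. The sketch below is a plain-Python illustration, not the repository's code: it swaps bicubic interpolation for simple block averaging (an assumption made to keep the example dependency-free), and `box_downsample` and `crop` are hypothetical helper names. With crop offsets that are multiples of the scale, cropping then downsampling gives the same patch as downsampling then cropping.

```python
def box_downsample(img, scale):
    """Average non-overlapping scale x scale blocks of a 2-D list (not bicubic)."""
    h, w = len(img) // scale, len(img[0]) // scale
    return [
        [
            sum(img[y * scale + dy][x * scale + dx]
                for dy in range(scale) for dx in range(scale)) / scale ** 2
            for x in range(w)
        ]
        for y in range(h)
    ]


def crop(img, top, left, size):
    """Cut a size x size patch whose top-left corner is (top, left)."""
    return [row[left:left + size] for row in img[top:top + size]]


SCALE = 3
PATCH = 6                      # HR patch size, a multiple of SCALE
image = [[(7 * y + 3 * x) % 11 for x in range(12)] for y in range(12)]
top, left = 3, 6               # offsets aligned to multiples of SCALE

# Order 1: crop the HR patch, then downsample it.
crop_then_down = box_downsample(crop(image, top, left, PATCH), SCALE)

# Order 2: downsample the whole image, then crop the matching LR patch.
down_then_crop = crop(box_downsample(image, SCALE),
                      top // SCALE, left // SCALE, PATCH // SCALE)

print(crop_then_down == down_then_crop)
```

With bicubic or Gaussian kernels the filter reaches across patch borders, so the two orders can differ slightly near the patch edges.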

In [2]:
train_dir = "./Image/set14" # Arbitrary
valid_dir = "./Image/set5"
AUTOTUNE = tf.data.experimental.AUTOTUNE
SCALE = 3

In [5]:
from src.write2tfrec import write_dst_tfrec, load_tfrecord

cache_dir = "./cache"
os.makedirs(cache_dir, exist_ok=True)

Save the patches to a tfrecord file. If the file already exists, we can use it directly.

In [6]:

if not os.path.isfile("./cache/set14_train_48x48.tfrec"):
    paths = list(glob.glob(os.path.join(train_dir, "*")))
    write_dst_tfrec(paths, 10, 48, "./cache/set14_train_48x48.tfrec")
    
if not os.path.isfile("./cache/set5_valid_48x48.tfrec"):
    paths = list(glob.glob(os.path.join(valid_dir, "*")))
    write_dst_tfrec(paths, 10, 48, "./cache/set5_valid_48x48.tfrec")

12it [00:00, 39.19it/s]
5it [00:00, 42.13it/s]
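
The check-then-write pattern in the cell above can be factored into a small helper. This is only a sketch: `ensure_cached` and `fake_writer` are hypothetical names, and `fake_writer` stands in for `write_dst_tfrec` so that the example runs without the repository.

```python
import os
import tempfile


def ensure_cached(path, build):
    """Call build(path) only if path does not exist yet; return True if it was built."""
    if os.path.isfile(path):
        return False
    build(path)
    return True


def fake_writer(path):
    # Stand-in for write_dst_tfrec: writes a few bytes instead of real patches.
    with open(path, "wb") as f:
        f.write(b"patches")


cache_dir = tempfile.mkdtemp()
target = os.path.join(cache_dir, "demo_48x48.tfrec")

first = ensure_cached(target, fake_writer)    # file is missing, so it is written
second = ensure_cached(target, fake_writer)   # file exists, so the write is skipped
print(first, second)  # prints: True False
```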


The input should be an HR patch, and the output should be a data pair (inputs, labels).
The degradation function used here is predefined in the `preprocess.py` file; you can
customize it if needed.

In [7]:
from src.preprocess import degrade_image

def preprocess(hr):
    lr, hr = degrade_image(hr, SCALE, method=-1, restore_shape=False, kernel_sigma=0.5)
    return lr, hr
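
The `kernel_sigma=0.5` argument presumably sets the standard deviation of a Gaussian blur kernel applied before downsampling; how `degrade_image` actually builds its kernel is not shown here. As a rough illustration, a normalized 2-D Gaussian kernel can be sketched as follows, where the name `gaussian_kernel` and the kernel size of 5 are assumptions:

```python
import math


def gaussian_kernel(size, sigma):
    """Return a size x size Gaussian kernel whose weights sum to 1."""
    center = (size - 1) / 2
    weights = [
        [math.exp(-((x - center) ** 2 + (y - center) ** 2) / (2 * sigma ** 2))
         for x in range(size)]
        for y in range(size)
    ]
    total = sum(sum(row) for row in weights)
    return [[w / total for w in row] for row in weights]


kernel = gaussian_kernel(size=5, sigma=0.5)
for row in kernel:
    print(" ".join(f"{w:.4f}" for w in row))
```

With a sigma this small, more than half of the total weight sits on the center pixel, so the blur is mild.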


Load the tfrecord files and map the preprocess function over them.
`repeat()` makes the training dataset repeat indefinitely.

In [8]:
trdst = load_tfrecord(48, "./cache/set14_train_48x48.tfrec").map(preprocess).repeat()
valdst = load_tfrecord(48, "./cache/set5_valid_48x48.tfrec").map(preprocess)
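
To make the effect of mapping and repeating concrete, here is a plain-Python analogue built on generators. It is not `tf.data`, and `load_patches`, `to_pair`, and `repeat` are hypothetical stand-ins for `load_tfrecord`, `preprocess`, and the dataset's `repeat()` method.

```python
import itertools


def load_patches():
    """Stand-in for load_tfrecord: yields a small, finite set of 'patches'."""
    yield from range(5)


def to_pair(hr):
    """Stand-in for preprocess: returns an (input, label) pair."""
    return hr // 2, hr


def repeat(make_iter):
    """Restart a finite iterator forever, like a repeated dataset."""
    while True:
        yield from make_iter()


# Map first, then repeat, in the same order as the training pipeline above.
train_stream = repeat(lambda: map(to_pair, load_patches()))

# The stream never ends, so each "epoch" has to take a fixed number of items.
steps_per_epoch = 7
epoch = list(itertools.islice(train_stream, steps_per_epoch))
print(epoch)
```

Because the repeated stream never ends, the training loop needs an explicit bound per epoch, which is the role `steps_per_epoch` plays in the `fit` call later in this tutorial. The validation dataset is not repeated, so it ends after one pass.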


## Train a pre-defined model

In [9]:
from src.model import EDSR_baseline

Here we train the EDSR-baseline model as an example.

In [10]:
model = EDSR_baseline(scale=SCALE, model_name="EDSR_Baseline",
                      channel=3).create_model(load_weights=False,
                                              weights_path=None)


In [11]:
model.fit(trdst,
          valdst,
          nb_epochs=2,
          steps_per_epoch=20,
          batch_size=16,
          use_wn=False)

Training model : EDSR_Baseline_X3


W0621 21:28:34.268955 14980 training_utils.py:1353] Expected a shuffled dataset but input dataset `x` is not shuffled. Please invoke `shuffle()` on input dataset.
W0621 21:28:34.341790 14980 tf_logging.py:161] Model failed to serialize as JSON. Ignoring... 


Epoch 1/2
 1/20 [>.............................] - ETA: 2:33 - loss: 5.8379 - psnr_tf: -6.6579

W0621 21:28:42.690458 14980 callbacks.py:236] Method (on_train_batch_end) is slow compared to the batch update (0.202476). Check your callbacks.


 2/20 [==>...........................] - ETA: 1:15 - loss: 3.4289 - psnr_tf: -2.9330

W0621 21:28:42.760271 14980 callbacks.py:236] Method (on_train_batch_end) is slow compared to the batch update (0.181550). Check your callbacks.


Epoch 00001: saving model to ./weights/EDSR_Baseline_X3.h5
Epoch 2/2
Epoch 00002: saving model to ./weights/EDSR_Baseline_X3.h5


<src.model.EDSR.EDSR_baseline at 0x29b8b951550>
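
The `psnr_tf` values in the log above are negative during the first steps. How `psnr_tf` is computed in this repository is not shown here. Assuming the standard definition, PSNR = 10 * log10(MAX^2 / MSE), a negative value simply means that the mean squared error exceeds the squared peak value, which is plausible for an untrained network. A minimal sketch, assuming pixel values scaled to [0, 1] and a hypothetical helper name `psnr`:

```python
import math


def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length sequences."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    if mse == 0:
        return float("inf")
    return 10 * math.log10(max_val ** 2 / mse)


target = [0.0, 0.5, 1.0, 0.25]
close = [0.1, 0.4, 0.9, 0.35]    # every value is off by 0.1, so MSE is about 0.01
far = [2.5, -2.0, 3.5, -2.25]    # every value is off by 2.5, so MSE is 6.25

psnr_close = psnr(close, target)  # about 20 dB
psnr_far = psnr(far, target)      # negative, because MSE > max_val ** 2
print(round(psnr_close, 2), round(psnr_far, 2))
```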