# Tutorial for Training Models

In [1]:
import os
import tensorflow as tf
%load_ext autoreload
%autoreload 2
%matplotlib inline

## Preparing the Data

- Data preparation in CNN-based super-resolution research is typically based on one assumption: the LR image is the `bicubic`-downsampled version of the HR image. However, this code also supports many other kinds of degradation, such as a `gaussian` kernel, a `bilinear` kernel, additive noise, and so on.

- For convenience, images are usually cropped into patches for training, while the whole image is reconstructed at test time (which works because the network is fully convolutional). Since cropping and downsampling can be applied in either order (up to effects at the patch borders), I suggest cropping the images and saving the HR patches to a `tfrecord` file first, then using the dataset's `map` method to downsample each HR patch. This way the same patch file can be reused with any degradation.
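The claim that cropping and downsampling can be swapped can be checked with a small NumPy sketch. It uses box (area-average) downsampling as a stand-in for bicubic: a box filter only looks inside each `scale × scale` block, so the two orders agree exactly when the crop offset and size are multiples of the scale, whereas bicubic or Gaussian kernels agree only away from the patch borders. The helpers `box_downsample` and `crop` are hypothetical illustrations, not functions from `src`.

```python
import numpy as np

def box_downsample(img: np.ndarray, scale: int) -> np.ndarray:
    """Average each non-overlapping scale x scale block (H and W must divide by scale)."""
    h, w, c = img.shape
    assert h % scale == 0 and w % scale == 0, "size must be a multiple of scale"
    return img.reshape(h // scale, scale, w // scale, scale, c).mean(axis=(1, 3))

def crop(img: np.ndarray, top: int, left: int, size: int) -> np.ndarray:
    """Cut a size x size patch whose top-left corner is at (top, left)."""
    return img[top:top + size, left:left + size]

rng = np.random.default_rng(0)
hr = rng.random((96, 120, 3))   # synthetic HR image
scale, patch = 3, 48            # a 48x48 HR patch becomes a 16x16 LR patch
top, left = 3 * 5, 3 * 11       # crop offsets aligned to the scale

# Order 1: crop the HR image, then downsample the patch (the order suggested above)
lr_a = box_downsample(crop(hr, top, left, patch), scale)
# Order 2: downsample the whole image, then crop the matching LR region
lr_b = crop(box_downsample(hr, scale), top // scale, left // scale, patch // scale)

print(lr_a.shape, np.allclose(lr_a, lr_b))
```

Note that the patch size must be divisible by the scale factor for the LR patch to have an integer size, e.g. 48 / 3 = 16.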

In [2]:
train_dir = "./Image/set14"  # arbitrary: any folder of HR images can be used
valid_dir = "./Image/set5"
AUTOTUNE = tf.data.experimental.AUTOTUNE
SCALE = 3  # upscaling factor

In [6]:
from src.write2tfrec import write_dst_tfrec, load_tfrecord

cache_dir = "./cache"
os.makedirs(cache_dir, exist_ok=True)

Save the patches into `tfrecord` files. If a file has already been saved, it is used directly.

In [7]:

# Write each tfrecord only once; later runs reuse the cached file.
train_tfrec = os.path.join(cache_dir, "set14_train_48x48.tfrec")
valid_tfrec = os.path.join(cache_dir, "set5_valid_48x48.tfrec")

if not os.path.isfile(train_tfrec):
    write_dst_tfrec(train_dir, 10, 48, train_tfrec)  # 48x48 HR patches

if not os.path.isfile(valid_tfrec):
    write_dst_tfrec(valid_dir, 10, 48, valid_tfrec)

12it [00:00, 37.25it/s]
5it [00:00, 51.69it/s]

The preprocessing function takes an HR patch as input and returns a data pair `(inputs, labels)`, here `(lr, hr)`.
The degradation function used here, `degrade_image`, is predefined in `src/preprocess.py`; customize it if needed.
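For readers who want to write their own degradation, here is a minimal NumPy sketch of one common choice: Gaussian blur, then decimation by the scale factor, then optional additive Gaussian noise. It is not the implementation of `degrade_image`; the names `gaussian_degrade`, `blur`, `gaussian_kernel_1d` and the parameters `sigma` and `noise_std` are hypothetical, and pixel values are assumed to lie in [0, 1]. It keeps the same contract as `preprocess`: take an HR patch, return an `(lr, hr)` pair.

```python
import numpy as np

def gaussian_kernel_1d(sigma: float, radius: int) -> np.ndarray:
    """Normalized 1-D Gaussian weights sampled on [-radius, radius]."""
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()

def blur(img: np.ndarray, sigma: float) -> np.ndarray:
    """Separable Gaussian blur of an H x W x C image, reflect-padded at the borders."""
    radius = max(1, int(round(3 * sigma)))
    k = gaussian_kernel_1d(sigma, radius)
    padded = np.pad(img, ((radius, radius), (radius, radius), (0, 0)), mode="reflect")

    def conv(v):
        # "valid" mode drops the padding: output length = len(v) - 2 * radius
        return np.convolve(v, k, mode="valid")

    rows = np.apply_along_axis(conv, 0, padded)  # filter vertically
    return np.apply_along_axis(conv, 1, rows)    # then horizontally

def gaussian_degrade(hr, scale, sigma=1.2, noise_std=0.0, rng=None):
    """Return an (lr, hr) pair: blur, keep every `scale`-th pixel, add noise."""
    hr = np.asarray(hr, dtype=np.float64)
    lr = blur(hr, sigma)[::scale, ::scale]
    if noise_std > 0:
        if rng is None:
            rng = np.random.default_rng()
        lr = lr + rng.normal(0.0, noise_std, size=lr.shape)
    return np.clip(lr, 0.0, 1.0), hr

patch = np.random.default_rng(0).random((48, 48, 3))  # stand-in HR patch
lr, hr = gaussian_degrade(patch, scale=3, noise_std=0.01)
print(lr.shape, hr.shape)
```

Because this runs in NumPy, using it inside a `tf.data` pipeline would require wrapping it (for example with `tf.numpy_function` in recent TensorFlow versions) or rewriting it with TensorFlow ops.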

In [8]:
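from src.preprocess import degrade_image

def preprocess(hr):
    """Map one HR patch to an (inputs, labels) = (lr, hr) training pair."""
    # `method` picks one of the degradations defined in src/preprocess.py
    lr, hr = degrade_image(hr, SCALE, method=2, restore_shape=False)
    return lr, hr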


Load each tfrecord file and map the `preprocess` function over it.
`repeat()` makes the training dataset repeat indefinitely, so the length of an epoch is set by `steps_per_epoch` in `fit`.
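Because the repeated training dataset never ends, something else has to decide where an epoch stops. The pure-Python analogy below uses `itertools.cycle` in place of `Dataset.repeat()` and a fixed step count in place of `steps_per_epoch`; the helper `epoch_batches` is hypothetical and only illustrates the counting, assuming `fit` batches the dataset with `batch_size`.

```python
from itertools import cycle, islice

def epoch_batches(samples, batch_size, steps_per_epoch, n_epochs):
    """Yield (epoch, batch) pairs drawn from an endlessly repeated sample stream."""
    stream = cycle(samples)  # like dataset.repeat(): never runs out
    for epoch in range(n_epochs):
        for _ in range(steps_per_epoch):
            yield epoch, list(islice(stream, batch_size))

# Toy numbers: 7 samples, batches of 3, 4 steps per epoch, 2 epochs
batches = list(epoch_batches(range(7), batch_size=3, steps_per_epoch=4, n_epochs=2))
print(len(batches))            # 8
print(batches[0], batches[2])  # (0, [0, 1, 2]) (0, [6, 0, 1])
```

Batches simply wrap around the end of the data. With the settings used below (`steps_per_epoch=20`, `batch_size=16`), each epoch would draw 20 × 16 = 320 patches, regardless of how many patches the tfrecord file holds.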

In [9]:
trdst = load_tfrecord(48, "./cache/set14_train_48x48.tfrec").map(
    preprocess, num_parallel_calls=AUTOTUNE).repeat()  # repeat indefinitely
valdst = load_tfrecord(48, "./cache/set5_valid_48x48.tfrec").map(
    preprocess, num_parallel_calls=AUTOTUNE)


## Training a Predefined Model

In [10]:
from src.model import EDSR_baseline

As an example, we train the EDSR-baseline model.
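In the EDSR paper, the upsampling module is built from sub-pixel convolution: a convolution produces `scale² × C` channels, and a pixel shuffle rearranges them into an image `scale` times larger. Whether `src.model.EDSR_baseline` uses exactly this layer is an assumption here. The NumPy sketch below mirrors the channel ordering documented for TensorFlow's `tf.nn.depth_to_space` with NHWC data; `pixel_shuffle` is a hypothetical helper name.

```python
import numpy as np

def pixel_shuffle(x: np.ndarray, scale: int) -> np.ndarray:
    """Rearrange an H x W x (scale*scale*C) array into (H*scale) x (W*scale) x C.

    Output pixel (h*scale + i, w*scale + j, c) comes from input channel
    (i*scale + j)*C + c at position (h, w), the depth_to_space ordering.
    """
    h, w, d = x.shape
    assert d % (scale * scale) == 0, "depth must be divisible by scale**2"
    c = d // (scale * scale)
    x = x.reshape(h, w, scale, scale, c)  # split depth into (i, j, c)
    x = x.transpose(0, 2, 1, 3, 4)        # reorder to (h, i, w, j, c)
    return x.reshape(h * scale, w * scale, c)

feat = np.arange(2 * 2 * 9).reshape(2, 2, 9)  # 2x2 feature map, 9 = 3*3*1 channels
out = pixel_shuffle(feat, 3)
print(out.shape)  # (6, 6, 1)
```

With `scale=3` and 3 output channels, the last convolution before the shuffle would need 27 channels.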

In [11]:
# Build EDSR-baseline for x3 upscaling of 3-channel (RGB) images, trained from scratch
model = EDSR_baseline(scale=SCALE, model_name="EDSR_Baseline",
                      channel=3).create_model(load_weights=False,
                                              weights_path=None)


In [12]:
model.fit(trdst,
          valdst,
          nb_epochs=2,
          steps_per_epoch=20,
          batch_size=16,
          use_wn=False)

Training model : EDSR_Baseline_X3

W0619 16:30:54.574010 12428 tf_logging.py:161] Model failed to serialize as JSON. Ignoring...

Epoch 1/2
Epoch 00001: saving model to ./weights/EDSR_Baseline_X3.h5
Epoch 2/2
Epoch 00002: saving model to ./weights/EDSR_Baseline_X3.h5

<src.model.EDSR.EDSR_baseline at 0x1798ea0fe10>