# Training an SRCNN model

In [1]:
import math
import os
from glob import glob

import numpy as np
import tensorflow as tf

# Import data loader
from data_loader import MultipleDataLoader

# Import model
from SRCNN import SRCNN

from training_helpers import clearMSE_metric, compute_loss

# NOTE(review): star import -- score_image_fast (used below) is expected to come
# from here; prefer an explicit import once the module's public names are confirmed.
from supreshelper import *

### The model input stacks the 4 most visible LR images together with their median LR image, so the number of input channels is 5

In [2]:
# Dataset location and model input depth (4 LR images + their median = 5 channels).
data_dir = "DataTFRecords/"
lr_channels = 5
DataLoader = MultipleDataLoader(data_dir)

### Build the `tf.data` input pipeline for training (parsing with augmentation, shuffling, repetition and batching)

In [3]:
batch_size = 4

# List tfrecords files (one per training scene). Sort them so the file order is
# reproducible; raw glob order depends on the filesystem.
train_files = sorted(glob(data_dir + "train/*/*/multiple.tfrecords"))
assert train_files, "No training TFRecords found under " + data_dir + "train/"

# Create a tf dataset of serialized examples
train_dataset = tf.data.TFRecordDataset(train_files)

# Shuffle *before* repeat so that every epoch is one complete, freshly reshuffled
# pass over all scenes. reshuffle_each_iteration only has an effect when the
# shuffled dataset is iterated more than once, which is what repeat() does.
# Shuffling the raw serialized records (rather than parsed tensors) also keeps
# the shuffle buffer small.
train_dataset = train_dataset.shuffle(len(train_files), reshuffle_each_iteration=True)
train_dataset = train_dataset.repeat()

# Map each record to the parsing function, enabling data augmentation. Nothing
# is cached, so the parser runs again on every pass and any random augmentation
# inside it is re-drawn each epoch.
train_dataset = train_dataset.map(
    lambda x: DataLoader.parse_multiple_fixed(x, augment=True, num_lrs=lr_channels),
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

# Set the batch size, then prefetch so input preparation overlaps with training.
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

### Building the SRCNN model

In [4]:
# Build the SRCNN network (without batch normalization) for lr_channels-deep
# inputs and keep only the underlying Keras model.
model = SRCNN(
    channel_dim=lr_channels,
    include_batch_norm=False,
).model

### Define initial learning rate and optimizer

In [5]:
# Adam with an initial learning rate of 1e-4.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

### Define a callback that saves the model after every epoch

In [6]:
# Save the full model after every epoch; Keras fills in {epoch:02d}.
filepath = "Model/SRCNN/{epoch:02d}.hdf5"
# Create the checkpoint directory derived from filepath, so the two cannot drift apart.
os.makedirs(os.path.dirname(filepath), exist_ok=True)

# Keras logs the training loss under "loss"; there is no "train_loss" key.
# `monitor` is only consulted when save_best_only=True, but it must name a
# logged metric for that option to work.
# NOTE(review): reloading these .hdf5 files with tf.keras.models.load_model will
# likely need compute_loss / clearMSE_metric passed via custom_objects.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
)

callbacks_list = [checkpoint]

### Compile and train the model

In [7]:
# Attach the optimizer plus the custom loss and metric from training_helpers.
model.compile(
    optimizer=optimizer,
    loss=compute_loss,
    metrics=[clearMSE_metric],
)

In [8]:
num_epochs = 30

# steps_per_epoch must be an integer number of batches. The dataset repeats
# forever, so round up to make each "epoch" cover every training scene at least once.
steps_per_epoch = math.ceil(len(train_files) / batch_size)

# fit_generator is deprecated in TF 2.x, and model.fit accepts tf.data datasets
# directly. use_multiprocessing is dropped because it only applies to Python
# generators/Sequences; the tf.data pipeline parallelises internally
# (num_parallel_calls). Keep the History for later inspection of the loss curve.
history = model.fit(
    train_dataset,
    steps_per_epoch=steps_per_epoch,
    epochs=num_epochs,
    callbacks=callbacks_list,
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<tensorflow.python.keras.callbacks.History at 0x7f6030046f98>

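### Compute the score on the training set

These are the same scenes the model was trained on, so this score is an optimistic estimate of performance on unseen data.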

In [9]:
# Sort the record files and derive each scene directory from its own record
# path, so that train_scenes[i] is guaranteed to be the scene of train_records[i].
# Two independent globs give no such guarantee: glob order is filesystem-dependent,
# and a "train/*/*/" pattern also matches directories that contain no record.
train_records = sorted(glob(data_dir + "train/*/*/multiple.tfrecords"))
# os.path.join(..., "") keeps the trailing separator that the directory glob produced.
train_scenes = [os.path.join(os.path.dirname(path), "") for path in train_records]

# Evaluation pipeline: same parser as training but without augmentation.
train_score = tf.data.TFRecordDataset(train_records)
train_score = train_score.map(lambda x: DataLoader.parse_multiple_fixed(x, augment=False, num_lrs=lr_channels))
# Batch size 1: each element is scored individually in the next cell.
train_score = train_score.batch(1)

In [11]:
# Super-resolve each training scene and score the result against its scene directory.
# NOTE(review): assumes each multiple.tfrecords file holds exactly one example,
# so the i-th dataset element belongs to train_scenes[i].
scores = []
for (lrs, _hr), scene_dir in zip(train_score, train_scenes):
    sr = model(lrs, training=False)
    # First (only, since batch size is 1) batch element, first channel -> 2-D image.
    scores.append(score_image_fast(sr[0, :, :, 0].numpy(), scene_dir))

# zip stops at the shorter input; make sure every scene was actually scored.
assert len(scores) == len(train_scenes), (len(scores), len(train_scenes))

In [12]:
# Mean score over all scored training scenes.
mean_train_score = np.mean(scores)
print(mean_train_score)

1.0306669966571653
