# Designed to run in Colab

In [None]:
# This program is designed to run in Colab (so, yes, there is no requirements file).
# The code is about as unstructured and inefficient as it gets, but hey, it works... umm... probably! :)
# For further details, please refer to our paper 'Transform Domain Pyramidal Dilated Convolution Networks For Restoration of Under Display Camera Images'.

In [None]:
!nvidia-smi

In [None]:
!pip3 install tensorlayer

# Import Packages

In [None]:
# import packages
import os
import glob
import time
import math
import random
import imageio
import numpy as np
import tensorlayer as tl
import tensorflow as tf
from tensorflow.python.keras.layers import Conv2D, Input, Conv2DTranspose, Concatenate, Add, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.keras.utils import plot_model
import matplotlib.pyplot as plt
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Set Training Config

In [None]:
# Set Defaults
class Configuration:
    """Training settings, read as attributes of the module-level `config` instance."""

    # Training data root: must contain `gt/` and `input/` sub-folders holding
    # ground-truth and degraded PNGs with matching file names.
    png_data = '/content/drive/My Drive/udc/DATASET/POLED/train/'
    # Root folder for every output this notebook writes.
    base_dir = '/content/drive/My Drive/udc/DWT_POLED/'
    save_image_dir = base_dir + 'images'
    save_model_dir = base_dir + 'models'
    log_dir = base_dir + 'log'
    # Steps per epoch; computed once the training set has been loaded.
    steps_total = None
    # (batch size, crop height, crop width, channels) of training batches.
    batch_shape = (10, 512, 512, 3)
    weight_mse_loss = 1
    weight_char_loss = 1
    # Epoch interval for printing progress. Use a high value on Colab so the
    # notebook output does not grow until the browser gets stuck.
    progress_freq = 10
    # Epoch interval for writing sample predictions to <base_dir>/images/<epoch>.
    display_freq = 50
    # Epoch interval for regenerating the training plots in <base_dir>/log.
    plot_training_freq = 20
    # Number of sample predictions written per display interval. It is capped
    # later at the validation batch size, so the effective value may be smaller.
    display_samples = 5


config = Configuration()
for _output_dir in (config.save_model_dir, config.save_image_dir, config.log_dir):
    tf.io.gfile.makedirs(_output_dir)
del _output_dir

# Set Testing Config

In [None]:
class TestConfiguration:
    """Validation settings; `file_count` is filled in after the data is loaded."""

    # Validation data root with the same `gt/` + `input/` layout (matching
    # file names) as the training data.
    png_data = '/content/drive/My Drive/udc/DATASET/POLED/validation/'
    # One full-size image per batch. For the RLQ-TOD 2020 UDC dataset the
    # image height is fixed at 1024; the width is left unspecified.
    batch_shape = (1, 1024, None, 3)
    file_count = None


Tconfig = TestConfiguration()

# Load Dataset

In [None]:
def _load_paired_pngs(root, n_threads):
    """Read the PNGs in ``root + 'gt/'`` and ``root + 'input/'``.

    Returns ``(gt_images, input_images)``: two lists in sorted-file-name order.
    Raises ValueError when the two folders do not hold the same file names,
    because the tf.data pipelines pair gt and input images purely by position.
    """
    # Escaped dot + end anchor: the previous '.*.png' (matched with re.search)
    # also accepted names such as 'a_png' or 'a.png.bak'.
    pattern = r'.*\.png$'
    subfolders = ('gt', 'input')
    names = {
        sub: sorted(tl.files.load_file_list(path=root + sub + '/', regx=pattern, printable=False))
        for sub in subfolders
    }
    if names['gt'] != names['input']:
        raise ValueError(
            'gt/ and input/ under {} must contain the same PNG file names '
            '({} vs {} files)'.format(root, len(names['gt']), len(names['input'])))
    gt_images, input_images = (
        tl.vis.read_images(names[sub], path=root + sub + '/', n_threads=n_threads)
        for sub in subfolders
    )
    return gt_images, input_images


train_gt, train_input = _load_paired_pngs(config.png_data, n_threads=64)
test_gt, test_input = _load_paired_pngs(Tconfig.png_data, n_threads=10)

# Create Pipeline

In [None]:
# This cell will change drastically in a future update
def get_train_data():
  """Build the infinite, shuffled, batched training dataset.

  Each element is an ``(input, gt)`` pair of float32 crops scaled by 1/255,
  taken at the same random location from a degraded image in ``train_input``
  and its ground truth in ``train_gt`` (paired by list position).
  """
  def hq_train():
      yield from train_gt

  def lq_train():
    # Bug fix: this used to iterate `test_gt` (validation ground truth), so the
    # network was never trained on the degraded training inputs.
    yield from train_input

  def random_crop(lq_img, hq_img, hq_crop_size=config.batch_shape[1], scale=1):
      lq_crop_size = hq_crop_size // scale
      lq_img_shape = tf.shape(lq_img)[:2]
      # Random top-left corner such that the crop fits inside the image.
      lq_w = tf.random.uniform(shape=(), maxval=lq_img_shape[1] - lq_crop_size + 1, dtype=tf.int32)
      lq_h = tf.random.uniform(shape=(), maxval=lq_img_shape[0] - lq_crop_size + 1, dtype=tf.int32)

      hq_w = lq_w * scale
      hq_h = lq_h * scale

      lq_img_cropped = lq_img[lq_h:lq_h + lq_crop_size, lq_w:lq_w + lq_crop_size]
      hq_img_cropped = hq_img[hq_h:hq_h + hq_crop_size, hq_w:hq_w + hq_crop_size]
      lq_img_cropped = tf.cast(lq_img_cropped, dtype=tf.float32) / 255.
      hq_img_cropped = tf.cast(hq_img_cropped, dtype=tf.float32) / 255.

      return lq_img_cropped, hq_img_cropped

  train_lq = tf.data.Dataset.from_generator(lq_train, output_types=(tf.uint8))
  train_hq = tf.data.Dataset.from_generator(hq_train, output_types=(tf.uint8))
  train_ds = tf.data.Dataset.zip((train_lq, train_hq))
  train_ds = train_ds.shuffle(buffer_size=len(train_gt), reshuffle_each_iteration=True)
  train_ds = train_ds.map(lambda lq, hq: random_crop(lq, hq, scale=1), num_parallel_calls=AUTOTUNE)

  train_ds = train_ds.repeat()
  train_ds = train_ds.batch(batch_size=config.batch_shape[0])
  train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
  return train_ds

In [None]:
# This cell will change drastically in a future update
def get_test_data():
  """Build the validation dataset of full-size ``(input, gt)`` pairs.

  Images from ``test_input`` / ``test_gt`` are declared float32 to
  ``from_generator``, divided by 255, shuffled each pass and batched with
  ``Tconfig.batch_shape[0]``.
  """
  def hq_test():
      yield from test_gt

  def lq_test():
      yield from test_input

  def _to_unit_range(img):
      return img / 255.

  test_lq = tf.data.Dataset.from_generator(lq_test, output_types=(tf.float32))
  test_lq = test_lq.map(_to_unit_range, num_parallel_calls=AUTOTUNE)
  test_hq = tf.data.Dataset.from_generator(hq_test, output_types=(tf.float32))
  test_hq = test_hq.map(_to_unit_range, num_parallel_calls=AUTOTUNE)
  test_ds = tf.data.Dataset.zip((test_lq, test_hq))
  test_ds = test_ds.shuffle(buffer_size=len(test_gt), reshuffle_each_iteration=True)
  test_ds = test_ds.batch(batch_size=Tconfig.batch_shape[0])
  # Prefetch whole batches; prefetching before batching only buffered single images.
  test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)
  return test_ds

# Execute Pipeline

In [None]:
_pipeline_start = time.time()
train_ds = get_train_data()
trainval_ds = get_test_data()
load_time = time.time() - _pipeline_start

# Never show more sample predictions than a single validation batch holds.
config.display_samples = min(config.display_samples, Tconfig.batch_shape[0])
# Steps per epoch, counting a final partial batch as a full step.
config.steps_total = math.ceil(len(train_gt) / config.batch_shape[0])
Tconfig.file_count = len(test_input)

for _template, _value in (('tf dataset creation time = {}', load_time),
                          ('Steps per epoch = {}', config.steps_total),
                          ('Test file count = {}', Tconfig.file_count)):
    print(_template.format(_value))

# Model Definition

In [None]:
def depth_to_space(scale):
    """Return a callable applying ``tf.nn.depth_to_space`` with block size ``scale``.

    Intended to be wrapped in a Keras ``Lambda`` layer.
    """
    def _apply(tensor):
        return tf.nn.depth_to_space(tensor, scale)
    return _apply

In [None]:
def space_to_depth(block_size):
    """Return a callable applying ``tf.nn.space_to_depth`` with ``block_size``.

    Intended to be wrapped in a Keras ``Lambda`` layer.
    """
    def _apply(tensor):
        return tf.nn.space_to_depth(tensor, block_size)
    return _apply

In [None]:
def conv_relu(x, filters, kernel, use_bias=True, dilation_rate=1):
    """Apply a 'same'-padded ``Conv2D`` with ReLU activation to ``x``.

    ``dilation_rate == 0`` is a special value meaning "plain 1x1 convolution":
    ``kernel`` is then ignored and no dilation is used.
    """
    if dilation_rate == 0:
        kernel, dilation_rate = 1, 1
    layer = tf.keras.layers.Conv2D(filters, kernel, padding='same', use_bias=use_bias,
                                   dilation_rate=dilation_rate, activation='relu')
    return layer(x)


def conv(x, filters, kernel, use_bias=True, dilation_rate=1):
    """Apply a 'same'-padded ``Conv2D`` without activation to ``x``."""
    layer = tf.keras.layers.Conv2D(filters, kernel, padding='same', use_bias=use_bias,
                                   dilation_rate=dilation_rate)
    return layer(x)

In [None]:
class DWT(tf.keras.layers.Layer):
    """Single-level 2-D Haar wavelet transform with no weights.

    Each 2x2 spatial neighbourhood is turned into four sub-bands that are
    concatenated on the channel axis (see ``compute_output_shape``: height and
    width halve, channels quadruple). Inputs are indexed as rank-4 tensors.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs, **kwargs):
        even_rows = inputs[:, 0::2, :, :] / 4.0
        odd_rows = inputs[:, 1::2, :, :] / 4.0
        a = even_rows[:, :, 0::2, :]   # even row, even column
        b = even_rows[:, :, 1::2, :]   # even row, odd column
        c = odd_rows[:, :, 0::2, :]    # odd row, even column
        d = odd_rows[:, :, 1::2, :]    # odd row, odd column
        bands = [
            a + b + c + d,
            a - b + c - d,
            a + b - c - d,
            a - b - c + d,
        ]
        return tf.keras.backend.concatenate(bands, axis=-1)

    def compute_output_shape(self, input_shape):
        channels = input_shape[-1] * 4
        height, width = input_shape[1], input_shape[2]
        if height is None or width is None:
            return (None, None, None, channels)
        return (input_shape[0], height >> 1, width >> 1, channels)

In [None]:
class IWT(tf.keras.layers.Layer):
	"""Inverse of `DWT`: recombine four stacked Haar sub-bands into pixels.

	The synthesis is a fixed 1x1 convolution (sums/differences of the four
	sub-bands of each colour channel) followed by ``tf.nn.depth_to_space(.., 2)``,
	so channels shrink by 4 and height/width double.
	"""

	def __init__(self, **kwargs):
		super(IWT, self).__init__(**kwargs)

	def build(self, input_shape):
		c = int(input_shape[-1])
		out_c = c >> 2
		kernel = np.zeros((1,1,c,c),dtype=np.float32)
		for i in range(0,c,4):
			idx = i >> 2
			# Output group k of depth_to_space lands at 2x2 position k; each group
			# combines the four sub-bands of colour channel idx with Haar signs.
			kernel[0,0,idx::out_c,idx]         = [1, 1, 1, 1]
			kernel[0,0,idx::out_c,idx+out_c]   = [1,-1, 1,-1]
			kernel[0,0,idx::out_c,idx+out_c*2] = [1, 1,-1,-1]
			kernel[0,0,idx::out_c,idx+out_c*3] = [1,-1,-1, 1]
		# Fixed, non-trainable synthesis filter. It used to be a plain tf.Variable
		# assigned as a layer attribute, which Keras tracks as a *trainable*
		# weight, so the optimizer silently altered the inverse transform.
		# Keeping it as a (non-trainable) weight preserves the per-layer weight
		# count, so previously saved hdf5 weight files still load.
		self.kernel = self.add_weight(name='kernel', shape=kernel.shape, dtype='float32',
			initializer='zeros', trainable=False)
		self.kernel.assign(kernel)

	def call(self, inputs, **kwargs):
		y = tf.keras.backend.conv2d(inputs, self.kernel, padding='same')
		y = tf.nn.depth_to_space(y,2)
		return y

	def compute_output_shape(self, input_shape):
		c = input_shape[-1]>>2
		if(input_shape[1] != None and input_shape[2] != None):
			return (input_shape[0], input_shape[1] << 1, input_shape[2] << 1, c)
		else:
			return (None, None, None, c)

In [None]:
def pyramid(x, filters, Pyramid_Cells):
    """Stack densely connected dilated-convolution "pyramid cells" on ``x``.

    Args:
        x: input tensor.
        filters: width of each dilated 3x3 conv inside a cell; the stem conv
            uses ``2 * filters`` and every cell ends with a 1x1 conv to
            ``filters`` channels.
        Pyramid_Cells: sequence of cells, each a non-empty sequence of dilation
            rates (a rate of 0 selects a plain 1x1 conv in ``conv_relu``).

    Returns:
        The single cell's output when there is one cell, otherwise the
        channel-wise concatenation of every cell's output.

    Raises:
        ValueError: if ``Pyramid_Cells`` or any of its cells is empty (this
            used to surface as an UnboundLocalError or an empty Concatenate).
    """
    if len(Pyramid_Cells) == 0 or any(len(cell) == 0 for cell in Pyramid_Cells):
        raise ValueError('Pyramid_Cells must be a non-empty sequence of '
                         'non-empty dilation-rate sequences')

    def pyramid_cell(x, filters, dilation_rates):
        # Dense connectivity: every conv sees the cell input concatenated with
        # the outputs of all previous convs in the cell.
        features = x
        for dilation_rate in dilation_rates:
            t = conv_relu(features, filters, 3, dilation_rate=dilation_rate)
            features = tf.keras.layers.Concatenate(axis=-1)([features, t])
        return features

    cell_outputs = []
    t = conv_relu(x, filters * 2, 3)
    for dilation_rates in Pyramid_Cells:
        # Cells are chained: each one consumes the previous cell's 1x1 output.
        t = pyramid_cell(t, filters, dilation_rates)
        t = conv_relu(t, filters, 1)
        cell_outputs.append(t)
    if len(cell_outputs) == 1:
        return t
    return tf.keras.layers.Concatenate(axis=-1)(cell_outputs)

In [None]:
def encoder(x, nFilters, nPyramids, Pyramid_Cells, nPyramidFilters, type):
    """Downsampling feature extractor built from three residual pyramid stages.

    Args:
        x: input image tensor (presumably (N, H, W, 3) -- TODO confirm; H and W
            likely need to be divisible by 16 on the wavelet path).
        nFilters: channels of the first stage; doubled by each of the two
            stride-2 convolutions.
        nPyramids: unused; kept for call compatibility.
        Pyramid_Cells: dilation-rate layout forwarded to ``pyramid``.
        nPyramidFilters: ``pyramid`` width of the first stage, scaled like
            ``nFilters``.
        type: 'wavelet' applies ``DWT`` first; any other value skips it. (The
            name shadows the builtin but is part of the keyword interface.)

    Returns:
        Feature tensor with ``4 * nFilters`` channels.
    """
    def residual_pyramid(t, n_filters, n_pyramid_filters):
        # Must not be named `pyramid`: the old nested `pyramid` resolved the
        # inner call to itself and failed with a TypeError (missing argument).
        body = pyramid(t, n_pyramid_filters, Pyramid_Cells)
        y = conv(body, n_filters, 3)
        y = tf.keras.layers.Lambda(lambda v: v * 0.1)(y)  # scaled residual branch
        return tf.keras.layers.Add()([t, y])

    if type == 'wavelet':
        x = DWT()(x)
    # tf.keras Lambda (not tensorflow.python.keras) to match the other layers.
    t = tf.keras.layers.Lambda(space_to_depth(block_size=2))(x)
    t = conv_relu(t, nFilters, 5)
    t = conv_relu(t, nFilters, 3)
    t = residual_pyramid(t, nFilters, nPyramidFilters)
    t = tf.keras.layers.Conv2D(nFilters * 2, 5, padding='same', strides=(2, 2), use_bias=True)(t)
    t = residual_pyramid(t, nFilters * 2, nPyramidFilters * 2)
    t = tf.keras.layers.Conv2D(nFilters * 4, 5, padding='same', strides=(2, 2), use_bias=True)(t)
    t = residual_pyramid(t, nFilters * 4, nPyramidFilters * 4)
    return t

In [None]:
def decoder(x, nFilters, nPyramids, Pyramid_Cells, nPyramidFilters, type):
    """Upsampling reconstruction head mirroring ``encoder``.

    Args:
        x: encoder features; a residual add requires them to have ``nFilters``
            channels.
        nFilters: channels of the first stage; halved by each of the two
            stride-2 transposed convolutions (integer division).
        nPyramids: unused; kept for call compatibility.
        Pyramid_Cells: dilation-rate layout forwarded to ``pyramid``.
        nPyramidFilters: ``pyramid`` width of the first stage, scaled like
            ``nFilters``.
        type: 'wavelet' predicts 12 sub-band channels and applies ``IWT``;
            any other value predicts 3 channels directly.

    Returns:
        The restored image tensor.
    """
    def residual_pyramid(t, n_filters, n_pyramid_filters):
        # Must not be named `pyramid`: the old nested `pyramid` resolved the
        # inner call to itself and failed with a TypeError (missing argument).
        body = pyramid(t, n_pyramid_filters, Pyramid_Cells)
        y = conv(body, n_filters, 3)
        y = tf.keras.layers.Lambda(lambda v: v * 0.1)(y)  # scaled residual branch
        return tf.keras.layers.Add()([t, y])

    t = residual_pyramid(x, nFilters, nPyramidFilters)
    # Integer filter counts (the old `/` produced floats such as 32.0).
    t = tf.keras.layers.Conv2DTranspose(nFilters // 2, 4, strides=(2, 2), padding='same', use_bias=True)(t)
    t = residual_pyramid(t, nFilters // 2, nPyramidFilters // 2)
    t = tf.keras.layers.Conv2DTranspose(nFilters // 4, 4, strides=(2, 2), padding='same', use_bias=True)(t)
    t = residual_pyramid(t, nFilters // 4, nPyramidFilters // 4)
    t = tf.keras.layers.Lambda(depth_to_space(scale=2))(t)
    if type == 'wavelet':
        out = conv(t, 3 * 4, 3)
        out = IWT()(out)
    else:
        out = conv(t, 3, 3)
    return out

In [None]:
def get_model(nPyramids, Pyramid_Cells, nFilters, nFilters_deco, nPyramidFilters, nPyramidFilters_deco):
    """Assemble the wavelet encoder/decoder as a Keras model named "generator".

    The input takes 3-channel images with unspecified height and width.
    """
    image_in = tf.keras.layers.Input(shape=(None, None, 3))
    features = encoder(image_in, nFilters, nPyramids, Pyramid_Cells, nPyramidFilters, type='wavelet')
    restored = decoder(features, nFilters_deco, nPyramids, Pyramid_Cells, nPyramidFilters_deco, type='wavelet')
    return tf.keras.Model(image_in, restored, name="generator")

# Training Losses

In [None]:
# Charbonnier Loss
@tf.function(input_signature=(tf.TensorSpec(shape=[None,config.batch_shape[1],config.batch_shape[2],config.batch_shape[3]], dtype=tf.float32),tf.TensorSpec(shape=[None,config.batch_shape[1],config.batch_shape[2],config.batch_shape[3]], dtype=tf.float32)))
def charbonnier(y_true, y_pred):
    """Mean Charbonnier penalty: mean(sqrt((y_true - y_pred)**2 + eps**2)), eps = 1e-3."""
    K = tf.keras.backend
    eps = 1e-3
    return K.mean(K.sqrt(K.square(y_true - y_pred) + K.square(eps)))

In [None]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None,config.batch_shape[1],config.batch_shape[2],config.batch_shape[3]], dtype=tf.float32),tf.TensorSpec(shape=[None,config.batch_shape[1],config.batch_shape[2],config.batch_shape[3]], dtype=tf.float32)))
def calculate_loss(target_batch, gen_out):
    """Return the weighted training loss for one batch (currently MSE only).

    ``MeanSquaredError`` already reduces to a scalar, so no extra mean is taken.
    """
    mse_loss = config.weight_mse_loss * tf.keras.losses.MeanSquaredError()(target_batch, gen_out)
    # To add the Charbonnier term, uncomment the next line and the `+ char_loss` below.
    # char_loss = config.weight_char_loss * charbonnier(target_batch, gen_out)
    return mse_loss  # + char_loss

# Train & Test Step

In [None]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None,config.batch_shape[1],config.batch_shape[2],config.batch_shape[3]], dtype=tf.float32),tf.TensorSpec(shape=[None,config.batch_shape[1],config.batch_shape[2],config.batch_shape[3]], dtype=tf.float32)))
def train_step(input_batch, target_batch):
    """Run one optimisation step on a batch and return its loss.

    Uses the module-level ``model`` and ``optimizer`` created by
    ``instantiate_training_variables``.
    """
    with tf.GradientTape() as tape:
        loss = calculate_loss(target_batch, model(input_batch, training=True))
    variables = model.trainable_weights
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss

@tf.function(input_signature=(tf.TensorSpec(shape=[None,None,None,Tconfig.batch_shape[3]], dtype=tf.float32),tf.TensorSpec(shape=[None,None,None,Tconfig.batch_shape[3]], dtype=tf.float32)))
def test_step(input_batch, target_batch):
    """Predict a validation batch; return ``(prediction, summed PSNR over the batch)``.

    Both prediction and target are converted to uint8 (with saturation) before
    PSNR is computed.
    """
    gen_out = model(input_batch, training=False)

    def as_uint8(images):
        return tf.image.convert_image_dtype(images, tf.dtypes.uint8, saturate=True)

    psnr_per_image = tf.image.psnr(as_uint8(gen_out), as_uint8(target_batch), max_val=255)
    return gen_out, tf.reduce_sum(psnr_per_image)

# Save Model & Prediction

In [None]:
def save_model(traintest_psnr):
  """Save the latest checkpoint and weights, plus the best ones when PSNR does not drop.

  ``ckpt.max_psnr`` is now updated *before* the latest checkpoint is written,
  so that checkpoint records the current best PSNR. Previously it stored the
  stale value and relied on the caller re-saving afterwards.
  """
  is_best = traintest_psnr >= ckpt.max_psnr
  if is_best:
    ckpt.max_psnr.assign(traintest_psnr)
  chkpt_manager_latest.save(checkpoint_number=1)
  model.save_weights(config.save_model_dir+'/latest_model.hdf5')
  if is_best:
    chkpt_manager_best.save(checkpoint_number=1)
    model.save_weights(config.save_model_dir+'/best_model.hdf5')
  return

# All-zero (black) vertical spacer, image height x 20 px x 3 channels, for the
# optional gapped layout in save_sample_predictions.
gap_img = tf.constant(np.zeros([Tconfig.batch_shape[1], 20, 3], dtype=np.float32))


def save_sample_predictions(epoch, input_batch, gen_out, target_batch):
    """Every ``config.display_freq`` epochs, write input|prediction|target strips as PNGs.

    The first ``config.display_samples`` items of the batch are written to
    ``<save_image_dir>/<epoch>/<i>.png``.
    """
    if epoch % config.display_freq != 0:
        return
    tf.io.gfile.makedirs(config.save_image_dir + '/{}'.format(epoch))
    for i in range(config.display_samples):
        # For a gap between panels, concatenate
        # [input_batch[i], gap_img, gen_out[i], gap_img, target_batch[i]] instead;
        # that only works when Tconfig.batch_shape[1] equals the image height.
        panels = [input_batch[i], gen_out[i], target_batch[i]]
        strip = tf.image.convert_image_dtype(tf.concat(panels, axis=1), tf.dtypes.uint8, saturate=True)
        imageio.imwrite(config.save_image_dir + '/{}/{}.png'.format(epoch, i), strip.numpy())
    return

# Instantiate Variables

In [None]:
# Network layout: a single pyramid cell with dilation rates 3, 2, 1, 1, 1, 1.
Pyramid_Cells = ((3, 2, 1, 1, 1, 1),)
nPyramids = 1
# Training state, populated by instantiate_training_variables().
schedule = optimizer = model = None
ckpt = chkpt_manager_latest = chkpt_manager_best = None

In [None]:
def instantiate_training_variables():
    """Create the LR schedule, optimizer, model and checkpoint managers as module globals."""
    global schedule, optimizer, model, ckpt, chkpt_manager_latest, chkpt_manager_best
    steps_per_epoch = config.steps_total
    # 1e-4 for the first 3000 epochs, 0.5e-4 until epoch 5000, 0.2e-4 afterwards.
    schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[3000 * steps_per_epoch, 5000 * steps_per_epoch],
        values=[1e-4, 0.5e-4, 0.2e-4])
    optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
    model = get_model(nPyramids, Pyramid_Cells, nFilters=16, nFilters_deco=64,
                      nPyramidFilters=16, nPyramidFilters_deco=64)
    ckpt = tf.train.Checkpoint(epoch=tf.Variable(0), max_psnr=tf.Variable(0.0), optimizer=optimizer)
    chkpt_manager_latest, chkpt_manager_best = (
        tf.train.CheckpointManager(ckpt, config.log_dir + sub_dir, max_to_keep=1, checkpoint_name='ckpt')
        for sub_dir in ('/latest_chkpt', '/best_chkpt')
    )
    return

# Load Checkpoint

In [None]:
def restore_checkpoint(training_mode):
    """Restore ``ckpt`` state and model weights according to ``training_mode``.

    'best'   -> best-PSNR checkpoint + best_model.hdf5 (no existence check).
    'latest' -> latest checkpoint + latest_model.hdf5 when the latest
                checkpoint directory exists; otherwise keep the fresh state.
    Any other value changes nothing.
    """
    if training_mode == 'best':
        manager, weights_name = chkpt_manager_best, '/best_model.hdf5'
    elif training_mode == 'latest':
        if not tf.io.gfile.exists(config.log_dir + '/latest_chkpt'):
            print("Created new checkpoint")
            return
        manager, weights_name = chkpt_manager_latest, '/latest_model.hdf5'
    else:
        return
    ckpt.restore(manager.latest_checkpoint)
    model.load_weights(config.save_model_dir + weights_name)
    print("Checkpoint restored from epoch {}".format(ckpt.epoch.numpy()))
    return

# Create/Update Log File

In [None]:
def create_log_files():
  """Prepare the loss/accuracy CSV logs for a fresh or resumed run.

  Fresh run (``ckpt.epoch == 0``): (re)create both files with only a header.
  Resumed run: create a missing file with its header; otherwise drop rows
  logged for epochs after the restored ``ckpt.epoch``, which were not
  checkpointed.
  """
  loss_log = config.log_dir+"/log_file_loss.csv"
  accuracy_log = config.log_dir+"/log_file_accuracy.csv"
  headers = {
    loss_log: '{},{},{},{}\n'.format('Epoch','Step','step time','MSE loss'),
    accuracy_log: '{},{}\n'.format('Epoch','traintest_psnr'),
  }

  def write_header(fileName):
    with open(fileName,'w') as f:
      f.write(headers[fileName])

  def trim_log_file(fileName, epoch):
    with open(fileName,'r') as f:
      records = [line.rstrip() for line in f if line.strip()]
    if not records:
      # Empty file: the old code raised IndexError here.
      write_header(fileName)
      return
    # Keep the header (records[0]) and pop trailing rows from later epochs.
    # The old loop parsed the header as an int (ValueError) once every data
    # row was gone, and used list.remove(), which drops the *first* equal line.
    while len(records) > 1 and int(records[-1].split(',')[0]) > epoch:
      records.pop()
    with open(fileName,'w') as f:
      f.write('\n'.join(records)+'\n')

  fresh_run = ckpt.epoch == 0
  for fileName in (loss_log, accuracy_log):
    if fresh_run or not os.path.exists(fileName):
      write_header(fileName)
    else:
      trim_log_file(fileName, ckpt.epoch.numpy())

# Plot Metrics

In [None]:
def plot_training_metrics():
    """Render MSE_loss.png and psnr.png in ``config.log_dir`` from the CSV logs."""
    def plot_column(csv_name, column, ylabel, xlabel, title, png_name):
        with open(config.log_dir + csv_name) as f:
            rows = [line.rstrip().split(",") for line in f]
        values = [float(row[column]) for row in rows[1:]]  # rows[0] is the header
        plt.plot(values)
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        plt.title(title)
        plt.grid()
        plt.savefig(config.log_dir + png_name, dpi=500)
        plt.close()

    plot_column('/log_file_loss.csv', 3, 'loss per step', 'Step', 'MSE loss', '/MSE_loss.png')
    plot_column('/log_file_accuracy.csv', 1, 'PSNR', 'Epoch', 'Accuracy', '/psnr.png')
    return

# Training Routine

In [None]:
def train(training_mode):
  """Train indefinitely, validating and checkpointing at the end of every epoch.

  ``training_mode`` is forwarded to ``restore_checkpoint`` ('latest' or 'best').
  ``train_ds`` repeats forever, so the loop only stops when interrupted. Per-step
  losses are appended to log_file_loss.csv and per-epoch validation PSNR to
  log_file_accuracy.csv.
  """
  # Template for the optional per-step progress print further down.
  print_template = 'Epoch: {}, step: {}, time: {:.3f}s, mse_loss: {:.5f}'

  instantiate_training_variables()
  restore_checkpoint(training_mode)
  create_log_files()
  plot_model(model, to_file = config.save_model_dir+'/g_model.png', show_shapes=True, show_layer_names=True)

  # Checkpoints store the last completed epoch; continue with the next one.
  ckpt.epoch.assign_add(1)
  print('learning rate = {}'.format(ckpt.optimizer._decayed_lr(tf.float32)))
  log_file = open(config.log_dir+"/log_file_loss.csv","a")
  try:
    for step,(input_batch,target_batch) in enumerate(train_ds,start=1):
      epoch_step = step%config.steps_total  # 0 on the last step of each epoch
      step_time = time.time()
      mse_loss = train_step(input_batch,target_batch)
      step_time = time.time()-step_time
      log_file.write('{},{},{},{}\n'.format(ckpt.epoch.numpy(),epoch_step,step_time,mse_loss))
      # print(print_template.format(ckpt.epoch.numpy(),epoch_step,step_time,mse_loss))  # uncomment for stepwise training progress. Be warned, this would result in browser crash after a lot of iterations if you are using jupyter/colab etc.

      if(epoch_step==0):
        log_file.close()
        traintest_psnr = 0.0
        for input_batch,target_batch in trainval_ds:
          gen_out,minibatch_psnr = test_step(input_batch,target_batch)
          traintest_psnr+=minibatch_psnr
        traintest_psnr=traintest_psnr/Tconfig.file_count
        save_model(traintest_psnr)
        save_sample_predictions(ckpt.epoch.numpy(),input_batch,gen_out,target_batch)
        with open(config.log_dir+"/log_file_accuracy.csv","a") as accuracy_file:
          accuracy_file.write('{},{}\n'.format(ckpt.epoch.numpy(),traintest_psnr))
        chkpt_manager_latest.save(checkpoint_number=1)
        model.save_weights(config.save_model_dir+'/latest_model.hdf5')
        print('Checkpoint saved')
        ckpt.epoch.assign_add(1)
        log_file = open(config.log_dir+"/log_file_loss.csv","a")
        if(ckpt.epoch.numpy()%config.plot_training_freq==0):
          plot_training_metrics()
        if(ckpt.epoch.numpy()%config.progress_freq==0):
          print('Epoch = {}, max_psnr = {}, current psnr = {}'.format(ckpt.epoch.numpy(), ckpt.max_psnr.numpy(), traintest_psnr))
  finally:
    # The loop normally ends with a KeyboardInterrupt. IPython keeps the frame
    # (and this handle) alive through the traceback, so without an explicit
    # close, buffered rows could be flushed later, after a resumed run's
    # create_log_files() has already rewritten the file. Closing twice is a no-op.
    log_file.close()

In [None]:
# 'latest' resumes from the latest checkpoint, or starts fresh when none exists
# (always use it when training from scratch). 'best' resumes from the
# best-PSNR checkpoint.
train(training_mode='latest')