# In [1]:
# coding: utf-8
from __future__ import print_function
import os, time, random
import tensorflow as tf
from PIL import Image
import numpy as np
from utils import *
from model import *
from glob import glob
import argparse

# Placeholders and tf.compat.v1.Session below require graph mode, so eager
# execution is switched off before anything else is built.
tf.compat.v1.disable_eager_execution()


# Command-line options for decomposition-network training.
parser = argparse.ArgumentParser(description='')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=10,
                    help='number of samples in one batch')
parser.add_argument('--patch_size', dest='patch_size', type=int, default=48,
                    help='patch size')
parser.add_argument('--train_data_dir', dest='train_data_dir',
                    default='./LOLdataset/our485',
                    help='directory for training inputs')
parser.add_argument('--train_result_dir', dest='train_result_dir',
                    default='./decom_net_train_result/',
                    help='directory for decomnet training results')

# parse_known_args collects unrecognised argv entries in `unknown` instead of
# exiting; presumably so this cell also runs inside a notebook kernel, which
# passes its own arguments.
args, unknown = parser.parse_known_args()

batch_size, patch_size = args.batch_size, args.patch_size

sess = tf.compat.v1.Session()

# Image batches: dynamic batch / height / width, 3 channels last.
input_low = tf.compat.v1.placeholder(
    tf.compat.v1.float32, [None, None, None, 3], name='input_low')
input_high = tf.compat.v1.placeholder(
    tf.compat.v1.float32, [None, None, None, 3], name='input_high')

# Decompose each input into an R map and an I map. Whether the two calls share
# weights is decided inside DecomNet (defined elsewhere).
R_low, I_low = DecomNet(input_low)
R_high, I_high = DecomNet(input_high)

# Replicate I three times along the channel axis so it can be multiplied with
# R and compared against the 3-channel inputs (I presumably has 1 channel).
I_low_3 = tf.compat.v1.concat([I_low] * 3, axis=3)
I_high_3 = tf.compat.v1.concat([I_high] * 3, axis=3)

# network outputs fetched during evaluation
output_R_low, output_R_high = R_low, R_high
output_I_low, output_I_high = I_low_3, I_high_3

# define loss

def mutual_i_loss(input_I_low, input_I_high):
    """Mutual smoothness loss between the low- and high-exposure I maps.

    For each direction d in ("x", "y"), the summed gradient
    g = gradient(I_low, d) + gradient(I_high, d) contributes g * exp(-10 * g).
    The x and y contributions are added elementwise and averaged.

    NOTE(review): ``gradient`` comes from utils (not shown). If it returns a
    non-negative per-pixel magnitude, the penalty peaks at g = 0.1 and decays
    for strong edges shared by both maps.
    """
    def _directional_term(direction):
        g = gradient(input_I_low, direction) + gradient(input_I_high, direction)
        return g * tf.compat.v1.exp(-10 * g)

    return tf.compat.v1.reduce_mean(_directional_term("x") + _directional_term("y"))

def mutual_i_input_loss(input_I_low, input_im):
    """Gradient-ratio loss tying an I map to the structure of its input image.

    The input image is converted to grayscale. For each direction d in
    ("x", "y"), the term is |gradient(I, d) / max(gradient(gray, d), 0.01)|.
    The x and y terms are added elementwise and averaged. The 0.01 floor keeps
    the denominator away from zero in flat regions of the input.

    NOTE(review): ``gradient`` comes from utils (not shown). It is assumed to
    return tensors of matching shape for the I map and the grayscale image.
    """
    input_gray = tf.compat.v1.image.rgb_to_grayscale(input_im)

    def _ratio(direction):
        # tf.compat.v1.div is deprecated "in favor of operator or
        # tf.math.divide". For float tensors tf.math.divide gives the same result.
        return tf.compat.v1.abs(tf.math.divide(
            gradient(input_I_low, direction),
            tf.compat.v1.maximum(gradient(input_gray, direction), 0.01)))

    return tf.compat.v1.reduce_mean(_ratio("x") + _ratio("y"))

# Reconstruction: R * I (I replicated to 3 channels) should reproduce each
# input image; measured as mean absolute error.
recon_loss_low = tf.compat.v1.reduce_mean(tf.compat.v1.abs(R_low * I_low_3 - input_low))
recon_loss_high = tf.compat.v1.reduce_mean(tf.compat.v1.abs(R_high * I_high_3 - input_high))

# Consistency: mean absolute difference between the two exposures' R maps.
equal_R_loss = tf.compat.v1.reduce_mean(tf.compat.v1.abs(R_low - R_high))

i_mutual_loss = mutual_i_loss(I_low, I_high)

i_input_mutual_loss_high = mutual_i_input_loss(I_high, input_high)
i_input_mutual_loss_low = mutual_i_input_loss(I_low, input_low)

# Weighted total objective; the terms are summed left to right in this order.
loss_Decom = (
    recon_loss_high
    + recon_loss_low
    + 0.009 * equal_R_loss
    + 0.2 * i_mutual_loss
    + 0.15 * i_input_mutual_loss_high
    + 0.15 * i_input_mutual_loss_low
)

# Optimiser: the learning rate is a placeholder, fed on every training step.
lr = tf.compat.v1.placeholder(tf.compat.v1.float32, name='learning_rate')

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')
# Only trainable variables whose name contains 'DecomNet' are optimised and
# checkpointed.
var_Decom = list(filter(lambda v: 'DecomNet' in v.name,
                        tf.compat.v1.trainable_variables()))
train_op_Decom = optimizer.minimize(loss_Decom, var_list=var_Decom)
sess.run(tf.compat.v1.global_variables_initializer())

saver_Decom = tf.compat.v1.train.Saver(var_list=var_Decom)
print("[*] Initialize model successfully...")

#load data
# Low/high images are paired by position after sorting, so both directories
# are assumed to hold matching filenames (or at least the same sort order).
# TODO confirm for new datasets.
###train_data
train_low_data_names = sorted(glob(os.path.join(args.train_data_dir, 'low', '*.png')))
train_high_data_names = sorted(glob(os.path.join(args.train_data_dir, 'high', '*.png')))
# Explicit check instead of `assert`, which is stripped under `python -O`.
if len(train_low_data_names) != len(train_high_data_names):
    raise ValueError('Mismatched training pairs: %d low vs %d high images'
                     % (len(train_low_data_names), len(train_high_data_names)))
print('[*] Number of training data: %d' % len(train_low_data_names))
train_low_data = [load_images(name) for name in train_low_data_names]
train_high_data = [load_images(name) for name in train_high_data_names]

###eval_data
eval_data_dir = './LOLdataset/eval15'
# Both sides use the same exact '*.png' pattern so stray files (e.g. '*.png.bak')
# cannot shift the pairing.
eval_low_data_name = sorted(glob(os.path.join(eval_data_dir, 'low', '*.png')))
eval_high_data_name = sorted(glob(os.path.join(eval_data_dir, 'high', '*.png')))
if len(eval_low_data_name) != len(eval_high_data_name):
    raise ValueError('Mismatched eval pairs: %d low vs %d high images'
                     % (len(eval_low_data_name), len(eval_high_data_name)))
eval_low_data = [load_images(name) for name in eval_low_data_name]
eval_high_data = [load_images(name) for name in eval_high_data_name]


# Training schedule. The count is named `num_epochs` so the loop variable
# `epoch` no longer shadows it.
num_epochs = 500
learning_rate = 0.0001

sample_dir = args.train_result_dir
os.makedirs(sample_dir, exist_ok=True)

eval_every_epoch = 500
train_phase = 'decomposition'
numBatch = len(train_low_data) // int(batch_size)
train_op = train_op_Decom
train_loss = loss_Decom
saver = saver_Decom

checkpoint_dir = './checkpoint/decom_net_retrain/'
os.makedirs(checkpoint_dir, exist_ok=True)
ckpt = tf.compat.v1.train.get_checkpoint_state(checkpoint_dir)
# Also require a non-empty model path before restoring (the usual TF idiom).
if ckpt and ckpt.model_checkpoint_path:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
else:
    print('No decomnet pretrained model!')


def _save_eval_decomposition(prefix, images, input_ph, output_R, output_I, epoch_num):
    """Run the network on each eval image and save its R / I outputs.

    Files go to ``sample_dir`` as ``<prefix>_<idx>_<epoch_num>.png``, where
    idx is 1-based. Iterating ``images`` directly means the loop always matches
    the list it indexes.
    """
    for idx, image in enumerate(images, start=1):
        result_R, result_I = sess.run([output_R, output_I],
                                      feed_dict={input_ph: np.expand_dims(image, axis=0)})
        save_images(os.path.join(sample_dir, '%s_%d_%d.png' % (prefix, idx, epoch_num)),
                    result_R, result_I)


start_step = 0
start_epoch = 0
iter_num = 0

print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))
start_time = time.time()
image_id = 0
for epoch in range(start_epoch, num_epochs):
    for batch_id in range(start_step, numBatch):
        batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
        batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
        for patch_id in range(batch_size):
            # Same random crop and augmentation mode for both images of a pair.
            h, w, _ = train_low_data[image_id].shape
            x = random.randint(0, h - patch_size)
            y = random.randint(0, w - patch_size)
            rand_mode = random.randint(0, 7)
            batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
            batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
            image_id = (image_id + 1) % len(train_low_data)
            if image_id == 0:
                # One full pass done: reshuffle while keeping pairs together.
                tmp = list(zip(train_low_data, train_high_data))
                random.shuffle(tmp)
                train_low_data, train_high_data = zip(*tmp)

        _, loss = sess.run([train_op, train_loss], feed_dict={input_low: batch_input_low,
                                                              input_high: batch_input_high,
                                                              lr: learning_rate})
        print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f"
              % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
        iter_num += 1

    if (epoch + 1) % eval_every_epoch == 0:
        print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch + 1))
        _save_eval_decomposition('low', eval_low_data, input_low,
                                 output_R_low, output_I_low, epoch + 1)
        _save_eval_decomposition('high', eval_high_data, input_high,
                                 output_R_high, output_I_high, epoch + 1)

    saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))

print("[*] Finish training for phase %s." % train_phase)





# Output of the cell above (captured log; truncated mid-line at the end):
# Instructions for updating:
# Deprecated in favor of operator or tf.math.divide.
# [*] Initialize model successfully...
# [*] Number of training data: 481
# loaded ./checkpoint/decom_net_retrain/model.ckpt
# INFO:tensorflow:Restoring parameters from ./checkpoint/decom_net_retrain/model.ckpt
# [*] Start training for phase decomposition, with start epoch 0 start iter 0 : 
# decomposition Epoch: [ 1] [   1/  48] time: 3.1189, loss: 0.170772
# decomposition Epoch: [ 1] [   2/  48] time: 3.7925, loss: 0.080093
# decomposition Epoch: [ 1] [   3/  48] time: 4.4590, loss: 0.073999
# decomposition Epoch: [ 1] [   4/  48] time: 5.1206, loss: 0.103606
# decomposition Epoch: [ 1] [   5/  48] time: 5.7772, loss: 0.118739
# decomposition Epoch: [ 1] [   6/  48] time: 6.4548, loss: 0.093089
# decomposition Epoch: [ 1] [   7/  48] time: 7.1564, loss: 0.122145
# decomposition Epoch: [ 1] [   8/  48] time: 7.8789, loss: 0.080379
# decomposition Epoch: [ 1] [   9/  48] time: 8.5905, loss: 0.131505
# decomposition Epoch: [ 1] [  10/  48]

## 