In [None]:
import os
from os import path
import numpy as np
import tensorflow as tf
import glob
import logging
import time
import socket
import random
from datetime import datetime
from data_gener.data_generator import process
from models.utils import ResultLogger,hps_logger,get_optimizer
from utils.mylogger import add_logging_level
from models.noise_flow_model import NoiseFlow
from utils.ArgParser import arg_parser
from utils.sidd_utils import init_params,restore_last_model
from models.train_sample import sample,train

In [None]:
# Set TensorFlow's native-logging threshold through its environment variable
# (presumably '3' silences the C++ INFO/WARNING/ERROR output - confirm).
# NOTE(review): `tensorflow` is already imported by the cell above; it is commonly
# recommended to set this variable before the import - confirm it takes effect here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Index lists selecting the training and test subsets; both are handed to
# `process(data_path, <index list>, patch_height)` in the next cell.
# NOTE(review): presumably SIDD scene-instance numbers - confirm against data_generator.
train_index = [4, 11, 13, 17, 18, 20, 22, 23, 25, 27, 28, 29, 30, 34, 35, 39, 40, 42, 43, 44, 45, 47, 81, 86, 88,
               90, 101, 102, 104, 105, 110, 111, 115, 116, 125, 126, 127, 129, 132, 135,
               138, 140, 175, 177, 178, 179, 180, 181, 185, 186, 189, 192, 193, 194, 196, 197]
test_index = [54, 55, 57, 59, 60, 62, 63, 66, 150, 151, 152, 154, 155, 159, 160, 161, 163, 164, 165, 166, 198,
              199]

# Alternative split, kept for reference (inactive):
# train_index = [154, 146, 160, 173, 189, 135, 106, 164, 102, 168, 126, 190, 181,186, 165, 151, 161, 140, 149,  29,  55,  59,   1,   6,  45,   5,
#                 51,   4,  22,   3, 101, 122,  88,  94, 167, 129, 150,   8, 163, 7, 105, 180,  86,  63, 147, 159, 188, 172,  48, 155,  18,  34,
#                 28,  33, 132, 197, 114,  66, 198, 107, 144, 137, 166, 157,  64]
# test_index = [111, 185, 130,  99, 156, 136,  32,  17,  25,  98, 125, 110, 134, 138, 184, 142,  60,  54, 192, 152, 191]

# Dataset root. The original machine-specific location stays the default, but it can
# now be overridden without editing the notebook by exporting SIDD_DATA_PATH.
data_path = os.environ.get('SIDD_DATA_PATH', '/home/lupin/SIDD_Small_Raw_Only/Data')

# Side length of the square patches; also fixes the model input shape
# [None, patch_height, patch_height, 4] built later in the notebook.
patch_height = 32

# Learning-rate decay settings forwarded to `get_optimizer`.
decay_steps = 1000
decay_rate = 0.9

# Number of samples per step; used to derive the per-epoch step counts.
batch_size = 128

In [None]:
# Build the train/test sample collections for the configured index splits.
train_list = process(data_path, train_index, patch_height)
test_list = process(data_path, test_index, patch_height)

# Shuffle both collections in place (training first, then test).
# NOTE(review): no seed for the `random` module is set anywhere in the visible
# notebook, so the resulting order presumably differs between runs - confirm.
random.shuffle(train_list)
random.shuffle(test_list)

# Steps per epoch = ceil(number of samples / batch_size), via integer arithmetic.
train_step = (len(train_list) + batch_size - 1) // batch_size
test_step = (len(test_list) + batch_size - 1) // batch_size

In [None]:
# Parse the run options; this notebook always starts a fresh (non-resumed) run.
hps = arg_parser()
hps.continue_training = False

# Start timestamp (turned into the elapsed time at the very end) and host name
# (shown in the per-epoch progress line).
total_time = time.time()
host = socket.gethostname()

# Seed the NumPy and TensorFlow generators from the configured seed.
np.random.seed(hps.seed)
tf.compat.v1.set_random_seed(hps.seed)

# Experiment directory: experiments/<problem>/<logdir>/ as an absolute path.
# The trailing '/' matters because file names are later appended by plain
# string concatenation (e.g. hps.logdir + 'hps.txt').
logdir = path.abspath(path.join('experiments', hps.problem, hps.logdir)) + '/'
if not path.exists(logdir):
    os.makedirs(logdir, exist_ok=True)

# Keep the short name for display and store the resolved directory on hps.
hps.logdirname = hps.logdir
hps.logdir = logdir

# Custom logger: register a TRACE level with numeric value 100 and make it the
# threshold for both this module's logger and the root configuration.
add_logging_level('TRACE', 100)
logging.getLogger(__name__).setLevel("TRACE")
logging.basicConfig(level=logging.TRACE)

# Input tensor shape: open leading (batch) dimension, then
# patch_height x patch_height x 4.
x_shape = [None, patch_height, patch_height, 4]
hps.x_shape = x_shape
# Number of elements in a single input sample.
hps.n_dims = np.prod(x_shape[1:])

input_shape = x_shape

# --- Build the NoiseFlow graph (TF1-style graph mode) ---
logging.trace('Building NoiseFlow...')
tf.compat.v1.disable_eager_execution()

# Graph inputs, fed at every session run.
is_training = tf.compat.v1.placeholder(dtype=tf.bool, name='is_training')
x = tf.compat.v1.placeholder(dtype=tf.float32, shape=x_shape, name='noise_image')
y = tf.compat.v1.placeholder(dtype=tf.float32, shape=x_shape, name='clean_image')
nlf0 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None], name='nlf0')
nlf1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None], name='nlf1')
iso = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None], name='iso')
shutter = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None], name='shutter')
lr = tf.compat.v1.placeholder(dtype=tf.float32, shape=None, name='learning_rate')

# Signal, gain and camera parameters are only initialised for the 'mix' conditioning.
if hps.sidd_cond == 'mix':
    init_params(hps)

# Model plus its loss tensor and the accompanying sd_z tensor.
nf = NoiseFlow(input_shape[1:], is_training, hps)
loss_val, sd_z = nf.loss(x, y, nlf0=nlf0, nlf1=nlf1, iso=iso, shutter=shutter)

# Dump the list of global variables to <logdir>/model_vars.txt.
vs = tf.compat.v1.global_variables()
vars_files = path.join(hps.logdir, 'model_vars.txt')
with open(vars_files, 'w') as vars_out:
    vars_out.write(str(vs))

# Total element count over all trainable variables.
hps.num_params = int(sum(np.prod(var.shape.as_list())
                         for var in tf.compat.v1.trainable_variables()))
logging.trace('number of parameters = {:d}'.format(hps.num_params))

# Record hyper-parameters, layer names and parameter count next to the other logs.
hps_logger(hps.logdir + 'hps.txt', hps, nf.get_layer_names(), hps.num_params)

# Session that owns all variables for the rest of the notebook.
sess = tf.compat.v1.Session()

# Saver over the variables that exist at this point; max_to_keep=0 keeps every checkpoint.
saver = tf.compat.v1.train.Saver(max_to_keep=0)

# Checkpoints live in <logdir>/ckpt and share the 'model.ckpt' prefix.
ckpt_dir = path.join(hps.logdir, 'ckpt')
ckpt_path = path.join(ckpt_dir, 'model.ckpt')
if not path.exists(ckpt_dir):
    os.makedirs(ckpt_dir, exist_ok=True)

# Sampling temperature falls back to 1.0 when it was not supplied.
if hps.temp is None:
    hps.temp = 1.0

# Result logs written into the experiment directory.
#   NLL: negative log likelihood
#   sdz: standard deviation of the base measure (sanity check)
test_logger = None
log_columns = ['epoch', 'NLL','sdz','train_time']
train_logger = ResultLogger(
    hps.logdir + 'train.txt',
    log_columns,
    hps.continue_training,
)
sample_logger = ResultLogger(
    hps.logdir + 'sample.txt',
    ['sample_time','KLD_G', 'KLD_NLF', 'KLD_NF', 'KLD_R'],
    hps.continue_training,
)
# Reference timestamp for the per-epoch wall-clock measurement in the training loop.
tcurr = time.time()

# First epoch to run; moved past the last checkpointed epoch when resuming.
start_epoch = 1
logging.trace('continue_training = ' + str(hps.continue_training))
if hps.continue_training:
    # Resume: initialise everything, then overwrite with the stored weights.
    sess.run(tf.compat.v1.global_variables_initializer())
    saver.restore(sess, path.join(ckpt_dir, 'model.ckpt.best'))
    last_epoch = restore_last_model(ckpt_dir, sess, saver)
    start_epoch = 1 + last_epoch

    # get_collection() returns a (possibly empty) list and does not raise for a
    # missing key, so an absent train op must be detected by emptiness. The former
    # try/except fallback could never trigger and left `train_op` as [].
    train_op = tf.compat.v1.get_collection('train_op')
    if not train_op:
        logging.trace('could not restore optimizer state, preparing a new optimizer')
        known_var_names = {v.name for v in tf.compat.v1.global_variables()}
        train_op = get_optimizer(hps, lr, loss_val, decay_steps, decay_rate)
        # Initialise only the variables that building the optimizer added (if any),
        # so the weights restored above are left untouched.
        fresh_vars = [v for v in tf.compat.v1.global_variables()
                      if v.name not in known_var_names]
        sess.run(tf.compat.v1.variables_initializer(fresh_vars))
else:
    # Fresh run: build the optimizer first so that its variables are covered by
    # the global initializer.
    logging.trace('preparing optimizer')
    train_op = get_optimizer(hps, lr, loss_val, decay_steps, decay_rate)
    logging.trace('initializing variables')
    sess.run(tf.compat.v1.global_variables_initializer())

# Epochs
logging.trace('Starting training/testing/samplings.')
logging.trace('Logging to ' + hps.logdir)
# Lowest kldiv3[2] seen so far; decides when the '.best' checkpoint is rewritten.
kldiv_best = np.inf

In [None]:
for epoch in range(start_epoch, hps.epochs + 1):
    # Sampling, checkpointing and reporting happen on every epoch below 10, on
    # every 10th epoch below 100, and on every `hps.epochs_full_valid`-th epoch.
    is_report_epoch = (epoch < 10
                       or (epoch < 100 and epoch % 10 == 0)
                       or epoch % hps.epochs_full_valid == 0.)

    # Testing / sampling
    # NOTE(review): the checkpoint for `epoch` is written before this epoch's call to
    # `train`, so it presumably holds the weights from before this epoch's updates and
    # the result of the final epoch is never checkpointed - confirm this is intended.
    if is_report_epoch:
        kldiv3, t_sample = sample(sess, nf, test_list, x, y, iso, shutter, is_training,
                                  batch_size, test_step, hps, sample_logger, epoch)
        saver.save(sess, ckpt_path, global_step=epoch)
        # Best model so far? Judged on the third entry of kldiv3.
        is_best = int(kldiv3[2] < kldiv_best)
        if is_best:
            kldiv_best = kldiv3[2]
            saver.save(sess, ckpt_path + '.best')

    # Training for this epoch.
    sd_z_tr, train_loss, t_train = train(
        sess, train_list, loss_val, sd_z, train_op, x, y, nlf0, nlf1, iso, shutter, lr,
        is_training, batch_size, train_step, hps, train_logger, epoch)

    # Wall-clock time since the previous epoch finished (covers sampling + training).
    t_curr = time.time() - tcurr
    tcurr = time.time()

    tr_l = train_loss
    if is_report_epoch:
        # One progress line per reported epoch:
        #   clock time, host, run name, epoch, train/sample time, T (epoch wall time),
        #   training loss, SDr (std. dev. of base measure in training),
        #   B (1 if best model, 0 otherwise),
        # followed by the comma-separated kldiv3 values: marginal KL divergence of
        # noise samples from Gaussian, camera-NLF, and NoiseFlow, respectively.
        summary = ('%s %s %s Epoch=%d train time=%.1f  sample time=%.1f T=%.1f '
                   'loss=%5.1f SDr=%.1f B=%d' %
                   (str(datetime.now())[11:16], host, hps.logdirname, epoch,
                    t_train, t_sample, t_curr, tr_l, sd_z_tr, is_best))
        kld_text = ','.join(format(kk, '.3f') for kk in kldiv3)
        print(summary, kld_text, flush=True)

# Turn the start timestamp into the elapsed wall-clock time of the whole run,
# then log it and persist it next to the other experiment logs.
total_time = time.time() - total_time
logging.trace('Total time = {:f}'.format(total_time))
with open(os.path.join(hps.logdir, 'total_time.txt'), 'w') as time_file:
    time_file.write('total_time (s) = {:f}'.format(total_time))
logging.trace("Finished!")