diff --git a/train_noise_flow.py b/train_noise_flow.py index 6c5d24a..f72586e 100644 --- a/train_noise_flow.py +++ b/train_noise_flow.py @@ -1,10 +1,10 @@ #!/usr/bin/env python +import logging import os import queue import socket import sys import time -import logging from datetime import datetime from os import path from threading import Thread @@ -19,10 +19,10 @@ from mylogger import add_logging_level from sidd.ArgParser import arg_parser from sidd.Initialization import initialize_data_stats_queues_baselines_histograms -from sidd.sidd_utils import sidd_filenames_que_inst, save_visual_minibatch, load_visual_minibatch, restore_last_model, \ - divide_parts, save_minibatch, load_minibatch, \ - divide_array_parts, calc_train_test_stats, print_train_test_stats, restore_epoch_model, sample_sidd_tf, \ - load_minibatches_que, calc_kldiv_mb, kl_div_3_data +from sidd.data_loader import check_download_sidd +from sidd.sidd_utils import sidd_filenames_que_inst, restore_last_model, \ + divide_parts, calc_train_test_stats, print_train_test_stats, sample_sidd_tf, \ + calc_kldiv_mb, kl_div_3_data os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' @@ -215,6 +215,10 @@ def init_params(hps1): def main(hps): + + # Download SIDD_Medium_Raw? + check_download_sidd() + total_time = time.time() host = socket.gethostname() tf.set_random_seed(hps.seed)