In [None]:
import os

# Select the GPU before importing project code: CUDA only honours
# CUDA_VISIBLE_DEVICES if it is set before CUDA is initialised, and it cannot be
# seen from here whether utils.main_function_helpers (or something it imports)
# initialises CUDA at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import json
import traceback

# Star import kept for compatibility; this notebook uses get_args, cli_main and
# cli_main_test from it.
from utils.main_function_helpers import *


In [None]:
#######################################################
# Adjust the following four parameters
#######################################################

# Assign an ID to each experiment (one per entry in config_files)
exp_nums = ['001']
# Pick which configuration file/s to run from the options folder
config_files = ['Fig2_Neigh2Neigh_N100_C128']
# Path to ImageNet train directory
# NOTE(review): machine-specific relative path — adjust for your setup.
path_to_ImageNet_train = '../../../../media/ssd1/ImageNet/ILSVRC/Data/CLS-LOC/'
# Run the best seed or all the seeds
run_which_seeds = 'run_best_seed'  # 'run_all_seeds' or 'run_best_seed'

########################################################

# Sanity checks: fail fast here rather than partway through a long sweep.
if len(config_files) != len(exp_nums):
    raise ValueError("Specify experiment ID for each experiment")
# The run cell treats any value other than 'run_best_seed' as "all seeds", so a
# typo would silently launch every seed; reject unknown values up front.
if run_which_seeds not in ('run_all_seeds', 'run_best_seed'):
    raise ValueError(
        "run_which_seeds must be 'run_all_seeds' or 'run_best_seed', "
        "got {!r}".format(run_which_seeds))

#########################################################
# The parameters below are fixed
training = True
testing = True



In [None]:
def _read_options(config_name):
    """Parse options/<config_name>.txt as JSON and attach the ImageNet path.

    Returns the resulting hyperparameter dict.
    """
    with open("options/{}.txt".format(config_name)) as options_file:
        options = json.load(options_file)
    # Wrapped in a one-element list, presumably to match the other
    # list-valued hyperparameters (the run cell reads e.g. hp['lr'][0]).
    options['path_to_ImageNet_train'] = [path_to_ImageNet_train]
    return options


# One hyperparameter dict per experiment, in the same order as exp_nums.
hps = [_read_options(config_files[idx]) for idx in range(len(exp_nums))]
    


In [None]:
def _record_failure(output_dir, log_name):
    """Append the traceback of the exception being handled to a log and echo it.

    Must be called from inside an ``except`` block. ``output_dir`` must already
    exist; the log is opened in append mode so failures from repeated runs
    accumulate in the same file.
    """
    # Format before opening the log so the traceback is still available
    # (and printed) even if opening the file itself fails.
    error_str = traceback.format_exc()
    with open(os.path.join(output_dir, log_name), "a+") as text_file:
        print(error_str, file=text_file)
    print(error_str)


for ee in range(len(exp_nums)):
    hp = hps[ee]
    # Iterable of (presumably integer) seed/run indices: each one is compared
    # with 0 and used to build the "_run<k>" suffix below.
    if run_which_seeds == 'run_best_seed':
        seeds = hp['best_seed']
    else:
        seeds = hp['all_seeds']

    for rr in seeds:
        # NOTE(review): str(lr)[2:] strips a leading "0." (0.001 -> "001") but
        # mangles values Python prints in exponent form (1e-05 -> "-05"). Kept
        # as-is so names of existing experiment directories still match.
        exp_name = ('E' + exp_nums[ee]
                    + '_t' + str(hp['train_size'][0])
                    + '_l' + str(hp['num_pool_layers'][0])
                    + 'c' + str(hp['chans'][0])
                    + '_bs' + str(hp['batch_size'][0])
                    + '_lr' + str(hp['lr'][0])[2:])
        if rr > 0:
            exp_name = exp_name + '_run{}'.format(rr + 1)
        output_dir = '../' + exp_name
        os.makedirs(output_dir, exist_ok=True)

        # Training. A failure is logged and the sweep continues with the next
        # step; `except Exception` (not a bare except) so Ctrl-C / SystemExit
        # still stop the notebook instead of being logged as a failed run.
        try:
            if training:
                print('\n{} - Training\n'.format(exp_name))
                args = get_args(hp, 0, rr)
                args.output_dir = output_dir
                cli_main(args)
                print('\n{} - Training finished\n'.format(exp_name))
        except Exception:
            _record_failure(output_dir, "errors_train.txt")

        # Testing on the "val" and "test" modes with the "best" restore mode,
        # pinning the test noise range to args.noise_std. Runs even if the
        # training step above failed.
        try:
            if testing:
                print('\n{} - Testing\n'.format(exp_name))
                for test_mode in ("val", "test"):
                    # Previously also iterated "last"; add it back here to
                    # evaluate that restore mode as well.
                    for restore_mode in ("best",):
                        args = get_args(hp, 0, rr)
                        args.output_dir = output_dir
                        args.restore_mode = restore_mode
                        args.test_mode = test_mode
                        args.test_noise_std_min = args.noise_std
                        args.test_noise_std_max = args.noise_std
                        cli_main_test(args)
                print('\n{} - Testing finished\n'.format(exp_name))
        except Exception:
            _record_failure(output_dir, "errors_test.txt")