diff --git a/options.py b/options.py index 58a78a1..124d949 100644 --- a/options.py +++ b/options.py @@ -1,7 +1,7 @@ import os, sys import argparse -def parse_training_arguments(): +def parse_training_arguments(model_name): parser = argparse.ArgumentParser() required_args = parser.add_argument_group('required arguments') parser.add_argument('-d', '--depth', dest="depth", type=str, @@ -13,8 +13,8 @@ def parse_training_arguments(): parser.add_argument('-k', '--kappa', type=str, dest="kappa", help='How fast to linearly ramp up KL loss', metavar="kappa", default='1.') - required_args.add_argument("-m", "--model_name", dest="model_name", - help="model name: adage | vae", metavar="model_name") + parser.add_argument("-m", "--model_name", dest="model_name", + help="model name: adage | vae", metavar="model_name", default=model_name) parser.add_argument('-N', '--noise', type=float, dest="noise", help='How much Gaussian noise to add during training', metavar="noise", default=0.05) @@ -87,7 +87,7 @@ def parse_prediction_arguments(): args = parser.parse_args() return args -def parse_command_line_arguments(task): +def parse_command_line_arguments(task, model_name): if task == "predict": opt = parse_prediction_arguments() checkpoint_combined = 'checkpoints/' + ".".join(["tybalt", @@ -96,7 +96,7 @@ def parse_command_line_arguments(task): elif task == "train": # task == "train" - opt = parse_training_arguments() + opt = parse_training_arguments(model_name) if not os.path.isdir(os.path.join(os.getcwd(), 'checkpoints')): os.mkdir(os.path.join(os.getcwd(), 'checkpoints')) if not opt.model_name: diff --git a/train.py b/train.py index 5ebc217..58c3954 100755 --- a/train.py +++ b/train.py @@ -47,7 +47,7 @@ def save_training_performance(pd, args): if __name__ == '__main__': opt, checkpoint_combined, checkpoint_encoder, checkpoint_decoder = \ - parse_command_line_arguments("train") + parse_command_line_arguments("train", candle_params['tybalt_model_name']) opt, rnaseq_df, train_df, test_df = get_data(opt)