Skip to content

Commit

Permalink
Remove requirement of having the type of model as a command line argu…
Browse files Browse the repository at this point in the history
…ment (it will simply be a hyperparameter) since we currently need to be able to run the model script like 'python MY_MODEL_SCRIPT.py' with no arguments
  • Loading branch information
andrew-weisman committed Aug 23, 2021
1 parent cf20f61 commit 721b0f1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions options.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os, sys
import argparse

def parse_training_arguments():
def parse_training_arguments(model_name):
parser = argparse.ArgumentParser()
required_args = parser.add_argument_group('required arguments')
parser.add_argument('-d', '--depth', dest="depth", type=str,
Expand All @@ -13,8 +13,8 @@ def parse_training_arguments():
parser.add_argument('-k', '--kappa', type=str, dest="kappa",
help='How fast to linearly ramp up KL loss',
metavar="kappa", default='1.')
required_args.add_argument("-m", "--model_name", dest="model_name",
help="model name: adage | vae", metavar="model_name")
parser.add_argument("-m", "--model_name", dest="model_name",
help="model name: adage | vae", metavar="model_name", default=model_name)
parser.add_argument('-N', '--noise', type=float, dest="noise",
help='How much Gaussian noise to add during training',
metavar="noise", default=0.05)
Expand Down Expand Up @@ -87,7 +87,7 @@ def parse_prediction_arguments():
args = parser.parse_args()
return args

def parse_command_line_arguments(task):
def parse_command_line_arguments(task, model_name):
if task == "predict":
opt = parse_prediction_arguments()
checkpoint_combined = 'checkpoints/' + ".".join(["tybalt",
Expand All @@ -96,7 +96,7 @@ def parse_command_line_arguments(task):

elif task == "train":
# task == "train"
opt = parse_training_arguments()
opt = parse_training_arguments(model_name)
if not os.path.isdir(os.path.join(os.getcwd(), 'checkpoints')):
os.mkdir(os.path.join(os.getcwd(), 'checkpoints'))
if not opt.model_name:
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def save_training_performance(pd, args):

if __name__ == '__main__':
opt, checkpoint_combined, checkpoint_encoder, checkpoint_decoder = \
parse_command_line_arguments("train")
parse_command_line_arguments("train", candle_params['tybalt_model_name'])

opt, rnaseq_df, train_df, test_df = get_data(opt)

Expand Down

0 comments on commit 721b0f1

Please sign in to comment.