forked from yang-song/score_sde_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
34 lines (27 loc) · 1.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import run_lib
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
"config", None, "Training configuration.", lock_config=False)
flags.DEFINE_string("checkpoint_path", None, "Checkpoint directory.")
flags.DEFINE_string("data_path", None, "Checkpoint directory.")
flags.DEFINE_string("log_path", "./", "Checkpoint directory.")
flags.DEFINE_enum("mode", "train", ["train", "test", "multi_scale_test", "compute_dataset_statistics", 'evaluation_pipeline'], "Running mode: train or test")
flags.DEFINE_string("eval_folder", "eval",
"The folder name for storing evaluation results")
flags.mark_flags_as_required(["config", "mode", "log_path"])
def main(argv):
if FLAGS.mode == 'train':
run_lib.train(FLAGS.config, FLAGS.log_path, FLAGS.checkpoint_path)
elif FLAGS.mode == 'test':
run_lib.test(FLAGS.config, FLAGS.log_path, FLAGS.checkpoint_path)
elif FLAGS.mode == 'multi_scale_test':
run_lib.multi_scale_test(FLAGS.config, FLAGS.log_path)
elif FLAGS.mode == 'compute_dataset_statistics':
run_lib.compute_data_stats(FLAGS.config)
elif FLAGS.mode == 'evaluation_pipeline':
run_lib.evaluation_pipeline(FLAGS.config)
if __name__ == "__main__":
app.run(main)