-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathtest.py
28 lines (21 loc) · 868 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import sys
import warnings
# Directory two levels above this file, i.e. the parent of the directory that
# contains this script (presumably the repository root holding the `adaseq`
# package — confirm against the repo layout).
parent_folder = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Appended (not prepended), so an `adaseq` already importable from an earlier
# sys.path entry would take precedence over this checkout.
sys.path.append(parent_folder)
# Must come after the sys.path tweak above; the noqa/isort markers stop linters
# from flagging or hoisting this late import.
from adaseq.commands.test import test_model # noqa: E402 isort:skip
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings('ignore')
def main(args):
    """Run model evaluation from parsed command-line options.

    Args:
        args: Parsed command-line namespace. It must expose ``work_dir``,
            ``device`` and ``checkpoint_path`` attributes, which are
            forwarded positionally, in that order, to ``test_model``.

    Returns:
        None. Any value produced by ``test_model`` is discarded.
    """
    work_dir = args.work_dir
    device = args.device
    checkpoint_path = args.checkpoint_path
    test_model(work_dir, device, checkpoint_path)
if __name__ == '__main__':
    # The first positional parameter of ArgumentParser is `prog`, not
    # `description`. Passing this sentence positionally would make it the
    # program name shown in the usage line, so it is passed by keyword.
    parser = argparse.ArgumentParser(description='test with a model checkpoint')
    parser.add_argument(
        '-w', '--work_dir', required=True, help='directory to load config and checkpoint'
    )
    parser.add_argument('-d', '--device', default='gpu', help='device name')
    # Optional; left as None when not given, and None is forwarded unchanged
    # to main().
    parser.add_argument('-ckpt', '--checkpoint_path', default=None, help='model checkpoint')
    args = parser.parse_args()
    main(args)