-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
89 lines (70 loc) · 3.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import yaml
from numpy.distutils.fcompiler import str2bool
from tensorboardX import SummaryWriter
from datasets import *
from solver import Solver
def _load_config(path):
    """Load the YAML config at *path* and return it as a dict.

    Raises:
        RuntimeError: if the file is not valid YAML (chained to the original
            ``yaml.YAMLError``) or its top level is not a mapping.
    """
    with open(path, 'r', encoding='utf-8') as stream:
        try:
            config = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Previously the error was printed and swallowed, which then
            # crashed later with an unrelated UnboundLocalError on `config`.
            raise RuntimeError(f"could not parse config file {path!r}: {exc}") from exc
    if not isinstance(config, dict):
        # An empty file yields None; every later step indexes config by key.
        raise RuntimeError(
            f"config file {path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config


def _make_loader(config, loader_key, dataset):
    """Store *dataset* in ``config[loader_key]`` and build a loader from it.

    ``config[loader_key]`` holds the keyword arguments for ``get_dataloader``;
    it is mutated in place, exactly as the original inline code did.
    """
    config[loader_key]['dataset'] = dataset
    return get_dataloader(**config[loader_key])


def main(args):
    """Run the stages selected on the command line.

    Reads ``./config.yml``, builds a ``Solver`` from it, then, depending on
    the flags in *args*, runs PRM training, filling training and/or demo
    inference, in that order.

    Args:
        args: parsed command-line namespace. The attributes read are
            ``train_prm``, ``train_filling``, ``run_demo``,
            ``train_prm_resume`` and ``train_filling_resume``.

    Raises:
        RuntimeError: if ``./config.yml`` is not valid YAML or is not a
            mapping at the top level.
    """
    # `os` was only reachable through `from datasets import *`; import it
    # explicitly so this function does not depend on that star import.
    import os

    config = _load_config("./config.yml")

    train_logger = SummaryWriter(log_dir=os.path.join(config['log'], 'train'), comment='training')
    # Resume flags are passed to the Solver through the shared config dict.
    config['train_prm_resume'] = args.train_prm_resume
    config['train_filling_resume'] = args.train_filling_resume
    solver = Solver(config)

    train_trans = image_transform(**config['train_transform'])
    config['train_dataset'].update({'transform': train_trans,
                                    'target_transform': None,
                                    'categories': config['class_names']})
    if config.get('val_dataset') is not None:
        # NOTE(review): validation reuses the *training* transform — confirm
        # this is intended rather than config['test_transform'].
        config['val_dataset'].update({'transform': train_trans,
                                      'target_transform': None,
                                      'categories': config['class_names']})

    if args.train_prm:
        config['train_dataset'].update({'train_type': 'prm'})
        data_loader = _make_loader(config, 'data_loaders',
                                   train_dataset(**config['train_dataset']))
        val_data_loader = None
        if config.get('val_dataset') is not None:
            config['val_dataset'].update({'train_type': 'prm'})
            # Validation reuses the training loader kwargs (batch size, etc.).
            val_data_loader = _make_loader(config, 'data_loaders',
                                           train_dataset(**config['val_dataset']))
        solver.train_prm(data_loader, train_logger, val_data_loader)
        print('train prm over')

    if args.train_filling:
        proposals_trans = proposals_transform(**config['train_transform'])
        # Overrides the 'prm' train_type possibly set by the stage above.
        config['train_dataset'].update({
            'train_type': 'filling',
            'target_transform': proposals_trans,
        })
        data_loader = _make_loader(config, 'data_loaders',
                                   train_dataset(**config['train_dataset']))
        solver.train_filling(data_loader, train_logger)
        print('train filling over')

    if args.run_demo:
        test_trans = image_transform(**config['test_transform'])
        config['test_dataset'].update({'image_size': config['test_transform']['image_size'],
                                       'transform': test_trans})
        data_loader = _make_loader(config, 'test_data_loaders',
                                   test_dataset(**config['test_dataset']))
        solver.inference(data_loader)
        print('predict over')
if __name__ == "__main__":
    # Every CLI switch is a boolean parsed by str2bool and defaulting to
    # False. Entries are registered in this order, which fixes --help output.
    # NOTE(review): str2bool comes from numpy.distutils, which is deprecated
    # and unavailable on newer NumPy/Python; consider a local helper.
    _BOOL_FLAGS = (
        (('--train_filling',), 'set train filling mode up'),
        (('--train_prm',), 'set train prm mode up'),
        (('--train_filling_resume',), 'train filling with latest weight'),
        (('--train_prm_resume',), 'train prm with latest weight'),
        (('--run_demo', '-I'), 'run demo'),
    )
    parser = argparse.ArgumentParser()
    for _flag_names, _help_text in _BOOL_FLAGS:
        parser.add_argument(*_flag_names, type=str2bool, default=False, help=_help_text)
    args = parser.parse_args()
    main(args)