-
Notifications
You must be signed in to change notification settings - Fork 17
/
main.py
97 lines (84 loc) · 3.59 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import print_function
from models.Dense3D import Dense3D
import torch
import toml
from training import Trainer
from testengine import Validator
from validation import Validator as Validator2
import torch.nn as nn
import os
from models.net2d import densenet161, resnet152, resnet152_plus, resnet152_R
import warnings
# NOTE: this silences *all* warnings for the whole process, including
# deprecation notices from torch and the project modules.
warnings.filterwarnings("ignore")
print("Loading options...")
# Alternative config for the slice pipeline: 'options_slice.toml'.
# TOML documents are UTF-8 by specification, so decode them explicitly rather
# than relying on the platform's default locale encoding.
with open('options_lip.toml', 'r', encoding='utf-8') as optionsFile:
    options = toml.loads(optionsFile.read())
if options["general"]["usecudnnbenchmark"] and options["general"]["usecudnn"]:
    print("Running cudnn benchmark...")
    torch.backends.cudnn.benchmark = True
# os.environ only accepts strings, but TOML yields an int for `gpuid = 0`
# and a list for `gpuid = [0, 1]`; normalise both to the comma-separated
# form CUDA_VISIBLE_DEVICES expects. Plain string values pass through as-is.
gpu_ids = options["general"]['gpuid']
if isinstance(gpu_ids, (list, tuple)):
    gpu_ids = ','.join(str(gpu) for gpu in gpu_ids)
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_ids)
torch.manual_seed(options["general"]['random_seed'])
# Build the network selected by the [general] section of the options file.
general_opts = options['general']
if general_opts['use_3d']:
    # Dense3D is configured from the full options dict.
    model = Dense3D(options)  # TODO:1
elif not general_opts['use_slice']:
    # Fallback network from models.net2d, called with a literal 2.
    # NOTE(review): presumably the class count (the resnets below receive
    # class_num in that position) — confirm against models.net2d.
    model = densenet161(2)
else:
    if general_opts['use_plus']:
        model = resnet152_plus(general_opts['class_num'],
                               asinput=general_opts['plus_as_input'],
                               USE_25D=general_opts['use25d'])
    else:
        model = resnet152(general_opts['class_num'],
                          USE_25D=general_opts['use25d'])  # vgg19_bn(2)#squeezenet1_1(2)
    # The mere presence of an 'R' key (whatever its value) swaps in the
    # resnet152_R variant, replacing the model constructed just above.
    # NOTE(review): indentation was lost in this copy of the script; the
    # nesting of this check inside the slice branch is reconstructed — confirm.
    if 'R' in general_opts:
        model = resnet152_R(general_opts['class_num'])
if options["general"]["loadpretrainedmodel"]:
    # Deserialize on the CPU: load_state_dict below copies the values into the
    # model's own parameters, so the checkpoint's original device (possibly a
    # GPU that is absent or hidden here) does not matter.
    pretrained_dict = torch.load(options["general"]["pretrainedmodelpath"],
                                 map_location='cpu')
    model_dict = model.state_dict()
    matched = {}
    for key, value in pretrained_dict.items():
        # Checkpoints saved from an nn.DataParallel wrapper prefix every key
        # with 'module.'; strip it when only the bare name exists in the model.
        name = key
        if name not in model_dict and name.startswith('module.'):
            name = name[len('module.'):]
        # Load only weights whose name and shape both match the current model;
        # everything else keeps the model's freshly initialised values.
        if name in model_dict and value.size() == model_dict[name].size():
            matched[name] = value
    print('matched keys:', len(matched))
    model_dict.update(matched)
    model.load_state_dict(model_dict)
# When running with CUDA, seed the GPU generators from the same
# `random_seed` option so GPU-side randomness is reproducible too.
# (Nothing in this section moves the model to a device.)
if options["general"]["usecudnn"]:
    cuda_seed = options["general"]['random_seed']
    torch.cuda.manual_seed(cuda_seed)      # current CUDA device
    torch.cuda.manual_seed_all(cuda_seed)  # every visible CUDA device
# Instantiate only the pipeline stages switched on in the options file.
if options["training"]["train"]:
    trainer = Trainer(options, model)

if options["validation"]["validate"]:
    # 'slice' mode uses validation.Validator (imported as Validator2);
    # every other mode uses testengine.Validator.
    validator_cls = Validator2 if options['general']['mod'] == 'slice' else Validator
    validator = validator_cls(options, 'validation', model,
                              savenpy=options["validation"]["saves"])  # TODO:change mod

if options['test']['test']:
    # NOTE(review): unlike the validator, the tester is built without `model`
    # and receives it when called instead — confirm this matches
    # testengine.Validator's constructor/__call__ signatures.
    tester = Validator(options, 'test')
def _print_accuracy_table(result, re_all):
    """Print the overall accuracy followed by a per-class accuracy table.

    Args:
        result: per-class accuracies, indexed by class number.
        re_all: overall accuracy, printed via ``str``.
    """
    print('-' * 21)
    print('All acc:' + str(re_all))
    print('{:<10}|{:>10}'.format('Cls #', 'Accuracy'))
    # Iterate over whatever classes were reported instead of assuming a fixed
    # count, so the table works for any class_num.
    for cls_idx, acc in enumerate(result):
        print('{:<10}|{:>10}'.format(cls_idx, acc))
    print('-' * 21)


for epoch in range(options["training"]["startepoch"], options["training"]["epochs"]):
    if options["training"]["train"]:
        trainer(epoch)
    if options["validation"]["validate"]:
        result, re_all = validator()
        # Disabled: LR scheduling driven by validation results —
        # trainer.ScheduleLR(result.min())
        print(options['training']['save_prefix'])
        _print_accuracy_table(result, re_all)
    if options['test']['test']:
        result, re_all = tester(model)
        _print_accuracy_table(result, re_all)
# NOTE(review): closes `writer` on the Trainer class itself, not on the
# `trainer` instance; assumes training.Trainer defines it as a class
# attribute — confirm.
Trainer.writer.close()
print(options['training']['save_prefix'])